@WhenWen can you fix FAILED tests/test_train_lm.py::test_train_lm - AttributeError: 'AdamConfig' object has no attribute 'nesterov'
Sorry for this, forgot to merge the config file for Nesterov AdamW. It is done now.
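For context on what the missing `nesterov` flag toggles: the Nesterov variant of AdamW blends the bias-corrected momentum with the current gradient before the Adam normalization. A minimal numpy sketch following Dozat's NAdam formulation — the function name and the exact bias-correction scheme are illustrative assumptions, not this repository's code:

```python
import numpy as np

def nadamw_step(p, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999,
                eps=1e-8, weight_decay=0.0):
    # Illustrative NAdamW step (Dozat-style); not this repo's implementation.
    m = b1 * m + (1 - b1) * g            # first moment
    v = b2 * v + (1 - b2) * g * g        # second moment
    m_hat = m / (1 - b1 ** t)            # bias-corrected momentum
    v_hat = v / (1 - b2 ** t)
    # Nesterov look-ahead: mix corrected momentum with the raw gradient
    m_nest = b1 * m_hat + (1 - b1) * g / (1 - b1 ** t)
    p = p - lr * (m_nest / (np.sqrt(v_hat) + eps) + weight_decay * p)
    return p, m, v
```

On a quadratic loss this behaves like Adam with a slight look-ahead in the momentum term.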
import chex

class ScaleByCautiousState(NamedTuple):
    """State for the Mars algorithm."""
Suggested change:
-    """State for the Mars algorithm."""
+    """State for the Cautious algorithm."""
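For readers unfamiliar with the technique behind this state class: the "cautious" modification masks out update coordinates whose sign disagrees with the current gradient, then rescales the survivors. A minimal numpy sketch — the rescaling convention (mean-magnitude preserving) follows the Cautious Optimizers paper, but treat it as an assumption rather than this repo's exact code:

```python
import numpy as np

def cautious_mask(update, grad, eps=1e-8):
    # Keep only coordinates where the proposed update and the gradient
    # agree in sign; rescale so the average update magnitude is preserved.
    mask = (update * grad > 0).astype(update.dtype)
    scale = mask.size / (mask.sum() + eps)
    return update * mask * scale
```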
@dataclass
class ScionConfig(OptimizerConfig):
    """
    Scion optimizer configuration: Momentum Orthogonalized by Newton-Schulz.
is this right?
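(The reviewer's doubt is understandable: "Momentum Orthogonalized by Newton-Schulz" is usually quoted as Muon's expansion, with Scion being a related method.) Either way, the shared core is a Newton-Schulz iteration that replaces the momentum matrix's singular values with 1, i.e. maps G = USVᵀ toward UVᵀ. A sketch of the textbook cubic variant — real Muon implementations use tuned higher-order coefficients, so this is illustrative only:

```python
import numpy as np

def newton_schulz_orthogonalize(G, steps=30):
    # Cubic Newton-Schulz iteration; drives all singular values toward 1.
    X = G / np.linalg.norm(G)   # Frobenius normalization => sing. values <= 1
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X
```

After enough steps, `X @ X.T` is (numerically) the identity for a full-row-rank input.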
def partition(self, tensor):
    """Partition tensor into blocks."""
    print('difference')
rm?
def partition(self, tensor):
    """Partition tensor into blocks.""" 
    print('difference')
    print(tensor.shape, self._shape)
rm
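Aside from the stray debug prints, the idea of a partition helper is standard in block-preconditioned optimizers (SOAP/Shampoo-style): split a tensor into fixed-size blocks so each block gets its own preconditioner. An illustrative 1-D version with zero-padding — the actual repo handles arbitrary shapes, so this is a simplified assumption:

```python
import numpy as np

def partition_blocks(x, block_size):
    """Split a 1-D array into equal blocks, zero-padding the tail."""
    n = x.shape[0]
    pad = (-n) % block_size
    x = np.concatenate([x, np.zeros(pad, dtype=x.dtype)])
    return x.reshape(-1, block_size), n

def merge_blocks(blocks, n):
    """Inverse of partition_blocks: flatten and drop the padding."""
    return blocks.reshape(-1)[:n]
```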
def create_mask(self, params):
    """
    Creates a mask that labels parameters as 'mini' or 'adamw' based on their
doc comment isn't quite accurate
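For readers wondering what such a mask looks like: label-mask trees of this shape are typically built by walking the parameter tree and tagging each leaf with the name of the update rule that should apply to it. A dependency-free sketch over nested dicts — the ndim-based criterion here is an illustrative assumption (which is precisely the kind of detail the reviewer says the docstring gets wrong):

```python
import numpy as np

def create_mask(params):
    # Tag each leaf: matrix-shaped (ndim >= 2) params get the 'mini' rule,
    # everything else falls back to 'adamw'. Criterion is illustrative only.
    if isinstance(params, dict):
        return {k: create_mask(v) for k, v in params.items()}
    return "mini" if getattr(params, "ndim", 0) >= 2 else "adamw"
```

In the real code this role is played by a tree-map over the parameter pytree, and the resulting label tree is handed to a multi-transform combinator.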
class ScaleByCautiousState(NamedTuple):
    """State for the Mars algorithm."""
Suggested change:
-    """State for the Mars algorithm."""
+    """State for the Cautious Adam algorithm."""
    **kwargs,
) -> base.GradientTransformation:
    """
    Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.
maybe put this link in the config class too
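Background for the name: PSGD's "Kron" variant stores its preconditioner in Kronecker-factored form. The reason this is cheap is the standard identity (A ⊗ B) vec(X) = vec(B X Aᵀ) (with column-major vec), so the full Kronecker product never has to be materialized. A numpy check of the identity:

```python
import numpy as np

rng = np.random.default_rng(0)
A, B = rng.normal(size=(3, 3)), rng.normal(size=(4, 4))
X = rng.normal(size=(4, 3))

# vec() is column-major (Fortran order) to match the identity's convention
vec = lambda M: M.reshape(-1, order="F")
lhs = np.kron(A, B) @ vec(X)       # O((mn)^2) if materialized
rhs = vec(B @ X @ A.T)             # O(mn(m+n)) without materializing A (x) B
assert np.allclose(lhs, rhs)
```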
def init_fn(params, return_partition_specs_only=False):
i'm not checking this logic carefully
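For anyone who does want to check it: optax-style transformations are a pair of pure functions — `init` builds optimizer state from the params, `update` maps (updates, state, params) to transformed updates and new state. A dependency-free sketch of that contract, using momentum as the example (optax's real `GradientTransformation` has this same shape; the dict-based "pytree" here is a simplification):

```python
from typing import Callable, NamedTuple

class GradientTransformation(NamedTuple):
    init: Callable
    update: Callable

def scale_by_momentum(decay=0.9):
    def init_fn(params):
        # one momentum buffer per parameter leaf
        return {k: 0.0 for k in params}
    def update_fn(updates, state, params=None):
        new_state = {k: decay * state[k] + updates[k] for k in updates}
        return new_state, new_state   # (transformed updates, new state)
    return GradientTransformation(init_fn, update_fn)
```

The `return_partition_specs_only` flag in the real `init_fn` is a sharding-specific extension of this pattern, not part of the base contract.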
params_sharding_ = jax.tree.map(lambda x: x.spec, params_sharding_)
updates, updates_struct = jax.tree.flatten(updates)
scanned_layers_ = jax.tree.leaves(scanned_layers_)
print(f"kron scanned_layers_: {scanned_layers_}")
rm
scanned_layers_ = jax.tree.leaves(scanned_layers_)
print(f"kron scanned_layers_: {scanned_layers_}")
params_sharding_ = jax.tree.leaves(params_sharding_)
print(f"kron params_sharding_: {params_sharding_}")
rm
Implemented a set of modern optimizers:

- ADOPT (adopt.py)
- Muon (muon.py)
- SCION (scion.py)
- MARS (mars.py)
- Cautious (cautious.py)
- Kron, a variant of PSGD (kron.py)
- RMSProp with Momentum (rmsprop.py)
- SOAP (soap.py)