Add Modern Optimizers in Levanter by WhenWen · Pull Request #955 · stanford-crfm/levanter

Open · wants to merge 5 commits into main

Conversation

WhenWen (Contributor) commented Apr 28, 2025

Implemented a list of modern optimizers.

  1. ADOPT

    • Implemented in adopt.py
    • Reference: Taniguchi et al. (2024). ADOPT: Modified Adam Can Converge with Any β₂ with the Optimal Rate. arXiv:2411.02853
  2. Muon

  3. SCION

    • Implemented in scion.py
    • Reference: Pethick et al. Training Deep Learning Models with Norm-Constrained LMOs. arXiv:2502.07529
  4. MARS

    • Implemented in mars.py
    • Reference: Yuan, H., Liu, Y., Wu, S., Zhou, X., & Gu, Q. (2024). MARS: Unleashing the Power of Variance Reduction for Training Large Models. arXiv:2411.10438
  5. Cautious

    • Implemented in cautious.py
    • Reference: Liang, K., Chen, L., Liu, B., & Liu, Q. (2024). Cautious Optimizers: Improving Training with One Line of Code. arXiv:2411.16085 (a minimal sketch of this masking idea follows the list)
  6. Kron (Variant of PSGD)

  7. RMSProp with Momentum

    • Implemented in rmsprop.py
    • Reference: Tieleman, T., & Hinton, G. (2012). RMSProp: Divide the gradient by a running average of its recent magnitude
  8. SOAP

    • Implemented in soap.py
    • Reference: Vyas, N., Morwani, D., Zhao, R., Kwun, M., Shapira, I., Brandfonbrener, D., … & Kakade, S. (2024). SOAP: Improving and Stabilizing Shampoo using Adam. arXiv:2409.11321
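
For context on item 5, the Cautious modification is essentially a one-line mask applied on top of any base optimizer: update components whose sign disagrees with the current gradient are zeroed out, and the remainder is rescaled. The following is a minimal optax-style sketch of that idea, not the PR's cautious.py; the function names here are illustrative.

import jax
import jax.numpy as jnp
import optax


def cautious_mask(updates, grads):
    """Zero update components whose sign disagrees with the gradient, then
    rescale so the average update magnitude stays roughly unchanged."""
    def _mask(u, g):
        m = (u * g > 0).astype(u.dtype)             # 1 where update and gradient agree in sign
        scale = m.size / jnp.maximum(m.sum(), 1.0)  # compensate for the zeroed entries
        return u * m * scale
    return jax.tree.map(_mask, updates, grads)


def cautious(base: optax.GradientTransformation) -> optax.GradientTransformation:
    """Wrap any base transform (e.g. optax.scale_by_adam()) with the cautious mask."""
    def update_fn(grads, state, params=None):
        updates, state = base.update(grads, state, params)
        return cautious_mask(updates, grads), state
    return optax.GradientTransformation(base.init, update_fn)


# Illustrative usage: cautious Adam with a fixed learning rate.
optimizer = optax.chain(cautious(optax.scale_by_adam()), optax.scale(-3e-4))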

WhenWen requested a review from dlwh April 28, 2025 17:00
dlwh (Member) commented Apr 29, 2025

@WhenWen can you fix

FAILED tests/test_train_lm.py::test_train_lm - AttributeError: 'AdamConfig' object has no attribute 'nesterov'
FAILED tests/test_train_lm.py::test_train_lm_fp8 - AttributeError: 'AdamConfig' object has no attribute 'nesterov'

WhenWen (Contributor, Author) commented Apr 30, 2025

> @WhenWen can you fix
>
> FAILED tests/test_train_lm.py::test_train_lm - AttributeError: 'AdamConfig' object has no attribute 'nesterov'
> FAILED tests/test_train_lm.py::test_train_lm_fp8 - AttributeError: 'AdamConfig' object has no attribute 'nesterov'

Sorry about this; I forgot to merge the config file for Nesterov AdamW. It is done now.
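
For readers, here is a rough sketch of what the new `nesterov` switch maps to, assuming the config ultimately builds an optax optimizer (recent optax releases expose a `nesterov` keyword on `optax.adamw`). `build_adamw` is an illustrative helper, not Levanter's actual config code; only the `nesterov` field name is taken from the test failure above.

import optax


def build_adamw(learning_rate: float, weight_decay: float, nesterov: bool = False):
    """AdamW, optionally with the Nesterov-style momentum correction (NAdamW)."""
    return optax.adamw(
        learning_rate=learning_rate,
        b1=0.9,
        b2=0.95,
        weight_decay=weight_decay,
        nesterov=nesterov,  # requires an optax version that supports this keyword
    )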

import chex

class ScaleByCautiousState(NamedTuple):
"""State for the Mars algorithm."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""State for the Mars algorithm."""
"""State for the Cautious algorithm."""

@dataclass
class ScionConfig(OptimizerConfig):
"""
Scion optimizer configuration: Momentum Orthogonalized by Newton-Schulz.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this right?
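
Side note for this thread: "Momentum Orthogonalized by Newton-Schulz" is the usual expansion of Muon, while the Scion paper frames its update as a norm-constrained LMO; for matrix parameters both rely on the same Newton-Schulz orthogonalization primitive. Below is a rough JAX sketch of that primitive with the commonly used quintic coefficients; it is not the code under review.

import jax.numpy as jnp


def newton_schulz_orthogonalize(G: jnp.ndarray, steps: int = 5) -> jnp.ndarray:
    """Approximately map the singular values of G toward 1, i.e. replace G by a
    near-orthogonal matrix with the same row/column space (quintic Newton-Schulz)."""
    a, b, c = 3.4445, -4.7750, 2.0315          # coefficients popularized by Muon implementations
    X = G / (jnp.linalg.norm(G) + 1e-7)        # Frobenius norm bounds the spectral norm, so this is safe
    transposed = X.shape[0] > X.shape[1]
    if transposed:                             # iterate on the short side for efficiency
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X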


def partition(self, tensor):
"""Partition tensor into blocks."""
print('difference')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rm?

def partition(self, tensor):
"""Partition tensor into blocks."""
print('difference')
print(tensor.shape, self._shape)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rm


def create_mask(self, params):
"""
Creates a mask that labels parameters as 'mini' or 'adamw' based on their
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc comment isn't quite accurate
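
For orientation, here is a generic sketch of the pattern this docstring describes: a labeling mask that routes parameters to different transforms via optax.multi_transform. The 'mini'/'adamw' labels come from the snippet above; the 2-D selection rule and the per-branch optimizers are assumptions for illustration, not the PR's actual logic.

import jax
import optax


def create_label_mask(params):
    """Assign each parameter a label: matrix-shaped parameters go to the 'mini'
    branch, everything else to 'adamw' (the ndim == 2 rule is illustrative)."""
    return jax.tree.map(lambda p: "mini" if p.ndim == 2 else "adamw", params)


# Illustrative wiring: a different transform per label.
optimizer = optax.multi_transform(
    {"mini": optax.sgd(0.02, momentum=0.95), "adamw": optax.adamw(3e-4)},
    create_label_mask,
)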



class ScaleByCautiousState(NamedTuple):
"""State for the Mars algorithm."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""State for the Mars algorithm."""
"""State for the Cautious Adam algorithm."""

**kwargs,
) -> base.GradientTransformation:
"""
Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe put this link in the config class too
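
A sketch of what that suggestion might look like, assuming the PR defines a KronConfig analogous to the ScionConfig shown earlier; the class name, import path, and field below are illustrative, not the PR's actual code.

from dataclasses import dataclass

from levanter.optim.config import OptimizerConfig  # import path assumed


@dataclass
class KronConfig(OptimizerConfig):
    """Configuration for PSGD Kron.

    Reference implementation: https://github.com/lixilinx/psgd_torch
    """

    # Illustrative field; the PR's actual fields may differ.
    preconditioner_update_probability: float = 0.1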

Comment on lines +268 to +269

def init_fn(params, return_partition_specs_only=False):
dlwh (Member): i'm not checking this logic carefully

params_sharding_ = jax.tree.map(lambda x: x.spec, params_sharding_)
updates, updates_struct = jax.tree.flatten(updates)
scanned_layers_ = jax.tree.leaves(scanned_layers_)
print(f"kron scanned_layers_: {scanned_layers_}")
dlwh (Member): rm

scanned_layers_ = jax.tree.leaves(scanned_layers_)
print(f"kron scanned_layers_: {scanned_layers_}")
params_sharding_ = jax.tree.leaves(params_sharding_)
print(f"kron params_sharding_: {params_sharding_}")
dlwh (Member): rm
