Commit

Merge pull request #1126 from leloykun:fc--add-muon
PiperOrigin-RevId: 712491634
OptaxDev committed Jan 6, 2025
2 parents 1e08bcc + 6a67850 commit ee25534
Showing 4 changed files with 320 additions and 1 deletion.
8 changes: 8 additions & 0 deletions docs/api/contrib.rst
@@ -25,6 +25,8 @@ Experimental features and algorithms that don't meet the
MomoState
momo_adam
MomoAdamState
muon
MuonState
prodigy
ProdigyState
sam
@@ -84,6 +86,12 @@ Momo
.. autofunction:: momo_adam
.. autoclass:: MomoAdamState

Muon
~~~~
.. autofunction:: muon
.. autofunction:: scale_by_muon
.. autoclass:: MuonState

Prodigy
~~~~~~~
.. autofunction:: prodigy
3 changes: 3 additions & 0 deletions optax/contrib/__init__.py
@@ -35,6 +35,9 @@
from optax.contrib._momo import momo_adam
from optax.contrib._momo import MomoAdamState
from optax.contrib._momo import MomoState
from optax.contrib._muon import muon
from optax.contrib._muon import MuonState
from optax.contrib._muon import scale_by_muon
from optax.contrib._privacy import differentially_private_aggregate
from optax.contrib._privacy import DifferentiallyPrivateAggregateState
from optax.contrib._privacy import dpsgd
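With these re-exports, the new symbols become available from `optax.contrib`. A minimal composition sketch (not part of the diff; it only uses `optax.chain` and `optax.scale_by_learning_rate`, which are existing Optax APIs, together with the `scale_by_muon` defaults added above):

import jax.numpy as jnp
import optax
from optax import contrib

# scale_by_muon only emits the orthogonalized-momentum direction, so chain it
# with a learning-rate transform (which also flips the sign for descent).
optimizer = optax.chain(
    contrib.scale_by_muon(),
    optax.scale_by_learning_rate(1e-2),
)
params = {'w': jnp.zeros((4, 4))}  # 2D only: scale_by_muon expects matrices.
opt_state = optimizer.init(params)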
40 changes: 39 additions & 1 deletion optax/contrib/_common_test.py
@@ -46,6 +46,7 @@
{'opt_name': 'dowg', 'opt_kwargs': {'learning_rate': 1.0}},
{'opt_name': 'momo', 'opt_kwargs': {'learning_rate': 1e-1}},
{'opt_name': 'momo_adam', 'opt_kwargs': {'learning_rate': 1e-1}},
{'opt_name': 'muon', 'opt_kwargs': {'learning_rate': 1e-2}},
{'opt_name': 'prodigy', 'opt_kwargs': {'learning_rate': 1e-1}},
{
'opt_name': 'schedule_free_sgd',
@@ -177,11 +178,48 @@ def obj_fn(params):
return initial_params, final_params, obj_fn


def _setup_matrix_parabola(dtype):
"""Quadratic function as an optimization target with 2D tensor parameters."""
initial_params = jnp.zeros((2, 2), dtype=dtype)
final_params = jnp.array([[3.0, -2.0], [1.0, 4.0]], dtype=dtype)

def obj_fn(params):
return jnp.sum(numerics.abs_sq(params - final_params))

return initial_params, final_params, obj_fn


def _setup_mixed_tensor_target(dtype):
"""Optimization target combining 1D and 2D tensor parameters."""
initial_1d_params = jnp.zeros((3,), dtype=dtype)
final_1d_params = jnp.array([1.0, -1.0, 2.0], dtype=dtype)

initial_2d_params = jnp.zeros((2, 2), dtype=dtype)
final_2d_params = jnp.array([[1.0, 0.0], [-1.0, 1.0]], dtype=dtype)

def obj_fn(params):
"""Objective function combining 1D and 2D parameters."""
params_1d, params_2d = params
loss_1d = jnp.sum(numerics.abs_sq(params_1d - final_1d_params))
loss_2d = jnp.sum(numerics.abs_sq(params_2d - final_2d_params))
return loss_1d + loss_2d

initial_params = (initial_1d_params, initial_2d_params)
final_params = (final_1d_params, final_2d_params)

return initial_params, final_params, obj_fn


class ContribTest(chex.TestCase):

@parameterized.product(
_ALL_OPTIMIZERS_UNDER_TEST,
target=(_setup_parabola, _setup_rosenbrock),
target=(
_setup_parabola,
_setup_rosenbrock,
_setup_matrix_parabola,
_setup_mixed_tensor_target,
),
dtype=('float32',),
)
def test_optimizers(
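For orientation, the loop these parameterized tests run against the new targets looks roughly like the sketch below (a hypothetical standalone version; the real harness and its convergence thresholds live in `_common_test.py`):

import jax
import jax.numpy as jnp
import optax
from optax import contrib

# The 2x2 'matrix parabola' target from _setup_matrix_parabola above.
final_params = jnp.array([[3.0, -2.0], [1.0, 4.0]])
obj_fn = lambda p: jnp.sum((p - final_params) ** 2)

optimizer = contrib.muon(learning_rate=1e-2)
params = jnp.zeros((2, 2))
opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state):
  grads = jax.grad(obj_fn)(params)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  return optax.apply_updates(params, updates), opt_state

for _ in range(1000):
  params, opt_state = step(params, opt_state)
# params should now sit within roughly a learning-rate-sized ball of
# final_params, since the orthogonalized update has a fixed scale.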
270 changes: 270 additions & 0 deletions optax/contrib/_muon.py
@@ -0,0 +1,270 @@
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Muon.
Implementation of the
[Muon optimizer](https://github.com/KellerJordan/modded-nanogpt)
by Keller Jordan
"""


from typing import NamedTuple, Optional, Union

import chex
import jax
import jax.numpy as jnp

from optax import tree_utils as otu
from optax._src import alias
from optax._src import base
from optax._src import combine
from optax._src import numerics
from optax._src import transform
from optax._src import utils


def orthogonalize_via_newton_schulz(
x: jax.Array,
ns_coeffs: jax.Array,
ns_steps: int = 5,
eps: float = 1e-8,
) -> jax.Array:
r"""Orthogonalize via Newton-Schulz iteration.
We opt to use a quintic iteration whose coefficients are selected to maximize
the slope at zero. For the purpose of minimizing steps, it turns out to be
empirically effective to keep increasing the slope at zero even beyond the
point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather
something like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5),
which turns out not to hurt model performance at all relative to UV^T, where
USV^T = G is the SVD.
Args:
x: A matrix to orthogonalize.
ns_coeffs: Coefficients for the Newton-schulz iterators.
Must have shape (n, 3) where n is the number of iterations.
ns_steps: Number of Newton-schulz iterations.
Ignored if `ns_coeffs` is a 2D array.
eps: Term added to denominators to improve numerical stability.
Returns:
The orthogonalized matrix.
"""
if x.ndim != 2:
raise ValueError(f'Input must have shape (m, n), got {x.shape}')
if ns_coeffs.ndim > 2 or ns_coeffs.shape[-1] != 3:
raise ValueError(
'Newton-Schulz coefficients must have shape (3,) or (n, 3), '
f'got {ns_coeffs.shape}'
)
def newton_schulz_iterator(x: jax.Array, coeffs: jax.Array) -> jax.Array:
a = x @ x.T
b = coeffs[1] * a + coeffs[2] * a @ a
return coeffs[0] * x + b @ x
transposed = False
if x.shape[0] > x.shape[1]:
x = x.T
transposed = True
x /= jnp.linalg.norm(x) + eps # Ensure spectral norm is at most 1
ns_coeffs = ns_coeffs.astype(x.dtype)
if ns_coeffs.ndim == 1:
x = jax.lax.fori_loop(
0, ns_steps, lambda _, x: newton_schulz_iterator(x, ns_coeffs), x
)
else:
x, _ = jax.lax.scan(
lambda x, abc: (newton_schulz_iterator(x, abc), None), x, ns_coeffs
)
if transposed:
x = x.T
return x


class MuonState(NamedTuple):
"""State for the Adam algorithm."""
count: chex.Array # shape=(), dtype=jnp.int32.
mu: base.Updates


def scale_by_muon(
ns_coeffs: Union[
tuple[float, float, float],
tuple[tuple[float, float, float], ...],
] = (3.4445, -4.7750, 2.0315),
ns_steps: int = 5,
beta: float = 0.95,
eps: float = 1e-8,
mu_dtype: Optional[chex.ArrayDType] = None,
*,
nesterov: bool = True,
adaptive: bool = False,
) -> base.GradientTransformation:
r"""Rescale updates according to the Muon algorithm.
Muon is a variant of Shampoo that uses the Newton-schulz method to
orthogonalize the momentum accumulated by the optimizer. Mathematically, it
does steepest descent under the Schatten-p norm, for some large p. With
p=infty, it is equivalent to Shampoo without accumulation, or steepest
descent under the Spectral norm.
Args:
ns_coeffs: Coefficients for the Newton-schulz method.
ns_steps: Number of Newton-schulz iterations.
Ignored if `ns_coeffs` is a tuple of tuples.
beta: Decay rate for the exponentially weighted average of grads.
eps: Term added to denominators to improve numerical stability.
mu_dtype: Data type of the momentum accumulator.
nesterov: Whether to use Nesterov momentum.
adaptive: Whether to scale the updates by the dual norm of the
original updates. See https://arxiv.org/abs/2409.20325
Returns:
A `GradientTransformation` object.
References:
Jordan, `modded-nanogpt: Speedrunning the NanoGPT baseline
https://github.com/KellerJordan/modded-nanogpt`_, 2024
Bernstein et al., `Old Optimizer, New Norm: An Anthology
https://arxiv.org/abs/2409.20325`_, 2024
"""
mu_dtype = utils.canonicalize_dtype(mu_dtype)
ns_coeffs_ = jnp.asarray(ns_coeffs)
if ns_coeffs_.ndim > 2 or ns_coeffs_.shape[-1] != 3:
raise ValueError(
f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}'
)

def init_fn(params):
mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment
return MuonState(count=jnp.zeros([], jnp.int32), mu=mu)

def update_fn(updates, state, params=None):
del params
mu = otu.tree_update_moment(updates, state.mu, beta, 1)
count_inc = numerics.safe_increment(state.count)
if nesterov:
mu_hat = jax.tree.map(
lambda m, g: beta * m + (1 - beta) * g,
otu.tree_bias_correction(
mu, beta, numerics.safe_increment(count_inc)
),
otu.tree_bias_correction(updates, beta, count_inc),
)
else:
mu_hat = otu.tree_bias_correction(mu, beta, count_inc)
    # Apply Newton-Schulz orthogonalization to the momentum estimate.
    updates = jax.tree.map(
        lambda x: orthogonalize_via_newton_schulz(x, ns_coeffs_, ns_steps, eps),
        mu_hat,
    )
if adaptive:
# Scale the orthogonalized updates by the dual norm of the original
# updates. See https://arxiv.org/abs/2409.20325 for the derivation.
updates = jax.tree.map(
lambda x, y: jnp.einsum('ij,ij,ab->ab', x, y, y), mu_hat, updates
)
mu = otu.tree_cast(mu, mu_dtype)
return updates, MuonState(count=count_inc, mu=mu)
return base.GradientTransformation(init_fn, update_fn)


def muon(
learning_rate: base.ScalarOrSchedule,
ns_coeffs: Union[
tuple[float, float, float],
tuple[tuple[float, float, float], ...],
] = (3.4445, -4.7750, 2.0315),
ns_steps: int = 5,
beta: float = 0.95,
eps: float = 1e-8,
mu_dtype: Optional[chex.ArrayDType] = None,
*,
nesterov: bool = True,
adaptive: bool = False,
adam_b1: float = 0.9,
adam_b2: float = 0.999,
adam_eps_root: float = 0.0,
adam_weight_decay: float = 0.0,
) -> base.GradientTransformation:
r"""Muon: Momentum Orthogonalized by Newton-schulz.
Muon is a variant of Shampoo that uses the Newton-schulz method to
orthogonalize the momentum accumulated by the optimizer. Mathematically, it
does steepest descent under the Schatten-p norm, for some large p. With
p=infty, it is equivalent to Shampoo without accumulation, or steepest
descent under the Spectral norm.
Note that Muon is currently only defined for 2D parameters, i.e. matrices.
This is because the Newton-Schulz iterator expects a matrix as input.
The non-2D parameters are instead passed through an Adam optimizer.
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
ns_coeffs: Coefficients for the Newton-schulz method.
ns_steps: Number of Newton-schulz iterations.
Ignored if `ns_coeffs` is a tuple of tuples.
beta: Decay rate for the exponentially weighted average of grads.
eps: Term added to the denominator to improve numerical stability.
mu_dtype: Data type of the momentum accumulator.
nesterov: Whether to use Nesterov momentum.
adaptive: Whether to scale the updates by the dual norm of the
original updates. See https://arxiv.org/abs/2409.20325
adam_b1: Exponential decay rate for Adam's first moment estimates.
adam_b2: Exponential decay rate for Adam's second moment estimates.
adam_eps_root: Epsilon to stabilize division in Adam, square root version.
adam_weight_decay: Weight decay factor for Adam.
Returns:
The corresponding `GradientTransformation`.
References:
Jordan, `modded-nanogpt: Speedrunning the NanoGPT baseline
https://github.com/KellerJordan/modded-nanogpt`_, 2024
Bernstein et al., `Old Optimizer, New Norm: An Anthology
https://arxiv.org/abs/2409.20325`_, 2024
"""
return combine.multi_transform(
transforms={
'muon': combine.chain(
scale_by_muon(
ns_coeffs=ns_coeffs,
ns_steps=ns_steps,
beta=beta,
eps=eps,
mu_dtype=mu_dtype,
nesterov=nesterov,
adaptive=adaptive,
),
transform.scale_by_learning_rate(learning_rate),
),
'adam': alias.adamw(
learning_rate=learning_rate,
b1=adam_b1,
b2=adam_b2,
eps=eps,
eps_root=adam_eps_root,
weight_decay=adam_weight_decay,
mu_dtype=mu_dtype,
nesterov=nesterov,
),
},
param_labels=lambda params: jax.tree.map(
lambda x: 'muon' if x.ndim == 2 else 'adam', params
),
)
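A quick numerical sanity check of orthogonalize_via_newton_schulz (an editorial sketch, not part of the committed file; it imports from the private optax.contrib._muon module, which this commit does not re-export):

import jax
import jax.numpy as jnp
from optax.contrib._muon import orthogonalize_via_newton_schulz

g = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
x = orthogonalize_via_newton_schulz(
    g, ns_coeffs=jnp.asarray((3.4445, -4.7750, 2.0315)), ns_steps=5
)
# The quintic iteration pushes singular values toward 1 (roughly into
# [0.5, 1.5] per the docstring), rather than reproducing the spectrum of g.
print(jnp.linalg.svd(x, compute_uv=False))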

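And an end-to-end sketch of the public muon wrapper on a mix of 2D and 1D parameters (hypothetical parameter names; per the multi_transform routing above, matrices take the Muon path and everything else the AdamW path):

import jax
import jax.numpy as jnp
import optax
from optax import contrib

params = {
    'kernel': jnp.ones((8, 4)),  # ndim == 2 -> 'muon' branch
    'bias': jnp.zeros((4,)),     # ndim != 2 -> 'adam' branch
}
optimizer = contrib.muon(learning_rate=1e-2, adam_weight_decay=1e-4)
opt_state = optimizer.init(params)

def loss(p):
  return jnp.sum(p['kernel'] ** 2) + jnp.sum(p['bias'] ** 2)

grads = jax.grad(loss)(params)
# Pass params so the AdamW branch can apply its decoupled weight decay.
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)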