added wrt option to nnx.Optimizer #3983

Merged (1 commit) on Jun 12, 2024
4 changes: 3 additions & 1 deletion flax/nnx/__init__.py
@@ -103,6 +103,9 @@
from .nnx.spmd import with_sharding_constraint as with_sharding_constraint
from .nnx.state import State as State
from .nnx.training import metrics as metrics
+ from .nnx.variables import (
+   Param as Param,
+ ) # this needs to be imported before optimizer to prevent circular import
from .nnx.training import optimizer as optimizer
from .nnx.training.metrics import Metric as Metric
from .nnx.training.metrics import MultiMetric as MultiMetric
@@ -127,7 +130,6 @@
from .nnx.variables import Cache as Cache
from .nnx.variables import Empty as Empty
from .nnx.variables import Intermediate as Intermediate
- from .nnx.variables import Param as Param
from .nnx.variables import Variable as Variable
from .nnx.variables import VariableState as VariableState
from .nnx.variables import VariableMetadata as VariableMetadata
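The ``Param`` import is moved ahead of the ``optimizer`` import because the new default ``wrt: filterlib.Filter = nnx.Param`` in ``optimizer.py`` is evaluated when that module is imported, so ``Param`` must already be bound on the package at that point. A minimal, self-contained sketch of the underlying Python behavior (hypothetical names, not the actual Flax package layout):

# Default argument values are evaluated once, when the enclosing function is
# defined, i.e. while its module is being imported. If `Param` were not yet
# importable at that point, defining `__init__` below would fail.

class Param:  # stands in for nnx.Param
  pass


class Optimizer:  # hypothetical stand-in, not the real nnx.Optimizer
  def __init__(self, wrt=Param):  # `Param` must already exist at definition time
    self.wrt = wrt


print(Optimizer().wrt)  # <class '__main__.Param'>
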
75 changes: 63 additions & 12 deletions flax/nnx/nnx/training/optimizer.py
@@ -113,42 +113,93 @@ def __init__(
self,
model: nnx.Module,
tx: optax.GradientTransformation,
wrt: filterlib.Filter = nnx.Param,
):
"""
Instantiate the class and wrap the :class:`Module` and Optax gradient
- transformation. Set the step count to 0.
+ transformation. Instantiate the optimizer state to keep track of
+ :class:`Variable` types specified in ``wrt``. Set the step count to 0.

Args:
model: An NNX Module.
tx: An Optax gradient transformation.
wrt: optional argument to filter which :class:`Variable` types to keep
track of in the optimizer state. These should be the :class:`Variable`
types that you plan on updating; i.e. this argument value should match the
``wrt`` argument passed to the ``nnx.grad`` call that generates the
gradients passed into the ``grads`` argument of the :func:`update` method.
"""
self.step = OptState(jnp.array(0, dtype=jnp.uint32))
self.model = model
self.tx = tx
- self.opt_state = tx.init(nnx.state(model, nnx.Param))
+ self.opt_state = tx.init(nnx.state(model, wrt))
+ self.wrt = wrt

def split(self, *filters: filterlib.Filter):
return graph.split(self, *filters)

def update(self, grads):
"""Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
The ``grads`` must be derived from ``nnx.grad(..., wrt=self.wrt)``, where the
gradients are with respect to the same :class:`Variable` types as defined in
``self.wrt`` during instantiation of this ``Optimizer``. For example::

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> import optax

>>> class CustomVariable(nnx.Variable):
... pass

>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.linear = nnx.Linear(2, 3, rngs=rngs)
... self.custom_variable = CustomVariable(jnp.ones((1, 3)))
... def __call__(self, x):
... return self.linear(x) + self.custom_variable
>>> model = Model(rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(model))
State({
'custom_variable': VariableState(
type=CustomVariable,
value=(1, 3)
),
'linear': {
'bias': VariableState(
type=Param,
value=(3,)
),
'kernel': VariableState(
type=Param,
value=(2, 3)
)
}
})

>>> # update:
>>> # - only Linear layer parameters
>>> # - only CustomVariable parameters
>>> # - both Linear layer and CustomVariable parameters
>>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
>>> for variable in (nnx.Param, CustomVariable, (nnx.Param, CustomVariable)):
... # make sure `wrt` arguments match for `nnx.Optimizer` and `nnx.grad`
... state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable)
... grads = nnx.grad(loss_fn, wrt=variable)(
... state.model, jnp.ones((1, 2)), jnp.ones((1, 3))
... )
... state.update(grads=grads)

Note that internally this function calls ``.tx.update()`` followed by a call
to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.

Args:
- grads: Gradients that have the same pytree structure as ``.params``.
- **kwargs: Additional dataclass attributes that should be ``.replace()``-ed.
-
- Returns:
-   An updated instance of ``self`` with ``step`` incremented by one, ``params``
-   and ``opt_state`` updated by applying ``grads``, and additional attributes
-   replaced as specified by ``kwargs``.
+ grads: the gradients derived from ``nnx.grad``.
"""
- params = nnx.state(self.model, nnx.Param)
+ state = nnx.state(self.model, self.wrt)

- updates, new_opt_state = self.tx.update(grads, self.opt_state, params)
- new_params = optax.apply_updates(params, updates)
+ updates, new_opt_state = self.tx.update(grads, self.opt_state, state)
+ new_params = optax.apply_updates(state, updates)
assert isinstance(new_params, nnx.State)

self.step.value += 1
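Taken together, the new ``wrt`` plumbing means the optimizer state can track only a subset of the model's variables. A minimal end-to-end sketch of the intended usage (assuming, as in the test below, that ``nnx.LoRA`` accepts a ``base_module`` and exposes ``nnx.LoRAParam`` variables; not an official recipe):

from flax import nnx
import jax.numpy as jnp
import optax

# LoRA-wrap a small base module; only the LoRA parameters will be optimized.
model = nnx.LoRA(
  in_features=4,
  lora_rank=2,
  out_features=10,
  base_module=nnx.Linear(4, 10, rngs=nnx.Rngs(0)),
  rngs=nnx.Rngs(1),
)

# `wrt` must match between nnx.Optimizer and nnx.grad.
state = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.LoRAParam)
loss_fn = lambda m, x, y: ((m(x) - y) ** 2).mean()
grads = nnx.grad(loss_fn, wrt=nnx.LoRAParam)(
  state.model, jnp.ones((1, 4)), jnp.ones((1, 10))
)
state.update(grads=grads)  # only LoRAParam variables and their opt state change
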
45 changes: 45 additions & 0 deletions flax/nnx/tests/optimizer_test.py
@@ -120,6 +120,51 @@ def update(self, *, grads, **updates): # type: ignore[signature-mismatch]
state.update(grads=grads, values=loss_fn(state.model))
self.assertTrue(state.metrics.compute() < initial_loss)

@parameterized.parameters(
{'variable': nnx.Param},
{'variable': nnx.LoRAParam},
{'variable': (nnx.Param, nnx.LoRAParam)},
)
def test_wrt_update(self, variable):
in_features = 4
out_features = 10
model = nnx.LoRA(
in_features=in_features,
lora_rank=2,
out_features=out_features,
base_module=Model(
in_features=in_features, out_features=out_features, rngs=nnx.Rngs(0)
),
rngs=nnx.Rngs(1),
)
state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable)
prev_variables, prev_other_variables = nnx.state(model, variable, ...)

x = jnp.ones((1, 4))
y = jnp.ones((1, 10))
loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()

grads = nnx.grad(loss_fn, wrt=variable)(state.model, x, y)
initial_loss = loss_fn(model, x, y)
state.update(grads=grads)
self.assertTrue(loss_fn(model, x, y) < initial_loss)

# make sure only the Variables filtered by `wrt` are changed, and the others are unchanged
variables, other_variables = nnx.state(model, variable, ...)
self.assertTrue(
jax.tree.all(
jax.tree.map(lambda x, y: (x != y).all(), prev_variables, variables)
)
)
if other_variables:
self.assertTrue(
jax.tree.all(
jax.tree.map(
lambda x, y: (x == y).all(), prev_other_variables, other_variables
)
)
)


if __name__ == '__main__':
absltest.main()
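
The ``Model`` used as ``base_module`` in the test above is a fixture defined earlier in ``optimizer_test.py`` and not shown in this diff. A plausible stand-in, purely for illustration (the actual fixture may differ):

from flax import nnx

class Model(nnx.Module):
  # Hypothetical stand-in for the test fixture: a single Linear layer
  # mapping `in_features` to `out_features`.
  def __init__(self, in_features, out_features, rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)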