diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index ad955f8d90..575409bd8b 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -103,6 +103,9 @@
 from .nnx.spmd import with_sharding_constraint as with_sharding_constraint
 from .nnx.state import State as State
 from .nnx.training import metrics as metrics
+from .nnx.variables import (
+  Param as Param,
+)  # this needs to be imported before optimizer to prevent circular import
 from .nnx.training import optimizer as optimizer
 from .nnx.training.metrics import Metric as Metric
 from .nnx.training.metrics import MultiMetric as MultiMetric
@@ -127,7 +130,6 @@
 from .nnx.variables import Cache as Cache
 from .nnx.variables import Empty as Empty
 from .nnx.variables import Intermediate as Intermediate
-from .nnx.variables import Param as Param
 from .nnx.variables import Variable as Variable
 from .nnx.variables import VariableState as VariableState
 from .nnx.variables import VariableMetadata as VariableMetadata
diff --git a/flax/nnx/nnx/training/optimizer.py b/flax/nnx/nnx/training/optimizer.py
index 47cd27ad60..88cd8cd774 100644
--- a/flax/nnx/nnx/training/optimizer.py
+++ b/flax/nnx/nnx/training/optimizer.py
@@ -113,42 +113,93 @@ def __init__(
     self,
     model: nnx.Module,
     tx: optax.GradientTransformation,
+    wrt: filterlib.Filter = nnx.Param,
   ):
     """
     Instantiate the class and wrap the :class:`Module` and Optax gradient
-    transformation. Set the step count to 0.
+    transformation. Instantiate the optimizer state to keep track of
+    :class:`Variable` types specified in ``wrt``. Set the step count to 0.
 
     Args:
       model: An NNX Module.
       tx: An Optax gradient transformation.
+      wrt: optional argument to filter for which :class:`Variable`'s to keep
+        track of in the optimizer state. These should be the :class:`Variable`'s
+        that you plan on updating; i.e. this argument value should match the
+        ``wrt`` argument passed to the ``nnx.grad`` call that will generate the
+        gradients that will be passed into the ``grads`` argument of the
+        :func:`update` method.
     """
     self.step = OptState(jnp.array(0, dtype=jnp.uint32))
     self.model = model
     self.tx = tx
-    self.opt_state = tx.init(nnx.state(model, nnx.Param))
+    self.opt_state = tx.init(nnx.state(model, wrt))
+    self.wrt = wrt
 
   def split(self, *filters: filterlib.Filter):
     return graph.split(self, *filters)
 
   def update(self, grads):
     """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
+    The ``grads`` must be derived from ``nnx.grad(..., wrt=self.wrt)``, where the
+    gradients are with respect to the same :class:`Variable` types as defined in
+    ``self.wrt`` during instantiation of this ``Optimizer``. For example::
+
+      >>> from flax import nnx
+      >>> import jax, jax.numpy as jnp
+      >>> import optax
+
+      >>> class CustomVariable(nnx.Variable):
+      ...   pass
+
+      >>> class Model(nnx.Module):
+      ...   def __init__(self, rngs):
+      ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
+      ...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
+      ...   def __call__(self, x):
+      ...     return self.linear(x) + self.custom_variable
+      >>> model = Model(rngs=nnx.Rngs(0))
+      >>> jax.tree.map(jnp.shape, nnx.state(model))
+      State({
+        'custom_variable': VariableState(
+          type=CustomVariable,
+          value=(1, 3)
+        ),
+        'linear': {
+          'bias': VariableState(
+            type=Param,
+            value=(3,)
+          ),
+          'kernel': VariableState(
+            type=Param,
+            value=(2, 3)
+          )
+        }
+      })
+
+      >>> # update:
+      >>> # - only Linear layer parameters
+      >>> # - only CustomVariable parameters
+      >>> # - both Linear layer and CustomVariable parameters
+      >>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
+      >>> for variable in (nnx.Param, CustomVariable, (nnx.Param, CustomVariable)):
+      ...   # make sure `wrt` arguments match for `nnx.Optimizer` and `nnx.grad`
+      ...   state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable)
+      ...   grads = nnx.grad(loss_fn, wrt=variable)(
+      ...     state.model, jnp.ones((1, 2)), jnp.ones((1, 3))
+      ...   )
+      ...   state.update(grads=grads)
 
     Note that internally this function calls ``.tx.update()`` followed by a call
     to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.
 
     Args:
-      grads: Gradients that have the same pytree structure as ``.params``.
-      **kwargs: Additional dataclass attributes that should be ``.replace()``-ed.
-
-    Returns:
-      An updated instance of ``self`` with ``step`` incremented by one, ``params``
-      and ``opt_state`` updated by applying ``grads``, and additional attributes
-      replaced as specified by ``kwargs``.
+      grads: the gradients derived from ``nnx.grad``.
     """
-    params = nnx.state(self.model, nnx.Param)
+    state = nnx.state(self.model, self.wrt)
 
-    updates, new_opt_state = self.tx.update(grads, self.opt_state, params)
-    new_params = optax.apply_updates(params, updates)
+    updates, new_opt_state = self.tx.update(grads, self.opt_state, state)
+    new_params = optax.apply_updates(state, updates)
     assert isinstance(new_params, nnx.State)
 
     self.step.value += 1
diff --git a/flax/nnx/tests/optimizer_test.py b/flax/nnx/tests/optimizer_test.py
index b612ca3b34..d2bcaf609f 100644
--- a/flax/nnx/tests/optimizer_test.py
+++ b/flax/nnx/tests/optimizer_test.py
@@ -120,6 +120,51 @@ def update(self, *, grads, **updates):  # type: ignore[signature-mismatch]
     state.update(grads=grads, values=loss_fn(state.model))
     self.assertTrue(state.metrics.compute() < initial_loss)
 
+  @parameterized.parameters(
+    {'variable': nnx.Param},
+    {'variable': nnx.LoRAParam},
+    {'variable': (nnx.Param, nnx.LoRAParam)},
+  )
+  def test_wrt_update(self, variable):
+    in_features = 4
+    out_features = 10
+    model = nnx.LoRA(
+      in_features=in_features,
+      lora_rank=2,
+      out_features=out_features,
+      base_module=Model(
+        in_features=in_features, out_features=out_features, rngs=nnx.Rngs(0)
+      ),
+      rngs=nnx.Rngs(1),
+    )
+    state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable)
+    prev_variables, prev_other_variables = nnx.state(model, variable, ...)
+
+    x = jnp.ones((1, 4))
+    y = jnp.ones((1, 10))
+    loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
+
+    grads = nnx.grad(loss_fn, wrt=variable)(state.model, x, y)
+    initial_loss = loss_fn(model, x, y)
+    state.update(grads=grads)
+    self.assertTrue(loss_fn(model, x, y) < initial_loss)
+
+    # make sure only the Variable's filtered in `wrt` are changed, and the others are unchanged
+    variables, other_variables = nnx.state(model, variable, ...)
+    self.assertTrue(
+      jax.tree.all(
+        jax.tree.map(lambda x, y: (x != y).all(), prev_variables, variables)
+      )
+    )
+    if other_variables:
+      self.assertTrue(
+        jax.tree.all(
+          jax.tree.map(
+            lambda x, y: (x == y).all(), prev_other_variables, other_variables
+          )
+        )
+      )
+
 
 if __name__ == '__main__':
   absltest.main()
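For reference, a minimal end-to-end sketch of the partial-update pattern this change enables, adapted from the doctest and the new test above. The ``nnx.Linear`` base module and the hyperparameters are illustrative stand-ins for the test's local ``Model`` helper; the rest uses only the calls exercised in this diff::

  from flax import nnx
  import jax.numpy as jnp
  import optax

  # Wrap a linear layer with LoRA; only the LoRA factors should be trained.
  model = nnx.LoRA(
    in_features=4,
    lora_rank=2,
    out_features=10,
    base_module=nnx.Linear(4, 10, rngs=nnx.Rngs(0)),
    rngs=nnx.Rngs(1),
  )

  # Keep optimizer state only for LoRAParam variables via the new `wrt` argument.
  optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.LoRAParam)

  loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
  x, y = jnp.ones((1, 4)), jnp.ones((1, 10))

  # The `wrt` filter passed to nnx.grad must match the optimizer's `wrt`.
  grads = nnx.grad(loss_fn, wrt=nnx.LoRAParam)(optimizer.model, x, y)
  optimizer.update(grads=grads)  # LoRA params change; the base module's Params stay untouched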