Add trimmed Linen to NNX guide #4209
Conversation
docs_nnx/guides/linen_to_nnx.rst
Outdated
* **Lazy vs. eager**: Linen modules only allocate space to create variables when they actually see their input. Whereas NNX module instances create their variables the moment they are instantiated, without seeing a sample input.

* This means Linen can ``@nn.compact`` decorator to define a model with only one method, wheras NNX modules must have both ``__init__`` and ``__call__`` defined. This also means that the input shape must be explicitly passed during module creation because the parameter shapes cannot be inferred from the input.
Some rewording to improve accuracy.
Suggested change:
- * This means Linen can ``@nn.compact`` decorator to define a model with only one method, wheras NNX modules must have both ``__init__`` and ``__call__`` defined. This also means that the input shape must be explicitly passed during module creation because the parameter shapes cannot be inferred from the input.
+ * Linen can use the ``@nn.compact`` decorator to define the model in a single method and use shape inference from the input sample, whereas NNX modules generally request additional shape information to create all parameters during ``__init__`` and separately define the computation in ``__call__``.
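To make the contrast concrete, here is a minimal sketch of the two styles (the class names and layer sizes are illustrative, not taken from the guide):

```python
import flax.linen as nn
from flax import nnx

class LinenMLP(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    # Lazy: the kernel shape is inferred from `x` at the first call.
    return nn.Dense(self.features)(x)

class NNXMLP(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    # Eager: parameters are created here, so the input size must be given explicitly.
    self.linear = nnx.Linear(din, dout, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)
```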
docs_nnx/guides/linen_to_nnx.rst
Outdated
* Linen uses ``@jax.jit`` to compile the training step, whereas NNX uses ``@nnx.jit``. ``jax.jit`` only accepts pure stateless arguments, but ``nnx.jit`` allows the arguments to be stateful NNX modules. This greatly reduced the number of lines needed for a train step.

* Similarly, Linen uses ``jax.grad()`` to return a raw dictionary of gradients, wheras NNX can use ``nnx.grad`` to return an NNX model state of gradients. In NNX, here you need to use the split/merge API
Suggested change:
- * Similarly, Linen uses ``jax.grad()`` to return a raw dictionary of gradients, wheras NNX can use ``nnx.grad`` to return an NNX model state of gradients. In NNX, here you need to use the split/merge API
+ * Similarly, Linen uses ``jax.grad()`` to return a raw dictionary of gradients, whereas NNX can use ``nnx.grad`` to return the gradients of Modules as NNX ``State`` dictionaries. To use regular ``jax.grad`` with NNX you need to use the split/merge API.
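For readers following along, a rough sketch of the two routes being described, assuming ``model`` is an ``nnx.Module`` and ``inputs``/``labels`` already exist:

```python
import jax
import optax
from flax import nnx

def loss_fn(model):
  logits = model(inputs)
  return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

# nnx.grad differentiates the stateful module directly and returns an nnx.State of gradients.
grads = nnx.grad(loss_fn)(model)

# Plain jax.grad needs the split/merge API to work on a pure pytree instead.
graphdef, state = nnx.split(model)

def pure_loss(state):
  return loss_fn(nnx.merge(graphdef, state))  # rebuild the module inside the traced function

grads = jax.grad(pure_loss)(state)  # gradient structure mirrors the State
```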
docs_nnx/guides/linen_to_nnx.rst
Outdated
* Similarly, Linen uses ``jax.grad()`` to return a raw dictionary of gradients, wheras NNX can use ``nnx.grad`` to return an NNX model state of gradients. In NNX, here you need to use the split/merge API

* If you are already using Optax optimizer classes like ``optax.adam(1e-3)`` and use its ``update()`` method to update your model params (instead of the raw ``jax.tree.map`` computation like here), check out `nnx.Optimizer example <https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms>`__ for a much more concise way of training and updating your model.
Suggested change:
- * If you are already using Optax optimizer classes like ``optax.adam(1e-3)`` and use its ``update()`` method to update your model params (instead of the raw ``jax.tree.map`` computation like here), check out `nnx.Optimizer example <https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms>`__ for a much more concise way of training and updating your model.
+ * If you are already using Optax optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here), check out the `nnx.Optimizer example <https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms>`__ for a much more concise way of training and updating your model.
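For reference, a hedged sketch of what that ``nnx.Optimizer`` route can look like (the exact ``update()`` signature may differ between Flax versions; ``model``, ``inputs`` and ``labels`` are assumed to exist):

```python
import optax
from flax import nnx

optimizer = nnx.Optimizer(model, optax.adamw(1e-3))  # wraps the model and the optax state

@nnx.jit
def train_step(optimizer, inputs, labels):
  def loss_fn(model):
    logits = model(inputs)
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(optimizer.model)
  optimizer.update(grads)  # applies the optax update and mutates the params in place
  return loss
```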
docs_nnx/guides/linen_to_nnx.rst
Outdated
def loss_fn(model):
  logits = model(
    inputs,  # <== inputs
  )
  return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
Could be more appealing to show it like this to contrast how much simpler NNX is in this case.
Suggested change:
- def loss_fn(model):
-   logits = model(
-     inputs,  # <== inputs
-   )
-   return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
+ def loss_fn(model):
+   logits = model(inputs)
+   return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
docs_nnx/guides/linen_to_nnx.rst
Outdated
* ``nn.BatchNorm`` creates ``batch_stats`` -> ``nnx.BatchNorm`` creates ``nnx.BatchStats``.

* ``linen.Module.sow()`` creates ``intermediates`` -> ``nnx.Module.sow()`` creates ``nnx.Intermediates``. You can also simply get the intermediates by assigning it to a module attribute, like ``self.sowed = nnx.Intermediates(x)``.
Suggested change:
- * ``nn.BatchNorm`` creates ``batch_stats`` -> ``nnx.BatchNorm`` creates ``nnx.BatchStats``.
- * ``linen.Module.sow()`` creates ``intermediates`` -> ``nnx.Module.sow()`` creates ``nnx.Intermediates``. You can also simply get the intermediates by assigning it to a module attribute, like ``self.sowed = nnx.Intermediates(x)``.
+ * ``nn.Dense`` creates ``params`` -> ``nnx.Linear`` creates ``nnx.Param``.
+ * ``nn.BatchNorm`` creates ``batch_stats`` -> ``nnx.BatchNorm`` creates ``nnx.BatchStat``.
+ * ``linen.Module.sow()`` creates ``intermediates`` -> ``nnx.Module.sow()`` creates ``nnx.Intermediate``. You can also simply get the intermediates by assigning it to a module attribute, like ``self.sowed = nnx.Intermediate(x)``.
BTW: in Linen this is also true, you can simply add an intermediate via ``self.variable('intermediates', 'sowed', lambda: x)``.
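A small sketch of both sowing styles mentioned above (feature sizes and names are illustrative):

```python
import flax.linen as nn
from flax import nnx

class LinenBlock(nn.Module):
  @nn.compact
  def __call__(self, x):
    y = nn.Dense(8)(x)
    self.sow('intermediates', 'y', y)  # stored in the 'intermediates' collection
    # Linen also allows: self.variable('intermediates', 'y', lambda: y)
    return y

class NNXBlock(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 8, rngs=rngs)

  def __call__(self, x):
    y = self.linear(x)
    self.sow(nnx.Intermediate, 'y', y)  # or simply: self.y = nnx.Intermediate(y)
    return y
```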
@nn.compact
Add new line to align __call__s.
Suggested change:
- @nn.compact
+
+ @nn.compact
docs_nnx/guides/linen_to_nnx.rst
Outdated
x = self.batchnorm(x, use_running_average=not training)
x = jax.nn.relu(x)
Increment count
Suggested change:
- x = self.batchnorm(x, use_running_average=not training)
- x = jax.nn.relu(x)
+ x = self.batchnorm(x, use_running_average=not training)
+ self.count.value += 1
+ x = jax.nn.relu(x)
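For context, a hedged sketch of the kind of module this suggestion assumes, with a mutable counter variable (the ``Count`` type and feature size are illustrative, not from the guide):

```python
import jax
import jax.numpy as jnp
from flax import nnx

class Count(nnx.Variable):  # a custom Variable type for plain mutable state
  pass

class Block(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.batchnorm = nnx.BatchNorm(64, rngs=rngs)
    self.count = Count(jnp.array(0))

  def __call__(self, x, *, training: bool):
    x = self.batchnorm(x, use_running_average=not training)
    self.count.value += 1  # mutate state directly, as in the suggestion above
    return jax.nn.relu(x)
```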
docs_nnx/guides/linen_to_nnx.rst
Outdated
Using Multiple Methods
==========

In this section we will take a look at how to use multiple methods in all three
Suggested change:
- In this section we will take a look at how to use multiple methods in all three
+ In this section we will take a look at how to use multiple methods in both
docs_nnx/guides/linen_to_nnx.rst
Outdated
In NNX, we define a scan function ``scan_fn`` that will use the ``RNNCell`` defined
in ``__init__`` to scan over the sequence.
Add some explanation for nnx.scan
Suggested change:
- In NNX, we define a scan function ``scan_fn`` that will use the ``RNNCell`` defined
- in ``__init__`` to scan over the sequence.
+ In NNX, we define a scan function ``scan_fn`` that will use the ``RNNCell`` defined
+ in ``__init__`` to scan over the sequence, and explicitly set ``in_axes=(nnx.Carry, None, 1)``:
+ ``nnx.Carry`` means that the ``carry`` argument will be the carry, ``None`` means that ``cell`` will
+ be broadcast to all steps, and ``1`` means ``x`` will be scanned across axis 1.
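A short sketch of that call, assuming ``cell`` is the ``RNNCell``, ``carry0`` is its initial carry, and ``cell(carry, x)`` returns the new carry plus a per-step output:

```python
from flax import nnx

def scan_fn(carry, cell, x):
  carry, y = cell(carry, x)  # one RNN step
  return carry, y

carry, ys = nnx.scan(
  scan_fn,
  in_axes=(nnx.Carry, None, 1),  # carry is the carry, cell is broadcast, x is scanned over axis 1
  out_axes=(nnx.Carry, 1),       # stack the per-step outputs back along axis 1
)(carry0, cell, x)
```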
docs_nnx/guides/linen_to_nnx.rst
Outdated
Scan over Layers
==========

In general, lifted transforms of Linen and NNX should look the same. However, NNX lifted transforms is designed to be closer to their lower level JAX counterparts, and thus we throw away some assumptions in certain Linen lifted transforms. This scan-over-layers use case will be a good example to showcase it.
Suggested change:
- In general, lifted transforms of Linen and NNX should look the same. However, NNX lifted transforms is designed to be closer to their lower level JAX counterparts, and thus we throw away some assumptions in certain Linen lifted transforms. This scan-over-layers use case will be a good example to showcase it.
+ In general, lifted transforms of Linen and NNX should look the same. However, NNX lifted transforms are designed to be closer to their lower level JAX counterparts, and thus we throw away some assumptions in certain Linen lifted transforms. This scan-over-layers use case will be a good example to showcase it.
docs_nnx/guides/linen_to_nnx.rst
Outdated
In Linen, we apply a ``nn.scan`` upon the module ``Block`` to create a larger module ``ScanBlock`` that contains 5 ``Block``. It will automatically create a large parameter of shape ``(5, 64, 64)`` at initialization time, and at call time iterate over every ``(64, 64)`` slice for a total of 5 times, like a ``jax.lax.scan`` would.

But if you think closely, there actually wans't any ``jax.lax.scan`` operation at initialization time. What happened there is more like a ``jax.vmap`` operation - you are given a ``Block`` that accepts ``(in_dim, out_dim)``, and you "vmap" it over ``num_layers`` of times to create a larger array.
Suggested change:
- But if you think closely, there actually wans't any ``jax.lax.scan`` operation at initialization time. What happened there is more like a ``jax.vmap`` operation - you are given a ``Block`` that accepts ``(in_dim, out_dim)``, and you "vmap" it over ``num_layers`` of times to create a larger array.
+ But if you look closely, there actually isn't any need for a ``jax.lax.scan`` operation at initialization time. What happens there is more like a ``jax.vmap`` operation - you are given a ``Block`` that accepts ``(in_dim, out_dim)``, and you "vmap" it ``num_layers`` times to create a larger array.
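A hedged sketch of the Linen side of this pattern, assuming ``Block(in_dim, out_dim)`` is a Linen module whose ``__call__(carry, _)`` returns ``(carry, None)``:

```python
import flax.linen as nn

class ScanBlock(nn.Module):
  @nn.compact
  def __call__(self, x):
    ScannedBlock = nn.scan(
      Block,
      variable_axes={'params': 0},  # stack each layer's params along axis 0 -> (5, 64, 64)
      split_rngs={'params': True},  # give each layer its own init RNG
      length=5,
    )
    x, _ = ScannedBlock(64, 64)(x, None)  # scan x through the 5 stacked Blocks
    return x
```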
docs_nnx/guides/linen_to_nnx.rst
Outdated
But if you think closely, there actually wans't any ``jax.lax.scan`` operation at initialization time. What happened there is more like a ``jax.vmap`` operation - you are given a ``Block`` that accepts ``(in_dim, out_dim)``, and you "vmap" it over ``num_layers`` of times to create a larger array.

This means that in NNX, since model initialization and running code are completely decoupled, we need to use ``nnx.vmap`` to initialize the underlying blocks, and then use ``nnx.scan`` to run the model input through them.
Suggested change:
- This means that in NNX, since model initialization and running code are completely decoupled, we need to use ``nnx.vmap`` to initialize the underlying blocks, and then use ``nnx.scan`` to run the model input through them.
+ In NNX we take advantage of the fact that model initialization and running code are completely decoupled, and instead use ``nnx.vmap`` to initialize the underlying blocks, and ``nnx.scan`` to run the model input through them.
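And a hedged sketch of the NNX side described above, again assuming a ``Block(in_dim, out_dim, rngs=...)`` module and a batch of inputs ``x``; the ``nnx.split_rngs`` decorator used here is explained in the next comment:

```python
from flax import nnx

@nnx.split_rngs(splits=5)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_blocks(rngs: nnx.Rngs):
  return Block(64, 64, rngs=rngs)  # "vmapped" init: params stacked to shape (5, 64, 64)

@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def forward(x, block):
  return block(x)  # apply one (64, 64) slice of the stacked Block per scan step

blocks = create_blocks(nnx.Rngs(0))
y = forward(x, blocks)
```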
docs_nnx/guides/linen_to_nnx.rst
Outdated
There are a few other details to explain in this example:

* **What is that `nnx.split_rngs` decorator?** This is because ``jax.vmap`` and ``jax.lax.scan`` requires a list of RNG keys if each of its internal operations needs its own key. So for the 5 layers inside ``MLP``, it will split and provide 5 different RNG keys from its arguments before going into the JAX transform.
Added explanation of why nnx.split_rngs is necessary.
Suggested change:
- * **What is that `nnx.split_rngs` decorator?** This is because ``jax.vmap`` and ``jax.lax.scan`` requires a list of RNG keys if each of its internal operations needs its own key. So for the 5 layers inside ``MLP``, it will split and provide 5 different RNG keys from its arguments before going into the JAX transform.
+ * **What is that ``nnx.split_rngs`` decorator?** NNX transforms are completely agnostic of RNG state; this makes them behave more like JAX transforms but diverge from the Linen transforms, which do handle RNG state. To regain this functionality, the ``nnx.split_rngs`` decorator allows you to split the ``Rngs`` before passing them to the decorated function and 'lower' them afterwards so they can be used outside.
+ * This is needed because ``jax.vmap`` and ``jax.lax.scan`` require a list of RNG keys if each of their internal operations needs its own key. So for the 5 layers inside ``MLP``, it will split and provide 5 different RNG keys from its arguments before going into the JAX transform.
Force-pushed from 138633f to fe2e586.
All comments adopted - thanks @cgarciae!
Extracted the Linen/NNX part from the original Haiku/Linen/NNX guide, removed a bunch of Haiku-targeted examples, and added a few structured explanations and usage examples targeting existing Linen users, including:
The original Haiku/Linen/NNX guide will later be tailored to a Haiku/Flax NNX guide.