Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] improve docs #4141

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions docs/nnx/index.rst
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@

NNX
========
.. div:: sd-text-left sd-font-italic

**N**\ eural **N**\ etworks for JA\ **X**

NNX is a **N**\ eural **N**\ etwork library for JA\ **X** that focuses on providing the best
development experience, so building and experimenting with neural networks is easy and
intuitive. It achieves this by embracing Python’s object-oriented model and making it
compatible with JAX transforms, resulting in code that is easy to inspect, debug, and
analyze.

----

NNX is a new Flax API that is designed to make it easier to create, inspect, debug,
and analyze neural networks in JAX. It achieves this by adding first class support
for Python reference semantics, allowing users to express their models using regular
Python objects. NNX takes years of feedback from Linen and brings to Flax a simpler
and more user-friendly experience.

Features
^^^^^^^^^
Expand Down
17 changes: 7 additions & 10 deletions docs/nnx/nnx_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@
"source": [
"# NNX Basics\n",
"\n",
"NNX is a **N**eural **N**etwork library for JA**X** that focuses on providing the best \n",
"development experience, so building and experimenting with neural networks is easy and\n",
"intuitive. It achieves this by representing objects as PyGraphs (instead of PyTrees), \n",
"enabling reference sharing and mutability. This design allows your models to resemble \n",
"familiar Python object-oriented code, particularly appealing to users of frameworks\n",
"like PyTorch.\n",
"\n",
"Despite its simplified implementation, NNX supports the same powerful design patterns \n",
"that have allowed Linen to scale effectively to large codebases."
"NNX is a new Flax API that is designed to make it easier to create, inspect, debug,\n",
"and analyze neural networks in JAX. It achieves this by adding first class support\n",
"for Python reference semantics, allowing users to express their models using regular\n",
"Python objects, which are modeled as PyGraphs (instead of PyTrees), enabling reference\n",
"sharing and mutability. This design should should make PyTorch or Keras users feel at\n",
"home."
]
},
{
Expand Down Expand Up @@ -68,7 +65,7 @@
}
],
"source": [
"! pip install -U flax treescope"
"# ! pip install -U flax treescope"
]
},
{
Expand Down
17 changes: 7 additions & 10 deletions docs/nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,17 @@ jupytext:

# NNX Basics

NNX is a **N**eural **N**etwork library for JA**X** that focuses on providing the best
development experience, so building and experimenting with neural networks is easy and
intuitive. It achieves this by representing objects as PyGraphs (instead of PyTrees),
enabling reference sharing and mutability. This design allows your models to resemble
familiar Python object-oriented code, particularly appealing to users of frameworks
like PyTorch.

Despite its simplified implementation, NNX supports the same powerful design patterns
that have allowed Linen to scale effectively to large codebases.
NNX is a new Flax API that is designed to make it easier to create, inspect, debug,
and analyze neural networks in JAX. It achieves this by adding first class support
for Python reference semantics, allowing users to express their models using regular
Python objects, which are modeled as PyGraphs (instead of PyTrees), enabling reference
sharing and mutability. This design should should make PyTorch or Keras users feel at
home.

```{code-cell} ipython3
:tags: [skip-execution]

! pip install -U flax treescope
# ! pip install -U flax treescope
```

```{code-cell} ipython3
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
from .nnx.variables import (
Param as Param,
register_variable_name_type_pair as register_variable_name_type_pair,
)
)
# this needs to be imported before optimizer to prevent circular import
from .nnx.training import optimizer as optimizer
from .nnx.training.metrics import Metric as Metric
Expand Down
5 changes: 4 additions & 1 deletion flax/nnx/nnx/rnglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,10 @@ def split_rngs_wrapper(*args, **kwargs):
key = stream()
backups.append((stream, stream.key.value, stream.count.value))
stream.key.value = jax.random.split(key, splits)
counts_shape = (splits, *stream.count.shape)
if isinstance(splits, int):
counts_shape = (splits, *stream.count.shape)
else:
counts_shape = (*splits, *stream.count.shape)
stream.count.value = jnp.zeros(counts_shape, dtype=jnp.uint32)

return SplitBackups(backups)
Expand Down
1 change: 0 additions & 1 deletion flax/nnx/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from flax.nnx.nnx.transforms.transforms import resolve_kwargs
from flax.typing import Leaf, MISSING, Missing, PytreeDeque
import jax
from jax._src.tree_util import broadcast_prefix
import jax.core
import jax.numpy as jnp
import jax.stages
Expand Down
1 change: 0 additions & 1 deletion flax/nnx/tests/bridge/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from absl.testing import absltest
import flax
Expand Down
4 changes: 2 additions & 2 deletions flax/nnx/tests/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,8 +499,8 @@ def __init__(self, dout: int, rngs: nnx.Rngs):
self.rngs = rngs

def __call__(self, x):

@partial(nnx.vmap, in_axes=(0, None), axis_size=5)
@nnx.split_rngs(splits=5)
@nnx.vmap(in_axes=(0, None), axis_size=5)
def vmap_fn(inner, x):
return inner(x)

Expand Down
Loading