Skip to content

Commit

Permalink
Switch NNX to use Treescope instead of Penzai.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 664936417
  • Loading branch information
danieldjohnson authored and Flax Authors committed Aug 19, 2024
1 parent 71b5a46 commit 7c96b0e
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 45 deletions.
4 changes: 2 additions & 2 deletions docs/nnx/nnx_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
}
],
"source": [
"! pip install -U flax penzai"
"! pip install -U flax treescope"
]
},
{
Expand Down Expand Up @@ -171,7 +171,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The above visualization by `nnx.display` is generated using the awesome [Penzai](https://penzai.readthedocs.io/en/stable/index.html#) library."
"The above visualization by `nnx.display` is generated using the awesome [Treescope](https://treescope.readthedocs.io/en/stable/index.html#) library."
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs/nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ that have allowed Linen to scale effectively to large codebases.
```{code-cell} ipython3
:tags: [skip-execution]
! pip install -U flax penzai
! pip install -U flax treescope
```

```{code-cell} ipython3
Expand Down Expand Up @@ -77,7 +77,7 @@ print(y)
nnx.display(model)
```

The above visualization by `nnx.display` is generated using the awesome [Penzai](https://penzai.readthedocs.io/en/stable/index.html#) library.
The above visualization by `nnx.display` is generated using the awesome [Treescope](https://treescope.readthedocs.io/en/stable/index.html#) library.

+++

Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ ipython_genutils
sphinx-design
jupytext==1.13.8
dm-haiku
penzai>=0.1.2; python_version>='3.10'
treescope>=0.1.1; python_version>='3.10'

# Need to pin docutils to 0.16 to make bulleted lists appear correctly on
# ReadTheDocs: https://stackoverflow.com/a/68008428
Expand Down
12 changes: 6 additions & 6 deletions flax/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,9 @@ def __nnx_repr__(self):
yield reprlib.Attr('leaves', reprlib.PrettyMapping(self.leaves))
yield reprlib.Attr('metadata', self.metadata)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
return pz_repr_lib.render_object_constructor(
def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
Expand Down Expand Up @@ -317,9 +317,9 @@ def __nnx_repr__(self):
yield reprlib.Attr('nodedef', self.nodedef)
yield reprlib.Attr('index_mapping', self.index_mapping)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
return pz_repr_lib.render_object_constructor(
def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes={
'nodedef': self.nodedef,
Expand Down
11 changes: 6 additions & 5 deletions flax/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,20 +403,21 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
flatten_func=partial(_module_flatten, with_keys=False),
)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
from penzai.treescope import formatting_util # type: ignore[import-not-found,import-untyped]
def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
children = {}
for name, value in vars(self).items():
if name.startswith('_'):
continue
children[name] = value
return pz_repr_lib.render_object_constructor(
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
color=formatting_util.color_from_string(type(self).__qualname__)
color=treescope.formatting_util.color_from_string(
type(self).__qualname__
)
)

# -------------------------
Expand Down
12 changes: 6 additions & 6 deletions flax/nnx/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def __nnx_repr__(self):
yield reprlib.Object(type(self))
yield reprlib.Attr('trace_state', self._trace_state)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
return pz_repr_lib.render_object_constructor(
def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes={'trace_state': self._trace_state},
path=path,
Expand Down Expand Up @@ -173,14 +173,14 @@ def to_shape_dtype(value):
if clear_seen:
CONTEXT.seen_modules_repr = None

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
children = {}
for name, value in vars(self).items():
if name.startswith('_'):
continue
children[name] = value
return pz_repr_lib.render_object_constructor(
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
Expand Down
8 changes: 4 additions & 4 deletions flax/nnx/nnx/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __nnx_repr__(self):
continue
yield r

def __penzai_repr__(self, path, subtree_renderer):
def __treescope_repr__(self, path, subtree_renderer):
children = {}
for k, v in self.state.items():
if isinstance(v, State):
Expand Down Expand Up @@ -141,15 +141,15 @@ def __nnx_repr__(self):
v = NestedStateRepr(v)
yield reprlib.Attr(repr(k), v)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]

children = {}
for k, v in self.items():
if isinstance(v, State):
v = NestedStateRepr(v)
children[k] = v
return pz_repr_lib.render_dictionary_wrapper(
return treescope.repr_lib.render_dictionary_wrapper(
object_type=type(self),
wrapped_dict=children,
path=path,
Expand Down
6 changes: 3 additions & 3 deletions flax/nnx/nnx/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def __nnx_repr__(self):
yield reprlib.Object(f'{type(self).__name__}')
yield reprlib.Attr('jax_trace', self._jax_trace)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
return pz_repr_lib.render_object_constructor(
def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes={'jax_trace': self._jax_trace},
path=path,
Expand Down
12 changes: 6 additions & 6 deletions flax/nnx/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,16 +429,16 @@ def __nnx_repr__(self):
continue
yield reprlib.Attr(name, repr(value))

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
children = {}
for name, value in vars(self).items():
if name == 'raw_value':
name = 'value'
if name.endswith('_hooks') or name == '_trace_state':
continue
children[name] = value
return pz_repr_lib.render_object_constructor(
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
Expand Down Expand Up @@ -853,14 +853,14 @@ def __nnx_repr__(self):
continue
yield reprlib.Attr(name, repr(value))

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
children = {'type': self.type}
for name, value in vars(self).items():
if name == 'type' or name.endswith('_hooks'):
continue
children[name] = value
return pz_repr_lib.render_object_constructor(
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
Expand Down
17 changes: 8 additions & 9 deletions flax/nnx/nnx/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import importlib.util

penzai_installed = importlib.util.find_spec('penzai') is not None
treescope_installed = importlib.util.find_spec('treescope') is not None
try:
from IPython import get_ipython

Expand All @@ -24,18 +24,17 @@


def display(*args):
"""Display the given objects using a Penzai visualizer.
"""Display the given objects using the Treescope pretty-printer.
If Penzai is not installed or the code is not running in IPython, ``display``
will print the objects instead.
If treescope is not installed or the code is not running in IPython,
``display`` will print the objects instead.
"""
if not penzai_installed or not in_ipython:
if not treescope_installed or not in_ipython:
for x in args:
print(x)
return

from penzai import pz # type: ignore[import-not-found,import-untyped]
import treescope # type: ignore[import-not-found,import-untyped]

with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
for x in args:
pz.ts.display(x, ignore_exceptions=True)
for x in args:
treescope.display(x, ignore_exceptions=True, autovisualize=True)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ testing = [
"nbstripout",
"black[jupyter]==23.7.0",
# "pyink==23.5.0", # disabling pyink fow now
"penzai>=0.1.2; python_version>='3.10'",
"treescope>=0.1.1; python_version>='3.10'",
]

[project.urls]
Expand Down

0 comments on commit 7c96b0e

Please sign in to comment.