Skip to content

Commit

Permalink
Copybara import of the project:
Browse files Browse the repository at this point in the history
--
a638b16 by Cristian Garcia <cgarcia.e88@gmail.com>:

add Module.module_paths

PiperOrigin-RevId: 601493850
  • Loading branch information
Flax Team committed Jan 25, 2024
1 parent fc77713 commit 60bb31d
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 100 deletions.
62 changes: 0 additions & 62 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2585,68 +2585,6 @@ def tabulate(
)
return tabulate_fn(*args, **kwargs)

def module_paths(
self,
rngs: Union[KeyArray, RNGSequences],
*args,
show_repeated: bool = False,
mutable: CollectionFilter = DenyList('intermediates'),
**kwargs,
) -> dict[str, 'Module']:
"""Returns a dictionary mapping module paths to module instances.
This method has the same signature and internally calls ``Module.init``,
but instead of returning the variables, it returns a dictionary mapping
module paths to unbounded copies of module instances that were used
at runtime. ``module_paths`` uses ``jax.eval_shape`` to run the forward
computation without consuming any FLOPs or allocating memory.
Example::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> class Foo(nn.Module):
... @nn.compact
... def __call__(self, x):
... h = nn.Dense(4)(x)
... return nn.Dense(2)(h)
>>> x = jnp.ones((16, 9))
>>> modules = Foo().module_paths(jax.random.key(0), x)
>>> print({
... p: type(m).__name__ for p, m in modules.items()
... })
{'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'}
`Args:
rngs: The rngs for the variable collections as passed to ``Module.init``.
*args: The arguments to the forward computation.
show_repeated: If ``True``, repeated calls to the same module will be
shown in the table, otherwise only the first call will be shown.
Default is ``False``.
mutable: Can be bool, str, or list. Specifies which collections should
be treated as mutable: ``bool``: all/no collections are mutable.
``str``: The name of a single mutable collection. ``list``: A list of
names of mutable collections. By default, all collections except
'intermediates' are mutable.
**kwargs: keyword arguments to pass to the forward computation.
Returns:
A dict`ionary mapping module paths to module instances.
"""
from flax.linen import summary

table = summary._get_module_table(
module=self,
depth=None,
show_repeated=show_repeated,
compute_flops=False,
compute_vjp_flops=False,
)(rngs, *args, **kwargs, mutable=mutable)

return {'/'.join(row.path): row.module_copy for row in table}


_ParentType = Union[Type[Module], Scope, Type[_Sentinel], None]

Expand Down
7 changes: 4 additions & 3 deletions flax/linen/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Sequence,
Set,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -129,7 +130,7 @@ class Row:
"""

path: Tuple[str, ...]
module_copy: 'module_lib.Module'
module_type: Type[module_lib.Module]
method: str
inputs: Any
outputs: Any
Expand Down Expand Up @@ -477,7 +478,7 @@ def _get_variables():
rows.append(
Row(
c.path,
c.module.copy(),
type(c.module),
c.method,
inputs,
c.outputs,
Expand Down Expand Up @@ -611,7 +612,7 @@ def _render_table(
)
rich_table.add_row(
path_repr,
type(row.module_copy).__name__ + method_repr,
row.module_type.__name__ + method_repr,
*(
_as_yaml_str(
_summary_tree_map(
Expand Down
35 changes: 0 additions & 35 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2487,41 +2487,6 @@ def my_property(self):
self.assertEqual(obj_loaded.b, 'ok')
self.assertEqual(obj_loaded.my_property, 'okok')

def test_module_paths(self):
class Bar(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(3)(x)
x = nn.Dense(4)(x)
return x

class Foo(nn.Module):
@nn.compact
def __call__(self, x):
x = Bar()(x)
x = nn.Dense(5)(x)
return x

x = jnp.ones((1, 2))
m = Foo()
module_paths = m.module_paths(random.key(0), x)

# assert all module are unbounded
for module in module_paths.values():
self.assertIsNone(module.scope)

# test paths
self.assertIn('', module_paths)
self.assertEqual(type(module_paths['']), Foo)
self.assertIn('Dense_0', module_paths)
self.assertEqual(type(module_paths['Dense_0']), nn.Dense)
self.assertIn('Bar_0', module_paths)
self.assertEqual(type(module_paths['Bar_0']), Bar)
self.assertIn('Bar_0/Dense_0', module_paths)
self.assertEqual(type(module_paths['Bar_0/Dense_0']), nn.Dense)
self.assertIn('Bar_0/Dense_1', module_paths)
self.assertEqual(type(module_paths['Bar_0/Dense_1']), nn.Dense)


class LeakTests(absltest.TestCase):
def test_tracer_leaks(self):
Expand Down

0 comments on commit 60bb31d

Please sign in to comment.