diff --git a/flax/linen/module.py b/flax/linen/module.py index 88692a807e..8fb4497d9c 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -2585,68 +2585,6 @@ def tabulate( ) return tabulate_fn(*args, **kwargs) - def module_paths( - self, - rngs: Union[KeyArray, RNGSequences], - *args, - show_repeated: bool = False, - mutable: CollectionFilter = DenyList('intermediates'), - **kwargs, - ) -> dict[str, 'Module']: - """Returns a dictionary mapping module paths to module instances. - - This method has the same signature and internally calls ``Module.init``, - but instead of returning the variables, it returns a dictionary mapping - module paths to unbounded copies of module instances that were used - at runtime. ``module_paths`` uses ``jax.eval_shape`` to run the forward - computation without consuming any FLOPs or allocating memory. - - Example:: - - >>> import flax.linen as nn - >>> import jax, jax.numpy as jnp - - >>> class Foo(nn.Module): - ... @nn.compact - ... def __call__(self, x): - ... h = nn.Dense(4)(x) - ... return nn.Dense(2)(h) - - >>> x = jnp.ones((16, 9)) - >>> modules = Foo().module_paths(jax.random.key(0), x) - >>> print({ - ... p: type(m).__name__ for p, m in modules.items() - ... }) - {'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'} - - `Args: - rngs: The rngs for the variable collections as passed to ``Module.init``. - *args: The arguments to the forward computation. - show_repeated: If ``True``, repeated calls to the same module will be - shown in the table, otherwise only the first call will be shown. - Default is ``False``. - mutable: Can be bool, str, or list. Specifies which collections should - be treated as mutable: ``bool``: all/no collections are mutable. - ``str``: The name of a single mutable collection. ``list``: A list of - names of mutable collections. By default, all collections except - 'intermediates' are mutable. - **kwargs: keyword arguments to pass to the forward computation. - - Returns: - A dict`ionary mapping module paths to module instances. - """ - from flax.linen import summary - - table = summary._get_module_table( - module=self, - depth=None, - show_repeated=show_repeated, - compute_flops=False, - compute_vjp_flops=False, - )(rngs, *args, **kwargs, mutable=mutable) - - return {'/'.join(row.path): row.module_copy for row in table} - _ParentType = Union[Type[Module], Scope, Type[_Sentinel], None] diff --git a/flax/linen/summary.py b/flax/linen/summary.py index 5692c7e9a8..fd90d3dcd4 100644 --- a/flax/linen/summary.py +++ b/flax/linen/summary.py @@ -30,6 +30,7 @@ Sequence, Set, Tuple, + Type, Union, ) @@ -129,7 +130,7 @@ class Row: """ path: Tuple[str, ...] - module_copy: 'module_lib.Module' + module_type: Type[module_lib.Module] method: str inputs: Any outputs: Any @@ -477,7 +478,7 @@ def _get_variables(): rows.append( Row( c.path, - c.module.copy(), + type(c.module), c.method, inputs, c.outputs, @@ -611,7 +612,7 @@ def _render_table( ) rich_table.add_row( path_repr, - type(row.module_copy).__name__ + method_repr, + row.module_type.__name__ + method_repr, *( _as_yaml_str( _summary_tree_map( diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index 9150a93718..a5d0416ada 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -2487,41 +2487,6 @@ def my_property(self): self.assertEqual(obj_loaded.b, 'ok') self.assertEqual(obj_loaded.my_property, 'okok') - def test_module_paths(self): - class Bar(nn.Module): - @nn.compact - def __call__(self, x): - x = nn.Dense(3)(x) - x = nn.Dense(4)(x) - return x - - class Foo(nn.Module): - @nn.compact - def __call__(self, x): - x = Bar()(x) - x = nn.Dense(5)(x) - return x - - x = jnp.ones((1, 2)) - m = Foo() - module_paths = m.module_paths(random.key(0), x) - - # assert all module are unbounded - for module in module_paths.values(): - self.assertIsNone(module.scope) - - # test paths - self.assertIn('', module_paths) - self.assertEqual(type(module_paths['']), Foo) - self.assertIn('Dense_0', module_paths) - self.assertEqual(type(module_paths['Dense_0']), nn.Dense) - self.assertIn('Bar_0', module_paths) - self.assertEqual(type(module_paths['Bar_0']), Bar) - self.assertIn('Bar_0/Dense_0', module_paths) - self.assertEqual(type(module_paths['Bar_0/Dense_0']), nn.Dense) - self.assertIn('Bar_0/Dense_1', module_paths) - self.assertEqual(type(module_paths['Bar_0/Dense_1']), nn.Dense) - class LeakTests(absltest.TestCase): def test_tracer_leaks(self):