diff --git a/flax/core/lift.py b/flax/core/lift.py index 08c77a1606..8daa718769 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -127,7 +127,7 @@ def _partial_pack( out_variable_filters: Sequence[CollectionFilter], rng_filters: Sequence[PRNGSequenceFilter], name=None, -) -> tuple[Callable[..., Any], Callable[..., Any], Any, Any, Callable[..., Any], Callable[..., Any]]: +) -> tuple[Callable[..., Any], Callable[..., Any], Any, Any, Callable[..., Any]]: """Pack variables and rngs for functional transformations. The _partial_pack function is the building block for all other lifted transformations. @@ -175,16 +175,11 @@ def _partial_pack( inner_rng_counters.append(rng_counters) rng_groups_xs_t = _transpose(rng_groups_xs) - inner_scopes: list[Scope] = [] - def scope_fn( variable_groups_xs_t, rng_groups_xs_t, mutable_filter: CollectionFilter = True, ): - nonlocal inner_scopes - for inner_scope in inner_scopes: - inner_scope.invalidate() inner_scopes = [] mutable: Filter = False for out_filter in out_variable_filters: @@ -260,10 +255,6 @@ def repack_fn(inner_scope_tree): return _transpose(out_variable_groups_xs) - def invalidate_scopes_fn(): - for inner_scope in inner_scopes: - inner_scope.invalidate() - def publish_results_fn(out_variable_groups_xs_t): out_variable_groups_xs = _transpose(out_variable_groups_xs_t) for scope, out_variable_groups, rng_counters in zip( @@ -278,14 +269,14 @@ def publish_results_fn(out_variable_groups_xs_t): scope.put_variable(col_name, var_name, value) return ( - scope_fn, - repack_fn, - variable_groups_xs_t, - rng_groups_xs_t, - publish_results_fn, - invalidate_scopes_fn, + scope_fn, + repack_fn, + variable_groups_xs_t, + rng_groups_xs_t, + publish_results_fn, ) + def pack( fn: Callable[..., Any], in_variable_filters: Sequence[CollectionFilter], @@ -322,24 +313,20 @@ def wrapper(scope_tree: Scope, *args, **kwargs): variable_groups_xs_t, rng_groups_xs_t, publish_results_fn, - invalidate_scopes_fn, ) = _partial_pack(scope_tree, in_variable_filters, out_variable_filters, rng_filters, name) - try: - if enable_kwargs: - y, out_variable_groups_xs_t = fn( - scope_fn, - repack_fn, - variable_groups_xs_t, - rng_groups_xs_t, - *args, - **kwargs, - ) - else: - y, out_variable_groups_xs_t = fn( - scope_fn, repack_fn, variable_groups_xs_t, rng_groups_xs_t, *args - ) - finally: - invalidate_scopes_fn() + if enable_kwargs: + y, out_variable_groups_xs_t = fn( + scope_fn, + repack_fn, + variable_groups_xs_t, + rng_groups_xs_t, + *args, + **kwargs, + ) + else: + y, out_variable_groups_xs_t = fn( + scope_fn, repack_fn, variable_groups_xs_t, rng_groups_xs_t, *args + ) publish_results_fn(out_variable_groups_xs_t) return y diff --git a/flax/linen/summary.py b/flax/linen/summary.py index badfa18178..d6676729f0 100644 --- a/flax/linen/summary.py +++ b/flax/linen/summary.py @@ -437,8 +437,9 @@ def _get_table_fn(*args, **kwargs): def _get_variables(): return module.init(*args, **kwargs) - - variables = jax.eval_shape(_get_variables) + # TODO(cgarciae): is it possible to avoid leaking tracers for summaries? + with jax.check_tracer_leaks(False): + variables = jax.eval_shape(_get_variables) calls = module_lib._context.call_info_stack[-1].calls calls.sort(key=lambda c: c.index) diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 1e23dfefd8..06ffd51968 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -31,6 +31,7 @@ TypeVar, Union, ) +import weakref from collections.abc import Callable, Iterable, Mapping, Sequence from flax import core @@ -440,14 +441,14 @@ class _HashableProxy: function should be retraced or not """ - module: Module + module_ref: weakref.ref hash_key: int @classmethod def from_module(cls, module: Module) -> '_HashableProxy': fingerprint = _module_fingerprint(module) hash_key = hash(fingerprint) - return cls(module, hash_key) + return cls(weakref.ref(module), hash_key) def __hash__(self): return self.hash_key @@ -455,6 +456,10 @@ def __hash__(self): def __eq__(self, other): return isinstance(other, _HashableProxy) and self.hash_key == other.hash_key + @property + def module(self): + return self.module_ref() + def _module_fingerprint(module: Module) -> tuple[type[Any], Any]: return _fingerprint_recursive(module, (), {}) diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index ac6cb6dc8b..7720aa7ee3 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -2627,6 +2627,25 @@ def __call__(self, x): y = Top().apply(vs, jnp.ones((2, 5))) assert vs['aux']['vfoo']['v'].value.shape == () + def test_vjp_tracer_leak(self): + class LearnScale(nn.Module): + @nn.compact + def __call__(self, x): + p = self.param('scale', nn.initializers.zeros, ()) + return p * x + class Foo(nn.Module): + @nn.compact + def __call__(self, x): + y, bwd = nn.vjp(lambda mdl, x: mdl(x), LearnScale(), x) + params_grad, x_grad = bwd(jnp.ones(y.shape)) + return y, params_grad, x_grad + key = jax.random.PRNGKey(0) + x = jnp.ones((2, 3)) + foo = Foo() + with jax.check_tracer_leaks(): + params = foo.init(key, x) + foo.apply(params, x) + if __name__ == '__main__': absltest.main()