Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix false positive tracer leaks in flax library. #4232

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 20 additions & 33 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _partial_pack(
out_variable_filters: Sequence[CollectionFilter],
rng_filters: Sequence[PRNGSequenceFilter],
name=None,
) -> tuple[Callable[..., Any], Callable[..., Any], Any, Any, Callable[..., Any], Callable[..., Any]]:
) -> tuple[Callable[..., Any], Callable[..., Any], Any, Any, Callable[..., Any]]:
"""Pack variables and rngs for functional transformations.

The _partial_pack function is the building block for all other lifted transformations.
Expand Down Expand Up @@ -175,16 +175,11 @@ def _partial_pack(
inner_rng_counters.append(rng_counters)
rng_groups_xs_t = _transpose(rng_groups_xs)

inner_scopes: list[Scope] = []

def scope_fn(
variable_groups_xs_t,
rng_groups_xs_t,
mutable_filter: CollectionFilter = True,
):
nonlocal inner_scopes
for inner_scope in inner_scopes:
inner_scope.invalidate()
inner_scopes = []
mutable: Filter = False
for out_filter in out_variable_filters:
Expand Down Expand Up @@ -260,10 +255,6 @@ def repack_fn(inner_scope_tree):

return _transpose(out_variable_groups_xs)

def invalidate_scopes_fn():
for inner_scope in inner_scopes:
inner_scope.invalidate()

def publish_results_fn(out_variable_groups_xs_t):
out_variable_groups_xs = _transpose(out_variable_groups_xs_t)
for scope, out_variable_groups, rng_counters in zip(
Expand All @@ -278,14 +269,14 @@ def publish_results_fn(out_variable_groups_xs_t):
scope.put_variable(col_name, var_name, value)

return (
scope_fn,
repack_fn,
variable_groups_xs_t,
rng_groups_xs_t,
publish_results_fn,
invalidate_scopes_fn,
scope_fn,
repack_fn,
variable_groups_xs_t,
rng_groups_xs_t,
publish_results_fn,
)


def pack(
fn: Callable[..., Any],
in_variable_filters: Sequence[CollectionFilter],
Expand Down Expand Up @@ -322,24 +313,20 @@ def wrapper(scope_tree: Scope, *args, **kwargs):
variable_groups_xs_t,
rng_groups_xs_t,
publish_results_fn,
invalidate_scopes_fn,
) = _partial_pack(scope_tree, in_variable_filters, out_variable_filters, rng_filters, name)
try:
if enable_kwargs:
y, out_variable_groups_xs_t = fn(
scope_fn,
repack_fn,
variable_groups_xs_t,
rng_groups_xs_t,
*args,
**kwargs,
)
else:
y, out_variable_groups_xs_t = fn(
scope_fn, repack_fn, variable_groups_xs_t, rng_groups_xs_t, *args
)
finally:
invalidate_scopes_fn()
if enable_kwargs:
y, out_variable_groups_xs_t = fn(
scope_fn,
repack_fn,
variable_groups_xs_t,
rng_groups_xs_t,
*args,
**kwargs,
)
else:
y, out_variable_groups_xs_t = fn(
scope_fn, repack_fn, variable_groups_xs_t, rng_groups_xs_t, *args
)
publish_results_fn(out_variable_groups_xs_t)
return y

Expand Down
5 changes: 3 additions & 2 deletions flax/linen/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,9 @@ def _get_table_fn(*args, **kwargs):

def _get_variables():
return module.init(*args, **kwargs)

variables = jax.eval_shape(_get_variables)
# TODO(cgarciae): is it possible to avoid leaking tracers for summaries?
with jax.check_tracer_leaks(False):
variables = jax.eval_shape(_get_variables)
calls = module_lib._context.call_info_stack[-1].calls
calls.sort(key=lambda c: c.index)

Expand Down
9 changes: 7 additions & 2 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
TypeVar,
Union,
)
import weakref
from collections.abc import Callable, Iterable, Mapping, Sequence

from flax import core
Expand Down Expand Up @@ -440,21 +441,25 @@ class _HashableProxy:
function should be retraced or not
"""

module: Module
module_ref: weakref.ref
hash_key: int

@classmethod
def from_module(cls, module: Module) -> '_HashableProxy':
fingerprint = _module_fingerprint(module)
hash_key = hash(fingerprint)
return cls(module, hash_key)
return cls(weakref.ref(module), hash_key)

def __hash__(self):
return self.hash_key

def __eq__(self, other):
return isinstance(other, _HashableProxy) and self.hash_key == other.hash_key

@property
def module(self):
return self.module_ref()


def _module_fingerprint(module: Module) -> tuple[type[Any], Any]:
return _fingerprint_recursive(module, (), {})
Expand Down
19 changes: 19 additions & 0 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2627,6 +2627,25 @@ def __call__(self, x):
y = Top().apply(vs, jnp.ones((2, 5)))
assert vs['aux']['vfoo']['v'].value.shape == ()

def test_vjp_tracer_leak(self):
class LearnScale(nn.Module):
@nn.compact
def __call__(self, x):
p = self.param('scale', nn.initializers.zeros, ())
return p * x
class Foo(nn.Module):
@nn.compact
def __call__(self, x):
y, bwd = nn.vjp(lambda mdl, x: mdl(x), LearnScale(), x)
params_grad, x_grad = bwd(jnp.ones(y.shape))
return y, params_grad, x_grad
key = jax.random.PRNGKey(0)
x = jnp.ones((2, 3))
foo = Foo()
with jax.check_tracer_leaks():
params = foo.init(key, x)
foo.apply(params, x)


if __name__ == '__main__':
absltest.main()
Loading