Add static_argnums to nn.checkpoint #2457

Merged: 2 commits, Sep 14, 2022
11 changes: 9 additions & 2 deletions flax/core/lift.py
@@ -1138,6 +1138,7 @@ def checkpoint(fn: Callable[..., Any],
rngs: PRNGSequenceFilter = True,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: Union[int, Tuple[int, ...]] = (),
policy: Optional[Callable[..., bool]] = None,
) -> Callable[..., Any]:
"""Lifted version of ``jax.checkpoint``.
@@ -1164,15 +1165,21 @@ def checkpoint(fn: Callable[..., Any],
``pmap``, CSE can defeat the purpose of this decorator. But in some
settings, like when used inside a ``scan``, this CSE prevention mechanism
is unnecessary, in which case ``prevent_cse`` can be set to False.
static_argnums: Optional, int or sequence of ints, indicates which argument
values to specialize on for tracing and caching purposes. Specifying
arguments as static can avoid ConcretizationTypeErrors when tracing, but
at the cost of more retracing overhead.
policy: Experimental checkpoint policy, see ``jax.checkpoint``.
Returns:
A wrapped version of ``fn``. When computing gradients, intermediate
computations will be re-computed on the backward pass.
"""
def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args, **kwargs):
# add 2 to each static_argnums because rematted takes two extra leading
# arguments (variable_groups and rng_groups)
static_argnums_ = jax.tree_util.tree_map(lambda x: x + 2, static_argnums)
@functools.partial(jax.remat,
concrete=concrete, prevent_cse=prevent_cse,
policy=policy)
concrete=concrete, static_argnums=static_argnums_,
prevent_cse=prevent_cse, policy=policy)
@functools.wraps(fn)
def rematted(variable_groups, rng_groups, *args, **kwargs):
scope = scope_fn(variable_groups, rng_groups)
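
For context, the behavior being lifted here is ``jax.checkpoint``'s own ``static_argnums``: an argument marked static is treated as a concrete Python value during tracing, so it can drive ordinary Python control flow under remat without a ConcretizationTypeError. A minimal, self-contained sketch at the plain JAX level (the function ``apply_relu_fn`` and its inputs are illustrative, not part of this PR):

import functools

import jax
import jax.numpy as jnp

# Mark `apply_relu` (positional argument 1) as static so the Python `if`
# below sees a concrete bool rather than a tracer.
@functools.partial(jax.checkpoint, static_argnums=(1,))
def apply_relu_fn(x, apply_relu: bool):
  h = x * 2.0
  return jax.nn.relu(h) if apply_relu else h

x = jnp.ones((3,))
# Without static_argnums, branching on `apply_relu` under remat would raise a
# ConcretizationTypeError when taking gradients.
grad = jax.grad(lambda x: apply_relu_fn(x, True).sum())(x)
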
9 changes: 9 additions & 0 deletions flax/linen/transforms.py
@@ -572,6 +572,7 @@ def checkpoint(target: Target,
rngs: lift.PRNGSequenceFilter = True,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: Union[int, Tuple[int, ...]] = (),
policy: Optional[Callable[..., bool]] = None,
methods=None) -> Target:
"""Lifted version of ``jax.checkpoint``.
@@ -599,16 +600,24 @@ def checkpoint(target: Target,
``pmap``, CSE can defeat the purpose of this decorator. But in some
settings, like when used inside a ``scan``, this CSE prevention mechanism
is unnecessary, in which case ``prevent_cse`` should be set to False.
static_argnums: Optional, int or sequence of ints, indicates which argument
values to specialize on for tracing and caching purposes. Specifying
arguments as static can avoid ConcretizationTypeErrors when tracing, but
at the cost of more retracing overhead.
policy: Experimental checkpoint policy, see ``jax.checkpoint``.
methods: If `target` is a `Module`, the methods of `Module` to checkpoint.

Returns:
A wrapped version of ``target``. When computing gradients, intermediate
computations will be re-computed on the backward pass.
"""
# subtract 1 from each static_argnums because 'self' is not passed to the
# lifted function
static_argnums = jax.tree_util.tree_map(lambda x: x - 1, static_argnums)
return lift_transform(
lift.checkpoint, target,
variables=variables, rngs=rngs, concrete=concrete,
static_argnums=static_argnums,
prevent_cse=prevent_cse, policy=policy,
methods=methods)

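
At the user-facing level, ``static_argnums`` counts the bound method's parameters including ``self``, which is why the wrapper shifts each index down by one before lifting. A short usage sketch, assuming the decorator form that the tests below also exercise (``nn.remat`` is Flax's alias for ``nn.checkpoint``); the ``MLP`` module itself is illustrative:

import functools

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  # Argument positions: `self` is 0, `x` is 1, `use_relu` is 2.
  @functools.partial(nn.remat, static_argnums=(2,))
  @nn.compact
  def __call__(self, x, use_relu: bool):
    h = nn.Dense(8)(x)
    if use_relu:  # plain Python branch, allowed because the argument is static
      h = nn.relu(h)
    return nn.Dense(1)(h)

model = MLP()
x = jnp.ones((4, 3))
variables = model.init(jax.random.PRNGKey(0), x, True)
grads = jax.grad(lambda v: model.apply(v, x, True).sum())(variables)
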
67 changes: 67 additions & 0 deletions tests/linen/linen_transforms_test.py
@@ -145,6 +145,73 @@ def __call__(self, input, apply_relu : bool = False):
# This next line crashes with a concretization error
_ = jax.grad(lambda x: remat_model.apply(p, x, apply_relu=True))(x)

def test_remat_static_argnums(self):
test = self

class Foo(nn.Module):
train_is_static: bool

@nn.compact
def __call__(self, inputs, train: bool):
if self.train_is_static:
test.assertTrue(isinstance(train, bool))
else:
test.assertTrue(isinstance(train, jnp.ndarray))

return nn.Dense(3, use_bias=False)(inputs)

# set train as a static argument
FooRemat = nn.remat(Foo, static_argnums=(2,))
foo = FooRemat(train_is_static=True)

x = jnp.empty((1, 2))
variables = foo.init(random.PRNGKey(0), x, True)
y = foo.apply(variables, x, False)
self.assertEqual(y.shape, (1, 3))

# set train as a non-static argument
FooRemat = nn.remat(Foo, static_argnums=())
foo = FooRemat(train_is_static=False)

variables = foo.init(random.PRNGKey(0), x, True)
y = foo.apply(variables, x, False)
self.assertEqual(y.shape, (1, 3))

def test_remat_decorator_static_argnums(self):
test = self

class FooTrainStatic(nn.Module):
@partial(nn.remat, static_argnums=(2,))
@nn.compact
def __call__(self, inputs, train: bool):
test.assertTrue(isinstance(train, bool))

return nn.Dense(3, use_bias=False)(inputs)

# set train as a static argument
foo = FooTrainStatic()

x = jnp.empty((1, 2))
variables = foo.init(random.PRNGKey(0), x, True)
y = foo.apply(variables, x, False)
self.assertEqual(y.shape, (1, 3))

class FooTrainDynamic(nn.Module):
@partial(nn.remat, static_argnums=())
@nn.compact
def __call__(self, inputs, train: bool):
test.assertTrue(isinstance(train, jnp.ndarray))

return nn.Dense(3, use_bias=False)(inputs)

# set train as a non-static argument
foo = FooTrainDynamic()

variables = foo.init(random.PRNGKey(0), x, True)
y = foo.apply(variables, x, False)
self.assertEqual(y.shape, (1, 3))


def test_vmap(self):
key1, key2 = random.split(random.PRNGKey(3), 2)
x = random.uniform(key1, (4, 4))