Skip to content

Commit

Permalink
add static_argnums to nn.checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Sep 12, 2022
1 parent e320e11 commit 9522bc0
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
11 changes: 9 additions & 2 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,7 @@ def checkpoint(fn: Callable[..., Any],
rngs: PRNGSequenceFilter = True,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: Union[int, Tuple[int, ...]] = (),
policy: Optional[Callable[..., bool]] = None,
) -> Callable[..., Any]:
"""Lifted version of ``jax.checkpoint``.
Expand All @@ -1164,15 +1165,21 @@ def checkpoint(fn: Callable[..., Any],
``pmap``, CSE can defeat the purpose of this decorator. But in some
settings, like when used inside a ``scan``, this CSE prevention mechanism
is unnecessary, in which case ``prevent_cse`` can be set to False.
static_argnums: Optional, int or sequence of ints, indicates which argument
values on which to specialize for tracing and caching purposes. Specifying
arguments as static can avoid ConcretizationTypeErrors when tracing, but
at the cost of more retracing overheads.
policy: Experimental checkpoint policy, see ``jax.checkpoint``.
Returns:
A wrapped version of ``fn``. When computing gradients intermediate
computations will be re-computed when computing gradients.
"""
def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args, **kwargs):
# add 2 to each static_argnums because we add two initial arguments to rematted
static_argnums_ = jax.tree_util.tree_map(lambda x: x + 2, static_argnums)
@functools.partial(jax.remat,
concrete=concrete, prevent_cse=prevent_cse,
policy=policy)
concrete=concrete, static_argnums=static_argnums_,
prevent_cse=prevent_cse, policy=policy)
@functools.wraps(fn)
def rematted(variable_groups, rng_groups, *args, **kwargs):
scope = scope_fn(variable_groups, rng_groups)
Expand Down
6 changes: 6 additions & 0 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ def checkpoint(target: Target,
rngs: lift.PRNGSequenceFilter = True,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: Union[int, Tuple[int, ...]] = (),
policy: Optional[Callable[..., bool]] = None,
methods=None) -> Target:
"""Lifted version of ``jax.checkpoint``.
Expand Down Expand Up @@ -599,6 +600,10 @@ def checkpoint(target: Target,
``pmap``, CSE can defeat the purpose of this decorator. But in some
settings, like when used inside a ``scan``, this CSE prevention mechanism
is unnecessary, in which case ``prevent_cse`` should be set to False.
static_argnums: Optional, int or sequence of ints, indicates which argument
values on which to specialize for tracing and caching purposes. Specifying
arguments as static can avoid ConcretizationTypeErrors when tracing, but
at the cost of more retracing overheads.
policy: Experimental checkpoint policy, see ``jax.checkpoint``.
methods: If `target` is a `Module`, the methods of `Module` to checkpoint.
Expand All @@ -609,6 +614,7 @@ def checkpoint(target: Target,
return lift_transform(
lift.checkpoint, target,
variables=variables, rngs=rngs, concrete=concrete,
static_argnums=static_argnums,
prevent_cse=prevent_cse, policy=policy,
methods=methods)

Expand Down
33 changes: 33 additions & 0 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,39 @@ def __call__(self, input, apply_relu : bool = False):
# This next line crashes with a concretization error
_ = jax.grad(lambda x: remat_model.apply(p, x, apply_relu=True))(x)

def test_remat_static_argnums(self):
  """Checks that ``nn.remat``'s ``static_argnums`` controls whether an
  argument is passed through as a plain Python value (static) or traced
  into a ``jnp.ndarray`` (non-static)."""
  outer = self

  class Foo(nn.Module):
    train_is_static: bool

    @nn.compact
    def __call__(self, inputs, train: bool):
      # A static `train` must arrive as a Python bool; a non-static one
      # is traced by jax and shows up as an array.
      expected_type = bool if self.train_is_static else jnp.ndarray
      outer.assertIsInstance(train, expected_type)
      return nn.Dense(3, use_bias=False)(inputs)

  x = jnp.empty((1, 2))

  # Declare `train` (positional index 1) as a static argument.
  FooRemat = nn.remat(Foo, static_argnums=(1,))
  foo = FooRemat(train_is_static=True)
  variables = foo.init(random.PRNGKey(0), x, True)
  y = foo.apply(variables, x, False)
  self.assertEqual(y.shape, (1, 3))

  # Declare no static arguments: `train` is traced like any other input.
  FooRemat = nn.remat(Foo, static_argnums=())
  foo = FooRemat(train_is_static=False)
  variables = foo.init(random.PRNGKey(0), x, True)
  y = foo.apply(variables, x, False)
  self.assertEqual(y.shape, (1, 3))

def test_vmap(self):
key1, key2 = random.split(random.PRNGKey(3), 2)
x = random.uniform(key1, (4, 4))
Expand Down

0 comments on commit 9522bc0

Please sign in to comment.