diff --git a/flax/core/lift.py b/flax/core/lift.py
index 0ec7e6abbc..ff19daee55 100644
--- a/flax/core/lift.py
+++ b/flax/core/lift.py
@@ -1138,6 +1138,7 @@ def checkpoint(fn: Callable[..., Any],
                rngs: PRNGSequenceFilter = True,
                concrete: bool = False,
                prevent_cse: bool = True,
+               static_argnums: Union[int, Tuple[int, ...]] = (),
                policy: Optional[Callable[..., bool]] = None,
                ) -> Callable[..., Any]:
   """Lifted version of ``jax.checkpoint``.
@@ -1164,15 +1165,21 @@ def checkpoint(fn: Callable[..., Any],
       ``pmap``, CSE can defeat the purpose of this decorator. But in some
       settings, like when used inside a ``scan``, this CSE prevention mechanism
       is unnecessary, in which case ``prevent_cse`` can be set to False.
+    static_argnums: Optional, int or sequence of ints, indicates which argument
+      values on which to specialize for tracing and caching purposes. Specifying
+      arguments as static can avoid ConcretizationTypeErrors when tracing, but
+      at the cost of more retracing overheads.
     policy: Experimental checkpoint policy, see ``jax.checkpoint``.
   Returns:
     A wrapped version of ``fn``. When computing gradients intermediate
     computations will be re-computed when computing gradients.
   """
   def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args, **kwargs):
+    # add 2 to each static_argnums because we add two initial arguments to rematted
+    static_argnums_ = jax.tree_util.tree_map(lambda x: x + 2, static_argnums)
     @functools.partial(jax.remat,
-                       concrete=concrete, prevent_cse=prevent_cse,
-                       policy=policy)
+                       concrete=concrete, static_argnums=static_argnums_,
+                       prevent_cse=prevent_cse, policy=policy)
     @functools.wraps(fn)
     def rematted(variable_groups, rng_groups, *args, **kwargs):
       scope = scope_fn(variable_groups, rng_groups)
diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py
index d9f78e0a95..805b8db413 100644
--- a/flax/linen/transforms.py
+++ b/flax/linen/transforms.py
@@ -572,6 +572,7 @@ def checkpoint(target: Target,
                rngs: lift.PRNGSequenceFilter = True,
                concrete: bool = False,
                prevent_cse: bool = True,
+               static_argnums: Union[int, Tuple[int, ...]] = (),
                policy: Optional[Callable[..., bool]] = None,
                methods=None) -> Target:
   """Lifted version of ``jax.checkpoint``.
@@ -599,6 +600,10 @@ def checkpoint(target: Target,
       ``pmap``, CSE can defeat the purpose of this decorator. But in some
       settings, like when used inside a ``scan``, this CSE prevention mechanism
       is unnecessary, in which case ``prevent_cse`` should be set to False.
+    static_argnums: Optional, int or sequence of ints, indicates which argument
+      values on which to specialize for tracing and caching purposes. Specifying
+      arguments as static can avoid ConcretizationTypeErrors when tracing, but
+      at the cost of more retracing overheads.
     policy: Experimental checkpoint policy, see ``jax.checkpoint``.
     methods: If `target` is a `Module`, the methods of `Module` to checkpoint.
@@ -606,9 +611,13 @@ def checkpoint(target: Target,
   Returns:
     A wrapped version of ``target``. When computing gradients intermediate
     computations will be re-computed on the backward pass.
""" + # subtract 1 from each static_argnums because 'self' is not passed to the + # lifted function + static_argnums = jax.tree_util.tree_map(lambda x: x - 1, static_argnums) return lift_transform( lift.checkpoint, target, variables=variables, rngs=rngs, concrete=concrete, + static_argnums=static_argnums, prevent_cse=prevent_cse, policy=policy, methods=methods) diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 24db263ba1..292a2bb776 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -145,6 +145,73 @@ def __call__(self, input, apply_relu : bool = False): # This next line crashes with a concretization error _ = jax.grad(lambda x: remat_model.apply(p, x, apply_relu=True))(x) + def test_remat_static_argnums(self): + test = self + + class Foo(nn.Module): + train_is_static: bool + + @nn.compact + def __call__(self, inputs, train: bool): + if self.train_is_static: + test.assertTrue(isinstance(train, bool)) + else: + test.assertTrue(isinstance(train, jnp.ndarray)) + + return nn.Dense(3, use_bias=False)(inputs) + + # set train as a static argument + FooRemat = nn.remat(Foo, static_argnums=(2,)) + foo = FooRemat(train_is_static=True) + + x = jnp.empty((1, 2)) + variables = foo.init(random.PRNGKey(0), x, True) + y = foo.apply(variables, x, False) + self.assertEqual(y.shape, (1, 3)) + + # set train as a non-static arguments + FooRemat = nn.remat(Foo, static_argnums=()) + foo = FooRemat(train_is_static=False) + + variables = foo.init(random.PRNGKey(0), x, True) + y = foo.apply(variables, x, False) + self.assertEqual(y.shape, (1, 3)) + + def test_remat_decorator_static_argnums(self): + test = self + + class FooTrainStatic(nn.Module): + @partial(nn.remat, static_argnums=(2,)) + @nn.compact + def __call__(self, inputs, train: bool): + test.assertTrue(isinstance(train, bool)) + + return nn.Dense(3, use_bias=False)(inputs) + + # set train as a static argument + foo = FooTrainStatic() + + x = jnp.empty((1, 2)) + variables = foo.init(random.PRNGKey(0), x, True) + y = foo.apply(variables, x, False) + self.assertEqual(y.shape, (1, 3)) + + class FooTrainDynamic(nn.Module): + @partial(nn.remat, static_argnums=()) + @nn.compact + def __call__(self, inputs, train: bool): + test.assertTrue(isinstance(train, jnp.ndarray)) + + return nn.Dense(3, use_bias=False)(inputs) + + # set train as a non-static arguments + foo = FooTrainDynamic() + + variables = foo.init(random.PRNGKey(0), x, True) + y = foo.apply(variables, x, False) + self.assertEqual(y.shape, (1, 3)) + + def test_vmap(self): key1, key2 = random.split(random.PRNGKey(3), 2) x = random.uniform(key1, (4, 4))