Skip to content

Commit

Permalink
correct argnum indexes for transforms.checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Sep 13, 2022
1 parent 9522bc0 commit f384b33
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
3 changes: 3 additions & 0 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,9 @@ def checkpoint(target: Target,
A wrapped version of ``target``. When computing gradients intermediate
computations will be re-computed on the backward pass.
"""
# subtract 1 from each static_argnums because 'self' is not passed to the
# lifted function
static_argnums = jax.tree_util.tree_map(lambda x: x - 1, static_argnums)
return lift_transform(
lift.checkpoint, target,
variables=variables, rngs=rngs, concrete=concrete,
Expand Down
38 changes: 36 additions & 2 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def __call__(self, input, apply_relu : bool = False):
_ = jax.grad(lambda x: remat_model.apply(p, x, apply_relu=True))(x)

def test_remat_static_argnums(self):

test = self

class Foo(nn.Module):
train_is_static: bool

Expand All @@ -161,7 +161,7 @@ def __call__(self, inputs, train: bool):
return nn.Dense(3, use_bias=False)(inputs)

# set train as a static argument
FooRemat = nn.remat(Foo, static_argnums=(1,))
FooRemat = nn.remat(Foo, static_argnums=(2,))
foo = FooRemat(train_is_static=True)

x = jnp.empty((1, 2))
Expand All @@ -176,6 +176,40 @@ def __call__(self, inputs, train: bool):
variables = foo.init(random.PRNGKey(0), x, True)
y = foo.apply(variables, x, False)
self.assertEqual(y.shape, (1, 3))

def test_remat_decorator_static_argnums(self):
  """Checks ``static_argnums`` when ``nn.remat`` decorates ``__call__``.

  Argument indexes count ``self`` as position 0, so for
  ``__call__(self, inputs, train)`` the ``train`` flag is index 2.

  Two cases are exercised:
    * ``static_argnums=(2,)``: ``train`` must reach the method body as a
      plain Python ``bool``.
    * ``static_argnums=()``: ``train`` must reach the method body as a
      ``jnp.ndarray`` instance.

  In both cases the module must initialise and apply, producing an output
  of shape ``(1, 3)``.
  """
  # The nested modules bind their own ``self``; keep a handle on the test
  # case so assertions can be made from inside ``__call__``.
  test = self

  class FooTrainStatic(nn.Module):
    # index 2 == ``train`` (0 is ``self``, 1 is ``inputs``): marks it static
    @partial(nn.remat, static_argnums=(2,))
    @nn.compact
    def __call__(self, inputs, train: bool):
      # assertIsInstance reports the offending value/type on failure,
      # unlike assertTrue(isinstance(...)).
      test.assertIsInstance(train, bool)

      return nn.Dense(3, use_bias=False)(inputs)

  foo = FooTrainStatic()

  x = jnp.empty((1, 2))
  variables = foo.init(random.PRNGKey(0), x, True)
  y = foo.apply(variables, x, False)
  self.assertEqual(y.shape, (1, 3))

  class FooTrainDynamic(nn.Module):
    # no static arguments: ``train`` is treated like any other input
    @partial(nn.remat, static_argnums=())
    @nn.compact
    def __call__(self, inputs, train: bool):
      test.assertIsInstance(train, jnp.ndarray)

      return nn.Dense(3, use_bias=False)(inputs)

  foo = FooTrainDynamic()

  variables = foo.init(random.PRNGKey(0), x, True)
  y = foo.apply(variables, x, False)
  self.assertEqual(y.shape, (1, 3))


def test_vmap(self):
Expand Down

0 comments on commit f384b33

Please sign in to comment.