diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index c46e89a90f..7954cdd728 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -971,12 +971,12 @@ def checkpoint( >>> import flax.linen as nn ... >>> class CheckpointedMLP(nn.Module): + ... @nn.checkpoint ... @nn.compact ... def __call__(self, x): - ... CheckpointDense = nn.checkpoint(nn.Dense) - ... x = CheckpointDense(128)(x) + ... x = nn.Dense(128)(x) ... x = nn.relu(x) - ... x = CheckpointDense(1)(x) + ... x = nn.Dense(1)(x) ... return x ... >>> model = CheckpointedMLP()