Skip to content

Commit

Permalink
Merge pull request #3703 from levskaya:rematdocfix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 608683576
  • Loading branch information
Flax Authors committed Feb 20, 2024
2 parents efbe705 + 5c7fb4c commit eca5dd6
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,12 +971,12 @@ def checkpoint(
>>> import flax.linen as nn
...
>>> class CheckpointedMLP(nn.Module):
... @nn.checkpoint
... @nn.compact
... def __call__(self, x):
... CheckpointDense = nn.checkpoint(nn.Dense)
... x = CheckpointDense(128)(x)
... x = nn.Dense(128)(x)
... x = nn.relu(x)
... x = CheckpointDense(1)(x)
... x = nn.Dense(1)(x)
... return x
...
>>> model = CheckpointedMLP()
Expand Down

0 comments on commit eca5dd6

Please sign in to comment.