From 522a8fd792913377e34fb39e34a31cb23cbe783a Mon Sep 17 00:00:00 2001 From: Jane Liu Date: Thu, 19 Dec 2024 11:35:24 -0800 Subject: [PATCH] consolidate the code example --- docs/gradient-checkpointing.md | 53 ++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md index 0358e02f5441..0938a5da944f 100644 --- a/docs/gradient-checkpointing.md +++ b/docs/gradient-checkpointing.md @@ -360,40 +360,55 @@ You may consider offloading to CPU memory instead of recomputing when checkpoint ```{code-cell} from jax.ad_checkpoint import checkpoint - policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims( - "device", "pinned_host") - @functools.partial(checkpoint, policy=policy) - def f(x): - x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST) - x = jnp.sin(x) - x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST) - x = jnp.sin(x) - x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST) - x = jnp.sin(x) - x = jnp.sum(x) - return x +def checkpoint_offload_dot_with_no_batch_dims(self): + policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims( + "device", "pinned_host") + + @functools.partial(checkpoint, policy=policy) + def f(x): + x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST) + x = jnp.sin(x) + x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST) + x = jnp.sin(x) + x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST) + x = jnp.sin(x) + x = jnp.sum(x) + return x ``` One of JAX's checkpoint policies allows specified checkpoint names to be offloaded to CPUs. This policy is implemented through `jax.checkpoint_policies.save_and_offload_only_these_names`, which has four arguments: `names_which_can_be_saved`, `names_which_can_be_offloaded`, the offloading source, and destination. 
Names listed in `names_which_can_be_saved` are kept on the device, names listed in `names_which_can_be_offloaded` are moved to CPU memory, and other names or operations without names are recomputed. For example, if we have checkpoint names `y`, `z`, and `w`, `y` can be saved on the device, `z` can be offloaded to CPU memory, and `w` can be recomputed. ```{code-cell} from jax.ad_checkpoint import checkpoint, checkpoint_name +from jax._src import test_util as jtu -def g(self): +def checkpoint_names_saved_offloaded_recomputed(self): + mesh = jtu.create_mesh((2,), ("x",)) + shape = (256, 128) + np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) + s = NamedSharding(mesh, P("x")) + inp = jax.device_put(np_inp, s) - policy = jax.checkpoint_policies.save_and_offload_only_these_names( - names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"], - offload_src='device', offload_dst='pinned_host') + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"], + offload_src='device', offload_dst='pinned_host') - @functools.partial(checkpoint, policy=policy) - def f(x): + @functools.partial(checkpoint, policy=policy) + def f(x): + def g(ys, _): + y, _ = ys y = checkpoint_name(jnp.sin(y), "y") z = checkpoint_name(jnp.sin(y), "z") + z = z.T w = checkpoint_name(jnp.sin(z), "w") - return jnp.sum(w) + return (w.T, jnp.sum(w)), None + _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0] + return scan_out ``` +The code defines a function `f` that applies checkpointing with a custom policy. This policy determines which computations can be saved or offloaded during execution. Inside `f`, there is a nested function `g` that performs the core computations. The `jax.lax.scan` function is used to apply `g` repeatedly over the input data. + #### List of policies The policies are: