Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade Flax NNX Model Surgery #4135

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 28 additions & 28 deletions docs/nnx/surgery.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
"source": [
"# Model surgery\n",
"\n",
"This guide will demostrate how to do model surgery in NNX with a few real-scenario use cases.\n",
"In this guide you will learn how to do model surgery with Flax NNX with several real-scenario use cases:\n",
"\n",
"* __Module manipulation__: Pythonic ways to manipulate submodules given a model.\n",
"* __Python module manipulation__: Pythonic ways to manipulate sub-modules given a model.\n",
"\n",
"* __Abstact model__: A key trick to play with NNX modules and states without memory allocation.\n",
"* __Manipulating an abstract model or state__: A key trick to play with Flax NNX modules and states without memory allocation.\n",
"\n",
"* __From raw state to model__: How to manipulate parameter state when they are incompatible with existing model code.\n",
"* __Checkpoint surgery: From a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code.\n",
"\n",
"* __Partial initialization__: Initializing only part of the model from scratch."
"* __Partial initialization__: How to initialize only a part of the model from scratch using a naive method or a memory-efficient method."
]
},
{
Expand Down Expand Up @@ -61,11 +61,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pythonic module manipulations\n",
"## Pythonic module manipulation\n",
"\n",
"Model surgery is easiest when you already have a fully fleshed-out model loaded with correct parameters, and you don't intend to change your model definition code.\n",
"Doing model surgery is easiest when you already have a fully fleshed-out model loaded with correct parameters, and you don't intend to change your model definition code.\n",
"\n",
"You can make a variety of pythonic operations on its submodules, like swapping in/out, sharing modules/weights, monkeypatching, etc. See a few code examples below."
"You can perform a variety of Pythonic operations on its sub-modules, such as sub-module swapping, module sharing, variable sharing, and monkey-patching:"
]
},
{
Expand All @@ -78,7 +78,7 @@
"x = jax.random.normal(jax.random.key(42), (3, 4))\n",
"np.testing.assert_allclose(model(x), model.linear2(model.linear1(x)))\n",
"\n",
"# Submodule swapping\n",
"# Sub-module swapping\n",
"original1, original2 = model.linear1, model.linear2\n",
"model.linear1, model.linear2 = model.linear2, model.linear1\n",
"np.testing.assert_allclose(model(x), original1(original2(x)))\n",
Expand All @@ -89,14 +89,14 @@
"assert not hasattr(nnx.state(model), 'linear2')\n",
"np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))\n",
"\n",
"# Variable sharing (weight tying)\n",
"# Variable sharing (weight-tying)\n",
"model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n",
"model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate\n",
"assert hasattr(nnx.state(model), 'linear2')\n",
"assert hasattr(nnx.state(model)['linear2'], 'bias')\n",
"assert not hasattr(nnx.state(model)['linear2'], 'kernel')\n",
"\n",
"# Monkey patching\n",
"# Monkey-patching\n",
"model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n",
"def awesome_layer(x): return x\n",
"model.linear2 = awesome_layer\n",
Expand All @@ -107,15 +107,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create model and state without memory allocation\n",
"## Creating an abstract model or state without memory allocation\n",
"\n",
"For more complex model surgery, a key technique is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints.\n",
"\n",
"To create an abstract model,\n",
"* Create a function that returns a valid NNX model; and\n",
"* Run `nnx.eval_shape` (not `jax.eval_shape`) upon it.\n",
"\n",
"Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model is now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information."
"Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model are now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information."
]
},
{
Expand Down Expand Up @@ -177,7 +177,7 @@
"abs_state['linear2']['kernel'].value = model.linear2.kernel\n",
"abs_state['linear2']['bias'].value = model.linear2.bias\n",
"nnx.update(abs_model, abs_state)\n",
"np.testing.assert_allclose(abs_model(x), model(x)) # they are equivalent now!"
"np.testing.assert_allclose(abs_model(x), model(x)) # They are equivalent now!"
]
},
{
Expand All @@ -186,9 +186,9 @@
"source": [
"## Checkpoint surgery\n",
"\n",
"With the abstract state technique in hand, we can do arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make them fit with our given model code, then call `nnx.update` to merge them.\n",
"With the abstract state technique in hand, you can do arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make them fit with your given model code, then call `nnx.update` to merge them.\n",
"\n",
"This is helpful when you are to change model code significantly (e.g., migrating from Linen to NNX) so that old weights are no longer naturally compatible. Let's run a simple example here."
"This can be helpful when you are trying to change model code significantly (for example, when migrating from Flax Linen to Flax NNX) and old weights are no longer naturally compatible. Let's run a simple example here:"
]
},
{
Expand All @@ -197,7 +197,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Save a version of model into a checkpoint\n",
"# Save a version of a model into a checkpoint.\n",
"checkpointer = orbax.PyTreeCheckpointer()\n",
"old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n",
"checkpointer.save(f'/tmp/nnx-surgery-state', nnx.state(model), force=True)"
Expand All @@ -207,7 +207,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In this new model, the submodules are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure changed, it's impossible to load the old checkpoint with the new model state structure."
"In this new model, the sub-modules are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure changed, it's impossible to load the old checkpoint with the new model state structure:"
]
},
{
Expand Down Expand Up @@ -293,7 +293,7 @@
" state = nnx.State.from_flat_path(state)\n",
" return nnx.merge(graph_def, state)\n",
"\n",
"# Make your local change on the checkpoint\n",
"# Make your local change on the checkpoint.\n",
"raw = checkpointer.restore('/tmp/nnx-surgery-state')\n",
"pprint(raw)\n",
"raw['layer1'], raw['layer2'] = raw['linear1'], raw['linear2']\n",
Expand All @@ -314,7 +314,7 @@
"source": [
"## Partial initialization\n",
"\n",
"In some cases (e.g., LoRA), you might want to randomly-initialize only *part of* your model parameters."
"In some cases (such as with Low-Rank Adapation (LoRA)), you may want to randomly-initialize only *part of* your model parameters. This can be achieved through naive partial initialization or memory-efficient partial initialization."
]
},
{
Expand All @@ -323,9 +323,9 @@
"source": [
"### Naive partial initialization\n",
"\n",
"You can simply initialize the whole model, then swap pre-trained params in. But this approach could allocate additional memory midway, if your modification requires re-creating module params that you will later discard. See this example below.\n",
"You can simply initialize the whole model, then swap pre-trained parameters in. But this approach could allocate additional memory midway, if your modification requires re-creating module parameters that you will later discard. See this example below.\n",
"\n",
"> Note: You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be messed up when you run a single notebook cell multiple times (due to garbage-collecting old python variables), but restarting kernel & running from scratch will always yield same output."
"> Note: You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be messed up when you run a single notebook cell multiple times (due to garbage-collecting old python variables), but restarting the kernel and running from scratch will always yield same output."
]
},
{
Expand All @@ -350,7 +350,7 @@
"simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))\n",
"print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')\n",
"# On this line, extra kernel and bias is created inside the new LoRALinear!\n",
"# They are wasted b/c we are to use the kernel and bias in `old_state` anyway.\n",
"# They are wasted since you are going to use the kernel and bias in `old_state` anyway.\n",
"simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))\n",
"print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'\n",
" ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')\n",
Expand All @@ -365,7 +365,7 @@
"source": [
"### Memory-efficient partial initialization\n",
"\n",
"Use `nnx.jit`'s efficiently compiled code to make sure only the state params you need are initialized."
"Use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized:"
]
},
{
Expand All @@ -383,16 +383,16 @@
}
],
"source": [
"# Some pretrained model state\n",
"# Some pretrained model state.\n",
"old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"\n",
"# Use nnx.jit (which wraps jax.jit) to automatically skip unused arrays - memory efficient!\n",
"# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!\n",
"@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1)\n",
"def partial_init(old_state, rngs):\n",
" model = TwoLayerMLP(4, rngs=rngs)\n",
" # create new state\n",
" # Create a new state.\n",
" model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs)\n",
" # add existing create\n",
" # Add the existing state.\n",
" nnx.update(model, old_state)\n",
" return model\n",
"\n",
Expand Down
Loading
Loading