diff --git a/docs_nnx/guides/surgery.ipynb b/docs_nnx/guides/surgery.ipynb index 8213abea2a..bebe6e2526 100644 --- a/docs_nnx/guides/surgery.ipynb +++ b/docs_nnx/guides/surgery.ipynb @@ -6,15 +6,13 @@ "source": [ "# Model surgery\n", "\n", - "> **Attention**: This page relates to the new Flax NNX API.\n", + "In this guide, you will learn how to perform model surgery in Flax NNX using several real-world scenarios:\n", "\n", - "In this guide you will learn how to do model surgery with Flax NNX with several real-scenario use cases:\n", + "* __Pythonic `nnx.Module` manipulation__: Using Pythonic ways to manipulate sub-`Module`s given a model.\n", "\n", - "* __Pythonic module manipulation__: Pythonic ways to manipulate sub-modules given a model.\n", + "* __Manipulation of an abstract model or state__: A key trick for playing with `flax.nnx.Module`s and states without memory allocation.\n", "\n", - "* __Manipulating an abstract model or state__: A key trick to play with Flax NNX modules and states without memory allocation.\n", - "\n", - "* __Checkpoint surgery: From a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code.\n", + "* __Checkpoint surgery from a raw state to a model__: How to manipulate parameter states when they are incompatible with existing model code.\n", "\n", "* __Partial initialization__: How to initialize only a part of the model from scratch using a naive method or a memory-efficient method." ] }, @@ -63,11 +61,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Pythonic module manipulation\n", + "## Pythonic `nnx.Module` manipulation\n", + "\n", + "It is easier to perform model surgery when:\n", "\n", - "Doing model surgery is easiest when you already have a fully fleshed-out model loaded with correct parameters, and you don't intend to change your model definition code.\n", + "1) You already have a fully fleshed-out model loaded with correct parameters; and\n", + "2) You don't intend to change your model definition code.\n", "\n", - "You can perform a variety of Pythonic operations on its sub-modules, such as sub-module swapping, module sharing, variable sharing, and monkey-patching:" + "You can perform a variety of Pythonic operations on its sub-`Module`s, such as sub-`Module` swapping, `Module` sharing, variable sharing, and monkey-patching:" ] }, { @@ -80,25 +81,25 @@ "x = jax.random.normal(jax.random.key(42), (3, 4))\n", "np.testing.assert_allclose(model(x), model.linear2(model.linear1(x)))\n", "\n", - "# Sub-module swapping\n", + "# Sub-`Module` swapping.\n", "original1, original2 = model.linear1, model.linear2\n", "model.linear1, model.linear2 = model.linear2, model.linear1\n", "np.testing.assert_allclose(model(x), original1(original2(x)))\n", "\n", - "# Module sharing (tying all weights)\n", + "# `Module` sharing (tying all weights together).\n", "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "model.linear2 = model.linear1\n", "assert not hasattr(nnx.state(model), 'linear2')\n", "np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))\n", "\n", - "# Variable sharing (weight-tying)\n", + "# Variable sharing (weight-tying).\n", "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate\n", "assert hasattr(nnx.state(model), 'linear2')\n", "assert hasattr(nnx.state(model)['linear2'], 'bias')\n", "assert not hasattr(nnx.state(model)['linear2'], 'kernel')\n", "\n", - "# Monkey-patching\n", + "# Monkey-patching.\n", "model = 
TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "def awesome_layer(x): return x\n", "model.linear2 = awesome_layer\n", @@ -111,13 +112,14 @@ "source": [ "## Creating an abstract model or state without memory allocation\n", "\n", - "For more complex model surgery, a key technique is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints.\n", + "To do more complex model surgery, the key technique you can use is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concerns about memory constraints.\n", + "\n", + "To create an abstract model:\n", "\n", - "To create an abstract model,\n", "* Create a function that returns a valid Flax NNX model; and\n", "* Run `nnx.eval_shape` (not `jax.eval_shape`) upon it.\n", "\n", - "Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model are now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information." + "Now you can use `nnx.split` as usual to get its abstract state. Note that all fields that should be `jax.Array`s in a real model are now of an abstract `jax.ShapeDtypeStruct` type with only shape/dtype/sharding information." ] }, { @@ -164,7 +166,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "When you fill every `VariableState` leaf's `value`s with real jax arrays, the abstract model becomes equivalent to a real model." + "When you fill every `nnx.VariableState` pytree leaf's `value` attribute with real `jax.Array`s, the abstract model becomes equivalent to a real model." ] }, { @@ -188,9 +190,11 @@ "source": [ "## Checkpoint surgery\n", "\n", - "With the abstract state technique in hand, you can do arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make them fit with your given model code, then call `nnx.update` to merge them.\n", + "With the abstract state technique in hand, you can perform arbitrary manipulation on any checkpoint - or runtime parameter pytree - to make it fit your given model code, and then call `nnx.update` to merge it.\n", + "\n", + "This can be helpful if you are trying to significantly change the model code - for example, when migrating from Flax Linen to Flax NNX - and old weights are no longer naturally compatible.\n", "\n", - "This can be helpful when you are trying to change model code significantly (for example, when migrating from Flax Linen to Flax NNX), and old weights are no longer naturally compatible. Let's run a simple example here:" + "Let's run a simple example here:" ] }, { @@ -209,7 +213,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In this new model, the sub-modules are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure changed, it's impossible to load the old checkpoint with the new model state structure:" + "In this new model, the sub-`Module`s are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure has changed, it is impossible to directly load the old checkpoint with the new model state structure:" ] }, { @@ -247,7 +251,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "But you can load the parameter tree as a raw dictionary, make the renames, and generate a new state that is guaranteed to be compatible with your new model definition." 
+ "However, you can load the parameter pytree as a raw dictionary, perform the renames, and generate a new state that is guaranteed to be compatible with your new model definition." ] }, { @@ -283,7 +287,7 @@ "source": [ "def process_raw_dict(raw_state_dict):\n", " flattened = nnx.traversals.flatten_mapping(raw_state_dict)\n", - " # Cut off the '.value' postfix on every leaf path.\n", + " # Cut the '.value' postfix on every leaf path.\n", " flattened = {(path[:-1] if path[-1] == 'value' else path): value\n", " for path, value in flattened.items()}\n", " return nnx.traversals.unflatten_mapping(flattened)\n", @@ -309,7 +313,10 @@ "source": [ "## Partial initialization\n", "\n", - "In some cases (such as with LoRA), you may want to randomly-initialize only *part of* your model parameters. This can be achieved through naive partial initialization or memory-efficient partial initialization." + "In some cases - such as with LoRA (Low-Rank Adaption) - you may want to randomly-initialize only *part of* your model parameters. This can be achieved through:\n", + "\n", + "- Naive partial initialization; or\n", + "- Memory-efficient partial initialization." ] }, { @@ -318,9 +325,9 @@ "source": [ "### Naive partial initialization\n", "\n", - "You can simply initialize the whole model, then swap pre-trained parameters in. But this approach could allocate additional memory midway, if your modification requires re-creating module parameters that you will later discard. See this example below.\n", + "To do naive partial initialization, you can just initialize the whole model, then swap the pre-trained parameters in. However, this approach may allocate additional memory midway if your modification requires re-creating module parameters that you will later discard. Below is an example of this.\n", "\n", - "> Note: You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be messed up when you run a single notebook cell multiple times (due to garbage-collecting old python variables), but restarting the kernel and running from scratch will always yield same output." + "> **Note:** You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be “messed up” when you run a single Jupyter notebook cell multiple times (due to garbage-collection of old Python variables). However, restarting the Python kernel in the notebook and running the code from scratch will always yield the same output." 
] }, { @@ -344,8 +351,8 @@ "\n", "simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))\n", "print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')\n", - "# On this line, extra kernel and bias is created inside the new LoRALinear!\n", - "# They are wasted since you are going to use the kernel and bias in `old_state` anyway.\n", + "# On this line, an extra kernel and bias are created inside the new LoRALinear!\n", + "# They are wasted, because you are going to use the kernel and bias in `old_state` anyway.\n", "simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))\n", "print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'\n", " ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')\n", @@ -360,7 +367,7 @@ "source": [ "### Memory-efficient partial initialization\n", "\n", - "Use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized:" + "To do memory-efficient partial initialization, use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized:" ] }, { @@ -391,10 +398,10 @@ " nnx.update(model, old_state)\n", " return model\n", "\n", - "print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')\n", + "print(f'Number of JAX Arrays in memory at start: {len(jax.live_arrays())}')\n", "# Note that `old_state` will be deleted after this `partial_init` call.\n", "good_model = partial_init(old_state, nnx.Rngs(42))\n", - "print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'\n", + "print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}'\n", " ' (2 new created - lora_a and lora_b)')" ] }, diff --git a/docs_nnx/guides/surgery.md b/docs_nnx/guides/surgery.md index 3c1aa786ae..5442e0eb69 100644 --- a/docs_nnx/guides/surgery.md +++ b/docs_nnx/guides/surgery.md @@ -10,15 +10,13 @@ jupytext: # Model surgery -> **Attention**: This page relates to the new Flax NNX API. +In this guide, you will learn how to perform model surgery in Flax NNX using several real-world scenarios: -In this guide you will learn how to do model surgery with Flax NNX with several real-scenario use cases: +* __Pythonic `nnx.Module` manipulation__: Using Pythonic ways to manipulate sub-`Module`s given a model. -* __Pythonic module manipulation__: Pythonic ways to manipulate sub-modules given a model. +* __Manipulation of an abstract model or state__: A key trick for playing with `flax.nnx.Module`s and states without memory allocation. -* __Manipulating an abstract model or state__: A key trick to play with Flax NNX modules and states without memory allocation. - -* __Checkpoint surgery: From a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code. +* __Checkpoint surgery from a raw state to a model__: How to manipulate parameter states when they are incompatible with existing model code. * __Partial initialization__: How to initialize only a part of the model from scratch using a naive method or a memory-efficient method. @@ -52,36 +50,39 @@ class TwoLayerMLP(nnx.Module): return self.linear2(x) ``` -## Pythonic module manipulation +## Pythonic `nnx.Module` manipulation + +It is easier to perform model surgery when: -Doing model surgery is easiest when you already have a fully fleshed-out model loaded with correct parameters, and you don't intend to change your model definition code. 
+1) You already have a fully fleshed-out model loaded with correct parameters; and +2) You don't intend to change your model definition code. -You can perform a variety of Pythonic operations on its sub-modules, such as sub-module swapping, module sharing, variable sharing, and monkey-patching: +You can perform a variety of Pythonic operations on its sub-`Module`s, such as sub-`Module` swapping, `Module` sharing, variable sharing, and monkey-patching: ```{code-cell} ipython3 model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) x = jax.random.normal(jax.random.key(42), (3, 4)) np.testing.assert_allclose(model(x), model.linear2(model.linear1(x))) -# Sub-module swapping +# Sub-`Module` swapping. original1, original2 = model.linear1, model.linear2 model.linear1, model.linear2 = model.linear2, model.linear1 np.testing.assert_allclose(model(x), original1(original2(x))) -# Module sharing (tying all weights) +# `Module` sharing (tying all weights together). model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) model.linear2 = model.linear1 assert not hasattr(nnx.state(model), 'linear2') np.testing.assert_allclose(model(x), model.linear1(model.linear1(x))) -# Variable sharing (weight-tying) +# Variable sharing (weight-tying). model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate assert hasattr(nnx.state(model), 'linear2') assert hasattr(nnx.state(model)['linear2'], 'bias') assert not hasattr(nnx.state(model)['linear2'], 'kernel') -# Monkey-patching +# Monkey-patching. model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) def awesome_layer(x): return x model.linear2 = awesome_layer np.testing.assert_allclose(model(x), model.linear1(x)) ``` ## Creating an abstract model or state without memory allocation -For more complex model surgery, a key technique is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints. +To do more complex model surgery, the key technique you can use is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concerns about memory constraints. + +To create an abstract model: -To create an abstract model, * Create a function that returns a valid Flax NNX model; and * Run `nnx.eval_shape` (not `jax.eval_shape`) upon it. -Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model are now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information. +Now you can use `nnx.split` as usual to get its abstract state. Note that all fields that should be `jax.Array`s in a real model are now of an abstract `jax.ShapeDtypeStruct` type with only shape/dtype/sharding information. ```{code-cell} ipython3 abs_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0))) gdef, abs_state = nnx.split(abs_model) pprint(abs_state) ``` -When you fill every `VariableState` leaf's `value`s with real jax arrays, the abstract model becomes equivalent to a real model. +When you fill every `nnx.VariableState` pytree leaf's `value` attribute with real `jax.Array`s, the abstract model becomes equivalent to a real model. ```{code-cell} ipython3 model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) abs_state['linear1']['kernel'].value = model.linear1.kernel abs_state['linear1']['bias'].value = model.linear1.bias abs_state['linear2']['kernel'].value = model.linear2.kernel abs_state['linear2']['bias'].value = model.linear2.bias nnx.update(abs_model, abs_state) np.testing.assert_allclose(abs_model(x), model(x)) # They are equivalent now! 
## Checkpoint surgery -With the abstract state technique in hand, you can do arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make them fit with your given model code, then call `nnx.update` to merge them. +With the abstract state technique in hand, you can perform arbitrary manipulation on any checkpoint - or runtime parameter pytree - to make it fit your given model code, and then call `nnx.update` to merge it. + +This can be helpful if you are trying to significantly change the model code - for example, when migrating from Flax Linen to Flax NNX - and old weights are no longer naturally compatible. -This can be helpful when you are trying to change model code significantly (for example, when migrating from Flax Linen to Flax NNX), and old weights are no longer naturally compatible. Let's run a simple example here: +Let's run a simple example here: ```{code-cell} ipython3 # Save a version of model into a checkpoint checkpointer = orbax.checkpoint.StandardCheckpointer() old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) checkpointer.save(f'/tmp/nnx-surgery-state', nnx.state(model), force=True) ``` -In this new model, the sub-modules are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure changed, it's impossible to load the old checkpoint with the new model state structure: +In this new model, the sub-`Module`s are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure has changed, it is impossible to directly load the old checkpoint with the new model state structure: ```{code-cell} ipython3 class ModifiedTwoLayerMLP(nnx.Module): @@ -149,12 +153,12 @@ except Exception as e: print(f'This will throw error: {type(e)}: {e}') ``` -But you can load the parameter tree as a raw dictionary, make the renames, and generate a new state that is guaranteed to be compatible with your new model definition. +However, you can load the parameter pytree as a raw dictionary, perform the renames, and generate a new state that is guaranteed to be compatible with your new model definition. ```{code-cell} ipython3 def process_raw_dict(raw_state_dict): flattened = nnx.traversals.flatten_mapping(raw_state_dict) - # Cut off the '.value' postfix on every leaf path. + # Strip the '.value' postfix from every leaf path. flattened = {(path[:-1] if path[-1] == 'value' else path): value for path, value in flattened.items()} return nnx.traversals.unflatten_mapping(flattened) @@ -176,15 +180,18 @@ np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones( ## Partial initialization -In some cases (such as with LoRA), you may want to randomly-initialize only *part of* your model parameters. This can be achieved through naive partial initialization or memory-efficient partial initialization. +In some cases - such as with LoRA (Low-Rank Adaptation) - you may want to randomly initialize only *part of* your model parameters. This can be achieved through: + +- Naive partial initialization; or +- Memory-efficient partial initialization. +++ ### Naive partial initialization -You can simply initialize the whole model, then swap pre-trained parameters in. But this approach could allocate additional memory midway, if your modification requires re-creating module parameters that you will later discard. See this example below. +To do naive partial initialization, you can just initialize the whole model, then swap the pre-trained parameters in. However, this approach may allocate additional memory midway if your modification requires re-creating module parameters that you will later discard. 
Below is an example of this. -> Note: You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be messed up when you run a single notebook cell multiple times (due to garbage-collecting old python variables), but restarting the kernel and running from scratch will always yield same output. +> **Note:** You can use `jax.live_arrays()` to check all the arrays that are live in memory at any given time. The count can be skewed when you run a single Jupyter notebook cell multiple times (due to garbage collection of old Python variables). However, restarting the Python kernel in the notebook and running the code from scratch will always yield the same output. ```{code-cell} ipython3 # Some pretrained model state old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0))) simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42))) print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}') -# On this line, extra kernel and bias is created inside the new LoRALinear! -# They are wasted since you are going to use the kernel and bias in `old_state` anyway. +# On this line, an extra kernel and bias are created inside the new LoRALinear! +# They are wasted, because you are going to use the kernel and bias in `old_state` anyway. simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42)) print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}' ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)') @@ -204,7 +211,7 @@ print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}' ### Memory-efficient partial initialization -Use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized: +To do memory-efficient partial initialization, use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized: ```{code-cell} ipython3 # Some pretrained model state old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0))) @nnx.jit(donate_argnums=0) def partial_init(old_state, rngs): model = TwoLayerMLP(4, rngs=rngs) # Set the newly created state trees to the old state. model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs) nnx.update(model, old_state) return model -print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}') +print(f'Number of JAX Arrays in memory at start: {len(jax.live_arrays())}') # Note that `old_state` will be deleted after this `partial_init` call. good_model = partial_init(old_state, nnx.Rngs(42)) -print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}' +print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}' ' (2 new created - lora_a and lora_b)') ```
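
The checkpoint-surgery sections above describe the rename step only in prose ("load the parameter pytree as a raw dictionary, perform the renames"). Below is a minimal sketch of that step, assuming the checkpoint has already been restored into a plain nested dictionary and reusing the guide's `process_raw_dict` helper; `rename_top_level_keys` and `renames` are hypothetical names introduced here for illustration and are not part of the Flax API.

```python
# A minimal sketch of the rename step, under the assumptions stated above.
# `raw_dict` is assumed to be a plain nested dict restored from the checkpoint,
# and `process_raw_dict` is the helper defined in the guide.
from flax import nnx


def rename_top_level_keys(raw_state_dict, renames):
  """Returns a copy of a nested state dict with its top-level keys renamed."""
  flattened = nnx.traversals.flatten_mapping(raw_state_dict)
  renamed = {
      (renames.get(path[0], path[0]), *path[1:]): value
      for path, value in flattened.items()
  }
  return nnx.traversals.unflatten_mapping(renamed)


# Hypothetical usage, following the guide's `linear(1|2)` -> `layer(1|2)` rename:
# renamed = rename_top_level_keys(
#     raw_dict, {'linear1': 'layer1', 'linear2': 'layer2'})
# `process_raw_dict(renamed)` can then be merged into the new model's state,
# as the guide does with `nnx.update`.
```

Because the rename only touches top-level keys, per-key one-liners such as `raw_dict['layer1'] = raw_dict.pop('linear1')` work just as well; the flatten/unflatten form simply generalizes to renames at deeper path levels.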