-
Notifications
You must be signed in to change notification settings - Fork 233
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add utilities for working with a mixture of Haiku and Flax modules.
PiperOrigin-RevId: 553733615
- Loading branch information
1 parent
1377a0d
commit 6bd4675
Showing
15 changed files
with
1,352 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,285 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"source": [ | ||
"# Haiku and Flax interop 🥂\n", | ||
"\n", | ||
"Utilities to move seamlessly between Haiku and Flax.\n", | ||
"\n", | ||
"## Flax inside Haiku\n", | ||
"\n", | ||
"Using a Flax module inside a `hk.transform` (or `hk.transform_with_state`) is\n", | ||
"straight forward.\n", | ||
"\n", | ||
"First construct an instance of your module, and then use `hkflax.lift` to\n", | ||
"\"lift\" (see [`hk.lift`]) the parameters and state from the flax module into the\n", | ||
"Haiku transform.\n", | ||
"\n", | ||
"Example:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"Array([[ 0.33030465, -1.3496182 , 0.02847686, -1.6579462 , -0.9166192 ,\n", | ||
" 0.2883583 , -0.046898 , 0.6414894 , -0.404975 , -2.1162813 ]], dtype=float32)" | ||
] | ||
}, | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"import haiku as hk\n", | ||
"import haiku.experimental.flax as hkflax\n", | ||
"import flax.linen as flax_nn\n", | ||
"\n", | ||
"def f(x):\n", | ||
" mod = hkflax.lift(flax_nn.Dense(10), name='my_flax_module')\n", | ||
" x = mod(x)\n", | ||
" return x\n", | ||
"\n", | ||
"f = hk.transform(f)\n", | ||
"x = jnp.ones([1, 1])\n", | ||
"rng = jax.random.PRNGKey(42)\n", | ||
"params = f.init(rng, x) # params contains the parameters for MyFlaxModule.\n", | ||
"f.apply(params, None, x) # MyFlaxModule will be passed parameters from params." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"source": [ | ||
"To use a stateful module simply swap `hk.transform` for\n", | ||
"`hk.transform_with_state`.\n", | ||
"\n", | ||
"## Haiku inside Flax\n", | ||
"\n", | ||
"There are two supported approaches for converting `Haiku` code to `Flax`. Both\n", | ||
"produce a Flax linen `nn.Module` which encapsulates the Haiku code and provides\n", | ||
"`init` and `apply` methods to create and use parameters and state.\n", | ||
"\n", | ||
"- [Convert an `hk.Module` to `nn.Module`](#hk-Module).\n", | ||
"- [Convert an `hk.transform` to `nn.Module`](#hk-transform).\n", | ||
"- [Convert an `hk.transform_with_state` to `nn.Module`](#hk-transform).\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"source": [ | ||
"### Converting `hk.Module` {#hk-Module}\n", | ||
"\n", | ||
"For stateless modules you simply need to construct the Flax module via\n", | ||
"`hkflax.Module.create`:\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"mod = hkflax.Module.create(hk.Linear, 1) # hk.Linear(1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"source": [ | ||
"You can use this like a regular Flax `nn.Module` (because it is one!):" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"rng = jax.random.PRNGKey(42)\n", | ||
"x = jnp.ones([1, 1])\n", | ||
"variables = mod.init(rng, x)\n", | ||
"out = mod.apply(variables, x)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"source": [ | ||
"For a stateful module like ResNet, you need to also handle output state, again\n", | ||
"this is the same as stateful Flax modules:\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"mod = hkflax.Module.create(hk.nets.ResNet50, 10)\n", | ||
"\n", | ||
"# Regular flax code from here on:\n", | ||
"rng = jax.random.PRNGKey(42)\n", | ||
"x = jnp.ones([1, 224, 224, 3])\n", | ||
"variables = mod.init(rng, x, is_training=True)\n", | ||
"for _ in range(10):\n", | ||
" out, state_out = mod.apply(variables, x, is_training=True,\n", | ||
" mutable=['state'])\n", | ||
" variables = {**variables, **state_out}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"source": [ | ||
"### Converting `hk.transform` or `hk.transform_with_state` {#hk-transform}\n", | ||
"\n", | ||
"`hkflax.Module` can be created from the result of `hk.transform` or\n", | ||
"`hk.transform_with_state` if you prefer:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def mlp(x):\n", | ||
" x = hk.Linear(300)(x)\n", | ||
" x = hk.Linear(100)(x)\n", | ||
" x = hk.Linear(10)(x)\n", | ||
" return x\n", | ||
"\n", | ||
"mlp = hk.transform(mlp)\n", | ||
"mlp = hkflax.Module(mlp)\n", | ||
"\n", | ||
"rng = jax.random.PRNGKey(42)\n", | ||
"x = jnp.ones([1, 28 * 28])\n", | ||
"variables = mlp.init(rng, x)\n", | ||
"out = mlp.apply(variables, x)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"source": [ | ||
"### Gotchas\n", | ||
"\n", | ||
"#### Initialization is different\n", | ||
"\n", | ||
"Flax and Haiku take different approaches to RNG key splitting. As such at the\n", | ||
"moment the parameters returned from `hkflax.Module(f).init` will differ from\n", | ||
"`hk.transform(f).init`.\n", | ||
"\n", | ||
"We have a route to support making Haiku transform match initialization of the\n", | ||
"Flax module, but there is not a path for the opposite direction at the moment.\n", | ||
"\n", | ||
"If aligning initialization across Haiku and Flax is important to you, we\n", | ||
"recommend using one of the libraries to create parameters, and then manipulate\n", | ||
"the params/state dictionary to match the other library as needed:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"source": [ | ||
"```python\n", | ||
"# Utilities.\n", | ||
"import haiku.data_structures as hkds\n", | ||
"\n", | ||
"make_flat = {f'{m}/{n}': w for m, n, w in hkds.traverse(d)}\n", | ||
"\n", | ||
"def make_nested(d):\n", | ||
" out = {}\n", | ||
" for k, w in d.items():\n", | ||
" m, n = k.rsplit('/', 1)\n", | ||
" out.setdefault(m, {})\n", | ||
" out[m][n] = w\n", | ||
" return out\n", | ||
"\n", | ||
"# The two modules here should be equivalent when run with Flax or Haiku.\n", | ||
"f = hk.transform_with_state(..)\n", | ||
"flax_mod = hkflax.Module(f)\n", | ||
"\n", | ||
"# Option 1: Convert Haiku initialized params/state to Flax.\n", | ||
"params, state = f.init(..)\n", | ||
"variables = {'params': make_flat(params), 'state': make_flat(state)}\n", | ||
"\n", | ||
"# Option 2: Convert Flax initialized variables to Haiku.\n", | ||
"variables = flax_mod.init(..)\n", | ||
"params = make_nested(variables.get('params', {}))\n", | ||
"state = make_nested(variables.get('state', {}))\n", | ||
"\n", | ||
"# The output of the Haiku transformed function or the Flax function should be\n", | ||
"# equivalent with either init.\n", | ||
"out, state = f.apply(params, state, ..)\n", | ||
"out, variables_out = flax_mod.apply(variables, .., mutable=['state'])\n", | ||
"```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "_" | ||
}, | ||
"source": [ | ||
"#### Multiple forward methods\n", | ||
"\n", | ||
"`hkflax.Module` only support `__call__` at the moment, please let tomhennigan@\n", | ||
"know if this is a blocker for you.\n", | ||
"\n", | ||
"[`hk.lift`]: https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.lift:" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": {} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.