Skip to content

Commit

Permalink
Add utilities for working with a mixture of Haiku and Flax modules.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 553733615
  • Loading branch information
tomhennigan authored and copybara-github committed Aug 4, 2023
1 parent 1377a0d commit 6bd4675
Show file tree
Hide file tree
Showing 15 changed files with 1,352 additions and 0 deletions.
28 changes: 28 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,37 @@ SupportsCall

.. autoclass:: SupportsCall

Flax Interop
============

.. automodule:: haiku.experimental.flax

Haiku inside Flax
-----------------

Module
~~~~~~

.. autoclass:: Module

flatten_flax_to_haiku
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: flatten_flax_to_haiku

Flax inside Haiku
-----------------

lift
~~~~

.. autofunction:: lift

Advanced State Management
=========================

.. automodule:: haiku

Lifting
-------

Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Alternatively, you can install via PyPI::
:caption: Advanced
:maxdepth: 1

notebooks/flax
notebooks/jax2tf
notebooks/build_your_own_haiku
notebooks/visualization
Expand Down
285 changes: 285 additions & 0 deletions docs/notebooks/flax.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "_"
},
"source": [
"# Haiku and Flax interop 🥂\n",
"\n",
"Utilities to move seamlessly between Haiku and Flax.\n",
"\n",
"## Flax inside Haiku\n",
"\n",
"Using a Flax module inside a `hk.transform` (or `hk.transform_with_state`) is\n",
"straight forward.\n",
"\n",
"First construct an instance of your module, and then use `hkflax.lift` to\n",
"\"lift\" (see [`hk.lift`]) the parameters and state from the flax module into the\n",
"Haiku transform.\n",
"\n",
"Example:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "_"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"data": {
"text/plain": [
"Array([[ 0.33030465, -1.3496182 , 0.02847686, -1.6579462 , -0.9166192 ,\n",
" 0.2883583 , -0.046898 , 0.6414894 , -0.404975 , -2.1162813 ]], dtype=float32)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import haiku as hk\n",
"import haiku.experimental.flax as hkflax\n",
"import flax.linen as flax_nn\n",
"\n",
"def f(x):\n",
" mod = hkflax.lift(flax_nn.Dense(10), name='my_flax_module')\n",
" x = mod(x)\n",
" return x\n",
"\n",
"f = hk.transform(f)\n",
"x = jnp.ones([1, 1])\n",
"rng = jax.random.PRNGKey(42)\n",
"params = f.init(rng, x) # params contains the parameters for MyFlaxModule.\n",
"f.apply(params, None, x) # MyFlaxModule will be passed parameters from params."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_"
},
"source": [
"To use a stateful module simply swap `hk.transform` for\n",
"`hk.transform_with_state`.\n",
"\n",
"## Haiku inside Flax\n",
"\n",
"There are two supported approaches for converting `Haiku` code to `Flax`. Both\n",
"produce a Flax linen `nn.Module` which encapsulates the Haiku code and provides\n",
"`init` and `apply` methods to create and use parameters and state.\n",
"\n",
"- [Convert an `hk.Module` to `nn.Module`](#hk-Module).\n",
"- [Convert an `hk.transform` to `nn.Module`](#hk-transform).\n",
"- [Convert an `hk.transform_with_state` to `nn.Module`](#hk-transform).\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_"
},
"source": [
"### Converting `hk.Module` {#hk-Module}\n",
"\n",
"For stateless modules you simply need to construct the Flax module via\n",
"`hkflax.Module.create`:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "_"
},
"outputs": [],
"source": [
"mod = hkflax.Module.create(hk.Linear, 1) # hk.Linear(1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_"
},
"source": [
"You can use this like a regular Flax `nn.Module` (because it is one!):"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "_"
},
"outputs": [],
"source": [
"rng = jax.random.PRNGKey(42)\n",
"x = jnp.ones([1, 1])\n",
"variables = mod.init(rng, x)\n",
"out = mod.apply(variables, x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_"
},
"source": [
"For a stateful module like ResNet, you need to also handle output state, again\n",
"this is the same as stateful Flax modules:\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "_"
},
"outputs": [],
"source": [
"mod = hkflax.Module.create(hk.nets.ResNet50, 10)\n",
"\n",
"# Regular flax code from here on:\n",
"rng = jax.random.PRNGKey(42)\n",
"x = jnp.ones([1, 224, 224, 3])\n",
"variables = mod.init(rng, x, is_training=True)\n",
"for _ in range(10):\n",
" out, state_out = mod.apply(variables, x, is_training=True,\n",
" mutable=['state'])\n",
" variables = {**variables, **state_out}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_"
},
"source": [
"### Converting `hk.transform` or `hk.transform_with_state` {#hk-transform}\n",
"\n",
"`hkflax.Module` can be created from the result of `hk.transform` or\n",
"`hk.transform_with_state` if you prefer:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "_"
},
"outputs": [],
"source": [
"def mlp(x):\n",
" x = hk.Linear(300)(x)\n",
" x = hk.Linear(100)(x)\n",
" x = hk.Linear(10)(x)\n",
" return x\n",
"\n",
"mlp = hk.transform(mlp)\n",
"mlp = hkflax.Module(mlp)\n",
"\n",
"rng = jax.random.PRNGKey(42)\n",
"x = jnp.ones([1, 28 * 28])\n",
"variables = mlp.init(rng, x)\n",
"out = mlp.apply(variables, x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_"
},
"source": [
"### Gotchas\n",
"\n",
"#### Initialization is different\n",
"\n",
"Flax and Haiku take different approaches to RNG key splitting. As such at the\n",
"moment the parameters returned from `hkflax.Module(f).init` will differ from\n",
"`hk.transform(f).init`.\n",
"\n",
"We have a route to support making Haiku transform match initialization of the\n",
"Flax module, but there is not a path for the opposite direction at the moment.\n",
"\n",
"If aligning initialization across Haiku and Flax is important to you, we\n",
"recommend using one of the libraries to create parameters, and then manipulate\n",
"the params/state dictionary to match the other library as needed:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_"
},
"source": [
"```python\n",
"# Utilities.\n",
"import haiku.data_structures as hkds\n",
"\n",
"make_flat = {f'{m}/{n}': w for m, n, w in hkds.traverse(d)}\n",
"\n",
"def make_nested(d):\n",
" out = {}\n",
" for k, w in d.items():\n",
" m, n = k.rsplit('/', 1)\n",
" out.setdefault(m, {})\n",
" out[m][n] = w\n",
" return out\n",
"\n",
"# The two modules here should be equivalent when run with Flax or Haiku.\n",
"f = hk.transform_with_state(..)\n",
"flax_mod = hkflax.Module(f)\n",
"\n",
"# Option 1: Convert Haiku initialized params/state to Flax.\n",
"params, state = f.init(..)\n",
"variables = {'params': make_flat(params), 'state': make_flat(state)}\n",
"\n",
"# Option 2: Convert Flax initialized variables to Haiku.\n",
"variables = flax_mod.init(..)\n",
"params = make_nested(variables.get('params', {}))\n",
"state = make_nested(variables.get('state', {}))\n",
"\n",
"# The output of the Haiku transformed function or the Flax function should be\n",
"# equivalent with either init.\n",
"out, state = f.apply(params, state, ..)\n",
"out, variables_out = flax_mod.apply(variables, .., mutable=['state'])\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_"
},
"source": [
"#### Multiple forward methods\n",
"\n",
"`hkflax.Module` only support `__call__` at the moment, please let tomhennigan@\n",
"know if this is a blocker for you.\n",
"\n",
"[`hk.lift`]: https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.lift:"
]
}
],
"metadata": {
"colab": {}
},
"nbformat": 4,
"nbformat_minor": 0
}
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ defusedxml==0.7.1
docutils==0.17.1
entrypoints==0.3
flatbuffers==2.0
flax==0.7.1
idna==3.3
imagesize==1.2.0
importlib-resources==5.9.0
Expand Down
3 changes: 3 additions & 0 deletions haiku/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ py_library(
"config.py",
"data_structures.py",
"experimental/__init__.py",
"experimental/flax.py",
"experimental/jaxpr_info.py",
"initializers.py",
"mixed_precision.py",
Expand Down Expand Up @@ -73,6 +74,8 @@ py_library(
"//haiku/_src:transform",
"//haiku/_src:typing",
"//haiku/_src:utils",
"//haiku/_src/flax:flax_module",
"//haiku/_src/flax:transform_flax",
"//haiku/_src/nets:mlp",
"//haiku/_src/nets:mobilenetv1",
"//haiku/_src/nets:resnet",
Expand Down
Loading

0 comments on commit 6bd4675

Please sign in to comment.