Commit

[nnx] add demo.ipynb
cgarciae committed Feb 7, 2024
1 parent d9585e0 commit 385e234
Showing 2 changed files with 540 additions and 0 deletions.
370 changes: 370 additions & 0 deletions flax/experimental/nnx/docs/demo.ipynb
@@ -0,0 +1,370 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a1b37dff",
"metadata": {},
"source": [
"# NNX Demo"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e8099a6f",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from jax import numpy as jnp\n",
"from flax.experimental import nnx"
]
},
{
"cell_type": "markdown",
"id": "bcc5cffe",
"metadata": {},
"source": [
"### [1] NNX is Pythonic"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d99b73af",
"metadata": {
"outputId": "d8ef66d5-6866-4d5c-94c2-d22512bfe718"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model = MLP(\n",
" blocks=[Block(\n",
" linear=Linear(\n",
" in_features=4,\n",
" out_features=4,\n",
" use_bias=True,\n",
" dtype=None,\n",
" param_dtype=<class 'jax.numpy.float32'>,\n",
" precision=None,\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x7f8aa7e24670>,\n",
" bias_init=<function zeros at 0x7f8b4a8d55a0>,\n",
" dot_general=<function dot_general at 0x7f8b4aed8f70>\n",
" ),\n",
" bn=BatchNorm(\n",
" num_features=4,\n",
" use_running_average=None,\n",
" \n",
"...\n"
]
}
],
"source": [
"\n",
"class Block(nnx.Module):\n",
" def __init__(self, din, dout, *, rngs):\n",
" self.linear = nnx.Linear(din, dout, rngs=rngs,\n",
" kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal() , ('data', 'mp')))\n",
" self.bn = nnx.BatchNorm(dout, rngs=rngs)\n",
"\n",
" def __call__(self, x, *, train: bool):\n",
" x = self.linear(x)\n",
" x = self.bn(x, use_running_average=not train)\n",
" x = nnx.relu(x)\n",
" return x\n",
"\n",
"\n",
"class MLP(nnx.Module):\n",
" def __init__(self, nlayers, dim, *, rngs): # explicit RNG threading\n",
" self.blocks = [\n",
" Block(dim, dim, rngs=rngs) for _ in range(nlayers)\n",
" ]\n",
" self.count = Count(0) # stateful variables are defined as attributes\n",
"\n",
" def __call__(self, x, *, train: bool):\n",
" self.count += 1 # in-place stateful updates\n",
" for block in self.blocks:\n",
" x = block(x, train=train)\n",
" return x\n",
"\n",
"class Count(nnx.Variable): # custom Variable types define the \"collections\"\n",
" pass\n",
"\n",
"model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method\n",
"y = model(jnp.ones((2, 4)), train=False) # call methods directly\n",
"\n",
"print(f'{model = }'[:500] + '\\n...')"
]
},
{
"cell_type": "markdown",
"id": "523aa27c",
"metadata": {},
"source": [
"Because NNX Modules contain their own state, they are very easily to inspect:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6f278ec4",
"metadata": {
"outputId": "10a46b0f-2993-4677-c26d-36a4ddf33449"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model.count = 1\n",
"model.blocks[0].linear.kernel = Array([[ 0.4541134 , -0.5264871 , -0.36505395, -0.57566494],\n",
" [ 0.3880299 , 0.56555384, 0.48706698, 0.22677685],\n",
" [-0.9015692 , 0.24465257, -0.58447087, 0.18421973],\n",
" [-0.06992681, -0.64693826, 0.20232539, 1.1200054 ]], dtype=float32)\n"
]
}
],
"source": [
"print(f'{model.count = }')\n",
"print(f'{model.blocks[0].linear.kernel = }')\n",
"# print(f'{model.blocks.sdf.kernel = }') # typesafe inspection"
]
},
{
"cell_type": "markdown",
"id": "95f389f2",
"metadata": {},
"source": [
"### [2] Model Surgery is Intuitive"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "96f61108",
"metadata": {
"outputId": "e6f86be8-3537-4c48-f471-316ee0fb6c45"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"y.shape = (2, 4)\n"
]
}
],
"source": [
"# Module sharing\n",
"model.blocks[1] = model.blocks[3]\n",
"# Weight tying\n",
"model.blocks[0].linear.variables.kernel = model.blocks[-1].linear.variables.kernel\n",
"# Monkey patching\n",
"def my_optimized_layer(x, *, train: bool): return x\n",
"model.blocks[2] = my_optimized_layer\n",
"\n",
"y = model(jnp.ones((2, 4)), train=False) # still works\n",
"print(f'{y.shape = }')"
]
},
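{
"cell_type": "markdown",
"id": "f2a9c3e1",
"metadata": {},
"source": [
"A quick sanity check (added here for illustration, not part of the original demo): the surgery above assigns plain Python references, so module sharing and weight tying can be verified directly with `is`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3b4c5d6",
"metadata": {},
"outputs": [],
"source": [
"# module sharing: blocks[1] and blocks[3] are the same object\n",
"print(f'{(model.blocks[1] is model.blocks[3]) = }')\n",
"# weight tying: the first and last kernels refer to the same Variable\n",
"print(f'{(model.blocks[0].linear.variables.kernel is model.blocks[-1].linear.variables.kernel) = }')"
]
},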
{
"cell_type": "markdown",
"id": "aca5a6cd",
"metadata": {},
"source": [
"### [3] Interacting with JAX is easy"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c166dcc7",
"metadata": {
"outputId": "9a3f378b-739e-4f45-9968-574651200ede"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"state = State({\n",
" 'blocks': {\n",
" '0': {\n",
" 'linear': {\n",
" 'kernel': Array([[-0.33674937, 1.0543901 , -0.524824 , 0.16665861],\n",
" [ 0.6607222 , 0.07498633, -0.165967 , -0.36928803],\n",
" [-0.7086948 , -0.5809104 , 0.2939486 , -0.6660238 ],\n",
" [-0.13412867, 0.09832543, 0.77024055, -0.2405255 ]], dtype=float32),\n",
" 'bias': Array([0., 0., 0., 0.], dtype=float32)\n",
" },\n",
" 'bn': {\n",
" 'mean': Array([0., 0., 0., 0.], dtype=float32),\n",
"...\n",
"\n",
"static = GraphDef(\n",
" type=MLP,\n",
" index=0,\n",
" attributes=('blocks', 'count'),\n",
" subgraphs={\n",
" 'blocks': GraphDef(\n",
" type=list,\n",
" index=1,\n",
" attributes=('0', '1', '2', '3', '4'),\n",
" subgraphs={\n",
" '0': GraphDef(\n",
" type=Block,\n",
" index=2,\n",
" attributes=('line\n",
"...\n"
]
}
],
"source": [
"state, static = model.split()\n",
"\n",
"# state is a dictionary-like JAX pytree\n",
"print(f'{state = }'[:500] + '\\n...')\n",
"\n",
"# static is also a JAX pytree, but containing no data, just metadata\n",
"print(f'\\n{static = }'[:300] + '\\n...')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9f03e3af",
"metadata": {
"outputId": "0007d357-152a-449e-bcb9-b1b5a91d2d8d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"y.shape = (2, 4)\n",
"model.count = Array(3, dtype=int32, weak_type=True)\n"
]
}
],
"source": [
"state, static = model.split()\n",
"\n",
"@jax.jit\n",
"def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array):\n",
" model = static.merge(state)\n",
" y = model(x, train=True)\n",
" state, _ = model.split()\n",
" return y, state\n",
"\n",
"x = jnp.ones((2, 4))\n",
"y, state = forward(static,state, x)\n",
"\n",
"model.update(state)\n",
"\n",
"print(f'{y.shape = }')\n",
"print(f'{model.count = }')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9e23dbb4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"y.shape = (2, 4)\n",
"model.count = Array(4, dtype=int32, weak_type=True)\n"
]
}
],
"source": [
"params, batch_stats, counts, static = model.split(nnx.Param, nnx.BatchStat, Count)\n",
"\n",
"@jax.jit\n",
"def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n",
" model = static.merge(params, batch_stats, counts)\n",
" y = model(x, train=True)\n",
" params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)\n",
" return y, params, batch_stats, counts\n",
"\n",
"x = jnp.ones((2, 4))\n",
"y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x)\n",
"\n",
"model.update(params, batch_stats, counts)\n",
"\n",
"print(f'{y.shape = }')\n",
"print(f'{model.count = }')"
]
},
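{
"cell_type": "markdown",
"id": "d7e8f9a0",
"metadata": {},
"source": [
"The same split / merge / update pattern composes with `jax.grad`. Below is a minimal, illustrative sketch (not part of the original demo) of a single gradient step over the `Param` state: the mean-squared-error loss, the target `y_target`, and the `0.01` learning rate are arbitrary assumptions, and the SGD update relies on `State` being a regular JAX pytree as described above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1f2a3b4",
"metadata": {},
"outputs": [],
"source": [
"params, batch_stats, counts, static = model.split(nnx.Param, nnx.BatchStat, Count)\n",
"\n",
"@jax.jit\n",
"def train_step(static: nnx.GraphDef, params, batch_stats, counts, x, y_target):\n",
"  def loss_fn(params):\n",
"    model = static.merge(params, batch_stats, counts)\n",
"    y = model(x, train=True)\n",
"    loss = jnp.mean((y - y_target) ** 2)\n",
"    _, new_batch_stats, new_counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)\n",
"    return loss, (new_batch_stats, new_counts)\n",
"\n",
"  (loss, (batch_stats, counts)), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)\n",
"  # plain SGD applied leaf-wise over the Param pytree (illustrative only)\n",
"  params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)\n",
"  return loss, params, batch_stats, counts\n",
"\n",
"x = jnp.ones((2, 4))\n",
"y_target = jnp.zeros((2, 4))\n",
"loss, params, batch_stats, counts = train_step(static, params, batch_stats, counts, x, y_target)\n",
"model.update(params, batch_stats, counts)\n",
"\n",
"print(f'{loss.shape = }')"
]
},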
{
"cell_type": "code",
"execution_count": 8,
"id": "2461bfe8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"y.shape = (2, 4)\n",
"parent.model.count = Array(5, dtype=int32, weak_type=True)\n"
]
}
],
"source": [
"class Parent(nnx.Module):\n",
"\n",
" def __init__(self, model: MLP):\n",
" self.model = model\n",
"\n",
" def __call__(self, x, *, train: bool):\n",
"\n",
" params, batch_stats, counts, static = self.model.split(nnx.Param, nnx.BatchStat, Count)\n",
"\n",
" @jax.jit\n",
" def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n",
" model = static.merge(params, batch_stats, counts)\n",
" y = model(x, train=True)\n",
" params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)\n",
" return y, params, batch_stats, counts\n",
"\n",
" y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x)\n",
"\n",
" self.model.update(params, batch_stats, counts)\n",
"\n",
" return y\n",
"\n",
"parent = Parent(model)\n",
"\n",
"y = parent(jnp.ones((2, 4)), train=False)\n",
"\n",
"print(f'{y.shape = }')\n",
"print(f'{parent.model.count = }')"
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,md:myst"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}