Skip to content

Commit

Permalink
[nnx] use explicit Variables
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Feb 29, 2024
1 parent d1f219f commit 6ad6c96
Show file tree
Hide file tree
Showing 73 changed files with 1,323 additions and 1,716 deletions.
2 changes: 1 addition & 1 deletion docs/experimental/nnx/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ Basic usage
self.din, self.dout = din, dout
def __call__(self, x: jax.Array):
return x @ self.w + self.b
return x @ self.w.value + self.b.value

rngs = nnx.Rngs(0) # explicit RNG handling
model = Linear(din=2, dout=3, rngs=rngs) # initialize the model
Expand Down
84 changes: 49 additions & 35 deletions docs/experimental/nnx/nnx_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
" self.din, self.dout = din, dout\n",
"\n",
" def __call__(self, x: jax.Array):\n",
" return x @ self.w + self.b"
" return x @ self.w.value + self.b.value"
]
},
{
Expand Down Expand Up @@ -80,26 +80,26 @@
" din=2,\n",
" dout=3\n",
")\n",
"model.w = Array([[0.19007349, 0.31424356, 0.3686391 ],\n",
"model.w.value = Array([[0.19007349, 0.31424356, 0.3686391 ],\n",
" [0.7862853 , 0.03352201, 0.50682676]], dtype=float32)\n",
"model.b = Array([0., 0., 0.], dtype=float32)\n"
"model.b.value = Array([0., 0., 0.], dtype=float32)\n"
]
}
],
"source": [
"model = Linear(din=2, dout=3, rngs=nnx.Rngs(params=0))\n",
"\n",
"print(f'{model = }')\n",
"print(f'{model.w = }')\n",
"print(f'{model.b = }')"
"print(f'{model.w.value = }')\n",
"print(f'{model.b.value = }')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is very handy for debugging as it allows accessing the entire structure or\n",
"modify it. Similarly, computation can ran directly."
"modify it. Similarly, computation can be ran directly."
]
},
{
Expand Down Expand Up @@ -145,15 +145,15 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"counter.count = 0\n",
"counter.count = 1\n"
"counter.count.value = 0\n",
"counter.count.value = 1\n"
]
}
],
Expand All @@ -163,12 +163,12 @@
" self.count = nnx.Variable(0)\n",
"\n",
" def __call__(self):\n",
" self.count += 1\n",
" self.count.value += 1\n",
"\n",
"counter = Counter()\n",
"print(f'{counter.count = }')\n",
"print(f'{counter.count.value = }')\n",
"counter()\n",
"print(f'{counter.count = }')"
"print(f'{counter.count.value = }')"
]
},
{
Expand Down Expand Up @@ -199,7 +199,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -215,9 +215,9 @@
" dtype=None,\n",
" param_dtype=<class 'jax.numpy.float32'>,\n",
" precision=None,\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x1307e2a60>,\n",
" bias_init=<function zeros at 0x11728bca0>,\n",
" dot_general=<function dot_general at 0x1169f6700>\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x169773f70>,\n",
" bias_init=<function zeros at 0x1353b8ca0>,\n",
" dot_general=<function dot_general at 0x126dc5700>\n",
" ),\n",
" bn=BatchNorm(\n",
" num_features=2,\n",
Expand Down Expand Up @@ -257,22 +257,22 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model.blocks[1].linear.kernel = Array([[0.992858 , 0.9711272],\n",
"model.blocks[1].linear.kernel.value = Array([[0.992858 , 0.9711272],\n",
" [1.4061186, 0.4704619]], dtype=float32)\n",
"model.blocks[0].bn.scale = Array([1., 1.], dtype=float32)\n"
"model.blocks[0].bn.scale.value = Array([1., 1.], dtype=float32)\n"
]
}
],
"source": [
"print(f'{model.blocks[1].linear.kernel = }')\n",
"print(f'{model.blocks[0].bn.scale = }')"
"print(f'{model.blocks[1].linear.kernel.value = }')\n",
"print(f'{model.blocks[0].bn.scale.value = }')"
]
},
{
Expand Down Expand Up @@ -316,7 +316,7 @@
"model.blocks[2] = awesome_layer\n",
"\n",
"# Variable sharing (weight tying)\n",
"model.blocks[-1].linear.variables.kernel = model.blocks[0].linear.variables.kernel\n",
"model.blocks[-1].linear.kernel = model.blocks[0].linear.kernel\n",
"\n",
"model(jnp.ones((1, 2)))"
]
Expand Down Expand Up @@ -353,8 +353,8 @@
" self.count = Count(0)\n",
"\n",
" def __call__(self, x: jax.Array):\n",
" self.count += 1\n",
" return x @ self.w + self.b\n",
" self.count.value += 1\n",
" return x @ self.w.value + self.b.value\n",
" \n",
"model = StatefulLinear(din=2, dout=3, rngs=nnx.Rngs(0))"
]
Expand Down Expand Up @@ -382,10 +382,16 @@
"output_type": "stream",
"text": [
"state = State({\n",
" 'w': Array([[0.19007349, 0.31424356, 0.3686391 ],\n",
" [0.7862853 , 0.03352201, 0.50682676]], dtype=float32),\n",
" 'b': Array([0., 0., 0.], dtype=float32),\n",
" 'count': 0\n",
" 'w': Param(\n",
" raw_value=Array([[0.19007349, 0.31424356, 0.3686391 ],\n",
" [0.7862853 , 0.03352201, 0.50682676]], dtype=float32)\n",
" ),\n",
" 'b': Param(\n",
" raw_value=Array([0., 0., 0.], dtype=float32)\n",
" ),\n",
" 'count': Count(\n",
" raw_value=0\n",
" )\n",
"})\n",
"\n",
"static = GraphDef(\n",
Expand Down Expand Up @@ -431,8 +437,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"model.count = 0\n",
"model.count = Array(1, dtype=int32, weak_type=True)\n"
"model.count = Count(\n",
" raw_value=0\n",
")\n",
"model.count.value = Array(1, dtype=int32, weak_type=True)\n"
]
}
],
Expand All @@ -456,7 +464,7 @@
"# 5. Update the state of the original Module\n",
"model.update(state)\n",
"\n",
"print(f'{model.count = }')"
"print(f'{model.count.value = }')"
]
},
{
Expand Down Expand Up @@ -503,13 +511,19 @@
"output_type": "stream",
"text": [
"params = State({\n",
" 'w': Array([[0.19007349, 0.31424356, 0.3686391 ],\n",
" [0.7862853 , 0.03352201, 0.50682676]], dtype=float32),\n",
" 'b': Array([0., 0., 0.], dtype=float32)\n",
" 'w': Param(\n",
" raw_value=Array([[0.19007349, 0.31424356, 0.3686391 ],\n",
" [0.7862853 , 0.03352201, 0.50682676]], dtype=float32)\n",
" ),\n",
" 'b': Param(\n",
" raw_value=Array([0., 0., 0.], dtype=float32)\n",
" )\n",
"})\n",
"\n",
"counts = State({\n",
" 'count': Array(1, dtype=int32, weak_type=True)\n",
" 'count': Count(\n",
" raw_value=Array(1, dtype=int32, weak_type=True)\n",
" )\n",
"})\n"
]
}
Expand Down
26 changes: 13 additions & 13 deletions docs/experimental/nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Linear(nnx.Module):
self.din, self.dout = din, dout
def __call__(self, x: jax.Array):
return x @ self.w + self.b
return x @ self.w.value + self.b.value
```

As shown above dynamic state is stored in `nnx.Variable`s such as `nnx.Param`,
Expand All @@ -54,12 +54,12 @@ for inspection.
model = Linear(din=2, dout=3, rngs=nnx.Rngs(params=0))
print(f'{model = }')
print(f'{model.w = }')
print(f'{model.b = }')
print(f'{model.w.value = }')
print(f'{model.b.value = }')
```

This is very handy for debugging as it allows accessing the entire structure or
modify it. Similarly, computation can ran directly.
modifying it. Similarly, computation can be run directly.

```{code-cell} ipython3
x = jnp.ones((1, 2))
Expand All @@ -84,12 +84,12 @@ class Counter(nnx.Module):
self.count = nnx.Variable(0)
def __call__(self):
self.count += 1
self.count.value += 1
counter = Counter()
print(f'{counter.count = }')
print(f'{counter.count.value = }')
counter()
print(f'{counter.count = }')
print(f'{counter.count.value = }')
```

**This looks too easy, what is the catch?**
Expand Down Expand Up @@ -136,8 +136,8 @@ One of the benefits of NNX is that nested Modules are easy to inspect and
static analyzers can help you while doing so.

```{code-cell} ipython3
print(f'{model.blocks[1].linear.kernel = }')
print(f'{model.blocks[0].bn.scale = }')
print(f'{model.blocks[1].linear.kernel.value = }')
print(f'{model.blocks[0].bn.scale.value = }')
```

#### Model Surgery
Expand All @@ -160,7 +160,7 @@ def awesome_layer(x): return x
model.blocks[2] = awesome_layer
# Variable sharing (weight tying)
model.blocks[-1].linear.variables.kernel = model.blocks[0].linear.variables.kernel
model.blocks[-1].linear.kernel = model.blocks[0].linear.kernel
model(jnp.ones((1, 2)))
```
Expand All @@ -187,8 +187,8 @@ class StatefulLinear(nnx.Module):
self.count = Count(0)
def __call__(self, x: jax.Array):
self.count += 1
return x @ self.w + self.b
self.count.value += 1
return x @ self.w.value + self.b.value
model = StatefulLinear(din=2, dout=3, rngs=nnx.Rngs(0))
```
Expand Down Expand Up @@ -236,7 +236,7 @@ y, state = forward(static, state, x=jnp.ones((1, 2)))
# 5. Update the state of the original Module
model.update(state)
print(f'{model.count = }')
print(f'{model.count.value = }')
```

The key insight of this pattern is that using mutable references is
Expand Down
6 changes: 3 additions & 3 deletions examples/lm1b/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,11 +515,11 @@ def encode_strings(strs, max_len):
predict_step,
in_axes=(
0,
jax.tree_util.tree_map(lambda x: None, state.params),
jax.tree_map(lambda x: None, state.params),
0,
None,
None,
jax.tree_util.tree_map(lambda x: None, predict_config),
jax.tree_map(lambda x: None, predict_config),
None,
None,
),
Expand Down Expand Up @@ -558,7 +558,7 @@ def encode_strings(strs, max_len):
# Shard data to devices and do a training step.
with jax.profiler.StepTraceAnnotation("train", step_num=step):
batch = next(train_iter)
batch = jax.tree_util.tree_map(lambda x: jnp.array(x), batch)
batch = jax.tree_map(lambda x: jnp.array(x), batch)
state, metrics = jit_train_step(
state, batch, train_config, learning_rate_fn, 0.0, dropout_rngs
)
Expand Down
8 changes: 2 additions & 6 deletions flax/core/frozen_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,7 @@ def copy(
if isinstance(x, FrozenDict):
return x.copy(add_or_replace)
elif isinstance(x, dict):
new_dict = jax.tree_util.tree_map(
lambda x: x, x
) # make a deep copy of dict x
new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
new_dict.update(add_or_replace)
return new_dict
raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
Expand Down Expand Up @@ -282,9 +280,7 @@ def pop(
if isinstance(x, FrozenDict):
return x.pop(key)
elif isinstance(x, dict):
new_dict = jax.tree_util.tree_map(
lambda x: x, x
) # make a deep copy of dict x
new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
value = new_dict.pop(key)
return new_dict, value
raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
Expand Down
Loading

0 comments on commit 6ad6c96

Please sign in to comment.