602 changes: 0 additions & 602 deletions docs_nnx/guides/array_ref.ipynb

This file was deleted.

572 changes: 572 additions & 0 deletions docs_nnx/guides/hijax.ipynb

Large diffs are not rendered by default.

82 changes: 36 additions & 46 deletions docs_nnx/guides/array_ref.md → docs_nnx/guides/hijax.md
@@ -21,41 +21,31 @@ import optax

+++

### Array Refs 101
### Variable Refs

```{code-cell} ipython3
a_ref = jax.new_ref(jnp.array([1, 2, 3]))
variable = nnx.Variable(jnp.array([1, 2, 3]), is_hijax=True)
print(f"{variable.is_hijax = }\n")

@jax.jit
def increment(a_ref: jax.Ref): # no return!
array: jax.Array = a_ref[...] # access
a_ref[...] = array + 1 # update
def increment(variable: nnx.Variable[jax.Array]): # no return!
new_value = variable + 1 # Array-like operations
variable[...] = new_value # in-place updates

print("[1] =", a_ref); increment(a_ref); print("[2] =", a_ref)
print("Before =", variable); increment(variable); print("After =", variable)
```

```{code-cell} ipython3
@jax.jit
def inc(x):
x[...] += 1

print(increment.lower(a_ref).as_text())
# TODO: enable once as_text is fixed
# print(increment.lower(variable).as_text())
```

### Variables Refs

```{code-cell} ipython3
variable = nnx.Variable(jnp.array([1, 2, 3]), use_ref=True)
print(f"{variable.has_ref = }\n")

print("[1] =", variable); increment(variable); print("[2] =", variable)
```
nnx.use_hijax(True)

```{code-cell} ipython3
with nnx.use_refs(True):
variable = nnx.Variable(jnp.array([1, 2, 3]))
variable = nnx.Variable(jnp.array([1, 2, 3]))

print(f"{variable.has_ref = }")
print(f"{variable.is_hijax = }")
```

Mention that `nnx.use_refs` can be used as a global flag
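As context for the note above, here is a minimal sketch of the global-flag pattern. It assumes the `nnx.use_hijax` and `is_hijax` API that appears in the cells above; the exact semantics (global default vs. scoped override) are an assumption, not verified against this PR:

```python
import jax.numpy as jnp
from flax import nnx

nnx.use_hijax(True)                    # assumed: sets the default for new Variables globally
v_hi = nnx.Variable(jnp.array(0))      # created as a hijax Variable

with nnx.use_hijax(False):             # assumed: also usable as a scoped context manager
    v_lo = nnx.Variable(jnp.array(0))  # created as a lojax Variable inside the block

print(v_hi.is_hijax, v_lo.is_hijax)    # expected: True False
```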
@@ -73,12 +63,14 @@ class Linear(nnx.Module):
def __call__(self, x):
return x @ self.kernel + self.bias[None]

model = Linear(1, 3, rngs=nnx.Rngs(0)) # without array refs
refs_model = nnx.to_refs(model) # convert to array refs
arrays_model = nnx.to_arrays(refs_model) # convert to regular arrays
with nnx.use_hijax(False): # use lojax Variables
model = Linear(1, 3, rngs=nnx.Rngs(0))

hijax_model = nnx.to_hijax(model) # convert to hijax Variables
arrays_model = nnx.to_lojax(hijax_model) # convert to lojax Variables

print("nnx.to_refs(model) =", refs_model)
print("nnx.to_arrays(refs_model) =", arrays_model)
print("nnx.to_hijax(model) =", hijax_model)
print("nnx.to_lojax(refs_model) =", arrays_model)
```

## Examples
@@ -99,9 +91,9 @@ class Block(nnx.Module):
### Training Loop

```{code-cell} ipython3
with nnx.use_refs(True):
model = Block(2, 64, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
# hijax Variables by default
model = Block(2, 64, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

@jax.jit
def train_step(model, optimizer, x, y):
@@ -110,7 +102,7 @@ def train_step(model, optimizer, x, y):
model = nnx.merge(graphdef, params, nondiff)
return ((model(x) - y) ** 2).mean()

loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params)) # freeze ArrayRefs for jax.grad
loss, grads = jax.value_and_grad(loss_fn)(nnx.to_lojax(params)) # lojax Variables for jax.grad
optimizer.update(model, grads)

return loss
@@ -121,12 +113,11 @@ train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))
### Scan Over Layers

```{code-cell} ipython3
@nnx.vmap
@jax.vmap
def create_stack(rngs):
return Block(2, 64, 2, rngs=rngs)
return nnx.to_lojax(Block(2, 64, 2, rngs=rngs))

with nnx.use_refs(True):
block_stack = create_stack(nnx.Rngs(0).fork(split=8))
block_stack = nnx.to_hijax(create_stack(nnx.Rngs(0).fork(split=8)))

def scan_fn(x, block):
x = block(x)
@@ -150,42 +141,42 @@ def create_model(rngs):
return Block(2, 64, 3, rngs=rngs)

try:
with nnx.use_refs(True):
model = create_model(nnx.Rngs(0))
model = create_model(nnx.Rngs(0))
except Exception as e:
print(f"Error:", e)
```

```{code-cell} ipython3
with nnx.use_refs(False): # <-- disable array refs
with nnx.use_hijax(False): # <-- disable hijax Variables
model = create_model(nnx.Rngs(0))

model = nnx.to_refs(model) # convert to mutable after creation
model = nnx.to_hijax(model) # convert to mutable after creation

print("model.linear =", model.linear)
```

```{code-cell} ipython3
# TODO: why does this work?
@nnx.jit
def create_model(rngs):
return Block(2, 64, 3, rngs=rngs)

with nnx.use_refs(True):
model = create_model(nnx.Rngs(0))
model = create_model(nnx.Rngs(0))

print("model.linear =", model.linear)
```

### Reference Sharing (aliasing)

```{code-cell} ipython3
# TODO: why does this not fail?
def get_error(f, *args):
try:
return f(*args)
except Exception as e:
return f"{type(e).__name__}: {e}"
x = jax.new_ref(jnp.array(0))

x = nnx.Variable(jnp.array(0))

@jax.jit
def f(a, b):
@@ -211,9 +202,8 @@ class SharedModules(nnx.Pytree):
def g(pytree):
...

with nnx.use_refs(True):
shared_variables = SharedVariables()
shared_modules = SharedModules()
shared_variables = SharedVariables()
shared_modules = SharedModules()

print("SharedVariables", get_error(g, shared_variables))
print("SharedModules", get_error(g, shared_modules))
8 changes: 2 additions & 6 deletions docs_nnx/index.rst
@@ -107,15 +107,11 @@ Basic usage
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

@nnx.jit # automatic state management for JAX transforms
@nnx.jit # automatic state propagation
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x) # call methods directly
return ((y_pred - y) ** 2).mean()

loss_fn = lambda model: ((model(x) - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(model, grads) # in-place updates

return loss


75 changes: 27 additions & 48 deletions docs_nnx/nnx_basics.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions docs_nnx/nnx_basics.md
@@ -40,7 +40,7 @@ class Linear(nnx.Module):
self.din, self.dout = din, dout

def __call__(self, x: jax.Array):
return x @ self.w + self.b
return x @ self.w + self.b[None]
```

Also note that the inner values of `Variable`s can be accessed using the `value` property, but for convenience they implement all numeric operators and can be used directly in arithmetic expressions (as shown in the code above).
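To make the two access styles in the sentence above concrete, a minimal sketch (assuming the pre-existing `nnx.Param` constructor and `.value` property shown elsewhere in this PR):

```python
import jax.numpy as jnp
from flax import nnx

w = nnx.Param(jnp.ones((2, 3)))   # Param is a Variable subtype
x = jnp.ones((4, 2))

y1 = x @ w.value                  # explicit read through the `value` property
y2 = x @ w                        # Variables also act like arrays in arithmetic
print(jnp.allclose(y1, y2))       # expected: True
```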
@@ -73,12 +73,12 @@ class Counter(nnx.Module):
self.count = Count(jnp.array(0))

def __call__(self):
self.count.value += 1
self.count[...] += 1

counter = Counter()
print(f'{counter.count.value = }')
print(f'{counter.count[...] = }')
counter()
print(f'{counter.count.value = }')
print(f'{counter.count[...] = }')
```

Mutable references are usually avoided in JAX. But Flax NNX provides sound mechanisms
2 changes: 1 addition & 1 deletion examples/gemma/helpers.py
@@ -74,7 +74,7 @@ def assign_val_fn(
mapped_path: tuple[str | int, ...],
val: Any,
) -> dict[tuple[str, ...], Any]:
state[mapped_path].value = val
state[mapped_path].set_value(val)
return state

mdl: M = nnx.eval_shape(module_factory)
4 changes: 2 additions & 2 deletions examples/gemma/helpers_test.py
@@ -137,11 +137,11 @@ def _map_key_fn(key: tuple[str, ...]) -> tuple[str | int, ...]:
np.testing.assert_array_equal(output, linen_output)
for i in range(len(num_features)):
np.testing.assert_array_equal(
mdl.layers[i].layers[0].mean.value,
mdl.layers[i].layers[0].mean[...],
linen_vars['batch_stats'][f'layers_{i}']['layers_0']['mean'],
)
np.testing.assert_array_equal(
mdl.layers[i].layers[0].var.value,
mdl.layers[i].layers[0].var[...],
linen_vars['batch_stats'][f'layers_{i}']['layers_0']['var'],
)

8 changes: 4 additions & 4 deletions examples/gemma/layers.py
@@ -44,11 +44,11 @@ def __init__(
self.w = nnx.Param(kernel_init(rngs.params(), shape, dtype))

def __call__(self, x: ArrayLike) -> Array:
return jnp.einsum(self.einsum_str, x, self.w.value)
return jnp.einsum(self.einsum_str, x, self.w[...])

@property
def shape(self) -> Shape:
return self.w.value.shape
return self.w.shape


class RMSNorm(nnx.Module):
@@ -65,12 +65,12 @@ def __init__(
self.scale = nnx.Param(scale_init(rngs.params(), dim, dtype))

def __call__(self, x: Array) -> Array:
dtype = self.scale.value.dtype
dtype = self.scale.dtype
var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
normed_inputs = jnp.asarray(x * jax.lax.rsqrt(var + 1e-06), dtype=dtype)
# normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
# a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
# a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
scale = jnp.expand_dims(self.scale.value, axis=range(len(x.shape) - 1))
scale = jnp.expand_dims(self.scale, axis=range(len(x.shape) - 1))
normed_inputs = normed_inputs * (1 + scale)
return normed_inputs
6 changes: 3 additions & 3 deletions examples/gemma/modules.py
@@ -63,15 +63,15 @@ def encode(self, x: ArrayLike) -> Array:
return x

def decode(self, x: ArrayLike) -> Array:
return jnp.dot(x, self.input_embedding.value.T)
return jnp.dot(x, self.input_embedding.T)

@property
def embed_dim(self):
return self.input_embedding.value.shape[1]
return self.input_embedding.shape[1]

@property
def num_embed(self):
return self.input_embedding.value.shape[0]
return self.input_embedding.shape[0]


class Attention(nnx.Module):
4 changes: 2 additions & 2 deletions examples/gemma/sampler_test.py
@@ -232,9 +232,9 @@ def test_forbidden_tokens(self):
transformer_config, rngs=nnx.Rngs(params=0)
)
# Pre-cook the embedding matrix so that the output is deterministic.
transformer.embedder.input_embedding.value = jnp.eye(
transformer.embedder.input_embedding.set_value(jnp.eye(
vocab.GetPieceSize(), 32
)
))
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
10 changes: 4 additions & 6 deletions examples/gemma/sow_lib.py
@@ -49,13 +49,11 @@ def merge(self, decoding_step, layer: nnx.Module):
if field.name.startswith('attn_'):
step_value = getattr(
layer.attn, field.name.replace('attn_', '')
).value[0]
)[0]
elif field.name.startswith('mlp_'):
step_value = getattr(layer.mlp, field.name.replace('mlp_', '')).value[
0
]
step_value = getattr(layer.mlp, field.name.replace('mlp_', ''))[0]
else:
step_value = getattr(layer, field.name).value[0]
step_value = getattr(layer, field.name)[0]
except AttributeError as exc:
raise ValueError(
f'Intermediate {field.name} is not in the step intermediates.'
@@ -93,7 +91,7 @@ def merge(self, decoding_step, transformer: nnx.Module):
if self.embeddings is not None:
try:
self.embeddings = self.embeddings.at[:, decoding_step + 1, ...].set(
transformer.embeddings.value[0][:, 0, ...]
transformer.embeddings[0][:, 0, ...]
)
except AttributeError as exc:
raise ValueError(
6 changes: 3 additions & 3 deletions examples/gemma/transformer.py
@@ -487,10 +487,10 @@ def _assign_linen_params_to_nnx_state(
if 'gate_proj' in mapped_path:
if transpose_gating_einsum:
val = jnp.swapaxes(val, 1, 2)
state[mapped_path].value = val[0]
state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1]
state[mapped_path].set_value(val[0])
state[mapped_path[:-2] + ('up_proj', 'kernel')].set_value(val[1])
else:
state[mapped_path].value = val
state[mapped_path].set_value(val)
return state

