Skip to content

Commit

Permalink
test codediff
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp committed May 3, 2024
1 parent d0e080d commit 434e9ae
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 48 deletions.
17 changes: 17 additions & 0 deletions docs/_ext/codediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,30 @@ def parse(
)
idx = lines.index(code_sep)
code_left = self._code_block(lines[0:idx])
# if title_left == 'NNX transforms123':
# # TODO: for some reason adding a space character after the separator character causes an error
# # find another example in another file where the space character is already added and it doesn't cause the error
# # and try to find the pattern
# test_code = lines[idx + 1+1 :]
# else:
# test_code = lines[idx + 1 :]
test_code = lines[idx + 1 :]
code_right = self._code_block(test_code)

output = self._tabs(
(title_left, code_left), (title_right, code_right), sync=sync
)

# # if title_left == 'NNX transforms123':
# if title_left == 'Haiku123':
# print('\nPRINT START')
# print('\n'+'='*20+'OUTPUT'+'='*20)
# print(output)
# print('\n'+'='*20+'TEST CODE'+'='*20)
# print(test_code)
# print('\nEND PRINT')
# assert False

return output, test_code

def _code_block(self, lines):
Expand Down
49 changes: 2 additions & 47 deletions docs/experimental/nnx/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ whereas the function signature of JAX-transformed functions can only accept the
the transformed function.

.. codediff::
:title_left: NNX transforms
:title_left: NNX transforms
:title_right: JAX transforms
:sync:

Expand All @@ -57,6 +57,7 @@ the transformed function.
train_step(model, x, y)

---

@jax.jit #!
def train_step(graphdef, state, x, y): #!
def loss_fn(graphdef, state): #!
Expand All @@ -74,49 +75,3 @@ the transformed function.

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0))) #!
graphdef, state = train_step(graphdef, state, x, y) #!


Mixing NNX and JAX transformations
**********************************

NNX and JAX transformations can be mixed together, so long as the JAX-transformed function is
pure and has valid argument types that are recognized by JAX.

.. codediff::
:title_left: Using ``nnx.jit`` with ``jax.grad``
:title_right: Using ``jax.jit`` with ``nnx.grad``
:sync:

@nnx.jit
def train_step(model, x, y):
def loss_fn(graphdef, state): #!
model = nnx.merge(graphdef, state)
return ((model(x) - y) ** 2).mean()
grads = jax.grad(loss_fn, 1)(*nnx.split(model)) #!
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)

---
@jax.jit #!
def train_step(graphdef, state, x, y): #!
model = nnx.merge(graphdef, state)
def loss_fn(model):
return ((model(x) - y) ** 2).mean()
grads = nnx.grad(loss_fn)(model)
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
return nnx.split(model)

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)


Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ whereas in Haiku ``name`` must be explicitly defined in the constructor
signature and passed to the superclass constructor.

.. codediff::
:title_left: Haiku
:title_left: Haiku
:title_right: Flax
:sync:

Expand Down Expand Up @@ -358,6 +358,7 @@ return value.
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

return params, new_state

---

def train_step(params, batch_stats, inputs, labels):
Expand Down

0 comments on commit 434e9ae

Please sign in to comment.