test codediff
chiamp committed May 3, 2024
1 parent d0e080d commit a4a1bb8
Showing 4 changed files with 36 additions and 16 deletions.
17 changes: 17 additions & 0 deletions docs/_ext/codediff.py
@@ -55,13 +55,30 @@ def parse(
     )
     idx = lines.index(code_sep)
     code_left = self._code_block(lines[0:idx])
+    # if title_left == 'NNX transforms123':
+    #   # TODO: for some reason adding a space character after the separator character causes an error
+    #   # find another example in another file where the space character is already added and it doesn't cause the error
+    #   # and try to find the pattern
+    #   test_code = lines[idx + 1+1 :]
+    # else:
+    #   test_code = lines[idx + 1 :]
     test_code = lines[idx + 1 :]
     code_right = self._code_block(test_code)

     output = self._tabs(
       (title_left, code_left), (title_right, code_right), sync=sync
     )

+    # # if title_left == 'NNX transforms123':
+    # if title_left == 'Haiku123':
+    #   print('\nPRINT START')
+    #   print('\n'+'='*20+'OUTPUT'+'='*20)
+    #   print(output)
+    #   print('\n'+'='*20+'TEST CODE'+'='*20)
+    #   print(test_code)
+    #   print('\nEND PRINT')
+    #   assert False
+
     return output, test_code

   def _code_block(self, lines):
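The slicing above is the heart of the directive parser: the body is split on the first line that exactly equals ``code_sep``. A minimal standalone sketch of that behavior, with hypothetical ``lines`` and ``code_sep`` values (not the extension's actual tests):

# Sketch of the separator slicing in parse(), assuming the separator
# must match a body line exactly.
code_sep = '---'
lines = ['left = 1', '---', 'right = 2']

idx = lines.index(code_sep)   # ValueError if there is no exact match,
                              # e.g. a stray trailing space on the line
code_left = lines[0:idx]      # everything before the separator
test_code = lines[idx + 1 :]  # everything after it

assert code_left == ['left = 1']
assert test_code == ['right = 2']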
29 changes: 15 additions & 14 deletions docs/experimental/nnx/transforms.rst
@@ -38,7 +38,7 @@ whereas the function signature of JAX-transformed functions can only accept the
the transformed function.

 .. codediff::
-  :title_left: NNX transforms
+  :title_left: NNX transforms123
   :title_right: JAX transforms
   :sync:

@@ -57,23 +57,24 @@
   train_step(model, x, y)

   ---
-  @jax.jit #!
-  def train_step(graphdef, state, x, y): #!
-    def loss_fn(graphdef, state): #!
-      model = nnx.merge(graphdef, state) #!
+  @jax.jit
+  def train_step(graphdef, state, x, y):
+    def loss_fn(graphdef, state):
+      model = nnx.merge(graphdef, state)
       return ((model(x) - y) ** 2).mean()
-    grads = jax.grad(loss_fn, argnums=1)(graphdef, state) #!
+    grads = jax.grad(loss_fn, argnums=1)(graphdef, state)

-    model = nnx.merge(graphdef, state) #!
+    model = nnx.merge(graphdef, state)
     params = nnx.state(model, nnx.Param)
     params = jax.tree_util.tree_map(
       lambda p, g: p - 0.1 * g, params, grads
     )
     nnx.update(model, params)
-    return nnx.split(model) #!
+    return nnx.split(model)

-  graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0))) #!
-  graphdef, state = train_step(graphdef, state, x, y) #!
+  graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
+  graphdef, state = train_step(graphdef, state, x, y)
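What makes the ``jax.jit`` version above possible is that ``nnx.split`` turns a stateful module into two valid JAX argument types (a hashable graph definition and a pytree of state) and ``nnx.merge`` reverses the split. A minimal round-trip sketch of just that mechanism; the input ``x`` here is hypothetical:

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
x = jnp.ones((1, 2))  # hypothetical input

# split() yields a static graph definition plus a state pytree ...
graphdef, state = nnx.split(model)
# ... and merge() reconstructs an equivalent module from the two halves.
restored = nnx.merge(graphdef, state)

assert jnp.allclose(model(x), restored(x))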


Mixing NNX and JAX transformations
@@ -89,10 +89,10 @@ pure and has valid argument types that are recognized by JAX.

   @nnx.jit
   def train_step(model, x, y):
-    def loss_fn(graphdef, state): #!
+    def loss_fn(graphdef, state):
       model = nnx.merge(graphdef, state)
       return ((model(x) - y) ** 2).mean()
-    grads = jax.grad(loss_fn, 1)(*nnx.split(model)) #!
+    grads = jax.grad(loss_fn, 1)(*nnx.split(model))
     params = nnx.state(model, nnx.Param)
     params = jax.tree_util.tree_map(
       lambda p, g: p - 0.1 * g, params, grads
@@ -103,8 +103,8 @@
   train_step(model, x, y)

   ---
-  @jax.jit #!
-  def train_step(graphdef, state, x, y): #!
+  @jax.jit
+  def train_step(graphdef, state, x, y):
     model = nnx.merge(graphdef, state)
     def loss_fn(model):
       return ((model(x) - y) ** 2).mean()
Expand Up @@ -25,7 +25,7 @@ whereas in Haiku ``name`` must be explicitly defined in the constructor
signature and passed to the superclass constructor.
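The contrast is visible in the constructor signatures themselves; a minimal sketch, assuming the standard ``haiku`` and ``flax.linen`` APIs (the module and feature names here are hypothetical, not from this guide):

import flax.linen as nn
import haiku as hk

# Haiku: ``name`` must appear in the constructor signature and be
# forwarded to the superclass explicitly.
class HaikuBlock(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x):
    return hk.Linear(self.features)(x)

# Flax: no constructor plumbing is needed; submodule names are derived
# automatically (or set via the built-in ``name`` attribute).
class FlaxBlock(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)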

 .. codediff::
-  :title_left: Haiku
+  :title_left: Haiku123
   :title_right: Flax
   :sync:

@@ -358,6 +358,7 @@ return value.
     params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

     return params, new_state
+
   ---

def train_step(params, batch_stats, inputs, labels):
3 changes: 2 additions & 1 deletion docs/guides/converting_and_upgrading/linen_upgrade_guide.rst
@@ -107,7 +107,8 @@ Using Flax Modules inside other Modules
       z = nn.Dense(x, 500, name="latents")
       return z

   ---
-  class Encoder(nn.Module):
+
+  class Encoder(nn.Module): #!
     @nn.compact
     def __call__(self, x):
       x = nn.Dense(500)(x)  # [1] #!
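A minimal sketch of the complete upgraded pattern this hunk points at, assuming the same 500-unit layers and ``latents`` name as the pre-upgrade code on the left:

import flax.linen as nn

class Encoder(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Submodules are now constructed and called inline, and ``name``
    # moves into the constructor keyword arguments.
    x = nn.Dense(500)(x)
    z = nn.Dense(500, name="latents")(x)
    return z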
