Skip to content

Commit

Permalink
test codediff
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp committed May 3, 2024
1 parent d0e080d commit 92dab2d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docs/_ext/codediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ def parse(
(title_left, code_left), (title_right, code_right), sync=sync
)

print('\nPRINT START')
print('\n'+'='*20+'OUTPUT'+'='*20)
print(output)
print('\n'+'='*20+'TEST CODE'+'='*20)
print(test_code)
print('\nEND PRINT')

return output, test_code

def _code_block(self, lines):
Expand Down
1 change: 1 addition & 0 deletions docs/experimental/nnx/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ the transformed function.
train_step(model, x, y)

---

@jax.jit #!
def train_step(graphdef, state, x, y): #!
def loss_fn(graphdef, state): #!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ return value.
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

return params, new_state

---

def train_step(params, batch_stats, inputs, labels):
Expand Down

0 comments on commit 92dab2d

Please sign in to comment.