Skip to content

Commit

Permalink
[nnx] improve nnx.scan in_axes/out_axes
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Sep 2, 2024
1 parent 839db8c commit 3591ec1
Show file tree
Hide file tree
Showing 3 changed files with 598 additions and 254 deletions.
2 changes: 1 addition & 1 deletion docs/nnx/haiku_linen_vs_nnx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,7 @@ in ``__init__`` to scan over the sequence.
scan_fn = lambda carry, cell, x: cell(carry, x)
carry = self.cell.initial_state(x.shape[0])
carry, y = nnx.scan(
scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=1
scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
)(carry, self.cell, x)

return y
Expand Down
Loading

0 comments on commit 3591ec1

Please sign in to comment.