Skip to content

Commit

Permalink
Avoid index out of range error in carry structure check
Browse files Browse the repository at this point in the history
  • Loading branch information
IvyZX authored and hawkinsp committed Dec 9, 2024
1 parent 259194a commit 65b6088
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
9 changes: 8 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@ Remember to align the itemized text with the first line of an item within a list
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->

## jax 0.4.36
## jax 0.4.37

* Bug fixes
* Fix a bug that will throw `index out of range` error in
{func}`jax.lax.while_loop` if the user register pytree node class with
different aux data for the flatten and flatten_with_path.

## jax 0.4.36 (Dec 5, 2024)

* Breaking Changes
* This release lands "stackless", an internal change to JAX's tracing
Expand Down
6 changes: 6 additions & 0 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
f'of the carry output is a {thing2}, so {explanation}'
for path, thing1, thing2, explanation
in equality_errors(in_carry, out_carry)]
if len(diffs) == 0:
# The trees may have different aux data but structures are the same.
return
if len(diffs) == 1:
differences = f'{diffs[0]}.\n'.capitalize()
else:
Expand All @@ -393,6 +396,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}'
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
if not core.typematch(in_aval, out_aval)]
if len(diffs) == 0:
# The trees may have different aux data but structures are the same.
return
if len(diffs) == 1:
differences = f'{diffs[0]}.\n'.capitalize()
else:
Expand Down
13 changes: 13 additions & 0 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,19 @@ def testWhileTypeErrors(self):
lax.while_loop(lambda c: True, lambda c: (True, True),
(np.bool_(True), np.float32(0.)))

def testWhileLoopCustomPytreeDiffAuxData(self):
class Node:
def __init__(self, x, y):
self.x = x
self.y = y
tree_util.register_pytree_with_keys(
Node,
lambda o: ((("x", o.x), ("y", o.y)), 'with_keys'), # flatten_with_keys
lambda _, xy: Node(xy[0], xy[1]), # unflatten (no key involved)
lambda o: ((o.x, o.y), 'without_keys'), # flatten
)
lax.while_loop(lambda o: o.x > 0., lambda c: Node(0., 0.), Node(1., 1.))

def testNestedWhileWithDynamicUpdateSlice(self):
num = 5

Expand Down

0 comments on commit 65b6088

Please sign in to comment.