Skip to content

Commit

Permalink
Fix a faulty soft_pmap rule for axis_index
Browse files Browse the repository at this point in the history
The rule didn't specify the precision for the `np.arange` constant,
which caused an accidental dtype promotion in X64 mode. Previously the
error has luckicly been hidden behind a coerction that followed
`axis_index` in that test, but the new implementation has surfaced it.
  • Loading branch information
apaszke committed Sep 22, 2020
1 parent 332a9ba commit 8ac19c7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/lax/lax_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
assert not vals and not mapped
idx = axis_index(axis_name) # type: ignore
return idx * chunk_size + np.arange(chunk_size), True
return idx * chunk_size + np.arange(chunk_size, dtype=np.int32), True

axis_index_p = core.Primitive('axis_index')
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
Expand Down

0 comments on commit 8ac19c7

Please sign in to comment.