Skip to content

Commit

Permalink
fix psum transpose rule
Browse files Browse the repository at this point in the history
  • Loading branch information
jekbradbury committed May 1, 2020
1 parent 564e4a2 commit 268e0d1
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/lax/lax_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _translate(val):
partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)
ad.deflinear(psum_p, lambda *ts, axis_name, axis_index_groups: psum(
ad.deflinear(psum_p, lambda ts, axis_name, axis_index_groups: psum_p.bind(
*ts, axis_name=axis_name, axis_index_groups=axis_index_groups))
pxla.multi_host_supported_collectives.add(psum_p)

Expand Down

0 comments on commit 268e0d1

Please sign in to comment.