# Patch against jax/lax/lax_parallel.py: one hunk at line 341 that replaces a
# single line, the function registered for psum_p through ad.deflinear.
#   old: lambda *ts, axis_name, axis_index_groups: psum(*ts, ...)
#   new: lambda ts, axis_name, axis_index_groups: psum_p.bind(*ts, ...)
# The new lambda takes the values as ONE positional argument `ts` and unpacks
# it in the call, and it calls psum_p.bind instead of psum. The keyword
# arguments axis_name and axis_index_groups are forwarded unchanged.
# NOTE(review): presumably ad.deflinear hands the rule all cotangents as a
# single sequence, which is what the new signature needs; confirm against
# jax/interpreters/ad.py.
# NOTE(review): the hunk header declares 7 lines per side, but the patch text
# below sits on one physical line, so the original line breaks look collapsed
# into spaces; restore them before trying to apply this patch.
diff --git a/jax/lax/lax_parallel.py b/jax/lax/lax_parallel.py index 6233b35cfebc..242a207461a8 100644 --- a/jax/lax/lax_parallel.py +++ b/jax/lax/lax_parallel.py @@ -341,7 +341,7 @@ def _translate(val): partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum) xla.parallel_translations[psum_p] = _psum_translation_rule pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args) -ad.deflinear(psum_p, lambda *ts, axis_name, axis_index_groups: psum( +ad.deflinear(psum_p, lambda ts, axis_name, axis_index_groups: psum_p.bind( *ts, axis_name=axis_name, axis_index_groups=axis_index_groups)) pxla.multi_host_supported_collectives.add(psum_p)