Skip to content

Commit

Permalink
rewrite axis_index implementation, use custom bind
Browse files Browse the repository at this point in the history
fixes #2716

Co-authored-by: Trevor Cai <tycai@google.com>
  • Loading branch information
mattjj and trevorcai committed Apr 23, 2020
1 parent 903010b commit 4729013
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,26 +383,27 @@ def axis_index(axis_name):
[0 1]
[0 1]]
"""
return axis_index_p.bind(axis_name=axis_name)

def _axis_index_bind(*, axis_name):
dynamic_axis_env = _thread_local_state.dynamic_axis_env
frame = dynamic_axis_env[axis_name]
sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
nreps = dynamic_axis_env.nreps
dummy_arg = frame.pmap_trace.pure(core.unit)
if frame.soft_trace:
dummy_arg = frame.soft_trace.pure(dummy_arg)

return axis_index_p.bind(dummy_arg, nreps=nreps, sizes=sizes,
soft_size=frame.soft_size, axis_name=axis_name)
trace = frame.pmap_trace

def _axis_index_partial_eval(trace, _, **params):
# This partial_eval rule adds the axis_index primitive into the jaxpr formed
# during pmap lowering. It is like the standard JaxprTrace.process_primitive
# rule except that we don't attempt to lower out of the trace.
out_aval = ShapedArray((), onp.int32)
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p, params)
eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
dict(nreps=nreps, sizes=sizes,
soft_size=frame.soft_size, axis_name=axis_name))
out_tracer.recipe = eqn
return out_tracer

if not frame.soft_trace:
return out_tracer
else:
val_out = out_tracer * frame.soft_size + onp.arange(frame.soft_size)
return SplitAxisTracer(frame.soft_trace, axis_name, val_out)

def _axis_index_translation_rule(c, nreps, sizes, soft_size, axis_name):
div = c.Constant(onp.array(nreps // prod(sizes), dtype=onp.uint32))
Expand All @@ -411,8 +412,8 @@ def _axis_index_translation_rule(c, nreps, sizes, soft_size, axis_name):
return c.ConvertElementType(unsigned_index, xb.dtype_to_etype(onp.int32))

axis_index_p = core.Primitive('axis_index')
axis_index_p.def_custom_bind(_axis_index_bind)
xla.translations[axis_index_p] = _axis_index_translation_rule
pe.custom_partial_eval_rules[axis_index_p] = _axis_index_partial_eval


### lazy device-memory persistence and result handling
Expand Down

0 comments on commit 4729013

Please sign in to comment.