Fix pjit's initial style usage of consts
FUTURE_COPYBARA_INTEGRATE_REVIEW=#13589 from jakevdp:cond-doc c9c6263
PiperOrigin-RevId: 499899435
yashk2810 authored and jax authors committed Jan 6, 2023
1 parent fc04c71 commit 373b0d4
Showing 6 changed files with 73 additions and 14 deletions.
13 changes: 11 additions & 2 deletions jax/_src/lax/control_flow/conditionals.py
@@ -155,16 +155,25 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
operand=_no_operand_sentinel, linear=None):
"""Conditionally apply ``true_fun`` or ``false_fun``.
Wraps XLA's `Conditional
<https://www.tensorflow.org/xla/operation_semantics#conditional>`_
operator.
Provided arguments are correctly typed, ``cond()`` has equivalent
- semantics to this Python implementation::
+ semantics to this Python implementation, where ``pred`` must be a
+ scalar type::
def cond(pred, true_fun, false_fun, *operands):
if pred:
return true_fun(*operands)
else:
return false_fun(*operands)
- ``pred`` must be a scalar type.
+ In contrast with :func:`jax.lax.select`, using ``cond`` indicates that only one of
+ the two branches is executed (up to compiler rewrites and optimizations).
+ However, when transformed with :func:`~jax.vmap` to operate over a batch of
+ predicates, ``cond`` is converted to :func:`~jax.lax.select`.
Args:
pred: Boolean scalar type, indicating which branch function to apply.
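For reference, a minimal usage sketch of the documented ``cond`` semantics (illustration only, not part of this diff; assumes standard ``jax`` imports)::

    import jax
    import jax.numpy as jnp
    from jax import lax

    def branchy(pred, x):
      # For a scalar predicate only the selected branch is executed
      # (up to compiler rewrites and optimizations).
      return lax.cond(pred, lambda v: v + 1.0, lambda v: v - 1.0, x)

    branchy(True, jnp.float32(3.0))     # 4.0
    branchy(False, jnp.float32(3.0))    # 2.0

    # Under vmap, a batch of predicates converts the cond into a select.
    preds = jnp.array([True, False, True])
    xs = jnp.arange(3.0)
    jax.vmap(branchy)(preds, xs)        # [1., 0., 3.]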
18 changes: 17 additions & 1 deletion jax/_src/lax/lax.py
@@ -908,9 +908,25 @@ def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array:
return rev_p.bind(operand, dimensions=tuple(dimensions))

def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array:
"""Wraps XLA's `Select
"""Selects between two branches based on a boolean predicate.
Wraps XLA's `Select
<https://www.tensorflow.org/xla/operation_semantics#select>`_
operator.
+ In general :func:`~jax.lax.select` leads to evaluation of both branches, although
+ the compiler may elide computations if possible. For a similar function that
+ usually evaluates only a single branch, see :func:`~jax.lax.cond`.
+ Args:
+   pred: boolean array
+   on_true: array containing entries to return where ``pred`` is True. Must have
+     the same shape as ``pred``, and the same shape and dtype as ``on_false``.
+   on_false: array containing entries to return where ``pred`` is False. Must have
+     the same shape as ``pred``, and the same shape and dtype as ``on_true``.
+ Returns:
+   result: array with same shape and dtype as ``on_true`` and ``on_false``.
"""
# Caution! The select_n_p primitive has the *opposite* order of arguments to
# select(). This is because it implements `select_n`.
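For comparison, a small sketch of ``select`` on arrays (illustration only, not part of this diff)::

    import jax.numpy as jnp
    from jax import lax

    pred = jnp.array([True, False, True])
    on_true = jnp.array([1.0, 2.0, 3.0])
    on_false = jnp.array([10.0, 20.0, 30.0])

    # Elementwise choice; in general both operands are evaluated.
    lax.select(pred, on_true, on_false)   # [ 1., 20.,  3.]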
29 changes: 20 additions & 9 deletions jax/_src/pjit.py
@@ -450,14 +450,22 @@ def infer_params(*args, **kwargs):
hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,
tuple(isinstance(a, GDA) for a in args_flat), resource_env)

- jaxpr, canonicalized_out_shardings_flat = _pjit_jaxpr(
+ jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
flat_fun, hashable_pytree(out_shardings), global_in_avals,
HashableFunction(out_tree, closure=()))

if (any(_is_from_gda(i) for i in canonicalized_in_shardings_flat) or
not config.jax_array):
canonicalized_in_shardings_flat = _maybe_replace_from_gda_with_pspec(
canonicalized_in_shardings_flat, args_flat)

+ assert len(args_flat) == len(canonicalized_in_shardings_flat)
+ canonicalized_in_shardings_flat = (
+     _UNSPECIFIED,) * len(consts) + canonicalized_in_shardings_flat
+ donated_invars = (False,) * len(consts) + donated_invars
+ in_positional_semantics = (
+     pxla._PositionalSemantics.GLOBAL,) * len(consts) + in_positional_semantics

# in_shardings and out_shardings here are all OpShardingSharding.
params = dict(
jaxpr=jaxpr,
@@ -471,7 +479,7 @@ def infer_params(*args, **kwargs):
keep_unused=keep_unused,
inline=inline,
)
- return (args_flat, local_in_avals, params, in_tree, out_tree(),
+ return (consts + args_flat, local_in_avals, params, in_tree, out_tree(),
donate_argnums)

if FLAGS.experimental_cpp_pjit and xla_extension_version >= 115:
@@ -699,7 +707,9 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree):
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, global_in_avals)
finally:
pxla._positional_semantics.val = prev_positional_val
- jaxpr = core.ClosedJaxpr(jaxpr, consts)
+
+ jaxpr = pe.convert_constvars_jaxpr(jaxpr)
+ jaxpr = pe.close_jaxpr(jaxpr)

out_shardings_flat = flatten_axis_resources(
"pjit out_axis_resources", out_tree(), out_shardings_thunk(), tupled_args=False)
@@ -713,7 +723,7 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree):
)

# lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
- return _ListWithW([jaxpr, canonicalized_out_shardings_flat])
+ return _ListWithW([jaxpr, consts, canonicalized_out_shardings_flat])


def pjit_check_aval_sharding(
@@ -981,7 +991,7 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
'multiple devices is not supported.')
else:
if isinstance(arg, np.ndarray) and not pxla.is_op_sharding_replicated(
- pjit_in_s._to_xla_op_sharding(arg.ndim)) and xb.process_count() > 1:
+ pjit_in_s._to_xla_op_sharding(arg.ndim)) and xb.process_count() > 1: # type: ignore
raise ValueError(
'When jax.Array is enabled, passing non-trivial shardings for numpy '
'inputs is not allowed. To fix this error, either specify a '
@@ -999,7 +1009,7 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
if (committed and
not isinstance(arg_s, PmapSharding) and
not pxla.are_op_shardings_equal(
- pjit_in_s._to_xla_op_sharding(arg.ndim),
+ pjit_in_s._to_xla_op_sharding(arg.ndim), # type: ignore
arg_s._to_xla_op_sharding(arg.ndim))):
op = getattr(pjit_in_s, '_original_sharding', pjit_in_s)
raise ValueError('Sharding passed to pjit does not match the sharding '
@@ -1179,9 +1189,10 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
out_shardings, resource_env, donated_invars,
in_positional_semantics, out_positional_semantics,
keep_unused, inline):
- if not isinstance(ctx.module_context.axis_context,
-                   (mlir.SPMDAxisContext, mlir.ShardingContext)):
-   raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
+ if not config.jax_array:
+   if not isinstance(ctx.module_context.axis_context,
+                     (mlir.SPMDAxisContext, mlir.ShardingContext)):
+     raise RuntimeError("Nesting pjit() inside jit() is not allowed.")

output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
flat_output_types = util.flatten(output_types)
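The case this change targets is a staged-out function that closes over array constants: such constants show up as consts of the traced jaxpr, and with this change they are converted into explicit inputs and prepended to the flat arguments (with unspecified shardings, global positional semantics, and no donation). A rough user-level illustration of where such consts come from (hypothetical example, not from this diff; assumes a jax.Array-enabled JAX where ``pjit`` can be used without a mesh)::

    import jax
    import jax.numpy as jnp
    from jax.experimental.pjit import pjit

    c = jnp.arange(4.0)                      # concrete array captured by the closure

    # Tracing the closure records `c` as a const of the jaxpr rather than an input.
    closed = jax.make_jaxpr(lambda x: x * c)(jnp.ones(4))
    print(len(closed.consts), len(closed.jaxpr.invars))   # 1 const, 1 input

    # Under pjit, such consts are now threaded through as extra arguments
    # with unspecified sharding.
    f = pjit(lambda x: x * c)
    print(f(jnp.ones(4)))                    # [0. 1. 2. 3.]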
9 changes: 8 additions & 1 deletion jax/_src/public_test_util.py
@@ -15,6 +15,7 @@
from functools import partial
import operator

+ import jax
from jax import config
from jax.tree_util import tree_map, tree_reduce
from jax._src import api
@@ -223,7 +224,13 @@ def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
rtol = _merge_tolerance(rtol, default_gradient_tolerance)
_rand_like = partial(rand_like, np.random.RandomState(0))
v_out, vjpfun = f_vjp(*args)
- v_out_expected = f(*args)
+ try:
+   v_out_expected = f(*args)
+ except:
+   print(f.lower(*args).compile().as_text())
+   print(len(args))
+   print(jax.make_jaxpr(f)(*args))
+   raise
check_close(v_out, v_out_expected, atol=atol, rtol=rtol,
err_msg=f'{err_msg} primal' if err_msg else 'primal')
tangent = tree_map(_rand_like, args)
3 changes: 3 additions & 0 deletions jax/interpreters/pxla.py
@@ -2859,6 +2859,8 @@ def lower_sharding_computation(
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

+ map(dispatch.prefetch, it.chain(consts, dispatch.jaxpr_literals(jaxpr)))

# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
@@ -2927,6 +2929,7 @@ def lower_sharding_computation(
if eff not in core.ordered_effects]
ordered_effects = [eff for eff in closed_jaxpr.effects
if eff in core.ordered_effects]
+ print('closed_jaxpr', closed_jaxpr)
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
15 changes: 14 additions & 1 deletion tests/pjit_test.py
@@ -2177,7 +2177,7 @@ def test_grad_of_pjit_single_device_sharding(self):
def test_autodiff_with_single_device_sharding(self):
# Add a constant captured by the nested pjit to make things more complicated
h = jnp.arange(4.)
- f = pjit(lambda x: x.sum(1) * h.sum())
+ f = lambda x: x.sum(1) * h.sum()
g = pjit(lambda x: f(jnp.sin(x * 4 + 2)))
jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)

@@ -3069,6 +3069,19 @@ def test_pmap_sharding_input_pjit_in_axis_resources(self):
self.assertArraysEqual(out, pmap_out * 2)
self.assertLen(out.devices(), 4)

+ def test_nested_pjit_closing_over_tracer(self):
+   @pjit
+   def f(x):
+     y = jnp.float32(2) * x
+
+     @pjit
+     def g(z):
+       return jax.pmap(lambda x: x[jnp.newaxis] * y)(z)
+
+     return g(x)
+
+   f(np.arange(1., dtype='float32').reshape((1, 1))) # doesn't crash


class TempSharding(Sharding):

