diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index 5bba512bf0e4..3c5df91bcb6b 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -155,8 +155,13 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
           operand=_no_operand_sentinel, linear=None):
   """Conditionally apply ``true_fun`` or ``false_fun``.
 
+  Wraps XLA's `Conditional
+  <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
+  operator.
+
   Provided arguments are correctly typed, ``cond()`` has equivalent
-  semantics to this Python implementation::
+  semantics to this Python implementation, where ``pred`` must be a
+  scalar type::
 
     def cond(pred, true_fun, false_fun, *operands):
       if pred:
@@ -164,7 +169,11 @@ def cond(pred, true_fun, false_fun, *operands):
       else:
         return false_fun(*operands)
 
-  ``pred`` must be a scalar type.
+
+  In contrast with :func:`jax.lax.select`, using ``cond`` indicates that only one of
+  the two branches is executed (up to compiler rewrites and optimizations).
+  However, when transformed with :func:`~jax.vmap` to operate over a batch of
+  predicates, ``cond`` is converted to :func:`~jax.lax.select`.
 
   Args:
     pred: Boolean scalar type, indicating which branch function to apply.
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index d181efb8e404..ba12477e8a28 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -908,9 +908,25 @@ def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array:
   return rev_p.bind(operand, dimensions=tuple(dimensions))
 
 def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array:
-  """Wraps XLA's `Select
+  """Selects between two branches based on a boolean predicate.
+
+  Wraps XLA's `Select
   <https://www.tensorflow.org/xla/operation_semantics#select>`_
   operator.
+
+  In general :func:`~jax.lax.select` leads to evaluation of both branches, although
+  the compiler may elide computations if possible. For a similar function that
+  usually evaluates only a single branch, see :func:`~jax.lax.cond`.
+
+  Args:
+    pred: boolean array
+    on_true: array containing entries to return where ``pred`` is True. Must have
+      the same shape as ``pred``, and the same shape and dtype as ``on_false``.
+    on_false: array containing entries to return where ``pred`` is False. Must have
+      the same shape as ``pred``, and the same shape and dtype as ``on_true``.
+
+  Returns:
+    result: array with same shape and dtype as ``on_true`` and ``on_false``.
   """
   # Caution! The select_n_p primitive has the *opposite* order of arguments to
   # select(). This is because it implements `select_n`.
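The docstring changes above draw the cond/select distinction: ``cond`` takes a scalar predicate and nominally runs one branch, ``select`` takes an elementwise predicate and evaluates both operands, and ``vmap`` of ``cond`` over a batch of predicates falls back to a select-style merge. A short usage sketch of that documented behavior (the function ``f`` and the sample inputs are illustrative, not part of the patch):

```python
import jax
import jax.numpy as jnp
from jax import lax

# Scalar predicate: cond applies exactly one branch (up to compiler rewrites).
def f(pred, x):
  return lax.cond(pred, jnp.sin, jnp.cos, x)

print(f(True, 1.0))   # sin(1.0)

# Elementwise predicate: select evaluates both operands, then merges them.
pred = jnp.array([True, False, True])
x = jnp.arange(3.0)
print(lax.select(pred, jnp.sin(x), jnp.cos(x)))

# Batched predicates: vmap of cond is converted to a select-style merge.
print(jax.vmap(f)(pred, x))
```

Both forms return the same values; the difference is which computations the lowered program promises to perform.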
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index f461f1c3876c..dceeb4df5530 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -450,7 +450,7 @@ def infer_params(*args, **kwargs):
         hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,
         tuple(isinstance(a, GDA) for a in args_flat), resource_env)
 
-    jaxpr, canonicalized_out_shardings_flat = _pjit_jaxpr(
+    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
         flat_fun, hashable_pytree(out_shardings), global_in_avals,
         HashableFunction(out_tree, closure=()))
@@ -458,6 +458,14 @@ def infer_params(*args, **kwargs):
         not config.jax_array):
       canonicalized_in_shardings_flat = _maybe_replace_from_gda_with_pspec(
           canonicalized_in_shardings_flat, args_flat)
+
+    assert len(args_flat) == len(canonicalized_in_shardings_flat)
+    canonicalized_in_shardings_flat = (
+        _UNSPECIFIED,) * len(consts) + canonicalized_in_shardings_flat
+    donated_invars = (False,) * len(consts) + donated_invars
+    in_positional_semantics = (
+        pxla._PositionalSemantics.GLOBAL,) * len(consts) + in_positional_semantics
+
     # in_shardings and out_shardings here are all OpShardingSharding.
     params = dict(
         jaxpr=jaxpr,
@@ -471,7 +479,7 @@ def infer_params(*args, **kwargs):
         keep_unused=keep_unused,
         inline=inline,
     )
-    return (args_flat, local_in_avals, params, in_tree, out_tree(),
+    return (consts + args_flat, local_in_avals, params, in_tree, out_tree(),
             donate_argnums)
 
   if FLAGS.experimental_cpp_pjit and xla_extension_version >= 115:
@@ -699,7 +707,9 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree):
     jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, global_in_avals)
   finally:
     pxla._positional_semantics.val = prev_positional_val
-  jaxpr = core.ClosedJaxpr(jaxpr, consts)
+
+  jaxpr = pe.convert_constvars_jaxpr(jaxpr)
+  jaxpr = pe.close_jaxpr(jaxpr)
 
   out_shardings_flat = flatten_axis_resources(
       "pjit out_axis_resources", out_tree(), out_shardings_thunk(), tupled_args=False)
@@ -713,7 +723,7 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree):
   )
 
   # lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
-  return _ListWithW([jaxpr, canonicalized_out_shardings_flat])
+  return _ListWithW([jaxpr, consts, canonicalized_out_shardings_flat])
 
 
 def pjit_check_aval_sharding(
@@ -981,7 +991,7 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
               'multiple devices is not supported.')
     else:
       if isinstance(arg, np.ndarray) and not pxla.is_op_sharding_replicated(
-          pjit_in_s._to_xla_op_sharding(arg.ndim)) and xb.process_count() > 1:
+          pjit_in_s._to_xla_op_sharding(arg.ndim)) and xb.process_count() > 1:  # type: ignore
         raise ValueError(
             'When jax.Array is enabled, passing non-trivial shardings for numpy '
            'inputs is not allowed. To fix this error, either specify a '
@@ -999,7 +1009,7 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
       if (committed and
           not isinstance(arg_s, PmapSharding) and
           not pxla.are_op_shardings_equal(
-              pjit_in_s._to_xla_op_sharding(arg.ndim),
+              pjit_in_s._to_xla_op_sharding(arg.ndim),  # type: ignore
               arg_s._to_xla_op_sharding(arg.ndim))):
         op = getattr(pjit_in_s, '_original_sharding', pjit_in_s)
         raise ValueError('Sharding passed to pjit does not match the sharding '
@@ -1179,9 +1189,10 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
                    out_shardings, resource_env, donated_invars,
                    in_positional_semantics, out_positional_semantics,
                    keep_unused, inline):
-  if not isinstance(ctx.module_context.axis_context,
-                    (mlir.SPMDAxisContext, mlir.ShardingContext)):
-    raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
+  if not config.jax_array:
+    if not isinstance(ctx.module_context.axis_context,
+                      (mlir.SPMDAxisContext, mlir.ShardingContext)):
+      raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
 
   output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
   flat_output_types = util.flatten(output_types)
diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py
index eb956f3c2902..6feaefeb9816 100644
--- a/jax/_src/public_test_util.py
+++ b/jax/_src/public_test_util.py
@@ -15,6 +15,7 @@
 from functools import partial
 import operator
 
+import jax
 from jax import config
 from jax.tree_util import tree_map, tree_reduce
 from jax._src import api
@@ -223,7 +224,13 @@ def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
   rtol = _merge_tolerance(rtol, default_gradient_tolerance)
   _rand_like = partial(rand_like, np.random.RandomState(0))
   v_out, vjpfun = f_vjp(*args)
-  v_out_expected = f(*args)
+  try:
+    v_out_expected = f(*args)
+  except:
+    print(f.lower(*args).compile().as_text())
+    print(len(args))
+    print(jax.make_jaxpr(f)(*args))
+    raise
   check_close(v_out, v_out_expected, atol=atol, rtol=rtol,
               err_msg=f'{err_msg} primal' if err_msg else 'primal')
   tangent = tree_map(_rand_like, args)
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 5766a34e83d9..8a7c6b67077d 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -2859,6 +2859,8 @@ def lower_sharding_computation(
   has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
   jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
 
+  map(dispatch.prefetch, it.chain(consts, dispatch.jaxpr_literals(jaxpr)))
+
   # Computations that only produce constants and/or only rearrange their inputs,
   # which are often produced from partial evaluation, don't need compilation,
   # and don't need to evaluate their arguments.
@@ -2927,6 +2929,7 @@ def lower_sharding_computation(
       if eff not in core.ordered_effects]
   ordered_effects = [eff for eff in closed_jaxpr.effects
                      if eff in core.ordered_effects]
+  print('closed_jaxpr', closed_jaxpr)
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
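The ``_pjit_jaxpr`` change above stops baking traced constants into a ``ClosedJaxpr`` and instead hoists them into explicit arguments (``pe.convert_constvars_jaxpr`` followed by ``pe.close_jaxpr``), while ``infer_params`` prepends ``consts`` to ``args_flat`` and pads the shardings, donation flags, and positional semantics to match. Below is a minimal sketch of that constvar-hoisting step, using the era's re-exported internals (``jax.interpreters.partial_eval``, ``jax.core``), which may have moved in later releases; the variable names are illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.interpreters import partial_eval as pe

h = jnp.arange(4.)  # closed-over array constant, captured as a constvar

closed = jax.make_jaxpr(lambda x: x * h)(np.ones(4, np.float32))
print(closed.jaxpr.constvars, closed.consts)   # one constvar, bound to h

# Hoist the constvars into leading invars, mirroring the pjit change:
# callers must now pass (*consts, *args), as infer_params does above.
lifted = pe.convert_constvars_jaxpr(closed.jaxpr)
out, = jax.core.eval_jaxpr(lifted, [], *closed.consts, np.ones(4, np.float32))
print(out)                                      # same result as x * h
```

Threading the constants through as real arguments is what allows values captured from an enclosing trace (including tracers) to reach the nested computation explicitly rather than as baked-in jaxpr constants.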
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 62795dc5a6db..cbda439d549d 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -2177,7 +2177,7 @@ def test_grad_of_pjit_single_device_sharding(self):
   def test_autodiff_with_single_device_sharding(self):
     # Add a constant captured by the nested pjit to make things more complicated
     h = jnp.arange(4.)
-    f = pjit(lambda x: x.sum(1) * h.sum())
+    f = lambda x: x.sum(1) * h.sum()
     g = pjit(lambda x: f(jnp.sin(x * 4 + 2)))
     jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)
 
@@ -3069,6 +3069,19 @@ def test_pmap_sharding_input_pjit_in_axis_resources(self):
     self.assertArraysEqual(out, pmap_out * 2)
     self.assertLen(out.devices(), 4)
 
+  def test_nested_pjit_closing_over_tracer(self):
+    @pjit
+    def f(x):
+      y = jnp.float32(2) * x
+
+      @pjit
+      def g(z):
+        return jax.pmap(lambda x: x[jnp.newaxis] * y)(z)
+
+      return g(x)
+
+    f(np.arange(1., dtype='float32').reshape((1, 1)))  # doesn't crash
+
 
 class TempSharding(Sharding):
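The new ``test_nested_pjit_closing_over_tracer`` exercises the pattern the consts-hoisting change is meant to support: an inner ``pjit`` (wrapping a ``pmap``) that closes over a tracer of the outer ``pjit``. Below is a stripped-down sketch of the same closure pattern with the ``pmap`` layer dropped; the names, shapes, and the bare ``@pjit`` usage mirror the test and are illustrative rather than part of the patch:

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental.pjit import pjit

@pjit
def f(x):
  y = jnp.float32(2) * x        # y is a tracer of the outer pjit

  @pjit
  def g(z):
    return z + y                # the inner pjit closes over the tracer y

  return g(x)

f(np.arange(4., dtype='float32'))  # relies on consts being hoisted to arguments
```

With the change above, the closed-over tracer appears to be passed into the inner pjit as an ordinary argument instead of being captured as a jaxpr constant, which is what the test's "doesn't crash" assertion checks.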