diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index 34395756f25a..f89e4d53a476 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -38,6 +38,7 @@ cond_p as cond_p, switch as switch, platform_dependent as platform_dependent, + platform_index_p as platform_index_p, ) from jax._src.lax.control_flow.solves import ( custom_linear_solve as custom_linear_solve, diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index b618339fafc0..db0a1f4dc4ee 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -945,6 +945,7 @@ def other_platforms_code(*args): ... platform_index = platform_index_p.bind( platforms=tuple(tuple(ps) for ps in platforms_lists), has_default=(default is not None)) + if default is not None: branches = branches + (default,) # Use a switch, to get the proper transformation rules for free. Since @@ -957,6 +958,8 @@ def other_platforms_code(*args): ... # recognized on the compilation platform. Detect eager mode and keep only the # needed branch. try: + # Note/TODO(mvoz): This actually rarely seems to concretize - we could look into + # core.ensure_compile_time_eval to get better single-branch selection. platform_index_concrete = core.concrete_or_error(operator.index, platform_index) except core.ConcretizationTypeError: return switch(platform_index, branches, *args) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6c907640a985..070f6bfda30a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -8341,18 +8341,41 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, Array([4, 8], dtype=int32) """ util.check_arraylike("diagonal", a) - a_shape = shape(a) + if ndim(a) < 2: raise ValueError("diagonal requires an array of at least two dimensions.") offset = core.concrete_or_error(operator.index, offset, "'offset' argument of jnp.diagonal()") - a = moveaxis(a, (axis1, axis2), (-2, -1)) + def _default_diag(a): + a_shape = shape(a) + + a = moveaxis(a, (axis1, axis2), (-2, -1)) - diag_size = max(0, min(a_shape[axis1] + min(offset, 0), - a_shape[axis2] - max(offset, 0))) - i = arange(diag_size) - j = arange(abs(offset), abs(offset) + diag_size) - return a[..., i, j] if offset >= 0 else a[..., j, i] + diag_size = max( + 0, min(a_shape[axis1] + min(offset, 0), a_shape[axis2] - max(offset, 0)) + ) + i = arange(diag_size) + j = arange(abs(offset), abs(offset) + diag_size) + return a[..., i, j] if offset >= 0 else a[..., j, i] + + + # The mosaic lowering rule for diag is only defined for square arrays. + # TODO(mvoz): Add support for offsets. + if shape(a)[0] != shape(a)[1] or ndim(a) != 2 or offset != 0: + return _default_diag(a) + else: + a_shape_eye = eye(shape(a)[0], dtype=_dtype(a)) + + def _mosaic_diag(a): + def _sum(x, axis): + return lax.reduce( + x, + np.array(0, _dtype(x)), + lax.add if _dtype(x) != bool_ else lax.bitwise_or, + (axis,), + ) + return _sum(lax.mul(a_shape_eye, a), axis=0) + return lax.platform_dependent(a, default=_default_diag, mosaic=_mosaic_diag) @export diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 70a1d71f5712..08a337d80f1a 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -547,9 +547,13 @@ def lower_jaxpr_to_module( module_name = name_and_src_info.name attrs["sym_name"] = ir.StringAttr.get(module_name) sym_tab = ir.SymbolTable(m.operation) + func_op = lower_jaxpr_to_func( - ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping, - name="main", for_verification=for_verification, + ctx, + jaxpr, + mosaic_grid_mapping=mosaic_grid_mapping, + name="main", + for_verification=for_verification, ) m.body.append(func_op) sym_tab.insert(func_op) @@ -568,6 +572,7 @@ def lower_jaxpr_to_module( # We checked above that the block does not require windowing. window_params.append(ir.DictAttr.get()) continue + mlir_func = lower_jaxpr_to_transform_func( ctx, bm.index_map_jaxpr.jaxpr, @@ -1990,6 +1995,36 @@ def _add_lowering_rule(ctx: LoweringRuleContext, x, y): skip_mlir_conversions.add(ad_util.add_any_p) +class FoldingError(Exception): + pass + + +def _fold_and_get_constant_value(x): + def _fold(x, fuel): + if fuel <= 0: + raise FoldingError("Folding depth exceeded") + op_name = getattr(x.owner, "name", None) + binop_folds = { + "arith.maxsi": max, + "arith.minsi": min, + } + if op_name == "arith.constant": + if ir.IntegerType.isinstance(x.type): + return ir.IntegerAttr(x.owner.attributes["value"]).value + elif ir.FloatType.isinstance(x.type): + return ir.FloatAttr(x.owner.attributes["value"]).value + else: + raise ValueError(f"Unsupported constant type: {x.type}") + if op_name in binop_folds: + return binop_folds[op_name](_fold(v, fuel - 1) for v in x.owner.operands) + raise FoldingError(f"Folding not supported for {x.owner}") + + try: + return _fold(x, 10) + except FoldingError: + return None + + def _max_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2708,6 +2743,12 @@ def _while_lowering_rule( def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches): index, *args = args + constant_index = _fold_and_get_constant_value(index) + + if constant_index is not None: + return jaxpr_subcomp( + ctx.lowering_context.replace(block_shapes=ctx.block_shapes[1:]), branches[constant_index].jaxpr, *args + ) out_types = map(aval_to_ir_type, ctx.avals_out) pred = arith.cmpi( arith.CmpIPredicate.ne, index, ir_constant(0, index.type) @@ -3375,3 +3416,25 @@ def _pad(val): lowering_rules[lax.pad_p] = _pad_lowering_rule + + +def _platform_index_lowering( + ctx: mlir.LoweringRuleContext, + *, + platforms: Sequence[Sequence[str]], + has_default: bool, +): + for i, ps in enumerate(platforms): + # note - slightly odd structure here, as platforms is a seq[seq[str]] + if "mosaic" in ps: + return ir_constant(i) + + if has_default: + return ir_constant(len(platforms)) + + raise NotImplementedError( + "No mosaic or default platform indexing rule found." + ) + + +lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index d848bc437df9..0210f75d7873 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -2127,6 +2127,18 @@ def kernel(x_ref, o_ref): ) self.assertTrue(acceptable_errors, "Failed with error: " + str(e)) + @parameterized.parameters((128, 128), (256, 256)) + def test_jnp_diagonal_pallas(self, n, m): + x = jnp.arange(n * m, dtype=jnp.float32).reshape((n, m)) + + def kernel(x_ref, out_ref): + out_ref[...] = jnp.diagonal(x_ref[...]) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((n,), jnp.float32) + )(x) + np.testing.assert_array_equal(out, np.diagonal(x)) + class OpsInterpretTest(OpsTest): INTERPRET = True