Skip to content

Commit

Permalink
[Jax][Pallas][Mosaic] Implement platform dependent diag, with branch …
Browse files Browse the repository at this point in the history
…selection driven by constant prop in mosaic lowering.

This CL builds out a simple sketch of constant prop by construction in mosaic - we walk the graph up from cond, collecting the values and either const propping or failing out of const prop. Failure out of const prop is not a bug, but hitting an unimplemented const prop func is for now, in order to drive better coverage.

This then allows us to pick a single branch, and ignore branches which do not have a viable mosaic implementation.

And, finally, for diag, this means we can replace the initial gather-dependent implementation in lax with a mosaic specific one that avoids gather.

PiperOrigin-RevId: 693594640
  • Loading branch information
Google-ML-Automation committed Dec 17, 2024
1 parent 7fe2579 commit d6e7194
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 9 deletions.
1 change: 1 addition & 0 deletions jax/_src/lax/control_flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
cond_p as cond_p,
switch as switch,
platform_dependent as platform_dependent,
platform_index_p as platform_index_p,
)
from jax._src.lax.control_flow.solves import (
custom_linear_solve as custom_linear_solve,
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,7 @@ def other_platforms_code(*args): ...
platform_index = platform_index_p.bind(
platforms=tuple(tuple(ps) for ps in platforms_lists),
has_default=(default is not None))

if default is not None:
branches = branches + (default,)
# Use a switch, to get the proper transformation rules for free. Since
Expand All @@ -946,6 +947,8 @@ def other_platforms_code(*args): ...
# recognized on the compilation platform. Detect eager mode and keep only the
# needed branch.
try:
# Note/TODO(mvoz): This actually rarely seems to concretize - we could look into
# core.ensure_compile_time_eval to get better single-branch selection.
platform_index_concrete = core.concrete_or_error(operator.index, platform_index)
except core.ConcretizationTypeError:
return switch(platform_index, branches, *args)
Expand Down
44 changes: 37 additions & 7 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8342,18 +8342,48 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
Array([4, 8], dtype=int32)
"""
util.check_arraylike("diagonal", a)
a_shape = shape(a)

if ndim(a) < 2:
raise ValueError("diagonal requires an array of at least two dimensions.")
offset = core.concrete_or_error(operator.index, offset, "'offset' argument of jnp.diagonal()")

a = moveaxis(a, (axis1, axis2), (-2, -1))
def _default_diag(a):
a_shape = shape(a)

a = moveaxis(a, (axis1, axis2), (-2, -1))

diag_size = max(
0, min(a_shape[axis1] + min(offset, 0), a_shape[axis2] - max(offset, 0))
)
i = arange(diag_size)
j = arange(abs(offset), abs(offset) + diag_size)
return a[..., i, j] if offset >= 0 else a[..., j, i]

def _mosaic_diag(a):
def _sum(x, axis):
return lax.reduce(
x,
np.array(0, x.dtype),
lax.add if x.dtype != bool_ else lax.bitwise_or,
(axis,),
)

diag_size = max(0, min(a_shape[axis1] + min(offset, 0),
a_shape[axis2] - max(offset, 0)))
i = arange(diag_size)
j = arange(abs(offset), abs(offset) + diag_size)
return a[..., i, j] if offset >= 0 else a[..., j, i]
if a.shape[0] != a.shape[1]:
# This is a hack, because there are cases where we cannot determine
# if we are in mosaic or not - so we cannot skip tracing all potential
# paths when we make the jax cond beneath platform_dependent.
#
# On non mosaic - this will be a no-op.
#
# On mosaic - this will be a failure with unimplemented op error.
return _default_diag(a)

a_shape_eye = eye(a.shape[0])
original_a_dtype = a.dtype
a_shape_eye, a = util.promote_dtypes(a_shape_eye, a)
return _sum(lax.mul(a_shape_eye, a), axis=0).astype(original_a_dtype)

return lax.platform_dependent(a, mosaic=_mosaic_diag, default=_default_diag)


@export
Expand Down
66 changes: 64 additions & 2 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class LoweringContext:
def grid_rank(self):
return len(self.grid_sizes)


@contextlib.contextmanager
def grid_name_context(self):
# TODO(b/355036977): generalize this across other platforms
Expand Down Expand Up @@ -547,9 +548,13 @@ def lower_jaxpr_to_module(
module_name = name_and_src_info.name
attrs["sym_name"] = ir.StringAttr.get(module_name)
sym_tab = ir.SymbolTable(m.operation)

func_op = lower_jaxpr_to_func(
ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
name="main", for_verification=for_verification,
ctx,
jaxpr,
mosaic_grid_mapping=mosaic_grid_mapping,
name="main",
for_verification=for_verification,
)
m.body.append(func_op)
sym_tab.insert(func_op)
Expand All @@ -568,6 +573,7 @@ def lower_jaxpr_to_module(
# We checked above that the block does not require windowing.
window_params.append(ir.DictAttr.get())
continue

mlir_func = lower_jaxpr_to_transform_func(
ctx,
bm.index_map_jaxpr.jaxpr,
Expand Down Expand Up @@ -1986,6 +1992,36 @@ def _add_lowering_rule(ctx: LoweringRuleContext, x, y):
skip_mlir_conversions.add(ad_util.add_any_p)


class FoldingError(Exception):
pass


def _fold_and_get_constant_value(x):
def _fold(x, fuel):
if fuel <= 0:
raise FoldingError("Folding depth exceeded")
op_name = getattr(x.owner, "name", None)
binop_folds = {
"arith.maxsi": max,
"arith.minsi": min,
}
if op_name == "arith.constant":
if ir.IntegerType.isinstance(x.type):
return ir.IntegerAttr(x.owner.attributes["value"]).value
elif ir.FloatType.isinstance(x.type):
return ir.FloatAttr(x.owner.attributes["value"]).value
else:
raise ValueError(f"Unsupported constant type: {x.type}")
if op_name in binop_folds:
return binop_folds[op_name](_fold(v, fuel - 1) for v in x.owner.operands)
raise NotImplementedError(f"Folding not supported for {x.owner}")

try:
return _fold(x, 10)
except FoldingError:
return None


def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
(aval_out,) = ctx.avals_out
Expand Down Expand Up @@ -2702,6 +2738,11 @@ def _while_lowering_rule(

def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches):
index, *args = args
constant_index = _fold_and_get_constant_value(index)
if constant_index is not None:
return jaxpr_subcomp(
ctx.lowering_context, branches[constant_index].jaxpr, *args
)
out_types = map(aval_to_ir_type, ctx.avals_out)
pred = arith.cmpi(
arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
Expand Down Expand Up @@ -3369,3 +3410,24 @@ def _pad(val):


lowering_rules[lax.pad_p] = _pad_lowering_rule


def _platform_index_lowering(
ctx: mlir.LoweringRuleContext,
*,
platforms: Sequence[Sequence[str]],
has_default: bool,
):
for i, ps in enumerate(platforms):
# note - slightly odd structure here, as platforms is a seq[seq[str]]
if ps == ("mosaic",):
return ir_constant(i)

if has_default:
return ir_constant(len(platforms))

raise NotImplementedError(
"No mosaic or default platform indexing rule found."
)

lowering_rules[lax.platform_index_p] = _platform_index_lowering
1 change: 1 addition & 0 deletions jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@
map as map,
scan as scan,
scan_p as scan_p,
platform_index_p as platform_index_p,
switch as switch,
while_loop as while_loop,
while_p as while_p,
Expand Down
12 changes: 12 additions & 0 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2101,6 +2101,18 @@ def kernel(x_ref, o_ref):
)
self.assertTrue(acceptable_errors, "Failed with error: " + str(e))

@parameterized.parameters((128, 128), (256, 256))
def test_jnp_diagonal_pallas(self, n, m):
x = jnp.arange(n * m, dtype=jnp.float32).reshape((n, m))

def kernel(x_ref, out_ref):
out_ref[...] = jnp.diagonal(x_ref[...])

out = self.pallas_call(
kernel, out_shape=jax.ShapeDtypeStruct((n,), jnp.float32)
)(x)
np.testing.assert_array_equal(out, np.diagonal(x))


class OpsInterpretTest(OpsTest):
INTERPRET = True
Expand Down

0 comments on commit d6e7194

Please sign in to comment.