From aadb50905c9f0bd7668e1d67c625b762234f9cb6 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 4 Oct 2024 04:53:24 -0700 Subject: [PATCH] [pallas:mosaic_gpu] Allowed indexing refs with scalars The transforms do not yet handle this case, so only the basic indexing works. PiperOrigin-RevId: 682273046 --- jax/_src/pallas/mosaic_gpu/BUILD | 1 + jax/_src/pallas/mosaic_gpu/core.py | 23 +++++++---- jax/_src/pallas/mosaic_gpu/lowering.py | 51 ++++++++++++++---------- jax/_src/pallas/mosaic_gpu/primitives.py | 2 +- tests/pallas/mosaic_gpu_test.py | 41 +++++++++++++++++++ 5 files changed, 88 insertions(+), 30 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 668a5960ea28..91616948be49 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -77,6 +77,7 @@ pytype_strict_library( "//jax:mosaic_gpu", "//jax:tree_util", "//jax/_src/pallas", + "//jaxlib/mlir:ir", ] + py_deps("numpy"), ) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 30f28044e861..9156811312f4 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -23,10 +23,11 @@ from jax._src import core as jax_core from jax._src import dtypes from jax._src import tree_util -from jax._src.state.types import Transform from jax._src.pallas import core as pallas_core +from jax._src.state.types import Transform import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp +from jaxlib.mlir import ir AbstractMemoryRef = pallas_core.AbstractMemoryRef @@ -75,6 +76,7 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: shape=self.to_gpu_transform().transform_shape(aval.shape) ) +Index = slice | int | ir.Value @dataclasses.dataclass(frozen=True) class TilingTransform(MemoryRefTransform): @@ -114,11 +116,14 @@ def transform_shape(self, shape): def transform_dtype(self, dtype): return dtype - def untransform_index(self, idxs: tuple[slice, ...]) -> tuple[slice, ...]: + def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]: + if not all(isinstance(idx, slice) for idx in idxs): + raise NotImplementedError("Non-slice indices are not supported") untiled_idxs = idxs[: -len(self.tiling)] tiled_idxs = idxs[-len(self.tiling) :] idxs_after_tiling = [] for idx, tile in zip(tiled_idxs, self.tiling): + assert isinstance(idx, slice) if idx.step is not None and idx.step != 1: raise NotImplementedError("Strided slices unsupported") if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): @@ -177,7 +182,7 @@ def transform_shape(self, shape): def transform_dtype(self, dtype): return dtype - def untransform_index(self, idxs: tuple[slice, ...]) -> tuple[slice, ...]: + def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]: return tuple(idxs[i] for i in _perm_inverse(self.permutation)) def tree_flatten(self): @@ -223,13 +228,17 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: class UnswizzleRef(Transform): swizzle: int - def untransform_index(self, idxs: tuple[slice, ...]) -> tuple[slice, ...]: + def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]: if not idxs: return idxs - if idxs[-1].step is not None and idxs[-1].step != 1: + if not all(isinstance(idx, slice) for idx in idxs): + raise NotImplementedError("Non-slice indices are not supported") + last_idx = idxs[-1] + assert isinstance(last_idx, slice) + if last_idx.step is not None and last_idx.step != 1: raise NotImplementedError("Swizzled dims cannot be sliced") - if (idxs[-1].start is not None and idxs[-1].start != 0) or ( - idxs[-1].stop is not None and idxs[-1].stop != self.swizzle + if (last_idx.start is not None and last_idx.start != 0) or ( + last_idx.stop is not None and last_idx.stop != self.swizzle ): raise ValueError("Swizzled dims cannot be sliced") return idxs diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index ed48d3fd9f2e..20ba3793181d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -732,21 +732,22 @@ def _handle_indexing( indexer = cast(indexing.NDIndexer, transforms[-1]) if indexer.int_indexer_shape: raise NotImplementedError("int_indexer_shape non-empty") - slices = _ndindexer_slices(indexer) + indices = _ndindexer_indices(indexer) for t in reversed(transforms[:-1]): - slices = t.untransform_index(slices) - return mgpu.memref_slice(ref, slices) + indices = t.untransform_index(indices) + return mgpu.memref_slice(ref, indices) -def _ndindexer_slices(indexer: indexing.NDIndexer) -> tuple[slice, ...]: - slices = [] - for s in indexer.indices: - if not isinstance(s, indexing.Slice): - raise NotImplementedError(f"Unsupported dimension index: {s}") - if s.is_dynamic_start or s.is_dynamic_size: - raise NotImplementedError(f"Unsupported slice: {s}") - slices.append(slice(s.start, s.start + s.size, s.stride)) - return tuple(slices) +def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]: + indices = [] + for idx in indexer.indices: + if isinstance(idx, indexing.Slice): + if idx.is_dynamic_start or idx.is_dynamic_size: + raise NotImplementedError(f"Unsupported slice: {idx}") + indices.append(slice(idx.start, idx.start + idx.size, idx.stride)) + else: + indices.append(_as_index(idx)) + return tuple(indices) def _is_swizzled(transforms: tuple[gpu_core.Transform, ...]) -> int | None: @@ -787,10 +788,10 @@ def _swap_lowering_rule( raise TypeError(f"Can only store arrays (got {value}).") if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): raise TypeError(f"Can only store to references (got {x_smem}).") - transform = jax.tree.unflatten(tree, leaves) - swizzle = _is_swizzled(transform) - x_smem = _handle_indexing(x_smem, transform) - x_aval, _ = ctx.avals_in + transforms = jax.tree.unflatten(tree, leaves) + swizzle = _is_swizzled(transforms) + x_smem = _handle_indexing(x_smem, transforms) + x_aval = ctx.avals_in[0] if swizzle is None: old_value = mgpu.FragmentedArray.load_strided( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) @@ -1178,9 +1179,15 @@ def _i64_constant(v: int) -> ir.Value: return arith_dialect.constant(ir.IntegerType.get_signless(64), v) -def _as_index(v: int | ir.Value) -> ir.Value: - if isinstance(v, int): - return arith_dialect.constant(ir.IndexType.get(), v) - if ir.IndexType.isinstance(v.type): - return v - return arith_dialect.index_cast(ir.IndexType.get(), v) +def _as_index(v: object) -> ir.Value: + match v: + case int(): + return arith_dialect.constant(ir.IndexType.get(), v) + case ir.Value() if ir.IndexType.isinstance(v.type): + return v + case ir.Value() if ir.IntegerType.isinstance(v.type): + return arith_dialect.index_cast(ir.IndexType.get(), v) + case mgpu.FragmentedArray(layout=mgpu.WGSplatFragLayout()): + return _as_index(v.registers.item()) + case _: + raise ValueError(f"Unsupported index: {v}") diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 1a9995ee1d91..7232a74d4cb6 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -77,7 +77,7 @@ def _extract_copy_params(transforms): transforms = transforms[1:] gpu_transforms = [t.to_gpu_transform() for t in transforms] return dict( - gmem_slice=lowering._ndindexer_slices(indexer), + gmem_slice=lowering._ndindexer_indices(indexer), gmem_transform=tuple(gpu_transforms), swizzle=swizzle, ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 917529d3aff7..83d9af060006 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -60,6 +60,18 @@ def kernel(x_ref, o_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), unary(x)) + def test_add_first(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[0] + + x = jnp.arange(256).astype(jnp.float32) + y = jnp.flip(x).reshape(1, 256) + np.testing.assert_array_equal(kernel(x, y), x + y[0]) + def test_add_xy(self): @functools.partial( pl.pallas_call, @@ -72,6 +84,19 @@ def kernel(x_ref, y_ref, o_ref): y = x + 1 np.testing.assert_array_equal(kernel(x, y), x + y) + def test_add_xy_indexed(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + idx = jnp.sum(y_ref[...]) + o_ref[...] = x_ref[idx] + + x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32) + y = jnp.zeros(128, dtype=jnp.int32) + np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)]) + def test_add_one_grid(self): @functools.partial( pl.pallas_call, @@ -469,6 +494,22 @@ def kernel(o_ref): kernel(), jnp.full([256], 5.0, dtype=jnp.float32) ) + def test_fori_loop_indexed_store(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + def body(idx, _): + o_ref[idx] = x_ref[idx] + y_ref[idx] + return () + + jax.lax.fori_loop(0, 4, body, ()) + + x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32) + y = x + 1 + np.testing.assert_array_equal(kernel(x, y), x + y) + def test_cond(self): @functools.partial(