[pallas:mosaic_gpu] Allowed indexing refs with scalars
The transforms do not yet handle scalar indices, so for now only basic indexing works.

PiperOrigin-RevId: 682273046
superbobry authored and Google-ML-Automation committed Oct 4, 2024
1 parent ad6604d commit aadb509
Showing 5 changed files with 88 additions and 30 deletions.
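
For context, the kernel-level pattern this commit enables is indexing a ref with a scalar computed at run time. Below is a minimal sketch mirroring the test_add_xy_indexed test added in this commit; it assumes an environment where pl.pallas_call lowers through the Mosaic GPU backend, as the test suite arranges.

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
)
def kernel(x_ref, y_ref, o_ref):
  # idx is a scalar computed at run time; x_ref[idx] selects one row of
  # the (4, 128) ref. This is the basic indexing case the commit enables.
  idx = jnp.sum(y_ref[...])
  o_ref[...] = x_ref[idx]


x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
y = jnp.zeros(128, dtype=jnp.int32)
np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)])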
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic_gpu/BUILD
@@ -77,6 +77,7 @@ pytype_strict_library(
"//jax:mosaic_gpu",
"//jax:tree_util",
"//jax/_src/pallas",
"//jaxlib/mlir:ir",
] + py_deps("numpy"),
)

23 changes: 16 additions & 7 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -23,10 +23,11 @@
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import tree_util
from jax._src.state.types import Transform
from jax._src.pallas import core as pallas_core
from jax._src.state.types import Transform
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp
from jaxlib.mlir import ir


AbstractMemoryRef = pallas_core.AbstractMemoryRef
@@ -75,6 +76,7 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
shape=self.to_gpu_transform().transform_shape(aval.shape)
)

Index = slice | int | ir.Value

@dataclasses.dataclass(frozen=True)
class TilingTransform(MemoryRefTransform):
@@ -114,11 +116,14 @@ def transform_shape(self, shape):
def transform_dtype(self, dtype):
return dtype

def untransform_index(self, idxs: tuple[slice, ...]) -> tuple[slice, ...]:
def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]:
if not all(isinstance(idx, slice) for idx in idxs):
raise NotImplementedError("Non-slice indices are not supported")
untiled_idxs = idxs[: -len(self.tiling)]
tiled_idxs = idxs[-len(self.tiling) :]
idxs_after_tiling = []
for idx, tile in zip(tiled_idxs, self.tiling):
assert isinstance(idx, slice)
if idx.step is not None and idx.step != 1:
raise NotImplementedError("Strided slices unsupported")
if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile):
@@ -177,7 +182,7 @@ def transform_shape(self, shape):
def transform_dtype(self, dtype):
return dtype

def untransform_index(self, idxs: tuple[slice, ...]) -> tuple[slice, ...]:
def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]:
return tuple(idxs[i] for i in _perm_inverse(self.permutation))

def tree_flatten(self):
@@ -223,13 +228,17 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
class UnswizzleRef(Transform):
swizzle: int

def untransform_index(self, idxs: tuple[slice, ...]) -> tuple[slice, ...]:
def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]:
if not idxs:
return idxs
if idxs[-1].step is not None and idxs[-1].step != 1:
if not all(isinstance(idx, slice) for idx in idxs):
raise NotImplementedError("Non-slice indices are not supported")
last_idx = idxs[-1]
assert isinstance(last_idx, slice)
if last_idx.step is not None and last_idx.step != 1:
raise NotImplementedError("Swizzled dims cannot be sliced")
if (idxs[-1].start is not None and idxs[-1].start != 0) or (
idxs[-1].stop is not None and idxs[-1].stop != self.swizzle
if (last_idx.start is not None and last_idx.start != 0) or (
last_idx.stop is not None and last_idx.stop != self.swizzle
):
raise ValueError("Swizzled dims cannot be sliced")
return idxs
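
To summarize the contract added in this file: Index = slice | int | ir.Value widens what untransform_index may receive, but the tiling and swizzle transforms still accept only slices and raise NotImplementedError otherwise, hence the commit note that only basic indexing works for now. A standalone toy sketch of the slice-only check (an illustrative stand-in, not the actual JAX class; ir.Value is omitted):

import dataclasses

Index = slice | int  # the real alias also admits ir.Value


@dataclasses.dataclass(frozen=True)
class ToyUnswizzleRef:
  # Toy stand-in: only full, unstrided slices of the swizzled (last)
  # dimension pass through; scalar indices are rejected outright.
  swizzle: int

  def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]:
    if not idxs:
      return idxs
    if not all(isinstance(idx, slice) for idx in idxs):
      raise NotImplementedError("Non-slice indices are not supported")
    last = idxs[-1]
    if last.step not in (None, 1):
      raise NotImplementedError("Swizzled dims cannot be sliced")
    if last.start not in (None, 0) or last.stop not in (None, self.swizzle):
      raise ValueError("Swizzled dims cannot be sliced")
    return idxs


ref = ToyUnswizzleRef(swizzle=128)
print(ref.untransform_index((slice(0, 64), slice(0, 128))))  # passes through
# ref.untransform_index((3, slice(0, 128)))  # raises NotImplementedError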
51 changes: 29 additions & 22 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -732,21 +732,22 @@ def _handle_indexing(
indexer = cast(indexing.NDIndexer, transforms[-1])
if indexer.int_indexer_shape:
raise NotImplementedError("int_indexer_shape non-empty")
slices = _ndindexer_slices(indexer)
indices = _ndindexer_indices(indexer)
for t in reversed(transforms[:-1]):
slices = t.untransform_index(slices)
return mgpu.memref_slice(ref, slices)
indices = t.untransform_index(indices)
return mgpu.memref_slice(ref, indices)


def _ndindexer_slices(indexer: indexing.NDIndexer) -> tuple[slice, ...]:
slices = []
for s in indexer.indices:
if not isinstance(s, indexing.Slice):
raise NotImplementedError(f"Unsupported dimension index: {s}")
if s.is_dynamic_start or s.is_dynamic_size:
raise NotImplementedError(f"Unsupported slice: {s}")
slices.append(slice(s.start, s.start + s.size, s.stride))
return tuple(slices)
def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]:
indices = []
for idx in indexer.indices:
if isinstance(idx, indexing.Slice):
if idx.is_dynamic_start or idx.is_dynamic_size:
raise NotImplementedError(f"Unsupported slice: {idx}")
indices.append(slice(idx.start, idx.start + idx.size, idx.stride))
else:
indices.append(_as_index(idx))
return tuple(indices)


def _is_swizzled(transforms: tuple[gpu_core.Transform, ...]) -> int | None:
@@ -787,10 +788,10 @@ def _swap_lowering_rule(
raise TypeError(f"Can only store arrays (got {value}).")
if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem):
raise TypeError(f"Can only store to references (got {x_smem}).")
transform = jax.tree.unflatten(tree, leaves)
swizzle = _is_swizzled(transform)
x_smem = _handle_indexing(x_smem, transform)
x_aval, _ = ctx.avals_in
transforms = jax.tree.unflatten(tree, leaves)
swizzle = _is_swizzled(transforms)
x_smem = _handle_indexing(x_smem, transforms)
x_aval = ctx.avals_in[0]
if swizzle is None:
old_value = mgpu.FragmentedArray.load_strided(
x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype)
@@ -1178,9 +1179,15 @@ def _i64_constant(v: int) -> ir.Value:
return arith_dialect.constant(ir.IntegerType.get_signless(64), v)


def _as_index(v: int | ir.Value) -> ir.Value:
if isinstance(v, int):
return arith_dialect.constant(ir.IndexType.get(), v)
if ir.IndexType.isinstance(v.type):
return v
return arith_dialect.index_cast(ir.IndexType.get(), v)
def _as_index(v: object) -> ir.Value:
match v:
case int():
return arith_dialect.constant(ir.IndexType.get(), v)
case ir.Value() if ir.IndexType.isinstance(v.type):
return v
case ir.Value() if ir.IntegerType.isinstance(v.type):
return arith_dialect.index_cast(ir.IndexType.get(), v)
case mgpu.FragmentedArray(layout=mgpu.WGSplatFragLayout()):
return _as_index(v.registers.item())
case _:
raise ValueError(f"Unsupported index: {v}")
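
Taken together, _handle_indexing now builds a mixed tuple of Python slices and scalar indices and passes it to mgpu.memref_slice, with _as_index materializing each scalar as an MLIR index value. A plain-Python sketch of the tuple-building step, using toy stand-ins for indexing.Slice / NDIndexer and no MLIR at all:

import dataclasses


@dataclasses.dataclass(frozen=True)
class ToySlice:
  # Stand-in for indexing.Slice with a static start/size/stride.
  start: int
  size: int
  stride: int = 1


def toy_ndindexer_indices(indices):
  # Mirrors _ndindexer_indices: static slices become Python slice objects;
  # anything else (a scalar index) is passed through untouched, to be
  # converted by _as_index into an MLIR index value during lowering.
  out = []
  for idx in indices:
    if isinstance(idx, ToySlice):
      out.append(slice(idx.start, idx.start + idx.size, idx.stride))
    else:
      out.append(idx)
  return tuple(out)


# x_ref[idx] on a (4, 128) ref: a scalar row index plus a full column slice.
print(toy_ndindexer_indices((3, ToySlice(0, 128))))  # (3, slice(0, 128, 1))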
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic_gpu/primitives.py
@@ -77,7 +77,7 @@ def _extract_copy_params(transforms):
transforms = transforms[1:]
gpu_transforms = [t.to_gpu_transform() for t in transforms]
return dict(
gmem_slice=lowering._ndindexer_slices(indexer),
gmem_slice=lowering._ndindexer_indices(indexer),
gmem_transform=tuple(gpu_transforms),
swizzle=swizzle,
)
41 changes: 41 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -60,6 +60,18 @@ def kernel(x_ref, o_ref):
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), unary(x))

def test_add_first(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def kernel(x_ref, y_ref, o_ref):
o_ref[...] = x_ref[...] + y_ref[0]

x = jnp.arange(256).astype(jnp.float32)
y = jnp.flip(x).reshape(1, 256)
np.testing.assert_array_equal(kernel(x, y), x + y[0])

def test_add_xy(self):
@functools.partial(
pl.pallas_call,
@@ -72,6 +84,19 @@ def kernel(x_ref, y_ref, o_ref):
y = x + 1
np.testing.assert_array_equal(kernel(x, y), x + y)

def test_add_xy_indexed(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
)
def kernel(x_ref, y_ref, o_ref):
idx = jnp.sum(y_ref[...])
o_ref[...] = x_ref[idx]

x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
y = jnp.zeros(128, dtype=jnp.int32)
np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)])

def test_add_one_grid(self):
@functools.partial(
pl.pallas_call,
@@ -469,6 +494,22 @@ def kernel(o_ref):
kernel(), jnp.full([256], 5.0, dtype=jnp.float32)
)

def test_fori_loop_indexed_store(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32),
)
def kernel(x_ref, y_ref, o_ref):
def body(idx, _):
o_ref[idx] = x_ref[idx] + y_ref[idx]
return ()

jax.lax.fori_loop(0, 4, body, ())

x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
y = x + 1
np.testing.assert_array_equal(kernel(x, y), x + y)

def test_cond(self):

@functools.partial(
