[Pallas] Implement tiled and swizzled Memref loads for Mosaic GPU via "GPUBlockSpec"

PiperOrigin-RevId: 673165201
justinjfu authored and jax authors committed Sep 11, 2024
1 parent c659dc9 commit e3c4b20
Showing 5 changed files with 264 additions and 86 deletions.
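
For orientation, the sketch below shows how a kernel author might use the new GPUBlockSpec. It is an illustrative sketch only: the import paths, the keyword-style constructor arguments, the concrete shapes, and the omission of Mosaic GPU backend/compiler-parameter selection are assumptions, not part of this commit; only the tiling/swizzle fields and the tiled ref shape follow from the diffs below.

# Illustrative usage sketch; constructor arguments, import paths, and backend
# selection are assumed rather than taken from this commit.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def copy_kernel(x_ref, o_ref):
  # With block_shape=(128, 64) and tiling=(64, 64), the input ref is presented
  # to the kernel as a (2, 1, 64, 64) ref (see TilingTransform below).
  o_ref[...] = x_ref[...]


# Note: selection of the Mosaic GPU backend / compiler params is omitted here.
copy = pl.pallas_call(
    copy_kernel,
    in_specs=[
        plgpu.GPUBlockSpec(
            block_shape=(128, 64),
            index_map=lambda: (0, 0),
            tiling=(64, 64),
            swizzle=128,  # Assumed to be the swizzle width in bytes.
        )
    ],
    out_specs=pl.BlockSpec((2, 1, 64, 64), lambda: (0, 0, 0, 0)),
    out_shape=jax.ShapeDtypeStruct((2, 1, 64, 64), jnp.float16),
)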
205 changes: 128 additions & 77 deletions jax/_src/pallas/core.py
@@ -318,6 +318,105 @@ def __init__(
self.memory_space = memory_space
self.indexing_mode = indexing_mode

def to_block_mapping(
self,
origin: OriginStr,
array_aval: jax_core.ShapedArray,
*,
# Inputs for the index_map
index_map_avals: Sequence[jax_core.AbstractValue],
index_map_tree: tree_util.PyTreeDef,
grid: GridMappingGrid,
mapped_dims: tuple[int, ...],
) -> BlockMapping:
if self.index_map is None:
index_map_func = lambda *args: (0,) * len(array_aval.shape)
else:
index_map_func = self.index_map
if self.block_shape is None:
block_shape = array_aval.shape
else:
block_shape = self.block_shape
if len(array_aval.shape) != len(block_shape):
raise ValueError(
f"Block shape for {origin} (= {block_shape}) "
"must have the same number of dimensions as the "
f"array shape {array_aval.shape}."
)

unmapped_block_shape = tuple(s for s in block_shape if s is not None)
block_array_aval = array_aval.update(shape=unmapped_block_shape)
if isinstance(array_aval, jax_core.DShapedArray):
# Get the "max" shape for the ragged array.
block_array_aval = jax_core.ShapedArray(
block_array_aval.shape,
block_array_aval.dtype,
block_array_aval.weak_type,
)
block_aval = AbstractMemoryRef(block_array_aval, self.memory_space)

if not jax_core.is_constant_shape(block_aval.shape):
raise ValueError(
"shape polymorphism for Pallas does not support "
"dynamically-shaped blocks. "
f"Block spec for {origin} has block_shape: {block_aval.shape}"
)

flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(index_map_func), index_map_tree
)
debug = pe.debug_info(
index_map_func,
index_map_tree,
index_map_out_tree_thunk,
False,
"pallas_call index_map",
)
index_map_src_info = NameAndSrcInfo.from_pallas_call(
None, debug.func_src_info
)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
flat_index_map_fun, index_map_avals, debug_info=debug
)
mapped_block_shape = tuple(mapped if s is None else s for s in block_shape)
if len(out_avals) != len(block_shape):
raise ValueError(
f"Index map function {index_map_src_info} for "
f"{origin} must return "
f"{len(block_shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values."
)
for i, ov in enumerate(out_avals):
if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]:
raise ValueError(
f"Index map function {index_map_src_info} for "
f"{origin} must return integer scalars. Output[{i}] has type "
f"{ov}."
)

if consts:
raise ValueError(
f"Index map function {index_map_src_info} for "
f"{origin} must not capture constants: {consts}"
)

array_aval_shape = _max_shape_from_aval(array_aval)

mapping = BlockMapping(
block_shape=mapped_block_shape,
block_aval=block_aval,
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
index_map_src_info=index_map_src_info,
indexing_mode=self.indexing_mode,
array_shape_dtype=jax.ShapeDtypeStruct(
array_aval_shape, array_aval.dtype
),
origin=origin,
)
mapping.check_invariants()
return mapping


class NoBlockSpec:
def __repr__(self):
@@ -329,6 +428,15 @@ def __repr__(self):
# BlockSpecTree = Sequence[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
BlockSpecTree = Any


class MemrefTransform(Protocol):
"""Represents a transformation applied to a Memref on load or store."""

def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef:
"""Returns the transformed aval given an input aval."""
raise NotImplementedError("Abstract evaluation not implemented.")


@dataclasses.dataclass(frozen=True)
class BlockMapping:
"""An internal canonicalized version of BlockSpec.
@@ -342,6 +450,9 @@ class BlockMapping:
indexing_mode: IndexingMode
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
origin: OriginStr
transforms: Sequence[MemrefTransform] = dataclasses.field(
default_factory=tuple
)

def check_invariants(self) -> None:
if not config.enable_checks.value: return
@@ -368,6 +479,14 @@ def replace(self, **kwargs):
new_self.check_invariants()
return new_self

@property
def ref_aval(self) -> AbstractMemoryRef:
"""Returns the abstract value of the Ref after transformations."""
block_aval = self.block_aval
for transform in self.transforms:
block_aval = transform(block_aval)
return block_aval

def compute_start_indices_interpret(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts
@@ -603,82 +722,14 @@ def _convert_block_spec_to_block_mapping(
) -> BlockMapping:
if block_spec is no_block_spec:
block_spec = BlockSpec(None, None)
if block_spec.index_map is None:
index_map_func = lambda *args: (0,) * len(array_aval.shape)
else:
index_map_func = block_spec.index_map
if block_spec.block_shape is None:
block_shape = array_aval.shape
else:
block_shape = block_spec.block_shape
if len(array_aval.shape) != len(block_shape):
raise ValueError(
f"Block shape for {origin} (= {block_shape}) "
"must have the same number of dimensions as the "
f"array shape {array_aval.shape}.")

unmapped_block_shape = tuple(s for s in block_shape if s is not None)
block_array_aval = array_aval.update(shape=unmapped_block_shape)
if isinstance(array_aval, jax_core.DShapedArray):
# Get the "max" shape for the ragged array.
block_array_aval = jax_core.ShapedArray(
block_array_aval.shape,
block_array_aval.dtype,
block_array_aval.weak_type,
)
block_aval = AbstractMemoryRef(block_array_aval, block_spec.memory_space)

if not jax_core.is_constant_shape(block_aval.shape):
raise ValueError(
"shape polymorphism for Pallas does not support "
"dynamically-shaped blocks. "
f"Block spec for {origin} has block_shape: {block_aval.shape}")

flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(index_map_func), index_map_tree)
debug = pe.debug_info(index_map_func, index_map_tree, index_map_out_tree_thunk,
False, "pallas_call index_map")
index_map_src_info = NameAndSrcInfo.from_pallas_call(None,
debug.func_src_info)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun,
index_map_avals,
debug_info=debug)
mapped_block_shape = tuple(
mapped if s is None else s for s in block_shape)
if len(out_avals) != len(block_shape):
raise ValueError(
f"Index map function {index_map_src_info} for "
f"{origin} must return "
f"{len(block_shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values.")
for i, ov in enumerate(out_avals):
if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]:
raise ValueError(
f"Index map function {index_map_src_info} for "
f"{origin} must return integer scalars. Output[{i}] has type "
f"{ov}.")

if consts:
raise ValueError(
f"Index map function {index_map_src_info} for "
f"{origin} must not capture constants: {consts}")

array_aval_shape = _max_shape_from_aval(array_aval)

mapping = BlockMapping(
block_shape=mapped_block_shape,
block_aval=block_aval,
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
index_map_src_info=index_map_src_info,
indexing_mode=block_spec.indexing_mode,
array_shape_dtype=jax.ShapeDtypeStruct(
array_aval_shape, array_aval.dtype
),
origin=origin,
return block_spec.to_block_mapping(
origin,
array_aval,
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
grid=grid,
mapped_dims=mapped_dims,
)
mapping.check_invariants()
return mapping

index_map_grid_aval = jax_core.ShapedArray((), jnp.int32)

@@ -846,11 +897,11 @@ def get_grid_mapping(
num_scratch_operands=num_flat_scratch_operands,
)
grid_mapping.check_invariants()
in_ref_avals = [bm.block_aval for bm in in_block_mappings]
in_ref_avals = [bm.ref_aval for bm in in_block_mappings]
jaxpr_in_ref_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
jaxpr_in_avals = (*jaxpr_scalar_ref_avals,
*jaxpr_in_ref_avals)
out_ref_avals = [bm.block_aval for bm in out_block_mappings]
out_ref_avals = [bm.ref_aval for bm in out_block_mappings]
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
if not isinstance(jaxpr_out_avals, (tuple, list)):
jaxpr_out_avals = (jaxpr_out_avals,)
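
The core.py hunks above add a MemrefTransform protocol, a transforms field on BlockMapping, and a ref_aval property that applies each transform in order to block_aval, so the kernel sees the transformed ref. The standalone toy sketch below mirrors only that folding pattern; ToyRef and ToyTranspose are hypothetical stand-ins for illustration (the real protocol maps AbstractMemoryRef to AbstractMemoryRef) and are not JAX classes.

# Minimal stand-alone sketch of the MemrefTransform composition pattern.
# ToyRef / ToyTranspose are hypothetical; the real protocol operates on
# AbstractMemoryRef (see MemrefTransform above).
import dataclasses
from typing import Protocol, Sequence


@dataclasses.dataclass(frozen=True)
class ToyRef:
  shape: tuple[int, ...]


class ToyTransform(Protocol):
  def __call__(self, ref: ToyRef) -> ToyRef:
    ...


@dataclasses.dataclass(frozen=True)
class ToyTranspose(ToyTransform):
  perm: tuple[int, ...]

  def __call__(self, ref: ToyRef) -> ToyRef:
    return ToyRef(tuple(ref.shape[i] for i in self.perm))


def ref_after_transforms(block: ToyRef,
                         transforms: Sequence[ToyTransform]) -> ToyRef:
  # Mirrors BlockMapping.ref_aval: fold each transform over the block aval.
  for transform in transforms:
    block = transform(block)
  return block


print(ref_after_transforms(ToyRef((8, 128)), [ToyTranspose((1, 0))]))
# ToyRef(shape=(128, 8))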
5 changes: 4 additions & 1 deletion jax/_src/pallas/mosaic_gpu/BUILD
@@ -71,6 +71,9 @@ pytype_strict_library(
srcs = ["core.py"],
deps = [
"//jax",
"//jax:core",
"//jax:mosaic_gpu",
"//jax:tree_util",
"//jax/_src/pallas",
],
] + py_deps("numpy"),
)
86 changes: 86 additions & 0 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -19,9 +19,13 @@
import enum
from typing import ClassVar, Literal
from jax import core as jax_core
from jax._src import core
from jax._src import tree_util
from jax._src.pallas import core as pallas_core
from jax.experimental.mosaic import gpu as mosaic_gpu
import jax.numpy as jnp


AbstractMemoryRef = pallas_core.AbstractMemoryRef


@@ -55,6 +59,88 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
return MemoryRef(shape, dtype, self)


class TilingTransform(pallas_core.MemrefTransform):
"""Represents a tiling transformation for Memrefs.
A tiling of (X, Y) on an array of shape (M, N) will result in a transformed
shape of (M // X, N // Y, X, Y). Ex. A (256, 256) block that is tiled with a
tiling of (64, 32) will be tiled as (4, 8, 64, 32).
"""

def __init__(self, tiling: tuple[int, ...]):
self.tiling = tiling

def __call__(
self, block_aval: pallas_core.AbstractMemoryRef
) -> pallas_core.AbstractMemoryRef:
block_shape = block_aval.inner_aval.shape # pytype: disable=attribute-error
old_tiled_dims = block_shape[-len(self.tiling) :]
num_tiles = tuple(
block_dim // tiling_dim
for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
)
rem = (
block_dim % tiling_dim
for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
)
if any(rem):
raise ValueError(
f"Block shape {block_shape} is not divisible by tiling {self.tiling}"
)
new_block_shape = block_shape[: -len(self.tiling)] + num_tiles + self.tiling
return block_aval.update(
inner_aval=block_aval.inner_aval.update(shape=new_block_shape)
)

def to_gpu_transform(self) -> mosaic_gpu.MemRefTransform:
return mosaic_gpu.TileTransform(self.tiling)


@dataclasses.dataclass(frozen=True)
class GPUBlockMapping(pallas_core.BlockMapping):
swizzle: int | None = None


@dataclasses.dataclass
class GPUBlockSpec(pallas_core.BlockSpec):
# TODO(justinfu): Replace tiling with a list of transforms.
tiling: tuple[int, ...] | None = None
swizzle: int | None = None

def to_block_mapping(
self,
origin: pallas_core.OriginStr,
array_aval: core.ShapedArray,
*,
index_map_avals: Sequence[core.AbstractValue],
index_map_tree: tree_util.PyTreeDef,
grid: pallas_core.GridMappingGrid,
mapped_dims: tuple[int, ...],
) -> GPUBlockMapping:
bm = super().to_block_mapping(
origin,
array_aval,
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
grid=grid,
mapped_dims=mapped_dims,
)
transforms: tuple[pallas_core.MemrefTransform, ...] = ()
if self.tiling is not None:
transforms += (TilingTransform(self.tiling),)
return GPUBlockMapping(
block_shape=bm.block_shape,
block_aval=bm.block_aval,
origin=bm.origin,
index_map_jaxpr=bm.index_map_jaxpr,
index_map_src_info=bm.index_map_src_info,
indexing_mode=bm.indexing_mode,
array_shape_dtype=bm.array_shape_dtype,
transforms=transforms,
swizzle=self.swizzle,
)


# TODO(b/354568887): Consolidate this with TPU's MemoryRef.
@dataclasses.dataclass(frozen=True)
class MemoryRef:
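
As a quick sanity check of the shape rule stated in the TilingTransform docstring above, the plain-Python snippet below recomputes the documented example; it is illustrative arithmetic only and does not depend on the new classes.

# Recompute the TilingTransform docstring example: a (256, 256) block tiled
# with (64, 32) becomes (256 // 64, 256 // 32, 64, 32) == (4, 8, 64, 32).
def tiled_shape(block_shape, tiling):
  lead, tiled = block_shape[:-len(tiling)], block_shape[-len(tiling):]
  if any(d % t for d, t in zip(tiled, tiling)):
    raise ValueError(f"{block_shape} is not divisible by {tiling}")
  return lead + tuple(d // t for d, t in zip(tiled, tiling)) + tuple(tiling)


assert tiled_shape((256, 256), (64, 32)) == (4, 8, 64, 32)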