Skip to content

Commit

Permalink
[pallas:mosaic_gpu] Added some runtime type checking to copy_* and …
Browse files Browse the repository at this point in the history
…`barrier_*` primitives

PiperOrigin-RevId: 710302436
  • Loading branch information
superbobry authored and Google-ML-Automation committed Dec 28, 2024
1 parent 7ab61b7 commit 76ccb19
Showing 1 changed file with 33 additions and 24 deletions.
57 changes: 33 additions & 24 deletions jax/_src/pallas/mosaic_gpu/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.pallas.mosaic_gpu import lowering
from jax._src.pallas.mosaic_gpu.core import state_types
from jax._src.state import discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
Expand All @@ -44,13 +45,30 @@
WARPGROUP_SIZE = 128


# Union of the reference types used to annotate the ``src``, ``dst`` and
# ``barrier`` parameters of the copy helpers in this module.
_Ref = pallas_core.AbstractMemoryRef | state_types.TransformedRef


def _check_ref(
    aval: object, name: str, memory_space: gpu_core.GPUMemorySpace
) -> None:
  """Validates that ``aval`` is a reference in the expected memory space.

  Args:
    aval: The abstract value to validate.
    name: The argument name to use in error messages.
    memory_space: The memory space the reference is required to be in.

  Raises:
    TypeError: If ``aval`` is not a ``state_types.AbstractRef``.
    ValueError: If the memory space of ``aval`` is not ``memory_space``.
  """
  if isinstance(aval, state_types.AbstractRef):
    # A missing (or falsy) ``memory_space`` attribute is treated as GMEM.
    actual_space = getattr(aval, "memory_space", None)
    if not actual_space:
      actual_space = gpu_core.GMEM
    if actual_space is memory_space:
      return
    raise ValueError(
        f"{name} must be a {memory_space.name.upper()} reference, got {aval}"
    )
  raise TypeError(f"{name} must be a reference, got {aval}")


# Primitive named "copy_smem_to_gmem"; flagged as producing multiple results.
copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem")
copy_smem_to_gmem_p.multiple_results = True


@copy_smem_to_gmem_p.def_effectful_abstract_eval
def _copy_smem_to_gmem_abstract_eval(*avals, **params):
del avals, params # Unused.
def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params):
  """Abstract evaluation rule for the ``copy_smem_to_gmem`` primitive.

  Checks that ``src`` is a SMEM reference and ``dst`` is a GMEM reference.

  Returns:
    A pair of an empty tuple of outputs and the effect set
    ``{state.ReadEffect(0), state.WriteEffect(1)}``.

  Raises:
    TypeError: If ``src`` or ``dst`` is not a reference.
    ValueError: If ``src`` or ``dst`` is in the wrong memory space.
  """
  del args, params  # Unused.
  # Validate ``src`` first so that its error takes precedence over ``dst``.
  _check_ref(src, "src", gpu_core.SMEM)
  _check_ref(dst, "dst", gpu_core.GMEM)
  effects = {state.ReadEffect(0), state.WriteEffect(1)}
  return (), effects


Expand Down Expand Up @@ -115,9 +133,7 @@ def _extract_smem_copy_params(transforms):


def copy_smem_to_gmem(
src: pallas_core.AbstractMemoryRef,
dst: pallas_core.AbstractMemoryRef,
predicate: jax.Array | None = None,
src: _Ref, dst: _Ref, predicate: jax.Array | None = None
) -> None:
"""Asynchronously copies a SMEM reference to a GMEM reference.
Expand All @@ -131,10 +147,6 @@ def copy_smem_to_gmem(
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
:func:`jax.experimental.mosaic.gpu.commit_smem`
"""
if src.memory_space is not gpu_core.SMEM:
raise TypeError(f"src must be a SMEM reference, got {src.memory_space}")
if getattr(dst, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}")
src, src_transforms = state_primitives.get_ref_and_transforms(
src, None, "copy_smem_to_gmem", force_trailing_indexer=False,
)
Expand Down Expand Up @@ -165,8 +177,11 @@ def copy_smem_to_gmem(


@copy_gmem_to_smem_p.def_effectful_abstract_eval
def _copy_gmem_to_smem_abstract_eval(*avals, **params):
del avals, params # Unused.
def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params):
  """Abstract evaluation rule for the ``copy_gmem_to_smem`` primitive.

  Checks that ``src`` is a GMEM reference and that both ``dst`` and
  ``barrier`` are SMEM references.

  Returns:
    A pair of an empty tuple of outputs and the effect set
    ``{state.ReadEffect(0), state.WriteEffect(1)}``.

  Raises:
    TypeError: If any of ``src``, ``dst`` or ``barrier`` is not a reference.
    ValueError: If any of them is in the wrong memory space.
  """
  # Operands are validated in positional order: src, dst, then barrier.
  _check_ref(src, "src", gpu_core.GMEM)
  _check_ref(dst, "dst", gpu_core.SMEM)
  _check_ref(barrier, "barrier", gpu_core.SMEM)
  del args, params  # Unused.
  effects = {state.ReadEffect(0), state.WriteEffect(1)}
  return (), effects


Expand Down Expand Up @@ -218,21 +233,13 @@ def _copy_gmem_to_smem_lowering(
return ()


def copy_gmem_to_smem(
src: pallas_core.AbstractMemoryRef,
dst: pallas_core.AbstractMemoryRef,
barrier: pallas_core.AbstractMemoryRef,
) -> None:
def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None:
"""Asynchronously copies a GMEM reference to a SMEM reference.
See also:
:func:`jax.experimental.mosaic.gpu.barrier_arrive`
:func:`jax.experimental.mosaic.gpu.barrier_wait`
"""
if getattr(src, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
raise TypeError(f"src must be a GMEM reference, got {src.memory_space}")
if dst.memory_space is not gpu_core.SMEM:
raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}")
src, src_transforms = state_primitives.get_ref_and_transforms(
src, None, "copy_gmem_to_smem", force_trailing_indexer=False,
)
Expand Down Expand Up @@ -292,8 +299,9 @@ def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None:


@barrier_arrive_p.def_effectful_abstract_eval
def _barrier_arrive_abstract_eval(*avals, **params):
del avals, params # Unused.
def _barrier_arrive_abstract_eval(barrier, *args, **params):
  """Abstract evaluation rule for the barrier-arrive primitive.

  Checks that ``barrier`` is a SMEM reference.

  Returns:
    A pair of an empty tuple of outputs and the effect set
    ``{gpu_core._memory_effect}``.

  Raises:
    TypeError: If ``barrier`` is not a reference.
    ValueError: If ``barrier`` is not in SMEM.
  """
  _check_ref(barrier, "barrier", gpu_core.SMEM)
  del args, params  # Only the barrier operand is inspected.
  no_outputs = ()
  return no_outputs, {gpu_core._memory_effect}


Expand Down Expand Up @@ -329,8 +337,9 @@ def barrier_arrive(barrier: pallas_core.AbstractMemoryRef) -> None:


@barrier_wait_p.def_effectful_abstract_eval
def _barrier_wait_abstract_eval(*avals, **params):
del avals, params # Unused.
def _barrier_wait_abstract_eval(barrier, *args, **params):
  """Abstract evaluation rule for the barrier-wait primitive.

  Checks that ``barrier`` is a SMEM reference.

  Returns:
    A pair of an empty tuple of outputs and the effect set
    ``{gpu_core._memory_effect}``.

  Raises:
    TypeError: If ``barrier`` is not a reference.
    ValueError: If ``barrier`` is not in SMEM.
  """
  del args, params  # Only the barrier operand is inspected.
  _check_ref(barrier, "barrier", gpu_core.SMEM)
  no_outputs = ()
  return no_outputs, {gpu_core._memory_effect}


Expand Down

0 comments on commit 76ccb19

Please sign in to comment.