From 0ee9531ef28c45046173c383ecd48d1de61b14ce Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 18 Oct 2024 04:54:59 -0700 Subject: [PATCH] [Pallas:MGPU] Add support for indexed refs to WGMMA PiperOrigin-RevId: 687258992 --- jax/_src/pallas/mosaic_gpu/primitives.py | 100 ++++++++++++----------- tests/pallas/mosaic_gpu_test.py | 37 ++++++++- 2 files changed, 88 insertions(+), 49 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index d6e4e01f1352..55e5b979d7c1 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -19,6 +19,7 @@ import enum from typing import Any, Literal +import jax from jax._src import core as jax_core from jax._src import effects from jax._src import state @@ -391,41 +392,25 @@ def wgmma( f" rhs={b.shape=}, acc={acc.shape}" ) - if (dtype := a.dtype) != b.dtype: + if a.dtype != b.dtype: raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}") - # Infer swizzle from a. - if not a.transforms or not isinstance( - (swizzle_transform := a.transforms[0]), gpu_core.UnswizzleRef - ): - raise ValueError("WGMMA lhs must be a tiled and swizzled reference.") + a_transforms_leaves, a_transforms_tree = jax.tree.flatten(a.transforms) + b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) - swizzle = swizzle_transform.swizzle - swizzle_elems = swizzle // dtype.itemsize - if a.transforms[1:] != (gpu_core.UntileRef((64, swizzle_elems)),): - raise ValueError( - f"WGMMA lhs must be tiled with 64x{swizzle_elems} tiles for element type" - f" {dtype}." - ) - - rhs_transpose_transform = gpu_core.TransposeRef((1, 0, 2, 3)) - rhs_tiling = gpu_core.UntileRef((swizzle_elems, swizzle_elems)) - if b.transforms == (swizzle_transform, rhs_tiling): - rhs_transpose = False - elif b.transforms == (swizzle_transform, rhs_transpose_transform, rhs_tiling): - rhs_transpose = True - else: - raise ValueError( - f"WGMMA rhs must have {swizzle=} and be tiled with" - f" {swizzle_elems}x{swizzle_elems} tiles for element type {dtype} (and" - " optionally transposed)." - ) - - wgmma_ref_p.bind(acc, a.ref, b.ref, swizzle=swizzle, rhs_transpose=rhs_transpose) + wgmma_ref_p.bind( + acc, + a.ref, + b.ref, + *a_transforms_leaves, + *b_transforms_leaves, + a_transforms_tree=a_transforms_tree, + b_transforms_tree=b_transforms_tree, + ) @wgmma_ref_p.def_effectful_abstract_eval -def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, **params): +def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, *_, **params): del a_aval, b_aval, params if not isinstance(acc_aval, gpu_core.WGMMAAbstractAccumulatorRef): raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc_aval}") @@ -439,23 +424,9 @@ def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, **params): @discharge.register_discharge_rule(wgmma_ref_p) -def _wgmma_ref_discharge( - in_avals, - out_avals, - acc, - a, - b, - swizzle, - rhs_transpose, -): +def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs): del in_avals, out_avals - return ( - wgmma_p.bind( - acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose - ), - None, - None, - ), [] + return (wgmma_p.bind(*args, **kwargs), *([None] * (len(args) - 1))), [] # Functional WGMMA, returns a shaped array. Internal. @@ -468,10 +439,43 @@ def _wgmma_lowering( acc, a, b, - swizzle, - rhs_transpose, + *transforms_leaves, + a_transforms_tree, + b_transforms_tree, ): - del ctx + _, a_aval, *_ = ctx.avals_in + a_transforms_leaves, b_transforms_leaves = util.split_list( + transforms_leaves, [a_transforms_tree.num_leaves] + ) + a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) + b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) + + a, a_transforms = lowering._handle_indexing(a, a_transforms) + b, b_transforms = lowering._handle_indexing(b, b_transforms) + + match a_transforms: + case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): + swizzle_elems = swizzle // a_aval.dtype.itemsize + if tiling != (64, swizzle_elems): + raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") + case _: + raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.") + + match b_transforms: + case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)): + rhs_transpose = False + # TODO(apaszke): Actually what we really want to test here is that we're + # only doing transposes within the tiles! + case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.TransposeRef((1, 0, 2, 3)), gpu_core.UntileRef(rhs_tiling)): + rhs_transpose = True + case _: + raise ValueError(f"WGMMA rhs has unsupported transforms: {b_transforms}.") + + if rhs_swizzle != swizzle: + raise NotImplementedError("WGMMA rhs swizzle must match lhs swizzle") + if rhs_tiling != (swizzle_elems, swizzle_elems): + raise NotImplementedError("WGMMA rhs tiling does not fit swizzle") + new_acc = mgpu.wgmma( acc, a, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index bb9789555477..6ab4bf8df145 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -677,7 +677,42 @@ def scope(acc_ref): res, a @ (b.T if rhs_transpose else b), rtol=1e-3 ) - def test_wgmma_sliced(self): + def test_wgmma_sliced_ref(self): + def kernel(a_ref, b_ref, o_ref): + def scope(acc_ref): + plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0]) + return acc_ref[...] + + o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32)) + + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(2, 64, 128), dtype=jnp.float16) + b = jax.random.uniform(key2, shape=(2, 128, 192), dtype=jnp.float16) + + res = pl.pallas_call( + kernel, + in_specs=[ + plgpu.GPUBlockSpec( + (2, 64, 128), lambda: (0, 0, 0), + transforms=( + plgpu.TilingTransform((64, 64)), + plgpu.SwizzleTransform(128), + ), + ), + plgpu.GPUBlockSpec( + (2, 128, 192), lambda: (0, 0, 0), + transforms=( + plgpu.TilingTransform((64, 64)), + plgpu.SwizzleTransform(128), + ), + ), + ], + out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), + out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), + )(a, b) + np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3) + + def test_wgmma_sliced_acc(self): swizzle = 128 elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize def kernel(a_ref, b_ref, o_ref):