Skip to content

Commit d2bf034

Browse files
[Mosaic GPU] Test the wgmma_op lowering when a is in registers.
I had to add support for wgmma layout in vector_load. Not sure if this is useful outside the test.

PiperOrigin-RevId: 735384104
1 parent 5a7ef40 commit d2bf034

File tree

2 files changed

+41
-12
lines changed


jax/experimental/mosaic/gpu/dialect_lowering.py

+23-9
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,6 @@ def _vector_load_op_lowering_rule(
235235
ir.ArrayAttr, vector_load_op.attributes["out_layouts"]
236236
)
237237

238-
if not layouts.is_strided_fragmented_layout(out_layout_attr):
239-
raise ValueError(
240-
f"{vector_load_op} has an unsupported layout: {out_layout_attr}"
241-
)
242-
243238
for i in vector_load_op.indices:
244239
index_defining_op = i.owner.opview
245240
if (
@@ -254,10 +249,29 @@ def _vector_load_op_lowering_rule(
254249

255250
element_type = vector_load_op.result.type.element_type
256251
is_signed = False if ir.IntegerType.isinstance(element_type) else None
257-
strided_layout = layouts.from_strided_fragmented_layout_attr(out_layout_attr)
258-
fragmented_array = fa.FragmentedArray.load_strided(
259-
vector_load_op.base, is_signed=is_signed, vec_size=strided_layout.vec_size
260-
)
252+
253+
if layouts.is_strided_fragmented_layout(out_layout_attr):
254+
strided_layout = layouts.from_strided_fragmented_layout_attr(
255+
out_layout_attr
256+
)
257+
fragmented_array = fa.FragmentedArray.load_strided(
258+
vector_load_op.base,
259+
is_signed=is_signed,
260+
vec_size=strided_layout.vec_size,
261+
)
262+
elif layouts.is_wgmma_fragmented_layout(out_layout_attr):
263+
layout = ir.MemRefType(vector_load_op.base.type).layout
264+
swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout)
265+
transformed_ref = transform_memref(vector_load_op.base, transforms)
266+
fragmented_array = fa.FragmentedArray.load_tiled(
267+
transformed_ref,
268+
swizzle=swizzle,
269+
is_signed=is_signed
270+
)
271+
else:
272+
raise ValueError(
273+
f"{vector_load_op} has an unsupported layout: {out_layout_attr}"
274+
)
261275
return [_fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)]
262276

263277

tests/mosaic/gpu_test.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -2755,6 +2755,7 @@ class TestCaseInput:
27552755
transforms_b: tuple[Tile | Transpose | Swizzle, ...] = ()
27562756
transpose_a: bool = False
27572757
transpose_b: bool = False
2758+
load_a_in_registers: bool = False
27582759

27592760
result = []
27602761
for swizzle in [
@@ -2786,6 +2787,13 @@ class TestCaseInput:
27862787
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
27872788
transforms_b=[Tile([k, k]), Swizzle(swizzle)],
27882789
),
2790+
TestCaseInput(
2791+
shape_a=[groups_m * 64, groups_k * k],
2792+
shape_b=[groups_k * k, groups_n * k],
2793+
shape_res=[groups_m * 64, groups_n * k],
2794+
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
2795+
load_a_in_registers=True,
2796+
),
27892797
])
27902798
# The below only works for 128-byte swizzling. Regardless of transposing,
27912799
# TMA needs the size of the last dimension to be compatible with the
@@ -2849,6 +2857,14 @@ def matmul(
28492857
parity, _ = tma_barrier.update_parities(parities)
28502858
mgpu_dialect.wait(dialect_barrier, parity)
28512859

2860+
# SMEM -> Registers
2861+
a_operand = a_smem_ref
2862+
zero_index = arith.constant(ir.IndexType.get(), 0)
2863+
if test_case.load_a_in_registers:
2864+
a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type)
2865+
zero_vector_indices = [zero_index] * len(test_case.shape_a)
2866+
a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices)
2867+
28522868
# Computation
28532869
shape_result = ir.MemRefType(result_gmem_ref.type).shape
28542870
result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type
@@ -2860,7 +2876,7 @@ def matmul(
28602876
)
28612877
result = mgpu_dialect.wgmma(
28622878
accumulator,
2863-
a_smem_ref,
2879+
a_operand,
28642880
b_smem_ref,
28652881
transpose_a=test_case.transpose_a,
28662882
transpose_b=test_case.transpose_b,
@@ -2870,8 +2886,7 @@ def matmul(
28702886
nvvm.wgmma_wait_group_sync_aligned(0)
28712887

28722888
# Registers -> SMEM
2873-
zero_index = arith.constant(ir.IndexType.get(), 0)
2874-
vector.store(result, result_smem_ref, [zero_index, zero_index])
2889+
vector.store(result, result_smem_ref, [zero_index] * len(shape_result))
28752890

28762891
# SMEM -> GMEM
28772892
mgpu_dialect.async_store(

0 commit comments

Comments (0)