
Commit 30a9e1b

apaszke authored and Google-ML-Automation committed on Mar 11, 2025
[Mosaic GPU] Add support for .cta_group::2 MMA with n=512 on Blackwell
This one is particularly annoying, because we have to break up the MMA into two collective N=256 MMAs. However, TensorCore only updates a contiguous chunk of columns in TMEM, and so after executing two of those we end up with a TMEM layout that looks like this:

```
Contributing CTA |   0   |    1    |    0    |    1    |
N local          | 0:128 |   0:128 | 128:256 | 128:256 |
N                | 0:128 | 256:384 | 128:256 | 384:512 |
```

You can see that the TMEM columns no longer go monotonically over all columns up to N=512, but instead contain a number of jumps!

We could fix this on the load side, by ensuring that each CTA in the group does a strided load along the tiled dimension, but that just seems like more trouble than it's worth (and is not that well supported by TMA unless we increase the number of striding levels). Instead, we encode this weirdness in the TMEM layout we use and make sure to rearrange the data properly while loading the tiles into registers.

PiperOrigin-RevId: 735791426
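For illustration, here is a minimal Python sketch (not part of the commit) that reproduces the table above: assuming two CTAs, 128-column TMEM tiles, and a per-CTA N of 256, it maps each TMEM column tile to the CTA that wrote it and to the local and global N ranges it holds. The names `tmem_tile_to_n`, `cols_per_tile`, and `n_per_cta` are hypothetical, not identifiers from the commit.

```python
# Illustrative sketch only: which logical N columns each 128-column TMEM tile
# holds after two collective N=256 MMAs with .cta_group::2.
def tmem_tile_to_n(tile: int, cols_per_tile: int = 128, n_per_cta: int = 256):
  cta = tile % 2                           # contributing CTAs alternate 0, 1, 0, 1
  n_local = (tile // 2) * cols_per_tile    # each MMA fills a contiguous column chunk
  n_global = cta * n_per_cta + n_local     # global N offset of this tile
  return cta, (n_local, n_local + cols_per_tile), (n_global, n_global + cols_per_tile)

for t in range(4):
  print(t, tmem_tile_to_n(t))
# 0 (0, (0, 128), (0, 128))
# 1 (1, (0, 128), (256, 384))
# 2 (0, (128, 256), (128, 256))
# 3 (1, (128, 256), (384, 512))
```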
1 parent 1aca76f commit 30a9e1b

5 files changed: +115 −51 lines changed
 

jax/experimental/mosaic/gpu/core.py (+1 −1)

@@ -260,7 +260,7 @@ def ref(member_thunks=member_thunks):
           dynamic_smem, c(dynamic_smem_offset, index), [],
       )
       if layout is None:
-        layout = tcgen05._infer_tmem_layout(shape)
+        layout = tcgen05._infer_tmem_layout(shape, collective)
       num_cols = layout.cols_in_shape(shape)
       delayed_warp_init.append(
           functools.partial(

jax/experimental/mosaic/gpu/examples/matmul_blackwell.py (+2 −2)

@@ -230,8 +230,8 @@ def main(unused_argv):
       tile_n *= 2
     if m < tile_m or n < tile_n:
       continue
-    if kwargs["collective"] and tile_n >= 512:
-      continue  # TODO(apaszke): Support 512
+    if tile_n > 512:
+      continue
     if (m // tile_m) % kwargs["grid_tile_m"]:
       continue
     try:

jax/experimental/mosaic/gpu/tcgen05.py (+107 −47)

@@ -83,6 +83,7 @@ def mma(
     accumulate: ir.Value | bool = True,
     collective: bool = False,
 ):
+  i32 = ir.IntegerType.get_signless(32)
   i64 = ir.IntegerType.get_signless(64)
   if isinstance(accumulate, bool):
     accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
@@ -112,6 +113,10 @@ def mma(
     raise ValueError(
         f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}"
     )
+  if d.layout != (expected_layout := _infer_tmem_layout(d.shape, collective)):
+    raise ValueError(
+        f"Accumulator layout mismatch: expected {expected_layout}, got {d.layout}"
+    )
   f32 = ir.F32Type.get()
   if element_type == f32 or element_type == ir.BF16Type.get():
     if d.dtype != f32:
@@ -136,11 +141,7 @@ def mma(
     raise ValueError(f"N must be a multiple of 8, got: {n}")
   elif n > 256 and n != 512:
     raise ValueError("Only N below 256 or N=512 are supported")
-  if num_cta == 2 and n > 256:
-    raise NotImplementedError(
-        "N is too big for collective MMA. Only up to 256 is supported."
-    )
-  n_group_elems = min(n, 256)
+  n_group_elems = min(n, 256 // num_cta)
   if m % m_group_elems:
     raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
   if k % k_group_elems:
@@ -179,6 +180,7 @@ def mma(
 
   # Step 4. Issue the instructions.
   true = arith.constant(ir.IntegerType.get_signless(1), 1)
+  n_collective_group_elems = n_group_elems * num_cta
   for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups):
     a_offset = mi * a_m_group_stride + ki * a_k_group_stride
     a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64))
@@ -188,9 +190,9 @@ def mma(
       raise NotImplementedError("D needs to be sliced")
     acc = accumulate if ki == 0 else true
     _do_mma(
-        d.slice(
-            slice(None), utils.ds(ni * n_group_elems, n_group_elems)
-        ).address,
+        arith.addi(
+            d.address, arith.constant(i32, ni * n_collective_group_elems)
+        ),
         a_mk,
         b_nk,
         d_type=ir.F32Type.get(),
@@ -377,8 +379,15 @@ class TMEMLayout:
   +------------------+------------------+
   | [0:64, 64:128]   | [64:128, 64:128] |
   +------------------+------------------+
+
+  The above is further complicated by column_tile_stride, which is used to
+  swizzle the ordering of column tiles. That is, if column_tile_stride is 2,
+  we will first lay out all tiles that have the column index 0, 2, 4, and so on
+  until we run out of tiles. Only then we lay out the tiles with column index
+  1, 3, etc.
   """
   elements_in_tile: tuple[int, int]
+  column_tile_stride: int = 1
 
   def __post_init__(self):
     row_tiling = self.elements_in_tile[0]
@@ -405,7 +414,7 @@ def cols_in_shape(self, shape: tuple[int, int]):
     return num_tiles // tiles_in_row * cols_in_tile
 
 
-def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout:
+def _infer_tmem_layout(shape: tuple[int, int], collective: bool) -> TMEMLayout:
   if shape[0] > TMEM_ROWS:
     raise ValueError(
         "Can only infer TMEM layout for shapes with at most 128 rows, got:"
@@ -421,7 +430,15 @@ def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout:
         "Can only infer TMEM layout for shapes with row count that's a power of"
         f" 2, got: {shape[0]}"
     )
-  return TMEMLayout(elements_in_tile=(shape[0], 1))
+  if shape[1] % 8:
+    raise ValueError(
+        "Can only infer TMEM layout for shapes with column count that's a"
+        f" multiple of 8, got: {shape[1]}"
+    )
+  if collective and shape[1] == 512:
+    return TMEMLayout(elements_in_tile=(shape[0], 128), column_tile_stride=2)
+  else:
+    return TMEMLayout(elements_in_tile=(shape[0], 8))
 
 
 @dataclasses.dataclass(frozen=True)
@@ -432,7 +449,14 @@ class TMEMRef:
   layout: TMEMLayout
 
   @classmethod
-  def from_alloc(cls, tmem_addr_ref: ir.Value, shape: tuple[int, int], dtype, layout: TMEMLayout | None = None):
+  def from_alloc(
+      cls,
+      tmem_addr_ref: ir.Value,
+      shape: tuple[int, int],
+      dtype,
+      collective: bool | None = None,
+      layout: TMEMLayout | None = None,
+  ):
     i32 = ir.IntegerType.get_signless(32)
     if not ir.MemRefType.isinstance(tmem_addr_ref.type):
       raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}")
@@ -449,7 +473,11 @@ def from_alloc(cls, tmem_addr_ref: ir.Value, shape: tuple[int, int], dtype, layo
     if shape[0] < 32:
       raise ValueError(f"TMEM refs must have at least 32 rows, got: {shape[0]}")
     if layout is None:
-      layout = _infer_tmem_layout(shape)
+      if collective is None:
+        raise ValueError(
+            "collective argument must be provided when TMEM layout is inferred"
+        )
+      layout = _infer_tmem_layout(shape, collective)
     else:
       layout.check_shape(shape)
     # TODO: Do we have to do this??
@@ -461,12 +489,17 @@ def slice(self, *idxs):
     base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
     if any(is_squeezed):
       raise ValueError("TMEM can only be sliced, not indexed")
-    if self.layout.elements_in_tile[0] != TMEM_ROWS:
+    if self.layout != TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)):
       raise NotImplementedError(
-          f"Slicing only implemented for refs with tiling of {TMEM_ROWS} rows"
+          "Slicing only implemented for refs with standard layout, got:"
+          f" {self.layout}"
       )
     if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS:
       raise NotImplementedError("TMEM cannot be sliced along rows")
+    if slice_shape[1] % 8:
+      raise NotImplementedError(
+          "TMEM column slice length must be a multiple of 8"
+      )
     col_idx = base_idx[1]
     if not isinstance(col_idx, ir.Value):
       col_idx = arith.constant(ir.IntegerType.get_signless(32), col_idx)
@@ -484,48 +517,75 @@ def __getitem__(self, *idxs):
       raise ValueError("TMEM loads only support slicing")
     if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape:
       raise NotImplementedError("Slicing of TMEM not impelmented yet")
-    if self.layout.elements_in_tile[0] != TMEM_ROWS:
-      raise NotImplementedError(
-          f"Loads only implemented for refs with tiling of {TMEM_ROWS} rows"
-      )
     if self.shape[1] % 8:
       raise NotImplementedError
     if self.dtype != ir.F32Type.get():
       raise NotImplementedError(self.dtype)
     layout = _m128_256bit_32bit_layout(self.shape)
     regs_shape = layout.registers_shape(self.shape)
-    num = self.shape[1] // 8
-    # TODO(apaszke): Make the tiling configurable through the args too.
-    if num <= 32:
-      num_tiling = num
-    elif num == 64:
-      num_tiling = 32
-    else:
-      raise NotImplementedError(num)
-    registers = np.empty(regs_shape, dtype=object)
-    # We load 16 lanes at a time, but need 32 in total.
-    for row_group in range(2):
-      addr_row = arith.addi(self.address, arith.constant(i32, (row_group * 16) << 16))
-      regs = []
-      cols_per_num_tile = 8  # This depends on the 16x256b below.
-      for num_group in range(num // num_tiling):
-        addr_row_col = arith.addi(
-            addr_row,
-            arith.constant(i32, num_tiling * num_group * cols_per_num_tile),
+    if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)):
+      # load_32xcols returns a 4xN array, but the FA tiling we use here tiles
+      # columns before rows, and so it is Nx4 (after ignoring all 1 dims).
+      registers = _load_32xcols(
+          self.address, self.shape[1], self.dtype
+      ).T.reshape(regs_shape)
+    elif self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 128), column_tile_stride=2):
+      if self.shape[1] % 128 != 0:
+        raise ValueError(
+            f"TMEM layout {self.layout} is not compatible with shape {self.shape}"
         )
-        regs += tmem_load(addr_row_col, "16x256b", num_tiling)
-      regs = [llvm.bitcast(self.dtype, r) for r in regs]
-      vector_regs = []
-      undef = llvm.mlir_undef(ir.VectorType.get((2,), self.dtype))
-      for r_low, r_high in zip(regs[::2], regs[1::2]):
-        high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32))
-        vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32))
-        vector_regs.append(vreg)
-      # Dimension 4 is the one where we split 32 rows into tiles of 8.
-      regs_slice = (slice(None),) * 4 + (slice(row_group * 2, (row_group + 1) * 2),)
-      registers[regs_slice] = np.asarray(vector_regs, dtype=object).reshape(registers[regs_slice].shape)
+      num_column_tiles = self.shape[1] // 128
+      column_tile_stride = self.layout.column_tile_stride
+      num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride)
+      tiles = []
+      for col_tile_base in range(num_strided_col_groups):
+        for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride):
+          tiles.append(
+              _load_32xcols(
+                  arith.addi(self.address, arith.constant(i32, col_tile * 128)),
+                  cols=128,
+                  dtype=self.dtype,
+              )
+          )
+      registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape)
+    else:
+      raise NotImplementedError(
+          f"Loads only implemented for refs with standard layout, got: {self.layout}"
+      )
     return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None)
 
+
+def _load_32xcols(base_addr, cols, dtype):
+  # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
+  i32 = ir.IntegerType.get_signless(32)
+  assert cols % 8 == 0
+  cols_per_num_tile = 8
+  load_shape = "16x256b"
+  num = cols // 8
+  if num <= 32:
+    num_tiling = num
+  elif num == 64:
+    num_tiling = 32
+  else:
+    raise NotImplementedError(num)
+  vector_regs = np.ndarray((4, num), dtype=object)
+  # We load 16 lanes at a time, but need 32 in total.
+  for row_group in range(2):
+    addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16))
+    regs = []
+    for num_group in range(num // num_tiling):
+      addr_row_col = arith.addi(
+          addr_row,
+          arith.constant(i32, num_tiling * num_group * cols_per_num_tile),
      )
+      regs += tmem_load(addr_row_col, load_shape, num_tiling)
+    regs = [llvm.bitcast(dtype, r) for r in regs]
+    undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype))
+    for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(num, 2)):
+      high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32))
+      vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32))
+      vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg
+  return vector_regs
+
 
 def _m128_256bit_32bit_layout(shape: tuple[int, ...]):
   if len(shape) != 2:
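As a reading aid for the column_tile_stride docstring and the new load path in TMEMRef.__getitem__ above, here is a small standalone sketch (not part of the diff) of the order in which column tiles are visited; `tile_visit_order` is a hypothetical helper, not an identifier from the commit.

```python
# Illustrative sketch only: mirrors the nested col_tile_base / col_tile loop
# added to TMEMRef.__getitem__ for layouts with column_tile_stride == 2.
def tile_visit_order(num_column_tiles: int, column_tile_stride: int) -> list[int]:
  num_strided_col_groups = -(-num_column_tiles // column_tile_stride)  # ceil_div
  order = []
  for col_tile_base in range(num_strided_col_groups):
    order.extend(range(col_tile_base, num_column_tiles, column_tile_stride))
  return order

# The collective N=512 accumulator has four 128-column tiles, visited as 0, 2, 1, 3,
# which matches the TMEM ordering shown in the commit message table.
assert tile_visit_order(4, 2) == [0, 2, 1, 3]
```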

jax/experimental/mosaic/gpu/utils.py (+4 −0)

@@ -1201,3 +1201,7 @@ def bitcast(x: ir.Value, new_type: ir.Type):
     assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
     return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
   raise ValueError(f"Can't bitcast {x.type} to {new_type}")
+
+
+def ceil_div(x: int, y: int):
+  return (x + y - 1) // y

tests/mosaic/gpu_test.py (+1 −1)

@@ -1026,7 +1026,7 @@ def quantize(x):
       in_jax_dtype=(jnp.float16,),  # TODO(apaszke): f32
       out_jax_dtype=(jnp.float32,),  # TODO(apaszke): f16 accumulation
       m=(256,),  # TODO(apaszke): 64, 192, 256
-      n=(128, 256),  # TODO(apaszke): 512, 192, other non-power-of-2
+      n=(128, 256, 512),  # TODO(apaszke): 192, other non-power-of-2
       k_steps=(1, 2),
       swizzle=(32, 64, 128,),
   )