@@ -83,6 +83,7 @@ def mma(
     accumulate: ir.Value | bool = True,
     collective: bool = False,
 ):
+  i32 = ir.IntegerType.get_signless(32)
   i64 = ir.IntegerType.get_signless(64)
   if isinstance(accumulate, bool):
     accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
@@ -112,6 +113,10 @@ def mma(
     raise ValueError(
         f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}"
     )
+  if d.layout != (expected_layout := _infer_tmem_layout(d.shape, collective)):
+    raise ValueError(
+        f"Accumulator layout mismatch: expected {expected_layout}, got {d.layout}"
+    )
   f32 = ir.F32Type.get()
   if element_type == f32 or element_type == ir.BF16Type.get():
     if d.dtype != f32:
@@ -136,11 +141,7 @@ def mma(
     raise ValueError(f"N must be a multiple of 8, got: {n}")
   elif n > 256 and n != 512:
     raise ValueError("Only N below 256 or N=512 are supported")
-  if num_cta == 2 and n > 256:
-    raise NotImplementedError(
-        "N is too big for collective MMA. Only up to 256 is supported."
-    )
-  n_group_elems = min(n, 256)
+  n_group_elems = min(n, 256 // num_cta)
   if m % m_group_elems:
     raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
   if k % k_group_elems:
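
Note on the hunk above: instead of rejecting N > 256 for collective MMA, the group
size is now capped at 256 // num_cta, so a 2-CTA collective MMA simply uses
128-column groups. A minimal pure-Python sketch of that arithmetic (the helper name
and the n_groups derivation are illustrative, not part of the patch):

def n_grouping(n: int, collective: bool) -> tuple[int, int]:
  # Mirrors n_group_elems = min(n, 256 // num_cta) from the hunk above,
  # assuming n_groups is later derived as n // n_group_elems.
  num_cta = 2 if collective else 1
  n_group_elems = min(n, 256 // num_cta)
  return n_group_elems, n // n_group_elems

assert n_grouping(256, collective=False) == (256, 1)
assert n_grouping(256, collective=True) == (128, 2)
assert n_grouping(512, collective=True) == (128, 4)
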
@@ -179,6 +180,7 @@ def mma(

   # Step 4. Issue the instructions.
   true = arith.constant(ir.IntegerType.get_signless(1), 1)
+  n_collective_group_elems = n_group_elems * num_cta
   for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups):
     a_offset = mi * a_m_group_stride + ki * a_k_group_stride
     a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64))
@@ -188,9 +190,9 @@ def mma(
       raise NotImplementedError("D needs to be sliced")
     acc = accumulate if ki == 0 else true
     _do_mma(
-        d.slice(
-            slice(None), utils.ds(ni * n_group_elems, n_group_elems)
-        ).address,
+        arith.addi(
+            d.address, arith.constant(i32, ni * n_collective_group_elems)
+        ),
         a_mk,
         b_nk,
         d_type=ir.F32Type.get(),
@@ -377,8 +379,15 @@ class TMEMLayout:
   +------------------+------------------+
   | [0:64, 64:128]   | [64:128, 64:128] |
   +------------------+------------------+
+
+  The above is further complicated by column_tile_stride, which is used to
+  swizzle the ordering of column tiles. That is, if column_tile_stride is 2,
+  we will first lay out all tiles that have column index 0, 2, 4, and so on,
+  until we run out of tiles. Only then do we lay out the tiles with column
+  index 1, 3, etc.
   """
   elements_in_tile: tuple[int, int]
+  column_tile_stride: int = 1

   def __post_init__(self):
     row_tiling = self.elements_in_tile[0]
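
The docstring added above describes the column_tile_stride swizzle in prose; the
standalone sketch below (illustrative only, not part of the patch) enumerates the
resulting column-tile order:

def column_tile_order(num_column_tiles: int, column_tile_stride: int) -> list[int]:
  # Visit every stride-th column tile before moving to the next start offset,
  # as described in the TMEMLayout docstring above.
  order = []
  for start in range(column_tile_stride):
    order.extend(range(start, num_column_tiles, column_tile_stride))
  return order

assert column_tile_order(4, 1) == [0, 1, 2, 3]  # default: contiguous column tiles
assert column_tile_order(4, 2) == [0, 2, 1, 3]  # stride 2: even tiles first, then odd
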
@@ -405,7 +414,7 @@ def cols_in_shape(self, shape: tuple[int, int]):
     return num_tiles // tiles_in_row * cols_in_tile


-def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout:
+def _infer_tmem_layout(shape: tuple[int, int], collective: bool) -> TMEMLayout:
   if shape[0] > TMEM_ROWS:
     raise ValueError(
         "Can only infer TMEM layout for shapes with at most 128 rows, got:"
@@ -421,7 +430,15 @@ def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout:
         "Can only infer TMEM layout for shapes with row count that's a power of"
         f" 2, got: {shape[0]}"
     )
-  return TMEMLayout(elements_in_tile=(shape[0], 1))
+  if shape[1] % 8:
+    raise ValueError(
+        "Can only infer TMEM layout for shapes with column count that's a"
+        f" multiple of 8, got: {shape[1]}"
+    )
+  if collective and shape[1] == 512:
+    return TMEMLayout(elements_in_tile=(shape[0], 128), column_tile_stride=2)
+  else:
+    return TMEMLayout(elements_in_tile=(shape[0], 8))


 @dataclasses.dataclass(frozen=True)
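
To make the new inference rule easy to check at a glance, the same branch is restated
below over plain tuples (purely illustrative; the real function returns TMEMLayout
instances and also performs the validation shown in the hunk above):

def inferred_tiling(shape: tuple[int, int], collective: bool) -> tuple[tuple[int, int], int]:
  rows, cols = shape
  if collective and cols == 512:
    return (rows, 128), 2  # 128-column tiles, column tiles interleaved with stride 2
  return (rows, 8), 1      # default: 8-column tiles, laid out contiguously

assert inferred_tiling((128, 512), collective=True) == ((128, 128), 2)
assert inferred_tiling((128, 512), collective=False) == ((128, 8), 1)
assert inferred_tiling((128, 256), collective=True) == ((128, 8), 1)
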
@@ -432,7 +449,14 @@ class TMEMRef:
   layout: TMEMLayout

   @classmethod
-  def from_alloc(cls, tmem_addr_ref: ir.Value, shape: tuple[int, int], dtype, layout: TMEMLayout | None = None):
+  def from_alloc(
+      cls,
+      tmem_addr_ref: ir.Value,
+      shape: tuple[int, int],
+      dtype,
+      collective: bool | None = None,
+      layout: TMEMLayout | None = None,
+  ):
     i32 = ir.IntegerType.get_signless(32)
     if not ir.MemRefType.isinstance(tmem_addr_ref.type):
       raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}")
@@ -449,7 +473,11 @@ def from_alloc(cls, tmem_addr_ref: ir.Value, shape: tuple[int, int], dtype, layo
     if shape[0] < 32:
       raise ValueError(f"TMEM refs must have at least 32 rows, got: {shape[0]}")
     if layout is None:
-      layout = _infer_tmem_layout(shape)
+      if collective is None:
+        raise ValueError(
+            "collective argument must be provided when TMEM layout is inferred"
+        )
+      layout = _infer_tmem_layout(shape, collective)
     else:
       layout.check_shape(shape)
     # TODO: Do we have to do this??
@@ -461,12 +489,17 @@ def slice(self, *idxs):
     base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
     if any(is_squeezed):
       raise ValueError("TMEM can only be sliced, not indexed")
-    if self.layout.elements_in_tile[0] != TMEM_ROWS:
+    if self.layout != TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)):
       raise NotImplementedError(
-          f"Slicing only implemented for refs with tiling of {TMEM_ROWS} rows"
+          "Slicing only implemented for refs with standard layout, got:"
+          f" {self.layout}"
       )
     if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS:
       raise NotImplementedError("TMEM cannot be sliced along rows")
+    if slice_shape[1] % 8:
+      raise NotImplementedError(
+          "TMEM column slice length must be a multiple of 8"
+      )
     col_idx = base_idx[1]
     if not isinstance(col_idx, ir.Value):
       col_idx = arith.constant(ir.IntegerType.get_signless(32), col_idx)
@@ -484,48 +517,75 @@ def __getitem__(self, *idxs):
       raise ValueError("TMEM loads only support slicing")
     if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape:
       raise NotImplementedError("Slicing of TMEM not impelmented yet")
-    if self.layout.elements_in_tile[0] != TMEM_ROWS:
-      raise NotImplementedError(
-          f"Loads only implemented for refs with tiling of {TMEM_ROWS} rows"
-      )
     if self.shape[1] % 8:
       raise NotImplementedError
     if self.dtype != ir.F32Type.get():
       raise NotImplementedError(self.dtype)
     layout = _m128_256bit_32bit_layout(self.shape)
     regs_shape = layout.registers_shape(self.shape)
-    num = self.shape[1] // 8
-    # TODO(apaszke): Make the tiling configurable through the args too.
-    if num <= 32:
-      num_tiling = num
-    elif num == 64:
-      num_tiling = 32
-    else:
-      raise NotImplementedError(num)
-    registers = np.empty(regs_shape, dtype=object)
-    # We load 16 lanes at a time, but need 32 in total.
-    for row_group in range(2):
-      addr_row = arith.addi(self.address, arith.constant(i32, (row_group * 16) << 16))
-      regs = []
-      cols_per_num_tile = 8  # This depends on the 16x256b below.
-      for num_group in range(num // num_tiling):
-        addr_row_col = arith.addi(
-            addr_row,
-            arith.constant(i32, num_tiling * num_group * cols_per_num_tile),
+    if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)):
+      # load_32xcols returns a 4xN array, but the FA tiling we use here tiles
+      # columns before rows, and so it is Nx4 (after ignoring all 1 dims).
+      registers = _load_32xcols(
+          self.address, self.shape[1], self.dtype
+      ).T.reshape(regs_shape)
+    elif self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 128), column_tile_stride=2):
+      if self.shape[1] % 128 != 0:
+        raise ValueError(
+            f"TMEM layout {self.layout} is not compatible with shape {self.shape}"
        )
-        regs += tmem_load(addr_row_col, "16x256b", num_tiling)
-      regs = [llvm.bitcast(self.dtype, r) for r in regs]
-      vector_regs = []
-      undef = llvm.mlir_undef(ir.VectorType.get((2,), self.dtype))
-      for r_low, r_high in zip(regs[::2], regs[1::2]):
-        high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32))
-        vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32))
-        vector_regs.append(vreg)
-      # Dimension 4 is the one where we split 32 rows into tiles of 8.
-      regs_slice = (slice(None),) * 4 + (slice(row_group * 2, (row_group + 1) * 2),)
-      registers[regs_slice] = np.asarray(vector_regs, dtype=object).reshape(registers[regs_slice].shape)
+      num_column_tiles = self.shape[1] // 128
+      column_tile_stride = self.layout.column_tile_stride
+      num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride)
+      tiles = []
+      for col_tile_base in range(num_strided_col_groups):
+        for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride):
+          tiles.append(
+              _load_32xcols(
+                  arith.addi(self.address, arith.constant(i32, col_tile * 128)),
+                  cols=128,
+                  dtype=self.dtype,
+              )
+          )
+      registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape)
+    else:
+      raise NotImplementedError(
+          f"Loads only implemented for refs with standard layout, got: {self.layout}"
+      )
     return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None)

+def _load_32xcols(base_addr, cols, dtype):
+  # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
+  i32 = ir.IntegerType.get_signless(32)
+  assert cols % 8 == 0
+  cols_per_num_tile = 8
+  load_shape = "16x256b"
+  num = cols // 8
+  if num <= 32:
+    num_tiling = num
+  elif num == 64:
+    num_tiling = 32
+  else:
+    raise NotImplementedError(num)
+  vector_regs = np.ndarray((4, num), dtype=object)
+  # We load 16 lanes at a time, but need 32 in total.
+  for row_group in range(2):
+    addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16))
+    regs = []
+    for num_group in range(num // num_tiling):
+      addr_row_col = arith.addi(
+          addr_row,
+          arith.constant(i32, num_tiling * num_group * cols_per_num_tile),
+      )
+      regs += tmem_load(addr_row_col, load_shape, num_tiling)
+    regs = [llvm.bitcast(dtype, r) for r in regs]
+    undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype))
+    for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(num, 2)):
+      high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32))
+      vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32))
+      vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg
+  return vector_regs
+

 def _m128_256bit_32bit_layout(shape: tuple[int, ...]):
   if len(shape) != 2:
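
One further note on the loads above: both the old inline loop and the new
_load_32xcols helper compose TMEM addresses by adding a row offset shifted into the
upper 16 bits ((row_group * 16) << 16) plus a plain column offset in the low bits.
A tiny sketch of that address arithmetic, assuming this upper/lower-half split
(illustrative only, not part of the patch):

def tmem_offset(row: int, col: int) -> int:
  # Row (lane) offset in the upper 16 bits, column offset in the lower 16.
  return (row << 16) | col

assert tmem_offset(16, 0) == (16 << 16)  # second 16-lane row group, as in the code above
assert tmem_offset(0, 128) == 128        # one 128-column tile to the right
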