@@ -2755,6 +2755,7 @@ class TestCaseInput:
2755
2755
transforms_b : tuple [Tile | Transpose | Swizzle , ...] = ()
2756
2756
transpose_a : bool = False
2757
2757
transpose_b : bool = False
2758
+ load_a_in_registers : bool = False
2758
2759
2759
2760
result = []
2760
2761
for swizzle in [
@@ -2786,6 +2787,13 @@ class TestCaseInput:
2786
2787
transforms_a = [Tile ([64 , k ]), Swizzle (swizzle )],
2787
2788
transforms_b = [Tile ([k , k ]), Swizzle (swizzle )],
2788
2789
),
2790
+ TestCaseInput (
2791
+ shape_a = [groups_m * 64 , groups_k * k ],
2792
+ shape_b = [groups_k * k , groups_n * k ],
2793
+ shape_res = [groups_m * 64 , groups_n * k ],
2794
+ transforms_a = [Tile ([64 , k ]), Swizzle (swizzle )],
2795
+ load_a_in_registers = True ,
2796
+ ),
2789
2797
])
2790
2798
# The below only works for 128-byte swizzling. Regardless of transposing,
2791
2799
# TMA needs the size of the last dimension to be compatible with the
@@ -2849,6 +2857,14 @@ def matmul(
2849
2857
parity , _ = tma_barrier .update_parities (parities )
2850
2858
mgpu_dialect .wait (dialect_barrier , parity )
2851
2859
2860
+ # SMEM -> Registers
2861
+ a_operand = a_smem_ref
2862
+ zero_index = arith .constant (ir .IndexType .get (), 0 )
2863
+ if test_case .load_a_in_registers :
2864
+ a_vector_type = ir .VectorType .get (test_case .shape_a , ab_elt_type )
2865
+ zero_vector_indices = [zero_index ] * len (test_case .shape_a )
2866
+ a_operand = vector .load (a_vector_type , a_smem_ref , zero_vector_indices )
2867
+
2852
2868
# Computation
2853
2869
shape_result = ir .MemRefType (result_gmem_ref .type ).shape
2854
2870
result_elt_type = ir .MemRefType (result_gmem_ref .type ).element_type
@@ -2860,7 +2876,7 @@ def matmul(
2860
2876
)
2861
2877
result = mgpu_dialect .wgmma (
2862
2878
accumulator ,
2863
- a_smem_ref ,
2879
+ a_operand ,
2864
2880
b_smem_ref ,
2865
2881
transpose_a = test_case .transpose_a ,
2866
2882
transpose_b = test_case .transpose_b ,
@@ -2870,8 +2886,7 @@ def matmul(
2870
2886
nvvm .wgmma_wait_group_sync_aligned (0 )
2871
2887
2872
2888
# Registers -> SMEM
2873
- zero_index = arith .constant (ir .IndexType .get (), 0 )
2874
- vector .store (result , result_smem_ref , [zero_index , zero_index ])
2889
+ vector .store (result , result_smem_ref , [zero_index ] * len (shape_result ))
2875
2890
2876
2891
# SMEM -> GMEM
2877
2892
mgpu_dialect .async_store (
0 commit comments