@@ -17,53 +17,66 @@ def make_mma_load_base_layout(dtype: str = "float16",
1717 ----------
1818 dtype : str
1919 The data type of the matrix.
20- local_buf : tir.Buffer
21- The local buffer representing a fragment of a matrix.
20+ matrix : Literal["A", "B"]
21+ The mma operand to be loaded.
22+ transposed : bool
23+ Whether the matrix is transposed, by default False.
2224
2325 Returns
2426 -------
2527 T.Fragment
26- A fragment object that describes how threads and indices
27- in `local_buf` are laid out.
28+ Describes how threads and indices in the fragment are laid out.
2829
29- Raises
30- ------
31- AssertionError
32- If `local_buf` is not detected to be a fragment buffer.
3330 """
3431 from tilelang .intrinsics .mma_layout import (
35- shared_16x16_to_mma_32x8_layout_sr ,
36- shared_16x16_to_mma_32x8_layout_rs ,
37- shared_16x32_to_mma_32x16_layout ,
38- shared_32x16_to_mma_32x16_layout ,
32+ shared_16x8_to_mma_32x4_layout_sr_a ,
33+ shared_16x16_to_mma_32x8_layout_sr_a ,
34+ shared_16x32_to_mma_32x16_layout_sr_a ,
35+ shared_16x8_to_mma_32x4_layout_sr_b ,
36+ shared_16x16_to_mma_32x8_layout_sr_b ,
37+ shared_16x32_to_mma_32x16_layout_sr_b ,
3938 )
4039 assert matrix in ["A" , "B" ], "matrix should be either A or B"
4140 dtype_bits = DataType (dtype ).bits
42- assert transposed is False , "transposed is not supported yet"
4341 # s represents spatial axis
4442 # r represents reduction axis
4543 # sr represents the two dims are spatial + reduction
4644 # rs represents the two dims are reduction + spatial
47- transform_func_sr : Callable = None
48- transform_func_rs : Callable = None
49- if dtype_bits == 16 :
50- transform_func_sr = shared_16x16_to_mma_32x8_layout_sr
51- transform_func_rs = shared_16x16_to_mma_32x8_layout_rs
45+ transform_func_sr_a : Callable = None
46+ transform_func_sr_b : Callable = None
47+ if dtype_bits == 32 :
48+ transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a
49+ transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b
50+ elif dtype_bits == 16 :
51+ transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a
52+ transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b
5253 elif dtype_bits == 8 :
53- transform_func_sr = shared_16x32_to_mma_32x16_layout
54- transform_func_rs = shared_32x16_to_mma_32x16_layout
54+ transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a
55+ transform_func_sr_b = shared_16x32_to_mma_32x16_layout_sr_b
5556 else :
5657 raise ValueError (f"Unsupported dtype { dtype } " )
58+
5759 is_sr_conditions = [False ]
5860 is_sr_conditions .append (matrix == "A" and not transposed )
5961 is_sr_conditions .append (matrix == "B" and transposed )
6062 is_sr_axis_order = any (is_sr_conditions )
6163
62- transform_func : Callable = transform_func_sr if is_sr_axis_order else transform_func_rs
63-
64- micro_size_s , _ , micro_size_r = get_mma_micro_size (dtype )
64+ micro_size_x , micro_size_y , micro_size_k = get_mma_micro_size (dtype )
65+
66+ # the layout of mma.sync is row.col.
67+ # so the B matrix expects a transposed basic layout
68+ transform_func : Callable = None
69+ if matrix == "A" :
70+ transform_func = transform_func_sr_a if is_sr_axis_order else lambda i , j : transform_func_sr_a (
71+ j , i )
72+ micro_size_s , micro_size_r = micro_size_x , micro_size_k
73+ elif matrix == "B" :
74+ transform_func = transform_func_sr_b if is_sr_axis_order else lambda i , j : transform_func_sr_b (
75+ j , i )
76+ micro_size_s , micro_size_r = micro_size_k , micro_size_y
77+ else :
78+ raise ValueError (f"Unsupported matrix { matrix } " )
6579
66- transform_func = transform_func
6780 inverse_mma_load_layout = IndexMap .from_func (transform_func , index_dtype = "int32" )
6881
6982 def forward_thread (i : int , j : int ) -> int :
@@ -81,7 +94,7 @@ def forward_index(i: int, j: int) -> int:
8194 return local_id
8295
8396 base_fragment = T .Fragment (
84- [micro_size_r , micro_size_s ],
97+ [micro_size_s , micro_size_r ] if is_sr_axis_order else [ micro_size_r , micro_size_s ],
8598 forward_thread_fn = forward_thread ,
8699 forward_index_fn = forward_index ,
87100 )
@@ -109,4 +122,4 @@ def forward_index(i: int, j: int) -> int:
109122# block layout 128x32
110123block_layout = warp_layout .repeat ([warp_rows , chunk ], repeat_on_thread = False , lower_dim_first = False )
111124print (block_layout )
112- # plot_layout(block_layout, name="block_layout")
125+ plot_layout (block_layout , name = "block_layout" )
0 commit comments