Skip to content

Commit 0a1d513

Browse files
authored
[Layout] fix plot layout (tile-ai#890)
1 parent d020aaa commit 0a1d513

File tree

2 files changed

+39
-27
lines changed

2 files changed

+39
-27
lines changed

examples/plot_layout/fragment_mma_load_a.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,53 +17,66 @@ def make_mma_load_base_layout(dtype: str = "float16",
1717
----------
1818
dtype : str
1919
The data type of the matrix.
20-
local_buf : tir.Buffer
21-
The local buffer representing a fragment of a matrix.
20+
matrix : Literal["A", "B"]
21+
The mma operand to be loaded.
22+
transposed : bool
23+
Whether the matrix is transposed, by default False.
2224
2325
Returns
2426
-------
2527
T.Fragment
26-
A fragment object that describes how threads and indices
27-
in `local_buf` are laid out.
28+
Describes how threads and indices in the fragment are laid out.
2829
29-
Raises
30-
------
31-
AssertionError
32-
If `local_buf` is not detected to be a fragment buffer.
3330
"""
3431
from tilelang.intrinsics.mma_layout import (
35-
shared_16x16_to_mma_32x8_layout_sr,
36-
shared_16x16_to_mma_32x8_layout_rs,
37-
shared_16x32_to_mma_32x16_layout,
38-
shared_32x16_to_mma_32x16_layout,
32+
shared_16x8_to_mma_32x4_layout_sr_a,
33+
shared_16x16_to_mma_32x8_layout_sr_a,
34+
shared_16x32_to_mma_32x16_layout_sr_a,
35+
shared_16x8_to_mma_32x4_layout_sr_b,
36+
shared_16x16_to_mma_32x8_layout_sr_b,
37+
shared_16x32_to_mma_32x16_layout_sr_b,
3938
)
4039
assert matrix in ["A", "B"], "matrix should be either A or B"
4140
dtype_bits = DataType(dtype).bits
42-
assert transposed is False, "transposed is not supported yet"
4341
# s represents spatial axis
4442
# r represents reduction axis
4543
# sr represents the two dims are spatial + reduction
4644
# rs represents the two dims are reduction + spatial
47-
transform_func_sr: Callable = None
48-
transform_func_rs: Callable = None
49-
if dtype_bits == 16:
50-
transform_func_sr = shared_16x16_to_mma_32x8_layout_sr
51-
transform_func_rs = shared_16x16_to_mma_32x8_layout_rs
45+
transform_func_sr_a: Callable = None
46+
transform_func_sr_b: Callable = None
47+
if dtype_bits == 32:
48+
transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a
49+
transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b
50+
elif dtype_bits == 16:
51+
transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a
52+
transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b
5253
elif dtype_bits == 8:
53-
transform_func_sr = shared_16x32_to_mma_32x16_layout
54-
transform_func_rs = shared_32x16_to_mma_32x16_layout
54+
transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a
55+
transform_func_sr_b = shared_16x32_to_mma_32x16_layout_sr_b
5556
else:
5657
raise ValueError(f"Unsupported dtype {dtype}")
58+
5759
is_sr_conditions = [False]
5860
is_sr_conditions.append(matrix == "A" and not transposed)
5961
is_sr_conditions.append(matrix == "B" and transposed)
6062
is_sr_axis_order = any(is_sr_conditions)
6163

62-
transform_func: Callable = transform_func_sr if is_sr_axis_order else transform_func_rs
63-
64-
micro_size_s, _, micro_size_r = get_mma_micro_size(dtype)
64+
micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)
65+
66+
# the layout of mma.sync is row.col.
67+
# so the B matrix expects a transposed basic layout
68+
transform_func: Callable = None
69+
if matrix == "A":
70+
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
71+
j, i)
72+
micro_size_s, micro_size_r = micro_size_x, micro_size_k
73+
elif matrix == "B":
74+
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
75+
j, i)
76+
micro_size_s, micro_size_r = micro_size_k, micro_size_y
77+
else:
78+
raise ValueError(f"Unsupported matrix {matrix}")
6579

66-
transform_func = transform_func
6780
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
6881

6982
def forward_thread(i: int, j: int) -> int:
@@ -81,7 +94,7 @@ def forward_index(i: int, j: int) -> int:
8194
return local_id
8295

8396
base_fragment = T.Fragment(
84-
[micro_size_r, micro_size_s],
97+
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
8598
forward_thread_fn=forward_thread,
8699
forward_index_fn=forward_index,
87100
)
@@ -109,4 +122,4 @@ def forward_index(i: int, j: int) -> int:
109122
# block layout 128x32
110123
block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False)
111124
print(block_layout)
112-
# plot_layout(block_layout, name="block_layout")
125+
plot_layout(block_layout, name="block_layout")

tilelang/intrinsics/mma_macro_generator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,6 @@ def make_mma_load_layout(self,
490490
transform_func_sr_a: Callable = None
491491
transform_func_sr_b: Callable = None
492492
if dtype_bits == 32:
493-
...
494493
transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a
495494
transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b
496495
elif dtype_bits == 16:

0 commit comments

Comments
 (0)