11from tilelang import tvm as tvm
22from tvm import tir
33from tilelang .utils .target import (
4- target_is_cuda ,
5- target_is_hip ,
6- )
7- from tilelang import _ffi_api
4+ target_is_cuda ,)
85from tilelang .intrinsics .mma_macro_generator import (
96 TensorCoreIntrinEmitter ,)
107from tilelang .layout import make_swizzled_layout
1411from tvm .runtime import Scriptable
1512import tvm .ffi
1613from tilelang .ir import GemmWarpPolicy
14+ from tilelang .transform .simplify import _Simplify
1715
1816
@tvm.ffi.register_func("tl.gemm_py.infer_layout")
def gemm_py_infer_layout(gemm_py, target, thread_bounds):
    """FFI entry point: delegate layout inference to the GemmPy node.

    `thread_bounds.extent` carries the number of threads available to the
    kernel; it is forwarded to `GemmPy.infer_layout` together with `target`.
    """
    return gemm_py.infer_layout(target, thread_bounds.extent)
2321
22+
@tvm.ffi.register_func("tl.gemm_py.lower")
def gemm_py_lower(gemm_py, target, thread_bounds, thread_var):
    """FFI entry point: delegate lowering to the GemmPy node.

    `thread_bounds.extent` carries the number of threads available to the
    kernel; the resulting TIR statement from `GemmPy.lower` is returned
    directly to the caller on the C++ side.
    """
    return gemm_py.lower(target, thread_bounds.extent, thread_var)
2928
29+
3030@tvm .ffi .register_object ("tl.GemmPy" )
3131class GemmPy (Node , Scriptable ):
3232 A : tir .Buffer
3333 B : tir .Buffer
3434 C : tir .Buffer
35-
35+
3636 APtr : tir .PrimExpr
3737 BPtr : tir .PrimExpr
3838 CPtr : tir .PrimExpr
@@ -52,23 +52,23 @@ class GemmPy(Node, Scriptable):
5252 k_pack : int
5353 wg_wait : int
5454 policy : GemmWarpPolicy
55-
5655
5756 def infer_layout (self , target : Target , thread_nums : int ):
5857 if target_is_cuda (target ):
5958 # TODO(lei): Support more cuda architectures, now mma only
6059 # Now only implement ssr layout
61- m_warp , n_warp = self .policy .compute_warp_partition (self .M , self .N , thread_nums , target , False )
62- warp_row_tiles = m_warp * 16
63- warp_col_tiles = n_warp * 16
60+ m_warp , n_warp = self .policy .compute_warp_partition (self .M , self .N , thread_nums , target ,
61+ False )
62+ warp_row_tiles = int (self .M // m_warp )
63+ warp_col_tiles = int (self .N // n_warp )
6464 mma_emitter = TensorCoreIntrinEmitter (
6565 a_dtype = self .in_dtype ,
6666 b_dtype = self .in_dtype ,
6767 accum_dtype = self .accum_dtype ,
6868 a_transposed = self .trans_A ,
6969 b_transposed = self .trans_B ,
70- block_row_warps = self . M ,
71- block_col_warps = self . N ,
70+ block_row_warps = m_warp ,
71+ block_col_warps = n_warp ,
7272 warp_row_tiles = warp_row_tiles ,
7373 warp_col_tiles = warp_col_tiles ,
7474 chunk = self .chunk ,
@@ -81,23 +81,23 @@ def infer_layout(self, target: Target, thread_nums: int):
8181 return layout_map
8282 else :
8383 raise ValueError (f"Unsupported target: { target } " )
84-
8584
8685 def lower (self , target : Target , thread_nums : int , thread_var : tir .Var ):
8786 if target_is_cuda (target ):
8887 # TODO(lei): Support more cuda architectures, now mma only
8988 # Now only implement ssr layout
90- m_warp , n_warp = self .policy .compute_warp_partition (self .M , self .N , thread_nums , target , False )
91- warp_row_tiles = m_warp * 16
92- warp_col_tiles = n_warp * 16
89+ m_warp , n_warp = self .policy .compute_warp_partition (self .M , self .N , thread_nums , target ,
90+ False )
91+ warp_row_tiles = int (self .M // m_warp )
92+ warp_col_tiles = int (self .N // n_warp )
9393 mma_emitter = TensorCoreIntrinEmitter (
9494 a_dtype = self .in_dtype ,
9595 b_dtype = self .in_dtype ,
9696 accum_dtype = self .accum_dtype ,
9797 a_transposed = self .trans_A ,
9898 b_transposed = self .trans_B ,
99- block_row_warps = self . M ,
100- block_col_warps = self . N ,
99+ block_row_warps = m_warp ,
100+ block_col_warps = n_warp ,
101101 warp_row_tiles = warp_row_tiles ,
102102 warp_col_tiles = warp_col_tiles ,
103103 chunk = self .chunk ,
@@ -125,7 +125,6 @@ def _gemm_ssr() -> None:
125125 A_local = T .alloc_local ((warp_rows * local_size_a ), in_dtype )
126126 B_local = T .alloc_local ((warp_cols * local_size_b ), in_dtype )
127127
128-
129128 for ki in T .serial (0 , (block_K // micro_size_k )):
130129 # Load A into fragment
131130 mma_emitter .ldmatrix_a (
@@ -143,10 +142,12 @@ def _gemm_ssr() -> None:
143142
144143 # Perform Matrix Multiplication
145144 mma_emitter .mma (A_local , B_local , C_local )
146- return _gemm_ssr .body
145+
146+ # Simplify to optimize the index computing
147+ # Must inline let statements to simplify the analysis
148+ return _Simplify (_gemm_ssr , inline_let = True ).body
147149 else :
148150 raise ValueError (f"Unsupported target: { target } " )
149-
150151
151152 @property
152153 def in_dtype (self ) -> str :
@@ -156,7 +157,7 @@ def in_dtype(self) -> str:
@property
def accum_dtype(self) -> str:
    """Accumulation data type, taken from the output buffer C."""
    c_buffer = self.C
    return c_buffer.dtype
159-
160+
@property
def chunk(self) -> int:
    """Extent of the reduction (K) axis of buffer A.

    When A is stored transposed the K axis is the second-to-last
    dimension of its shape; otherwise it is the last dimension.
    """
    k_axis = -2 if self.trans_A else -1
    return self.A.shape[k_axis]
0 commit comments