Skip to content

Commit 52800a5

Browse files
committed
Refactor GEMM and frontend legalize operations for improved clarity and functionality
- Updated `gemm_py.h` to include the correct header for GEMM operations.
- Renamed the `FrontendLegalizer` class to `LetInliner` and updated related methods to reflect this change, enhancing code clarity.
- Renamed the pass function from `FrontendLegalize` to `LetInline` to better align with its purpose.
- Updated test cases to use the new `gemm_v2` function and adjusted the testing framework for improved output and clarity.
- Renamed the test file `test_tilelang_transform_frontend_legalize.py` to `test_tilelang_transform_let_inline.py` to match the new pass name.
- Enhanced the `LowerAndLegalize` function to use the new `LetInline` pass, improving the overall transformation process.
1 parent 1b5dde9 commit 52800a5

File tree

10 files changed

+83
-60
lines changed

10 files changed

+83
-60
lines changed

src/op/gemm_py.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,15 @@
77
#ifndef TVM_TL_OP_GEMM_PY_H_
88
#define TVM_TL_OP_GEMM_PY_H_
99

10+
#include "gemm.h"
1011
#include "operator.h"
11-
#include "gemm_py.h"
1212

1313
namespace tvm {
1414

1515
namespace tl {
1616

1717
using namespace tir;
1818

19-
2019
class GemmPyNode : public TileOperatorNode {
2120
public:
2221
bool CheckWGMMA() const;

src/transform/frontend_legalize.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ namespace tl {
3434

3535
using namespace tir;
3636

37-
class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
37+
class LetInliner : public arith::IRMutatorWithAnalyzer {
3838
public:
3939
static PrimFunc Substitute(PrimFunc f) {
4040
arith::Analyzer analyzer;
41-
FrontendLegalizer substituter(&analyzer);
41+
LetInliner substituter(&analyzer);
4242
PrimFuncNode *fptr = f.CopyOnWrite();
4343
fptr->body = substituter.VisitStmt(f->body);
4444
return f;
@@ -82,16 +82,16 @@ class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
8282

8383
using namespace tir::transform;
8484

85-
Pass FrontendLegalize() {
85+
Pass LetInline() {
8686
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
87-
return FrontendLegalizer::Substitute(std::move(f));
87+
return LetInliner::Substitute(std::move(f));
8888
};
89-
return CreatePrimFuncPass(pass_func, 0, "tl.FrontendLegalize", {});
89+
return CreatePrimFuncPass(pass_func, 0, "tl.LetInline", {});
9090
}
9191

9292
TVM_FFI_STATIC_INIT_BLOCK({
9393
namespace refl = tvm::ffi::reflection;
94-
refl::GlobalDef().def("tl.transform.FrontendLegalize", FrontendLegalize);
94+
refl::GlobalDef().def("tl.transform.LetInline", LetInline);
9595
});
9696

9797
} // namespace tl

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def main(
4444
T.copy(B[bx * block_N, k * block_K], B_shared)
4545
else:
4646
T.copy(B[k * block_K, bx * block_N], B_shared)
47-
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
47+
T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
4848
T.copy(C_local, C[by * block_M, bx * block_N])
4949

5050
return main
@@ -88,6 +88,7 @@ def run_gemm(
8888
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
8989
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
9090
})
91+
print(kernel.get_kernel_source())
9192
profiler = kernel.get_profiler()
9293

9394
def ref_program(A, B):
@@ -157,7 +158,7 @@ def main(
157158
T.copy(B[bx * block_N, k * block_K], B_shared)
158159
else:
159160
T.copy(B[k * block_K, bx * block_N], B_shared)
160-
T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
161+
T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
161162
T.copy(C_local, C[by * block_M, bx * block_N])
162163

163164
return main
@@ -224,4 +225,9 @@ def test_gemm_rs():
224225

225226

226227
if __name__ == "__main__":
227-
tilelang.testing.main()
228+
# tilelang.testing.main()
229+
tilelang.disable_cache()
230+
# run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
231+
# print("gemm fp16 nt ss done")
232+
run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
233+
print("gemm fp16 nn ss done")

testing/python/transform/test_tilelang_transform_frontend_legalize.py renamed to testing/python/transform/test_tilelang_transform_let_inline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
def _check(original, transformed):
88
func = original
99
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
10-
mod = tl.transform.FrontendLegalize()(mod)
10+
mod = tl.transform.LetInline()(mod)
1111
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
1212
True)
1313

tilelang/engine/phase.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
8585
"""
8686
mod = tir.transform.BindTarget(target)(mod)
8787

88-
# Legalize the frontend IR to make it compatible with TVM
89-
mod = tilelang.transform.FrontendLegalize()(mod)
88+
# Inline let expressions and statements
89+
mod = tilelang.transform.LetInline()(mod)
9090
# Inject assumes to speedup tvm prover
9191
mod = tilelang.transform.InjectAssumes()(mod)
9292
# Simplify the IR expressions

tilelang/intrinsics/mma_macro_generator.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,12 @@ def _warp_ldmatrix_a(
188188
):
189189
stride = A_shared_buf.shape[-1]
190190
tx, _, warp_m = self.extract_thread_binding(thread_binding)
191+
trans = self.a_transposed
192+
191193
for i in T.serial(warp_rows):
192194
T.ptx_ldmatrix(
193195
a_dtype,
194-
T.bool(False),
196+
T.bool(trans),
195197
4,
196198
".b16",
197199
A_local_buf.data,
@@ -230,22 +232,25 @@ def _warp_ldmatrix_b(
230232
):
231233
stride = B_shared_buf.shape[-1]
232234
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
235+
trans = not self.b_transposed
233236

234237
for j in T.serial(warp_cols):
235238
# Assign B_shared_elem
236-
ri, rj = (
239+
wi, wk = (
237240
warp_n * warp_col_tiles + j * micro_size_y,
238241
rk * chunk + ki * micro_size_k,
239242
)
243+
B_shared_buf_elem = B_shared_buf[wi, wk] if self.b_transposed else B_shared_buf[wk,
244+
wi]
240245

241246
T.ptx_ldmatrix(
242247
b_dtype,
243-
T.bool(False), # TODO(lei): should be optimized
248+
T.bool(trans),
244249
4,
245250
".b16",
246251
B_local_buf.data,
247252
j * local_size_b,
248-
T.address_of(B_shared_buf[ri, rj]),
253+
T.address_of(B_shared_buf_elem),
249254
get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed),
250255
)
251256

@@ -289,7 +294,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
289294
b_local_stride + j * local_size_b,
290295
C_local_buf.data,
291296
i * warp_cols * local_size_out + j * local_size_out,
292-
T.bool(False),
297+
T.bool(False), # saturate
293298
)
294299

295300
T.ptx_mma(
@@ -306,7 +311,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
306311
b_local_stride + j * local_size_b + lift(local_size_b) // 2,
307312
C_local_buf.data,
308313
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
309-
T.bool(False),
314+
T.bool(False), # saturate
310315
)
311316

312317
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)

tilelang/tileop/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .gemm import GemmPy
1+
from .gemm import GemmPy # noqa: F401

tilelang/tileop/gemm/__init__.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
from tilelang import tvm as tvm
22
from tvm import tir
33
from tilelang.utils.target import (
4-
target_is_cuda,
5-
target_is_hip,
6-
)
7-
from tilelang import _ffi_api
4+
target_is_cuda,)
85
from tilelang.intrinsics.mma_macro_generator import (
96
TensorCoreIntrinEmitter,)
107
from tilelang.layout import make_swizzled_layout
@@ -14,25 +11,28 @@
1411
from tvm.runtime import Scriptable
1512
import tvm.ffi
1613
from tilelang.ir import GemmWarpPolicy
14+
from tilelang.transform.simplify import _Simplify
1715

1816

1917
@tvm.ffi.register_func("tl.gemm_py.infer_layout")
2018
def gemm_py_infer_layout(gemm_py, target, thread_bounds):
2119
thread_nums = thread_bounds.extent
2220
return gemm_py.infer_layout(target, thread_nums)
2321

22+
2423
@tvm.ffi.register_func("tl.gemm_py.lower")
2524
def gemm_py_lower(gemm_py, target, thread_bounds, thread_var):
2625
thread_nums = thread_bounds.extent
2726
stmt = gemm_py.lower(target, thread_nums, thread_var)
2827
return stmt
2928

29+
3030
@tvm.ffi.register_object("tl.GemmPy")
3131
class GemmPy(Node, Scriptable):
3232
A: tir.Buffer
3333
B: tir.Buffer
3434
C: tir.Buffer
35-
35+
3636
APtr: tir.PrimExpr
3737
BPtr: tir.PrimExpr
3838
CPtr: tir.PrimExpr
@@ -52,23 +52,23 @@ class GemmPy(Node, Scriptable):
5252
k_pack: int
5353
wg_wait: int
5454
policy: GemmWarpPolicy
55-
5655

5756
def infer_layout(self, target: Target, thread_nums: int):
5857
if target_is_cuda(target):
5958
# TODO(lei): Support more cuda architectures, now mma only
6059
# Now only implement ssr layout
61-
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
62-
warp_row_tiles = m_warp * 16
63-
warp_col_tiles = n_warp * 16
60+
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
61+
False)
62+
warp_row_tiles = int(self.M // m_warp)
63+
warp_col_tiles = int(self.N // n_warp)
6464
mma_emitter = TensorCoreIntrinEmitter(
6565
a_dtype=self.in_dtype,
6666
b_dtype=self.in_dtype,
6767
accum_dtype=self.accum_dtype,
6868
a_transposed=self.trans_A,
6969
b_transposed=self.trans_B,
70-
block_row_warps=self.M,
71-
block_col_warps=self.N,
70+
block_row_warps=m_warp,
71+
block_col_warps=n_warp,
7272
warp_row_tiles=warp_row_tiles,
7373
warp_col_tiles=warp_col_tiles,
7474
chunk=self.chunk,
@@ -81,23 +81,23 @@ def infer_layout(self, target: Target, thread_nums: int):
8181
return layout_map
8282
else:
8383
raise ValueError(f"Unsupported target: {target}")
84-
8584

8685
def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
8786
if target_is_cuda(target):
8887
# TODO(lei): Support more cuda architectures, now mma only
8988
# Now only implement ssr layout
90-
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
91-
warp_row_tiles = m_warp * 16
92-
warp_col_tiles = n_warp * 16
89+
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
90+
False)
91+
warp_row_tiles = int(self.M // m_warp)
92+
warp_col_tiles = int(self.N // n_warp)
9393
mma_emitter = TensorCoreIntrinEmitter(
9494
a_dtype=self.in_dtype,
9595
b_dtype=self.in_dtype,
9696
accum_dtype=self.accum_dtype,
9797
a_transposed=self.trans_A,
9898
b_transposed=self.trans_B,
99-
block_row_warps=self.M,
100-
block_col_warps=self.N,
99+
block_row_warps=m_warp,
100+
block_col_warps=n_warp,
101101
warp_row_tiles=warp_row_tiles,
102102
warp_col_tiles=warp_col_tiles,
103103
chunk=self.chunk,
@@ -125,7 +125,6 @@ def _gemm_ssr() -> None:
125125
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
126126
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
127127

128-
129128
for ki in T.serial(0, (block_K // micro_size_k)):
130129
# Load A into fragment
131130
mma_emitter.ldmatrix_a(
@@ -143,10 +142,12 @@ def _gemm_ssr() -> None:
143142

144143
# Perform Matrix Multiplication
145144
mma_emitter.mma(A_local, B_local, C_local)
146-
return _gemm_ssr.body
145+
146+
# Simplify to optimize the index computing
147+
# Must inline let statements to simplify the analysis
148+
return _Simplify(_gemm_ssr, inline_let=True).body
147149
else:
148150
raise ValueError(f"Unsupported target: {target}")
149-
150151

151152
@property
152153
def in_dtype(self) -> str:
@@ -156,7 +157,7 @@ def in_dtype(self) -> str:
156157
@property
157158
def accum_dtype(self) -> str:
158159
return self.C.dtype
159-
160+
160161
@property
161162
def chunk(self) -> int:
162-
return self.A.shape[-2] if self.trans_A else self.A.shape[-1]
163+
return self.A.shape[-2] if self.trans_A else self.A.shape[-1]

tilelang/transform/__init__.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# pylint: disable=invalid-name, unsupported-binary-operation
33

44
from . import _ffi_api
5-
from .simplify import Simplify, simplify_prim_func # noqa: F401
5+
from .simplify import Simplify, simplify_prim_func, LetInline # noqa: F401
66
from .pass_config import PassConfigKey # noqa: F401
77
from tilelang import tvm as tvm # noqa: F401
88
from tvm.ir.transform import PassContext # noqa: F401
@@ -68,17 +68,6 @@ def InjectSoftwarePipeline():
6868
return _ffi_api.InjectSoftwarePipeline() # type: ignore
6969

7070

71-
def FrontendLegalize():
72-
"""FrontendLegalize
73-
74-
Returns
75-
-------
76-
fpass : tvm.transform.Pass
77-
The result pass
78-
"""
79-
return _ffi_api.FrontendLegalize() # type: ignore
80-
81-
8271
def InjectAssumes():
8372
"""Inject Assumes
8473

tilelang/transform/simplify.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,17 @@
55
from . import _ffi_api
66

77

8+
def LetInline():
9+
"""LetInline
10+
11+
Returns
12+
-------
13+
fpass : tvm.transform.Pass
14+
The result pass
15+
"""
16+
return _ffi_api.LetInline() # type: ignore
17+
18+
819
def Simplify(simplify_arguments: bool = False):
920
"""Simplify
1021
@@ -16,13 +27,24 @@ def Simplify(simplify_arguments: bool = False):
1627
return _ffi_api.Simplify(simplify_arguments) # type: ignore
1728

1829

19-
def _Simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
30+
def _Simplify(stmt: Union[PrimFunc, IRModule],
31+
inline_let: bool = False) -> Union[PrimFunc, IRModule]:
2032
if isinstance(stmt, PrimFunc):
21-
mod = Simplify(simplify_arguments=True)(IRModule.from_expr(stmt))
33+
if inline_let:
34+
mod = LetInline()(IRModule.from_expr(stmt))
35+
mod = Simplify(simplify_arguments=True)(mod)
36+
else:
37+
mod = Simplify(simplify_arguments=True)(IRModule.from_expr(stmt))
2238
assert len(mod.functions) == 1, "Simplify should return a single function"
2339
return list(mod.functions.values()).pop()
2440
elif isinstance(stmt, IRModule):
25-
return Simplify(simplify_arguments=True)(stmt)
41+
if inline_let:
42+
mod = LetInline()(stmt)
43+
mod = Simplify(simplify_arguments=True)(mod)
44+
else:
45+
mod = Simplify(simplify_arguments=True)(stmt)
46+
assert len(mod.functions) == 1, "Simplify should return a single function"
47+
return list(mod.functions.values()).pop()
2648
else:
2749
raise ValueError(f"Unsupported type: {type(stmt)}")
2850

@@ -37,6 +59,7 @@ def wrapper(*args, **kwargs):
3759
return wrapper
3860

3961

40-
def apply_simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
62+
def apply_simplify(stmt: Union[PrimFunc, IRModule],
63+
inline_let: bool = False) -> Union[PrimFunc, IRModule]:
4164
"""Apply Simplify pass to a PrimFunc or IRModule."""
42-
return _Simplify(stmt)
65+
return _Simplify(stmt, inline_let)

0 commit comments

Comments
 (0)