
Commit e93c7c4

[Example] Add GQA Example (#118)
* Add DeepSeek MLA decode example with Flash Attention implementation

* Add GEMM SplitK and StreamK example implementations

  This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
  - `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
  - `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang

  Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations.

* Refactor GEMM SplitK and StreamK example implementations

  Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
  - Remove unused import (Profiler) in splitk example
  - Simplify line breaks and improve code readability
  - Standardize indentation and remove unnecessary whitespace
  - Optimize atomic add and copy operations for better clarity

* Add block sparse attention benchmarks for multiple libraries

  This commit introduces comprehensive block sparse attention benchmarks for different libraries:
  - TileLang block sparse FMHA implementation
  - Triton block sparse FMHA implementation
  - PyTorch reference block sparse FMHA implementation
  - FlashAttention dense FMHA reference implementation

  The benchmarks include:
  - Configurable benchmark parameters (batch size, heads, sequence length, etc.)
  - Sparse mask generation using top-k and threshold methods
  - Performance measurement for different sparse attention configurations
  - Utility functions for mask generation and benchmarking

* Refactor block sparse attention benchmarks with code style improvements

  - Add Ruff linter ignore comments to benchmark files
  - Improve code formatting and line breaks
  - Remove unused imports
  - Standardize print statement formatting
  - Enhance code readability across multiple library benchmarks

* lint fix

* Add CUDA atomic operations for BFLOAT16 and update function naming

  - Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header
  - Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd)
  - Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values
  - Update kernel and language customization to use new function names
  - Add return type annotations in profiler module

* lint fix

* Add example for Group Query Attention (GQA) forward pass using Flash Attention in TileLang

  This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates:
  - Group Query Attention (GQA) implementation
  - Flash Attention forward pass
  - Performance benchmarking
  - Configurable parameters for batch, heads, sequence length, and dimension
  - Autotuning support
  - Reference implementation comparison

* Refactor IR lowering pipeline into modular phases

  This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases:
  - `LowerAndLegalize`: Handles initial IR legalization and transformation
  - `OptimizeForTarget`: Applies target-specific optimizations

  The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.

* lint fix
1 parent 540aef4 commit e93c7c4

5 files changed: +348, -168 lines

example_gqa_fwd_bshd.py (new file)

Lines changed: 241 additions & 0 deletions

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn.functional as F
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial


def get_configs():
    block_M = [128]
    block_N = [128]
    num_stages = [2]
    threads = [256]
    _configs = list(itertools.product(block_M, block_N, num_stages, threads))

    configs = [{
        'block_M': c[0],
        'block_N': c[1],
        'num_stages': c[2],
        'threads': c[3]
    } for c in _configs]
    return configs


def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
    dtype = "float16"
    accum_dtype = "float"

    def kernel_func(block_M, block_N, num_stages, threads):

        @T.macro
        def MMA0(
            K: T.Buffer(kv_shape, dtype),
            Q_shared: T.Buffer([block_M, dim], dtype),
            K_shared: T.Buffer([block_N, dim], dtype),
            acc_s: T.Buffer([block_M, block_N], accum_dtype),
            k: T.int32,
            bx: T.int32,
            by: T.int32,
            bz: T.int32,
        ):
            T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                 -T.infinity(acc_s.dtype))
            else:
                T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def MMA1(
            V: T.Buffer(kv_shape, dtype),
            V_shared: T.Buffer([block_N, dim], dtype),
            acc_s_cast: T.Buffer([block_M, block_N], dtype),
            acc_o: T.Buffer([block_M, dim], accum_dtype),
            k: T.int32,
            by: T.int32,
            bz: T.int32,
        ):
            T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def Softmax(
            acc_s: T.Buffer([block_M, block_N], accum_dtype),
            acc_s_cast: T.Buffer([block_M, block_N], dtype),
            scores_max: T.Buffer([block_M], accum_dtype),
            scores_max_prev: T.Buffer([block_M], accum_dtype),
            scores_scale: T.Buffer([block_M], accum_dtype),
            scores_sum: T.Buffer([block_M], accum_dtype),
            logsum: T.Buffer([block_M], accum_dtype),
        ):
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
            # To do causal softmax, we need to set scores_max to 0 if it is -inf.
            # This process is called Check_inf in the FlashAttention3 code, and it only needs
            # to be done in the first ceil_div(kBlockM, kBlockN) steps.
            # for i in T.Parallel(block_M):
            #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
            for i in T.Parallel(block_M):
                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
            for i, j in T.Parallel(block_M, block_N):
                # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                # max * log_2(e)). This allows the compiler to use the ffma
                # instruction instead of fadd and fmul separately.
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            for i in T.Parallel(block_M):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            T.copy(acc_s, acc_s_cast)

        @T.macro
        def Rescale(
            acc_o: T.Buffer([block_M, dim], accum_dtype),
            scores_scale: T.Buffer([block_M], accum_dtype),
        ):
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]

        @T.prim_func
        def main(
            Q: T.Buffer(q_shape, dtype),
            K: T.Buffer(kv_shape, dtype),
            V: T.Buffer(kv_shape, dtype),
            Output: T.Buffer(q_shape, dtype),
        ):
            with T.Kernel(
                    T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                O_shared = T.alloc_shared([block_M, dim], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)

                T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                loop_range = (
                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))

                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                    Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                            scores_sum, logsum)
                    Rescale(acc_o, scores_scale)
                    MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

        return main

    if tune:

        @autotune(
            configs=get_configs(),
            keys=["block_M", "block_N", "num_stages", "threads"],
            warmup=10,
            rep=10)
        @jit(
            out_idx=[3],
            supply_type=tilelang.TensorSupplyType.Integer,
            ref_prog=None,
            profiler="auto")
        def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
            return kernel_func(block_M, block_N, num_stages, threads)

        return kernel()
    else:

        def kernel(block_M, block_N, num_stages, threads):
            return kernel_func(block_M, block_N, num_stages, threads)

        return kernel


def ref_program(Q, K, V, is_causal, groups=1):
    # Q: [B, T, HQ, D]
    # K: [B, T, HK, D]
    # V: [B, T, HV, D]
    # HQ = HKV * groups
    assert Q.size(2) == K.size(
        2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(
        2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"

    dim = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=32, help='heads')
    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    parser.add_argument('--groups', type=int, default=8, help='groups')
    args = parser.parse_args()
    batch, heads, seq_len, dim, is_causal, groups = args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    if not args.tune:
        program = flashattn(
            batch, heads, seq_len, dim, is_causal, tune=args.tune, groups=groups)(
                block_M=128, block_N=128, num_stages=1, threads=128)
        ref_program = partial(ref_program, is_causal=is_causal, groups=groups)
        mod, params = tilelang.lower(program)
        mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
        mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = mod.do_bench(ref_program, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = mod.do_bench(mod.func, warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        best_latency, best_config, _ = flashattn(
            batch, heads, seq_len, dim, is_causal, tune=args.tune, groups=groups)
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")

tilelang/engine/lower.py

Lines changed: 10 additions & 63 deletions
@@ -2,7 +2,6 @@
 # Licensed under the MIT License.
 """The compiler for TL programs."""
 
-import tilelang as tl
 import os
 import os.path as osp
 from typing import Union, Optional, Callable
@@ -12,6 +11,10 @@
 from tvm.target import Target
 from tilelang.contrib import hipcc, nvcc
 from tilelang.utils.target import determine_target
+from tilelang.engine.phase import (
+    LowerAndLegalize,
+    OptimizeForTarget,
+)
 
 
 def is_cpu_device_backend(target: Target):
@@ -152,68 +155,12 @@ def lower(
     _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target))
     _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target))
 
-    mod = tir.transform.BindTarget(target)(mod)
-
-    mod = tl.transform.FrontendLegalize()(mod)
-    mod = tir.transform.Simplify()(mod)
-    mod = tl.transform.LayoutInference()(mod)
-    mod = tl.transform.LowerTileOp()(mod)
-    mod = tl.transform.LegalizeVectorizedLoop()(mod)
-    mod = tl.transform.LegalizeSafeMemoryAccess()(mod)
-    # Inject Simplify to remove the duplicated conditions
-    mod = tir.transform.Simplify()(mod)
-
-    # which may be introduced by the LegalizeSafeMemoryAccess
-    if target.arch == "sm_90":
-        mod = tl.transform.MultiVersionBuffer()(mod)
-        mod = tl.transform.WarpSpecialized()(mod)
-        mod = tl.transform.InjectSoftwarePipeline()(mod)
-        mod = tir.transform.LowerOpaqueBlock()(mod)
-        # mod = tl.transform.WarpSpecializedPipeline()(mod)
-        mod = tl.transform.InjectFenceProxy()(mod)
-    else:
-        mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
-        mod = tl.transform.PipelinePlanning()(mod)
-        mod = tl.transform.InjectSoftwarePipeline()(mod)
-
-    mod = tir.transform.LowerOpaqueBlock()(mod)
-    mod = tir.transform.FlattenBuffer()(mod)
-    mod = tir.transform.NarrowDataType(32)(mod)
-    mod = tir.transform.Simplify()(mod)
-    mod = tl.transform.VectorizeLoop()(mod)
-    mod = tir.transform.StorageRewrite()(mod)
-    mod = tir.transform.UnrollLoop()(mod)
-    mod = tir.transform.RenormalizeSplitPattern()(mod)
-    mod = tir.transform.Simplify()(mod)
-    mod = tir.transform.RemoveNoOp()(mod)
-    mod = tir.transform.RewriteUnsafeSelect()(mod)
-    mod = tir.transform.HoistIfThenElse()(mod)
-
-    mod = tir.transform.VerifyMemory()(mod)
-    mod = tir.transform.AnnotateEntryFunc()(mod)
-    # TODO(lei): This is a hack to make sure the
-    # thread level allreduce pass can be applied
-    # in TL. As Tl only use one thread dimension
-    # the var binding information will be lost
-    # in the lowering process with Legalization
-    # and Simplify pass.
-    # We can find a way better to create var instead
-    # of putting the LowerThreadAllreduce before
-    # the Legalization.
-    mod = tl.transform.ThreadPartialSync("shared.dyn")(mod)
-    mod = tir.transform.InferFragment()(mod)
-    mod = tir.transform.LowerThreadAllreduce()(mod)
-    mod = tl.transform.LowerHopperIntrin()(mod)
-    mod = tl.transform.ThreadSync("shared")(mod)
-    mod = tl.transform.ThreadSync("shared.dyn")(mod)
-    mod = tir.transform.InjectPTXAsyncCopy()(mod)
-
-    mod = tl.transform.AnnotateDeviceRegions()(mod)
-    mod = tir.transform.SplitHostDevice()(mod)
-    mod = tir.transform.MergeSharedMemoryAllocations()(mod)
-
-    mod = tl.transform.MakePackedAPI()(mod)
-    mod = tir.transform.LowerDeviceKernelLaunch()(mod)
+    # Phase 1: Lower and legalize the IR
+    mod = LowerAndLegalize(mod, target)
+
+    # Phase 2: Optimize the IR for the target
+    mod = OptimizeForTarget(mod, target)
+
     host_mod = tir.transform.Filter(_is_host_call)(mod)
     host_mod = tir.transform.BindTarget(target_host)(host_mod)
     host_mod = tir.transform.FP8StorageLegalize()(host_mod)
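The two phase functions are imported from the new `tilelang/engine/phase.py` described in the commit message. The sketch below is a hypothetical reconstruction of that module, obtained simply by regrouping the pass sequence deleted above into two functions; the real file may split or order the passes differently (in particular, where Phase 1 ends and Phase 2 begins is an assumption).

# Hypothetical sketch of tilelang/engine/phase.py; it regroups the pass
# sequence removed from lower() above, and is not the commit's actual code.
from tvm import tir
from tvm.ir import IRModule
from tvm.target import Target
import tilelang as tl


def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
    # Phase 1: bind the target and legalize the frontend IR.
    mod = tir.transform.BindTarget(target)(mod)
    mod = tl.transform.FrontendLegalize()(mod)
    mod = tir.transform.Simplify()(mod)
    mod = tl.transform.LayoutInference()(mod)
    mod = tl.transform.LowerTileOp()(mod)
    mod = tl.transform.LegalizeVectorizedLoop()(mod)
    mod = tl.transform.LegalizeSafeMemoryAccess()(mod)
    # Remove duplicated conditions introduced by LegalizeSafeMemoryAccess.
    mod = tir.transform.Simplify()(mod)
    return mod


def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    # Phase 2: target-specific pipelining, flattening, and storage optimizations.
    if target.arch == "sm_90":
        mod = tl.transform.MultiVersionBuffer()(mod)
        mod = tl.transform.WarpSpecialized()(mod)
        mod = tl.transform.InjectSoftwarePipeline()(mod)
        mod = tir.transform.LowerOpaqueBlock()(mod)
        mod = tl.transform.InjectFenceProxy()(mod)
    else:
        mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
        mod = tl.transform.PipelinePlanning()(mod)
        mod = tl.transform.InjectSoftwarePipeline()(mod)

    mod = tir.transform.LowerOpaqueBlock()(mod)
    mod = tir.transform.FlattenBuffer()(mod)
    mod = tir.transform.NarrowDataType(32)(mod)
    mod = tir.transform.Simplify()(mod)
    mod = tl.transform.VectorizeLoop()(mod)
    mod = tir.transform.StorageRewrite()(mod)
    mod = tir.transform.UnrollLoop()(mod)
    mod = tir.transform.RenormalizeSplitPattern()(mod)
    mod = tir.transform.Simplify()(mod)
    mod = tir.transform.RemoveNoOp()(mod)
    mod = tir.transform.RewriteUnsafeSelect()(mod)
    mod = tir.transform.HoistIfThenElse()(mod)
    mod = tir.transform.VerifyMemory()(mod)
    mod = tir.transform.AnnotateEntryFunc()(mod)
    mod = tl.transform.ThreadPartialSync("shared.dyn")(mod)
    mod = tir.transform.InferFragment()(mod)
    mod = tir.transform.LowerThreadAllreduce()(mod)
    mod = tl.transform.LowerHopperIntrin()(mod)
    mod = tl.transform.ThreadSync("shared")(mod)
    mod = tl.transform.ThreadSync("shared.dyn")(mod)
    mod = tir.transform.InjectPTXAsyncCopy()(mod)
    mod = tl.transform.AnnotateDeviceRegions()(mod)
    mod = tir.transform.SplitHostDevice()(mod)
    mod = tir.transform.MergeSharedMemoryAllocations()(mod)
    mod = tl.transform.MakePackedAPI()(mod)
    mod = tir.transform.LowerDeviceKernelLaunch()(mod)
    return mod

Grouping the pipeline this way keeps lower() readable and lets other entry points reuse the same two phases, which is the maintainability benefit the commit message describes.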
