
Commit a2ae88c

[Example] Update GEMM FP8 Example (#123)
* Add DeepSeek MLA decode example with Flash Attention implementation

* Add GEMM SplitK and StreamK example implementations

  This introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
  - `example_tilelang_gemm_splitk.py`: implements a Split-K GEMM kernel using TileLang
  - `example_tilelang_gemm_streamk.py`: implements a Stream-K GEMM kernel using TileLang

  Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing against PyTorch reference implementations.

* Refactor GEMM SplitK and StreamK example implementations

  Clean up and improve code formatting in the SplitK and StreamK GEMM example scripts:
  - Remove an unused import (Profiler) in the SplitK example
  - Simplify line breaks and improve code readability
  - Standardize indentation and remove unnecessary whitespace
  - Optimize atomic add and copy operations for better clarity

* Add block sparse attention benchmarks for multiple libraries

  This introduces comprehensive block sparse attention benchmarks for different libraries:
  - TileLang block sparse FMHA implementation
  - Triton block sparse FMHA implementation
  - PyTorch reference block sparse FMHA implementation
  - FlashAttention dense FMHA reference implementation

  The benchmarks include:
  - Configurable benchmark parameters (batch size, heads, sequence length, etc.)
  - Sparse mask generation using top-k and threshold methods
  - Performance measurement for different sparse attention configurations
  - Utility functions for mask generation and benchmarking

* Refactor block sparse attention benchmarks with code style improvements

  - Add Ruff linter ignore comments to benchmark files
  - Improve code formatting and line breaks
  - Remove unused imports
  - Standardize print statement formatting
  - Enhance code readability across the benchmarks for all libraries

* lint fix

* Add CUDA atomic operations for BFLOAT16 and update function naming

  - Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in the CUDA common header
  - Rename the existing atomic add functions to PascalCase (atomicAdd -> AtomicAdd)
  - Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values
  - Update kernel and language customization to use the new function names
  - Add return type annotations in the profiler module

* lint fix

* Add an example for the Group Query Attention (GQA) forward pass using Flash Attention in TileLang

  The new example script `example_gqa_fwd_bshd.py` demonstrates:
  - a Group Query Attention (GQA) implementation
  - the Flash Attention forward pass
  - performance benchmarking
  - configurable parameters for batch, heads, sequence length, and dimension
  - autotuning support
  - comparison against a reference implementation

* Refactor the IR lowering pipeline into modular phases

  A new module, `phase.py`, modularizes the IR lowering process by splitting the lowering pipeline into two distinct phases:
  - `LowerAndLegalize`: handles initial IR legalization and transformation
  - `OptimizeForTarget`: applies target-specific optimizations

  The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.

* lint fix

* NSA kernel

* Enhance the Native Sparse Attention examples with code improvements and parameter updates

  - Update example_tilelang_nsa.py and example_triton_nsa.py with code formatting and style improvements
  - Increase the default number of heads and selected blocks in the TileLang NSA example
  - Add Ruff linter ignore comments to reference.py
  - Standardize function signatures and improve code readability across the NSA implementations

* Add utility math functions for integer operations

  - Implement `next_power_of_2()` to compute the next power of 2 for an integer
  - Add `cdiv()` for ceiling division of integers
  (a minimal sketch of both helpers follows this commit message)

* Refactor the DeepSeek MLA decode example with an enhanced Flash Attention implementation

  - Update the flash attention kernel to support positional embeddings (PE)
  - Modify the reference implementation to handle PE and group query attention
  - Increase the default batch size and adjust the benchmarking parameters
  - Improve kernel performance and readability
  - Use einops and torch operations for more flexible tensor manipulation

* Update README.md with the corrected Flash MLA Decoding example path

  - Point the Flash MLA Decoding example link at the correct directory
  - Ensure accurate navigation to the DeepSeek MLA decoding example
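The utility math helpers listed above are simple enough to illustrate. Below is a minimal sketch of what `next_power_of_2()` and `cdiv()` compute; only the two function names come from the commit message, so the bodies and the usage values are assumptions rather than the code actually added in this commit.

```python
# Hedged sketch of the two integer helpers named in the commit message.
# The actual tilelang implementation may differ in detail.

def next_power_of_2(n: int) -> int:
    """Smallest power of 2 greater than or equal to n (n >= 1 assumed)."""
    return 1 << (n - 1).bit_length()

def cdiv(a: int, b: int) -> int:
    """Ceiling division: how many size-b tiles are needed to cover a elements."""
    return (a + b - 1) // b

# Example: tiling an 8192-long KV sequence into blocks of 64.
assert next_power_of_2(5) == 8
assert cdiv(8192, 64) == 128
```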
1 parent 375423c commit a2ae88c

4 files changed: +324, -354 lines


README.md

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ Although tile-lang aims to be portable across a range of Devices, it has been sp
 - [Dequantization GEMM](./examples/dequantize_gemm/)
 - [Flash Attention](./examples/flash_attention/)
 - [Flash Linear Attention](./examples/linear_attention/)
-- [Flash MLA Decoding](./examples/flash_decoding/example_mla_decode.py)
+- [Flash MLA Decoding](./examples/deepseek_mla/)
 - [Native Sparse Attention](./examples/native_sparse_attention/)
 
 Within the `examples` directory, you will also find additional complex kernels—such as convolutions, forward/backward passes for FlashAttention, more operators will continuously be added.

examples/deepseek_mla/example_mla_decode.py

Lines changed: 84 additions & 64 deletions
@@ -3,17 +3,13 @@
 import tilelang
 from tilelang.autotuner import *
 import tilelang.language as T
+from einops import rearrange, einsum
 
-num_split = 4
+num_split = 1
 
 
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
     scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
-    shape_q = [batch, heads, (dim + pe_dim)]
-    shape_k = [batch, seqlen_kv, kv_head_num, (dim + pe_dim)]
-    shape_v = [batch, seqlen_kv, kv_head_num, dim]
-    shape_o = [batch, heads, dim]
-    part_shape = [batch, heads, num_split, dim]
     dtype = "float16"
     accum_dtype = "float"
     kv_group_num = heads // kv_head_num
@@ -22,19 +18,23 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
 
     @T.macro
     def flash_attn_split(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
+            Q: T.Buffer([batch, heads, dim], dtype),
+            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
+            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
+            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
             glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
     ):
         with T.Kernel(
-                batch, heads // min(block_H, kv_group_num), num_split, threads=128) as (bx, by, bz):
-            Q_shared = T.alloc_shared([block_H, (dim + pe_dim)], dtype)
-            K_shared = T.alloc_shared([block_N, (dim + pe_dim)], dtype)
-            V_shared = T.alloc_shared([block_N, dim], dtype)
+                batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
+            Q_shared = T.alloc_shared([block_H, dim], dtype)
+            S_shared = T.alloc_shared([block_H, block_N], dtype)
+            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
+            KV_shared = T.alloc_shared([block_N, dim], dtype)
+            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
             O_shared = T.alloc_shared([block_H, dim], dtype)
             acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
+            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
             acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
             acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
             scores_max = T.alloc_fragment([block_H], accum_dtype)
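For orientation, the hunk above replaces the fused Q/K/V buffers with a query plus a separate positional-embedding (PE) part, and a single latent KV cache plus its PE part. The snippet below only illustrates host-side tensors matching those `T.Buffer` shapes, using the sizes set in the `__main__` block further down; the dtype and device choices are assumptions for illustration, not code from this commit.

```python
# Illustrative allocation of inputs with the shapes declared in flash_attn_split.
import torch

batch, heads, kv_head_num, seqlen_kv, dim, pe_dim = 128, 128, 1, 8192, 512, 64
num_split = 1  # new default in this commit

q    = torch.randn(batch, heads, dim, dtype=torch.float16, device="cuda")
q_pe = torch.randn(batch, heads, pe_dim, dtype=torch.float16, device="cuda")
kv   = torch.randn(batch, seqlen_kv, kv_head_num, dim, dtype=torch.float16, device="cuda")
k_pe = torch.randn(batch, seqlen_kv, kv_head_num, pe_dim, dtype=torch.float16, device="cuda")

# Scratch buffers for the split-KV reduction (glse and Output_partial).
glse           = torch.empty(batch, heads, num_split, dtype=torch.float16, device="cuda")
output_partial = torch.empty(batch, heads, num_split, dim, dtype=torch.float16, device="cuda")
```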
@@ -53,20 +53,32 @@ def flash_attn_split(
             })
 
             T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
+            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
 
             loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
-            for k in T.Pipelined(loop_range, num_stages=1):
-                T.copy(
-                    K[bid, (seqlen_kv // num_split) * sid +
-                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                      cur_kv_head, :], K_shared)
-                T.clear(acc_s)
-                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+            for k in T.Pipelined(loop_range, num_stages=2):
+                kv_start = (seqlen_kv // num_split) * sid + k * block_N
+                kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N
+
+                T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
+                T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
+
+                T.clear(acc_s_0)
+                T.gemm(
+                    Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
+                T.gemm(
+                    Q_pe_shared,
+                    K_pe_shared,
+                    acc_s_0,
+                    transpose_B=True,
+                    policy=T.GemmWarpPolicy.FullCol)
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
+                T.copy(acc_s_0, S_shared)
+                T.copy(S_shared, acc_s)
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
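The two GEMMs into `acc_s_0` rely on a simple identity: because the query and key features are split into a content part and a PE part along the last dimension, summing the two partial products equals one product over the concatenated features. The sketch below demonstrates that identity with plain torch tensors at illustrative sizes; it is not kernel code.

```python
# Equivalence behind accumulating both GEMMs into acc_s_0:
#   Q @ KV^T + Q_pe @ K_pe^T == [Q | Q_pe] @ [KV | K_pe]^T
import torch

block_H, block_N, dim, pe_dim = 64, 64, 512, 64
q, q_pe = torch.randn(block_H, dim), torch.randn(block_H, pe_dim)
kv, k_pe = torch.randn(block_N, dim), torch.randn(block_N, pe_dim)

scores_split = q @ kv.T + q_pe @ k_pe.T                          # two partial GEMMs
scores_concat = torch.cat([q, q_pe], -1) @ torch.cat([kv, k_pe], -1).T

assert torch.allclose(scores_split, scores_concat, atol=1e-3)
```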
@@ -78,11 +90,7 @@ def flash_attn_split(
                 T.copy(acc_s, acc_s_cast)
                 for i, j in T.Parallel(block_H, dim):
                     acc_o[i, j] *= scores_scale[i]
-                T.copy(
-                    V[bid, (seqlen_kv // num_split) * sid +
-                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                      cur_kv_head, :], V_shared)
-                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
+                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
             for i, j in T.Parallel(block_H, dim):
                 acc_o[i, j] /= logsum[i]
             for i in T.Parallel(block_H):
@@ -96,8 +104,8 @@ def flash_attn_split(
     @T.macro
     def combine(
             glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
-            Output: T.Buffer(shape_o, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
+            Output: T.Buffer([batch, heads, dim], dtype),
     ):
         with T.Kernel(heads, batch, threads=128) as (by, bz):
             po_local = T.alloc_fragment([dim], dtype)
@@ -133,50 +141,61 @@ def combine(
 
     @T.prim_func
     def main(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
+            Q: T.Buffer([batch, heads, dim], dtype),
+            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
+            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
+            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
             glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),  # [batch, heads, num_split, dim]
-            Output: T.Buffer(shape_o, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
+            Output: T.Buffer([batch, heads, dim], dtype),
     ):
-        flash_attn_split(Q, K, V, glse, Output_partial)
+        flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
         combine(glse, Output_partial, Output)
 
     return main
 
 
-def ref_program(query, key, value, glse, Output_partial):
+def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     # """
     # Inputs:
-    # - query (Tensor): [batch, heads, dim]
-    # - key (Tensor): [batch, seqlen_kv, kv_head_num, dim]
-    # - value (Tensor): [batch, seqlen_kv, kv_head_num, dim]
-
+    # - q (Tensor): [batch, heads, dim]
+    # - q_pe (Tensor): [batch, heads, pe_dim]
+    # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
+    # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
+    # - glse (Tensor): [batch, heads, num_split]
+    # - Output_partial (Tensor): [batch, heads, num_split, dim]
     # Outputs:
     # - output (Tensor): [batch, heads, dim]
     # """
-    from einops import rearrange
-    batch_size, query_heads, dim = query.shape  # [batch_size, query_heads, dim]
-    _, seqlen_kv, kv_heads, _ = key.shape  # [batch_size, seqlen_kv, kv_heads, kv_dim]
-    dim_v = value.shape[-1]
-    assert kv_heads == 1, "kv_heads must be 1"
-
-    query_expanded = rearrange(query, 'b h d -> b h 1 d')  # [batch_size, query_heads, 1, dim]
-    key_expanded = key.expand(-1, -1, query_heads, -1)  # [batch_size, query_heads, seqlen_kv, dim]
-    value_expanded = value.expand(-1, -1, query_heads,
-                                  -1)  # [batch_size, query_heads, seqlen_kv, dim]
-    key_expanded = rearrange(key_expanded,
-                             'b n h d -> b h n d')  # [batch_size, kv_head_num, seqlen_kv, dim]
-    value_expanded = rearrange(value_expanded,
-                               'b n h d -> b h n d')  # [batch_size, query_heads, seqlen_kv, dim]
-
-    scores = torch.matmul(query_expanded,
-                          key_expanded.transpose(-1, -2))  # [batch_size, query_heads, 1, seqlen_kv]
-    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
-    attention_weights = F.softmax(scores, dim=-1)  # [batch_size, query_heads, 1, seqlen_kv]
-    output = torch.matmul(attention_weights, value_expanded)  # [batch_size, query_heads, 1, dim]
-    return output.view(batch_size, query_heads, dim_v)
+    dim = q.shape[-1]
+    pe_dim = q_pe.shape[-1]
+    num_head_groups = q.shape[1] // kv.shape[2]
+    scale = (dim + pe_dim)**0.5
+    q = rearrange(
+        q, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
+
+    q_pe = rearrange(
+        q_pe, 'b (h g) d -> b g h d',
+        g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
+
+    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
+
+    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, num_head_groups, groups, pe_dim]
+
+    query = torch.concat([q, q_pe], dim=-1)
+    key = torch.concat([kv, k_pe], dim=-1)
+
+    scores = einsum(
+        query, key,
+        'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]
+
+    attention = F.softmax(
+        scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]
+
+    out = einsum(attention, kv,
+                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
+    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    return out
 
 
 def flash_split_ref(Q, K, V):
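The rewritten `ref_program` can be exercised on its own with the shapes documented in its comment block. The call below is a hedged usage sketch, assuming the module's imports (`torch`, `torch.nn.functional as F`, `einops`) are in scope; the small sizes and default dtype are chosen only for illustration, and `glse` / `output_partial` are passed solely to match the kernel signature (the reference path does not read them).

```python
# Illustrative standalone call to ref_program with the documented shapes.
import torch

batch, heads, kv_heads, seqlen_kv, dim, pe_dim, num_split = 2, 16, 1, 256, 128, 32, 1
q    = torch.randn(batch, heads, dim)
q_pe = torch.randn(batch, heads, pe_dim)
kv   = torch.randn(batch, seqlen_kv, kv_heads, dim)
k_pe = torch.randn(batch, seqlen_kv, kv_heads, pe_dim)
glse           = torch.empty(batch, heads, num_split)
output_partial = torch.empty(batch, heads, num_split, dim)

out = ref_program(q, q_pe, kv, k_pe, glse, output_partial)
print(out.shape)  # torch.Size([2, 16, 128])
```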
@@ -251,7 +270,7 @@ def reduce_ref(Q, K, V, glse, Output_partial):
 
 
 if __name__ == "__main__":
-    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 64, 128, 1, 8192, 512, 64
+    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 128, 128, 1, 8192, 512, 64
     qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE)
     pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD
     total_flops = qk_flops + pv_flops
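For reference, the FLOP model above works out as follows with the new defaults; this is just the arithmetic of the two expressions written out.

```python
# Worked numbers for the FLOP estimate with BATCH=128, H_Q=128, KV_CTX=8192,
# D_HEAD=512, DPE=64 (the values set in __main__ after this change).
BATCH, H_Q, KV_CTX, D_HEAD, DPE = 128, 128, 8192, 512, 64

qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE)  # 154,618,822,656 ~ 154.6 GFLOP
pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD          # 137,438,953,472 ~ 137.4 GFLOP
total_flops = qk_flops + pv_flops                     # 292,057,776,128 ~ 292 GFLOP

# With latency measured in milliseconds, total_flops / latency * 1e-9 is
# GFLOP per ms, i.e. TFLOP/s, which is what the final print statement reports.
```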
@@ -260,8 +279,9 @@ def reduce_ref(Q, K, V, glse, Output_partial):
 
     program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H)
     mod, params = tilelang.lower(program)
-    mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
+    mod = tilelang.Profiler(mod, params, [6], tilelang.TensorSupplyType.Normal)
     mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
-    latency = mod.do_bench(mod.func, warmup=500)
+    print("All close")
+    latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch")
     print("Tile-lang: {:.2f} ms".format(latency))
-    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
+    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
