
Commit 0bbd063

[Dev][Doc] Add DeepSeek MLA Decode Example with Documentation and Performance Benchmarks (#134)
* [Dev] Add RetNet Linear Attention example

* [Dev] Add WgmmaSync rewriter for pipelined WGMMA operations and add MHA WGMMA pipelined example (FA3-like scheduling)

  This commit introduces a new transformation pass `RewriteWgmmaSync` to optimize warp group matrix multiply accumulate (WGMMA) operations in the TileLang compiler:
  - Implemented `WgmmaSyncRewriter` in `src/transform/wgmma_sync_rewriter.cc`
  - Added pass registration for `RewriteWgmmaSync`
  - Updated `tilelang/engine/phase.py` to include the new transformation pass
  - Updated `tilelang/transform/__init__.py` to expose the new pass

  The rewriter intelligently manages synchronization and dependencies between WGMMA operations, improving pipeline efficiency for complex matrix multiplication kernels.

* [Bugfix] Fix bug in ThreadTagChecker for warp specialization

  Improve thread tag validation in the warp-specialized rewriter to prevent unintended transformations:
  - Add more precise checks for threadIdx.y and threadIdx.z
  - Validate thread extent to ensure only single-extent thread bindings are allowed
  - Prevent warp specialization for multi-extent thread bindings in y and z dimensions

* lint

* [CI] Add TMA descriptor attribute to transformed module in test case

* [Dev] Refactor DeepSeek MLA Decode Example with Non-Split and Split Flash Attention Implementations
  - Add new `flash_attn` macro for non-split flash attention implementation
  - Add swizzled layout for tile in shared memory
  - Use threadblock swizzle to improve L2 cache hit rate

* [Dev] Add DeepSeek MLA Decode Example with Documentation and Performance Benchmarks
  - Add detailed README.md explaining MLA (Multi-Head Latent Attention) implementation
  - Include performance benchmark images for batch sizes 64 and 128
  - Add layout visualization images for QK and PV operations
  - Implement torch reference implementations in torch_refs.py
  - Update example_mla_decode.py with command-line argument support and flexible configuration
  - Add performance benchmarking and comparison with other implementations
1 parent 38d13c2 commit 0bbd063

File tree

7 files changed: +339 -87 lines changed

examples/deepseek_mla/README.md

Lines changed: 140 additions & 0 deletions
# 🚀 How to write high-performance kernels with TileLang: MLA as an example

TileLang is a user-friendly AI programming language that significantly lowers the barrier to kernel programming, helping users quickly build customized operators. However, users still need to master certain programming techniques to take full advantage of TileLang's capabilities. Here, we use MLA as an example to demonstrate how to write high-performance kernels with TileLang.
## Introduction to MLA

DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism known for its hardware efficiency and significant improvements in model inference speed. In February 2025, [FlashMLA](https://github.com/deepseek-ai/FlashMLA) was open-sourced on GitHub. FlashMLA utilizes [CUTLASS](https://github.com/NVIDIA/cutlass) templates and incorporates optimization techniques from [FlashAttention](https://github.com/Dao-AILab/flash-attention), achieving impressive performance. Subsequently, various deep learning compilers (such as [Triton](https://github.com/triton-lang/triton)) and libraries (such as [FlashInfer](https://github.com/flashinfer-ai/flashinfer)) have released their own MLA implementations.
## Benchmark Results

We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashInfer at batch sizes 64 and 128 with the float16 data type, as shown in the figures below.
<figure style="text-align: center">
  <a href="./bs64_float16.png">
    <img src="./bs64_float16.png" alt="bs64_float16">
  </a>
  <figcaption>Figure 1: Performance under batch size = 64</figcaption>
</figure>

<figure style="text-align: center">
  <a href="./bs128_float16.png">
    <img src="./bs128_float16.png" alt="bs128_float16">
  </a>
  <figcaption>Figure 2: Performance under batch size = 128</figcaption>
</figure>
As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton.
Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this.
## Implementation

First, let's review the core computation logic of traditional FlashAttention:
```python
# acc_s: [block_M, block_N]
# scores_max: [block_M]
# scores_scale: [block_M]
# acc_o: [block_M, dim]

for i in range(loop_range):
    acc_s = Q @ K[i]
    scores_max_prev = scores_max
    scores_max = max(acc_s, dim=1)
    scores_scale = exp(scores_max_prev - scores_max)
    acc_o *= scores_scale
    acc_s = exp(acc_s - scores_max)
    acc_o += acc_s @ V[i]   # accumulate, after rescaling by scores_scale above
    ...
```
Here, `acc_s` holds the `Q @ K` result of each iteration with shape `[block_M, block_N]`, while `acc_o` holds the running output with shape `[block_M, dim]`. Both `acc_s` and `acc_o` need to stay in registers to reduce latency.
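For readers who want to check the arithmetic end to end, below is a minimal, runnable PyTorch reference of this online-softmax loop, including the final normalization that the `...` above elides. The function name, the running row maximum, and the `logsum` bookkeeping are our additions for clarity and numerical safety; this is a sketch of the algorithm, not code taken from the example.

```python
import torch

def flash_attn_reference(Q, K, V, block_N=64):
    """Online-softmax attention for one query block.

    Q: [block_M, dim], K/V: [seq_len, dim]. Variable names mirror the
    pseudocode above; everything is kept in fp32 for simplicity.
    """
    block_M, dim = Q.shape
    scale = dim ** -0.5

    acc_o = torch.zeros(block_M, dim)
    scores_max = torch.full((block_M,), float("-inf"))
    logsum = torch.zeros(block_M)

    for start in range(0, K.shape[0], block_N):
        K_i = K[start:start + block_N].float()
        V_i = V[start:start + block_N].float()

        acc_s = (Q.float() @ K_i.T) * scale                      # [block_M, block_N]
        scores_max_prev = scores_max
        scores_max = torch.maximum(scores_max_prev, acc_s.max(dim=1).values)
        scores_scale = torch.exp(scores_max_prev - scores_max)   # rescale old partial sums
        acc_o *= scores_scale[:, None]
        acc_s = torch.exp(acc_s - scores_max[:, None])
        acc_o += acc_s @ V_i
        logsum = logsum * scores_scale + acc_s.sum(dim=1)

    return acc_o / logsum[:, None]                               # final normalization
```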
Compared to traditional attention operators like MHA (Multi-Head Attention) or GQA (Grouped Query Attention), a major challenge in optimizing MLA is its large head dimensions: `query` and `key` have a head dimension of 576 (512 + 64), while `value` has a head dimension of 512. This creates a significant issue: `acc_o` becomes very large, and with an insufficient number of threads (e.g., 128 threads), register spilling occurs, severely impacting performance.

This raises the question of how to partition the matrix multiplication. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` requires a minimum M dimension of 64, so each warpgroup's M tile can be reduced only to 64, and a 64x512 accumulator tile is too large for a single warpgroup, leading to register spilling.
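A quick back-of-the-envelope calculation (assuming fp32 accumulation and the usual 255-register-per-thread CUDA limit) shows why a single warpgroup cannot hold a 64x512 accumulator:

```python
# One 128-thread warpgroup holding a 64 x 512 fp32 accumulator:
values_per_thread = 64 * 512 // 128   # = 256 fp32 values = 256 registers per thread
# CUDA caps a thread at 255 registers, and the kernel still needs registers for
# acc_s, the softmax statistics, addresses, etc., so acc_o alone already spills.
```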
Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right halves of `acc_o`, respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input.

Our solution is to have each warpgroup compute half of `acc_s` during the `Q @ K` computation, then obtain the other half, computed by the other warpgroup, through shared memory.
### Layout Inference

While the above process may seem complex, don't worry: TileLang handles all these intricacies for you.

Figure 3 and Figure 4 illustrate the frontend TileLang script and its corresponding execution plan for MLA. Here, `T.gemm` represents a matrix multiplication operation, `transpose_B=True` indicates that matrix B is transposed, and `policy=FullCol` specifies that each warpgroup computes one column block (i.e., the result matrix is split along the column dimension). `T.copy` represents a buffer-to-buffer copy.
<figure style="text-align: center">
  <a href="./qk_layout.jpg">
    <img src="./qk_layout.jpg" alt="QK Layout">
  </a>
  <figcaption>Figure 3: Buffer shapes in Q @ K</figcaption>
</figure>

<figure style="text-align: center">
  <a href="./pv_layout.jpg">
    <img src="./pv_layout.jpg" alt="PV Layout">
  </a>
  <figcaption>Figure 4: Buffer shapes in acc_s @ V</figcaption>
</figure>
The mapping from TileLang frontend code to an execution plan is accomplished through layout inference, a core optimization technique in TileLang. It automatically deduces the required buffer shapes and optimal layouts based on the tile operators (such as `T.gemm` and `T.copy`), then generates the corresponding code. Here, we walk through a concrete example of buffer shape inference in MLA.

For instance, when computing `Q @ K`, TileLang infers that each warpgroup's `acc_s_0` shape should be `[block_M, block_N / 2]` based on the `policy=FullCol` annotation in `T.gemm`. Since this is followed by an `acc_s @ V` operation, also with `policy=FullCol`, which requires each warpgroup to have the complete `acc_s` result, TileLang deduces that `acc_s`'s shape at this point should be `[block_M, block_N]`. Consequently, TileLang can propagate the inference further, determining that both `S_shared` and `acc_s` in `T.copy(S_shared, acc_s)` should have shape `[block_M, block_N]`.
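To make the inferred dataflow concrete, here is a simplified TileLang-style sketch of the two GEMMs and the shared-memory exchange. Buffer names, shapes, and the exact spelling of the policy enum are illustrative rather than copied from `example_mla_decode.py` (and the softmax/rescaling steps between the two GEMMs are omitted); treat it as a picture of the inference described above, not as the actual kernel:

```python
acc_s = T.alloc_fragment((block_M, block_N), accum_dtype)  # each warpgroup materializes half
S_shared = T.alloc_shared((block_M, block_N), dtype)       # full tile visible to both warpgroups
acc_o = T.alloc_fragment((block_M, dim), accum_dtype)      # physically split: half of dim per warpgroup

# Q @ K: policy=FullCol splits the block_N columns across the two warpgroups,
# so each warpgroup computes only a [block_M, block_N / 2] slice of acc_s.
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)

# Exchange: publish the local half to shared memory, then read back the full tile.
T.copy(acc_s, S_shared)
T.copy(S_shared, acc_s)

# acc_s @ V: FullCol again splits dim across the warpgroups, but each warpgroup
# now holds the complete [block_M, block_N] operand it needs.
T.gemm(acc_s, V_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
```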
It's worth noting that our scheduling approach differs from FlashMLA's implementation strategy. In FlashMLA, `Q @ K` is assigned to a single warpgroup, while the `acc_o` partitioning scheme remains consistent with ours. Nevertheless, our scheduling approach still achieves comparable performance.
### Threadblock Swizzling

Threadblock swizzling is a common technique in GPU kernel optimization. On GPUs, the L2 cache is a high-speed cache shared among multiple SMs (Streaming Multiprocessors). Threadblock swizzling optimizes data access patterns by remapping the scheduling order of threadblocks, thereby improving the L2 cache hit rate. Traditional scheduling typically executes threadblocks in the natural order of the grid, which can lead to non-contiguous data access patterns between adjacent threadblocks and inefficient utilization of cached data. Swizzling applies a mathematical mapping (such as a diagonal or interleaved mapping) to adjust the execution order of threadblocks, ensuring that consecutively scheduled threadblocks access adjacent or overlapping data regions.

In TileLang, threadblock swizzling can be enabled with a single line of Python code:
```python
T.use_swizzle(panel_size: int, order: str = "row")
```
Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either `"row"` or `"col"`.
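The effect of the remapping can be pictured with a few lines of plain Python. The function below is a conceptual model of panel-based swizzling (it is not TileLang's internal formula): threadblocks are walked panel by panel, so blocks launched back-to-back work on neighbouring tiles and can reuse each other's L2-resident data.

```python
def swizzled_block_order(grid_m, grid_n, panel_size):
    """Yield (by, bx) tile coordinates in panel-major order."""
    order = []
    for panel_start in range(0, grid_m, panel_size):
        panel_end = min(panel_start + panel_size, grid_m)
        for bx in range(grid_n):                      # sweep columns...
            for by in range(panel_start, panel_end):  # ...within a narrow band of rows
                order.append((by, bx))
    return order

# Natural order finishes an entire row of tiles before touching the next one;
# panel order keeps `panel_size` rows "hot" in L2 while sweeping across columns.
print(swizzled_block_order(grid_m=4, grid_n=4, panel_size=2))
```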
### Shared Memory Swizzling

In CUDA programming, shared memory is divided into multiple memory banks, and each bank can service one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance.

One common strategy to address bank conflicts is shared memory swizzling. This technique rearranges how data is stored in shared memory by remapping addresses that would originally fall into the same bank to different banks, thereby reducing conflicts. For example, XOR operations or other bit manipulations can be incorporated into address calculations to alter the data layout, resulting in memory accesses that are more evenly distributed across consecutive threads. This is particularly important for high-performance computing tasks like matrix multiplication and convolution, as it can significantly improve memory access parallelism and overall execution efficiency.
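As a small illustration of the XOR trick, and independent of the swizzle function TileLang actually generates, the snippet below assumes 32 four-byte banks and compares which banks a column read hits with and without swizzling:

```python
BANKS = 32       # 4-byte shared memory banks on current NVIDIA GPUs
ROW_WORDS = 32   # one row of a 32 x 32 fp32 tile

def bank_naive(row, col):
    return (row * ROW_WORDS + col) % BANKS

def bank_swizzled(row, col):
    # XOR the column with the row index before computing the address.
    return (row * ROW_WORDS + (col ^ (row % BANKS))) % BANKS

# 16 threads reading the same column of consecutive rows:
print([bank_naive(r, 0) for r in range(16)])     # all land in bank 0 -> 16-way conflict
print([bank_swizzled(r, 0) for r in range(16)])  # spread across 16 different banks
```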
Similarly, TileLang also supports shared memory swizzling. Users only need to add a single line of Python code:
```python
T.annotate_layout({
    S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
```
Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout.
### Warp Specialization

The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (the Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computation. However, this programming pattern is complex, requiring developers to manually manage the execution logic of producers and consumers, including synchronization through `mbarrier` objects.

In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, and TileLang handles all producer-consumer synchronization automatically, enabling efficient computation.
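For intuition about what that transformation produces, the toy program below mimics the resulting producer-consumer structure on the CPU: one "producer" thread stands in for the TMA warpgroup, a "consumer" stands in for the compute warpgroups, and a pair of single-slot queues per stage plays the role that `mbarrier` arrive/wait pairs play on the GPU. It is an analogy only, not generated code.

```python
import threading, queue

NUM_STAGES, LOOP_RANGE = 2, 8
# 'produced' hands loaded tiles to the consumer; 'consumed' hands free slots
# back to the producer.
produced = [queue.Queue(maxsize=1) for _ in range(NUM_STAGES)]
consumed = [queue.Queue(maxsize=1) for _ in range(NUM_STAGES)]
for q in consumed:
    q.put(None)                        # every shared-memory slot starts out free

def producer():                        # "TMA warpgroup": only moves data
    for i in range(LOOP_RANGE):
        s = i % NUM_STAGES
        consumed[s].get()              # wait until slot s has been consumed
        produced[s].put(f"tile {i}")   # "load" tile i into slot s

def consumer():                        # compute warpgroups: only do math
    for i in range(LOOP_RANGE):
        s = i % NUM_STAGES
        tile = produced[s].get()       # wait for the tile to arrive
        print("computing on", tile)
        consumed[s].put(None)          # release slot s back to the producer

t = threading.Thread(target=producer)
t.start()
consumer()
t.join()
```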
### Pipeline

Pipelining is a technique that improves memory access efficiency by overlapping memory access with computation. In TileLang, a pipeline can be created through the `T.pipelined` annotation:
```python
T.pipelined(range: int, stage: int)
```
Here, `range` specifies the iteration range of the pipelined loop, and `stage` specifies the number of pipeline stages. Multi-stage pipelining overlaps computation with memory access, which can significantly improve performance for memory-bound operators. However, a higher number of stages consumes more shared memory, so the optimal configuration needs to be determined for each use case.
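As a usage sketch, a pipelined loop over the KV blocks looks roughly like the following. The capitalized `T.Pipelined` spelling and the `num_stages` keyword follow TileLang's public examples, and `kv_ctx`, `block_N`, the buffers, and `flash_attn_step` are placeholders of ours rather than the exact code in `example_mla_decode.py`:

```python
# Two-stage pipeline: while block i is being used for computation,
# the copy of block i + 1 is already in flight.
for i in T.Pipelined(T.ceildiv(kv_ctx, block_N), num_stages=2):
    T.copy(KV[bid, i * block_N:(i + 1) * block_N, :], KV_shared)  # async global -> shared load
    flash_attn_step(KV_shared)                                    # Q @ K, online softmax, @ V
```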
### Split-KV

We have also implemented a Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, the SMs cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results.
In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter.
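The combine step itself is just a numerically stable weighted sum of the per-split partial results. A hedged PyTorch reference (tensor names and shapes are ours, not the kernel's actual interface; the batch dimension is omitted) looks like this:

```python
import torch

def combine_splits(partial_out, partial_lse):
    """partial_out: [num_split, heads, dim] per-split attention outputs,
    partial_lse: [num_split, heads] log-sum-exp of the scores in each split."""
    lse = torch.logsumexp(partial_lse, dim=0)                 # [heads]
    weights = torch.exp(partial_lse - lse)                    # [num_split, heads], sums to 1 over splits
    return (weights.unsqueeze(-1) * partial_out).sum(dim=0)   # [heads, dim]
```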