Commit 2b24293
perf: FlashAttention-3 style MLA PageAttention (#887)
This PR is the follow-up of #804, where we implemented a FlashAttention-3 style warp-specialization pattern (splitting on the head dimension) for faster MLA attention on Hopper GPUs. Compared to the previous (FA2-style) version, this PR makes the following changes:

1. Use one warpgroup as the producer and two warpgroups as consumers.
2. Use async WGMMA instead of MMA.
3. Use the software-pipelining algorithm from FlashAttention-3 to overlap CUDA-core and Tensor-core operations.
4. Unlike standard attention, MLA shares a single matrix (ckv) between K and V. If we kept `CTA_TILE_KV=64` and `PIPE_STAGES=2`, the software pipeline would block the memory copy of the next KV tile because both pipeline slots would already be occupied; standard attention does not hit this issue because it has separate `pipeline_k` and `pipeline_v`, doubling the number of stages. This PR switches to `CTA_TILE_KV=32` and `PIPE_STAGES=4` so that the current KV tile can be computed while the next one is being loaded.
5. Unlike standard attention, we cannot reuse V's shared memory for O. This PR introduces a circular buffer for `o_smem` that reuses the KV slots; since one KV slot is not large enough for `o_smem`, two KV shared-memory slots back one `o_smem`, and a barrier guarantees the required memory ordering.

## Pipeline

This figure explains our pipeline design:

![pipeline-design-mla](https://github.com/user-attachments/assets/178e465e-e671-459f-a4ea-02e2eaf17343)

## Results

Benchmark results on H100 SXM (80GB).

This PR (fa3 template), `page_size=1`:

```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1305.40 GB/s
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 2228.56 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2759.33 GB/s
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1766.33 GB/s
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 2498.08 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2768.37 GB/s
```

#804 + #863 (fa2 template), `page_size=1`:

```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1067.74 GB/s
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 1761.25 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2065.78 GB/s
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1384.35 GB/s
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 1892.64 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2075.97 GB/s
```

Using TMA and multicast could further improve performance for `page_size` larger than 1; we leave that for future work.
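To make the shared-memory budgeting in point 4 concrete, here is a small sanity-check sketch in Python. It assumes bf16 KV entries and the `head_dim_ckv=512` / `head_dim_kpe=64` configuration from the benchmark below; the slot bookkeeping is a simplified model of the pipeline, not the kernel code itself.

```python
# Illustrative model only (assumptions: bf16 KV entries, head_dim_ckv=512, head_dim_kpe=64).
BYTES_PER_ELEM = 2  # bf16
HEAD_DIM_CKV, HEAD_DIM_KPE = 512, 64


def kv_smem_bytes(cta_tile_kv: int, pipe_stages: int) -> int:
    """Shared memory occupied by the KV pipeline slots of one CTA."""
    per_row = (HEAD_DIM_CKV + HEAD_DIM_KPE) * BYTES_PER_ELEM
    return cta_tile_kv * per_row * pipe_stages


# Same total KV footprint, but four half-size slots instead of two full-size ones.
assert kv_smem_bytes(64, 2) == kv_smem_bytes(32, 4)


def free_slots(pipe_stages: int, tiles_in_flight: int = 2) -> int:
    """With the FA3-style pipeline, roughly two KV tiles are consumed concurrently
    (QK^T of tile i+1 overlaps P*V of tile i); the remaining slots can host
    async prefetches of upcoming tiles."""
    return pipe_stages - tiles_in_flight


print(free_slots(2))  # 0 -> the copy of the next KV tile is blocked
print(free_slots(4))  # 2 -> loading the next tile overlaps with compute
```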
1 parent 26c0296 commit 2b24293

File tree

13 files changed: +1679 -94 lines

benchmarks/bench_deepseek_mla.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -20,7 +20,7 @@
 import flashinfer
 
 
-def bench_deepseek_mla_decode(batch_size, seq_len, num_heads):
+def bench_deepseek_mla_decode(batch_size, seq_len, num_heads, backend):
     head_dim_ckv = 512
     head_dim_kpe = 64
     page_size = 1
@@ -39,7 +39,7 @@ def bench_deepseek_mla_decode(batch_size, seq_len, num_heads):
     sm_scale = 1.0 / ((head_dim_ckv + head_dim_kpe) ** 0.5)
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
     wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-        workspace_buffer, backend="fa2"
+        workspace_buffer, backend=backend
     )
     q_indptr = torch.arange(0, batch_size + 1).to(0).int()
     kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * seq_len
@@ -74,6 +74,6 @@ def bench_deepseek_mla_decode(batch_size, seq_len, num_heads):
 
 
 if __name__ == "__main__":
-    for seq_len in [1024, 2048, 4096, 8192, 16384, 32768]:
-        for batch_size in [1, 16, 32, 64]:
-            bench_deepseek_mla_decode(batch_size, seq_len, 16)
+    for seq_len in [1024, 2048]:
+        for batch_size in [64, 128, 768]:
+            bench_deepseek_mla_decode(batch_size, seq_len, 64, "auto")
```
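For convenience, a hedged variant of the driver above that sweeps both templates, so the fa2 and fa3 numbers quoted in the PR description can be reproduced side by side. This loop is illustrative and not part of the commit; `"fa2"`/`"fa3"` follow the backend strings accepted by the wrapper in this PR, and `"auto"` defers the choice to FlashInfer.

```python
if __name__ == "__main__":
    for backend in ["fa2", "fa3"]:
        for seq_len in [1024, 2048]:
            for batch_size in [64, 128, 768]:
                bench_deepseek_mla_decode(batch_size, seq_len, 64, backend)
```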

csrc/batch_mla_run.cu

Lines changed: 1 addition & 1 deletion
```diff
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <flashinfer/attention/mla_fa2.cuh>
+#include <flashinfer/attention/mla.cuh>
 #include <flashinfer/attention/scheduler.cuh>
 #include <flashinfer/fastdiv.cuh>
 #include <optional>
```

csrc/batch_mla_sm90_plan.cu

Lines changed: 51 additions & 0 deletions
```cpp
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/attention/scheduler.cuh>
#include <optional>

#include "batch_mla_sm90_config.inc"
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

at::Tensor BatchMLAPagedAttentionSM90Plan(at::Tensor float_workspace_buffer,
                                          at::Tensor int_workspace_buffer,
                                          at::Tensor page_locked_int_workspace_buffer,
                                          at::Tensor qo_indptr, at::Tensor kv_indptr,
                                          at::Tensor kv_len, int64_t num_heads, int64_t head_dim_o,
                                          bool causal, int64_t cuda_stream) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
  size_t int_workspace_size_in_bytes =
      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

  MLAPlanInfo plan_info;

  int batch_size = kv_len.size(0);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  cudaError_t status =
      MLAPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
              int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
              int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(qo_indptr.data_ptr()),
              static_cast<IdType*>(kv_indptr.data_ptr()), static_cast<IdType*>(kv_len.data_ptr()),
              batch_size, num_heads, head_dim_o, causal, stream);

  TORCH_CHECK(status == cudaSuccess, "Failed to plan MLA, error: ", cudaGetErrorString(status));

  return vec_to_tensor(plan_info.ToVector());
}
```

csrc/batch_mla_sm90_pybind.cu

Lines changed: 37 additions & 0 deletions
```cpp
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "batch_mla_sm90_config.inc"
#include "pytorch_extension_utils.h"

at::Tensor BatchMLAPagedAttentionSM90Plan(at::Tensor float_workspace_buffer,
                                          at::Tensor int_workspace_buffer,
                                          at::Tensor page_locked_int_workspace_buffer,
                                          at::Tensor qo_indptr, at::Tensor kv_indptr,
                                          at::Tensor kv_len, int64_t num_heads, int64_t head_dim_o,
                                          bool causal, int64_t cuda_stream);

void BatchMLAPagedAttentionSM90Run(at::Tensor float_workspace_buffer,
                                   at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
                                   at::Tensor q_nope, at::Tensor q_pe, at::Tensor ckv_cache,
                                   at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor o,
                                   std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                                   int64_t num_heads, int64_t page_size, double sm_scale,
                                   int64_t cuda_stream);

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  m.def("plan", &BatchMLAPagedAttentionSM90Plan);
  m.def("run", &BatchMLAPagedAttentionSM90Run);
}
```

csrc/batch_mla_sm90_run.cu

Lines changed: 122 additions & 0 deletions
```cpp
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/attention/mla_hopper.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/fastdiv.cuh>
#include <optional>

#include "batch_mla_sm90_config.inc"
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

void BatchMLAPagedAttentionSM90Run(at::Tensor float_workspace_buffer,
                                   at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
                                   at::Tensor q_nope, at::Tensor q_pe, at::Tensor ckv_cache,
                                   at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor o,
                                   std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                                   int64_t num_heads, int64_t page_size, double sm_scale,
                                   int64_t cuda_stream) {
  // q_nope: [n, num_heads, head_dim_ckv]
  // q_pe: [n, num_heads, head_dim_kpe]
  // ckv_cache: [num_pages, page_size, head_dim_ckv]
  // kpe_cache: [num_pages, page_size, head_dim_kpe]
  MLAPlanInfo plan_info;
  plan_info.FromVector(tensor_to_vec(plan_info_vec));

  auto device = q_nope.device();

  void* float_buffer_ptr = float_workspace_buffer.data_ptr();
  void* int_buffer_ptr = int_workspace_buffer.data_ptr();

  const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);

  auto q_scalar_type = q_nope.scalar_type();
  auto kv_scalar_type = ckv_cache.scalar_type();

  unsigned int q_nope_stride_n = q_nope.stride(0);
  unsigned int q_nope_stride_h = q_nope.stride(1);
  unsigned int q_pe_stride_n = q_pe.stride(0);
  unsigned int q_pe_stride_h = q_pe.stride(1);
  unsigned int ckv_stride_page = ckv_cache.stride(0);
  unsigned int ckv_stride_n = ckv_cache.stride(1);
  unsigned int kpe_stride_page = kpe_cache.stride(0);
  unsigned int kpe_stride_n = kpe_cache.stride(1);
  unsigned int o_stride_n = o.stride(0);
  unsigned int o_stride_h = o.stride(1);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);

  DISPATCH_context(
      DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] {
        Params params;

        params.q_nope = static_cast<DTypeQ*>(q_nope.data_ptr());
        params.q_pe = static_cast<DTypeQ*>(q_pe.data_ptr());
        params.ckv = static_cast<DTypeKV*>(ckv_cache.data_ptr());
        params.kpe = static_cast<DTypeKV*>(kpe_cache.data_ptr());

        params.q_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_indptr_offset);
        params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset);
        params.partial_indptr =
            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.partial_indptr_offset);
        params.kv_indices = static_cast<IdType*>(kv_indices.data_ptr());
        params.q_len = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_len_offset);
        params.kv_len = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset);
        params.q_start = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_start_offset);
        params.kv_start = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_start_offset);
        params.kv_end = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_end_offset);
        params.work_indptr =
            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
        params.merge_packed_offset_start = GetPtrFromBaseOffset<IdType>(
            int_buffer_ptr, plan_info.merge_packed_offset_start_offset);
        params.merge_packed_offset_end =
            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_packed_offset_end_offset);
        params.merge_indptr =
            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_indptr_offset);
        params.final_o = static_cast<DTypeO*>(o.data_ptr());
        params.final_lse =
            maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
        params.partial_o =
            GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_o_offset);
        params.partial_lse =
            GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);

        params.num_heads = uint_fastdiv(num_heads);
        params.block_size = uint_fastdiv(page_size);

        params.q_nope_stride_n = q_nope_stride_n;
        params.q_nope_stride_h = q_nope_stride_h;
        params.q_pe_stride_n = q_pe_stride_n;
        params.q_pe_stride_h = q_pe_stride_h;
        params.ckv_stride_page = ckv_stride_page;
        params.ckv_stride_n = ckv_stride_n;
        params.kpe_stride_page = kpe_stride_page;
        params.kpe_stride_n = kpe_stride_n;
        params.o_stride_n = o_stride_n;
        params.o_stride_h = o_stride_h;

        params.sm_scale = sm_scale;

        cudaError_t status =
            mla::BatchMLAPageAttentionHopper<MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE>(
                params, plan_info.num_blks_x, plan_info.num_blks_y, stream);

        TORCH_CHECK(status == cudaSuccess,
                    "Failed to run MLA, error: ", cudaGetErrorString(status));
      });
}
```
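For orientation, a sketch of the tensor shapes implied by the comments at the top of `BatchMLAPagedAttentionSM90Run`. The head dimensions and `page_size=1` follow the benchmark in this PR; the remaining sizes are placeholder assumptions, including the output sharing `head_dim_ckv` as its head dimension.

```python
import torch

# Placeholder sizes; one query token per request, as in the decode benchmark.
batch_size, num_heads, page_size = 64, 64, 1
head_dim_ckv, head_dim_kpe = 512, 64
seq_len = 1024
num_pages = batch_size * seq_len  # page_size == 1

q_nope = torch.randn(batch_size, num_heads, head_dim_ckv, dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn(batch_size, num_heads, head_dim_kpe, dtype=torch.bfloat16, device="cuda")
ckv_cache = torch.randn(num_pages, page_size, head_dim_ckv, dtype=torch.bfloat16, device="cuda")
kpe_cache = torch.randn(num_pages, page_size, head_dim_kpe, dtype=torch.bfloat16, device="cuda")
o = torch.empty(batch_size, num_heads, head_dim_ckv, dtype=torch.bfloat16, device="cuda")
```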

flashinfer/jit/attention/pytorch.py

Lines changed: 70 additions & 29 deletions
```diff
@@ -20,7 +20,7 @@
 import jinja2
 import torch
 
-from ..core import load_cuda_ops, logger
+from ..core import load_cuda_ops, logger, sm90a_nvcc_flags
 from ..env import FLASHINFER_CSRC_DIR, FLASHINFER_GEN_SRC_DIR
 from ..utils import (
     dtype_map,
@@ -79,6 +79,7 @@ def get_batch_decode_uri(
 
 
 def get_batch_mla_uri(
+    backend: str,
     dtype_q: torch.dtype,
     dtype_kv: torch.dtype,
     dtype_o: torch.dtype,
@@ -93,18 +94,22 @@ def get_batch_mla_uri(
         f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
         f"head_dim_ckv_{head_dim_ckv}_"
         f"head_dim_kpe_{head_dim_kpe}"
-    )
+    ) + ("_sm90" if backend == "fa3" else "")
 
 
 def gen_batch_mla_module(
+    backend: str,
     dtype_q: torch.dtype,
     dtype_kv: torch.dtype,
     dtype_o: torch.dtype,
     dtype_idx: torch.dtype,
     head_dim_ckv: int,
     head_dim_kpe: int,
 ):
+    if backend == "auto":
+        raise ValueError("backend should not be auto when jit_args is provided")
     uri = get_batch_mla_uri(
+        backend,
        dtype_q,
         dtype_kv,
         dtype_o,
@@ -115,35 +120,71 @@ def gen_batch_mla_module(
     gen_directory = FLASHINFER_GEN_SRC_DIR / uri
     os.makedirs(gen_directory, exist_ok=True)
 
-    with open(FLASHINFER_CSRC_DIR / "batch_mla_config.jinja") as f:
-        config_templ = jinja2.Template(f.read())
-    generated_config_path = gen_directory / "batch_mla_config.inc"
-    write_if_different(
-        generated_config_path,
-        config_templ.render(
-            dtype_q=dtype_map[dtype_q],
-            dtype_kv=dtype_map[dtype_kv],
-            dtype_o=dtype_map[dtype_o],
-            dtype_idx=dtype_map[dtype_idx],
-            head_dim_ckv=head_dim_ckv,
-            head_dim_kpe=head_dim_kpe,
-        ),
-    )
+    if backend == "fa2":
+        with open(FLASHINFER_CSRC_DIR / "batch_mla_config.jinja") as f:
+            config_templ = jinja2.Template(f.read())
+        generated_config_path = gen_directory / "batch_mla_config.inc"
+        write_if_different(
+            generated_config_path,
+            config_templ.render(
+                dtype_q=dtype_map[dtype_q],
+                dtype_kv=dtype_map[dtype_kv],
+                dtype_o=dtype_map[dtype_o],
+                dtype_idx=dtype_map[dtype_idx],
+                head_dim_ckv=head_dim_ckv,
+                head_dim_kpe=head_dim_kpe,
+            ),
+        )
 
-    source_paths = []
-    for filename in [
-        "batch_mla_plan.cu",
-        "batch_mla_run.cu",
-        "batch_mla_pybind.cu",
-    ]:
-        src_path = FLASHINFER_CSRC_DIR / filename
-        dest_path = gen_directory / filename
-        source_paths.append(dest_path)
-        with open(src_path, "r") as f:
-            source = f.read()
-        write_if_different(dest_path, source)
+        source_paths = []
+        for filename in [
+            "batch_mla_plan.cu",
+            "batch_mla_run.cu",
+            "batch_mla_pybind.cu",
+        ]:
+            src_path = FLASHINFER_CSRC_DIR / filename
+            dest_path = gen_directory / filename
+            source_paths.append(dest_path)
+            with open(src_path, "r") as f:
+                source = f.read()
+            write_if_different(dest_path, source)
+    elif backend == "fa3":
+        with open(FLASHINFER_CSRC_DIR / "batch_mla_config.jinja") as f:
+            config_templ = jinja2.Template(f.read())
+        generated_config_path = gen_directory / "batch_mla_sm90_config.inc"
+        write_if_different(
+            generated_config_path,
+            config_templ.render(
+                dtype_q=dtype_map[dtype_q],
+                dtype_kv=dtype_map[dtype_kv],
+                dtype_o=dtype_map[dtype_o],
+                dtype_idx=dtype_map[dtype_idx],
+                head_dim_ckv=head_dim_ckv,
+                head_dim_kpe=head_dim_kpe,
+            ),
+        )
+        source_paths = []
+        for filename in [
+            "batch_mla_sm90_plan.cu",
+            "batch_mla_sm90_run.cu",
+            "batch_mla_sm90_pybind.cu",
+        ]:
+            src_path = FLASHINFER_CSRC_DIR / filename
+            dest_path = gen_directory / filename
+            source_paths.append(dest_path)
+            with open(src_path, "r") as f:
+                source = f.read()
+            write_if_different(dest_path, source)
+    else:
+        raise ValueError(f"Unsupported backend: {backend}")
 
-    return load_cuda_ops(uri, source_paths)
+    return load_cuda_ops(
+        uri,
+        source_paths,
+        extra_cuda_cflags=(
+            ["-gencode=arch=compute_90a,code=sm_90a"] if backend == "fa3" else []
+        ),
+    )
 
 
 def get_batch_decode_mla_uri(
```
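A usage sketch of the JIT path added above. The import path simply mirrors this file's location, and the dtypes and head dimensions are the ones used elsewhere in this PR rather than a required configuration.

```python
import torch
from flashinfer.jit.attention.pytorch import gen_batch_mla_module

# Compile the SM90 ("fa3") MLA module; per the diff above this pulls in the
# batch_mla_sm90_*.cu sources and builds them with the sm_90a gencode flag.
module = gen_batch_mla_module(
    backend="fa3",
    dtype_q=torch.bfloat16,
    dtype_kv=torch.bfloat16,
    dtype_o=torch.bfloat16,
    dtype_idx=torch.int32,
    head_dim_ckv=512,
    head_dim_kpe=64,
)
```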
