Commit a9935ea

bugfix: Parameterize prefix mask call (needed by POD-Attention) (#1059)
Passes the thread ID to the prefix mask call explicitly. This shouldn't change existing code paths, but it is necessary because POD-Attention remaps thread and block IDs.
1 parent 9a811ee commit a9935ea
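Below is a minimal sketch of the pattern the fix relies on: give the device-side mask helper defaulted index parameters so that ordinary call sites keep reading the hardware threadIdx.x / blockIdx.z, while a caller that remaps thread and block indices (as POD-Attention does) passes its own values explicitly. The helper and kernels are hypothetical stand-ins, not the actual FlashInfer code; the real change adds defaulted lane_idx / kv_head_idx parameters to logits_mask_multi_item_scoring and forwards tid.x and kv_head_idx at the call site in prefill.cuh (see the diff below).

#include <cstdint>

// Hypothetical mask helper. The defaults are evaluated at each call site in
// device code, so callers that pass nothing still see the hardware indices.
__device__ __forceinline__ void mask_helper(float* scores, uint32_t n,
                                            uint32_t lane_idx = threadIdx.x,
                                            uint32_t kv_head_idx = blockIdx.z) {
  uint32_t i = kv_head_idx * blockDim.x + lane_idx;
  if (i < n) scores[i] = -5e4f;  // mask out this position
}

// Pre-existing kernel: the call is unchanged and behaves exactly as before.
__global__ void regular_kernel(float* scores, uint32_t n) {
  mask_helper(scores, n);
}

// POD-style kernel: prefill and decode share one launch, so logical thread and
// block indices are remapped from the physical ones. The remapped values must
// be forwarded explicitly, because threadIdx/blockIdx read inside the helper
// would refer to the physical launch geometry, not the logical one.
__global__ void pod_style_kernel(float* scores, uint32_t n) {
  uint32_t logical_lane = (threadIdx.x + 16) % blockDim.x;  // stand-in remapping
  uint32_t logical_head = blockIdx.x % 2;                   // stand-in remapping
  mask_helper(scores, n, logical_lane, logical_head);
}

Because the new parameters are defaulted, every caller that does not pass them stays source- and behavior-compatible, which is why the change is expected not to affect existing code paths.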

File tree: 2 files changed, +32 −12 lines changed

benchmarks/bench_mixed_attention.py

Lines changed: 23 additions & 6 deletions
@@ -68,23 +68,27 @@ def run_bench(
         q_data_type=torch.bfloat16,
         kv_data_type=torch.bfloat16,
     )
+    o = wrapper_old.run(q, kv_data)
     ms_old = do_bench(lambda: wrapper_old.run(q, kv_data))

-    if len(p_kv_lens) > 0:
+    if len(p_kv_lens) == 1:
         q_d = q[: d_q_indptr[-1]]
         kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
         q_p = q[d_q_indptr[-1] :]
         k_p, v_p = kv_data[d_kv_indptr[-1] :].unbind(1)
         k_p, v_p = k_p.squeeze(1), v_p.squeeze(1)
+        kv_indices_d = torch.arange(
+            0, d_kv_indptr[-1], device=device, dtype=torch.int32
+        )

         last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
         wrapper_pod = flashinfer.PODWithPagedKVCacheWrapper(
             workspace_buffer,
             kv_layout=kv_layout,
         )
         wrapper_pod.plan(
-            d_q_indptr.to(device),
             d_kv_indptr.to(device),
+            kv_indices_d.to(device),
             last_page_len=last_page_len_d,
             num_qo_heads=num_qo_heads,
             num_kv_heads=num_kv_heads,
@@ -93,6 +97,19 @@ def run_bench(
             q_data_type=torch.bfloat16,
             kv_data_type=torch.bfloat16,
         )
+        o_p, o_d = wrapper_pod.run(
+            q_p,
+            k_p,
+            v_p,
+            q_d,
+            kv_data,
+            causal_p=causal,
+        )
+        o_pod = torch.cat([o_d, o_p], dim=0)
+        # Verify output matches
+        torch.testing.assert_close(
+            o, o_pod, rtol=1e-3, atol=1e-3, msg="POD-Attention output mismatch!"
+        )
         ms_pod = do_bench(
             lambda: wrapper_pod.run(
                 q_p,
@@ -106,7 +123,7 @@ def run_bench(
         )

     print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
-    if len(p_kv_lens) > 0:
+    if len(p_kv_lens) == 1:
         print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
     total_bytes = (
         q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
@@ -116,7 +133,7 @@ def run_bench(
     bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3) / (1024**3)

     print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s")
-    if len(p_kv_lens) > 0:
+    if len(p_kv_lens) == 1:
         bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3)
         print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s")

@@ -128,8 +145,8 @@ def run_bench(
     # Irregular sequence lengths for prefill and decode
     d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256]
     d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256]
-    p_q_configs = [[17] * 8, [], [17] * 16, []]
-    p_kv_configs = [[10000] * 8, [], [8192] * 16, []]
+    p_q_configs = [[17] * 1, [10000], [17] * 1, []]
+    p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []]

     # construct random length testcases
     for _ in range(1):

include/flashinfer/attention/prefill.cuh

Lines changed: 9 additions & 6 deletions
@@ -797,8 +797,8 @@ __device__ __forceinline__ void logits_mask_multi_item_scoring(
     const uint32_t kv_len, const uint32_t window_left, const uint32_t chunk_end,
     const uint_fastdiv group_size, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8],
     // new arguments for compact description of mask
-    const uint32_t prefix_len, uint16_t* token_pos_in_items) {
-  const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z;
+    const uint32_t prefix_len, uint16_t* token_pos_in_items, const uint32_t lane_idx = threadIdx.x,
+    const uint32_t kv_head_idx = blockIdx.z) {
   constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q;
   constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
   using DTypeQKAccum = typename KTraits::DTypeQKAccum;
@@ -2285,23 +2285,26 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice(
       if (iter >= mask_iteration || iter < window_iteration) {
         logits_mask<KTraits>(
             params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base,
-            chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>()) * NUM_MMA_KV * 16,
+            chunk_start +
+                (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
             qo_len, kv_len, chunk_end, group_size, s_frag);
       }
     } else if constexpr (MASK_MODE == MaskMode::kMultiItemScoring) {
       if (iter + 1 >= num_iterations_prefix) {
         logits_mask_multi_item_scoring<KTraits>(
             params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base,
-            chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>()) * NUM_MMA_KV * 16,
+            chunk_start +
+                (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
             qo_len, kv_len, window_left, chunk_end, group_size, s_frag,
             __ldg(maybe_prefix_len_ptr + request_idx),
-            maybe_token_pos_in_items_ptr + request_idx * token_pos_in_items_len);
+            maybe_token_pos_in_items_ptr + request_idx * token_pos_in_items_len, tid.x,
+            kv_head_idx);
       } else {
         if (iter >= mask_iteration || iter < window_iteration) {
           logits_mask<KTraits>(
               params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base,
               chunk_start +
-                  (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>()) * NUM_MMA_KV * 16,
+                  (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
               qo_len, kv_len, chunk_end, group_size, s_frag);
         }
       }
