Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions benchmarks/bench_mixed_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,27 @@ def run_bench(
q_data_type=torch.bfloat16,
kv_data_type=torch.bfloat16,
)
o = wrapper_old.run(q, kv_data)
ms_old = do_bench(lambda: wrapper_old.run(q, kv_data))

if len(p_kv_lens) > 0:
if len(p_kv_lens) == 1:
q_d = q[: d_q_indptr[-1]]
kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
q_p = q[d_q_indptr[-1] :]
k_p, v_p = kv_data[d_kv_indptr[-1] :].unbind(1)
k_p, v_p = k_p.squeeze(1), v_p.squeeze(1)
kv_indices_d = torch.arange(
0, d_kv_indptr[-1], device=device, dtype=torch.int32
)

last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
wrapper_pod = flashinfer.PODWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout=kv_layout,
)
wrapper_pod.plan(
d_q_indptr.to(device),
d_kv_indptr.to(device),
kv_indices_d.to(device),
last_page_len=last_page_len_d,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
Expand All @@ -93,6 +97,19 @@ def run_bench(
q_data_type=torch.bfloat16,
kv_data_type=torch.bfloat16,
)
o_p, o_d = wrapper_pod.run(
q_p,
k_p,
v_p,
q_d,
kv_data,
causal_p=causal,
)
o_pod = torch.cat([o_d, o_p], dim=0)
# Verify output matches
torch.testing.assert_close(
o, o_pod, rtol=1e-3, atol=1e-3, msg="POD-Attention output mismatch!"
)
ms_pod = do_bench(
lambda: wrapper_pod.run(
q_p,
Expand All @@ -106,7 +123,7 @@ def run_bench(
)

print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
if len(p_kv_lens) > 0:
if len(p_kv_lens) == 1:
print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
total_bytes = (
q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
Expand All @@ -116,7 +133,7 @@ def run_bench(
bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3) / (1024**3)

print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s")
if len(p_kv_lens) > 0:
if len(p_kv_lens) == 1:
bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3)
print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s")

Expand All @@ -128,8 +145,8 @@ def run_bench(
# Irregular sequence lengths for prefill and decode
d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256]
d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256]
p_q_configs = [[17] * 8, [], [17] * 16, []]
p_kv_configs = [[10000] * 8, [], [8192] * 16, []]
p_q_configs = [[17] * 1, [10000], [17] * 1, []]
p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []]

# construct random length testcases
for _ in range(1):
Expand Down
15 changes: 9 additions & 6 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -797,8 +797,8 @@ __device__ __forceinline__ void logits_mask_multi_item_scoring(
const uint32_t kv_len, const uint32_t window_left, const uint32_t chunk_end,
const uint_fastdiv group_size, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8],
// new arguments for compact description of mask
const uint32_t prefix_len, uint16_t* token_pos_in_items) {
const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z;
const uint32_t prefix_len, uint16_t* token_pos_in_items, const uint32_t lane_idx = threadIdx.x,
const uint32_t kv_head_idx = blockIdx.z) {
constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q;
constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
using DTypeQKAccum = typename KTraits::DTypeQKAccum;
Expand Down Expand Up @@ -2285,23 +2285,26 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice(
if (iter >= mask_iteration || iter < window_iteration) {
logits_mask<KTraits>(
params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base,
chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>()) * NUM_MMA_KV * 16,
chunk_start +
(iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
qo_len, kv_len, chunk_end, group_size, s_frag);
}
} else if constexpr (MASK_MODE == MaskMode::kMultiItemScoring) {
if (iter + 1 >= num_iterations_prefix) {
logits_mask_multi_item_scoring<KTraits>(
params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base,
chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>()) * NUM_MMA_KV * 16,
chunk_start +
(iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
qo_len, kv_len, window_left, chunk_end, group_size, s_frag,
__ldg(maybe_prefix_len_ptr + request_idx),
maybe_token_pos_in_items_ptr + request_idx * token_pos_in_items_len);
maybe_token_pos_in_items_ptr + request_idx * token_pos_in_items_len, tid.x,
kv_head_idx);
} else {
if (iter >= mask_iteration || iter < window_iteration) {
logits_mask<KTraits>(
params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base,
chunk_start +
(iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>()) * NUM_MMA_KV * 16,
(iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
qo_len, kv_len, chunk_end, group_size, s_frag);
}
}
Expand Down