diff --git a/benchmarks/bench_mixed_attention.py b/benchmarks/bench_mixed_attention.py
index caf8c633f7..f581628b97 100644
--- a/benchmarks/bench_mixed_attention.py
+++ b/benchmarks/bench_mixed_attention.py
@@ -68,14 +68,18 @@ def run_bench(
         q_data_type=torch.bfloat16,
         kv_data_type=torch.bfloat16,
     )
+    o = wrapper_old.run(q, kv_data)
     ms_old = do_bench(lambda: wrapper_old.run(q, kv_data))
 
-    if len(p_kv_lens) > 0:
+    if len(p_kv_lens) == 1:
         q_d = q[: d_q_indptr[-1]]
         kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
         q_p = q[d_q_indptr[-1] :]
         k_p, v_p = kv_data[d_kv_indptr[-1] :].unbind(1)
         k_p, v_p = k_p.squeeze(1), v_p.squeeze(1)
+        kv_indices_d = torch.arange(
+            0, d_kv_indptr[-1], device=device, dtype=torch.int32
+        )
 
         last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
         wrapper_pod = flashinfer.PODWithPagedKVCacheWrapper(
@@ -83,8 +87,8 @@ def run_bench(
             kv_layout=kv_layout,
         )
         wrapper_pod.plan(
-            d_q_indptr.to(device),
             d_kv_indptr.to(device),
+            kv_indices_d.to(device),
             last_page_len=last_page_len_d,
             num_qo_heads=num_qo_heads,
             num_kv_heads=num_kv_heads,
@@ -93,6 +97,19 @@ def run_bench(
             q_data_type=torch.bfloat16,
             kv_data_type=torch.bfloat16,
         )
+        o_p, o_d = wrapper_pod.run(
+            q_p,
+            k_p,
+            v_p,
+            q_d,
+            kv_data,
+            causal_p=causal,
+        )
+        o_pod = torch.cat([o_d, o_p], dim=0)
+        # Verify output matches
+        torch.testing.assert_close(
+            o, o_pod, rtol=1e-3, atol=1e-3, msg="POD-Attention output mismatch!"
+        )
         ms_pod = do_bench(
             lambda: wrapper_pod.run(
                 q_p,
@@ -106,7 +123,7 @@ def run_bench(
         )
 
     print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
-    if len(p_kv_lens) > 0:
+    if len(p_kv_lens) == 1:
         print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
     total_bytes = (
         q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
@@ -116,7 +133,7 @@ def run_bench(
     bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3) / (1024**3)
     print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s")
-    if len(p_kv_lens) > 0:
+    if len(p_kv_lens) == 1:
         bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3)
         print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s")
 
 
@@ -128,8 +145,8 @@ def run_bench(
     # Irregular sequence lengths for prefill and decode
    d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256]
    d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256]
-    p_q_configs = [[17] * 8, [], [17] * 16, []]
-    p_kv_configs = [[10000] * 8, [], [8192] * 16, []]
+    p_q_configs = [[17] * 1, [10000], [17] * 1, []]
+    p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []]
 
     # construct random length testcases
     for _ in range(1):
diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh
index fb0d260dc3..e4cc53b0e4 100644
--- a/include/flashinfer/attention/prefill.cuh
+++ b/include/flashinfer/attention/prefill.cuh
@@ -797,8 +797,8 @@ __device__ __forceinline__ void logits_mask_multi_item_scoring(
     const uint32_t kv_len, const uint32_t window_left, const uint32_t chunk_end,
     const uint_fastdiv group_size, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8],
     // new arguments for compact description of mask
-    const uint32_t prefix_len, uint16_t* token_pos_in_items) {
-  const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z;
+    const uint32_t prefix_len, uint16_t* token_pos_in_items, const uint32_t lane_idx = threadIdx.x,
+    const uint32_t kv_head_idx = blockIdx.z) {
   constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q;
   constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
   using DTypeQKAccum = typename KTraits::DTypeQKAccum;
@@ -2285,23 +2285,26 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice(
         if (iter >= mask_iteration || iter < window_iteration) {
           logits_mask<KTraits>(
               params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base,
-              chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>()) * NUM_MMA_KV * 16,
+              chunk_start +
+                  (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
               qo_len, kv_len, chunk_end, group_size, s_frag);
         }
       } else if constexpr (MASK_MODE == MaskMode::kMultiItemScoring) {
         if (iter + 1 >= num_iterations_prefix) {
           logits_mask_multi_item_scoring<KTraits>(
               params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base,
-              chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>()) * NUM_MMA_KV * 16,
+              chunk_start +
+                  (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
               qo_len, kv_len, window_left, chunk_end, group_size, s_frag,
               __ldg(maybe_prefix_len_ptr + request_idx),
-              maybe_token_pos_in_items_ptr + request_idx * token_pos_in_items_len);
+              maybe_token_pos_in_items_ptr + request_idx * token_pos_in_items_len, tid.x,
+              kv_head_idx);
         } else {
           if (iter >= mask_iteration || iter < window_iteration) {
             logits_mask<KTraits>(
                 params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base,
                 chunk_start +
-                    (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>()) * NUM_MMA_KV * 16,
+                    (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
                 qo_len, kv_len, chunk_end, group_size, s_frag);
           }
         }
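
For reference, below is a minimal, self-contained sketch of the call pattern the updated benchmark exercises. The wrapper name, the plan()/run() argument order, and the (o_p, o_d) return pair are taken from the diff above; the workspace-buffer size, the head_dim/page_size keyword names, and all tensor shapes are illustrative assumptions rather than values visible in the hunks.

import torch
import flashinfer

device = torch.device("cuda:0")
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 1

# Decode side: one request whose KV cache spans 4 pages of a paged cache
# stored in NHD layout ([num_pages, 2, page_size, num_kv_heads, head_dim]).
num_pages_d = 4
d_kv_indptr = torch.tensor([0, num_pages_d], dtype=torch.int32, device=device)
kv_indices_d = torch.arange(0, num_pages_d, dtype=torch.int32, device=device)
last_page_len_d = torch.tensor([1], dtype=torch.int32, device=device)
kv_data = torch.randn(
    num_pages_d, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.bfloat16, device=device,
)
q_d = torch.randn(1, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)

# Prefill side: a single ragged request, matching the benchmark's new
# len(p_kv_lens) == 1 gate.
p_len = 17
q_p = torch.randn(p_len, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)
k_p = torch.randn(p_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device)
v_p = torch.randn_like(k_p)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper_pod = flashinfer.PODWithPagedKVCacheWrapper(workspace, kv_layout="NHD")
# plan() now takes the decode page indices alongside the page indptr.
wrapper_pod.plan(
    d_kv_indptr,
    kv_indices_d,
    last_page_len=last_page_len_d,
    num_qo_heads=num_qo_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,  # keyword name assumed
    page_size=page_size,  # keyword name assumed
    q_data_type=torch.bfloat16,
    kv_data_type=torch.bfloat16,
)
# run() returns the prefill and decode outputs separately; the benchmark
# concatenates decode-first because its fused q tensor stores decode queries
# before prefill queries, so this matches the batched-prefill baseline layout.
o_p, o_d = wrapper_pod.run(q_p, k_p, v_p, q_d, kv_data, causal_p=True)
o_pod = torch.cat([o_d, o_p], dim=0)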