From 485e8e5699989f97410924c46ad64d6fbc3d0762 Mon Sep 17 00:00:00 2001
From: DearPlanet
Date: Thu, 27 Mar 2025 15:11:56 +0800
Subject: [PATCH 1/2] fix(backend): fix impl of `swap_blocks` and `copy_blocks` in flashinfer backend

Signed-off-by: DearPlanet
---
 vllm/attention/backends/flashinfer.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 0556c191ddea..0c65edfce248 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -83,14 +83,19 @@ def swap_blocks(
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        # Directly swap an entire KV Cache block, instead of splitting into K and V at the first dim
+        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        # K and V are seperated in the second dim, not the first dim
+        key_caches = [kv_cache[:, 0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[:, 1] for kv_cache in kv_caches]
+
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
 
     @staticmethod
     def get_supported_head_sizes() -> List[int]:

From 19377d2fbf879a62804a8d6eee562dd094fef714 Mon Sep 17 00:00:00 2001
From: DearPlanet
Date: Thu, 27 Mar 2025 15:25:03 +0800
Subject: [PATCH 2/2] fix: typo

Signed-off-by: DearPlanet
---
 vllm/attention/backends/flashinfer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 0c65edfce248..163497983f10 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -36,7 +36,6 @@
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.layer import Attention
-from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
@@ -83,7 +82,8 @@ def swap_blocks(
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        # Directly swap an entire KV Cache block, instead of splitting into K and V at the first dim
+        # Directly swap an entire KV Cache block
+        # Instead of splitting into K and V at the first dim
         ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
     @staticmethod
@@ -91,10 +91,10 @@ def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        # K and V are seperated in the second dim, not the first dim
+        # K and V are separated in the second dim, not the first dim
         key_caches = [kv_cache[:, 0] for kv_cache in kv_caches]
         value_caches = [kv_cache[:, 1] for kv_cache in kv_caches]
-
+
         ops.copy_blocks(key_caches, value_caches, src_to_dists)
 
     @staticmethod
@@ -642,7 +642,7 @@ def prepare(self):
         # [0, 3, 6, 8]
         self.paged_kv_indices: List[int] = []
         # 0 at the beginning of paged_kv_indptr indicates the start of the
-        # first request’s page indices in the paged_kv_indices list.
+        # first request's page indices in the paged_kv_indices list.
         self.paged_kv_indptr: List[int] = [0]
         # paged_kv_last_page_len is the length of the last page of each request
         self.paged_kv_last_page_len: List[int] = []
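
Note (not part of the patches): a minimal pure-PyTorch sketch of why copy_blocks now splits the cache at the second dim. It assumes a FlashInfer-style layout of (num_blocks, 2, block_size, num_kv_heads, head_size) with K/V stacked along dim 1, and the copy_blocks_reference helper below is a hypothetical stand-in for ops.copy_blocks; all names and shapes here are illustrative only.

import torch

# Assumed FlashInfer-style layout: (num_blocks, 2, block_size, num_kv_heads, head_size),
# K/V stacked along dim 1 rather than dim 0.
num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64
kv_cache = torch.randn(num_blocks, 2, block_size, num_kv_heads, head_size)

# Splitting at dim 1, as the patched copy_blocks does, yields K/V views that
# share storage with kv_cache, so copying through them updates the cache itself.
key_cache = kv_cache[:, 0]    # (num_blocks, block_size, num_kv_heads, head_size)
value_cache = kv_cache[:, 1]

def copy_blocks_reference(key_caches, value_caches, src_to_dists):
    # Hypothetical stand-in for ops.copy_blocks: copy whole blocks from src to
    # dst in every layer's key cache and value cache.
    for src, dst in src_to_dists.tolist():
        for k, v in zip(key_caches, value_caches):
            k[dst].copy_(k[src])
            v[dst].copy_(v[src])

src_to_dists = torch.tensor([[0, 3], [1, 5]])
copy_blocks_reference([key_cache], [value_cache], src_to_dists)
# The whole block (both its K and V halves) lands at the destination index.
assert torch.equal(kv_cache[3], kv_cache[0])

Because kv_cache[:, 0] and kv_cache[:, 1] are views into the same storage, copies made through them show up in the combined cache; under the assumed layout, splitting at dim 0 instead would index blocks rather than select K or V, which is the mismatch the first patch removes.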