Commit 2cc2069

[TPU][Bugfix] fix kv cache padding (#20048)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
1 parent 9f0608f commit 2cc2069

2 files changed: +14 -9 lines changed

vllm/v1/attention/backends/pallas.py

Lines changed: 1 addition & 7 deletions
@@ -48,13 +48,7 @@ def get_kv_cache_shape(
     ) -> tuple[int, ...]:
         padded_head_size = cdiv(
             head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
-        num_blocks = num_blocks * head_size // padded_head_size
-        if padded_head_size != head_size:
-            logger.warning_once(
-                "head size is padded to %d, and num_blocks is adjusted to %d"
-                " accordingly", padded_head_size, num_blocks)
-            head_size = padded_head_size
-        return (num_blocks, block_size, num_kv_heads * 2, head_size)
+        return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
 
     @staticmethod
     def swap_blocks(
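
With this change, get_kv_cache_shape no longer rescales num_blocks or mutates head_size; it only pads the last dimension up to the TPU alignment. Below is a minimal standalone sketch of that computation; the helper name, the local cdiv, and the alignment value of 128 are illustrative assumptions, not taken from the commit.

# Minimal sketch of the padded KV cache shape after this commit.
# TPU_HEAD_SIZE_ALIGNMENT = 128 is assumed here for illustration only;
# cdiv mirrors ceiling division as in vllm.utils.cdiv.
TPU_HEAD_SIZE_ALIGNMENT = 128


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(-a // b)


def padded_kv_cache_shape(num_blocks: int, block_size: int,
                          num_kv_heads: int,
                          head_size: int) -> tuple[int, ...]:
    # Pad only the head dimension; num_blocks is left untouched now that
    # the memory accounting happens in the TPU worker instead.
    padded_head_size = cdiv(
        head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)


# Example: a head size of 80 pads up to 128 while the block count stays fixed.
print(padded_kv_cache_shape(num_blocks=1000, block_size=16,
                            num_kv_heads=8, head_size=80))
# -> (1000, 16, 16, 128)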

vllm/v1/worker/tpu_worker.py

Lines changed: 13 additions & 2 deletions
@@ -18,7 +18,8 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -221,7 +222,17 @@ def determine_available_memory(self) -> int:
         usable_memory_size = int(total_memory_size *
                                  self.cache_config.gpu_memory_utilization)
         tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
-
+        head_size = self.model_config.get_head_size()
+        if head_size > 0:
+            padded_head_size = cdiv(
+                head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+            if padded_head_size != head_size:
+                logger.warning_once("head size is padded to %d",
+                                    padded_head_size)
+            # We adjust the usable memory size for the KV cache to prevent OOM
+            # errors, even after padding the head_size.
+            tpu_kv_cache_bytes = (tpu_kv_cache_bytes * head_size //
+                                  padded_head_size)
         return int(tpu_kv_cache_bytes)
 
     def execute_model(
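
To make the worker-side adjustment concrete, here is a small sketch with illustrative numbers (the helper name and the values are assumptions, not part of the commit): because each token's KV entry now stores padded_head_size elements instead of head_size, the byte budget is scaled down by head_size / padded_head_size so the padded cache still fits in the profiled memory.

# Sketch of the KV cache budget adjustment performed in
# determine_available_memory; helper name and numbers are illustrative.
def adjust_kv_cache_bytes(tpu_kv_cache_bytes: int, head_size: int,
                          padded_head_size: int) -> int:
    if head_size <= 0 or padded_head_size == head_size:
        return tpu_kv_cache_bytes
    # The padded cache needs padded_head_size / head_size times more bytes
    # per token, so shrink the budget by the inverse ratio to avoid OOM.
    return tpu_kv_cache_bytes * head_size // padded_head_size


budget = 10 * 1024**3  # 10 GiB usable for the KV cache (example value)
print(adjust_kv_cache_bytes(budget, head_size=80, padded_head_size=128))
# -> 6710886400 bytes, i.e. 6.25 GiB of logical KV data before padding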
