2 files changed: +14 −9 lines changed

@@ -48,13 +48,7 @@ def get_kv_cache_shape(
     ) -> tuple[int, ...]:
         padded_head_size = cdiv(
             head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
-        num_blocks = num_blocks * head_size // padded_head_size
-        if padded_head_size != head_size:
-            logger.warning_once(
-                "head size is padded to %d, and num_blocks is adjusted to %d"
-                " accordingly", padded_head_size, num_blocks)
-        head_size = padded_head_size
-        return (num_blocks, block_size, num_kv_heads * 2, head_size)
+        return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
 
     @staticmethod
     def swap_blocks(
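In effect, get_kv_cache_shape now reports the padded head size in the returned shape and no longer shrinks num_blocks itself; that accounting moves to the worker below. A minimal standalone sketch of the padding arithmetic, assuming TPU_HEAD_SIZE_ALIGNMENT is 128 and using a local stand-in for vllm.utils.cdiv:

# Standalone sketch of the new shape computation (not the vLLM source itself).
# TPU_HEAD_SIZE_ALIGNMENT = 128 is an assumption based on the Pallas backend.
TPU_HEAD_SIZE_ALIGNMENT = 128

def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring vllm.utils.cdiv.
    return -(a // -b)

def get_kv_cache_shape(num_blocks: int, block_size: int,
                       num_kv_heads: int, head_size: int) -> tuple[int, ...]:
    padded_head_size = cdiv(
        head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    # The shape reports the padded head size directly; num_blocks is left
    # untouched here because the worker now accounts for the padding when it
    # sizes the available KV cache memory.
    return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)

# Example: head_size=96 pads up to 128, so each block is 4/3 as large.
print(get_kv_cache_shape(1000, 16, 8, 96))  # (1000, 16, 16, 128)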

@@ -18,7 +18,8 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -221,7 +222,17 @@ def determine_available_memory(self) -> int:
         usable_memory_size = int(total_memory_size *
                                  self.cache_config.gpu_memory_utilization)
         tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
-
+        head_size = self.model_config.get_head_size()
+        if head_size > 0:
+            padded_head_size = cdiv(
+                head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+            if padded_head_size != head_size:
+                logger.warning_once("head size is padded to %d",
+                                    padded_head_size)
+            # We adjust the usable memory size for the KV cache to prevent OOM
+            # errors, even after padding the head_size.
+            tpu_kv_cache_bytes = (tpu_kv_cache_bytes * head_size //
+                                  padded_head_size)
         return int(tpu_kv_cache_bytes)
 
     def execute_model(
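To see why tpu_kv_cache_bytes is scaled by head_size / padded_head_size: each block in the physical cache stores padded_head_size elements per head, so deriving the block count from the raw byte budget would overshoot the TPU's real capacity once padding is applied. A rough worked example with illustrative numbers (the 8 GiB budget and the 128 alignment are assumptions, not values from this PR):

# Illustrative numbers only; nothing here is measured from a real TPU.
head_size = 96
padded_head_size = 128              # cdiv(96, 128) * 128, assumed alignment
tpu_kv_cache_bytes = 8 * 1024**3    # pretend ~8 GiB remain after profiling

# Each block effectively grows by padded_head_size / head_size, so the byte
# budget is scaled down before the number of cache blocks is derived from it.
adjusted = tpu_kv_cache_bytes * head_size // padded_head_size
print(adjusted)  # 6442450944 bytes (6 GiB) instead of 8 GiB, avoiding OOM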