Skip to content

Commit f1dfb05

Browse files
committed
Respect pass in block size.
1 parent fabe89b commit f1dfb05

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

vllm/platforms/tpu.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
9393
from vllm.config import CompilationLevel
9494

9595
cache_config = vllm_config.cache_config
96-
# For v0, the default block size is 16.
97-
if cache_config and cache_config.block_size is None:
98-
cache_config.block_size = cast(BlockSize, 16)
96+
assert cache_config is not None
97+
9998
compilation_config = vllm_config.compilation_config
10099

101100
# TPU only supports DYNAMO_ONCE compilation level
@@ -118,8 +117,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
118117
if envs.VLLM_USE_V1:
119118
from vllm.v1.attention.backends.pallas import (
120119
PallasAttentionBackend)
121-
cache_config.block_size = PallasAttentionBackend.get_page_size(
122-
vllm_config) # type: ignore[assignment]
120+
# For v1, the default block size is calculated from vllm_config.
121+
cache_config.block_size = (
122+
cache_config.block_size
123+
or PallasAttentionBackend.get_page_size(vllm_config) # type: ignore[assignment]
124+
)
125+
123126
min_page_size = PallasAttentionBackend.get_min_page_size(
124127
vllm_config)
125128
if min_page_size > cache_config.block_size:
@@ -130,7 +133,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
130133
min_page_size,
131134
)
132135
cache_config.block_size = min_page_size # type: ignore[assignment]
133-
136+
else:
137+
# For v0, the default block size is 16.
138+
cache_config.block_size = (
139+
cache_config.block_size or cast(BlockSize, 16)
140+
)
134141
parallel_config = vllm_config.parallel_config
135142
scheduler_config = vllm_config.scheduler_config
136143
if parallel_config.worker_cls == "auto":

0 commit comments

Comments
 (0)