
Commit 84be1b0

Add unit test for block_size calculation.
1 parent f1dfb05 commit 84be1b0

2 files changed: +70 −4 lines changed

tests/platforms/test_tpu.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the TPU platform's block_size configuration.

Run `pytest tests/platforms/test_tpu.py`.
"""
from unittest.mock import MagicMock, patch

import pytest

import vllm.config
from vllm.platforms.tpu import TpuPlatform


@pytest.mark.parametrize(
    "use_v1,initial_block_size,expected_block_size",
    [
        (True, 32, 32),  # Case 1: v1: block_size set, should remain unchanged
        (
            True, None, 128
        ),  # Case 2: v1: block_size None, should be set to get_page_size (128)
        (False, None, 16),  # Case 3: v0: block_size None, should be set to 16
        (False, 32, 32),  # Case 4: v0: block_size set, should remain unchanged
    ])
@patch(
    "vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_page_size",
    return_value=128)
@patch(
    "vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_min_page_size",
    return_value=8)
def test_tpu_platform_update_vllm_config_block_size_respect_passin_block_size(
        mock_get_min_page_size, mock_get_page_size, use_v1, initial_block_size,
        expected_block_size) -> None:
    """Test TPU platform updates VLLM config with block size."""
    # arrange
    mock_cached_config = MagicMock()
    mock_cached_config.block_size = initial_block_size

    mock_model_config = MagicMock()
    mock_model_config.dtype = "float16"

    mock_vllm_config = MagicMock()
    mock_vllm_config.cache_config = mock_cached_config
    mock_vllm_config.compilation_config = MagicMock()
    mock_vllm_config.compilation_config.level = (
        vllm.config.CompilationLevel.DYNAMO_ONCE)
    mock_vllm_config.compilation_config.backend = "openxla"
    mock_vllm_config.model_config = mock_model_config
    mock_vllm_config.speculative_config = None
    mock_vllm_config.parallel_config = MagicMock()
    mock_vllm_config.parallel_config.worker_cls = (
        "vllm.v1.worker.tpu_worker.TPUWorker")
    mock_vllm_config.scheduler_config = MagicMock()

    # act
    with patch("vllm.envs.VLLM_USE_V1", use_v1):
        TpuPlatform.check_and_update_config(mock_vllm_config)

    # assert
    assert mock_cached_config.block_size == expected_block_size
    if use_v1:
        mock_get_min_page_size.assert_called()
        if initial_block_size is None:
            mock_get_page_size.assert_called()
        else:
            mock_get_page_size.assert_not_called()
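
A note on the mock ordering in the signature above: when `@patch` decorators are stacked, the decorator closest to the function is applied first and its mock is injected as the first positional argument, which is why `mock_get_min_page_size` precedes `mock_get_page_size` even though `get_page_size` appears first from top to bottom. A minimal, self-contained sketch of that behavior (patching `math.ceil` and `math.floor` is an arbitrary choice for illustration only):

import math
from unittest.mock import patch


@patch("math.floor")  # top decorator -> second mock argument
@patch("math.ceil")   # bottom decorator, closest to the function -> first mock argument
def show_patch_order(mock_ceil, mock_floor):
    # Inside the function, math.ceil and math.floor have been replaced by the
    # mocks in bottom-up order, mirroring the test signature above.
    assert math.ceil is mock_ceil
    assert math.floor is mock_floor


show_patch_order()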

vllm/platforms/tpu.py

Lines changed: 5 additions & 4 deletions
@@ -117,10 +117,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if envs.VLLM_USE_V1:
             from vllm.v1.attention.backends.pallas import (
                 PallasAttentionBackend)
+
             # For v1, the default block size is calculated from vllm_config.
             cache_config.block_size = (
                 cache_config.block_size
-                or PallasAttentionBackend.get_page_size(vllm_config)  # type: ignore[assignment]
+                or PallasAttentionBackend.get_page_size(
+                    vllm_config)  # type: ignore[assignment]
             )

             min_page_size = PallasAttentionBackend.get_min_page_size(
@@ -135,9 +137,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             cache_config.block_size = min_page_size  # type: ignore[assignment]
         else:
             # For v0, the default block size is 16.
-            cache_config.block_size = (
-                cache_config.block_size or cast(BlockSize, 16)
-            )
+            cache_config.block_size = (cache_config.block_size
+                                       or cast(BlockSize, 16))
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
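
Both branches above rely on Python's `or` short-circuit to pick the block size: a user-supplied value is kept whenever it is truthy, and the backend-derived default (v1) or the hard-coded 16 (v0) is used only when `block_size` is `None` (or another falsy value such as 0). A minimal sketch of that selection logic, with the hypothetical `default_page_size` standing in for `PallasAttentionBackend.get_page_size(vllm_config)`:

from typing import Optional


def resolve_block_size(configured: Optional[int], default_page_size: int) -> int:
    # Mirrors `cache_config.block_size or <default>`: keep an explicit value,
    # fall back to the default when the configured value is None (or 0).
    return configured or default_page_size


assert resolve_block_size(32, 128) == 32     # explicit value respected (cases 1 and 4)
assert resolve_block_size(None, 128) == 128  # v1 default from the backend (case 2)
assert resolve_block_size(None, 16) == 16    # v0 default (case 3)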
