|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +"""Tests for the Pallas MOE implementation. |
| 3 | +
|
| 4 | +Run `pytest tests/platforms/test_tpu.py`. |
| 5 | +""" |
| 6 | +from unittest.mock import MagicMock, patch |
| 7 | + |
| 8 | +import pytest |
| 9 | + |
| 10 | +import vllm.config |
| 11 | +from vllm.platforms.tpu import TpuPlatform |
| 12 | + |
| 13 | + |
| 14 | +@pytest.mark.parametrize( |
| 15 | + "use_v1,initial_block_size,expected_block_size", |
| 16 | + [ |
| 17 | + (True, 32, 32), # Case 1: v1: block_size set, should remain unchanged |
| 18 | + ( |
| 19 | + True, None, 128 |
| 20 | + ), # Case 2: v1: block_size None, should be set to get_page_size (128) |
| 21 | + (False, None, 16), # Case 3: v0: block_size None, should be set to 16 |
| 22 | + (False, 32, 32), # Case 4: v0: block_size set, should remain unchanged |
| 23 | + ]) |
| 24 | +@patch( |
| 25 | + "vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_page_size", |
| 26 | + return_value=128) |
| 27 | +@patch( |
| 28 | + "vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_min_page_size", |
| 29 | + return_value=8) |
| 30 | +def test_tpu_platform_update_vllm_config_block_size_respect_passin_block_size( |
| 31 | + mock_get_min_page_size, mock_get_page_size, use_v1, initial_block_size, |
| 32 | + expected_block_size) -> None: |
| 33 | + """Test TPU platform updates VLLM config with block size.""" |
| 34 | + # arrange |
| 35 | + mock_cached_config = MagicMock() |
| 36 | + mock_cached_config.block_size = initial_block_size |
| 37 | + |
| 38 | + mock_model_config = MagicMock() |
| 39 | + mock_model_config.dtype = "float16" |
| 40 | + |
| 41 | + mock_vllm_config = MagicMock() |
| 42 | + mock_vllm_config.cache_config = mock_cached_config |
| 43 | + mock_vllm_config.compilation_config = MagicMock() |
| 44 | + mock_vllm_config.compilation_config.level = ( |
| 45 | + vllm.config.CompilationLevel.DYNAMO_ONCE) |
| 46 | + mock_vllm_config.compilation_config.backend = "openxla" |
| 47 | + mock_vllm_config.model_config = mock_model_config |
| 48 | + mock_vllm_config.speculative_config = None |
| 49 | + mock_vllm_config.parallel_config = MagicMock() |
| 50 | + mock_vllm_config.parallel_config.worker_cls = ( |
| 51 | + "vllm.v1.worker.tpu_worker.TPUWorker") |
| 52 | + mock_vllm_config.scheduler_config = MagicMock() |
| 53 | + |
| 54 | + # act |
| 55 | + with patch("vllm.envs.VLLM_USE_V1", use_v1): |
| 56 | + TpuPlatform.check_and_update_config(mock_vllm_config) |
| 57 | + |
| 58 | + # assert |
| 59 | + assert mock_cached_config.block_size == expected_block_size |
| 60 | + if use_v1: |
| 61 | + mock_get_min_page_size.assert_called() |
| 62 | + if initial_block_size is None: |
| 63 | + mock_get_page_size.assert_called() |
| 64 | + else: |
| 65 | + mock_get_page_size.assert_not_called() |
0 commit comments