From 99fd5cc57a29f521d3fc587649a1b2955cf01037 Mon Sep 17 00:00:00 2001
From: Chengji Yao
Date: Sat, 19 Apr 2025 05:53:39 +0000
Subject: [PATCH 1/2] [TPU][V1] Implicitly adjust page size when there's SMEM OOM

Signed-off-by: Chengji Yao
---
 tests/v1/tpu/test_basic.py           |  4 +++-
 vllm/platforms/tpu.py                | 14 ++++++++++++++
 vllm/v1/attention/backends/pallas.py | 15 +++++++++++++++
 3 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py
index 8164952fe382..13e118561271 100644
--- a/tests/v1/tpu/test_basic.py
+++ b/tests/v1/tpu/test_basic.py
@@ -22,6 +22,7 @@
 ]
 
 TENSOR_PARALLEL_SIZES = [1]
+MAX_NUM_REQS = [16, 1924]
 
 # TODO: Enable when CI/CD will have a multi-tpu instance
 # TENSOR_PARALLEL_SIZES = [1, 4]
@@ -32,6 +33,7 @@
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
+@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
 def test_basic(
     vllm_runner: type[VllmRunner],
     monkeypatch: pytest.MonkeyPatch,
@@ -51,7 +53,7 @@ def test_basic(
             # Note: max_num_batched_tokens == 1024 is needed here to
             # actually test chunked prompt
             max_num_batched_tokens=1024,
-            max_model_len=8196,
+            max_model_len=8192,
             gpu_memory_utilization=0.7,
             max_num_seqs=16,
             tensor_parallel_size=tensor_parallel_size) as vllm_model:
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 83dd3e9c817a..b1e221e28b43 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -97,6 +97,20 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "Using bfloat16 instead.", vllm_config.model_config.dtype)
             vllm_config.model_config.dtype = torch.bfloat16
 
+        if envs.VLLM_USE_V1:
+            from vllm.v1.attention.backends.pallas import (
+                PallasAttentionBackend)
+            min_page_size = PallasAttentionBackend.get_min_page_size(
+                vllm_config)
+            if min_page_size > vllm_config.cache_config.block_size:
+                logger.warning(
+                    "Increasing the page size from %s to %s to make sure "
+                    "there is no SMEM OOM.",
+                    vllm_config.cache_config.block_size,
+                    min_page_size,
+                )
+                vllm_config.cache_config.block_size = min_page_size
+
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 3e8149a24ebf..05b97172bc6c 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -10,7 +10,9 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.utils import cdiv
 
 logger = init_logger(__name__)
 
@@ -50,6 +52,19 @@
     ) -> None:
         raise RuntimeError("swap_blocks is not used for the TPU backend.")
 
+    # In recent TPU generations, up to v6e, the SMEM size is 1MB. The
+    # block_tables in PallasMetadata account for almost the entire SMEM usage;
+    # their size is max_num_seqs * num_page_per_seq * 4 bytes (int32). Here we
+    # make sure that this size stays below half of the SMEM capacity.
+    @staticmethod
+    def get_min_page_size(vllm_config: VllmConfig) -> int:
+        max_num_page_per_req = (1024 * 1024 // 2 //
+                                vllm_config.scheduler_config.max_num_seqs // 4)
+        min_page_size = cdiv(vllm_config.model_config.max_model_len,
+                             max_num_page_per_req)
+        min_page_size = 1 << (min_page_size - 1).bit_length()
+        return min_page_size
+
 
 @dataclass
 class PallasMetadata:

From dae230238310b2eba70dac28d0cadfb905a9e133 Mon Sep 17 00:00:00 2001
From: Chengji Yao
Date: Sat, 19 Apr 2025 05:57:05 +0000
Subject: [PATCH 2/2] fix test

Signed-off-by: Chengji Yao
---
 tests/v1/tpu/test_basic.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py
index 13e118561271..a4571a554572 100644
--- a/tests/v1/tpu/test_basic.py
+++ b/tests/v1/tpu/test_basic.py
@@ -22,7 +22,7 @@
 ]
 
 TENSOR_PARALLEL_SIZES = [1]
-MAX_NUM_REQS = [16, 1924]
+MAX_NUM_REQS = [16, 1024]
 
 # TODO: Enable when CI/CD will have a multi-tpu instance
 # TENSOR_PARALLEL_SIZES = [1, 4]
@@ -40,6 +40,7 @@ def test_basic(
     model: str,
     max_tokens: int,
     tensor_parallel_size: int,
+    max_num_seqs: int,
 ) -> None:
     prompt = "The next numbers of the sequence " + ", ".join(
         str(i) for i in range(1024)) + " are:"
@@ -55,7 +56,7 @@ def test_basic(
             max_num_batched_tokens=1024,
             max_model_len=8192,
             gpu_memory_utilization=0.7,
-            max_num_seqs=16,
+            max_num_seqs=max_num_seqs,
             tensor_parallel_size=tensor_parallel_size) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                   max_tokens)
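
The heuristic added in PATCH 1/2 can be sanity-checked in isolation. Below is a
minimal standalone sketch (not part of the patches) that mirrors the
get_min_page_size() arithmetic under its stated assumptions: 1MB of SMEM up to
TPU v6e, half of it budgeted for block_tables, and 4-byte (int32) block-table
entries. Plain integers stand in for the vLLM config objects, so it runs
without vLLM installed; cdiv is a local re-implementation of vllm.utils.cdiv.

# Standalone sketch of the min-page-size heuristic from PATCH 1/2.
# Assumptions (taken from the patch): 1MB SMEM, half of it reserved for
# block_tables, 4 bytes (int32) per block-table entry. The arguments stand
# in for scheduler_config.max_num_seqs and model_config.max_model_len.

SMEM_BYTES = 1024 * 1024          # SMEM capacity, up to TPU v6e
BLOCK_TABLE_BUDGET = SMEM_BYTES // 2
ENTRY_BYTES = 4                   # int32 page indices


def cdiv(a: int, b: int) -> int:
    # Ceiling division, equivalent to vllm.utils.cdiv.
    return -(-a // b)


def min_page_size(max_num_seqs: int, max_model_len: int) -> int:
    # Largest number of pages each request may occupy in the block table.
    max_num_page_per_req = BLOCK_TABLE_BUDGET // max_num_seqs // ENTRY_BYTES
    # Smallest page size that fits max_model_len into that many pages,
    # rounded up to the next power of two.
    pages = cdiv(max_model_len, max_num_page_per_req)
    return 1 << (pages - 1).bit_length()


if __name__ == "__main__":
    # max_num_seqs=16, max_model_len=8192: the constraint is loose (result 1),
    # so check_and_update_config leaves the configured block_size alone.
    print(min_page_size(16, 8192))    # -> 1
    # max_num_seqs=1024 (the new test value): result 64, so any configured
    # block_size below 64 is bumped up to 64.
    print(min_page_size(1024, 8192))  # -> 64

With the updated MAX_NUM_REQS, the 1024-sequence test case is the one expected
to exercise this adjustment path (assuming the configured block size is below
64), while the 16-sequence case is not.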