7 changes: 5 additions & 2 deletions tests/v1/tpu/test_basic.py
@@ -22,6 +22,7 @@
 ]
 
 TENSOR_PARALLEL_SIZES = [1]
+MAX_NUM_REQS = [16, 1024]
 
 # TODO: Enable when CI/CD will have a multi-tpu instance
 # TENSOR_PARALLEL_SIZES = [1, 4]
@@ -32,12 +33,14 @@
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
+@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
 def test_basic(
     vllm_runner: type[VllmRunner],
     monkeypatch: pytest.MonkeyPatch,
     model: str,
     max_tokens: int,
     tensor_parallel_size: int,
+    max_num_seqs: int,
 ) -> None:
     prompt = "The next numbers of the sequence " + ", ".join(
         str(i) for i in range(1024)) + " are:"
@@ -51,9 +54,9 @@ def test_basic(
             # Note: max_num_batched_tokens == 1024 is needed here to
             # actually test chunked prompt
             max_num_batched_tokens=1024,
-            max_model_len=8196,
+            max_model_len=8192,
             gpu_memory_utilization=0.7,
-            max_num_seqs=16,
+            max_num_seqs=max_num_seqs,
             tensor_parallel_size=tensor_parallel_size) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                   max_tokens)
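The new max_num_seqs=1024 case is what exercises the SMEM guard added below: with max_model_len=8192 and 1024 sequences, the computed minimum page size is 64, above a typical default block size of 16, so the platform check is expected to kick in for the test to pass. A hedged way to run only that case from Python (test IDs depend on parametrize order, so substring matching with -k is safest):

    # Illustrative invocation; assumes a TPU host with vLLM's test deps installed.
    import pytest

    pytest.main(["tests/v1/tpu/test_basic.py", "-k", "1024", "-v"])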
14 changes: 14 additions & 0 deletions vllm/platforms/tpu.py
@@ -97,6 +97,20 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "Using bfloat16 instead.", vllm_config.model_config.dtype)
             vllm_config.model_config.dtype = torch.bfloat16
 
+        if envs.VLLM_USE_V1:
+            from vllm.v1.attention.backends.pallas import (
+                PallasAttentionBackend)
+            min_page_size = PallasAttentionBackend.get_min_page_size(
+                vllm_config)
+            if min_page_size > vllm_config.cache_config.block_size:
+                logger.warning(
+                    "Increase the page size from %s to %s to make sure "
+                    "there's no SMEM OOM",
+                    vllm_config.cache_config.block_size,
+                    min_page_size,
+                )
+                vllm_config.cache_config.block_size = min_page_size
+
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
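For context, a minimal sketch of how this check surfaces to a user, assuming vLLM's Python entrypoint; the model name is a placeholder and the exact kwarg plumbing may differ by version, neither is taken from this PR. With these values the computed minimum page size is 64, so an explicit block_size of 16 would be raised and the warning above logged.

    # Sketch only: kwargs mirror the test's config; the model is a placeholder.
    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen2-1.5B-Instruct",  # placeholder, not from the PR
        max_model_len=8192,
        max_num_seqs=1024,
        block_size=16,  # below the computed minimum of 64; raised automatically
        gpu_memory_utilization=0.7,
    )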
15 changes: 15 additions & 0 deletions vllm/v1/attention/backends/pallas.py
@@ -10,7 +10,9 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.utils import cdiv
 
 logger = init_logger(__name__)
@@ -50,6 +52,19 @@ def swap_blocks(
     ) -> None:
         raise RuntimeError("swap_blocks is not used for the TPU backend.")
 
+    # In recent TPU generations, up to v6e, the SMEM size is 1MB. The
+    # block_tables within the PallasMetadata constitute almost the entire SMEM
+    # requirement. Their size is max_num_seqs * num_page_per_seq * 4 bytes
+    # (int32). Here we simply make sure it stays below half of SMEM capacity.
+    @staticmethod
+    def get_min_page_size(vllm_config: VllmConfig) -> int:
+        max_num_page_per_req = (1024 * 1024 // 2 //
+                                vllm_config.scheduler_config.max_num_seqs // 4)
+        min_page_size = cdiv(vllm_config.model_config.max_model_len,
+                             max_num_page_per_req)
+        min_page_size = 1 << (min_page_size - 1).bit_length()
+        return min_page_size
+
 
 @dataclass
 class PallasMetadata:

Review thread on the `max_num_page_per_req` expression:

Collaborator: what's the 2 here?

Collaborator (Author): Here we simply make sure that the size is smaller than half of SMEM capacity.
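To make the arithmetic behind the reviewer's question concrete: half of the 1MB SMEM budget is divided across max_num_seqs block tables at 4 bytes per entry, which bounds the number of pages per request; the page size is then the smallest power of two covering max_model_len. A standalone sketch with the test's values plugged in (the helper name is ours, not vLLM's):

    # Re-derivation of get_min_page_size, independent of vLLM's config objects.
    from math import ceil

    SMEM_BYTES = 1024 * 1024  # ~1MB of SMEM on TPU generations up to v6e

    def min_page_size_for(max_model_len: int, max_num_seqs: int) -> int:
        # Half the SMEM budget, 4 bytes per block-table entry.
        max_num_page_per_req = SMEM_BYTES // 2 // max_num_seqs // 4
        min_page_size = ceil(max_model_len / max_num_page_per_req)  # cdiv
        # Round up to the next power of two.
        return 1 << (min_page_size - 1).bit_length()

    assert min_page_size_for(8192, 1024) == 64  # 128 pages/req -> 64-token pages
    assert min_page_size_for(8192, 16) == 1     # ample SMEM; no bump needed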