9 changes: 8 additions & 1 deletion vllm/utils/__init__.py
@@ -734,6 +734,13 @@ def next_power_of_2(n) -> int:
return 1 << (n - 1).bit_length()


def prev_power_of_2(n: int) -> int:
    """The previous power of 2 (inclusive): the largest power of 2 <= n, or 0 if n <= 0."""
if n <= 0:
return 0
return 1 << (n.bit_length() - 1)


def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y

@@ -2985,4 +2992,4 @@ def has_deep_ep() -> bool:
def has_deep_gemm() -> bool:
"""Whether the optional `deep_gemm` package is available."""

return _has_module("deep_gemm")
return _has_module("deep_gemm")
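A quick sanity check of the new prev_power_of_2 helper (an illustrative sketch; the function is re-inlined here rather than imported from vllm.utils, and the sample values are arbitrary):

def prev_power_of_2(n: int) -> int:
    # Largest power of 2 <= n; 0 for non-positive n (mirrors the helper above).
    if n <= 0:
        return 0
    return 1 << (n.bit_length() - 1)

assert prev_power_of_2(64) == 64   # exact powers of two map to themselves
assert prev_power_of_2(100) == 64  # otherwise round down to the previous power of two
assert prev_power_of_2(0) == 0     # non-positive inputs return 0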
6 changes: 6 additions & 0 deletions vllm/v1/attention/backends/pallas.py
@@ -318,3 +318,9 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
page_size: int,
num_slices_per_block: int) -> torch.Tensor:
return kv_cache


def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
kv_cache_dtype: torch.dtype) -> int:
"""Returns the size in bytes of one page of the KV cache."""
return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize
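To make the arithmetic concrete, a hedged example of the page-size computation with hypothetical cache dimensions (the numbers below are illustrative, not taken from any particular model config):

import torch

def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                        kv_cache_dtype: torch.dtype) -> int:
    # Copy of the helper above: bytes needed for one page of the KV cache.
    return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize

# 32 tokens/block * 8 KV heads * head_size 128 * 2 bytes (bfloat16) = 65,536 bytes.
print(get_page_size_bytes(32, 8, 128, torch.bfloat16))  # 65536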
66 changes: 50 additions & 16 deletions vllm/v1/worker/tpu_model_runner.py
@@ -30,9 +30,10 @@
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
is_pin_memory_available)
is_pin_memory_available, prev_power_of_2)
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasMetadata)
PallasMetadata,
get_page_size_bytes)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec,
@@ -56,8 +57,6 @@
INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8
# Block size used for kv cache updating kernel
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8


#########################################################
@@ -139,7 +138,11 @@ def __init__(
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
model_dtype = self.dtype
if isinstance(model_dtype, str):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
else:
self.kv_cache_dtype = model_dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
@@ -192,6 +195,14 @@ def __init__(
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size

self._num_slices_per_kv_cache_update_block = \
_get_num_slices_per_kv_cache_update_block(get_page_size_bytes(
block_size=self.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
kv_cache_dtype=self.kv_cache_dtype,
))

# Lazy initialization
self.model: nn.Module # Set after load_model
self.kv_caches: list[torch.Tensor] = []
@@ -719,7 +730,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
num_kv_update_slices = slot_mapping_metadata.shape[0]
padded_num_slices = _get_padded_num_kv_cache_update_slices(
padded_total_num_scheduled_tokens, self.max_num_reqs,
self.block_size)
self.block_size, self._num_slices_per_kv_cache_update_block)
slot_mapping_metadata = np.pad(
slot_mapping_metadata,
[[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
@@ -750,8 +761,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
num_kv_update_slices=torch.tensor([num_kv_update_slices],
dtype=torch.int32,
device=self.device),
num_slices_per_kv_cache_update_block=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
num_slices_per_kv_cache_update_block=self.
_num_slices_per_kv_cache_update_block,
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
@@ -1178,7 +1189,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
position_ids = torch.zeros(num_tokens,
dtype=torch.int32).to(self.device)
padded_num_slices = _get_padded_num_kv_cache_update_slices(
num_tokens, self.max_num_reqs, self.block_size)
num_tokens, self.max_num_reqs, self.block_size,
self._num_slices_per_kv_cache_update_block)
num_kv_update_slices = torch.tensor([padded_num_slices],
dtype=torch.int32).to(self.device)
slot_mapping = torch.zeros((3, padded_num_slices),
@@ -1201,8 +1213,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
query_start_loc=query_start_loc,
num_seqs=num_seqs,
num_kv_update_slices=num_kv_update_slices,
num_slices_per_kv_cache_update_block=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
num_slices_per_kv_cache_update_block=self.
_num_slices_per_kv_cache_update_block,
)

if self.is_multimodal_model:
@@ -1807,19 +1819,41 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
return paddings[index]


def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
page_size: int) -> int:
def _get_padded_num_kv_cache_update_slices(
num_tokens: int, max_num_reqs: int, page_size: int,
num_slices_per_kv_cache_update_block: int) -> int:
"""Calculates the padded number of KV cache update slices to avoid
recompilation."""
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
padded_num_slices = min(padded_num_slices, num_tokens)
padded_num_slices = (
padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
padded_num_slices + num_slices_per_kv_cache_update_block - 1
) // num_slices_per_kv_cache_update_block * \
num_slices_per_kv_cache_update_block
return padded_num_slices
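A small worked example of the padding rule above, using hypothetical sizes (not drawn from the PR):

# Hypothetical values for illustration only.
num_tokens = 500
max_num_reqs = 16
page_size = 32
slices_per_block = 16

padded = 2 * max_num_reqs + num_tokens // page_size  # 32 + 15 = 47
padded = min(padded, num_tokens)                     # still 47
padded = ((padded + slices_per_block - 1)
          // slices_per_block * slices_per_block)    # round up to a multiple of 16 -> 48
print(padded)  # 48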


def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
"""Find the optimum number of slices to copy per Pallas program instance.

Increasing the number of slices copied in one instance of the kernel program
will increase HBM bandwidth utilization via more in-flight DMAs.

However, it will also use more VMEM, and experimentally we observed a
performance regression at 128 slices on v6e, likely due to running
out of scalar registers. Thus this function limits the number of
slices to 64.
"""
# Conservative VMEM usage limit: 32 MiB
vmem_limit = 32 * 1024 * 1024
num_slices_per_block = vmem_limit // page_size_bytes
assert num_slices_per_block > 0, "Number of slices should be positive"
num_slices_per_block = prev_power_of_2(num_slices_per_block)
if num_slices_per_block > 64:
num_slices_per_block = 64
return num_slices_per_block
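Putting the two helpers together, this is roughly how the 32 MiB VMEM budget translates into a slice count for a couple of hypothetical page sizes (an illustrative sketch, not output from the PR):

vmem_limit = 32 * 1024 * 1024  # same 32 MiB budget as the function above

for page_size_bytes in (64 * 1024, 1024 * 1024):  # hypothetical 64 KiB and 1 MiB pages
    n = vmem_limit // page_size_bytes   # how many pages fit in the budget
    n = 1 << (n.bit_length() - 1)       # prev_power_of_2 for positive n
    print(page_size_bytes, min(n, 64))  # 64 KiB page -> 64 slices; 1 MiB page -> 32 slices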


def replace_set_lora(model):

def _tpu_set_lora(