
Commit 863d315

[V1][TPU] Pad the block_table.shape[1] so the ragged paged attention can handle it correctly (#14597)
1 parent: d374f04

File tree

1 file changed: +5, -2 lines

vllm/v1/worker/tpu_model_runner.py

Lines changed: 5 additions & 2 deletions
@@ -23,7 +23,8 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
+                                               PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -138,8 +139,10 @@ def __init__(
                                             device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
 
+        padded_max_num_blocks_per_req = _get_padded_number(
+            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, self.max_num_blocks_per_req),
+            (self.max_num_tokens, padded_max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")
 
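
For context, the new allocation pads the block table's second dimension up to a multiple of NUM_KV_PAGES_PER_BLOCK, so the ragged paged attention kernel always sees a whole number of KV-page groups per request. Below is a minimal, hypothetical sketch of that round-up step in Python; the standalone helper name, the constant value 16, and the example sizes are illustrative assumptions rather than values taken from the vLLM source.

def get_padded_number(n: int, multiple: int) -> int:
    # Round n up to the nearest multiple of `multiple` (the assumed semantics
    # of the _get_padded_number helper referenced in the diff above).
    return ((n + multiple - 1) // multiple) * multiple

# Illustrative values only: NUM_KV_PAGES_PER_BLOCK is assumed to be 16 here.
NUM_KV_PAGES_PER_BLOCK = 16
max_num_blocks_per_req = 33
padded = get_padded_number(max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
print(padded)  # 48: block_table_cpu would be allocated with width 48 instead of 33

Because the tensor is created with torch.zeros, the extra padded columns simply start out zero-filled.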
