@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-import enum
 import time
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 from unittest.mock import patch

@@ -13,7 +11,6 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr

-from vllm.attention import AttentionMetadata
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
@@ -22,13 +19,13 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.sampling_params import SamplingType
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
-                                               PallasMetadata,
+from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                                NUM_QUERIES_PER_BLOCK,
-                                               NUM_KV_PAGES_PER_BLOCK)
+                                               PallasAttentionBackend,
+                                               PallasMetadata)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

@@ -48,7 +45,7 @@
 INVALID_TOKEN_ID = -1


-class TPUModelRunner():
+class TPUModelRunner:

     def __init__(
         self,
@@ -80,8 +77,8 @@ def __init__(
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens  # 8192
-        self.max_num_reqs = scheduler_config.max_num_seqs  # 16
+        self.max_num_tokens = scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = scheduler_config.max_num_seqs

         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -115,8 +112,8 @@ def __init__(
         # The pytorch tensor and numpy array share the same buffer.
         # Sometimes the numpy op is faster so we create both.
         self.input_ids_cpu = torch.zeros(self.max_num_tokens,
-                                        dtype=torch.int32,
-                                        device="cpu")
+                                         dtype=torch.int32,
+                                         device="cpu")
         self.input_ids_np = self.input_ids_cpu.numpy()

         self.positions_cpu = torch.zeros(self.max_num_tokens,
@@ -132,10 +129,12 @@ def __init__(
         # self.input_batch.block_table has a shape of [max_num_reqs, max_num_blocks_per_req].
         # To reduce the number of recompilations, we want block_table.shape[0] to be num_tokens.
         # To make the block_table compatible with the paged attention kernel, we want block_table.shape[1] to be a multiple of NUM_KV_PAGES_PER_BLOCK.
-        padded_max_num_blocks_per_req = _get_padded_number(self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
-        self.block_table_cpu = torch.zeros((self.max_num_tokens, padded_max_num_blocks_per_req),
-                                           dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
-                                           device="cpu")
+        padded_max_num_blocks_per_req = _get_padded_number(
+            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
+        self.block_table_cpu = torch.zeros(
+            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
+            device="cpu")

         self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
                                                dtype=torch.int32,
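_get_padded_number is called above but defined outside the hunks shown in this diff. As a rough illustration of the block-table padding being set up here, a minimal sketch assuming the helper simply rounds its first argument up to the nearest multiple of the second:

# Sketch only: assumes _get_padded_number rounds x up to a multiple of `multiple`.
# The real helper lives elsewhere in tpu_model_runner.py and is not shown here.
def _get_padded_number(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

# With a hypothetical NUM_KV_PAGES_PER_BLOCK of 16, a request needing 17 blocks
# gets a block-table row padded out to 32 entries; 16 stays at 16.
assert _get_padded_number(17, 16) == 32
assert _get_padded_number(16, 16) == 16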
@@ -325,9 +324,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             assert req_id is not None
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens_per_req.append(num_tokens)
-            max_num_scheduled_tokens_all_reqs = max(max_num_scheduled_tokens_all_reqs,
-                                                    num_tokens)
-        num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, dtype=np.int32)
+            max_num_scheduled_tokens_all_reqs = max(
+                max_num_scheduled_tokens_all_reqs, num_tokens)
+        num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
+                                                dtype=np.int32)
         assert max_num_scheduled_tokens_all_reqs > 0

         # Get request indices.
@@ -341,13 +341,13 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # For each scheduled token, what is its position in the corresponding req.
         arange = np.concatenate(
             [self.arange_np[:n] for n in num_scheduled_tokens_per_req])
-
+
         # Get positions.
         positions_np = self.positions_np[:total_num_scheduled_tokens]
         np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
                arange,
                out=positions_np)
-
+
         # Get token indices.
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@@ -362,7 +362,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
                            0,
                            torch.from_numpy(token_indices),
                            out=self.input_ids_cpu[:total_num_scheduled_tokens])
-
+
         # Calculate the slot mapping.
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
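The two "E.g." comments above compress a fair amount of index arithmetic. A self-contained sketch of the same gather pattern with toy sizes (all values below are invented for illustration; the real code also goes through a flattened block-table lookup rather than the 2-D indexing used here):

import numpy as np

M = 8           # toy stand-in for the per-request row stride (max_model_len)
block_size = 4  # toy block size

# Three requests scheduled for 2, 5 and 3 new tokens, none computed yet.
num_scheduled = np.array([2, 5, 3], dtype=np.int32)
req_indices = np.repeat(np.arange(3, dtype=np.int32), num_scheduled)
positions = np.concatenate([np.arange(n, dtype=np.int32) for n in num_scheduled])

# Token indices into a [num_reqs, M] token buffer viewed as one flat array.
token_indices = positions + req_indices * M
print(token_indices)  # [ 0  1  8  9 10 11 12 16 17 18]
#                       i.e. [0, 1, M, M + 1, ..., 2 * M + 2] as in the comment

# Slot mapping: physical block number times block_size plus in-block offset.
# Toy block table: request r owns physical blocks 2r and 2r + 1.
block_table = np.array([[0, 1], [2, 3], [4, 5]], dtype=np.int32)
block_numbers = block_table[req_indices, positions // block_size]
block_offsets = positions % block_size
slot_mapping = block_numbers * block_size + block_offsets
print(slot_mapping)   # [ 0  1  8  9 10 11 12 16 17 18]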
@@ -381,27 +381,40 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         np.add(block_numbers * self.block_size,
                block_offsets,
                out=self.slot_mapping_np[:total_num_scheduled_tokens])
-
+
         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
         np.cumsum(num_scheduled_tokens_per_req,
                   out=self.query_start_loc_np[1:num_reqs + 1])
-
+
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens_per_req)

         # Do the padding and copy the tensors to the TPU.
-        padded_total_num_scheduled_tokens = _get_padded_number(total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK)
-        self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to(self.device)
-        self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to(self.device)
+        padded_total_num_scheduled_tokens = _get_padded_number(
+            total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK)
+        self.input_ids = self.input_ids_cpu[:
+                                            padded_total_num_scheduled_tokens].to(
+                                                self.device)
+        self.position_ids = self.positions_cpu[:
+                                               padded_total_num_scheduled_tokens].to(
+                                                   self.device)
         self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID
-        slot_mapping = self.slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(self.device)
-        padded_block_table = self.block_table_cpu[:padded_total_num_scheduled_tokens]
-        padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = self.input_batch.block_table.get_cpu_tensor()[:num_reqs]
+        slot_mapping = self.slot_mapping_cpu[:
+                                             padded_total_num_scheduled_tokens].to(
+                                                 self.device)
+        padded_block_table = self.block_table_cpu[:
+                                                  padded_total_num_scheduled_tokens]
+        padded_block_table[:num_reqs, :self.
+                           max_num_blocks_per_req] = self.input_batch.block_table.get_cpu_tensor(
+                           )[:num_reqs]
         padded_block_table = padded_block_table.to(self.device)
-        query_start_loc = self.query_start_loc_cpu[:padded_total_num_scheduled_tokens + 1].to(self.device)
-        seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(self.device)
+        query_start_loc = self.query_start_loc_cpu[:
+                                                   padded_total_num_scheduled_tokens
+                                                   + 1].to(self.device)
+        seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(
+            self.device)

         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
@@ -418,7 +431,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices

-
     @torch.no_grad()
     def execute_model(
         self,
@@ -432,7 +444,7 @@ def execute_model(
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

         # Run the decoder
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
                 token_ids=self.input_ids,
                 position_ids=self.position_ids,
@@ -461,14 +473,14 @@ def execute_model(
                 if generator is not None:
                     # This relies on cuda-specific torch-internal impl details
                     generator.set_offset(generator.get_offset() - 4)
-
+
         # num_reqs entries should be non-None
         assert all(
             req_id is not None for req_id in
             self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
         req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])

-        prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}
+        prompt_logprobs_dict = {}
         for req_id in self.input_batch.req_ids[:num_reqs]:
             prompt_logprobs_dict[req_id] = None

@@ -526,7 +538,6 @@ def load_model(self) -> None:
         xm.mark_step()
         xm.wait_device_ops()
         model = ModelWrapperV1(model)
-        # self.model = model
         self.model = torch.compile(model,
                                    backend="openxla",
                                    fullgraph=True,
@@ -546,12 +557,14 @@ def dummy_run(
         slot_mapping = torch.zeros(num_tokens,
                                    dtype=torch.int64,
                                    device=self.device)
-        block_tables = torch.zeros(
-            (num_tokens, self.block_table_cpu.shape[1]),
-            dtype=torch.int32,
-            device=self.device)
+        block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]),
+                                   dtype=torch.int32,
+                                   device=self.device)
         query_lens = [1] * num_tokens
-        query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32).to(self.device)
+        query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
+                                                    dtype=torch.int32),
+                                       dim=0,
+                                       dtype=torch.int32).to(self.device)
         context_lens = torch.ones((num_tokens, ),
                                   dtype=torch.int32,
                                   device=self.device)
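In dummy_run every query has length one, so the reformatted torch.cumsum call above simply yields 0..num_tokens. A tiny standalone check with the same arguments and a toy size:

import torch

num_tokens = 8
query_lens = [1] * num_tokens
query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32),
                               dim=0,
                               dtype=torch.int32)
print(query_start_loc)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32)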
@@ -581,7 +594,6 @@ def capture_model(self) -> None:

         start = time.perf_counter()
         num_tokens = 16
-        # The num_tokens_list below is how GPU precompiles.
         while True:
             self.dummy_run(self.kv_caches, num_tokens)
             logger.info(" -- num_tokens: %d", num_tokens)
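The loop above doubles num_tokens between dummy runs, so the precompiled sizes form a power-of-two ladder starting at 16. A sketch of the sizes it walks through, assuming the break condition (which sits outside the hunks shown here) stops once num_tokens covers the max batched-token budget:

# max_num_batched_tokens is a hypothetical example value, not taken from this diff.
max_num_batched_tokens = 8192
num_tokens = 16
sizes = []
while True:
    sizes.append(num_tokens)
    if num_tokens >= max_num_batched_tokens:
        break
    num_tokens *= 2
print(sizes)  # [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]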
@@ -591,8 +603,7 @@ def capture_model(self) -> None:
                 break
             num_tokens *= 2
         end = time.perf_counter()
-        logger.info("Compilation finished in %.2f [secs].",
-                    end - start)
+        logger.info("Compilation finished in %.2f [secs].", end - start)

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
@@ -682,7 +693,7 @@ def forward(
             position_ids,
             kv_caches,
         )
-
+
         return hidden_states

     def compute_logits(