 import numpy as np
 import numpy.typing as npt
 import torch
-import torch.distributed
 import torch.nn as nn
 from vllm.attention import AttentionType
 from vllm.attention.layer import Attention
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv, is_pin_memory_available)
+from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)

 from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
                                                 AscendMetadata)
+from vllm_ascend.platform import NPUPlatform

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 class NPUModelRunner:

     def __init__(self, vllm_config: VllmConfig, device: torch.device):
-
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
         self.scheduler_config = vllm_config.scheduler_config
-        self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
-        self.observability_config = vllm_config.observability_config
-
-        model_config = self.model_config
-        cache_config = self.cache_config
-        scheduler_config = self.scheduler_config
-        parallel_config = self.parallel_config
-
         self.device = device
-        self.pin_memory = is_pin_memory_available()
-        self.dtype = self.model_config.dtype
-
-        if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
-        else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
-
-        self.is_multimodal_model = model_config.is_multimodal_model
-        self.sliding_window = model_config.get_sliding_window()
-        self.block_size = cache_config.block_size
-        self.max_model_len = model_config.max_model_len
-        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = scheduler_config.max_num_seqs
+        self.is_multimodal_model = self.model_config.is_multimodal_model
+        self.block_size = vllm_config.cache_config.block_size
+        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
+                                           self.block_size)
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = self.scheduler_config.max_num_seqs
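For reference, `max_num_blocks_per_req` is a ceiling division: the longest possible request must be covered by whole KV-cache blocks. A minimal sketch of the arithmetic, with hypothetical sizes and `cdiv` shown as plain integer ceiling division:

    def cdiv(a: int, b: int) -> int:
        # Ceiling division without floats: -(a // -b) == ceil(a / b) for b > 0.
        return -(a // -b)

    # Hypothetical sizes: a 4096-token context with 16-token blocks needs
    # exactly 256 blocks; one extra token forces a 257th block.
    assert cdiv(4096, 16) == 256
    assert cdiv(4097, 16) == 257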

         # Model-related.
-        self.num_attn_layers = model_config.get_num_layers_by_block_type(
-            parallel_config, LayerBlockType.attention)
-        self.num_query_heads = model_config.get_num_attention_heads(
-            parallel_config)
-        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
-        self.head_size = model_config.get_head_size()
-        self.hidden_size = model_config.get_hidden_size()
+        self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
+            vllm_config.parallel_config, LayerBlockType.attention)
+        self.hidden_size = self.model_config.get_hidden_size()

         # Multi-modal data support
         self.input_registry = INPUT_REGISTRY
         self.mm_registry = MULTIMODAL_REGISTRY
-        self.uses_mrope = model_config.uses_mrope
+        self.uses_mrope = self.model_config.uses_mrope

-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
+        self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
+            model_config=self.model_config,
+            scheduler_config=self.scheduler_config,
             mm_registry=self.mm_registry)
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size

         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
         self.kv_caches: List[torch.Tensor] = []
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

-        # Set up speculative decoding.
-        self.use_spec_decode = False
-
         # Request states.
         self.requests: Dict[str, CachedRequestState] = {}
         # Persistent batch.
         self.input_batch = InputBatch(
             max_num_reqs=self.max_num_reqs,
-            max_model_len=self.max_model_len,
+            max_model_len=self.model_config.max_model_len,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
             device=self.device,
-            pin_memory=self.pin_memory,
-            vocab_size=model_config.get_vocab_size(),
+            pin_memory=True,
+            vocab_size=self.model_config.get_vocab_size(),
         )

         self.input_ids = torch.zeros(self.max_num_tokens,
@@ -165,46 +131,41 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                                  (3, self.max_num_tokens + 1),
                                                  dtype=torch.int64,
                                                  device="cpu",
-                                                 pin_memory=self.pin_memory)
+                                                 pin_memory=True)

         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
+            dtype=self.model_config.dtype,
             device=self.device)

         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         self.arange_np: npt.NDArray[np.int32] = np.arange(max(
-            self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
+            self.max_num_reqs + 1, self.model_config.max_model_len,
+            self.max_num_tokens),
                                                           dtype=np.int32)
         # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
         # a faster version of creating a new tensor every time. Thus, we should
         # not make any assumptions about the values in these tensors.
         self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                          dtype=torch.int32,
                                          device="cpu",
-                                         pin_memory=self.pin_memory)
-        self.input_ids_np = self.input_ids_cpu.numpy()
+                                         pin_memory=True)
         self.positions_cpu = torch.zeros(self.max_num_tokens,
                                          dtype=torch.int64,
                                          device="cpu",
-                                         pin_memory=self.pin_memory)
+                                         pin_memory=True)
         self.positions_np = self.positions_cpu.numpy()

         self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                             dtype=torch.int32,
                                             device="cpu",
-                                            pin_memory=self.pin_memory)
+                                            pin_memory=True)
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()

-        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
-                                               dtype=torch.int32,
-                                               device="cpu",
-                                               pin_memory=self.pin_memory)
-        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
         self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                         dtype=torch.int32,
                                         device="cpu",
-                                        pin_memory=self.pin_memory)
+                                        pin_memory=True)
         self.seq_lens_np = self.seq_lens_cpu.numpy()

         self.input_positions_cpu = torch.arange(0,
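The pinned-host-buffer pattern above is worth spelling out: each `*_cpu` tensor is allocated once in page-locked memory, and the matching `*_np` array is a zero-copy NumPy view of the same storage, so per-step writes happen on the host and the device transfer can be asynchronous. A minimal sketch of the pattern, with a hypothetical size:

    import torch

    max_num_tokens = 1024  # hypothetical budget

    # Page-locked ("pinned") host buffer, allocated once and reused.
    positions_cpu = torch.zeros(max_num_tokens, dtype=torch.int64,
                                device="cpu", pin_memory=True)
    # Zero-copy NumPy view sharing the pinned storage: writes through it
    # land directly in the buffer that will be copied to the device.
    positions_np = positions_cpu.numpy()
    positions_np[:4] = [0, 1, 2, 3]
    # Pinned memory is what makes a non-blocking host-to-device copy,
    # e.g. positions_cpu[:n].to(device, non_blocking=True), effective.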
@@ -220,7 +181,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         # Therefore, an environment variable is added here to dynamically set
         # the size of the pre-constructed mask matrix based on requirements.
         mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
-        self.attn_mask_len = min(self.max_model_len, int(mask_len))
+        self.attn_mask_len = min(self.model_config.max_model_len,
+                                 int(mask_len))
         self.attn_mask_npu = torch.full(
             (self.attn_mask_len, self.attn_mask_len),
             NPU_PAGED_ATTENTION_MASK_VALUE,
@@ -384,8 +346,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
     def get_model(self) -> nn.Module:
         return self.model

-    def make_attention_mask(self, seq_lens, query_lens,
-                            position) -> torch.Tensor:
+    def _make_attention_mask(self, seq_lens, query_lens,
+                             position) -> torch.Tensor:
         max_seq_len = max(seq_lens, default=0)
         if max_seq_len <= self.attn_mask_len:
             return torch.index_select(self.attn_mask_npu,
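The fast path of `_make_attention_mask` slices rows out of the pre-built `attn_mask_npu` instead of materializing a fresh mask every step. A sketch of that idea; the fill pattern of the real NPU mask is not shown in this hunk, so the `triu_` layout below is an assumption:

    import torch

    MASK_VALUE = -10000.0  # hypothetical stand-in for NPU_PAGED_ATTENTION_MASK_VALUE
    attn_mask_len = 8      # hypothetical; sized via PAGED_ATTENTION_MASK_LEN above

    # One large mask built once; entries above the diagonal are masked out
    # (assumed causal layout).
    attn_mask = torch.full((attn_mask_len, attn_mask_len),
                           MASK_VALUE).triu_(diagonal=1)
    # Per step, select only the rows for the current query positions.
    positions = torch.tensor([2, 3])
    step_mask = torch.index_select(attn_mask, 0, positions)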
@@ -475,9 +437,9 @@ def _process_reqs(
         slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
             self.device, non_blocking=True)

-        attn_mask = self.make_attention_mask(seq_lens=seq_lens,
-                                             query_lens=num_scheduled_tokens,
-                                             position=positions)
+        attn_mask = self._make_attention_mask(seq_lens=seq_lens,
+                                              query_lens=num_scheduled_tokens,
+                                              position=positions)

         attn_metadata = AscendMetadata(
             seq_lens=query_lens,
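`slot_mapping` is what ties each scheduled token to its physical slot in the paged KV cache: with a block table, a token at position `p` lands in slot `block_table[p // block_size] * block_size + p % block_size`. A small NumPy sketch with hypothetical values:

    import numpy as np

    block_size = 16  # hypothetical
    # Hypothetical block table for one request: logical block -> physical block.
    block_table = np.array([7, 2, 9])
    positions = np.arange(20, 24)  # token positions written this step

    slots = (block_table[positions // block_size] * block_size
             + positions % block_size)
    # Positions 20..23 sit in logical block 1 (physical block 2): slots 36..39.
    assert slots.tolist() == [36, 37, 38, 39]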
@@ -653,22 +615,19 @@ def _profile_multimodal(self) -> None:
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

     @torch.inference_mode()
-    def _dummy_run(
-        self,
-        num_tokens: int,
-    ) -> torch.Tensor:
+    def _dummy_run(self) -> torch.Tensor:
         model = self.model
         if self.is_multimodal_model:
             input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_tokens]
+            inputs_embeds = self.inputs_embeds[:self.max_num_tokens]
         else:
-            input_ids = self.input_ids[:num_tokens]
+            input_ids = self.input_ids[:self.max_num_tokens]
             inputs_embeds = None

         if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_tokens]
+            positions = self.mrope_positions[:, :self.max_num_tokens]
         else:
-            positions = self.input_positions_cpu[:num_tokens]
+            positions = self.input_positions_cpu[:self.max_num_tokens]

         if get_pp_group().is_first_rank:
             intermediate_tensors = None
@@ -680,7 +639,7 @@ def _dummy_run(
                     dtype=self.model_config.dtype,
                     device=self.device))
             intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
+                k: v[:self.max_num_tokens]
                 for k, v in self.intermediate_tensors.items()
             })

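Dropping the `num_tokens` parameter means the dummy pass always runs at the full `max_num_tokens` budget, so the warm-up exercises the worst-case shape once rather than a caller-chosen size. The call site simplifies accordingly:

    # Before: hidden_states = self._dummy_run(self.max_num_tokens)
    hidden_states = self._dummy_run()  # budget now implicit from scheduler_config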
@@ -719,15 +678,15 @@ def profile_run(self) -> None:
         ]

         # Trigger compilation for general shape.
-        hidden_states = self._dummy_run(self.max_num_tokens)
+        hidden_states = self._dummy_run()

         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
             logits = self.model.compute_logits(hidden_states, None)
         else:
             logits = None

-        current_platform.synchronize()
+        NPUPlatform.synchronize()
         del hidden_states, logits, dummy_kv_caches
         self.encoder_cache.clear()
         gc.collect()
@@ -739,10 +698,8 @@ def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
         if self.lora_config:
             raise ValueError("LoRA model is not supported on NPU now.")
-
-        self.model_memory_usage = m.consumed_memory
         logger.info("Loading model weights took %.4f GB",
-                    self.model_memory_usage / float(2**30))
+                    m.consumed_memory / float(2**30))

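`DeviceMemoryProfiler` is a context manager from `vllm.utils`; `consumed_memory` is reported in bytes, hence the division by `2**30` to log the figure in GB. The surrounding `with` block is not shown in this hunk, so the sketch below is an assumption about how `load_model` reads after the change:

    with DeviceMemoryProfiler() as m:  # measures the device-memory delta
        self.model = get_model(vllm_config=self.vllm_config)
        if self.lora_config:
            raise ValueError("LoRA model is not supported on NPU now.")
    logger.info("Loading model weights took %.4f GB",
                m.consumed_memory / float(2**30))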
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """