 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
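Note: the new `TorchCompileWrapperWithCustomDispatcher` import supports the `reset_dynamo_cache` method added further down, which evicts Dynamo's cached bytecode for the compiled forward. A minimal sketch of the underlying PyTorch mechanism, using a toy function in place of vLLM's wrapper:

    import torch

    def f(x):
        return x * 2

    compiled_f = torch.compile(f)
    compiled_f(torch.ones(4))  # first call compiles and caches bytecode for f.__code__
    # Evicting the cache entry forces a full recompile on the next call:
    torch._dynamo.eval_frame.remove_from_cache(f.__code__)
    compiled_f(torch.ones(4))  # recompiles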
@@ -691,11 +692,10 @@ def execute_model(
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=self.position_ids,
-                kv_caches=self.kv_caches,
                 inputs_embeds=inputs_embeds,
             )
-        selected_token_ids = self.model.sample_from_hidden(
-            hidden_states, tpu_sampling_metadata)
+        selected_token_ids = self.sample_from_hidden(hidden_states,
+                                                     tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
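Note: the `.cpu()[:num_reqs]` trim is what keeps the dynamic slice out of the XLA graph; the device tensor retains its pre-compiled padded shape. A hedged illustration with made-up sizes:

    padded_ids = torch.zeros(16, 1, dtype=torch.int64)  # static bucket shape on device
    num_reqs = 5
    token_ids = padded_ids.cpu()[:num_reqs]  # dynamic slicing happens on the host,
                                             # so no new XLA graph is compiled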
@@ -795,17 +795,15 @@ def load_model(self) -> None:
                 "get_tensor_model_parallel_rank",
                 return_value=xm_tp_rank):
             model = get_model(vllm_config=self.vllm_config)
-        model = model.eval()
+        # Sync all pending XLA execution during model initialization and weight
+        # loading.
         xm.mark_step()
         xm.wait_device_ops()
-        model = ModelWrapperV1(model)
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=False)
+        self.model = model
+        self.sampler = TPUSampler()
 
     @torch.no_grad()
-    def _dummy_run(self, kv_caches, num_tokens: int) -> None:
+    def _dummy_run(self, num_tokens: int) -> None:
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@@ -856,7 +854,6 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None:
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             out = self.model(input_ids=input_ids,
                              positions=position_ids,
-                             kv_caches=kv_caches,
                              inputs_embeds=inputs_embeds)
         self._hidden_states_dtype = out.dtype
 
@@ -868,7 +865,7 @@ def capture_model(self) -> None:
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info("  -- num_tokens: %d", num_tokens)
-            self._dummy_run(self.kv_caches, num_tokens)
+            self._dummy_run(num_tokens)
         xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
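Note: this warm-up loop follows the usual torch_xla pattern of cutting one graph per padding bucket and then blocking until device work drains. A minimal sketch, with `run_step` as a hypothetical stand-in for `self._dummy_run`:

    import torch_xla.core.xla_model as xm

    for num_tokens in (16, 32, 64):  # assumed padding buckets
        run_step(num_tokens)         # stand-in for self._dummy_run(num_tokens)
        xm.mark_step()               # cut and dispatch the traced graph
    xm.wait_device_ops()             # block until all compiles/executions finish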
@@ -899,8 +896,7 @@ def capture_model(self) -> None:
                 from_input_batch(self.input_batch, indices)
             logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
                         num_reqs_to_sample)
-            out = self.model.sample_from_hidden(dummy_hidden,
-                                                sampling_meta)
+            out = self.sample_from_hidden(dummy_hidden, sampling_meta)
             out = out.cpu()
             # Requests can't be more than tokens. But do compile for the
             # next bigger value in case num_tokens uses bucketed padding.
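Note: the comment above is about bucket boundaries. Request counts are padded just like token counts, so sampling also gets pre-compiled for the next larger request bucket; a worked example with assumed bucket sizes:

    # Assumed request buckets and 10 live requests, with num_tokens padded to 16.
    reqs_buckets = (8, 16, 32)
    num_reqs = 10
    num_reqs_to_sample = min(b for b in reqs_buckets if b >= num_reqs)  # -> 16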
@@ -954,79 +950,48 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)
 
-
-class ModelWrapperV1(nn.Module):
-
-    def __init__(self, model: nn.Module):
-        super().__init__()
-        self.model = model
-        self.sampler = TPUSampler()
-
-    def sample(
-            self, logits: torch.Tensor,
-            sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
-        sampler_out = self.sampler(logits, sampling_metadata)
-        return sampler_out
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Executes the forward pass of the model.
-
-        Args:
-            input_ids: The input token IDs of shape [num_tokens].
-            positions: The input position IDs of shape [num_tokens].
-            kv_caches: The key and value caches. They can be None during the
-                memory profiling at initialization.
-            inputs_embeds: The input embeddings of shape [num_tokens,
-                hidden_size]. It is used for multimodal models.
-        """
-
-        hidden_states = self.model(
-            input_ids=input_ids,
-            positions=positions,
-            inputs_embeds=inputs_embeds,
-        )
-
-        return hidden_states
+    def reset_dynamo_cache(self):
+        if self.is_multimodal_model:
+            assert hasattr(self.model, "language_model")
+            compiled_model = self.model.language_model.model
+        else:
+            compiled_model = self.model.model
+        if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
+            logger.info("Clear dynamo cache and cached dynamo bytecode.")
+            torch._dynamo.eval_frame.remove_from_cache(
+                compiled_model.original_code_object)
+            compiled_model.compiled_codes.clear()
 
     def sample_from_hidden(
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> torch.Tensor:
         """
-        Sample with xla-friendly function. This function is to be traced
-        separately from `forward` for lighter compilation overhead.
-        """
+        Sample with xla-friendly function. This function is to be traced
+        separately for lighter compilation overhead.
+        """
         # Tensor `sample_hidden_states` is of fixed pre-compiled size.
         sample_hidden_states = \
             hidden_states[sampling_metadata.indices_do_sample]
-        logits = self.compute_logits(sample_hidden_states)
+        # SamplingMetadata here for pruning output in LogitsProcessor, disabled.
+        logits = self.model.compute_logits(sample_hidden_states, None)
+
+        def sample(
+            logits: torch.Tensor,
+            sampling_metadata: TPUSupportedSamplingMetadata
+        ) -> SamplerOutput:
+            sampler_out = self.sampler(logits, sampling_metadata)
+            return sampler_out
+
         # Optimized greedy sampling branch, tracing both paths in a single pass
         # NOTE all_greedy is a scalar, this is just an optimized if/else.
-        out_tokens = torch.where(sampling_metadata.all_greedy,
-                                 torch.argmax(logits, dim=-1, keepdim=True),
-                                 self.sample(logits, sampling_metadata)\
-                                     .sampled_token_ids)
+        out_tokens = torch.where(
+            sampling_metadata.all_greedy,
+            torch.argmax(logits, dim=-1, keepdim=True),
+            sample(logits, sampling_metadata).sampled_token_ids)
         return out_tokens
 
-    def compute_logits(self,
-                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
-        # SamplingMetadata here for pruning output in LogitsProcessor, disabled
-        logits = self.model.compute_logits(hidden_states, None)
-        return logits
-
-    def get_multimodal_embeddings(self, *args, **kwargs):
-        return self.model.get_multimodal_embeddings(*args, **kwargs)
-
-    def get_input_embeddings(self, *args, **kwargs):
-        return self.model.get_input_embeddings(*args, **kwargs)
-
 
 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
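`_get_padded_number` is plain round-up-to-multiple arithmetic, which the bucketed paddings above rely on; for example:

    _get_padded_number(10, 8)   # ((10 + 7) // 8) * 8 == 16
    _get_padded_number(16, 8)   # exact multiples are unchanged -> 16
    _get_padded_number(1, 16)   # a single request still fills a 16-wide bucket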