 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -55,8 +54,8 @@
 from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    flatten_bn, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, maybe_prefix)
+                    flatten_bn, make_empty_intermediate_tensors_factory,
+                    maybe_prefix)

 logger = init_logger(__name__)

@@ -414,64 +413,63 @@ def __exit__(self, exc_type, exc_value, traceback):
             setattr(self.config, key, value)


-class TransformersModel:
+class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
+    embedding_padding_modules = ["lm_head"]
+    embedding_modules = ["embed_tokens"
+                         ]  # TODO transformers will have a util to get it

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         logger.info("Using Transformers backend.")

-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        cache_config: CacheConfig = vllm_config.cache_config
-        device_config: DeviceConfig = vllm_config.device_config
-        model_config: ModelConfig = vllm_config.model_config
-        parallel_config: ParallelConfig = vllm_config.parallel_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
-
-        self.config = config
-        self.text_config = config.get_text_config()
-        self.cache_config = cache_config
-        self.device_config = device_config
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.quant_config = quant_config
+        self.config: PretrainedConfig = vllm_config.model_config.hf_config
+        self.text_config: PretrainedConfig = self.config.get_text_config()
+        self.cache_config: CacheConfig = vllm_config.cache_config
+        self.device_config: DeviceConfig = vllm_config.device_config
+        self.model_config: ModelConfig = vllm_config.model_config
+        self.parallel_config: ParallelConfig = vllm_config.parallel_config
+        self.quant_config: QuantizationConfig = vllm_config.quant_config

         self.pp_group = get_pp_group()
         self.pp_size = self.pp_group.world_size
         self.pp_rank = self.pp_group.rank_in_group
         self.tp_size = get_tensor_model_parallel_world_size()

+        # To be updated in child classes for use in `load_weights`
+        self.skip_prefixes: Optional[list[str]] = None
+
         # vLLM handles interleaved sliding window attention by creating a new
         # interleaved_sliding_window attribute and deleting the sliding_window
         # attribute. This breaks the constructors in Transformers so we
         # temporarily add the attribute back to construct the model.
         config_override = nullcontext()
-        if hasattr(config, "interleaved_sliding_window"):
+        if hasattr(self.config, "interleaved_sliding_window"):
             config_override = ConfigOverride(
-                config, sliding_window=config.interleaved_sliding_window)
+                self.config,
+                sliding_window=self.config.interleaved_sliding_window)

         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
         # method after v4.54.0 is released
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"), config_override:
             self.model: PreTrainedModel = AutoModel.from_config(
-                config,
-                torch_dtype=model_config.dtype,
-                trust_remote_code=model_config.trust_remote_code,
+                self.config,
+                torch_dtype=self.model_config.dtype,
+                trust_remote_code=self.model_config.trust_remote_code,
             )

         self.pipeline_parallel()
         self.tensor_parallel()

         # Input embeddings
-        text_config = config.get_text_config()
         if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
             self.model.set_input_embeddings(
                 VocabParallelEmbedding(
-                    text_config.vocab_size,
-                    text_config.hidden_size,
-                    org_num_embeddings=text_config.vocab_size,
-                    quant_config=quant_config,
+                    self.text_config.vocab_size,
+                    self.text_config.hidden_size,
+                    org_num_embeddings=self.text_config.vocab_size,
+                    quant_config=self.quant_config,
                 ))

         # Attention layers
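For readers unfamiliar with the `config_override = nullcontext()` idiom in the hunk above: the construction-time override can be illustrated with a small, self-contained Python sketch. `attr_override` and `Cfg` below are made-up stand-ins, not vLLM's `ConfigOverride`; the sketch only shows the pattern of temporarily exposing `sliding_window` while the model is built and restoring the config afterwards.

# Simplified sketch of the override pattern above; `attr_override` and
# `Cfg` are illustrative names, not vLLM's ConfigOverride.
from contextlib import contextmanager, nullcontext


@contextmanager
def attr_override(obj, **overrides):
    """Temporarily set attributes on `obj`, restoring them on exit."""
    missing = object()
    saved = {key: getattr(obj, key, missing) for key in overrides}
    for key, value in overrides.items():
        setattr(obj, key, value)
    try:
        yield obj
    finally:
        for key, value in saved.items():
            if value is missing:
                delattr(obj, key)
            else:
                setattr(obj, key, value)


class Cfg:
    interleaved_sliding_window = 4096


cfg = Cfg()
# Fall back to a no-op context manager when there is nothing to override,
# mirroring `config_override = nullcontext()` in the diff.
ctx = (attr_override(cfg, sliding_window=cfg.interleaved_sliding_window)
       if hasattr(cfg, "interleaved_sliding_window") else nullcontext())
with ctx:
    assert cfg.sliding_window == 4096  # visible while the model is built
assert not hasattr(cfg, "sliding_window")  # restored afterwards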
@@ -481,8 +479,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.init_parameters(self.model)

         self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(["hidden_states"],
-                                                    text_config.hidden_size))
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states"], self.text_config.hidden_size))

     def pipeline_parallel(self):
         """
@@ -654,78 +652,40 @@ def forward(

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters())
-
-        loaded_params = set[str]()
-        for name, loaded_weight in weights:
-            # Use "model" instead of base_model_prefix because
-            # the base model attribute in vLLM is always `model`
-            if not name.startswith(prefix := "model."):
-                name = prefix + name
-
-            if is_pp_missing_parameter(name, self):
-                continue
-            if name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


 @support_torch_compile
-class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
-                              SupportsPP):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"
-                         ]  # TODO transformers will have a util to get it
+class TransformersForCausalLM(TransformersBase):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
-
-        self.config = config
+        super().__init__(vllm_config=vllm_config, prefix=prefix)

-        self.transformers_model = TransformersModel(vllm_config=vllm_config,
-                                                    prefix=prefix)
-        self.model = self.transformers_model.model
+        # Tell `TransformersBase.load_weights` to skip
+        # `lm_head` if the model has tied word embeddings
+        if self.text_config.tie_word_embeddings:
+            self.skip_prefixes = ["lm_head."]

         if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = config.vocab_size
+            self.unpadded_vocab_size = self.text_config.vocab_size
             self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
+                self.text_config.vocab_size,
+                self.text_config.hidden_size,
+                quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, "lm_head"),
             )
-            if config.tie_word_embeddings:
+            if self.text_config.tie_word_embeddings:
                 self.lm_head = self.lm_head.tie_weights(
                     self.model.get_input_embeddings())

-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
+            logit_scale = getattr(self.text_config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(
+                self.unpadded_vocab_size, self.text_config.vocab_size,
+                logit_scale)
         else:
             self.lm_head = PPMissingLayer()

-        self.make_empty_intermediate_tensors = (
-            self.transformers_model.make_empty_intermediate_tensors)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.transformers_model.forward(input_ids, positions,
-                                                       intermediate_tensors,
-                                                       inputs_embeds)
-        return model_output
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
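The heart of this hunk is that `load_weights` now lives once in `TransformersBase`, and subclasses such as `TransformersForCausalLM` only declare which prefixes to skip (e.g. `lm_head.` when word embeddings are tied). A minimal stand-alone sketch of that delegation pattern follows; `Base` and `CausalLM` are toy classes, not vLLM's `AutoWeightsLoader` machinery.

# Toy illustration of the skip_prefixes delegation; `Base` and `CausalLM`
# are stand-ins, not vLLM's TransformersBase/AutoWeightsLoader.
from typing import Iterable, Optional


class Base:

    def __init__(self) -> None:
        # Subclasses may set this before weights are loaded.
        self.skip_prefixes: Optional[list[str]] = None
        self.state: dict[str, object] = {}

    def load_weights(self, weights: Iterable[tuple[str, object]]) -> set[str]:
        skip = tuple(self.skip_prefixes or ())
        loaded: set[str] = set()
        for name, tensor in weights:
            if skip and name.startswith(skip):
                continue  # e.g. a tied lm_head comes from embed_tokens instead
            self.state[name] = tensor
            loaded.add(name)
        return loaded


class CausalLM(Base):

    def __init__(self, tie_word_embeddings: bool) -> None:
        super().__init__()
        if tie_word_embeddings:
            self.skip_prefixes = ["lm_head."]


model = CausalLM(tie_word_embeddings=True)
loaded = model.load_weights([("model.embed_tokens.weight", 0),
                             ("lm_head.weight", 1)])
assert loaded == {"model.embed_tokens.weight"}

The tied `lm_head.weight` is simply skipped at load time because it is materialised from the input embeddings via `tie_weights`, which is what the diff expresses with `self.skip_prefixes = ["lm_head."]`.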
@@ -735,23 +695,12 @@ def compute_logits(
                                        sampling_metadata)
         return logits

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        skip_prefixes = ["lm_head."
-                         ] if self.config.tie_word_embeddings else None
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
-        return loader.load_weights(weights)
-

 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
     dummy_inputs=MultiModalDummyInputsBuilder)
-class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
-                                  SupportsPP, SupportsMultiModal):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"]
-
+class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
     # Backwards compatibility for prev released models. State dicts back then
     # had different formats and cannot be loaded with `AutoModel` mapping as is
     hf_to_vllm_mapper = WeightsMapper(
@@ -776,40 +725,10 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
         })

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
+        super().__init__(vllm_config=vllm_config, prefix=prefix)

-        self.config = config
         self.dtype = vllm_config.model_config.dtype

-        self.transformers_model = TransformersModel(vllm_config=vllm_config,
-                                                    prefix=prefix)
-        self.model = self.transformers_model.model
-        text_config = config.get_text_config()
-
-        if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = text_config.vocab_size
-            self.lm_head = ParallelLMHead(
-                text_config.vocab_size,
-                text_config.hidden_size,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "lm_head"),
-            )
-            if text_config.tie_word_embeddings:
-                self.lm_head = self.lm_head.tie_weights(
-                    self.model.get_input_embeddings())
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    text_config.vocab_size,
-                                                    logit_scale)
-        else:
-            self.lm_head = PPMissingLayer()
-
-        self.make_empty_intermediate_tensors = (
-            self.transformers_model.make_empty_intermediate_tensors)
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -828,30 +747,10 @@ def forward(
                 input_ids, multimodal_embeds)
             input_ids = None

-        model_output = self.transformers_model.forward(input_ids, positions,
-                                                       intermediate_tensors,
-                                                       inputs_embeds)
+        model_output = super().forward(input_ids, positions,
+                                       intermediate_tensors, inputs_embeds)
         return model_output

-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "lm_head."
-            ] if self.config.get_text_config().tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
-
     def get_multimodal_embeddings(self, **kwargs):
         pixel_values = kwargs.pop("pixel_values", None)
         pixel_values = pixel_values if pixel_values is not None else kwargs.pop(