 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.utils import is_list_of

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)

@@ -122,8 +124,8 @@ def __init__(
         # Strided linear layer.
         assert self._qkv_same_embed_dim, \
             'Visual Attention implementation only supports self-attention'
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
-        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
+        self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

     def forward(
@@ -133,7 +135,7 @@ def forward(
     ) -> torch.Tensor:
         # query/key/value: [sq, b, h]
         sq, b, _ = x.size()
-        mixed_x_layer = self.in_proj(x)
+        mixed_x_layer, _ = self.in_proj(x)

         # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
@@ -182,7 +184,7 @@ def forward(
             (self.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)

-        output = self.out_proj(context_layer)
+        output, _ = self.out_proj(context_layer)

         return output

@@ -860,18 +862,15 @@ def dummy_data_for_qwen(
     return seq_data, mm_data


-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
-class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
+class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

     def __init__(
         self,
         config: PretrainedConfig,
         multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -990,3 +989,91 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+
+class QWenLLM(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+
+class QWenVL(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+        # visual module
+        "out_proj",
+        "in_proj",
+        "c_fc",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="transformer.h",
+            connector="transformer.visual.attn_pool",
+            tower_model="transformer.visual.transformer")
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
+class QWenLMHeadModel(QWenBaseModel):
1052+ """
1053+ QWenLMHeadModel is not only applicable to LLM but also to VL, which is not
1054+ conducive to the current integration logic of LoRA in vLLM. Therefore, it
1055+ is necessary to separate them.
1056+ """
1057+ # Ensure that the LoRA support check passes when the class is not
1058+ # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __new__(
+        cls,
+        config: PretrainedConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        # Initialize VL
+        if hasattr(config, "visual"):
+            return QWenVL(config, multimodal_config, cache_config,
+                          quant_config, lora_config)
+        # Initialize LLM
+        else:
+            return QWenLLM(config, multimodal_config, cache_config,
+                           quant_config, lora_config)
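
A minimal, self-contained sketch of the dispatch pattern introduced above: __new__ inspects the config and returns an instance of one concrete subclass or the other, so each variant can carry its own LoRA attributes. All names below (_Base, _TextOnly, _VisionLanguage, _Entrypoint, _Cfg) are illustrative stand-ins, not vLLM classes.

class _Base:
    def __init__(self, config):
        self.config = config


class _TextOnly(_Base):
    # LoRA targets for the text-only model (mirrors QWenLLM above).
    supported_lora_modules = ["c_attn", "gate_up_proj", "c_proj"]


class _VisionLanguage(_Base):
    # Adds the visual-tower and resampler projections (mirrors QWenVL above).
    supported_lora_modules = ["c_attn", "gate_up_proj", "c_proj",
                              "in_proj", "out_proj", "c_fc", "kv_proj"]


class _Entrypoint(_Base):
    # Empty so a static "does this architecture support LoRA?" check can
    # pass before any instance exists.
    supported_lora_modules = []

    def __new__(cls, config):
        # Qwen-VL checkpoints expose a `visual` section in their config;
        # text-only Qwen checkpoints do not.
        if hasattr(config, "visual"):
            return _VisionLanguage(config)
        return _TextOnly(config)


class _Cfg:
    """Stand-in for transformers.PretrainedConfig."""


cfg = _Cfg()
cfg.visual = {"image_size": 448}   # pretend this is a Qwen-VL style config
model = _Entrypoint(cfg)
print(type(model).__name__)        # -> _VisionLanguage

Because the object returned from __new__ is not an instance of the entry-point class, Python skips the entry point's __init__, which is why QWenLMHeadModel can hand construction off to QWenVL or QWenLLM without double initialization.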