@@ -43,7 +43,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -814,11 +813,20 @@ def __init__(
         self.q_local_size = self.num_local_heads * self.head_dim
         self.kv_local_size = self.num_local_key_value_heads * self.head_dim

-        self.qkv_proj = QKVCrossParallelLinear(
+        # TODO(Isotr0py): Use QKVCrossParallelLinear when it supports
+        # quantization
+        self.q_proj = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_heads * self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.q_proj",
+        )
+        self.kv_proj = QKVParallelLinear(
             self.hidden_size,
             self.head_dim,
-            self.num_heads,
-            self.num_key_value_heads,
+            total_num_heads=0,
+            total_num_kv_heads=self.num_key_value_heads,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
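Why the split: `QKVCrossParallelLinear` does not yet go through the quantization-aware paths, so the fused cross-attention projection is rebuilt from a plain `ColumnParallelLinear` for Q plus a `QKVParallelLinear` with `total_num_heads=0`, which degenerates into a packed K/V-only projection. Below is a minimal shape sketch of that idea using vanilla PyTorch linears; the sizes, tensors, and single-rank setup are illustrative assumptions, not Mllama's real configuration or vLLM's parallel layers.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real config values (illustrative, not Mllama's numbers).
hidden_size, head_dim = 64, 16
num_heads, num_kv_heads = 4, 2            # local == total here (tp_size == 1)
q_local_size = num_heads * head_dim       # 64
kv_local_size = num_kv_heads * head_dim   # 32

# q_proj covers only the query heads; kv_proj packs K and V into one matmul,
# which is what QKVParallelLinear reduces to when total_num_heads=0.
q_proj = nn.Linear(hidden_size, q_local_size, bias=False)
kv_proj = nn.Linear(hidden_size, 2 * kv_local_size, bias=False)

x = torch.randn(3, hidden_size)           # decoder hidden states (3 tokens)
enc = torch.randn(5, hidden_size)         # cross_attention_states (5 tokens)
q = q_proj(x)                             # queries only
kv = kv_proj(enc)                         # K and V packed together
assert q.shape == (3, q_local_size)
assert kv.shape == (5, 2 * kv_local_size)
```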
@@ -854,11 +862,15 @@ def forward(
         kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
+        q, _ = self.q_proj(hidden_states)
         if cross_attention_states is not None:
+            kv, _ = self.kv_proj(cross_attention_states)
+            k, v = kv.split([self.kv_local_size, self.kv_local_size], dim=-1)
             k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
             v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
             k = self.k_norm(k)
+        else:
+            k = v = None

         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)
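In `forward`, Q is always projected from the decoder hidden states, while K/V are only produced when new `cross_attention_states` arrive; on cache-only decode steps they stay `None` and attention reads the cached values. A self-contained check of the split-and-reshape arithmetic, with toy sizes assumed purely for illustration:

```python
import torch

head_dim, num_kv_heads = 16, 2                  # toy values, for illustration
kv_local_size = num_kv_heads * head_dim         # 32

# Packed KV output for 5 encoder tokens, shaped like kv_proj's output.
kv = torch.randn(5, 2 * kv_local_size)
k, v = kv.split([kv_local_size, kv_local_size], dim=-1)
k = k.view(-1, num_kv_heads, head_dim)          # (5, 2, 16)
v = v.view(-1, num_kv_heads, head_dim)
assert k.shape == v.shape == (5, num_kv_heads, head_dim)

# Cache-only decode step: no new cross_attention_states, so nothing is
# projected and attention reads K/V from the KV cache instead.
cross_attention_states = None
if cross_attention_states is None:
    k = v = None
```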
@@ -1149,8 +1161,13 @@ def forward(
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsV0Only):
     packed_modules_mapping = {
-        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
+        "self_attn.qkv_proj": [
+            "self_attn.q_proj",
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+        ],
+        "cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
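Prefixing the packed-module keys keeps the two attention flavours apart: decoder self-attention still packs Q/K/V into `qkv_proj`, while cross-attention packs only K/V into `kv_proj` and leaves its `q_proj` standalone. Below is a rough, self-contained sketch of the substring matching this mapping implies, not vLLM's actual loader or quantization code; `remap` is a hypothetical helper and the module paths are example names used only for illustration.

```python
packed_modules_mapping = {
    "self_attn.qkv_proj": [
        "self_attn.q_proj",
        "self_attn.k_proj",
        "self_attn.v_proj",
    ],
    "cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def remap(module_name: str) -> str:
    """Map an original (checkpoint-style) module name to its packed module."""
    for packed, originals in packed_modules_mapping.items():
        for orig in originals:
            if orig in module_name:
                return module_name.replace(orig, packed)
    return module_name

# Self-attention Q/K/V all fold into the packed qkv_proj ...
assert (remap("model.layers.0.self_attn.k_proj")
        == "model.layers.0.self_attn.qkv_proj")
# ... while the cross-attention Q projection stays standalone and only its
# K/V land in the packed kv_proj.
assert (remap("model.layers.3.cross_attn.q_proj")
        == "model.layers.3.cross_attn.q_proj")
assert (remap("model.layers.3.cross_attn.v_proj")
        == "model.layers.3.cross_attn.kv_proj")
```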
@@ -1420,9 +1437,11 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
+            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
+            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
+            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
+            (".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
+            (".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
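`load_weights` walks this table: any checkpoint weight whose name contains `shard_name` is renamed to the packed `param_name` and handed to that parameter's `weight_loader` together with `shard_id`, which selects the slice to fill. The snippet below is a standalone illustration of just the renaming step, not the real loader (which also handles vision weights, scales, and fallthrough cases); `route` is a hypothetical helper and the weight names are assumed checkpoint-style examples.

```python
from typing import Optional, Tuple, Union

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
    (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
    (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
    (".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
    (".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

def route(weight_name: str) -> Tuple[str, Optional[Union[str, int]]]:
    """Return (target parameter name, shard id) for a checkpoint weight."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in weight_name:
            return weight_name.replace(shard_name, param_name), shard_id
    return weight_name, None  # no packed destination, loaded as-is

# Cross-attention K now has a packed destination (kv_proj, shard "k"),
# while cross-attention Q is loaded directly into its own q_proj.
print(route("language_model.model.layers.3.cross_attn.k_proj.weight"))
print(route("language_model.model.layers.3.cross_attn.q_proj.weight"))
```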