     default_weight_loader,
     maybe_remap_kv_scale_name,
 )
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
@@ -70,20 +71,61 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size,
-            [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
-        )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.down_proj",
-        )
+
+        # Detect GGUF quantization
+        is_gguf_quantized = False
+        if quant_config is not None:
+            quant_config_type = type(quant_config).__name__.lower()
+            if "gguf" in quant_config_type or (
+                hasattr(quant_config, "quant_method")
+                and "gguf" in str(quant_config.quant_method).lower()
+            ):
+                is_gguf_quantized = True
+
+        if is_gguf_quantized:
+            # Use separate (unmerged) linear layers for GGUF compatibility.
+            from vllm.model_executor.layers.linear import ColumnParallelLinear
+
+            self.gate_proj = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                quant_config=quant_config,  # Enable GGUF quantization
+                prefix=f"{prefix}.gate_proj",
+            )
+            self.up_proj = ColumnParallelLinear(
+                hidden_size,
+                intermediate_size,
+                bias=False,
+                quant_config=quant_config,  # Enable GGUF quantization
+                prefix=f"{prefix}.up_proj",
+            )
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,  # Enable GGUF quantization
+                prefix=f"{prefix}.down_proj",
+            )
+            self.gate_up_proj = None  # Not used for GGUF
+        else:
+            # Use the merged gate_up projection for non-GGUF models.
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size,
+                [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj",
+            )
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
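+        # NOTE: GGUF checkpoints quantize each tensor independently and ship
+        # gate_proj/up_proj as separate tensors, which is presumably why the
+        # GGUF path above avoids the merged layer.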
         if hidden_activation != "gelu_pytorch_tanh":
             raise ValueError(
                 "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
@@ -93,7 +135,15 @@ def __init__(
         self.act_fn = GeluAndMul(approximate="tanh")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        gate_up, _ = self.gate_up_proj(x)
+        if self.gate_up_proj is None:
+            # GGUF mode: use the separate gate_proj and up_proj.
+            gate, _ = self.gate_proj(x)
+            up, _ = self.up_proj(x)
+            gate_up = torch.cat([gate, up], dim=-1)
+        else:
+            # Non-GGUF mode: use the merged gate_up_proj.
+            gate_up, _ = self.gate_up_proj(x)
+
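+        # GeluAndMul splits its input in half along the last dim (gate | up),
+        # so both branches above feed it the same layout.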
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -135,22 +185,85 @@ def __init__(
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = config.query_pre_attn_scalar**-0.5
 
-        self.qkv_proj = QKVParallelLinear(
-            hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=config.attention_bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
-        self.o_proj = RowParallelLinear(
-            self.total_num_heads * self.head_dim,
-            hidden_size,
-            bias=config.attention_bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
-        )
+        # GGUF quantization requires separate Q/K/V layers instead of a
+        # fused QKV projection. Detect GGUF from the quant_config type.
+        is_gguf_quantized = False
+        if quant_config is not None:
+            quant_config_type = type(quant_config).__name__.lower()
+            if "gguf" in quant_config_type or (
+                hasattr(quant_config, "quant_method")
+                and "gguf" in str(quant_config.quant_method).lower()
+            ):
+                is_gguf_quantized = True
+
+        # Store the GGUF detection result for use in load_weights.
+        self.is_gguf_quantized = is_gguf_quantized
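+        # (Gemma3Model.load_weights walks self.modules() and checks this flag
+        # to decide whether to keep the stacked q/k/v and gate/up mappings.)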
+
+        if is_gguf_quantized:
+            # Create separate Q/K/V linear layers for GGUF compatibility.
+            # Passing quant_config keeps the weights GGUF-quantized
+            # (compressed).
+            from vllm.model_executor.layers.linear import ColumnParallelLinear
+
+            self.q_proj = ColumnParallelLinear(
+                hidden_size,
+                self.total_num_heads * self.head_dim,
+                bias=config.attention_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.q_proj",
+            )
+            self.k_proj = ColumnParallelLinear(
+                hidden_size,
+                self.total_num_kv_heads * self.head_dim,
+                bias=config.attention_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.k_proj",
+            )
+            self.v_proj = ColumnParallelLinear(
+                hidden_size,
+                self.total_num_kv_heads * self.head_dim,
+                bias=config.attention_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.v_proj",
+            )
+            self.qkv_proj = None  # Not used for GGUF
+        else:
+            # Use the fused QKV projection for non-GGUF models.
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size,
+                self.head_dim,
+                self.total_num_heads,
+                self.total_num_kv_heads,
+                bias=config.attention_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+
+        # o_proj is identical in both modes; RowParallelLinear is already
+        # imported at module level.
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
 
         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -207,8 +320,16 @@ def forward(
         hidden_states: torch.Tensor,
         **kwargs,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # Handle both fused QKV and separate Q/K/V projections.
+        if self.qkv_proj is not None:
+            # Fused QKV projection (non-GGUF models).
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        else:
+            # Separate Q/K/V projections (GGUF models).
+            q, _ = self.q_proj(hidden_states)
+            k, _ = self.k_proj(hidden_states)
+            v, _ = self.v_proj(hidden_states)
 
         q = q.unflatten(-1, (self.num_heads, self.head_dim))
         q = self.q_norm(q)
@@ -369,6 +490,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.quant_config = quant_config
 
+        # Detect GGUF quantization from the model config.
+        self.is_gguf_quantized = vllm_config.model_config.quantization == "gguf"
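+        # nn.Module.modules() yields the model itself, so this flag alone is
+        # enough for the GGUF check in load_weights below.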
+
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
@@ -439,9 +563,48 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        # Check whether any submodule reports GGUF quantization.
+        has_gguf_attention = False
+        for module in self.modules():
+            if getattr(module, "is_gguf_quantized", False):
+                has_gguf_attention = True
+                break
+
+        if has_gguf_attention:
+            # GGUF keeps q/k/v and gate/up as separate layers, so drop the
+            # stacked mappings that would fold them into fused parameters.
+            # Non-GGUF models keep the default mappings above unchanged.
+            stacked_params_mapping = [
+                mapping
+                for mapping in stacked_params_mapping
+                if mapping[0] not in ("qkv_proj", "gate_up_proj")
+            ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            # Apply the GGUF-specific RMSNorm weight correction for Gemma3.
+            # This must happen BEFORE any other transformation (transpose,
+            # etc.). GemmaRMSNorm computes: output = x * (1 + weight).
+            # GGUF stores the full weight values (for a standard x * weight),
+            # but vLLM's GemmaRMSNorm expects (weight - 1) since it adds 1
+            # during the forward pass.
+            if (
+                self.quant_config is not None
+                and self.quant_config.get_name() == "gguf"
+                and "norm" in name
+                and len(loaded_weight.shape) == 1
+            ):
+                loaded_weight = loaded_weight - 1.0
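+                # E.g. a GGUF-stored norm weight of 1.0 (identity scale) is
+                # loaded as 0.0 here, and the forward pass restores it via
+                # x * (1 + 0.0).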
+
             if self.quant_config is not None and (
                 scale_name := self.quant_config.get_cache_scale(name)
             ):
@@ -478,20 +641,78 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]
+
+                # Fix GGUF shape mismatches: transpose when the loaded weight
+                # matches the parameter shape only after transposition.
+                if (
+                    has_gguf_attention
+                    and "weight" in name
+                    and ("self_attn" in name or "mlp" in name)
+                    and param.shape != loaded_weight.shape
+                    and param.shape == loaded_weight.T.shape
+                ):
+                    loaded_weight = loaded_weight.T
+
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip GGUF qweight_type metadata for layers that don't have
+                # it (e.g. embedding layers); it is handled by the GGUF
+                # quantization layers themselves.
+                if name.endswith(".qweight_type") and name not in params_dict:
+                    continue
+                # Gemma3's GGUF path uses regular ColumnParallelLinear with
+                # 'weight' instead of 'qweight', so remap missing qweight
+                # names to the regular weight.
+                if name.endswith(".qweight") and name not in params_dict:
+                    name = name.replace(".qweight", ".weight")
+                    if name not in params_dict:
+                        continue
                 # Remapping the name of FP8 kv-scale.
                 name = maybe_remap_kv_scale_name(name, params_dict)
                 if name is None:
                     continue
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]
+
+                # GGUF quantized layers may use UninitializedParameter, which
+                # has no shape, so skip the shape checks for those.
+                from torch.nn.parameter import UninitializedParameter
+
+                is_uninitialized = isinstance(param, UninitializedParameter)
+
+                # Fix GGUF shape mismatches for non-stacked parameters:
+                # transpose when the loaded weight matches the parameter
+                # shape only after transposition.
+                if (
+                    has_gguf_attention
+                    and "weight" in name
+                    and ("self_attn" in name or "mlp" in name)
+                    and not is_uninitialized
+                    and param.shape != loaded_weight.shape
+                    and param.shape == loaded_weight.T.shape
+                ):
+                    loaded_weight = loaded_weight.T
+
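+                # Weights whose shape already matches (including square ones)
+                # are left untouched, since the first shape check fails.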
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
@@ -519,6 +740,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         del lora_config  # Unused.
         super().__init__()
         self.config = config
+        # Store the model config for quantization access.
+        self.model_config = vllm_config.model_config
         # currently all existing Gemma models have `tie_word_embeddings` enabled
         assert config.tie_word_embeddings
         self.quant_config = quant_config
@@ -551,8 +774,11 @@ def forward(
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.model.embed_tokens, hidden_states)
+        logits = self.logits_processor(
+            self.model.embed_tokens, hidden_states, sampling_metadata
+        )
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: