44 | 44 | default_weight_loader, |
45 | 45 | maybe_remap_kv_scale_name, |
46 | 46 | ) |
| 47 | +from vllm.model_executor.sampling_metadata import SamplingMetadata |
47 | 48 | from vllm.sequence import IntermediateTensors |
48 | 49 |
49 | 50 | from ...attention.layers.encoder_only_attention import EncoderOnlyAttention |
@@ -442,6 +443,20 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: |
442 | 443 | params_dict = dict(self.named_parameters()) |
443 | 444 | loaded_params: set[str] = set() |
444 | 445 | for name, loaded_weight in weights: |
| 446 | + # Apply GGUF-specific RMSNorm weight correction for Gemma3 |
| 447 | + # This must happen BEFORE any transformations (transpose, etc.) |
| 448 | + # GemmaRMSNorm computes: output = x * (1 + weight) |
| 449 | + # GGUF stores full weight values (for standard x * weight) |
| 450 | + # but vLLM's GemmaRMSNorm expects (weight - 1) since it adds 1 |
| 451 | + # during the forward pass. |
| 452 | + if ( |
| 453 | + self.quant_config is not None |
| 454 | + and self.quant_config.get_name() == "gguf" |
| 455 | + and "norm" in name |
| 456 | + and len(loaded_weight.shape) == 1 |
| 457 | + ): |
| 458 | + loaded_weight = loaded_weight - 1.0 |
| 459 | + |
445 | 460 | if self.quant_config is not None and ( |
446 | 461 | scale_name := self.quant_config.get_cache_scale(name) |
447 | 462 | ): |
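
A minimal sketch of why the `- 1.0` offset above is the right correction (plain PyTorch; the `rms_norm` / `gemma_rms_norm` helpers are illustrative only and simplified relative to vLLM's `GemmaRMSNorm`, which also handles dtype casting): a full-valued GGUF norm weight under the standard `x * weight` formulation produces the same output as the offset weight under Gemma's `x * (1 + weight)` formulation.

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Standard RMSNorm: the formulation GGUF norm weights are stored for.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


def gemma_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gemma-style RMSNorm: the weight is applied as (1 + weight).
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * (1.0 + weight)


x = torch.randn(2, 8)
gguf_weight = torch.rand(8) + 0.5  # full-valued norm weight as read from a GGUF file
corrected = gguf_weight - 1.0      # the offset applied in load_weights above

assert torch.allclose(rms_norm(x, gguf_weight), gemma_rms_norm(x, corrected))
```
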
@@ -485,6 +500,21 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: |
485 | 500 | # Skip loading extra bias for GPTQ models. |
486 | 501 | if name.endswith(".bias") and name not in params_dict: |
487 | 502 | continue |
|  503 | +            # Skip GGUF qweight_type metadata when the target layer does not
|  504 | +            # register it as a parameter (e.g., embedding layers); GGUF-aware
|  505 | +            # quantized layers register and consume this metadata themselves.
| 506 | + if name.endswith(".qweight_type") and name not in params_dict: |
| 507 | + continue |
| 508 | + |
|  509 | +            # Handle GGUF qweight for embedding and other non-merged layers.
|  510 | +            # GGUF uses .qweight for quantized weights, but some layers
|  511 | +            # (like VocabParallelEmbedding) expect .weight instead.
| 512 | + if name.endswith(".qweight") and name not in params_dict: |
| 513 | + # Try to load as regular weight instead |
| 514 | + name = name.replace(".qweight", ".weight") |
| 515 | + if name not in params_dict: |
| 516 | + continue |
| 517 | + |
488 | 518 | # Remapping the name of FP8 kv-scale. |
489 | 519 | name = maybe_remap_kv_scale_name(name, params_dict) |
490 | 520 | if name is None: |
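
The two GGUF name-handling branches above can be exercised in isolation. Below is a minimal sketch; `resolve_gguf_name` and the tensor names are made up for illustration, and the real `params_dict` comes from `self.named_parameters()`. Metadata-only `.qweight_type` entries are skipped when no matching parameter exists, and `.qweight` falls back to `.weight` for layers such as embeddings that only register a plain weight.

```python
from typing import Optional


def resolve_gguf_name(name: str, params_dict: dict) -> Optional[str]:
    # Mirrors the skip/fallback logic in load_weights; None means "skip".
    if name.endswith(".qweight_type") and name not in params_dict:
        return None  # metadata tensor with no matching parameter
    if name.endswith(".qweight") and name not in params_dict:
        # Fall back to the plain ".weight" parameter, e.g. for embeddings.
        name = name.replace(".qweight", ".weight")
        if name not in params_dict:
            return None
    return name


params = {
    "model.embed_tokens.weight": None,
    "model.layers.0.mlp.down_proj.qweight": None,
    "model.layers.0.mlp.down_proj.qweight_type": None,
}
assert resolve_gguf_name("model.embed_tokens.qweight", params) == "model.embed_tokens.weight"
assert resolve_gguf_name("model.embed_tokens.qweight_type", params) is None
assert resolve_gguf_name("model.layers.0.mlp.down_proj.qweight", params) == (
    "model.layers.0.mlp.down_proj.qweight"
)
```
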
@@ -519,6 +549,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
519 | 549 | del lora_config # Unused. |
520 | 550 | super().__init__() |
521 | 551 | self.config = config |
| 552 | + # Store model config for quantization access |
| 553 | + self.model_config = vllm_config.model_config |
522 | 554 | # currently all existing Gemma models have `tie_word_embeddings` enabled |
523 | 555 | assert config.tie_word_embeddings |
524 | 556 | self.quant_config = quant_config |
@@ -551,8 +583,11 @@ def forward( |
551 | 583 | def compute_logits( |
552 | 584 | self, |
553 | 585 | hidden_states: torch.Tensor, |
| 586 | + sampling_metadata: SamplingMetadata, |
554 | 587 | ) -> Optional[torch.Tensor]: |
555 | | - logits = self.logits_processor(self.model.embed_tokens, hidden_states) |
| 588 | + logits = self.logits_processor( |
| 589 | + self.model.embed_tokens, hidden_states, sampling_metadata |
| 590 | + ) |
556 | 591 | return logits |
557 | 592 |
558 | 593 | def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: |
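
Because `tie_word_embeddings` is asserted in the constructor, `compute_logits` reuses `self.model.embed_tokens` as the output projection, and the new `sampling_metadata` argument is simply threaded through to the `LogitsProcessor`. A minimal sketch of what the tied-weight projection amounts to (plain PyTorch, ignoring `LogitsProcessor` details such as vocab padding and any scaling or position selection driven by `sampling_metadata`):

```python
import torch

vocab_size, hidden_size, num_tokens = 32, 8, 4
embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)
hidden_states = torch.randn(num_tokens, hidden_size)

# With tied word embeddings, the unembedding is the embedding matrix
# transposed: each logit is the dot product of a hidden state with a
# token's embedding row.
logits = hidden_states @ embed_tokens.weight.t()
assert logits.shape == (num_tokens, vocab_size)
```
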
|