
Commit 0f46a78

[Model] [Quantization] Support quantization for Gemma3n (#21974)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent e1a7fe4 commit 0f46a78

File tree

1 file changed: +12 −4 lines changed


vllm/model_executor/models/gemma3n.py

Lines changed: 12 additions & 4 deletions
@@ -46,6 +46,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from .interfaces import SupportsQuant
 from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter, make_layers, maybe_prefix)

@@ -68,6 +69,7 @@ def __init__(
         altup_num_inputs: int,
         altup_coef_clip: float,
         altup_active_idx: int,
+        quant_config: QuantizationConfig,
         prefix: str,
     ):
         super().__init__()
@@ -80,20 +82,23 @@ def __init__(
             altup_num_inputs,
             altup_num_inputs,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.correction_coefs",
             return_bias=False,
         )
         self.prediction_coefs = ReplicatedLinear(
             altup_num_inputs,
             altup_num_inputs**2,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.prediction_coefs",
             return_bias=False,
         )
         self.modality_router = ReplicatedLinear(
             hidden_size,
             altup_num_inputs,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.modality_router",
             return_bias=False,
         )
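
Note on this hunk: the three AltUp projections are `ReplicatedLinear` layers, and before this change they were constructed without a `quant_config`, so they always fell back to unquantized weights even when the rest of the model was quantized. Below is a minimal sketch of the per-layer dispatch pattern; all names are hypothetical stand-ins, not vLLM's real `ReplicatedLinear` or `QuantizationConfig`.

```python
# Minimal sketch of per-layer quantization dispatch (hypothetical classes).
from typing import Optional

import torch
import torch.nn as nn


class SketchQuantConfig:
    """Stand-in for a quantization config that can exempt layers by prefix."""

    def __init__(self, skip_prefixes: Optional[set] = None):
        self.skip_prefixes = skip_prefixes or set()

    def wants_quant(self, prefix: str) -> bool:
        return prefix not in self.skip_prefixes


class SketchLinear(nn.Module):
    """Stand-in for ReplicatedLinear: quantizes only when configured to."""

    def __init__(self, in_features: int, out_features: int,
                 quant_config: Optional[SketchQuantConfig] = None,
                 prefix: str = ""):
        super().__init__()
        quantized = (quant_config is not None
                     and quant_config.wants_quant(prefix))
        # int8 storage as a placeholder for a real quantized weight format.
        dtype = torch.int8 if quantized else torch.bfloat16
        self.weight = nn.Parameter(
            torch.zeros(out_features, in_features, dtype=dtype),
            requires_grad=False)


# Omitting quant_config (the old behavior) silently keeps the layer dense:
cfg = SketchQuantConfig()
dense = SketchLinear(4, 4, prefix="altup.correction_coefs")       # bf16
quant = SketchLinear(4, 4, cfg, prefix="altup.correction_coefs")  # int8
assert dense.weight.dtype == torch.bfloat16
assert quant.weight.dtype == torch.int8
```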
@@ -400,6 +405,7 @@ def __init__(
             altup_num_inputs=config.altup_num_inputs,
             altup_coef_clip=config.altup_coef_clip,
             altup_active_idx=config.altup_active_idx,
+            quant_config=quant_config,
             prefix=f"{prefix}.altup",
         )
         self.self_attn = Gemma3nAttention(
@@ -527,7 +533,7 @@ def forward(
 
 
 @support_torch_compile
-class Gemma3nTextModel(nn.Module):
+class Gemma3nTextModel(nn.Module, SupportsQuant):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
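
Note on this hunk: `SupportsQuant` is one of vLLM's capability interfaces (imported above from `vllm/model_executor/models/interfaces.py`); tagging the class lets generic code detect quantization support without enumerating concrete model types. A rough sketch of the marker-mixin pattern with assumed names, not the interface's actual definition:

```python
# Illustrative marker-mixin pattern; SupportsQuantSketch is a hypothetical
# stand-in, not vLLM's real SupportsQuant.
import torch.nn as nn


class SupportsQuantSketch:
    """Marker mixin: the model knows how to consume a quantization config."""
    supports_quant: bool = True


class TextModelSketch(nn.Module, SupportsQuantSketch):
    pass


def model_supports_quant(model: nn.Module) -> bool:
    # Generic check, usable without importing any concrete model class.
    return getattr(model, "supports_quant", False)


assert model_supports_quant(TextModelSketch())
assert not model_supports_quant(nn.Linear(2, 2))
```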
@@ -540,6 +546,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
             prefix=f"{prefix}.embed_tokens",
         )
         self.embed_scale = torch.tensor(
@@ -549,6 +556,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.embed_tokens_per_layer = VocabParallelEmbedding(
             config.vocab_size_per_layer_input,
             config.num_hidden_layers * config.hidden_size_per_layer_input,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_embed_tokens",
         )
         self.embed_scale_per_layer = torch.tensor(
@@ -582,7 +590,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 gather_output=True,
                 return_bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.{idx-1}.altup_projections",
+                prefix=f"{prefix}.altup_projections.{idx-1}",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
         self.altup_unembed_projections = nn.ModuleList([
@@ -593,7 +601,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 gather_output=True,
                 return_bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
+                prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
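
Note on these two prefix fixes: `nn.ModuleList` registers its children under `<attribute>.<index>`, so the qualified parameter names that weight loading (and prefix-keyed quantization metadata) must match put the attribute name first and the list index second. A quick way to verify the naming convention:

```python
# The qualified names PyTorch generates for a ModuleList put the
# attribute name before the index, matching the corrected prefixes.
import torch.nn as nn

m = nn.Module()
m.altup_projections = nn.ModuleList([nn.Linear(2, 2, bias=False)])
print([name for name, _ in m.named_parameters()])
# ['altup_projections.0.weight']  -- not '0.altup_projections.weight'
```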

@@ -774,7 +782,7 @@ def forward(
                                   **kwargs)
 
 
-class Gemma3nForConditionalGeneration(nn.Module):
+class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
