 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from .interfaces import SupportsQuant
 from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter, make_layers, maybe_prefix)
 
@@ -68,6 +69,7 @@ def __init__(
         altup_num_inputs: int,
         altup_coef_clip: float,
         altup_active_idx: int,
+        quant_config: QuantizationConfig,
         prefix: str,
     ):
         super().__init__()
@@ -80,20 +82,23 @@ def __init__(
             altup_num_inputs,
             altup_num_inputs,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.correction_coefs",
             return_bias=False,
         )
         self.prediction_coefs = ReplicatedLinear(
             altup_num_inputs,
             altup_num_inputs**2,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.prediction_coefs",
             return_bias=False,
         )
         self.modality_router = ReplicatedLinear(
             hidden_size,
             altup_num_inputs,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.modality_router",
             return_bias=False,
         )
@@ -400,6 +405,7 @@ def __init__(
             altup_num_inputs=config.altup_num_inputs,
             altup_coef_clip=config.altup_coef_clip,
             altup_active_idx=config.altup_active_idx,
+            quant_config=quant_config,
             prefix=f"{prefix}.altup",
         )
         self.self_attn = Gemma3nAttention(
@@ -527,7 +533,7 @@ def forward(
 
 
 @support_torch_compile
-class Gemma3nTextModel(nn.Module):
+class Gemma3nTextModel(nn.Module, SupportsQuant):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -540,6 +546,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
             prefix=f"{prefix}.embed_tokens",
         )
         self.embed_scale = torch.tensor(
@@ -549,6 +556,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.embed_tokens_per_layer = VocabParallelEmbedding(
             config.vocab_size_per_layer_input,
             config.num_hidden_layers * config.hidden_size_per_layer_input,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_embed_tokens",
         )
         self.embed_scale_per_layer = torch.tensor(
@@ -582,7 +590,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 gather_output=True,
                 return_bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.{idx - 1}.altup_projections",
+                prefix=f"{prefix}.altup_projections.{idx - 1}",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
         self.altup_unembed_projections = nn.ModuleList([
@@ -593,7 +601,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 gather_output=True,
                 return_bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.{idx - 1}.altup_unembed_projections",
+                prefix=f"{prefix}.altup_unembed_projections.{idx - 1}",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
 
@@ -774,7 +782,7 @@ def forward(
                           **kwargs)
 
 
-class Gemma3nForConditionalGeneration(nn.Module):
+class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",