@@ -64,20 +64,32 @@ class NemotronHMLP(nn.Module):
     def __init__(
         self,
         config: NemotronHConfig,
+        layer_idx: int,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        hybrid_override_pattern = config.hybrid_override_pattern
+        mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1
+        if isinstance(config.intermediate_size, list):
+            if len(config.intermediate_size) == 1:
+                intermediate_size = config.intermediate_size[0]
+            else:
+                intermediate_size = config.intermediate_size[mlp_index]
+        else:
+            intermediate_size = config.intermediate_size
+
         self.up_proj = ColumnParallelLinear(
             input_size=config.hidden_size,
-            output_size=config.intermediate_size,
+            output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.up_proj",
         )
         self.down_proj = RowParallelLinear(
-            input_size=config.intermediate_size,
+            input_size=intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
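
For context, the per-layer intermediate size selection introduced above can be exercised on its own. The sketch below mirrors that logic against a hypothetical hybrid_override_pattern; the helper name and sample values are illustrative, not part of the diff.

from typing import Union

def select_intermediate_size(
    intermediate_size: Union[int, list[int]],
    hybrid_override_pattern: str,
    layer_idx: int,
) -> int:
    # Mirror of the selection above: "-" marks an MLP layer, and the n-th MLP
    # layer picks the n-th entry of a per-layer intermediate_size list.
    mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1
    if isinstance(intermediate_size, list):
        if len(intermediate_size) == 1:
            return intermediate_size[0]
        return intermediate_size[mlp_index]
    return intermediate_size

# Hypothetical pattern: "M" = Mamba, "*" = attention, "-" = MLP.
pattern = "M-M*-"
assert select_intermediate_size([3072, 4096], pattern, layer_idx=1) == 3072
assert select_intermediate_size([3072, 4096], pattern, layer_idx=4) == 4096
assert select_intermediate_size(8192, pattern, layer_idx=4) == 8192
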
@@ -110,6 +122,7 @@ def __init__(
             quant_config=quant_config,
             bias=config.mlp_bias,
             prefix=f"{prefix}.mixer",
+            layer_idx=layer_idx,
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -146,7 +159,7 @@ def __init__(
             hidden_size=config.hidden_size,
             ssm_state_size=config.ssm_state_size,
             conv_kernel_size=config.conv_kernel,
-            intermediate_size=config.expand * config.hidden_size,
+            intermediate_size=config.mamba_num_heads * config.mamba_head_dim,
             use_conv_bias=config.use_conv_bias,
             use_bias=config.use_bias,
             n_groups=config.n_groups,
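
The old and new formulas for the mixer width only coincide when mamba_num_heads * mamba_head_dim happens to equal expand * hidden_size. A quick check with hypothetical config values (not taken from any real Nemotron-H checkpoint) shows how the two can diverge:

# Hypothetical values, for illustration only.
hidden_size = 4096
expand = 2
mamba_num_heads = 112
mamba_head_dim = 80

old_intermediate = expand * hidden_size              # 8192
new_intermediate = mamba_num_heads * mamba_head_dim  # 8960
print(old_intermediate, new_intermediate)  # they differ, so the head-derived size is the one to use
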
@@ -205,7 +218,10 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = config.hidden_size // self.total_num_heads
+        if hasattr(config, "head_dim") and config.head_dim is not None:
+            self.head_dim = config.head_dim
+        else:
+            self.head_dim = config.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
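
The attention change above prefers an explicit head_dim from the config and only derives it when absent. An equivalent standalone sketch of that fallback (the helper name is hypothetical, not code from the diff):

def resolve_head_dim(config, total_num_heads: int) -> int:
    # Prefer an explicit head_dim from the checkpoint config, since it may not
    # equal hidden_size // num_heads; otherwise fall back to the derived value.
    head_dim = getattr(config, "head_dim", None)
    if head_dim is not None:
        return head_dim
    return config.hidden_size // total_num_heads
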
@@ -481,7 +497,7 @@ def get_mamba_state_shape_from_config(
         """
         parallel_config = vllm_config.parallel_config
         hf_config = vllm_config.model_config.hf_config
-        intermediate_size = hf_config.expand * hf_config.hidden_size
+        intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim
 
         return MambaStateShapeCalculator.mamba2_state_shape(
             intermediate_size=intermediate_size,
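
The state-shape calculation has to agree with the mixer width fixed above, because the intermediate size sets how wide the cached conv state is. A rough, hypothetical sketch of per-sequence Mamba2 state shapes (an assumption for intuition only, not the actual MambaStateShapeCalculator implementation, and it ignores tensor-parallel sharding):

def approx_mamba2_state_shapes(intermediate_size: int, n_groups: int,
                               ssm_state_size: int, conv_kernel: int,
                               num_heads: int, head_dim: int):
    # The conv cache roughly spans the gated x/B/C channels; the SSM cache is per head.
    conv_state = (intermediate_size + 2 * n_groups * ssm_state_size, conv_kernel - 1)
    ssm_state = (num_heads, head_dim, ssm_state_size)
    return conv_state, ssm_state
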