diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index eb62d5a53c1a..08315a13853c 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -64,20 +64,32 @@ class NemotronHMLP(nn.Module):
     def __init__(
         self,
         config: NemotronHConfig,
+        layer_idx: int,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        hybrid_override_pattern = config.hybrid_override_pattern
+        mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1
+        if isinstance(config.intermediate_size, list):
+            if len(config.intermediate_size) == 1:
+                intermediate_size = config.intermediate_size[0]
+            else:
+                intermediate_size = config.intermediate_size[mlp_index]
+        else:
+            intermediate_size = config.intermediate_size
+
         self.up_proj = ColumnParallelLinear(
             input_size=config.hidden_size,
-            output_size=config.intermediate_size,
+            output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.up_proj",
         )
         self.down_proj = RowParallelLinear(
-            input_size=config.intermediate_size,
+            input_size=intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
@@ -110,6 +122,7 @@ def __init__(
             quant_config=quant_config,
             bias=config.mlp_bias,
             prefix=f"{prefix}.mixer",
+            layer_idx=layer_idx,
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -146,7 +159,7 @@ def __init__(
             hidden_size=config.hidden_size,
             ssm_state_size=config.ssm_state_size,
             conv_kernel_size=config.conv_kernel,
-            intermediate_size=config.expand * config.hidden_size,
+            intermediate_size=config.mamba_num_heads * config.mamba_head_dim,
             use_conv_bias=config.use_conv_bias,
             use_bias=config.use_bias,
             n_groups=config.n_groups,
@@ -205,7 +218,10 @@ def __init__(
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = config.hidden_size // self.total_num_heads
+        if hasattr(config, "head_dim") and config.head_dim is not None:
+            self.head_dim = config.head_dim
+        else:
+            self.head_dim = config.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -481,7 +497,7 @@ get_mamba_state_shape_from_config(
         """
         parallel_config = vllm_config.parallel_config
         hf_config = vllm_config.model_config.hf_config
-        intermediate_size = hf_config.expand * hf_config.hidden_size
+        intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim
 
         return MambaStateShapeCalculator.mamba2_state_shape(
             intermediate_size=intermediate_size,
diff --git a/vllm/transformers_utils/configs/nemotron_h.py b/vllm/transformers_utils/configs/nemotron_h.py
index 457b3371e90d..027f2911543f 100644
--- a/vllm/transformers_utils/configs/nemotron_h.py
+++ b/vllm/transformers_utils/configs/nemotron_h.py
@@ -151,7 +151,7 @@ def __init__(
         num_hidden_layers=52,
         hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
         num_attention_heads=32,
-        attention_head_dim=128,
+        head_dim=128,
         num_key_value_heads=8,  # nemo: num_query_groups
         mlp_hidden_act="relu2",
         attention_bias=False,
@@ -194,7 +194,7 @@ def __init__(
         self.num_hidden_layers = num_hidden_layers
         self.hybrid_override_pattern = hybrid_override_pattern
         self.num_attention_heads = num_attention_heads
-        self.attention_head_dim = attention_head_dim
+        self.head_dim = head_dim
         self.sliding_window = sliding_window
         self.max_position_embeddings = max_position_embeddings
         self.attention_dropout = attention_dropout
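For reference, here is a minimal standalone sketch (not part of the diff) of the per-MLP-layer `intermediate_size` lookup that the `NemotronHMLP` change above introduces: MLP layers are the `-` entries in `hybrid_override_pattern`, and the n-th MLP layer reads the n-th entry of the `intermediate_size` list. The pattern string and size list below are made-up example values; in vLLM they come from `NemotronHConfig`.

```python
# Hypothetical example values; real ones come from NemotronHConfig.
hybrid_override_pattern = "M-M*-M-"      # 'M' = Mamba, '*' = attention, '-' = MLP
intermediate_size = [1024, 2048, 3072]   # one entry per MLP ('-') layer


def resolve_intermediate_size(layer_idx: int) -> int:
    # Count the MLP layers up to and including this one; that count minus one
    # indexes into the per-MLP-layer size list (same logic as the diff).
    mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1
    if isinstance(intermediate_size, list):
        if len(intermediate_size) == 1:
            return intermediate_size[0]
        return intermediate_size[mlp_index]
    return intermediate_size


for layer_idx, block in enumerate(hybrid_override_pattern):
    if block == "-":
        print(layer_idx, resolve_intermediate_size(layer_idx))
# prints: layer 1 -> 1024, layer 4 -> 2048, layer 6 -> 3072
```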