From 988718e99987bb95958cb0a79646800aab5c3d7c Mon Sep 17 00:00:00 2001
From: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Date: Mon, 6 May 2024 17:01:43 +0300
Subject: [PATCH] Jamba official hf (#14)

* remove JambaConfig and use official one from transformers

* changes in Jamba modeling file to align with official HF format
---
 vllm/model_executor/models/jamba.py         | 213 ++++++++------------
 vllm/transformers_utils/configs/__init__.py |   2 +
 2 files changed, 89 insertions(+), 126 deletions(-)

diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 18b5a209ca01f..6133b00ac6812 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -4,15 +4,14 @@
 from typing import Iterable, List, Optional, Tuple
 
 import torch
-from transformers import JambaConfig
-from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
-from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from torch import nn
-from torch.nn.parameter import Parameter
 
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
+
+from transformers import JambaConfig
+from torch.nn.parameter import Parameter
 from vllm.config import LoRAConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -33,6 +32,9 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import SamplerOutput
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -43,7 +45,6 @@ class MambaCacheParams:
     ssm_state: torch.Tensor = torch.Tensor()
 
 
-
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class JambaMambaMixer(nn.Module):
     """
@@ -124,28 +125,10 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
             input_is_parallel=True,
         )
         self.activation = config.hidden_act
-        self.apply_inner_layernorms = config.mamba_inner_layernorms
-
-        if self.apply_inner_layernorms:
-            self.dt_layernorm = RMSNorm(self.time_step_rank,
-                                        eps=config.rms_norm_eps)
-            self.B_layernorm = RMSNorm(self.ssm_state_size,
-                                       eps=config.rms_norm_eps)
-            self.C_layernorm = RMSNorm(self.ssm_state_size,
-                                       eps=config.rms_norm_eps)
-        else:
-            self.dt_layernorm = None
-            self.B_layernorm = None
-            self.C_layernorm = None
-
-    def _apply_layernorms(self, dt, B, C):
-        if self.dt_layernorm is not None:
-            dt = self.dt_layernorm.forward(dt.contiguous())
-        if self.B_layernorm is not None:
-            B = self.B_layernorm.forward(B.contiguous())
-        if self.C_layernorm is not None:
-            C = self.C_layernorm.forward(C.contiguous())
-        return dt, B, C
+
+        self.dt_layernorm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
+        self.b_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
+        self.c_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
 
     def mamba_forward(self,
                       hidden_states: torch.Tensor,
@@ -189,7 +172,9 @@ def mamba_forward(self,
             [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
             dim=-1,
         )
-        time_step, B, C = self._apply_layernorms(time_step, B, C)
+        time_step = self.dt_layernorm(time_step.contiguous())
+        B = self.b_layernorm(B.contiguous())
+        C = self.c_layernorm(C.contiguous())
         discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
@@ -275,6 +260,36 @@ def forward(
         return hidden_states
 
 
+class JambaMLP(nn.Module):
+    def __init__(
+        self,
+        config: JambaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+        hidden_act = config.hidden_act
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           quant_config=quant_config)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
 class JambaMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each expert
     across all ranks.
@@ -285,33 +300,27 @@ class JambaMoE(nn.Module):
 
     def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
+        self,
+        config: JambaConfig,
+        params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
+        self.num_total_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size // self.tp_size
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
 
-        if self.num_total_experts > 1:
-            # init expert router iff this layer has multiple experts
-            self.router = ReplicatedLinear(
-                self.hidden_size,
-                self.num_total_experts,
-                bias=False,
-                params_dtype=self.params_dtype,
-            )
+        self.router = ReplicatedLinear(self.hidden_size,
+                                       self.num_total_experts,
+                                       bias=False,
+                                       params_dtype=self.params_dtype)
 
         self.ws = nn.Parameter(
             torch.empty(
@@ -366,14 +375,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
-        if self.num_total_experts > 1:
-            router_logits, _ = self.router(hidden_states)
-        else:
-            router_logits = torch.ones(
-                [hidden_states.shape[0], 1],
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
+        router_logits, _ = self.router(hidden_states)
 
         final_hidden_states = fused_moe(
             hidden_states,
@@ -394,28 +396,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 class JambaMambaDecoderLayer(nn.Module):
-
     def __init__(
-        self,
-        config: JambaConfig,
-        actual_num_experts: int,
-        actual_num_experts_per_tok: int,
-        layer_idx: int,
+        self, config: JambaConfig, layer_idx: int, quant_config: Optional[QuantizationConfig] = None
     ) -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
         self.mamba = JambaMambaMixer(config, layer_idx)
-        self.moe = JambaMoE(
-            num_experts=actual_num_experts,
-            top_k=actual_num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-        )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.pre_moe_layernorm = RMSNorm(config.hidden_size,
-                                         eps=config.rms_norm_eps)
+
+        num_experts = config.layers_num_experts[layer_idx]
+        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
+        self.feed_forward = ffn_layer_class(config, quant_config)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -436,20 +429,15 @@ def forward(
         hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
                                    ssm_state)
         # Fully Connected
-        hidden_states, residual = self.pre_moe_layernorm(
-            hidden_states, residual)
-        hidden_states = self.moe(hidden_states)
+        hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual
 
 
 class JambaAttentionDecoderLayer(nn.Module):
 
     def __init__(
-        self,
-        config: JambaConfig,
-        actual_num_experts: int,
-        actual_num_experts_per_tok: int,
-        quant_config: Optional[QuantizationConfig] = None,
+        self, config: JambaConfig, layer_idx: int, quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -494,16 +482,11 @@ def __init__(
             sliding_window=self.sliding_window,
         )
 
-        self.moe = JambaMoE(
-            num_experts=actual_num_experts,
-            top_k=actual_num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-        )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.pre_moe_layernorm = RMSNorm(config.hidden_size,
-                                         eps=config.rms_norm_eps)
+        num_experts = config.layers_num_experts[layer_idx]
+        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
+        self.feed_forward = ffn_layer_class(config, quant_config)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def self_attention(
         self,
@@ -542,12 +525,14 @@ def forward(
             attn_metadata=attn_metadata,
         )
         # Fully Connected
-        hidden_states, residual = self.pre_moe_layernorm(
-            hidden_states, residual)
-        hidden_states = self.moe(hidden_states)
+        hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual
 
 
+ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer}
+
+
 class JambaModel(nn.Module):
 
     def __init__(
@@ -570,40 +555,12 @@ def __init__(
             org_num_embeddings=config.vocab_size,
         )
 
-        # init each model layer, decide if it's mamba/attention and
-        # has experts and pass it down
-
-        module_list = []
+        decoder_layers = []
         for i in range(config.num_hidden_layers):
-            is_attn = ((i - self.config.attn_layer_offset) %
-                       self.config.attn_layer_period == 0)
-            is_expert = ((i - self.config.expert_layer_offset) %
-                         self.config.expert_layer_period == 0)
-
-            actual_num_experts = config.num_experts if is_expert else 1
-            actual_num_experts_per_tok = config.num_experts_per_tok \
-                if is_expert else 1
-
-            if is_attn:
-                module_list.append(
-                    JambaAttentionDecoderLayer(
-                        config,
-                        actual_num_experts=actual_num_experts,
-                        actual_num_experts_per_tok=actual_num_experts_per_tok,
-                        quant_config=quant_config
-                    ))
-            else:
-                module_list.append(
-                    JambaMambaDecoderLayer(
-                        config,
-                        actual_num_experts=actual_num_experts,
-                        actual_num_experts_per_tok=actual_num_experts_per_tok,
-                        layer_idx=i,
-                    ))
-
-        self.layers = nn.ModuleList(module_list)
-        self.final_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
+            decoder_layers.append(layer_class(config, layer_idx=i, quant_config=quant_config))
+        self.layers = nn.ModuleList(decoder_layers)
+        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -732,6 +689,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
 
         expert_params_mapping = [
@@ -758,6 +717,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                if 'experts' in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 78dc6207a0352..3bccc425cc826 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -7,6 +7,8 @@
 from vllm.transformers_utils.configs.jais import JAISConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
 
+from vllm.transformers_utils.configs.jamba import JambaConfig
+
 __all__ = [
     "ChatGLMConfig", "DbrxConfig", "MPTConfig", "RWConfig", "JAISConfig"
 ]
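
Illustrative note (not part of the patch): the refactor replaces the offset/period arithmetic with a lookup of each layer's class from config.layers_block_type. A minimal, self-contained sketch of that dispatch, using stand-in layer classes and a made-up config rather than the real JambaConfig and decoder layers:

# Hypothetical, simplified sketch of the layers_block_type dispatch shown above.
from dataclasses import dataclass, field
from typing import List


class AttentionLayer:  # stand-in for JambaAttentionDecoderLayer
    pass


class MambaLayer:  # stand-in for JambaMambaDecoderLayer
    pass


LAYER_TYPES = {"attention": AttentionLayer, "mamba": MambaLayer}


@dataclass
class ToyConfig:  # made-up config mirroring only the field used here
    layers_block_type: List[str] = field(
        default_factory=lambda: ["mamba", "mamba", "attention", "mamba"])


config = ToyConfig()
# One layer per entry, its class chosen purely by the per-layer block type.
layers = [LAYER_TYPES[block_type]() for block_type in config.layers_block_type]
assert [type(layer).__name__ for layer in layers] == [
    "MambaLayer", "MambaLayer", "AttentionLayer", "MambaLayer"
]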