Jamba official hf (vllm-project#14)
* remove JambaConfig and use official one from transformers

* changes in Jamba modeling file to align with official HF format
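
For reference, here is a minimal sketch of the per-layer metadata that the official Hugging Face JambaConfig exposes and that the updated modeling code below relies on (config.layers_block_type and config.layers_num_experts). This is illustrative only and not part of the commit; it assumes a transformers release that ships JambaConfig.

# Illustrative sketch, not part of this commit. Assumes a transformers
# release that ships the official JambaConfig.
from transformers import JambaConfig

config = JambaConfig()  # default Jamba hyperparameters
for i in range(config.num_hidden_layers):
    block_type = config.layers_block_type[i]    # "attention" or "mamba"
    num_experts = config.layers_num_experts[i]  # 1 for dense FFN layers, >1 for MoE layers
    print(f"layer {i:2d}: {block_type:9s} block, {num_experts} expert(s)")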
tomeras91 committed May 6, 2024
1 parent af7a4ac commit 988718e
Showing 2 changed files with 89 additions and 126 deletions.
213 changes: 87 additions & 126 deletions vllm/model_executor/models/jamba.py
@@ -4,15 +4,14 @@
from typing import Iterable, List, Optional, Tuple

import torch
from transformers import JambaConfig
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from torch import nn
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention

from transformers import JambaConfig
from torch.nn.parameter import Parameter
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -33,6 +32,9 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -43,7 +45,6 @@ class MambaCacheParams:
ssm_state: torch.Tensor = torch.Tensor()



# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class JambaMambaMixer(nn.Module):
"""
@@ -124,28 +125,10 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
input_is_parallel=True,
)
self.activation = config.hidden_act
self.apply_inner_layernorms = config.mamba_inner_layernorms

if self.apply_inner_layernorms:
self.dt_layernorm = RMSNorm(self.time_step_rank,
eps=config.rms_norm_eps)
self.B_layernorm = RMSNorm(self.ssm_state_size,
eps=config.rms_norm_eps)
self.C_layernorm = RMSNorm(self.ssm_state_size,
eps=config.rms_norm_eps)
else:
self.dt_layernorm = None
self.B_layernorm = None
self.C_layernorm = None

def _apply_layernorms(self, dt, B, C):
if self.dt_layernorm is not None:
dt = self.dt_layernorm.forward(dt.contiguous())
if self.B_layernorm is not None:
B = self.B_layernorm.forward(B.contiguous())
if self.C_layernorm is not None:
C = self.C_layernorm.forward(C.contiguous())
return dt, B, C

self.dt_layernorm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
self.b_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
self.c_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)

def mamba_forward(self,
hidden_states: torch.Tensor,
@@ -189,7 +172,9 @@ def mamba_forward(self,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1,
)
time_step, B, C = self._apply_layernorms(time_step, B, C)
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())

discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
@@ -275,6 +260,36 @@ def forward(
return hidden_states


class JambaMLP(nn.Module):
def __init__(
self,
config: JambaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
hidden_act = config.hidden_act
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()

def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x


class JambaMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks.
@@ -285,33 +300,27 @@ class JambaMoE(nn.Module):
"""

def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
self,
config: JambaConfig,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
self.num_total_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // self.tp_size

if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype

if self.num_total_experts > 1:
# init expert router iff this layer has multiple experts
self.router = ReplicatedLinear(
self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
)
self.router = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype)

self.ws = nn.Parameter(
torch.empty(
@@ -366,14 +375,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (batch * sequence_length, n_experts)
if self.num_total_experts > 1:
router_logits, _ = self.router(hidden_states)
else:
router_logits = torch.ones(
[hidden_states.shape[0], 1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
router_logits, _ = self.router(hidden_states)

final_hidden_states = fused_moe(
hidden_states,
@@ -394,28 +396,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


class JambaMambaDecoderLayer(nn.Module):

def __init__(
self,
config: JambaConfig,
actual_num_experts: int,
actual_num_experts_per_tok: int,
layer_idx: int,
self, config: JambaConfig, layer_idx: int, quant_config: Optional[QuantizationConfig] = None
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.mamba = JambaMambaMixer(config, layer_idx)
self.moe = JambaMoE(
num_experts=actual_num_experts,
top_k=actual_num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_moe_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

num_experts = config.layers_num_experts[layer_idx]
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
self.feed_forward = ffn_layer_class(config, quant_config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
@@ -436,20 +429,15 @@ def forward(
hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
ssm_state)
# Fully Connected
hidden_states, residual = self.pre_moe_layernorm(
hidden_states, residual)
hidden_states = self.moe(hidden_states)
hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual


class JambaAttentionDecoderLayer(nn.Module):

def __init__(
self,
config: JambaConfig,
actual_num_experts: int,
actual_num_experts_per_tok: int,
quant_config: Optional[QuantizationConfig] = None,
self, config: JambaConfig, layer_idx: int, quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -494,16 +482,11 @@ def __init__(
sliding_window=self.sliding_window,
)

self.moe = JambaMoE(
num_experts=actual_num_experts,
top_k=actual_num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_moe_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
num_experts = config.layers_num_experts[layer_idx]
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
self.feed_forward = ffn_layer_class(config, quant_config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def self_attention(
self,
@@ -542,12 +525,14 @@ def forward(
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.pre_moe_layernorm(
hidden_states, residual)
hidden_states = self.moe(hidden_states)
hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual


ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer}


class JambaModel(nn.Module):

def __init__(
@@ -570,40 +555,12 @@ def __init__(
org_num_embeddings=config.vocab_size,
)

# init each model layer, decide if it's mamba/attention and
# has experts and pass it down

module_list = []
decoder_layers = []
for i in range(config.num_hidden_layers):
is_attn = ((i - self.config.attn_layer_offset) %
self.config.attn_layer_period == 0)
is_expert = ((i - self.config.expert_layer_offset) %
self.config.expert_layer_period == 0)

actual_num_experts = config.num_experts if is_expert else 1
actual_num_experts_per_tok = config.num_experts_per_tok \
if is_expert else 1

if is_attn:
module_list.append(
JambaAttentionDecoderLayer(
config,
actual_num_experts=actual_num_experts,
actual_num_experts_per_tok=actual_num_experts_per_tok,
quant_config=quant_config
))
else:
module_list.append(
JambaMambaDecoderLayer(
config,
actual_num_experts=actual_num_experts,
actual_num_experts_per_tok=actual_num_experts_per_tok,
layer_idx=i,
))

self.layers = nn.ModuleList(module_list)
self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
decoder_layers.append(layer_class(config, layer_idx=i, quant_config=quant_config))
self.layers = nn.ModuleList(decoder_layers)
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
@@ -732,6 +689,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]

expert_params_mapping = [
Expand All @@ -758,6 +717,8 @@ def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]]):
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if 'experts' in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
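For readers following the dense-FFN path introduced above: a plain-PyTorch sketch of what JambaMLP.forward computes, assuming vLLM's SiluAndMul splits the merged gate/up projection in half and applies silu(gate) * up. Tensor parallelism and quantization are omitted; this is a reference sketch, not vLLM code.

# Reference sketch only (assumes SiluAndMul(x) == silu(x[..., :d]) * x[..., d:]).
import torch
import torch.nn.functional as F

def jamba_mlp_reference(x: torch.Tensor,
                        gate_up_weight: torch.Tensor,  # (2 * intermediate_size, hidden_size)
                        down_weight: torch.Tensor      # (hidden_size, intermediate_size)
                        ) -> torch.Tensor:
    gate_up = x @ gate_up_weight.t()              # merged gate/up projection (gate_up_proj)
    gate, up = gate_up.chunk(2, dim=-1)
    return (F.silu(gate) * up) @ down_weight.t()  # act_fn + down_proj

hidden_size, intermediate_size = 8, 16            # toy sizes for illustration
x = torch.randn(4, hidden_size)
out = jamba_mlp_reference(x,
                          torch.randn(2 * intermediate_size, hidden_size),
                          torch.randn(hidden_size, intermediate_size))
assert out.shape == (4, hidden_size)
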
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
@@ -7,6 +7,8 @@
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.mpt import MPTConfig

from vllm.transformers_utils.configs.jamba import JambaConfig

__all__ = [
"ChatGLMConfig", "DbrxConfig", "MPTConfig", "RWConfig", "JAISConfig"
]
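
A hedged sketch of the checkpoint-name mapping that the load_weights changes above rely on: dense-layer gate_proj/up_proj weights are folded into the merged gate_up_proj parameter, while expert weights are left to the separate expert_params_mapping path. The checkpoint names used here are hypothetical examples, not taken from an actual Jamba checkpoint.

# Illustrative sketch of the stacked-parameter renaming (mirrors the mapping in the diff).
stacked_params_mapping = [
    # (param_name, weight_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def map_checkpoint_name(name: str):
    """Return (target_param_name, shard_id); shard_id is None if no mapping applies."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name and "experts" not in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

# Hypothetical checkpoint names, for illustration only:
print(map_checkpoint_name("model.layers.0.feed_forward.gate_proj.weight"))
# -> ('model.layers.0.feed_forward.gate_up_proj.weight', 0)
print(map_checkpoint_name("model.layers.1.feed_forward.experts.0.up_proj.weight"))
# -> unchanged: expert weights are handled by expert_params_mapping instead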
