
Commit 603a661

[Model] factoring out MambaMixer out of Jamba (#8993)
Signed-off-by: mzusman <mor.zusmann@gmail.com>
1 parent fb2716d commit 603a661

File tree

3 files changed: +245 -374 lines changed
Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
import torch
from torch import nn
from torch.nn.parameter import Parameter

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer")
class MambaMixer(CustomOp):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute
    the `contextualized_states`. A, D are input independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
    for why A isn't selective) ∆, B, C are input-dependent
    (this is a key difference between Mamba and the linear time
    invariant S4, and is why Mamba is called
    **selective** state spaces)
    """

    def __init__(self,
                 hidden_size: int,
                 ssm_state_size: int,
                 conv_kernel_size: int,
                 intermediate_size: int,
                 time_step_rank: int,
                 use_conv_bias: bool,
                 use_bias: bool,
                 use_rms_norm: bool,
                 rms_norm_eps: float = 1e-5,
                 activation="silu"):
        super().__init__()
        self.time_step_rank = time_step_rank
        self.ssm_state_size = ssm_state_size
        self.use_rms_norm = use_rms_norm
        self.activation = activation

        self.conv1d = ColumnParallelLinear(
            input_size=conv_kernel_size,
            output_size=intermediate_size,
            bias=use_conv_bias,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(hidden_size,
                                                  [intermediate_size] * 2,
                                                  bias=use_bias)
        # selective projection used to make dt, B and C input dependent
        self.x_proj = RowParallelLinear(
            intermediate_size,
            time_step_rank + ssm_state_size * 2,
            bias=False,
        )
        # time step projection (discretization) -
        # In the forward we need to apply dt_proj without the bias,
        # as the bias is added in the selective scan kernel.
        self.dt_proj = ColumnParallelLinear(time_step_rank,
                                            intermediate_size,
                                            bias=True,
                                            skip_bias_add=True)

        def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()
            param.data.copy_(
                loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
                                         dim=0)[tp_rank])

        def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            weight_loader(param, -torch.exp(loaded_weight.float()))

        tp_size = get_tensor_model_parallel_world_size()
        self.A = nn.Parameter(
            torch.empty(
                intermediate_size // tp_size,
                ssm_state_size,
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(intermediate_size // tp_size))

        set_weight_attrs(self.D, {"weight_loader": weight_loader})
        set_weight_attrs(self.A, {"weight_loader": A_weight_loader})

        self.out_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=use_bias,
            input_is_parallel=True,
        )

        self.dt_layernorm = RMSNorm(time_step_rank,
                                    eps=rms_norm_eps) if use_rms_norm else None

        self.b_layernorm = RMSNorm(ssm_state_size,
                                   eps=rms_norm_eps) if use_rms_norm else None

        self.c_layernorm = RMSNorm(ssm_state_size,
                                   eps=rms_norm_eps) if use_rms_norm else None

    def forward_native(self, hidden_states: torch.Tensor,
                       attn_metadata: AttentionMetadata,
                       conv_state: torch.Tensor, ssm_state: torch.Tensor):
        pass

    def forward_cuda(self, hidden_states: torch.Tensor,
                     attn_metadata: AttentionMetadata,
                     mamba_cache_params: MambaCacheParams):

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
        hidden_states, gate = projected_states.chunk(2, dim=-2)

        # 2. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

        if attn_metadata.query_start_loc is not None \
                and attn_metadata.context_lens_tensor is not None:
            # |---------- N-1 iteration --------|
            # |---------------- N iteration ---------------------|
            # |- tokenA -|......................|-- newTokens ---|
            # |---------- context_len ----------|
            # |-------------------- seq_len ---------------------|
            #                                   |-- query_len ---|
            hidden_states = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=mamba_cache_params.conv_state,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                cache_indices=mamba_cache_params.state_indices_tensor,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            hidden_states = causal_conv1d_update(
                hidden_states.transpose(0, 1),
                mamba_cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=mamba_cache_params.state_indices_tensor)
            hidden_states = hidden_states.transpose(0, 1)

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]

        time_step, B, C = torch.split(
            ssm_parameters,
            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
            dim=-1,
        )
        if self.use_rms_norm:
            assert self.dt_layernorm is not None
            assert self.b_layernorm is not None
            assert self.c_layernorm is not None
            time_step = self.dt_layernorm(time_step.contiguous())
            B = self.b_layernorm(B.contiguous())
            C = self.c_layernorm(C.contiguous())

        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
            self.dt_proj, "bias") else None)

        if attn_metadata.query_start_loc is not None \
                and attn_metadata.context_lens_tensor is not None:
            scan_outputs = selective_scan_fn(
                hidden_states,
                mamba_cache_params.ssm_state,
                discrete_time_step,
                self.A,
                B.transpose(-2, -1),
                C.transpose(-2, -1),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                cache_indices=mamba_cache_params.state_indices_tensor,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            scan_outputs = selective_state_update(
                mamba_cache_params.ssm_state,
                hidden_states.transpose(0, 1),
                discrete_time_step.transpose(0, 1),
                self.A,
                B,
                C,
                self.D,
                gate.transpose(0, 1),
                time_proj_bias,
                dt_softplus=True,
                state_batch_indices=mamba_cache_params.state_indices_tensor)
            scan_outputs = scan_outputs.transpose(0, 1)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_outputs.transpose(-2,
                                                                     -1))[0]
        return contextualized_states
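
The two query_start_loc / context_lens_tensor checks in forward_cuda separate batched prefill (the varlen causal_conv1d_fn / selective_scan_fn path) from single-token decode (the *_update path). For illustration only, here is what that metadata might look like for a small batch, assuming vLLM's convention that query_start_loc holds cumulative new-token offsets; the concrete values are made up.

import torch

# Hypothetical batch of two prefill requests (values are illustrative):
# request 0 starts from scratch with 5 new tokens, request 1 continues a
# 12-token cached prefix with 4 new tokens.
query_start_loc = torch.tensor([0, 5, 9])      # cumulative offsets into the flattened token dim
context_lens_tensor = torch.tensor([0, 12])    # tokens already held in the cache per request
has_initial_state = context_lens_tensor > 0    # tensor([False, True]): request 1 resumes its conv/SSM state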
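
The class docstring describes the selective state-space recurrence that selective_scan_fn evaluates over a whole prompt and that selective_state_update advances one token at a time. As a reference for what those fused kernels compute, here is a minimal eager-mode sketch; the function name, the shapes, and the softplus placement are assumptions for illustration rather than the kernels' actual API.

import torch
import torch.nn.functional as F


def selective_scan_reference(x, dt, A, B, C, D, z):
    """Eager sketch of y = SSM(A, B, C)(x) with a gated output.

    Assumed shapes: x, dt, z: (dim, seqlen); A: (dim, state);
    B, C: (state, seqlen); D: (dim,).
    """
    dim, seqlen = x.shape
    dt = F.softplus(dt)                       # mirrors delta_softplus=True
    h = x.new_zeros(dim, A.shape[1])          # recurrent SSM state
    ys = []
    for t in range(seqlen):
        dA = torch.exp(dt[:, t, None] * A)    # discretized A (A itself is input-independent)
        dB = dt[:, t, None] * B[None, :, t]   # discretized, input-dependent B
        h = dA * h + dB * x[:, t, None]       # h_t = dA * h_{t-1} + dB * x_t
        ys.append((h * C[None, :, t]).sum(-1))    # y_t = C_t . h_t
    y = torch.stack(ys, dim=-1) + D[:, None] * x  # skip connection through D
    return y * F.silu(z)                      # gate, matching activation="silu"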
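
Finally, a hedged sketch of how the factored-out module might be constructed and called. The constructor and forward_cuda signatures come from the diff above; the concrete sizes are illustrative (roughly Jamba-like), and the sketch assumes vLLM's tensor-parallel state is already initialized and that CustomOp dispatches __call__ to forward_cuda on CUDA.

# MambaMixer as defined in the diff above; the sizes below are assumptions.
mixer = MambaMixer(hidden_size=4096,
                   ssm_state_size=16,
                   conv_kernel_size=4,
                   intermediate_size=8192,
                   time_step_rank=256,
                   use_conv_bias=True,
                   use_bias=False,
                   use_rms_norm=True)

# hidden_states: (num_tokens, hidden_size), flattened across the batch.
# attn_metadata and mamba_cache_params are supplied by the calling model
# (e.g. a Jamba decoder layer), so the call itself is left as a comment:
# out = mixer(hidden_states, attn_metadata, mamba_cache_params)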

0 commit comments
