
Commit f6b9729

[Model] FalconMamba Support (#9325)
1 parent 496e991 commit f6b9729

5 files changed: +35 additions, −12 deletions

docs/source/models/supported_models.rst

Lines changed: 5 additions & 0 deletions
@@ -87,6 +87,11 @@ Text Generation
     - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
     -
     - ✅︎
+  * - :code:`FalconMambaForCausalLM`
+    - FalconMamba
+    - :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc.
+    - ✅︎
+    -
   * - :code:`GemmaForCausalLM`
     - Gemma
     - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
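
With this table entry in place, the FalconMamba checkpoints listed above can be loaded like any other supported text-generation model. A minimal offline-inference sketch using vLLM's standard API (prompt and sampling settings are arbitrary; the base tiiuae/falcon-mamba-7b checkpoint works the same way):

    from vllm import LLM, SamplingParams

    llm = LLM(model="tiiuae/falcon-mamba-7b-instruct")
    params = SamplingParams(temperature=0.0, max_tokens=64)
    outputs = llm.generate(["The capital of France is"], params)
    print(outputs[0].outputs[0].text)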

tests/models/decoder_only/language/test_mamba.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 
 from ...utils import check_outputs_equal
 
-MODELS = ["state-spaces/mamba-130m-hf"]
+MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"]
 
 
 # Use lower-level interfaces to create this greedy generator, as mamba will

vllm/model_executor/layers/layernorm.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ def __init__(
         self.variance_epsilon = eps
         self.variance_size_override = (None if var_hidden_size == hidden_size
                                        else var_hidden_size)
-
         self.weight = nn.Parameter(torch.ones(hidden_size))
 
     def forward_native(

vllm/model_executor/models/mamba.py

Lines changed: 28 additions & 10 deletions
@@ -22,7 +22,7 @@
     QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     composed_weight_loader, default_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.interfaces import (HasInnerState,
@@ -59,7 +59,7 @@ def __init__(self, config: MambaConfig, layer_idx):
         self.conv_kernel_size = config.conv_kernel
         self.intermediate_size = config.intermediate_size
         self.time_step_rank = int(config.time_step_rank)
-
+        self.is_falcon_mamba = config.model_type == "falcon_mamba"
         self.conv1d = ColumnParallelLinear(
             input_size=self.conv_kernel_size,
             output_size=self.intermediate_size,
@@ -109,6 +109,13 @@ def __init__(self, config: MambaConfig, layer_idx):
             input_is_parallel=True,
         )
         self.activation = config.hidden_act
+        if self.is_falcon_mamba:
+            self.dt_layernorm = RMSNorm(self.time_step_rank,
+                                        eps=config.mixer_rms_eps)
+            self.b_layernorm = RMSNorm(self.ssm_state_size,
+                                       eps=config.mixer_rms_eps)
+            self.c_layernorm = RMSNorm(self.ssm_state_size,
+                                       eps=config.mixer_rms_eps)
 
     def forward(self, hidden_states: torch.Tensor,
                 attn_metadata: AttentionMetadata,
@@ -158,8 +165,12 @@ def forward(self, hidden_states: torch.Tensor,
             [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
             dim=-1,
         )
-
-        # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
+        # Note that Jamba and FalconMamba normalizes B, C, and time_step here
+        # but Mamba doesn't.
+        if self.is_falcon_mamba:
+            time_step = self.dt_layernorm(time_step.contiguous())
+            B = self.b_layernorm(B.contiguous())
+            C = self.c_layernorm(C.contiguous())
 
         discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
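
The two mixer hunks above are the core FalconMamba difference: the time step (dt) and the SSM input matrices B and C are RMS-normalized with the checkpoint's mixer_rms_eps before the selective-scan recurrence, whereas plain Mamba uses them as produced by the input projection. A rough standalone sketch in plain PyTorch (not vLLM's fused kernels; shapes and values are illustrative only):

    import torch
    import torch.nn as nn

    class RMSNorm(nn.Module):
        """Minimal RMSNorm, standing in for vLLM's RMSNorm layer."""

        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            variance = x.pow(2).mean(-1, keepdim=True)
            return self.weight * x * torch.rsqrt(variance + self.eps)

    # Illustrative shapes; in the real mixer these come out of x_proj.
    seq_len, time_step_rank, ssm_state_size = 8, 16, 16
    time_step = torch.randn(seq_len, time_step_rank)
    B = torch.randn(seq_len, ssm_state_size)
    C = torch.randn(seq_len, ssm_state_size)

    dt_layernorm = RMSNorm(time_step_rank)
    b_layernorm = RMSNorm(ssm_state_size)
    c_layernorm = RMSNorm(ssm_state_size)

    # FalconMamba normalizes dt, B and C before the recurrence;
    # plain Mamba would pass them through unchanged.
    time_step = dt_layernorm(time_step)
    B = b_layernorm(B)
    C = c_layernorm(C)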
@@ -213,11 +224,9 @@ def __init__(self,
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
+        self.is_falcon_mamba = config.model_type == "falcon_mamba"
         self.mixer = MambaMixer(config, layer_idx)
-
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
-        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
-                                        eps=config.layer_norm_epsilon)
 
     def forward(
         self,
@@ -319,8 +328,18 @@ def __init__(
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-
-        self.lm_head = self.backbone.embeddings
+        if config.tie_word_embeddings:
+            self.lm_head = self.backbone.embeddings
+        else:
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+            )
 
         # Used to track and store by the Mamba cache between steps.
         self.mamba_cache: Optional[MambaCacheManager] = None
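
Previously the Mamba implementation always reused the input embeddings as the output projection; the hunk above makes that conditional, since the FalconMamba checkpoints do not tie their word embeddings. A conceptual sketch with plain PyTorch modules standing in for vLLM's VocabParallelEmbedding / ParallelLMHead (sizes are made up for illustration):

    import torch.nn as nn

    vocab_size, hidden_size = 32_000, 4_096
    tie_word_embeddings = False  # what an untied checkpoint would set

    embeddings = nn.Embedding(vocab_size, hidden_size)
    if tie_word_embeddings:
        # Reuse the input embedding matrix as the output projection,
        # as the original Mamba checkpoints do.
        lm_head_weight = embeddings.weight
    else:
        # Untied checkpoints ship a separate lm_head, so allocate one.
        lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        lm_head_weight = lm_head.weight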
@@ -398,7 +417,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "A_log" in name:
                 name = name.replace("A_log", "A")
-
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
+    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
     "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
