@@ -22,7 +22,7 @@
     QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     composed_weight_loader, default_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.interfaces import (HasInnerState,
@@ -59,7 +59,7 @@ def __init__(self, config: MambaConfig, layer_idx): |
         self.conv_kernel_size = config.conv_kernel
         self.intermediate_size = config.intermediate_size
         self.time_step_rank = int(config.time_step_rank)
-
+        self.is_falcon_mamba = config.model_type == "falcon_mamba"
         self.conv1d = ColumnParallelLinear(
             input_size=self.conv_kernel_size,
             output_size=self.intermediate_size,
@@ -109,6 +109,13 @@ def __init__(self, config: MambaConfig, layer_idx): |
             input_is_parallel=True,
         )
         self.activation = config.hidden_act
+        if self.is_falcon_mamba:
+            self.dt_layernorm = RMSNorm(self.time_step_rank,
+                                        eps=config.mixer_rms_eps)
+            self.b_layernorm = RMSNorm(self.ssm_state_size,
+                                       eps=config.mixer_rms_eps)
+            self.c_layernorm = RMSNorm(self.ssm_state_size,
+                                       eps=config.mixer_rms_eps)
 
     def forward(self, hidden_states: torch.Tensor,
                 attn_metadata: AttentionMetadata,
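
Note: the three modules added in this hunk are plain RMS norms over the dt / B / C slices, parameterized by the checkpoint's mixer_rms_eps. A minimal sketch of what each one computes, with made-up sizes and without vLLM's fused kernel:

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # x / sqrt(mean(x^2) over the last dim + eps), then a learned per-channel scale.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight

    time_step_rank, eps = 16, 1e-6            # toy values for the sketch
    dt = torch.randn(2, 8, time_step_rank)    # (batch, seq, dt_rank)
    out = rms_norm(dt, torch.ones(time_step_rank), eps)
    print(out.shape)                          # torch.Size([2, 8, 16])
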
@@ -158,8 +165,12 @@ def forward(self, hidden_states: torch.Tensor, |
             [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
             dim=-1,
         )
-
-        # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
+        # Note that Jamba and FalconMamba normalize B, C, and time_step here
+        # but Mamba doesn't.
+        if self.is_falcon_mamba:
+            time_step = self.dt_layernorm(time_step.contiguous())
+            B = self.b_layernorm(B.contiguous())
+            C = self.c_layernorm(C.contiguous())
 
         discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
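
As a companion to this hunk, a standalone data-flow sketch with toy tensors (sizes are made up; this is not the vLLM kernel path). The .contiguous() calls appear because torch.split returns views and fused norm kernels generally expect contiguous inputs; treat that rationale as an inference, not a statement from the patch:

    import torch

    time_step_rank, ssm_state_size = 16, 16
    ssm_parameters = torch.randn(2, 8, time_step_rank + 2 * ssm_state_size)

    time_step, B, C = torch.split(
        ssm_parameters,
        [time_step_rank, ssm_state_size, ssm_state_size],
        dim=-1,
    )

    # falcon_mamba-style extra normalization before dt_proj / the selective scan
    # (separate modules per slice, mirroring the patch; requires PyTorch >= 2.4).
    dt_norm = torch.nn.RMSNorm(time_step_rank, eps=1e-6)
    b_norm = torch.nn.RMSNorm(ssm_state_size, eps=1e-6)
    c_norm = torch.nn.RMSNorm(ssm_state_size, eps=1e-6)
    time_step = dt_norm(time_step.contiguous())
    B = b_norm(B.contiguous())
    C = c_norm(C.contiguous())
    print(time_step.shape, B.shape, C.shape)
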
@@ -213,11 +224,9 @@ def __init__(self, |
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
+        self.is_falcon_mamba = config.model_type == "falcon_mamba"
         self.mixer = MambaMixer(config, layer_idx)
-
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
-        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
-                                        eps=config.layer_norm_epsilon)
 
     def forward(
         self,
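
The dropped pre_ff_layernorm reflects that a pure Mamba decoder layer has no separate feed-forward sub-block, so a single pre-mixer norm plus a residual connection suffices. A toy illustration of that layer shape (class name and the Linear stand-in are illustrative only, not vLLM's classes):

    import torch
    from torch import nn

    class ToyMambaLayer(nn.Module):
        def __init__(self, hidden_size: int, eps: float = 1e-5):
            super().__init__()
            self.norm = nn.RMSNorm(hidden_size, eps=eps)      # PyTorch >= 2.4
            self.mixer = nn.Linear(hidden_size, hidden_size)  # stand-in for MambaMixer

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
            return residual + self.mixer(hidden_states)

    layer = ToyMambaLayer(64)
    print(layer(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])
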
@@ -319,8 +328,18 @@ def __init__( |
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-
-        self.lm_head = self.backbone.embeddings
+        if config.tie_word_embeddings:
+            self.lm_head = self.backbone.embeddings
+        else:
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+            )
 
         # Used to track and store by the Mamba cache between steps.
         self.mamba_cache: Optional[MambaCacheManager] = None
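
This hunk distinguishes tied and untied output embeddings: the tied path keeps the old behavior of reusing backbone.embeddings, while the untied path (which FalconMamba checkpoints appear to need, presumably why it was added here) builds a separate, padded ParallelLMHead. A minimal sketch of the difference in plain PyTorch, with illustrative weights and sizes and no tensor parallelism or vocab padding:

    import torch
    from torch import nn

    vocab_size, hidden_size = 1000, 64
    tie_word_embeddings = False                  # config flag being branched on above

    embeddings = nn.Embedding(vocab_size, hidden_size)
    if tie_word_embeddings:
        lm_head_weight = embeddings.weight       # shares storage with the input embedding
    else:
        lm_head_weight = nn.Parameter(torch.randn(vocab_size, hidden_size) * 0.02)

    hidden = torch.randn(2, 8, hidden_size)
    logits = hidden @ lm_head_weight.t()         # (batch, seq, vocab)
    print(logits.shape)                          # torch.Size([2, 8, 1000])
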
@@ -398,7 +417,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): |
         for name, loaded_weight in weights:
             if "A_log" in name:
                 name = name.replace("A_log", "A")
-
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
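
For context on the A_log handling in load_weights: Mamba-style checkpoints store the state-transition matrix in log space as "A_log", while the module's parameter is named "A", hence the rename; the composed_weight_loader imported earlier in this diff is used elsewhere in the file to apply the sign-flipped exponential at load time. A tiny sketch of that transform with a toy tensor:

    import torch

    A_log = torch.randn(8, 16)        # toy stand-in for a checkpoint's "A_log" tensor
    A = -torch.exp(A_log.float())     # A = -exp(A_log), applied by the weight loader
    print(bool((A < 0).all()))        # True: A stays strictly negative, as the SSM expects
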