diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index d44d2c790198..f94ab75f9a4f 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -77,7 +77,7 @@ def forward_native(
         input_dtype = x.dtype
         x = x * nn.functional.silu(gate.to(torch.float32))
         if not self.use_rms_norm:
-            return x
+            return x.to(input_dtype)

         if self.n_groups == 1:
             if self.tp_size > 1:
@@ -117,9 +117,11 @@ def forward_cuda(
         x: torch.Tensor,
         gate: torch.Tensor,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-
+        input_dtype = x.dtype
         if not self.use_rms_norm:
-            return x * nn.functional.silu(gate.to(torch.float32))
+            # Keep gate in float32 for numerical stability during silu
+            return x * nn.functional.silu(gate.to(
+                torch.float32)).to(input_dtype)

         if self.tp_size > 1 or self.n_groups != 1:
             return self.forward_native(x, gate)
diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py
index 591a75ffdb73..1c0e3911fcce 100644
--- a/vllm/model_executor/models/falcon_h1.py
+++ b/vllm/model_executor/models/falcon_h1.py
@@ -453,7 +453,6 @@ def forward(
         attn_metadata = get_forward_context().attn_metadata
         mamba2_metadata = prepare_mamba2_metadata(
             chunk_size=self.config.mamba_chunk_size,
-            input_ids=input_ids,
             attn_metadata=attn_metadata,
         )
         if get_pp_group().is_first_rank:
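
For context, a minimal standalone sketch (not part of the patch; the helper name gated_silu_no_norm is hypothetical) of the dtype-preserving gated activation this change restores on the use_rms_norm=False path: the silu on the gate is computed in float32 for numerical stability, and the result is cast back to the caller's dtype so the output no longer silently upcasts to float32.

import torch
import torch.nn as nn


def gated_silu_no_norm(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    input_dtype = x.dtype
    # Compute silu on the gate in float32 for numerical stability.
    out = x * nn.functional.silu(gate.to(torch.float32))
    # Cast back to the caller's dtype, mirroring the .to(input_dtype) added in the diff.
    return out.to(input_dtype)


if __name__ == "__main__":
    x = torch.randn(2, 8, dtype=torch.bfloat16)
    gate = torch.randn(2, 8, dtype=torch.bfloat16)
    out = gated_silu_no_norm(x, gate)
    # Without the final .to(input_dtype), multiplying by the float32 gate
    # would promote the output to float32.
    assert out.dtype == torch.bfloat16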