diff --git a/src/transformers/kernels/falcon_mamba/__init__.py b/src/transformers/kernels/falcon_mamba/__init__.py new file mode 100644 index 00000000000000..da88e3394f6533 --- /dev/null +++ b/src/transformers/kernels/falcon_mamba/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .selective_scan_with_ln_interface import mamba_inner_fn diff --git a/src/transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py b/src/transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py new file mode 100644 index 00000000000000..4a74986a81a13f --- /dev/null +++ b/src/transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py @@ -0,0 +1,525 @@ +# coding=utf-8 +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Original code from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch.cuda.amp import custom_bwd, custom_fwd + + +try: + import causal_conv1d_cuda +except ImportError: + causal_conv1d_cuda = None + +import mamba_ssm +import selective_scan_cuda + + +# For BC for old mamba-ssm versions: https://github.com/huggingface/transformers/pull/33195#discussion_r1736401127 +if hasattr(mamba_ssm.ops.triton, "layernorm"): + from mamba_ssm.ops.triton.layernorm import _layer_norm_fwd +else: + from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd + + +class SelectiveScanFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False + ): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + dout, + x, + out, + None, + ctx.delta_softplus, + False, # option to recompute out_z, not used here + ) + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + return ( + du, + ddelta, + dA, + dB, + dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, + None, + ) + + +def rms_norm_forward( + x, + weight, + bias, + eps=1e-6, + is_rms_norm=True, +): + # x (b l) d + if x.stride(-1) != 1: + x = x.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y = _layer_norm_fwd(x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm)[0] + # y (b l) d + return y + + +def selective_scan_fn( + u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False +): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. + """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + + +def selective_scan_ref( + u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False +): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum("bdn,dn->bd", x, C) + else: + if C.dim() == 3: + y = torch.einsum("bdn,bn->bd", x, C[:, :, i]) + else: + y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +class MambaInnerFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + xz, + conv1d_weight, + conv1d_bias, + x_proj_weight, + delta_proj_weight, + out_proj_weight, + out_proj_bias, + A, + B=None, + C=None, + D=None, + delta_bias=None, + B_proj_bias=None, + C_proj_bias=None, + delta_softplus=True, + checkpoint_lvl=1, + b_rms_weight=None, + c_rms_weight=None, + dt_rms_weight=None, + b_c_dt_rms_eps=1e-6, + ): + """ + xz: (batch, dim, seqlen) + """ + assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." + assert checkpoint_lvl in [0, 1] + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + if torch.is_autocast_enabled(): + x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + out_proj_bias = ( + out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) if out_proj_bias is not None else None + ) + if xz.stride(-1) != 1: + xz = xz.contiguous() + conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") + x, z = xz.chunk(2, dim=1) + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) + # We're being very careful here about the layout, to avoid extra transposes. + # We want delta to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight) # (bl d) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L) + ctx.is_variable_B = B is None + ctx.is_variable_C = C is None + ctx.B_proj_bias_is_None = B_proj_bias is None + ctx.C_proj_bias_is_None = C_proj_bias is None + if B is None: # variable B + B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate) + if B_proj_bias is not None: + B = B + B_proj_bias.to(dtype=B.dtype) + if not A.is_complex(): + # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if B.stride(-1) != 1: + B = B.contiguous() + if C is None: # variable C + C = x_dbl[:, -d_state:] # (bl dstate) + if C_proj_bias is not None: + C = C + C_proj_bias.to(dtype=C.dtype) + if not A.is_complex(): + # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if C.stride(-1) != 1: + C = C.contiguous() + if D is not None: + D = D.contiguous() + + if b_rms_weight is not None: + B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous() + B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps) + B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + if c_rms_weight is not None: + C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous() + C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps) + C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + if dt_rms_weight is not None: + delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous() + delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps) + delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous() + + out, scan_intermediates, out_z = selective_scan_cuda.fwd( + conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus + ) + ctx.delta_softplus = delta_softplus + ctx.out_proj_bias_is_None = out_proj_bias is None + ctx.checkpoint_lvl = checkpoint_lvl + ctx.b_rms_weight = b_rms_weight + ctx.c_rms_weight = c_rms_weight + ctx.dt_rms_weight = dt_rms_weight + ctx.b_c_dt_rms_eps = b_c_dt_rms_eps + if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass + conv1d_out, delta = None, None + ctx.save_for_backward( + xz, + conv1d_weight, + conv1d_bias, + x_dbl, + x_proj_weight, + delta_proj_weight, + out_proj_weight, + conv1d_out, + delta, + A, + B, + C, + D, + delta_bias, + scan_intermediates, + b_rms_weight, + c_rms_weight, + dt_rms_weight, + out, + ) + return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) + + @staticmethod + @custom_bwd + def backward(ctx, dout): + # dout: (batch, seqlen, dim) + assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." + ( + xz, + conv1d_weight, + conv1d_bias, + x_dbl, + x_proj_weight, + delta_proj_weight, + out_proj_weight, + conv1d_out, + delta, + A, + B, + C, + D, + delta_bias, + scan_intermediates, + b_rms_weight, + c_rms_weight, + dt_rms_weight, + out, + ) = ctx.saved_tensors + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + x, z = xz.chunk(2, dim=1) + if dout.stride(-1) != 1: + dout = dout.contiguous() + if ctx.checkpoint_lvl == 1: + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L) + if dt_rms_weight is not None: + delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous() + delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps) + delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous() + if b_rms_weight is not None: + # Recompute & RMSNorm B + B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous() + B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps) + B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + if c_rms_weight is not None: + # Recompute & RMSNorm C + C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous() + C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps) + C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + dxz = torch.empty_like(xz) # (batch, dim, seqlen) + dx, dz = dxz.chunk(2, dim=1) + dout = rearrange(dout, "b l e -> e (b l)") + dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) + dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( + conv1d_out, + delta, + A, + B, + C, + D, + z, + delta_bias, + dout_y, + scan_intermediates, + out, + dz, + ctx.delta_softplus, + True, # option to recompute out_z + ) + dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) + dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None + dD = dD if D is not None else None + dx_dbl = torch.empty_like(x_dbl) + dB_proj_bias = None + if ctx.is_variable_B: + if not A.is_complex(): + dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None + dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d) + dB = None + dC_proj_bias = None + if ctx.is_variable_C: + if not A.is_complex(): + dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None + dx_dbl[:, -d_state:] = dC # (bl d) + dC = None + ddelta = rearrange(ddelta, "b d l -> d (b l)") + ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) + dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) + dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") + dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) + dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) + dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( + x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True + ) + dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None + dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") + return ( + dxz, + dconv1d_weight, + dconv1d_bias, + dx_proj_weight, + ddelta_proj_weight, + dout_proj_weight, + dout_proj_bias, + dA, + dB, + dC, + dD, + ddelta_bias if delta_bias is not None else None, + # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps + dB_proj_bias, + dC_proj_bias, + None, + None, + None, + None, + None, + None, + ) + + +def mamba_inner_fn( + xz, + conv1d_weight, + conv1d_bias, + x_proj_weight, + delta_proj_weight, + out_proj_weight, + out_proj_bias, + A, + B=None, + C=None, + D=None, + delta_bias=None, + B_proj_bias=None, + C_proj_bias=None, + delta_softplus=True, + checkpoint_lvl=1, + b_rms_weight=None, + c_rms_weight=None, + dt_rms_weight=None, + b_c_dt_rms_eps=1e-6, +): + return MambaInnerFn.apply( + xz, + conv1d_weight, + conv1d_bias, + x_proj_weight, + delta_proj_weight, + out_proj_weight, + out_proj_bias, + A, + B, + C, + D, + delta_bias, + B_proj_bias, + C_proj_bias, + delta_softplus, + checkpoint_lvl, + b_rms_weight, + c_rms_weight, + dt_rms_weight, + b_c_dt_rms_eps, + ) diff --git a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py index 71df37a980cd0a..cabba738a479e1 100644 --- a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py @@ -23,7 +23,6 @@ logger = logging.get_logger(__name__) -# Copied from transformers.models.mamba.configuration_mamba.MambaConfig with mamba->falcon_mamba,Mamba->FalconMamba,MAMBA->FALCON_MAMBA,state-spaces/falcon_mamba-2.8b->tiiuae/falcon-mamba-7b,use_falcon_mambapy->use_mambapy class FalconMambaConfig(PretrainedConfig): """ This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA @@ -82,8 +81,8 @@ class FalconMambaConfig(PretrainedConfig): Whether or not the cache should be used. use_mambapy (`bool`, *optional*, defaults to `False`): Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not avaiable. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited. - - + mixer_rms_eps (`float`, *optional*, defaults to 1e-06): + The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states. Example: ```python @@ -127,6 +126,7 @@ def __init__( rescale_prenorm_residual=False, use_cache=True, use_mambapy=False, + mixer_rms_eps=1e-6, **kwargs, ): self.vocab_size = vocab_size @@ -154,5 +154,6 @@ def __init__( self.residual_in_fp32 = residual_in_fp32 self.use_cache = use_cache self.use_mambapy = use_mambapy + self.mixer_rms_eps = mixer_rms_eps super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 07374fe1dfd7b5..f682f75f222ef6 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 state-spaces/falcon_mamba org and HuggingFace Inc. team. +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -45,8 +45,10 @@ pscan = None if is_mamba_ssm_available(): - from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update + + from ...kernels.falcon_mamba import mamba_inner_fn else: selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None @@ -131,6 +133,15 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias + # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here + self.register_buffer( + "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False + ) + self.register_buffer( + "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False + ) + self.rms_eps = config.mixer_rms_eps + if not is_fast_path_available: if self.use_mambapy: if is_mambapy_available(): @@ -175,6 +186,10 @@ def cuda_kernels_forward( self.D.float(), delta_bias=self.dt_proj.bias.float(), delta_softplus=True, + b_rms_weight=self.b_c_rms, + c_rms_weight=self.b_c_rms, + dt_rms_weight=self.dt_rms, + b_c_dt_rms_eps=self.rms_eps, ) else: @@ -214,9 +229,9 @@ def cuda_kernels_forward( ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 ) - B = rms_forward(B) - C = rms_forward(C) - time_step = rms_forward(time_step) + B = rms_forward(B, variance_epsilon=self.rms_eps) + C = rms_forward(C, variance_epsilon=self.rms_eps) + time_step = rms_forward(time_step, variance_epsilon=self.rms_eps) # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module # at the price of a small overhead. @@ -315,9 +330,9 @@ def slow_forward( ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 ) - B = rms_forward(B) - C = rms_forward(C) - time_step = rms_forward(time_step) + B = rms_forward(B, variance_epsilon=self.rms_eps) + C = rms_forward(C, variance_epsilon=self.rms_eps) + time_step = rms_forward(time_step, variance_epsilon=self.rms_eps) discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] discrete_time_step = nn.functional.softplus(discrete_time_step).transpose( diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index d75014f370d29f..b94f235a1a61e3 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -524,3 +524,32 @@ def test_batched_generation(self): out = tok.batch_decode(out, skip_special_tokens=True) self.assertListEqual(out, EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_training_kernel(self): + model_id = "tiiuae/falcon-mamba-7b" + + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16) + tokenizer.pad_token_id = tokenizer.eos_token_id + + text = "Hello today" + + inputs = tokenizer(text, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + logits = torch.argmax(model(**inputs).logits, dim=-1) + + out_no_training = tokenizer.batch_decode(logits) + + model.train() + lm_logits = model(**inputs).logits + next_token = torch.argmax(lm_logits, dim=-1) + + out_training = tokenizer.batch_decode(next_token) + + # Just verify backward works + loss = (1 - lm_logits).mean() + loss.backward() + + self.assertEqual(out_training, out_no_training) diff --git a/utils/check_inits.py b/utils/check_inits.py index 19c23279b9b8d8..a9a841f3bbffb6 100644 --- a/utils/check_inits.py +++ b/utils/check_inits.py @@ -332,6 +332,7 @@ def get_transformers_submodules() -> List[str]: "modeling_attn_mask_utils", "safetensors_conversion", "modeling_gguf_pytorch_utils", + "kernels.falcon_mamba", ]