Skip to content

Commit

Permalink
Added mamba.py backend (#30139)
Browse files Browse the repository at this point in the history
* Update README.md

* tests: forward ok

* backward test done

* done testing

* removed check. scripts

* Update README.md

* added use_mambapy arg

* fixed typo in warning

* protected imports w/ mambapy package

* delete pscan.py + raise rather than assert

* Update import_utils.py

* fix whitespaces and unused import

* trailing whitespace + import block unformatted

* Update modeling_mamba.py

* transpose before pscan

* shape comment

* ran make style

* use_mambapy=False by default

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* ran make fix-copies

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
  • Loading branch information
2 people authored and itazap committed Jul 25, 2024
1 parent 4352e10 commit 6648bed
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 16 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/mamba/configuration_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class MambaConfig(PretrainedConfig):
Whether or not to rescale `out_proj` weights when initializing.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the cache should be used.
use_mambapy (`bool`, *optional*, defaults to `False`):
Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not avaiable. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
Example:
Expand Down Expand Up @@ -123,6 +125,7 @@ def __init__(
time_step_floor=1e-4,
rescale_prenorm_residual=False,
use_cache=True,
use_mambapy=False,
**kwargs,
):
self.vocab_size = vocab_size
Expand All @@ -149,5 +152,6 @@ def __init__(
self.rescale_prenorm_residual = rescale_prenorm_residual
self.residual_in_fp32 = residual_in_fp32
self.use_cache = use_cache
self.use_mambapy = use_mambapy

super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
58 changes: 42 additions & 16 deletions src/transformers/models/mamba/modeling_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,17 @@
add_start_docstrings_to_model_forward,
logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
from .configuration_mamba import MambaConfig


logger = logging.get_logger(__name__)

if is_mambapy_available():
from mambapy.pscan import pscan
else:
pscan = None

if is_mamba_ssm_available():
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
Expand Down Expand Up @@ -87,6 +92,8 @@ def __init__(self, config: MambaConfig, layer_idx: int):
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]

self.use_mambapy = config.use_mambapy

# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
# selective projection used to make dt, B and C input dependant
Expand All @@ -105,11 +112,23 @@ def __init__(self, config: MambaConfig, layer_idx: int):
self.use_bias = config.use_bias

if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
)
if self.use_mambapy:
if is_mambapy_available():
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
)
else:
raise ImportError(
"use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
)
else:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
)

def cuda_kernels_forward(
self,
Expand Down Expand Up @@ -257,17 +276,24 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

# 3.c perform the recurrence y ← SSM(A, B, C)(x)
scan_outputs = []
for i in range(seq_len):
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state]
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
scan_outputs.append(scan_output[:, :, 0])
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
scan_output = scan_output + (hidden_states * self.D[None, :, None])
scan_output = (scan_output * self.act(gate))
if self.use_mambapy and self.training and cache_params is None:
hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, seq_len, intermediate_size, ssm_state_size]

if cache_params is not None:
cache_params.update_ssm_state(self.layer_idx, ssm_state)
scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len]
scan_output = scan_output + hidden_states * self.D[None, :, None]
scan_output = scan_output * self.act(gate)
else:
scan_outputs = []
for i in range(seq_len):
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
scan_outputs.append(scan_output[:, :, 0])
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size]
scan_output = scan_output + (hidden_states * self.D[None, :, None])
scan_output = (scan_output * self.act(gate))

if cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

# 4. Final linear projection
contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
Expand Down
6 changes: 6 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,12 @@ def is_causal_conv1d_available():
return False


def is_mambapy_available():
if is_torch_available():
return _is_package_available("mambapy")
return False


def is_torch_mps_available():
if is_torch_available():
import torch
Expand Down

0 comments on commit 6648bed

Please sign in to comment.