3 changes: 2 additions & 1 deletion Dockerfile
@@ -29,8 +29,9 @@ ENV PIP_CONSTRAINT=""
# There is no pre-built mamba image for pytorch 2.8, so we build it before the rest to avoid rebuilds.
# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d).
# We set the number of workers to avoid OOM when compiling on a laptop. (TODO: Can we make it configurable?)
# Using varlen_mamba for variable-length sequence (packing) support.
RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"
RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"
RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba"
# Copy dependency files with universal write permissions for all users.
COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
20 changes: 20 additions & 0 deletions fast_llm/layers/ssm/config.py
@@ -12,6 +12,26 @@
from fast_llm.tensor import Initializer


class BaseSSMKwargs:
_kwargs_attributes = {
"cu_seqlens": "cu_seqlens",
"seq_idx": "seq_idx",
"ssm_position_ids": "ssm_position_ids",
}

_prefix = ""

def __init_subclass__(cls, prefix="", **kwargs):
super().__init_subclass__(**kwargs)
cls._prefix = prefix
for attr, value in BaseSSMKwargs._kwargs_attributes.items():
setattr(cls, value, f"{cls._prefix}_{value}" if cls._prefix else value)


class SSMKwargs(BaseSSMKwargs, prefix=""):
pass


class SSMDimNames:
# TODO: Use separate tensor space for different mixers so there is no risk of name conflict.
state = "ssm_state" # State dimension (N), aka head size / num channels
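The `BaseSSMKwargs` hook above namespaces the kwargs keys per subclass: `__init_subclass__` runs once for each subclass and rewrites every class attribute to a prefixed string. A minimal sketch of the resulting keys, assuming the `BaseSSMKwargs` / `SSMKwargs` definitions from the diff are importable; the `PrefixedSSMKwargs` subclass is hypothetical, added only to show a non-empty prefix:

# Assumes BaseSSMKwargs / SSMKwargs from fast_llm.layers.ssm.config (added in this diff).
from fast_llm.layers.ssm.config import BaseSSMKwargs, SSMKwargs


class PrefixedSSMKwargs(BaseSSMKwargs, prefix="mamba"):  # hypothetical subclass, illustration only
    pass


# An empty prefix leaves the keys unchanged...
assert SSMKwargs.cu_seqlens == "cu_seqlens"
assert SSMKwargs.ssm_position_ids == "ssm_position_ids"
# ...while a non-empty prefix namespaces every key.
assert PrefixedSSMKwargs.cu_seqlens == "mamba_cu_seqlens"
assert PrefixedSSMKwargs.seq_idx == "mamba_seq_idx"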
74 changes: 61 additions & 13 deletions fast_llm/layers/ssm/mamba2.py
@@ -1,3 +1,4 @@
import inspect
import logging
import typing

@@ -6,17 +7,28 @@
from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace
from fast_llm.functional.config import ActivationType
from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear
from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames
from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames, SSMKwargs
from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias
from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs
from fast_llm.layers.transformer.transformer import Mixer
from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_
from fast_llm.utils import Assert, div, get_lr_scale

_mamba_varlen = False
try:
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa

_mamba_available = True
sig = inspect.signature(selective_scan_fn)
if "position_indices" in sig.parameters:
_mamba_varlen = True
logging.warning("Using selective_scan_fn from varlen_mamba that supports packing")
else:
_mamba_varlen = False
logging.warning("Using selective_scan_fn from original mamba without packing support")
# For training with packing, install https://github.com/jxiw/varlen_mamba
# See https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md

except (ImportError, RuntimeError):
_mamba_available = False

@@ -143,8 +155,16 @@ def __init__(
)

def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Note: we are not doing "real" sequence-tensor-parallel training here, since inner_projection is gathered over all GPUs.
This is also desired, since the currently used mamba kernel does not support STP.
TODO: use the correct kernel from Mamba2!
"""
assert _mamba_available
assert _causal_conv1d_available
cu_seqlens = kwargs[SSMKwargs.cu_seqlens]
seq_idx = kwargs[SSMKwargs.seq_idx]
position_indices = kwargs[SSMKwargs.ssm_position_ids]

# inner_projection : (batch/local_sequence, local_sequence/batch, hidden)
# -> (batch/sequence, sequence/batch, inner_projection)
@@ -174,9 +194,20 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
.repeat_interleave(self._group_heads, 1, output_size=self._local_heads)
.flatten(1, 2)
)
x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu")

if cu_seqlens is not None:
# from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/verl/models/mamba/hybrid_wrapper.py#L152
x = _causal_conv1d_fn(
x=x.transpose(1, 2).contiguous().transpose(1, 2),
weight=self.conv1d_weight.squeeze(1),
bias=self.conv1d_bias,
seq_idx=seq_idx,
activation="silu",
)
else:
x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu")

if not self._config.repeat_kv_before_conv:
x = (
x.unflatten(1, (self._local_head_groups, self._config.state_size))
.repeat_interleave(self._group_heads, 1, output_size=self._local_heads)
@@ -203,17 +234,34 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
self._debug_log(c, "c", self._BC_DIMS, kwargs)
self._debug_log(dt, "dt", self._XZ_DIMS, kwargs)

y = selective_scan_fn(
x,
dt,
-torch.exp(self.A_log.float()),
b,
c,
self.D.float(),
z,
delta_bias=self.dt_proj_bias.float(),
delta_softplus=True,
)
if not _mamba_varlen:
Assert.eq(cu_seqlens, None, msg="This version of Mamba2 does not support cu_seqlens; install varlen_mamba")
y = selective_scan_fn(
x,
dt,
-torch.exp(self.A_log.float()),
b,
c,
self.D.float(),
z,
delta_bias=self.dt_proj_bias.float(),
delta_softplus=True,
)
else:
position_indices = position_indices if cu_seqlens is not None else None

y = selective_scan_fn(
x,
dt,
-torch.exp(self.A_log.float()),
b,
c,
self.D.float(),
z,
delta_bias=self.dt_proj_bias.float(),
delta_softplus=True,
position_indices=position_indices,
)

if self._debug_level:
self._debug_log(y, "y", self._XZ_DIMS, kwargs)
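Taken together, the packed path above threads the packing metadata through both kernels: `_causal_conv1d_fn` receives `seq_idx` so the convolution does not mix tokens from different packed documents, and the varlen `selective_scan_fn` receives `position_indices` so the scan state restarts at every document boundary. A condensed sketch of that call pattern, with the projection and reshape logic omitted; the `position_indices` keyword exists only in the varlen_mamba fork, and the argument names here mirror the diff rather than a documented API:

import torch
from causal_conv1d import causal_conv1d_fn
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn


def packed_scan(x, z, dt, b, c, A_log, D, dt_proj_bias, conv_weight, conv_bias,
                cu_seqlens=None, seq_idx=None, position_indices=None):
    """Condensed, illustrative version of the packed branch in Mamba2.forward."""
    if cu_seqlens is not None:
        # seq_idx labels each token with its document index so the causal
        # convolution does not leak context across packed-document boundaries.
        x = causal_conv1d_fn(
            x=x.transpose(1, 2).contiguous().transpose(1, 2),
            weight=conv_weight.squeeze(1),
            bias=conv_bias,
            seq_idx=seq_idx,
            activation="silu",
        )
    else:
        x = causal_conv1d_fn(x=x, weight=conv_weight.squeeze(1), bias=conv_bias, activation="silu")

    # position_indices restarts the recurrent state wherever the position is 0,
    # i.e. at the start of each packed document (varlen_mamba only).
    return selective_scan_fn(
        x,
        dt,
        -torch.exp(A_log.float()),
        b,
        c,
        D.float(),
        z,
        delta_bias=dt_proj_bias.float(),
        delta_softplus=True,
        position_indices=position_indices if cu_seqlens is not None else None,
    )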
68 changes: 68 additions & 0 deletions fast_llm/layers/ssm/preprocessing.py
@@ -0,0 +1,68 @@
import logging
import typing

import torch

from fast_llm.engine.base_model.config import Preprocessor
from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.layers.ssm.config import SSMKwargs
from fast_llm.layers.transformer.config import TransformerKwargs
from fast_llm.models.ssm.config import HybridSSMBaseModelConfig
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)


class Mamba2Preprocessor(Preprocessor):
def __init__(self, config: HybridSSMBaseModelConfig, tensor_space: TensorSpace):
self._config = config
self._tensor_space = tensor_space
self._distributed_config = self._tensor_space.distributed_config
self._transformer_dim_names = config.transformer._transformer_dim_names

def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
"""
Simplified preprocessor that does not take into account micro-sequences.
"""
if TransformerKwargs.sequence_lengths not in kwargs:
return
sequence_lengths = kwargs[TransformerKwargs.sequence_lengths]
if TransformerKwargs.cu_seqlens_k in kwargs:
# already set this in the transformer preprocessor, so we can use it here
cu_seqlens_k = kwargs[TransformerKwargs.cu_seqlens_k]
cu_seqlens_q = kwargs[TransformerKwargs.cu_seqlens_q]
Assert.eq(
cu_seqlens_k.shape[0],
cu_seqlens_q.shape[0],
msg="cu_seqlens_k and cu_seqlens_q have different lengths, is micro_sequence_length being used? This is currently not supported for Mamba.",
)
Assert.all_equal(cu_seqlens_k, cu_seqlens_q)
cu_seqlens = cu_seqlens_k
else:
seqlens = torch.cat(sequence_lengths)
cu_seqlens = torch.cat(
(
torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device),
torch.cumsum(seqlens, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device),
)
)
kwargs[SSMKwargs.cu_seqlens] = cu_seqlens
# from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/verl/models/mamba/hybrid_wrapper.py#L152
kwargs[SSMKwargs.seq_idx] = torch.cat(
[
torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1])
],
dim=0,
).unsqueeze(0)

sequence_lengths = kwargs.get(TransformerKwargs.sequence_lengths)
sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size
sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size
position_ids = torch.stack(
[torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths]
).to(self._tensor_space.distributed.device, dtype=torch.int64)
position_ids = position_ids[
:, sequence_k - sequence_q : sequence_k
] # this is only needed if we do micro-sequences?
kwargs[SSMKwargs.ssm_position_ids] = position_ids.to(torch.int32)
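All three packing tensors produced above are derived from the per-sample document lengths. A small CPU-only sketch of the same construction for one packed row containing documents of length 3, 2 and 4 (the values are illustrative):

import torch

sequence_lengths = [torch.tensor([3, 2, 4])]  # one packed row, three documents

# cu_seqlens: cumulative document boundaries, prefixed with 0.
seqlens = torch.cat(sequence_lengths)
cu_seqlens = torch.cat(
    (torch.zeros(1, dtype=torch.int32), torch.cumsum(seqlens, dim=0, dtype=torch.int32))
)
# -> tensor([0, 3, 5, 9], dtype=torch.int32)

# seq_idx: for every token, the index of the document it belongs to.
seq_idx = torch.cat(
    [torch.full((int(s),), i, dtype=torch.int32) for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1])]
).unsqueeze(0)
# -> tensor([[0, 0, 0, 1, 1, 2, 2, 2, 2]], dtype=torch.int32)

# ssm_position_ids: positions restart at 0 at each document start, which is what
# the varlen selective scan uses to reset its state.
position_ids = torch.stack(
    [torch.cat([torch.arange(int(x)) for x in sample_lens]) for sample_lens in sequence_lengths]
).to(torch.int32)
# -> tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]], dtype=torch.int32)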
5 changes: 5 additions & 0 deletions fast_llm/models/ssm/config.py
@@ -207,6 +207,11 @@ def get_trainer_class(cls) -> type["HybridSSMTrainer"]:

def _validate(self) -> None:
super()._validate()
Assert.eq(
self.batch.micro_sequence_length,
self.batch.sequence_length,
msg="Micro-sequences not supported for SSMs. at htis point",
)
if (name := self.model.base_model.distillation_model) is None:
Assert.empty(self.reference_models)
else:
2 changes: 2 additions & 0 deletions fast_llm/models/ssm/model.py
@@ -6,6 +6,7 @@
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.layers.language_model.head import LanguageModelHead
from fast_llm.layers.ssm.llamba_block import SSMBlock
from fast_llm.layers.ssm.preprocessing import Mamba2Preprocessor
from fast_llm.layers.transformer.transformer import TransformerBlock
from fast_llm.models.gpt.config import GPTBatchConfig
from fast_llm.models.gpt.model import GPTBaseModel, GPTModel
@@ -30,6 +31,7 @@ def __init__(
distributed_config: DistributedConfig,
):
super().__init__(config, distributed_config)
self._preprocessors.append(Mamba2Preprocessor(config, self._tensor_space))

def get_output_layers(self) -> list[Layer]:
"""
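Registering `Mamba2Preprocessor` here means the SSM packing keys are filled into the shared `kwargs` dict during batch preparation, alongside the transformer preprocessors, and the Mamba2 mixer reads them back by key in `forward`. A hedged sketch of that producer/consumer contract; the driver loop is illustrative, not Fast-LLM's actual runner code:

import typing

from fast_llm.layers.ssm.config import SSMKwargs


def run_preprocessors(preprocessors: list, batch, kwargs: dict[str, typing.Any]) -> None:
    # Illustrative driver: each registered preprocessor mutates `kwargs` in place.
    # Mamba2Preprocessor adds SSMKwargs.cu_seqlens / seq_idx / ssm_position_ids
    # when TransformerKwargs.sequence_lengths is present.
    for preprocessor in preprocessors:
        preprocessor.preprocess(batch, kwargs)


# Consumer side (see Mamba2.forward in the diff above):
#     cu_seqlens = kwargs[SSMKwargs.cu_seqlens]
#     seq_idx = kwargs[SSMKwargs.seq_idx]
#     position_indices = kwargs[SSMKwargs.ssm_position_ids]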
2 changes: 1 addition & 1 deletion setup.cfg
@@ -50,7 +50,7 @@ HUGGINGFACE =
# To install on cpu environment (ex. for IDE support):
# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation
SSM =
mamba_ssm[causal-conv1d]==2.2.4
mamba_ssm[causal-conv1d] @ git+https://github.com/jxiw/varlen_mamba.git@varlen_mamba
cartesia_pytorch>=0.0.2

# GENERATION =