Merged
81 commits
cb86f45
Test all models
jlamypoirier Jun 5, 2025
f8850e4
Parametrized dependencies
jlamypoirier Jun 6, 2025
478ac05
fixes
jlamypoirier Jun 6, 2025
d3b18a1
stuff
jlamypoirier Jun 9, 2025
8c64f03
fix
jlamypoirier Jun 9, 2025
c0f648c
fixes
jlamypoirier Jun 10, 2025
e92c311
stuff
jlamypoirier Jun 11, 2025
b877fb2
stuff
jlamypoirier Jun 11, 2025
907aef0
attempt
jlamypoirier Jun 11, 2025
1340903
attempt
jlamypoirier Jun 11, 2025
8aed0a3
Cleanup tests
jlamypoirier Jun 11, 2025
830a380
fixes
jlamypoirier Jun 11, 2025
13e1da5
fix
jlamypoirier Jun 12, 2025
aa0e821
Merge remote-tracking branch 'origin/main' into update_base_image
jlamypoirier Jun 12, 2025
45bb0ff
Merge remote-tracking branch 'origin/main' into update_base_image
jlamypoirier Jun 12, 2025
c467b63
Merge branch 'update_base_image' into test_all_models
jlamypoirier Jun 12, 2025
0dffe5c
fixes
jlamypoirier Jun 12, 2025
dcc5064
fixes
jlamypoirier Jun 12, 2025
9d415bc
fixes
jlamypoirier Jun 12, 2025
a6cce17
Merge branch 'update_base_image' into test_all_models
jlamypoirier Jun 12, 2025
68251c2
fixes
jlamypoirier Jun 12, 2025
68333ef
Merge remote-tracking branch 'origin/main' into test_all_models
jlamypoirier Jun 12, 2025
639d6c2
doc
jlamypoirier Jun 12, 2025
7465428
stuff
jlamypoirier Jun 12, 2025
ced34e0
stuff
jlamypoirier Jun 12, 2025
b328f07
stuff
jlamypoirier Jun 12, 2025
7ed804b
stuff
jlamypoirier Jun 12, 2025
890ad75
Merge branch 'improve_testing' into test_all_models
jlamypoirier Jun 12, 2025
6f00035
stuff
jlamypoirier Jun 12, 2025
e45ff6a
stuff
jlamypoirier Jun 12, 2025
8b16be2
Merge branch 'improve_testing' into test_all_models
jlamypoirier Jun 12, 2025
68db703
Merge branch 'main' into improve_testing
jlamypoirier Jun 12, 2025
e8615c2
Merge branch 'improve_testing' into test_all_models
jlamypoirier Jun 12, 2025
67d3c92
fix
jlamypoirier Jun 12, 2025
c2ae03d
fix
jlamypoirier Jun 13, 2025
31da2a8
misc
jlamypoirier Jun 13, 2025
c2ee8fe
stuff
jlamypoirier Jun 13, 2025
6c775e4
stuff
jlamypoirier Jun 13, 2025
4ba584b
Merge branch 'model_testing_configs' into test_all_models
jlamypoirier Jun 13, 2025
d41e0d5
misc
jlamypoirier Jun 13, 2025
59582c3
misc
jlamypoirier Jun 13, 2025
c0ca0b9
Merge branch 'improve_testing' into model_testing_configs
jlamypoirier Jun 13, 2025
8ecf81e
fix
jlamypoirier Jun 13, 2025
2c009a8
Merge branch 'model_testing_configs' into test_all_models
jlamypoirier Jun 13, 2025
c5b29e2
Revert "misc"
jlamypoirier Jun 13, 2025
4071b70
Merge branch 'improve_testing' into model_testing_configs
jlamypoirier Jun 13, 2025
bfa8d00
Merge branch 'model_testing_configs' into test_all_models
jlamypoirier Jun 13, 2025
edced8c
Cleanup tests
jlamypoirier Jun 13, 2025
9b904ad
Merge branch 'cleanup_tests' into test_all_models
jlamypoirier Jun 13, 2025
58677d2
fix
jlamypoirier Jun 13, 2025
8d48d1f
Merge branch 'improve_testing' into model_testing_configs
jlamypoirier Jun 13, 2025
c0b5e8e
Merge branch 'model_testing_configs' into cleanup_tests
jlamypoirier Jun 13, 2025
4171f27
Merge branch 'cleanup_tests' into test_all_models
jlamypoirier Jun 13, 2025
e125fa9
move to directory
jlamypoirier Jun 13, 2025
d61445a
Merge remote-tracking branch 'origin/main' into improve_testing
jlamypoirier Jun 16, 2025
9c5883e
Merge branch 'improve_testing' into model_testing_configs
jlamypoirier Jun 16, 2025
5a928f0
Merge branch 'model_testing_configs' into cleanup_tests
jlamypoirier Jun 16, 2025
7dc7f53
Merge branch 'cleanup_tests' into test_all_models
jlamypoirier Jun 16, 2025
d164f25
fixes
jlamypoirier Jun 16, 2025
0889d2f
Merge branch 'main' into improve_testing
jlamypoirier Jun 16, 2025
006e1ff
Merge branch 'improve_testing' into model_testing_configs
jlamypoirier Jun 16, 2025
8dc3abe
Merge remote-tracking branch 'origin/main' into model_testing_configs
jlamypoirier Jun 16, 2025
7a04c6a
Merge branch 'model_testing_configs' into cleanup_tests
jlamypoirier Jun 16, 2025
7eb4c5d
Merge remote-tracking branch 'origin/main' into cleanup_tests
jlamypoirier Jun 16, 2025
645eeb1
Merge branch 'cleanup_tests' into test_all_models
jlamypoirier Jun 16, 2025
9179127
fix
jlamypoirier Jun 16, 2025
d97e4c1
fix
jlamypoirier Jun 16, 2025
c95e8eb
Fix dropless mlp
jlamypoirier Jun 17, 2025
c4a34f0
Merge branch 'main' into cleanup_tests
jlamypoirier Jun 17, 2025
bdf37ca
Merge branch 'cleanup_tests' into test_all_models
jlamypoirier Jun 17, 2025
8667b9d
Merge branch 'test_all_models' into fix_dropless_mlp
jlamypoirier Jun 17, 2025
468ed7e
fix
jlamypoirier Jun 17, 2025
eb734bd
fix
jlamypoirier Jun 17, 2025
141ab00
Merge branch 'test_all_models' into fix_dropless_mlp
jlamypoirier Jun 17, 2025
4f74237
Merge branch 'main' into test_all_models
jlamypoirier Jun 19, 2025
58d4275
Merge branch 'test_all_models' into fix_dropless_mlp
jlamypoirier Jun 19, 2025
c338d44
fixes
jlamypoirier Jun 19, 2025
cc806ef
Merge branch 'main' into cleanup_tests
jlamypoirier Jun 19, 2025
9cba39b
Merge branch 'cleanup_tests' into test_all_models
jlamypoirier Jun 19, 2025
33d5595
Merge branch 'test_all_models' into fix_dropless_mlp
jlamypoirier Jun 19, 2025
ef7ab29
Merge remote-tracking branch 'origin/main' into fix_dropless_mlp
jlamypoirier Jun 19, 2025
2 changes: 1 addition & 1 deletion fast_llm/functional/config.py
@@ -15,7 +15,7 @@ class TritonConfig:
MAX_BLOCK_SIZE_BYTES = 65536


class MLPRecomputeLevel(str, enum.Enum):
class MLPRecomputeLevel(enum.StrEnum):
none = "none"
activation = "activation"
activation_and_input = "activation_and_input"
5 changes: 5 additions & 0 deletions fast_llm/functional/triton/sparse_copy.py
@@ -11,10 +11,15 @@
@dataclasses.dataclass()
class SparseMap:
sparse_rows: torch.Tensor
# The end row for each expert, including padding. `expert_ends[i] = expert_begins[i] + padded_tokens_per_expert[i]`
expert_ends: torch.Tensor
# The end row for each expert, excluding padding. `expert_pad_begins[i] = expert_begins[i] + unpadded_tokens_per_expert[i]`
expert_pad_begins: torch.Tensor
# The number of rows in the dense tensor, i.e., the number of tokens.
num_rows_dense: int
# The number of sparse rows, including padding. `num_rows = expert_ends[-1]`
num_rows: int
# The number of sparse rows, excluding padding. `num_rows_unpadded = num_rows_dense * num_experts_per_token`
num_rows_unpadded: int
num_experts: int
num_experts_per_token: int
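Note: as a hedged illustration of how these fields fit together, the following plain-PyTorch sketch mirrors the first case from the new tests/functional/test_sparse_matmul.py; the padded counts and the num_experts_per_token value are assumptions made for the example, not taken from the diff.

import torch

# Per-expert token counts; padded counts would typically be rounded up to a kernel block size.
unpadded_tokens_per_expert = torch.tensor([78, 256, 54])
padded_tokens_per_expert = torch.tensor([128, 256, 128])
num_experts_per_token = 4  # assumed for the example

expert_begins = torch.cat([torch.zeros(1, dtype=torch.int64), padded_tokens_per_expert.cumsum(0)[:-1]])
expert_ends = expert_begins + padded_tokens_per_expert          # tensor([128, 384, 512]), includes padding
expert_pad_begins = expert_begins + unpadded_tokens_per_expert  # tensor([ 78, 384, 438]), excludes padding
num_rows = int(expert_ends[-1])                                 # 512 sparse rows, including padding
num_rows_unpadded = int(unpadded_tokens_per_expert.sum())       # 388 = num_rows_dense * num_experts_per_token
num_rows_dense = num_rows_unpadded // num_experts_per_token     # 97 dense rows (tokens)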
27 changes: 16 additions & 11 deletions fast_llm/functional/triton/sparse_linear.py
@@ -1,10 +1,12 @@
import os

import torch

from fast_llm.functional.triton import TritonConfig, tl, tl_constexpr, triton, triton_autotune, triton_jit
from fast_llm.functional.triton.sparse_copy import SparseMap
from fast_llm.utils import Assert, div

autotune_configs = [
autotune_configs = (
TritonConfig(
{"block_size_row": 128, "block_size_col": 256, "block_size_inner": 64, "group_size_row": 8},
num_stages=3,
@@ -45,7 +47,10 @@
num_stages=5,
num_warps=2,
),
]
)

if os.environ.get("FAST_LLM_SKIP_TRITON_AUTOTUNE"):
autotune_configs = (autotune_configs[2],)


@triton_autotune(
@@ -255,13 +260,13 @@ def output_sparse_matmul_kernel(
def output_sparse_matmul(
lhs: torch.Tensor,
rhs: torch.Tensor,
sparse_map: SparseMap | None,
sparse_map: SparseMap | None = None,
out: torch.Tensor | None = None,
accumulate: bool = False,
) -> torch.Tensor:
"""
Output-sparse matrix multiplication with a sparse column dimension,
i.e., with a mapping row_index -> sparse_index (obtained from expert_ends).
Output-sparse matrix multiplication with a sparse column dimension
and a mapping row_index -> sparse_index (obtained from expert_ends).
Ex.: MLP layer 1 forward (Y = X x W1^T), MLP layer 2 input grad (gY = gZ x W2).
Formula: out[i, js] = sum_k(lhs[i, k] * rhs[k, jd]), where jd = js + col_sparse_dim * sparse_index[i]
sparse_index[i] = sum(expert_ends <= i)
@@ -381,13 +386,13 @@ def input_inner_sparse_matmul_kernel(
def input_inner_sparse_matmul(
lhs: torch.Tensor,
rhs: torch.Tensor,
sparse_map: SparseMap | None,
sparse_map: SparseMap | None = None,
out: torch.Tensor | None = None,
accumulate: bool = False,
) -> torch.Tensor:
"""
Left-input-sparse matrix multiplication with a sparse inner dimension,
i.e., with a mapping row_index -> sparse_index (obtained from expert_ends).
Left-input-sparse matrix multiplication with a sparse inner dimension
and a mapping row_index -> sparse_index (obtained from expert_ends).
Ex.: MLP layer 2 forward (Z = Y x W2^T), MLP layer 1 input grad (gX = gY x W1).
Formula: out[i, j] = sum_ks(lhs[i, ks] * rhs[kd, j]), where kd = ks + inner_sparse_dim * sparse_index[i]
sparse_index[i] = sum(expert_ends <= i)
@@ -511,13 +516,13 @@ def input_row_sparse_matmul_kernel(
def input_row_sparse_matmul(
lhs: torch.Tensor,
rhs: torch.Tensor,
sparse_map: SparseMap | None,
sparse_map: SparseMap | None = None,
out: torch.Tensor | None = None,
accumulate: bool = False,
) -> torch.Tensor:
"""
Left-input-sparse matrix multiplication with a sparse row dimension,
i.e., with a mapping inner_index -> sparse_index.
Left-input-sparse matrix multiplication with a sparse row dimension
and a mapping inner_index -> sparse_index.
Ex.: MLP layer 1 weight grad (gW1 = gY^T x X), MLP layer 2 weight grad (gW2^T = Y^T x gZ).
Formula: out[id, j] = sum_ks(lhs[is, ks] * rhs[ks, j]), where
sparse_begin[sparse_index[id]] <= ks < sparse_end[sparse_index[id]],
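Note: for readers of the docstrings above, here is a hedged dense-PyTorch restatement of the output_sparse_matmul formula. It is a naive per-row reference for illustration only; the Triton kernels in the diff are the real implementation, and the helper name below is hypothetical.

import torch

def output_sparse_matmul_reference(
    lhs: torch.Tensor, rhs: torch.Tensor, expert_ends: torch.Tensor, col_sparse_dim: int
) -> torch.Tensor:
    # out[i, js] = sum_k(lhs[i, k] * rhs[k, jd]), where jd = js + col_sparse_dim * sparse_index[i]
    # and sparse_index[i] = sum(expert_ends <= i), i.e. the expert that owns (padded) sparse row i.
    out = lhs.new_empty(lhs.shape[0], col_sparse_dim)
    for i in range(lhs.shape[0]):
        sparse_index = int((expert_ends <= i).sum())
        out[i] = lhs[i] @ rhs[:, sparse_index * col_sparse_dim : (sparse_index + 1) * col_sparse_dim]
    return out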
5 changes: 2 additions & 3 deletions fast_llm/models/ssm/model.py
@@ -3,14 +3,13 @@

from fast_llm.engine.base_model.base_model import Layer
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
from fast_llm.layers.language_model.head import LanguageModelHead
from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
from fast_llm.layers.ssm.llamba_block import LlambaBlock
from fast_llm.layers.ssm.mamba_layer import MambaLayer
from fast_llm.layers.transformer.transformer import TransformerLayer
from fast_llm.models.gpt.model import GPTBaseModel
from fast_llm.models.gpt.model import GPTBaseModel, GPTModel
from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType

logger = logging.getLogger(__name__)
@@ -135,7 +134,7 @@ def get_layers(self) -> list[Layer]:
return layers


class HybridSSMModel[ConfigType: HybridSSMModelConfig](FastLLMModel[ConfigType]):
class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]):
"""
A hybrid model that combines Transformer and SSM blocks.
"""
4 changes: 3 additions & 1 deletion setup.cfg
@@ -24,6 +24,8 @@ CORE =
safetensors>=0.5.3
# Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation
flash-attn==2.7.3
# Dropless MLP is broken with triton 3.2.0, 3.3.0 and 3.3.1. TODO: Remove once a working triton version is released.
triton==3.1.0


# Small packages required for some optional features and tools.
@@ -57,7 +59,7 @@ DEV =
pytest-xdist>=3.7.0
# Somehow needed for Megatron to work with base image 24.11
setuptools>=80.9.0
# dependency manager needs it.
# Dependency manager needs colorama to show colors.
colorama>=0.4.6

# Required for building the documentation
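Note: with these pins, a local editable install would look roughly like the following (illustrative command, assuming a CUDA environment with a compatible torch already present; the --no-build-isolation hint comes from the flash-attn comment above):

pip install --no-build-isolation -e ".[CORE,DEV]"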
3 changes: 3 additions & 0 deletions tests/conftest.py
@@ -113,6 +113,9 @@ def pytest_configure(config):
rendezvous_port=TORCHRUN_DEFAULT_PORT + 2 * worker_id + 1,
)

# Skip slow autotune for tests. The default config has the highest block size, so this shouldn't hide any bug.
os.environ["FAST_LLM_SKIP_TRITON_AUTOTUNE"] = "TRUE"


@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(config, items: list[pytest.Function]):
Empty file added tests/functional/__init__.py
Empty file.
@@ -224,8 +224,6 @@ def test_mlp_recomputation(gated, activation_type):
@pytest.mark.slow
@requires_cuda
def test_dropless_mlp():
# TODO: Fix dropless MOE
pytest.fail("Test fails, aborting to avoid breaking cuda", False)
num_experts = 4
experts_per_token = 4
tokens = 256
@@ -273,7 +271,7 @@ def test_dropless_mlp():
sparse_map = get_sparse_map(top_experts, num_experts)

for i, recompute_level in enumerate(MLPRecomputeLevel):
print(recompute_level.value) # noqa
print("recompute_level", recompute_level) # noqa
input_.grad = None
scores.grad = None
for param in params:
154 changes: 154 additions & 0 deletions tests/functional/test_sparse_matmul.py
@@ -0,0 +1,154 @@
import dataclasses
import functools

import pytest
import torch

from fast_llm.functional.triton.sparse_copy import SparseMap
from fast_llm.functional.triton.sparse_linear import (
dense_matmul,
input_inner_sparse_matmul,
input_row_sparse_matmul,
output_sparse_matmul,
)
from fast_llm.utils import Assert
from tests.utils.utils import requires_cuda


@dataclasses.dataclass
class _SparseTestData:
dense_dim: int
sparse_dim: int
expert_ends: tuple[int, ...]
tokens_per_expert: tuple[int, ...]
std: float = 0.125

@functools.cached_property
def expert_begins(self) -> tuple[int, ...]:
return (0,) + self.expert_ends[:-1]

@functools.cached_property
def expert_pad_begins(self) -> tuple[int, ...]:
return tuple(
expert_begin + expert_tokens
for expert_begin, expert_tokens in zip(self.expert_begins, self.tokens_per_expert, strict=True)
)

@functools.cached_property
def token_dim(self) -> int:
return self.expert_ends[-1]

@property
def sparse_dim_expanded(self) -> int:
return self.sparse_dim * self.num_experts

@functools.cached_property
def num_experts(self) -> int:
return len(self.expert_begins)

@functools.cached_property
def sparse_map(self) -> SparseMap:
return SparseMap(
num_experts=self.num_experts,
expert_ends=torch.tensor(self.expert_ends, device="cuda"),
expert_pad_begins=torch.tensor(self.expert_pad_begins, device="cuda"),
num_rows=self.expert_ends[-1],
# Not needed
sparse_rows=None,
num_rows_dense=None,
num_rows_unpadded=None,
num_experts_per_token=None,
)

def normal(self, dim_0: int, dim_1: int) -> torch.Tensor:
return torch.normal(0, self.std, (dim_0, dim_1), device="cuda")


_SPARSE_TEST_DATAS = (
_SparseTestData(
dense_dim=384,
sparse_dim=256,
expert_ends=(128, 384, 512),
tokens_per_expert=(78, 256, 54),
),
_SparseTestData(
dense_dim=256,
sparse_dim=512,
expert_ends=(128, 256, 256, 384),
tokens_per_expert=(52, 125, 0, 97),
),
)


@requires_cuda
@pytest.mark.slow
@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS)
def test_dense_matmul(sparse_test_data):
lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim)
rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim)

output = dense_matmul(lhs, rhs)
output_ref = torch.matmul(lhs, rhs)
Assert.rms_close(output, output_ref, 1e-3)


@requires_cuda
@pytest.mark.slow
@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS)
def test_output_sparse_matmul(sparse_test_data):
lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim)
rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim_expanded)

# Randomly initialize the output to ensure padded values have no effect.
out = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim)
output = output_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map, out)

output_ref = torch.zeros_like(output)
for i in range(sparse_test_data.num_experts):
# Padded tokens are treated like regular ones.
output_ref[sparse_test_data.expert_begins[i] : sparse_test_data.expert_ends[i]] = torch.matmul(
lhs[sparse_test_data.expert_begins[i] : sparse_test_data.expert_ends[i]],
rhs[:, i * sparse_test_data.sparse_dim : (i + 1) * sparse_test_data.sparse_dim],
)

Assert.rms_close(output, output_ref, 1e-3)


@requires_cuda
@pytest.mark.slow
@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS)
def test_input_inner_sparse_matmul(sparse_test_data):
lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim)
rhs = sparse_test_data.normal(sparse_test_data.sparse_dim_expanded, sparse_test_data.dense_dim)

output = input_inner_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map)

output_ref = torch.zeros_like(output)
for i in range(sparse_test_data.num_experts):
# Padded tokens are treated like regular ones.
output_ref[sparse_test_data.expert_begins[i] : sparse_test_data.expert_ends[i]] = torch.matmul(
lhs[sparse_test_data.expert_begins[i] : sparse_test_data.expert_ends[i]],
rhs[i * sparse_test_data.sparse_dim : (i + 1) * sparse_test_data.sparse_dim],
)

Assert.rms_close(output, output_ref, 1e-3)


@requires_cuda
@pytest.mark.slow
@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS)
def test_input_row_sparse_matmul(sparse_test_data):
lhs = sparse_test_data.normal(sparse_test_data.sparse_dim, sparse_test_data.token_dim)
rhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim)

output = input_row_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map)

output_ref = torch.zeros_like(output)
for i in range(sparse_test_data.num_experts):
# Padded tokens are excluded from the sum.
output_ref[i * sparse_test_data.sparse_dim : (i + 1) * sparse_test_data.sparse_dim] = torch.matmul(
lhs[:, sparse_test_data.expert_begins[i] : sparse_test_data.expert_pad_begins[i]],
rhs[sparse_test_data.expert_begins[i] : sparse_test_data.expert_pad_begins[i]],
)

Assert.rms_close(output, output_ref, 1e-3)
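Note: as a quick worked example of the fixture, the first _SparseTestData entry evaluates to expert_begins = (0, 128, 384), expert_pad_begins = (78, 384, 438), token_dim = 512, num_experts = 3 and sparse_dim_expanded = 768, which are the slice boundaries used by the reference loops in the tests above.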
13 changes: 9 additions & 4 deletions tests/models/test_checkpoint.py
@@ -30,7 +30,7 @@ def test_checkpoint_and_eval(run_test_script_for_all_models, model_testing_confi
+ [
"training.checkpoint.interval=1",
"training.evaluators.validation.interval=2",
"training.evaluators.validation.evaluators.iterations=1",
"training.evaluators.validation.evaluator.iterations=1",
],
)

@@ -63,7 +63,7 @@ def test_resume(run_test_script_for_all_models):
[
"training.checkpoint.interval=1",
"training.evaluators.validation.interval=2",
"training.evaluators.validation.evaluators.iterations=1",
"training.evaluators.validation.evaluator.iterations=1",
],
compare=f"test_checkpoint_and_eval",
prepare_fn=_prepare_resume_fn,
@@ -79,7 +79,7 @@ def test_resume_frozen(run_test_script_for_all_models):
[
"training.checkpoint.interval=1",
"training.evaluators.validation.interval=2",
"training.evaluators.validation.evaluators.iterations=1",
"training.evaluators.validation.evaluator.iterations=1",
"model.base_model.transformer.mlp_lr_scale=0.",
],
compare="test_checkpoint_and_eval",
@@ -442,7 +442,12 @@ def test_run_converted_model(model_testing_config, convert_paths):
)
errors = []
compare = CompareConfig()
model_as_hf = transformers.AutoModel.from_pretrained(
auto_model = (
transformers.AutoModel
if model_testing_config.name in ("diffusion_llama", "dream")
else transformers.AutoModelForCausalLM
)
model_as_hf = auto_model.from_pretrained(
convert_paths["huggingface_0"], trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code
).cuda()
for name, model in zip(