Merged
8 changes: 6 additions & 2 deletions .github/workflows/docs.yaml
@@ -31,7 +31,9 @@ jobs:
       - run: |
           pip install "torch>=2.2.2"
           pip install pybind11
-          FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
+          FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
+          MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
+          pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
       - name: Build the documentation
         run: mkdocs build

@@ -56,6 +58,8 @@ jobs:
       - run: |
           pip install "torch>=2.2.2"
           pip install pybind11
-          FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
+          FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
+          MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
+          pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
       - name: Publish the documentation
         run: mkdocs gh-deploy --force --dirty
18 changes: 15 additions & 3 deletions fast_llm/engine/checkpoint/safe_load.py
@@ -145,9 +145,21 @@ def _check_parameters(self, errors: list[str]) -> None:
             elif counter is not None and counter > 0:
                 errors.append(f'Loaded off-device parameter : "{parameter_name}" for shard "{shard_name}"')
         if self._distributed.world_group is not None:
-            counter_tensor = torch.tensor(
-                [counter or 0 for counter in counter_per_parameter.values()], dtype=torch.int64
-            ).to(self._distributed.device)
+            counter_list = []
+            for parameter_name, counter in counter_per_parameter.items():
+                parameter_stage = self._model.get_parameter_stage(parameter_name)
+                parameter_meta = parameter_stage.get_parameter_meta(parameter_name)
+                if (
+                    counter is None
+                    or (not parameter_meta.is_tensor_parallel and self._distributed.config.tensor_rank != 0)
+                    or parameter_stage.is_tied_weight_copy
+                ):
+                    # Ignore the counter from missing or duplicate tensors.
+                    counter = 0
+                counter_list.append(counter)
+
+            counter_tensor = torch.tensor(counter_list, dtype=torch.int64).to(self._distributed.device)
+
             add_ephemeral_timeout(self._distributed.world_group, self._timeout)
             all_reduce(counter_tensor, group=self._distributed.world_group)
             counter_per_parameter = {
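Reviewer note: below is a minimal, self-contained sketch of the new counter check, with all_reduce replaced by a plain sum over mocked ranks. The parameter names, rank layout, and expected totals are hypothetical; only the filtering rules (skip missing counters, skip non-tensor-parallel parameters on tensor_rank != 0, skip tied-weight copies) come from the diff above.

import torch

# Two tensor-parallel ranks; each reports how many elements it loaded per parameter.
# Entry format: (tensor_rank, {parameter_name: (counter, is_tensor_parallel, is_tied_weight_copy)})
ranks = [
    (0, {"embed.weight": (100, True, False), "norm.weight": (10, False, False)}),
    (1, {"embed.weight": (100, True, False), "norm.weight": (10, False, False)}),
]

counter_tensors = []
for tensor_rank, counters in ranks:
    counter_list = []
    for name, (counter, is_tensor_parallel, is_tied_weight_copy) in counters.items():
        if counter is None or (not is_tensor_parallel and tensor_rank != 0) or is_tied_weight_copy:
            # Ignore the counter from missing or duplicate tensors.
            counter = 0
        counter_list.append(counter)
    counter_tensors.append(torch.tensor(counter_list, dtype=torch.int64))

# Stand-in for all_reduce(counter_tensor, group=world_group): sum over ranks.
reduced = torch.stack(counter_tensors).sum(dim=0)
print(reduced.tolist())  # [200, 10]: "norm.weight" is only counted once, on tensor_rank 0.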
5 changes: 2 additions & 3 deletions fast_llm/engine/distributed/config.py
@@ -284,9 +284,8 @@ def _validate(self) -> None:
         self.batch_data_rank = self.data_rank // self.sequence_data_parallel

         self.tensor_rank = self.rank % self.tensor_parallel
-        if self.tensor_parallel == 1:
-            with self._set_implicit_default():
-                self.sequence_tensor_parallel = False
+        if self.tensor_parallel == 1 and self.sequence_tensor_parallel:
+            self.sequence_tensor_parallel = False

         if self.reference_config is not None:
             self.reference_config.validate()
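Reviewer note: a tiny sketch of the new validation rule, using a bare dataclass instead of Fast-LLM's config class (so _set_implicit_default and the rest of _validate are out of scope). The point of the change is that sequence_tensor_parallel is now only rewritten when it was actually enabled with tensor_parallel == 1, instead of being reassigned as an implicit default on every validation.

import dataclasses

@dataclasses.dataclass
class MiniDistributedConfig:
    tensor_parallel: int = 1
    sequence_tensor_parallel: bool = False

    def validate(self) -> "MiniDistributedConfig":
        # Sequence-tensor-parallelism is meaningless without tensor parallelism,
        # so silently disable it rather than fail.
        if self.tensor_parallel == 1 and self.sequence_tensor_parallel:
            self.sequence_tensor_parallel = False
        return self

print(MiniDistributedConfig(tensor_parallel=1, sequence_tensor_parallel=True).validate())
# MiniDistributedConfig(tensor_parallel=1, sequence_tensor_parallel=False)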
5 changes: 2 additions & 3 deletions fast_llm/engine/multi_stage/fast_llm_model.py
@@ -89,9 +89,8 @@ def initialize_weights(self, timeout: float | None = None) -> None:
             stage.initialize_weights()
         for name, tied_parameter in self._tied_parameters.items():
             if tied_parameter.group is not None:
-                broadcast(
-                    self._stages[tied_parameter.main_stage].weight_shard, 0, tied_parameter.group, timeout=timeout
-                )
+                for fsdp in self._stages[tied_parameter.main_stage].fsdps:
+                    broadcast(fsdp.weight_shard, 0, tied_parameter.group, timeout=timeout)
         self._finalize_load(reset_optimizer=True)

     def _finalize_load(self, reset_optimizer: bool = True) -> None:
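Reviewer note: a sketch of the per-FSDP tied-weight broadcast, run in a single-process gloo group so it stays self-contained. The Mock* classes, the TCP address, and the group setup are hypothetical stand-ins; the loop mirrors the diff above, where a stage now exposes several FSDPs and each one's weight_shard is broadcast from rank 0 of the tied-parameter group.

import torch
import torch.distributed as dist

# Single-process group, just to make the broadcast calls executable.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29510", rank=0, world_size=1)

class MockFsdp:
    def __init__(self, numel: int) -> None:
        self.weight_shard = torch.randn(numel)

class MockStage:
    def __init__(self) -> None:
        # A stage may now hold several FSDPs, so every FSDP's shard has to be kept in sync,
        # not just a single stage-level weight_shard.
        self.fsdps = [MockFsdp(8), MockFsdp(4)]

main_stage = MockStage()
group = dist.group.WORLD  # stand-in for tied_parameter.group

for fsdp in main_stage.fsdps:
    # Rank 0 of the tied-parameter group is the source of truth for the tied weight.
    dist.broadcast(fsdp.weight_shard, src=0, group=group)

dist.destroy_process_group()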
2 changes: 1 addition & 1 deletion fast_llm/engine/multi_stage/fsdp.py
@@ -11,7 +11,7 @@
 from fast_llm.engine.distributed.config import DistributedDim
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, ShardName, StageMode
-from fast_llm.functional.triton.pointwise import triton_add, triton_copy
+from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill
 from fast_llm.logging import log_distributed_tensor
 from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta
 from fast_llm.utils import Assert, clamp, padded_cumsum
2 changes: 0 additions & 2 deletions fast_llm/engine/multi_stage/stage_base.py
@@ -316,15 +316,13 @@ def import_state_tensor(
         """
         Given a global parameter tensor, set the associated slice of a local parameter shard.
         Return the size of the local slice.
-        TODO: Doesn't work
         """
         fsdp_index = self._fsdp_index[parameter_name]
         return self._fsdps[fsdp_index].import_state_tensor(parameter_name, shards[fsdp_index], tensor)

     def _export_shard(
         self, shards: tuple[torch.Tensor], data_type: DataType | None = None
     ) -> typing.Generator[tuple[str, torch.Tensor], None, None]:
-        # TODO: Doesn't work
         for fsdp, shard in zip(self._fsdps, shards, strict=True):
             yield from fsdp.export_shard(shard, self._distributed, data_type)

1 change: 1 addition & 0 deletions tests/test_ssms.py
@@ -139,6 +139,7 @@ def test_load_from_llamba_checkpoint(distributed_config):
     assert torch.allclose(logits, hf_logits, atol=1e-2)


+# TODO: Speed up this test or bring it back as an integration test.
 @pytest.mark.skip(reason="Too slow.")
 @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed")
 @pytest.mark.parametrize(