Merged

55 commits
cfee9c9
HF Trainer: ALST/Ulysses sequence parallelism integration via HF Acce…
sfc-gh-sbekman Oct 23, 2025
6e28ca8
make it work + tests
sfc-gh-sbekman Oct 28, 2025
86a09b9
cleanup
sfc-gh-sbekman Oct 28, 2025
bb902f9
Merge branch 'main' into alst-integration
stas00 Oct 28, 2025
c0e8e0d
undo
sfc-gh-sbekman Oct 28, 2025
101eaff
normalize
sfc-gh-sbekman Nov 5, 2025
d8770d5
always return cp_size
sfc-gh-sbekman Nov 5, 2025
4f416a4
cleanup
sfc-gh-sbekman Nov 5, 2025
ce5e392
extract code into _deepspeed_cp_compute_loss
sfc-gh-sbekman Nov 5, 2025
3ceaa94
fix
sfc-gh-sbekman Nov 5, 2025
607e166
Merge branch 'main' into alst-integration
stas00 Nov 5, 2025
211b6df
ALST/Ulysses sequence parallelism docs
kashif Nov 9, 2025
34b208c
typo
kashif Nov 10, 2025
816cc96
add link to UlyssesSPDataLoaderAdapter
kashif Nov 10, 2025
b3cbfb1
Merge pull request #3 from kashif/alst-doc
stas00 Nov 10, 2025
674db46
Merge remote-tracking branch 'origin/main' into alst-integration
sfc-gh-sbekman Nov 17, 2025
b12249a
adapt to renaming to SP
sfc-gh-sbekman Nov 17, 2025
4be7619
improve
sfc-gh-sbekman Nov 17, 2025
21ec5e5
fix
sfc-gh-sbekman Nov 17, 2025
bc32a16
Update docs/source/en/deepspeed.md
stas00 Nov 17, 2025
a850a3a
Merge branch 'main' into alst-integration
stas00 Nov 18, 2025
0127933
address comments
sfc-gh-sbekman Nov 18, 2025
a50c89c
Merge branch 'alst-integration' of https://github.com/stas00/transfor…
sfc-gh-sbekman Nov 18, 2025
5e29dd9
address comments
sfc-gh-sbekman Nov 18, 2025
6ce745d
Update src/transformers/trainer.py
stas00 Nov 18, 2025
59972a3
address comments
sfc-gh-sbekman Nov 18, 2025
f554277
address comments
sfc-gh-sbekman Nov 18, 2025
0eef76f
Update src/transformers/trainer.py
stas00 Nov 18, 2025
c277586
Merge branch 'main' into alst-integration
stas00 Nov 18, 2025
8201150
Update src/transformers/trainer.py
stas00 Nov 18, 2025
854fd51
style
sfc-gh-sbekman Nov 18, 2025
76ee3ad
Update docs/source/en/deepspeed.md
stas00 Nov 19, 2025
083ca01
Update docs/source/en/deepspeed.md
stas00 Nov 19, 2025
6929fb2
Account for Sequence Parallelism (SP) dataloader adapter effect
kashif Nov 19, 2025
6ae2bb0
Update src/transformers/trainer.py
stas00 Nov 19, 2025
407f34a
Update docs/source/en/deepspeed.md
stas00 Nov 19, 2025
363909b
Update docs/source/en/deepspeed.md
stas00 Nov 19, 2025
8f62f14
Merge branch 'main' into alst-integration
stas00 Nov 19, 2025
6c5b00c
Merge pull request #4 from kashif/sp_len
stas00 Nov 19, 2025
49c5ed7
model_accepts_loss_kwargs to False
kashif Nov 19, 2025
4cafb9b
better comment
kashif Nov 19, 2025
7c1abd5
Merge pull request #5 from kashif/loss_kwargs
stas00 Nov 19, 2025
58e4e13
Apply suggestion from @kashif
kashif Nov 19, 2025
d8d53c2
Apply suggestion from @kashif
kashif Nov 19, 2025
a05eb52
Apply suggestions from code review
kashif Nov 19, 2025
ad61079
Merge branch 'main' into alst-integration
kashif Nov 20, 2025
3fd097d
Apply suggestion from @kashif
kashif Nov 20, 2025
e3d8eda
Apply suggestion from @kashif
kashif Nov 20, 2025
ef59f3e
Apply suggestion from @kashif
kashif Nov 20, 2025
59487a8
Update src/transformers/trainer.py
kashif Nov 20, 2025
2444728
Update src/transformers/training_args.py
kashif Nov 20, 2025
4f33c2f
Merge branch 'main' into alst-integration
kashif Nov 21, 2025
7d09b28
Apply suggestion from @kashif
kashif Nov 21, 2025
2e52913
Apply suggestion from @kashif
kashif Nov 21, 2025
7a5c45e
Merge branch 'main' into alst-integration
SunMarc Nov 21, 2025
102 changes: 102 additions & 0 deletions docs/source/en/deepspeed.md
@@ -368,6 +368,108 @@ The example ZeRO-3 and ZeRO-Infinity config below sets most of the parameter val
}
```

### Sequence Parallelism

DeepSpeed's ALST/Ulysses sequence parallelism enables training with very long sequences by splitting each sequence across multiple GPUs. This is particularly useful for training large language models on long-context data.

Arctic Long Sequence Training (ALST) uses a combination of sharding inputs along the sequence dimension and attention head parallelism. With this approach, you can train models with sequence lengths up to 500K tokens on a single H100 GPU, 3.7M on a single node, or 15M tokens on just four nodes with Llama-8B. The implementation described here enables one component of the full ALST system. For additional optimizations like TiledMLP and activation checkpoint offloading, refer to the [DeepSpeed ALST tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).

> [!TIP]
> For more detailed information about sequence parallelism, see the Accelerate [Sequence Parallelism](https://huggingface.co/docs/accelerate/concept_guides/sequence_parallelism) guide.

To enable ALST/Ulysses sequence parallelism with [`Trainer`], configure `parallelism_config` in [`TrainingArguments`]. Sequence parallelism is configured via Accelerate's `ParallelismConfig` and requires Accelerate v1.12.0 or newer.

```py
from accelerate.utils import ParallelismConfig, DeepSpeedSequenceParallelConfig
from transformers import TrainingArguments

# Example: 4 GPUs with sp_size=4, dp_replicate_size=1 (no data parallelism)
# Ensure total_size = dp_replicate_size * dp_shard_size * sp_size = 1 * 1 * 4 = 4 GPUs
parallelism_config = ParallelismConfig(
sp_backend="deepspeed",
sp_size=4, # Number of GPUs to split sequence across
dp_replicate_size=1, # Explicit: no data parallelism
sp_handler=DeepSpeedSequenceParallelConfig(
sp_seq_length_is_variable=True,
sp_attn_implementation="sdpa",
),
)

training_args = TrainingArguments(
...,
deepspeed="path/to/deepspeed_config.json",
parallelism_config=parallelism_config,
)
```

You can also configure sequence parallelism using an Accelerate config file.

```yaml
distributed_type: DEEPSPEED
deepspeed_config:
deepspeed_config_file: path/to/ds_config.json
machine_rank: 0
num_machines: 1
num_processes: 4 # Total number of processes
parallelism_config:
parallelism_config_sp_size: 4 # Sequence parallel size
parallelism_config_dp_replicate_size: 1 # Must be: dp_replicate_size * dp_shard_size * sp_size = num_processes
parallelism_config_sp_backend: deepspeed
parallelism_config_sp_seq_length_is_variable: true
parallelism_config_sp_attn_implementation: sdpa
```

Important configuration parameters include the following.

* `sp_backend` must be set to `"deepspeed"` to use ALST/Ulysses sequence parallelism.
* `sp_size` is the degree of sequence parallelism. For example, `sp_size=4` means 4 GPUs process a single sequence in parallel; you need at least 2 GPUs to enable sequence parallelism. **Data feeding**: each rank receives a unique data stream from the DataLoader (as in DP). **Batch size calculation**: the effective `dp_world_size = world_size / sp_size`, so with 4 GPUs and `sp_size=4`, each of the 4 ranks gets different samples from the DataLoader but `dp_world_size=1` for total batch size calculations (see the worked example below).
* `sp_seq_length_is_variable` determines how sequence lengths are handled. When set to `True` (recommended), the implementation adapts to varying sequence lengths between batches. When `False`, all sequences must be padded to a fixed length specified by `sp_seq_length`.
* `sp_attn_implementation` specifies the attention implementation to use. Supported values are `"sdpa"`, `"flash_attention_2"`, or `"flash_attention_3"`. Flash Attention is recommended for best performance, especially with multiple samples in a batch, because SDPA may incorrectly attend across sample boundaries.

> [!WARNING]
> Sequence parallelism requires your model to use one of the supported attention implementations (`sdpa`, `flash_attention_2`, or `flash_attention_3`). The `eager` attention implementation is not supported because it doesn't properly handle `position_ids`.
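
To make the batch size arithmetic concrete, here is a small worked example (the variable names are illustrative; [`Trainer`] performs the equivalent calculation internally):

```py
# Illustrative arithmetic only; Trainer computes this internally.
world_size = 8                      # total number of GPUs
sp_size = 4                         # sequence parallel degree
tp_size, cp_size = 1, 1             # no tensor or context parallelism here

dp_world_size = world_size // (tp_size * cp_size * sp_size)  # -> 2

per_device_train_batch_size = 1
gradient_accumulation_steps = 1
total_train_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * dp_world_size
)  # -> 2 samples per optimizer step
```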

When using sequence parallelism, ensure your sequences are properly padded. Use `pad_to_multiple_of` in your data collator to ensure sequences are divisible by `sp_size`. For example, with `sp_size=4`, set `pad_to_multiple_of=4` or higher.

```py
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False,
pad_to_multiple_of=4, # Ensure sequences are divisible by sp_size
)
```
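
As a standalone illustration of why this divisibility matters (independent of the [`Trainer`] API), a length-10 sequence padded to a multiple of `sp_size=4` splits evenly into 4 chunks, one per sequence parallel rank:

```py
import torch

sp_size = 4
input_ids = torch.arange(10)                  # original sequence of length 10
pad_len = (-input_ids.numel()) % sp_size      # 2 extra padding tokens needed
padded = torch.nn.functional.pad(input_ids, (0, pad_len), value=0)
chunks = padded.chunk(sp_size)                # each SP rank receives 3 tokens
print([chunk.tolist() for chunk in chunks])
```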

When using `sp_size` with multiple GPUs, you **must** explicitly set `dp_replicate_size` or `dp_shard_size` to ensure `total_size = dp_replicate_size * dp_shard_size * sp_size` equals your total number of GPUs. For example, with 8 GPUs and `sp_size=4`, you must set `dp_replicate_size=2` (since 2 × 1 × 4 = 8):

```py
parallelism_config = ParallelismConfig(
sp_backend="deepspeed",
sp_size=4,
dp_replicate_size=2,
sp_handler=DeepSpeedSequenceParallelConfig(
sp_seq_length_is_variable=True,
sp_attn_implementation="flash_attention_2",
),
)
```

[`Trainer`] automatically handles the special requirements for sequence parallelism, including:

* Adapting the data loader via DeepSpeed's [`UlyssesSPDataLoaderAdapter`](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/runtime/sequence_parallel/ulysses_sp.py) to shard sequences across GPUs. **Important**: Unlike Tensor Parallelism (where all ranks must receive identical data), each rank with SP receives a unique data stream from the DataLoader (similar to DP). The adapter handles distributing sequence chunks across SP ranks internally, so your DataLoader should continue feeding different samples to each rank.
* Generating `position_ids` when not provided (see the sketch after this list)
* Creating `shift_labels` for causal language modeling
* Aggregating loss across sequence parallel ranks with proper masking for `-100` labels
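
The sketch below shows roughly what the `position_ids` and `shift_labels` preparation amounts to; the tensors are made up for illustration, and the real logic lives in [`Trainer`] and DeepSpeed's `UlyssesSPDataLoaderAdapter`.

```py
import torch

input_ids = torch.tensor([[11, 12, 13, 14]])
labels = input_ids.clone()

# position_ids are generated when the data collator does not supply them
position_ids = torch.arange(input_ids.size(1), device=input_ids.device).expand(input_ids.size(0), -1)

# shift_labels for causal LM: next-token targets, with -100 appended so the
# final position is ignored by the loss
shift_labels = torch.nn.functional.pad(labels, (0, 1), value=-100)[:, 1:].contiguous()
```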

You can launch training with sequence parallelism using the `accelerate launch` command.

```bash
accelerate launch --config_file alst_config.yaml your_training_script.py \
--output_dir output_dir \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1
```

## Training features

DeepSpeed supports many training features that can be configured in the config file. This section describes some of the most important features.
19 changes: 15 additions & 4 deletions src/transformers/testing_utils.py
@@ -21,6 +21,7 @@
import gc
import importlib
import inspect
import json
import logging
import multiprocessing
import os
@@ -2005,14 +2006,12 @@ def get_env(self):
paths = [self.repo_root_dir_str, self.src_dir_str]
if "/examples" in self.test_file_dir_str:
paths.append(self.examples_dir_str)
else:
paths.append(self.tests_dir_str)
paths.append(env.get("PYTHONPATH", ""))

env["PYTHONPATH"] = ":".join(paths)
return env

def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None, return_pathlib_obj=False):
@stas00 stas00 Oct 28, 2025
this is a really old version. In the latest incarnation it always returns a Path object. But to keep BC, I added a new flag here instead. The tests are less clunky that way.

The latest version is here: https://github.com/stas00/ml-engineering/blob/master/testing/testing_utils.py

If you want, you could switch to the latest version instead and adapt the tests to simplify them.

stas00 (Contributor Author)
It's much better for it to always return a pathlib.Path object, but you'd need to tweak a few tests which use this API.
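
For context, a hypothetical test usage of the new flag might look like this (illustrative only, not part of this diff):

```py
# assumes a test method on a TestCasePlus subclass
tmp_dir = self.get_auto_remove_tmp_dir(return_pathlib_obj=True)
ds_config_path = tmp_dir / "ds_config.json"  # pathlib.Path makes joining paths cleaner
```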

"""
Args:
tmp_dir (`string`, *optional*):
@@ -2032,6 +2031,8 @@ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
after (`bool`, *optional*):
If `True`, delete the `tmp_dir` at the end of the test if `False`, leave the `tmp_dir` and its contents
intact at the end of the test.
return_pathlib_obj (`bool`, *optional*):
If `True`, returns a `pathlib.Path` object instead of a string

Returns:
tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
@@ -2078,7 +2079,7 @@
# register for deletion
self.teardown_tmp_dirs.append(tmp_dir)

return tmp_dir
return Path(tmp_dir).resolve() if return_pathlib_obj else tmp_dir

def python_one_liner_max_rss(self, one_liner_str):
"""
@@ -4076,3 +4077,13 @@ def use_one_line_repr(obj):
cache[(id(obj), indent, mode, prefix)] = output

return output


def write_file(file, content):
with open(file, "w") as f:
f.write(content)


def read_json_file(file):
with open(file, "r") as fh:
return json.load(fh)
137 changes: 121 additions & 16 deletions src/transformers/trainer.py
@@ -603,6 +603,11 @@ def __init__(
k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
)

# Override for Sequence Parallelism: SP computes its own good_tokens count, so skip num_items_in_batch calculation
pc = getattr(self.accelerator, "parallelism_config", None)
if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled:
self.model_accepts_loss_kwargs = False

self.neftune_noise_alpha = args.neftune_noise_alpha

self.compute_metrics = compute_metrics
@@ -2163,6 +2168,22 @@ def train(
ignore_keys_for_eval=ignore_keys_for_eval,
)

def get_sp_size(self) -> int:
"""Get the sequence parallel size"""
if getattr(self.accelerator, "parallelism_config", None) is None:
return 1
else:
pc = self.accelerator.parallelism_config
return pc.sp_size

def get_cp_size(self) -> int:
"""Get the context parallel size"""
if getattr(self.accelerator, "parallelism_config", None) is None:
return 1
else:
pc = self.accelerator.parallelism_config
return pc.cp_size

def get_tp_size(self) -> int:
"""Get the tensor parallel size from either the model or DeepSpeed config."""

@@ -2180,8 +2201,19 @@ def get_tp_size(self) -> int:
def get_total_train_batch_size(self, args) -> int:
"""Calculates total batch size (micro_batch * grad_accum * dp_world_size).

Note: Only considers DP and TP (dp_world_size = world_size // tp_size)."""
dp_world_size = args.world_size // self.get_tp_size()
Accounts for all parallelism dimensions: TP, CP, and SP.

Formula: dp_world_size = world_size // (tp_size * cp_size * sp_size)

Where:
- TP (Tensor Parallelism): Model layers split across GPUs
- CP (Context Parallelism): Sequences split using Ring Attention (FSDP2)
- SP (Sequence Parallelism): Sequences split using ALST/Ulysses (DeepSpeed)

All dimensions are separate and multiplicative: world_size = dp_size * tp_size * cp_size * sp_size
"""

dp_world_size = args.world_size // self.get_tp_size() // self.get_cp_size() // self.get_sp_size()
return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size

def _inner_training_loop(
Expand Down Expand Up @@ -2305,6 +2337,11 @@ def _inner_training_loop(
else:
self.optimizer = self.accelerator.prepare(self.optimizer)

# since the DataLoader was prepared by Accelerate without a model arg in the same call, we now have to complete the DataLoader wrapping for ALST/UlyssesSP after the model has been prepared
pc = getattr(self.accelerator, "parallelism_config", None)
if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled:
train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model)

if self.is_fsdp_enabled:
self.model = self.model_wrapped = model

@@ -3639,23 +3676,30 @@ def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.
getattr(self.accelerator, "parallelism_config", None) is not None
and self.accelerator.parallelism_config.cp_enabled
):
if hasattr(model, "config"):
if model.config._attn_implementation != "sdpa":
raise ValueError(
f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
)
if self.accelerator.parallelism_config.cp_backend == "torch":
if hasattr(model, "config"):
if model.config._attn_implementation != "sdpa":
raise ValueError(
f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
)

if "shift_labels" not in inputs:
logger.warning_once("Shift labels not found in the inputs, shifting manually")
if "labels" in inputs:
_ignore_index = -100
labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
inputs["shift_labels"] = labels[:, 1:].contiguous()

# note: we don't do anything for accelerator.parallelism_config.sp_backend == "deepspeed" since:
# - accelerator.parallelism_config performs the `model.config._attn_implementation` checks already and it supports more than `sdpa`
# - UlyssesSPDataLoaderAdapter called from Accelerate performs the `shift_label` creation - must not interfere
# - position_ids generation should be done by HF Trainer if it wasn't done by the user

if "position_ids" not in inputs:
logger.warning_once("Position IDs not found in the inputs, generating manually")
inputs["position_ids"] = torch.arange(
inputs["input_ids"].size(1), device=inputs["input_ids"].device
).expand(inputs["input_ids"].size(0), -1)
if "shift_labels" not in inputs:
logger.warning_once("Shift labels not found in the inputs, shifting manually")
if "labels" in inputs:
_ignore_index = -100
labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
inputs["shift_labels"] = labels[:, 1:].contiguous()

buffers = []
buffer_seq_dims = []
@@ -3824,6 +3868,10 @@ def compute_loss(
Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your loss,
make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculating might be slightly inaccurate when performing gradient accumulation.
"""
pc = getattr(self.accelerator, "parallelism_config", None)
if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled:
return self._deepspeed_sp_compute_loss(model, inputs, return_outputs, pc)

if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
labels = inputs.pop("labels")
else:
Expand Down Expand Up @@ -3877,6 +3925,55 @@ def compute_loss(

return (loss, outputs) if return_outputs else loss

def _deepspeed_sp_compute_loss(self, model, inputs, return_outputs, pc):
"""
How the loss is computed by Trainer under sequence parallelism with sp_backend=="deepspeed" and sp_size>1.
Performs weighted loss aggregation across SP ranks, accounting for varying numbers of valid tokens per rank
(e.g., when some ranks receive only padding or prompt tokens that are masked with -100).

Args:
model (`nn.Module`):
The model to compute the loss for.
inputs (`dict[str, Union[torch.Tensor, Any]]`):
The input data for the model. Must include "shift_labels" key.
return_outputs (`bool`, *optional*, defaults to `False`):
Whether to return the model outputs along with the loss.
pc (`accelerate.parallelism_config.ParallelismConfig`):
self.accelerator.parallelism_config object (not None)

Returns:
The loss of the model along with its output if return_outputs was set to True
"""

unwrapped_model = self.accelerator.unwrap_model(model)

outputs = model(**inputs)
shift_labels = inputs["shift_labels"]
loss = unwrapped_model.loss_function(
logits=outputs.logits,
labels=None,
shift_labels=shift_labels,
vocab_size=unwrapped_model.config.vocab_size,
)

sp_group = self.accelerator.torch_device_mesh["sp"].get_group()
sp_world_size = pc.sp_size
# differentiable weighted per-shard-loss aggregation across ranks
losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
# special dealing with SFT that has prompt tokens that aren't used in loss computation
good_tokens = (shift_labels != -100).view(-1).sum()
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
# Skip ranks with zero valid tokens
total_loss = sum(
losses_per_rank[rank] * good_tokens_per_rank[rank]
for rank in range(sp_world_size)
if good_tokens_per_rank[rank] > 0
)
total_good_tokens = sum(good_tokens_per_rank)
loss = total_loss / max(total_good_tokens, 1)

return (loss, outputs) if return_outputs else loss

def is_local_process_zero(self) -> bool:
"""
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
Expand Down Expand Up @@ -3917,7 +4014,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
Path(os.path.join(output_dir, "user_content.pt")).touch()
# We are in N-D parallelism if we have parallelism_config set, so we check accelerate if we're on a to_save rank
elif getattr(self.accelerator, "parallelism_config", None) is not None:
if self.accelerator.should_save_model:
# DeepSpeed SP already handles checkpoint saving below, so skip manual save in that case
pc = getattr(self.accelerator, "parallelism_config")
if self.accelerator.should_save_model and not (pc.sp_enabled and pc.sp_backend == "deepspeed"):
self._save(output_dir)
# If we drop to here, we're in 1D parallelism, so all ranks need to go to `save_pretrained`
elif (tp_size := getattr(self.model, "_tp_size", 0)) is not None and tp_size > 1:
@@ -4986,9 +5085,10 @@ def create_accelerator_and_postprocess(self):

# We defer compatibility checks to accelerator
if self.args.parallelism_config is not None:
if not is_accelerate_available("1.10.1"):
min_accelerate_version = "1.12.0"
if not is_accelerate_available(min_accelerate_version):
raise ImportError(
"ParallelismConfig requires accelerate v1.10.1 and above. Please upgrade accelerate to use this feature."
f"ParallelismConfig requires accelerate>={min_accelerate_version}). Please upgrade accelerate to use this feature."
)
args["parallelism_config"] = self.args.parallelism_config

@@ -5182,6 +5282,11 @@ def set_initial_training_values(
epoch_based = max_steps < 0
len_dataloader = len(dataloader) if has_length(dataloader) else None

# Account for Sequence Parallelism (SP) dataloader adapter's effect
sp_size = self.get_sp_size()
if sp_size > 1 and len_dataloader is not None:
len_dataloader = len_dataloader * sp_size

# Case 2: We have a dataloader length and can extrapolate
if len_dataloader is not None:
num_update_steps_per_epoch = max(