
Commit 7e0ea69

stas00, sfc-gh-sbekman, kashif, and SunMarc authored
HF Trainer: ALST/Ulysses sequence parallelism integration via HF Accelerate (#41832)
* HF Trainer: ALST/Ulysses sequence parallelism integration via HF Accelerate Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* make it work + tests Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* cleanup Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* undo Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* normalize Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* always return cp_size Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* cleanup Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* extract code into _deepspeed_cp_compute_loss Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* fix Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* ALST/Ulysses sequence parallelism docs
* typo
* add link to UlyssesSPDataLoaderAdapter
* adapt to renaming to SP Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* improve Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* fix Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* Update docs/source/en/deepspeed.md Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* address comments Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* address comments Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* Update src/transformers/trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* address comments Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* address comments Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* Update src/transformers/trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* Update src/transformers/trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* style Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
* Update docs/source/en/deepspeed.md Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* Update docs/source/en/deepspeed.md Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* Account for Sequence Parallelism (SP) dataloader adapter effect
* Update src/transformers/trainer.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* Update docs/source/en/deepspeed.md Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* Update docs/source/en/deepspeed.md Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* model_accepts_loss_kwargs to False
* better comment
* Apply suggestion from @kashif
* Apply suggestion from @kashif
* Apply suggestions from code review
* Apply suggestion from @kashif
* Apply suggestion from @kashif
* Apply suggestion from @kashif
* Update src/transformers/trainer.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* Update src/transformers/training_args.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* Apply suggestion from @kashif
* Apply suggestion from @kashif

---------

Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
Co-authored-by: Stas Bekman <stas.bekman@snowflake.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
1 parent afdc40d commit 7e0ea69

5 files changed: +453 −21 lines changed

docs/source/en/deepspeed.md

Lines changed: 102 additions & 0 deletions
@@ -368,6 +368,108 @@ The example ZeRO-3 and ZeRO-Infinity config below sets most of the parameter val
 }
 ```
 
+### Sequence Parallelism
+
+DeepSpeed's ALST/Ulysses sequence parallelism enables training with very long sequences by splitting the sequence across multiple GPUs. This is particularly useful for training large language models with very long sequence lengths.
+
+Arctic Long Sequence Training (ALST) uses a combination of sharding inputs along the sequence dimension and attention head parallelism. With this approach, you can train models with sequence lengths up to 500K tokens on a single H100 GPU, 3.7M on a single node, or 15M tokens on just four nodes with Llama-8B. The implementation described here enables one component of the full ALST system. For additional optimizations like TiledMLP and activation checkpoint offloading, refer to the [DeepSpeed ALST tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).
+
+> [!TIP]
+> For more detailed information about sequence parallelism, see the Accelerate [Sequence Parallelism](https://huggingface.co/docs/accelerate/concept_guides/sequence_parallelism) guide.
+
+To enable ALST/Ulysses sequence parallelism with [`Trainer`], configure `parallelism_config` in [`TrainingArguments`]. Sequence parallelism is configured via Accelerate's `ParallelismConfig` and requires Accelerate v1.12.0 or newer.
+
+```py
+from accelerate.utils import ParallelismConfig, DeepSpeedSequenceParallelConfig
+
+# Example: 4 GPUs with sp_size=4, dp_replicate_size=1 (no data parallelism)
+# Ensure total_size = dp_replicate_size * dp_shard_size * sp_size = 1 * 1 * 4 = 4 GPUs
+parallelism_config = ParallelismConfig(
+    sp_backend="deepspeed",
+    sp_size=4,  # Number of GPUs to split sequence across
+    dp_replicate_size=1,  # Explicit: no data parallelism
+    sp_handler=DeepSpeedSequenceParallelConfig(
+        sp_seq_length_is_variable=True,
+        sp_attn_implementation="sdpa",
+    ),
+)
+
+training_args = TrainingArguments(
+    ...,
+    deepspeed="path/to/deepspeed_config.json",
+    parallelism_config=parallelism_config,
+)
+```
+
+You can also configure sequence parallelism using an Accelerate config file.
+
+```yaml
+distributed_type: DEEPSPEED
+deepspeed_config:
+  deepspeed_config_file: path/to/ds_config.json
+machine_rank: 0
+num_machines: 1
+num_processes: 4  # Total number of processes
+parallelism_config:
+  parallelism_config_sp_size: 4  # Sequence parallel size
+  parallelism_config_dp_replicate_size: 1  # Must be: dp_replicate_size * dp_shard_size * sp_size = num_processes
+  parallelism_config_sp_backend: deepspeed
+  parallelism_config_sp_seq_length_is_variable: true
+  parallelism_config_sp_attn_implementation: sdpa
+```
+
+Important configuration parameters include the following.
+
+* `sp_backend` must be set to `"deepspeed"` to use ALST/Ulysses sequence parallelism.
+* `sp_size` is the degree of sequence parallelism. For example, `sp_size=4` means 4 GPUs will process a single sequence in parallel. You need at least 2 GPUs to enable sequence parallelism. **Data feeding**: each rank receives a unique data stream from the DataLoader (like DP). **Batch size calculation**: the effective `dp_world_size = world_size / sp_size`, so with 4 GPUs and `sp_size=4`, each of the 4 ranks gets different samples from the DataLoader, but `dp_world_size=1` for total batch size calculations.
+* `sp_seq_length_is_variable` determines how sequence lengths are handled. When set to `True` (recommended), the implementation adapts to varying sequence lengths between batches. When `False`, all sequences must be padded to a fixed length specified by `sp_seq_length`.
+* `sp_attn_implementation` specifies the attention implementation to use. Supported values are `"sdpa"`, `"flash_attention_2"`, or `"flash_attention_3"`. Flash Attention is recommended for best performance, especially with multiple samples in a batch, because SDPA may incorrectly attend across sample boundaries.
+
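As a rough, non-distributed sketch of the sequence splitting that `sp_size` controls (plain tensor slicing for illustration only, not the actual Ulysses adapter code):

```py
import torch

sp_size = 4
input_ids = torch.arange(16).unsqueeze(0)  # one sample with sequence length 16

# Conceptually, each of the sp_size ranks works on one contiguous slice of the
# sequence, which is why the sequence length must be divisible by sp_size.
shards = torch.chunk(input_ids, sp_size, dim=1)
print([tuple(s.shape) for s in shards])  # [(1, 4), (1, 4), (1, 4), (1, 4)]
```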
+> [!WARNING]
+> Sequence parallelism requires your model to use one of the supported attention implementations (`sdpa`, `flash_attention_2`, or `flash_attention_3`). The `eager` attention implementation is not supported because it doesn't properly handle `position_ids`.
+
+When using sequence parallelism, ensure your sequences are properly padded. Use `pad_to_multiple_of` in your data collator to ensure sequences are divisible by `sp_size`. For example, with `sp_size=4`, set `pad_to_multiple_of=4` or higher.
+
+```py
+from transformers import DataCollatorForLanguageModeling
+
+data_collator = DataCollatorForLanguageModeling(
+    tokenizer=tokenizer,
+    mlm=False,
+    pad_to_multiple_of=4,  # Ensure sequences are divisible by sp_size
+)
+```
+
+When using `sp_size` with multiple GPUs, you **must** explicitly set `dp_replicate_size` or `dp_shard_size` to ensure `total_size = dp_replicate_size * dp_shard_size * sp_size` equals your total number of GPUs. For example, with 8 GPUs and `sp_size=4`, you must set `dp_replicate_size=2` (since 2 × 1 × 4 = 8):
+
+```py
+parallelism_config = ParallelismConfig(
+    sp_backend="deepspeed",
+    sp_size=4,
+    dp_replicate_size=2,
+    sp_handler=DeepSpeedSequenceParallelConfig(
+        sp_seq_length_is_variable=True,
+        sp_attn_implementation="flash_attention_2",
+    ),
+)
+```
+
+[`Trainer`] automatically handles the special requirements for sequence parallelism including:
+
+* Adapting the data loader via DeepSpeed's [`UlyssesSPDataLoaderAdapter`](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/runtime/sequence_parallel/ulysses_sp.py) to shard sequences across GPUs. **Important**: Unlike Tensor Parallelism (where all ranks must receive identical data), each rank with SP receives a unique data stream from the DataLoader (similar to DP). The adapter handles distributing sequence chunks across SP ranks internally, so your DataLoader should continue feeding different samples to each rank.
+* Generating `position_ids` when not provided
+* Creating `shift_labels` for causal language modeling
+* Aggregating loss across sequence parallel ranks with proper masking for `-100` labels
+
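As a rough illustration of the `shift_labels` convention mentioned in the list above (a sketch of the general causal-LM pattern using a made-up label tensor; the actual work is done inside the adapter and [`Trainer`]):

```py
import torch

# Hypothetical labels for one sample; -100 marks positions ignored by the loss.
labels = torch.tensor([[15, 27, 42, -100]])

# Pad on the right with -100, then drop the first position so that
# shift_labels[t] is the target for the logits produced at position t.
shift_labels = torch.nn.functional.pad(labels, (0, 1), value=-100)[:, 1:].contiguous()
print(shift_labels)  # tensor([[  27,   42, -100, -100]])
```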
+You can launch training with sequence parallelism using the `accelerate launch` command.
+
+```bash
+accelerate launch --config_file alst_config.yaml your_training_script.py \
+    --output_dir output_dir \
+    --per_device_train_batch_size 1 \
+    --gradient_accumulation_steps 1
+```
+
 ## Training features
 
 DeepSpeed supports many training features that can be configured in the config file. This section describes some of the most important features.

src/transformers/testing_utils.py

Lines changed: 15 additions & 4 deletions
@@ -21,6 +21,7 @@
 import gc
 import importlib
 import inspect
+import json
 import logging
 import multiprocessing
 import os
@@ -2005,14 +2006,12 @@ def get_env(self):
         paths = [self.repo_root_dir_str, self.src_dir_str]
         if "/examples" in self.test_file_dir_str:
             paths.append(self.examples_dir_str)
-        else:
-            paths.append(self.tests_dir_str)
         paths.append(env.get("PYTHONPATH", ""))
 
         env["PYTHONPATH"] = ":".join(paths)
         return env
 
-    def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
+    def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None, return_pathlib_obj=False):
         """
         Args:
             tmp_dir (`string`, *optional*):
@@ -2032,6 +2031,8 @@ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
             after (`bool`, *optional*):
                 If `True`, delete the `tmp_dir` at the end of the test if `False`, leave the `tmp_dir` and its contents
                 intact at the end of the test.
+            return_pathlib_obj (`bool`, *optional*):
+                If `True`, will return a `pathlib.Path` object
 
         Returns:
             tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
@@ -2078,7 +2079,7 @@ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
         # register for deletion
         self.teardown_tmp_dirs.append(tmp_dir)
 
-        return tmp_dir
+        return Path(tmp_dir).resolve() if return_pathlib_obj else tmp_dir
 
     def python_one_liner_max_rss(self, one_liner_str):
         """
@@ -4076,3 +4077,13 @@ def use_one_line_repr(obj):
     cache[(id(obj), indent, mode, prefix)] = output
 
     return output
+
+
+def write_file(file, content):
+    with open(file, "w") as f:
+        f.write(content)
+
+
+def read_json_file(file):
+    with open(file, "r") as fh:
+        return json.load(fh)
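A small usage sketch of the testing helpers added above, assuming a test class that derives from `TestCasePlus` (the test name and JSON payload are made up for illustration):

```py
from transformers.testing_utils import TestCasePlus, read_json_file, write_file


class ExampleTest(TestCasePlus):
    def test_tmp_dir_roundtrip(self):
        # return_pathlib_obj=True yields a pathlib.Path instead of a str
        tmp_dir = self.get_auto_remove_tmp_dir(return_pathlib_obj=True)
        config_path = tmp_dir / "config.json"
        write_file(config_path, '{"sp_size": 4}')
        self.assertEqual(read_json_file(config_path)["sp_size"], 4)
```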

src/transformers/trainer.py

Lines changed: 121 additions & 16 deletions
@@ -603,6 +603,11 @@ def __init__(
             k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
         )
 
+        # Override for Sequence Parallelism: SP computes its own good_tokens count, so skip num_items_in_batch calculation
+        pc = getattr(self.accelerator, "parallelism_config", None)
+        if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled:
+            self.model_accepts_loss_kwargs = False
+
         self.neftune_noise_alpha = args.neftune_noise_alpha
 
         self.compute_metrics = compute_metrics
@@ -2163,6 +2168,22 @@ def train(
             ignore_keys_for_eval=ignore_keys_for_eval,
         )
 
+    def get_sp_size(self) -> int:
+        """Get the sequence parallel size"""
+        if getattr(self.accelerator, "parallelism_config", None) is None:
+            return 1
+        else:
+            pc = self.accelerator.parallelism_config
+            return pc.sp_size
+
+    def get_cp_size(self) -> int:
+        """Get the context parallel size"""
+        if getattr(self.accelerator, "parallelism_config", None) is None:
+            return 1
+        else:
+            pc = self.accelerator.parallelism_config
+            return pc.cp_size
+
     def get_tp_size(self) -> int:
         """Get the tensor parallel size from either the model or DeepSpeed config."""
 
@@ -2180,8 +2201,19 @@ def get_tp_size(self) -> int:
     def get_total_train_batch_size(self, args) -> int:
         """Calculates total batch size (micro_batch * grad_accum * dp_world_size).
 
-        Note: Only considers DP and TP (dp_world_size = world_size // tp_size)."""
-        dp_world_size = args.world_size // self.get_tp_size()
+        Accounts for all parallelism dimensions: TP, CP, and SP.
+
+        Formula: dp_world_size = world_size // (tp_size * cp_size * sp_size)
+
+        Where:
+        - TP (Tensor Parallelism): Model layers split across GPUs
+        - CP (Context Parallelism): Sequences split using Ring Attention (FSDP2)
+        - SP (Sequence Parallelism): Sequences split using ALST/Ulysses (DeepSpeed)
+
+        All dimensions are separate and multiplicative: world_size = dp_size * tp_size * cp_size * sp_size
+        """
+
+        dp_world_size = args.world_size // self.get_tp_size() // self.get_cp_size() // self.get_sp_size()
         return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size
 
     def _inner_training_loop(
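To make the updated docstring formula concrete, here is a small worked example with hypothetical sizes (chosen only so the dimensions multiply out; not taken from the commit):

```py
# world_size = dp_size * tp_size * cp_size * sp_size  ->  16 = 2 * 2 * 1 * 4
world_size = 16
tp_size, cp_size, sp_size = 2, 1, 4
micro_batch, grad_accum = 1, 8

dp_world_size = world_size // tp_size // cp_size // sp_size  # 16 // 2 // 1 // 4 = 2
total_train_batch_size = micro_batch * grad_accum * dp_world_size  # 1 * 8 * 2 = 16
```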
@@ -2305,6 +2337,11 @@ def _inner_training_loop(
             else:
                 self.optimizer = self.accelerator.prepare(self.optimizer)
 
+        # since DataLoader was Accelerate prepared w/o a model arg in the same call, we now have to complete the DL wrapping for ALST/UlyssesSP, after model has been prepared
+        pc = getattr(self.accelerator, "parallelism_config", None)
+        if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled:
+            train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model)
+
         if self.is_fsdp_enabled:
             self.model = self.model_wrapped = model
 
@@ -3639,23 +3676,30 @@ def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.
             getattr(self.accelerator, "parallelism_config", None) is not None
             and self.accelerator.parallelism_config.cp_enabled
         ):
-            if hasattr(model, "config"):
-                if model.config._attn_implementation != "sdpa":
-                    raise ValueError(
-                        f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
-                    )
+            if self.accelerator.parallelism_config.cp_backend == "torch":
+                if hasattr(model, "config"):
+                    if model.config._attn_implementation != "sdpa":
+                        raise ValueError(
+                            f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
+                        )
+
+                if "shift_labels" not in inputs:
+                    logger.warning_once("Shift labels not found in the inputs, shifting manually")
+                    if "labels" in inputs:
+                        _ignore_index = -100
+                        labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
+                        inputs["shift_labels"] = labels[:, 1:].contiguous()
+
+            # note: we don't do anything for accelerator.parallelism_config.sp_backend == "deepspeed" since:
+            # - accelerator.parallelism_config performs the `model.config._attn_implementation` checks already and it supports more than `sdpa`
+            # - UlyssesSPDataLoaderAdapter called from Accelerate performs the `shift_labels` creation - must not interfere
+            # - position_ids generation should be done by HF Trainer if it wasn't done by the user
 
             if "position_ids" not in inputs:
                 logger.warning_once("Position IDs not found in the inputs, generating manually")
                 inputs["position_ids"] = torch.arange(
                     inputs["input_ids"].size(1), device=inputs["input_ids"].device
                 ).expand(inputs["input_ids"].size(0), -1)
-            if "shift_labels" not in inputs:
-                logger.warning_once("Shift labels not found in the inputs, shifting manually")
-                if "labels" in inputs:
-                    _ignore_index = -100
-                    labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
-                    inputs["shift_labels"] = labels[:, 1:].contiguous()
 
         buffers = []
         buffer_seq_dims = []
@@ -3824,6 +3868,10 @@ def compute_loss(
         Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your loss,
         make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculating might be slightly inaccurate when performing gradient accumulation.
         """
+        pc = getattr(self.accelerator, "parallelism_config", None)
+        if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled:
+            return self._deepspeed_sp_compute_loss(model, inputs, return_outputs, pc)
+
         if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
             labels = inputs.pop("labels")
         else:
@@ -3877,6 +3925,55 @@ def compute_loss(
 
         return (loss, outputs) if return_outputs else loss
 
+    def _deepspeed_sp_compute_loss(self, model, inputs, return_outputs, pc):
+        """
+        How the loss is computed by Trainer under sequence parallelism with sp_backend=="deepspeed" and sp_size>1.
+        Performs weighted loss aggregation across SP ranks, accounting for varying numbers of valid tokens per rank
+        (e.g., when some ranks receive only padding or prompt tokens that are masked with -100).
+
+        Args:
+            model (`nn.Module`):
+                The model to compute the loss for.
+            inputs (`dict[str, Union[torch.Tensor, Any]]`):
+                The input data for the model. Must include "shift_labels" key.
+            return_outputs (`bool`, *optional*, defaults to `False`):
+                Whether to return the model outputs along with the loss.
+            pc (`accelerate.parallelism_config.ParallelismConfig`):
+                self.accelerator.parallelism_config object (not None)
+
+        Returns:
+            The loss of the model along with its output if return_outputs was set to True
+        """
+
+        unwrapped_model = self.accelerator.unwrap_model(model)
+
+        outputs = model(**inputs)
+        shift_labels = inputs["shift_labels"]
+        loss = unwrapped_model.loss_function(
+            logits=outputs.logits,
+            labels=None,
+            shift_labels=shift_labels,
+            vocab_size=unwrapped_model.config.vocab_size,
+        )
+
+        sp_group = self.accelerator.torch_device_mesh["sp"].get_group()
+        sp_world_size = pc.sp_size
+        # differentiable weighted per-shard-loss aggregation across ranks
+        losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
+        # special dealing with SFT that has prompt tokens that aren't used in loss computation
+        good_tokens = (shift_labels != -100).view(-1).sum()
+        good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
+        # Skip ranks with zero valid tokens
+        total_loss = sum(
+            losses_per_rank[rank] * good_tokens_per_rank[rank]
+            for rank in range(sp_world_size)
+            if good_tokens_per_rank[rank] > 0
+        )
+        total_good_tokens = sum(good_tokens_per_rank)
+        loss = total_loss / max(total_good_tokens, 1)
+
+        return (loss, outputs) if return_outputs else loss
+
     def is_local_process_zero(self) -> bool:
         """
         Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
@@ -3917,7 +4014,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
             Path(os.path.join(output_dir, "user_content.pt")).touch()
         # We are in N-D parallelism if we have parallelism_config set, so we check accelerate if we're on a to_save rank
         elif getattr(self.accelerator, "parallelism_config", None) is not None:
-            if self.accelerator.should_save_model:
+            # DeepSpeed SP already handles checkpoint saving below, so skip manual save in that case
+            pc = getattr(self.accelerator, "parallelism_config")
+            if self.accelerator.should_save_model and not (pc.sp_enabled and pc.sp_backend == "deepspeed"):
                 self._save(output_dir)
         # If we drop to here, we're in 1D parallelism, so all ranks need to go to `save_pretrained`
         elif (tp_size := getattr(self.model, "_tp_size", 0)) is not None and tp_size > 1:
@@ -4986,9 +5085,10 @@ def create_accelerator_and_postprocess(self):
 
         # We defer compatibility checks to accelerator
        if self.args.parallelism_config is not None:
-            if not is_accelerate_available("1.10.1"):
+            min_accelerate_version = "1.12.0"
+            if not is_accelerate_available(min_accelerate_version):
                 raise ImportError(
-                    "ParallelismConfig requires accelerate v1.10.1 and above. Please upgrade accelerate to use this feature."
+                    f"ParallelismConfig requires accelerate>={min_accelerate_version}. Please upgrade accelerate to use this feature."
                 )
             args["parallelism_config"] = self.args.parallelism_config
 
@@ -5182,6 +5282,11 @@ def set_initial_training_values(
         epoch_based = max_steps < 0
         len_dataloader = len(dataloader) if has_length(dataloader) else None
 
+        # Account for Sequence Parallelism (SP) dataloader adapter's effect
+        sp_size = self.get_sp_size()
+        if sp_size > 1 and len_dataloader is not None:
+            len_dataloader = len_dataloader * sp_size
+
         # Case 2: We have a dataloader length and can extrapolate
         if len_dataloader is not None:
             num_update_steps_per_epoch = max(
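A self-contained sketch of the weighted aggregation performed in `_deepspeed_sp_compute_loss` above, with made-up per-rank values and no distributed collectives (the real code all-gathers these tensors across the SP group):

```py
import torch

# Per-rank mean losses and valid-token counts; a rank that only saw prompt or
# padding tokens masked with -100 contributes zero valid tokens.
losses_per_rank = [torch.tensor(2.0), torch.tensor(1.5), torch.tensor(3.0), torch.tensor(0.0)]
good_tokens_per_rank = [torch.tensor(10), torch.tensor(5), torch.tensor(5), torch.tensor(0)]

# Weight each rank's loss by its valid-token count and skip empty ranks,
# so ranks without targets don't drag the average down.
total_loss = sum(l * n for l, n in zip(losses_per_rank, good_tokens_per_rank) if n > 0)
total_good_tokens = sum(good_tokens_per_rank)
loss = total_loss / max(total_good_tokens, 1)
print(loss)  # (2.0*10 + 1.5*5 + 3.0*5) / 20 = 2.125
```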
