diff --git a/docs/source/en/deepspeed.md b/docs/source/en/deepspeed.md
index 7971854011ee..bee280fdc328 100644
--- a/docs/source/en/deepspeed.md
+++ b/docs/source/en/deepspeed.md
@@ -368,6 +368,108 @@ The example ZeRO-3 and ZeRO-Infinity config below sets most of the parameter val
 }
 ```
+### Sequence Parallelism
+
+DeepSpeed's ALST/Ulysses sequence parallelism enables training with very long sequences by splitting each sequence across multiple GPUs. This is particularly useful for training large language models on very long contexts.
+
+Arctic Long Sequence Training (ALST) combines sharding inputs along the sequence dimension with attention head parallelism. With this approach, you can train Llama-8B with sequence lengths of up to 500K tokens on a single H100 GPU, 3.7M tokens on a single node, or 15M tokens on just four nodes. The integration described here enables one component of the full ALST system. For additional optimizations like TiledMLP and activation checkpoint offloading, refer to the [DeepSpeed ALST tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).
+
+> [!TIP]
+> For more detailed information about sequence parallelism, see the Accelerate [Sequence Parallelism](https://huggingface.co/docs/accelerate/concept_guides/sequence_parallelism) guide.
+
+To enable ALST/Ulysses sequence parallelism with [`Trainer`], configure `parallelism_config` in [`TrainingArguments`]. Sequence parallelism is configured through Accelerate's `ParallelismConfig` and requires Accelerate v1.12.0 or newer.
+
+```py
+from accelerate.utils import ParallelismConfig, DeepSpeedSequenceParallelConfig
+
+# Example: 4 GPUs with sp_size=4, dp_replicate_size=1 (no data parallelism)
+# Ensure total_size = dp_replicate_size * dp_shard_size * sp_size = 1 * 1 * 4 = 4 GPUs
+parallelism_config = ParallelismConfig(
+    sp_backend="deepspeed",
+    sp_size=4,  # Number of GPUs to split the sequence across
+    dp_replicate_size=1,  # Explicit: no data parallelism
+    sp_handler=DeepSpeedSequenceParallelConfig(
+        sp_seq_length_is_variable=True,
+        sp_attn_implementation="sdpa",
+    ),
+)
+
+training_args = TrainingArguments(
+    ...,
+    deepspeed="path/to/deepspeed_config.json",
+    parallelism_config=parallelism_config,
+)
+```
+
+You can also configure sequence parallelism using an Accelerate config file.
+
+```yaml
+distributed_type: DEEPSPEED
+deepspeed_config:
+  deepspeed_config_file: path/to/ds_config.json
+machine_rank: 0
+num_machines: 1
+num_processes: 4  # Total number of processes
+parallelism_config:
+  parallelism_config_sp_size: 4  # Sequence parallel size
+  parallelism_config_dp_replicate_size: 1  # Must satisfy: dp_replicate_size * dp_shard_size * sp_size = num_processes
+  parallelism_config_sp_backend: deepspeed
+  parallelism_config_sp_seq_length_is_variable: true
+  parallelism_config_sp_attn_implementation: sdpa
+```
+
+Important configuration parameters include the following.
+
+* `sp_backend` must be set to `"deepspeed"` to use ALST/Ulysses sequence parallelism.
+* `sp_size` is the degree of sequence parallelism. For example, `sp_size=4` means 4 GPUs process a single sequence in parallel; at least 2 GPUs are required to enable sequence parallelism. Each rank still receives its own stream of samples from the DataLoader (as in DP), but for batch size calculations the effective `dp_world_size = world_size / sp_size`.
So with 4 GPUs and `sp_size=4`, each of the 4 ranks gets different samples from the DataLoader, yet `dp_world_size=1` is used when computing the total batch size.
+* `sp_seq_length_is_variable` determines how sequence lengths are handled. When set to `True` (recommended), the implementation adapts to varying sequence lengths between batches. When `False`, all sequences must be padded to a fixed length specified by `sp_seq_length`.
+* `sp_attn_implementation` specifies the attention implementation to use. Supported values are `"sdpa"`, `"flash_attention_2"`, or `"flash_attention_3"`. Flash Attention is recommended for best performance, especially with multiple samples in a batch, because SDPA may incorrectly attend across sample boundaries.
+
+> [!WARNING]
+> Sequence parallelism requires your model to use one of the supported attention implementations (`sdpa`, `flash_attention_2`, or `flash_attention_3`). The `eager` attention implementation is not supported because it doesn't properly handle `position_ids`.
+
+When using sequence parallelism, make sure your sequences are properly padded. Use `pad_to_multiple_of` in your data collator so that sequence lengths are divisible by `sp_size`. For example, with `sp_size=4`, set `pad_to_multiple_of` to 4 or another multiple of 4.
+
+```py
+from transformers import DataCollatorForLanguageModeling
+
+data_collator = DataCollatorForLanguageModeling(
+    tokenizer=tokenizer,
+    mlm=False,
+    pad_to_multiple_of=4,  # Ensure sequences are divisible by sp_size
+)
+```
+
+When using `sp_size` with multiple GPUs, you **must** explicitly set `dp_replicate_size` or `dp_shard_size` so that `total_size = dp_replicate_size * dp_shard_size * sp_size` equals your total number of GPUs. For example, with 8 GPUs and `sp_size=4`, set `dp_replicate_size=2` (since 2 × 1 × 4 = 8):
+
+```py
+parallelism_config = ParallelismConfig(
+    sp_backend="deepspeed",
+    sp_size=4,
+    dp_replicate_size=2,
+    sp_handler=DeepSpeedSequenceParallelConfig(
+        sp_seq_length_is_variable=True,
+        sp_attn_implementation="flash_attention_2",
+    ),
+)
+```
+
+[`Trainer`] automatically handles the special requirements of sequence parallelism, including:
+
+* Adapting the data loader via DeepSpeed's [`UlyssesSPDataLoaderAdapter`](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/runtime/sequence_parallel/ulysses_sp.py) to shard sequences across GPUs. **Important**: unlike tensor parallelism (where all ranks must receive identical data), each SP rank receives a unique data stream from the DataLoader (similar to DP). The adapter distributes sequence chunks across SP ranks internally, so your DataLoader should continue feeding different samples to each rank.
+* Generating `position_ids` when they are not provided
+* Creating `shift_labels` for causal language modeling
+* Aggregating the loss across sequence parallel ranks with proper masking for `-100` labels
+
+You can launch training with sequence parallelism using the `accelerate launch` command.
+
+```bash
+accelerate launch --config_file alst_config.yaml your_training_script.py \
+--output_dir output_dir \
+--per_device_train_batch_size 1 \
+--gradient_accumulation_steps 1
+```
+
 ## Training features
 
 DeepSpeed supports many training features that can be configured in the config file. This section describes some of the most important features.
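Before turning to the code changes, a small self-contained sketch may help make the batch-size bookkeeping from the new documentation concrete. It mirrors the arithmetic the patch implements in `Trainer.get_total_train_batch_size` (the effective data-parallel world size is the world size divided by the product of the TP, CP, and SP degrees); the helper names below are illustrative and not part of the patch.

```py
def effective_dp_world_size(world_size: int, tp_size: int = 1, cp_size: int = 1, sp_size: int = 1) -> int:
    """All parallelism dimensions are multiplicative:
    world_size = dp_size * tp_size * cp_size * sp_size."""
    denom = tp_size * cp_size * sp_size
    assert world_size % denom == 0, "world_size must be divisible by tp_size * cp_size * sp_size"
    return world_size // denom


def total_train_batch_size(micro_batch: int, grad_accum: int, world_size: int, sp_size: int = 1) -> int:
    """Total batch size = micro_batch * grad_accum * dp_world_size."""
    return micro_batch * grad_accum * effective_dp_world_size(world_size, sp_size=sp_size)


# 4 GPUs with sp_size=4: all four ranks jointly process one sequence, so dp_world_size=1
print(total_train_batch_size(micro_batch=1, grad_accum=1, world_size=4, sp_size=4))  # -> 1
# 8 GPUs with sp_size=4 (dp_replicate_size=2): two SP groups, so dp_world_size=2
print(total_train_batch_size(micro_batch=1, grad_accum=1, world_size=8, sp_size=4))  # -> 2
```

With `sp_size` equal to the number of GPUs, the whole cluster works on one sequence at a time, which is why the global batch size stays at `micro_batch * grad_accum`.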
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 81f4de255f9c..9e5263eed1d2 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -21,6 +21,7 @@ import gc import importlib import inspect +import json import logging import multiprocessing import os @@ -2005,14 +2006,12 @@ def get_env(self): paths = [self.repo_root_dir_str, self.src_dir_str] if "/examples" in self.test_file_dir_str: paths.append(self.examples_dir_str) - else: - paths.append(self.tests_dir_str) paths.append(env.get("PYTHONPATH", "")) env["PYTHONPATH"] = ":".join(paths) return env - def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None): + def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None, return_pathlib_obj=False): """ Args: tmp_dir (`string`, *optional*): @@ -2032,6 +2031,8 @@ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None): after (`bool`, *optional*): If `True`, delete the `tmp_dir` at the end of the test if `False`, leave the `tmp_dir` and its contents intact at the end of the test. + return_pathlib_obj (`bool`, *optional*): + If `True` will return a pathlib.Path object Returns: tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir @@ -2078,7 +2079,7 @@ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None): # register for deletion self.teardown_tmp_dirs.append(tmp_dir) - return tmp_dir + return Path(tmp_dir).resolve() if return_pathlib_obj else tmp_dir def python_one_liner_max_rss(self, one_liner_str): """ @@ -4076,3 +4077,13 @@ def use_one_line_repr(obj): cache[(id(obj), indent, mode, prefix)] = output return output + + +def write_file(file, content): + with open(file, "w") as f: + f.write(content) + + +def read_json_file(file): + with open(file, "r") as fh: + return json.load(fh) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index bccb5a881732..c97c2bb7846b 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -603,6 +603,11 @@ def __init__( k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values() ) + # Override for Sequence Parallelism: SP computes its own good_tokens count, so skip num_items_in_batch calculation + pc = getattr(self.accelerator, "parallelism_config", None) + if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled: + self.model_accepts_loss_kwargs = False + self.neftune_noise_alpha = args.neftune_noise_alpha self.compute_metrics = compute_metrics @@ -2163,6 +2168,22 @@ def train( ignore_keys_for_eval=ignore_keys_for_eval, ) + def get_sp_size(self) -> int: + """Get the sequence parallel size""" + if getattr(self.accelerator, "parallelism_config", None) is None: + return 1 + else: + pc = self.accelerator.parallelism_config + return pc.sp_size + + def get_cp_size(self) -> int: + """Get the context parallel size""" + if getattr(self.accelerator, "parallelism_config", None) is None: + return 1 + else: + pc = self.accelerator.parallelism_config + return pc.cp_size + def get_tp_size(self) -> int: """Get the tensor parallel size from either the model or DeepSpeed config.""" @@ -2180,8 +2201,19 @@ def get_tp_size(self) -> int: def get_total_train_batch_size(self, args) -> int: """Calculates total batch size (micro_batch * grad_accum * dp_world_size). 
- Note: Only considers DP and TP (dp_world_size = world_size // tp_size).""" - dp_world_size = args.world_size // self.get_tp_size() + Accounts for all parallelism dimensions: TP, CP, and SP. + + Formula: dp_world_size = world_size // (tp_size * cp_size * sp_size) + + Where: + - TP (Tensor Parallelism): Model layers split across GPUs + - CP (Context Parallelism): Sequences split using Ring Attention (FSDP2) + - SP (Sequence Parallelism): Sequences split using ALST/Ulysses (DeepSpeed) + + All dimensions are separate and multiplicative: world_size = dp_size * tp_size * cp_size * sp_size + """ + + dp_world_size = args.world_size // self.get_tp_size() // self.get_cp_size() // self.get_sp_size() return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size def _inner_training_loop( @@ -2305,6 +2337,11 @@ def _inner_training_loop( else: self.optimizer = self.accelerator.prepare(self.optimizer) + # since DataLoader was Accelerate prepared w/o a model arg in the same call, we now have to complete the DL wrapping for ALST/UlyssesSP, after model has been prepared + pc = getattr(self.accelerator, "parallelism_config", None) + if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled: + train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model) + if self.is_fsdp_enabled: self.model = self.model_wrapped = model @@ -3639,23 +3676,30 @@ def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch. getattr(self.accelerator, "parallelism_config", None) is not None and self.accelerator.parallelism_config.cp_enabled ): - if hasattr(model, "config"): - if model.config._attn_implementation != "sdpa": - raise ValueError( - f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}." - ) + if self.accelerator.parallelism_config.cp_backend == "torch": + if hasattr(model, "config"): + if model.config._attn_implementation != "sdpa": + raise ValueError( + f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}." + ) + + if "shift_labels" not in inputs: + logger.warning_once("Shift labels not found in the inputs, shifting manually") + if "labels" in inputs: + _ignore_index = -100 + labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index) + inputs["shift_labels"] = labels[:, 1:].contiguous() + + # note: we don't do anything for accelerator.parallelism_config.sp_backend == "deepspeed" since: + # - accelerator.parallelism_config performs the `model.config._attn_implementation` checks already and it supports more than `dspa` + # - UlyssesSPDataLoaderAdapter called from Accelerate performs the `shift_label` creation - must not interfere + # - position_ids generation should be done by HF Trainer if it wasn't done by the user if "position_ids" not in inputs: logger.warning_once("Position IDs not found in the inputs, generating manually") inputs["position_ids"] = torch.arange( inputs["input_ids"].size(1), device=inputs["input_ids"].device ).expand(inputs["input_ids"].size(0), -1) - if "shift_labels" not in inputs: - logger.warning_once("Shift labels not found in the inputs, shifting manually") - if "labels" in inputs: - _ignore_index = -100 - labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index) - inputs["shift_labels"] = labels[:, 1:].contiguous() buffers = [] buffer_seq_dims = [] @@ -3824,6 +3868,10 @@ def compute_loss( Subclass and override for custom behavior. 
If you are not using `num_items_in_batch` when computing your loss, make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculating might be slightly inaccurate when performing gradient accumulation. """ + pc = getattr(self.accelerator, "parallelism_config", None) + if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled: + return self._deepspeed_sp_compute_loss(model, inputs, return_outputs, pc) + if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs: labels = inputs.pop("labels") else: @@ -3877,6 +3925,55 @@ def compute_loss( return (loss, outputs) if return_outputs else loss + def _deepspeed_sp_compute_loss(self, model, inputs, return_outputs, pc): + """ + How the loss is computed by Trainer under sequence parallelism with sp_backend=="deepspeed" and sp_size>1. + Performs weighted loss aggregation across SP ranks, accounting for varying numbers of valid tokens per rank + (e.g., when some ranks receive only padding or prompt tokens that are masked with -100). + + Args: + model (`nn.Module`): + The model to compute the loss for. + inputs (`dict[str, Union[torch.Tensor, Any]]`): + The input data for the model. Must include "shift_labels" key. + return_outputs (`bool`, *optional*, defaults to `False`): + Whether to return the model outputs along with the loss. + pc (`accelerate.parallelism_config.ParallelismConfig`): + self.accelerator.parallelism_config object (not None) + + Returns: + The loss of the model along with its output if return_outputs was set to True + """ + + unwrapped_model = self.accelerator.unwrap_model(model) + + outputs = model(**inputs) + shift_labels = inputs["shift_labels"] + loss = unwrapped_model.loss_function( + logits=outputs.logits, + labels=None, + shift_labels=shift_labels, + vocab_size=unwrapped_model.config.vocab_size, + ) + + sp_group = self.accelerator.torch_device_mesh["sp"].get_group() + sp_world_size = pc.sp_size + # differentiable weighted per-shard-loss aggregation across ranks + losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group) + # special dealing with SFT that has prompt tokens that aren't used in loss computation + good_tokens = (shift_labels != -100).view(-1).sum() + good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group) + # Skip ranks with zero valid tokens + total_loss = sum( + losses_per_rank[rank] * good_tokens_per_rank[rank] + for rank in range(sp_world_size) + if good_tokens_per_rank[rank] > 0 + ) + total_good_tokens = sum(good_tokens_per_rank) + loss = total_loss / max(total_good_tokens, 1) + + return (loss, outputs) if return_outputs else loss + def is_local_process_zero(self) -> bool: """ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several @@ -3917,7 +4014,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa Path(os.path.join(output_dir, "user_content.pt")).touch() # We are in N-D parallelism if we have parallelism_config set, so we check accelerate if we're on a to_save rank elif getattr(self.accelerator, "parallelism_config", None) is not None: - if self.accelerator.should_save_model: + # DeepSpeed SP already handles checkpoint saving below, so skip manual save in that case + pc = getattr(self.accelerator, "parallelism_config") + if self.accelerator.should_save_model and not (pc.sp_enabled and pc.sp_backend == "deepspeed"): self._save(output_dir) # If we drop to here, we're in 1D 
parallelism, so all ranks need to go to `save_pretrained` elif (tp_size := getattr(self.model, "_tp_size", 0)) is not None and tp_size > 1: @@ -4986,9 +5085,10 @@ def create_accelerator_and_postprocess(self): # We defer compatibility checks to accelerator if self.args.parallelism_config is not None: - if not is_accelerate_available("1.10.1"): + min_accelerate_version = "1.12.0" + if not is_accelerate_available(min_accelerate_version): raise ImportError( - "ParallelismConfig requires accelerate v1.10.1 and above. Please upgrade accelerate to use this feature." + f"ParallelismConfig requires accelerate>={min_accelerate_version}). Please upgrade accelerate to use this feature." ) args["parallelism_config"] = self.args.parallelism_config @@ -5182,6 +5282,11 @@ def set_initial_training_values( epoch_based = max_steps < 0 len_dataloader = len(dataloader) if has_length(dataloader) else None + # Account for Sequence Parallelism (SP) dataloader adapter's effect + sp_size = self.get_sp_size() + if sp_size > 1 and len_dataloader is not None: + len_dataloader = len_dataloader * sp_size + # Case 2: We have a dataloader length and can extrapolate if len_dataloader is not None: num_update_steps_per_epoch = max( diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 68ff158316be..8c7f1a5573c3 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1139,7 +1139,7 @@ class TrainingArguments: ) parallelism_config: Optional[ParallelismConfig] = field( default=None, - metadata={"help": ("Parallelism configuration for the training run. Requires Accelerate `1.10.1`")}, + metadata={"help": ("Parallelism configuration for the training run. Requires Accelerate `1.12.0`")}, ) deepspeed: Optional[Union[dict, str]] = field( default=None, diff --git a/tests/deepspeed/test_alst_ulysses_sp.py b/tests/deepspeed/test_alst_ulysses_sp.py new file mode 100644 index 000000000000..2f24325f3ab4 --- /dev/null +++ b/tests/deepspeed/test_alst_ulysses_sp.py @@ -0,0 +1,214 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import sys + +from transformers import is_torch_available +from transformers.testing_utils import ( + TestCasePlus, + execute_subprocess_async, + read_json_file, + require_accelerate, + require_torch_multi_accelerator, + slow, + write_file, +) + + +if is_torch_available(): + import torch + + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + HfArgumentParser, + Trainer, + TrainingArguments, + ) + + +class TestTrainerALSTUlyssesSP(TestCasePlus): + """Test Trainer with ALST/Ulysses sequence parallelism enabled via accelerate's ParallelismConfig.""" + + @require_torch_multi_accelerator + @require_accelerate + @slow + def test_sp_equivalence(self): + """Test that ALST/Ulysses sequence parallelism produces the same losses as without it.""" + + # shared setup + world_size = 2 + script_path = __file__ # self.test_file_dir} / "test_alst_ulysses_sp.py" + ds_config_path = self.test_file_dir / "ds_config_zero2.json" + + # step 1. Run with SP enabled (sp_size=world_size) + sp_yes_output_dir = self.get_auto_remove_tmp_dir(return_pathlib_obj=True) + sp_yes_accelerate_config_path = sp_yes_output_dir / "context_parallel_config.yaml" + sp_yes_losses_path = sp_yes_output_dir / "sp_yes_losses.json" + write_file( + sp_yes_accelerate_config_path, + f""" +distributed_type: DEEPSPEED +deepspeed_config: + deepspeed_config_file: {ds_config_path} +machine_rank: 0 +num_machines: 1 +num_processes: {world_size} +parallelism_config: + parallelism_config_sp_size: {world_size} + parallelism_config_sp_backend: deepspeed + parallelism_config_sp_seq_length_is_variable: true + parallelism_config_sp_attn_implementation: sdpa + """, + ) + + cmd_sp = f""" + accelerate launch + --config_file {sp_yes_accelerate_config_path} + {script_path} + --output_dir {sp_yes_output_dir} + --report_to none + --max_steps 10 + --per_device_train_batch_size 1 + --gradient_accumulation_steps 1 + --logging_steps 1 + --remove_unused_columns False + --seed 42 + --loss_output_file {sp_yes_losses_path} + """.split() + + execute_subprocess_async(cmd_sp, env=self.get_env()) + + # step 2. 
Run without SP enabled (sp_size=world_size) + sp_no_output_dir = self.get_auto_remove_tmp_dir(return_pathlib_obj=True) + sp_no_accelerate_config_path = sp_no_output_dir / "context_parallel_config.yaml" + sp_no_losses_path = sp_no_output_dir / "sp_yes_losses.json" + write_file( + sp_no_accelerate_config_path, + f""" +distributed_type: DEEPSPEED +deepspeed_config: + deepspeed_config_file: {ds_config_path} +machine_rank: 0 +num_machines: 1 +num_processes: {world_size} + """, + ) + + cmd_sp = f""" + accelerate launch + --config_file {sp_no_accelerate_config_path} + {script_path} + --output_dir {sp_no_output_dir} + --report_to none + --max_steps 10 + --per_device_train_batch_size 1 + --gradient_accumulation_steps 1 + --logging_steps 1 + --remove_unused_columns False + --seed 42 + --loss_output_file {sp_no_losses_path} + """.split() + + execute_subprocess_async(cmd_sp, env=self.get_env()) + + # Compare losses - should be very close since SP just splits sequence computation + sp_yes_losses = read_json_file(sp_yes_losses_path) + sp_no_losses = read_json_file(sp_no_losses_path) + + assert len(sp_yes_losses) == len(sp_no_losses), ( + f"Different number of losses: SP has {len(sp_yes_losses)}, no-SP has {len(sp_no_losses)}" + ) + + # ALST/UlyssesSP should produce very similar results (small numerical differences expected) + # The differences come from: + # - Different gradient reduction patterns in distributed training + # - BF16 mixed precision accumulated differences + sp_yes_losses_tensor = torch.tensor(sp_yes_losses) + sp_no_losses_tensor = torch.tensor(sp_no_losses) + torch.testing.assert_close( + sp_yes_losses_tensor, + sp_no_losses_tensor, + atol=2e-2, + rtol=2e-5, + msg=f"SP-enabled losses {sp_yes_losses} do not match SP-disabled losses {sp_no_losses}", + ) + + +if __name__ == "__main__": + model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM" + + # Parse custom arguments (not TrainingArguments parameters) + loss_output_file = None + + if "--loss_output_file" in sys.argv: + idx = sys.argv.index("--loss_output_file") + loss_output_file = sys.argv[idx + 1] + sys.argv.pop(idx) + sys.argv.pop(idx) + + parser = HfArgumentParser((TrainingArguments,)) + training_args = parser.parse_args_into_dataclasses()[0] + + tokenizer = AutoTokenizer.from_pretrained(model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + model_name, + attn_implementation="sdpa", # SP requires SDPA or FA + ) + # fix the outdated testing model config + model.generation_config.pad_token_id = 1 + + # Create simple dataset: just tokenize some text + texts = [ + "The quick brown fox jumps over the lazy dog. " * 10, + "Hello world, this is a test sentence for training. 
" * 10, + ] * 4 # 8 samples total + + def tokenize_function(examples): + return tokenizer(examples, max_length=128, truncation=True, padding="max_length") + + train_dataset = [tokenize_function(text) for text in texts] + + # Use standard DataCollatorForLanguageModeling for causal LM + # pad_to_multiple_of=4 ensures sequences are divisible by sp_size * 2 (for sp_size=2) + # Trainer will automatically generate position_ids and shift_labels as needed + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, # Causal language modeling + pad_to_multiple_of=4, + ) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + data_collator=data_collator, + ) + + # Train for a few steps + trainer.train() + + # Verify training completed + assert trainer.state.global_step > 0, "Training should have completed at least one step" + + # Save losses to file if requested (for equivalence testing) + if loss_output_file and training_args.process_index == 0: + losses = [log["loss"] for log in trainer.state.log_history if "loss" in log] + with open(loss_output_file, "w") as f: + json.dump(losses, f)