24 changes: 5 additions & 19 deletions torchtitan/experiments/forge/example_train.py
@@ -7,7 +7,7 @@
import importlib
import time
from datetime import timedelta
from typing import Any, Iterable, Optional
from typing import Any, Iterable

import torch
from torch.distributed.elastic.multiprocessing.errors import record
@@ -17,15 +17,16 @@
from torchtitan.components.metrics import build_metrics_processor
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.config import ConfigManager, JobConfig
from torchtitan.config import JobConfig
from torchtitan.distributed import utils as dist_utils
from torchtitan.hf_datasets.text_datasets import build_text_dataloader
from torchtitan.tools import utils
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.logging import logger
from torchtitan.tools.profiling import (
maybe_enable_memory_snapshot,
maybe_enable_profiling,
)
from torchtitan.train import main

from .engine import ForgeEngine

@@ -350,19 +351,4 @@ def close(self) -> None:


if __name__ == "__main__":
init_logger()
config_manager = ConfigManager()
config = config_manager.parse_args()
trainer: Optional[Trainer] = None

try:
trainer = Trainer(config)
trainer.train()
except Exception:
if trainer:
trainer.close()
raise
else:
trainer.close()
torch.distributed.destroy_process_group()
logger.info("Process group destroyed.")
main(Trainer)
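The entry-point files in this diff all move to the same pattern: the per-file `if __name__ == "__main__":` boilerplate (config parsing, try/except, close(), destroy_process_group) is replaced by a call to the shared `main()` added to torchtitan/train.py later in this diff. A minimal sketch of the resulting pattern, with `MyTrainer` as a hypothetical stand-in for the concrete trainer class each file passes in:

```python
# Illustrative sketch only, not part of this diff; MyTrainer is a hypothetical name.
from torchtitan.train import main, Trainer


class MyTrainer(Trainer):
    # Override trainer hooks (e.g. forward_backward_step, close) as needed;
    # no per-file __main__ boilerplate is required anymore.
    pass


if __name__ == "__main__":
    # main() owns config parsing, trainer construction, seed-checkpoint handling,
    # error handling, and process-group teardown.
    main(MyTrainer)
```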
47 changes: 9 additions & 38 deletions torchtitan/experiments/torchcomms/train.py
@@ -4,15 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional

import torch

from torchtitan.config import ConfigManager
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer
from torchtitan.train import main, Trainer

from .parallel_dims import TorchCommsParallelDims

@@ -32,35 +25,13 @@ def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
world_size=world_size,
)

def close(self) -> None:
# Call finalize on all comms after training and before destroying process group.
if hasattr(self, "parallel_dims"):
for comm in self.parallel_dims.comms:
comm.finalize()
super().close()

if __name__ == "__main__":
init_logger()
config_manager = ConfigManager()
config = config_manager.parse_args()
trainer: Optional[TorchCommsTrainer] = None

try:
trainer = TorchCommsTrainer(config)

if config.checkpoint.create_seed_checkpoint:
assert (
int(os.environ["WORLD_SIZE"]) == 1
), "Must create seed checkpoint using a single device, to disable sharding."
assert (
config.checkpoint.enable
), "Must enable checkpointing when creating a seed checkpoint."
trainer.checkpointer.save(curr_step=0, last_step=True)
logger.info("Created seed checkpoint")
else:
trainer.train()
# Call finalize on all comms after training and before destroying process group.
for comm in trainer.parallel_dims.comms:
comm.finalize()
except Exception:
if trainer:
trainer.close()
raise
else:
trainer.close()
torch.distributed.destroy_process_group()
logger.info("Process group destroyed")
if __name__ == "__main__":
main(TorchCommsTrainer)
35 changes: 3 additions & 32 deletions torchtitan/models/flux/train.py
@@ -4,12 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional

import torch

from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import utils as dist_utils

from torchtitan.models.flux.infra.parallelize import parallelize_encoders
@@ -20,8 +17,7 @@
pack_latents,
preprocess_data,
)
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer
from torchtitan.train import main, Trainer


class FluxTrainer(Trainer):
@@ -175,29 +171,4 @@ def forward_backward_step(


if __name__ == "__main__":
init_logger()
config_manager = ConfigManager()
config = config_manager.parse_args()
trainer: Optional[FluxTrainer] = None

try:
trainer = FluxTrainer(config)
if config.checkpoint.create_seed_checkpoint:
assert (
int(os.environ["WORLD_SIZE"]) == 1
), "Must create seed checkpoint using a single device, to disable sharding."
assert (
config.checkpoint.enable
), "Must enable checkpointing when creating a seed checkpoint."
trainer.checkpointer.save(curr_step=0, last_step=True)
logger.info("Created seed checkpoint")
else:
trainer.train()
except Exception:
if trainer:
trainer.close()
raise
else:
trainer.close()
torch.distributed.destroy_process_group()
logger.info("Process group destroyed.")
main(FluxTrainer)
72 changes: 61 additions & 11 deletions torchtitan/train.py
@@ -8,7 +8,7 @@
import os
import time
from datetime import timedelta
from typing import Any, Generator, Iterable, Optional
from typing import Any, Generator, Iterable

import torch

@@ -410,26 +410,67 @@ def batch_generator(

yield input_dict, labels

def forward_backward_step(
def post_dataloading_process(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
model_parts = self.model_parts
parallel_dims = self.parallel_dims

) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
"""
Post-processing hook after data loading and before model forward pass.

This method processes the raw data from the dataloader and prepares it for
the model's forward pass. It separates the main input tensor from auxiliary
inputs and constructs additional keyword arguments (e.g., attention masks).

This method can be overridden in subclasses to customize data processing
for different training strategies (e.g., converting tensors to DTensors,
applying custom transformations, etc.).

Args:
input_dict: Dictionary containing tensors from the dataloader. Must
contain an "input" key with the main input tensor. May contain
additional keys for auxiliary inputs (e.g., position ids).
labels: Target labels for the batch.

Returns:
A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
- inputs: Main input tensor extracted from input_dict["input"].
- labels: Target labels (unchanged from input parameter).
- extra_inputs: Dict of auxiliary input tensors (all keys except
"input" from input_dict). These are passed to the model forward
but are NOT forwarded across pipeline parallel stages.
- extra_kwargs: Dict of additional keyword arguments for model forward.
These ARE forwarded across pipeline parallel stages. Contains
attention_masks if flex attention is enabled.

Note:
The distinction between extra_inputs and extra_kwargs is important for
pipeline parallelism: extra_kwargs are forwarded to all pipeline stages,
while extra_inputs are only available to the first stage.
"""
inputs = input_dict["input"]
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
# For arguments like attention_masks, we have to put them in a separate
# dict, as extra_inputs are not forwarded to other stages in PP but
# extra_kwargs are.
extra_kwargs = {}
extra_kwargs: dict[str, Any] = {}

if getattr(self.model_args, "use_flex_attn", False):
extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
input_batch=inputs,
tokenizer=self.tokenizer,
extra_inputs=extra_inputs,
)

return inputs, labels, extra_inputs, extra_kwargs

def forward_backward_step(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
model_parts = self.model_parts
parallel_dims = self.parallel_dims

inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
input_dict, labels
)
# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
optional_context_parallel_ctx = (
@@ -662,14 +703,19 @@ def close(self) -> None:
self.metrics_processor.close()


if __name__ == "__main__":
def main(trainer_class: type[Trainer]) -> None:
"""Main entry point for training with a specified trainer class.

Args:
trainer_class: The trainer class to instantiate (e.g., Trainer, FluxTrainer, TorchCommsTrainer)
"""
init_logger()
config_manager = ConfigManager()
config = config_manager.parse_args()
trainer: Optional[Trainer] = None
trainer: Trainer | None = None

try:
trainer = Trainer(config)
trainer = trainer_class(config)

if config.checkpoint.create_seed_checkpoint:
assert (
@@ -690,3 +736,7 @@ def close(self) -> None:
trainer.close()
torch.distributed.destroy_process_group()
logger.info("Process group destroyed")


if __name__ == "__main__":
main(Trainer)
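To show how the new hook and entry point are meant to be used together, here is a minimal sketch of a downstream trainer overriding `post_dataloading_process`. The subclass name and the extra `segment_ids` kwarg are hypothetical (the model forward would have to accept it); the signature and the extra_inputs/extra_kwargs split follow the docstring above:

```python
# Illustrative sketch only; CustomTrainer and "segment_ids" are hypothetical.
from typing import Any

import torch

from torchtitan.train import main, Trainer


class CustomTrainer(Trainer):
    def post_dataloading_process(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
        # Reuse the default split into (inputs, labels, extra_inputs, extra_kwargs) ...
        inputs, labels, extra_inputs, extra_kwargs = super().post_dataloading_process(
            input_dict, labels
        )
        # ... then add an argument that must reach every pipeline stage, so it goes
        # into extra_kwargs rather than extra_inputs.
        extra_kwargs["segment_ids"] = torch.zeros_like(inputs)
        return inputs, labels, extra_inputs, extra_kwargs


if __name__ == "__main__":
    main(CustomTrainer)
```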