diff --git a/docs/debugging.md b/docs/debugging.md
index f7758cbde5..4deb20bbac 100644
--- a/docs/debugging.md
+++ b/docs/debugging.md
@@ -70,7 +70,7 @@ When debugging issues with multi-dimensional parallelism (combinations of FSDP,
 Set consistent random seeds across all parallelism dimensions:
 
 ```bash
-CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.seed 42
+CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.seed 42
 ```
 
 **Seed behavior with parallelism:**
@@ -84,7 +84,7 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr
 Enable deterministic algorithms to ensure bit-for-bit reproducibility across runs:
 
 ```bash
-CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.deterministic
+CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.deterministic
 ```
 
 **What it does:**
@@ -93,6 +93,19 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr
 - Sets deterministic workspace configuration for CuBLAS operations
 - **Note:** This will significantly reduce training performance but ensures exact reproducibility
 
+Use `--debug.deterministic_warn_only` to only warn about (rather than stop on) kernels that have no deterministic implementation.
+
+### Activation Checkpointing Debugging
+
+The following debug configs are available for activation checkpointing (AC):
+
+`preserve_rng_state` - set to true if deterministic output compared to non-checkpointed passes is required. This stashes and restores the RNG state during each checkpoint, which may be slower.
+
+`determinism_check` - a string specifying the determinism function; the default checks that recomputed tensors match the shapes, dtypes, and devices of the originals.
+
+`debug` - capture AC debug information; this will be slower.
+
+See https://docs.pytorch.org/docs/stable/checkpoint.html for details.
 
 ### Seed-Checkpoint-based Reproducibility
 
diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py
index 21322ba232..ec835b4166 100644
--- a/scripts/generate/test_generate.py
+++ b/scripts/generate/test_generate.py
@@ -25,7 +25,7 @@
     RowwiseParallel,
 )
 from torchtitan.components.metrics import build_device_memory_monitor
-from torchtitan.config import ConfigManager
+from torchtitan.config import ConfigManager, Debug as DebugConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
 from torchtitan.protocols.train_spec import get_train_spec
 from torchtitan.tools import utils
@@ -133,7 +133,8 @@ def test_generate(
         # sequences would require https://github.com/pytorch/torchtitan/pull/686
         apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"])
 
-    dist_utils.set_determinism(world_mesh, device, seed, deterministic)
+    debug_config = DebugConfig(seed=seed, deterministic=deterministic)
+    dist_utils.set_determinism(world_mesh, device, debug_config)
 
     # materalize model
     model.to_empty(device=device_type)
diff --git a/torchtitan/config/__init__.py b/torchtitan/config/__init__.py
index ba2795a601..53f390149a 100644
--- a/torchtitan/config/__init__.py
+++ b/torchtitan/config/__init__.py
@@ -16,6 +16,7 @@
     ActivationCheckpoint,
     Checkpoint,
     Comm,
+    Debug,
     FaultTolerance,
     Job,
     JobConfig,
@@ -49,4 +50,5 @@
     "Profiling",
     "Training",
     "Validation",
+    "Debug",
 ]
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 7fe6802374..95588d2c3b 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -263,15 +263,6 @@ class Training:
     many temporary files.
""" - seed: int | None = None - """Choose the base RNG seed used for training""" - - deterministic: bool = False - """Use deterministic algorithms wherever possible, may be slower""" - - debug_moe_force_load_balance: bool = False - """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only.""" - @dataclass class Parallelism: @@ -639,6 +630,26 @@ class ActivationCheckpoint: https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015 """ + preserve_rng_state: bool = False + """ + If deterministic output compared to non-checkpointed passes is required, set + to true. Results in stashing and restoring the RNG state during each checkpoint, + may be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html + for details. + """ + + determinism_check: str = "default" + """ + A string specifying the determinism function. See + https://docs.pytorch.org/docs/stable/checkpoint.html for details. + """ + + debug: bool = False + """ + Capture ac debug information. Will be slower. See + https://docs.pytorch.org/docs/stable/checkpoint.html for details. + """ + @dataclass class Compile: @@ -887,6 +898,21 @@ def __post_init__(self): ), "validation steps must be positive or -1" +@dataclass +class Debug: + seed: int | None = None + """Choose the base RNG seed used for training""" + + deterministic: bool = False + """Use deterministic algorithms wherever possible, may be slower""" + + deterministic_warn_only: bool = False + """Only warns about ops without deterministic implementations rather than erroring out """ + + moe_force_load_balance: bool = False + """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only.""" + + @dataclass class JobConfig: """ @@ -912,6 +938,7 @@ class JobConfig: fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance) experimental: Experimental = field(default_factory=Experimental) validation: Validation = field(default_factory=Validation) + debug: Debug = field(default_factory=Debug) def to_dict(self) -> dict[str, Any]: return asdict(self) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 57809c45f9..8359f71730 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -38,7 +38,11 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module: ac_freq = int(ac_config.selective_ac_option) if not ac_freq or _layer_sac_count % ac_freq == 0: return ptd_checkpoint_wrapper( - module, preserve_rng_state=False, early_stop=ac_config.early_stop + module, + preserve_rng_state=ac_config.preserve_rng_state, + determinism_check=ac_config.determinism_check, + early_stop=ac_config.early_stop, + debug=ac_config.debug, ) else: return module @@ -125,8 +129,10 @@ def selective_checkpointing_context_fn(): return ptd_checkpoint_wrapper( module, context_fn=selective_checkpointing_context_fn, - preserve_rng_state=False, + preserve_rng_state=ac_config.preserve_rng_state, + determinism_check=ac_config.determinism_check, early_stop=ac_config.early_stop, + debug=ac_config.debug, ) @@ -141,7 +147,11 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: nn.Module: The module with full activation checkpointing applied. 
""" return ptd_checkpoint_wrapper( - module, preserve_rng_state=False, early_stop=ac_config.early_stop + module, + preserve_rng_state=ac_config.preserve_rng_state, + determinism_check=ac_config.determinism_check, + early_stop=ac_config.early_stop, + debug=ac_config.debug, ) @@ -157,7 +167,7 @@ def _apply_op_sac_to_transformer_block_with_flex( Args: module (nn.Module): The transformer block to apply SAC to. - ac_config (ACConfig): The activation checkpointing config. + ac_config (ACConfig): The Activation Checkpoint config. base_fqn (str, optional): The base fqn of the module. Defaults to None. model_compile_enabled (bool): Whether model compilation is enabled. Defaults to False. @@ -298,7 +308,6 @@ def apply_ac( Returns: None """ - if ac_config.mode == "memory_budget": assert model_compile_enabled, "Memory budget mode requires model to be compiled" if ac_config.visualize_memory_budget_pareto: diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 93a96a4439..1d2c623e1b 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -17,7 +17,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor -from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP +from torchtitan.config import Comm as CommConfig, Debug as DebugConfig, TORCH_DTYPE_MAP from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.tools.logging import logger from torchtitan.tools.utils import device_module, device_type @@ -83,8 +83,7 @@ def dist_mean( def set_determinism( world_mesh: DeviceMesh | None, device: torch.device, - seed: int | None = None, - deterministic: bool = False, + debug_config: DebugConfig, distinct_seed_mesh_dim: str = "pp", ) -> None: """ @@ -97,9 +96,12 @@ def set_determinism( Set Determinism flags for increased reproducibility with loss of performance. 
""" - if deterministic: + if debug_config.deterministic: logger.info("Deterministic algorithm enabled (expect perf degradation).") torch.use_deterministic_algorithms(True) + torch.use_deterministic_algorithms( + True, warn_only=debug_config.deterministic_warn_only + ) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # env var for deterministic CuBLAS @@ -114,6 +116,7 @@ def set_determinism( FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) + seed = debug_config.seed if not world_mesh: if seed is not None: torch.manual_seed(seed) diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index d832c39696..2f1887b2d7 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -104,8 +104,7 @@ def __init__(self, job_config: ForgeJobConfig): dist_utils.set_determinism( world_mesh, self.device, - job_config.training.seed, - job_config.training.deterministic, + job_config.debug, ) self.train_spec = get_train_spec(job_config.model.name) diff --git a/torchtitan/experiments/forge/job_config.py b/torchtitan/experiments/forge/job_config.py index b1c014cc1d..efc6e6c074 100644 --- a/torchtitan/experiments/forge/job_config.py +++ b/torchtitan/experiments/forge/job_config.py @@ -12,6 +12,7 @@ Checkpoint, Comm, Compile, + Debug, Job, LRScheduler, MemoryEstimation, @@ -45,6 +46,7 @@ class ForgeJobConfig: # fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance) # experimental: Experimental = field(default_factory=Experimental) # validation: Validation = field(default_factory=Validation) + debug: Debug = field(default_factory=Debug) def to_dict(self) -> dict[str, Any]: return asdict(self) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index fe0de3354d..85115fef2b 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -99,6 +99,7 @@ def __init__(self) -> None: SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, ] def forward( diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 3bac6e82f1..0328c52334 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -107,7 +107,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args._debug_force_load_balance = ( - job_config.training.debug_moe_force_load_balance + job_config.debug.moe_force_load_balance ) def get_nparams_and_flops( diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index 69d5b7dfbf..3fa2389335 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -34,8 +34,7 @@ def __init__(self, job_config: JobConfig): dist_utils.set_determinism( self.parallel_dims.world_mesh, self.device, - job_config.training.seed, - job_config.training.deterministic, + job_config.debug, distinct_seed_mesh_dim="dp_shard", ) diff --git a/torchtitan/models/llama4/model/args.py b/torchtitan/models/llama4/model/args.py index faeb60aadf..53043a1d02 100644 --- a/torchtitan/models/llama4/model/args.py +++ b/torchtitan/models/llama4/model/args.py @@ -82,7 +82,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args._debug_force_load_balance = ( - job_config.training.debug_moe_force_load_balance + job_config.debug.moe_force_load_balance ) def get_nparams_and_flops( diff --git 
a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py index 5fd98fdce7..5660ba9c9c 100644 --- a/torchtitan/models/qwen3/model/args.py +++ b/torchtitan/models/qwen3/model/args.py @@ -56,7 +56,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: self.max_seq_len = seq_len self.moe_args._debug_force_load_balance = ( - job_config.training.debug_moe_force_load_balance + job_config.debug.moe_force_load_balance ) def get_nparams_and_flops( diff --git a/torchtitan/train.py b/torchtitan/train.py index 2efd7931ed..7a7e3fb516 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,6 +11,7 @@ from typing import Any, Generator, Iterable, Optional import torch + from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module @@ -118,8 +119,7 @@ def __init__(self, job_config: JobConfig): dist_utils.set_determinism( world_mesh, self.device, - job_config.training.seed, - job_config.training.deterministic, + job_config.debug, ) self.train_spec = train_spec_module.get_train_spec(job_config.model.name)
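Below is a minimal usage sketch (not part of the diff) of how the relocated debug options are consumed after this change. It mirrors the updated `scripts/generate/test_generate.py` call site above; the concrete seed value, the `world_mesh=None` single-process setup, and the CUDA device are illustrative assumptions.

```python
# Illustrative sketch only: exercises the new Debug config introduced in this diff.
import torch

from torchtitan.config import Debug as DebugConfig
from torchtitan.distributed import utils as dist_utils

# Options that previously lived under job_config.training (--training.seed,
# --training.deterministic) are now fields of the Debug dataclass
# ([debug] section in TOML, --debug.* on the CLI).
debug_config = DebugConfig(
    seed=42,
    deterministic=True,
    deterministic_warn_only=True,  # warn instead of erroring on ops that lack
                                   # deterministic implementations
)

# set_determinism now takes the whole Debug config instead of separate
# seed/deterministic arguments. With world_mesh=None it skips the per-mesh
# DTensor seeding and just seeds torch globally when a seed is given.
dist_utils.set_determinism(None, torch.device("cuda"), debug_config)
```

The equivalent command-line form is the `--debug.seed 42` / `--debug.deterministic` usage shown in the updated `docs/debugging.md` hunk.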