From 958cac9cc76d333f74ef107b1b19a2a1e8d74972 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Tue, 9 Jul 2024 13:15:04 -0700
Subject: [PATCH] Add support for DDP and experimental CompiledAutograd

Summary:
Address the comments in https://github.com/pytorch/torchtitan/pull/319 and resubmit the PR to fit the current codebase.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600 --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: a3131b7b0a835576992dd86d010f53866da4aa9d
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/432
---
 estimation.py                                |  1 +
 test_runner.py                               | 13 +++++++++
 torchtitan/config_manager.py                 | 11 ++++++++
 torchtitan/parallelisms/__init__.py          |  3 +++
 torchtitan/parallelisms/parallelize_llama.py | 28 +++++++++++++++++---
 train.py                                     | 27 +++++++++++++++----
 6 files changed, 75 insertions(+), 8 deletions(-)

diff --git a/estimation.py b/estimation.py
index e82a7b71..c1b4f4a9 100644
--- a/estimation.py
+++ b/estimation.py
@@ -67,6 +67,7 @@ def estimate_memory(job_config: JobConfig):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
+        dp_type=job_config.experimental.data_parallel_type,
     )
 
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
diff --git a/test_runner.py b/test_runner.py
index cba63544..48482c6f 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -273,6 +273,19 @@ def build_test_list():
             "fsdp2_mem_tracker",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.tensor_parallel_degree 1",
+                    "--training.data_parallel_degree 8",
+                    "--experimental.data_parallel_type ddp",
+                    "--experimental.enable_compiled_autograd",
+                ]
+            ],
+            "CompiledDDP",
+            "compiled_ddp",
+            ngpu=8,
+        ),
     ]
     return integration_tests_flavors
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 86bcffd8..7e7e427f 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -312,6 +312,17 @@ def __init__(self):
                 The default value will be the number of pipeline stages, if unspecified.
             """,
         )
+        self.parser.add_argument(
+            "--experimental.data_parallel_type",
+            type=str,
+            default="fsdp",
+            help="Data parallelism type. TorchTitan currently supports FSDP and DDP.",
+        )
+        self.parser.add_argument(
+            "--experimental.enable_compiled_autograd",
+            action="store_true",
+            help="Enable CompiledAutograd to compile the backward pass.",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py
index 7e1b21c7..2fdba316 100644
--- a/torchtitan/parallelisms/__init__.py
+++ b/torchtitan/parallelisms/__init__.py
@@ -28,8 +28,10 @@ class ParallelDims:
     pp: int
     world_size: int
     enable_loss_parallel: bool
+    dp_type: str
 
     def __post_init__(self):
+        self.dp_type = self.dp_type.lower()
         self._validate()
 
     def _validate(self):
@@ -42,6 +44,7 @@ def _validate(self):
         assert (
             dp * tp * pp == self.world_size
         ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+        assert self.dp_type in ("fsdp", "ddp")
 
     def build_mesh(self, device_type):
         dims = []
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 6cafa4ab..aa8dd50c 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -12,8 +12,9 @@ from typing import Dict, Tuple
 
 import torch
-
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+
+from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
 )
@@ -452,7 +453,7 @@ def apply_compile(model, job_config: JobConfig):
     return model
 
 
-def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
+def apply_fsdp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply data parallelism to the model. FSDP2 is used here.
""" @@ -489,6 +490,24 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig): return model +def apply_ddp(model, world_mesh, parallel_dims, job_config: JobConfig): + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism.") + + if job_config.training.compile: + if job_config.experimental.compiled_autograd: + torch._dynamo.config.optimize_ddp = ( + "python_reducer_without_compiled_forward" + ) + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") + return model + + def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): """ Apply tensor parallelism, activation checkpointing, torch.compile, and data @@ -508,6 +527,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): model = apply_compile(model, job_config) if parallel_dims.dp_enabled: - model = apply_dp(model, world_mesh, parallel_dims, job_config) + if parallel_dims.dp_type == "fsdp": + model = apply_fsdp(model, world_mesh, parallel_dims, job_config) + else: + model = apply_ddp(model, world_mesh, parallel_dims, job_config) return model diff --git a/train.py b/train.py index 8e55c210..cb14c700 100644 --- a/train.py +++ b/train.py @@ -135,6 +135,22 @@ def zero_grad(self): return OptimizersContainer([_build_optimizer(model) for model in model_parts]) +def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool): + @contextlib.contextmanager + def context(): + with contextlib.ExitStack() as stack: + if enable_loss_parallel: + stack.enter_context(loss_parallel()) + if enable_compiled_autograd: + stack.enter_context( + torch._dynamo.utils.maybe_enable_compiled_autograd(True) + ) + + yield + + return context + + # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @record def main(job_config: JobConfig): @@ -157,6 +173,7 @@ def main(job_config: JobConfig): pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, + dp_type=job_config.experimental.data_parallel_type, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) @@ -191,9 +208,9 @@ def main(job_config: JobConfig): dp_rank, ) - # loss_parallel enables dispatching to efficient loss operators - loss_parallel_ctx = ( - loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext + train_context = get_train_context( + parallel_dims.loss_parallel_enabled, + job_config.experimental.enable_compiled_autograd, ) # loss fn can be shared by pipeline-parallel or non-pp execution @@ -362,7 +379,7 @@ def loss_fn(pred, labels): # pipeline parallel forward / backward inside step() call is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 - with loss_parallel_ctx(): + with train_context(): if pp_mesh.get_local_rank() == 0: pp_schedule.step(input_ids) elif is_last_stage: @@ -379,7 +396,7 @@ def loss_fn(pred, labels): ) else: # Non-PP forward / backward - with loss_parallel_ctx(): + with train_context(): pred = model(input_ids) loss = loss_fn(pred, labels) # pred.shape=(bs, seq_len, vocab_size)