diff --git a/estimation.py b/estimation.py index f58907c6f7..9edeff6d2b 100644 --- a/estimation.py +++ b/estimation.py @@ -14,6 +14,7 @@ from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker from torch.testing._internal.distributed.fake_pg import FakeStore +from torchtitan import utils from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_tokenizer from torchtitan.float8 import Float8Handler @@ -21,7 +22,6 @@ from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config from torchtitan.optimizer import build_lr_schedulers, build_optimizers from torchtitan.parallelisms import models_parallelize_fns, ParallelDims -from train import get_train_context def estimate_memory(job_config: JobConfig): @@ -66,6 +66,7 @@ def estimate_memory(job_config: JobConfig): parallel_dims = ParallelDims( dp_shard=job_config.training.data_parallel_shard_degree, dp_replicate=job_config.training.data_parallel_replicate_degree, + cp=job_config.experimental.context_parallel_degree, tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, @@ -94,7 +95,7 @@ def estimate_memory(job_config: JobConfig): tokenizer_type = model_name_to_tokenizer[model_name] tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) - train_context = get_train_context( + train_context = utils.get_train_context( parallel_dims.loss_parallel_enabled, job_config.experimental.enable_compiled_autograd, ) diff --git a/test_runner.py b/test_runner.py index 6bed5ae341..61031d742f 100755 --- a/test_runner.py +++ b/test_runner.py @@ -306,6 +306,41 @@ def build_test_list(): "hsdp+tp", ngpu=8, ), + OverrideDefinitions( + [ + [ + "--training.data_parallel_shard_degree=2", + "--experimental.context_parallel_degree=2", + ] + ], + "FSDP+CP", + "fsdp+cp", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--training.data_parallel_shard_degree=2", + "--training.data_parallel_replicate_degree=2", + "--experimental.context_parallel_degree=2", + ] + ], + "HSDP+CP", + "hsdp+cp", + ngpu=8, + ), + OverrideDefinitions( + [ + [ + "--training.data_parallel_shard_degree=2", + "--training.tensor_parallel_degree=2", + "--experimental.context_parallel_degree=2", + ] + ], + "FSDP+TP+CP", + "fsdp+tp+cp", + ngpu=8, + ), OverrideDefinitions( [ [ diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 88e51f0270..bd88fe1870 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -325,6 +325,12 @@ def __init__(self): action="store_true", help="Enable CompiledAutograd to compile the backward.", ) + self.parser.add_argument( + "--experimental.context_parallel_degree", + type=int, + default=1, + help="Context parallelism degree. 
1 means disabled.", + ) self.parser.add_argument( "--training.mixed_precision_param", type=str, diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 7f102a8012..01aa21b51e 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -415,8 +415,9 @@ def _precompute_freqs_cis(self) -> torch.Tensor: return precompute_freqs_cis( self.model_args.dim // self.model_args.n_heads, # Need to compute until at least the max token limit for generation - # (use 2x max sequence length to be safe) - self.model_args.max_seq_len * 2, + # TODO: explain in docs/composability.md why we removed the 2x + # relaxing in our CP enablement PR + self.model_args.max_seq_len, self.model_args.rope_theta, ) diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index da9d4240bc..aa7e7b3745 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -15,6 +15,7 @@ class ParallelDims: dp_replicate: int dp_shard: int + cp: int tp: int pp: int world_size: int @@ -24,36 +25,38 @@ def __post_init__(self): self._validate() def _validate(self): - dp_replicate, dp_shard, tp, pp = ( + dp_replicate, dp_shard, cp, tp, pp = ( self.dp_replicate, self.dp_shard, + self.cp, self.tp, self.pp, ) - for d in (dp_replicate, tp, pp): + for d in (dp_replicate, cp, tp, pp): assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." dp = dp_replicate * dp_shard if dp < 0: - dp = self.world_size // (tp * pp) + dp = self.world_size // (cp * tp * pp) self.dp_shard = dp_shard = dp // dp_replicate assert dp_replicate >= 1 assert dp_shard >= 1 + assert cp >= 1, cp assert tp >= 1, tp assert pp >= 1, pp - assert dp_replicate * dp_shard * tp * pp == self.world_size, ( + assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, ( f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " - f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" + f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" ) def build_mesh(self, device_type): dims = [] names = [] for d, name in zip( - [self.pp, self.dp_replicate, self.dp_shard, self.tp], - ["pp", "dp_replicate", "dp_shard", "tp"], + [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], + ["pp", "dp_replicate", "dp_shard", "cp", "tp"], ): if d > 1: dims.append(d) @@ -71,6 +74,13 @@ def build_mesh(self, device_type): # initialized if self.dp_replicate > 1 and self.dp_shard > 1: mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp") + + if self.cp > 1: + if self.dp_replicate > 1 and self.dp_shard > 1: + mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp") + else: + mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp") + return mesh @property @@ -85,6 +95,10 @@ def dp_replicate_enabled(self): def dp_shard_enabled(self): return self.dp_shard > 1 + @property + def cp_enabled(self): + return self.cp > 1 + @property def tp_enabled(self): return self.tp > 1 @@ -98,5 +112,5 @@ def loss_parallel_enabled(self): return self.tp > 1 and self.enable_loss_parallel @cached_property - def model_parallel_size(self): - return self.tp * self.pp + def non_data_parallel_size(self): + return self.cp * self.tp * self.pp diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index fc26703db0..f5b20ebca4 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ 
b/torchtitan/parallelisms/parallelize_llama.py
@@ -11,6 +11,7 @@
 
 import torch
 import torch.nn as nn
+
 from torch.distributed import DeviceMesh
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._composable.replicate import replicate
@@ -72,36 +73,51 @@ def parallelize_llama(
             )
         apply_compile(model)
 
-    if parallel_dims.dp_enabled:
-        if parallel_dims.dp_shard_enabled:
-            if parallel_dims.dp_replicate_enabled:
-                dp_mesh = world_mesh["dp_replicate", "dp_shard"]
-            else:
-                dp_mesh = world_mesh["dp"]
-
-            apply_fsdp(
-                model,
-                dp_mesh,
-                param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-                reduce_dtype=TORCH_DTYPE_MAP[
-                    job_config.training.mixed_precision_reduce
-                ],
-                tp_enabled=parallel_dims.tp_enabled,
-                pp_enabled=parallel_dims.pp_enabled,
-            )
-            if parallel_dims.dp_replicate_enabled:
-                logger.info("Applied HSDP to the model")
-            else:
-                logger.info("Applied FSDP to the model")
+    if (
+        parallel_dims.dp_shard_enabled
+    ):  # apply FSDP or HSDP, potentially with Context Parallel
+
+        # TODO: instead of flattening the mesh twice, we could've done it in a better way:
+        #     dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
+        #     However, this leads to an error in `DeviceMesh.__getitem__` which I believe is
+        #     a bug in DeviceMesh. We should fix it and then use the above line.
+        dp_mesh_dim_names = (
+            ("dp_replicate", "dp_shard")
+            if parallel_dims.dp_replicate_enabled
+            else ("dp",)
+        )
+        # note that mesh can only be flattened from the finest-grained mesh dimensions
+        dp_mesh = (
+            world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
+            if parallel_dims.cp_enabled
+            else world_mesh[dp_mesh_dim_names]
+        )
+
+        apply_fsdp(
+            model,
+            dp_mesh,
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+            tp_enabled=parallel_dims.tp_enabled,
+            pp_enabled=parallel_dims.pp_enabled,
+        )
+
+        if parallel_dims.dp_replicate_enabled:
+            logger.info("Applied HSDP to the model")
         else:
-            if world_mesh.ndim > 1:
-                raise RuntimeError("DDP has not supported > 1D parallelism")
-            apply_ddp(
-                model,
-                world_mesh,
-                enable_compile=job_config.training.compile,
-                enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
-            )
+            logger.info("Applied FSDP to the model")
+
+        if parallel_dims.cp_enabled:
+            logger.info("Applied Context Parallel to the model")
+    elif parallel_dims.dp_replicate_enabled:
+        if world_mesh.ndim > 1:
+            raise RuntimeError("DDP does not support > 1D parallelism")
+        apply_ddp(
+            model,
+            world_mesh,
+            enable_compile=job_config.training.compile,
+            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
+        )
 
 
 def apply_tp(
diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index 7c562b47bf..933a363552 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -4,12 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import contextlib
 import gc
 import os
 import subprocess
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Optional, Union
+from typing import Generator, List, Optional, Set, Union
 
 import torch
 import torch.distributed._functional_collectives as funcol
@@ -101,6 +102,57 @@ def run(self, step_count):
 SKIP_CLEANUP = "3"
 
 
+def create_context_parallel_ctx(
+    cp_mesh: DeviceMesh,
+    cp_buffers: List[torch.Tensor],
+    cp_seq_dims: List[int],
+    cp_no_restore_buffers: Set[torch.Tensor],
+):
+    try:
+        from torch.distributed.tensor.experimental import context_parallel
+    except ImportError as e:
+        raise ImportError(
+            f"PyTorch version {torch.__version__} does not include the experimental "
+            "Context Parallel API. Please update to a newer version."
+        ) from e
+
+    return context_parallel(
+        cp_mesh,
+        buffers=cp_buffers,
+        buffer_seq_dims=cp_seq_dims,
+        no_restore_buffers=cp_no_restore_buffers,
+    )
+
+
+def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
+    @contextlib.contextmanager
+    def context(cp_context: Optional[Generator[None, None, None]] = None):
+        with contextlib.ExitStack() as stack:
+            if enable_loss_parallel:
+                stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
+
+            if enable_compiled_autograd:
+                stack.enter_context(
+                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
+                )
+
+            if cp_context is not None:
+                from torch.nn.attention import sdpa_kernel, SDPBackend
+
+                # currently we only support these two SDP backends.
+                # TODO (xilunwu): support cuDNN backend
+                stack.enter_context(
+                    sdpa_kernel(
+                        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+                    )
+                )
+                stack.enter_context(cp_context)
+
+            yield
+
+    return context
+
+
 def init_distributed(job_config):
     # FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
     # to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
diff --git a/train.py b/train.py
index 3e8994a34d..8991673a4e 100644
--- a/train.py
+++ b/train.py
@@ -4,12 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import contextlib import os import time from datetime import timedelta import torch + from torch.distributed.elastic.multiprocessing.errors import record from torchtitan import utils @@ -29,21 +29,6 @@ from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling -def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool): - @contextlib.contextmanager - def context(): - with contextlib.ExitStack() as stack: - if enable_loss_parallel: - stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) - if enable_compiled_autograd: - stack.enter_context( - torch._dynamo.utils.maybe_enable_compiled_autograd(True) - ) - yield - - return context - - # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @record def main(job_config: JobConfig): @@ -70,6 +55,7 @@ def main(job_config: JobConfig): parallel_dims = ParallelDims( dp_shard=job_config.training.data_parallel_shard_degree, dp_replicate=job_config.training.data_parallel_replicate_degree, + cp=job_config.experimental.context_parallel_degree, tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, @@ -232,7 +218,7 @@ def loss_fn(pred, labels): data_iterator = iter(data_loader) - train_context = get_train_context( + train_context = utils.get_train_context( parallel_dims.loss_parallel_enabled, job_config.experimental.enable_compiled_autograd, ) @@ -275,11 +261,23 @@ def loss_fn(pred, labels): labels = labels.cuda() optimizers.zero_grad() + # apply context parallelism if cp is enabled + optional_context_parallel_ctx = ( + utils.create_context_parallel_ctx( + cp_mesh=world_mesh["cp"], + cp_buffers=[input_ids, labels, model.freqs_cis], + cp_seq_dims=[1, 1, 0], + cp_no_restore_buffers={input_ids, labels}, + ) + if parallel_dims.cp_enabled + else None + ) + if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 - with train_context(): + with train_context(optional_context_parallel_ctx): if pp_mesh.get_local_rank() == 0: pp_schedule.step(input_ids) elif is_last_stage: @@ -296,7 +294,7 @@ def loss_fn(pred, labels): ) else: # Non-PP forward / backward - with train_context(): + with train_context(optional_context_parallel_ctx): pred = model(input_ids) loss = loss_fn(pred, labels) # pred.shape=(bs, seq_len, vocab_size) @@ -348,7 +346,7 @@ def loss_fn(pred, labels): # tokens per second, abbr. 
as wps by convention wps = ntokens_since_last_log / ( - time_delta * parallel_dims.model_parallel_size + time_delta * parallel_dims.non_data_parallel_size ) # model FLOPS utilization # For its definition and calculation, please refer to the PaLM paper: diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index bb3cd35371..da3bc45e8c 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -42,6 +42,7 @@ compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [experimental] +context_parallel_degree = 1 pipeline_parallel_degree = 1 enable_async_tensor_parallel = false diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml index 3230b208bf..b238ebada1 100644 --- a/train_configs/llama2_13b.toml +++ b/train_configs/llama2_13b.toml @@ -38,6 +38,7 @@ compile = false dataset = "c4" [experimental] +context_parallel_degree = 1 pipeline_parallel_degree = 1 [checkpoint] diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml index e7c920c65d..2764a57ef7 100644 --- a/train_configs/llama2_70b.toml +++ b/train_configs/llama2_70b.toml @@ -38,6 +38,7 @@ compile = false dataset = "c4" [experimental] +context_parallel_degree = 1 pipeline_parallel_degree = 1 [checkpoint] diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml index 5ffaaeca7f..e64d8aa8d2 100644 --- a/train_configs/llama2_7b.toml +++ b/train_configs/llama2_7b.toml @@ -37,6 +37,7 @@ compile = false dataset = "c4" [experimental] +context_parallel_degree = 1 pipeline_parallel_degree = 1 [checkpoint] diff --git a/train_configs/llama3_405b.toml b/train_configs/llama3_405b.toml index c7723ef319..e52beb3663 100644 --- a/train_configs/llama3_405b.toml +++ b/train_configs/llama3_405b.toml @@ -38,6 +38,7 @@ compile = true dataset = "c4" [experimental] +context_parallel_degree = 1 pipeline_parallel_degree = 1 enable_async_tensor_parallel = true diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index fb6d5f50ba..2d55a36d39 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -38,6 +38,7 @@ compile = false dataset = "c4" [experimental] +context_parallel_degree = 1 pipeline_parallel_degree = 1 [checkpoint] diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index e0c5bd03eb..3001ec7487 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -38,6 +38,7 @@ compile = false dataset = "c4" [experimental] +context_parallel_degree = 1 pipeline_parallel_degree = 1 [checkpoint]
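
Note (outside the patch): a minimal sketch of the degree arithmetic this change introduces, assuming torchtitan is importable and that ParallelDims still accepts an enable_loss_parallel flag as in the repo; the concrete degrees below are illustrative only. With dp_shard left at -1, _validate infers the data-parallel degree as world_size // (cp * tp * pp), and non_data_parallel_size (which replaces model_parallel_size in the wps calculation) evaluates to cp * tp * pp.

from torchtitan.parallelisms import ParallelDims

# Illustrative FSDP+TP+CP layout on 8 ranks, mirroring the new "fsdp+tp+cp" test entry.
parallel_dims = ParallelDims(
    dp_replicate=1,
    dp_shard=-1,  # inferred: 8 // (cp * tp * pp) = 8 // (2 * 2 * 1) = 2
    cp=2,
    tp=2,
    pp=1,
    world_size=8,
    enable_loss_parallel=False,
)
assert parallel_dims.dp_shard == 2
assert parallel_dims.cp_enabled  # cp > 1
assert parallel_dims.non_data_parallel_size == 4  # cp * tp * pp, the divisor used for wps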