diff --git a/estimation.py b/estimation.py
index f5527f74..de6d56d6 100644
--- a/estimation.py
+++ b/estimation.py
@@ -70,6 +70,7 @@ def estimate_memory(job_config: JobConfig):
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 33070120..3b963d1c 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -323,6 +323,12 @@ def __init__(self):
             action="store_true",
             help="Enable CompiledAutograd to compile the backward.",
         )
+        self.parser.add_argument(
+            "--experimental.context_parallel_degree",
+            type=int,
+            default=1,
+            help="Context parallelism degree. 1 means disabled.",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 49cda624..8ba64e69 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -410,7 +410,7 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.dim // self.model_args.n_heads,
             # Need to compute until at least the max token limit for generation
             # (use 2x max sequence length to be safe)
-            self.model_args.max_seq_len * 2,
+            self.model_args.max_seq_len,
             self.model_args.rope_theta,
         )
diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py
index 2fdba316..0bcd1484 100644
--- a/torchtitan/parallelisms/__init__.py
+++ b/torchtitan/parallelisms/__init__.py
@@ -24,6 +24,7 @@
 @dataclass
 class ParallelDims:
     dp: int
+    cp: int
     tp: int
     pp: int
     world_size: int
@@ -35,34 +36,41 @@ def __post_init__(self):
         self._validate()
 
     def _validate(self):
-        dp, tp, pp = self.dp, self.tp, self.pp
+        dp, cp, tp, pp = self.dp, self.cp, self.tp, self.pp
         if dp == -1:
-            self.dp = dp = self.world_size // (tp * pp)
+            self.dp = dp = self.world_size // (cp * tp * pp)
         assert dp >= 1, dp
+        assert cp >= 1, cp
         assert tp >= 1, tp
         assert pp >= 1, pp
-        assert (
-            dp * tp * pp == self.world_size
-        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+        assert dp * cp * tp * pp == self.world_size, (
+            f"Invalid parallel dims: dp({dp}) * cp({cp}) * tp({tp}) * pp({pp}) "
+            f"!= WORLD_SIZE({self.world_size})"
+        )
         assert self.dp_type in ("fsdp", "ddp")
 
     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
+            [self.pp, self.dp, self.cp, self.tp], ["pp", "dp", "cp", "tp"], strict=True
         ):
             if d > 1:
                 dims.append(d)
                 names.append(name)
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         names = tuple(names)
-        return init_device_mesh(device_type, dims, mesh_dim_names=names)
+        world_mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+        return world_mesh
 
     @property
     def dp_enabled(self):
         return self.dp > 1
 
+    @property
+    def cp_enabled(self):
+        return self.cp > 1
+
     @property
     def tp_enabled(self):
         return self.tp > 1
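For reference, a minimal usage sketch of the extended ParallelDims (illustrative only; it assumes the remaining constructor fields, enable_loss_parallel and dp_type, are unchanged from the existing class). With 8 ranks, cp=2 and tp=2, the data-parallel degree is inferred as 2, and build_mesh yields a 3-D ("dp", "cp", "tp") mesh when run under torchrun:

# Illustrative sketch, not part of the patch.
from torchtitan.parallelisms import ParallelDims

parallel_dims = ParallelDims(
    dp=-1,  # inferred by _validate(): 8 // (cp * tp * pp) = 2
    cp=2,
    tp=2,
    pp=1,
    world_size=8,
    enable_loss_parallel=False,  # assumed pre-existing field
    dp_type="fsdp",              # assumed pre-existing field
)
assert parallel_dims.dp == 2 and parallel_dims.cp_enabled

# Under torchrun with 8 ranks this builds a ("dp", "cp", "tp") DeviceMesh;
# the apply_fsdp change below then shards over the flattened ("dp", "cp") sub-mesh.
# world_mesh = parallel_dims.build_mesh(device_type="cuda")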
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index e3c6fc80..21b0f6bb 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -19,6 +19,11 @@
 from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
+
+try:
+    from torch.distributed._tensor.experimental.attention import enable_context_parallel
+except ImportError:
+    print("The PyTorch version does not include the experimental CP APIs.")
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
 )
@@ -472,6 +477,22 @@ def apply_compile(model: nn.Module, job_config: JobConfig):
     return model
 
 
+def apply_cp(model, world_mesh, parallel_dims, job_config: JobConfig):
+    """
+    Apply context parallelism to the model. This is an experimental feature.
+    """
+    if parallel_dims.pp_enabled:
+        raise NotImplementedError("CP + PP is not supported yet.")
+    cp_mesh = world_mesh["cp"]
+    callers = []
+    for layer_id, transformer_block in model.layers.items():
+        callers.append(transformer_block.attention)
+    enable_context_parallel(seq_dim=2, callers=callers, device_mesh=cp_mesh)
+    logger.info("Applied CP to the model")
+
+    return model
+
+
 def apply_fsdp(
     model: nn.Module,
     world_mesh: DeviceMesh,
@@ -482,8 +503,14 @@
     Apply data parallelism to the model. FSDP2 is used here.
     """
-    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+    if parallel_dims.cp_enabled:
+        # Temporary solution to enable FSDP + CP
+        if parallel_dims.dp_enabled:
+            dp_mesh = world_mesh["dp", "cp"]._flatten()
+        else:
+            dp_mesh = world_mesh["cp"]
+    else:
+        dp_mesh = world_mesh["dp"]
 
     mp_policy = MixedPrecisionPolicy(
         param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
@@ -570,7 +597,10 @@ def parallelize_llama(
     if job_config.training.compile:
         model = apply_compile(model, job_config)
 
-    if parallel_dims.dp_enabled:
+    if parallel_dims.cp_enabled:
+        model = apply_cp(model, world_mesh, parallel_dims, job_config)
+
+    if parallel_dims.dp_enabled or parallel_dims.cp_enabled:
         if parallel_dims.dp_type == "fsdp":
             model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
         else:
diff --git a/train.py b/train.py
index 5a637f46..11c2ec22 100644
--- a/train.py
+++ b/train.py
@@ -11,18 +11,21 @@
 from dataclasses import dataclass, field
 from datetime import timedelta
+from functools import partial
 from io import BytesIO
 from timeit import default_timer as timer
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 
 import numpy as np
 
 import torch
 import torch.nn.functional as F
 from torch.distributed import destroy_process_group
+from torch.distributed._tensor.experimental.attention import context_parallel_buffers
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
+from torch.distributed.device_mesh import DeviceMesh
 
 from torchtitan.checkpoint import CheckpointManager
 from torchtitan.config_manager import JobConfig
@@ -139,18 +142,40 @@ def zero_grad(self):
     return OptimizersContainer([_build_optimizer(model) for model in model_parts])
 
 
-def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
+def get_train_context(
+    enable_loss_parallel: bool,
+    enable_compiled_autograd: bool,
+    cp_mesh: Optional[DeviceMesh],
+):
+    if cp_mesh is not None:
+        context_parallel_ctx = partial(context_parallel_buffers, mesh=cp_mesh)
+    else:
+        context_parallel_ctx = partial(context_parallel_buffers, mesh=None)
+
     @contextlib.contextmanager
-    def context():
+    def context(
+        cp_buffers: List[torch.Tensor],
+        cp_seq_dims: List[int],
+        cp_restore_funcs: List[Optional[Callable]],
+    ):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(loss_parallel())
+
             if enable_compiled_autograd:
                 stack.enter_context(
                     torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                 )
-            yield
+
+            buffers = stack.enter_context(
+                context_parallel_ctx(
+                    buffers=cp_buffers,
+                    seq_dims=cp_seq_dims,
+                    restore_funcs=cp_restore_funcs,
+                )
+            )
+
+            yield buffers
 
     return context
 
@@ -173,6 +198,7 @@ def main(job_config: JobConfig):
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
@@ -215,6 +241,7 @@ def main(job_config: JobConfig):
     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
         job_config.experimental.enable_compiled_autograd,
+        world_mesh["cp"] if parallel_dims.cp_enabled else None,
     )
 
     # loss fn can be shared by pipeline-parallel or non-pp execution
@@ -376,18 +403,28 @@ def loss_fn(pred, labels):
             data_load_start = timer()
             batch = next(data_iterator)
             input_ids, labels = batch
-            ntokens_since_last_log += labels.numel()
+            ntokens_since_last_log += labels.numel() // parallel_dims.cp
             data_loading_times.append(timer() - data_load_start)
-            input_ids = input_ids.cuda()
-            labels = labels.cuda()
 
             optimizers.zero_grad()
 
-            if parallel_dims.pp_enabled:
-                # pipeline parallel forward / backward inside step() call
-                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+            with train_context(
+                cp_buffers=[input_ids, labels, model.freqs_cis],
+                cp_seq_dims=[1, 1, 0],
+                cp_restore_funcs=[
+                    None,
+                    None,
+                    lambda buf, m=model: setattr(m, "freqs_cis", buf),
+                ],
+            ) as cp_buffers:
+                input_ids = cp_buffers[0].cuda()
+                labels = cp_buffers[1].cuda()
+                model.freqs_cis = cp_buffers[2]
+
+                if parallel_dims.pp_enabled:
+                    # pipeline parallel forward / backward inside step() call
+                    is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
 
-                with train_context():
                     if pp_mesh.get_local_rank() == 0:
                         pp_schedule.step(input_ids)
                     elif is_last_stage:
@@ -396,15 +433,13 @@ def loss_fn(pred, labels):
                     else:
                         pp_schedule.step()
 
-                # accumulate losses across pipeline microbatches
-                loss = (
-                    torch.mean(torch.stack(losses))
-                    if is_last_stage
-                    else torch.Tensor([-1.0])
-                )
-            else:
-                # Non-PP forward / backward
-                with train_context():
+                    loss = (
+                        torch.mean(torch.stack(losses))
+                        if is_last_stage
+                        else torch.Tensor([-1.0])
+                    )
+                else:
+                    # Non-PP forward / backward
                     pred = model(input_ids)
                     loss = loss_fn(pred, labels)
                     # pred.shape=(bs, seq_len, vocab_size)
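To make the train.py changes easier to follow, here is a self-contained sketch of what the CP buffer context conceptually does: inside the context each buffer (input_ids and labels along dim 1, freqs_cis along dim 0) is replaced by the local shard for the current CP rank, and the restore functions hand the original buffers back on exit. This is a simplified stand-in, not the actual implementation or signature of the experimental torch.distributed._tensor.experimental.attention.context_parallel_buffers API.

# Conceptual sketch only; the name, signature, and single-process behavior here
# are assumptions made for illustration, not the experimental PyTorch API.
import contextlib
from typing import Callable, List, Optional

import torch


@contextlib.contextmanager
def shard_buffers_for_cp(
    buffers: List[torch.Tensor],
    seq_dims: List[int],
    restore_funcs: List[Optional[Callable]],
    cp_rank: int,
    cp_degree: int,
):
    # Replace each buffer with its cp_rank-th chunk along its sequence dimension.
    shards = [
        buf.chunk(cp_degree, dim=dim)[cp_rank]
        for buf, dim in zip(buffers, seq_dims)
    ]
    try:
        yield shards
    finally:
        # Hand the original (unsharded) buffers back to their owners,
        # e.g. to restore model.freqs_cis after the step.
        for buf, restore in zip(buffers, restore_funcs):
            if restore is not None:
                restore(buf)


# With cp_degree=2, rank 0 sees the first half of every sequence.
x = torch.arange(8).reshape(1, 8)  # (batch, seq)
freqs = torch.randn(8, 4)          # (seq, head_dim // 2)
with shard_buffers_for_cp([x, freqs], [1, 0], [None, None], cp_rank=0, cp_degree=2) as s:
    assert s[0].shape == (1, 4) and s[1].shape == (4, 4)

What the sketch does not model is the attention-level communication: in the patch, enable_context_parallel, applied to each transformer_block.attention in apply_cp, is what lets the CP ranks cooperate so attention is still computed over the full sequence.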