Commit

Enable CP
This PR adds experimental flags and functions to enable context parallelism (CP). We currently support FSDP + CP and CP alone; CP + TP is still being tested.

ghstack-source-id: 5d4f276bcff9ff53c2caadd161161a9b7a33142a
Pull Request resolved: #433
fegin committed Aug 7, 2024
1 parent b069f70 commit 66fea42
Showing 6 changed files with 111 additions and 31 deletions.
estimation.py (1 change: 1 addition & 0 deletions)
@@ -70,6 +70,7 @@ def estimate_memory(job_config: JobConfig):
 
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
torchtitan/config_manager.py (6 changes: 6 additions & 0 deletions)
@@ -323,6 +323,12 @@ def __init__(self):
             action="store_true",
             help="Enable CompiledAutograd to compile the backward.",
         )
+        self.parser.add_argument(
+            "--experimental.context_parallel_degree",
+            type=int,
+            default=1,
+            help="Context parallelism degree. 1 means disabled.",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
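The new flag slots into the existing argparse-based JobConfig, so it can be passed on the command line like the other experimental options. A minimal usage sketch follows; only the flag name and default come from this diff, while JobConfig and its parse_args helper are assumed to behave as in the rest of torchtitan:

from torchtitan.config_manager import JobConfig

# Hypothetical sketch: parse the new experimental flag and read it back.
config = JobConfig()
config.parse_args(["--experimental.context_parallel_degree", "2"])
assert config.experimental.context_parallel_degree == 2  # the default of 1 means CP is disabled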
torchtitan/models/llama/model.py (2 changes: 1 addition & 1 deletion)
@@ -410,7 +410,7 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.dim // self.model_args.n_heads,
             # Need to compute until at least the max token limit for generation
             # (use 2x max sequence length to be safe)
-            self.model_args.max_seq_len * 2,
+            self.model_args.max_seq_len,
             self.model_args.rope_theta,
         )

torchtitan/parallelisms/__init__.py (22 changes: 15 additions & 7 deletions)
@@ -24,6 +24,7 @@
 @dataclass
 class ParallelDims:
     dp: int
+    cp: int
     tp: int
     pp: int
     world_size: int
@@ -35,34 +36,41 @@ def __post_init__(self):
         self._validate()
 
     def _validate(self):
-        dp, tp, pp = self.dp, self.tp, self.pp
+        dp, cp, tp, pp = self.dp, self.cp, self.tp, self.pp
         if dp == -1:
-            self.dp = dp = self.world_size // (tp * pp)
+            self.dp = dp = self.world_size // (cp * tp * pp)
         assert dp >= 1, dp
+        assert cp >= 1, cp
         assert tp >= 1, tp
         assert pp >= 1, pp
-        assert (
-            dp * tp * pp == self.world_size
-        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+        assert dp * cp * tp * pp == self.world_size, (
+            f"Invalid parallel dims: dp({dp}) * cp ({cp}) * tp({tp}) * pp({pp}) "
+            f"!= WORLD_SIZE({self.world_size})"
+        )
         assert self.dp_type in ("fsdp", "ddp")
 
     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
+            [self.pp, self.dp, self.cp, self.tp], ["pp", "dp", "cp", "tp"], strict=True
         ):
             if d > 1:
                 dims.append(d)
                 names.append(name)
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         names = tuple(names)
-        return init_device_mesh(device_type, dims, mesh_dim_names=names)
+        world_mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+        return world_mesh
 
     @property
     def dp_enabled(self):
         return self.dp > 1
 
+    @property
+    def cp_enabled(self):
+        return self.cp > 1
+
     @property
     def tp_enabled(self):
         return self.tp > 1
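To make the degree bookkeeping concrete, here is a small sketch of the extended validation. It is a sketch only: the import path and the trailing constructor fields enable_loss_parallel and dp_type are assumed from the surrounding torchtitan code, and build_mesh itself needs an initialized process group (e.g. under torchrun), so only the validation step is exercised:

from torchtitan.parallelisms import ParallelDims

# 8 ranks, CP degree 2, TP degree 2, no PP; dp=-1 asks _validate() to infer DP.
dims = ParallelDims(
    dp=-1, cp=2, tp=2, pp=1, world_size=8,
    enable_loss_parallel=True, dp_type="fsdp",
)
assert dims.dp == 2  # inferred as 8 // (cp * tp * pp)
assert dims.cp_enabled and dims.tp_enabled and not dims.pp_enabled
# dims.build_mesh("cuda") would then return a 3-D mesh named ("dp", "cp", "tp"),
# since the pp == 1 dimension is dropped by the `if d > 1` filter above.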
torchtitan/parallelisms/parallelize_llama.py (36 changes: 33 additions & 3 deletions)
@@ -19,6 +19,11 @@
 
 from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
+
+try:
+    from torch.distributed._tensor.experimental.attention import enable_context_parallel
+except ImportError:
+    print("The PyTorch version does not include the experimental CP APIs.")
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
 )
@@ -472,6 +477,22 @@ def apply_compile(model: nn.Module, job_config: JobConfig):
     return model
 
 
+def apply_cp(model, world_mesh, parallel_dims, job_config: JobConfig):
+    """
+    Apply context parallelism to the model. This is an experimental feature.
+    """
+    if parallel_dims.pp_enabled:
+        raise NotImplementedError("CP + PP is not supported yet.")
+    cp_mesh = world_mesh["cp"]
+    callers = []
+    for layer_id, transformer_block in model.layers.items():
+        callers.append(transformer_block.attention)
+    enable_context_parallel(seq_dim=2, callers=callers, device_mesh=cp_mesh)
+    logger.info("Applied CP to the model")
+
+    return model
+
+
 def apply_fsdp(
     model: nn.Module,
     world_mesh: DeviceMesh,
@@ -482,8 +503,14 @@
     Apply data parallelism to the model. FSDP2 is used here.
     """
 
-    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+    if parallel_dims.cp_enabled:
+        # Temporary solution to enable FSDP + CP
+        if parallel_dims.dp_enabled:
+            dp_mesh = world_mesh["dp", "cp"]._flatten()
+        else:
+            dp_mesh = world_mesh["cp"]
+    else:
+        dp_mesh = world_mesh["dp"]
 
     mp_policy = MixedPrecisionPolicy(
         param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
@@ -570,7 +597,10 @@ def parallelize_llama(
     if job_config.training.compile:
         model = apply_compile(model, job_config)
 
-    if parallel_dims.dp_enabled:
+    if parallel_dims.cp_enabled:
+        model = apply_cp(model, world_mesh, parallel_dims, job_config)
+
+    if parallel_dims.dp_enabled or parallel_dims.cp_enabled:
         if parallel_dims.dp_type == "fsdp":
             model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
         else:
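The least obvious piece above is the dp_mesh handling in apply_fsdp: FSDP shards parameters over all DP and CP ranks (CP itself only splits the sequence inside attention), so the "dp" and "cp" mesh dimensions are flattened into a single 1-D mesh. A small sketch of that mesh manipulation, assuming 4 ranks launched via torchrun with dp=2 and cp=2; the tuple slicing and the private _flatten() call are the same DeviceMesh APIs this diff relies on:

from torch.distributed.device_mesh import init_device_mesh

# Run under `torchrun --nproc_per_node=4 ...` so a process group exists.
world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "cp"))

cp_mesh = world_mesh["cp"]                   # 1-D sub-mesh CP uses to split attention inputs
dp_mesh = world_mesh["dp", "cp"]._flatten()  # 1-D mesh over all 4 ranks that FSDP shards over

When only CP is enabled (dp == 1), the else branch instead hands the "cp" sub-mesh to FSDP directly.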
train.py (75 changes: 55 additions & 20 deletions)
@@ -11,18 +11,21 @@
 
 from dataclasses import dataclass, field
 from datetime import timedelta
+from functools import partial
 from io import BytesIO
 from timeit import default_timer as timer
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 
 import numpy as np
 
 import torch
 import torch.nn.functional as F
 from torch.distributed import destroy_process_group
+from torch.distributed._tensor.experimental.attention import context_parallel_buffers
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
+from torch.distributed.device_mesh import DeviceMesh
 
 from torchtitan.checkpoint import CheckpointManager
 from torchtitan.config_manager import JobConfig
@@ -139,18 +142,40 @@ def zero_grad(self):
     return OptimizersContainer([_build_optimizer(model) for model in model_parts])
 
 
-def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
+def get_train_context(
+    enable_loss_parallel: bool,
+    enable_compiled_autograd: bool,
+    cp_mesh: Optional[DeviceMesh],
+):
+    if cp_mesh is not None:
+        context_parallel_ctx = partial(context_parallel_buffers, mesh=cp_mesh)
+    else:
+        context_parallel_ctx = partial(context_parallel_buffers, mesh=None)
+
     @contextlib.contextmanager
-    def context():
+    def context(
+        cp_buffers: List[torch.Tensor],
+        cp_seq_dims: List[int],
+        cp_restore_funcs: List[Optional[Callable]],
+    ):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(loss_parallel())
 
             if enable_compiled_autograd:
                 stack.enter_context(
                     torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                 )
 
-            yield
+            buffers = stack.enter_context(
+                context_parallel_ctx(
+                    buffers=cp_buffers,
+                    seq_dims=cp_seq_dims,
+                    restore_funcs=cp_restore_funcs,
+                )
+            )
+
+            yield buffers
 
     return context

@@ -173,6 +198,7 @@ def main(job_config: JobConfig):
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
@@ -215,6 +241,7 @@ def main(job_config: JobConfig):
     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
         job_config.experimental.enable_compiled_autograd,
+        world_mesh["cp"] if parallel_dims.cp_enabled else None,
     )
 
     # loss fn can be shared by pipeline-parallel or non-pp execution
@@ -376,18 +403,28 @@ def loss_fn(pred, labels):
         data_load_start = timer()
         batch = next(data_iterator)
         input_ids, labels = batch
-        ntokens_since_last_log += labels.numel()
+        ntokens_since_last_log += labels.numel() // parallel_dims.cp
         data_loading_times.append(timer() - data_load_start)
 
-        input_ids = input_ids.cuda()
-        labels = labels.cuda()
         optimizers.zero_grad()
 
-        if parallel_dims.pp_enabled:
-            # pipeline parallel forward / backward inside step() call
-            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+        with train_context(
+            cp_buffers=[input_ids, labels, model.freqs_cis],
+            cp_seq_dims=[1, 1, 0],
+            cp_restore_funcs=[
+                None,
+                None,
+                lambda buf, m=model: setattr(m, "freqs_cis", buf),
+            ],
+        ) as cp_buffers:
+            input_ids = cp_buffers[0].cuda()
+            labels = cp_buffers[1].cuda()
+            model.freqs_cis = cp_buffers[2]
+
+            if parallel_dims.pp_enabled:
+                # pipeline parallel forward / backward inside step() call
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
 
-            with train_context():
                 if pp_mesh.get_local_rank() == 0:
                     pp_schedule.step(input_ids)
                 elif is_last_stage:
@@ -396,15 +433,13 @@ def loss_fn(pred, labels):
                 else:
                     pp_schedule.step()
 
-            # accumulate losses across pipeline microbatches
-            loss = (
-                torch.mean(torch.stack(losses))
-                if is_last_stage
-                else torch.Tensor([-1.0])
-            )
-        else:
-            # Non-PP forward / backward
-            with train_context():
+                loss = (
+                    torch.mean(torch.stack(losses))
+                    if is_last_stage
+                    else torch.Tensor([-1.0])
+                )
+            else:
+                # Non-PP forward / backward
                 pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 # pred.shape=(bs, seq_len, vocab_size)
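For orientation, cp_seq_dims above names the sequence dimension of each buffer handed to the CP context: input_ids and labels are (bs, seq_len) tensors, so their sequence dim is 1, while freqs_cis is indexed by position along dim 0 (presumably also why this commit trims _precompute_freqs_cis from 2x max_seq_len to exactly max_seq_len, so it can be split in lockstep with the inputs). The cp_restore_funcs lambda appears to put the full freqs_cis back on the model once the context exits. A toy shape check with made-up sizes, using plain chunking only to illustrate the per-rank shapes rather than the actual sharding scheme of the experimental CP API:

import torch

bs, seq_len, cp = 4, 8192, 2
input_ids = torch.randint(0, 32_000, (bs, seq_len))  # sequence on dim 1
labels = torch.randint(0, 32_000, (bs, seq_len))     # sequence on dim 1
freqs_cis = torch.randn(seq_len, 64)                 # sequence on dim 0 (64 is a made-up head_dim // 2)

# Each of the cp ranks ends up with seq_len // cp positions of every buffer.
assert input_ids.chunk(cp, dim=1)[0].shape == (bs, seq_len // cp)
assert labels.chunk(cp, dim=1)[0].shape == (bs, seq_len // cp)
assert freqs_cis.chunk(cp, dim=0)[0].shape == (seq_len // cp, 64)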
