Commit 3a4ab5d

This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and standalone CP; CP + TP is being tested.

ghstack-source-id: d57fcdae2fdc2481722471d8d4efbb4f416fe396
Pull Request resolved: #433
fegin committed Jun 27, 2024
1 parent 681a94d commit 3a4ab5d
Showing 4 changed files with 131 additions and 36 deletions.
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -323,6 +323,12 @@ def __init__(self):
action="store_true",
help="Enable CompiledAutograd to compile the backward.",
)
self.parser.add_argument(
"--experimental.context_parallel_degree",
type=int,
default=1,
help="Context parallelism degree. 1 means disabled.",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
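For context (not part of the diff), here is a minimal, self-contained sketch of how the new flag parses. It mirrors the add_argument call above with plain argparse rather than going through torchtitan's JobConfig wrapper; the value 2 is purely illustrative.

import argparse

# Mirror of the flag added to config_manager.py above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.context_parallel_degree",
    type=int,
    default=1,
    help="Context parallelism degree. 1 means disabled.",
)

# argparse keeps the dot in the destination name, so getattr is needed.
args = parser.parse_args(["--experimental.context_parallel_degree", "2"])
print(getattr(args, "experimental.context_parallel_degree"))  # -> 2
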
19 changes: 14 additions & 5 deletions torchtitan/parallelisms/__init__.py
@@ -24,6 +24,7 @@
@dataclass
class ParallelDims:
dp: int
cp: int
tp: int
pp: int
world_size: int
@@ -35,22 +36,26 @@ def __post_init__(self):
self._validate()

def _validate(self):
dp, tp, pp = self.dp, self.tp, self.pp
dp, cp, tp, pp = self.dp, self.cp, self.tp, self.pp
if dp == -1:
self.dp = dp = self.world_size // (tp * pp)
self.dp = dp = self.world_size // (cp * tp * pp)
assert dp >= 1, dp
assert cp >= 1, cp
assert tp >= 1, tp
assert pp >= 1, pp
assert (
dp * tp * pp == self.world_size
), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
dp * cp * tp * pp == self.world_size
), (
f"Invalid parallel dims: dp({dp}) * cp ({cp}) * tp({tp}) * pp({pp}) "
f"!= WORLD_SIZE({self.world_size})"
)
assert self.dp_type in ("fsdp", "ddp")

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
[self.pp, self.dp, self.cp, self.tp], ["pp", "dp", "cp", "tp"], strict=True
):
if d > 1:
dims.append(d)
@@ -63,6 +68,10 @@ def build_mesh(self, device_type):
def dp_enabled(self):
return self.dp > 1

@property
def cp_enabled(self):
return self.cp > 1

@property
def tp_enabled(self):
return self.tp > 1
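A short sketch (illustrative values, not from the PR) of the arithmetic _validate now enforces and of how build_mesh treats the new dimension: dp == -1 is resolved from the remaining degrees, and only degrees greater than 1 become mesh dimensions, so a pure-CP run yields a 1-D "cp" mesh.

world_size = 8
dp, cp, tp, pp = -1, 2, 1, 1

# dp == -1 means "infer from world size", now dividing by cp as well.
if dp == -1:
    dp = world_size // (cp * tp * pp)   # -> 4

assert dp * cp * tp * pp == world_size

# build_mesh keeps only degrees > 1, in (pp, dp, cp, tp) order.
mesh_dims = [
    (name, d)
    for d, name in zip([pp, dp, cp, tp], ["pp", "dp", "cp", "tp"])
    if d > 1
]
print(mesh_dims)  # [('dp', 4), ('cp', 2)]
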
60 changes: 59 additions & 1 deletion torchtitan/parallelisms/parallelize_llama.py
@@ -12,14 +12,20 @@
from typing import Dict, Tuple

import torch
import torch.nn.functional as F

from torch.distributed._composable.replicate import replicate
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
try:
from torch.distributed._tensor.experimental.attention import enable_context_parallel
except ImportError:
print("The PyTorch version does not include the experimental CP APIs.")
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@@ -450,12 +456,61 @@ def apply_compile(model, job_config: JobConfig):
return model


def apply_cp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply context parallelism to the model. This is an experimental feature.
"""
if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
raise NotImplementedError("CP + TP and CP + PP are not supported yet.")
cp_mesh = world_mesh["cp"]
# If data parallelism is not enabled, we have to enable FSDP2 for
# gradient reduction.
mp_policy = MixedPrecisionPolicy(
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
)
fsdp_config = {"mesh": cp_mesh, "mp_policy": mp_policy}
callers = []
for layer_id, transformer_block in model.layers.items():
if not parallel_dims.dp_enabled:
reshard_after_forward = (
int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled
)
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block
callers.append(transformer_block.attention)

# seq_dim=2: q/k/v passed to scaled_dot_product_attention inside these modules
# are laid out as (bs, n_heads, seq_len, head_dim), so the sequence dim is 2.
enable_context_parallel(seq_dim=2, callers=callers, device_mesh=cp_mesh)

if not parallel_dims.dp_enabled:
model = fully_shard(
model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
)
logger.info("Applied CP to the model")

return model


def apply_fsdp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply data parallelism to the model. FSDP2 is used here.
"""

dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
if parallel_dims.cp_enabled:
# Manually create another device mesh for now as we don't support
# submesh flattening/reshape yet.
dp_mesh = init_device_mesh(
world_mesh.device_type,
(parallel_dims.dp * parallel_dims.cp,),
mesh_dim_names=["dp"],
)
else:
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh

assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names

mp_policy = MixedPrecisionPolicy(
@@ -521,6 +576,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
if job_config.training.compile:
model = apply_compile(model, job_config)

if parallel_dims.cp_enabled:
model = apply_cp(model, world_mesh, parallel_dims, job_config)

if parallel_dims.dp_enabled:
if parallel_dims.dp_type == "fsdp":
model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
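Note the ordering in parallelize_llama above: CP is applied before FSDP, and when both are enabled apply_fsdp builds a flattened 1-D mesh of size dp * cp because submesh flattening/reshaping is not supported yet. A hedged arithmetic sketch with illustrative degrees (not taken from the PR):

world_size = 8
dp, cp = 2, 4
assert dp * cp == world_size

# apply_cp: attention inputs are sharded along the sequence dim over the "cp" mesh.
cp_mesh_size = cp          # -> 4 ranks

# apply_fsdp: parameters and gradients are sharded over a manually created
# 1-D "dp" mesh that spans dp * cp ranks (see init_device_mesh above).
fsdp_mesh_size = dp * cp   # -> 8 ranks

print(cp_mesh_size, fsdp_mesh_size)  # 4 8
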
82 changes: 52 additions & 30 deletions train.py
@@ -11,6 +11,7 @@

from dataclasses import dataclass, field
from datetime import timedelta
from functools import partial
from io import BytesIO
from timeit import default_timer as timer
from typing import Any, Dict, List
@@ -20,6 +21,7 @@
import torch
import torch.nn.functional as F
from torch.distributed import destroy_process_group
from torch.distributed._tensor.experimental.attention import context_parallel_buffers
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor.parallel import loss_parallel
@@ -167,6 +169,7 @@ def main(job_config: JobConfig):
world_size = int(os.environ["WORLD_SIZE"])
parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
cp=job_config.experimental.context_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
@@ -211,6 +214,20 @@
job_config.experimental.enable_compiled_autograd
)

if parallel_dims.cp_enabled:
cp_mesh = world_mesh["cp"]
context_parallel_ctx = partial(
context_parallel_buffers,
cp_rank=cp_mesh.get_local_rank(),
cp_world_size=cp_mesh.size(),
)
else:
context_parallel_ctx = partial(
context_parallel_buffers,
cp_rank=0,
cp_world_size=1,
)

# loss fn can be shared by pipeline-parallel or non-pp execution
def loss_fn(pred, labels):
return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
@@ -369,38 +386,43 @@ def loss_fn(pred, labels):
ntokens_since_last_log += labels.numel()
data_loading_times.append(timer() - data_load_start)

input_ids = input_ids.cuda()
labels = labels.cuda()
optimizers.zero_grad()

if parallel_dims.pp_enabled:
# pipeline parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with train_context():
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
losses = []
pp_schedule.step(target=labels, losses=losses)
else:
pp_schedule.step()

# accumulate losses across pipeline microbatches
loss = (
torch.mean(torch.stack(losses))
if is_last_stage
else torch.Tensor([-1.0])
)
else:
# Non-PP forward / backward
with train_context():
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
# need to free pred before bwd to avoid a memory peak
del pred
loss.backward()
with context_parallel_ctx(
buffers=[input_ids, labels, model.freqs_cis],
seq_dims=[1, 1, 0],
keep_orig_buffers=[False, False, True],
):
input_ids = input_ids.cuda()
labels = labels.cuda()
if parallel_dims.pp_enabled:
# pipeline parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with train_context():
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
losses = []
pp_schedule.step(target=labels, losses=losses)
else:
pp_schedule.step()

# accumulate losses across pipeline microbatches
loss = (
torch.mean(torch.stack(losses))
if is_last_stage
else torch.Tensor([-1.0])
)
else:
# Non-PP forward / backward
with train_context():
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
# need to free pred before bwd to avoid a memory peak
del pred
loss.backward()

# clip gradients
for model in model_parts:
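To make seq_dims=[1, 1, 0] above concrete, here is a hedged, single-process sketch of the per-rank slicing that context_parallel_buffers is expected to perform; the experimental API may chunk differently (e.g. with load balancing), and the shapes below are illustrative assumptions, not values from the PR.

import torch

cp_rank, cp_world_size = 1, 2

input_ids = torch.zeros(8, 4096, dtype=torch.long)        # (bs, seq_len)          -> seq dim 1
labels = torch.zeros(8, 4096, dtype=torch.long)           # (bs, seq_len)          -> seq dim 1
freqs_cis = torch.zeros(4096, 64, dtype=torch.complex64)  # (seq_len, head_dim//2) -> seq dim 0

# Each CP rank keeps only its 1/cp_world_size slice along the sequence dim.
local_input_ids = input_ids.chunk(cp_world_size, dim=1)[cp_rank]  # (8, 2048)
local_labels = labels.chunk(cp_world_size, dim=1)[cp_rank]        # (8, 2048)
local_freqs_cis = freqs_cis.chunk(cp_world_size, dim=0)[cp_rank]  # (2048, 64)
print(local_input_ids.shape, local_labels.shape, local_freqs_cis.shape)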
