Commit b2c55fb

Enable CP
This PR adds experimental flags and functions to enable context parallelism. Only FSDP + CP and CP-only configurations are currently supported; CP + TP is still being tested.

ghstack-source-id: d3d3fba
Pull Request resolved: #433
1 parent 35969a5 commit b2c55fb

4 files changed (+130 / -37 lines)

torchtitan/config_manager.py

Lines changed: 6 additions & 0 deletions

@@ -323,6 +323,12 @@ def __init__(self):
             action="store_true",
             help="Enable CompiledAutograd to compile the backward.",
         )
+        self.parser.add_argument(
+            "--experimental.context_parallel_degree",
+            type=int,
+            default=1,
+            help="Context parallelism degree. 1 means disabled.",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,

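The new flag lands under the experimental namespace of JobConfig, so downstream code reads it as job_config.experimental.context_parallel_degree (see the train.py diff below). A minimal sketch of exercising it, assuming the parse_args helper that config_manager.py already provides:

    from torchtitan.config_manager import JobConfig

    config = JobConfig()
    # Hypothetical invocation; in practice the value comes from the CLI or a TOML config file.
    config.parse_args(["--experimental.context_parallel_degree", "2"])
    assert config.experimental.context_parallel_degree == 2  # default is 1, meaning CP disabled
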
torchtitan/parallelisms/__init__.py

Lines changed: 13 additions & 6 deletions

@@ -24,6 +24,7 @@
 @dataclass
 class ParallelDims:
     dp: int
+    cp: int
     tp: int
     pp: int
     world_size: int
@@ -35,22 +36,24 @@ def __post_init__(self):
         self._validate()
 
     def _validate(self):
-        dp, tp, pp = self.dp, self.tp, self.pp
+        dp, cp, tp, pp = self.dp, self.cp, self.tp, self.pp
         if dp == -1:
-            self.dp = dp = self.world_size // (tp * pp)
+            self.dp = dp = self.world_size // (cp * tp * pp)
         assert dp >= 1, dp
+        assert cp >= 1, cp
         assert tp >= 1, tp
         assert pp >= 1, pp
-        assert (
-            dp * tp * pp == self.world_size
-        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+        assert dp * cp * tp * pp == self.world_size, (
+            f"Invalid parallel dims: dp({dp}) * cp ({cp}) * tp({tp}) * pp({pp}) "
+            f"!= WORLD_SIZE({self.world_size})"
+        )
         assert self.dp_type in ("fsdp", "ddp")
 
     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
+            [self.pp, self.dp, self.cp, self.tp], ["pp", "dp", "cp", "tp"], strict=True
         ):
             if d > 1:
                 dims.append(d)
@@ -63,6 +66,10 @@ def build_mesh(self, device_type):
     def dp_enabled(self):
         return self.dp > 1
 
+    @property
+    def cp_enabled(self):
+        return self.cp > 1
+
     @property
     def tp_enabled(self):
         return self.tp > 1

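A quick worked example of the updated validation arithmetic in _validate(): with 8 ranks, cp=2, tp=1, pp=1, and dp left at -1 (auto), dp resolves to 4 and the product check passes. The numbers below are illustrative only, not from the PR:

    world_size, cp, tp, pp = 8, 2, 1, 1
    dp = world_size // (cp * tp * pp)        # dp == -1 resolves to 8 // (2 * 1 * 1) == 4
    assert dp * cp * tp * pp == world_size   # 4 * 2 * 1 * 1 == 8, so _validate() would pass
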
torchtitan/parallelisms/parallelize_llama.py

Lines changed: 59 additions & 1 deletion

@@ -16,10 +16,16 @@
 
 from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
+
+try:
+    from torch.distributed._tensor.experimental.attention import enable_context_parallel
+except ImportError:
+    print("The PyTorch version does not include the experimental CP APIs.")
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -453,12 +459,61 @@ def apply_compile(model, job_config: JobConfig):
     return model
 
 
+def apply_cp(model, world_mesh, parallel_dims, job_config: JobConfig):
+    """
+    Apply context parallelism to the model. This is an experimental feature.
+    """
+    if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
+        raise NotImplementedError("CP + TP or CP + PP are not supported yet.")
+    cp_mesh = world_mesh["cp"]
+    # If data parallelism is not enabled, we have to enable FSDP2 for
+    # gradient reduction.
+    mp_policy = MixedPrecisionPolicy(
+        param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+        reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+    )
+    fsdp_config = {"mesh": cp_mesh, "mp_policy": mp_policy}
+    callers = []
+    for layer_id, transformer_block in model.layers.items():
+        if not parallel_dims.dp_enabled:
+            reshard_after_forward = (
+                int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled
+            )
+            fully_shard(
+                transformer_block,
+                **fsdp_config,
+                reshard_after_forward=reshard_after_forward,
+            )
+            model.layers[layer_id] = transformer_block
+        callers.append(transformer_block.attention)
+
+    enable_context_parallel(seq_dim=2, callers=callers, device_mesh=cp_mesh)
+
+    if not parallel_dims.dp_enabled:
+        model = fully_shard(
+            model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
+        )
+    logger.info("Applied CP to the model")
+
+    return model
+
+
 def apply_fsdp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """
 
-    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
+    if parallel_dims.cp_enabled:
+        # Manually create another device mesh for now as we don't support
+        # submesh flattening/reshape yet.
+        dp_mesh = init_device_mesh(
+            world_mesh.device_type,
+            (parallel_dims.dp * parallel_dims.cp,),
+            mesh_dim_names=["dp"],
+        )
+    else:
+        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
+
     assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
 
     mp_policy = MixedPrecisionPolicy(
@@ -526,6 +581,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     if job_config.training.compile:
         model = apply_compile(model, job_config)
 
+    if parallel_dims.cp_enabled:
+        model = apply_cp(model, world_mesh, parallel_dims, job_config)
+
     if parallel_dims.dp_enabled:
         if parallel_dims.dp_type == "fsdp":
             model = apply_fsdp(model, world_mesh, parallel_dims, job_config)

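For intuition, enable_context_parallel(seq_dim=2, ...) shards the attention inputs along their sequence dimension, which is dim 2 in the usual (bs, n_heads, seq_len, head_dim) layout. The snippet below is a conceptual sketch of that sharding only, with made-up shapes; the experimental API additionally handles the cross-rank communication needed to attend over the full sequence:

    import torch

    bs, n_heads, seq_len, head_dim = 2, 8, 4096, 64
    cp_world_size, cp_rank = 2, 0  # e.g. world_mesh["cp"].size() and .get_local_rank()

    q = torch.randn(bs, n_heads, seq_len, head_dim)
    # Each CP rank keeps a contiguous slice of the sequence dimension (dim 2).
    q_local = q.chunk(cp_world_size, dim=2)[cp_rank]
    assert q_local.shape == (bs, n_heads, seq_len // cp_world_size, head_dim)
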
train.py

Lines changed: 52 additions & 30 deletions

@@ -11,6 +11,7 @@
 
 from dataclasses import dataclass, field
 from datetime import timedelta
+from functools import partial
 from io import BytesIO
 from timeit import default_timer as timer
 from typing import Any, Dict, List
@@ -20,6 +21,7 @@
 import torch
 import torch.nn.functional as F
 from torch.distributed import destroy_process_group
+from torch.distributed._tensor.experimental.attention import context_parallel_buffers
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
@@ -169,6 +171,7 @@ def main(job_config: JobConfig):
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
@@ -213,6 +216,20 @@ def main(job_config: JobConfig):
         job_config.experimental.enable_compiled_autograd,
     )
 
+    if parallel_dims.cp_enabled:
+        cp_mesh = world_mesh["cp"]
+        context_parallel_ctx = partial(
+            context_parallel_buffers,
+            cp_rank=cp_mesh.get_local_rank(),
+            cp_world_size=cp_mesh.size(),
+        )
+    else:
+        context_parallel_ctx = partial(
+            context_parallel_buffers,
+            cp_rank=0,
+            cp_world_size=1,
+        )
+
     # loss fn can be shared by pipeline-parallel or non-pp execution
     def loss_fn(pred, labels):
         return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
@@ -371,38 +388,43 @@ def loss_fn(pred, labels):
             ntokens_since_last_log += labels.numel()
             data_loading_times.append(timer() - data_load_start)
 
-            input_ids = input_ids.cuda()
-            labels = labels.cuda()
             optimizers.zero_grad()
 
-            if parallel_dims.pp_enabled:
-                # pipeline parallel forward / backward inside step() call
-                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-
-                with train_context():
-                    if pp_mesh.get_local_rank() == 0:
-                        pp_schedule.step(input_ids)
-                    elif is_last_stage:
-                        losses = []
-                        pp_schedule.step(target=labels, losses=losses)
-                    else:
-                        pp_schedule.step()
-
-                # accumulate losses across pipeline microbatches
-                loss = (
-                    torch.mean(torch.stack(losses))
-                    if is_last_stage
-                    else torch.Tensor([-1.0])
-                )
-            else:
-                # Non-PP forward / backward
-                with train_context():
-                    pred = model(input_ids)
-                    loss = loss_fn(pred, labels)
-                    # pred.shape=(bs, seq_len, vocab_size)
-                    # need to free to before bwd to avoid peaking memory
-                    del pred
-                    loss.backward()
+            with context_parallel_ctx(
+                buffers=[input_ids, labels, model.freqs_cis],
+                seq_dims=[1, 1, 0],
+                keep_orig_buffers=[False, False, True],
+            ):
+                input_ids = input_ids.cuda()
+                labels = labels.cuda()
+                if parallel_dims.pp_enabled:
+                    # pipeline parallel forward / backward inside step() call
+                    is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+                    with train_context():
+                        if pp_mesh.get_local_rank() == 0:
+                            pp_schedule.step(input_ids)
+                        elif is_last_stage:
+                            losses = []
+                            pp_schedule.step(target=labels, losses=losses)
+                        else:
+                            pp_schedule.step()
+
+                    # accumulate losses across pipeline microbatches
+                    loss = (
+                        torch.mean(torch.stack(losses))
+                        if is_last_stage
+                        else torch.Tensor([-1.0])
+                    )
+                else:
+                    # Non-PP forward / backward
+                    with train_context():
+                        pred = model(input_ids)
+                        loss = loss_fn(pred, labels)
+                        # pred.shape=(bs, seq_len, vocab_size)
+                        # need to free to before bwd to avoid peaking memory
+                        del pred
+                        loss.backward()
 
             # clip gradients
             for model in model_parts:
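The buffers, seq_dims, and keep_orig_buffers arguments line up positionally: input_ids and labels are (bs, seq_len) tensors sharded along dim 1, while the freqs_cis table is sharded along dim 0 with keep_orig_buffers=True so the full table remains available. A conceptual sketch of that mapping with made-up shapes, not the experimental API's internals:

    import torch

    cp_rank, cp_world_size = 0, 2
    input_ids = torch.randint(0, 32000, (8, 4096))  # (bs, seq_len)   -> seq_dims[0] == 1
    labels = torch.randint(0, 32000, (8, 4096))     # (bs, seq_len)   -> seq_dims[1] == 1
    freqs_cis = torch.randn(4096, 64)               # (seq_len, ...)  -> seq_dims[2] == 0

    def shard(t, dim):
        # Conceptually, each CP rank sees only its slice of the sequence dimension.
        return t.chunk(cp_world_size, dim=dim)[cp_rank]

    local_input_ids, local_labels = shard(input_ids, 1), shard(labels, 1)
    local_freqs_cis = shard(freqs_cis, 0)  # keep_orig_buffers=True for the full freqs_cis table
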