Commit aedd0a7

enable Context Parallel

ghstack-source-id: 5a8900c
Pull Request resolved: #592

1 parent 36fba84 · commit aedd0a7

File tree

15 files changed: +186 -63 lines changed

estimation.py

Lines changed: 3 additions & 2 deletions

@@ -14,14 +14,14 @@
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
+from torchtitan import utils
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_tokenizer
 from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
-from train import get_train_context
 
 
 def estimate_memory(job_config: JobConfig):
@@ -66,6 +66,7 @@ def estimate_memory(job_config: JobConfig):
     parallel_dims = ParallelDims(
         dp_shard=job_config.training.data_parallel_shard_degree,
         dp_replicate=job_config.training.data_parallel_replicate_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
@@ -94,7 +95,7 @@ def estimate_memory(job_config: JobConfig):
     tokenizer_type = model_name_to_tokenizer[model_name]
     tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
 
-    train_context = get_train_context(
+    train_context = utils.get_train_context(
         parallel_dims.loss_parallel_enabled,
         job_config.experimental.enable_compiled_autograd,
     )

test_runner.py

Lines changed: 35 additions & 0 deletions

@@ -306,6 +306,41 @@ def build_test_list():
             "hsdp+tp",
             ngpu=8,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_shard_degree=2",
+                    "--experimental.context_parallel_degree=2",
+                ]
+            ],
+            "FSDP+CP",
+            "fsdp+cp",
+            ngpu=4,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_shard_degree=2",
+                    "--training.data_parallel_replicate_degree=2",
+                    "--experimental.context_parallel_degree=2",
+                ]
+            ],
+            "HSDP+CP",
+            "hsdp+cp",
+            ngpu=8,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_shard_degree=2",
+                    "--training.tensor_parallel_degree=2",
+                    "--experimental.context_parallel_degree=2",
+                ]
+            ],
+            "FSDP+TP+CP",
+            "fsdp+tp+cp",
+            ngpu=8,
+        ),
         OverrideDefinitions(
             [
                 [

torchtitan/config_manager.py

Lines changed: 6 additions & 0 deletions

@@ -325,6 +325,12 @@ def __init__(self):
             action="store_true",
             help="Enable CompiledAutograd to compile the backward.",
         )
+        self.parser.add_argument(
+            "--experimental.context_parallel_degree",
+            type=int,
+            default=1,
+            help="Context parallelism degree. 1 means disabled.",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
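
For illustration, a minimal sketch of how the new flag surfaces on JobConfig; the parse_args call and the assertion are assumptions based on the existing config flow, not part of this diff:

from torchtitan.config_manager import JobConfig

# Hypothetical usage: parse a CLI override the same way the trainer does and
# read the new experimental field; the default of 1 keeps Context Parallel off.
config = JobConfig()
config.parse_args(["--experimental.context_parallel_degree=2"])
assert config.experimental.context_parallel_degree == 2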

torchtitan/models/llama/model.py

Lines changed: 3 additions & 2 deletions

@@ -415,8 +415,9 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
         return precompute_freqs_cis(
             self.model_args.dim // self.model_args.n_heads,
             # Need to compute until at least the max token limit for generation
-            # (use 2x max sequence length to be safe)
-            self.model_args.max_seq_len * 2,
+            # TODO: explain in docs/composability.md why we removed the 2x
+            # relaxing in our CP enablement PR
+            self.model_args.max_seq_len,
             self.model_args.rope_theta,
         )
 

torchtitan/parallelisms/parallel_dims.py

Lines changed: 23 additions & 9 deletions

@@ -15,6 +15,7 @@
 class ParallelDims:
     dp_replicate: int
     dp_shard: int
+    cp: int
     tp: int
     pp: int
     world_size: int
@@ -24,36 +25,38 @@ def __post_init__(self):
         self._validate()
 
     def _validate(self):
-        dp_replicate, dp_shard, tp, pp = (
+        dp_replicate, dp_shard, cp, tp, pp = (
             self.dp_replicate,
             self.dp_shard,
+            self.cp,
             self.tp,
             self.pp,
         )
-        for d in (dp_replicate, tp, pp):
+        for d in (dp_replicate, cp, tp, pp):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
         assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
 
         dp = dp_replicate * dp_shard
         if dp < 0:
-            dp = self.world_size // (tp * pp)
+            dp = self.world_size // (cp * tp * pp)
         self.dp_shard = dp_shard = dp // dp_replicate
 
         assert dp_replicate >= 1
         assert dp_shard >= 1
+        assert cp >= 1, cp
         assert tp >= 1, tp
         assert pp >= 1, pp
-        assert dp_replicate * dp_shard * tp * pp == self.world_size, (
+        assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
             f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
-            f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+            f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
         )
 
     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp_replicate, self.dp_shard, self.tp],
-            ["pp", "dp_replicate", "dp_shard", "tp"],
+            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
+            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
             if d > 1:
                 dims.append(d)
@@ -71,6 +74,13 @@ def build_mesh(self, device_type):
         # initialized
         if self.dp_replicate > 1 and self.dp_shard > 1:
             mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
+
+        if self.cp > 1:
+            if self.dp_replicate > 1 and self.dp_shard > 1:
+                mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
+            else:
+                mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
+
         return mesh
 
     @property
@@ -85,6 +95,10 @@ def dp_replicate_enabled(self):
     def dp_shard_enabled(self):
         return self.dp_shard > 1
 
+    @property
+    def cp_enabled(self):
+        return self.cp > 1
+
     @property
     def tp_enabled(self):
         return self.tp > 1
@@ -98,5 +112,5 @@ def loss_parallel_enabled(self):
         return self.tp > 1 and self.enable_loss_parallel
 
     @cached_property
-    def model_parallel_size(self):
-        return self.tp * self.pp
+    def non_data_parallel_size(self):
+        return self.cp * self.tp * self.pp
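
To make the new validation concrete, here is a small sketch (not part of the diff) of an 8-GPU layout combining FSDP, CP, and TP; the enable_loss_parallel field is assumed from the existing class definition:

from torchtitan.parallelisms import ParallelDims

# Hypothetical 8-GPU layout: dp_replicate(1) * dp_shard(2) * cp(2) * tp(2) * pp(1) == 8,
# so _validate() passes. A dp_shard of -1 would instead be inferred as
# world_size // (cp * tp * pp) = 8 // 4 = 2.
parallel_dims = ParallelDims(
    dp_replicate=1,
    dp_shard=2,
    cp=2,
    tp=2,
    pp=1,
    world_size=8,
    enable_loss_parallel=True,
)
assert parallel_dims.cp_enabled
assert parallel_dims.non_data_parallel_size == 4  # cp * tp * pp
# build_mesh("cuda") would then create the device mesh and, because cp > 1,
# flatten the data-parallel and cp dims into a joint "dp_cp" mesh dim.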

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 38 additions & 29 deletions

@@ -11,6 +11,7 @@
 
 import torch
 import torch.nn as nn
+
 from torch.distributed import DeviceMesh
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._composable.replicate import replicate
@@ -72,36 +73,44 @@ def parallelize_llama(
            )
        apply_compile(model)
 
-    if parallel_dims.dp_enabled:
-        if parallel_dims.dp_shard_enabled:
-            if parallel_dims.dp_replicate_enabled:
-                dp_mesh = world_mesh["dp_replicate", "dp_shard"]
-            else:
-                dp_mesh = world_mesh["dp"]
-
-            apply_fsdp(
-                model,
-                dp_mesh,
-                param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-                reduce_dtype=TORCH_DTYPE_MAP[
-                    job_config.training.mixed_precision_reduce
-                ],
-                tp_enabled=parallel_dims.tp_enabled,
-                pp_enabled=parallel_dims.pp_enabled,
-            )
-            if parallel_dims.dp_replicate_enabled:
-                logger.info("Applied HSDP to the model")
-            else:
-                logger.info("Applied FSDP to the model")
+    if (
+        parallel_dims.dp_shard_enabled
+    ):  # apply FSDP or HSDP, potentially with Context Parallel
+        dp_mesh_dim_names = (
+            ("dp_replicate", "dp_shard")
+            if parallel_dims.dp_replicate_enabled
+            else ("dp",)
+        )
+        dp_mesh = (
+            world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
+            if parallel_dims.cp_enabled
+            else world_mesh[dp_mesh_dim_names]
+        )
+        apply_fsdp(
+            model,
+            dp_mesh,
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+            tp_enabled=parallel_dims.tp_enabled,
+            pp_enabled=parallel_dims.pp_enabled,
+        )
+
+        if parallel_dims.dp_replicate_enabled:
+            logger.info("Applied HSDP to the model")
         else:
-            if world_mesh.ndim > 1:
-                raise RuntimeError("DDP has not supported > 1D parallelism")
-            apply_ddp(
-                model,
-                world_mesh,
-                enable_compile=job_config.training.compile,
-                enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
-            )
+            logger.info("Applied FSDP to the model")
+
+        if parallel_dims.cp_enabled:
+            logger.info("Applied Context Parallel to the model")
+    elif parallel_dims.dp_replicate_enabled:
+        if world_mesh.ndim > 1:
+            raise RuntimeError("DDP has not supported > 1D parallelism")
+        apply_ddp(
+            model,
+            world_mesh,
+            enable_compile=job_config.training.compile,
+            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
+        )
 
 
 def apply_tp(
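
The key change above is that, with CP enabled, FSDP shards over a mesh flattened across the data-parallel and context-parallel dims, so parameter sharding and gradient reduction span both groups. A standalone sketch of that mesh manipulation, using a hypothetical 2x2 mesh rather than the real world_mesh (needs 4 ranks, e.g. torchrun --nproc_per_node=4):

from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 4-rank illustration: 2 FSDP ranks x 2 CP ranks. With cp enabled,
# apply_fsdp receives the mesh flattened across both dims ("dp_cp"), mirroring
# the dp_mesh selection in parallelize_llama above.
world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "cp"))
dp_mesh = world_mesh["dp", "cp"]._flatten("dp_cp")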

torchtitan/utils.py

Lines changed: 53 additions & 1 deletion

@@ -4,12 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import contextlib
 import gc
 import os
 import subprocess
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Optional, Union
+from typing import Generator, List, Optional, Set, Union
 
 import torch
 import torch.distributed._functional_collectives as funcol
@@ -101,6 +102,57 @@ def run(self, step_count):
 SKIP_CLEANUP = "3"
 
 
+def create_context_parallel_ctx(
+    cp_mesh: DeviceMesh,
+    cp_buffers: List[torch.Tensor],
+    cp_seq_dims: List[int],
+    cp_no_restore_buffers: Set[torch.Tensor],
+):
+    try:
+        from torch.distributed.tensor.experimental import context_parallel
+    except ImportError:
+        print(
+            f"PyTorch version {torch.__version__} does not include the experimental "
+            "Context Parallel API. Please update to a newer version."
+        )
+
+    return context_parallel(
+        cp_mesh,
+        buffers=cp_buffers,
+        buffer_seq_dims=cp_seq_dims,
+        no_restore_buffers=cp_no_restore_buffers,
+    )
+
+
+def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
+    @contextlib.contextmanager
+    def context(cp_context: Optional[Generator[None, None, None]] = None):
+        with contextlib.ExitStack() as stack:
+            if enable_loss_parallel:
+                stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
+
+            if enable_compiled_autograd:
+                stack.enter_context(
+                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
+                )
+
+            if cp_context is not None:
+                from torch.nn.attention import sdpa_kernel, SDPBackend
+
+                # currently we only support these two SDP backends.
+                # TODO (xilunwu): support cuDNN backend
+                stack.enter_context(
+                    sdpa_kernel(
+                        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+                    )
+                )
+                stack.enter_context(cp_context)
+
+            yield
+
+    return context
+
+
 def init_distributed(job_config):
     # FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
     # to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
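
For context, a sketch of how these two helpers are meant to compose in a training step. The buffer list, sequence dims, and surrounding names (input_ids, labels, loss_fn, world_mesh, parallel_dims, job_config, model) are illustrative assumptions; the actual wiring lives in train.py, which is not shown in this excerpt:

from torchtitan import utils

# Hypothetical training-step wiring under the assumptions stated above.
train_context = utils.get_train_context(
    parallel_dims.loss_parallel_enabled,
    job_config.experimental.enable_compiled_autograd,
)

optional_context_parallel_ctx = (
    utils.create_context_parallel_ctx(
        cp_mesh=world_mesh["cp"],
        cp_buffers=[input_ids, labels, model.freqs_cis],
        cp_seq_dims=[1, 1, 0],  # token dim for inputs/labels, seq dim for the rope cache
        cp_no_restore_buffers={input_ids, labels},
    )
    if parallel_dims.cp_enabled
    else None
)

with train_context(optional_context_parallel_ctx):
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    loss.backward()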
