 from datetime import timedelta

 import torch
+
+from typing import List, Optional, Set
+from functools import partial
+
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.elastic.multiprocessing.errors import record

 from torchtitan import utils
 )
 from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling

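+# Context Parallel (CP) shards sequence-dimension work across a device mesh.
+# The API is experimental, so guard the import for older PyTorch builds.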
+try:
+    from torch.distributed.tensor.experimental import context_parallel
+except ImportError:
+    print(
+        f"PyTorch version {torch.__version__} does not include the experimental "
+        "Context Parallel API. Please update to a newer version."
+    )
+
+
+def get_train_context(
+    enable_loss_parallel: bool,
+    enable_compiled_autograd: bool,
+    cp_mesh: Optional[DeviceMesh] = None,
+):
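+    # Bind the CP mesh up front; the buffers to shard are supplied per training
+    # step when the returned context manager is created.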
+    if cp_mesh is not None:
+        context_parallel_ctx = partial(context_parallel, mesh=cp_mesh)

-def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
     @contextlib.contextmanager
-    def context():
+    def context(
+        cp_buffers: List[torch.Tensor],
+        cp_seq_dims: List[int],
+        cp_no_restore_buffers: Set[torch.Tensor],
+    ):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
+
             if enable_compiled_autograd:
                 stack.enter_context(
                     torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                 )
+
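+            # Shard this step's buffers along their sequence dims across the CP
+            # mesh; buffers in cp_no_restore_buffers are not restored on exit.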
+            if cp_mesh is not None:
+                stack.enter_context(
+                    context_parallel_ctx(
+                        buffers=cp_buffers,
+                        buffer_seq_dims=cp_seq_dims,
+                        no_restore_buffers=cp_no_restore_buffers,
+                    )
+                )
+
             yield

     return context
@@ -61,6 +96,7 @@ def main(job_config: JobConfig):
     parallel_dims = ParallelDims(
         dp_shard=job_config.training.data_parallel_shard_degree,
         dp_replicate=job_config.training.data_parallel_replicate_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
@@ -226,6 +262,7 @@ def loss_fn(pred, labels):
     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
         job_config.experimental.enable_compiled_autograd,
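+        # Pass the "cp" submesh only when context parallelism is enabled.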
+        world_mesh["cp"] if parallel_dims.cp_enabled else None,
     )

     # variables used to keep info for metrics logging
@@ -259,18 +296,24 @@ def loss_fn(pred, labels):
         data_load_start = time.perf_counter()
         batch = next(data_iterator)
         input_ids, labels = batch
-        ntokens_since_last_log += labels.numel()
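+        # With CP, each rank processes only 1/cp of the sequence, so scale the
+        # per-rank token count accordingly.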
+        ntokens_since_last_log += labels.numel() // parallel_dims.cp
         data_loading_times.append(time.perf_counter() - data_load_start)

         input_ids = input_ids.cuda()
         labels = labels.cuda()
         optimizers.zero_grad()

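+        # input_ids and labels are sharded along the sequence dim (1), freqs_cis
+        # along dim 0; input_ids and labels need not be restored after the step.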
+        training_context = train_context(
+            cp_buffers=[input_ids, labels, model.freqs_cis],
+            cp_seq_dims=[1, 1, 0],
+            cp_no_restore_buffers={input_ids, labels},
+        )
+
         if parallel_dims.pp_enabled:
             # Pipeline Parallel forward / backward inside step() call
             is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

-            with train_context():
+            with training_context:
                 if pp_mesh.get_local_rank() == 0:
                     pp_schedule.step(input_ids)
                 elif is_last_stage:
@@ -287,7 +330,7 @@ def loss_fn(pred, labels):
             )
         else:
             # Non-PP forward / backward
-            with train_context():
+            with training_context:
                 pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 # pred.shape=(bs, seq_len, vocab_size)