@@ -10,7 +10,13 @@
 from datetime import timedelta
 
 import torch
+
+from typing import List, Optional, Set
+from functools import partial
+
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -28,17 +34,52 @@
 )
 from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
 
+try:
+    from torch.distributed.tensor.experimental import context_parallel
+except ImportError:
+    print(
+        f"PyTorch version {torch.__version__} does not include the experimental "
+        "Context Parallel API. Please update to a newer version."
+    )
+
+
+def get_train_context(
+    enable_loss_parallel: bool,
+    enable_compiled_autograd: bool,
+    cp_mesh: Optional[DeviceMesh] = None,
+):
+    if cp_mesh is not None:
+        context_parallel_ctx = partial(context_parallel, mesh=cp_mesh)
 
-def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
     @contextlib.contextmanager
-    def context():
+    def context(
+        cp_buffers: List[torch.Tensor],
+        cp_seq_dims: List[int],
+        cp_no_restore_buffers: Set[torch.Tensor],
+    ):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
+
             if enable_compiled_autograd:
                 stack.enter_context(
                     torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                 )
+
+            if cp_mesh is not None:
+                # currently we only support these two SDP backends.
+                # TODO (xilunwu): support cuDNN backend
+                stack.enter_context(
+                    sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+                )
+                stack.enter_context(
+                    context_parallel_ctx(
+                        buffers=cp_buffers,
+                        buffer_seq_dims=cp_seq_dims,
+                        no_restore_buffers=cp_no_restore_buffers,
+                    )
+                )
+
             yield
 
     return context
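It may help to see how the new `context` closure is meant to be entered each training step. A minimal sketch, not part of the commit: the variable names mirror the call site added later in this diff, and the flag values are illustrative.

    train_context = get_train_context(
        enable_loss_parallel=True,
        enable_compiled_autograd=False,
        cp_mesh=cp_mesh,  # e.g. world_mesh["cp"], or None to run without Context Parallel
    )
    with train_context(
        cp_buffers=[input_ids, labels, model.freqs_cis],  # tensors sharded along their sequence dim
        cp_seq_dims=[1, 1, 0],                            # sequence dim of each buffer above
        cp_no_restore_buffers={input_ids, labels},        # not restored to full sequences on exit
    ):
        pred = model(input_ids)
        loss = loss_fn(pred, labels)
        loss.backward()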
@@ -61,6 +102,7 @@ def main(job_config: JobConfig):
     parallel_dims = ParallelDims(
         dp_shard=job_config.training.data_parallel_shard_degree,
         dp_replicate=job_config.training.data_parallel_replicate_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
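The new `cp` degree implies the world mesh now carries a "cp" dimension that `world_mesh["cp"]` (used below) can slice out. A hedged sketch of an equivalent mesh, built directly with `init_device_mesh` and hypothetical degrees rather than through `ParallelDims`:

    from torch.distributed.device_mesh import init_device_mesh

    # 8 GPUs split as dp_shard=2 x cp=2 x tp=2 (illustrative sizes only)
    world_mesh = init_device_mesh(
        "cuda", (2, 2, 2), mesh_dim_names=("dp_shard", "cp", "tp")
    )
    cp_mesh = world_mesh["cp"]  # the sub-mesh handed to get_train_context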
@@ -226,6 +268,7 @@ def loss_fn(pred, labels):
     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
         job_config.experimental.enable_compiled_autograd,
+        world_mesh["cp"] if parallel_dims.cp_enabled else None,
     )
 
     # variables used to keep info for metrics logging
@@ -259,18 +302,24 @@ def loss_fn(pred, labels):
         data_load_start = time.perf_counter()
         batch = next(data_iterator)
         input_ids, labels = batch
-        ntokens_since_last_log += labels.numel()
+        ntokens_since_last_log += labels.numel() // parallel_dims.cp
         data_loading_times.append(time.perf_counter() - data_load_start)
 
         input_ids = input_ids.cuda()
         labels = labels.cuda()
         optimizers.zero_grad()
 
+        training_context = train_context(
+            cp_buffers=[input_ids, labels, model.freqs_cis],
+            cp_seq_dims=[1, 1, 0],
+            cp_no_restore_buffers={input_ids, labels},
+        )
+
         if parallel_dims.pp_enabled:
             # Pipeline Parallel forward / backward inside step() call
             is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
 
-            with train_context():
+            with training_context:
                 if pp_mesh.get_local_rank() == 0:
                     pp_schedule.step(input_ids)
                 elif is_last_stage:
@@ -287,7 +336,7 @@ def loss_fn(pred, labels):
             )
         else:
             # Non-PP forward / backward
-            with train_context():
+            with training_context:
                 pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 # pred.shape=(bs, seq_len, vocab_size)
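On the throughput accounting change above: dividing `labels.numel()` by `parallel_dims.cp` credits each rank only with the 1/cp slice of every sequence it actually processes once the buffers are sharded by the context-parallel context; with the context-parallel degree left at its presumable default of 1, the division is a no-op. Rough arithmetic under assumed shapes (illustrative numbers, not from the commit):

    bs, seq_len, cp = 8, 8192, 2
    tokens_loaded_per_step = bs * seq_len                    # 65536 tokens in the local batch
    tokens_counted_per_rank = tokens_loaded_per_step // cp   # 32768 tokens credited per CP rank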