 from datetime import timedelta
 
 import torch
+
+from typing import List, Optional, Set
+from functools import partial
+
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
 )
 from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
 
+try:
+    from torch.distributed.tensor.experimental import context_parallel
+except ImportError:
+    print(
+        f"PyTorch version {torch.__version__} does not include the experimental "
+        "Context Parallel API. Please update to a newer version."
+    )
+
+
+def get_train_context(
+    enable_loss_parallel: bool,
+    enable_compiled_autograd: bool,
+    cp_mesh: Optional[DeviceMesh] = None,
+):
+    if cp_mesh is not None:
+        context_parallel_ctx = partial(context_parallel, mesh=cp_mesh)
 
-def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
     @contextlib.contextmanager
-    def context():
+    def context(
+        cp_buffers: List[torch.Tensor],
+        cp_seq_dims: List[int],
+        cp_no_restore_buffers: Set[torch.Tensor],
+    ):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
+
             if enable_compiled_autograd:
                 stack.enter_context(
                     torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                 )
+
+            if cp_mesh is not None:
+                # currently we only support these two SDP backends.
+                # TODO (xilunwu): support cuDNN backend
+                stack.enter_context(
+                    sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+                )
+                stack.enter_context(
+                    context_parallel_ctx(
+                        buffers=cp_buffers,
+                        buffer_seq_dims=cp_seq_dims,
+                        no_restore_buffers=cp_no_restore_buffers,
+                    )
+                )
+
             yield
 
     return context
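
A note on the API shape above: the inner context() now takes the per-step CP buffers as arguments, because input_ids and labels only exist once a batch has been loaded. Per the experimental API, context_parallel shards each tensor in cp_buffers along its entry in cp_seq_dims across cp_mesh for the duration of the context, and restores the original content of every buffer not listed in cp_no_restore_buffers on exit. A minimal usage sketch, mirroring the call sites added further down in this commit (world_mesh, model, input_ids, labels and loss_fn are assumed to exist as in main):

    train_context = get_train_context(
        enable_loss_parallel=False,
        enable_compiled_autograd=False,
        cp_mesh=world_mesh["cp"],  # 1-D "cp" sub-mesh of the device mesh
    )
    with train_context(
        cp_buffers=[input_ids, labels, model.freqs_cis],
        cp_seq_dims=[1, 1, 0],
        cp_no_restore_buffers={input_ids, labels},
    ):
        loss = loss_fn(model(input_ids), labels)
        loss.backward()
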
@@ -70,6 +111,7 @@ def main(job_config: JobConfig):
     parallel_dims = ParallelDims(
         dp_shard=job_config.training.data_parallel_shard_degree,
         dp_replicate=job_config.training.data_parallel_replicate_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
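
For the degree bookkeeping, the new cp dimension multiplies into the overall device count just like the existing ones. A rough sanity-check sketch, assuming ParallelDims keeps validating that the product of all degrees equals world_size as it does for the current dimensions (the numbers are illustrative only):

    # 16 GPUs split as 2-way FSDP x 2-way CP x 2-way TP x 2-way PP
    world_size = 16
    dp_replicate, dp_shard, cp, tp, pp = 1, 2, 2, 2, 2
    assert dp_replicate * dp_shard * cp * tp * pp == world_size
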
@@ -235,6 +277,7 @@ def loss_fn(pred, labels):
     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
         job_config.experimental.enable_compiled_autograd,
+        world_mesh["cp"] if parallel_dims.cp_enabled else None,
     )
 
     # variables used to keep info for metrics logging
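
Here world_mesh["cp"] selects the 1-D cp sub-mesh by dimension name, and it is only passed when context parallelism is enabled, so non-CP runs keep the previous behavior (cp_mesh=None). A small sketch of the mesh slicing, using a simplified 2-D mesh rather than the one torchtitan actually builds:

    from torch.distributed.device_mesh import init_device_mesh

    # illustrative 8-GPU mesh: 4-way data parallel x 2-way context parallel
    world_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp_shard", "cp"))
    cp_mesh = world_mesh["cp"]  # 1-D sub-mesh handed to get_train_context
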
@@ -268,18 +311,24 @@ def loss_fn(pred, labels):
         data_load_start = time.perf_counter()
         batch = next(data_iterator)
         input_ids, labels = batch
-        ntokens_since_last_log += labels.numel()
+        ntokens_since_last_log += labels.numel() // parallel_dims.cp
         data_loading_times.append(time.perf_counter() - data_load_start)
 
         input_ids = input_ids.cuda()
         labels = labels.cuda()
         optimizers.zero_grad()
 
+        training_context = train_context(
+            cp_buffers=[input_ids, labels, model.freqs_cis],
+            cp_seq_dims=[1, 1, 0],
+            cp_no_restore_buffers={input_ids, labels},
+        )
+
         if parallel_dims.pp_enabled:
             # Pipeline Parallel forward / backward inside step() call
             is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
 
-            with train_context():
+            with training_context:
                 if pp_mesh.get_local_rank() == 0:
                     pp_schedule.step(input_ids)
                 elif is_last_stage:
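
Two details worth spelling out in the hunk above. ntokens_since_last_log is divided by parallel_dims.cp, presumably because each rank now only processes 1/cp of every sequence, which keeps the per-device throughput numbers comparable to non-CP runs. In cp_seq_dims, input_ids and labels are shaped (bs, seq_len), so their sequence dimension is 1, while model.freqs_cis is indexed by position first, so its sequence dimension is 0; input_ids and labels are re-read from the dataloader on every step, which is presumably why they sit in cp_no_restore_buffers (their unsharded content need not be restored when the context exits), whereas freqs_cis is a persistent model buffer and is restored.
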
@@ -296,7 +345,7 @@ def loss_fn(pred, labels):
             )
         else:
             # Non-PP forward / backward
-            with train_context():
+            with training_context:
                 pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 # pred.shape=(bs, seq_len, vocab_size)