 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.tools.logging import logger


@@ -67,10 +68,6 @@ def parallelize_llama(
    ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
    """

-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
    if parallel_dims.tp_enabled:
        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
@@ -91,6 +88,11 @@ def parallelize_llama(
        )
        maybe_enable_async_tp(job_config, world_mesh["tp"])

+    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+    if parallel_dims.cp_enabled:
+        logger.info("Applied Context Parallel to the model")
+        apply_cp(model, world_mesh["cp"], use_flex_attn)
+
    model_compile_enabled = (
        job_config.compile.enable and "model" in job_config.compile.components
    )
@@ -131,9 +133,6 @@ def parallelize_llama(
        else:
            logger.info("Applied FSDP to the model")

-        if parallel_dims.cp_enabled:
-            logger.info("Applied Context Parallel to the model")
-
        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
@@ -328,3 +327,91 @@ def apply_ddp(
    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")
+
+
+def apply_cp(
+    model: nn.Module,
+    cp_mesh: DeviceMesh,
+    use_flex_attn: bool,
+) -> None:
+    """
+    Apply context parallelism to the model.
+    """
+    from torch.distributed.tensor.experimental._attention import (
+        _ContextParallel,
+        _enable_context_parallel_dispatcher,
+    )
+
+    # Apply context parallelism to every transformer block.
+    # TODO: make seq_dim configurable once the implementation doesn't assume 2
+    # internally.
+    if use_flex_attn:
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX
+        )
+    else:
+        # This is currently required because the DTensor dispatcher does not
+        # dispatch SDPA to the CP implementation. TorchTitan never disables CP
+        # dispatching since it is not needed, but the corresponding API,
+        # _disable_context_parallel_dispatcher, is available for users who
+        # have that use case.
+        _enable_context_parallel_dispatcher()
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA
+        )
+
+    for transformer_block in model.layers.values():
+        module = transformer_block.attention.inner_attention
+        if use_flex_attn:
+            module = module._flex_attention_kernel
+
+        parallelize_module(
+            module=module,
+            device_mesh=cp_mesh,
+            parallelize_plan=cp_plan,
+        )
+
+
+def cp_shard(
+    cp_mesh: DeviceMesh,
+    inputs: torch.Tensor,
+    labels: torch.Tensor,
+    attention_masks: AttentionMasksType,
+    order_sensitive_buffers: dict[str, torch.Tensor],
+    order_sensitive_buffers_seq_dims: dict[str, int],
+):
+    """
+    Shard inputs, labels, attention masks, and order-sensitive buffers along
+    their sequence dimensions across the CP mesh.
+    """
+    from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.nn.attention.flex_attention import BlockMask
385+
386+ load_balancer = None
387+ inputs , labels = _context_parallel_shard (
388+ mesh = cp_mesh ,
389+ buffers = (inputs , labels ),
390+ seq_dims = (1 , 1 ),
391+ load_balancer = load_balancer ,
392+ )
393+
394+ masks = (
395+ [attention_masks ]
396+ if isinstance (attention_masks , BlockMask )
397+ else list (attention_masks .values ())
398+ )
399+ masks = _context_parallel_shard (
400+ mesh = cp_mesh ,
401+ buffers = masks ,
402+ seq_dims = (2 ,) * len (masks ),
403+ load_balancer = load_balancer ,
404+ )
405+ attention_masks = (
406+ masks [0 ]
407+ if isinstance (attention_masks , BlockMask )
408+ else {k : v for k , v in zip (attention_masks .keys (), masks )}
409+ )
410+
411+ order_sensitive_buffers = _context_parallel_shard (
412+ mesh = cp_mesh ,
413+ buffers = order_sensitive_buffers ,
414+ seq_dims = order_sensitive_buffers_seq_dims ,
415+ load_balancer = load_balancer ,
416+ )
417+ return inputs , labels , attention_masks , order_sensitive_buffers
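
Note: below is a minimal sketch of how cp_shard might be wired into a training step. The variable names (inputs, labels, attention_masks, extra_buffers, extra_buffers_seq_dims) are illustrative assumptions, not torchtitan's actual trainer code, and the real call site may differ.

# Illustrative only; names here are assumptions rather than torchtitan's trainer.
if parallel_dims.cp_enabled:
    inputs, labels, attention_masks, extra_buffers = cp_shard(
        cp_mesh=world_mesh["cp"],
        inputs=inputs,
        labels=labels,
        attention_masks=attention_masks,
        order_sensitive_buffers=extra_buffers,
        order_sensitive_buffers_seq_dims=extra_buffers_seq_dims,
    )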