Commit d59132b
[RFC][WIP][CP] Enable FlexAttention CP for llama3

This PR uses the latest CP APIs to enable FlexAttention + CP for llama3. It removes the usage of the context_parallel() context manager and instead uses `_context_parallel_shard()` to shard the input data.

ghstack-source-id: 5d04d61
Pull-Request: #1857
1 parent eaa4393 commit d59132b

5 files changed: 207 additions & 51 deletions
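For orientation, here is a minimal sketch (not part of the commit) of the per-step flow this diff moves to: the FlexAttention masks and order-sensitive buffers are built once per batch, everything is sharded up front with the new `dist_utils.cp_shard()` helper, and forward/backward then runs without any CP context manager. The tensor names, the `tokenizer`, and the single-model-part setup are illustrative simplifications of the `torchtitan/train.py` changes below.

# Sketch only: assumes an initialized job (e.g. under torchrun) whose
# parallel_dims expose a 1-D "cp" sub-mesh, and a single llama3 model part.
attention_masks = model.get_attention_masks(
    input_batch=inputs, tokenizer=tokenizer, extra_inputs={}
)
buffers, buffer_seq_dims = model.get_order_sensitive_buffers(
    inputs.size(0), inputs.size(1)
)
inputs, labels, attention_masks, buffers = dist_utils.cp_shard(
    cp_mesh, inputs, labels, attention_masks, buffers, buffer_seq_dims
)
# No context_parallel() context manager: the already-sharded batch, masks,
# and buffers are fed straight into the model's forward.
pred = model(inputs, attention_masks=attention_masks, **buffers)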

torchtitan/distributed/utils.py

Lines changed: 51 additions & 3 deletions
@@ -19,6 +19,7 @@
 
 from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed.parallel_dims import ParallelDims
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type
 
@@ -200,9 +201,6 @@ def context(cp_context: Generator[None, None, None] | None = None):
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
 
-            if cp_context:
-                stack.enter_context(cp_context)
-
             yield
 
     return context
@@ -443,3 +441,53 @@ def _clip_grad_norm_with_ep(
     torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)
 
     return total_norm
+
+
+def cp_shard(
+    cp_mesh: DeviceMesh,
+    inputs: torch.Tensor,
+    labels: torch.Tensor,
+    attention_masks: AttentionMasksType | None,
+    order_sensitive_buffers: dict[str, torch.Tensor],
+    order_sensitive_buffers_seq_dims: dict[str, int],
+):
+    from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.distributed.tensor.experimental._load_balancer import _PTRRLoadBalancer
+    from torch.nn.attention.flex_attention import BlockMask
+
+    load_balancer = _PTRRLoadBalancer(attention_masks, cp_mesh.size(0))
+    inputs, labels = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=(inputs, labels),
+        seq_dims=(1, 1),
+        load_balancer=load_balancer,
+    )
+
+    order_sensitive_buffers = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=order_sensitive_buffers,
+        seq_dims=order_sensitive_buffers_seq_dims,
+        load_balancer=load_balancer,
+    )
+
+    if attention_masks is None:
+        return inputs, labels, None, order_sensitive_buffers
+
+    masks = (
+        [attention_masks]
+        if isinstance(attention_masks, BlockMask)
+        else list(attention_masks.values())
+    )
+    masks = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=masks,
+        seq_dims=(2,) * len(masks),
+        load_balancer=load_balancer,
+    )
+    attention_masks = (
+        masks[0]
+        if isinstance(attention_masks, BlockMask)
+        else {k: v for k, v in zip(attention_masks.keys(), masks)}
+    )
+
+    return inputs, labels, attention_masks, order_sensitive_buffers
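A hedged usage note for `cp_shard()` above: the mesh, tensors, buffer names, and mask keys below are illustrative and assume an initialized distributed CP setup. A single `BlockMask` round-trips as a single `BlockMask`, a dict of masks comes back as a dict with the same keys, and `attention_masks=None` (the SDPA path) skips mask sharding.

# Illustrative call; inputs/labels are [batch, seq_len], and the "freqs_cis"
# buffer with sequence dim 0 mirrors the llama3 usage elsewhere in this PR.
inputs, labels, masks, buffers = cp_shard(
    cp_mesh,                                            # 1-D DeviceMesh for the CP group
    inputs,
    labels,
    {"causal": causal_mask, "sliding": sliding_mask},   # dict[str, BlockMask]
    {"freqs_cis": freqs_cis},                           # order-sensitive buffers
    {"freqs_cis": 0},                                   # their sequence dimensions
)
assert set(masks.keys()) == {"causal", "sliding"}       # same container shape back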

torchtitan/models/attention.py

Lines changed: 24 additions & 11 deletions
@@ -32,6 +32,22 @@
 ]
 
 
+class FlexAttentionKernel(torch.nn.Module):
+    """Wrapper to enable FlexCP"""
+
+    _compiled_flex_attn: ClassVar[Callable] = torch.compile(
+        flex_attention, mode="max-autotune-no-cudagraphs"
+    )
+
+    def forward(self, *args, **kwargs):
+        # 1. _compiled_flex_attn has to be a class variable, otherwise there will
+        # be multiple compiled flex_attention instances, which can be slow.
+        # 2. `self._compiled_flex_attn` is not correct, `self` will be passed in
+        # as the first argument, which will cause an error.
+        # `FlexAttentionKernel._compiled_flex_attn` is correct.
+        return FlexAttentionKernel._compiled_flex_attn(*args, **kwargs)
+
+
 class FlexAttentionWrapper(torch.nn.Module):
     """Wrapper around `flex_attention` to make it torch.compile and CP compatible.
 
@@ -45,9 +61,11 @@ class FlexAttentionWrapper(torch.nn.Module):
     block_mask as a keyword argument to be compatible with _ContextParallel.
     """
 
-    _compiled_flex_attn: ClassVar[Callable] = torch.compile(
-        flex_attention, mode="max-autotune-no-cudagraphs"
-    )
+    def __init__(self) -> None:
+        super().__init__()
+        # TODO: remove this wrapper once FlexAttentionWrapper.forward() has the
+        # same signature as flex_attention() and is compatible with _ContextParallel.
+        self._flex_attention_kernel = FlexAttentionKernel()
 
     def forward(
         self,
@@ -59,15 +77,10 @@ def forward(
         scale: float | None = None,
         return_lse: bool = False,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        # 1. _compiled_flex_attn has to be a class variable, otherwise there will
-        # be multiple compiled flex_attention instances, which can be slow.
-        # 2. `self._compiled_flex_attn` is not correct, `self` will be passed in
-        # as the first argument, which will cause an error.
-        # `FlexAttentionWrapper._compiled_flex_attn` is correct.
-        # 3. Used `return_lse` instead of `return_aux` because of easier TP module notation
-        # to convert `lse` to be DTensor.
+        # Used `return_lse` instead of `return_aux` because of easier TP module notation
+        # to convert `lse` to be DTensor.
 
-        return FlexAttentionWrapper._compiled_flex_attn(
+        return self._flex_attention_kernel(
             q,
             k,
             v,
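The numbered comments above describe a general Python pitfall rather than anything flex_attention-specific. A standalone illustration (not from the commit, with a plain function standing in for the compiled kernel): a function stored as a class attribute becomes a bound method when looked up through `self`, so the instance is silently prepended to the arguments, while class-qualified lookup passes the arguments through unchanged.

def kernel(*args):
    # Stand-in for the torch.compile'd flex_attention callable.
    return args

class Wrapper:
    _kernel = kernel  # class attribute, mirroring _compiled_flex_attn

    def via_self(self, x):
        return self._kernel(x)      # binds as a method: self is prepended

    def via_class(self, x):
        return Wrapper._kernel(x)   # plain function lookup: only x is passed

w = Wrapper()
print(w.via_self(1))   # (<__main__.Wrapper object at 0x...>, 1)
print(w.via_class(1))  # (1,)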

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 94 additions & 7 deletions
@@ -27,6 +27,7 @@
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.tools.logging import logger
 
 
@@ -67,10 +68,6 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
@@ -91,6 +88,11 @@ def parallelize_llama(
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])
 
+    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+    if parallel_dims.cp_enabled:
+        logger.info("Applied Context Parallel to the model")
+        apply_cp(model, world_mesh["cp"], use_flex_attn)
+
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
@@ -131,9 +133,6 @@ def parallelize_llama(
         else:
             logger.info("Applied FSDP to the model")
 
-        if parallel_dims.cp_enabled:
-            logger.info("Applied Context Parallel to the model")
-
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
@@ -328,3 +327,91 @@ def apply_ddp(
     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
 
     logger.info("Applied DDP to the model")
+
+
+def apply_cp(
+    model: nn.Module,
+    cp_mesh: DeviceMesh,
+    use_flex_attn: bool,
+) -> None:
+    """
+    Apply context parallelism to the model.
+    """
+    from torch.distributed.tensor.experimental._attention import (
+        _ContextParallel,
+        _enable_context_parallel_dispatcher,
+    )
+
+    # Apply context parallelism to every transformer block.
+    # TODO: make seq_dim configurable once the implementation doesn't assume 2
+    # internally.
+    if use_flex_attn:
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX
+        )
+    else:
+        # This is currently required as the DTensor dispatcher is not enabled to
+        # dispatch SDPA to the CP implementation. We don't disable the CP
+        # dispatching in TorchTitan as it is not needed. But there is a
+        # corresponding API, _disable_context_parallel_dispatcher, to do
+        # that if users have this use case.
+        _enable_context_parallel_dispatcher()
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA
+        )
+
+    for transformer_block in model.layers.values():
+        module = transformer_block.attention.inner_attention
+        if use_flex_attn:
+            module = module._flex_attention_kernel
+
+        parallelize_module(
+            module=module,
+            device_mesh=cp_mesh,
+            parallelize_plan=cp_plan,
+        )
+
+
+def cp_shard(
+    cp_mesh: DeviceMesh,
+    inputs: torch.Tensor,
+    labels: torch.Tensor,
+    attention_masks: AttentionMasksType,
+    order_sensitive_buffers: dict[str, torch.Tensor],
+    order_sensitive_buffers_seq_dims: dict[str, int],
+):
+    from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.nn.attention.flex_attention import BlockMask
+
+    load_balancer = None
+    inputs, labels = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=(inputs, labels),
+        seq_dims=(1, 1),
+        load_balancer=load_balancer,
+    )
+
+    masks = (
+        [attention_masks]
+        if isinstance(attention_masks, BlockMask)
+        else list(attention_masks.values())
+    )
+    masks = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=masks,
+        seq_dims=(2,) * len(masks),
+        load_balancer=load_balancer,
+    )
+    attention_masks = (
+        masks[0]
+        if isinstance(attention_masks, BlockMask)
+        else {k: v for k, v in zip(attention_masks.keys(), masks)}
+    )
+
+    order_sensitive_buffers = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=order_sensitive_buffers,
+        seq_dims=order_sensitive_buffers_seq_dims,
+        load_balancer=load_balancer,
+    )
+    return inputs, labels, attention_masks, order_sensitive_buffers
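A design note on the SDPA branch of `apply_cp` above: rather than wrapping each training step in a context manager, the CP dispatcher is switched on globally and left on. The code comment names a symmetric API for undoing that; the sketch below is hedged, assumes `_disable_context_parallel_dispatcher` lives in the same private module as its enable counterpart, and uses a placeholder training function. Both functions are private and may change between releases.

from torch.distributed.tensor.experimental._attention import (
    _disable_context_parallel_dispatcher,   # named in the comment above; private API
    _enable_context_parallel_dispatcher,
)

_enable_context_parallel_dispatcher()       # route SDPA calls to the CP implementation
try:
    run_sdpa_cp_training()                  # placeholder for the actual training loop
finally:
    _disable_context_parallel_dispatcher()  # restore default SDPA dispatch if needed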

torchtitan/models/llama3/model/args.py

Lines changed: 0 additions & 5 deletions
@@ -55,11 +55,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
             )
         self.max_seq_len = seq_len
 
-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
-
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
     ) -> tuple[int, float]:

torchtitan/train.py

Lines changed: 38 additions & 25 deletions
@@ -409,12 +409,10 @@ def batch_generator(
 
             yield input_dict, labels
 
-    def forward_backward_step(
+    def post_dataloader_step(
         self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
-    ) -> torch.Tensor:
-        model_parts = self.model_parts
-        parallel_dims = self.parallel_dims
-
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any], dict[str, Any]]:
+        """Post processing of the batch and labels after being loaded from the dataloader."""
         inputs = input_dict["input"]
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # For arguments, like attention_masks, we have to put them in a separate
@@ -423,38 +421,53 @@ def forward_backward_step(
         extra_kwargs = {}
 
         if getattr(self.model_args, "use_flex_attn", False):
-            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
+            extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
                 extra_inputs=extra_inputs,
             )
+        else:
+            extra_kwargs["attention_masks"] = None
 
         # Get the order sensitive buffers
-        order_sensitive_buffers = model_parts[0].get_order_sensitive_buffers(
+        order_sensitive_buffers = self.model_parts[0].get_order_sensitive_buffers(
             inputs.size(0), inputs.size(1)
         )
-        extra_args.update(order_sensitive_buffers[0])
-
-        # apply context parallelism if cp is enabled
-        # ensure CP handles the separate freqs_cis buffer for each pp stage
-        optional_context_parallel_ctx = (
-            dist_utils.create_context_parallel_ctx(
-                cp_mesh=parallel_dims.world_mesh["cp"],
-                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
-                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
-                cp_no_restore_buffers={inputs, labels},
-                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
-            )
-            if parallel_dims.cp_enabled
+        cp_mesh = (
+            self.parallel_dims.world_mesh["cp"]
+            if self.parallel_dims.cp_enabled
             else None
         )
+        if cp_mesh:
+            (
+                inputs,
+                labels,
+                extra_kwargs["attention_masks"],
+                *order_sensitive_buffers,
+            ) = dist_utils.cp_shard(
+                cp_mesh,
+                inputs,
+                labels,
+                extra_kwargs["attention_masks"],
+                *order_sensitive_buffers,
+            )
+        extra_kwargs.update(order_sensitive_buffers[0])
+        return inputs, labels, extra_inputs, extra_kwargs
+
+    def forward_backward_step(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> torch.Tensor:
+        model_parts = self.model_parts
+        parallel_dims = self.parallel_dims
+
+        inputs, labels, extra_inputs, extra_kwargs = self.post_dataloader_step(
+            input_dict, labels
+        )
 
         if parallel_dims.pp_enabled:
             # Pipeline Parallel forward / backward inside step() call
-            with self.train_context(optional_context_parallel_ctx):
-                targets, losses = (
-                    (labels, []) if self.pp_has_last_stage else (None, None)
-                )
+            targets, losses = (labels, []) if self.pp_has_last_stage else (None, None)
+            with self.train_context():
                 if self.pp_has_first_stage:
                     self.pp_schedule.step(
                         inputs,
@@ -484,7 +497,7 @@ def forward_backward_step(
                     )
         else:
             # Non-PP forward / backward
-            with self.train_context(optional_context_parallel_ctx):
+            with self.train_context():
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
                     pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
