Commit 94150da
[RFC][WIP][CP] Enable FlexAttention CP for llama3
This PR uses the latest CP APIs to enable FlexAttention + CP for llama3. It removes the usage of the `context_parallel()` context manager and uses `_context_parallel_shard()` to shard the input data.
ghstack-source-id: 673d743
Pull-Request: #1857
1 parent 8099cbb commit 94150da
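To make the API shift concrete, here is a minimal sketch of what a training step looks like with the new flow. Everything other than `cp_shard` and `get_order_sensitive_buffers` (both added in this commit) is an assumption standing in for the trainer code, which is not part of this diff; in particular the `get_attention_masks()` call shape is hypothetical.

import torch
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.distributed.utils import cp_shard  # added in this commit


def train_step_sketch(model, cp_mesh: DeviceMesh, inputs: torch.Tensor, labels: torch.Tensor):
    # Build masks and order-sensitive buffers (e.g. freqs_cis) for the full sequence.
    attention_masks = model.get_attention_masks(inputs)  # assumed call shape
    buffers, buffer_seq_dims = model.get_order_sensitive_buffers(inputs.size(0), inputs.size(1))

    # Shard everything along the sequence dimension up front, instead of wrapping
    # forward/backward in the old context_parallel() context manager.
    inputs, labels, attention_masks, buffers = cp_shard(
        cp_mesh, inputs, labels, attention_masks, buffers, buffer_seq_dims
    )

    pred = model(inputs, buffers["freqs_cis"], attention_masks=attention_masks)
    loss = torch.nn.functional.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
    loss.backward()
    return loss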

7 files changed: +243 -54 lines changed

torchtitan/distributed/utils.py

Lines changed: 63 additions & 3 deletions
@@ -19,6 +19,7 @@

 from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed.parallel_dims import ParallelDims
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type

@@ -200,9 +201,6 @@ def context(cp_context: Generator[None, None, None] | None = None):
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

-            if cp_context:
-                stack.enter_context(cp_context)
-
             yield

     return context
@@ -443,3 +441,65 @@ def _clip_grad_norm_with_ep(
     torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)

     return total_norm
+
+
+def cp_shard(
+    cp_mesh: DeviceMesh,
+    inputs: torch.Tensor,
+    labels: torch.Tensor,
+    attention_masks: AttentionMasksType | None,
+    order_sensitive_buffers: dict[str, torch.Tensor],
+    order_sensitive_buffers_seq_dims: dict[str, int],
+):
+    from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.distributed.tensor.experimental._load_balancer import (
+        _HeadTailLoadBalancer,
+        _PTRRLoadBalancer,
+    )
+    from torch.nn.attention.flex_attention import BlockMask
+
+    seq_len = inputs.size(1)
+    cp_world_size = cp_mesh.size(0)
+    if isinstance(attention_masks, BlockMask):
+        load_balancer = _PTRRLoadBalancer(attention_masks, cp_world_size)
+    else:
+        # For multiple BlockMasks or SDPA, we use the _HeadTailLoadBalancer.
+        load_balancer = _HeadTailLoadBalancer(
+            seq_len, cp_world_size, cp_mesh.device_type
+        )
+
+    inputs, labels = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=(inputs, labels),
+        seq_dims=(1, 1),
+        load_balancer=load_balancer,
+    )
+
+    order_sensitive_buffers = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=order_sensitive_buffers,
+        seq_dims=order_sensitive_buffers_seq_dims,
+        load_balancer=load_balancer,
+    )
+
+    if attention_masks is None:
+        return inputs, labels, None, order_sensitive_buffers
+
+    masks = (
+        [attention_masks]
+        if isinstance(attention_masks, BlockMask)
+        else list(attention_masks.values())
+    )
+    masks = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=masks,
+        seq_dims=(2,) * len(masks),
+        load_balancer=load_balancer,
+    )
+    attention_masks = (
+        masks[0]
+        if isinstance(attention_masks, BlockMask)
+        else {k: v for k, v in zip(attention_masks.keys(), masks)}
+    )
+
+    return inputs, labels, attention_masks, order_sensitive_buffers
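For intuition about the load balancers selected above, here is a rough sketch of head-tail load balancing as it is usually described: the sequence is split into 2 * cp_world_size chunks, and each rank gets one chunk from the head plus its mirror from the tail, so causal-attention work is spread roughly evenly across ranks. This is an illustration of the idea under that assumption, not a call into the private PyTorch classes.

def head_tail_indices(seq_len: int, cp_world_size: int, rank: int) -> list[int]:
    # Assumed scheme: 2 * cp_world_size equal chunks; rank i owns chunk i and
    # chunk (2 * cp_world_size - 1 - i).
    chunk = seq_len // (2 * cp_world_size)
    head = list(range(rank * chunk, (rank + 1) * chunk))
    tail_start = (2 * cp_world_size - 1 - rank) * chunk
    tail = list(range(tail_start, tail_start + chunk))
    return head + tail


# seq_len=8, cp_world_size=2: rank 0 -> [0, 1, 6, 7], rank 1 -> [2, 3, 4, 5]
print(head_tail_indices(8, 2, 0), head_tail_indices(8, 2, 1))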

torchtitan/models/attention.py

Lines changed: 24 additions & 11 deletions
@@ -32,6 +32,22 @@
 ]


+class FlexAttentionKernel(torch.nn.Module):
+    """Wrapper to enable FlexCP"""
+
+    _compiled_flex_attn: ClassVar[Callable] = torch.compile(
+        flex_attention, mode="max-autotune-no-cudagraphs"
+    )
+
+    def forward(self, *args, **kwargs):
+        # 1. _compiled_flex_attn has to be a class variable, otherwise there will
+        #    be multiple compiled flex_attention instances, which can be slow.
+        # 2. `self._compiled_flex_attn` is not correct: `self` would be passed in
+        #    as the first argument, which would cause an error.
+        #    `FlexAttentionKernel._compiled_flex_attn` is correct.
+        return FlexAttentionKernel._compiled_flex_attn(*args, **kwargs)
+
+
 class FlexAttentionWrapper(torch.nn.Module):
     """Wrapper around `flex_attention` to make it torch.compile and CP compatible.

@@ -45,9 +61,11 @@ class FlexAttentionWrapper(torch.nn.Module):
     block_mask as a keyword argument to be compatible with _ContextParallel.
     """

-    _compiled_flex_attn: ClassVar[Callable] = torch.compile(
-        flex_attention, mode="max-autotune-no-cudagraphs"
-    )
+    def __init__(self) -> None:
+        super().__init__()
+        # TODO: remove this wrapper once FlexAttentionWrapper.forward() has the
+        # same signature as flex_attention() and is compatible with _ContextParallel.
+        self._flex_attention_kernel = FlexAttentionKernel()

     def forward(
         self,
@@ -59,15 +77,10 @@ def forward(
         scale: float | None = None,
         return_lse: bool = False,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        # 1. _compiled_flex_attn has to be a class variable, otherwise there will
-        # be multiple compiled flex_attention instances, which can be slow.
-        # 2. `self._compiled_flex_attn` is not correct, `self` will be passed in
-        # as the first argument, which will cause an error.
-        # `FlexAttentionWrapper._compiled_flex_attn` is correct.
-        # 3. Used `return_lse` instead of `return_aux` because of easier TP module notation
-        # to convert `lse` to be DTensor.
+        # Use `return_lse` instead of `return_aux` because it is easier for the TP
+        # module notation to convert `lse` to a DTensor.

-        return FlexAttentionWrapper._compiled_flex_attn(
+        return self._flex_attention_kernel(
             q,
             k,
             v,
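The ClassVar comments above come down to Python attribute binding: a plain callable stored on a class and looked up through `self` is bound like a method, so `self` is silently prepended to the arguments. A generic illustration with an ordinary function (not the actual compiled flex_attention object):

def add(a, b):
    return a + b


class Holder:
    fn = add  # plain function stored as a class attribute

    def call_via_self(self, a, b):
        # Bound like a method: effectively add(self, a, b), which raises TypeError.
        return self.fn(a, b)

    def call_via_class(self, a, b):
        # No binding: add(a, b), as intended.
        return Holder.fn(a, b)


h = Holder()
print(h.call_via_class(1, 2))   # 3
# h.call_via_self(1, 2)         # TypeError: add() takes 2 positional arguments but 3 were given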

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 94 additions & 7 deletions
@@ -27,6 +27,7 @@
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.tools.logging import logger


@@ -67,10 +68,6 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """

-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
@@ -91,6 +88,11 @@ def parallelize_llama(
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])

+    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+    if parallel_dims.cp_enabled:
+        logger.info("Applied Context Parallel to the model")
+        apply_cp(model, world_mesh["cp"], use_flex_attn)
+
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
@@ -131,9 +133,6 @@ def parallelize_llama(
         else:
             logger.info("Applied FSDP to the model")

-        if parallel_dims.cp_enabled:
-            logger.info("Applied Context Parallel to the model")
-
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
@@ -328,3 +327,91 @@ def apply_ddp(
     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

     logger.info("Applied DDP to the model")
+
+
+def apply_cp(
+    model: nn.Module,
+    cp_mesh: DeviceMesh,
+    use_flex_attn: bool,
+) -> None:
+    """
+    Apply context parallelism to the model.
+    """
+    from torch.distributed.tensor.experimental._attention import (
+        _ContextParallel,
+        _enable_context_parallel_dispatcher,
+    )
+
+    # Apply context parallelism to every transformer block.
+    # TODO: make seq_dim configurable once the implementation doesn't assume 2
+    # internally.
+    if use_flex_attn:
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX
+        )
+    else:
+        # This is currently required because the DTensor dispatcher is not enabled
+        # to dispatch SDPA to the CP implementation. We don't disable the CP
+        # dispatching in TorchTitan as it is not needed, but there is a
+        # corresponding API, _disable_context_parallel_dispatcher, for users
+        # who have that use case.
+        _enable_context_parallel_dispatcher()
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA
+        )

+    for transformer_block in model.layers.values():
+        module = transformer_block.attention.inner_attention
+        if use_flex_attn:
+            module = module._flex_attention_kernel
+
+        parallelize_module(
+            module=module,
+            device_mesh=cp_mesh,
+            parallelize_plan=cp_plan,
+        )
+
+
+def cp_shard(
+    cp_mesh: DeviceMesh,
+    inputs: torch.Tensor,
+    labels: torch.Tensor,
+    attention_masks: AttentionMasksType,
+    order_sensitive_buffers: dict[str, torch.Tensor],
+    order_sensitive_buffers_seq_dims: dict[str, int],
+):
+    from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.nn.attention.flex_attention import BlockMask
+
+    load_balancer = None
+    inputs, labels = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=(inputs, labels),
+        seq_dims=(1, 1),
+        load_balancer=load_balancer,
+    )
+
+    masks = (
+        [attention_masks]
+        if isinstance(attention_masks, BlockMask)
+        else list(attention_masks.values())
+    )
+    masks = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=masks,
+        seq_dims=(2,) * len(masks),
+        load_balancer=load_balancer,
+    )
+    attention_masks = (
+        masks[0]
+        if isinstance(attention_masks, BlockMask)
+        else {k: v for k, v in zip(attention_masks.keys(), masks)}
+    )
+
+    order_sensitive_buffers = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=order_sensitive_buffers,
+        seq_dims=order_sensitive_buffers_seq_dims,
+        load_balancer=load_balancer,
+    )
+    return inputs, labels, attention_masks, order_sensitive_buffers
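One detail worth spelling out from the plans above: `seq_dim=2` in the `_ContextParallel` plan refers to the q/k/v layout seen by the wrapped attention module, which in this model is (batch, n_heads, seq_len, head_dim), while `cp_shard` shards token and label tensors on dim 1 because they are (batch, seq_len). A shape sketch under those assumptions, with placeholder sizes:

import torch

batch, n_heads, seq_len, head_dim, cp_degree = 2, 8, 4096, 64, 4

tokens = torch.empty(batch, seq_len, dtype=torch.long)  # cp_shard shards this on dim 1
q = torch.empty(batch, n_heads, seq_len, head_dim)      # the CP plan shards this on dim 2

local_tokens_shape = (batch, seq_len // cp_degree)
local_q_shape = (batch, n_heads, seq_len // cp_degree, head_dim)
print(tokens.shape, "->", local_tokens_shape)  # torch.Size([2, 4096]) -> (2, 1024)
print(q.shape, "->", local_q_shape)            # torch.Size([2, 8, 4096, 64]) -> (2, 8, 1024, 64)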

torchtitan/models/llama3/model/args.py

Lines changed: 0 additions & 5 deletions
@@ -55,11 +55,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
         self.max_seq_len = seq_len

-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
-
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
     ) -> tuple[int, float]:

torchtitan/models/llama3/model/model.py

Lines changed: 14 additions & 6 deletions
@@ -92,8 +92,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
     for the purpose of broadcasting the frequency tensor during element-wise operations.

-    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
-    and the first seqlen elements will be sliced, but dim must match x.
+    The input freqs_cis tensor is assumed to be of shape (batch_size, seqlen, dim).

     Args:
         freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
@@ -104,10 +103,10 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     """
     ndim = x.ndim
     assert ndim > 1
+    batch_size = x.shape[0]
     seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    assert freqs_cis.shape == (batch_size, seqlen, x.shape[-1])
+    shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)


@@ -474,9 +473,18 @@ def get_attention_masks(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )

+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
+        freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
+        return ({"freqs_cis": freqs_cis}, {"freqs_cis": 1})
+
     def forward(
         self,
         tokens: torch.Tensor,
+        freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
     ):
         """
@@ -496,7 +504,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks=attention_masks)
+            h = layer(h, freqs_cis, attention_masks=attention_masks)

         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
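To make the freqs_cis changes concrete, a small shape sketch: `get_order_sensitive_buffers` slices the precomputed rotary buffer to the current sequence length and repeats it per batch, producing a (batch, seq_len, dim) tensor that CP can shard on dim 1 alongside the tokens, and the updated `reshape_for_broadcast` keeps dims 0, 1, and the last dim. The sizes below are placeholders, and the (max_seq_len, head_dim // 2) complex layout is an assumption matching the usual llama3 rotary buffer.

import torch

batch_size, seq_len, n_heads, head_dim = 2, 16, 4, 8
max_seq_len = 32

# Stand-in for the model's precomputed rotary buffer.
freqs_cis = torch.polar(
    torch.ones(max_seq_len, head_dim // 2), torch.zeros(max_seq_len, head_dim // 2)
)

# get_order_sensitive_buffers: slice to seq_len, then repeat per batch -> (B, S, D/2).
per_batch = freqs_cis[:seq_len].repeat(batch_size, 1, 1)
print(per_batch.shape)  # torch.Size([2, 16, 4])

# reshape_for_broadcast keeps dims 0, 1 and the last dim for x of shape (B, S, H, D/2).
x = torch.empty(batch_size, seq_len, n_heads, head_dim // 2, dtype=torch.complex64)
shape = [d if i in (0, 1, x.ndim - 1) else 1 for i, d in enumerate(x.shape)]
print(per_batch.view(*shape).shape)  # torch.Size([2, 16, 1, 4])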

torchtitan/protocols/model.py

Lines changed: 7 additions & 0 deletions
@@ -70,3 +70,10 @@ def get_attention_masks(
         raise NotImplementedError(
             "This model does not support attention masking/Flex Attention."
         )
+
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
+        return ({}, {})
