Commit 35552f0
[RFC] Refactor attention and make attention mask an argument to the model
**Status** This PR is not landable yet and serves as an RFC. If people are okay with this design, the PR requires the following changes and verifications:
1. Change all models, including the experimental ones.
2. E2E loss verification (a functional check has been done, but loss verification is not done yet).
3. Add a unit test for attention. Since we don't have GPU unit tests, this can be done in a separate PR.

**Summary** This PR refactors how TorchTitan builds the attention masks and passes them to the model. Before this PR, init_attention_masks() was called in the Trainer, but the masks were stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one in particular being #1723. Now that pytorch/pytorch#164111 proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward().

The new design:
1. The model needs to provide `get_attention_masks()`, which accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, this API should return None, as SDPA currently doesn't support varlen. Once it does, we may have to return some tuple of ints that represents the mask. Justification: attention logic is technically part of the model but requires some information from the trainer/dataloader, so it's the model author's responsibility to provide an API the trainer can call to get the masks.
2. `get_attention_masks()` will be called from the trainer, and the resulting masks are passed to model.forward(). Justification: this allows us to fix #1723 with pytorch/pytorch#164111 and this PR.
3. Provide a single AttentionOp instead of two. Justification: since the masking logic is moved outside, we don't need to do bookkeeping of masks in FlexAttentionWrapper. The logic is simple enough that one AttentionOp keeps things cleaner.

Note: we still have two very thin op wrappers that are used for CP. I keep these two for CP education purposes, but this can certainly be confusing for Titan's users. I'm open to merging them into AttentionOp. See the discussion in #1723.

ghstack-source-id: 71f4e41
Pull-Request-resolved: #1776
1 parent eb13ba2 commit 35552f0
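
As a quick illustration of the intended flow, here is a hedged sketch, not TorchTitan's actual Trainer code; `model`, `tokenizer`, and `input_batch` are stand-ins, and the `get_attention_masks()` signature follows the llama4 diff below:

```python
# Hedged sketch of the new call flow: the trainer asks the model for its masks and
# passes them through forward() as a plain argument (so PP can split BlockMask).
import torch


def train_step(model, tokenizer, input_batch: torch.Tensor) -> torch.Tensor:
    # The model owns the masking logic; it returns None when it uses plain SDPA.
    attention_masks = (
        model.get_attention_masks(input_batch=input_batch, tokenizer=tokenizer)
        if getattr(model.model_args, "use_flex_attn", False)
        else None
    )
    # Masks are now an explicit forward() argument instead of FlexAttentionWrapper state.
    return model(input_batch, attention_masks=attention_masks)
```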

File tree

10 files changed: +362, -237 lines


torchtitan/distributed/utils.py

Lines changed: 8 additions & 0 deletions
@@ -106,6 +106,14 @@ def set_determinism(
     # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 
+    # Ensure flex_attention is compiled without max-autotune. This is needed to ensure
+    # reproducibility, since the autotune results may not be deterministic.
+    from torch.nn.attention.flex_attention import flex_attention
+
+    from torchtitan.models.attention import FlexAttentionWrapper
+
+    FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention)
+
     if not world_mesh:
         if seed is not None:
             torch.manual_seed(seed)
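
For context, a standalone sketch of the behavior this override relies on (assumes a CUDA device and a PyTorch build with flex_attention; shapes are arbitrary): compiling flex_attention without `mode="max-autotune"` skips autotuning, whose kernel selection can vary between runs and break reproducibility.

```python
# Standalone sketch (assumes CUDA and a PyTorch build that ships flex_attention).
import torch
from torch.nn.attention.flex_attention import flex_attention

# Default compile mode: no autotuning, so the selected kernel is stable across runs.
deterministic_flex_attn = torch.compile(flex_attention)

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
out = deterministic_flex_attn(q, k, v)  # full attention; pass block_mask=... to mask
```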

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 2 additions & 2 deletions
@@ -238,8 +238,8 @@ def apply_non_moe_tp(
     layer_plan = {
         "attention_norm": SequenceParallel(),
         "attention": prepare_module_input(
-            input_layouts=(Shard(1), None),
-            desired_input_layouts=(Replicate(), None),
+            input_layouts=(Shard(1), None, None),
+            desired_input_layouts=(Replicate(), None, None),
         ),
         "attention.wq": colwise_parallel(),
         "attention.wk": colwise_parallel(),

torchtitan/experiments/llama4/model/model.py

Lines changed: 58 additions & 10 deletions
@@ -4,14 +4,23 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import torch
 import torch.nn.functional as F
 from torch import nn
-
-from torchtitan.models.attention import build_attention
+from torch.nn.attention.flex_attention import and_masks
+
+from torchtitan.components.tokenizer import BaseTokenizer
+from torchtitan.models.attention import (
+    create_attention_mask,
+    FlexAttentionWrapper,
+    get_causal_mask_mod,
+    get_document_mask_mod,
+    get_fixed_block_mask_mod,
+    ScaledDotProductAttentionWrapper,
+)
 from torchtitan.models.moe import MoE
-from torchtitan.protocols import ModelProtocol
+from torchtitan.protocols.model import AttentionMasksType
+from torchtitan.protocols.train_spec import ModelProtocol
 
 from .args import TransformerModelArgs
 
@@ -155,9 +164,11 @@ def __init__(
         # values of these two variables.
         self.use_rope = use_rope
 
-        self.sdpa = build_attention(
-            model_args.use_flex_attn, model_args.attn_mask_type, fixed_block_size
-        )
+        self.use_flex_attn = model_args.use_flex_attn
+        if self.use_flex_attn:
+            self.inner_attention = FlexAttentionWrapper()
+        else:
+            self.inner_attention = ScaledDotProductAttentionWrapper()
 
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -168,6 +179,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
     ):
         """
        Forward pass of the attention module.
@@ -202,7 +214,13 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
 
-        output = self.sdpa(xq, xk, xv)
+        if self.use_flex_attn:
+            assert isinstance(attention_masks, dict), attention_masks
+            attention_mask = attention_masks["rope" if self.use_rope else "nope"]
+            output = self.inner_attention(xq, xk, xv, block_mask=attention_mask)
+        else:
+            assert attention_masks is None
+            output = self.inner_attention(xq, xk, xv)
 
         output = output.transpose(
             1, 2
@@ -335,6 +353,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
     ):
         """
        Perform a forward pass through the TransformerBlock.
@@ -347,7 +366,7 @@ def forward(
            torch.Tensor: Output tensor after applying attention and feedforward layers.
 
        """
-        h = x + self.attention(self.attention_norm(x), freqs_cis)
+        h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
         if self.moe_enabled:
             out = h + self.moe(self.ffn_norm(h))
         else:
@@ -447,9 +466,38 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.rope_theta,
         )
 
+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        mask_mods = [get_causal_mask_mod()]
+        match self.model_args.attn_mask_type:
+            case "causal":
+                B = 1
+            case "block_causal":
+                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+                B = input_batch.shape[0]
+            case _:
+                raise ValueError(f"Unknown attention mask type: {self.attn_mask_type}")
+
+        rope_mask_mod = and_masks(
+            *mask_mods,
+            get_fixed_block_mask_mod(self.model_args.fixed_attn_block_size),
+        )
+        nope_mask_mod = and_masks(*mask_mods)
+
+        seqlen = input_batch.shape[1]
+        return {
+            "rope": create_attention_mask(rope_mask_mod, B, None, seqlen, seqlen),
+            "nope": create_attention_mask(nope_mask_mod, B, None, seqlen, seqlen),
+        }
+
     def forward(
         self,
         tokens: torch.Tensor,
+        attention_masks: AttentionMasksType | None = None,
         input_batch: torch.Tensor | None = None,
     ):
         """
@@ -473,7 +521,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis)
+            h = layer(h, self.freqs_cis, attention_masks)
 
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
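
get_attention_masks() above composes TorchTitan's mask-mod helpers; a rough standalone analogue using only the public flex_attention API might look like this (the document mask_mod is an assumption about what get_document_mask_mod computes, create_attention_mask is assumed to wrap create_block_mask, and a CUDA device is assumed):

```python
# Rough, hedged analogue of get_attention_masks() built from public PyTorch APIs only.
import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask


def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def make_document_mask_mod(input_batch: torch.Tensor, eos_id: int):
    # Assumption: tokens may only attend within their own document, with documents
    # delimited by EOS tokens in the packed batch.
    doc_id = (input_batch == eos_id).cumsum(dim=-1)

    def document_mask_mod(b, h, q_idx, kv_idx):
        return doc_id[b, q_idx] == doc_id[b, kv_idx]

    return document_mask_mod


tokens = torch.randint(0, 32_000, (2, 256), device="cuda")
seqlen = tokens.shape[1]
mask_mod = and_masks(causal_mask_mod, make_document_mask_mod(tokens, eos_id=2))
block_mask = create_block_mask(
    mask_mod, B=tokens.shape[0], H=None, Q_LEN=seqlen, KV_LEN=seqlen
)
# block_mask can then be passed to flex_attention(xq, xk, xv, block_mask=block_mask).
```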
