
Commit a4dd9ac

fegin authored and githubsgi committed
Refactor attention and make attention mask an argument to the model (pytorch#1776)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):

* pytorch#1797
* __->__ pytorch#1776

**Status**

1. Change all models, including the experimental ones.
2. E2E loss verification.
3. We should add a unit test for attention. But since we don't have GPU unit tests, this can be done in a separate PR.

**Summary**

This PR refactors how TorchTitan builds the attention masks and passes them to the model. Before this PR, init_attention_masks() is called in the Trainer, but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one in particular being pytorch#1723.

pytorch/pytorch#164111 proves that we can let PP split a BlockMask, so this PR performs the refactor to pass masks as an argument of model.forward().

The new design:

1. The model needs to provide `get_attention_masks()`, which accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, this API should return None, as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of ints that represents the mask.

   Justification: attention logic is technically part of the model, but it requires some information from the trainer/dataloader. So it's the model author's responsibility to provide an API that the trainer can call to get the masks.

2. `get_attention_masks()` will be called from the trainer, and the resulting masks are passed to model.forward().

   Justification: this will allow us to fix pytorch#1723 with pytorch/pytorch#164111 and this PR.

3. SDPA and FlexAttention are now wrapped in two different classes.

   ~~Note: we still have two very thin op wrappers that are used for CP. I keep these two for CP education purposes, but this can certainly be confusing for Titan's users. I'm open to merging them into AttentionOp.~~ See the discussion in pytorch#1723.

**Verification**

*llama3*

```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml"
```

*llama3 flex*

```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn"
```

*llama4*

```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint
```

*llama4 irope*

```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint
```

*deepseek*

```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml"
```

*deepseek flex*

```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn"
```
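For reference, the new call pattern described in the Summary reduces to the following minimal sketch (illustrative only, not part of the diff; `run_step`, `model`, `tokenizer`, and `inputs` are placeholders, while `get_attention_masks()` and the `attention_masks` forward argument match the code changed in this commit):

```python
# Minimal sketch of the new flow, assuming a model that follows this PR's protocol.
# The helper name and its arguments are hypothetical placeholders.
import torch


def run_step(model, tokenizer, inputs: torch.Tensor) -> torch.Tensor:
    extra_args = {}
    if getattr(model.model_args, "use_flex_attn", False):
        # FlexAttention models build their own BlockMask(s) from the batch and eos id;
        # SDPA models skip this and effectively receive attention_masks=None.
        extra_args["attention_masks"] = model.get_attention_masks(
            input_batch=inputs,
            tokenizer=tokenizer,
        )
    # The masks travel as an ordinary forward() argument, which is what lets PP
    # schedules split a BlockMask per microbatch (see pytorch/pytorch#164111).
    return model(inputs, **extra_args)
```

This mirrors `forward_backward_step()` in `torchtitan/experiments/forge/example_train.py` below.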
1 parent a287903 commit a4dd9ac

15 files changed: +508 −276 lines


torchtitan/distributed/utils.py

Lines changed: 2 additions & 10 deletions
```diff
@@ -112,9 +112,9 @@ def set_determinism(
         # reproducibility, since the autotune results may not be deterministic.
         from torch.nn.attention.flex_attention import flex_attention
 
-        from torchtitan.models.attention import FlexAttention
+        from torchtitan.models.attention import FlexAttentionWrapper
 
-        FlexAttention.flex_attn = torch.compile(flex_attention)
+        FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention)
 
     if not world_mesh:
         if seed is not None:
@@ -209,14 +209,6 @@ def context(cp_context: Generator[None, None, None] | None = None):
                     torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                 )
 
-            if cp_context is not None:
-                from torch.nn.attention import SDPBackend
-
-                from torchtitan.models.attention import ScaledDotProductAttention
-
-                if SDPBackend.MATH in ScaledDotProductAttention.backends:
-                    ScaledDotProductAttention.backends.remove(SDPBackend.MATH)
-
             stack.enter_context(cp_context)
 
             yield
```

torchtitan/experiments/forge/example_train.py

Lines changed: 15 additions & 9 deletions
```diff
@@ -157,15 +157,14 @@ def forward_backward_step(
         model_parts = self.model_parts
         parallel_dims = self.parallel_dims
 
-        # apply context parallelism if cp is enabled
-        # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["input"]
-        # Create the FlexAttention mask according to the input
+        extra_args = {}
+
         if getattr(self.model_args, "use_flex_attn", False):
-            cp_mesh = (
-                parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
+            extra_args["attention_masks"] = model_parts[0].get_attention_masks(
+                input_batch=inputs,
+                tokenizer=self.tokenizer,
             )
-            init_attention_mask(inputs, self.tokenizer.eos_id, cp_mesh)
 
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
@@ -187,11 +186,18 @@ def forward_backward_step(
                 )
                 if self.pp_has_first_stage:
                     self.pp_schedule.step(
-                        inputs, target=targets, losses=losses, input_batch=inputs
+                        inputs,
+                        **extra_args,
+                        target=targets,
+                        losses=losses,
+                        input_batch=inputs,
                     )
                 else:
                     self.pp_schedule.step(
-                        target=targets, losses=losses, input_batch=inputs
+                        **extra_args,
+                        target=targets,
+                        losses=losses,
+                        input_batch=inputs,
                     )
 
             # accumulate losses across pipeline microbatches
@@ -209,7 +215,7 @@ def forward_backward_step(
             with self.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
-                    pred = model_parts[0](inputs)
+                    pred = model_parts[0](inputs, **extra_args)
                     loss = self.loss_fn(pred, labels)
                     # need to free to before bwd to avoid peaking memory
                     del pred
```

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -239,8 +239,8 @@ def apply_non_moe_tp(
     layer_plan = {
         "attention_norm": SequenceParallel(),
         "attention": prepare_module_input(
-            input_layouts=(Shard(1), None),
-            desired_input_layouts=(Replicate(), None),
+            input_layouts=(Shard(1), None, None),
+            desired_input_layouts=(Replicate(), None, None),
         ),
         "attention.wq": colwise_parallel(),
         "attention.wk": colwise_parallel(),
```

torchtitan/experiments/llama4/model/model.py

Lines changed: 60 additions & 9 deletions
```diff
@@ -9,10 +9,20 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
-
-from torchtitan.models.attention import build_attention
+from torch.nn.attention.flex_attention import and_masks
+
+from torchtitan.components.tokenizer import BaseTokenizer
+from torchtitan.models.attention import (
+    create_attention_mask,
+    FlexAttentionWrapper,
+    get_causal_mask_mod,
+    get_document_mask_mod,
+    get_fixed_block_mask_mod,
+    ScaledDotProductAttentionWrapper,
+)
 from torchtitan.models.moe import MoE
-from torchtitan.protocols import ModelProtocol
+from torchtitan.protocols.model import AttentionMasksType
+from torchtitan.protocols.train_spec import ModelProtocol
 
 from .args import RoPEScalingArgs, TransformerModelArgs
 
@@ -192,9 +202,11 @@ def __init__(
         # values of these two variables.
         self.use_rope = use_rope
 
-        self.sdpa = build_attention(
-            model_args.use_flex_attn, model_args.attn_mask_type, fixed_block_size
-        )
+        self.use_flex_attn = model_args.use_flex_attn
+        if self.use_flex_attn:
+            self.inner_attention = FlexAttentionWrapper()
+        else:
+            self.inner_attention = ScaledDotProductAttentionWrapper()
 
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -205,6 +217,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
     ):
         """
         Forward pass of the attention module.
@@ -239,7 +252,13 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
 
-        output = self.sdpa(xq, xk, xv)
+        if self.use_flex_attn:
+            assert isinstance(attention_masks, dict), attention_masks
+            attention_mask = attention_masks["rope" if self.use_rope else "nope"]
+            output = self.inner_attention(xq, xk, xv, block_mask=attention_mask)
+        else:
+            assert attention_masks is None
+            output = self.inner_attention(xq, xk, xv)
 
         output = output.transpose(
             1, 2
@@ -372,6 +391,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
     ):
         """
         Perform a forward pass through the TransformerBlock.
@@ -384,7 +404,7 @@ def forward(
             torch.Tensor: Output tensor after applying attention and feedforward layers.
 
         """
-        h = x + self.attention(self.attention_norm(x), freqs_cis)
+        h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
         if self.moe_enabled:
             out = h + self.moe(self.ffn_norm(h))
         else:
@@ -485,9 +505,40 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.rope_scaling_args,
         )
 
+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        mask_mods = [get_causal_mask_mod()]
+        match self.model_args.attn_mask_type:
+            case "causal":
+                B = 1
+            case "block_causal":
+                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+                B = input_batch.shape[0]
+            case _:
+                raise ValueError(
+                    f"Unknown attention mask type: {self.model_args.attn_mask_type}"
+                )
+
+        rope_mask_mod = and_masks(
+            *mask_mods,
+            get_fixed_block_mask_mod(self.model_args.fixed_attn_block_size),
+        )
+        nope_mask_mod = and_masks(*mask_mods)
+
+        seqlen = input_batch.shape[1]
+        return {
+            "rope": create_attention_mask(rope_mask_mod, B, None, seqlen, seqlen),
+            "nope": create_attention_mask(nope_mask_mod, B, None, seqlen, seqlen),
+        }
+
     def forward(
         self,
         tokens: torch.Tensor,
+        attention_masks: AttentionMasksType | None = None,
         input_batch: torch.Tensor | None = None,
     ):
         """
@@ -511,7 +562,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis)
+            h = layer(h, self.freqs_cis, attention_masks)
 
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
```

torchtitan/experiments/qwen3/model/model.py

Lines changed: 50 additions & 6 deletions
```diff
@@ -10,13 +10,23 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
-
-from torchtitan.models.attention import build_attention
+from torch.nn.attention.flex_attention import and_masks, BlockMask
+
+from torchtitan.components.tokenizer import BaseTokenizer
+from torchtitan.models.attention import (
+    create_attention_mask,
+    FlexAttentionWrapper,
+    get_causal_mask_mod,
+    get_document_mask_mod,
+    ScaledDotProductAttentionWrapper,
+)
 from torchtitan.models.moe import MoE
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.protocols.train_spec import ModelProtocol
 
 from .args import Qwen3ModelArgs
 
+
 # Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_positional_embeddings.py
 def precompute_rope_cache(
     dim: int, max_seq_len: int, base: float = 1_000_000.0
@@ -133,6 +143,7 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.n_rep = self.n_heads // self.n_kv_heads
         self.head_dim = model_args.head_dim
         self.scaling = self.head_dim**-0.5
+        self.use_flex_attn = getattr(model_args, "use_flex_attn", False)
 
         # RMSNorm added here to the here to include the q-k norm
         # This is one of the main differences between Llama3 and Qwen3
@@ -155,7 +166,11 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.wo = nn.Linear(
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
-        self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)
+
+        if self.use_flex_attn:
+            self.inner_attention = FlexAttentionWrapper()
+        else:
+            self.inner_attention = ScaledDotProductAttentionWrapper()
 
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -170,6 +185,7 @@ def forward(
         self,
         x: torch.Tensor,
         rope_cache: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
     ):
         """
         Forward pass of the attention module.
@@ -210,7 +226,12 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
 
-        output = self.sdpa(xq, xk, xv, scale=self.scaling)
+        if self.use_flex_attn:
+            assert isinstance(attention_masks, BlockMask), attention_masks
+            output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
+        else:
+            assert attention_masks is None
+            output = self.inner_attention(xq, xk, xv)
 
         output = output.transpose(
             1, 2
@@ -308,6 +329,7 @@ def forward(
         self,
         x: torch.Tensor,
         rope_cache: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
    ):
         """
         Perform a forward pass through the TransformerBlock.
@@ -320,7 +342,7 @@ def forward(
             torch.Tensor: Output tensor after applying attention and feedforward layers.
 
         """
-        x = x + self.attention(self.attention_norm(x), rope_cache)
+        x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks)
 
         if self.moe_enabled:
             x = x + self.moe(self.ffn_norm(x))
@@ -423,9 +445,31 @@ def _precompute_rope_cache(self) -> torch.Tensor:
             self.model_args.rope_theta,
         )
 
+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        mask_mods = [get_causal_mask_mod()]
+        match self.model_args.attn_mask_type:
+            case "causal":
+                B = 1
+            case "block_causal":
+                B = input_batch.shape[0]
+                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+            case _:
+                raise ValueError(
+                    f"Unknown attention mask type: {self.model_args.attn_mask_type}"
+                )
+        return create_attention_mask(
+            and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
+        )
+
     def forward(
         self,
         tokens: torch.Tensor,
+        attention_masks: AttentionMasksType | None = None,
         input_batch: torch.Tensor | None = None,
     ):
         """
@@ -449,7 +493,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
         for layer in self.layers.values():
-            h = layer(h, self.rope_cache)
+            h = layer(h, self.rope_cache, attention_masks)
 
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
```

torchtitan/experiments/vlm/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -35,7 +35,7 @@ def _get_dict(obj) -> dict[str, Any]:
 
 llama3_siglip2_configs = {
     "debugmodel": Llama3Siglip2ModelArgs(
-        **_get_dict(llama3_configs["debugmodel"]),
+        **_get_dict(llama3_configs["debugmodel_flex_attn"]),
         encoder=Siglip2ModelArgs(
             dim=128,
             ffn_dim=256,
```
