Skip to content

Commit 139926b

Browse files
fegingithubsgi
authored andcommitted
Refactor attention and make attention mask an argument to the model (pytorch#1776)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#1797 * __->__ pytorch#1776 **Status** 1. Change all models, including the experimental ones. 2. E2E loss verification. 3. We should add an unittest for attention. But since we don't have GPU unittest, this can be done in a separate PR. **Summary** This PR aims to refactor how TorchTitan build the attention masks and pass to model. Before this PR, init_attention_masks() is called in Trainer but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one particular one is pytorch#1723. pytorch/pytorch#164111 proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward(). The new design: 1. Model needs to provide `get_attention_masks()` that accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, then this API should return None as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of int that represents the mask. Justification: attention logic is technically a part of the model, but requires some information from trainer/dataloader. So it's model author's responsibility to provide some API that let trainer calls to get the masks. 2. `get_attention_masks()` will be called from the trainer and the resulting masks are passed to the model.forward(). Justification: this will allow us to fix pytorch#1723 with pytorch/pytorch#164111 and this PR. 3. Now SDPA and FlexAttention are wrapped in two different classes. ~~Note: we still have two very very thin op wrappers that are used for CP. I keep these two for the CP education purpose. But this certainly can be confusion for Titan's users. I'm opnn to merge them to AttentionOp.~~ See the discussion in pytorch#1723. **Verification** *llama3* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" ``` *llama3 flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ``` *llama4* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *llama4 irope* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *deepseek* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ``` *deepseek flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ```
1 parent 1b88f57 commit 139926b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchtitan/distributed/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ def set_determinism(
120120
# reproducibility, since the autotune results may not be deterministic.
121121
from torch.nn.attention.flex_attention import flex_attention
122122

123-
from torchtitan.models.attention import FlexAttention
123+
from torchtitan.models.attention import FlexAttentionWrapper
124124

125-
FlexAttention.flex_attn = torch.compile(flex_attention)
125+
FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention)
126126

127127
if not world_mesh:
128128
if seed is not None:

0 commit comments

Comments
 (0)