Refactor attention and make attention mask an argument to the model (pytorch#1776)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* pytorch#1797
* __->__ pytorch#1776
**Status**
1. Change all models, including the experimental ones.
2. E2E loss verification.
3. We should add a unit test for attention. But since we don't have GPU
unit tests, this can be done in a separate PR.
**Summary**
This PR refactors how TorchTitan builds the attention masks and passes
them to the model. Before this PR, `init_attention_masks()` was called
in the Trainer, but the masks were stored as a class variable of
`FlexAttentionWrapper()`. We chose this shortcut to support the case
where a single model requires multiple masks.
The previous design has several issues; one notable one is
pytorch#1723.
Since pytorch/pytorch#164111 proves that we can let PP split BlockMask,
this PR performs the refactor to pass masks as an argument of
`model.forward()`.
The new design:
1. The model needs to provide `get_attention_masks()`, which accepts
`create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA,
this API should return None, as SDPA currently doesn't support varlen;
once it does, we may have to return some tuple of ints that represents
the mask. A sketch of this contract follows the list.
Justification: attention logic is technically part of the model, but it
requires some information from the trainer/dataloader. So it is the
model author's responsibility to provide an API that the trainer can
call to get the masks.
2. `get_attention_masks()` is called from the trainer and the resulting
masks are passed to `model.forward()`.
Justification: this allows us to fix pytorch#1723 with
pytorch/pytorch#164111 and this PR.
3. SDPA and FlexAttention are now wrapped in two different classes.
~~Note: we still have two very thin op wrappers that are used for CP. I
keep these two for CP education purposes, but this can certainly be
confusing for Titan's users. I'm open to merging them into
AttentionOp.~~
See the discussion in pytorch#1723.
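To make the new contract concrete, here is a minimal sketch of what a
model-provided `get_attention_masks()` and the trainer-side call could look
like, assuming a document-causal FlexAttention mask. The names
(`TransformerSketch`, `doc_causal`, `use_flex_attn`, `attention_masks`) are
illustrative assumptions, not the exact torchtitan code.
```
# Sketch only -- names and shapes are assumptions, not the exact torchtitan API.
from typing import Callable, Optional

import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask


class TransformerSketch(torch.nn.Module):
    """Illustrative model exposing the new get_attention_masks() hook."""

    use_flex_attn: bool = True

    def get_attention_masks(
        self,
        create_mask_fn: Callable[..., BlockMask],
        batch: torch.Tensor,  # token ids, shape (B, S)
        eos_id: int,
    ) -> Optional[BlockMask]:
        # SDPA doesn't support varlen yet, so no mask is built for it.
        if not self.use_flex_attn:
            return None

        # Document ids: tokens after each eos belong to the next document.
        docs = (batch == eos_id).cumsum(dim=-1)

        def doc_causal(b, h, q_idx, kv_idx):
            # Causal, and restricted to the same document within the batch row.
            return (q_idx >= kv_idx) & (docs[b, q_idx] == docs[b, kv_idx])

        B, S = batch.shape
        return create_mask_fn(doc_causal, B, None, S, S)


# Trainer side: build the masks per batch and pass them to forward(),
# instead of stashing them on FlexAttentionWrapper as a class variable.
#   masks = model.get_attention_masks(create_block_mask, inputs, eos_id)
#   pred = model(inputs, attention_masks=masks)
```
And a minimal sketch of the split into two thin attention-op wrappers; the
SDPA wrapper name and exact signatures are assumptions:
```
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import BlockMask, flex_attention


class ScaledDotProductAttentionWrapper(torch.nn.Module):
    """Thin SDPA wrapper; takes no mask until SDPA supports varlen."""

    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


class FlexAttentionWrapper(torch.nn.Module):
    """Thin FlexAttention wrapper; the BlockMask is now a forward() argument."""

    def __init__(self):
        super().__init__()
        # flex_attention is typically compiled for performance.
        self._flex_attn = torch.compile(flex_attention)

    def forward(self, q, k, v, block_mask: BlockMask):
        return self._flex_attn(q, k, v, block_mask=block_mask)
```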
**Verification**
*llama3*
```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml"
```
*llama3 flex*
```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn"
```
*llama4*
```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint
```
*llama4 irope*
```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint
```
*deepseek*
```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml"
```
*deepseek flex*
```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn"
```