
Commit 98e2424

Use MHA SDPA (#1)
1 parent c00aa57 commit 98e2424
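The title indicates that multi-head attention is now routed through PyTorch's scaled_dot_product_attention (SDPA). As a rough, generic sketch only (this is not torchtune's MultiHeadAttention module; the function name, tensor shapes, and the mask-broadcast step are assumptions for illustration), attention over already-projected heads with a boolean mask like the one added in this commit could be computed as:

# Hedged illustration of MHA on top of SDPA -- not torchtune's implementation.
# Assumes q/k/v are already projected and split into heads: [bsz, num_heads, seq_len, head_dim],
# and that `mask` is a boolean [bsz, seq_len, seq_len] tensor where True means "attend".
from typing import Optional

import torch
import torch.nn.functional as F


def mha_sdpa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if mask is not None:
        # Broadcast the per-batch mask across the head dimension: [bsz, 1, seq_len, seq_len].
        mask = mask[:, None, :, :]
    # SDPA picks a fused kernel where possible and applies 1/sqrt(head_dim) scaling by default.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    # Merge heads back into the embedding dimension: [bsz, seq_len, num_heads * head_dim].
    bsz, num_heads, seq_len, head_dim = out.shape
    return out.transpose(1, 2).reshape(bsz, seq_len, num_heads * head_dim)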

8 files changed: +277, -63 lines changed
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchtune.models.gemma2._attention_mask import get_sliding_attention_mask


class TestGetSlidingAttentionMask:
    @pytest.fixture
    def basic_params(self):
        return {"bsz": 2, "seq_len": 4, "sliding_window_size": 2, "device": None}

    def test_get_sliding_attention_mask(self, basic_params):
        """Test that when mask is None, a causal mask is created and sliding window is applied."""
        bsz = 2
        seq_len = 4
        sliding_window_size = 2
        mask = get_sliding_attention_mask(
            mask=None,
            sliding_window_size=basic_params["sliding_window_size"],
            bsz=basic_params["bsz"],
            seq_len=basic_params["seq_len"],
            device=basic_params["device"],
        )

        assert mask.shape == (
            basic_params["bsz"],
            basic_params["seq_len"],
            basic_params["seq_len"],
        )
        assert mask.dtype == torch.bool

        # Check that the mask has the expected sliding window pattern
        # True positions can be attended to, False positions are masked
        expected_pattern = torch.tensor(
            [
                [True, False, False, False],
                [True, True, False, False],
                [False, True, True, False],
                [False, False, True, True],
            ],
            dtype=torch.bool,
        )

        # Check first batch element
        torch.testing.assert_close(mask[0], expected_pattern)
        # All batch elements should be identical
        torch.testing.assert_close(mask[0], mask[1])

    def test_get_sliding_attention_mask_different_window_sizes(self):
        """Test sliding window with different window sizes."""
        bsz, seq_len = 1, 5

        # Test window size 1 (only current position)
        mask = get_sliding_attention_mask(
            mask=None,
            sliding_window_size=1,
            bsz=bsz,
            seq_len=seq_len,
            device=None,
        )

        expected_window_1 = torch.tensor(
            [
                [True, False, False, False, False],
                [False, True, False, False, False],
                [False, False, True, False, False],
                [False, False, False, True, False],
                [False, False, False, False, True],
            ],
            dtype=torch.bool,
        )

        torch.testing.assert_close(mask[0], expected_window_1)

        # Test window size 3
        mask = get_sliding_attention_mask(
            mask=None,
            sliding_window_size=3,
            bsz=bsz,
            seq_len=seq_len,
            device=None,
        )

        expected_window_3 = torch.tensor(
            [
                [True, False, False, False, False],
                [True, True, False, False, False],
                [True, True, True, False, False],
                [False, True, True, True, False],
                [False, False, True, True, True],
            ],
            dtype=torch.bool,
        )

        torch.testing.assert_close(mask[0], expected_window_3)

    def test_get_sliding_attention_mask_large_window(self):
        """Test sliding window larger than sequence length."""
        bsz, seq_len = 1, 3
        sliding_window_size = 5  # Larger than seq_len

        mask = get_sliding_attention_mask(
            mask=None,
            sliding_window_size=sliding_window_size,
            bsz=bsz,
            seq_len=seq_len,
            device=None,
        )

        # Should behave like a regular causal mask when window is larger than seq_len
        expected_causal = torch.tensor(
            [
                [True, False, False],
                [True, True, False],
                [True, True, True],
            ],
            dtype=torch.bool,
        )

        torch.testing.assert_close(mask[0], expected_causal)

tests/torchtune/modules/test_attention_utils.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def test_flex_attention(self, mock_sdpa, mock_flex):
         _attention_call = _sdpa_or_flex_attention()
         _ = _attention_call(q, k, v, attn_mask, dropout_p, is_causal)
         mock_sdpa.assert_not_called()
-        mock_flex.assert_called_with(q, k, v, block_mask=attn_mask)
+        mock_flex.assert_called_with(q, k, v, block_mask=attn_mask, scale=None)
         # If mask is not a BlockMask, then we should call SDPA
         _attention_call = _sdpa_or_flex_attention()
         _ = _attention_call(q, k, v, attn_mask, dropout_p, is_causal)
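The only behavioral change asserted here is that the flex-attention path now receives an explicit scale keyword. As a minimal sketch of the idea, assuming a mask-type-based dispatcher (the wrapper below is illustrative, not torchtune's actual _sdpa_or_flex_attention; BlockMask, flex_attention, and scaled_dot_product_attention are real PyTorch 2.5+ APIs), forwarding scale to both backends might look like:

# Illustrative dispatcher, not torchtune's code: route BlockMask inputs to flex attention
# and everything else to SDPA, forwarding the optional softmax `scale` to both backends.
from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import BlockMask, flex_attention


def attention_call(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    mask=None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    if isinstance(mask, BlockMask):
        # flex_attention consumes the BlockMask and its own scale argument.
        return flex_attention(q, k, v, block_mask=mask, scale=scale)
    # Otherwise fall back to SDPA, which also accepts an optional softmax scale.
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale
    )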
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from torchtune.modules.attention_utils import _MaskType


def get_sliding_attention_mask(
    mask: Optional[_MaskType],
    sliding_window_size: int,
    bsz: int,
    seq_len: int,
    device: Optional[torch.device] = None,
) -> _MaskType:
    """
    Args:
        mask (Optional[_MaskType]): Mask to apply to the attention scores.
        sliding_window_size (int): Sliding window size to apply to the attention mask.
        bsz (int): Batch size. Argument is unused, but listed for consistency.
        seq_len (int): Sequence length.
        device (Optional[torch.device]): Device to use for the mask. Defaults to None.

    Returns:
        A tensor mask that applies sliding window masking.

    Raises:
        ValueError: If the input mask is not a Tensor
    """

    if mask is None:
        mask = torch.tril(
            torch.ones(size=(bsz, seq_len, seq_len), dtype=torch.bool).to(device)
        )

    if not isinstance(mask, torch.Tensor):
        raise ValueError(
            f"For non-flex attention, mask must be a Tensor. Got: {type(mask)}"
        )

    all_ones = torch.ones_like(mask, dtype=torch.bool)
    sliding_mask = torch.triu(all_ones, -1 * sliding_window_size + 1) & torch.tril(
        all_ones, sliding_window_size - 1
    )
    mask = mask & sliding_mask

    return mask
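Because the band is built with torch.triu/torch.tril at diagonals ±(sliding_window_size - 1) and then intersected with the causal mask, each query position i can attend to key positions in [i - sliding_window_size + 1, i]. A quick self-contained sanity check, mirroring the window-size-3 case asserted in the tests above:

# Sanity-check sketch: reproduce the window=3, seq_len=5 pattern from the tests above.
from torchtune.models.gemma2._attention_mask import get_sliding_attention_mask

mask = get_sliding_attention_mask(
    mask=None, sliding_window_size=3, bsz=1, seq_len=5, device=None
)
print(mask[0].int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)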
