Commit e2fe5e7

Add a test for boolean attention mask within SDPA (#2480)
Follow-up to #2479
1 parent: d8ad301

File tree

1 file changed: +22 -0 lines changed


tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 22 additions & 0 deletions
@@ -76,6 +76,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
         _testing.assert_onnx_program(onnx_program)
 
+    def test_sdpa_with_bool_attn_mask(self):
+        class ScaledDotProductAttention(torch.nn.Module):
+            def forward(self, query, key, value, attn_mask):
+                return torch.nn.functional.scaled_dot_product_attention(  # pylint: disable=not-callable
+                    query, key, value, attn_mask=attn_mask
+                )
+
+        model = ScaledDotProductAttention()
+        attn_mask = torch.ones(2, 4, 8, 8).bool()  # boolean mask for attention
+        attn_mask[0, 0, 0, :] = False  # masking an entire row (padding token)
+        query = key = value = torch.randn(2, 4, 8, 16)
+
+        onnx_program = torch.onnx.export(
+            model,
+            (query, key, value, attn_mask),
+            input_names=["query", "key", "value", "attn_mask"],
+            output_names=["output"],
+            opset_version=18,
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
 
 if __name__ == "__main__":
     unittest.main()
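For readers unfamiliar with boolean attention masks, here is a minimal standalone sketch (not part of the commit) of the semantics the test exercises: with a bool attn_mask, True marks key positions a query may attend to, and masking a position behaves like adding -inf to its attention score before the softmax. The shapes mirror the test above; the causal-style mask and the printed comparison are illustrative assumptions, chosen to avoid the fully-masked-row edge case the test deliberately covers.

import torch
import torch.nn.functional as F

# Sketch: a boolean attn_mask versus the equivalent additive float mask.
query = key = value = torch.randn(2, 4, 8, 16)

# Lower-triangular boolean mask so every query row keeps at least one
# attendable key position (unlike the fully masked row in the test above).
bool_mask = torch.tril(torch.ones(8, 8, dtype=torch.bool)).expand(2, 4, 8, 8)

# Additive form: 0.0 where attention is allowed, -inf where it is masked.
float_mask = torch.zeros(2, 4, 8, 8).masked_fill(~bool_mask, float("-inf"))

out_bool = F.scaled_dot_product_attention(query, key, value, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(query, key, value, attn_mask=float_mask)
print((out_bool - out_float).abs().max())  # expected to be (near) zero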
