@@ -76,6 +76,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
         _testing.assert_onnx_program(onnx_program)
 
+    def test_sdpa_with_bool_attn_mask(self):
+        class ScaledDotProductAttention(torch.nn.Module):
+            def forward(self, query, key, value, attn_mask):
+                return torch.nn.functional.scaled_dot_product_attention(  # pylint: disable=not-callable
+                    query, key, value, attn_mask=attn_mask
+                )
+
+        model = ScaledDotProductAttention()
+        attn_mask = torch.ones(2, 4, 8, 8).bool()  # boolean mask for attention
+        attn_mask[0, 0, 0, :] = False  # masking an entire row (padding token)
+        query = key = value = torch.randn(2, 4, 8, 16)
+
+        onnx_program = torch.onnx.export(
+            model,
+            (query, key, value, attn_mask),
+            input_names=["query", "key", "value", "attn_mask"],
+            output_names=["output"],
+            opset_version=18,
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
 
 if __name__ == "__main__":
     unittest.main()
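For context, here is a minimal sketch (not part of the diff) of the kind of check `_testing.assert_onnx_program` performs for this test: the exported graph should reproduce PyTorch's eager output for the boolean mask, including the fully masked row. It assumes onnxruntime is installed and reuses `onnx_program`, `model`, `query`, `key`, `value`, and `attn_mask` from the test above; the file name and tolerances are illustrative assumptions, not values from the source.

import onnxruntime as ort
import torch

# Assumes the objects defined in test_sdpa_with_bool_attn_mask above;
# the path and tolerances below are illustrative.
onnx_program.save("sdpa_bool_mask.onnx")  # serialize the exported ONNX model
sess = ort.InferenceSession("sdpa_bool_mask.onnx")
(ort_out,) = sess.run(
    None,
    {
        "query": query.numpy(),
        "key": key.numpy(),
        "value": value.numpy(),
        "attn_mask": attn_mask.numpy(),  # boolean mask, row [0, 0, 0, :] fully masked
    },
)
expected = model(query, key, value, attn_mask)  # eager PyTorch reference
torch.testing.assert_close(
    torch.from_numpy(ort_out), expected, rtol=1e-3, atol=1e-5
)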