@@ -74,6 +74,52 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
74
74
return attn_output
75
75
76
76
77
+ @script ()
78
+ def _custom_scale_pre_div_sdpa_script (query , key , value ):
79
+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
80
+ divisor = op .Constant (value_float = 2.0 )
81
+ scaled_query = op .Div (query , divisor )
82
+ scaled_key = op .Div (key_transposed , divisor )
83
+ attn_score = op .MatMul (scaled_query , scaled_key )
84
+ attn_weight = op .Softmax (attn_score , axis = - 1 )
85
+ attn_output = op .MatMul (attn_weight , value )
86
+ return attn_output
87
+
88
+
89
+ @script ()
90
+ def _custom_scale_pre_mul_sdpa_script (query , key , value ):
91
+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
92
+ multiplier = op .Constant (value_float = 0.5 )
93
+ scaled_query = op .Mul (query , multiplier )
94
+ scaled_key = op .Mul (key_transposed , multiplier )
95
+ attn_score = op .MatMul (scaled_query , scaled_key )
96
+ attn_weight = op .Softmax (attn_score , axis = - 1 )
97
+ attn_output = op .MatMul (attn_weight , value )
98
+ return attn_output
99
+
100
+
101
+ @script ()
102
+ def _custom_scale_post_div_sdpa_script (query , key , value ):
103
+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
104
+ divisor = op .Constant (value_float = 0.1 )
105
+ attn_score = op .MatMul (query , key_transposed )
106
+ scaled_attn_score = op .Div (attn_score , divisor )
107
+ attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
108
+ attn_output = op .MatMul (attn_weight , value )
109
+ return attn_output
110
+
111
+
112
+ @script ()
113
+ def _custom_scale_post_mul_sdpa_script (query , key , value ):
114
+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
115
+ multiplier = op .Constant (value_float = 0.125 )
116
+ attn_score = op .MatMul (query , key_transposed )
117
+ scaled_attn_score = op .Mul (attn_score , multiplier )
118
+ attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
119
+ attn_output = op .MatMul (attn_weight , value )
120
+ return attn_output
121
+
122
+
77
123
@script ()
78
124
def _masked_pre_div_sdpa_script (query , key , value , mask ):
79
125
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
@@ -124,6 +170,56 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
124
170
return attn_output
125
171
126
172
173
+ @script ()
174
+ def _custom_scale_pre_div_sdpa_script (query , key , value , mask ):
175
+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
176
+ divisor = op .Constant (value_float = 2.0 )
177
+ scaled_query = op .Div (query , divisor )
178
+ scaled_key = op .Div (key_transposed , divisor )
179
+ attn_score = op .MatMul (scaled_query , scaled_key )
180
+ masked_attn_score = op .Add (attn_score , mask )
181
+ attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
182
+ attn_output = op .MatMul (attn_weight , value )
183
+ return attn_output
184
+
185
+
186
+ @script ()
187
+ def _custom_scale_mul_sdpa_script (query , key , value , mask ):
188
+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
189
+ multiplier = op .Constant (value_float = 0.5 )
190
+ scaled_query = op .Mul (query , multiplier )
191
+ scaled_key = op .Mul (key_transposed , multiplier )
192
+ attn_score = op .MatMul (scaled_query , scaled_key )
193
+ masked_attn_score = op .Add (attn_score , mask )
194
+ attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
195
+ attn_output = op .MatMul (attn_weight , value )
196
+ return attn_output
197
+
198
+
199
+ @script ()
200
+ def _custom_scale_post_div_sdpa_script (query , key , value , mask ):
201
+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
202
+ divisor = op .Constant (value_float = 0.1 )
203
+ attn_score = op .MatMul (query , key_transposed )
204
+ scaled_attn_score = op .Div (attn_score , divisor )
205
+ masked_attn_score = op .Add (scaled_attn_score , mask )
206
+ attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
207
+ attn_output = op .MatMul (attn_weight , value )
208
+ return attn_output
209
+
210
+
211
+ @script ()
212
+ def _custom_scale_post_mul_sdpa_script (query , key , value , mask ):
213
+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
214
+ multiplier = op .Constant (value_float = 0.125 )
215
+ attn_score = op .MatMul (query , key_transposed )
216
+ scaled_attn_score = op .Mul (attn_score , multiplier )
217
+ masked_attn_score = op .Add (scaled_attn_score , mask )
218
+ attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
219
+ attn_output = op .MatMul (attn_weight , value )
220
+ return attn_output
221
+
222
+
127
223
class SDPATestCase :
128
224
def __init__ (self , script_func ):
129
225
self .script_func = script_func
@@ -161,6 +257,14 @@ class TestSDPAFusion(unittest.TestCase):
161
257
("pre_mul" , _masked_pre_mul_sdpa_script ),
162
258
("post_div" , _masked_post_div_sdpa_script ),
163
259
("post_mul" , _masked_post_mul_sdpa_script ),
260
+ ("custom_scale_post_mul" , _custom_scale_post_mul_sdpa_script ),
261
+ ("custom_scale_post_div" , _custom_scale_post_div_sdpa_script ),
262
+ ("custom_scale_pre_mul" , _custom_scale_pre_mul_sdpa_script ),
263
+ ("custom_scale_pre_div" , _custom_scale_pre_div_sdpa_script ),
264
+ ("custom_scale_post_mul_masked" , _custom_scale_post_mul_sdpa_script ),
265
+ ("custom_scale_post_div_masked" , _custom_scale_post_div_sdpa_script ),
266
+ ("custom_scale_pre_mul_masked" , _custom_scale_pre_mul_sdpa_script ),
267
+ ("custom_scale_pre_div_masked" , _custom_scale_pre_div_sdpa_script ),
164
268
]
165
269
)
166
270
def test_sdpa_fusion (self , name , script_func ):
0 commit comments