Skip to content

Commit 854f585

Browse files
add custom scale
1 parent d7955f4 commit 854f585

File tree

2 files changed

+129
-10
lines changed

2 files changed

+129
-10
lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def __init__(self, name: str, *, use_mask: bool, pre_scale: bool, use_mul: bool)
1313
self._use_mask = use_mask
1414
self._pre_scale = pre_scale
1515
self._use_mul = use_mul
16+
self._custom_scale = False
1617

1718
def pattern(
1819
self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale
@@ -57,34 +58,48 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
5758

5859
if self._pre_scale:
5960
# Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor)
61+
# If they are scalars but != sqrt(expected_scaling_factor), a custom scale is being used.
6062
sqrt_scaling_factor = math.sqrt(expected_scaling_factor)
61-
if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3):
63+
64+
if _ir_utils.get_singleton_value(query_scale) is None:
6265
return check_result.fail(
63-
"Query scale is not a scalar or does not match the expected scaling factor.",
66+
"Query scale is not a scalar.",
6467
query_scale,
6568
)
66-
if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3):
69+
if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3):
70+
self._custom_scale = True
71+
if _ir_utils.get_singleton_value(key_scale) is None:
6772
return check_result.fail(
68-
"Key scale is not a scalar or does not match the expected scaling factor.",
73+
"Key scale is not a scalar.",
6974
key_scale,
7075
)
76+
if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3):
77+
self._custom_scale = True
7178
else:
7279
# Check if qk_scale is a scalar == expected_scaling_factor
73-
if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3):
80+
# If it is a scalar but != expected_scaling_factor, a custom scale is being used.
81+
if _ir_utils.get_singleton_value(qk_scale) is None:
7482
return check_result.fail(
75-
"QK scale is not a scalar or does not match the expected scaling factor.",
83+
"QK scale is not a scalar.",
7684
qk_scale,
7785
)
86+
if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3):
87+
self._custom_scale = True
7888

7989
# check ranks/shapes
8090

8191
return check_result
8292

83-
def rewrite(
    self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale, **_
):
    """Replace the matched attention subgraph with a single fused SDPA op.

    The mask is forwarded only when this rule was built with ``use_mask``.
    When ``check`` flagged a non-default scaling factor, the effective
    factor is passed through the ``scale`` attribute; otherwise the
    attribute is omitted so the op falls back to its default.

    Args:
        op: Builder used to emit the replacement node.
        query: Matched query value.
        key_transposed: Matched (already transposed) key value.
        value: Matched value tensor.
        mask: Matched additive mask (used only when ``use_mask`` is set).
        query_scale: Scalar applied to the query in the pre-scale patterns.
        key_scale: Scalar applied to the key in the pre-scale patterns.
        qk_scale: Scalar applied to the Q.K^T scores in the post-scale patterns.

    Returns:
        The output of the emitted ``SDPA`` node.
    """
    sdpa_args = [query, key_transposed, value]
    if self._use_mask:
        sdpa_args.append(mask)

    if not self._custom_scale:
        return op.SDPA(*sdpa_args, _domain="ai.onnxruntime.fusion")

    if self._pre_scale:
        # Query and key are scaled separately before the MatMul, so the factor
        # that reaches the scores is the product of both scalars.
        scale = float(_ir_utils.get_singleton_value(query_scale)) * float(
            _ir_utils.get_singleton_value(key_scale)
        )
    else:
        scale = float(_ir_utils.get_singleton_value(qk_scale))

    if not self._use_mul:
        # The matched pattern divides by the scalar; convert the divisor into
        # the equivalent multiplier.
        # NOTE(review): assumes the SDPA `scale` attribute multiplies the
        # Q.K^T scores - confirm against the consumer of this fused op.
        scale = 1.0 / scale

    return op.SDPA(*sdpa_args, scale=scale, _domain="ai.onnxruntime.fusion")
88103

89104

90105
# Rules for SDPA without mask

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,52 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
7474
return attn_output
7575

7676

77+
@script()
def _custom_scale_pre_div_sdpa_script(query, key, value):
    """Unmasked attention: query and transposed key are each divided by 2.0 before the MatMul."""
    # Swap the last two axes of the 4-D key.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    two = op.Constant(value_float=2.0)
    q_scaled = op.Div(query, two)
    k_scaled = op.Div(k_t, two)
    scores = op.MatMul(q_scaled, k_scaled)
    probs = op.Softmax(scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
87+
88+
89+
@script()
def _custom_scale_pre_mul_sdpa_script(query, key, value):
    """Unmasked attention: query and transposed key are each multiplied by 0.5 before the MatMul."""
    # Swap the last two axes of the 4-D key.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    half = op.Constant(value_float=0.5)
    q_scaled = op.Mul(query, half)
    k_scaled = op.Mul(k_t, half)
    scores = op.MatMul(q_scaled, k_scaled)
    probs = op.Softmax(scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
99+
100+
101+
@script()
def _custom_scale_post_div_sdpa_script(query, key, value):
    """Unmasked attention: the Q.K^T scores are divided by 0.1 before the Softmax."""
    # Swap the last two axes of the 4-D key.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    tenth = op.Constant(value_float=0.1)
    scores = op.MatMul(query, k_t)
    scores_scaled = op.Div(scores, tenth)
    probs = op.Softmax(scores_scaled, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
110+
111+
112+
@script()
def _custom_scale_post_mul_sdpa_script(query, key, value):
    """Unmasked attention: the Q.K^T scores are multiplied by 0.125 before the Softmax."""
    # Swap the last two axes of the 4-D key.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    eighth = op.Constant(value_float=0.125)
    scores = op.MatMul(query, k_t)
    scores_scaled = op.Mul(scores, eighth)
    probs = op.Softmax(scores_scaled, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
121+
122+
77123
@script()
78124
def _masked_pre_div_sdpa_script(query, key, value, mask):
79125
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
@@ -124,6 +170,56 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
124170
return attn_output
125171

126172

173+
@script()
def _masked_custom_scale_pre_div_sdpa_script(query, key, value, mask):
    """Masked attention: query and transposed key are each divided by 2.0 before the MatMul.

    Carries the ``_masked_`` prefix so it no longer redefines (and thereby
    shadows) the three-argument ``_custom_scale_pre_div_sdpa_script``.
    """
    # NOTE(review): the "*_masked" parameterized test entries should reference this name.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    divisor = op.Constant(value_float=2.0)
    scaled_query = op.Div(query, divisor)
    scaled_key = op.Div(key_transposed, divisor)
    attn_score = op.MatMul(scaled_query, scaled_key)
    # The mask is added to the scores before the Softmax.
    masked_attn_score = op.Add(attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
184+
185+
186+
@script()
def _masked_custom_scale_pre_mul_sdpa_script(query, key, value, mask):
    """Masked attention: query and transposed key are each multiplied by 0.5 before the MatMul.

    Previously named ``_custom_scale_mul_sdpa_script``, which dropped the
    ``pre`` qualifier used by the sibling scripts. The ``_masked_`` prefix
    keeps it distinct from the three-argument
    ``_custom_scale_pre_mul_sdpa_script``.
    """
    # NOTE(review): the "*_masked" parameterized test entries should reference this name.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=0.5)
    scaled_query = op.Mul(query, multiplier)
    scaled_key = op.Mul(key_transposed, multiplier)
    attn_score = op.MatMul(scaled_query, scaled_key)
    # The mask is added to the scores before the Softmax.
    masked_attn_score = op.Add(attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
197+
198+
199+
@script()
def _masked_custom_scale_post_div_sdpa_script(query, key, value, mask):
    """Masked attention: the Q.K^T scores are divided by 0.1, then the mask is added.

    Carries the ``_masked_`` prefix so it no longer redefines (and thereby
    shadows) the three-argument ``_custom_scale_post_div_sdpa_script``.
    """
    # NOTE(review): the "*_masked" parameterized test entries should reference this name.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    divisor = op.Constant(value_float=0.1)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Div(attn_score, divisor)
    # The mask is added after scaling and before the Softmax.
    masked_attn_score = op.Add(scaled_attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
209+
210+
211+
@script()
def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
    """Masked attention: the Q.K^T scores are multiplied by 0.125, then the mask is added.

    Carries the ``_masked_`` prefix so it no longer redefines (and thereby
    shadows) the three-argument ``_custom_scale_post_mul_sdpa_script``.
    """
    # NOTE(review): the "*_masked" parameterized test entries should reference this name.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=0.125)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Mul(attn_score, multiplier)
    # The mask is added after scaling and before the Softmax.
    masked_attn_score = op.Add(scaled_attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
221+
222+
127223
class SDPATestCase:
128224
def __init__(self, script_func):
129225
self.script_func = script_func
@@ -161,6 +257,14 @@ class TestSDPAFusion(unittest.TestCase):
161257
("pre_mul", _masked_pre_mul_sdpa_script),
162258
("post_div", _masked_post_div_sdpa_script),
163259
("post_mul", _masked_post_mul_sdpa_script),
260+
("custom_scale_post_mul", _custom_scale_post_mul_sdpa_script),
261+
("custom_scale_post_div", _custom_scale_post_div_sdpa_script),
262+
("custom_scale_pre_mul", _custom_scale_pre_mul_sdpa_script),
263+
("custom_scale_pre_div", _custom_scale_pre_div_sdpa_script),
264+
("custom_scale_post_mul_masked", _custom_scale_post_mul_sdpa_script),
265+
("custom_scale_post_div_masked", _custom_scale_post_div_sdpa_script),
266+
("custom_scale_pre_mul_masked", _custom_scale_pre_mul_sdpa_script),
267+
("custom_scale_pre_div_masked", _custom_scale_pre_div_sdpa_script),
164268
]
165269
)
166270
def test_sdpa_fusion(self, name, script_func):

0 commit comments

Comments
 (0)