Skip to content

Commit cddf207

Browse files
refactor custom_scale usage
1 parent 06b27c7 commit cddf207

File tree

1 file changed

+12
-12
lines changed
  • onnxscript/rewriter/ort_fusions

1 file changed

+12
-12
lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self, name: str, *, use_mask: bool, pre_scale: bool, use_mul: bool)
1313
self._use_mask = use_mask
1414
self._pre_scale = pre_scale
1515
self._use_mul = use_mul
16-
self._custom_scale = False
16+
self._scale: float | None = None
1717

1818
def pattern(
1919
self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale
@@ -60,14 +60,18 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
6060
# Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor)
6161
# If they are scalars but != sqrt(expected_scaling_factor), a custom scale is being used.
6262
sqrt_scaling_factor = math.sqrt(expected_scaling_factor)
63-
63+
# Calculate the scaling factor for query
6464
if _ir_utils.get_singleton_value(query_scale) is None:
6565
return check_result.fail(
6666
"Query scale is not a scalar.",
6767
query_scale,
6868
)
6969
if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3):
70-
self._custom_scale = True
70+
query_scale_value = _ir_utils.get_singleton_value(query_scale)
71+
self._scale = query_scale_value * query_scale_value
72+
else:
73+
self._scale = expected_scaling_factor
74+
# Ensure the scaling factor for key is the same as for query
7175
if _ir_utils.get_singleton_value(key_scale) is None:
7276
return check_result.fail(
7377
"Key scale is not a scalar.",
@@ -81,7 +85,6 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
8185
"Query and key scales are not equal.",
8286
query_scale,
8387
)
84-
self._custom_scale = True
8588
else:
8689
# Check if qk_scale is a scalar == expected_scaling_factor
8790
# If it is a scalar but != expected_scaling_factor, a custom scale is being used.
@@ -91,22 +94,19 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
9194
qk_scale,
9295
)
9396
if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3):
94-
self._custom_scale = True
97+
self._scale = _ir_utils.get_singleton_value(qk_scale)
98+
else:
99+
self._scale = expected_scaling_factor
95100

96101
# check ranks/shapes
97102

98103
return check_result
99104

100-
def rewrite(
101-
self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale, **_
102-
):
105+
def rewrite(self, op, query, key_transposed, value, mask, **_):
103106
sdpa_args = [query, key_transposed, value]
104107
if self._use_mask:
105108
sdpa_args.append(mask)
106-
if self._custom_scale:
107-
scale = _ir_utils.get_singleton_value(query_scale if self._pre_scale else qk_scale)
108-
return op.SDPA(*sdpa_args, scale=scale, _domain="ai.onnxruntime.fusion")
109-
return op.SDPA(*sdpa_args, _domain="ai.onnxruntime.fusion")
109+
return op.SDPA(*sdpa_args, scale=self._scale, _domain="ai.onnxruntime.fusion")
110110

111111

112112
# Rules for SDPA without mask

0 commit comments

Comments
 (0)