@@ -13,7 +13,7 @@ def __init__(self, name: str, *, use_mask: bool, pre_scale: bool, use_mul: bool)
13
13
self ._use_mask = use_mask
14
14
self ._pre_scale = pre_scale
15
15
self ._use_mul = use_mul
16
- self ._custom_scale = False
16
+ self ._scale : float | None = None
17
17
18
18
def pattern (
19
19
self , op , query , key_transposed , value , mask , query_scale , key_scale , qk_scale
@@ -60,14 +60,18 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
60
60
# Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor)
61
61
# If they are scalars but != sqrt(expected_scaling_factor), a custom scale is being used.
62
62
sqrt_scaling_factor = math .sqrt (expected_scaling_factor )
63
-
63
+ # Calculate the scaling factor for query
64
64
if _ir_utils .get_singleton_value (query_scale ) is None :
65
65
return check_result .fail (
66
66
"Query scale is not a scalar." ,
67
67
query_scale ,
68
68
)
69
69
if not _ir_utils .is_singleton_value (query_scale , sqrt_scaling_factor , rtol = 1e-3 ):
70
- self ._custom_scale = True
70
+ query_scale_value = _ir_utils .get_singleton_value (query_scale )
71
+ self ._scale = query_scale_value * query_scale_value
72
+ else :
73
+ self ._scale = expected_scaling_factor
74
+ # Ensure the scaling factor for key is the same as for query
71
75
if _ir_utils .get_singleton_value (key_scale ) is None :
72
76
return check_result .fail (
73
77
"Key scale is not a scalar." ,
@@ -81,7 +85,6 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
81
85
"Query and key scales are not equal." ,
82
86
query_scale ,
83
87
)
84
- self ._custom_scale = True
85
88
else :
86
89
# Check if qk_scale is a scalar == expected_scaling_factor
87
90
# If it is a scalar but != expected_scaling_factor, a custom scale is being used
@@ -91,22 +94,19 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
91
94
qk_scale ,
92
95
)
93
96
if not _ir_utils .is_singleton_value (qk_scale , expected_scaling_factor , rtol = 1e-3 ):
94
- self ._custom_scale = True
97
+ self ._scale = _ir_utils .get_singleton_value (qk_scale )
98
+ else :
99
+ self ._scale = expected_scaling_factor
95
100
96
101
# check ranks/shapes
97
102
98
103
return check_result
99
104
100
- def rewrite (
101
- self , op , query , key_transposed , value , mask , query_scale , key_scale , qk_scale , ** _
102
- ):
105
+ def rewrite (self , op , query , key_transposed , value , mask , ** _ ):
103
106
sdpa_args = [query , key_transposed , value ]
104
107
if self ._use_mask :
105
108
sdpa_args .append (mask )
106
- if self ._custom_scale :
107
- scale = _ir_utils .get_singleton_value (query_scale if self ._pre_scale else qk_scale )
108
- return op .SDPA (* sdpa_args , scale = scale , _domain = "ai.onnxruntime.fusion" )
109
- return op .SDPA (* sdpa_args , _domain = "ai.onnxruntime.fusion" )
109
+ return op .SDPA (* sdpa_args , scale = self ._scale , _domain = "ai.onnxruntime.fusion" )
110
110
111
111
112
112
# Rules for SDPA without mask
0 commit comments