From 5374b3d64bdbf7d04bf49e756fc83db59027aa42 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 28 May 2025 12:26:48 -0700 Subject: [PATCH 01/18] Partial fixes to SDPA Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/sdpa.py | 149 ++++++++++++------------ 1 file changed, 77 insertions(+), 72 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index fa827e79aa..0784e4ce81 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -45,102 +45,107 @@ def pattern( # from 3D to 4D and scaling is applied before the reshaping. query_reshape, ): - if self._pre_scale: - # Some implementations scale the query and key before computing the dot product - if self._use_mul: - if self._pre_scale_q: - query = op.Mul(query, qk_scale) - else: - query = op.Mul(query, query_scale) - key_transposed = op.Mul(key_transposed, key_scale) - else: - if self._pre_scale_q: - query = op.Div(query, qk_scale) - else: - query = op.Div(query, query_scale) - key_transposed = op.Div(key_transposed, key_scale) - - # There might be patterns where the reshape and transpose are done - # after the pre-scaling. If the inputs are 3D, we need to reshape them to 4D - # and apply the approriate transposes to query. - if self._has_3d_query and self._pre_scale_q: - # Reshape and transpose 3D input of shape (B, S, D) - # to 4D input of shape (B, N, S, H) - queryBNSH = op.Reshape(query, query_reshape) - query = op.Transpose(queryBNSH, perm=[0, 2, 1, 3]) + # Some implementations scale the query and key before computing the dot product + query = pattern.OrValue([ + op.Mul(query, query_scale), + op.Div(query, query_scale), + query, + ], tag_var="query_scaling", tag_values=["Mul", "Div", "None"]) + key_transposed = pattern.OrValue([ + op.Mul(key_transposed, key_scale), + op.Div(key_transposed, key_scale), + key_transposed, + ], tag_var="key_scaling", tag_values=["Mul", "Div", "None"]) + attn_score = op.MatMul(query, key_transposed) - if not self._pre_scale: - # Some implementations scale the dot product. - if self._use_mul: - attn_score = op.Mul(attn_score, qk_scale) - else: - attn_score = op.Div(attn_score, qk_scale) - if self._use_mask: - # Some implementations add a mask to the dot product. - attn_score = op.Add(attn_score, mask) + + # Some implementations scale the dot product. + attn_score = pattern.OrValues ([ + op.Mul(attn_score, qk_scale), + op.Div(attn_score, qk_scale), + attn_score, + ], tag_var="qk_scaling", tag_values=["Mul", "Div", "None"]) + + # Some implementations add a mask to the dot product. + masked_attn_score = op.Add(attn_score, mask) + attn_score = pattern.OrValue([masked_attn_score, attn_score], tag_var="has_mask", tag_values=[True, False]) + attn_weight = op.Softmax(attn_score, axis=-1) attn_output = op.MatMul(attn_weight, value) return attn_output def check( - self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale, **_ + self, op, query, key_transposed, value, mask, query_scaling, query_scale, key_scaling, key_scale, qk_scale, **_ ): check_result = pattern.MatchResult() - # Check that the scaling factors match what SDPA implements: - - # We need to know the hidden size to check the scaling factors. - if query is None or query.shape is None or len(query.shape) < 2: - return check_result.fail( - "Query shape is not known or has less than 2 dimensions.", query - ) - hidden_size = query.shape[-1] - if not isinstance(hidden_size, int): - return check_result.fail("Hidden size is not an integer.") - - expected_scaling_factor = math.sqrt(hidden_size) - if self._use_mul: - expected_scaling_factor = 1.0 / expected_scaling_factor - - if self._pre_scale and not self._pre_scale_q: - # Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor) - # If they are scalars but != sqrt(expected_scaling_factor), a custom scale is being used. - sqrt_scaling_factor = math.sqrt(expected_scaling_factor) - # Calculate the scaling factor for query - if (query_scale_value := _ir_utils.get_singleton_value(query_scale)) is None: + + if query_scaling == "None": + query_scale_value = 1.0 + elif query_scaling == "Mul": + if (query_scale_value := _ir_utils.get_singleton_value(query_scale, rank=0)) is None: + return check_result.fail( + "Query scale is not a scalar.", + query_scale, + ) + else: + assert query_scaling == "Div", "Unexpected query scaling operation" + if (query_scale_value := _ir_utils.get_singleton_value(query_scale, rank=0)) is None: return check_result.fail( "Query scale is not a scalar.", query_scale, ) - # Ensure the scaling factor for key is the same as for query - if (key_scale_value := _ir_utils.get_singleton_value(key_scale)) is None: + query_scale_value = 1.0 / query_scale_value + + if key_scaling == "None": + key_scale_value = 1.0 + elif key_scaling == "Mul": + if (key_scale_value := _ir_utils.get_singleton_value(key_scale, rank=0)) is None: return check_result.fail( "Key scale is not a scalar.", key_scale, ) - if not math.isclose(query_scale_value, key_scale_value, rel_tol=1e-3): + else: + assert key_scaling == "Div", "Unexpected key scaling operation" + if (key_scale_value := _ir_utils.get_singleton_value(key_scale, rank=0)) is None: return check_result.fail( - "Query and key scales are not equal.", - query_scale, + "Key scale is not a scalar.", + key_scale, + ) + key_scale_value = 1.0 / key_scale_value + + if qk_scale == "None": + qk_scale_value = 1.0 + elif qk_scale == "Mul": + if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale, rank=0)) is None: + return check_result.fail( + "QK scale is not a scalar.", + qk_scale, ) - if not math.isclose(query_scale_value, sqrt_scaling_factor, rel_tol=1e-3): - self._scale = query_scale_value * query_scale_value - else: - # Pass no scaling factor to SDPA, SDPA will use the default scaling factor - self._scale = None else: - # Check if qk_scale is a scalar == expected_scaling_factor) - # If it is a scalar but != sqrt(expected_scaling_factor), a custom scale is being used - if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale)) is None: + assert qk_scale == "Div", "Unexpected QK scaling operation" + if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale, rank=0)) is None: return check_result.fail( "QK scale is not a scalar.", qk_scale, ) - if not math.isclose(qk_scale_value, expected_scaling_factor, rel_tol=1e-3): - self._scale = qk_scale_value - else: - # Pass no scaling factor to SDPA, SDPA will use the default scaling factor - self._scale = None + qk_scale_value = 1.0 / qk_scale_value + + self._scale = query_scale_value * key_scale_value * qk_scale_value + + # If the scaling factor is the default one, we can skip passing it to SDPA. + + if query is None or query.shape is None or len(query.shape) < 2: + return + hidden_size = query.shape[-1] + if not isinstance(hidden_size, int): + return + + default_scaling_factor = math.sqrt(hidden_size) + + if self._scale == default_scaling_factor: + # Pass no scaling factor to SDPA, SDPA will use the default scaling factor + self._scale = None # check ranks/shapes From ac2d91d033a326d3334b16f22aa57d7a7fcb270e Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 28 May 2025 14:46:38 -0700 Subject: [PATCH 02/18] Use disjunction in SDPA fusion Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/sdpa.py | 160 +++++++------------ onnxscript/rewriter/ort_fusions/sdpa_test.py | 33 ++-- 2 files changed, 77 insertions(+), 116 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 0784e4ce81..5a68a903ef 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -8,28 +8,7 @@ class SDPA(pattern.RewriteRuleClassBase): - def __init__( - self, - name: str, - *, - use_mask: bool, - pre_scale: bool, - pre_scale_q: bool, - use_mul: bool, - has_3d_query: bool, - ): - super().__init__(name=name) - self._use_mask = use_mask - self._pre_scale = pre_scale - # There are some patterns where only the query is scaled before the dot product - # and essentially (query * qk_scale) * key is equivalent to (query * key) * qk_scale - # TODO: Capture patterns where only the key is scaled before the dot product - self._pre_scale_q = pre_scale_q - self._use_mul = use_mul - # Capture patterns where the query is reshaped from 3D to 4D - # after scaling has been applied to query. - self._has_3d_query = has_3d_query - self._scale: float | None = None + _scale: float | None def pattern( self, @@ -41,56 +20,82 @@ def pattern( query_scale, key_scale, qk_scale, - # Shape used for reshaping the query in patterns where query is reshaped - # from 3D to 4D and scaling is applied before the reshaping. - query_reshape, ): # Some implementations scale the query and key before computing the dot product - query = pattern.OrValue([ - op.Mul(query, query_scale), - op.Div(query, query_scale), - query, - ], tag_var="query_scaling", tag_values=["Mul", "Div", "None"]) - key_transposed = pattern.OrValue([ - op.Mul(key_transposed, key_scale), - op.Div(key_transposed, key_scale), - key_transposed, - ], tag_var="key_scaling", tag_values=["Mul", "Div", "None"]) - + query = pattern.OrValue( + [ + op.Mul(query, query_scale), + op.Div(query, query_scale), + query, + ], + tag_var="query_scaling", + tag_values=["Mul", "Div", "None"], + ) + key_transposed = pattern.OrValue( + [ + op.Mul(key_transposed, key_scale), + op.Div(key_transposed, key_scale), + key_transposed, + ], + tag_var="key_scaling", + tag_values=["Mul", "Div", "None"], + ) attn_score = op.MatMul(query, key_transposed) # Some implementations scale the dot product. - attn_score = pattern.OrValues ([ - op.Mul(attn_score, qk_scale), - op.Div(attn_score, qk_scale), - attn_score, - ], tag_var="qk_scaling", tag_values=["Mul", "Div", "None"]) + attn_score = pattern.OrValue( + [ + op.Mul(attn_score, qk_scale), + op.Div(attn_score, qk_scale), + attn_score, + ], + tag_var="qk_scaling", + tag_values=["Mul", "Div", "None"], + ) # Some implementations add a mask to the dot product. masked_attn_score = op.Add(attn_score, mask) - attn_score = pattern.OrValue([masked_attn_score, attn_score], tag_var="has_mask", tag_values=[True, False]) + attn_score = pattern.OrValue( + [masked_attn_score, attn_score], tag_var="has_mask", tag_values=[True, False] + ) attn_weight = op.Softmax(attn_score, axis=-1) attn_output = op.MatMul(attn_weight, value) return attn_output def check( - self, op, query, key_transposed, value, mask, query_scaling, query_scale, key_scaling, key_scale, qk_scale, **_ + self, + op, + query, + key_transposed, + value, + mask, + query_scaling, + query_scale, + key_scaling, + key_scale, + qk_scaling, + qk_scale, + **_, ): check_result = pattern.MatchResult() - + if query_scaling == "None": query_scale_value = 1.0 elif query_scaling == "Mul": - if (query_scale_value := _ir_utils.get_singleton_value(query_scale, rank=0)) is None: + if ( + query_scale_value := _ir_utils.get_singleton_value(query_scale, rank=0) + ) is None: return check_result.fail( "Query scale is not a scalar.", query_scale, ) else: assert query_scaling == "Div", "Unexpected query scaling operation" - if (query_scale_value := _ir_utils.get_singleton_value(query_scale, rank=0)) is None: + if ( + query_scale_value := _ir_utils.get_singleton_value(query_scale, rank=0) + ) is None: return check_result.fail( "Query scale is not a scalar.", query_scale, @@ -114,16 +119,16 @@ def check( ) key_scale_value = 1.0 / key_scale_value - if qk_scale == "None": + if qk_scaling == "None": qk_scale_value = 1.0 - elif qk_scale == "Mul": + elif qk_scaling == "Mul": if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale, rank=0)) is None: return check_result.fail( "QK scale is not a scalar.", qk_scale, ) else: - assert qk_scale == "Div", "Unexpected QK scaling operation" + assert qk_scaling == "Div", "Unexpected QK scaling operation" if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale, rank=0)) is None: return check_result.fail( "QK scale is not a scalar.", @@ -131,7 +136,7 @@ def check( ) qk_scale_value = 1.0 / qk_scale_value - self._scale = query_scale_value * key_scale_value * qk_scale_value + self._scale = query_scale_value * key_scale_value * qk_scale_value # If the scaling factor is the default one, we can skip passing it to SDPA. @@ -141,13 +146,13 @@ def check( if not isinstance(hidden_size, int): return - default_scaling_factor = math.sqrt(hidden_size) + default_scaling_factor = 1.0 / math.sqrt(hidden_size) - if self._scale == default_scaling_factor: + if math.isclose(self._scale, default_scaling_factor, rel_tol=1e-5, abs_tol=1e-8): # Pass no scaling factor to SDPA, SDPA will use the default scaling factor self._scale = None - # check ranks/shapes + # TODO: check ranks/shapes return check_result @@ -158,61 +163,16 @@ def rewrite( key_transposed, value, mask, - query_scale, - key_scale, - qk_scale, - query_reshape=None, **_, ): - if self._pre_scale and self._pre_scale_q: - if self._use_mul: - query_mul = op.Mul(query, qk_scale) - else: - query_mul = op.Div(query, qk_scale) - # Reshape and transpose 3D input of shape (B, S, D) - # to 4D input of shape (B, N, S, H) - if self._has_3d_query: - queryBNSH = op.Reshape(query_mul, query_reshape) - query = op.Transpose(queryBNSH, perm=[0, 2, 1, 3]) - else: - query = query_mul - sdpa_args = [query, key_transposed, value] - if self._use_mask: + if mask is not None: sdpa_args.append(mask) return op.SDPA(*sdpa_args, scale=self._scale, _domain="ai.onnxruntime.fusion") -parameter_combinations = [ - { - "name": f"sdpa_{'masked_' if use_mask else 'unmasked_'}{'pre_' if pre_scale else 'post_'}{'only_q_' if pre_scale_q else ''}{'mul' if use_mul else 'div'}{'_3d_query' if has_3d_query else ''}", - "use_mask": use_mask, - "pre_scale": pre_scale, - "pre_scale_q": pre_scale_q, - "use_mul": use_mul, - "has_3d_query": has_3d_query, - } - for use_mask in [False, True] - for pre_scale in [False, True] - for pre_scale_q in [False, True] - for use_mul in [False, True] - for has_3d_query in [False, True] -] - # Dynamically create the rules -sdpa_rules = pattern.RewriteRuleSet( - [ - SDPA.rule( - params["name"], - use_mask=params["use_mask"], - pre_scale=params["pre_scale"], - pre_scale_q=params["pre_scale_q"], - use_mul=params["use_mul"], - has_3d_query=params["has_3d_query"], - ) - for params in parameter_combinations - ] -) +sdpa_rules = pattern.RewriteRuleSet([SDPA.rule()]) fuse_sdpa = _fusion_utils.apply_fusion_rules(sdpa_rules) diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index 74c718147f..76247766ff 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -26,7 +26,12 @@ MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR) SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR) -CUSTOM_SCALE_FACTOR = 2.0 +# Custom scale factors for testing +CUSTOM_SCALE_FACTOR = 1.0 / math.sqrt(80) +CUSTOM_MUL_SCALE_FACTOR = CUSTOM_SCALE_FACTOR +CUSTOM_DIV_SCALE_FACTOR = 1.0 / CUSTOM_SCALE_FACTOR +SQRT_CUSTOM_MUL_SCALE_FACTOR = math.sqrt(CUSTOM_MUL_SCALE_FACTOR) +SQRT_CUSTOM_DIV_SCALE_FACTOR = math.sqrt(CUSTOM_DIV_SCALE_FACTOR) @script() @@ -78,7 +83,7 @@ def _unmasked_post_mul_sdpa_script(query, key, value): @script() def _custom_scale_pre_div_sdpa_script(query, key, value): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) - divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR) + divisor = op.Constant(value_float=SQRT_CUSTOM_DIV_SCALE_FACTOR) scaled_query = op.Div(query, divisor) scaled_key = op.Div(key_transposed, divisor) attn_score = op.MatMul(scaled_query, scaled_key) @@ -90,7 +95,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value): @script() def _custom_scale_pre_mul_sdpa_script(query, key, value): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) - multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR) + multiplier = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR) scaled_query = op.Mul(query, multiplier) scaled_key = op.Mul(key_transposed, multiplier) attn_score = op.MatMul(scaled_query, scaled_key) @@ -102,8 +107,8 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value): @script() def _custom_multi_scale_pre_mul_sdpa_script(query, key, value): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) - multiplier_q = op.Constant(value_float=CUSTOM_SCALE_FACTOR) - multiplier_k = op.Constant(value_float=CUSTOM_SCALE_FACTOR) + multiplier_q = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR) + multiplier_k = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR) scaled_query = op.Mul(query, multiplier_q) scaled_key = op.Mul(key_transposed, multiplier_k) attn_score = op.MatMul(scaled_query, scaled_key) @@ -115,7 +120,7 @@ def _custom_multi_scale_pre_mul_sdpa_script(query, key, value): @script() def _custom_scale_post_div_sdpa_script(query, key, value): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) - divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR) + divisor = op.Constant(value_float=CUSTOM_DIV_SCALE_FACTOR) attn_score = op.MatMul(query, key_transposed) scaled_attn_score = op.Div(attn_score, divisor) attn_weight = op.Softmax(scaled_attn_score, axis=-1) @@ -126,7 +131,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value): @script() def _custom_scale_post_mul_sdpa_script(query, key, value): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) - multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR) + multiplier = op.Constant(value_float=CUSTOM_MUL_SCALE_FACTOR) attn_score = op.MatMul(query, key_transposed) scaled_attn_score = op.Mul(attn_score, multiplier) attn_weight = op.Softmax(scaled_attn_score, axis=-1) @@ -187,7 +192,7 @@ def _masked_post_mul_sdpa_script(query, key, value, mask): @script() def _custom_scale_pre_div_sdpa_script(query, key, value, mask): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) - divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR) + divisor = op.Constant(value_float=SQRT_CUSTOM_DIV_SCALE_FACTOR) scaled_query = op.Div(query, divisor) scaled_key = op.Div(key_transposed, divisor) attn_score = op.MatMul(scaled_query, scaled_key) @@ -200,7 +205,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value, mask): @script() def _custom_scale_pre_mul_sdpa_script(query, key, value, mask): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) - multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR) + multiplier = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR) scaled_query = op.Mul(query, multiplier) scaled_key = op.Mul(key_transposed, multiplier) attn_score = op.MatMul(scaled_query, scaled_key) @@ -213,7 +218,7 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value, mask): @script() def _custom_scale_post_div_sdpa_script(query, key, value, mask): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) - divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR) + divisor = op.Constant(value_float=CUSTOM_DIV_SCALE_FACTOR) attn_score = op.MatMul(query, key_transposed) scaled_attn_score = op.Div(attn_score, divisor) masked_attn_score = op.Add(scaled_attn_score, mask) @@ -225,7 +230,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value, mask): @script() def _custom_scale_post_mul_sdpa_script(query, key, value, mask): key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) - multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR) + multiplier = op.Constant(value_float=CUSTOM_MUL_SCALE_FACTOR) attn_score = op.MatMul(query, key_transposed) scaled_attn_score = op.Mul(attn_score, multiplier) masked_attn_score = op.Add(scaled_attn_score, mask) @@ -307,11 +312,7 @@ def test_sdpa_fusion(self, name, script_func): if "custom" in name: self.assertIsNotNone(sdpa_node.attributes.get("scale")) scale_factor = sdpa_node.attributes["scale"].value - self.assertIsNotNone(scale_factor) - if "pre" in name: - self.assertEqual(scale_factor, CUSTOM_SCALE_FACTOR * CUSTOM_SCALE_FACTOR) - elif "post" in name: - self.assertEqual(scale_factor, CUSTOM_SCALE_FACTOR) + self.assertAlmostEqual(scale_factor, CUSTOM_SCALE_FACTOR, delta=1e-8) else: # These tests are for the default scaling factors, no scale factor is passed to SDPA # pattern rewriting check functions should be sufficient to check if expected value From e9fefba8939b3d34dcd497aa00ba3f8560a49201 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 28 May 2025 15:13:15 -0700 Subject: [PATCH 03/18] Remove rank check Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/mha_test.py | 2 +- onnxscript/rewriter/ort_fusions/sdpa.py | 16 ++++++---------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/mha_test.py b/onnxscript/rewriter/ort_fusions/mha_test.py index e7efb9c978..bfb7e3bbf7 100644 --- a/onnxscript/rewriter/ort_fusions/mha_test.py +++ b/onnxscript/rewriter/ort_fusions/mha_test.py @@ -33,7 +33,7 @@ def test_smollm(self): original_outputs = ort_run("original", model, inputs) # Fuse SDPA and MHA - sdpa_count = xformers.fuse_sdpa(model) + sdpa_count = xformers.fuse_sdpa(model, debug=True) self.assertGreater(sdpa_count, 0) mha_count = xformers.fuse_mha1(model) mha_count += xformers.fuse_mha2(model) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 5a68a903ef..6457ac977e 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -84,18 +84,14 @@ def check( if query_scaling == "None": query_scale_value = 1.0 elif query_scaling == "Mul": - if ( - query_scale_value := _ir_utils.get_singleton_value(query_scale, rank=0) - ) is None: + if (query_scale_value := _ir_utils.get_singleton_value(query_scale)) is None: return check_result.fail( "Query scale is not a scalar.", query_scale, ) else: assert query_scaling == "Div", "Unexpected query scaling operation" - if ( - query_scale_value := _ir_utils.get_singleton_value(query_scale, rank=0) - ) is None: + if (query_scale_value := _ir_utils.get_singleton_value(query_scale)) is None: return check_result.fail( "Query scale is not a scalar.", query_scale, @@ -105,14 +101,14 @@ def check( if key_scaling == "None": key_scale_value = 1.0 elif key_scaling == "Mul": - if (key_scale_value := _ir_utils.get_singleton_value(key_scale, rank=0)) is None: + if (key_scale_value := _ir_utils.get_singleton_value(key_scale)) is None: return check_result.fail( "Key scale is not a scalar.", key_scale, ) else: assert key_scaling == "Div", "Unexpected key scaling operation" - if (key_scale_value := _ir_utils.get_singleton_value(key_scale, rank=0)) is None: + if (key_scale_value := _ir_utils.get_singleton_value(key_scale)) is None: return check_result.fail( "Key scale is not a scalar.", key_scale, @@ -122,14 +118,14 @@ def check( if qk_scaling == "None": qk_scale_value = 1.0 elif qk_scaling == "Mul": - if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale, rank=0)) is None: + if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale)) is None: return check_result.fail( "QK scale is not a scalar.", qk_scale, ) else: assert qk_scaling == "Div", "Unexpected QK scaling operation" - if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale, rank=0)) is None: + if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale)) is None: return check_result.fail( "QK scale is not a scalar.", qk_scale, From 01f7b21acfa5e23058a28c7a639e2ddf0a6bdbae Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 28 May 2025 15:55:13 -0700 Subject: [PATCH 04/18] Remove debug Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/mha_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/mha_test.py b/onnxscript/rewriter/ort_fusions/mha_test.py index bfb7e3bbf7..e7efb9c978 100644 --- a/onnxscript/rewriter/ort_fusions/mha_test.py +++ b/onnxscript/rewriter/ort_fusions/mha_test.py @@ -33,7 +33,7 @@ def test_smollm(self): original_outputs = ort_run("original", model, inputs) # Fuse SDPA and MHA - sdpa_count = xformers.fuse_sdpa(model, debug=True) + sdpa_count = xformers.fuse_sdpa(model) self.assertGreater(sdpa_count, 0) mha_count = xformers.fuse_mha1(model) mha_count += xformers.fuse_mha2(model) From 151244aedad82062ec6bdca9dd41ec97e8b6381e Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 28 May 2025 16:12:34 -0700 Subject: [PATCH 05/18] Add type annotations Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/sdpa.py | 28 ++++++++++++------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 6457ac977e..165059ba41 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -4,6 +4,7 @@ import math +from onnxscript import ir from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern @@ -66,17 +67,14 @@ def pattern( def check( self, - op, - query, - key_transposed, - value, - mask, - query_scaling, - query_scale, - key_scaling, - key_scale, - qk_scaling, - qk_scale, + context, + query: ir.Value | None, + query_scaling: str, + query_scale: ir.Value | None, + key_scaling: str, + key_scale: ir.Value | None, + qk_scaling: str, + qk_scale: ir.Value | None, **_, ): check_result = pattern.MatchResult() @@ -155,10 +153,10 @@ def check( def rewrite( self, op, - query, - key_transposed, - value, - mask, + query: ir.Value | None, + key_transposed: ir.Value | None, + value: ir.Value | None, + mask: ir.Value | None, **_, ): sdpa_args = [query, key_transposed, value] From 366e167d42a40ae66302d056440838db5a8ccf20 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 29 May 2025 11:28:49 -0700 Subject: [PATCH 06/18] Add missing shapes to gqa_test Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/gqa_test.py | 13 ++++++++++++ onnxscript/rewriter/ort_fusions/sdpa.py | 23 +++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py index 4f8f9ab8ba..bab25e6d70 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -44,6 +44,7 @@ def __init__(self, *args, **kwargs): "num_heads must be divisible by kv_num_heads" ) self.num_groups = self.num_heads // self.kv_num_heads + self.total_seqlen = self.seqlen + self.past_seqlen # Abbreviations B = self.batchsize @@ -311,12 +312,24 @@ def test_fusion(self): onnx.TensorProto.FLOAT, ["B", self.seqlen, self.kv_num_heads, self.head_size], ) + key_transposed_value_info = onnx.helper.make_tensor_value_info( + "key_transposed", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.head_size, self.total_seqlen], + ) + value_BHSDh_value_info = onnx.helper.make_tensor_value_info( + "value_BHSDh", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.total_seqlen, self.head_size], + ) source_model.graph.value_info.extend( [ query_BHSDh_rope_value_info, key_BHkvSDh_rope_value_info, query_BSHDh_value_info, key_BSHkvDh_value_info, + key_transposed_value_info, + value_BHSDh_value_info, ] ) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 165059ba41..0791814daa 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -3,10 +3,13 @@ from __future__ import annotations import math +from typing import Sequence, Union from onnxscript import ir from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern +Dim = Union[int, ir.SymbolicDim] + class SDPA(pattern.RewriteRuleClassBase): _scale: float | None @@ -69,6 +72,9 @@ def check( self, context, query: ir.Value | None, + key_transposed: ir.Value | None, + value: ir.Value | None, + mask: ir.Value | None, query_scaling: str, query_scale: ir.Value | None, key_scaling: str, @@ -147,6 +153,23 @@ def check( self._scale = None # TODO: check ranks/shapes + bindings: dict[str, Dim] = {} + + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + return not _fusion_utils._check_shape(bindings, val, dims) + + # Check that query/key/value have the expected shapes: + # They all should have same batch-size (B) and number of heads (H). Conceptually, it is + # different for Q and K/V, but the certain op implementations require them to be the same, + # which is usually achieved via tiling/expanding K/V num-heads to match Q num-heads. + # Query and Key should have same head-size (Dh) while value can have different head-size (Dv). + # Key and Value should have same sequence length (Skv), while Query can have different sequence length (S). + if no_match(query, ["B", "H", "S", "Dh"]): + return False + if no_match(key_transposed, ["B", "H", "Dh", "Skv"]): + return False + if no_match(value, ["B", "H", "Skv", "Dv"]): + return False return check_result From f32b823bf47ac14ebff76e8edafa248d41a0e653 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 29 May 2025 12:29:00 -0700 Subject: [PATCH 07/18] Cleanup match failure handling Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/_basics.py | 31 +++++++++++++++++++++++++ onnxscript/rewriter/_fusion_utils.py | 17 ++++++++++++++ onnxscript/rewriter/_rewrite_rule.py | 9 ++++++- onnxscript/rewriter/ort_fusions/sdpa.py | 9 +++---- 4 files changed, 59 insertions(+), 7 deletions(-) diff --git a/onnxscript/rewriter/_basics.py b/onnxscript/rewriter/_basics.py index a875626d3f..17ab0c97a4 100644 --- a/onnxscript/rewriter/_basics.py +++ b/onnxscript/rewriter/_basics.py @@ -16,6 +16,37 @@ import onnxscript.rewriter._rewrite_rule as _rewrite_rule +class MatchFailureInfo: + """Encapsulates information about a pattern match failure.""" + + def __init__( + self, + reason: str = "", + *failure_source: ir.Node | ir.Value, + ): + self.reason = reason + # failure_source is a tuple; convert to list for consistency + self.failure_sources: tuple[ir.Node | ir.Value, ...] = failure_source + assert all(isinstance(item, (ir.Node, ir.Value)) for item in failure_source), ( + f"All items in failure_source must be ir.Node or ir.Value, got {[type(item) for item in failure_source]}" + ) + + def __str__(self): + return f"MatchFailureInfo(reason={self.reason!r}, failure_sources={self.failure_sources!r})" + + +class MatchFailureError(MatchFailureInfo, Exception): + """Exception raised when a pattern match fails.""" + + def __init__( + self, + reason: str = "", + *failure_source: ir.Node | ir.Value, + ): + MatchFailureInfo.__init__(self, reason, *failure_source) + Exception.__init__(self, reason) + + class MatchResult: """The state object used by the pattern-matching algorithm. diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py index 59bdf87bd0..5e886932bf 100644 --- a/onnxscript/rewriter/_fusion_utils.py +++ b/onnxscript/rewriter/_fusion_utils.py @@ -7,6 +7,7 @@ import onnxscript.ir as ir from onnxscript.ir.passes.common import shape_inference from onnxscript.rewriter import pattern +from onnxscript.rewriter._basics import MatchFailureError Dim = Union[int, ir.SymbolicDim] @@ -24,6 +25,22 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) return True +def check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]): + if val.shape is None: + raise MatchFailureError(f"The shape of {val} is unknown.") + if val.shape.rank() != len(shape): + raise MatchFailureError( + f"The rank of {val} ({val.shape.rank()} does not match the expected rank {len(shape)}." + ) + for i, (actual, expected) in enumerate(zip(val.shape, shape)): + if expected not in bindings: + bindings[expected] = actual # type: ignore[assignment] + elif actual != bindings[expected]: + raise MatchFailureError( + f"Dimenion {i} of {val} ({actual}) does not have expected size ({bindings[expected]})." + ) + + def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> Callable: """ Apply the given fusion rules to the model and return the number of fusions applied. diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index bc90a92a21..603353aed3 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -174,7 +174,14 @@ def try_rewrite( if var.name is not None: if var.name not in match.bindings: match.bind(var.name, None) - check_match_result = self._condition_function(context, **match.bindings) + try: + check_match_result = self._condition_function(context, **match.bindings) + except _basics.MatchFailureException as e: + check_match_result = _basics.MatchResult() + check_match_result.fail( + e.reason, + list(e.failure_nodes_and_values), + ) if not check_match_result: # If check function was provided, but it failed, return the reason for failure to the tracer. if isinstance(check_match_result, _basics.MatchResult): diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 0791814daa..70a7089dca 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -164,12 +164,9 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool: # which is usually achieved via tiling/expanding K/V num-heads to match Q num-heads. # Query and Key should have same head-size (Dh) while value can have different head-size (Dv). # Key and Value should have same sequence length (Skv), while Query can have different sequence length (S). - if no_match(query, ["B", "H", "S", "Dh"]): - return False - if no_match(key_transposed, ["B", "H", "Dh", "Skv"]): - return False - if no_match(value, ["B", "H", "Skv", "Dv"]): - return False + _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"]) + _fusion_utils.check_shape(bindings, key_transposed, ["B", "H", "Dh", "Skv"]) + _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"]) return check_result From 938e5d0d4e1071ef1c8dd4b65ce9003379640778 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 29 May 2025 12:33:03 -0700 Subject: [PATCH 08/18] Add value to match error Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/_fusion_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py index 5e886932bf..cb232d10ed 100644 --- a/onnxscript/rewriter/_fusion_utils.py +++ b/onnxscript/rewriter/_fusion_utils.py @@ -27,17 +27,19 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) def check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]): if val.shape is None: - raise MatchFailureError(f"The shape of {val} is unknown.") + raise MatchFailureError(f"The shape of {val} is unknown.", val) if val.shape.rank() != len(shape): raise MatchFailureError( - f"The rank of {val} ({val.shape.rank()} does not match the expected rank {len(shape)}." + f"The rank of {val} ({val.shape.rank()} does not match the expected rank {len(shape)}.", + val, ) for i, (actual, expected) in enumerate(zip(val.shape, shape)): if expected not in bindings: bindings[expected] = actual # type: ignore[assignment] elif actual != bindings[expected]: raise MatchFailureError( - f"Dimenion {i} of {val} ({actual}) does not have expected size ({bindings[expected]})." + f"Dimenion {i} of {val} ({actual}) does not have expected size ({bindings[expected]}).", + val, ) From d0eb6fa3e644c7eac714770c03b22bf86326ead7 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 29 May 2025 14:51:24 -0700 Subject: [PATCH 09/18] Cleanup duplicated code Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/sdpa.py | 77 +++++++------------------ 1 file changed, 20 insertions(+), 57 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 70a7089dca..06a84f6e51 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -7,6 +7,7 @@ from onnxscript import ir from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern +from onnxscript.rewriter._basics import MatchFailureError Dim = Union[int, ir.SymbolicDim] @@ -75,66 +76,28 @@ def check( key_transposed: ir.Value | None, value: ir.Value | None, mask: ir.Value | None, - query_scaling: str, - query_scale: ir.Value | None, - key_scaling: str, - key_scale: ir.Value | None, - qk_scaling: str, - qk_scale: ir.Value | None, - **_, + **match_bindings, ): check_result = pattern.MatchResult() - if query_scaling == "None": - query_scale_value = 1.0 - elif query_scaling == "Mul": - if (query_scale_value := _ir_utils.get_singleton_value(query_scale)) is None: - return check_result.fail( - "Query scale is not a scalar.", - query_scale, - ) - else: - assert query_scaling == "Div", "Unexpected query scaling operation" - if (query_scale_value := _ir_utils.get_singleton_value(query_scale)) is None: - return check_result.fail( - "Query scale is not a scalar.", - query_scale, - ) - query_scale_value = 1.0 / query_scale_value - - if key_scaling == "None": - key_scale_value = 1.0 - elif key_scaling == "Mul": - if (key_scale_value := _ir_utils.get_singleton_value(key_scale)) is None: - return check_result.fail( - "Key scale is not a scalar.", - key_scale, - ) - else: - assert key_scaling == "Div", "Unexpected key scaling operation" - if (key_scale_value := _ir_utils.get_singleton_value(key_scale)) is None: - return check_result.fail( - "Key scale is not a scalar.", - key_scale, - ) - key_scale_value = 1.0 / key_scale_value - - if qk_scaling == "None": - qk_scale_value = 1.0 - elif qk_scaling == "Mul": - if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale)) is None: - return check_result.fail( - "QK scale is not a scalar.", - qk_scale, - ) - else: - assert qk_scaling == "Div", "Unexpected QK scaling operation" - if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale)) is None: - return check_result.fail( - "QK scale is not a scalar.", - qk_scale, - ) - qk_scale_value = 1.0 / qk_scale_value + def get_scale_value(tag_name: str, scale_name: str) -> float: + scaling_type = match_bindings.get(tag_name, "None") + if scaling_type == "None": + return 1.0 + else: + scale = match_bindings.get(scale_name) + value = _ir_utils.get_singleton_value(scale) + if value is None: + raise MatchFailureError(f"{scale_name} is not a scalar.", scale) + if scaling_type == "Mul": + return value + else: + assert scaling_type == "Div", f"Unexpected {scale_name} scaling operation" + return 1.0 / value + + query_scale_value = get_scale_value("query_scaling", "query_scale") + key_scale_value = get_scale_value("key_scaling", "key_scale") + qk_scale_value = get_scale_value("qk_scaling", "qk_scale") self._scale = query_scale_value * key_scale_value * qk_scale_value From b1450173917497e4a1c497cd6812a102e1e7f132 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 29 May 2025 14:53:39 -0700 Subject: [PATCH 10/18] Remove outdated comment Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/_basics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/rewriter/_basics.py b/onnxscript/rewriter/_basics.py index 17ab0c97a4..c03737b16a 100644 --- a/onnxscript/rewriter/_basics.py +++ b/onnxscript/rewriter/_basics.py @@ -25,7 +25,6 @@ def __init__( *failure_source: ir.Node | ir.Value, ): self.reason = reason - # failure_source is a tuple; convert to list for consistency self.failure_sources: tuple[ir.Node | ir.Value, ...] = failure_source assert all(isinstance(item, (ir.Node, ir.Value)) for item in failure_source), ( f"All items in failure_source must be ir.Node or ir.Value, got {[type(item) for item in failure_source]}" From c5ae7f4091804d397944879a591f9d441be31e78 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 29 May 2025 14:59:05 -0700 Subject: [PATCH 11/18] Fix renaming Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/_rewrite_rule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 603353aed3..1bebc58943 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -176,7 +176,7 @@ def try_rewrite( match.bind(var.name, None) try: check_match_result = self._condition_function(context, **match.bindings) - except _basics.MatchFailureException as e: + except _basics.MatchFailureError as e: check_match_result = _basics.MatchResult() check_match_result.fail( e.reason, From 2cd57fb6f198589ea43eb75672dfd45420fa8a00 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 30 May 2025 20:58:36 -0700 Subject: [PATCH 12/18] Update onnxscript/rewriter/_fusion_utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/rewriter/_fusion_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py index cb232d10ed..7b80bdfefa 100644 --- a/onnxscript/rewriter/_fusion_utils.py +++ b/onnxscript/rewriter/_fusion_utils.py @@ -38,7 +38,7 @@ def check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]): bindings[expected] = actual # type: ignore[assignment] elif actual != bindings[expected]: raise MatchFailureError( - f"Dimenion {i} of {val} ({actual}) does not have expected size ({bindings[expected]}).", + f"Dimension {i} of {val} ({actual}) does not have expected size ({bindings[expected]}).", val, ) From ae4ff4ac6e1e9cdfdb788e51297dfd9589dd74d6 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 30 May 2025 20:59:45 -0700 Subject: [PATCH 13/18] Potential fix for code scanning alert no. 17154: Unused local variable Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- onnxscript/rewriter/ort_fusions/sdpa.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 06a84f6e51..642aab8dac 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -118,8 +118,7 @@ def get_scale_value(tag_name: str, scale_name: str) -> float: # TODO: check ranks/shapes bindings: dict[str, Dim] = {} - def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) +# Removed unused local variable `no_match`. # Check that query/key/value have the expected shapes: # They all should have same batch-size (B) and number of heads (H). Conceptually, it is From ba623234a461f6557581bbf17026ae0669bddf2f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 30 May 2025 21:28:40 -0700 Subject: [PATCH 14/18] Fix handling of unknown headsize Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/sdpa.py | 39 +++++++++++-------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 642aab8dac..b8b5f93cf2 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -3,7 +3,7 @@ from __future__ import annotations import math -from typing import Sequence, Union +from typing import Union from onnxscript import ir from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern @@ -80,6 +80,18 @@ def check( ): check_result = pattern.MatchResult() + bindings: dict[str, Dim] = {} + + # Check that query/key/value have the expected shapes: + # They all should have same batch-size (B) and number of heads (H). Conceptually, it is + # different for Q and K/V, but the certain op implementations require them to be the same, + # which is usually achieved via tiling/expanding K/V num-heads to match Q num-heads. + # Query and Key should have same head-size (Dh) while value can have different head-size (Dv). + # Key and Value should have same sequence length (Skv), while Query can have different sequence length (S). + _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"]) + _fusion_utils.check_shape(bindings, key_transposed, ["B", "H", "Dh", "Skv"]) + _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"]) + def get_scale_value(tag_name: str, scale_name: str) -> float: scaling_type = match_bindings.get(tag_name, "None") if scaling_type == "None": @@ -103,33 +115,16 @@ def get_scale_value(tag_name: str, scale_name: str) -> float: # If the scaling factor is the default one, we can skip passing it to SDPA. - if query is None or query.shape is None or len(query.shape) < 2: - return - hidden_size = query.shape[-1] - if not isinstance(hidden_size, int): - return + head_size = bindings["Dh"] + if not isinstance(head_size, int): + return check_result - default_scaling_factor = 1.0 / math.sqrt(hidden_size) + default_scaling_factor = 1.0 / math.sqrt(head_size) if math.isclose(self._scale, default_scaling_factor, rel_tol=1e-5, abs_tol=1e-8): # Pass no scaling factor to SDPA, SDPA will use the default scaling factor self._scale = None - # TODO: check ranks/shapes - bindings: dict[str, Dim] = {} - -# Removed unused local variable `no_match`. - - # Check that query/key/value have the expected shapes: - # They all should have same batch-size (B) and number of heads (H). Conceptually, it is - # different for Q and K/V, but the certain op implementations require them to be the same, - # which is usually achieved via tiling/expanding K/V num-heads to match Q num-heads. - # Query and Key should have same head-size (Dh) while value can have different head-size (Dv). - # Key and Value should have same sequence length (Skv), while Query can have different sequence length (S). - _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"]) - _fusion_utils.check_shape(bindings, key_transposed, ["B", "H", "Dh", "Skv"]) - _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"]) - return check_result def rewrite( From 2a46711d5b47a8ecf8de962e10251b6c32bfbd46 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Mon, 2 Jun 2025 14:32:46 -0700 Subject: [PATCH 15/18] Update onnxscript/rewriter/ort_fusions/sdpa.py Co-authored-by: Ti-Tai Wang --- onnxscript/rewriter/ort_fusions/sdpa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index b8b5f93cf2..5d9ab87866 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -5,7 +5,7 @@ import math from typing import Union -from onnxscript import ir +import onnx_ir as ir from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern from onnxscript.rewriter._basics import MatchFailureError From 5792af3b7a1b6dcc02a132718c24fa1aa6a76dbe Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 3 Jun 2025 11:12:33 -0700 Subject: [PATCH 16/18] Add negative SDPA test case Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/_rewrite_rule.py | 5 +-- onnxscript/rewriter/ort_fusions/sdpa.py | 1 + onnxscript/rewriter/ort_fusions/sdpa_test.py | 35 ++++++++++++++++++++ 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index b85bff5c1a..33f2aee8a5 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -178,10 +178,7 @@ def try_rewrite( check_match_result = self._condition_function(context, **match.bindings) except _basics.MatchFailureError as e: check_match_result = _basics.MatchResult() - check_match_result.fail( - e.reason, - list(e.failure_nodes_and_values), - ) + check_match_result.fail(e.reason, list(e.failure_sources)) if not check_match_result: # If check function was provided, but it failed, return the reason for failure to the tracer. if isinstance(check_match_result, _basics.MatchResult): diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 5d9ab87866..0ce3378010 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -6,6 +6,7 @@ from typing import Union import onnx_ir as ir + from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern from onnxscript.rewriter._basics import MatchFailureError diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index 76247766ff..88eec4fe5d 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -265,6 +265,34 @@ def get_ort_inputs(self): return self._ort_inputs +class InvalidSDPATestCase: + def __init__(self, script_func): + self.script_func = script_func + + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + qk_type = FLOAT[B, N, S, H] + # We broadcast value in the batch dimension, which is not supported by SDPA fusion + v_type = FLOAT[1, N, S, H] + mask_type = FLOAT[B, N, S, S] + model_proto = self.script_func.to_model_proto( + input_types=[qk_type, qk_type, v_type, mask_type], output_types=[qk_type] + ) + self._onnx_model = ir.serde.deserialize_model(model_proto) + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + inputs = { + "query": numpy.random.rand(B, N, S, H).astype(numpy.float32), + "key": numpy.random.rand(B, N, S, H).astype(numpy.float32), + "value": numpy.random.rand(1, N, S, H).astype(numpy.float32), + "mask": numpy.random.rand(B, N, S, S).astype(numpy.float32), + } + self._ort_inputs = inputs + return self._ort_inputs + + class TestSDPAFusion(unittest.TestCase): @parameterized.parameterized.expand( [ @@ -322,6 +350,13 @@ def test_sdpa_fusion(self, name, script_func): # new_outputs = ort_run("optimized", model, inputs) # assert_allclose(new_outputs, original_outputs) + def test_invalid_sdpa_fusion_value_batch_dim(self): + test_case = InvalidSDPATestCase(_masked_pre_mul_sdpa_script) + model = test_case.get_onnx_model() + onnxscript.optimizer.optimize(model) + count = fuse_sdpa(model) + self.assertEqual(count, 0) + if __name__ == "__main__": unittest.main() From 2b17200335bde2fcf935211cfc5c804270df4fce Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 3 Jun 2025 11:35:37 -0700 Subject: [PATCH 17/18] Address PR feedaback Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/sdpa.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 0ce3378010..1ca4c3b1ff 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -140,6 +140,7 @@ def rewrite( sdpa_args = [query, key_transposed, value] if mask is not None: sdpa_args.append(mask) + # If the scale is None, SDPA will use the default scaling factor, which is 1/sqrt(head_size). return op.SDPA(*sdpa_args, scale=self._scale, _domain="ai.onnxruntime.fusion") From 556e7ff48db225b0dcee60f05e614cb8789f2371 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 3 Jun 2025 11:40:51 -0700 Subject: [PATCH 18/18] Address PR feedback Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/_basics.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/_basics.py b/onnxscript/rewriter/_basics.py index 65466d7aef..8ea8a24bb3 100644 --- a/onnxscript/rewriter/_basics.py +++ b/onnxscript/rewriter/_basics.py @@ -35,7 +35,13 @@ def __str__(self): class MatchFailureError(MatchFailureInfo, Exception): - """Exception raised when a pattern match fails.""" + """Exception raised when a pattern match fails. + + This makes it easier to handle match failures in a compositional way, + for example, during the condition-checking phase of a pattern match. + It allows us to define utility functions without having to check for + and propagate match failures explicitly. + """ def __init__( self,