
Add rotary embedding fusion rule (part 1) #1981

Open · wants to merge 19 commits into main
Conversation

@gramalingam (Collaborator) commented on Dec 18, 2024:

Initial version of fusion for rotary embedding.

Limitations: currently handles only the non-interleaved, full-rotation case (see the illustration below).
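For context, the two layouts differ in how the rotation pairs the elements of each head; a minimal illustration (the values and head size are made up):

```python
import numpy as np

# Illustrative only: how the two rotary layouts pair elements of one head.
x = np.arange(6)  # pretend head_size == 6
half = x.size // 2

# Non-interleaved (handled by this PR): first half pairs with second half.
non_interleaved_pairs = list(zip(x[:half], x[half:]))  # [(0, 3), (1, 4), (2, 5)]

# Interleaved (not yet handled): even positions pair with odd positions.
interleaved_pairs = list(zip(x[0::2], x[1::2]))        # [(0, 1), (2, 3), (4, 5)]

print(non_interleaved_pairs, interleaved_pairs)
```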

Other changes:

  • Add support for rewrite rules where the matched nodes are not removed; useful when the matched nodes include nodes shared with other parts of the graph.
  • Add an optimization to eliminate redundant Reshape ops (helps simplify the pattern); a sketch of the idea follows this list.
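A minimal sketch of the redundant-Reshape idea (the helper name is hypothetical; only the fully static case is shown, since -1 or 0 entries in the target shape need separate handling):

```python
def is_redundant_reshape(input_shape, target_shape) -> bool:  # hypothetical helper
    """A Reshape whose target equals the (fully known) input shape is a no-op."""
    return (
        input_shape is not None
        and all(isinstance(d, int) for d in input_shape)
        and list(target_shape) == list(input_shape)
    )

assert is_redundant_reshape([2, 8, 32, 64], [2, 8, 32, 64])    # can be eliminated
assert not is_redundant_reshape([2, 8, 32, 64], [2, 8, 2048])  # a real reshape
```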


codecov bot commented Dec 18, 2024

❌ 15 Tests Failed:

| Tests completed | Failed | Passed | Skipped |
| --- | --- | --- | --- |
| 16840 | 15 | 16825 | 3780 |
Top failed test (by shortest run time):

::onnxscript.rewriter.onnxruntime.xformers._test_models
Stack trace: 0s run time — no failure message available
2 flaky tests:
tests.eager_mode_test.TestEagerModeArguments_0_reference_runtime::test_function_input_and_attribute_by_kwargs_out_of_order

Flake rate in main: 39.27% (Passed 12625 times, Failed 8164 times)

Stack trace (0.002s run time):

```
..../test_onnx_weekly/lib/python3.11.../reference/ops/_op.py:91: in run
    res = self._run(x, y)
..../test_onnx_weekly/lib/python3.11.../reference/ops/_op.py:139: in _run
    res = (convert_from_ml_dtypes(res[0]),)
..../test_onnx_weekly/lib/python3.11.../onnx/reference/custom_element_types.py:50: in convert_from_ml_dtypes
    return array.view(dtype=dtype)
E   ValueError: Changing the dtype of a 0d array is only supported if the itemsize is unchanged

The above exception was the direct cause of the following exception:
tests/eager_mode_test.py:115: in test_function_input_and_attribute_by_kwargs_out_of_order
    self.assertEqual(add_with_alpha(alpha=3.0, other=2.0, this=1.0), 7.0)
onnxscript/values.py:576: in __call__
    return evaluator.default().eval_function(self, args, kwargs)
onnxscript/evaluator.py:307: in eval_function
    result = function.function(*adapted_args, **adapted_kwargs)
tests/eager_mode_test.py:59: in add_with_alpha
    other = op.Mul(other, alpha)
.../onnx_opset/_impl/opset14.py:696: in Mul
    return op(*self._prepare_inputs(schema, A, B))
onnxscript/values.py:304: in __call__
    return evaluator.default().eval(schema, args, kwargs)
onnxscript/evaluator.py:194: in eval
    outputs = self._eval(schema, inputs, attributes, closure)
onnxscript/evaluator.py:524: in _eval
    result = session.run(None, session_run_input)
..../test_onnx_weekly/lib/python3.11.../onnx/reference/reference_evaluator.py:593: in run
    outputs = node.run(*inputs, **linked_attributes)
..../test_onnx_weekly/lib/python3.11.../reference/ops/_op.py:114: in run
    res = OpRunBinary.run(self, x, y)
..../test_onnx_weekly/lib/python3.11.../reference/ops/_op.py:93: in run
    raise TypeError(
E   TypeError: Issues with types <class 'numpy.ndarray'>, <class 'numpy.ndarray'> (binary operator 'Mul').
```
tests.eager_mode_test.TestEagerModeArguments_0_reference_runtime::test_function_some_input_by_kwargs

Flake rate in main: 39.27% (Passed 12625 times, Failed 8164 times)

Stack trace (0.002s run time):

```
..../test_onnx_weekly/lib/python3.11.../reference/ops/_op.py:91: in run
    res = self._run(x, y)
..../test_onnx_weekly/lib/python3.11.../reference/ops/_op.py:139: in _run
    res = (convert_from_ml_dtypes(res[0]),)
..../test_onnx_weekly/lib/python3.11.../onnx/reference/custom_element_types.py:50: in convert_from_ml_dtypes
    return array.view(dtype=dtype)
E   ValueError: Changing the dtype of a 0d array is only supported if the itemsize is unchanged

The above exception was the direct cause of the following exception:
tests/eager_mode_test.py:106: in test_function_some_input_by_kwargs
    self.assertEqual(add_with_alpha(1.0, other=2.0), 3.0)
onnxscript/values.py:576: in __call__
    return evaluator.default().eval_function(self, args, kwargs)
onnxscript/evaluator.py:307: in eval_function
    result = function.function(*adapted_args, **adapted_kwargs)
tests/eager_mode_test.py:59: in add_with_alpha
    other = op.Mul(other, alpha)
.../onnx_opset/_impl/opset14.py:696: in Mul
    return op(*self._prepare_inputs(schema, A, B))
onnxscript/values.py:304: in __call__
    return evaluator.default().eval(schema, args, kwargs)
onnxscript/evaluator.py:194: in eval
    outputs = self._eval(schema, inputs, attributes, closure)
onnxscript/evaluator.py:524: in _eval
    result = session.run(None, session_run_input)
..../test_onnx_weekly/lib/python3.11.../onnx/reference/reference_evaluator.py:593: in run
    outputs = node.run(*inputs, **linked_attributes)
..../test_onnx_weekly/lib/python3.11.../reference/ops/_op.py:114: in run
    res = OpRunBinary.run(self, x, y)
..../test_onnx_weekly/lib/python3.11.../reference/ops/_op.py:93: in run
    raise TypeError(
E   TypeError: Issues with types <class 'numpy.ndarray'>, <class 'numpy.ndarray'> (binary operator 'Mul').
```


@gramalingam changed the title from "Add rotary embedding fusion rule (part 1)" to "[Draft - WIP] Add rotary embedding fusion rule (part 1)" on Dec 20, 2024
@gramalingam changed the title from "[Draft - WIP] Add rotary embedding fusion rule (part 1)" to "Add rotary embedding fusion rule (part 1)" on Dec 23, 2024
Review thread on the rotate-half pattern:

```python
# Slice(input, starts, ends, axes, steps)
x1 = op.Slice(x, start1, end1, [3], [1])
x2 = op.Slice(x, start2, end2, [3], [1])
minus_x2 = op.Neg(x2)
```
Contributor commented:

Although this logic is correct and makes sense, it doesn't match the function logic in the op definition. Is it correct to assume that the pattern logic should mimic the ONNX function in the op schema?

In the current op schema, the pattern after computing x1 and x2 would look like this (the schema uses Split instead of Slice for the non-interleaved case):

```python
real = cos * x1 - sin * x2
imag = sin * x1 + cos * x2
rotated_x = op.Concat(real, imag)
```

So the Concat happens after the multiplication.
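(Aside: the two formulations agree whenever the full-width cos/sin inputs are the half-size tables duplicated along the last axis, which is what the transformers code produces. A quick NumPy check with illustrative shapes:)

```python
import numpy as np

# Sanity check (illustrative shapes, not from the PR): the exporter-style
# form "x * cos + rotate_half(x) * sin" equals the schema-style form
# "Concat(cos*x1 - sin*x2, sin*x1 + cos*x2)".
rng = np.random.default_rng(0)
b, h, s_len, d = 1, 2, 3, 8
x = rng.standard_normal((b, h, s_len, d)).astype(np.float32)
c = rng.standard_normal((s_len, d // 2)).astype(np.float32)  # half-size cos
s = rng.standard_normal((s_len, d // 2)).astype(np.float32)  # half-size sin

x1, x2 = x[..., : d // 2], x[..., d // 2 :]

# Exporter/transformers style: full-width tables plus rotate_half.
cos_full = np.concatenate([c, c], axis=-1)
sin_full = np.concatenate([s, s], axis=-1)
rotate_half = np.concatenate([-x2, x1], axis=-1)
y_pattern = x * cos_full + rotate_half * sin_full

# Op-schema style: multiply the halves first, then concatenate.
y_schema = np.concatenate([c * x1 - s * x2, s * x1 + c * x2], axis=-1)

np.testing.assert_allclose(y_pattern, y_schema, rtol=1e-6)
```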

@gramalingam (Collaborator, author) replied:

It does not have to match the function logic in the op definition. But it does have to match the function graph produced by the ONNX exporter from the logic defined in the source (e.g., the transformers implementation).

What we have to guarantee is that replacing this logic by the pattern in the rewrite is sound: both must produce the same values.

@gramalingam (Collaborator, author) added:

Specifically, it is more important to match the source logic, such as this transformers code.

@gramalingam (Collaborator, author) added:

Putting all this together, there are three parts to these rewrite rules (a sketch follows this list):

  • The pattern should typically be aligned with the subgraph pattern we see in the ONNX graphs produced by the exporter (which itself depends on the source PyTorch code).
  • The rewrite part is aligned with the (fused) op definition (existing in ORT or being introduced to ONNX).
  • The check condition has to be strong enough to guarantee that the replacement is sound, so that we can be sure we will produce the same outputs with or without the optimization.
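To make the three parts concrete, here is a sketch assembling the fragments quoted later in this thread into one rule. The class name is hypothetical, `_rotate_half_pattern` is a helper defined elsewhere in this PR, and the exact base class and registration API are omitted; this is an illustration, not the PR's exact code.

```python
class RotaryEmbeddingFusionRule:  # hypothetical name, base class omitted
    def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
        # (1) Match the exporter-produced subgraph: x * cos + rotate_half(x) * sin.
        return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin

    def check(self, op, x, start1, end1, start2, end2, **_):
        # (3) Soundness guard: x must be 4D with known num_heads (dim 1)
        #     and head_size (dim 3).
        shape = x.shape
        return shape is not None and len(shape) == 4

    def rewrite(self, op, x, cos, sin, **_):
        # (2) Replace the matched subgraph with the fused ORT op.
        num_heads = x.shape[1]
        return op.RotaryEmbedding(
            x, cos, sin, interleaved=0, num_heads=num_heads,
            _domain="ai.onnxruntime.fusion",
        )
```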

Review thread on the check condition:

```python
    return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin

def check(self, op, x, start1, end1, start2, end2, **_):
    # x needs to be a 4D tensor with known last-dimension size (== head_size)
    # and known second dimension (== num_heads)
```
Contributor commented:

According to the schema, x can be a 3D tensor as well, and num_heads needs to be known in the 3D case.

@gramalingam (Collaborator, author) replied:

The 4D restriction means the optimization is safe and correct (in this regard). To generalize and allow 3D inputs here, we would need to guarantee that the entire fusion remains semantically correct; it is not enough to know that the RotaryEmbedding op permits 3D inputs.

What do you think about the correctness of this fusion optimization? Do you think it is fine to generalize and allow 3D here?
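To illustrate the 3D point: a 3D input packs all heads into one hidden dimension, so head_size can only be recovered if num_heads is supplied separately. A minimal NumPy sketch with made-up shapes:

```python
import numpy as np

# Illustrative shapes only (not from the PR).
batch, seq, num_heads, head_size = 2, 5, 4, 8
x4d = np.zeros((batch, num_heads, seq, head_size))   # 4D: head_size is explicit (dim 3)
x3d = np.zeros((batch, seq, num_heads * head_size))  # 3D: heads packed into dim 2

hidden = x3d.shape[-1]
assert hidden % num_heads == 0
recovered_head_size = hidden // num_heads            # requires num_heads as extra info
assert recovered_head_size == head_size
```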

Review thread on the rewrite target:

```python
def rewrite(self, op, x, cos, sin, **_):
    num_heads = x.shape[1]
    return op.RotaryEmbedding(
        x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion"
    )
```
Contributor commented:

Just curious about the domain here

@gramalingam (Collaborator, author) replied:

Good question. I am currently using these custom domains to split the fusion optimization into multiple stages. We may need to clean this up eventually. For now, we also need to target the existing RotaryEmbedding op in ORT (which is what we can test against). Eventually, we can target the newly proposed RotaryEmbedding op, so we may also need to support some variations of the fusion optimization (depending on the target ORT/ONNX versions).

@gramalingam enabled auto-merge (squash) on January 2, 2025, 20:57