Add rotary embedding fusion rule (part 1) #1981

Open · wants to merge 19 commits into main · showing changes from 12 commits
22 changes: 22 additions & 0 deletions onnxscript/optimizer/_constant_folding.py
@@ -306,6 +306,28 @@
return default


@register("Reshape")
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    input = _get_input(node, 0)
    shape = _get_input(node, 1)
    if input is None or shape is None:
        return None
    input_shape = input.shape
    if input_shape is None:
        return None
    input_shape_dims = list(input_shape.dims)
    if any(not isinstance(dim, int) for dim in input_shape_dims):
        return None
    shape_value = _get_numpy_value(shape)
    if shape_value is None:
        return None
    target_shape_dims = shape_value.tolist()
    if input_shape_dims == target_shape_dims:
        # No need to check for special values like -1, 0, etc. here
        return op.Identity(input)
    return None


@register("Cast")
def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    input = _get_input(node, 0)
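
For context, a minimal sketch (illustrative plain Python, not the optimizer's ir types) of the condition the new Reshape folding rule checks: when the constant target shape is exactly equal to the statically known input shape, the Reshape is a no-op and is replaced by Identity; anything involving -1/0 semantics or symbolic dimensions is left untouched.

import numpy as np

# Hypothetical values standing in for input.shape and the constant "shape" input of Reshape.
input_shape_dims = [2, 3, 4]
target_shape_dims = np.array([2, 3, 4], dtype=np.int64).tolist()

if input_shape_dims == target_shape_dims:
    print("fold: Reshape(x, shape) -> Identity(x)")
else:
    print("keep Reshape (shapes differ, or special values like -1/0 may be involved)")
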
27 changes: 27 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
@@ -2,6 +2,9 @@
# Licensed under the MIT License.
from __future__ import annotations

import math
from typing import Callable

import numpy as np

import onnxscript.ir as ir
@@ -77,3 +80,27 @@
    if np_val is not None and np_val.size == 1:
        return np_val.item()
    return None


def is_singleton_value(
    val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None
) -> bool:
    """Returns True if the value is a single-element tensor with the given value, and False otherwise."""
    scalar = get_singleton_value(val)
    if scalar is None:
        return False
    if callable(expected):
        return expected(scalar)
    if isinstance(expected, int):
        return expected == scalar
    # rtol must be specified for float comparison
    assert rtol is not None
    return math.isclose(scalar, expected, rel_tol=rtol)


def has_rank(value: ir.Value | None, rank: int) -> bool:
    """Returns True if the value is statically known to have the given rank, and False otherwise."""
    if value is None:
        return False
    shape = value.shape
    return (shape is not None) and (shape.rank() == rank)
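
A rough plain-Python mirror of the comparison logic in is_singleton_value (the ir.Value wrapping is elided; the scalars and calls below are illustrative), showing the exact-int, predicate, and relative-tolerance float paths that the fusion checks below rely on:

import math

def matches(scalar, expected, rtol=None):
    # Mirrors is_singleton_value's comparison on a bare scalar (illustration only).
    if callable(expected):
        return expected(scalar)
    if isinstance(expected, int):
        return expected == scalar
    assert rtol is not None  # rtol must be given for float comparisons
    return math.isclose(scalar, expected, rel_tol=rtol)

print(matches(32, 32))                       # True: exact integer match
print(matches(64, lambda v: v >= 48))        # True: predicate form, as used for the Slice "end" check
print(matches(0.10000001, 0.1, rtol=1e-5))   # True: float match within relative tolerance
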
12 changes: 12 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/__init__.py
@@ -1,3 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization

__all__ = [
    "fuse_rms_normalization",
    "fuse_normalization",
    "fuse_rotary_embedding",
    "fuse_cos_sin_cache",
]
2 changes: 1 addition & 1 deletion onnxscript/rewriter/onnxruntime/xformers/_test_utils.py
@@ -25,7 +25,7 @@ def ort_run(model_name: str, model, inputs):
    providers = ["CPUExecutionProvider"]
    with tempfile.TemporaryDirectory() as temp_dir:
        model_path = os.path.join(temp_dir, f"{model_name}.onnx")
-       io.save(model, model_path)
+       _save(model, model_path)
        # Run model
        session = onnxruntime.InferenceSession(model_path, providers=providers)
        ort_outputs = session.run(None, inputs)
90 changes: 90 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
@@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import numpy as np

import onnxscript.ir as ir
from onnxscript.optimizer import remove_unused_nodes
from onnxscript.rewriter import _ir_utils, pattern

# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops.

# Original code (from transformers) for computing cos/sin cache for RoPE:
# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135
# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
# position_ids_expanded = position_ids[:, None, :].float()
# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
# emb = torch.cat((freqs, freqs), dim=-1)
# cos = emb.cos()
# sin = emb.sin()


class CosSinCacheFusion(pattern.RewriteRuleClassBase):
    def __init__(self, name: str, max_pos_id: int):
        super().__init__(name)
        self._max_pos_id = max_pos_id
        self.remove_nodes = False

    def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads):
        position_ids_expanded = op.Unsqueeze(position_ids, 1)
        position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
        freqs = op.MatMul(inv_freq, position_ids_expanded)
        freqs = op.Transpose(freqs, perm=[0, 2, 1])
        emb = op.Concat(freqs, freqs, axis=-1)
        cos = op.Cos(emb)
        sin = op.Sin(emb)
        cos_4d = op.Unsqueeze(cos, 1)  # convert to 4D
        sin_4d = op.Unsqueeze(sin, 1)
        return op.RotaryEmbedding(
            x,
            cos_4d,
            sin_4d,
            interleaved=interleaved,
            num_heads=num_heads,
            _domain="ai.onnxruntime.fusion",
        )

    def check(self, context, inv_freq, position_ids, **_):
        if not _ir_utils.has_rank(position_ids, 2):
            return False
        if not _ir_utils.has_rank(inv_freq, 3):
            return False
        inv_freq_shape = inv_freq.shape
        if inv_freq.const_value is None:
            return False
        return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1

    def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_):
        inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
        pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
        angles = np.matmul(pos_id_range, inv_freq_values)
        cos_value = np.cos(angles)
        cos_value = np.concatenate([cos_value, cos_value], axis=-1)
        sin_value = np.sin(angles)
        sin_value = np.concatenate([sin_value, sin_value], axis=-1)
        cos_2d = op.Constant(value=ir.tensor(cos_value))
        # cos = op.Gather(cos_2d, position_ids, axis=0)
        sin_2d = op.Constant(value=ir.tensor(sin_value))
        # sin = op.Gather(sin_2d, position_ids, axis=0)
        return op.RotaryEmbedding(
            x,
            position_ids,
            cos_2d,
            sin_2d,
            interleaved=interleaved,
            num_heads=num_heads,
            _domain="com.microsoft",
        )


_rule = CosSinCacheFusion.rule("CosSinCache", 2048)

cos_sin_cache_rules = pattern.RewriteRuleSet([_rule])


def fuse_cos_sin_cache(model: ir.Model) -> int:
    count = cos_sin_cache_rules.apply_to_model(model)
    print(f"CosSinCache count: {count}")
    remove_unused_nodes(model)
    return count
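
For reference, a minimal end-to-end sketch of applying both fusions in the order used by the test below; ir.load/ir.save and the file paths are assumptions for illustration:

import onnxscript.ir as ir
import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import fuse_cos_sin_cache, fuse_rotary_embedding

model = ir.load("model.onnx")          # hypothetical input path
onnxscript.optimizer.optimize(model)   # constant folding exposes the cos/sin subgraph
fuse_rotary_embedding(model)           # first fuse the rotate-half pattern into RotaryEmbedding
fuse_cos_sin_cache(model)              # then precompute the cos/sin cache and retarget com.microsoft
ir.save(model, "model_fused.onnx")     # hypothetical output path
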
29 changes: 29 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import fuse_cos_sin_cache, fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run


class TestCosSinCacheTransform(unittest.TestCase):
    def test_smollm(self):
        smollm_test = _SmollmTestData()
        model = smollm_test.get_onnx_model()
        onnxscript.optimizer.optimize(model)
        inputs = smollm_test.get_ort_inputs()
        original_outputs = ort_run("original", model, inputs)
        count = fuse_rotary_embedding(model)
        self.assertGreater(count, 0)
        count = fuse_cos_sin_cache(model)
        self.assertGreater(count, 0)
        new_outputs = ort_run("optimized", model, inputs)
        assert_allclose(new_outputs, original_outputs)


if __name__ == "__main__":
    unittest.main()

8 changes: 2 additions & 6 deletions onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py
@@ -35,14 +35,10 @@ def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool):
            cast_input: Whether to cast input to do the normalization in a different precision.
            cast_normalized: Whether to cast the normalized output to the target dtype (same as scale).
        """
-       self._name = name
+       super().__init__(name=name)
        self._cast_input = cast_input
        self._cast_normalized = cast_normalized

-   @property
-   def name(self):
-       return self._name

    def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
        if self._cast_input:
            x = op.Cast(x, to=compute_dtype)
@@ -95,5 +91,5 @@ def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):


def fuse_rms_normalization(model: ir.Model) -> None:
-   count = rms_normalization_ruleset.apply_to_model(model, verbose=5)
+   count = rms_normalization_ruleset.apply_to_model(model)
    print(f"RMS Normalization count: {count}")
onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py
@@ -4,21 +4,12 @@

import unittest

-import onnx
-
import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization

-
-def model_repr(self):
-    return f"Model({self.graph.name})"
-
-
-onnx.ModelProto.__repr__ = model_repr
-

class TestRmsNormalization(unittest.TestCase):
    def test_smollm(self):
        smollm_test = _SmollmTestData()
64 changes: 64 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py
@@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.rewriter import _ir_utils, pattern

# Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern
# for full rotation without interleaving.
# TODO(rama): Add pattern variations to handle other cases.

# Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet,
# so it can't be tested by running against ORT. Unfortunately, this is the pattern produced by
# the current version of transformers (not yet supported by ORT).


def _rotate_half_pattern(op, x, start1, end1, start2, end2):
    # Slice(input, starts, ends, axes, steps)
    x1 = op.Slice(x, start1, end1, [3], [1])
    x2 = op.Slice(x, start2, end2, [3], [1])
    minus_x2 = op.Neg(x2)
Review thread on the rotate-half pattern:

Contributor:
Although this logic is correct and makes sense, this doesn't match the function logic in the op definition. Is it correct to assume that the pattern logic should mimic the ONNX function in the op schema?

Currently in the op schema, this pattern would look like the following after x1, x2 (which uses Split instead of Slice for the non-interleaved case):

real = cos * x1 - sin * x2
imag = sin * x1 + cos * x2
rotated_x = op.Concat(real, imag)

So the Concat happens after the multiplication.

Collaborator (Author):
It does not have to match the function logic in the op definition. But it has to match the function graph produced by the ONNX exporter from the logic defined in the source (e.g., the transformers implementation).

What we have to guarantee, though, is that replacing this logic by the pattern in rewrite is fine: that they will both produce the same values.

Collaborator (Author):
Specifically, it is more important to match the source logic, like this transformers code.

Collaborator (Author):
Putting all this together, there are 3 parts to these rewrite rules:

  • the pattern should typically be aligned with the subgraph pattern we see in the ONNX graphs produced by the exporter (which itself depends on the source PyTorch code).
  • the rewrite part is aligned with the (fused) op definition (existing in ORT or being introduced to ONNX).
  • the check condition has to be strong enough to guarantee that the replacement is sound, so that we can be sure we will produce the same outputs with or without the optimization.

    rotated_x = op.Concat(minus_x2, x1, axis=-1)
    return rotated_x


class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase):
    def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
        return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin

    def check(self, op, x, start1, end1, start2, end2, **_):
        # x needs to be a 4D tensor with known last dimension size (== head_size)
        # and known second dimension (num_heads).
Review thread on the rank-4 restriction:

Contributor:
According to the schema, x can be a 3D tensor as well, and num_heads is required to be known in the 3D case.

Collaborator (Author):
That means the optimization is safe and correct (in this regard). To generalize and also allow 3D here, we would need to guarantee that the entire fusion is semantically correct; it is not enough to know that the RotaryEmbedding op permits 3D inputs.

What do you think about the correctness of this fusion optimization? Do you think it is fine to generalize and allow 3D here?

        if x is None or x.shape is None or len(x.shape) != 4:
            return False
        if not isinstance(x.shape[1], int):
            return False
        head_size = x.shape[3]
        if not isinstance(head_size, int):
            return False
        half_head_size = head_size // 2

        # Check that x is being split into two equal halves of size half_head_size
        return (
            _ir_utils.is_singleton_value(start1, 0)
            and _ir_utils.is_singleton_value(end1, half_head_size)
            and _ir_utils.is_singleton_value(start2, half_head_size)
            and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size)
        )

    def rewrite(self, op, x, cos, sin, **_):
        num_heads = x.shape[1]
        return op.RotaryEmbedding(
            x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion"
Review thread on the custom domain:

Contributor:
Just curious about the domain here.

Collaborator (Author):
Good question. I am currently using these to split the fusion optimization into multiple stages. We may need to clean this up eventually. For now, we also need to target the existing RotaryEmbedding op in ORT (which is what we can test against). Eventually, we can target the new proposed RotaryEmbedding op, so we may also need to support some variations in the fusion optimization (depending on the target ORT/ONNX versions).
        )


_rule = RotaryEmbeddingFusion.rule()

rotary_embedding_rules = pattern.RewriteRuleSet([_rule])


def fuse_rotary_embedding(model: ir.Model) -> int:
    count = rotary_embedding_rules.apply_to_model(model)
    print(f"Rotary Embedding count: {count}")
    return count
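
Relating to the review thread on _rotate_half_pattern above: a quick numpy sanity check (an illustration, assuming cos/sin carry duplicated halves as produced by emb = cat(freqs, freqs)) that the pattern matched here, x*cos + rotate_half(x)*sin, yields the same values as the per-half real/imag formulation quoted from the op schema:

import numpy as np

rng = np.random.default_rng(0)
half = 4
x = rng.standard_normal((1, 2, 3, 2 * half)).astype(np.float32)  # [batch, num_heads, seq, head_size]
cos_half = rng.standard_normal((1, 1, 3, half)).astype(np.float32)
sin_half = rng.standard_normal((1, 1, 3, half)).astype(np.float32)
cos = np.concatenate([cos_half, cos_half], axis=-1)  # duplicated-half cache
sin = np.concatenate([sin_half, sin_half], axis=-1)

x1, x2 = np.split(x, 2, axis=-1)

# Pattern matched by this rule: x*cos + rotate_half(x)*sin, with rotate_half(x) = concat(-x2, x1).
out_pattern = x * cos + np.concatenate([-x2, x1], axis=-1) * sin

# Per-half formulation from the review comment / op schema: real and imag, then concat.
real = cos_half * x1 - sin_half * x2
imag = sin_half * x1 + cos_half * x2
out_schema = np.concatenate([real, imag], axis=-1)

np.testing.assert_allclose(out_pattern, out_schema, rtol=1e-5)
print("both formulations agree")
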
23 changes: 23 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py
@@ -0,0 +1,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding


class TestRotaryEmbedding(unittest.TestCase):
    def test_smollm(self):
        smollm_test = _SmollmTestData()
        model = smollm_test.get_onnx_model()
        onnxscript.optimizer.optimize(model)
        fuse_rotary_embedding(model)
        op_types = [n.op_type for n in model.graph]
        self.assertIn("RotaryEmbedding", op_types)


if __name__ == "__main__":
    unittest.main()
