3 changes: 2 additions & 1 deletion onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
@@ -4,7 +4,7 @@

import onnx_ir as ir

from onnxscript.rewriter.rules.fusion import _rms_normalization, _rotary_embedding
from onnxscript.rewriter.rules.fusion import _gqa, _rms_normalization, _rotary_embedding


def _get_onnx_opset_version(model: ir.Model) -> int | None:
@@ -24,6 +24,7 @@ def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
counts: dict[str, int] = {}
counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug)
counts["RotaryEmbedding"] = _rotary_embedding.fuse_rotary_embedding(model, debug=debug)
counts["GQA"] = _gqa.fuse_gqa(model, debug=debug)
return counts


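A minimal sketch of how the new "GQA" count surfaces alongside the existing opset-23 fusions. The model path is a placeholder, and _opset_23_fuse is the private helper shown above (the module's public fusion entry point presumably dispatches to it based on the model's declared opset version):

import onnx_ir as ir

from onnxscript.rewriter.onnx_fusions import _onnx_fusions

model = ir.load("model.onnx")  # placeholder path; the model must import ONNX opset >= 23
counts = _onnx_fusions._opset_23_fuse(model)
print(counts)  # e.g. {"RMSNormalization": 0, "RotaryEmbedding": 0, "GQA": 1}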
113 changes: 113 additions & 0 deletions onnxscript/rewriter/rules/fusion/_gqa.py
@@ -0,0 +1,113 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import Union

import onnx_ir as ir

import onnxscript.rewriter._fusion_utils as _fusion_utils
from onnxscript.rewriter import _basics, pattern

Dim = Union[int, ir.SymbolicDim]


class OnnxGroupQueryAttention(pattern.RewriteRuleClassBase):
def __init__(self):
super().__init__("ONNXGQA", remove_nodes=False)

def pattern(
self,
op,
query_BHSD,
key_BHkvSD,
value_BHkvSD,
past_key_BHkvSpD,
past_value_BHkvSpD,
):
# Concatenate past_key cache and current key, expand across heads
# that share key/value.

present_key_BHkvStD = op.Concat(past_key_BHkvSpD, key_BHkvSD, axis=-2)
present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2)
present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, pattern.ANY_VALUE)
present_key_BHStD = op.Reshape(
present_key_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_key_BHStD"]
)

# Concatenate past_value cache and current value, expand across heads
# that share key/value.
present_value_BHkvStD = op.Concat(past_value_BHkvSpD, value_BHkvSD, axis=-2)
present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2)
present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, pattern.ANY_VALUE)
present_value_BHStD = op.Reshape(
present_value_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_value_BHStD"]
)

attention_BHSDh = op.Attention(
query_BHSD,
present_key_BHStD,
present_value_BHStD,
pattern.Var("mask", can_match_none=True),
_outputs=["attention_BHSDh"],
)

return attention_BHSDh

def check(
self,
context: _basics.MatchContext,
query_BHSD,
key_BHkvSD,
value_BHkvSD,
past_key_BHkvSpD,
past_value_BHkvSpD,
present_key_BHStD,
present_value_BHStD,
**_,
):
bindings: dict[str, Dim] = {}
# Check that inputs to new Attention node have expected shapes
_fusion_utils.check_shape(bindings, query_BHSD, ["B", "H", "S", "D"])
_fusion_utils.check_shape(bindings, key_BHkvSD, ["B", "Hkv", "S", "D"])
_fusion_utils.check_shape(bindings, value_BHkvSD, ["B", "Hkv", "S", "D"])
_fusion_utils.check_shape(bindings, past_key_BHkvSpD, ["B", "Hkv", "P", "D"])
_fusion_utils.check_shape(bindings, past_value_BHkvSpD, ["B", "Hkv", "P", "D"])
# We need to check that the Expand/Reshape arguments are as expected.
# As a substitute, we check that the outputs of Expand=>Reshape have expected shapes.
# TODO (rama): May be better to check the actual Expand/Reshape arguments.
_fusion_utils.check_shape(bindings, present_key_BHStD, ["B", "H", "S+P", "D"])
_fusion_utils.check_shape(bindings, present_value_BHStD, ["B", "H", "S+P", "D"])

return True

def rewrite(
self,
op,
query_BHSD,
key_BHkvSD,
value_BHkvSD,
past_key_BHkvSpD,
past_value_BHkvSpD,
mask,
attention_BHSDh,
**_,
):
original_attention_node = attention_BHSDh.producer()
original_attrs = original_attention_node.attributes
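        # Opset-23 Attention supports Hkv != H (grouped-query attention) natively,
        # so the fused node takes the unexpanded key/value plus the past key/value
        # caches directly; the matched Unsqueeze/Expand/Reshape chain becomes dead
        # code (the rule uses remove_nodes=False and leaves cleanup to DCE).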
return op.Attention(
query_BHSD,
key_BHkvSD,
value_BHkvSD,
mask,
past_key_BHkvSpD,
past_value_BHkvSpD,
**original_attrs,
)


_basic_gqa_rule = OnnxGroupQueryAttention.rule()

gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])

fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
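For context on the pattern above: the matched Unsqueeze => Expand => Reshape chain is the standard way of repeating each key/value head G = H/Hkv times so that a plain multi-head Attention can consume the result. A minimal numpy sketch of that equivalence (shapes arbitrary, illustrative only, not part of the PR):

import numpy as np

B, Hkv, G, T, D = 2, 4, 2, 12, 64
H = Hkv * G

kv = np.random.rand(B, Hkv, T, D).astype(np.float32)

# Unsqueeze(axis=2) -> Expand -> Reshape, as matched by the pattern.
expanded = np.broadcast_to(kv[:, :, None, :, :], (B, Hkv, G, T, D))
reshaped = expanded.reshape(B, H, T, D)

# Equivalent to repeating each key/value head G times along the head axis.
assert np.array_equal(reshaped, np.repeat(kv, G, axis=1))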
97 changes: 97 additions & 0 deletions onnxscript/rewriter/rules/fusion/_gqa_test.py
@@ -0,0 +1,97 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import unittest

import onnx
import onnx_ir as ir
from packaging import version

import onnxscript
import onnxscript.optimizer
import onnxscript.rewriter.testing
from onnxscript import FLOAT, script
from onnxscript.rewriter.rules.fusion._gqa import fuse_gqa

op = onnxscript.values.Opset("", 23)

H = [8] # Number of attention heads
Hkv = [4] # Number of key/value heads (H should be divisible by Hkv)
D = [64] # Head size
G = [2] # Number of groups


@script(ir_version=10)
def _gqa_script(
query_BHSD: FLOAT[2, 8, 4, 64], # B=2, H=8, S=4, D=64
key_BHkvSD: FLOAT[2, 4, 4, 64], # B=2, Hkv=4, S=4, D=64
value_BHkvSD: FLOAT[2, 4, 4, 64], # B=2, Hkv=4, S=4, D=64
past_key_BHkvPD: FLOAT[2, 4, 8, 64], # B=2, Hkv=4, P=8, D=64
past_value_BHkvPD: FLOAT[2, 4, 8, 64], # B=2, Hkv=4, P=8, D=64
) -> FLOAT[2, 8, 4, 64]:
"""Basic GQA pattern that should be fused into an Attention op."""

# Concatenate past_key cache and current key
present_key_BHkvStD = op.Concat(past_key_BHkvPD, key_BHkvSD, axis=-2) # [B, Hkv, S+P, D]

# Unsqueeze to add group dimension
present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2) # [B, Hkv, 1, S+P, D]

# Calculate shapes dynamically
B = op.Shape(query_BHSD, start=0, end=1) # [B]
T = op.Shape(present_key_BHkvStD, start=2, end=3) # [S+P]

# Create expand shape [B, Hkv, G, S+P, D]
expand_shape = op.Concat(B, Hkv, G, T, D, axis=0)
present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, expand_shape) # [B, Hkv, G, S+P, D]

# Create reshape shape [B, H, S+P, D]
reshape_shape = op.Concat(B, H, T, D, axis=0)
present_key_BHStD = op.Reshape(present_key_BHkvGStD, reshape_shape) # [B, H, S+P, D]

# Same for value
present_value_BHkvStD = op.Concat(
past_value_BHkvPD, value_BHkvSD, axis=-2
) # [B, Hkv, S+P, D]
present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2) # [B, Hkv, 1, S+P, D]
present_value_BHkvGStD = op.Expand(
present_value_BHkv1StD, expand_shape
) # [B, Hkv, G, S+P, D]
present_value_BHStD = op.Reshape(present_value_BHkvGStD, reshape_shape) # [B, H, S+P, D]

# Attention computation
attention_BHSDh = op.Attention(
query_BHSD,
present_key_BHStD,
present_value_BHStD,
)

return attention_BHSDh


class GQAFusionTest(unittest.TestCase):
def test_basic_gqa_fusion(self):
"""Test basic GQA fusion pattern."""
model_proto = _gqa_script.to_model_proto()

# Apply GQA fusion
model = ir.serde.deserialize_model(model_proto)
onnxscript.optimizer.optimize(model)
count = fuse_gqa(model)
self.assertGreater(count, 0, "GQA fusion should have occurred")

        # Numerical equivalence can be tested only against onnx>=1.19.1, which fixes
        # a bug in the Attention op spec/reference implementation.
onnx_ver = version.parse(onnx.__version__)
if onnx_ver >= version.parse("1.19.1") and not (
onnx_ver.is_prerelease or onnx_ver.is_devrelease
):
# Only official releases >= 1.19.1
onnxscript.optimizer.remove_unused_nodes(model)
rewritten_model_proto = ir.serde.serialize_model(model)
onnxscript.rewriter.testing.assert_numerically_equal(
model_proto, rewritten_model_proto, use_reference=True
)


if __name__ == "__main__":
unittest.main()
68 changes: 51 additions & 17 deletions onnxscript/rewriter/testing.py
@@ -6,6 +6,7 @@

import numpy as np
import onnx
import onnx.reference
import onnxruntime as ort

from onnxscript import ir
@@ -32,10 +33,11 @@ def generate_random_inputs(model: onnx.ModelProto) -> dict[str, Any]:
def assert_numerically_equal(
original_model_proto: onnx.ModelProto | ir.Model,
rewritten_model_proto: onnx.ModelProto | ir.Model,
args: tuple[Any, ...] | dict[str, Any],
args: tuple[Any, ...] | dict[str, Any] | None = None,
ort_optimization_level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
rtol: float = 1,
atol: float = 1e-3,
use_reference: bool = False,
):
"""Assert that the two models are numerically equal.

@@ -46,14 +48,18 @@ def assert_numerically_equal(
ort_optimization_level: Onnxruntime optimization level.
rtol: Relative tolerance.
atol: Absolute tolerance.
use_reference: If True, use ONNX reference implementation instead of ONNXRuntime.
"""

if isinstance(original_model_proto, ir.Model):
original_model_proto = ir.serde.serialize_model(original_model_proto)
if isinstance(rewritten_model_proto, ir.Model):
rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto)

if isinstance(args, dict):
if args is None:
original_proto_ort_inputs = generate_random_inputs(original_model_proto)
the_rewritten_proto_ort_inputs = original_proto_ort_inputs
elif isinstance(args, dict):
original_proto_ort_inputs = args
the_rewritten_proto_ort_inputs = args
else:
@@ -64,21 +70,34 @@
k.name: v for k, v in zip(rewritten_model_proto.graph.input, args)
}

original_proto_ort_inference_session = _ort_session_initializer(
original_model_proto.SerializeToString(), ort_optimization_level
)
run_options = ort.RunOptions()
run_options.log_severity_level = 3 # 3: Error
original_outputs = original_proto_ort_inference_session.run(
None, original_proto_ort_inputs, run_options=run_options
)

the_rewritten_proto_ort_inference_session = _ort_session_initializer(
rewritten_model_proto.SerializeToString(), ort_optimization_level
)
the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run(
None, the_rewritten_proto_ort_inputs, run_options=run_options
)
if use_reference:
# Use ONNX reference implementation
original_evaluator = _reference_session(
original_model_proto.SerializeToString(), ort_optimization_level
)
original_outputs = original_evaluator.run(None, original_proto_ort_inputs)

rewritten_evaluator = _reference_session(
rewritten_model_proto.SerializeToString(), ort_optimization_level
)
the_rewritten_outputs = rewritten_evaluator.run(None, the_rewritten_proto_ort_inputs)
else:
# Use ONNXRuntime
original_proto_ort_inference_session = _ort_session_initializer(
original_model_proto.SerializeToString(), ort_optimization_level
)
run_options = ort.RunOptions()
run_options.log_severity_level = 3 # 3: Error
original_outputs = original_proto_ort_inference_session.run(
None, original_proto_ort_inputs, run_options=run_options
)

the_rewritten_proto_ort_inference_session = _ort_session_initializer(
rewritten_model_proto.SerializeToString(), ort_optimization_level
)
the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run(
None, the_rewritten_proto_ort_inputs, run_options=run_options
)

np.testing.assert_allclose(
original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True
@@ -103,3 +122,18 @@ def _ort_session_initializer(
provider for provider in possible_providers if provider in available_providers
]
return ort.InferenceSession(model, providers=providers, sess_options=session_options)


def _reference_session(
model: str | bytes, ort_optimization_level: ort.GraphOptimizationLevel
) -> onnx.reference.ReferenceEvaluator:
"""Initialize an ONNX reference evaluator with the specified model."""
    # Deserialize the model if it was passed as serialized bytes; a str is a file path.
    if isinstance(model, bytes):
        model_proto = onnx.load_from_string(model)
    else:
        model_proto = onnx.load(model)

# Note: ort_optimization_level is ignored for reference implementation
# as it doesn't have equivalent optimization levels
return onnx.reference.ReferenceEvaluator(model_proto)
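With these defaults the helper no longer needs hand-built inputs, which is what the GQA test above relies on. A usage sketch (the two model protos are placeholders): when args is omitted, random inputs are generated from the original model's input types, and use_reference=True runs both models under onnx.reference.ReferenceEvaluator instead of ONNXRuntime.

import onnxscript.rewriter.testing

# original_proto and rewritten_proto are onnx.ModelProto (or ir.Model) objects.
onnxscript.rewriter.testing.assert_numerically_equal(
    original_proto,
    rewritten_proto,
    use_reference=True,
)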