Add onnxscript to peephole pass #1530

Merged 4 commits on Dec 20, 2024
2 changes: 1 addition & 1 deletion olive/olive_config.json
@@ -111,7 +111,7 @@
"supported_providers": [ "*" ],
"supported_accelerators": [ "*" ],
"supported_precisions": [ "*" ],
"extra_dependencies": [ "onnxoptimizer" ]
"extra_dependencies": [ "onnxoptimizer", "onnxscript" ]
},
"OnnxOpVersionConversion": {
"module_path": "olive.passes.onnx.conversion.OnnxOpVersionConversion",
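Here, the `extra_dependencies` entry for the peephole pass gains `onnxscript` alongside `onnxoptimizer`. As a hypothetical illustration (this helper is not part of Olive), extras declared this way can be sanity-checked with the standard library, assuming each declared pip name doubles as its import name, which holds for both packages here:

import importlib.util
import json

def missing_extras(config_path: str, pass_name: str) -> list:
    # Return the declared extra dependencies of a pass that are not importable.
    # Assumes pass entries are keyed by name at the level shown in the hunk above.
    with open(config_path) as f:
        config = json.load(f)
    extras = config[pass_name].get("extra_dependencies", [])
    return [pkg for pkg in extras if importlib.util.find_spec(pkg) is None]

# e.g. missing_extras("olive/olive_config.json", "OnnxPeepholeOptimizer")
# would include "onnxscript" when it is not installed.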
119 changes: 8 additions & 111 deletions olive/passes/onnx/peephole_optimizer.py
@@ -2,16 +2,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import copy
import logging
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import onnx
import onnxruntime as ort
from google.protobuf.message import Message
from onnx import TensorProto, helper, numpy_helper

from olive.hardware.accelerator import AcceleratorSpec
from olive.model import ONNXModelHandler
@@ -235,114 +232,14 @@ def fuse_reshape_operations(self):
logger.debug("Fused %d redundant Reshape operators", num_changed)
o_model.prune_graph()

def _find_initializer_by_name(self, model, name):
for initializer in model.graph.initializer:
if initializer.name == name:
return initializer
raise ValueError(f"No initializer named {name}")

def _find_value_info_proto_by_name(self, model, name):
"""Find the ValueInfoProto with the name name in the model's value_info."""
for vi in model.graph.value_info:
if vi.name == name:
return vi

for initializer in model.graph.initializer:
if initializer.name == name:
return helper.make_tensor_value_info(name, initializer.data_type, initializer.dims)

for graph_input in model.graph.input:
if graph_input.name == name:
return graph_input

raise ValueError(f"No value info proto named {name}")

def _run_op(self, model, op):
input_names = set()

op_model = onnx.ModelProto()
op_model.ir_version = model.ir_version
op_model.producer_name = "constant_folding"
op_model.opset_import.extend(model.opset_import)
op_model.graph.name = "ConstantFoldingGraph"
op_model.graph.node.extend([copy.deepcopy(op)])

for input_name in op.input:
if input_name and input_name not in input_names:
try:
initializer = self._find_initializer_by_name(model, input_name)
op_model.graph.initializer.append(copy.deepcopy(initializer))
vi = helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)
op_model.graph.input.append(vi)
except ValueError:
vi = self._find_value_info_proto_by_name(model, input_name)
op_model.graph.input.append(copy.deepcopy(vi))
input_names.add(input_name)

for output_name in op.output:
vi = helper.make_tensor_value_info(output_name, TensorProto.UNDEFINED, [])
op_model.graph.output.append(vi)

return op_model

def _run_onnx_model(self, model):
session = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
input_dict = {}
for model_input in session.get_inputs():
name = model_input.name
tensor = self._find_initializer_by_name(model, name)
input_dict[name] = numpy_helper.to_array(tensor)
return session.run(None, input_dict)

def _get_constant_nodes(self, model):
dynamic_inputs = {graph_input.name for graph_input in model.graph.input}
const_inputs = {
initializer.name for initializer in model.graph.initializer if initializer.name not in dynamic_inputs
}
const_nodes = []
for node in model.graph.node:
if all(node_input == "" or node_input in const_inputs for node_input in node.input):
const_nodes.append(node)
const_inputs.update(node.output)
return const_nodes

def fold_constant(self):
while True:
const_nodes = self._get_constant_nodes(self.model)
if not const_nodes:
break

nodes_to_remove = []
failed_nodes = set()

for node in const_nodes:
if node.name in failed_nodes:
continue

try:
op_model = self._run_op(self.model, node)
outputs = self._run_onnx_model(op_model)
for output_array, name in zip(outputs, node.output):
if any(init.name == name for init in self.model.graph.initializer):
continue
tensor = numpy_helper.from_array(output_array, name)
self.model.graph.initializer.append(tensor)
vi = helper.make_tensor_value_info(name, tensor.data_type, tensor.dims)
self.model.graph.value_info.append(vi)
nodes_to_remove.append(node)
except Exception as e:
logger.warning("Failed to run %s op (name is %s): %s, skip...", node.op_type, node.name, e)
failed_nodes.add(node.name)

if not nodes_to_remove and const_nodes:
logger.warning(
"Failed to fold constants for the following nodes: %s",
", ".join(node.name for node in const_nodes),
)
break
            for node in nodes_to_remove:
                self.model.graph.node.remove(node)

    def onnxscript_optimize(self):
        try:
            import onnxscript
        except ImportError:
            logger.warning("Please install onnxscript to use the ONNX optimizer feature. Skipping onnxscript optimization.")
            return

        onnxscript.optimizer.optimize(self.model)


class OnnxPeepholeOptimizer(Pass):
@@ -379,7 +276,7 @@ def _run_for_config(
peephole_optimizer.fuse_transpose_qat()
peephole_optimizer.patch_unsupported_argmax_operator()
peephole_optimizer.fuse_reshape_operations()
peephole_optimizer.fold_constant()
peephole_optimizer.onnxscript_optimize()

if config["onnxoptimizer"]:
try:
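The hundred-odd lines of hand-rolled constant folding removed above are thus delegated to onnxscript's optimizer, which bundles constant folding with other graph-level cleanups. For reference, a minimal standalone sketch of the same call outside Olive (assuming onnxscript is installed; depending on the version, `optimize` may return a new `ModelProto` instead of mutating its argument, so capturing the result is the defensive pattern):

import onnx
import onnxscript

# "model.onnx" is a placeholder path for any ONNX model.
model = onnx.load("model.onnx")

# Constant folding plus redundant-node elimination in one call.
result = onnxscript.optimizer.optimize(model)
if result is not None:  # some versions optimize in place and return None
    model = result

onnx.save(model, "model_optimized.onnx")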
89 changes: 19 additions & 70 deletions test/unit_test/passes/onnx/test_peephole_optimizer.py
@@ -65,7 +65,10 @@ def _make_model_for_patch_unsupported_argmax_operator(
return model_proto_to_olive_model(model, filepath, config)


def test_onnx_peephole_optimizer_pass_patch_unsupported_argmax_operator_modified(tmp_path, external_data_config):
@patch("onnxscript.optimizer.optimize")
def test_onnx_peephole_optimizer_pass_patch_unsupported_argmax_operator_modified(
mock_optimize, tmp_path, external_data_config
):
m = _make_model_for_patch_unsupported_argmax_operator(
TensorProto.INT64, str(tmp_path / "input.onnx"), external_data_config
)
@@ -189,6 +192,21 @@ def test_onnx_peephole_optimizer_pass_fuse_reshape_operations(tmp_path, external
assert others_op_count == 0


@patch("olive.passes.onnx.peephole_optimizer.model_proto_to_olive_model")
@patch("onnxscript.optimizer.optimize")
def test_onnxscript(mock_optimize, mock_model_proto_to_olive_model, tmp_path):
# setup
input_model = get_onnx_model()
p = create_pass_from_dict(OnnxPeepholeOptimizer, {}, disable_search=True)
output_folder = str(tmp_path / "onnx")

# execute
p.run(input_model, output_folder)

# assert
mock_optimize.assert_called_once_with(input_model.load_model())


@patch("olive.passes.onnx.peephole_optimizer.model_proto_to_olive_model")
@patch("onnxoptimizer.optimize")
def test_onnxoptimizer(mock_optimize, mock_model_proto_to_olive_model, tmp_path):
@@ -212,72 +230,3 @@ def test_onnxoptimizer(mock_optimize, mock_model_proto_to_olive_model, tmp_path)

# assert
mock_optimize.assert_called_once_with(input_model.load_model(), passes, fixed_point)


def test_onnx_peephole_optimizer_pass_constant_folding(tmp_path, external_data_config):
import numpy as np

# setup
model = _get_onnx_model_with_constant(str(tmp_path / "input.onnx"), external_data_config)
input_model = model.load_model()
input_constant_nodes = [node for node in input_model.graph.node if node.op_type == "Constant"]
input_initializer_names = {initializer.name for initializer in input_model.graph.initializer}
p = create_pass_from_dict(
OnnxPeepholeOptimizer, None, disable_search=True, accelerator_spec=DEFAULT_CPU_ACCELERATOR
)

# execute
output_model = p.run(model, tmp_path / "onnx")

# assert
assert Path(output_model.model_path).exists()
output_model = output_model.load_model()
output_constant_nodes = [node for node in output_model.graph.node if node.op_type == "Constant"]
output_initializer_names = {initializer.name for initializer in output_model.graph.initializer}
assert len(output_constant_nodes) < len(input_constant_nodes), "Constant nodes were not folded."
assert len(output_initializer_names) > len(input_initializer_names), "Initializers were not updated."
inputs = {"input": np.random.rand(1, 3).astype(np.float32)}
original_outputs = _run_onnx_model(input_model, inputs)
optimized_outputs = _run_onnx_model(output_model, inputs)
assert np.allclose(
original_outputs, optimized_outputs, atol=1e-6
), "Outputs are not consistent after constant folding."


def _get_onnx_model_with_constant(model_path, external_data_config):
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3])
output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 3])
const_tensor = helper.make_tensor(
name="const_add",
data_type=TensorProto.FLOAT,
dims=[1, 3],
vals=[1.0, 2.0, 3.0],
)
const_node = helper.make_node(
"Constant",
inputs=[],
outputs=["const"],
value=const_tensor,
)
add_node = helper.make_node(
"Add",
inputs=["input", "const"],
outputs=["output"],
)
graph_def = helper.make_graph(
nodes=[const_node, add_node],
name="ConstantFoldingGraph",
inputs=[input_tensor],
outputs=[output_tensor],
initializer=[],
)
opset_imports = [helper.make_operatorsetid("", 21)]
model = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=opset_imports)
return model_proto_to_olive_model(model, model_path, external_data_config)


def _run_onnx_model(model, inputs):
import onnxruntime as ort

session = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
return session.run(None, inputs)
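The deleted end-to-end constant-folding test above maps almost directly onto the new code path. A sketch of the equivalent check against onnxscript itself (assuming its optimizer folds the `Constant` node of this graph into an initializer, as the deleted `fold_constant` pass did):

import onnxscript
from onnx import TensorProto, helper

# Rebuild the Constant + Add graph from the deleted helper above.
const_tensor = helper.make_tensor("const_add", TensorProto.FLOAT, [1, 3], [1.0, 2.0, 3.0])
const_node = helper.make_node("Constant", inputs=[], outputs=["const"], value=const_tensor)
add_node = helper.make_node("Add", inputs=["input", "const"], outputs=["output"])
graph = helper.make_graph(
    nodes=[const_node, add_node],
    name="ConstantFoldingGraph",
    inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3])],
    outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 21)])

optimized = onnxscript.optimizer.optimize(model)
if optimized is None:  # some versions optimize in place
    optimized = model
# "Constant" is expected to be gone from the node list after folding.
print([node.op_type for node in optimized.graph.node])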