
[onnx][importer] Add support for externalized params #18880

Merged: 9 commits, Nov 14, 2024
Changes from all commits
1 change: 1 addition & 0 deletions compiler/bindings/python/CMakeLists.txt
@@ -145,6 +145,7 @@ declare_mlir_python_sources(IREECompilerAPIPythonTools
tools/tf.py
tools/tflite.py
tools/import_onnx/__main__.py
tools/import_onnx/importer_externalization_overrides.py
tools/ir_tool/__main__.py
tools/scripts/iree_compile/__main__.py
tools/scripts/iree_opt/__main__.py
Member:
I won't have time to review this again for a few days. One thing that could be worked on in the meantime is documentation at https://iree.dev/guides/ml-frameworks/onnx/ for parameter usage, similar to https://iree.dev/guides/ml-frameworks/pytorch/#using-external-parameters. The source of that page is here: https://github.com/iree-org/iree/blob/main/docs/website/docs/guides/ml-frameworks/onnx.md
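For context, the workflow such a page would document looks roughly like the sketch below. The CLI flags come from this PR; the runtime loading calls (ParameterIndex.load, create_provider) mirror the existing PyTorch external-parameters guide and are an assumption here, as are the file names:

# Import, splitting large initializers out into an .irpa archive:
#   python -m iree.compiler.tools.import_onnx model.onnx \
#       --externalize-params --params-scope model -o model.mlir
# This writes model.mlir plus model_params.irpa next to it.
import iree.runtime as rt

# Load the archive produced by the importer and expose it under the
# "model" scope when instantiating the compiled module.
index = rt.ParameterIndex()
index.load("model_params.irpa")
provider = index.create_provider(scope="model")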

compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py
@@ -14,42 +14,43 @@

python -m iree.compiler.tools.import_onnx ...
"""

import argparse
import os
from pathlib import Path
import sys
import tempfile

try:
    import onnx
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        f"iree-import-onnx requires that the `onnx` Python package is installed "
        f"(typically `{sys.executable} -m pip install onnx`)"
    ) from e

try:
    from ...extras import onnx_importer
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        "iree-import-onnx is only available if IREE was built with Torch support"
    ) from e

from ...ir import (
    Context,
)
from .importer_externalization_overrides import *


def main(args: argparse.Namespace):
    model_proto = load_onnx_model(args)
    context = Context()
    model_info = onnx_importer.ModelInfo(model_proto)
    m = model_info.create_module(context=context).operation

    imp: Any = None
    if args.externalize_params:
        imp = IREENodeImporter.define_function(
            model_info.main_graph, m, args.num_elements_threshold, args.params_scope
        )
    else:
        imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
    imp.import_all()

    if not args.no_verify:
        m.verify()

    if args.externalize_params:
        default_param_path = Path(args.output_file).parent / Path(args.output_file).stem
        param_path = (
            (str(default_param_path) + "_params.irpa")
            if args.save_params_to is None
            else str(args.save_params_to)
        )
        imp.param_archive.create_archive_file(param_path)

    # TODO: This isn't very efficient output. If these files ever
    # get large, enable bytecode and direct binary emission to save
    # some copies.
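As a concrete illustration of the default archive naming above (the output path here is hypothetical), the .irpa file lands next to the MLIR output:

from pathlib import Path

output_file = "out/model.mlir"
default_param_path = Path(output_file).parent / Path(output_file).stem
print(str(default_param_path) + "_params.irpa")  # out/model_params.irpa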
@@ -71,6 +72,12 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
        raw_model = onnx.load(args.input_file, load_external_data=False)
        onnx.load_external_data_for_model(raw_model, str(args.data_dir))

    # Only change the opset version if it is greater than the current one.
    if args.opset_version and args.opset_version > raw_model.opset_import[0].version:
        raw_model = onnx.version_converter.convert_version(
            raw_model, args.opset_version
        )

    # Do shape inference two ways. First, attempt in-memory to avoid redundant
    # loading and the need for writing a temporary file somewhere. If that
    # fails, typically because of the 2 GB protobuf size limit, try again via
@@ -132,6 +139,37 @@ def parse_arguments(argv=None) -> argparse.Namespace:
        " Defaults to the directory of the input file.",
        type=Path,
    )
    parser.add_argument(
        "--opset-version",
        help="Allows specification of a newer opset_version to update the model"
        " to before importing to MLIR. This can sometimes assist with shape inference.",
        type=int,
    )
    parser.add_argument(
        "--num-elements-threshold",
        help="Minimum number of elements for an initializer to be externalized."
        " Only has an effect if 'externalize-params' is true.",
        type=int,
        default=100,
    )
    parser.add_argument(
        "--externalize-params",
        help="Externalize large parameters and store them on the disk, to load at runtime.",
        action=argparse.BooleanOptionalAction,
        default=False,
    )
    parser.add_argument(
        "--save-params-to",
        help="Location to save the externalized parameters. When not set, the parameters"
        " will be written to '<output_file_name>_params.irpa' under the namespace 'model',"
        " which can be configured by passing the namespace string to 'params-scope'.",
        default=None,
        type=Path,
    )
    parser.add_argument(
        "--params-scope",
        help="The namespace or the scope in which the externalized parameters are placed."
        " Default is 'model'.",
        type=str,
        default="model",
    )
    args = parser.parse_args(argv)
    return args
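Putting the new flags together, a typical programmatic invocation might look like the following sketch (model and output names are hypothetical; the positional input file and the -o flag come from the unchanged part of this file):

# Equivalent to:
#   python -m iree.compiler.tools.import_onnx model.onnx -o model.mlir \
#       --opset-version 17 --externalize-params --num-elements-threshold 256 \
#       --params-scope my_model
args = parse_arguments(
    [
        "model.onnx",
        "-o", "model.mlir",
        "--opset-version", "17",
        "--externalize-params",
        "--num-elements-threshold", "256",
        "--params-scope", "my_model",
    ]
)
main(args)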

compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py (new file)
@@ -0,0 +1,263 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import copy
import random
import string
import sys
import iree.runtime as rt
Member:
I missed this during review, but this is the only place in the compiler Python bindings that depends on the runtime Python bindings. It is used for rt.ParameterIndex and create_archive_file. That makes

pip install iree-base-compiler[onnx]

insufficient for running iree-import-onnx now, as iree-base-runtime must also be installed.

We could:

  • have this check the import and fail gracefully
  • add a compiler API for creating parameter archives so the runtime API is not needed
  • have the compiler Python package depend on the runtime Python package

Contributor:
I'm working on a few improvements in #19217. Will make a conditional import for this in that patch.

Contributor Author (@vinayakdsci), Nov 20, 2024:
@ScottTodd I think it would be best to fail gracefully if the import fails, at least for now. I am not sure how having the compiler Python package depend on the runtime package should be implemented (maybe you could look at that if you find the time?), and adding a compiler API for creating parameter archives could lead to code duplication, IMO.

I can create a PR for the first option right away.

Contributor:
I think the first option (importing conditionally on whether "--externalize-params" is passed, with a message along the lines of "externalizing parameters requires the runtime API") would be the most efficient route.
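A minimal sketch of that conditional-import option (hypothetical; written here for illustration, not taken from this PR):

try:
    import iree.runtime as rt
except ModuleNotFoundError:
    rt = None


def _require_runtime() -> None:
    # Called only on the --externalize-params path.
    if rt is None:
        raise ModuleNotFoundError(
            "Externalizing parameters requires the IREE runtime Python "
            "bindings (pip install iree-base-runtime)."
        )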


from ...dialects import util
from typing import Optional, Tuple, Any

try:
    import onnx
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        f"iree-import-onnx requires that the `onnx` Python package is installed "
        f"(typically `{sys.executable} -m pip install onnx`)"
    ) from e

try:
    from ...extras import onnx_importer
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        "iree-import-onnx is only available if IREE was built with Torch support"
    ) from e

from onnx import numpy_helper

from ...ir import (
    Context,
    Type as IrType,
    TypeAttr,
    RankedTensorType,
    StringAttr,
    Attribute,
    Operation,
    Location,
    InsertionPoint,
    Value,
    SymbolTable,
    IntegerType,
)


class IREENodeImporter(onnx_importer.NodeImporter):
    def __init__(
        self,
        graph_info: onnx_importer.GraphInfo,
        *,
        parent_op: Operation,
        block: onnx_importer.Block,
        context_cache: "onnx_importer.ContextCache",
        module_op: Operation,
        module_cache: "onnx_importer.ModuleCache",
        num_elements_threshold: int,
        params_scope: str,
    ):
        super().__init__(
            graph_info,
            parent_op=parent_op,
            block=block,
            context_cache=context_cache,
            module_op=module_op,
            module_cache=module_cache,
        )
        self.last_global_op = None
        self.symbol_table = SymbolTable(module_op)
        self.symbol_table.insert(parent_op)
        self.num_elements_threshold = num_elements_threshold
        self.param_archive = rt.ParameterIndex()
        self.params_scope = params_scope

    def sanitize_name(self, name: str) -> str:
        # Initializers in some models have no name, or contain characters
        # like ':' (e.g. in '::') that are invalid in symbol names and can
        # cause conflicts for symbolic references. Replace ':' with '_' when
        # the name is non-empty, and generate a random placeholder when it
        # is empty.
        new_name = name.replace(":", "_")
        if len(new_name) == 0:
            ch = random.choice(string.ascii_lowercase)
            new_name = str(random.randrange(1, 1000)) + "__" + ch
        return new_name
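For example, given an importer instance imp (illustrative; the empty-name result is random and shown here only for shape):

imp.sanitize_name("onnx::MatMul_123")  # -> "onnx__MatMul_123"
imp.sanitize_name("")                  # -> e.g. "417__k" (random placeholder)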

    def create_tensor_global(
        self,
        t: onnx.TensorProto,
    ) -> Tuple[str, IrType]:
        # Always create globals at the top. Then, if a prior one exists,
        # move the new one after it to maintain declaration order.
        name = self.sanitize_name(t.name)
        with InsertionPoint.at_block_begin(
            self._m.regions[0].blocks[0]
        ), Location.unknown():
            # After lowering to linalg-on-tensors, the data type needs to be
            # signless. So we construct the globals to have signless types,
            # and use torch_c.from_builtin_tensor to convert to the correct
            # frontend type.
            vtensor_type = RankedTensorType.get(
                tuple(t.dims), ELEM_TYPE_TO_SIGNLESS_IR_TYPE[t.data_type]()
            )
            ir_attrs = {
                "sym_name": StringAttr.get(name),
                "sym_visibility": StringAttr.get("private"),
                "type": TypeAttr.get(vtensor_type),
            }

            external_scope_attr = StringAttr.get(self.params_scope)
            external_name_attr = StringAttr.get(name)
            ir_attrs["initial_value"] = Attribute.parse(
                f"#stream.parameter.named<{external_scope_attr}::{external_name_attr}> : {vtensor_type}"
            )
            global_op = util.GlobalOp(
                ir_attrs["sym_name"],
                ir_attrs["type"],
                sym_visibility=ir_attrs["sym_visibility"],
                initial_value=ir_attrs["initial_value"],
            )
            self.symbol_table.insert(global_op)
            if self.last_global_op is not None:
                global_op.move_after(self.last_global_op)
            self.last_global_op = global_op
            actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value
        return actual_symbol_name, vtensor_type

    @classmethod
    def define_function(
        cls,
        graph_info: onnx_importer.GraphInfo,
        module_op: Operation,
        num_elements_threshold: int,
        params_scope: str,
        context_cache: Optional["onnx_importer.ContextCache"] = None,
        module_cache: Optional["onnx_importer.ModuleCache"] = None,
        private: bool = False,
    ) -> "IREENodeImporter":
        # Recover per-context caches of various attributes.
        # Allows modifications in the same context without
        # loss of current state.
        cc = (
            context_cache
            if context_cache is not None
            else onnx_importer.ContextCache(module_op.context)
        )
        # Recover per-module caches of various attributes.
        # Allows modification in the same module_op without
        # loss of current state.
        mc = (
            module_cache
            if module_cache is not None
            else onnx_importer.ModuleCache(module_op, cc)
        )
        with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"):
            body = module_op.regions[0].blocks[0]
            func_name = graph_info.graph_proto.name
            input_types = [
                cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values()
            ]
            output_types = [
                cc.type_proto_to_type(out.type)
                for out in graph_info.output_map.values()
            ]
            ftype = onnx_importer.FunctionType.get(input_types, output_types)
            func_op = onnx_importer.func_dialect.FuncOp(
                func_name,
                ftype,
                ip=InsertionPoint(body),
                visibility="private" if private else None,
            )
            block = func_op.add_entry_block(
                [Location.name(k) for k in graph_info.input_map.keys()]
            )
            imp = IREENodeImporter(
                graph_info,
                parent_op=func_op,
                block=block,
                context_cache=cc,
                module_op=module_op,
                module_cache=mc,
                num_elements_threshold=num_elements_threshold,
                params_scope=params_scope,
            )
            for node_name, input_value in zip(
                graph_info.input_map.keys(), block.arguments
            ):
                imp._nv_map[node_name] = input_value
            imp._populate_graph_attrs(func_op)
            return imp

    def import_initializer(
        self, initializer: onnx.TensorProto, extern_name: Optional[str] = None
    ) -> Value:
        # If an explicitly specified name is given, use that; otherwise, pick
        # up the name from the tensor proto itself.
        initializer_name = extern_name if extern_name else initializer.name
        dims = list(initializer.dims)
        num_elements = 1
        for d in dims:
            num_elements = num_elements * d
        if num_elements < self.num_elements_threshold:
            imported_tensor = super().import_initializer(initializer)
            self._nv_map[initializer_name] = imported_tensor
            return imported_tensor

        actual_symbol_name, tensor_type = self.create_tensor_global(initializer)
        vtensor_type = self._cc.get_vtensor_type(
            tuple(initializer.dims), self._cc.tensor_element_type(initializer.data_type)
        )

        with InsertionPoint(self._b), Location.name(initializer_name):
            old_op = util.GlobalLoadOp(tensor_type, actual_symbol_name)
            converted_value = Operation.create(
                "torch_c.from_builtin_tensor",
                results=[vtensor_type],
                operands=[old_op.result],
            ).result

        self._nv_map[initializer_name] = converted_value
        tensor_as_array = numpy_helper.to_array(initializer)
        self.param_archive.add_buffer(actual_symbol_name, tensor_as_array)
        return converted_value
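The externalization decision above boils down to a simple predicate, sketched standalone here with the default threshold (the helper name is illustrative, not part of this PR):

import math

def is_externalized(dims, threshold=100):
    # Mirrors import_initializer: inline when below the threshold,
    # externalize to the parameter archive otherwise.
    return math.prod(dims) >= threshold

assert not is_externalized([3, 3])   # 9 elements: inlined literal
assert is_externalized([128, 128])   # 16384 elements: util.global + .irpa entry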


ELEM_TYPE_TO_SIGNLESS_IR_TYPE = copy.deepcopy(onnx_importer.ELEM_TYPE_TO_IR_TYPE_CB)

ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
    onnx.TensorProto.DataType.INT64
] = lambda: IntegerType.get_signless(64)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
    onnx.TensorProto.DataType.INT32
] = lambda: IntegerType.get_signless(32)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
    onnx.TensorProto.DataType.INT16
] = lambda: IntegerType.get_signless(16)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
    onnx.TensorProto.DataType.INT8
] = lambda: IntegerType.get_signless(8)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
    onnx.TensorProto.DataType.INT4
] = lambda: IntegerType.get_signless(4)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
    onnx.TensorProto.DataType.UINT8
] = lambda: IntegerType.get_signless(8)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
    onnx.TensorProto.DataType.UINT4
] = lambda: IntegerType.get_signless(4)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
    onnx.TensorProto.DataType.UINT16
] = lambda: IntegerType.get_signless(16)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
    onnx.TensorProto.DataType.UINT64
] = lambda: IntegerType.get_signless(64)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
    onnx.TensorProto.DataType.UINT32
] = lambda: IntegerType.get_signless(32)
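To illustrate the table above: signed and unsigned ONNX integer types of the same width collapse to one signless MLIR type. A small sketch, assuming the module above is importable and using an explicitly created MLIR context:

from iree.compiler.ir import Context

with Context():
    # INT8 and UINT8 both map to plain i8 rather than si8/ui8.
    ty = ELEM_TYPE_TO_SIGNLESS_IR_TYPE[onnx.TensorProto.DataType.UINT8]()
    assert str(ty) == "i8"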