[onnx][importer] Add support for externalized params (iree-org#18880)
This patch adds support for externalizing parameters and storing them at the
given path as an IRPA file.

IR imported with externalization enabled should now avoid the OOM errors
that large inlined parameters could otherwise cause.
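
As a usage sketch, a hypothetical invocation (file names are placeholders; the
positional input and the -o output flag are the tool's existing interface,
while the remaining flags are the ones added by this patch):

    python -m iree.compiler.tools.import_onnx model.onnx -o model.mlir \
        --externalize-params \
        --num-elements-threshold 100 \
        --save-params-to model_params.irpa \
        --params-scope model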

Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
vinayakdsci authored and giacs-epic committed Dec 4, 2024
1 parent 86f634d commit cdb4ac6
Showing 5 changed files with 399 additions and 19 deletions.
1 change: 1 addition & 0 deletions compiler/bindings/python/CMakeLists.txt
@@ -145,6 +145,7 @@ declare_mlir_python_sources(IREECompilerAPIPythonTools
tools/tf.py
tools/tflite.py
tools/import_onnx/__main__.py
tools/import_onnx/importer_externalization_overrides.py
tools/ir_tool/__main__.py
tools/scripts/iree_compile/__main__.py
tools/scripts/iree_opt/__main__.py
compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py
@@ -14,42 +14,43 @@
python -m iree.compiler.tools.import_onnx ...
"""

import argparse
import os
from pathlib import Path
import sys
import tempfile

try:
import onnx
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"iree-import-onnx requires that the `onnx` Python package is installed "
f"(typically `{sys.executable} -m pip install onnx`)"
) from e

try:
from ...extras import onnx_importer
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"iree-import-onnx is only available if IREE was built with Torch support"
) from e

from ...ir import (
Context,
)
from .importer_externalization_overrides import *


def main(args: argparse.Namespace):
model_proto = load_onnx_model(args)
context = Context()
model_info = onnx_importer.ModelInfo(model_proto)
m = model_info.create_module(context=context).operation
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)

imp: Any = None
if args.externalize_params:
imp = IREENodeImporter.define_function(
model_info.main_graph, m, args.num_elements_threshold, args.params_scope
)
else:
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
imp.import_all()

if not args.no_verify:
m.verify()

if args.externalize_params:
default_param_path = Path(args.output_file).parent / Path(args.output_file).stem
param_path = (
(str(default_param_path) + "_params.irpa")
if args.save_params_to is None
else str(args.save_params_to)
)
imp.param_archive.create_archive_file(param_path)
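        # A worked example of the defaulting above (hypothetical paths): with
        # `-o build/model.mlir` and no --save-params-to, the archive is
        # written to build/model_params.irpa.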

# TODO: This isn't very efficient output. If these files ever
# get large, enable bytecode and direct binary emission to save
# some copies.
@@ -71,6 +72,12 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
raw_model = onnx.load(args.input_file, load_external_data=False)
onnx.load_external_data_for_model(raw_model, str(args.data_dir))

# Only change the opset version if it is greater than the current one.
if args.opset_version and args.opset_version > raw_model.opset_import[0].version:
raw_model = onnx.version_converter.convert_version(
raw_model, args.opset_version
)

# Do shape inference two ways. First, attempt in-memory to avoid redundant
# loading and the need for writing a temporary file somewhere. If that
# fails, typically because of the 2 GB protobuf size limit, try again via
@@ -132,6 +139,37 @@ def parse_arguments(argv=None) -> argparse.Namespace:
" Defaults to the directory of the input file.",
type=Path,
)
    parser.add_argument(
        "--opset-version",
        help="Allows specification of a newer opset version to convert the model"
        " to before importing to MLIR. This can sometimes assist with shape"
        " inference.",
        type=int,
    )
parser.add_argument(
"--num-elements-threshold",
help="Minimum number of elements for an initializer to be externalized. Only has an effect if 'externalize-params' is true.",
type=int,
default=100,
)
parser.add_argument(
"--externalize-params",
help="Externalize large parameters and store them on the disk, to load at runtime.",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--save-params-to",
help="Location to save the externalized parameters. When not set, the parameters will be written to '<output_file_name>_params.irpa'"
" under the namespace 'model', which can be configured by passing the namespace string to 'params-scope'.",
default=None,
type=Path,
)
parser.add_argument(
"--params-scope",
help="The namespace or the scope in which the externalized parameters are placed. Default is 'model'.",
type=str,
default="model",
)
args = parser.parse_args(argv)
return args

compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py
@@ -0,0 +1,263 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import copy
import random
import string
import sys

import iree.runtime as rt

from ...dialects import util
from typing import Optional, Tuple, Any

try:
import onnx
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"iree-import-onnx requires that the `onnx` Python package is installed "
f"(typically `{sys.executable} -m pip install onnx`)"
) from e

try:
from ...extras import onnx_importer
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"iree-import-onnx is only available if IREE was built with Torch support"
) from e

from onnx import numpy_helper

from ...ir import (
Context,
Type as IrType,
TypeAttr,
RankedTensorType,
StringAttr,
Attribute,
Operation,
Location,
InsertionPoint,
Value,
SymbolTable,
IntegerType,
)


class IREENodeImporter(onnx_importer.NodeImporter):
def __init__(
self,
graph_info: onnx_importer.GraphInfo,
*,
parent_op: Operation,
block: onnx_importer.Block,
context_cache: "onnx_importer.ContextCache",
module_op: Operation,
module_cache: "onnx_importer.ModuleCache",
num_elements_threshold: int,
params_scope: str,
):
super().__init__(
graph_info,
parent_op=parent_op,
block=block,
context_cache=context_cache,
module_op=module_op,
module_cache=module_cache,
)
self.last_global_op = None
self.symbol_table = SymbolTable(module_op)
self.symbol_table.insert(parent_op)
self.num_elements_threshold = num_elements_threshold
self.param_archive = rt.ParameterIndex()
self.params_scope = params_scope

    def sanitize_name(self, name: str) -> str:
        # Initializer names in ONNX models are sometimes empty or contain
        # substrings like '::' that are invalid in MLIR symbol names and can
        # cause conflicts between symbolic references. Replace ':' with '_'
        # (e.g. 'onnx::MatMul_123' -> 'onnx__MatMul_123') when the name is
        # non-empty, and generate a random placeholder name when it is empty.
        new_name = name.replace(":", "_")
        if not new_name:
            alpha = string.ascii_lowercase
            ch = random.choice(alpha)
            new_name = str(random.randrange(1, 1000)) + "__" + ch
        return new_name

def create_tensor_global(
self,
t: onnx.TensorProto,
) -> Tuple[str, IrType]:
# Always create globals at the top. Then after created, if there was
# a prior one, move the new one to after it to maintain declaration
# order.
name = self.sanitize_name(t.name)
with InsertionPoint.at_block_begin(
self._m.regions[0].blocks[0]
), Location.unknown():
# After lowering to linalg-on-tensors, the data type needs to be signless.
# So, we construct the globals to have signless types, and use
# torch_c.from_builtin_tensor to convert to the correct frontend type.
vtensor_type = RankedTensorType.get(
tuple(t.dims), ELEM_TYPE_TO_SIGNLESS_IR_TYPE[t.data_type]()
)
ir_attrs = {
"sym_name": StringAttr.get(name),
"sym_visibility": StringAttr.get("private"),
"type": TypeAttr.get(vtensor_type),
}

external_scope_attr = StringAttr.get(self.params_scope)
external_name_attr = StringAttr.get(name)
ir_attrs["initial_value"] = Attribute.parse(
f"#stream.parameter.named<{external_scope_attr}::{external_name_attr}> : {vtensor_type}"
)
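            # e.g. (assuming the default scope and a hypothetical tensor):
            # #stream.parameter.named<"model"::"my_weight"> : tensor<128xf32>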
global_op = util.GlobalOp(
ir_attrs["sym_name"],
ir_attrs["type"],
sym_visibility=ir_attrs["sym_visibility"],
initial_value=ir_attrs["initial_value"],
)
self.symbol_table.insert(global_op)
if self.last_global_op is not None:
global_op.move_after(self.last_global_op)
self.last_global_op = global_op
actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value
return actual_symbol_name, vtensor_type

@classmethod
def define_function(
cls,
graph_info: onnx_importer.GraphInfo,
module_op: Operation,
num_elements_threshold: int,
params_scope: str,
context_cache: Optional["onnx_importer.ContextCache"] = None,
module_cache: Optional["onnx_importer.ModuleCache"] = None,
private: bool = False,
) -> "IREENodeImporter":
# Recover per-context caches of various attributes.
# Allows modifications in the same context without
# loss of current state.
cc = (
context_cache
if context_cache is not None
else onnx_importer.ContextCache(module_op.context)
)
# Recover per-module caches of various attributes.
# Allows modification in the same module_op without
# loss of current state.
mc = (
module_cache
if module_cache is not None
else onnx_importer.ModuleCache(module_op, cc)
)
with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"):
body = module_op.regions[0].blocks[0]
func_name = graph_info.graph_proto.name
input_types = [
cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values()
]
output_types = [
cc.type_proto_to_type(out.type)
for out in graph_info.output_map.values()
]
ftype = onnx_importer.FunctionType.get(input_types, output_types)
func_op = onnx_importer.func_dialect.FuncOp(
func_name,
ftype,
ip=InsertionPoint(body),
visibility="private" if private else None,
)
block = func_op.add_entry_block(
[Location.name(k) for k in graph_info.input_map.keys()]
)
imp = IREENodeImporter(
graph_info,
parent_op=func_op,
block=block,
context_cache=cc,
module_op=module_op,
module_cache=mc,
num_elements_threshold=num_elements_threshold,
params_scope=params_scope,
)
for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments):
imp._nv_map[node_name] = input_value
imp._populate_graph_attrs(func_op)
return imp

def import_initializer(
self, initializer: onnx.TensorProto, extern_name: Optional[str] = None
) -> Value:
        # If an explicitly specified name is given, use that; otherwise, pick
        # up the name from the tensor proto itself.
initializer_name = extern_name if extern_name else initializer.name
dims = list(initializer.dims)
num_elements = 1
for d in dims:
num_elements = num_elements * d
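        # Initializers below the element-count threshold
        # (--num-elements-threshold, 100 by default) stay inlined as before;
        # anything larger becomes a named parameter.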
if num_elements < self.num_elements_threshold:
imported_tensor = super().import_initializer(initializer)
self._nv_map[initializer_name] = imported_tensor
return imported_tensor

actual_symbol_name, tensor_type = self.create_tensor_global(initializer)
vtensor_type = self._cc.get_vtensor_type(
tuple(initializer.dims), self._cc.tensor_element_type(initializer.data_type)
)

with InsertionPoint(self._b), Location.name(initializer_name):
old_op = util.GlobalLoadOp(tensor_type, actual_symbol_name)
converted_value = Operation.create(
"torch_c.from_builtin_tensor",
results=[vtensor_type],
operands=[old_op.result],
).result

self._nv_map[initializer_name] = converted_value
tensor_as_array = numpy_helper.to_array(initializer)
self.param_archive.add_buffer(actual_symbol_name, tensor_as_array)
return converted_value


ELEM_TYPE_TO_SIGNLESS_IR_TYPE = copy.deepcopy(onnx_importer.ELEM_TYPE_TO_IR_TYPE_CB)

ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.INT64
] = lambda: IntegerType.get_signless(64)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.INT32
] = lambda: IntegerType.get_signless(32)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.INT16
] = lambda: IntegerType.get_signless(16)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.INT8
] = lambda: IntegerType.get_signless(8)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.INT4
] = lambda: IntegerType.get_signless(4)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.UINT8
] = lambda: IntegerType.get_signless(8)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.UINT4
] = lambda: IntegerType.get_signless(4)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.UINT16
] = lambda: IntegerType.get_signless(16)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.UINT64
] = lambda: IntegerType.get_signless(64)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
onnx.TensorProto.DataType.UINT32
] = lambda: IntegerType.get_signless(32)
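
A note on consuming the archive: at runtime, the compiled module resolves each
#stream.parameter.named<"model"::...> reference against a parameter provider
built from the IRPA file. A minimal sketch of the loading side, assuming an
archive named model_params.irpa produced as above and iree.runtime's
ParameterIndex.load / create_provider APIs (treat the exact signatures as an
assumption to verify against the iree.runtime documentation):

    import iree.runtime as rt

    # Index the archive written by imp.param_archive.create_archive_file(...).
    index = rt.ParameterIndex()
    index.load("model_params.irpa")  # hypothetical path

    # Expose entries under the same scope used at import time
    # (--params-scope, 'model' by default) so that the
    # #stream.parameter.named references resolve.
    provider = index.create_provider(scope="model")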