[Relay][Frontend][Onnx] Loop Support #6700

Merged · 17 commits · Oct 20, 2020
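This PR adds ONNX Loop operator support to the Relay frontend by converting the Loop body subgraph and lowering it onto Relay's while_loop construct. A minimal usage sketch (the model file name and input shape here are invented for illustration):

import onnx
from tvm import relay

# Hypothetical ONNX model containing a Loop node.
model = onnx.load("model_with_loop.onnx")
mod, params = relay.frontend.from_onnx(model, shape={"input": (1, 3, 224, 224)})
# The Loop body is converted as a nested subgraph and becomes a relay while_loop.
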
224 changes: 214 additions & 10 deletions python/tvm/relay/frontend/onnx.py
@@ -17,16 +17,20 @@
# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
# pylint: disable=import-outside-toplevel
"""ONNX: Open Neural Network Exchange frontend for Relay."""
import warnings
import numpy as np
import tvm
from tvm.ir import IRModule
from tvm.topi.util import get_const_tuple

from ... import nd as _nd
from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import vision as _vision
from .. import loops as _loops
from .. import ty as _ty

from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
@@ -95,6 +99,29 @@ def get_numpy(tensor_proto):
return to_array(tensor_proto)


def get_type(elem_type):
"""Converts onnx integer datatype to numpy datatype"""
try:
from onnx import TensorProto
except ImportError as e:
raise ImportError("Unable to import onnx which is required {}".format(e))
return TensorProto.DataType.Name(elem_type).lower()
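# Illustrative note (not part of this diff): protobuf enum reflection gives
# TensorProto.DataType.Name(TensorProto.FLOAT) == "FLOAT", so get_type(1)
# returns "float" and get_type(7) returns "int64".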


def get_info(info_proto):
"""Extract the shape from a ValueInfoProto."""
shape = []
for dim in info_proto.type.tensor_type.shape.dim:
value = dim.dim_value
if value is None:
value = _ty.Any()
shape.append(value)

name = info_proto.name
dtype = get_type(info_proto.type.tensor_type.elem_type)
return name, shape, dtype


def dimension_picker(prefix, suffix=""):
"""Check that dimensions are supported."""

Expand Down Expand Up @@ -1995,6 +2022,164 @@ def _impl_v11(cls, inputs, attr, params):
return result


class Loop(OnnxOpConverter):
"""Operator converter for Loop"""

@classmethod
def _impl_v11(cls, inputs, attr, params):
max_loop_count = inputs[0]
cond = inputs[1]
loop_deps = inputs[2:]
num_deps = len(loop_deps)
body = attr["body"]
iter_dtype = infer_type(max_loop_count).checked_type.dtype

# Determine what condition mode we're in.
assert cond is not None or max_loop_count is not None
is_for_loop = max_loop_count is not None and cond is None
is_condition_for_loop = cond is not None and max_loop_count is not None

# Loop inputs will be packed as
# [iter_count, max_count, condition, loop_deps, scan_outputs]
def cond_fn(*loop_inputs):
i = loop_inputs[0]
max_count = loop_inputs[1]
w = loop_inputs[2]

if cond is not None:
out_while = _op.equal(w, _expr.const(True, "bool"))
if max_loop_count is not None:
out_loop = _op.less(i, max_count)

if is_condition_for_loop:
return _op.logical_and(out_while, out_loop)
if is_for_loop:
return out_loop
return out_while
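# Summary of the modes handled above (per the ONNX Loop spec): a max count
# alone is a for-loop over a fixed trip count, a condition alone is a
# while-loop, and both together form a for-loop with an early-exit test,
# hence the logical_and of the two checks.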

# Get the current graph proto and create a clone for the subgraph
graph_scope = GraphProto.current
subgraph_scope = GraphProto(graph_scope._shape, graph_scope._dtype)
# Load nodes from outer graph into inner graph.
subgraph_scope._nodes = graph_scope._nodes.copy()

# Create a list of variables for each value updated in the loop.
def get_var(name, val, scan=False):
checked_type = infer_type(val)
if hasattr(checked_type, "type_annotation"):
checked_type = checked_type.type_annotation
shape = get_const_tuple(checked_type.shape)
actual_shape = []
for dim in shape:
if isinstance(dim, int) and dim == 0:
actual_shape.append(_ty.Any())
else:
actual_shape.append(dim)
if scan:
return _expr.var(name, shape=[_ty.Any()] + actual_shape, dtype=checked_type.dtype)

return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)

loop_vars = [
_expr.var(body.input[0].name, shape=(), dtype=iter_dtype), # iteration count
_expr.var("max_count", shape=(), dtype=iter_dtype), # iteration count
get_var(body.input[1].name, cond), # exit condition
]
loop_vars += [get_var(body.input[i + 2].name, v) for i, v in enumerate(loop_deps)]
loop_var_names = [v.name_hint for v in loop_vars]

num_scan_outputs = len(body.output) - (1 + num_deps)
# TODO (jwfromm) Test with strided slice once type unifier for this case is fixed.
[Review comment, kevinLu1114 (Contributor), Jun 6, 2021]
Hi @jwfromm, did you make this change? I got an error and I am not sure if it
is because of this. Can you help me please? The model is a MobileNetV1-SSD:
https://github.com/onnx/models/blob/master/vision/object_detection_segmentation/ssd-mobilenetv1/model/ssd_mobilenet_v1_10.onnx

if num_scan_outputs != 0 and "Slice" in [n.op_type for n in body.node]:
warnings.warn(
"""
Using scan outputs in a loop with strided slice
currently may cause errors during compilation.
"""
)

# Construct variables and initial empty tensors for any scan outputs.
scan_output_vars = []
scan_output_init = []
for i in range(num_scan_outputs):
name, shape, dtype = get_info(body.output[i + 1 + num_deps])
scan_output_vars.append(_expr.var(name, shape=([_ty.Any()] + shape), dtype=dtype))
scan_output_init.append(_op.reshape(_expr.const([]), [0] + shape))

# Now we can remove loop iter variables from our inner loop's inputs.
# This is kind of a hack since we have graph inputs that we don't
# want to treat as actual inputs.
while len(body.input) != 0:
body.input.pop(0)

# Define the loop body, in this function we need to unpack loop inputs,
# convert the loop subgraph, and pack outputs for the next iteration.
def body_fn(*loop_inputs):
# Unpack inputs
loop_count = loop_inputs[0]
max_count = loop_inputs[1]
cond = loop_inputs[2]
current_vars = list(loop_inputs[3 : (3 + num_deps)])
scan_outputs = loop_inputs[(3 + num_deps) :]

# Prepare body inputs by adding them to node dictionary.
new_inputs = [loop_count, max_count, cond] + current_vars
for i, inp in enumerate(new_inputs):
subgraph_scope._nodes[loop_var_names[i]] = inp

# Get the output of the current loop using the updated inputs.
with subgraph_scope:
loop_outputs = subgraph_scope.from_onnx(body, 11, get_output_expr=True)
# Unpack the body outputs and prepare variables for next iteration.
new_cond = loop_outputs[0]
new_loop_vars = [loop_outputs[i] for i in range(1, 1 + num_deps)]
new_scan_outputs = [loop_outputs[i] for i in range(1 + num_deps, len(loop_outputs))]

# Increment counter.
if max_loop_count is not None:
incr = _expr.const(1, dtype=iter_dtype)
loop_count = loop_count + incr

# Add new scan outputs to tracking
combined_scan_outputs = []
for i, scan in enumerate(scan_outputs):
new_scan = _op.expand_dims(new_scan_outputs[i], axis=0)
combined_scan = _op.concatenate([scan, new_scan], axis=0)
combined_scan_outputs.append(combined_scan)
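# Illustration (with an assumed concrete shape): a per-iteration scan value
# of shape (C,) becomes (1, C) after expand_dims and is concatenated along
# axis 0, so after k iterations the accumulated output has shape (k, C).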

# Pack loop outputs for the next iteration:
# [iter_count, max_count, cond, loop_deps, loop_scans]
return [loop_count, max_count, new_cond] + new_loop_vars + combined_scan_outputs

# Create the loop function.
loop = _loops.while_loop(cond_fn, loop_vars + scan_output_vars, body_fn)

# Now run the initial values through the loop.
init_count = _expr.const(0, dtype=iter_dtype)
loop_vals = loop(init_count, max_loop_count, cond, *loop_deps, *scan_output_init)

# Extract final iteration outputs.
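# The packed loop values are [iter_count, max_count, cond, deps..., scans...],
# so the user-visible ONNX outputs start at index 3.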
if num_deps + num_scan_outputs == 1:
outputs = _expr.TupleGetItem(loop_vals, 3)
else:
outputs = _expr.TupleWrapper(
_expr.Tuple(
[
_expr.TupleGetItem(loop_vals, i + 3)
for i in range(num_deps + num_scan_outputs)
]
),
num_deps + num_scan_outputs,
)

# Update outer graph with constants found in the subgraph.
free_vars = analysis.free_vars(loop)
graph_scope._params.update(subgraph_scope._params)
for var in free_vars:
graph_scope._nodes.update({var.name_hint: var})
return outputs
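
For context, here is a minimal sketch (assumptions only, not part of this diff) of the relay while_loop combinator the converter builds on: it takes a condition function, initial loop variables, and a body function returning the next iteration's variables, and yields a callable recursive loop.

from tvm import relay
from tvm.relay.loops import while_loop

i = relay.var("i", shape=(), dtype="int32")

def cond(i):
    # Continue while i < 10.
    return relay.less(i, relay.const(10, "int32"))

def body(i):
    # Each iteration returns the updated loop variables as a list.
    return [i + relay.const(1, "int32")]

loop = while_loop(cond, [i], body)
# loop(relay.const(0, "int32")) evaluates to a tuple of the final loop values.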


# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -2150,6 +2335,8 @@ def _get_convert_map(opset):
"Resize": Resize.get_converter(opset),
"NonZero": NonZero.get_converter(opset),
"Range": Range.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
}


@@ -2166,6 +2353,8 @@ class GraphProto:
The input types to the graph
"""

current = None

def __init__(self, shape, dtype):
self._nodes = {}
self._params = {}
@@ -2176,15 +2365,24 @@ def __init__(self, shape, dtype):
self._shape = shape if shape else {}
self._dtype = dtype

def __enter__(self):
self._old_manager = GraphProto.current
GraphProto.current = self
return self

def __exit__(self, ptype, value, trace):
GraphProto.current = self._old_manager
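# Illustration (assumption): `current` acts as a dynamically scoped pointer,
# so entering a nested subgraph scope while converting a Loop body temporarily
# replaces the outer scope and restores it on exit:
#     with outer_graph:          # GraphProto.current is outer_graph
#         with subgraph_scope:   # GraphProto.current is subgraph_scope
#             ...
#         # GraphProto.current is outer_graph again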

def freeze(self, func, params):
bind_map = {}
for name in params.keys():
bind_map[self._nodes[name]] = _expr.const(params[name])
if name in self._nodes.keys():
bind_map[self._nodes[name]] = _expr.const(params[name])
body = _expr.bind(func.body, bind_map)
fn = _function.Function(analysis.free_vars(body), body)
return fn, {}

def from_onnx(self, graph, opset, freeze_params=False):
def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
"""Construct Relay expression from ONNX graph.

Onnx graph is a python protobuf object.
@@ -2208,6 +2406,11 @@ def from_onnx(self, graph, opset, freeze_params=False):
at compile time and helps in making models static if certain inputs represent
attributes relay would traditionally consider compile-time constants.

get_output_expr: bool
If set to true, this conversion will return each output expression rather
than a packaged module. This can be useful when converting subgraphs to
relay.

Returns
-------
mod : tvm.IRModule
@@ -2309,6 +2512,9 @@ def from_onnx(self, graph, opset, freeze_params=False):
# now return the outputs
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
# If requested, directly return the converted expressions.
if get_output_expr:
return outputs
## Maintain the order of inputs and parameters from the ONNX graph, but only include
## those parameters that are needed to execute the relay graph
free_vars = analysis.free_vars(outputs)
@@ -2317,6 +2523,7 @@ def _parse_attr(self, attr_proto):
for i_name in self._params:
if i_name in free_vars and i_name not in self._inputs:
self._inputs[i_name] = self._nodes[i_name]
# Create a function from our output expression and all input variables.
func = _function.Function([v for k, v in self._inputs.items()], outputs)
if freeze_params:
func, params = self.freeze(func, self._params)
@@ -2348,7 +2555,7 @@ def _parse_attr(self, attr_proto):
"""Convert a list of AttributeProto to a dict, with names as keys."""
attrs = {}
for a in attr_proto:
for f in ["f", "i", "s"]:
for f in ["f", "i", "s", "g"]:
if a.HasField(f):
attrs[a.name] = getattr(a, f)
for f in ["floats", "ints", "strings"]:
Expand All @@ -2362,12 +2569,9 @@ def _parse_attr(self, attr_proto):
if list(getattr(a, f)):
assert a.name not in attrs, "Only one type of attr is allowed"
attrs[a.name] = tuple(getattr(a, f))
for f in ["g"]:
if a.HasField(f):
raise NotImplementedError("Filed {} is not supported in relay.".format(f))
for f in ["graphs"]:
if list(getattr(a, f)):
raise NotImplementedError("Filed {} is not supported in relay.".format(f))
raise NotImplementedError("Field {} is not supported in relay.".format(f))
if a.name not in attrs:
raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
return attrs
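# Note: accepting the single-graph attribute "g" above is what delivers the
# Loop body subgraph to the converter as attr["body"].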
@@ -2469,8 +2673,6 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=False):
try:
onnx.checker.check_model(model)
except onnx.onnx_cpp2py_export.checker.ValidationError as e:
import warnings

# the checker is a bit violent about errors, so simply print warnings here
warnings.warn(str(e))
except ImportError:
Expand All @@ -2482,5 +2684,7 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals
opset = model.opset_import[0].version if model.opset_import else 1
except AttributeError:
opset = 1
mod, params = g.from_onnx(graph, opset, freeze_params)
# Use the graph proto as a scope so that ops can access other nodes if needed.
with g:
mod, params = g.from_onnx(graph, opset, freeze_params)
return mod, params
17 changes: 16 additions & 1 deletion python/tvm/relay/op/tensor.py
@@ -15,13 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""Basic tensor operations."""
# pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin, unused-argument
from tvm.runtime import ndarray as _nd
from tvm.runtime import TVMContext as _TVMContext
from tvm.te.hybrid import script

from . import _make
from .dyn import _make as _dyn_make
from ..expr import Tuple, Expr
from . import op as reg


# We create a wrapper function for each operator in the
@@ -1138,6 +1140,19 @@ def copy(data):
return _make.copy(data)


@script
def _copy_shape_func(data_shape):
return data_shape


@reg.register_shape_func("copy", False)
def copy_shape_func(attrs, inputs, _):
"""
Shape function for copy op.
"""
return [_copy_shape_func(inputs[0])]
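# Note (an assumption about the registration API): the False flag marks this
# shape function as data-independent, i.e. copy's output shape is computed
# from its input shape alone, which lets dynamically shaped loop tensors pass
# through copy when executed by the VM.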


def device_copy(data, src_dev, dst_dev):
"""Copy data from the source device to the destination device. This
operator helps data transferring between different contexts for
7 changes: 0 additions & 7 deletions src/relay/backend/vm/compiler.cc
@@ -343,13 +343,6 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
void VisitExpr_(const ConstantNode* const_node) {
// Check the shape is valid
NDArray data = const_node->data;
const DLTensor* tensor = data.operator->();
if (tensor->ndim > 0) {
int64_t* shapes = reinterpret_cast<int64_t*>(tensor->shape);
for (auto i = 0; i < tensor->ndim; i++) {
CHECK_GT(shapes[i], 0U);
}
}
size_t konst_idx = context_->constants.size();
if (expr_device_map_.empty()) {
context_->const_device_type.push_back(targets_.begin()->first);
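Context, inferred from the frontend change above rather than stated in this diff: Loop initializes scan outputs as empty tensors with a leading 0 dimension, so the constant-shape checks that rejected zero-sized dimensions are removed here and in fold_constant below.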
3 changes: 0 additions & 3 deletions src/relay/transforms/fold_constant.cc
@@ -199,9 +199,6 @@ class ConstantFolder : public MixedModeMutator {
Expr ObjectToExpr(const ObjectRef& value) {
if (value->IsInstance<runtime::NDArray::ContainerType>()) {
auto nd_array = Downcast<runtime::NDArray>(value);
for (auto dim : nd_array.Shape()) {
CHECK_GT(dim, 0) << "invalid dimension after constant eval";
}
return Constant(nd_array);
} else if (const auto* val = value.as<runtime::ADTObj>()) {
runtime::ADT adt = GetRef<runtime::ADT>(val);