[Relay][Frontend][Onnx] Loop Support (apache#6700)
* Onnx loop almost working, checkpointing for safety.

* Very close to working.

* Last piece is fixing scan initialization.

* snapshotting for debug.

* Fix Josh's issue

* Use subgraph proto class.

* Loop with scan.

* Simple loop test now working.

* Scan outputs now working.

* Added second loop test.

* Removed unneeded helper functions.

* Remove bad merge artifact.

* Cleaned up scan output creation.

* Cleaned up some style mistakes.

* Add pylint skip for unused-argument.

* Remove onnx dependency.

* Remove now obsolete checks for 0 shaped tensors.

Co-authored-by: Jared Roesch <jroesch@octoml.ai>
2 people authored and Trevor Morris committed Dec 4, 2020
1 parent bb767ac commit 7edbfcf
Showing 5 changed files with 383 additions and 21 deletions.
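For context, a minimal usage sketch of what this change enables: converting an ONNX model that contains a Loop node through the Relay frontend. It is illustrative rather than part of the commit; the file name and the input names/shapes are hypothetical placeholders.

import onnx
from tvm import relay

# Load an ONNX model that contains a Loop node (hypothetical file name).
model = onnx.load("model_with_loop.onnx")
# Top-level graph inputs (hypothetical names and shapes).
shape_dict = {"trip_count": (), "cond": (), "x": (1, 4)}
mod, params = relay.frontend.from_onnx(model, shape=shape_dict)
# The Loop body appears in the module as a recursive Relay function built
# with relay.loops.while_loop.
print(mod)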
224 changes: 214 additions & 10 deletions python/tvm/relay/frontend/onnx.py
@@ -17,16 +17,20 @@
# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
# pylint: disable=import-outside-toplevel
"""ONNX: Open Neural Network Exchange frontend for Relay."""
import warnings
import numpy as np
import tvm
from tvm.ir import IRModule
from tvm.topi.util import get_const_tuple

from ... import nd as _nd
from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import vision as _vision
from .. import loops as _loops
from .. import ty as _ty

from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
@@ -95,6 +99,29 @@ def get_numpy(tensor_proto):
return to_array(tensor_proto)


def get_type(elem_type):
"""Converts onnx integer datatype to numpy datatype"""
try:
from onnx import TensorProto
except ImportError as e:
raise ImportError("Unable to import onnx which is required {}".format(e))
return TensorProto.DataType.Name(elem_type).lower()


def get_info(info_proto):
"""Extract the shape from a ValueInfoProto."""
shape = []
for dim in info_proto.type.tensor_type.shape.dim:
value = dim.dim_value
if value is None:
value = _ty.Any
shape.append(value)

name = info_proto.name
dtype = get_type(info_proto.type.tensor_type.elem_type)
return name, shape, dtype


def dimension_picker(prefix, suffix=""):
"""Check that dimensions are supported."""

@@ -1995,6 +2022,164 @@ def _impl_v11(cls, inputs, attr, params):
return result


class Loop(OnnxOpConverter):
"""Operator converter for Loop"""

@classmethod
def _impl_v11(cls, inputs, attr, params):
max_loop_count = inputs[0]
cond = inputs[1]
loop_deps = inputs[2:]
num_deps = len(loop_deps)
body = attr["body"]
iter_dtype = infer_type(max_loop_count).checked_type.dtype

# Determine what condition mode we're in.
assert cond is not None or max_loop_count is not None
is_for_loop = max_loop_count is not None and cond is None
is_condition_for_loop = cond is not None and max_loop_count is not None

# Loop inputs will be packed as
# [iter_count, max_count, condition, loop_deps, scan_outputs]
def cond_fn(*loop_inputs):
i = loop_inputs[0]
max_count = loop_inputs[1]
w = loop_inputs[2]

if cond is not None:
out_while = _op.equal(w, _expr.const(True, "bool"))
if max_loop_count is not None:
out_loop = _op.less(i, max_count)

if is_condition_for_loop:
return _op.logical_and(out_while, out_loop)
if is_for_loop:
return out_loop
return out_while

# Get the current graph proto and create a clone for the subgraph
graph_scope = GraphProto.current
subgraph_scope = GraphProto(graph_scope._shape, graph_scope._dtype)
# Load nodes from outer graph into inner graph.
subgraph_scope._nodes = graph_scope._nodes.copy()

# Create a list of variables for each value updated in the loop.
def get_var(name, val, scan=False):
checked_type = infer_type(val)
if hasattr(checked_type, "type_annotation"):
checked_type = checked_type.type_annotation
shape = get_const_tuple(checked_type.shape)
actual_shape = []
for dim in shape:
if isinstance(dim, int) and dim == 0:
actual_shape.append(_ty.Any())
else:
actual_shape.append(dim)
if scan:
return _expr.var(name, shape=[_ty.Any()] + actual_shape, dtype=checked_type.dtype)

return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)

loop_vars = [
_expr.var(body.input[0].name, shape=(), dtype=iter_dtype), # iteration count
_expr.var("max_count", shape=(), dtype=iter_dtype), # iteration count
get_var(body.input[1].name, cond), # exit condition
]
loop_vars += [get_var(body.input[i + 2].name, v) for i, v in enumerate(loop_deps)]
loop_var_names = [v.name_hint for v in loop_vars]

num_scan_outputs = len(body.output) - (1 + num_deps)
# TODO (jwfromm) Test with strided slice once type unifier for this case is fixed.
if num_scan_outputs != 0 and "Slice" in [n.op_type for n in body.node]:
warnings.warn(
"""
Using scan outputs in a loop with strided slice
currently may cause errors during compilation.
"""
)

# Construct variables and initial empty tensors for any scan outputs.
scan_output_vars = []
scan_output_init = []
for i in range(num_scan_outputs):
name, shape, dtype = get_info(body.output[i + 1 + num_deps])
scan_output_vars.append(_expr.var(name, shape=([_ty.Any()] + shape), dtype=dtype))
scan_output_init.append(_op.reshape(_expr.const([]), [0] + shape))

# Now we can remove loop iter variables from our inner loop's inputs.
# This is kind of a hack since we have graph inputs that we don't
# want to treat as actual inputs.
while len(body.input) != 0:
body.input.pop(0)

# Define the loop body, in this function we need to unpack loop inputs,
# convert the loop subgraph, and pack outputs for the next iteration.
def body_fn(*loop_inputs):
# Unpack inputs
loop_count = loop_inputs[0]
max_count = loop_inputs[1]
cond = loop_inputs[2]
current_vars = list(loop_inputs[3 : (3 + num_deps)])
scan_outputs = loop_inputs[(3 + num_deps) :]

# Prepare body inputs by adding them to node dictionary.
new_inputs = [loop_count, max_count, cond] + current_vars
for i, inp in enumerate(new_inputs):
subgraph_scope._nodes[loop_var_names[i]] = inp

# Get the output of the current loop using the updated inputs.
with subgraph_scope:
loop_outputs = subgraph_scope.from_onnx(body, 11, get_output_expr=True)
# Unpack the body outputs and prepare variables for next iteration.
new_cond = loop_outputs[0]
new_loop_vars = [loop_outputs[i] for i in range(1, 1 + num_deps)]
new_scan_outputs = [loop_outputs[i] for i in range(1 + num_deps, len(loop_outputs))]

# Increment counter.
if max_loop_count is not None:
incr = _expr.const(1, dtype=iter_dtype)
loop_count = loop_count + incr

# Add new scan outputs to tracking
combined_scan_outputs = []
for i, scan in enumerate(scan_outputs):
new_scan = _op.expand_dims(new_scan_outputs[i], axis=0)
combined_scan = _op.concatenate([scan, new_scan], axis=0)
combined_scan_outputs.append(combined_scan)

# Pack loop outputs for next iteration
# [iter_count, cond, loop_deps, loop_scans]
return [loop_count, max_count, new_cond] + new_loop_vars + combined_scan_outputs

# Create the loop function.
loop = _loops.while_loop(cond_fn, loop_vars + scan_output_vars, body_fn)

# Now need to run initial values through the graph.
init_count = _expr.const(0, dtype=iter_dtype)
loop_vals = loop(init_count, max_loop_count, cond, *loop_deps, *scan_output_init)

# Extract final iteration outputs.
if num_deps + num_scan_outputs == 1:
outputs = _expr.TupleGetItem(loop_vals, 3)
else:
outputs = _expr.TupleWrapper(
_expr.Tuple(
[
_expr.TupleGetItem(loop_vals, i + 3)
for i in range(num_deps + num_scan_outputs)
]
),
num_deps + num_scan_outputs,
)

# Update outer graph with constants found in the subgraph.
free_vars = analysis.free_vars(loop)
graph_scope._params.update(subgraph_scope._params)
for var in free_vars:
graph_scope._nodes.update({var.name_hint: var})
return outputs


# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -2150,6 +2335,8 @@ def _get_convert_map(opset):
"Resize": Resize.get_converter(opset),
"NonZero": NonZero.get_converter(opset),
"Range": Range.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
}


@@ -2166,6 +2353,8 @@ class GraphProto:
The input types to the graph
"""

current = None

def __init__(self, shape, dtype):
self._nodes = {}
self._params = {}
@@ -2176,15 +2365,24 @@ def __init__(self, shape, dtype):
self._shape = shape if shape else {}
self._dtype = dtype

def __enter__(self):
self._old_manager = GraphProto.current
GraphProto.current = self
return self

def __exit__(self, ptype, value, trace):
GraphProto.current = self._old_manager

def freeze(self, func, params):
bind_map = {}
for name in params.keys():
bind_map[self._nodes[name]] = _expr.const(params[name])
if name in self._nodes.keys():
bind_map[self._nodes[name]] = _expr.const(params[name])
body = _expr.bind(func.body, bind_map)
fn = _function.Function(analysis.free_vars(body), body)
return fn, {}

def from_onnx(self, graph, opset, freeze_params=False):
def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
"""Construct Relay expression from ONNX graph.
Onnx graph is a python protobuf object.
@@ -2208,6 +2406,11 @@ def from_onnx(self, graph, opset, freeze_params=False):
at compile time and helps in making models static if certain inputs represent
attributes relay would traditionally consider compile-time constants.
get_output_expr: bool
If set to true, this conversion will return each output expression rather
than a packaged module. This can be useful when converting subgraphs to
relay.
Returns
-------
mod : tvm.IRModule
Expand Down Expand Up @@ -2309,6 +2512,9 @@ def from_onnx(self, graph, opset, freeze_params=False):
# now return the outputs
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
# If requested, directly return the converted expressions.
if get_output_expr:
return outputs
## Maintain the order of inputs and parameters from the ONNX graph, but only include
## those parameters that are needed to execute the relay graph
free_vars = analysis.free_vars(outputs)
@@ -2317,6 +2523,7 @@ def from_onnx(self, graph, opset, freeze_params=False):
for i_name in self._params:
if i_name in free_vars and i_name not in self._inputs:
self._inputs[i_name] = self._nodes[i_name]
# Create a function from our output expression and all input variables.
func = _function.Function([v for k, v in self._inputs.items()], outputs)
if freeze_params:
func, params = self.freeze(func, self._params)
@@ -2348,7 +2555,7 @@ def _parse_attr(self, attr_proto):
"""Convert a list of AttributeProto to a dict, with names as keys."""
attrs = {}
for a in attr_proto:
for f in ["f", "i", "s"]:
for f in ["f", "i", "s", "g"]:
if a.HasField(f):
attrs[a.name] = getattr(a, f)
for f in ["floats", "ints", "strings"]:
Expand All @@ -2362,12 +2569,9 @@ def _parse_attr(self, attr_proto):
if list(getattr(a, f)):
assert a.name not in attrs, "Only one type of attr is allowed"
attrs[a.name] = tuple(getattr(a, f))
for f in ["g"]:
if a.HasField(f):
raise NotImplementedError("Filed {} is not supported in relay.".format(f))
for f in ["graphs"]:
if list(getattr(a, f)):
raise NotImplementedError("Filed {} is not supported in relay.".format(f))
raise NotImplementedError("Field {} is not supported in relay.".format(f))
if a.name not in attrs:
raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
return attrs
@@ -2469,8 +2673,6 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals
try:
onnx.checker.check_model(model)
except onnx.onnx_cpp2py_export.checker.ValidationError as e:
import warnings

# the checker is a bit violent about errors, so simply print warnings here
warnings.warn(str(e))
except ImportError:
Expand All @@ -2482,5 +2684,7 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals
opset = model.opset_import[0].version if model.opset_import else 1
except AttributeError:
opset = 1
mod, params = g.from_onnx(graph, opset, freeze_params)
# Use the graph proto as a scope so that ops can access other nodes if needed.
with g:
mod, params = g.from_onnx(graph, opset, freeze_params)
return mod, params
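For context, the converter above lowers ONNX Loop onto tvm.relay.loops.while_loop. Below is a minimal sketch of that helper in isolation, using an illustrative summation loop that is not part of this commit.

import tvm
from tvm import relay
from tvm.relay.loops import while_loop

# Loop state: an iteration counter and an accumulator (names are illustrative).
i = relay.var("i", shape=(), dtype="int32")
acc = relay.var("acc", shape=(), dtype="int32")

def cond_fn(i, acc):
    # Continue while i < 10.
    return relay.less(i, relay.const(10, "int32"))

def body_fn(i, acc):
    # Return the next value of every loop variable.
    return [i + relay.const(1, "int32"), acc + i]

loop = while_loop(cond_fn, [i, acc], body_fn)
result = loop(relay.const(0, "int32"), relay.const(0, "int32"))
func = relay.Function([], relay.TupleGetItem(result, 1))
mod = tvm.IRModule.from_expr(func)
# while_loop produces a recursive function, so run it on the VM; the graph
# executor cannot execute Relay control flow.
out = relay.create_executor("vm", mod=mod, target="llvm").evaluate()()
print(out)  # 45

This is also why the converter threads max_count and the scan outputs through the loop variables: while_loop only carries state that is explicitly packed into its variable list.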
17 changes: 16 additions & 1 deletion python/tvm/relay/op/tensor.py
@@ -15,13 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""Basic tensor operations."""
# pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin, unused-argument
from tvm.runtime import ndarray as _nd
from tvm.runtime import TVMContext as _TVMContext
from tvm.te.hybrid import script

from . import _make
from .dyn import _make as _dyn_make
from ..expr import Tuple, Expr
from . import op as reg


# We create a wrapper function for each operator in the
@@ -1138,6 +1140,19 @@ def copy(data):
return _make.copy(data)


@script
def _copy_shape_func(data_shape):
return data_shape


@reg.register_shape_func("copy", False)
def copy_shape_func(attrs, inputs, _):
"""
Shape function for copy op.
"""
return [_copy_shape_func(inputs[0])]


def device_copy(data, src_dev, dst_dev):
"""Copy data from the source device to the destination device. This
operator helps data transferring between difference contexts for
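For context, a small sketch (not part of this commit, shapes are illustrative) of why the copy shape function above matters: it lets copy appear in graphs whose shapes are only known at runtime, such as the scan outputs built by the Loop converter.

import numpy as np
import tvm
from tvm import relay

# A tensor whose first dimension is only known at runtime.
x = relay.var("x", shape=(relay.Any(), 3), dtype="float32")
func = relay.Function([x], relay.op.copy(x))
mod = tvm.IRModule.from_expr(func)
data = np.random.rand(4, 3).astype("float32")
# The VM needs a registered shape function for every op in a dynamic graph.
res = relay.create_executor("vm", mod=mod, target="llvm").evaluate()(data)
print(res.shape)  # (4, 3)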
7 changes: 0 additions & 7 deletions src/relay/backend/vm/compiler.cc
@@ -343,13 +343,6 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
void VisitExpr_(const ConstantNode* const_node) {
// Check the shape is valid
NDArray data = const_node->data;
const DLTensor* tensor = data.operator->();
if (tensor->ndim > 0) {
int64_t* shapes = reinterpret_cast<int64_t*>(tensor->shape);
for (auto i = 0; i < tensor->ndim; i++) {
CHECK_GT(shapes[i], 0U);
}
}
size_t konst_idx = context_->constants.size();
if (expr_device_map_.empty()) {
context_->const_device_type.push_back(targets_.begin()->first);
3 changes: 0 additions & 3 deletions src/relay/transforms/fold_constant.cc
@@ -199,9 +199,6 @@ class ConstantFolder : public MixedModeMutator {
Expr ObjectToExpr(const ObjectRef& value) {
if (value->IsInstance<runtime::NDArray::ContainerType>()) {
auto nd_array = Downcast<runtime::NDArray>(value);
for (auto dim : nd_array.Shape()) {
CHECK_GT(dim, 0) << "invalid dimension after constant eval";
}
return Constant(nd_array);
} else if (const auto* val = value.as<runtime::ADTObj>()) {
runtime::ADT adt = GetRef<runtime::ADT>(val);
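For context, an illustrative sketch (not part of this commit) of the kind of constant the removed checks used to reject: the Loop converter seeds each scan output with an empty tensor whose leading dimension is zero, so zero-sized constants must now be allowed through the VM compiler and the constant folder.

import numpy as np
from tvm import relay

# An empty scan-output initializer with shape (0, 3); the sizes are illustrative.
empty_scan = relay.const(np.zeros((0, 3), dtype="float32"))
print(empty_scan.data.shape)  # (0, 3)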