diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 9fae94b5a8a1..e2c6b9abc449 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -17,9 +17,11 @@ # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines # pylint: disable=import-outside-toplevel """ONNX: Open Neural Network Exchange frontend for Relay.""" +import warnings import numpy as np import tvm from tvm.ir import IRModule +from tvm.topi.util import get_const_tuple from ... import nd as _nd from .. import analysis @@ -27,6 +29,8 @@ from .. import function as _function from .. import op as _op from .. import vision as _vision +from .. import loops as _loops +from .. import ty as _ty from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels @@ -95,6 +99,29 @@ def get_numpy(tensor_proto): return to_array(tensor_proto) +def get_type(elem_type): + """Converts onnx integer datatype to numpy datatype""" + try: + from onnx import TensorProto + except ImportError as e: + raise ImportError("Unable to import onnx which is required {}".format(e)) + return TensorProto.DataType.Name(elem_type).lower() + + +def get_info(info_proto): + """Extract the shape from a ValueInfoProto.""" + shape = [] + for dim in info_proto.type.tensor_type.shape.dim: + value = dim.dim_value + if value is None: + value = _ty.Any + shape.append(value) + + name = info_proto.name + dtype = get_type(info_proto.type.tensor_type.elem_type) + return name, shape, dtype + + def dimension_picker(prefix, suffix=""): """Check that dimensions are supported.""" @@ -1995,6 +2022,164 @@ def _impl_v11(cls, inputs, attr, params): return result +class Loop(OnnxOpConverter): + """Operator converter for Loop""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + max_loop_count = inputs[0] + cond = inputs[1] + loop_deps = inputs[2:] + num_deps = len(loop_deps) + body = attr["body"] + iter_dtype = infer_type(max_loop_count).checked_type.dtype + + # Determine what condition mode we're in. + assert cond is not None or max_loop_count is not None + is_for_loop = max_loop_count is not None and cond is None + is_condition_for_loop = cond is not None and max_loop_count is not None + + # Loop inputs will be packed as + # [iter_count, max_count, condition, loop_deps, scan_outputs] + def cond_fn(*loop_inputs): + i = loop_inputs[0] + max_count = loop_inputs[1] + w = loop_inputs[2] + + if cond is not None: + out_while = _op.equal(w, _expr.const(True, "bool")) + if max_loop_count is not None: + out_loop = _op.less(i, max_count) + + if is_condition_for_loop: + return _op.logical_and(out_while, out_loop) + if is_for_loop: + return out_loop + return out_while + + # Get the current graph proto and create a clone for the subgraph + graph_scope = GraphProto.current + subgraph_scope = GraphProto(graph_scope._shape, graph_scope._dtype) + # Load nodes from outer graph into inner graph. + subgraph_scope._nodes = graph_scope._nodes.copy() + + # Create a list of variables for each value updated in the loop. 
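+        # Shapes are inferred from the incoming values; any dimension reported as 0 is
+        # treated as dynamic (relay Any), and scan-output variables get an extra leading
+        # Any dimension because they grow by one slice per loop iteration.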
+ def get_var(name, val, scan=False): + checked_type = infer_type(val) + if hasattr(checked_type, "type_annotation"): + checked_type = checked_type.type_annotation + shape = get_const_tuple(checked_type.shape) + actual_shape = [] + for dim in shape: + if isinstance(dim, int) and dim == 0: + actual_shape.append(_ty.Any()) + else: + actual_shape.append(dim) + if scan: + return _expr.var(name, shape=[_ty.Any()] + actual_shape, dtype=checked_type.dtype) + + return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype) + + loop_vars = [ + _expr.var(body.input[0].name, shape=(), dtype=iter_dtype), # iteration count + _expr.var("max_count", shape=(), dtype=iter_dtype), # iteration count + get_var(body.input[1].name, cond), # exit condition + ] + loop_vars += [get_var(body.input[i + 2].name, v) for i, v in enumerate(loop_deps)] + loop_var_names = [v.name_hint for v in loop_vars] + + num_scan_outputs = len(body.output) - (1 + num_deps) + # TODO (jwfromm) Test with strided slice once type unifier for this case is fixed. + if num_scan_outputs != 0 and "Slice" in [n.op_type for n in body.node]: + warnings.warn( + """ + Using scan outputs in a loop with strided slice + currently may cause errors during compilation. + """ + ) + + # Construct variables and intial empty tensors for any scan outputs. + scan_output_vars = [] + scan_output_init = [] + for i in range(num_scan_outputs): + name, shape, dtype = get_info(body.output[i + 1 + num_deps]) + scan_output_vars.append(_expr.var(name, shape=([_ty.Any()] + shape), dtype=dtype)) + scan_output_init.append(_op.reshape(_expr.const([]), [0] + shape)) + + # Now we can remove loop iter variables from our inner loop's inputs. + # This is kind of a hack since we have graph inputs that we don't + # want to treat as actual inputs. + while len(body.input) != 0: + body.input.pop(0) + + # Define the loop body, in this function we need to unpack loop inputs, + # convert the loop subgraph, and pack outputs for the next iteration. + def body_fn(*loop_inputs): + # Unpack inputs + loop_count = loop_inputs[0] + max_count = loop_inputs[1] + cond = loop_inputs[2] + current_vars = list(loop_inputs[3 : (3 + num_deps)]) + scan_outputs = loop_inputs[(3 + num_deps) :] + + # Prepare body inputs by adding them to node dictionary. + new_inputs = [loop_count, max_count, cond] + current_vars + for i, inp in enumerate(new_inputs): + subgraph_scope._nodes[loop_var_names[i]] = inp + + # Get the output of the current loop using the updated inputs. + with subgraph_scope: + loop_outputs = subgraph_scope.from_onnx(body, 11, get_output_expr=True) + # Unpack the body outputs and prepare variables for next iteration. + new_cond = loop_outputs[0] + new_loop_vars = [loop_outputs[i] for i in range(1, 1 + num_deps)] + new_scan_outputs = [loop_outputs[i] for i in range(1 + num_deps, len(loop_outputs))] + + # Increment counter. + if max_loop_count is not None: + incr = _expr.const(1, dtype=iter_dtype) + loop_count = loop_count + incr + + # Add new scan outputs to tracking + combined_scan_outputs = [] + for i, scan in enumerate(scan_outputs): + new_scan = _op.expand_dims(new_scan_outputs[i], axis=0) + combined_scan = _op.concatenate([scan, new_scan], axis=0) + combined_scan_outputs.append(combined_scan) + + # Pack loop outputs for next iteration + # [iter_count, cond, loop_deps, loop_scans] + return [loop_count, max_count, new_cond] + new_loop_vars + combined_scan_outputs + + # Create the loop function. 
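+        # _loops.while_loop lowers the ONNX Loop into a recursive Relay function whose
+        # state is the packed list [iter_count, max_count, cond, loop_deps..., scan_outputs...];
+        # cond_fn decides whether another iteration runs and body_fn produces the next state.
+        # The converted deps and scan outputs are read back out of the result tuple below,
+        # starting at offset 3 (after the counter, max count, and condition).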
+ loop = _loops.while_loop(cond_fn, loop_vars + scan_output_vars, body_fn) + + # Now need to run initial values through the graph. + init_count = _expr.const(0, dtype=iter_dtype) + loop_vals = loop(init_count, max_loop_count, cond, *loop_deps, *scan_output_init) + + # Extract final iteration outputs. + if num_deps + num_scan_outputs == 1: + outputs = _expr.TupleGetItem(loop_vals, 3) + else: + outputs = _expr.TupleWrapper( + _expr.Tuple( + [ + _expr.TupleGetItem(loop_vals, i + 3) + for i in range(num_deps + num_scan_outputs) + ] + ), + num_deps + num_scan_outputs, + ) + + # Update outer graph with constants found in the subgraph. + free_vars = analysis.free_vars(loop) + graph_scope._params.update(subgraph_scope._params) + for var in free_vars: + graph_scope._nodes.update({var.name_hint: var}) + return outputs + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -2150,6 +2335,8 @@ def _get_convert_map(opset): "Resize": Resize.get_converter(opset), "NonZero": NonZero.get_converter(opset), "Range": Range.get_converter(opset), + # defs/control_flow + "Loop": Loop.get_converter(opset), } @@ -2166,6 +2353,8 @@ class GraphProto: The input types to the graph """ + current = None + def __init__(self, shape, dtype): self._nodes = {} self._params = {} @@ -2176,15 +2365,24 @@ def __init__(self, shape, dtype): self._shape = shape if shape else {} self._dtype = dtype + def __enter__(self): + self._old_manager = GraphProto.current + GraphProto.current = self + return self + + def __exit__(self, ptype, value, trace): + GraphProto.current = self._old_manager + def freeze(self, func, params): bind_map = {} for name in params.keys(): - bind_map[self._nodes[name]] = _expr.const(params[name]) + if name in self._nodes.keys(): + bind_map[self._nodes[name]] = _expr.const(params[name]) body = _expr.bind(func.body, bind_map) fn = _function.Function(analysis.free_vars(body), body) return fn, {} - def from_onnx(self, graph, opset, freeze_params=False): + def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False): """Construct Relay expression from ONNX graph. Onnx graph is a python protobuf object. @@ -2208,6 +2406,11 @@ def from_onnx(self, graph, opset, freeze_params=False): at compile time and helps in making models static if certain inputs represent attributes relay would traditionally consider compile-time constants. + get_output_expr: bool + If set to true, this conversion will return each output expression rather + than a packaged module. This can be useful when converting subgraphs to + relay. + Returns ------- mod : tvm.IRModule @@ -2309,6 +2512,9 @@ def from_onnx(self, graph, opset, freeze_params=False): # now return the outputs outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + # If requested, directly return the converted expressions. + if get_output_expr: + return outputs ## Maintain the order of inputs and parameters from the ONNX graph, but only include ## those parameters that are needed to execute the relay graph free_vars = analysis.free_vars(outputs) @@ -2317,6 +2523,7 @@ def from_onnx(self, graph, opset, freeze_params=False): for i_name in self._params: if i_name in free_vars and i_name not in self._inputs: self._inputs[i_name] = self._nodes[i_name] + # Create a function from our output expression and all input variables. 
func = _function.Function([v for k, v in self._inputs.items()], outputs) if freeze_params: func, params = self.freeze(func, self._params) @@ -2348,7 +2555,7 @@ def _parse_attr(self, attr_proto): """Convert a list of AttributeProto to a dict, with names as keys.""" attrs = {} for a in attr_proto: - for f in ["f", "i", "s"]: + for f in ["f", "i", "s", "g"]: if a.HasField(f): attrs[a.name] = getattr(a, f) for f in ["floats", "ints", "strings"]: @@ -2362,12 +2569,9 @@ def _parse_attr(self, attr_proto): if list(getattr(a, f)): assert a.name not in attrs, "Only one type of attr is allowed" attrs[a.name] = tuple(getattr(a, f)) - for f in ["g"]: - if a.HasField(f): - raise NotImplementedError("Filed {} is not supported in relay.".format(f)) for f in ["graphs"]: if list(getattr(a, f)): - raise NotImplementedError("Filed {} is not supported in relay.".format(f)) + raise NotImplementedError("Field {} is not supported in relay.".format(f)) if a.name not in attrs: raise ValueError("Cannot parse attribute: \n{}\n.".format(a)) return attrs @@ -2469,8 +2673,6 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals try: onnx.checker.check_model(model) except onnx.onnx_cpp2py_export.checker.ValidationError as e: - import warnings - # the checker is a bit violent about errors, so simply print warnings here warnings.warn(str(e)) except ImportError: @@ -2482,5 +2684,7 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals opset = model.opset_import[0].version if model.opset_import else 1 except AttributeError: opset = 1 - mod, params = g.from_onnx(graph, opset, freeze_params) + # Use the graph proto as a scope so that ops can access other nodes if needed. + with g: + mod, params = g.from_onnx(graph, opset, freeze_params) return mod, params diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 832372a6ed0d..453a9b7a7759 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -15,13 +15,15 @@ # specific language governing permissions and limitations # under the License. """Basic tensor operations.""" -# pylint: disable=redefined-builtin +# pylint: disable=redefined-builtin, unused-argument from tvm.runtime import ndarray as _nd from tvm.runtime import TVMContext as _TVMContext +from tvm.te.hybrid import script from . import _make from .dyn import _make as _dyn_make from ..expr import Tuple, Expr +from . import op as reg # We create a wrapper function for each operator in the @@ -1138,6 +1140,19 @@ def copy(data): return _make.copy(data) +@script +def _copy_shape_func(data_shape): + return data_shape + + +@reg.register_shape_func("copy", False) +def copy_shape_func(attrs, inputs, _): + """ + Shape function for copy op. + """ + return [_copy_shape_func(inputs[0])] + + def device_copy(data, src_dev, dst_dev): """Copy data from the source device to the destination device. 
This operator helps data transferring between difference contexts for diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index c7ceca3604c8..c3bf80571638 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -343,13 +343,6 @@ class VMFunctionCompiler : ExprFunctor { void VisitExpr_(const ConstantNode* const_node) { // Check the shape is valid NDArray data = const_node->data; - const DLTensor* tensor = data.operator->(); - if (tensor->ndim > 0) { - int64_t* shapes = reinterpret_cast(tensor->shape); - for (auto i = 0; i < tensor->ndim; i++) { - CHECK_GT(shapes[i], 0U); - } - } size_t konst_idx = context_->constants.size(); if (expr_device_map_.empty()) { context_->const_device_type.push_back(targets_.begin()->first); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 8d2cba05be49..1de690d91036 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -199,9 +199,6 @@ class ConstantFolder : public MixedModeMutator { Expr ObjectToExpr(const ObjectRef& value) { if (value->IsInstance()) { auto nd_array = Downcast(value); - for (auto dim : nd_array.Shape()) { - CHECK_GT(dim, 0) << "invalid dimension after constant eval"; - } return Constant(nd_array); } else if (const auto* val = value.as()) { runtime::ADT adt = GetRef(val); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 07e6dc465268..81b5186d0e26 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3660,6 +3660,157 @@ def verify_roi_align( verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0) +def verify_cond_loop(): + y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [1]) + y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [1]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [1]) + cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) + cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) + iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, []) + + y = np.array([-2]).astype(np.float32) + + five_const_node = helper.make_node( + "Constant", + inputs=[], + outputs=["five"], + value=helper.make_tensor( + name="const_tensor_five", data_type=TensorProto.FLOAT, dims=(), vals=[5] + ), + ) + + iter_cast_node = helper.make_node( + "Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT + ) + + y_add_node = helper.make_node("Add", inputs=["y_in", "iter_cast"], outputs=["y_out"]) + + less_node = helper.make_node("Less", inputs=["y_out", "five"], outputs=["cond_less"]) + + squeeze_node = helper.make_node("Squeeze", inputs=["cond_less"], outputs=["cond_squeeze"]) + + cond_cast_node = helper.make_node( + "Cast", inputs=["cond_squeeze"], outputs=["cond_out"], to=onnx.TensorProto.BOOL + ) + + scan_identity_node = helper.make_node("Identity", inputs=["y_out"], outputs=["scan_out"]) + + loop_body = helper.make_graph( + [ + five_const_node, + iter_cast_node, + y_add_node, + less_node, + squeeze_node, + cond_cast_node, + scan_identity_node, + ], + "loop_body", + [iter_count, cond_in, y_in], + [cond_out, y_out, scan_out], + ) + + loop_node = helper.make_node( + "Loop", inputs=["trip_count", "cond", "y"], outputs=["res_y", "res_scan"], body=loop_body + ) + + trip_count = np.array(5).astype(np.int64) + res_y = 
np.array([13]).astype(np.float32) + cond = np.array(1).astype(np.bool) + loop_graph = onnx.helper.make_graph( + [loop_node], + "loop_outer", + inputs=[ + onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []), + onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 1]), + ], + ) + loop_model = onnx.helper.make_model(loop_graph) + + # Set a high trip count so that condition trips first. + trip_count = np.array(40).astype(np.int64) + cond = np.array(1).astype(np.bool) + input_vals = [trip_count, cond, y] + onnx_out = get_onnxruntime_output(loop_model, input_vals) + + for target, ctx in [("llvm", tvm.cpu())]: + tvm_out = get_tvm_output_with_vm(loop_model, input_vals, target, ctx, freeze_params=True) + for i in range(len(tvm_out)): + tvm.testing.assert_allclose(onnx_out[i], tvm_out[i], rtol=1e-05, atol=1e-05) + + +def verify_count_loop(): + y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [1]) + y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [1]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [1]) + cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) + cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) + iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, []) + + y = np.array([-2]).astype(np.float32) + + iter_cast_node = helper.make_node( + "Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT + ) + + y_add_node = helper.make_node("Add", inputs=["y_in", "iter_cast"], outputs=["y_out"]) + + identity_node = helper.make_node("Identity", inputs=["cond_in"], outputs=["cond_out"]) + + scan_identity_node = helper.make_node("Identity", inputs=["y_out"], outputs=["scan_out"]) + + loop_body = helper.make_graph( + [identity_node, iter_cast_node, y_add_node, scan_identity_node], + "loop_body", + [iter_count, cond_in, y_in], + [cond_out, y_out, scan_out], + ) + + loop_node = helper.make_node( + "Loop", inputs=["trip_count", "cond", "y"], outputs=["res_y", "res_scan"], body=loop_body + ) + + trip_count = np.array(5).astype(np.int64) + res_y = np.array([13]).astype(np.float32) + cond = np.array(1).astype(np.bool) + loop_graph = onnx.helper.make_graph( + [loop_node], + "loop_outer", + inputs=[ + onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []), + onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 1]), + ], + ) + loop_model = onnx.helper.make_model(loop_graph) + + trip_count = np.array(5).astype(np.int64) + cond = np.array(1).astype(np.bool) + input_vals = [trip_count, cond, y] + onnx_out = get_onnxruntime_output(loop_model, input_vals) + + for target, ctx in [("llvm", tvm.cpu())]: + tvm_out = get_tvm_output_with_vm(loop_model, input_vals, target, ctx, freeze_params=True) + for i in range(len(tvm_out)): + tvm.testing.assert_allclose(onnx_out[i], tvm_out[i], rtol=1e-05, atol=1e-05) + + +def test_loop(): + # Test a loop that exits once a condition is met. 
+ verify_cond_loop() + # Test a loop that exits after a fixed number of iterations. + verify_count_loop() + + if __name__ == "__main__": test_flatten() test_reshape() @@ -3734,3 +3885,5 @@ def verify_roi_align( test_xor() test_max_roi_pool() test_roi_align() + test_range() + test_loop()
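
As a rough usage sketch (the variable names below mirror verify_count_loop() above and are assumptions for illustration, not part of the test suite), a converted Loop model can be imported and executed through the Relay VM roughly like this:

import tvm
from tvm import relay

# Assumes `loop_model`, `trip_count`, `cond`, and `y` were constructed as in
# verify_count_loop() above. freeze_params folds graph initializers into
# constants, which keeps the Loop subgraph conversion as static as possible.
mod, params = relay.frontend.from_onnx(loop_model, freeze_params=True)

# The Loop converter relies on dynamic (Any) shapes, so run with the VM executor.
ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(trip_count, cond, y)
res_y, res_scan = result[0].asnumpy(), result[1].asnumpy()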