From 2658ebe737d38b441dee6121c01ba3f9f83ce518 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 2 Oct 2020 20:32:39 -0600 Subject: [PATCH] Dynamic ONNX Importer (#6351) * Change onnx importer to use dynamic upsampling3d (#3) fix pylint * Refactor ONNX frontend to be dynamic Make OneHot dynamic Support BatchMatMul with dynamically shaped inputs fix dynamic broadcast Add null checks to broadcast_to rel functions fail more isolated broadcast_to test use StructuralEqual instead of pointer comparisions in dynamic_to_static pass add an optional weight freeze argument to onnx importer convert onnx resize to dynamic op add dynamic expand to onnx importer add a shape_func for power fix BERTSquad, lint handle onnx graph initializer parameters more intelligently * Dynamic ONNX importer: Upsampling and Pad (#2) fix lint fix Call reference fix a type issue with expand fix a bad test refactor respond to review comments, fix batch matmul tests * black format * fix batch matmul test * add dynamic strided slice to the onnx importer * fix clip importer * fix qnn tutorial * fix bad merge, respond to review comments * add a simple dynamic model test * Add dynamic-shaped autopadding to convolution and pooling ops * fix dynamic issues in a few ops * fix pylint * disable tests onnxrt doesn't support * fix pytorch test * respond to review comments * add documentation about partially supporting dynamic shapes Co-authored-by: Lily Orth-Smith --- include/tvm/relay/transform.h | 11 + include/tvm/topi/broadcast.h | 11 +- python/tvm/relay/frontend/onnx.py | 607 +++++++++--------- python/tvm/relay/op/_tensor.py | 1 + python/tvm/relay/op/nn/_nn.py | 51 +- python/tvm/relay/op/strategy/generic.py | 2 +- python/tvm/relay/op/strategy/x86.py | 21 +- python/tvm/topi/cuda/batch_matmul.py | 2 +- python/tvm/topi/nn/batch_matmul.py | 25 +- python/tvm/topi/x86/batch_matmul.py | 6 +- src/relay/backend/build_module.cc | 3 + src/relay/op/dyn/tensor/transform.cc | 18 +- src/relay/op/nn/convolution.h | 
77 ++- src/relay/op/nn/nn.cc | 53 +- src/relay/op/nn/nn.h | 8 +- src/relay/op/tensor/transform.cc | 10 +- src/relay/transforms/dynamic_to_static.cc | 9 +- tests/python/frontend/onnx/test_forward.py | 413 +++++++++--- .../relay/dyn/test_dynamic_op_level10.py | 82 ++- tests/python/relay/test_op_level10.py | 27 + tutorials/frontend/from_onnx.py | 9 + 21 files changed, 957 insertions(+), 489 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index de2bcc4f4318..faa2698fdcbc 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -208,6 +208,17 @@ TVM_DLL Pass SimplifyInference(); */ TVM_DLL Pass FastMath(); +/*! + * \brief Find Dynamic ops and make them static + * + * Searches the graph for dynamic ops. If the dynamic inputs to those ops are constants, it replaces + * them with static ops and re-performs type inference and constant folding. The pass repeats + * itself until the graph stops changing or we run too many iterations. + * + * \return The pass. + */ +TVM_DLL Pass DynamicToStatic(); + /*! * \brief Infer the type of an expression. 
* diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index 8fabaaee14f9..d03ddc93b4c0 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -54,14 +54,19 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, << "\nvs\ninput: " << t; auto bh = detail::BroadcastShape(output_shape, t->shape); CHECK_EQ(output_shape.size(), bh.common_shape.size()); + Array oshape; for (size_t i = 0; i < output_shape.size(); ++i) { - CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i])); + if (output_shape[i].as() == nullptr) { + oshape.push_back(output_shape[i]); + } else { + CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i])); + oshape.push_back(bh.common_shape[i]); + } } auto l = [&](tvm::Array ovars) { return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars)); }; - return tvm::te::compute(tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), - l, name, tag); + return tvm::te::compute(oshape, l, name, tag); } #define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 841ff77b142d..59fdb32d1a16 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -28,31 +28,12 @@ from .. import op as _op from .. 
import vision as _vision -from ..function import Function -from ..expr import Call, Let -from ..expr import If, Tuple, TupleGetItem -from ..expr import RefCreate, RefRead, RefWrite -from ..expr_functor import ExprFunctor -from ..adt import Match, Clause -from ..op.tensor import minimum as _minimum, maximum as _maximum - from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels from .common import infer_type, get_name -from .common import infer_value as _infer_value -from .common import infer_value_simulated as _infer_value_simulated - -__all__ = ["from_onnx"] - -g = None -def infer_value(input_val, params, mod=None): - return g.infer_value(input_val, params, mod) - - -def infer_value_simulated(input_val, params): - return g.infer_value_simulated(input_val, params) +__all__ = ["from_onnx"] class onnx_input: @@ -256,21 +237,28 @@ class Pool(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - input_shape = infer_shape(inputs[0]) + data = inputs[0] + input_shape = infer_shape(data) + ndim = len(input_shape) if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - pad_tuple = [] - for axis in range(len(input_shape) - 2): - axis_shape = input_shape[2 + axis] - stride = attr["strides"][axis] - kernel = attr["kernel_shape"][axis] - pad = get_pad_pair(axis_shape, kernel, stride) - pad_tuple.append(pad) - pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) - attr["pads"] = pad_tuple + if cls.name == "avg_pool": + pad_tuple = [] + for axis in range(len(input_shape) - 2): + axis_shape = input_shape[2 + axis] + stride = attr["strides"][axis] + kernel = attr["kernel_shape"][axis] + pad = get_pad_pair(axis_shape, kernel, stride) + pad_tuple.append(pad) + pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) + attr["pads"] = pad_tuple + else: + # Warning: Pool does not yet support dynamic shapes, + # one 
will need to run dynamic_to_static on this model after import + data = autopad(data, attr["strides"], attr["kernel_shape"], [1] * ndim, ndim) elif attr["auto_pad"] == "VALID": - attr["pads"] = 0 + attr["pads"] = tuple([0 for i in range(ndim - 2)]) elif attr["auto_pad"] == "NOTSET": pass else: @@ -290,7 +278,7 @@ def _impl_v1(cls, inputs, attr, params): transforms={"kernel_shape": "pool_size", "pads": ("padding", 0)}, ignores=["dilations", "storage_order"], custom_check=dimension_constraint(), - )(inputs, attr, params) + )([data], attr, params) class Absolute(Unary): @@ -331,29 +319,68 @@ def _impl_v1(cls, inputs, attr, params): return AttrCvt(op_name="instance_norm")(inputs, attr, params) +def autopad(data, strides, kernel_shape, dilations, ndim, pad_type="constant", deconv=False): + """ + Perform autopadding with dynamic input shapes + """ + # get attributes as constants + strides = _op.const(np.array(strides), dtype="int64") + dilated_kernel_shape = _op.const( + np.array( + [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] + ), + dtype="int64", + ) + shape = _op.strided_slice(_op.shape_of(data, dtype="int64"), [2], [ndim]) + # get input shape + + # set up integer constants + zero = _op.const(0, dtype="int64") + one = _op.const(1, dtype="int64") + two = _op.const(2, dtype="int64") + + # Calculate total padding + mod = _op.mod(shape, strides) + + left = _op.maximum(dilated_kernel_shape - strides, zero) + right = _op.maximum(dilated_kernel_shape - mod, zero) + + total_pad = _op.where(_op.equal(mod, zero), left, right) + if deconv: + total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad + + # split total padding into before and after + pad_before = _op.floor_divide(total_pad, two) + pad_after = total_pad - pad_before + + # combine + pad = _op.concatenate( + [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 + ) + + # pad N and C with zeros + pad = 
_op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) + + return _op.nn.pad(data, pad, _op.const(0.0), pad_type) + + class Conv(OnnxOpConverter): """Operator converter for Conv.""" @classmethod def _impl_v1(cls, inputs, attr, params): # Use shape of input to determine convolution type. - input_shape = infer_shape(inputs[0]) + data = inputs[0] + input_shape = infer_shape(data) + ndim = len(input_shape) if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - pad_tuple = [] - for axis in range(len(input_shape) - 2): - axis_shape = input_shape[2 + axis] - stride = attr["strides"][axis] - kernel = attr["kernel_shape"][axis] - dilation = attr["dilations"][axis] - dilated_kernel = (kernel - 1) * dilation + 1 - pad = get_pad_pair(axis_shape, dilated_kernel, stride) - pad_tuple.append(pad) - pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) - attr["pads"] = pad_tuple + # Warning: Convolution does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import + data = autopad(data, attr["strides"], attr["kernel_shape"], attr["dilations"], ndim) elif attr["auto_pad"] == "VALID": - attr["pads"] = tuple([0 for i in range(len(input_shape) - 2)]) + attr["pads"] = tuple([0 for i in range(ndim - 2)]) elif attr["auto_pad"] == "NOTSET": pass else: @@ -381,7 +408,7 @@ def _impl_v1(cls, inputs, attr, params): "group": ("groups", 1), }, custom_check=dimension_constraint(), - )(inputs[:2], attr, params) + )([data, inputs[1]], attr, params) use_bias = len(inputs) == 3 if use_bias: @@ -400,21 +427,24 @@ def _impl_v1(cls, inputs, attr, params): groups = attr.pop("group") attr["groups"] = groups # infer pads for auto_pad + data = inputs[0] + input_shape = infer_shape(data) + ndim = len(input_shape) if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - 
input_shape = infer_shape(inputs[0]) - in_h, in_w = input_shape[2], input_shape[3] - stride_h, stride_w = attr["strides"] - kernel_h, kernel_w = attr["kernel_shape"] - dilation_h, dilation_w = attr["dilations"] - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h) - pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w) - attr["pads"] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1]) + # Warning: Convolution does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import + data = autopad( + data, + attr["strides"], + attr["kernel_shape"], + attr["dilations"], + ndim, + deconv=True, + ) elif attr["auto_pad"] == "VALID": - attr["pads"] = (0, 0) + attr["pads"] = tuple([0 for i in range(ndim - 2)]) elif attr["auto_pad"] == "NOTSET": pass else: @@ -426,12 +456,13 @@ def _impl_v1(cls, inputs, attr, params): op_name=dimension_picker("conv", "_transpose"), transforms={ "kernel_shape": "kernel_size", - "dilations": ("dilation", (0, 0)), - "pads": ("padding", (0, 0), revert_caffe2_pad), + "dilations": ("dilation", 1), + "pads": ("padding", 0), + "group": ("groups", 1), }, disables=["output_shape"], custom_check=dimension_constraint(), - )(inputs[:2], attr, params) + )([data, inputs[1]], attr, params) use_bias = len(inputs) == 3 if use_bias: out = _op.nn.bias_add(out, inputs[2]) @@ -492,25 +523,46 @@ class MatMul(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs)) # Need to check input shape as batch matmul must be supported. - a_shape = infer_shape(inputs[0]) + a_shape = _op.shape_of(inputs[0]) # When performing a batch matmul, we need to properly handle N-dim shapes. 
- if len(a_shape) > 2: - b_shape = infer_shape(inputs[1]) + if infer_shape(a_shape)[0] > 2: + b_shape = _op.shape_of(inputs[1]) + + def flatten_to_3d(x, x_shape): + ndims = infer_shape(x_shape)[0] + newshape = _op.concatenate( + [_expr.const([-1]), _op.strided_slice(x_shape, [ndims - 2], [ndims])], 0 + ) + out = _op.reshape(x, newshape) + return out + # Convert a and b into 3 dimensional tensors. - a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]]) - b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]]) + a = flatten_to_3d(inputs[0], a_shape) + b = flatten_to_3d(inputs[1], b_shape) # Broadcast b to match batch size of a - new_b_shape = list(infer_shape(b)) - new_a_shape = infer_shape(a) - if new_a_shape[0] > new_b_shape[0]: - new_b_shape[0] = new_a_shape[0] - b = _op.broadcast_to(b, new_b_shape) + new_b_shape = _op.concatenate( + [ + _op.strided_slice(_op.shape_of(a), [0], [1]), + _op.strided_slice(_op.shape_of(b), [1], [3]), + ], + 0, + ) + b = _op.broadcast_to(b, new_b_shape) # Transpose matrix dimensions of b. b = _op.transpose(b, [0, 2, 1]) # Perform a batch matmul. output = _op.nn.batch_matmul(a, b) # Reshape output to original dimensions. - return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]]) + final_shape = _op.concatenate( + [ + _op.strided_slice(a_shape, [0], [infer_shape(a_shape)[0] - 1]), + _op.strided_slice( + b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]] + ), + ], + 0, + ) + return _op.reshape(output, final_shape) # Otherwise a simple dense op will get the job done. 
input_1_t = _op.transpose(inputs[1], axes=(1, 0)) return _op.nn.dense(inputs[0], input_1_t) @@ -545,23 +597,18 @@ class LpPool(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - input_shape = infer_shape(inputs[0]) dtype = infer_type(inputs[0]).checked_type.dtype - + data = inputs[0] + input_shape = infer_shape(data) + ndim = len(input_shape) if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - pad_tuple = [] - for axis in range(len(input_shape) - 2): - axis_shape = input_shape[2 + axis] - stride = attr["strides"][axis] - kernel = attr["kernel_shape"][axis] - pad = get_pad_pair(axis_shape, kernel, stride) - pad_tuple.append(pad) - pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) - attr["pads"] = pad_tuple + # Warning: LpPool does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import + data = autopad(data, attr["strides"], attr["kernel_shape"], [1] * ndim, ndim) elif attr["auto_pad"] == "VALID": - attr["pads"] = 0 + attr["pads"] = tuple([0 for i in range(ndim - 2)]) elif attr["auto_pad"] == "NOTSET": pass else: @@ -578,7 +625,7 @@ def _impl_v1(cls, inputs, attr, params): p = _expr.const(attr["p"], dtype) reci_p = _expr.const(1.0 / attr["p"], dtype) - inputs[0] = _op.power(inputs[0], p) + data = _op.power(data, p) out = AttrCvt( op_name=dimension_picker("avg_pool"), @@ -586,7 +633,7 @@ def _impl_v1(cls, inputs, attr, params): extras={"count_include_pad": True}, ignores=["p"], custom_check=dimension_constraint(), - )(inputs, attr, params) + )([data], attr, params) kernels = attr["kernel_shape"] out = _op.abs(out) * _expr.const(np.prod(kernels).astype(dtype)) return _op.power(out, reci_p) @@ -651,27 +698,23 @@ def _impl_v2(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): - pad_width = [] - pads = infer_value_simulated(inputs[1], params).asnumpy() + pads = inputs[1] if 
len(inputs) == 3: - value = infer_value_simulated(inputs[2], params).asnumpy().item() + value = _op.take(inputs[2], _op.const(0)) else: value = 0 - attr["pad_value"] = value - dims = int(len(pads) / 2) - for i in range(dims): - pad_width.append((pads[i], pads[i + dims])) - attr["pad_width"] = pad_width + + pads_shape = infer_shape(pads) + dims = int(pads_shape[0] / 2) + pad_width_expr = _op.transpose(_op.reshape(pads, (2, dims))) pad_mode = attr.get("mode", b"constant").decode("utf-8") - if pad_mode in ["constant", "edge", "reflect"]: - attr["pad_mode"] = pad_mode - attr.pop("mode", None) - else: + + if not pad_mode in ["constant", "edge", "reflect"]: raise tvm.error.OpAttributeInvalid( "Value " + pad_mode + ' in attribute "mode" is invalid for operator Pad.' ) - return AttrCvt("pad")(inputs[:1], attr, params) + return _op.nn.pad(inputs[0], pad_width_expr, value, pad_mode=pad_mode) class ParametricSoftPlus(OnnxOpConverter): @@ -736,9 +779,7 @@ def _impl_v5(cls, inputs, attr, params): shape = tuple(params.pop(inputs[1].name_hint).asnumpy().astype("int32")) out = _op.reshape(inputs[0], shape) else: - data, shape = inputs - static_shape = infer_value_simulated(shape, params) - out = _op.reshape(data, newshape=tuple(static_shape.asnumpy().astype("int32"))) + out = _op.reshape(*inputs) return out @@ -883,17 +924,22 @@ class Upsample(OnnxOpConverter): @classmethod def _impl_v9(cls, inputs, attr, params): scales = attr.get("scales") + + input_shape = infer_shape(inputs[0]) + dims = len(input_shape) + if not scales: # Here we are going to higher OPSET version. 
- assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs)) + assert len(inputs) == 2, "Upsample op takes 2 inputs, {} given".format(len(inputs)) + if get_name(inputs[1]) in params: scales = params[inputs[1].name_hint].asnumpy() else: - scales = infer_value_simulated(inputs[1], params).asnumpy() - inputs = inputs[:1] - assert scales[0] == 1.0 and scales[1] == 1.0 - input_shape = infer_shape(inputs[0]) - dims = len(input_shape) + scales = inputs[1] + + if not isinstance(scales, _expr.Call): + assert scales[0] == 1.0 and scales[1] == 1.0 + mode = attr.get("mode") if mode == b"nearest": method = "nearest_neighbor" @@ -903,21 +949,47 @@ def _impl_v9(cls, inputs, attr, params): raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode) ) - attr = {"scale_h": scales[-2], "scale_w": scales[-1], "method": method} + + if method == "nearest_neighbor": + align_corners = False + else: + align_corners = True + # in 3d case, we use the purely static op if dims == 5: - assert len(scales) == 5 - attr["scale_d"] = scales[-3] - attr["layout"] = "NCDHW" - op_name = "upsampling3d" + if isinstance(scales, _expr.Call): + scale_h = _op.take(scales, _op.const(3)) + scale_w = _op.take(scales, _op.const(4)) + scale_d = _op.take(scales, _op.const(1)) + else: + assert len(scales) == 5 + scale_h = scales[-2] + scale_w = scales[-1] + scale_d = scales[-3] + + layout = "NCDHW" + out = _op.nn.upsampling3d( + inputs[0], scale_d, scale_h, scale_w, layout=layout, method=method + ) + # in 2d case, use dynamic op else: - assert len(scales) == 4 - attr["layout"] = "NCHW" - if method == "nearest_neighbor": - attr["align_corners"] = False + if isinstance(scales, _expr.Call): + scale_h = _op.take(scales, _op.const(3)) + scale_w = _op.take(scales, _op.const(4)) else: - attr["align_corners"] = True - op_name = "upsampling" - return AttrCvt(op_name)(inputs, attr) + assert len(scales) == 4 + scale_h = scales[-2] + scale_w = 
scales[-1] + layout = "NCHW" + + out = _op.nn.upsampling( + inputs[0], + scale_h, + scale_w, + layout=layout, + method=method, + align_corners=align_corners, + ) + return out class Shape(OnnxOpConverter): @@ -970,8 +1042,7 @@ def _impl_v1(cls, inputs, attr, params): attr["indices_or_sections"].append(index) # When splits isnt specified divide evenly over axis. else: - in_shape = infer_shape(inputs[0]) - attr["indices_or_sections"] = in_shape[attr["axis"]] + attr["indices_or_sections"] = attr["tvm_custom"]["num_outputs"] return AttrCvt("split", ignores=["split"])(inputs, attr, params) @@ -1022,38 +1093,35 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v10(cls, inputs, attr, params): - attrs = {"starts": inputs[1], "ends": inputs[2]} - if len(inputs) >= 4: - attrs["axes"] = inputs[3] - if len(inputs) >= 5: - attrs["steps"] = inputs[4] - - attrs = {k: (v, get_name(v)) for (k, v) in attrs.items()} - attrs = { - k: params[v[1]].asnumpy() - if v[1] in params - else infer_value_simulated(v[0], params).asnumpy() - for (k, v) in attrs.items() - } + starts = inputs[1] + ends = inputs[2] + axes = inputs[3] + steps = inputs[4] - # Update the starts and ends according to axes if required. - if "axes" in attrs and max(attrs["axes"] + 1) != len(attrs["axes"]): - new_starts, new_ends, _ = cls._common(attrs["starts"], attrs["ends"], attrs["axes"]) - attrs["starts"] = new_starts - attrs["ends"] = new_ends + data_rank = len(infer_shape(inputs[0])) - begins = list(attrs["starts"]) - ends = list(attrs["ends"]) - strides = [1] * len(begins) + # Update the starts and ends according to axes if required. 
+ if axes is not None: + data_shape = _op.shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype) + starts = _op.scatter( + _op.const([0] * data_rank, dtype=infer_type(starts).checked_type.dtype), + axes, + starts, + axis=0, + ) + ends = _op.scatter(data_shape, axes, ends, axis=0) + if steps is not None: + steps = _op.scatter( + _op.const([1] * data_rank, dtype=infer_type(steps).checked_type.dtype), + axes, + steps, + axis=0, + ) - if "steps" in attrs: - steps = list(attrs["steps"]) - axes = attrs["axes"] - assert len(steps) == len(axes) - for axis, step in zip(axes, steps): - strides[axis] = step + if steps is None: + steps = _op.const([1] * data_rank, dtype=infer_type(starts).checked_type.dtype) - return _op.strided_slice(inputs[0], begin=begins, end=ends, strides=strides) + return _op.strided_slice(inputs[0], starts, ends, steps) class Gather(OnnxOpConverter): @@ -1337,8 +1405,6 @@ def _impl_v9(cls, inputs, attr, params): off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1)) # Extract the datatype of the output from on_value. dtype = infer_type(on_value).checked_type.dtype - # Convert depth into an integer. 
- depth = int(infer_value(depth, params).asnumpy()[0]) # set default value when axis is not set in the model if "axis" not in attr: attr["axis"] = -1 @@ -1357,8 +1423,7 @@ def _impl_v9(cls, inputs, attr, params): else: value = _expr.const(0) dtype = "float32" - static_shape = infer_value_simulated(inputs[0], params) - output = _op.full(value, shape=tuple(static_shape.asnumpy().astype("int32")), dtype=dtype) + output = _op.full(value, inputs[0], dtype=dtype) return output @@ -1406,8 +1471,7 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v6(cls, inputs, attr, params): - reps = tuple(infer_value_simulated(inputs[1], params).asnumpy().astype("int32")) - return _op.tile(inputs[0], reps) + return _op.tile(inputs[0], inputs[1]) class Erf(OnnxOpConverter): @@ -1466,11 +1530,9 @@ class Expand(OnnxOpConverter): @classmethod def _impl_v8(cls, inputs, attr, params): - in_shape = np.array(infer_shape(inputs[0])).astype("int32") - if get_name(inputs[1]) in params: - shape = params[inputs[1].name_hint].asnumpy().astype("int32") - else: - shape = infer_value_simulated(inputs[1], params).asnumpy().astype("int32") + dtype = infer_type(inputs[1]).checked_type.dtype + in_shape = _op.shape_of(inputs[0], dtype=dtype) + shape = inputs[1] # Currently 'op.broadcast_to' expect the rank of the given 'shape' # (the 2nd input) is always higher than that of the given 'input' (the 1st input) @@ -1485,28 +1547,41 @@ def expand_shape(in_shape, shape): intput. Also it replaces the extent of the shape with the corresponding extent of the intput when it is 1. """ - - # here we flip the shapes because this can be more simply written - # when the innermost dimension is located at the index 0. 
- in_shape = np.flip(in_shape, axis=0) - shape = np.flip(shape, axis=0) - - if in_shape.size < shape.size: - for i in range(shape.size): - if i < in_shape.size and in_shape[i] > shape[i]: - shape[i] = in_shape[i] - else: - for i in range(in_shape.size): - if i >= shape.size: - np.append(shape, in_shape[i]) - elif shape[i] == 1: - shape[i] = in_shape[i] - - new_shape = np.flip(shape, axis=0) + in_dims = infer_shape(in_shape)[0] + new_dims = infer_shape(shape)[0] + if in_dims < new_dims: + in_shape = _op.concatenate( + [ + _expr.const( + [ + 1, + ] + * (new_dims - in_dims), + dtype=dtype, + ), + in_shape, + ], + axis=0, + ) + elif new_dims > in_dims: + shape = _op.concatenate( + [ + _expr.const( + [ + 1, + ] + * (in_dims - new_dims), + dtype=dtype, + ), + shape, + ], + axis=0, + ) + new_shape = _op.maximum(in_shape, shape) return new_shape shape = expand_shape(in_shape, shape) - return _op.broadcast_to(inputs[0], shape=tuple(shape)) + return _op.broadcast_to(inputs[0], shape=shape) class RNN(OnnxOpConverter): @@ -1779,14 +1854,18 @@ def _impl_v11(cls, inputs, attr, params): 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) ) - in_size = np.array(infer_shape(inputs[0])) - scale = infer_value_simulated(inputs[2], params).asnumpy() + scale = inputs[2] + scale_shape = infer_shape(scale) if len(inputs) == 4: - assert len(scale) == 0, "One of scale or size should be passed, not both." - size = infer_value_simulated(inputs[3], params).asnumpy().astype(np.int32) + assert ( + len(scale_shape) == 0 or scale_shape[0] == 0 + ), "One of scale or size should be passed, not both." + size = inputs[3] else: - assert len(scale) != 0, "One of scale or size should be passed." - size = (in_size * scale).astype(np.int32) + assert len(scale_shape) != 0, "One of scale or size should be passed." 
+ size = ( + _op.cast(_op.shape_of(inputs[0]), infer_type(scale).type_annotation.dtype) * scale + ) coord_trans = attr.get("coordinate_transformation_mode") if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]: @@ -1800,7 +1879,7 @@ def _impl_v11(cls, inputs, attr, params): "Unsupported coordinate_transformation_mode: {}".format(coord_trans) ) layout = "NCHW" # ONNX assumes NCHW layout - out_size = (size[2], size[3]) + out_size = _op.strided_slice(size, [2], [4]) return _op.image.resize(inputs[0], out_size, layout, method, coord_trans) @@ -1831,9 +1910,7 @@ def _impl_v1(cls, inputs, attr, params): if largest == 0: raise ValueError("TVM only supports finding TopK largest elements") - K = int(infer_value(inputs[1], params).asnumpy()[0]) - - return _op.topk(inputs[0], k=K, axis=axis) + return _op.topk(inputs[0], inputs[1], axis=axis) class MaxRoiPool(OnnxOpConverter): @@ -1898,7 +1975,7 @@ def _impl_v11(cls, inputs, attr, params): assert len(inputs) <= 3, "Clip-11 takes up to 3 inputs, input, min, max" result = inputs[0] - for i, op in enumerate([_maximum, _minimum]): + for i, op in enumerate([_op.tensor.maximum, _op.tensor.minimum]): if i < len(inputs) - 1: result = op(result, inputs[i + 1]) return result @@ -2061,7 +2138,7 @@ def _get_convert_map(opset): } -class GraphProto(ExprFunctor): +class GraphProto: """A helper class for handling Relay expression copying from pb2.GraphProto. 
Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto @@ -2077,108 +2154,22 @@ class GraphProto(ExprFunctor): def __init__(self, shape, dtype): self._nodes = {} self._params = {} + self._inputs = {} self._renames = {} self._num_input = 0 self._num_param = 0 self._shape = shape if shape else {} self._dtype = dtype - # For infering Values - self._tmp_params = {} - self._infer_simulated = True - self._mod = None - super(GraphProto, self).__init__() - - def infer_value(self, input_val, params, mod=None): - self._tmp_params = params - self._infer_simulated = False - self._mod = mod - return self.visit(input_val).data - - def infer_value_simulated(self, input_val, params): - self._tmp_params = params - self._infer_simulated = True - return self.visit(input_val).data - - def infer(self, expr): - if self._infer_simulated: - out = _infer_value_simulated(expr, self._tmp_params) - else: - out = _infer_value(expr, self._tmp_params) - return _expr.const(out.asnumpy()) - - def visit_function(self, fn): - new_params = [self.visit(x) for x in fn.params] - new_body = self.visit(fn.body) - return self.infer( - Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs) - ) - - def visit_let(self, let): - newvar = self.visit(let.var) - newval = self.visit(let.value) - newbody = self.visit(let.body) - return self.infer(Let(newvar, newval, newbody)) - - def visit_call(self, call): - new_fn = self.visit(call.op) - new_args = [self.visit(arg) for arg in call.args] - call = Call(new_fn, new_args, call.attrs) - if new_fn == _op.get("nn.batch_norm"): - return call - return self.infer(call) - - def visit_var(self, var): - return self.infer(var) - - def visit_global_id(self, global_var): - return self.infer(global_var) - - def visit_if(self, ite): - return self.infer( - If(self.visit(ite.cond), self.visit(ite.true_branch), self.visit(ite.false_branch)) - ) - - def visit_tuple(self, tup): - return Tuple([self.visit(field) for field in tup.fields]) - - def 
visit_tuple_getitem(self, op): - tuple_value = self.visit(op.tuple_value) - if not tuple_value.same_as(op.tuple_value): - return self.infer(TupleGetItem(tuple_value, op.index)) - return self.infer(op) - - def visit_global_var(self, gvar): - return self.infer(gvar) - - def visit_op(self, op): - return op - - def visit_constant(self, const): - return const + def freeze(self, func, params): + bind_map = {} + for name in params.keys(): + bind_map[self._nodes[name]] = _expr.const(params[name]) + body = _expr.bind(func.body, bind_map) + fn = _function.Function(analysis.free_vars(body), body) + return fn, {} - def visit_constructor(self, con): - return con - - def visit_match(self, m): - return self.infer( - Match( - self.visit(m.data), - [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses], - complete=m.complete, - ) - ) - - def visit_ref_create(self, r): - return RefCreate(self.visit(r.value)) - - def visit_ref_write(self, r): - return RefWrite(self.visit(r.ref), self.visit(r.value)) - - def visit_ref_read(self, r): - return RefRead(self.visit(r.ref)) - - def from_onnx(self, graph, opset): + def from_onnx(self, graph, opset, freeze_params=False): """Construct Relay expression from ONNX graph. Onnx graph is a python protobuf object. @@ -2195,6 +2186,13 @@ def from_onnx(self, graph, opset): opset : opset version + freeze_params: bool + If this parameter is true, the importer will take any provided + onnx input values (weights, shapes, etc) and embed them into the relay model + as Constants instead of variables. This allows more aggressive optimizations + at compile time and helps in making models static if certain inputs represent + attributes relay would traditionally consider compile-time constants. 
+ Returns ------- mod : tvm.IRModule @@ -2236,6 +2234,7 @@ def from_onnx(self, graph, opset): else: dtype = d_type self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype) + self._inputs[i_name] = self._nodes[i_name] # get list of unsupported ops convert_map = _get_convert_map(opset) unsupported_ops = set() @@ -2271,11 +2270,12 @@ def from_onnx(self, graph, opset): ) else: i_name = self._parse_value_proto(node) + node_output = self._fix_outputs(op_name, node.output) attr["tvm_custom"] = {} attr["tvm_custom"]["name"] = i_name + attr["tvm_custom"]["num_outputs"] = len(node_output) op = self._convert_operator(op_name, inputs, attr, opset) - node_output = self._fix_outputs(op_name, node.output) if not isinstance(op, _expr.TupleWrapper): outputs_num = 1 else: @@ -2294,7 +2294,18 @@ def from_onnx(self, graph, opset): # now return the outputs outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _function.Function(analysis.free_vars(outputs), outputs) + ## Maintain the order of inputs and parameters from the ONNX graph, but only include + ## those parameters that are needed to execute the relay graph + free_vars = analysis.free_vars(outputs) + nodes = {v: k for k, v in self._nodes.items()} + free_vars = [nodes[var] for var in free_vars] + for i_name in self._params: + if i_name in free_vars and i_name not in self._inputs: + self._inputs[i_name] = self._nodes[i_name] + func = _function.Function([v for k, v in self._inputs.items()], outputs) + if freeze_params: + func, params = self.freeze(func, self._params) + return IRModule.from_expr(func), params return IRModule.from_expr(func), self._params def _parse_value_proto(self, value_proto): @@ -2388,7 +2399,7 @@ def _fix_outputs(self, op_name, outputs): return outputs -def from_onnx(model, shape=None, dtype="float32", opset=None): +def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=False): 
"""Convert a ONNX model into an equivalent Relay Function. ONNX graphs are represented as Python Protobuf objects. @@ -2398,6 +2409,13 @@ def from_onnx(model, shape=None, dtype="float32", opset=None): For convenience, we rename the `real` input names to "input_0", "input_1"... And renaming parameters to "param_0", "param_1"... + By default, ONNX defines models in terms of dynamic shapes. The ONNX importer + retains that dynamism upon import, and the compiler attempts to convert the + model into a static shapes at compile time. If this fails, there may still + be dynamic operations in the model. Not all TVM kernels currently support + dynamic shapes, please file an issue on discuss.tvm.ai + if you hit an error with dynamic kernels. + Parameters ---------- model : protobuf object @@ -2413,6 +2431,13 @@ def from_onnx(model, shape=None, dtype="float32", opset=None): Override to autodetected opset. This can be helpful for some testing. + freeze_params: bool + If this parameter is true, the importer will take any provided + onnx input values (weights, shapes, etc) and embed them into the relay model + as Constants instead of variables. This allows more aggressive optimizations + at compile time and helps in making models static if certain inputs represent + attributes relay would traditionally consider compile-time constants. 
+ Returns ------- mod : tvm.IRModule @@ -2435,7 +2460,6 @@ def from_onnx(model, shape=None, dtype="float32", opset=None): warnings.warn(str(e)) except ImportError: pass - global g g = GraphProto(shape, dtype) graph = model.graph if opset is None: @@ -2443,6 +2467,5 @@ def from_onnx(model, shape=None, dtype="float32", opset=None): opset = model.opset_import[0].version if model.opset_import else 1 except AttributeError: opset = 1 - mod, params = g.from_onnx(graph, opset) - g = None + mod, params = g.from_onnx(graph, opset, freeze_params) return mod, params diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 907c512c4a7c..e9e608b578e1 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -241,6 +241,7 @@ def elemwise_shape_func(attrs, inputs, _): register_shape_func("multiply", False, broadcast_shape_func) register_shape_func("divide", False, broadcast_shape_func) register_shape_func("floor_divide", False, broadcast_shape_func) +register_shape_func("power", False, broadcast_shape_func) register_shape_func("mod", False, broadcast_shape_func) register_shape_func("floor_mod", False, broadcast_shape_func) register_shape_func("logical_and", False, broadcast_shape_func) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 6694b5a5fd75..c83f6a943a31 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -722,29 +722,18 @@ def compute_space_to_depth(attrs, inputs, out_dtype): @script -def _conv2d_shape_func(dshape, kshape, strides, padding, dilation): +def _conv_shape_func(dshape, kshape, strides, padding, dilation): out = output_tensor((dshape.shape[0],), "int64") - height = dshape[2] - width = dshape[3] - kheight = kshape[2] - kwidth = kshape[3] - dilated_kh = (kheight - 1) * dilation[0] + 1 - dilated_kw = (kwidth - 1) * dilation[1] + 1 - - oc = kshape[0] - - out_height = (height + 2 * padding[0] - dilated_kh) // strides[0] + 1 - out_width = (width + 2 * 
padding[1] - dilated_kw) // strides[1] + 1 - out[0] = dshape[0] - out[1] = oc - out[2] = out_height - out[3] = out_width + out[1] = kshape[0] + + for i in const_range(dshape.shape[0] - 2): + dilated_k = (kshape[i + 2] - 1) * dilation[i] + 1 + out[i + 2] = (dshape[i + 2] + 2 * padding[i] - dilated_k) // strides[i] + 1 return out -@reg.register_shape_func("nn.conv2d", False) -def conv2d_shape_func(attrs, inputs, _): +def conv_shape_func(attrs, inputs, _): """ Shape function for contrib_conv2d_NCHWc op. """ @@ -753,7 +742,7 @@ def conv2d_shape_func(attrs, inputs, _): dilation = get_const_tuple(attrs.dilation) return [ - _conv2d_shape_func( + _conv_shape_func( inputs[0], inputs[1], convert(strides), @@ -763,6 +752,11 @@ def conv2d_shape_func(attrs, inputs, _): ] +reg.register_shape_func("nn.conv1d", False, conv_shape_func) +reg.register_shape_func("nn.conv2d", False, conv_shape_func) +reg.register_shape_func("nn.conv3d", False, conv_shape_func) + + @script def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn): out = output_tensor((dshape.shape[0],), "int64") @@ -968,6 +962,25 @@ def dense_shape_func(attrs, inputs, _): return ret +@script +def _batch_matmul_shape_func(data_shape, weight_shape): + out = output_tensor((data_shape.shape[0],), "int64") + for i in const_range(out.shape[0] - 1): + out[i] = data_shape[i] + out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2] + + return out + + +@reg.register_shape_func("nn.batch_matmul", False) +def batch_matmul_shape_func(attrs, inputs, _): + """ + Shape function for batch_matmul op. 
+ """ + ret = [_batch_matmul_shape_func(inputs[0], inputs[1])] + return ret + + @script def _pad_shape_func(data_shape, pad_width): out = output_tensor((data_shape.shape[0],), "int64") diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 68889f3638a9..56ae97652b79 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -683,7 +683,7 @@ def wrap_compute_batch_matmul(topi_compute): """wrap batch_matmul topi compute""" def _compute_batch_matmul(attrs, inputs, out_type): - return [topi_compute(inputs[0], inputs[1])] + return [topi_compute(inputs[0], inputs[1], out_type.shape)] return _compute_batch_matmul diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 8925723d9916..e2a82d396b22 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -21,6 +21,7 @@ import re from tvm import topi from tvm.te import SpecializedCondition +from tvm.relay.ty import is_dynamic from .generic import * from .. 
import op as _op @@ -355,12 +356,20 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): """batch_matmul x86 strategy""" strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_batch_matmul(topi.x86.batch_matmul), - wrap_topi_schedule(topi.x86.schedule_batch_matmul), - name="batch_matmul.x86", - plevel=10, - ) + if is_dynamic(out_type): + strategy.add_implementation( + wrap_compute_batch_matmul(topi.nn.batch_matmul), + wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul), + name="batch_matmul.generic", + plevel=10, + ) + else: + strategy.add_implementation( + wrap_compute_batch_matmul(topi.x86.batch_matmul), + wrap_topi_schedule(topi.x86.schedule_batch_matmul), + name="batch_matmul.x86", + plevel=10, + ) if "cblas" in target.libs: strategy.add_implementation( wrap_compute_batch_matmul(topi.x86.batch_matmul_cblas), diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 26647dd14e52..bb060b3ad8a7 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -26,7 +26,7 @@ @autotvm.register_topi_compute("batch_matmul.cuda") -def batch_matmul(cfg, x, y): +def batch_matmul(cfg, x, y, out_shape=None): """Compute conv2d with NCHW layout""" return nn.batch_matmul(x, y) diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index 7c8fead569ae..34a8c6dafc87 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -20,7 +20,7 @@ from ..util import get_const_tuple -def batch_matmul(x, y): +def batch_matmul(x, y, oshape=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. 
@@ -37,14 +37,19 @@ def batch_matmul(x, y): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" - x_shape = get_const_tuple(x.shape) - y_shape = get_const_tuple(y.shape) - assert x_shape[0] == y_shape[0], "batch dimension doesn't match" - assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant" - batch, M, K = x.shape - N = y.shape[1] - k = te.reduce_axis((0, K), name="k") + if oshape is None: + assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" + x_shape = get_const_tuple(x.shape) + y_shape = get_const_tuple(y.shape) + assert x_shape[0] == y_shape[0], "batch dimension doesn't match" + assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant" + batch, M, K = x.shape + N = y.shape[1] + k = te.reduce_axis((0, K), name="k") + oshape = (batch, M, N) + else: + _, _, K = x.shape + k = te.reduce_axis((0, K), name="k") return te.compute( - (batch, M, N), lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul" + oshape, lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul" ) diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 333d3bed278b..c095dcb0b6bb 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -25,7 +25,7 @@ @autotvm.register_topi_compute("batch_matmul.x86") -def batch_matmul(cfg, x, y): +def batch_matmul(cfg, x, y, out_shape=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. 
@@ -49,6 +49,10 @@ def batch_matmul(cfg, x, y): assert XK == YK, "shapes of x and y is inconsistant" B = XB K = XK + if out_shape is not None: + assert out_shape[0] == B, "got invalid output shape" + assert out_shape[1] == M, "got invalid output shape" + assert out_shape[2] == N, "got invalid output shape" if cfg.is_fallback: _default_batch_matmul_config(cfg, M, N, K) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 21fd5915a806..b95e0962bd27 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -263,6 +263,9 @@ class RelayBuildModule : public runtime::ModuleNode { pass_seqs.push_back(transform::Legalize()); } + // Convert Dynamic ops to static versions + pass_seqs.push_back(transform::DynamicToStatic()); + pass_seqs.push_back(transform::SimplifyInference()); PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { Expr expr = args[0]; diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index de1cc5a4ed95..4b594ffccfa5 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -58,6 +58,11 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, Array oshape; const auto* newshape = types[1].as(); + if (newshape == nullptr) { + CHECK(types[1].as()) + << "reshape: expect input type to be TensorType but get " << types[1]; + return false; + } // Doesn't support dynamic output rank for (int i = 0; i < newshape->shape[0].as()->value; i++) { @@ -209,10 +214,17 @@ bool BroadCastToRel(const Array& types, int num_inputs, const Attrs& attrs // types = [data_type, broadcast_shape_type, ret_type] CHECK_EQ(types.size(), 3); - const auto* target_shape = types[1].as(); - DataType out_dtype = types[0].as()->dtype; + const auto* input_type = types[0].as(); + const auto* target_type = types[1].as(); + if (target_type == nullptr) { + return false; + } + if (input_type == nullptr) { + return false; + } + auto 
out_dtype = input_type->dtype; // rank must be static - const IntImmNode* rank = target_shape->shape[0].as(); + const IntImmNode* rank = target_type->shape[0].as(); CHECK(rank) << "Target shape must have static rank"; // rank must be static even in dyn pass // could add support for dyn rank in futures diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index f53f4e0454a4..2311585deb60 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -100,7 +100,9 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, << "Conv1D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << wshape; } - CHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1])); + if (!dshape_ncw[1].as() && !wshape[1].as()) { + CHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1])); + } channels = wshape[0]; dilated_ksize = 1 + (wshape[2] - 1) * param->dilation[0]; } @@ -211,7 +213,9 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, << "Conv2D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << wshape; } - CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1])); + if (!dshape_nchw[1].as() && !wshape[1].as()) { + CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1])); + } channels = wshape[0]; dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; @@ -322,7 +326,9 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, << "Conv3D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << wshape; } - CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1])); + if (!dshape_ncdhw[1].as() && !wshape[1].as()) { + CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1])); + } channels = wshape[0]; dilated_ksize_z = 1 + 
(wshape[2] - 1) * param->dilation[0]; dilated_ksize_y = 1 + (wshape[3] - 1) * param->dilation[1]; @@ -800,7 +806,9 @@ bool Conv1DTransposeRel(const Array& types, int num_inputs, const Attrs& a << "Conv1D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << Array(wshape); } - CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0])); + if (!dshape_ncw[1].as() && !wshape[0].as()) { + CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0])); + } channels = wshape[1]; dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0]; } @@ -808,8 +816,12 @@ bool Conv1DTransposeRel(const Array& types, int num_inputs, const Attrs& a IndexExpr pad_w; GetPaddingWidth(param->padding, &pad_w); Array oshape({dshape_ncw[0], channels, 0}); - oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w + - param->output_padding[0])); + if (!dshape_ncw[2].as()) { + oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w + + param->output_padding[0])); + } else { + oshape.Set(2, dshape_ncw[2]); + } DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -890,7 +902,9 @@ bool Conv3DTransposeRel(const Array& types, int num_inputs, const Attrs& a << "Conv3D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << Array(wshape); } - CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0])); + if (!dshape_ncdhw[1].as() && !wshape[0].as()) { + CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0])); + } channels = wshape[1]; dilated_ksize_d = 1 + (wshape[2] - 1) * param->dilation[0]; dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; @@ -901,12 +915,25 @@ bool Conv3DTransposeRel(const Array& types, int num_inputs, const Attrs& a Array oshape({dshape_ncdhw[0], channels, 0, 0, 0}); IndexExpr pad_d, pad_h, pad_w; GetPaddingDepthHeightWidth(param->padding, 
&pad_d, &pad_h, &pad_w); - oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - pad_d + - param->output_padding[0])); - oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - pad_h + - param->output_padding[1])); - oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - pad_w + - param->output_padding[2])); + + if (!dshape_ncdhw[2].as()) { + oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - pad_d + + param->output_padding[0])); + } else { + oshape.Set(2, dshape_ncdhw[2]); + } + if (!dshape_ncdhw[3].as()) { + oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - pad_h + + param->output_padding[1])); + } else { + oshape.Set(3, dshape_ncdhw[3]); + } + if (!dshape_ncdhw[4].as()) { + oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - pad_w + + param->output_padding[2])); + } else { + oshape.Set(4, dshape_ncdhw[4]); + } DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -985,7 +1012,9 @@ bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& a << "Conv2D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << Array(wshape); } - CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0])); + if (!dshape_nchw[1].as() && !wshape[0].as()) { + CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0])); + } channels = wshape[1]; dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; @@ -994,10 +1023,18 @@ bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& a Array oshape({dshape_nchw[0], channels, 0, 0}); IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h + - param->output_padding[0])); - oshape.Set(3, (param->strides[1] * 
(dshape_nchw[3] - 1) + dilated_ksize_x - pad_w + - param->output_padding[1])); + if (!dshape_nchw[2].as()) { + oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h + + param->output_padding[0])); + } else { + oshape.Set(2, dshape_nchw[2]); + } + if (!dshape_nchw[3].as()) { + oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w + + param->output_padding[1])); + } else { + oshape.Set(3, dshape_nchw[3]); + } DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -1053,7 +1090,9 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& << "DeformableConv2D: shape of weight is inconsistent with channels, " << " channels=" << param->channels << " wshape=" << wshape; } - CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1])); + if (!data->shape[1].as() && !wshape[1].as()) { + CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1])); + } channels = wshape[0]; ksize_y = wshape[2]; ksize_x = wshape[3]; diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 619b86d358d1..38ebe421d38d 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -851,15 +851,26 @@ bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs const auto* y = types[1].as(); if (x == nullptr || y == nullptr) return false; CHECK(x->shape.size() == 3 && y->shape.size() == 3); - CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) - << "BatchDot: batch dimension doesn't match, " - << " x shape=" << x->shape << ", y shape=" << y->shape; - CHECK(reporter->AssertEQ(x->shape[2], y->shape[2])) - << "BatchDot: shapes of x and y is inconsistent, " - << " x shape=" << x->shape << ", y shape=" << y->shape; - - Array oshape = x->shape; - oshape.Set(2, y->shape[1]); + bool is_dyn = false; + Array oshape; + for (size_t i = 0; i < 3; ++i) { + if (x->shape[i].as() != nullptr || y->shape[i].as() != nullptr) { + is_dyn = true; + oshape.push_back(Any()); + 
} else { + oshape.push_back(x->shape[i]); + } + } + if (!is_dyn) { + CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) + << "BatchDot: batch dimension doesn't match, " + << " x shape=" << x->shape << ", y shape=" << y->shape; + CHECK(reporter->AssertEQ(x->shape[2], y->shape[2])) + << "BatchDot: shapes of x and y is inconsistent, " + << " x shape=" << x->shape << ", y shape=" << y->shape; + + oshape.Set(2, y->shape[1]); + } // assign output type reporter->Assign(types[2], TensorType(oshape, x->dtype)); @@ -1021,9 +1032,15 @@ bool DepthToSpaceRel(const Array& types, int num_inputs, const Attrs& attr << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(1, indexdiv(oshape[1], (block_size * block_size))); - oshape.Set(2, oshape[2] * block_size); - oshape.Set(3, oshape[3] * block_size); + if (!oshape[1].as()) { + oshape.Set(1, indexdiv(oshape[1], (block_size * block_size))); + } + if (!oshape[2].as()) { + oshape.Set(2, oshape[2] * block_size); + } + if (!oshape[3].as()) { + oshape.Set(3, oshape[3] * block_size); + } // Assign output type reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); @@ -1078,9 +1095,15 @@ bool SpaceToDepthRel(const Array& types, int num_inputs, const Attrs& attr << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(1, oshape[1] * (block_size * block_size)); - oshape.Set(2, indexdiv(oshape[2], block_size)); - oshape.Set(3, indexdiv(oshape[3], block_size)); + if (!oshape[1].as()) { + oshape.Set(1, oshape[1] * (block_size * block_size)); + } + if (!oshape[2].as()) { + oshape.Set(2, indexdiv(oshape[2], block_size)); + } + if (!oshape[3].as()) { + oshape.Set(3, indexdiv(oshape[3], block_size)); + } // Assign output type reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 0fb02638db07..e7f5a4b9d618 100644 --- 
a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -63,9 +63,11 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, if (weight == nullptr) return false; Array wshape = weight->shape; CHECK(static_cast(weight->shape.size()) == 2); - CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1])) - << "DenseRel: input dimension doesn't match," - << " data shape=" << data->shape << ", weight shape=" << weight->shape; + if (!data->shape.back().as()) { + CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1])) + << "DenseRel: input dimension doesn't match," + << " data shape=" << data->shape << ", weight shape=" << weight->shape; + } oshape.Set((oshape.size() - 1), wshape[0]); } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 8d2d39184ce9..16495860aa96 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1822,9 +1822,9 @@ bool SqueezeRel(const Array& types, int num_inputs, const Attrs& attrs, if (p.second) { result_shape.push_back(p.first); } else { - const int64_t* axis_ptr = tir::as_const_int(p.first); - CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input tensor"; - CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1"; + if (const int64_t* axis_ptr = tir::as_const_int(p.first)) { + CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1"; + } } } } @@ -2028,7 +2028,9 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const StridedSliceAttrs* param = attrs.as(); - CHECK(param != nullptr); + if (param == nullptr) { + return false; + } const auto* data = types[0].as(); if (data == nullptr) { diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 113b599579ab..edcb83972cc7 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ 
b/src/relay/transforms/dynamic_to_static.cc @@ -227,6 +227,9 @@ Expr DynamicToStatic(Function f, IRModule m) { vars.Set(kv.second, kv.first); } const auto gv = vars[f]; + // Put a limit on the while loop + // Primarily used to prevent accidental infinite loops in development + const int loop_limit = 1000; int i = 0; do { pre = expr; @@ -236,13 +239,13 @@ Expr DynamicToStatic(Function f, IRModule m) { expr = mutator.Mutate(m->functions[gv]); m->Update(gv, Downcast(expr)); i += 1; - } while (pre != expr && i < 1000); + } while (!StructuralEqual()(pre, expr) && i < loop_limit); return expr; } namespace transform { -Pass ConvertDynamicToStatic() { +Pass DynamicToStatic() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast(DynamicToStatic(f, m)); @@ -251,7 +254,7 @@ Pass ConvertDynamicToStatic() { } TVM_REGISTER_GLOBAL("relay._transform.DynamicToStatic").set_body_typed([]() { - return ConvertDynamicToStatic(); + return DynamicToStatic(); }); } // namespace transform diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 1c0fced6c3ef..1aeb430de52f 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -42,12 +42,21 @@ def get_input_data_shape_dict(graph_def, input_data): return input_names, shape_dict -def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None): +def get_tvm_output_with_vm( + graph_def, input_data, target, ctx, opset=None, freeze_params=False, convert_to_static=False +): """ Generic function to execute and get tvm output with vm executor""" - + if not isinstance(input_data, list): + input_data = [input_data] _, shape_dict = get_input_data_shape_dict(graph_def, input_data) - mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset) + mod, params = relay.frontend.from_onnx( + graph_def, shape_dict, opset=opset, freeze_params=freeze_params + ) + if convert_to_static: + from 
tvm.relay import transform + + mod = transform.DynamicToStatic()(mod) ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) result = ex.evaluate()(*input_data) @@ -118,6 +127,8 @@ def verify_with_ort_with_inputs( targets=None, use_vm=False, opset=None, + freeze_params=False, + convert_to_static=False, dtype="float32", rtol=1e-5, atol=1e-5, @@ -136,9 +147,16 @@ def flatten(out): for target in targets: ctx = tvm.context(target, 0) - if use_vm: - tvm_out = get_tvm_output_with_vm(model, inputs, target, ctx, opset=opset) + tvm_out = get_tvm_output_with_vm( + model, + inputs, + target, + ctx, + opset=opset, + freeze_params=freeze_params, + convert_to_static=convert_to_static, + ) else: tvm_out = get_tvm_output(model, inputs, target, ctx, out_shape, dtype, opset=opset) @@ -152,6 +170,8 @@ def verify_with_ort( targets=None, use_vm=False, opset=None, + freeze_params=False, + convert_to_static=False, dtype="float32", rtol=1e-5, atol=1e-5, @@ -164,6 +184,8 @@ def verify_with_ort( targets=targets, use_vm=use_vm, opset=opset, + freeze_params=freeze_params, + convert_to_static=convert_to_static, dtype=dtype, rtol=rtol, atol=atol, @@ -213,21 +235,37 @@ def test_reshape(): tvm.testing.assert_allclose(ref_shape, tvm_out.shape) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_expand(): - def _test_expand(name, data, shape, ref_data): + def _test_expand(name, data, shape, ref_data, dtype="int32"): shape_array = np.array(shape) - shape_node = onnx.helper.make_node( - "Constant", - inputs=[], - outputs=["shape"], - value=onnx.helper.make_tensor( - name="const_tensor", - data_type=onnx.TensorProto.INT32, - dims=shape_array.shape, - vals=shape_array.flatten().astype("int32"), - ), - ) + if dtype == "int32": + shape_node = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["shape"], + value=onnx.helper.make_tensor( + name="const_tensor", + data_type=onnx.TensorProto.INT32, + 
dims=shape_array.shape, + vals=shape_array.flatten().astype("int32"), + ), + ) + elif dtype == "int64": + shape_node = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["shape"], + value=onnx.helper.make_tensor( + name="const_tensor", + data_type=onnx.TensorProto.INT64, + dims=shape_array.shape, + vals=shape_array.flatten().astype("int64"), + ), + ) + else: + raise "Invalid dtype" expand_node = helper.make_node("Expand", ["in", "shape"], ["out"]) graph = helper.make_graph( @@ -240,20 +278,22 @@ def _test_expand(name, data, shape, ref_data): model = helper.make_model(graph, producer_name=name) for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, data, target, ctx, ref_data.shape, "float32") + tvm_out = get_tvm_output_with_vm(model, data, target, ctx, freeze_params=True) tvm.testing.assert_allclose(ref_data, tvm_out) in_shape = (3, 1) shape = (3, 4) data = np.random.uniform(size=in_shape).astype(np.float32) ref_data = np.tile(data, 4) - _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data) + _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data, "int32") + _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data, "int64") in_shape = (3, 1) shape = (2, 1, 6) data = np.random.uniform(size=in_shape).astype(np.float32) ref_data = data * np.ones(shape, dtype=np.float32) - _test_expand("expand_with_dim_changed_test", data, shape, ref_data) + _test_expand("expand_with_dim_changed_test", data, shape, ref_data, "int32") + _test_expand("expand_with_dim_changed_test", data, shape, ref_data, "int64") def verify_depth_to_space(inshape, outshape, mode, blockSize): @@ -650,11 +690,12 @@ def add_noop_to_input_attr(attr_name, attr): model = helper.make_model(graph, producer_name="slice_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=10) + tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, 
opset=10, freeze_params=True) tvm.testing.assert_allclose(outdata, tvm_out) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_slice(): x = np.random.randn(20, 10, 5).astype(np.float32) _test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1)) @@ -856,12 +897,13 @@ def test_gather_nd(): verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], "float32") -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_onehot(): indices_shape = [10] indices_array = np.random.randint(low=0, high=9, size=indices_shape, dtype="int32") depth = 10 - values = np.asarray([0, 1]) + values = np.asarray([0, 1]).astype("int32") out_np = np.eye(depth)[indices_array.reshape(-1)] onehot_node = helper.make_node("OneHot", ["indices", "depth", "values"], ["out"]) @@ -874,17 +916,15 @@ def test_onehot(): helper.make_tensor_value_info("depth", TensorProto.INT32, [1]), helper.make_tensor_value_info("values", TensorProto.INT32, values.shape), ], - initializer=[ - helper.make_tensor("depth", TensorProto.INT32, [1], [depth]), - helper.make_tensor("values", TensorProto.INT32, values.shape, values), - ], outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, out_np.shape)], ) model = helper.make_model(graph, producer_name="onehot_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [indices_array], target, ctx, out_np.shape) + tvm_out = get_tvm_output_with_vm( + model, [indices_array, np.array([depth]).astype("int32"), values], target, ctx + ) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) @@ -916,7 +956,7 @@ def test_matmul(): tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) -def verify_batch_matmul(a_shape, b_shape): +def verify_batch_matmul(a_shape, b_shape, target, ctx): a_array = np.random.uniform(size=a_shape).astype("float32") b_array = 
np.random.uniform(size=b_shape).astype("float32") out_np = np.matmul(a_array, b_array) @@ -935,16 +975,67 @@ def verify_batch_matmul(a_shape, b_shape): model = helper.make_model(graph, producer_name="matmul_test") - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, out_np.shape) + tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx) + tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + + +# TODO(mbrookhart): enable cuda once VM supports heterogenous execution +@tvm.testing.parametrize_targets("llvm") +def test_batch_matmul(target, ctx): + verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), target, ctx) + verify_batch_matmul((2, 4, 3), (3, 4), target, ctx) + verify_batch_matmul((2, 3, 4, 3), (3, 4), target, ctx) + + +def verify_simple_dynamic_model(a_shape, b_shape, target, ctx): + def verify_model(ex, a_shape, b_shape): + a_array = np.random.uniform(size=a_shape).astype("float32") + b_array = np.random.uniform(size=b_shape).astype("float32") + # matmul + out_np = np.matmul(a_array, b_array) + # relu + out_np[out_np < 0] = 0 + + tvm_out = ex.evaluate()(a_array, b_array).asnumpy() tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) + relu_node = helper.make_node("Relu", ["out"], ["relu"]) -@tvm.testing.uses_gpu -def test_batch_matmul(): - verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4)) - verify_batch_matmul((2, 4, 3), (3, 4)) - verify_batch_matmul((2, 3, 4, 3), (3, 4)) + a_array = np.random.uniform(size=a_shape).astype("float32") + b_array = np.random.uniform(size=b_shape).astype("float32") + # matmul + out_np = np.matmul(a_array, b_array) + + graph = helper.make_graph( + [mul_node, relu_node], + "matmul_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), + ], + 
outputs=[helper.make_tensor_value_info("relu", TensorProto.FLOAT, list(out_np.shape))], + ) + + model = helper.make_model(graph, producer_name="matmul_test") + + a_anys = [relay.Any()] * len(a_shape) + b_anys = [relay.Any()] * len(b_shape) + + mod, params = relay.frontend.from_onnx(model, {"a": a_anys, "b": b_anys}) + + ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) + verify_model(ex, a_shape, b_shape) + verify_model(ex, [a * 2 for a in a_shape], [b * 2 for b in b_shape]) + verify_model(ex, [a * 3 for a in a_shape], [b * 3 for b in b_shape]) + + +# TODO(mbrookhart): enable cuda once VM supports heterogenous execution +@tvm.testing.parametrize_targets("llvm") +def test_batch_matmul_dynamic_model(target, ctx): + verify_simple_dynamic_model((2, 3, 4, 3), (2, 3, 3, 4), target, ctx) + verify_simple_dynamic_model((2, 4, 3), (3, 4), target, ctx) + verify_simple_dynamic_model((2, 3, 4, 3), (3, 4), target, ctx) def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None): @@ -1149,8 +1240,9 @@ def _test_upsample_bilinear_opset9(): model = helper.make_model(graph, producer_name="upsample_bilinear_opset9_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, "float32") - tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) + tvm_out = get_tvm_output_with_vm( + model, [in_array], target, ctx, opset=9, freeze_params=True + ) def _test_upsample3d_trilinear(): @@ -1194,7 +1286,8 @@ def _test_upsample3d_trilinear(): tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_upsample(): _test_upsample_nearest() _test_upsample_bilinear() @@ -1475,18 +1568,19 @@ def verify_constantofshape(input_dim, value, dtype): "fill_test", inputs, outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(out.shape))], - 
initializer=[helper.make_tensor("input", TensorProto.INT32, (len(input_dim),), input_dim)], ) model = helper.make_model(graph, producer_name="fill_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [], target, ctx, out.shape) + input_np = np.array(input_dim).astype("float32") + tvm_out = get_tvm_output_with_vm(model, [input_np], target, ctx) tvm.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_constantofshape(): verify_constantofshape((2, 3, 4, 5), 10, "float32") verify_constantofshape((3, 3), 0, "int32") @@ -1550,7 +1644,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0): ], ) else: - inputs = [indata, pads, np.array([value])] + inputs = [indata, pads, np.array([value]).astype("float32")] outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) node = helper.make_node( "Pad", inputs=["input", "pads", "constant_value"], outputs=["output"], mode="constant" @@ -1561,7 +1655,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0): inputs=[ helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)), helper.make_tensor_value_info("pads", TensorProto.INT64, (len(pads),)), - helper.make_tensor_value_info("constant_value", TensorProto.INT64, (1,)), + helper.make_tensor_value_info("constant_value", TensorProto.FLOAT, (1,)), ], initializer=[ helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads), @@ -1574,11 +1668,12 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0): model = helper.make_model(graph, producer_name="pad_test") # tvm result for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, inputs, target, ctx, outdata.shape, "float32", opset=11) + tvm_out = get_tvm_output_with_vm(model, inputs, target, ctx, opset=11, freeze_params=False) tvm.testing.assert_allclose(outdata, 
tvm_out, rtol=1e-5, atol=1e-5) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_pad(): verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], "constant", 0.0) verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], "constant", 0.0) @@ -1660,20 +1755,28 @@ def test_all_reduce_funcs(): ) -def verify_split(indata, outdatas, split, axis=0): +def verify_split(indata, outdatas, split, axis=0, pass_split=True): indata = np.array(indata).astype(np.float32) outdatas = [np.array(o).astype(np.float32) for o in outdatas] if split: split_index = range(len(split)) else: split_index = range(len(outdatas)) - node = helper.make_node( - "Split", - inputs=["input"], - outputs=["output_{}".format(i) for i in range(len(split_index))], - axis=axis, - split=split, - ) + if pass_split: + node = helper.make_node( + "Split", + inputs=["input"], + outputs=["output_{}".format(i) for i in range(len(split_index))], + axis=axis, + split=split, + ) + else: + node = helper.make_node( + "Split", + inputs=["input"], + outputs=["output_{}".format(i) for i in range(len(split_index))], + axis=axis, + ) graph = helper.make_graph( [node], "split_test", @@ -1687,18 +1790,26 @@ def verify_split(indata, outdatas, split, axis=0): ) model = helper.make_model(graph, producer_name="split_test") + import onnxruntime.backend + + rep = onnxruntime.backend.prepare(model, "CPU") + onnx_out = rep.run(indata) + for target, ctx in tvm.testing.enabled_targets(): output_shape = [o.shape for o in outdatas] output_type = ["float32", "float32", "float32"] tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, output_type) - for o, t in zip(outdatas, tvm_out): - tvm.testing.assert_allclose(o, t) + for o, t in zip(onnx_out, tvm_out): + tvm.testing.assert_allclose(o, t) @tvm.testing.uses_gpu def test_split(): # 1D verify_split([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 2, 2], 0) + 
verify_split( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 2, 2], 0, False + ) verify_split([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]], [2, 1, 3], 0) # 2D verify_split( @@ -1708,7 +1819,7 @@ def test_split(): 1, ) # Split evenly (unstack) - verify_split([1, 2, 3], [[1], [2], [3]], False) + verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False) @tvm.testing.uses_gpu @@ -2098,19 +2209,17 @@ def verify_tile_v6(indata, repeats, outdata): helper.make_tensor_value_info("repeats", TensorProto.INT64, list(repeats.shape)), ], outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], - initializer=[ - helper.make_tensor("repeats", TensorProto.INT64, list(repeats.shape), repeats) - ], ) model = helper.make_model(graph, producer_name="tile_test") for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape, opset=6) + tvm_out = get_tvm_output_with_vm(model, [indata, repeats], target, ctx, opset=6) tvm.testing.assert_allclose(outdata, tvm_out) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_tile(): x = np.random.rand(2, 3, 4, 5).astype(np.float32) repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64) @@ -2283,9 +2392,11 @@ def verify_batch_norm(in_shape): verify_batch_norm([16, 16, 10, 10]) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_batch_norm_dynamic_subgraph(): def verify_batch_norm_dynamic_subgraph(in_shape, o_shape): + batchnorm = onnx.helper.make_node( "BatchNormalization", inputs=["x", "scale", "B", "mean", "var"], outputs=["Y"] ) @@ -2307,9 +2418,10 @@ def verify_batch_norm_dynamic_subgraph(in_shape, o_shape): ) model = helper.make_model(graph, producer_name="batchnorm_test") + # X, inp, scale, b, mean, var inshapes = [in_shape, o_shape, 
in_shape[1], in_shape[1], in_shape[1], in_shape[1]] - verify_with_ort(model, inshapes, in_shape, use_vm=False) + verify_with_ort(model, inshapes, in_shape, use_vm=True) verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160]) @@ -2373,7 +2485,7 @@ def verify_conv( model = helper.make_model(graph, producer_name="conv_test") - verify_with_ort(model, [x_shape, w_shape], y_shape) + verify_with_ort(model, [x_shape, w_shape], y_shape, use_vm=True, convert_to_static=True) @tvm.testing.uses_gpu @@ -2458,6 +2570,68 @@ def repeat(N, D): ) +def verify_convtranspose_with_padding( + x_shape, + w_shape, + y_shape, + padding, + kernel_shape, + strides, + dilations, + auto_pad="NOTSET", + unset_pad=False, +): + if unset_pad: + node = helper.make_node( + "ConvTranspose", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + group=1, + ) + elif padding is None: + node = helper.make_node( + "ConvTranspose", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + group=1, + auto_pad=auto_pad, + ) + else: + node = helper.make_node( + "ConvTranspose", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + group=1, + pads=padding, + ) + + graph = helper.make_graph( + [node], + "convtranspose_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))], + ) + + model = helper.make_model(graph, producer_name="conv_test") + + verify_with_ort(model, [x_shape, w_shape], y_shape, use_vm=True, convert_to_static=True) + + def verify_convtranspose(x_shape, w_shape, y_shape, p): node = onnx.helper.make_node( 
"ConvTranspose", @@ -2492,6 +2666,88 @@ def test_convtranspose(): # [1, 2, 1, 2] list for pads verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2]) + def repeat(N, D): + return tuple([N for _ in range(D)]) + + # TODO(mbrookhart): onnxruntime in CI only supports 2D, + # find something else to test 1D and 3D against + for D in [2]: + # Convolution with padding + verify_convtranspose_with_padding( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + # Convolution without padding + verify_convtranspose_with_padding( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(7, D), + 2 * repeat(0, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + # Convolution with autopadding + verify_convtranspose_with_padding( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) + # Convolution with valid autopadding + verify_convtranspose_with_padding( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(7, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="VALID", + ) + # Convolution with unset padding + verify_convtranspose_with_padding( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(7, D), + 2 * repeat(0, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + True, + ) + # Convolution with non uniform stride + verify_convtranspose_with_padding( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(9, D), + None, + repeat(3, D), + repeat(2, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) + # Convolution with dilation + # TODO(mbrookhart): Relay doesn't currently support convtranspose with dilation + # verify_convtranspose_with_padding( + # (1, 1) + repeat(5, D), + # (1, 1) + repeat(3, D), + # (1, 1) + repeat(5, D), + # 2 * repeat(2, D), + # repeat(3, D), + # repeat(1, D), 
+ # repeat(2, D), + # ) + @tvm.testing.uses_gpu def test_unsqueeze_constant(): @@ -2515,6 +2771,7 @@ def forward(self, input): def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_pad="NOTSET"): + print(x_shape, kernel_shape, strides, mode, pads, auto_pad) x_np = np.random.uniform(size=x_shape).astype("float32") if mode == "max": @@ -2546,7 +2803,7 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_p ) model = helper.make_model(graph, producer_name="pooling_test") - verify_with_ort(model, [x_shape], out_shape) + verify_with_ort(model, [x_shape], out_shape, use_vm=True, convert_to_static=True) @tvm.testing.uses_gpu @@ -2796,7 +3053,7 @@ def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad=" ) model = helper.make_model(graph, producer_name="lppool_test") - verify_with_ort(model, [x_shape], out_shape) + verify_with_ort(model, [x_shape], out_shape, use_vm=True, convert_to_static=True) @tvm.testing.uses_gpu @@ -3169,7 +3426,8 @@ def test_gru(): ) -@tvm.testing.uses_gpu +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu def test_resize(): def verify(ishape, oshape, scales, mode, coord_trans): nodes = [ @@ -3194,7 +3452,6 @@ def verify(ishape, oshape, scales, mode, coord_trans): if oshape == []: oshape = [round(dim * scale) for (dim, scale) in zip(ishape, scales)] - graph = helper.make_graph( nodes, "resize_test", @@ -3204,7 +3461,7 @@ def verify(ishape, oshape, scales, mode, coord_trans): model = helper.make_model(graph, producer_name="resize_test") - verify_with_ort(model, [ishape], oshape, use_vm=False, opset=11) + verify_with_ort(model, [ishape], oshape, use_vm=True, opset=11, freeze_params=True) # upsampling verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric") @@ -3273,7 +3530,6 @@ def verify_topk(input_dims, K, axis=-1): ], ), ], - initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])], outputs=[ 
helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims), helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims), @@ -3283,17 +3539,10 @@ def verify_topk(input_dims, K, axis=-1): model = helper.make_model(graph, producer_name="topk_test") indata = np.random.uniform(-10, 10, input_dims).astype(np.float32) - onnx_out = get_onnxruntime_output(model, [indata, k]) + onnx_out = get_onnxruntime_output(model, [indata, np.array([K])]) for target, ctx in [("llvm", tvm.cpu())]: - tvm_out = get_tvm_output( - model, - indata, - target, - ctx, - [output_dims, output_dims], - output_dtype=["float32", "int64"], - ) + tvm_out = get_tvm_output_with_vm(model, [indata, np.array(K)], target, ctx) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) for n in [12, 32]: diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py b/tests/python/relay/dyn/test_dynamic_op_level10.py index 622e29118f46..18e1dd5bb72e 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level10.py +++ b/tests/python/relay/dyn/test_dynamic_op_level10.py @@ -27,34 +27,62 @@ import random import tvm.testing -# TODO(mbrookhart): Enable when VM supports heterogenus execution +# TODO(mbrookhart): Enable when the VM supports heterogenus execution # @tvm.testing.uses_gpu -def test_dyn_broadcast_to(): - dtype = "uint8" - rank = 3 - shape_type = "int64" - dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type)) - x_shape = (1,) - x = relay.Var("x", relay.ty.TensorType(x_shape, dtype)) - z = relay.broadcast_to(x, dyn_shape) - zz = run_infer_type(z) - - assert zz.checked_type == relay.ty.TensorType((relay.Any(),) * rank, dtype) - - func = relay.Function([x, dyn_shape], z) - - x = np.random.uniform(size=x_shape).astype(dtype) - dyn_shape = (1,) * rank - ref_res = np.broadcast_to(x, dyn_shape) - for target, ctx in tvm.testing.enabled_targets(): - for kind in ["vm", "debug"]: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, 
mod=mod, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type)) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - - -# TODO(mbrookhart): Enable when VM supports heterogenus execution +def test_broadcast_to(): + def verify_more_dynamic_broadcast_to(x_shape, out_shape): + rank = len(out_shape) + dtype = "float32" + shape_type = "int64" + reshape_shape = relay.Var("shape", relay.ty.TensorType((len(x_shape),), shape_type)) + broadcast_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type)) + x = relay.Var("x", relay.ty.TensorType((np.prod(x_shape),), dtype)) + r = relay.reshape(x, reshape_shape) + z = relay.broadcast_to(r, broadcast_shape) + + func = relay.Function([x, reshape_shape, broadcast_shape], z) + + x = np.random.uniform(size=np.prod(x_shape)).astype(dtype) + ref_res = np.broadcast_to(np.reshape(x, x_shape), out_shape) + for target, ctx in tvm.testing.enabled_targets(): + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate(func)( + x, np.array(x_shape).astype(shape_type), np.array(out_shape).astype(shape_type) + ) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + + verify_more_dynamic_broadcast_to((4, 3), (3, 4, 3)) + + def verify_broadcast_to(x_shape, out_shape): + rank = len(out_shape) + dtype = "float32" + shape_type = "int64" + dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type)) + x = relay.Var("x", relay.ty.TensorType(x_shape, dtype)) + z = relay.broadcast_to(x, dyn_shape) + zz = run_infer_type(z) + + assert zz.checked_type == relay.ty.TensorType((relay.Any(),) * rank, dtype) + + func = relay.Function([x, dyn_shape], z) + + x = np.random.uniform(size=x_shape).astype(dtype) + ref_res = np.broadcast_to(x, out_shape) + for target, ctx in tvm.testing.enabled_targets(): + for kind in ["vm", "debug"]: + mod = 
tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x, np.array(out_shape).astype(shape_type)) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + + verify_broadcast_to((1,), (1, 1, 1)) + verify_broadcast_to((1, 1), (4, 1, 1)) + verify_broadcast_to((4, 1), (1, 4, 3)) + + +# TODO(mbrookhart): Enable when the VM supports heterogenus execution # @tvm.testing.uses_gpu def test_dyn_one_hot(): def _get_oshape(indices_shape, depth, axis): diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index edb7c460d5ba..bc565682d932 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -362,6 +362,33 @@ def test_batch_matmul(): verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20)) +def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): + x = relay.var("x", relay.TensorType(x_shape, dtype)) + y = relay.var("y", relay.TensorType((relay.Any(),) * len(y_shape), dtype)) + z = relay.nn.batch_matmul(x, y) + + func = relay.Function([x, y], z) + x_np = np.random.uniform(size=x_shape).astype(dtype) + y_np = np.random.uniform(size=y_shape).astype(dtype) + z_np = tvm.topi.testing.batch_matmul(x_np, y_np) + + for target, ctx in tvm.testing.enabled_targets(): + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + z = intrp.evaluate()(x_np, y_np) + tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5) + + +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu +def test_dynamic_batch_matmul(): + verify_dynamic_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16)) + verify_dynamic_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16)) + verify_dynamic_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20)) + verify_dynamic_batch_matmul((30, 16, 32), (30, 20, 32), (30, 
16, 20)) + + @tvm.testing.uses_gpu def test_shape_of(): shape = (10, 5, 12) diff --git a/tutorials/frontend/from_onnx.py b/tutorials/frontend/from_onnx.py index e68a398e44b0..22c839cede12 100644 --- a/tutorials/frontend/from_onnx.py +++ b/tutorials/frontend/from_onnx.py @@ -103,3 +103,12 @@ canvas[:, 672:, :] = np.asarray(result) plt.imshow(canvas.astype(np.uint8)) plt.show() + +###################################################################### +# Notes +# --------------------------------------------- +# By default, ONNX defines models in terms of dynamic shapes. The ONNX importer +# retains that dynamism upon import, and the compiler attempts to convert the model +# into static shapes at compile time. If this fails, there may still be dynamic +# operations in the model. Not all TVM kernels currently support dynamic shapes; +# please file an issue on discuss.tvm.ai if you hit an error with dynamic kernels.