diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 5f1ee2f31cc5..6bfdb492fed0 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -85,7 +85,7 @@ struct Conv1DAttrs : public tvm::AttrsNode { .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") - .set_default(NullValue >()); + .set_default(NullValue>()); TVM_ATTR_FIELD(data_layout) .set_default("NCW") .describe( @@ -148,7 +148,7 @@ struct Conv2DAttrs : public tvm::AttrsNode { .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") - .set_default(NullValue >()); + .set_default(NullValue>()); TVM_ATTR_FIELD(data_layout) .set_default("NCHW") .describe( @@ -242,7 +242,7 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode { .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") - .set_default(NullValue >()); + .set_default(NullValue>()); TVM_ATTR_FIELD(data_layout) .set_default("NCHW") .describe( @@ -331,7 +331,7 @@ struct Conv3DAttrs : public tvm::AttrsNode { .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") - .set_default(NullValue >()); + .set_default(NullValue>()); TVM_ATTR_FIELD(data_layout) .set_default("NCDHW") .describe( @@ -381,7 +381,7 @@ struct Conv3DTransposeAttrs : public tvm::AttrsNode { "i.e. the number of output channels in the convolution."); TVM_ATTR_FIELD(kernel_size) .describe("The dimensions of the convolution window.") - .set_default(NullValue >()); + .set_default(NullValue>()); TVM_ATTR_FIELD(strides) .set_default(Array({1, 1, 1})) .describe("The strides of the convolution."); @@ -480,7 +480,7 @@ struct Conv3DWinogradAttrs : public tvm::AttrsNode { .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") - .set_default(NullValue >()); + .set_default(NullValue>()); TVM_ATTR_FIELD(data_layout) .set_default("NCDHW") .describe( @@ -539,7 +539,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode { "i.e. the number of output channels in the convolution."); TVM_ATTR_FIELD(kernel_size) .describe("The dimensions of the convolution window.") - .set_default(NullValue >()); + .set_default(NullValue>()); TVM_ATTR_FIELD(strides) .set_default(Array({1, 1})) .describe("The strides of the convolution."); @@ -626,7 +626,7 @@ struct Conv1DTransposeAttrs : public tvm::AttrsNode { "i.e. the number of output channels in the convolution."); TVM_ATTR_FIELD(kernel_size) .describe("The dimensions of the convolution window.") - .set_default(NullValue >()); + .set_default(NullValue>()); TVM_ATTR_FIELD(strides) .set_default(Array({1})) .describe("The strides of the convolution."); @@ -1016,7 +1016,7 @@ struct UpSampling3DAttrs : public tvm::AttrsNode { /*! \brief Attributes used for the padding operator */ struct PadAttrs : public tvm::AttrsNode { double pad_value; - Array > pad_width; + Array> pad_width; std::string pad_mode; TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") { @@ -1037,7 +1037,7 @@ struct PadAttrs : public tvm::AttrsNode { /*! 
\brief Attributes used for the MirrorPadding operator */ struct MirrorPadAttrs : public tvm::AttrsNode { std::string mode; - Array > pad_width; + Array> pad_width; TVM_DECLARE_ATTRS(MirrorPadAttrs, "relay.attrs.MirrorPadAttrs") { TVM_ATTR_FIELD(mode) @@ -1242,7 +1242,7 @@ struct DeformableConv2DAttrs : public tvm::AttrsNode { .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") - .set_default(NullValue >()); + .set_default(NullValue>()); TVM_ATTR_FIELD(data_layout) .set_default("NCHW") .describe( diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 17eb0d0fcf3f..d257d3cbb863 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -124,6 +124,8 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& sl * "constant" pads with constant_value; * "edge" pads using the edge values of the input array; * "reflect" pads by reflecting values with respect to the edges. + * \param dyn_output_shape Output shape of the pad op, default nullptr. + * You only need to pass this in if the shape was evaluated dynamically. * \param name The name of the operation * \param tag The tag to mark the operation * @@ -151,30 +153,40 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& sl inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array& pad_before, tvm::Array pad_after = tvm::Array(), PrimExpr pad_value = PrimExpr(), std::string name = "T_pad", - std::string tag = kElementWise, std::string pad_mode = "constant") { + std::string tag = kElementWise, std::string pad_mode = "constant", + const Array* dyn_output_shape = nullptr) { if (pad_after.size() < pad_before.size()) { for (size_t i = pad_after.size(); i < pad_before.size(); ++i) { pad_after.push_back(pad_before[i]); } } + arith::Analyzer analyzer; CHECK_GE(pad_before.size(), 1); CHECK_EQ(pad_before.size(), pad_after.size()); - tvm::Array output_shape; tvm::Array pad_before_int32; tvm::Array pad_after_int32; + for (const auto& ele : pad_before) { pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); } for (const auto& ele : pad_after) { pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); } - for (size_t i = 0; i < t->shape.size(); ++i) { - if (i >= pad_before.size()) { - output_shape.push_back(t->shape[i]); - } else { - output_shape.push_back( - analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i])); + + tvm::Array output_shape; + if (dyn_output_shape == nullptr) { + for (size_t i = 0; i < t->shape.size(); ++i) { + if (i >= pad_before.size()) { + output_shape.push_back(t->shape[i]); + } else { + output_shape.push_back( + analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i])); + } + } + } else { + for (size_t i = 0; i < dyn_output_shape->size(); i++) { + output_shape.push_back((*dyn_output_shape)[i]); } } diff --git a/python/tvm/relay/op/dyn/nn/__init__.py b/python/tvm/relay/op/dyn/nn/__init__.py new file mode 100644 index 000000000000..01a3a1bc0679 --- /dev/null +++ b/python/tvm/relay/op/dyn/nn/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin, invalid-name +"""The Relay namespace containing dynamic ops.""" + +from . import _nn diff --git a/python/tvm/relay/op/dyn/nn/_make.py b/python/tvm/relay/op/dyn/nn/_make.py new file mode 100644 index 000000000000..280fe72315ad --- /dev/null +++ b/python/tvm/relay/op/dyn/nn/_make.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +import tvm._ffi + +tvm._ffi._init_api("relay.op.dyn.nn._make", __name__) diff --git a/python/tvm/relay/op/dyn/nn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py new file mode 100644 index 000000000000..141fc22a1e80 --- /dev/null +++ b/python/tvm/relay/op/dyn/nn/_nn.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments, consider-using-in +"""Backend compiler related feature registration""" + +from __future__ import absolute_import + +from tvm.te.hybrid import script +from ...op import register_shape_func +from ...op import register_broadcast_schedule + +# pad +register_broadcast_schedule("dyn.nn.pad") + +##################### +# Shape functions # +##################### + +@script +def _dyn_pad_shape_func(data, pad_width): + ndim = len(data.shape) + out = output_tensor((ndim,), "int64") + for i in const_range(ndim): + out[i] = int64(pad_width[i, 0] + pad_width[i, 1] + data.shape[i]) + return out + +@register_shape_func("dyn.nn.pad", True) +def pad_shape_func(attrs, inputs, data): + """ + Shape function for dynamic pad op. 
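+
+    Returns the output shape as a 1-D int64 tensor. inputs[0] is the data
+    tensor and inputs[1] is the (ndim, 2) pad_width tensor; for each axis i
+    the output extent is pad_width[i, 0] + pad_width[i, 1] + data.shape[i].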
+ """ + return [_dyn_pad_shape_func(inputs[0], inputs[1])] diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index b2df8505e691..c04db3060f97 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -19,7 +19,9 @@ from tvm.relay import expr from . import _make +from ..dyn.nn import _make as _dyn_make from .util import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d +from ...expr import const, Expr def conv1d(data, @@ -1410,7 +1412,7 @@ def prelu(data, alpha, axis=1): def pad(data, pad_width, - pad_value=0.0, + pad_value=0, pad_mode='constant'): r"""Padding @@ -1421,10 +1423,10 @@ def pad(data, ---------- data: tvm.relay.Expr The input data to the operator - pad_width: tuple of >, required + pad_width: tuple of >, or tvm.relay.Expr, required Number of values padded to the edges of each axis, in the format of ((before_1, after_1), ..., (before_N, after_N)) - pad_value: float, optional, default=0.0 + pad_value: float, or tvm.relay.Expr, optional, default=0 The value used for padding pad_mode: 'constant', 'edge', 'reflect' 'constant' pads with constant_value pad_value @@ -1435,6 +1437,12 @@ def pad(data, result : tvm.relay.Expr The computed result. """ + if (isinstance(pad_width, Expr) or (isinstance(pad_value, Expr))): + if not isinstance(pad_width, Expr): + pad_width = const(list(pad_width)) + if not isinstance(pad_value, Expr): + pad_value = const(pad_value) + return _dyn_make.pad(data, pad_width, pad_value, pad_mode) return _make.pad(data, pad_width, pad_value, pad_mode) diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index 78ed1dce3a44..88ade6e49294 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -73,7 +73,7 @@ def _math_intrin(func_id, args): from tvm.tir import op return getattr(op, func_id)(*args) -sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name +sqrt = log = exp = tanh = sigmoid = power = popcount = round = _math_intrin #pylint: disable=invalid-name def _min_max(func_id, args): diff --git a/python/tvm/te/hybrid/runtime.py b/python/tvm/te/hybrid/runtime.py index 7dcfc7c3966b..7987e46a4768 100644 --- a/python/tvm/te/hybrid/runtime.py +++ b/python/tvm/te/hybrid/runtime.py @@ -126,6 +126,7 @@ def max_num_threads(allow_none=True): 'exp' : numpy.exp, 'sigmoid' : sigmoid, 'popcount' : popcount, + 'round' : round, 'likely' : lambda cond: cond, 'uint8' : numpy.uint8, 'uint16' : numpy.uint16, diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py index 96a13efc541a..d8da41f79c0a 100644 --- a/python/tvm/topi/nn/upsampling.py +++ b/python/tvm/topi/nn/upsampling.py @@ -57,7 +57,6 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', elif layout == "NHWC": out_shape = (simplify(topi.cast(te.round(data.shape[1] * scale_h), data.shape[1].dtype)), simplify(topi.cast(te.round(data.shape[2] * scale_w), data.shape[2].dtype))) - else: raise ValueError("not support this layout {} yet".format(layout)) coord_trans = "align_corners" if align_corners else "asymmetric" diff --git a/src/relay/op/dyn/nn/pad.cc b/src/relay/op/dyn/nn/pad.cc new file mode 100644 index 000000000000..8a17f50df0df --- /dev/null +++ b/src/relay/op/dyn/nn/pad.cc @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file pad.cc + * \brief Implementation of dynamic pad + */ +#include +#include +#include +#include +#include + +#include + +#include "../../make_op.h" +#include "../../op_common.h" + +namespace tvm { +namespace relay { +namespace dyn { + +// relay.dyn.nn.pad + +bool PadRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types = [data_type, pad_width_type, pad_value_type, ret_type] + CHECK_EQ(types.size(), 4); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const auto* pad_width = types[1].as(); + if (pad_width == nullptr) return false; + + const auto* pad_value = types[2].as(); + if (pad_value == nullptr) return false; + + int data_rank = data->shape.size(); + CHECK(data_rank) << "Data shape must have static rank"; + + int pad_width_rank = pad_width->shape.size(); + CHECK_EQ(pad_width_rank, 2) << "Pad width must be 2D"; + + auto pad_width_dim1 = pad_width->shape[0].as(); + auto pad_width_dim2 = pad_width->shape[1].as(); + + CHECK(pad_width_dim1->value == data_rank && pad_width_dim2->value == 2) + << "Pad width must have shape (N, 2), where N is the rank of input data"; + + const PadAttrs* param = attrs.as(); + CHECK(param != nullptr); + + std::vector oshape; + for (int i = 0; i < data_rank; i++) { + oshape.push_back(Any()); + } + + reporter->Assign(types[3], TensorType(oshape, data->dtype)); + return true; +} + +Array PadCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + CHECK(param); + + auto data = inputs[0]; + auto pad_width = inputs[1]; + + const PrimExpr& pad_value = inputs[2](Array()); + + Array pad_before; + Array pad_after; + + for (int i = 0; i < pad_width->shape[0].as()->value; ++i) { + pad_before.push_back(pad_width[i][0]); + pad_after.push_back(pad_width[i][1]); + } + + const auto* out_ttype = out_type.as(); + CHECK(out_ttype != nullptr); + + return Array{topi::pad(inputs[0], pad_before, pad_after, pad_value, "T_pad", + topi::kElementWise, param->pad_mode, + &out_type.as()->shape)}; +} + +// Handler to create a call to the padding op used by front-end FFI +Expr MakePad(Expr data, Expr pad_width, Expr pad_value, String pad_mode) { + auto attrs = make_object(); + attrs->pad_mode = std::move(pad_mode); + static const Op& op = Op::Get("dyn.nn.pad"); + return Call(op, {data, pad_width, pad_value}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.pad").set_body_typed(MakePad); + +RELAY_REGISTER_OP("dyn.nn.pad") + .describe(R"code(Pad for n-D tensor. 
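+
+Unlike the static nn.pad, pad_width is passed as a second input tensor of
+shape (n, 2) and pad_value as a third (scalar) input, so the padding amounts
+may only be known at runtime. The output is therefore inferred with a dynamic
+(Any) extent in every padded dimension.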
+ +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Tensor that will be padded") + .add_argument("pad_width", "Tensor", "Tensor of how much to pad by") + .add_argument("pad_val", "double", "The value to fill the padded area with") + .set_support_level(2) + .add_type_rel("DynamicPad", PadRel) + .set_attr("TOpPattern", kInjective) + .set_attr("FTVMCompute", PadCompute); + +} // namespace dyn +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 1e17bbe90692..c759be338cd9 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -54,7 +54,7 @@ Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout); Expr MakeOnes(Array shape, DataType dtype); -Expr MakePad(Expr data, Array> pad_width, double pad_value, String pad_mode); +Expr MakePad(Expr data, Array> pad_width, double pad_value, String pad_mode); Expr MakeReduce(Expr data, Array axis, bool keepdims, bool exclude, String op_name); diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index d7103602deca..45447e155135 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -53,7 +53,7 @@ Array> PadInferCorrectLayout(const Attrs& attrs, const Array> axis_pad_width; + std::map> axis_pad_width; int index_counter = 0; CHECK_EQ(new_in_layouts.size(), 1); CHECK_EQ(old_in_layouts.size(), 1); @@ -64,7 +64,7 @@ Array> PadInferCorrectLayout(const Attrs& attrs, const Array> new_pad_width; + tvm::Array> new_pad_width; for (auto iter_var : new_in_layouts[0]->axes) { const auto& new_layout_axis = LayoutAxis::Get(iter_var); auto axis_name = new_layout_axis.name(); @@ -178,7 +178,7 @@ Array PadCompute(const Attrs& attrs, const Array& inputs } // Handler to create a call to the padding op used by front-end FFI -Expr MakePad(Expr data, Array> pad_width, double pad_value, String pad_mode) { +Expr MakePad(Expr data, Array> pad_width, double pad_value, String pad_mode) { auto attrs = make_object(); attrs->pad_value = pad_value; attrs->pad_width = std::move(pad_width); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 0ccc4c3d1269..3de773eeed9f 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -124,6 +124,21 @@ class DynamicToStaticMutator : public MixedModeMutator { } return Expr(nullptr); }}, + {Op::Get("dyn.nn.pad"), + [](const CallNode* call_node) { + const ConstantNode* pad_width = call_node->args[1].as(); + const ConstantNode* pad_fill = call_node->args[2].as(); + if (pad_width && pad_fill) { + CHECK_EQ(pad_fill->data->ndim, 0); // pad_val is 1d + CHECK_EQ(pad_width->data->ndim, 2); // pad_width is 2d + + const PadAttrs* param = call_node->attrs.as(); + CHECK(param); + return MakePad(call_node->args[0], ToMatrix(pad_width->data), ToScalar(pad_fill->data), + param->pad_mode); + } + return Expr(nullptr); + }}, }; } diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 0b6484642ef1..f493720aeda9 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -419,7 +419,7 @@ static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) */ static inline Array ToVector(const runtime::NDArray& array) { size_t ndim = array.Shape().size(); - CHECK_EQ(ndim, 1) << "This function should only used for shape tensor."; + CHECK_EQ(ndim, 1) << "This function should only be used for 1D NDArrays"; size_t len = array.Shape().front(); 
Array out; for (size_t i = 0; i < len; ++i) { @@ -429,6 +429,30 @@ static inline Array ToVector(const runtime::NDArray& array) { return out; } +/*! + * \brief Convert a NDArray with type int or float to Array>. + * \param array Input NDArray + * \return Converted Array. + */ +static inline Array> ToMatrix(const runtime::NDArray& array) { + size_t ndim = array.Shape().size(); + CHECK_EQ(ndim, 2) << "This function should only used for 2D NDArrays"; + size_t dim1 = array.Shape().at(0); + size_t dim2 = array.Shape().at(1); + + Array> out; + + for (size_t i = 0; i < dim1; ++i) { + Array inner_out; + for (size_t j = 0; j < dim2; ++j) { + double elem_val = ToScalar(array, i * dim2 + j); + inner_out.push_back(Integer(static_cast(elem_val))); + } + out.push_back(inner_out); + } + return out; +} + inline Expr GetField(Expr t, size_t i) { return TupleGetItem(t, i); } inline Expr Pair(Expr l, Expr r) { return Tuple({l, r}); } @@ -629,7 +653,11 @@ static inline Expr AvgPool2D(Expr data, Array pool_size, Array> pad_width, double pad_value, std::string pad_mode) { - return MakePad(data, pad_width, pad_value, pad_mode); + Array> pad_width_int; + for (size_t i = 0; i < pad_width.size(); ++i) { + pad_width_int.push_back(CheckConstantShapeArrayInteger(pad_width[i])); + } + return MakePad(data, pad_width_int, pad_value, pad_mode); } static inline Expr Tile(Expr data, Array reps) { return MakeTile(data, reps); } diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py new file mode 100644 index 000000000000..137febd19d1b --- /dev/null +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Support level2 dynamic operator test cases. 
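+
+These tests exercise relay.nn.pad when pad_width is supplied as a relay
+expression, which lowers to the dyn.nn.pad operator.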
+""" + +import numpy as np +import tvm +from tvm import relay +from tvm import te +from tvm.relay.testing import ctx_list +import random +from test_dynamic_op_level3 import verify_func +import tvm.topi.testing +from tvm.relay.testing import run_infer_type + +def test_dyn_pad(): + def verify_pad(dshape, pad_width, pad_val, dtype): + x = relay.var("x", relay.TensorType(dshape, dtype)) + ndim = len(dshape) + pad_width_var = relay.var("pad_width_var", relay.TensorType((ndim, 2), 'int64')) + pad_val_var = relay.var("pad_val_var", relay.TensorType((), dtype)) + y = relay.nn.pad(x, pad_width_var, pad_val_var) + yy = run_infer_type(y) + + assert yy.checked_type == relay.ty.TensorType((relay.Any(),) * ndim, dtype) + func = relay.Function([x, pad_width_var, pad_val_var], y) + data = np.random.uniform(size=dshape).astype(dtype) + ref_res = np.pad(data, pad_width, 'constant', constant_values=(((pad_val,)*2),) * ndim) + pad_width = np.array(pad_width).astype('int64') + + verify_func(func, [data, pad_width, np.array(pad_val).astype(dtype)], ref_res) + + def verify_pad_default_fill(dshape, pad_width, dtype): + x = relay.var("x", relay.TensorType(dshape, dtype)) + ndim = len(dshape) + pad_width_var = relay.var("pad_width_var", relay.TensorType((ndim, 2), 'int64')) + y = relay.nn.pad(x, pad_width_var) + yy = run_infer_type(y) + + assert yy.checked_type == relay.ty.TensorType((relay.Any(),) * ndim, dtype) + func = relay.Function([x, pad_width_var], y) + data = np.random.uniform(size=dshape).astype(dtype) + ref_res = np.pad(data, pad_width) + pad_width = np.array(pad_width).astype('int64') + + verify_func(func, [data, pad_width], ref_res) + + verify_pad((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), 2.0, "int32") + verify_pad((2, 7), ((1, 4), (2, 2)), 4.0, "float64") + verify_pad_default_fill((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), "float64") + verify_pad_default_fill((2, 7), ((1, 4), (2, 2)), "int32") + +if __name__ == "__main__": + test_dyn_pad() diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index c61f169d53e0..ed9b94c5a9d2 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -23,7 +23,6 @@ from tvm.relay.testing import run_infer_type, create_workload, ctx_list import tvm.topi.testing - def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, tvm.transform.Pass) @@ -312,7 +311,7 @@ def verify_full(fill_value, fill_shape, dtype): zz = func2.body assert isinstance(zz, relay.Call) - assert zz.checked_type == relay.TensorType(fill_shape, dtype) + assert zz.op == relay.op.get("full") ref_res = np.full(fill_shape, fill_value).astype(dtype) y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype('int64') @@ -321,6 +320,24 @@ def verify_full(fill_value, fill_shape, dtype): verify_full(4, (1, 2, 3, 4), 'int32') verify_full(4.0, (1, 2, 8, 10), 'float32') +def test_dynamic_to_static_pad(): + def verify_pad(data_shape, pad_width, pad_val, dtype): + x = relay.var("x", relay.TensorType(data_shape, dtype)) + z = relay.nn.pad(x, relay.const(np.array(pad_width)), pad_val) + func = run_infer_type(relay.Function([x], z)) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + zz = func2.body + assert isinstance(zz, relay.Call) + assert zz.op == relay.op.get("nn.pad") + + x_data = np.random.uniform(size=data_shape).astype(dtype) + ref_res = np.pad(x_data, pad_width, 'constant', constant_values=(((pad_val,)*2),) * 
len(data_shape)) + verify_func(func2, [x_data], ref_res) + + verify_pad((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), 2.0, "int32") + verify_pad((2, 7), ((1, 4), (2, 2)), 4.0, "float64") + + if __name__ == "__main__": test_dynamic_to_static_reshape() test_dynamic_to_static_double_reshape() @@ -332,3 +349,4 @@ def verify_full(fill_value, fill_shape, dtype): test_dynamic_to_static_resize() test_dynamic_to_static_one_hot() test_dynamic_to_static_full() + test_dynamic_to_static_pad()
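
Usage sketch (not part of the patch; a minimal example assuming a TVM build that includes these changes, modeled on the tests above):

import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import run_infer_type

# Dynamic case: pad_width is a relay var, so relay.nn.pad dispatches to
# dyn.nn.pad and type inference gives a dynamic (Any) extent in every dim.
x = relay.var("x", relay.TensorType((2, 7), "float32"))
pad_width = relay.var("pad_width", relay.TensorType((2, 2), "int64"))
y = run_infer_type(relay.nn.pad(x, pad_width, 0.0))
print(y.checked_type)  # both output dims are dynamic (Any)

# Static case: with a constant pad_width, DynamicToStatic folds the call
# back into the static nn.pad operator.
z = relay.nn.pad(x, relay.const(np.array([[1, 1], [2, 2]])), 4.0)
mod = tvm.IRModule.from_expr(relay.Function([x], z))
mod = relay.transform.InferType()(mod)
mod = relay.transform.DynamicToStatic()(mod)
print(mod["main"].body.op)  # should now be the static nn.pad op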