diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 750a8a43163c..8af9f6349d41 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -82,7 +82,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
 
 /*! \brief Attributes used in reshape operators */
 struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
-  Optional<Array<Integer>> newshape;
+  Array<Integer> newshape;
   bool reverse;
   TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
     TVM_ATTR_FIELD(newshape).describe(
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 9c565409a49b..e3909d9d6378 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -45,6 +45,7 @@
 from .op import annotation
 from .op import vision
 from .op import contrib
+from .op import dyn
 from .op.reduce import *
 from .op.tensor import *
 from .op.transform import *
diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py
index eb567658f2a1..ac60a1f7bb51 100644
--- a/python/tvm/relay/_parser.py
+++ b/python/tvm/relay/_parser.py
@@ -114,7 +114,7 @@ def convert(self, v):
     def __call__(self, args, attrs, type_args):
         if attrs is None:
             attrs = {}
-        if self.operator in (op.reshape, op.strided_slice):
+        if self.operator in (op.strided_slice,):
            x = self.operator(*args)
         elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to):
             x = self.operator(*args, dtype=attrs["dtype"])
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 2907d72abc03..849d0a3f26d4 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -511,7 +511,7 @@ def batch_matmul_grad(orig, grad):
 @register_gradient("reshape")
 def reshape_grad(orig, grad):
     """Gradient of reshape"""
-    return [reshape_like(grad, orig.args[0]), orig.args[1]]
+    return [reshape_like(grad, orig.args[0])]
 
 
 @register_gradient("cast")
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 10238d101b0b..a3f2e08e28d9 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -273,82 +273,11 @@ def _reshape_shape_func_input_shape(data_shape, newshape, ndim):
         out[infer_idx] = old_size // new_size
     return out
 
-@script
-def _reshape_shape_func_input_data(data, newshape, ndim):
-    out = output_tensor((ndim,), "int64")
-    data_shape = allocate((len(data.shape),), "int64")
-    for x in const_range(len(data.shape)):
-        data_shape[x] = int64(data.shape[x])
-    src_idx = 0
-    dst_idx = 0
-    infer_idx = -1
-    copy = False
-    skip = 0
-    for i in const_range(len(newshape)):
-        if skip > 0:
-            skip -= 1
-        elif newshape[i] > 0:
-            out[dst_idx] = int64(newshape[i])
-            src_idx += 1
-            dst_idx += 1
-        elif newshape[i] == 0:
-            out[dst_idx] = data_shape[src_idx]
-            src_idx += 1
-            dst_idx += 1
-        elif newshape[i] == -1:
-            assert infer_idx < 0, "One and only one dim can be inferred"
-            out[dst_idx] = int64(1)
-            infer_idx = i
-            dst_idx += 1
-        elif newshape[i] == -2:
-            copy = True
-        elif newshape[i] == -3:
-            assert data_shape.shape[0] - src_idx > 1, \
-                "Not enough dims in input shape for -3"
-            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
-            src_idx += 2
-            dst_idx += 1
-        elif newshape[i] == -4:
-            assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
-            if newshape[i+1] == -1:
-                assert newshape[i+2] != -1, "Split dims cannot both be -1."
-                out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
-                out[dst_idx+1] = int64(newshape[i+2])
-            else:
-                out[dst_idx] = int64(newshape[i+1])
-                if newshape[i+2] == -1:
-                    out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
-                else:
-                    out[dst_idx+1] = int64(newshape[i+2])
-            assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
-                "Product of split dims doesn't match to input dim"
-            src_idx += 1
-            dst_idx += 2
-            skip = 2
-        else:
-            assert False, "Invalid special values in new shape"
-    if len(data_shape.shape) > 0:
-        # if data is not constant, we can then handle -1 and -2
-        if copy:
-            for i in range(src_idx, data_shape.shape[0]):
-                out[dst_idx] = data_shape[i]
-                dst_idx += 1
-        if infer_idx >= 0:
-            old_size = int64(1)
-            for i in const_range(data_shape.shape[0]):
-                old_size *= data_shape[i]
-            new_size = int64(1)
-            for i in const_range(out.shape[0]):
-                new_size *= out[i]
-            out[infer_idx] = old_size // new_size
-    return out
-
-@_reg.register_shape_func("reshape", True)
+@_reg.register_shape_func("reshape", False)
 def reshape_shape_func(attrs, inputs, out_ndims):
-    if attrs.newshape is None:
-        return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
+    newshape = get_const_tuple(attrs.newshape)
     return [_reshape_shape_func_input_shape(inputs[0],
-                                            convert(attrs.newshape),
+                                            convert(newshape),
                                             out_ndims[0])]
 
 @script
diff --git a/python/tvm/relay/op/dyn/__init__.py b/python/tvm/relay/op/dyn/__init__.py
new file mode 100644
index 000000000000..d659203e27e1
--- /dev/null
+++ b/python/tvm/relay/op/dyn/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import, redefined-builtin, invalid-name
+"""The Relay namespace containing dynamic ops."""
+
+from . import _transform
diff --git a/python/tvm/relay/op/dyn/_make.py b/python/tvm/relay/op/dyn/_make.py
new file mode 100644
index 000000000000..ab88fe872458
--- /dev/null
+++ b/python/tvm/relay/op/dyn/_make.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Constructor APIs"""
+import tvm._ffi
+
+tvm._ffi._init_api("relay.op.dyn._make", __name__)
diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py
new file mode 100644
index 000000000000..81c6e5e9229d
--- /dev/null
+++ b/python/tvm/relay/op/dyn/_transform.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Backend compiler related feature registration"""
+# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
+from __future__ import absolute_import
+from tvm.te.hybrid import script
+from .. import op as _reg
+
+_reg.register_injective_schedule("dyn.reshape")
+
+@script
+def _reshape_shape_func_input_data(data, newshape, ndim):
+    out = output_tensor((ndim,), "int64")
+    data_shape = allocate((len(data.shape),), "int64")
+    for x in const_range(len(data.shape)):
+        data_shape[x] = int64(data.shape[x])
+    src_idx = 0
+    dst_idx = 0
+    infer_idx = -1
+    copy = False
+    skip = 0
+    for i in const_range(len(newshape)):
+        if skip > 0:
+            skip -= 1
+        elif newshape[i] > 0:
+            out[dst_idx] = int64(newshape[i])
+            src_idx += 1
+            dst_idx += 1
+        elif newshape[i] == 0:
+            out[dst_idx] = data_shape[src_idx]
+            src_idx += 1
+            dst_idx += 1
+        elif newshape[i] == -1:
+            assert infer_idx < 0, "One and only one dim can be inferred"
+            out[dst_idx] = int64(1)
+            infer_idx = i
+            src_idx += 1
+            dst_idx += 1
+        elif newshape[i] == -2:
+            assert False, "Value -2 is not valid in newshape argument of dynamic reshape"
+        elif newshape[i] == -3:
+            assert data_shape.shape[0] - src_idx > 1, \
+                "Not enough dims in input shape for -3"
+            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
+            src_idx += 2
+            dst_idx += 1
+        elif newshape[i] == -4:
+            assert False, "Value -4 is not valid in newshape argument of dynamic reshape"
+        else:
+            assert False, "Invalid special values in new shape"
+    if len(data_shape.shape) > 0:
+        # if data is not constant, we can then handle -1 and -2
+        if copy:
+            for i in range(src_idx, data_shape.shape[0]):
+                out[dst_idx] = data_shape[i]
+                dst_idx += 1
+        if infer_idx >= 0:
+            old_size = int64(1)
+            for i in const_range(data_shape.shape[0]):
+                old_size *= data_shape[i]
+            new_size = int64(1)
+            for i in const_range(out.shape[0]):
+                new_size *= out[i]
+            out[infer_idx] = old_size // new_size
+    return out
+
+@_reg.register_shape_func("dyn.reshape", True)
+def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
+    return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index cc9f73042266..9dc96f543e98 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -19,7 +19,8 @@
 """Transform operators."""
 from . import _make
-from ..expr import TupleWrapper, const
+from .dyn import _make as _dyn_make
+from ..expr import TupleWrapper, const, Expr
 from ...tir import expr as _expr
 
 
@@ -210,8 +211,10 @@ def reshape(data, newshape):
     result : relay.Expr
         The reshaped result.
     """
+    if isinstance(newshape, Expr):
+        return _dyn_make.reshape(data, newshape)
     if isinstance(newshape, int):
-        newshape = const([newshape])
+        newshape = [newshape]
     if isinstance(newshape, (tuple, list)):
         tempshape = []
         for shape in newshape:
@@ -222,8 +225,8 @@ def reshape(data, newshape):
                 tempshape.append(int(shape))
             except ValueError as err:
                 raise RuntimeError('Unrecognized shape type: %s' % err)
-        newshape = const(tempshape)
-    return _make.reshape(data, newshape)
+        newshape = tempshape
+    return _make.reshape(data, list(newshape))
 
 
 def argwhere(condition):
     """Find the indices of elements of a tensor that are
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index a490d6f00a71..ede63808d4fd 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -625,6 +625,17 @@ def AnnotateTarget(targets):
     return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])
 
 
+def DynamicToStatic():
+    """If possible, convert tvm.relay.dynamic* ops to static versions
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for dynamic->static conversion.
+    """
+    return _ffi_api.DynamicToStatic()
+
+
 def Inline():
     """Perform inlining on the given Relay IR module. The global functions that
     are marked as `inline` should be always inlined. A cost model will be
diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index f07c14a286d8..10c226e17f7a 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -448,14 +448,7 @@ bool IsDataDependant(const CallNode* call) {
     return false;
   }
 
-  if (op->name == "reshape") {
-    if (const auto* attrs = call->attrs.as<ReshapeAttrs>()) {
-      if (attrs->newshape) {
-        // If newshape attribute exists, it isn't data dependant.
-        return false;
-      }
-    }
-  } else if (op->name == "topk") {
+  if (op->name == "topk") {
     if (const auto* attrs = call->attrs.as<TopKAttrs>()) {
       if (attrs->k) {
         // If k attribute exists, it isn't data dependant.
diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc
new file mode 100644
index 000000000000..18eaa67242f5
--- /dev/null
+++ b/src/relay/op/dyn/tensor/transform.cc
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file transform.cc
+ * \brief Dynamic Transform operators.
+ */
+#include "transform.h"
+
+#include <topi/transform.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace relay {
+namespace dyn {
+
+/* relay.dyn.reshape */
+bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  // types: [data, newshape, result]
+  CHECK_EQ(types.size(), 3);
+
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "reshape: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+
+  Array<IndexExpr> oshape;
+  const auto* newshape = types[1].as<TensorTypeNode>();
+
+  // Doesn't support dynamic output rank
+  for (int i = 0; i < newshape->shape[0].as<IntImmNode>()->value; i++) {
+    oshape.push_back(Any());
+  }
+
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
+  return true;
+}
+
+Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                 const Type& out_type) {
+  const auto* out_ttype = out_type.as<TensorTypeNode>();
+  CHECK(out_ttype != nullptr);
+  Array<IndexExpr> newshape;
+  for (auto val : out_ttype->shape) {
+    if (val->IsInstance<tir::AnyNode>()) {
+      newshape.push_back(val.as<tir::AnyNode>()->ToVar());
+    } else {
+      newshape.push_back(val);
+    }
+  }
+  return {topi::reshape(inputs[0], newshape)};
+}
+
+Expr MakeReshape(Expr data, Expr newshape) {
+  auto attrs = make_object<ReshapeAttrs>();
+  attrs->reverse = false;
+  static const Op& op = Op::Get("dyn.reshape");
+  return Call(op, {data, newshape}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn._make.reshape").set_body_typed(MakeReshape);
+
+RELAY_REGISTER_OP("dyn.reshape")
+    .describe(R"code(Reshapes the input array based on the values in the newshape array.
+
+To give the user more convenience without doing manual shape inference,
+some dimensions of the shape can take special values from the set {0, -1, -3}.
+The significance of each is explained below:
+
+``0`` copy this dimension from the input to the output shape.
+
+.. code-block:: python
+
+    data.shape = (2,3,4), newshape = (4,0,2), result.shape = (4,3,2)
+    data.shape = (2,3,4), newshape = (2,0,0), result.shape = (2,3,4)
+
+``-1`` infers the dimension of the output shape by using the remainder of
+the input dimensions keeping the size of the new array same as that of the input array.
+At most one dimension of shape can be -1.
+
+.. code-block:: python
+
+    data.shape = (2,3,4), newshape = (6,1,-1), result.shape = (6,1,4)
+    data.shape = (2,3,4), newshape = (3,-1,8), result.shape = (3,1,8)
+    data.shape = (2,3,4), newshape = (-1,), result.shape = (24,)
+
+``-3`` use the product of two consecutive dimensions of the input shape
+as the output dimension.
+
+.. code-block:: python
+
+    data.shape = (2,3,4), newshape = (-3,4), result.shape = (6,4)
+    data.shape = (2,3,4,5), newshape = (-3,-3), result.shape = (6,20)
+    data.shape = (2,3,4), newshape = (0,-3), result.shape = (2,12)
+
+Special values -2 and -4 from the standard reshape op would introduce dynamic rank
+in this op. Thus, they are not permitted.
+
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .set_attrs_type<ReshapeAttrs>()
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("newshape", "Tensor", "The shape of output tensor.")
+    .set_support_level(3)
+    .add_type_rel("DynamicReshape", ReshapeRel)
+    .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+}  // namespace dyn
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/op/dyn/tensor/transform.h b/src/relay/op/dyn/tensor/transform.h
new file mode 100644
index 000000000000..98b0474a7e2b
--- /dev/null
+++ b/src/relay/op/dyn/tensor/transform.h
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/dyn/tensor/transform.h
+ * \brief Transform op attributes that can be shared among Relay and its dialects.
+ */
+#ifndef TVM_RELAY_OP_DYN_TENSOR_TRANSFORM_H_
+#define TVM_RELAY_OP_DYN_TENSOR_TRANSFORM_H_
+
+namespace tvm {
+namespace relay {
+namespace dyn {}  // namespace dyn
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_OP_DYN_TENSOR_TRANSFORM_H_
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index a07fa9ac82f8..b44ddf4ddf2a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -449,13 +449,8 @@ TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
 bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
   const auto* param = attrs.as<ReshapeAttrs>();
-  if (param->reverse) {
-    // types: [data, result]
-    CHECK_EQ(types.size(), 2);
-  } else {
-    // types: [data, newshape, result]
-    CHECK_EQ(types.size(), 3);
-  }
+  // types: [data, result]
+  CHECK_EQ(types.size(), 2);
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) {
     CHECK(types[0].as<IncompleteTypeNode>())
@@ -467,25 +462,12 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   Array<IndexExpr> data_shape;
   Array<Integer> newshape;
-  if (param->newshape) {
-    auto temp = param->newshape.value();
-    if (param->reverse) {
-      data_shape.Assign(data->shape.rbegin(), data->shape.rend());
-      newshape.Assign(temp.rbegin(), temp.rend());
-    } else {
-      data_shape = data->shape;
-      newshape = temp;
-    }
+  if (param->reverse) {
+    data_shape.Assign(data->shape.rbegin(), data->shape.rend());
+    newshape.Assign(param->newshape.rbegin(), param->newshape.rend());
   } else {
-    const auto* newshape = types[1].as<TensorTypeNode>();
-
-    // Doesn't support dynamic output rank
-    for (int i = 0; i < newshape->shape[0].as<IntImmNode>()->value; i++) {
-      oshape.push_back(Any());
-    }
-
-    reporter->Assign(types[2], TensorType(oshape, data->dtype));
-    return true;
+    data_shape = data->shape;
+    newshape = param->newshape;
   }
 
   std::unordered_set<size_t> used_input_dims;
@@ -600,7 +582,7 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     reporter->Assign(types[1],
                      TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
   } else {
-    reporter->Assign(types[2], TensorType(oshape, data->dtype));
+    reporter->Assign(types[1], TensorType(oshape, data->dtype));
   }
   return true;
 }
@@ -620,15 +602,12 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& in
   return {topi::reshape(inputs[0], newshape)};
 }
 
-Expr MakeReshape(Expr data, Expr newshape) {
+Expr MakeReshape(Expr data, Array<Integer> newshape) {
   auto attrs = make_object<ReshapeAttrs>();
-  if (const ConstantNode* c = newshape.as<ConstantNode>()) {
-    CHECK_EQ(c->data->ndim, 1);
-    attrs->newshape = ToVector(c->data);
-  }
+  attrs->newshape = std::move(newshape);
   attrs->reverse = false;
   static const Op& op = Op::Get("reshape");
-  return Call(op, {data, newshape}, Attrs(attrs), {});
+  return Call(op, {data}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.reshape").set_body_typed(MakeReshape);
@@ -684,10 +663,9 @@ Example::
 - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
 
 )code" TVM_ADD_FILELINE)
-    .set_num_inputs(2)
+    .set_num_inputs(1)
     .set_attrs_type<ReshapeAttrs>()
     .add_argument("data", "Tensor", "The input tensor.")
-    .add_argument("newshape", "Tensor", "The shape of output tensor.")
     .set_support_level(3)
     .add_type_rel("Reshape", ReshapeRel)
     .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 7149417aa9b5..c68dfba784c7 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -38,7 +38,7 @@
 namespace tvm {
 namespace relay {
 
-extern Expr MakeReshape(Expr data, Expr newshape);
+Expr MakeReshape(Expr data, Array<Integer> newshape);
 
 template <typename AttrType>
 bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
new file mode 100644
index 000000000000..7b3f1957811b
--- /dev/null
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file dynamic_to_static.cc
+ * \brief Rewrite Dynamic Operations to Static operations where possible
+ */
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "pattern_util.h"
+
+namespace tvm {
+namespace relay {
+
+class DynamicToStaticMutator : public MixedModeMutator {
+ public:
+  DynamicToStaticMutator() : dyn_reshape_op_(Op::Get("dyn.reshape")) {}
+
+ private:
+  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+    const CallNode* call_node = post.as<CallNode>();
+    if (call_node->op == dyn_reshape_op_) {
+      if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+        auto attrs = make_object<ReshapeAttrs>();
+        CHECK_EQ(shape->data->ndim, 1);
+        attrs->newshape = ToVector(shape->data);
+        attrs->reverse = false;
+        static const Op& reshape = Op::Get("reshape");
+        return Call(reshape, {call_node->args[0]}, Attrs(attrs), {});
+      }
+    }
+    return post;
+  }
+  Expr DispatchVisitExpr(const Expr& expr) override {
+    auto post = MixedModeMutator::DispatchVisitExpr(expr);
+    if (auto op = post.as<FunctionNode>()) {
+      return Function(op->params, op->body, NullValue<Type>(), op->type_params, op->attrs);
+    }
+    return post;
+  }
+
+  const Op& dyn_reshape_op_;
+};
+
+Expr DynamicToStatic(Function f, IRModule m) {
+  Expr pre = f;
+  Expr expr = f;
+  auto fold_const = transform::FoldConstant();
+  auto infer_type = transform::InferType();
+  Map<BaseFunc, GlobalVar> vars;
+  for (auto kv : m->functions) {
+    vars.Set(kv.second, kv.first);
+  }
+  const auto gv = vars[f];
+  int i = 0;
+  do {
+    pre = expr;
+    // TODO(mbrookhart): Is it possible to run these passes JUST on the current function?
+    m = infer_type(m);
+    m = fold_const(m);
+    expr = DynamicToStaticMutator().Mutate(m->functions[gv]);
+    m->Update(gv, Downcast<Function>(expr));
+    i += 1;
+  } while (pre != expr && i < 1000);
+  return expr;
+}
+
+namespace transform {
+
+Pass ConvertDynamicToStatic() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(DynamicToStatic(f, m));
+      };
+  return CreateFunctionPass(pass_func, 3, "DynamicToStatic", {});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.DynamicToStatic").set_body_typed([]() {
+  return ConvertDynamicToStatic();
+});
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index a3765f3c3bef..0c2abbfdd238 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -329,8 +329,7 @@ static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
       arr.push_back(1);
     }
   }
-  return MakeReshape(
-      scale, MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(arr.size())}, arr));
+  return MakeReshape(scale, std::move(arr));
 }
 
 // if only one axis, use expand dim. Else, use reshape
diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h
index d55041163054..78068e88a510 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -630,12 +630,10 @@ static inline Expr Sum(Expr data, Array<Integer> axis, bool keepdims, bool exclu
   return Call(op, {data}, Attrs(attrs), {});
 }
 
-Expr MakeReshape(Expr data, Expr newshape);
+Expr MakeReshape(Expr data, Array<Integer> newshape);
 
 static inline Expr Reshape(Expr data, Array<Integer> newshape) {
-  auto newshape_tensor =
-      MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(newshape.size())}, newshape);
-  return MakeReshape(data, newshape_tensor);
+  return MakeReshape(data, newshape);
 }
 
 static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py
new file mode 100644
index 000000000000..29168b6d0d1a
--- /dev/null
+++ b/tests/python/relay/dyn/test_dynamic_op_level3.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Support level3 operator test cases.
+""" +import numpy as np +import pytest +import tvm +from tvm import te +from tvm import relay +from tvm.relay import create_executor, transform +from tvm.relay.testing import ctx_list, check_grad, run_infer_type + +def verify_func(func, data, ref_res): + assert isinstance(data, list) + for target, ctx in ctx_list(): + #TODO(mbrookhart): enable Cuda tests onces the VM supports dynamic shapes + if "llvm" not in target: continue + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(*data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + relay.backend.compile_engine.get().clear() + +def test_dyn_reshape(): + def verify_reshape(shape, newshape, oshape): + x = relay.var("x", relay.TensorType(shape, "float32")) + y = relay.var("y", relay.TensorType((len(newshape), ), "int64")) + z = relay.reshape(x, y) + + func = relay.Function([x, y], z) + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + ref_res = np.reshape(x_data, oshape) + verify_func(func, [x_data, np.array(newshape).astype("int64")], ref_res) + verify_reshape((2, 3, 4), (8, 3), (8, 3)) + verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) + verify_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2)) + verify_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4)) + verify_reshape((2, 3, 4), (0, -1), (2, 12)) + verify_reshape((2, 3, 4), (-1, 0), (8, 3)) + verify_reshape((2, 3, 4), (-3, 4), (6, 4)) + verify_reshape((2, 3, 4, 5), (-3, -3), (6, 20)) + verify_reshape((2, 3, 4), (0, -3), (2, 12)) + +def test_dyn_shape_reshape(): + def verify_reshape(shape, newshape, oshape): + x = relay.var("x", relay.TensorType(shape, "float32")) + y = relay.var("y", relay.TensorType(newshape, "float32")) + z = relay.reshape(x, relay.shape_of(y)) + + func = relay.Function([x, y], z) + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32") + ref_res = np.reshape(x_data, oshape) + verify_func(func, [x_data, y_data], ref_res) + verify_reshape((2, 3, 4), (8, 3), (8, 3)) + verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) + +if __name__ == "__main__": + test_dyn_reshape() + test_dyn_shape_reshape() diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 8e535a692b88..6d940a563566 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -184,9 +184,9 @@ def test_any_reshape(): # Variable newshape only supports that output rank is the same as newshape verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24), variable_newshape) verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12), variable_newshape) - verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4), variable_newshape) verify_any_reshape(any_dims(3), (0, -2), (2, 3, 4), (2, 3, 4)) verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12)) + verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4)) def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): x = relay.var('x', shape=x_shape, dtype=dtype) diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py new file mode 100644 index 000000000000..052d95cef0a7 --- /dev/null +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+from tvm import te
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.testing import run_infer_type, create_workload, ctx_list
+
+
+def run_opt_pass(expr, opt_pass):
+    assert isinstance(opt_pass, tvm.transform.Pass)
+
+    mod = tvm.IRModule.from_expr(expr)
+    mod = opt_pass(mod)
+    entry = mod["main"]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+def verify_func(func, data, ref_res):
+    assert isinstance(data, list)
+    for target, ctx in ctx_list():
+        for kind in ["graph", "vm", "debug"]:
+            mod = tvm.ir.IRModule.from_expr(func)
+            intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+            op_res = intrp.evaluate()(*data)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
+def test_dynamic_to_static_reshape():
+    def verify_reshape(shape, newshape, oshape):
+        x = relay.var("x", relay.TensorType(shape, "float32"))
+        y = relay.var("y", relay.TensorType(newshape, "float32"))
+        z = relay.reshape(x, relay.shape_of(y))
+        func = run_infer_type(relay.Function([x, y], z))
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("reshape")
+        assert "newshape=" in zz.astext()
+        assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
+
+        x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
+        y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
+        ref_res = np.reshape(x_data, oshape)
+        verify_func(func2, [x_data, y_data], ref_res)
+
+    verify_reshape((2, 3, 4), (8, 3), (8, 3))
+    verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
+
+def test_dynamic_to_static_double_reshape():
+    def verify_reshape(shape, newshape):
+        x = relay.var("x", relay.TensorType(shape, "float32"))
+        y = relay.var("y", relay.TensorType(newshape, "float32"))
+        z = relay.reshape(x, relay.shape_of(y))
+        z = relay.reshape(z, relay.shape_of(x))
+        func = run_infer_type(relay.Function([x, y], z))
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("reshape")
+        assert "newshape=" in zz.astext()
+        assert zz.checked_type == relay.ty.TensorType(shape, "float32")
+
+        x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
+        y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
+        verify_func(func2, [x_data, y_data], x_data)
+
+    verify_reshape((2, 3, 4), (8, 3))
+    verify_reshape((4, 7), (2, 7, 2))
+
+def test_dynamic_to_static_quad_reshape():
+    def verify_reshape(shape, newshape):
+        x = relay.var("x", relay.TensorType(shape, "float32"))
+        y = relay.var("y", relay.TensorType(newshape, "float32"))
+        z1 = relay.reshape(x, relay.shape_of(y))
+        z2 = relay.reshape(z1, relay.shape_of(x))
+        z3 = relay.reshape(z2, relay.shape_of(z1))
+        z4 = relay.reshape(z3, relay.shape_of(z2))
+        func = run_infer_type(relay.Function([x, y], z4))
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("reshape")
+        assert "newshape=" in zz.astext()
+        assert zz.checked_type == relay.ty.TensorType(shape, "float32")
+
+        x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
+        y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
+        verify_func(func2, [x_data, y_data], x_data)
+
+    verify_reshape((2, 3, 4), (8, 3))
+    verify_reshape((4, 7), (2, 7, 2))
+
+if __name__=="__main__":
+    test_dynamic_to_static_reshape()
+    test_dynamic_to_static_double_reshape()
+    test_dynamic_to_static_quad_reshape()
+
diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py
index 231d40033350..633ef3f60c9b 100644
--- a/vta/python/vta/top/graphpack.py
+++ b/vta/python/vta/top/graphpack.py
@@ -345,7 +345,7 @@ def visit_call(self, call):
                                      method,
                                      align_corners)
         elif call.op == self.reshape and len(input_types[0].shape) == 4:
-            data, _ = args
+            data, = args
             data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
             return op.reshape(data, [int(x) for x in input_types[0].shape])