diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 84dda6fdbc41..c0e227291ada 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -82,7 +82,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
 /*! \brief Attributes used in reshape operators */
 struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
-  Array<Integer> newshape;
+  Optional<Array<Integer>> newshape;
   bool reverse;
   TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
     TVM_ATTR_FIELD(newshape).describe(
diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py
index 7731efe04dfd..1d97b5577361 100644
--- a/python/tvm/relay/_parser.py
+++ b/python/tvm/relay/_parser.py
@@ -114,7 +114,10 @@ def convert(self, v):
     def __call__(self, args, attrs, type_args):
         if attrs is None:
             attrs = {}
-        x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
+        if self.operator is op.reshape:
+            x = self.operator(*args)
+        else:
+            x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
         if isinstance(x, expr.TupleWrapper):
             x = x.astuple()
         return x
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 913a0165066d..ab9e9e656516 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1155,14 +1155,11 @@ def _impl(inputs, attr, params, mod):
             shape_arg = tuple(params_new.asnumpy().astype('int64').flatten())
         except Exception:
             # Deal with symbolic shape case.
-            # Currently only shape_of can be the direct ancestor.
-            if not isinstance(pop_node, tvm.relay.expr.Call) or \
-                    "shape_of" not in str(pop_node.op):
-                raise RuntimeError("If shape operator is used in reshape to "
-                                   "express reshape_like, shape_of must be "
-                                   "the direct ancestor of reshape when input "
-                                   "shape is symbolic.")
-            return _op.reshape_like(inputs[0], pop_node.args[0])
+            if isinstance(pop_node, _expr.Call) and \
+                    "shape_of" in str(pop_node.op):
+                # shape_of is the direct ancestor.
+                return _op.reshape_like(inputs[0], pop_node.args[0])
+            shape_arg = pop_node
         return AttrCvt(
             op_name="reshape",
             extras={'newshape': shape_arg},
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index c034bccc5dee..8be335842f0e 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -479,7 +479,7 @@ def dense_grad(orig, grad):
 @register_gradient("reshape")
 def reshape_grad(orig, grad):
     """Gradient of reshape"""
-    return [reshape_like(grad, orig.args[0])]
+    return [reshape_like(grad, orig.args[0]), orig.args[1]]
 
 
 @register_gradient("cast")
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index ee23fcefe010..43d8d6266b77 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -123,7 +123,7 @@ def concatenate_shape_func(attrs, inputs, _):
     return [_concatenate_shape_func(inputs, convert(axis))]
 
 @script
-def _reshape_shape_func(data_shape, newshape, ndim):
+def _reshape_shape_func_input_shape(data_shape, newshape, ndim):
     out = output_tensor((ndim,), "int64")
     src_idx = 0
     dst_idx = 0
@@ -189,10 +189,83 @@ def _reshape_shape_func(data_shape, newshape, ndim):
         out[infer_idx] = old_size // new_size
     return out
 
-@_reg.register_shape_func("reshape", False)
+@script
+def _reshape_shape_func_input_data(data, newshape, ndim):
+    out = output_tensor((ndim,), "int64")
+    data_shape = allocate((len(data.shape),), "int64")
+    for x in const_range(len(data.shape)):
+        data_shape[x] = int64(data.shape[x])
+    src_idx = 0
+    dst_idx = 0
+    infer_idx = -1
+    copy = False
+    skip = 0
+    for i in const_range(len(newshape)):
+        if skip > 0:
+            skip -= 1
+        elif newshape[i] > 0:
+            out[dst_idx] = int64(newshape[i])
+            src_idx += 1
+            dst_idx += 1
+        elif newshape[i] == 0:
+            out[dst_idx] = data_shape[src_idx]
+            src_idx += 1
+            dst_idx += 1
+        elif newshape[i] == -1:
+            assert infer_idx < 0, "One and only one dim can be inferred"
+            out[dst_idx] = int64(1)
+            infer_idx = i
+            dst_idx += 1
+        elif newshape[i] == -2:
+            copy = True
+        elif newshape[i] == -3:
+            assert data_shape.shape[0] - src_idx > 1, \
+                "Not enough dims in input shape for -3"
+            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
+            src_idx += 2
+            dst_idx += 1
+        elif newshape[i] == -4:
+            assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
+            if newshape[i+1] == -1:
+                assert newshape[i+2] != -1, "Split dims cannot both be -1."
+                out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
+                out[dst_idx+1] = int64(newshape[i+2])
+            else:
+                out[dst_idx] = int64(newshape[i+1])
+                if newshape[i+2] == -1:
+                    out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
+                else:
+                    out[dst_idx+1] = int64(newshape[i+2])
+            assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
+                "Product of split dims doesn't match to input dim"
+            src_idx += 1
+            dst_idx += 2
+            skip = 2
+        else:
+            assert False, "Invalid special values in new shape"
+    if len(data_shape.shape) > 0:
+        # if data is not constant, we can then handle -1 and -2
+        if copy:
+            for i in range(src_idx, data_shape.shape[0]):
+                out[dst_idx] = data_shape[i]
+                dst_idx += 1
+        if infer_idx >= 0:
+            old_size = int64(1)
+            for i in const_range(data_shape.shape[0]):
+                old_size *= data_shape[i]
+            new_size = int64(1)
+            for i in const_range(out.shape[0]):
+                new_size *= out[i]
+            out[infer_idx] = old_size // new_size
+    return out
+
+@_reg.register_shape_func("reshape", True)
 def reshape_shape_func(attrs, inputs, out_ndims):
-    newshape = get_const_tuple(attrs.newshape)
-    return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])]
+    if attrs.newshape is None:
+        return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
+    return [_reshape_shape_func_input_shape(inputs[0],
+                                            convert(attrs.newshape),
+                                            out_ndims[0])]
 
 @script
 def _take_no_axis_shape_func(indices_shape, out_ndim):
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 4e9bb45abd94..2d9e4ba40197 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -201,7 +201,7 @@ def reshape(data, newshape):
     data : relay.Expr
         The input data to the operator.
 
-    newshape : Union[int, Tuple[int], List[int]]
+    newshape : Union[int, Tuple[int], List[int]] or relay.Expr
         The new shape. Should be compatible with the original shape.
 
     Returns
@@ -210,8 +210,10 @@
         The reshaped result.
     """
     if isinstance(newshape, int):
-        newshape = [newshape]
-    return _make.reshape(data, list(newshape))
+        newshape = const([newshape])
+    if isinstance(newshape, (tuple, list)):
+        newshape = const(list(newshape))
+    return _make.reshape(data, newshape)
 
 def argwhere(condition):
     """Find the indices of elements of a tensor that are
diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index 1d840166d89a..af23836b5479 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -27,6 +27,7 @@
 #include
 #include
 #include
+#include
 #include
 
 #include "../transforms/pass_util.h"
@@ -414,5 +415,44 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
   return ret;
 }
 
+struct IsDynamicVisitor : public TypeVisitor {
+  bool is_dyn{false};
+  void VisitType_(const TensorTypeNode* tt) {
+    for (auto dim : tt->shape) {
+      if (dim.as<AnyNode>()) {
+        is_dyn = true;
+        break;
+      }
+    }
+  }
+};
+
+bool IsDynamic(const Type& ty) {
+  IsDynamicVisitor v;
+  v.VisitType(ty);
+  return v.is_dyn;
+}
+
+TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
+
+bool IsDataDependant(const CallNode* call) {
+  static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>("TShapeDataDependant");
+  Op op = Downcast<Op>(call->op);
+
+  if (!tshape_data_dependant.count(op)) {
+    return false;
+  }
+
+  if (op->name == "reshape") {
+    if (const auto* attrs = call->attrs.as<ReshapeAttrs>()) {
+      if (attrs->newshape) {
+        // If newshape attribute exists, it isn't data dependant.
+        return false;
+      }
+    }
+  }
+
+  return tshape_data_dependant[op];
+}
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 3851de144bf0..12a5add248d0 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -45,6 +45,7 @@
 #include
 #include
 
+#include "../transforms/pass_util.h"
 #include "utils.h"
 
 namespace tvm {
@@ -70,27 +71,6 @@ CCacheKey::CCacheKey(Function source_func, Target target) {
   data_ = std::move(n);
 }
 
-struct IsDynamicVisitor : public TypeVisitor {
-  bool is_dyn{false};
-  void VisitType_(const TensorTypeNode* tt) {
-    for (auto dim : tt->shape) {
-      if (dim.as<AnyNode>()) {
-        is_dyn = true;
-        break;
-      }
-    }
-  }
-};
-
-bool IsDynamic(const Type& ty) {
-  IsDynamicVisitor v;
-  v.VisitType(ty);
-  return v.is_dyn;
-}
-
-// TODO(@jroesch): MOVE ME
-TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
-
 Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
   // for now, we always use int32 shape when possible
   // even if the result of shape inference becomes int64.
@@ -485,7 +465,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     CHECK_GT(tshape_data_dependant.count(op), 0)
        << "Internal error, cannot find TShapeDataDependant for " << op->name;
 
-    data_dependants_.push_back(tshape_data_dependant[op]);
+    data_dependants_.push_back(IsDataDependant(call_node));
     // Visit all inputs
     Array<te::Tensor> inputs;
     int count_tuple = 0;
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 15761f6eb0f4..8b58946dbd3a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -447,10 +447,54 @@ RELAY_REGISTER_OP("transpose")
 /* relay.reshape */
 TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
 
+double ToScalar(const runtime::NDArray& array, int i = 0) {
+  if (array->dtype.code == kDLInt) {
+    if (array->dtype.bits == 8) {
+      return reinterpret_cast<int8_t*>(array->data)[i];
+    } else if (array->dtype.bits == 16) {
+      return reinterpret_cast<int16_t*>(array->data)[i];
+    } else if (array->dtype.bits == 32) {
+      return reinterpret_cast<int32_t*>(array->data)[i];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<int64_t*>(array->data)[i];
+    }
+  } else if (array->dtype.code == kDLUInt) {
+    if (array->dtype.bits == 8) {
+      return reinterpret_cast<uint8_t*>(array->data)[i];
+    } else if (array->dtype.bits == 16) {
+      return reinterpret_cast<uint16_t*>(array->data)[i];
+    } else if (array->dtype.bits == 32) {
+      return reinterpret_cast<uint32_t*>(array->data)[i];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<uint64_t*>(array->data)[i];
+    }
+  } else if (array->dtype.code == kDLFloat) {
+#if (__ARM_FP16_FORMAT_IEEE == 1)
+    if (array->dtype.bits == 16) {
+      return reinterpret_cast<__fp16*>(array->data)[i];
+    }
+#endif
+    if (array->dtype.bits == 32) {
+      return reinterpret_cast<float*>(array->data)[i];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<double*>(array->data)[i];
+    }
+  }
+  LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
+  // make compiler happy
+  return -std::numeric_limits<double>::infinity();
+}
+
 bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
-  // types: [data, result]
-  CHECK_EQ(types.size(), 2);
+  const auto* param = attrs.as<ReshapeAttrs>();
+  if (param->reverse) {
+    // types: [data, result]
+    CHECK_EQ(types.size(), 2);
+  } else {
+    // types: [data, newshape, result]
+    CHECK_EQ(types.size(), 3);
+  }
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) {
     CHECK(types[0].as<IncompleteTypeNode>())
@@ -458,17 +502,31 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     return false;
   }
 
-  const auto* param = attrs.as<ReshapeAttrs>();
+  Array<IndexExpr> oshape;
   Array<IndexExpr> data_shape;
   Array<Integer> newshape;
-  if (param->reverse) {
-    data_shape.assign(data->shape.rbegin(), data->shape.rend());
-    newshape.assign(param->newshape.rbegin(), param->newshape.rend());
+
+  if (param->newshape) {
+    auto temp = param->newshape.value();
+    if (param->reverse) {
+      data_shape.assign(data->shape.rbegin(), data->shape.rend());
+      newshape.assign(temp.rbegin(), temp.rend());
+    } else {
+      data_shape = data->shape;
+      newshape = temp;
+    }
   } else {
-    data_shape = data->shape;
-    newshape = param->newshape;
+    const auto* newshape = types[1].as<TensorTypeNode>();
+
+    // Doesn't support dynamic output rank
+    for (int i = 0; i < newshape->shape[0].as<IntImmNode>()->value; i++) {
+      oshape.push_back(Any::make());
+    }
+
+    reporter->Assign(types[2], TensorType(oshape, data->dtype));
+    return true;
   }
-  Array<IndexExpr> oshape;
+
   std::unordered_set<size_t> used_input_dims;
   std::unordered_set<size_t> used_output_dims;
   size_t src_idx = 0;
@@ -581,7 +639,7 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()),
                                           data->dtype));
   } else {
-    reporter->Assign(types[1], TensorType(oshape, data->dtype));
+    reporter->Assign(types[2], TensorType(oshape, data->dtype));
   }
   return true;
 }
@@ -601,12 +659,19 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
   return {topi::reshape(inputs[0], newshape)};
 }
 
-Expr MakeReshape(Expr data, Array<Integer> newshape) {
+Expr MakeReshape(Expr data, Expr newshape) {
   auto attrs = make_object<ReshapeAttrs>();
-  attrs->newshape = std::move(newshape);
+  if (const ConstantNode* c = newshape.as<ConstantNode>()) {
+    CHECK_EQ(c->data->ndim, 1);
+    Array<Integer> newshape;
+    for (int i = 0; i < c->data->shape[0]; i++) {
+      newshape.push_back(Integer(static_cast<int>(ToScalar(c->data, i))));
+    }
+    attrs->newshape = newshape;
+  }
   attrs->reverse = false;
   static const Op& op = Op::Get("reshape");
-  return Call(op, {data}, Attrs(attrs), {});
+  return Call(op, {data, newshape}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.reshape").set_body_typed(MakeReshape);
@@ -662,9 +727,10 @@ Example::
 - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
 
 )code" TVM_ADD_FILELINE)
-    .set_num_inputs(1)
+    .set_num_inputs(2)
     .set_attrs_type<ReshapeAttrs>()
     .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("newshape", "Tensor", "The shape of output tensor.")
     .set_support_level(3)
     .add_type_rel("Reshape", ReshapeRel)
     .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
@@ -1005,44 +1071,6 @@ and type as the input array.
 // arange operator
 TVM_REGISTER_NODE_TYPE(ArangeAttrs);
 
-double ToScalar(const runtime::NDArray& array) {
-  if (array->dtype.code == kDLInt) {
-    if (array->dtype.bits == 8) {
-      return reinterpret_cast<int8_t*>(array->data)[0];
-    } else if (array->dtype.bits == 16) {
-      return reinterpret_cast<int16_t*>(array->data)[0];
-    } else if (array->dtype.bits == 32) {
-      return reinterpret_cast<int32_t*>(array->data)[0];
-    } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<int64_t*>(array->data)[0];
-    }
-  } else if (array->dtype.code == kDLUInt) {
-    if (array->dtype.bits == 8) {
-      return reinterpret_cast<uint8_t*>(array->data)[0];
-    } else if (array->dtype.bits == 16) {
-      return reinterpret_cast<uint16_t*>(array->data)[0];
-    } else if (array->dtype.bits == 32) {
-      return reinterpret_cast<uint32_t*>(array->data)[0];
-    } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<uint64_t*>(array->data)[0];
-    }
-  } else if (array->dtype.code == kDLFloat) {
-#if (__ARM_FP16_FORMAT_IEEE == 1)
-    if (array->dtype.bits == 16) {
-      return reinterpret_cast<__fp16*>(array->data)[0];
-    }
-#endif
-    if (array->dtype.bits == 32) {
-      return reinterpret_cast<float*>(array->data)[0];
-    } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<double*>(array->data)[0];
-    }
-  }
-  LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
-  // make compiler happy
-  return -std::numeric_limits<double>::infinity();
-}
-
 bool ArangeRel(const Array<Type>& types, int num_inputs, const Attrs& raw_attrs,
                const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 4);
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 1d1f9c0b64ee..bc35ed629fbd 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -38,7 +38,7 @@
 namespace tvm {
 namespace relay {
 
-extern Expr MakeReshape(Expr data, Array<Integer> newshape);
+extern Expr MakeReshape(Expr data, Expr newshape);
 
 template <typename AttrType>
 bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index 4c8025a8d382..4083d0816932 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -329,7 +329,8 @@ static Expr ReshapeToMatchAxis(Expr scale, const Array<IndexExpr>& shape,
       arr.push_back(1);
     }
   }
-  return MakeReshape(scale, std::move(arr));
+  return MakeReshape(
+      scale, MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(arr.size())}, arr));
 }
 
 // if only one axis, use expand dim. Else, use reshape
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index 0ca8d7c3bdf1..054244dc3516 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -31,6 +31,7 @@
 #include
 
 #include "../../support/arena.h"
+#include "pass_util.h"
 #include "pattern_util.h"
 
 namespace tvm {
@@ -237,7 +238,13 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     // need to call Update, as it may be an arbitrary expression.
     OpPatternKind op_pattern = kOpaque;
     if (const OpNode* opnode = call->op.as<OpNode>()) {
-      op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
+      auto op = GetRef<Op>(opnode);
+      if (IsDynamic(call->checked_type()) && IsDataDependant(call)) {
+        // output of a shape func can't be fed to a data-dependent shape func
+        op_pattern = kOpaque;
+      } else {
+        op_pattern = static_cast<OpPatternKind>(fpattern[op]);
+      }
     } else {
       this->Update(call->op, node, kOpaque);
     }
diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h
index 32ee09fe5a77..cbdd4b4a626b 100644
--- a/src/relay/transforms/pass_util.h
+++ b/src/relay/transforms/pass_util.h
@@ -76,6 +76,20 @@ Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map);
  */
 Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
 
+/*!
+ * \brief Check if type is dynamic.
+ * \param ty The type to be checked.
+ * \return Whether the type is dynamic.
+ */
+bool IsDynamic(const Type& ty);
+
+/*!
+ * \brief Check if call is data dependant.
+ * \param call The call to be checked.
+ * \return Whether the call is data dependant.
+ */
+bool IsDataDependant(const CallNode* call);
+
 /*!
  * \brief Make arbitrary transformation preserve the out most function.
  * \param func The transformation.
diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h
index edb6a659f092..0a51404911e0 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -282,6 +282,34 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> shape,
   return Constant(arr);
 }
 
+/*!
+ * \brief Create a Constant with a tensor.
+ *
+ * \param dtype The data type.
+ * \param value The array of the tensor values.
+ * \return A Constant.
+ */
+template <typename T>
+static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> shape,
+                                          Array<T> value) {
+  runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
+  TVM_DTYPE_DISPATCH(dtype, DType, {
+    for (size_t i = 0; i < value.size(); i++) {
+      if (dtype == DataType::Float(16)) {
+        // convert to float16
+        // storage is uint16_t
+        // Similar handling as that in MakeConstantScalar
+        *(static_cast<DType*>(arr->data) + i) =
+            __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
+                static_cast<float>(value[i]));
+      } else {
+        *(static_cast<DType*>(arr->data) + i) = value[i];
+      }
+    }
+  })
+  return Constant(arr);
+}
+
 /*!
  * \brief Check if two expressions are equal scalars.
 * \param a The expression to be checked.
@@ -519,12 +547,12 @@ static inline Expr Sum(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
   return Call(op, {data}, Attrs(attrs), {});
 }
 
+Expr MakeReshape(Expr data, Expr newshape);
+
 static inline Expr Reshape(Expr data, Array<Integer> newshape) {
-  auto attrs = make_object<ReshapeAttrs>();
-  attrs->newshape = std::move(newshape);
-  attrs->reverse = false;
-  static const Op& op = Op::Get("reshape");
-  return Call(op, {data}, Attrs(attrs), {});
+  auto newshape_tensor =
+      MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(newshape.size())}, newshape);
+  return MakeReshape(data, newshape_tensor);
 }
 
 static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc
index 33f606130d48..d7ce0c0e3d6c 100644
--- a/tests/cpp/relay_build_module_test.cc
+++ b/tests/cpp/relay_build_module_test.cc
@@ -105,6 +105,7 @@ TEST(Relay, BuildModule) {
   }
   auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs);
   (*reg)("add", "FTVMStrategy", fgeneric, 10);
+  (*reg)("add", "TShapeDataDependant", false, 10);
   // build
   auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
   tvm::runtime::Module build_mod = (*pfb)();
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index cd6c45476a90..c3313b69a0bd 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -747,6 +747,17 @@ def _test_reshape_like(data, shape_like):
         compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
 
 
+def _test_reshape_symbolic(data, a_data, b_data):
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        a = array_ops.placeholder(shape=a_data.shape, dtype=a_data.dtype)
+        b = array_ops.placeholder(shape=b_data.shape, dtype=b_data.dtype)
+        newshape = tf.add(a, b)
+        out = array_ops.reshape(in_data, newshape)
+
+        for mode in ["debug", "vm"]:
+            compare_tf_with_tvm([data, a_data, b_data], [in_data.name, a.name, b.name], out.name, mode=mode)
+
 def test_forward_reshape():
     _test_reshape(np.arange(6.0), [2, 3])
     _test_reshape(np.arange(6), [-1, 2])
@@ -754,6 +765,10 @@
     _test_reshape(np.arange(6), [-1])
     _test_reshape_with_call()
     _test_reshape_like(np.zeros((3, 6)), np.zeros((9, 2)))
+    _test_reshape_symbolic(np.arange(6.0), np.array([2, 0]), np.array([0, 3]))
+    _test_reshape_symbolic(np.arange(6), np.array([-1, 0]), np.array([0, 2]))
+    _test_reshape_symbolic(np.arange(6), np.array([3, 0]), np.array([3, -1]))
+    _test_reshape_symbolic(np.arange(6), np.array([0]), np.array([-1]))
 
 #######################################################################
 # DepthToSpace
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 6ce59bbf1c36..c9de6754aa89 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -138,23 +138,36 @@ def test_any_concat():
     result = ex.evaluate()(x_np, y_np)
     tvm.testing.assert_allclose(result.asnumpy(), ref)
 
-def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape):
+def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False):
     x = relay.var('x', shape=x_shape, dtype="float32")
-    y = relay.reshape(x, newshape=newshape)
-    mod = tvm.IRModule()
-    mod["main"] = relay.Function([x], y)
+    relu_x = relay.nn.relu(x)
     data = np.random.uniform(size=x_np_shape).astype('float32')
+    params = [x]
+    args = [data]
+
+    if variable_newshape:
+        newshape_var = relay.var('newshape', shape=(len(newshape),), dtype='int64')
+        params.append(newshape_var)
+        args.append(np.array(newshape, dtype='int64'))
+        newshape = newshape_var
+
+    y = relay.reshape(relu_x, newshape=newshape)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function(params, y)
+
     for kind in ["debug", "vm"]:
         ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
-        result = ex.evaluate()(data).asnumpy()
+        result = ex.evaluate()(*args).asnumpy()
         assert result.shape == out_shape
         tvm.testing.assert_allclose(result.flatten(), data.flatten())
 
 def test_any_reshape():
-    verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24))
-    verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12))
+    for variable_newshape in [False, True]:
+        # Variable newshape only supports that output rank is the same as newshape
+        verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24), variable_newshape)
+        verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12), variable_newshape)
+        verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4), variable_newshape)
     verify_any_reshape(any_dims(3), (0, -2), (2, 3, 4), (2, 3, 4))
-    verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
     verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
 
 def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py
index 2334de7e6905..e1fdfcb7ba0f 100644
--- a/vta/python/vta/top/graphpack.py
+++ b/vta/python/vta/top/graphpack.py
@@ -345,9 +345,9 @@ def visit_call(self, call):
                 method,
                 align_corners)
         elif call.op == self.reshape and len(input_types[0].shape) == 4:
-            data, = args
+            data, _ = args
            data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
-            return op.reshape(data, input_types[0].shape)
+            return op.reshape(data, [int(x) for x in input_types[0].shape])
 
         return relay.Call(
             self.visit(call.op),
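
Usage note (reviewer illustration, not part of the patch): with this change relay.reshape also accepts a relay.Expr for newshape, so the target shape can be supplied at run time instead of as a constant attribute. The sketch below mirrors verify_any_reshape from tests/python/relay/test_any.py in this diff; the variable names and sample shapes are illustrative only, and it assumes the VM executor as used by the tests.

    import numpy as np
    import tvm
    from tvm import relay

    # Target shape is a graph input rather than a compile-time constant.
    x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32")
    newshape = relay.var("newshape", shape=(2,), dtype="int64")
    y = relay.reshape(relay.nn.relu(x), newshape=newshape)

    mod = tvm.IRModule()
    mod["main"] = relay.Function([x, newshape], y)

    # The VM executor handles the dynamically shaped output.
    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    data = np.random.uniform(size=(2, 12)).astype("float32")
    result = ex.evaluate()(data, np.array([4, 6], dtype="int64"))
    assert result.asnumpy().shape == (4, 6)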