From 026ea8cbcb9f1e1be2191b22c2681c663b718ddc Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 23 Jan 2019 20:59:33 -0800 Subject: [PATCH 1/4] Enable reverse in reshape --- include/tvm/relay/attrs/transform.h | 4 ++ python/tvm/relay/frontend/nnvm_common.py | 5 +- python/tvm/relay/op/transform.py | 14 ++++- src/relay/op/tensor/transform.cc | 69 ++++++++++++++++-------- tests/python/relay/test_op_level3.py | 32 ++++++++--- 5 files changed, 90 insertions(+), 34 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 36ccd3fb6c89..fa27a4d437d2 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -63,9 +63,13 @@ struct TransposeAttrs : public tvm::AttrsNode { /*! \brief Attributes used in reshape operators */ struct ReshapeAttrs : public tvm::AttrsNode { Array newshape; + bool reverse; TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") { TVM_ATTR_FIELD(newshape) .describe("The new shape. Should be compatible with the original shape."); + TVM_ATTR_FIELD(reverse) + .describe("Infer the special values from right to left if true") + .set_default(false); } }; // struct ReshapeAttrs diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py index d75b0cac65bc..85e73de3a7dd 100644 --- a/python/tvm/relay/frontend/nnvm_common.py +++ b/python/tvm/relay/frontend/nnvm_common.py @@ -22,10 +22,9 @@ def impl(inputs, _, _dtype='float32'): def _reshape(inputs, attrs): - if attrs.get_bool("reverse", False): - raise RuntimeError("reshape do not support option reverse") shape = attrs.get_int_tuple("shape") - return _op.reshape(inputs[0], newshape=shape) + reverse = attrs.get_bool("reverse", False) + return _op.reshape(inputs[0], newshape=shape, reverse=reverse) def _init_op(new_op): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 454d335692c6..cf251bdf55f4 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -92,7 +92,7 @@ def squeeze(data, axis=None): return _make.squeeze(data, axis) -def reshape(data, newshape): +def reshape(data, newshape, reverse=False): """Reshapes the input array. Example:: @@ -144,6 +144,13 @@ def reshape(data, newshape): - data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape =(1,2,3,4) - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4) + - If the argument reverse is set to True, then the special values are inferred from right to left. + + Example:: + + - with reverse=False, for input shape = (10,5,4), shape = (-1,0), output shape would be (40,5). + - with reverse=True, output shape will be (50,4). + Parameters ---------- data : relay.Expr @@ -152,6 +159,9 @@ def reshape(data, newshape): newshape : Union[int, Tuple[int], List[int]] The new shape. Should be compatible with the original shape. + reverse : bool + Infer the special values from right to left if set to be True. + Returns ------- result : relay.Expr @@ -159,7 +169,7 @@ def reshape(data, newshape): """ if isinstance(newshape, int): newshape = [newshape] - return _make.reshape(data, list(newshape)) + return _make.reshape(data, list(newshape), reverse) def reshape_like(data, shape_like): diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index f4d625195a2f..abba35a11cac 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -382,20 +382,29 @@ bool ReshapeRel(const Array& types, } const auto* param = attrs.as(); + Array data_shape; + Array newshape; + if (param->reverse) { + data_shape.assign(data->shape.rbegin(), data->shape.rend()); + newshape.assign(param->newshape.rbegin(), param->newshape.rend()); + } else { + data_shape = data->shape; + newshape = param->newshape; + } Array oshape; size_t src_idx = 0; int infer_idx = -1; - for (size_t i = 0; i < param->newshape.size(); ++i) { - int svalue = param->newshape[i]->value; + for (size_t i = 0; i < newshape.size(); ++i) { + int svalue = newshape[i]->value; // special flag handling for shape inference. if (svalue > 0) { - oshape.push_back(param->newshape[i]); + oshape.push_back(newshape[i]); ++src_idx; } else if (svalue == 0) { // keep same - CHECK_LT(src_idx, data->shape.size()); - oshape.push_back(data->shape[src_idx++]); + CHECK_LT(src_idx, data_shape.size()); + oshape.push_back(data_shape[src_idx++]); } else if (svalue == -1) { // inference based on rest CHECK_LT(infer_idx, 0) @@ -405,42 +414,51 @@ bool ReshapeRel(const Array& types, ++src_idx; } else if (svalue == -2) { // copy all remaining dims from source - while (src_idx < data->shape.size()) { - oshape.push_back(data->shape[src_idx++]); + while (src_idx < data_shape.size()) { + oshape.push_back(data_shape[src_idx++]); } } else if (svalue == -3) { // merge two dims from source - CHECK_LT(src_idx + 1, data->shape.size()); - IndexExpr d1 = data->shape[src_idx++]; - IndexExpr d2 = data->shape[src_idx++]; + CHECK_LT(src_idx + 1, data_shape.size()); + IndexExpr d1 = data_shape[src_idx++]; + IndexExpr d2 = data_shape[src_idx++]; oshape.push_back(d1 * d2); } else if (svalue == -4) { // split the source dim s into two dims // read the left dim and then the right dim (either can be -1) - CHECK_LT(i + 2, param->newshape.size()); - CHECK_LT(src_idx, data->shape.size()); - IndexExpr d0 = data->shape[src_idx++]; - Integer d1 = param->newshape[++i]; - Integer d2 = param->newshape[++i]; + CHECK_LT(i + 2, newshape.size()); + CHECK_LT(src_idx, data_shape.size()); + IndexExpr d0 = data_shape[src_idx++]; + Integer d1 = newshape[++i]; + Integer d2 = newshape[++i]; if (d1->value == -1) { CHECK(d2->value != -1) << "Split dims cannot both be -1."; oshape.push_back(d0 / d2); oshape.push_back(d2); } else { - CHECK_EQ(d2->value, -1); oshape.push_back(d1); - oshape.push_back(d0 / d1); + if (d2->value == -1) { + oshape.push_back(d0 / d1); + } else { + oshape.push_back(d2); + } } } } if (infer_idx >= 0) { IndexExpr new_size = arith::ComputeReduce(oshape, 1); - IndexExpr old_size = arith::ComputeReduce(data->shape, 1); + IndexExpr old_size = arith::ComputeReduce(data_shape, 1); oshape.Set(infer_idx, old_size / new_size); } - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + + if (param->reverse) { + reporter->Assign(types[1], TensorTypeNode::make( + Array(oshape.rbegin(), oshape.rend()), data->dtype)); + } else { + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + } return true; } @@ -454,16 +472,18 @@ Array ReshapeCompute(const Attrs& attrs, } Expr MakeReshape(Expr data, - Array newshape) { + Array newshape, + bool reverse) { auto attrs = make_node(); attrs->newshape = std::move(newshape); + attrs->reverse = reverse; static const Op& op = Op::Get("reshape"); return CallNode::make(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_API("relay.op._make.reshape") .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeReshape, args, rv); + runtime::detail::unpack_call(MakeReshape, args, rv); }); RELAY_REGISTER_OP("reshape") @@ -516,6 +536,13 @@ Example:: - data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape =(1,2,3,4) - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4) +- If the argument reverse is set to true, then the special values are inferred from right to left. + +Example:: + +- with reverse=false, for input shape = (10,5,4), shape = (-1,0), output shape would be (40,5). +- with reverse=true, output shape will be (50,4). + )code" TVM_ADD_FILELINE) .set_num_inputs(1) .set_attrs_type_key("relay.attrs.ReshapeAttrs") diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index e66a73cf84cb..838c154ef80e 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -152,25 +152,41 @@ def test_reshape_infer_type(): (n, t, 2000), "float32") def test_reshape(): - def verify_reshape(shape, oshape): - x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") - ref_res = np.reshape(x_data, oshape) - + def verify_reshape(shape, newshape, oshape, reverse=False): x = relay.var("x", relay.TensorType(shape, "float32")) - z = relay.reshape(x, newshape=ref_res.shape) + z = relay.reshape(x, newshape=newshape, reverse=reverse) zz = relay.ir_pass.infer_type(z) assert "newshape=" in z.astext() assert zz.checked_type == relay.ty.TensorType(oshape, "float32") func = relay.Function([x], z) - + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + ref_res = np.reshape(x_data, oshape) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - verify_reshape((2, 3, 4), (8, 3)) - verify_reshape((4, 7), (2, 7, 2)) + verify_reshape((2, 3, 4), (8, 3), (8, 3)) + verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) + verify_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2)) + verify_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2), True) + verify_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4)) + verify_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4), True) + verify_reshape((2, 3, 4), (0, -1), (2, 12)) + verify_reshape((2, 3, 4), (0, -1), (3, 8), True) + verify_reshape((2, 3, 4), (-1, 0), (8, 3)) + verify_reshape((2, 3, 4), (-1, 0), (6, 4), True) + verify_reshape((2, 3, 4), (2, -2), (2, 3, 4)) + verify_reshape((2, 3, 4), (-2, 1, 1), (2, 3, 4, 1, 1)) + verify_reshape((2, 3, 4), (-3, 4), (6, 4)) + verify_reshape((2, 3, 4, 5), (-3, -3), (6, 20)) + verify_reshape((2, 3, 4), (0, -3), (2, 12)) + verify_reshape((2, 3, 4), (0, -3), (2, 12), True) + verify_reshape((2, 3, 4), (-3, -2), (6, 4)) + verify_reshape((2, 3, 4), (-4, 1, 2, -2), (1, 2, 3, 4)) + verify_reshape((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4)) + def test_reshape_like_infer_type(): # concrete shape From 10a4117ebb3e0b6cee4929c49271b2a8fa18b3d5 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Thu, 24 Jan 2019 11:49:35 -0800 Subject: [PATCH 2/4] Fix lint and typo --- python/tvm/relay/op/transform.py | 9 +++++---- src/relay/op/tensor/transform.cc | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index cf251bdf55f4..8c5f58bd161b 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -141,15 +141,16 @@ def reshape(data, newshape, reverse=False): Example:: - - data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape =(1,2,3,4) + - data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape = (1,2,3,4) - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4) - - If the argument reverse is set to True, then the special values are inferred from right to left. + - If the argument reverse is set to True, then the special values are inferred + from right to left. Example:: - - with reverse=False, for input shape = (10,5,4), shape = (-1,0), output shape would be (40,5). - - with reverse=True, output shape will be (50,4). + - with reverse = False, data.shape = (10,5,4), newshape = (-1,0), result.shape = (40,5) + - with reverse = True, result.shape = (50,4) Parameters ---------- diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index abba35a11cac..52b940f7d32a 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -540,8 +540,8 @@ Example:: Example:: -- with reverse=false, for input shape = (10,5,4), shape = (-1,0), output shape would be (40,5). -- with reverse=true, output shape will be (50,4). +- with reverse = False, data.shape = (10,5,4), newshape = (-1,0), result.shape = (40,5) +- with reverse = True, result.shape = (50,4) )code" TVM_ADD_FILELINE) .set_num_inputs(1) From 3272f0c5b016d10295df6d57b9976bdde4ab9a15 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Thu, 31 Jan 2019 16:57:45 -0800 Subject: [PATCH 3/4] Put reverse_reshape into a separate op --- docs/langref/relay_op.rst | 4 +- python/tvm/relay/frontend/nnvm_common.py | 4 +- python/tvm/relay/op/_transform.py | 1 + python/tvm/relay/op/transform.py | 46 +++++++++++++++------ src/relay/op/tensor/transform.cc | 52 +++++++++++++++++++----- tests/python/relay/test_op_level10.py | 23 +++++++++++ tests/python/relay/test_op_level3.py | 9 +--- 7 files changed, 106 insertions(+), 33 deletions(-) diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 82272b0249bd..e1f38c61eb1f 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -149,6 +149,7 @@ This level support backpropagation of broadcast operators. It is temporary. tvm.relay.layout_transform tvm.relay.device_copy tvm.relay.annotation.on_device + tvm.relay.reverse_reshape Level 1 Definitions @@ -257,4 +258,5 @@ Level 10 Definitions .. autofunction:: tvm.relay.slice_like .. autofunction:: tvm.relay.layout_transform .. autofunction:: tvm.relay.device_copy -.. autofunction:: tvm.relay.annotation.on_device \ No newline at end of file +.. autofunction:: tvm.relay.annotation.on_device +.. autofunction:: tvm.relay.reverse_reshape diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py index 85e73de3a7dd..3838c3d4aa3b 100644 --- a/python/tvm/relay/frontend/nnvm_common.py +++ b/python/tvm/relay/frontend/nnvm_common.py @@ -24,7 +24,9 @@ def impl(inputs, _, _dtype='float32'): def _reshape(inputs, attrs): shape = attrs.get_int_tuple("shape") reverse = attrs.get_bool("reverse", False) - return _op.reshape(inputs[0], newshape=shape, reverse=reverse) + if reverse: + return _op.reverse_reshape(inputs[0], newshape=shape) + return _op.reshape(inputs[0], newshape=shape) def _init_op(new_op): diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 085a8ceed5d1..abf0b5317b48 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -26,6 +26,7 @@ _reg.register_schedule("take", schedule_injective) _reg.register_schedule("transpose", schedule_injective) _reg.register_schedule("where", schedule_broadcast) +_reg.register_schedule("_contrib_reverse_reshape", schedule_injective) # layout_transform _reg.register_schedule("layout_transform", schedule_injective) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 8c5f58bd161b..5ecf25a0559c 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -92,7 +92,7 @@ def squeeze(data, axis=None): return _make.squeeze(data, axis) -def reshape(data, newshape, reverse=False): +def reshape(data, newshape): """Reshapes the input array. Example:: @@ -144,14 +144,6 @@ def reshape(data, newshape, reverse=False): - data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape = (1,2,3,4) - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4) - - If the argument reverse is set to True, then the special values are inferred - from right to left. - - Example:: - - - with reverse = False, data.shape = (10,5,4), newshape = (-1,0), result.shape = (40,5) - - with reverse = True, result.shape = (50,4) - Parameters ---------- data : relay.Expr @@ -160,9 +152,6 @@ def reshape(data, newshape, reverse=False): newshape : Union[int, Tuple[int], List[int]] The new shape. Should be compatible with the original shape. - reverse : bool - Infer the special values from right to left if set to be True. - Returns ------- result : relay.Expr @@ -170,7 +159,7 @@ def reshape(data, newshape, reverse=False): """ if isinstance(newshape, int): newshape = [newshape] - return _make.reshape(data, list(newshape), reverse) + return _make.reshape(data, list(newshape)) def reshape_like(data, shape_like): @@ -460,3 +449,34 @@ def layout_transform(data, src_layout, dst_layout): The transformed tensor. """ return _make.layout_transform(data, src_layout, dst_layout) + + +def reverse_reshape(data, newshape): + """Reshapes the input array where the special values are inferred from + right to left. + + Example:: + + The special values have the same semantics as :py:class:`tvm.relay.reshape`. + The difference is that special values are inferred from right to left. It + can be explained in the example below:: + + - data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5) + - data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (40,5) + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + newshape : Union[int, Tuple[int], List[int]] + The new shape. Should be compatible with the original shape. + + Returns + ------- + result : relay.Expr + The reshaped result. + """ + if isinstance(newshape, int): + newshape = [newshape] + return _make._contrib_reverse_reshape(data, list(newshape)) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 52b940f7d32a..b38999d1b1b7 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -472,18 +472,17 @@ Array ReshapeCompute(const Attrs& attrs, } Expr MakeReshape(Expr data, - Array newshape, - bool reverse) { + Array newshape) { auto attrs = make_node(); attrs->newshape = std::move(newshape); - attrs->reverse = reverse; + attrs->reverse = false; static const Op& op = Op::Get("reshape"); return CallNode::make(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_API("relay.op._make.reshape") .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeReshape, args, rv); + runtime::detail::unpack_call(MakeReshape, args, rv); }); RELAY_REGISTER_OP("reshape") @@ -536,13 +535,6 @@ Example:: - data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape =(1,2,3,4) - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4) -- If the argument reverse is set to true, then the special values are inferred from right to left. - -Example:: - -- with reverse = False, data.shape = (10,5,4), newshape = (-1,0), result.shape = (40,5) -- with reverse = True, result.shape = (50,4) - )code" TVM_ADD_FILELINE) .set_num_inputs(1) .set_attrs_type_key("relay.attrs.ReshapeAttrs") @@ -1726,5 +1718,43 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] .set_support_level(5) .set_attr("FTVMCompute", LayoutTransformCompute); + +/* relay._contrib_reverse_reshape */ +Expr MakeReverseReshape(Expr data, + Array newshape) { + auto attrs = make_node(); + attrs->newshape = std::move(newshape); + attrs->reverse = true; + static const Op& op = Op::Get("_contrib_reverse_reshape"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make._contrib_reverse_reshape") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeReverseReshape, args, rv); +}); + +RELAY_REGISTER_OP("_contrib_reverse_reshape") +.describe(R"code(Reshapes the input array where the special values are inferred from +right to left. + +Example:: + +The special values have the same semantics as reshape. The difference is that +special values are inferred from right to left. It can be explained in the +example below:: + +- data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5) +- data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (40,5) + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.ReshapeAttrs") +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(10) +.add_type_rel("Reshape", ReshapeRel) +.set_attr("FTVMCompute", ReshapeCompute) +.set_attr("TOpPattern", kInjective); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 2c0ed73a7535..a6e169e23a6c 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -121,8 +121,31 @@ def test_slice_like(): axes=(2, 3), output=(1, 3, 112, 112)) +def test_reverse_reshape(): + def verify_reverse_reshape(shape, newshape, oshape): + x = relay.var("x", relay.TensorType(shape, "float32")) + z = relay.reverse_reshape(x, newshape=newshape) + zz = relay.ir_pass.infer_type(z) + print(zz.checked_type) + assert "newshape=" in z.astext() + assert zz.checked_type == relay.ty.TensorType(oshape, "float32") + + func = relay.Function([x], z) + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + ref_res = np.reshape(x_data, oshape) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + verify_reverse_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2)) + verify_reverse_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4)) + verify_reverse_reshape((2, 3, 4), (0, -1), (3, 8)) + verify_reverse_reshape((2, 3, 4), (-1, 0), (6, 4)) + verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12)) if __name__ == "__main__": test_collapse_sum_like() test_broadcast_to_like() test_slice_like() + test_reverse_reshape() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 838c154ef80e..550637023d43 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -152,9 +152,9 @@ def test_reshape_infer_type(): (n, t, 2000), "float32") def test_reshape(): - def verify_reshape(shape, newshape, oshape, reverse=False): + def verify_reshape(shape, newshape, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) - z = relay.reshape(x, newshape=newshape, reverse=reverse) + z = relay.reshape(x, newshape=newshape) zz = relay.ir_pass.infer_type(z) assert "newshape=" in z.astext() assert zz.checked_type == relay.ty.TensorType(oshape, "float32") @@ -170,19 +170,14 @@ def verify_reshape(shape, newshape, oshape, reverse=False): verify_reshape((2, 3, 4), (8, 3), (8, 3)) verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) verify_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2)) - verify_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2), True) verify_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4)) - verify_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4), True) verify_reshape((2, 3, 4), (0, -1), (2, 12)) - verify_reshape((2, 3, 4), (0, -1), (3, 8), True) verify_reshape((2, 3, 4), (-1, 0), (8, 3)) - verify_reshape((2, 3, 4), (-1, 0), (6, 4), True) verify_reshape((2, 3, 4), (2, -2), (2, 3, 4)) verify_reshape((2, 3, 4), (-2, 1, 1), (2, 3, 4, 1, 1)) verify_reshape((2, 3, 4), (-3, 4), (6, 4)) verify_reshape((2, 3, 4, 5), (-3, -3), (6, 20)) verify_reshape((2, 3, 4), (0, -3), (2, 12)) - verify_reshape((2, 3, 4), (0, -3), (2, 12), True) verify_reshape((2, 3, 4), (-3, -2), (6, 4)) verify_reshape((2, 3, 4), (-4, 1, 2, -2), (1, 2, 3, 4)) verify_reshape((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4)) From 61e3097c706dfec29e041cff39b1570d625f441b Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Thu, 31 Jan 2019 17:26:55 -0800 Subject: [PATCH 4/4] Fix pylint --- python/tvm/relay/op/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 5ecf25a0559c..78efd3cfd4d9 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -458,7 +458,7 @@ def reverse_reshape(data, newshape): Example:: The special values have the same semantics as :py:class:`tvm.relay.reshape`. - The difference is that special values are inferred from right to left. It + The difference is that special values are inferred from right to left. It can be explained in the example below:: - data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5)