diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 405f071e3283..e99ac3c97f73 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -123,6 +123,7 @@ This level enables additional math and transform operators. tvm.relay.min tvm.relay.mean tvm.relay.prod + tvm.relay.strided_slice **Level 5: Vision/Image Operators** @@ -227,6 +228,7 @@ Level 4 Definitions .. autofunction:: tvm.relay.min .. autofunction:: tvm.relay.mean .. autofunction:: tvm.relay.prod +.. autofunction:: tvm.relay.strided_slice diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index cb87d358e966..4d2008628d3a 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -123,6 +123,21 @@ struct SplitAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for StridedSlice operator */ +struct StridedSliceAttrs : public tvm::AttrsNode { + Array begin; + Array end; + Array strides; + + TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { + TVM_ATTR_FIELD(begin) + .describe("Indices for begin of slice, begin index is also inclusive"); + TVM_ATTR_FIELD(end) + .describe("Indices for end of slice, end index is also inclusive"); + TVM_ATTR_FIELD(strides).set_default(Array({})) + .describe("Stride values of the slice"); + } +}; } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 4d08bf761326..2f42727d6083 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -980,23 +980,25 @@ Examples:: const Array& inputs, const Array& out_info) { const StridedSliceParam& param = nnvm::get(attrs.parsed); - Array begin; - Array end; - Array stride; + Array begin; + Array end; + Array stride; for (int64_t i : param.begin) { - begin.push_back(tvm::make_const(tvm::Int(32), i)); + begin.push_back(static_cast(i)); } for (int64_t i : param.end) { - end.push_back(tvm::make_const(tvm::Int(32), i)); + end.push_back(static_cast(i)); } for (int64_t i : param.stride) { - stride.push_back(tvm::make_const(tvm::Int(32), i)); + stride.push_back(static_cast(i)); } - return Array{ topi::strided_slice(inputs[0], begin, end, stride) }; + return Array{ + topi::strided_slice(inputs[0], begin, end, stride) + }; }) .set_support_level(1); @@ -1210,6 +1212,15 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs, return true; } +// Adapter function to make int array. +Array GetIntArray(Array arr) { + for (size_t i = 0; i < arr.size(); ++i) { + CHECK(!arr[i].defined() || arr[i].as()) + << "Expect an int array"; + } + return Array(arr.node_); +} + NNVM_REGISTER_OP(slice_like) .describe(R"code(Slice the first input respect to the second input. )code" NNVM_ADD_FILELINE) @@ -1261,7 +1272,10 @@ NNVM_REGISTER_OP(slice_like) } } return Array{ - topi::strided_slice(inputs[0], begin_idx, end_idx, strides) + topi::strided_slice(inputs[0], + GetIntArray(begin_idx), + GetIntArray(end_idx), + GetIntArray(strides)) }; }) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { diff --git a/python/tvm/_ffi/node_generic.py b/python/tvm/_ffi/node_generic.py index b7230f29da59..e86453499faa 100644 --- a/python/tvm/_ffi/node_generic.py +++ b/python/tvm/_ffi/node_generic.py @@ -56,6 +56,8 @@ def convert_to_node(value): return _api_internal._Map(*vlist) elif isinstance(value, NodeGeneric): return value.asnode() + elif value is None: + return None else: raise ValueError("don't know how to convert type %s to node" % type(value)) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 9b581486608b..30aef433d7c6 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -13,6 +13,7 @@ # operator registry from . import _tensor +from . import _transform from ..expr import Expr from ..base import register_relay_node diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py new file mode 100644 index 000000000000..7867336d033f --- /dev/null +++ b/python/tvm/relay/op/_transform.py @@ -0,0 +1,8 @@ +#pylint: disable=invalid-name, unused-argument +"""Backend compiler related feature registration""" +from __future__ import absolute_import +from . import op as _reg +from .op import schedule_injective + +# strided_slice +_reg.register_schedule("strided_slice", schedule_injective) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 909b175f08ca..e43a4a573e54 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -334,3 +334,30 @@ def split(data, indices_or_sections, axis=0): else: ret_size = len(indices_or_sections) + 1 return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) + + +def strided_slice(data, begin, end, strides=None): + """Strided slice of an array.. + + Parameters + ---------- + data : relay.Expr + The source array to be sliced. + + begin: list of int + The indices to begin with in the slicing. + + end: list of int + Indicies indicating end of the slice. + + strides: list of int, optional + Specifies the stride values, it can be negative in that case, + the input tensor will be reversed in that particular axis. + + Returns + ------- + ret : relay.Expr + The computed result. + """ + strides = strides or [] + return _make.strided_slice(data, list(begin), list(end), list(strides)) diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 75365da5bf50..3525e23b8b20 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -47,7 +47,11 @@ TVM_REGISTER_API("_Array") .set_body([](TVMArgs args, TVMRetValue* ret) { std::vector > data; for (int i = 0; i < args.size(); ++i) { - data.push_back(args[i].node_sptr()); + if (args[i].type_code() != kNull) { + data.push_back(args[i].node_sptr()); + } else { + data.push_back(NodePtr(nullptr)); + } } auto node = make_node(); node->data = std::move(data); diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 93ed76bed3c2..bfc5f0db52b7 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -403,7 +403,11 @@ class TextPrinter : * \param os The output type. */ void PrintAttr(const NodeRef& value, std::ostream& os) { // NOLINT(*) - this->VisitAttr(value, os); + if (value.defined()) { + this->VisitAttr(value, os); + } else { + os << "None"; + } } //------------------------------------ // Overload of Attr printing functions diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 20e0e3adbfd3..98ac1c30b66c 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "../op_common.h" @@ -890,6 +891,173 @@ RELAY_REGISTER_OP("broadcast_to_like") .set_support_level(10) .add_type_rel("BroadCastToLike", BroadCastToLikeRel); + +// strided_slice +TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); +bool StridedSliceRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const StridedSliceAttrs *param = attrs.as(); + CHECK(param != nullptr); + + auto dshape = data->shape; + auto num_axis = dshape.size(); + + std::vector stride_vec; + for (Integer i : param->strides) { + CHECK(i.defined()); + stride_vec.push_back(i->value); + } + for (size_t i = stride_vec.size(); i < num_axis; ++i) { + stride_vec.push_back(1); + } + const int64_t max_range = std::numeric_limits::max(); + + std::vector begin_vec; + for (size_t i = 0; i < param->begin.size(); ++i) { + if (!param->begin[i].defined()) { + // value=None + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + } else { + begin_vec.push_back(param->begin[i]->value); + } + } + for (size_t i = begin_vec.size(); i < num_axis; ++i) { + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + } + + std::vector end_vec; + for (size_t i = 0; i < param->end.size(); ++i) { + // allow end to be None + if (!param->end[i].defined()) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else { + end_vec.push_back(param->end[i]->value); + } + } + for (size_t i = end_vec.size(); i < num_axis; ++i) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } + + std::vector oshape(dshape.size()); + for (size_t i = 0; i < num_axis; ++i) { + int64_t stride_v = stride_vec[i]; + int64_t begin_v = begin_vec[i]; + int64_t end_v = end_vec[i]; + + if ((stride_v == 1 && + begin_v == 0 && + end_v == max_range) || + (stride_v == -1 && + begin_v == max_range && + end_v == 0)) { + // Quick path, do not slice this dimension. + oshape[i] = dshape[i]; + continue; + } + // Normal path, require the shape to be concrete integer. + // Require concrete integer as symbolic inference of min/max + // can get complicated and not very helpful. + const int64_t* p_dim_size = as_const_int(dshape[i]); + CHECK(p_dim_size) + << "strided_slice requires sliced dimension to be concrete int"; + int64_t dim_size = p_dim_size[0]; + begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; + end_v = (end_v < 0) ? dim_size + end_v : end_v; + + int64_t slice_range, step; + if (stride_v < 0) { + if (end_v < -1) end_v = -1; + CHECK_LT(end_v, begin_v) + << "strided_slice get empty slice at axis " << i; + begin_v = std::min(dim_size - 1, begin_v); + slice_range = begin_v - end_v; + step = -stride_v; + } else { + if (begin_v < 0) begin_v = 0; + CHECK_GE(stride_v, 0); + CHECK_LT(begin_v, end_v) + << "strided_slice get empty slice at axis " << i; + end_v = std::min(dim_size, end_v); + slice_range = end_v - begin_v; + step = stride_v; + } + oshape[i] = make_const(dshape[i].type(), (slice_range + step - 1) / step); + } + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + + +// Positional relay function to create StridedSlice operator used by frontend FFI. +Expr MakeStridedSlice(Expr data, + Array begin, + Array end, + Array strides) { + auto attrs = make_node(); + attrs->begin = std::move(begin); + attrs->end = std::move(end); + attrs->strides = std::move(strides); + static const Op& op = Op::Get("strided_slice"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +Array StridedSliceCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const StridedSliceAttrs *param = attrs.as(); + CHECK(param != nullptr); + return Array{ + topi::strided_slice(inputs[0], param->begin, param->end, param->strides) + }; +} + + +TVM_REGISTER_API("relay.op._make.strided_slice") + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeStridedSlice, args, rv); + }); + + +RELAY_REGISTER_OP("strided_slice") + .describe(R"code(Strided slice of an array. + +Examples:: + + x = [[ 1., 4., 7., 10.], + [ 2., 5., 8., 11.], + [ 3., 6., 9., 12.]] + + strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4., 7., 10.], + [ 5., 8., 11.]] + + x = [[[ 1., 2.], + [ 3., 4.]], + + [[ 5., 6.], + [ 7., 8.]]] + + strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 2.], + [ 3., 4.]], + + [[ 5., 6.], + [ 7., 8.]]] +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(4) +.set_attrs_type_key("relay.attrs.StridedSliceAttrs") +.add_type_rel("StridedSlice", StridedSliceRel) +.set_attr("FTVMCompute", StridedSliceCompute) +.set_attr("TOpPattern", kInjective); + + // Split TVM_REGISTER_NODE_TYPE(SplitAttrs); diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 6fd70c386567..dd12dc7cff3a 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -2,7 +2,7 @@ import numpy as np from tvm import relay from tvm.relay.testing import ctx_list - +import topi.testing def test_binary_op(): def check_binary_op(opfunc, ref): @@ -142,7 +142,43 @@ def test_reduce_functions(): verify_reduce(func, (128, 24, 128), (0, 1), True, False, (1, 1, 128)) verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1)) + +def test_strided_slice(): + def verify(dshape, begin, end, strides, output, test_ref=True): + x = relay.var("x", relay.TensorType(dshape, "float32")) + z = relay.strided_slice(x, begin=begin, end=end, strides=strides) + func = relay.Function([x], z) + func = relay.ir_pass.infer_type(func) + text = func.astext() + assert "begin=" in text + assert "end=" in text + if output: + assert func.body.checked_type == relay.ty.TensorType(output, "float32") + if not test_ref: + return + x_data = np.random.uniform(size=dshape).astype("float32") + ref_res = topi.testing.strided_slice_python( + x_data, begin, end, strides) + for target, ctx in ctx_list(): + intrp = relay.create_executor("graph", ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) + + d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") + verify((d1, d2, 3), [None, None, 1], [None, None, 2], None, (d1, d2, 1), False) + verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3)) + verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) + verify((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2)) + verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3)) + + if __name__ == "__main__": + test_strided_slice() test_binary_op() test_cmp_type() test_binary_int_broadcast() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 7fc408c2c79c..cb09f1cb419e 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -10,6 +10,7 @@ #include #include #include +#include #include "topi/tags.h" #include "topi/detail/ravel_unravel.h" @@ -403,31 +404,51 @@ inline Array split(const Tensor& x, * \return A Tensor whose op member is the split operation */ inline Tensor strided_slice(const Tensor& x, - const Array& begin, - const Array& end, - const Array& strides, + const Array& begin, + const Array& end, + const Array& strides, std::string name = "tensor", std::string tag = kInjective) { size_t src_tensor_dim = static_cast(x->shape.size()); - std::vector begin_vec = GetConstInt64Values(begin, "begin"); - std::vector end_vec = GetConstInt64Values(end, "end"); - std::vector stride_vec = GetConstInt64Values(strides, "strides"); - // in case user has not provided begin indices for all the axes, - // then inflate it with default value = 0 - for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) { - begin_vec.push_back(0); - } - // in case user has not provided end indices for all the axes, - // then inflate it with default value = input_tensor.shape[axis] - for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) { - end_vec.push_back(GetConstInt(x->shape[i])); + // Setup the ranges. + // NOTE: this code duplicates the shape inference logic relay.op + // Consider to refactor in the future. + std::vector stride_vec; + for (Integer i : strides) { + CHECK(i.defined()); + stride_vec.push_back(i->value); } - // in case user has not provided stride values, - // then inflate it with default value = 1 for (size_t i = stride_vec.size(); i < src_tensor_dim; ++i) { stride_vec.push_back(1); } + const int64_t max_range = std::numeric_limits::max(); + + std::vector begin_vec; + for (size_t i = 0; i < begin.size(); ++i) { + if (!begin[i].defined()) { + // value=None + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + } else { + begin_vec.push_back(begin[i]->value); + } + } + for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) { + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + } + std::vector end_vec; + for (size_t i = 0; i < end.size(); ++i) { + // allow end to be None + if (!end[i].defined()) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else { + end_vec.push_back(end[i]->value); + } + } + for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } + // Compute Array out_shape; Array begin_expr; Array strides_expr; diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 8a3269ba83ae..c496e08c1835 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -19,3 +19,4 @@ from .lrn_python import lrn_python from .l2_normalize_python import l2_normalize_python from .gather_nd_python import gather_nd_python +from .strided_slice_python import strided_slice_python diff --git a/topi/python/topi/testing/strided_slice_python.py b/topi/python/topi/testing/strided_slice_python.py new file mode 100644 index 000000000000..4407b3bec1c7 --- /dev/null +++ b/topi/python/topi/testing/strided_slice_python.py @@ -0,0 +1,32 @@ +"""gather_nd in python""" + +def strided_slice_python(data, begin, end, strides): + """Python version of strided slice operator. + + Parameters + ---------- + data : numpy.ndarray + Input data + + begin : list + Begining of the slices. + + end : list + End of the slices. + + strides : list + The stride of each slice. + + Returns + ------- + result : numpy.ndarray + The sliced result. + """ + strides = [] if strides is None else strides + slices = [] + for i in range(len(data.shape)): + slices.append(slice( + begin[i] if i < len(begin) else None, + end[i] if i < len(end) else None, + strides[i] if i < len(strides) else None)) + return data[tuple(slices)] diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 75e4d3b675b0..dc3c3fb70b24 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -249,13 +249,11 @@ def check_device(device): for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]: check_device(device) -def verify_strided_slice(in_shape, begin, end, stride=None): - stride = stride if stride else [1, 1, 1] +def verify_strided_slice(in_shape, begin, end, strides=None): A = tvm.placeholder(shape=in_shape, name="A") - B = topi.strided_slice(A, begin, end, stride) + 1 - def test_forward(x, begin, end, stride): - return x[begin[0]:end[0]:stride[0], - begin[1]:end[1]:stride[1], begin[2]:end[2]:stride[2]] + 1 + strides = [1,1,1] if strides is None else strides + B = topi.strided_slice(A, begin, end, strides) + 1 + def check_device(device): ctx = tvm.context(device, 0) if not ctx.exist: @@ -267,7 +265,8 @@ def check_device(device): foo = tvm.build(s, [A, B], device, name="stride_slice") x_np = np.random.uniform(size=in_shape).astype(A.dtype) - out_npy = test_forward(x_np, begin, end, stride) + out_npy = topi.testing.strided_slice_python( + x_np, begin, end, strides) + 1 data_nd = tvm.nd.array(x_np, ctx) out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype) foo(data_nd, out_nd) @@ -298,7 +297,7 @@ def check_device(device): shape_size = shape_size * src_shape[i] data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape)) out_npys = topi.testing.gather_nd_python(data_npy, indices_src) - + data_nd = tvm.nd.array(data_npy, ctx) indices_nd = tvm.nd.array(indices_src, ctx) out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype) @@ -412,6 +411,7 @@ def test_gather_nd(): indices_dtype) if __name__ == "__main__": + test_strided_slice() test_concatenate() test_tranpose() test_expand_dims() @@ -421,5 +421,4 @@ def test_gather_nd(): test_flip() test_expand_like() test_take() - test_strided_slice() test_gather_nd()