Skip to content

Commit

Permalink
[RELAY][OP] strided_slice (#2094)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Nov 13, 2018
1 parent 4369b7f commit 1f2c815
Show file tree
Hide file tree
Showing 15 changed files with 371 additions and 37 deletions.
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ This level enables additional math and transform operators.
tvm.relay.min
tvm.relay.mean
tvm.relay.prod
tvm.relay.strided_slice


**Level 5: Vision/Image Operators**
Expand Down Expand Up @@ -227,6 +228,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.min
.. autofunction:: tvm.relay.mean
.. autofunction:: tvm.relay.prod
.. autofunction:: tvm.relay.strided_slice



Expand Down
15 changes: 15 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,21 @@ struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
}
};

/*! \brief Attributes for StridedSlice operator */
struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
Array<Integer> begin;
Array<Integer> end;
Array<Integer> strides;

TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
TVM_ATTR_FIELD(begin)
.describe("Indices for begin of slice, begin index is also inclusive");
TVM_ATTR_FIELD(end)
.describe("Indices for end of slice, end index is also inclusive");
TVM_ATTR_FIELD(strides).set_default(Array<Integer>({}))
.describe("Stride values of the slice");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
30 changes: 22 additions & 8 deletions nnvm/src/top/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -980,23 +980,25 @@ Examples::
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const StridedSliceParam& param = nnvm::get<StridedSliceParam>(attrs.parsed);
Array<Expr> begin;
Array<Expr> end;
Array<Expr> stride;
Array<Integer> begin;
Array<Integer> end;
Array<Integer> stride;

for (int64_t i : param.begin) {
begin.push_back(tvm::make_const(tvm::Int(32), i));
begin.push_back(static_cast<int>(i));
}

for (int64_t i : param.end) {
end.push_back(tvm::make_const(tvm::Int(32), i));
end.push_back(static_cast<int>(i));
}

for (int64_t i : param.stride) {
stride.push_back(tvm::make_const(tvm::Int(32), i));
stride.push_back(static_cast<int>(i));
}

return Array<Tensor>{ topi::strided_slice(inputs[0], begin, end, stride) };
return Array<Tensor>{
topi::strided_slice(inputs[0], begin, end, stride)
};
})
.set_support_level(1);

Expand Down Expand Up @@ -1210,6 +1212,15 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
return true;
}

// Adapter function to make int array.
Array<Integer> GetIntArray(Array<Expr> arr) {
for (size_t i = 0; i < arr.size(); ++i) {
CHECK(!arr[i].defined() || arr[i].as<IntImm>())
<< "Expect an int array";
}
return Array<Integer>(arr.node_);
}

NNVM_REGISTER_OP(slice_like)
.describe(R"code(Slice the first input respect to the second input.
)code" NNVM_ADD_FILELINE)
Expand Down Expand Up @@ -1261,7 +1272,10 @@ NNVM_REGISTER_OP(slice_like)
}
}
return Array<Tensor>{
topi::strided_slice(inputs[0], begin_idx, end_idx, strides)
topi::strided_slice(inputs[0],
GetIntArray(begin_idx),
GetIntArray(end_idx),
GetIntArray(strides))
};
})
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/_ffi/node_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def convert_to_node(value):
return _api_internal._Map(*vlist)
elif isinstance(value, NodeGeneric):
return value.asnode()
elif value is None:
return None
else:
raise ValueError("don't know how to convert type %s to node" % type(value))

Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

# operator registry
from . import _tensor
from . import _transform
from ..expr import Expr
from ..base import register_relay_node

Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from . import op as _reg
from .op import schedule_injective

# strided_slice
_reg.register_schedule("strided_slice", schedule_injective)
27 changes: 27 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,30 @@ def split(data, indices_or_sections, axis=0):
else:
ret_size = len(indices_or_sections) + 1
return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size)


def strided_slice(data, begin, end, strides=None):
"""Strided slice of an array..
Parameters
----------
data : relay.Expr
The source array to be sliced.
begin: list of int
The indices to begin with in the slicing.
end: list of int
Indicies indicating end of the slice.
strides: list of int, optional
Specifies the stride values, it can be negative in that case,
the input tensor will be reversed in that particular axis.
Returns
-------
ret : relay.Expr
The computed result.
"""
strides = strides or []
return _make.strided_slice(data, list(begin), list(end), list(strides))
6 changes: 5 additions & 1 deletion src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ TVM_REGISTER_API("_Array")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<NodePtr<Node> > data;
for (int i = 0; i < args.size(); ++i) {
data.push_back(args[i].node_sptr());
if (args[i].type_code() != kNull) {
data.push_back(args[i].node_sptr());
} else {
data.push_back(NodePtr<Node>(nullptr));
}
}
auto node = make_node<ArrayNode>();
node->data = std::move(data);
Expand Down
6 changes: 5 additions & 1 deletion src/relay/ir/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,11 @@ class TextPrinter :
* \param os The output type.
*/
void PrintAttr(const NodeRef& value, std::ostream& os) { // NOLINT(*)
this->VisitAttr(value, os);
if (value.defined()) {
this->VisitAttr(value, os);
} else {
os << "None";
}
}
//------------------------------------
// Overload of Attr printing functions
Expand Down
168 changes: 168 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h>
#include <tvm/ir.h>
#include <topi/transform.h>
#include <vector>
#include "../op_common.h"

Expand Down Expand Up @@ -890,6 +891,173 @@ RELAY_REGISTER_OP("broadcast_to_like")
.set_support_level(10)
.add_type_rel("BroadCastToLike", BroadCastToLikeRel);


// strided_slice
TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
bool StridedSliceRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;

const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
CHECK(param != nullptr);

auto dshape = data->shape;
auto num_axis = dshape.size();

std::vector<int64_t> stride_vec;
for (Integer i : param->strides) {
CHECK(i.defined());
stride_vec.push_back(i->value);
}
for (size_t i = stride_vec.size(); i < num_axis; ++i) {
stride_vec.push_back(1);
}
const int64_t max_range = std::numeric_limits<int64_t>::max();

std::vector<int64_t> begin_vec;
for (size_t i = 0; i < param->begin.size(); ++i) {
if (!param->begin[i].defined()) {
// value=None
begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
} else {
begin_vec.push_back(param->begin[i]->value);
}
}
for (size_t i = begin_vec.size(); i < num_axis; ++i) {
begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
}

std::vector<int64_t> end_vec;
for (size_t i = 0; i < param->end.size(); ++i) {
// allow end to be None
if (!param->end[i].defined()) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
} else {
end_vec.push_back(param->end[i]->value);
}
}
for (size_t i = end_vec.size(); i < num_axis; ++i) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
}

std::vector<IndexExpr> oshape(dshape.size());
for (size_t i = 0; i < num_axis; ++i) {
int64_t stride_v = stride_vec[i];
int64_t begin_v = begin_vec[i];
int64_t end_v = end_vec[i];

if ((stride_v == 1 &&
begin_v == 0 &&
end_v == max_range) ||
(stride_v == -1 &&
begin_v == max_range &&
end_v == 0)) {
// Quick path, do not slice this dimension.
oshape[i] = dshape[i];
continue;
}
// Normal path, require the shape to be concrete integer.
// Require concrete integer as symbolic inference of min/max
// can get complicated and not very helpful.
const int64_t* p_dim_size = as_const_int(dshape[i]);
CHECK(p_dim_size)
<< "strided_slice requires sliced dimension to be concrete int";
int64_t dim_size = p_dim_size[0];
begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v;
end_v = (end_v < 0) ? dim_size + end_v : end_v;

int64_t slice_range, step;
if (stride_v < 0) {
if (end_v < -1) end_v = -1;
CHECK_LT(end_v, begin_v)
<< "strided_slice get empty slice at axis " << i;
begin_v = std::min(dim_size - 1, begin_v);
slice_range = begin_v - end_v;
step = -stride_v;
} else {
if (begin_v < 0) begin_v = 0;
CHECK_GE(stride_v, 0);
CHECK_LT(begin_v, end_v)
<< "strided_slice get empty slice at axis " << i;
end_v = std::min(dim_size, end_v);
slice_range = end_v - begin_v;
step = stride_v;
}
oshape[i] = make_const(dshape[i].type(), (slice_range + step - 1) / step);
}
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}


// Positional relay function to create StridedSlice operator used by frontend FFI.
Expr MakeStridedSlice(Expr data,
Array<Integer> begin,
Array<Integer> end,
Array<Integer> strides) {
auto attrs = make_node<StridedSliceAttrs>();
attrs->begin = std::move(begin);
attrs->end = std::move(end);
attrs->strides = std::move(strides);
static const Op& op = Op::Get("strided_slice");
return CallNode::make(op, {data}, Attrs(attrs), {});
}

Array<Tensor> StridedSliceCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
CHECK(param != nullptr);
return Array<Tensor>{
topi::strided_slice(inputs[0], param->begin, param->end, param->strides)
};
}


TVM_REGISTER_API("relay.op._make.strided_slice")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeStridedSlice, args, rv);
});


RELAY_REGISTER_OP("strided_slice")
.describe(R"code(Strided slice of an array.
Examples::
x = [[ 1., 4., 7., 10.],
[ 2., 5., 8., 11.],
[ 3., 6., 9., 12.]]
strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4., 7., 10.],
[ 5., 8., 11.]]
x = [[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]]]
strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4)
.set_attrs_type_key("relay.attrs.StridedSliceAttrs")
.add_type_rel("StridedSlice", StridedSliceRel)
.set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);


// Split
TVM_REGISTER_NODE_TYPE(SplitAttrs);

Expand Down
Loading

0 comments on commit 1f2c815

Please sign in to comment.