change slice_mode to string
yongwww committed Jun 6, 2020
1 parent 4b7a924 commit 274ec31
Showing 14 changed files with 78 additions and 81 deletions.
15 changes: 9 additions & 6 deletions include/tvm/relay/attrs/transform.h
@@ -213,18 +213,21 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
Optional<Array<Integer>> begin;
Optional<Array<Integer>> end;
Optional<Array<Integer>> strides;
bool slice_mode;
std::string slice_mode;

TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive");
TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive");
TVM_ATTR_FIELD(strides).describe("Stride values of the slice");
TVM_ATTR_FIELD(strides)
.describe(
"Stride values of the slice, a stride can be negative, which causes a reverse slice.");
TVM_ATTR_FIELD(slice_mode)
.set_default(false)
.set_default("end")
.describe(
"Specifies whether to enable slice mode. In slice mode,"
"strides will be ignored, end indicates the size of a slice"
"starting at the location specified by begin. If end[i] is -1,"
"The slice mode [end, size]."
"end - The default slice mode, ending indices for the slice."
"size - The input strides will be ignored, input end in this mode indicates the size"
"of a slice starting at the location specified by begin. If end[i] is -1,"
"all remaining elements in that dimension are included in the slice");
}
};
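To make the new string attribute concrete, here is a minimal NumPy sketch of the two modes described above; the helper name and structure are illustrative, not part of this commit, and it ignores negative begins and negative strides.

import numpy as np

def ref_strided_slice(data, begin, end, strides=None, slice_mode="end"):
    # slice_mode="end": end holds exclusive stop indices and strides are honored.
    # slice_mode="size": end holds per-axis lengths, strides are ignored (treated as 1),
    # and end[i] == -1 keeps everything from begin[i] to the end of that axis.
    slices = []
    for i, b in enumerate(begin):
        if slice_mode == "size":
            stop = None if end[i] == -1 else b + end[i]
            slices.append(slice(b, stop, 1))
        else:
            step = 1 if strides is None else strides[i]
            slices.append(slice(b, end[i], step))
    return data[tuple(slices)]

x = np.arange(12).reshape(3, 4)
print(ref_strided_slice(x, [1, 0], [3, 2]).shape)                      # (2, 2)
print(ref_strided_slice(x, [1, 1], [-1, 2], slice_mode="size").shape)  # (2, 2)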
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/pytorch.py
@@ -247,7 +247,7 @@ def _impl(inputs, input_types):
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(strides),
slice_mode=True)
slice_mode="size")
return _impl

def _split():
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/tensorflow.py
@@ -666,7 +666,7 @@ def _impl(inputs, attr, params, mod):

# slice to get the dynamic result
ret = get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]),
end=size, slice_mode=True)
end=size, slice_mode="size")
return ret
return _impl

@@ -1188,7 +1188,7 @@ def _impl(inputs, attr, params, mod):
size = _infer_value(inputs[2], params).asnumpy().tolist()
except Exception:
size = inputs[2]
return _op.strided_slice(inputs[0], begin=begin, end=size, slice_mode=True)
return _op.strided_slice(inputs[0], begin=begin, end=size, slice_mode="size")
return _impl
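TensorFlow's Slice op already takes (begin, size) pairs, which is why the converter can forward size directly as end with slice_mode="size". A quick NumPy check of the intended equivalence, with values chosen here only for illustration:

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
begin, size = [0, 1, 1], [-1, 2, 2]   # size == -1 keeps the rest of that axis

# tf.slice(x, begin, size) selects the same block as a "size"-mode strided_slice:
expected = x[0:, 1:1 + 2, 1:1 + 2]
assert expected.shape == (2, 2, 2)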


2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
@@ -163,7 +163,7 @@ def strided_slice_shape_func(attrs, inputs, _):
"""
Shape func for strided_slice
"""
slice_mode = convert(get_const_int(attrs.slice_mode))
slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
# data independent if begin, end and strides exist
if attrs.begin and attrs.end and attrs.strides:
return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin, attrs.end,
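The hybrid shape function cannot take a string argument, so the mode is lowered to an integer flag: 0 for "end", 1 for "size". A rough sketch of the per-axis extent math that the flag selects (illustrative only; positive strides, non-negative begin):

def out_extent(dim, begin, end, stride, slice_mode_flag):
    if slice_mode_flag == 1:    # "size": end is a length, stride forced to 1
        stop = dim if end < 0 else begin + end
        stride = 1
    else:                       # "end": end is an exclusive stop index
        stop = min(end, dim)
    return max(0, (stop - begin + stride - 1) // stride)

assert out_extent(10, 2, 6, 2, 0) == 2   # x[2:6:2] has 2 elements
assert out_extent(10, 2, 6, 2, 1) == 6   # 6 elements starting at 2, stride ignored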
13 changes: 7 additions & 6 deletions python/tvm/relay/op/transform.py
@@ -611,7 +611,7 @@ def split(data, indices_or_sections, axis=0):
return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size)


def strided_slice(data, begin, end, strides=None, slice_mode=False):
def strided_slice(data, begin, end, strides=None, slice_mode="end"):
"""Strided slice of an array.
Parameters
@@ -629,11 +629,12 @@ def strided_slice(data, begin, end, strides=None, slice_mode=False):
Specifies the stride values, it can be negative in that case,
the input tensor will be reversed in that particular axis.
slice_mode: boolean, optional
Specifies whether to enable slice mode. In slice mode,
strides will be ignored, end indicates the size of a slice
starting at the location specified by begin. If end[i] is -1,
all remaining elements in that dimension are included in the slice
slice_mode: str, optional
The slice mode [end, size].
end: The ending indices for the slice [default].
size: The input strides will be ignored, input end in this mode indicates
the size of a slice starting at the location specified by begin. If end[i]
is -1, all remaining elements in that dimension are included in the slice.
Returns
-------
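For usage, the updated tests below exercise the new attribute; the following sketch simply mirrors those calls rather than introducing any new API:

from tvm import relay

x = relay.var("x", shape=(3, 4, 3), dtype="float32")

# Default "end" mode: end holds exclusive stop indices.
y_end = relay.strided_slice(x,
                            begin=relay.const([1, 0, 0], "int64"),
                            end=relay.const([3, 2, 3], "int64"),
                            strides=relay.const([1, 1, 1], "int64"))

# "size" mode: end holds per-axis lengths; -1 keeps the remainder of that axis.
y_size = relay.strided_slice(x,
                             begin=relay.const([1, 0, 0], "int64"),
                             end=relay.const([-1, 2, 3], "int64"),
                             strides=relay.const([1, 1, 1], "int64"),
                             slice_mode="size")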
18 changes: 10 additions & 8 deletions src/relay/op/tensor/transform.cc
@@ -1689,7 +1689,7 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
if (param->begin && param->end && param->strides) {
// stride will be set as 1 if slice mode is enabled
std::vector<int64_t> stride_vec(num_axis, 1);
if (!param->slice_mode) {
if (param->slice_mode == "end") {
for (size_t i = 0; i < param->strides.value().size(); ++i) {
CHECK(param->strides.value()[i].defined());
stride_vec[i] = param->strides.value()[i]->value;
@@ -1713,14 +1713,16 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
// allow end to be None
if (!param->end.value()[i].defined()) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
} else if (param->slice_mode) {
} else if (param->slice_mode == "size") {
if (param->end.value()[i]->value < 0) {
end_vec.push_back(max_range);
} else {
end_vec.push_back(begin_vec[i] + param->end.value()[i]->value);
}
} else {
} else if (param->slice_mode == "end") {
end_vec.push_back(param->end.value()[i]->value);
} else {
LOG(FATAL) << "Unsupported slice mode: " << param->slice_mode;
}
}
for (int64_t i = end_vec.size(); i < num_axis; ++i) {
@@ -1805,7 +1807,7 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
if (params->begin && params->end && params->strides) {
for (Integer i : params->strides.value()) {
CHECK(i.defined());
strides.push_back(params->slice_mode ? 1 : i->value);
strides.push_back(params->slice_mode == "size" ? 1 : i->value);
}

for (Integer i : params->begin.value()) {
@@ -1842,7 +1844,7 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
int64_t ed;
if (!end[i].defined()) {
ed = shape[i].as<IntImmNode>()->value;
} else if (params->slice_mode) {
} else if (params->slice_mode == "size") {
if (end[i]->value < 0) {
ed = shape[i].as<IntImmNode>()->value;
} else {
@@ -1918,7 +1920,7 @@ Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor
}

// Positional relay function to create StridedSlice operator used by frontend FFI.
Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, bool slice_mode) {
Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, String slice_mode) {
auto attrs = make_object<StridedSliceAttrs>();
const ConstantNode *cbegin, *cend, *cstrides;
if ((cbegin = begin.as<ConstantNode>()) && (cend = end.as<ConstantNode>()) &&
@@ -1970,7 +1972,7 @@ Examples::
.add_argument("begin", "Tensor", "The indices to begin with in the slicing.")
.add_argument("end", "Tensor", "Indices indicating end of the slice.")
.add_argument("strides", "Tensor", "The stride values.")
.add_argument("slice_mode", "Tensor", "Whether to enable slice mode.")
.add_argument("slice_mode", "Tensor", "The slice mode.")
.set_support_level(4)
.set_attrs_type<StridedSliceAttrs>()
.add_type_rel("StridedSlice", StridedSliceRel)
@@ -2230,7 +2232,7 @@ Array<te::Tensor> SliceLikeCompute(const Attrs& attrs, const Array<te::Tensor>&
}
}
return Array<te::Tensor>{topi::strided_slice(inputs[0], GetIntArray(begin_idx),
GetIntArray(end_idx), GetIntArray(strides), false)};
GetIntArray(end_idx), GetIntArray(strides), "end")};
}

TVM_REGISTER_GLOBAL("relay.op._make.slice_like").set_body_typed(MakeSliceLike);
30 changes: 9 additions & 21 deletions src/relay/transforms/combine_parallel_conv2d.cc
@@ -168,36 +168,24 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
for (const auto& branch : branches) {
const CallNode* conv2d = branch[0];
int64_t channels = GetConv2DSuperChannelsDim(conv2d);
Array<Integer> begin;
Array<Integer> end;
std::vector<int64_t> begin;
std::vector<int64_t> end;
for (size_t i = 0; i < channel_pos_; i++) {
begin.push_back(0);
end.push_back(-1);
}
begin.push_back(index);
index += channels;
end.push_back(index);
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
auto begin_ndarray = runtime::NDArray::Empty({int64_t(begin.size())}, DataType::Int(64), ctx);
auto end_ndarray = runtime::NDArray::Empty({int64_t(begin.size())}, DataType::Int(64), ctx);
auto strides_ndarray =
runtime::NDArray::Empty({int64_t(begin.size())}, DataType::Int(64), ctx);

auto* begin_data = static_cast<int64_t*>(begin_ndarray->data);
auto* end_data = static_cast<int64_t*>(end_ndarray->data);
auto* strides_data = static_cast<int64_t*>(strides_ndarray->data);

std::vector<int64_t> strides(begin.size(), 1);
for (size_t i = 0; i < begin.size(); ++i) {
begin_data[i] = begin[i];
end_data[i] = end[i];
end_data[i] -= begin_data[i];
strides_data[i] = 1;
end[i] -= begin[i];
}

auto slice = MakeStridedSlice(data, Constant(begin_ndarray), Constant(end_ndarray),
Constant(strides_ndarray), true);
std::vector<int64_t> ndarray_shape = {static_cast<int64_t>(begin.size())};
Constant begin_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, begin);
Constant end_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, end);
Constant strides_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, strides);
auto slice = MakeStridedSlice(data, begin_const, end_const, strides_const, "size");
subst_map->insert({GetRef<Expr>(branch[depth]), slice});
}
}
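The rewritten pass now builds the begin/end/strides constants with MakeConstantTensor and slices in "size" mode: after the end[i] -= begin[i] adjustment, end becomes [-1, ..., channels], so each branch keeps exactly its own number of output channels starting at its offset. A small NumPy illustration of the resulting slice, with shapes chosen arbitrarily:

import numpy as np

combined = np.zeros((1, 48, 7, 7))   # channel_pos_ == 1 (NCHW), three 16-channel branches fused
start, channels = 16, 16             # the second branch starts at output channel 16

# "size"-mode slice emitted by the pass: begin = [0, 16], end = [-1, 16]
branch_out = combined[0:, start:start + channels]
assert branch_out.shape == (1, 16, 7, 7)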
2 changes: 1 addition & 1 deletion src/relay/transforms/pattern_util.h
@@ -673,7 +673,7 @@ Expr MakeConcatenate(Expr data, int axis);

Expr MakeRepeat(Expr data, int repeats, int axis);

Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, bool slice_mode);
Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, String slice_mode);

Expr MakeStack(Expr data, int axis);

6 changes: 3 additions & 3 deletions tests/python/relay/test_any.py
@@ -644,12 +644,12 @@ def test_arange_with_dynamic_shape():
tvm.testing.assert_allclose(result.asnumpy(), np.array(range(10)).astype("int32")+1)

def verify_any_strided_slice(data_shape, begin_shape, end_shape, strides_shape,
data_np_shape, slice_mode=False, const_attrs=False):
data_np_shape, slice_mode="end", const_attrs=False):
# Generate random numpy input data
np_data = np.random.uniform(size=data_np_shape).astype('float32')
np_begin = np.random.randint(2, size=begin_shape, dtype="int32")
np_end = np.random.randint(5, 10, size=end_shape, dtype="int32")
np_strides = np.random.randint(1, 2 if slice_mode else 3, size=strides_shape, dtype="int32")
np_strides = np.random.randint(1, 2 if slice_mode == "size" else 3, size=strides_shape, dtype="int32")
# target numpy result
ref_res = topi.testing.strided_slice_python(np_data, np_begin, np_end, np_strides, slice_mode)

@@ -685,7 +685,7 @@ def test_any_strided_slice():
verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21))
verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (23, 29, 41))
verify_any_strided_slice(any_dims(4), (4,), (4,), (4,), (40, 50, 60, 70))
verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21), slice_mode=True)
verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21), slice_mode="size")
verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21), const_attrs=True)


6 changes: 3 additions & 3 deletions tests/python/relay/test_op_level4.py
@@ -296,7 +296,7 @@ def test_mean_var_std():


def test_strided_slice():
def verify(dshape, begin, end, strides, output, slice_mode=False,
def verify(dshape, begin, end, strides, output, slice_mode="end",
attr_const=True, test_ref=True, dtype="int32"):
x = relay.var("x", relay.TensorType(dshape, "float32"))
ndim = len(dshape)
@@ -355,9 +355,9 @@ def verify(dshape, begin, end, strides, output, slice_mode=False,
verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1],
(2, 4, 3), slice_mode=True, test_ref=False)
(2, 4, 3), slice_mode="size", test_ref=False)
verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1],
(2, 2, 3), slice_mode=True, test_ref=True)
(2, 2, 3), slice_mode="size", test_ref=True)
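Working through the last case above: with dshape (3, 4, 3), begin [1, 0, 0] and end [-1, 2, 3] in "size" mode, axis 0 keeps the remaining 2 rows (end[0] == -1) while axes 1 and 2 take 2 and 3 elements, giving the expected (2, 2, 3) output. In NumPy terms:

import numpy as np

x = np.random.rand(3, 4, 3)
out = x[1:, 0:0 + 2, 0:0 + 3]   # "size" mode with begin [1, 0, 0], end [-1, 2, 3]
assert out.shape == (2, 2, 3)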

def test_strided_set():
def verify(dshape, begin, end, strides, vshape, test_ref=True):
18 changes: 9 additions & 9 deletions tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -53,18 +53,18 @@ def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
begin=relay.const([0, 0], "int64"),
end=relay.const([-1, channels1], "int64"),
strides=relay.const([1, 1], 'int64'),
slice_mode=True)
slice_mode="size")
y2 = relay.strided_slice(y,
begin=relay.const([0, channels1], "int64"),
end=relay.const([-1, channels2], "int64"),
strides=relay.const([1, 1], 'int64'),
slice_mode=True)
slice_mode="size")
y3 = relay.nn.conv2d(x, w3)
y4 = relay.strided_slice(y,
begin=relay.const([0, channels1 + channels2], "int64"),
end=relay.const([-1, channels4], "int64"),
strides=relay.const([1, 1], 'int64'),
slice_mode=True)
slice_mode="size")
y5 = relay.nn.max_pool2d(x)
y = relay.Tuple((y1, y2, y3, y4, y5))
return relay.Function(args, y)
@@ -113,12 +113,12 @@ def expected(x, w1, w2, scale1, scale2, bias, channels1, channels2):
begin=relay.const([0, 0], "int64"),
end=relay.const([-1, channels1], "int64"),
strides=relay.const([1, 1], "int64"),
slice_mode=True)
slice_mode="size")
y2 = relay.strided_slice(y,
begin=relay.const([0, channels1], "int64"),
end=relay.const([-1, channels2], "int64"),
strides=relay.const([1, 1], "int64"),
slice_mode=True)
slice_mode="size")
y2 = relay.add(y2, bias)
y = relay.Tuple((y1, y2))
return relay.Function(args, y)
@@ -160,12 +160,12 @@ def expected(x, w1, w2, scale1, scale2, channels1, channels2):
begin=relay.const([0, 0], "int64"),
end=relay.const([-1, channels1], "int64"),
strides=relay.const([1, 1], "int64"),
slice_mode=True)
slice_mode="size")
y2 = relay.strided_slice(y,
begin=relay.const([0, channels1], "int64"),
end=relay.const([-1, channels2], "int64"),
strides=relay.const([1, 1], "int64"),
slice_mode=True)
slice_mode="size")
y1 = relay.multiply(y1, scale1)
y2 = relay.multiply(y2, scale2)
y = relay.Tuple((y1, y2))
@@ -208,12 +208,12 @@ def expected(x, w, channels, repeat):
begin=relay.const([0, 0], "int64"),
end=relay.const([-1, channels], "int64"),
strides=relay.const([1, 1], "int64"),
slice_mode=True)
slice_mode="size")
y2 = relay.strided_slice(y,
begin=relay.const([0, channels], "int64"),
end=relay.const([-1, channels], "int64"),
strides=relay.const([1, 1], "int64"),
slice_mode=True)
slice_mode="size")
y = relay.concatenate((y1, y2), axis=1)
return relay.Function(args, y)

6 changes: 3 additions & 3 deletions topi/include/topi/transform.h
@@ -520,15 +520,15 @@ inline Array<Tensor> split(const Tensor& x, Array<Integer> split_indices, int ax
* \param begin The indices to begin with in the slicing
* \param end Indicies indicating end of the slice
* \param strides Specifies the stride values, it can be negative
* \param slice_mode Specifies whether to enable slice mode
* in that case, the input tensor will be reversed in that particular axis
* \param slice_mode Specifies the slice mode
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the split operation
*/
inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
const Array<Integer>& strides, const bool& slice_mode,
const Array<Integer>& strides, std::string slice_mode = "end",
std::string name = "T_strided_slice", std::string tag = kInjective) {
size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
// Setup the ranges.
@@ -561,7 +561,7 @@ inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const

if (!end[i].defined()) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
} else if (slice_mode) {
} else if (slice_mode == "size") {
if (end[i]->value < 0) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
} else {