Skip to content

Commit

Permalink
[Relay, TOPI] Refactor strided_slice and add axes argument (#8165)
Browse files Browse the repository at this point in the history
* Initial import

commit 667011f
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Thu May 27 16:28:57 2021 +0900

    Squashed commit of the following:

    commit 95242d8
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Thu May 27 15:45:19 2021 +0900

        Add function attribute for shape func for profiling

commit b8ede24
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Thu May 27 10:21:06 2021 +0900

    layout transform support complete

commit 5782b70
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Thu May 27 08:31:11 2021 +0900

    support layout transform part1

commit e94aa6b
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 19:47:46 2021 +0900

    moved utilities to its own file

commit 8bf8891
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 17:39:50 2021 +0900

    fix format

commit e89d599
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 17:33:02 2021 +0900

    ToVec -> ConvertToVec

commit 001982c
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 17:26:56 2021 +0900

    format

commit fae57f9
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 17:24:35 2021 +0900

    use Any for relay type rel path

commit 053eee2
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 17:14:44 2021 +0900

    fix

commit fbb099c
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 16:39:37 2021 +0900

    refactor type rel

commit ecfe3cd
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 16:23:47 2021 +0900

    working

commit b357c2f
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 16:07:07 2021 +0900

    refactoring output shape calc

commit f69ef40
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 14:23:36 2021 +0900

    bug fix end param init

commit a5611c9
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 13:42:31 2021 +0900

    fix test shape

commit e79931a
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 13:42:03 2021 +0900

    dyn slice tests left as todo now work

commit 7db4cea
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 13:36:30 2021 +0900

    remove dynamic input specific op

commit 510bce6
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 12:52:30 2021 +0900

    refactoring dynamic slice

commit 1b3969a
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 09:06:46 2021 +0900

    fix slice axes dispatch

commit 9a79560
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 08:32:54 2021 +0900

    refactor compute

commit 80442f8
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 08:11:18 2021 +0900

    fixed output shape, refactored version working

commit d2538ae
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 07:56:05 2021 +0900

    restore another slice variant

commit 36aa777
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon May 24 06:41:50 2021 +0900

    refactoring slice with axes

commit 32698b7
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 22 13:11:01 2021 +0900

    fix axes null check

commit 54fb723
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 22 12:52:18 2021 +0900

    Revert "[Bugfix][Vulkan] Call VulkanDeviceAPI destructor on program exit (#7997)"

    This reverts commit 58c3413.

commit 37eaf57
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 22 04:30:37 2021 +0900

    remove wip layout transform support for slice with axes

commit 9bcb2ad
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 18:01:59 2021 +0900

    fix pylint

commit 7063a09
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 17:57:03 2021 +0900

    minor fix

commit 96c9231
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 17:54:16 2021 +0900

    support dynamic scatter nd

commit d4a4db8
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 17:33:19 2021 +0900

    gather_dim -> num_indices_per_tuple

commit a489375
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 17:23:46 2021 +0900

    add dynamic gather_nd test

commit 533854a
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 17:18:26 2021 +0900

    refactor gather_nd ref funcs

commit 36a4501
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 14:36:34 2021 +0900

    add gather_nd shape func

commit 1853c35
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 22 04:20:39 2021 +0900

    add eyelike support

commit 150e945
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 04:08:37 2021 +0900

    migrating inlined topi compute to topi/transform.h

commit 763ac37
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 03:45:37 2021 +0900

    strided slice with axes support

* fix bad merge

* fix cpplint

* fix pylint

* more cpplint fix

* fix compiler warning

* add doc

* add tests

* typo fixed

* support axes argument in topi cpp strided slice

* Properly test axes argument in relay tests

* fix bad merge (revert vm change)

* fix tests
  • Loading branch information
masahi authored Jun 2, 2021
1 parent b7c98b8 commit cbe3dca
Show file tree
Hide file tree
Showing 17 changed files with 790 additions and 357 deletions.
4 changes: 4 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
Optional<Array<Integer>> end;
Optional<Array<Integer>> strides;
std::string slice_mode;
Optional<Array<Integer>> axes;

TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive");
Expand All @@ -324,6 +325,9 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
"size - The input strides will be ignored, input end in this mode indicates the size"
"of a slice starting at the location specified by begin. If end[i] is -1,"
"all remaining elements in that dimension are included in the slice");
TVM_ATTR_FIELD(axes).describe(
"Axes along which slicing is applied. When it is specified, the length of begin, end, "
"strides, and axes must be equal.");
}
};

Expand Down
156 changes: 156 additions & 0 deletions include/tvm/topi/detail/strided_slice.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file strided_slice.h
* \brief Utility functions for strided_slice op
*/
#ifndef TVM_TOPI_DETAIL_STRIDED_SLICE_H_
#define TVM_TOPI_DETAIL_STRIDED_SLICE_H_

#include <tvm/tir/expr.h>

#include <algorithm>
#include <limits>
#include <string>
#include <tuple>
#include <vector>

#include "constant_utils.h"

namespace tvm {
namespace topi {
namespace detail {

using namespace tvm::te;

inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) {
int64_t begin_range = stride < 0 ? -1 : 0;
int64_t end_range = stride < 0 ? extent - 1 : extent;
if (index < 0) {
index += extent;
}
return std::min(std::max(index, begin_range), end_range);
}

inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>> ConvertToVec(
const Array<Integer>& begin, const Array<Integer>& end, const Array<Integer>& strides,
std::string slice_mode) {
std::vector<int64_t> stride_vec(strides.size(), 1);
if (slice_mode == "end") {
for (size_t i = 0; i < strides.size(); ++i) {
ICHECK(strides[i].defined());
stride_vec[i] = GetConstInt(strides[i]);
}
}
const int64_t max_range = std::numeric_limits<int64_t>::max();
std::vector<int64_t> begin_vec;
for (size_t i = 0; i < begin.size(); ++i) {
if (!begin[i].defined()) {
// value=None
begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
} else {
begin_vec.push_back(GetConstInt(begin[i]));
}
}
std::vector<int64_t> end_vec;
for (size_t i = 0; i < end.size(); ++i) {
// allow end to be None
if (!end[i].defined()) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
} else if (slice_mode == "size") {
int64_t end_val = GetConstInt(end[i]);
if (end_val < 0) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
} else {
end_vec.push_back(begin_vec[i] + end_val);
}
} else {
end_vec.push_back(GetConstInt(end[i]));
}
}
return std::make_tuple(begin_vec, end_vec, stride_vec);
}

inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Array<PrimExpr>& ishape,
const std::vector<int64_t>& begin,
const std::vector<int64_t>& strides,
const Array<Integer>& axes, DataType dtype,
std::string slice_mode = "end") {
Array<PrimExpr> begin_expr;
for (size_t i = 0; i < axes.size(); ++i) {
if (ishape[axes[i]]->IsInstance<tvm::IntImmNode>()) {
int64_t dim_i = GetConstInt(ishape[axes[i]]);
int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]);
begin_expr.push_back(make_const(dtype, begin_i));
} else {
auto idim = ishape[axes[i]];
auto b_expr = make_const(dtype, begin[i]);
PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
auto s = strides[i];
if (s < 0) {
b = tvm::min(b, idim - 1);
} else {
b = tvm::if_then_else(b < 0, 0, b);
}
begin_expr.push_back(b);
}
}
return begin_expr;
}

inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape,
const std::vector<int64_t>& begin,
const std::vector<int64_t>& end,
const std::vector<int64_t>& strides,
const Array<Integer>& axes, std::string slice_mode,
const Array<PrimExpr>& begin_canonicalized,
bool use_any = false) {
const size_t src_tensor_dim = ishape.size();
Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(ishape[i]);
}

for (size_t i = 0; i < axes.size(); ++i) {
if (ishape[axes[i]]->IsInstance<tvm::IntImmNode>()) {
const int64_t dim_i = GetConstInt(ishape[axes[i]]);
ICHECK(begin_canonicalized[i]->IsInstance<tvm::IntImmNode>());
int64_t begin_i = GetConstInt(begin_canonicalized[i]);
int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]);
int interval = std::abs(end_i - begin_i);
int slice_size =
static_cast<int>((interval + std::abs(strides[i]) - 1) / std::abs(strides[i]));
ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
<< ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i;
out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size)));
} else if (use_any) {
out_shape.Set(axes[i], tvm::tir::Any());
} else {
out_shape.Set(axes[i], tvm::tir::Var("dim", out_shape[i]->dtype));
}
}

return out_shape;
}

} // namespace detail
} // namespace topi
} // namespace tvm
#endif // TVM_TOPI_DETAIL_STRIDED_SLICE_H_
2 changes: 1 addition & 1 deletion include/tvm/topi/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
out = reshape(out, r_p_shape);

// Crop the start and end of dimensions of out
Array<PrimExpr> begin_idx, end_idx, strides;
Array<Integer> begin_idx, end_idx, strides;
for (size_t i = 0; i < r_p_shape.size(); ++i) {
strides.push_back(Integer(1));
if (i > 0 && i <= num_block_dims) {
Expand Down
Loading

0 comments on commit cbe3dca

Please sign in to comment.