Skip to content

Commit

Permalink
[TOPI] Fix index dtype in topi strided_slice (apache#14022)
Browse files Browse the repository at this point in the history
try use more proper index dtype than i64 to avoid integer arith issue in lowering.
  • Loading branch information
wrongtest-intellif authored and yongwww committed Feb 27, 2023
1 parent caf7aa1 commit 3a11abc
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -710,16 +710,17 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
const te::Tensor& end, const te::Tensor& strides,
std::string name = "T_strided_slice_dynamic",
std::string tag = topi::kInjective) {
DataType index_dtype = begin->shape[0]->dtype;
const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);

Array<PrimExpr> begin_expr, end_expr, strides_expr;
for (int64_t i = 0; i < num_dynamic_axes; ++i) {
auto i64_ind = IntImm(DataType::Int(64), i);
begin_expr.push_back(begin(i64_ind));
end_expr.push_back(end(i64_ind));
strides_expr.push_back(strides(i64_ind));
auto ind = make_const(index_dtype, i);
begin_expr.push_back(begin(ind));
end_expr.push_back(end(ind));
strides_expr.push_back(strides(ind));
}
return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
}
Expand Down Expand Up @@ -822,9 +823,10 @@ inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const
Array<Integer> end_full(end);
Array<Integer> strides_full(strides);

const IntImm one = IntImm(DataType::Int(64), 1);
const IntImm zero = IntImm(DataType::Int(64), 0);
const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits<int64_t>::max());
DataType index_dtype = begin.size() > 0 ? begin[0]->dtype : DataType::Int(64);
const IntImm one = IntImm(index_dtype, 1);
const IntImm zero = IntImm(index_dtype, 0);
const IntImm max_range = Downcast<IntImm>(max_value(index_dtype));

for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
strides_full.push_back(one);
Expand Down

0 comments on commit 3a11abc

Please sign in to comment.