diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 4c96ed42f6e9..dff6374a6185 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -710,16 +710,17 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b const te::Tensor& end, const te::Tensor& strides, std::string name = "T_strided_slice_dynamic", std::string tag = topi::kInjective) { + DataType index_dtype = begin->shape[0]->dtype; const int64_t num_dynamic_axes = begin->shape[0].as()->value; ICHECK_EQ(end->shape[0].as()->value, num_dynamic_axes); ICHECK_EQ(strides->shape[0].as()->value, num_dynamic_axes); Array begin_expr, end_expr, strides_expr; for (int64_t i = 0; i < num_dynamic_axes; ++i) { - auto i64_ind = IntImm(DataType::Int(64), i); - begin_expr.push_back(begin(i64_ind)); - end_expr.push_back(end(i64_ind)); - strides_expr.push_back(strides(i64_ind)); + auto ind = make_const(index_dtype, i); + begin_expr.push_back(begin(ind)); + end_expr.push_back(end(ind)); + strides_expr.push_back(strides(ind)); } return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag); } @@ -822,9 +823,10 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, const Array end_full(end); Array strides_full(strides); - const IntImm one = IntImm(DataType::Int(64), 1); - const IntImm zero = IntImm(DataType::Int(64), 0); - const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits::max()); + DataType index_dtype = begin.size() > 0 ? begin[0]->dtype : DataType::Int(64); + const IntImm one = IntImm(index_dtype, 1); + const IntImm zero = IntImm(index_dtype, 0); + const IntImm max_range = Downcast(max_value(index_dtype)); for (size_t i = strides.size(); i < src_tensor_dim; ++i) { strides_full.push_back(one);