From 0ad4c834184b5054b337e5105ecff91d41c5e455 Mon Sep 17 00:00:00 2001 From: haozech Date: Wed, 9 Jun 2021 15:12:55 +0800 Subject: [PATCH] fix bug of init_reduce tensor (#398) --- cinn/ir/tensor.cc | 29 +++++++++++++++-------------- cinn/poly/isl_utils.cc | 18 ++++++------------ cinn/poly/stage.cc | 12 ++++++++++++ cinn/poly/stage.h | 1 + 4 files changed, 34 insertions(+), 26 deletions(-) mode change 100755 => 100644 cinn/ir/tensor.cc mode change 100644 => 100755 cinn/poly/isl_utils.cc diff --git a/cinn/ir/tensor.cc b/cinn/ir/tensor.cc old mode 100755 new mode 100644 index 4fc2957cf8ca5..7c436c49bf9fe --- a/cinn/ir/tensor.cc +++ b/cinn/ir/tensor.cc @@ -241,26 +241,27 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target) int reduce_axis_num = this->reduce_axis.size(); auto dim_out_names = poly::isl_get_dim_names(stages[this]->transform(), isl_dim_out); auto dim_in_size = isl_map_dim(stages[this]->transform().get(), isl_dim_in); - temp_transform = isl::manage( - isl_map_remove_dims(temp_transform.release(), isl_dim_in, dim_in_size - reduce_axis_num, reduce_axis_num)); - std::string deleted_transform = isl_map_to_str(temp_transform.get()); - int compute_at_axis = -1; - int deleted_dim = 0; - //! Get the ComputeAt level. It increases until reduce_axis. - for (int i = 0; i < dim_out_names.size(); i++) { - if (utils::Count(&deleted_transform, dim_out_names[i]) == utils::Count(&this_transform, dim_out_names[i])) { + auto dim_in_names = poly::isl_get_dim_names(stages[this]->transform(), isl_dim_in); + std::vector reduce_axis_input; + for (int i = dim_in_size - reduce_axis_num; i < dim_in_size; i++) { + reduce_axis_input.push_back(dim_in_names[i]); + } + auto reduce_axis_output = poly::GetRelatedOutputAxies(temp_transform, reduce_axis_input); + std::set reduce_axis_output_set; + for (auto &i : reduce_axis_output) { + reduce_axis_output_set.insert(i); + } + int compute_at_axis = -1; + for (auto &i : dim_out_names) { + if (reduce_axis_output_set.count(i) == 0) { compute_at_axis++; } else { break; } } - for (int i = 0; i < dim_out_names.size(); i++) { - if (utils::Count(&deleted_transform, dim_out_names[i]) != utils::Count(&this_transform, dim_out_names[i])) { - temp_transform = isl::manage(isl_map_remove_dims(temp_transform.release(), isl_dim_out, i - deleted_dim, 1)); - deleted_dim++; - } - } + temp_transform = poly::RemoveAxiesByOutputNames(temp_transform, reduce_axis_output); + //! When the first axis is not reduce axis, do ComputeAt. if (compute_at_axis >= 0) { stages[init_tensor]->ComputeAt2(stages[this], compute_at_axis); diff --git a/cinn/poly/isl_utils.cc b/cinn/poly/isl_utils.cc old mode 100644 new mode 100755 index c8753f63364ba..c460278638fcd --- a/cinn/poly/isl_utils.cc +++ b/cinn/poly/isl_utils.cc @@ -354,16 +354,13 @@ isl::map RemoveAxiesByInputNames(const isl::map &x, const std::vector Stage::Skew(const Iterator &i, const Iterator &j, return std::make_tuple(i_new, j_new); } +Iterator Stage::Fuse(const std::vector &levels) { + CHECK_GE(levels.size(), 2); + if (levels.size() == 2) { + return Fuse(levels[0], levels[1]); + } else { + for (int i = 0; i < levels.size() - 1; i++) { + auto temp = Fuse(levels[0], levels[1]); + if (i == levels.size() - 2) return temp; + } + } +} + Iterator Stage::Fuse(int level0, int level1) { AssertAxisIsNotLocked(level0); AssertAxisIsNotLocked(level1); diff --git a/cinn/poly/stage.h b/cinn/poly/stage.h index ba39386d45525..2a346d4ed2145 100755 --- a/cinn/poly/stage.h +++ b/cinn/poly/stage.h @@ -286,6 +286,7 @@ class Stage : public Object { * @param level1 the second level. * @return the new level. */ + Iterator Fuse(const std::vector& levels); Iterator Fuse(const Iterator& level0, const Iterator& level1); Iterator Fuse(int level0, int level1); Iterator Fuse(const std::string& level0, const std::string& level1);