Skip to content

Commit

Permalink
fix bug of init_reduce tensor (PaddlePaddle#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
haozech authored Jun 9, 2021
1 parent 967911c commit 0ad4c83
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 26 deletions.
29 changes: 15 additions & 14 deletions cinn/ir/tensor.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -241,26 +241,27 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target)
int reduce_axis_num = this->reduce_axis.size();
auto dim_out_names = poly::isl_get_dim_names(stages[this]->transform(), isl_dim_out);
auto dim_in_size = isl_map_dim(stages[this]->transform().get(), isl_dim_in);
temp_transform = isl::manage(
isl_map_remove_dims(temp_transform.release(), isl_dim_in, dim_in_size - reduce_axis_num, reduce_axis_num));
std::string deleted_transform = isl_map_to_str(temp_transform.get());
int compute_at_axis = -1;
int deleted_dim = 0;
//! Get the ComputeAt level. It increases until reduce_axis.
for (int i = 0; i < dim_out_names.size(); i++) {
if (utils::Count(&deleted_transform, dim_out_names[i]) == utils::Count(&this_transform, dim_out_names[i])) {
auto dim_in_names = poly::isl_get_dim_names(stages[this]->transform(), isl_dim_in);
std::vector<std::string> reduce_axis_input;
for (int i = dim_in_size - reduce_axis_num; i < dim_in_size; i++) {
reduce_axis_input.push_back(dim_in_names[i]);
}
auto reduce_axis_output = poly::GetRelatedOutputAxies(temp_transform, reduce_axis_input);
std::set<std::string> reduce_axis_output_set;
for (auto &i : reduce_axis_output) {
reduce_axis_output_set.insert(i);
}
int compute_at_axis = -1;
for (auto &i : dim_out_names) {
if (reduce_axis_output_set.count(i) == 0) {
compute_at_axis++;
} else {
break;
}
}

for (int i = 0; i < dim_out_names.size(); i++) {
if (utils::Count(&deleted_transform, dim_out_names[i]) != utils::Count(&this_transform, dim_out_names[i])) {
temp_transform = isl::manage(isl_map_remove_dims(temp_transform.release(), isl_dim_out, i - deleted_dim, 1));
deleted_dim++;
}
}
temp_transform = poly::RemoveAxiesByOutputNames(temp_transform, reduce_axis_output);

//! When the first axis is not reduce axis, do ComputeAt.
if (compute_at_axis >= 0) {
stages[init_tensor]->ComputeAt2(stages[this], compute_at_axis);
Expand Down
18 changes: 6 additions & 12 deletions cinn/poly/isl_utils.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -354,16 +354,13 @@ isl::map RemoveAxiesByInputNames(const isl::map &x, const std::vector<std::strin
std::string map_str = isl_map_to_str(x.get());
isl::ctx this_ctx = x.ctx();
isl::map temp_transform(this_ctx, map_str);
auto related_output_names = GetRelatedOutputAxies(x, dim_in_names);
if (dim_in_names.empty()) return temp_transform;
auto dim_out_names = isl_get_dim_names(temp_transform, isl_dim_out);
for (auto &i : dim_in_names) {
temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_in, i.c_str()));
}
std::string deleted_map = isl_map_to_str(temp_transform.get());
for (auto &i : dim_out_names) {
if (utils::Count(&map_str, i) != utils::Count(&deleted_map, i)) {
temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_out, i.c_str()));
}
for (auto &i : related_output_names) {
temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_out, i.c_str()));
}
return temp_transform;
}
Expand All @@ -372,16 +369,13 @@ isl::map RemoveAxiesByOutputNames(const isl::map &x, const std::vector<std::stri
std::string map_str = isl_map_to_str(x.get());
isl::ctx this_ctx = x.ctx();
isl::map temp_transform(this_ctx, map_str);
auto related_input_names = GetRelatedInputAxies(x, dim_out_names);
if (dim_out_names.empty()) return temp_transform;
auto dim_in_names = isl_get_dim_names(temp_transform, isl_dim_in);
for (auto &i : dim_out_names) {
temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_out, i.c_str()));
}
std::string deleted_map = isl_map_to_str(temp_transform.get());
for (auto &i : dim_in_names) {
if (utils::Count(&map_str, i) != utils::Count(&deleted_map, i)) {
temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_in, i.c_str()));
}
for (auto &i : related_input_names) {
temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_in, i.c_str()));
}
return temp_transform;
}
Expand Down
12 changes: 12 additions & 0 deletions cinn/poly/stage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,18 @@ std::tuple<Iterator, Iterator> Stage::Skew(const Iterator &i, const Iterator &j,
return std::make_tuple(i_new, j_new);
}

Iterator Stage::Fuse(const std::vector<int> &levels) {
CHECK_GE(levels.size(), 2);
if (levels.size() == 2) {
return Fuse(levels[0], levels[1]);
} else {
for (int i = 0; i < levels.size() - 1; i++) {
auto temp = Fuse(levels[0], levels[1]);
if (i == levels.size() - 2) return temp;
}
}
}

Iterator Stage::Fuse(int level0, int level1) {
AssertAxisIsNotLocked(level0);
AssertAxisIsNotLocked(level1);
Expand Down
1 change: 1 addition & 0 deletions cinn/poly/stage.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ class Stage : public Object {
* @param level1 the second level.
* @return the new level.
*/
Iterator Fuse(const std::vector<int>& levels);
Iterator Fuse(const Iterator& level0, const Iterator& level1);
Iterator Fuse(int level0, int level1);
Iterator Fuse(const std::string& level0, const std::string& level1);
Expand Down

0 comments on commit 0ad4c83

Please sign in to comment.