Skip to content

Commit

Permalink
fix isl for1 problems (PaddlePaddle#395)
Browse files Browse the repository at this point in the history
* fix isl for1 problems
  • Loading branch information
wenming2014 authored Jun 3, 2021
1 parent eeaba40 commit 967911c
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 10 deletions.
2 changes: 1 addition & 1 deletion cinn/backends/codegen_c_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ void main(void* _args, int32_t num_args)
const float* A = ((const float*)(_A->memory));
const float* B = ((const float*)(_B->memory));
float* C = ((float*)(_C->memory));
for (int32_t i = 0; i < 1; i += 1) {
{
cinn_pod_value_t _pod_val_;
buffer_p_to_cinn_pod_value(_A, &_pod_val_);
cinn_pod_value_t _pod_val__0;
Expand Down
1 change: 1 addition & 0 deletions cinn/lang/lower_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ struct MarkVectorizeMutator : public ir::IRMutator<Expr*> {
CHECK(tensor_n);
auto it = vectorizes.find(tensor_n->name);
if (it != vectorizes.end()) {
CHECK_LT(it->second.level, forloop_stack.size());
forloop_stack[it->second.level]->set_vectorize_info(it->second);
CHECK(it->second.valid());
}
Expand Down
4 changes: 1 addition & 3 deletions cinn/poly/ast_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,9 @@ isl::ast_node AstGen::Build() {

ast_build = ast_build.set_at_each_domain(collect);

isl::union_set new_domain = TransIdentityExtentToContextId(impl_->domain());

isl::union_map transformed_schedule = impl_->transform().apply_range(schedule);
VLOG(4) << "transformed_schedule: " << transformed_schedule;
auto schedule_domain = transformed_schedule.intersect_domain(new_domain);
auto schedule_domain = transformed_schedule.intersect_domain(impl_->domain());
VLOG(4) << "domain: " << impl_->domain();
VLOG(4) << "transform schedule " << impl_->stages()[0]->transform();
VLOG(4) << "schedule: " << schedule;
Expand Down
29 changes: 29 additions & 0 deletions cinn/poly/isl_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,35 @@ isl_set *isl_get_precending_aixs(isl_set *set, int level, bool with_tuple_name)
return isl_set_apply(set, transform.release());
}

//! Counts the axes of set `a` in the range [0, level) whose constant min and max bounds are equal.
//! Such axes are the ones isl removes after ast_build, so the result is the amount by which a level
//! has to be decreased to address the same loop in the generated AST.
//! @param a The set to inspect (kept, not released by this function).
//! @param level The axis before which removed axes are counted; the axis itself is not inspected.
//! @return The number of axes before `level` that have identical min and max bounds.
int isl_get_precending_removed_axes_counts(isl_set __isl_keep *a, int level) {
  int removed_axes_counts = 0;
  for (int i = 0; i < level; i++) {
    // Only axes with a parameter-free constant bound are candidates.
    if (!isl_set_axis_has_noparam_constant_bound(a, i)) continue;
    auto [minv, maxv] = isl_set_get_axis_range(a, i);
    // Compare the values as returned by get_num_si() instead of narrowing them to int first.
    if (minv.get_num_si() == maxv.get_num_si()) {
      removed_axes_counts++;
    }
  }
  return removed_axes_counts;
}

//! Tells whether the axis `level` of set `a` has a parameter-free constant bound whose min and max
//! are equal. Such an axis is the kind isl removes after ast_build.
//! @param a The set to inspect (kept, not released by this function).
//! @param level The axis to inspect.
//! @return true if the axis has identical constant min and max bounds, false otherwise.
bool isl_is_removed_axis(isl_set __isl_keep *a, int level) {
  if (!isl_set_axis_has_noparam_constant_bound(a, level)) {
    return false;
  }
  auto [minv, maxv] = isl_set_get_axis_range(a, level);
  // Compare the values as returned by get_num_si() instead of narrowing them to int first.
  return minv.get_num_si() == maxv.get_num_si();
}

int isl_max_level_compatible(isl_set *a, isl_set *b) {
int an = isl_set_dim(a, isl_dim_set);
int bn = isl_set_dim(b, isl_dim_set);
Expand Down
9 changes: 9 additions & 0 deletions cinn/poly/isl_utils.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <llvm/ADT/ArrayRef.h>

#include <string>
#include <tuple>
#include <vector>

namespace cinn {
Expand Down Expand Up @@ -34,6 +35,14 @@ std::string isl_map_get_statement_repr(__isl_keep isl_map* map, isl_dim_type typ

isl_set* __isl_give isl_get_precending_aixs(isl_set* set, int level, bool with_tuple_name);

//! If the min and max bounds of an axis are the same, isl will remove this axis after ast_build. Counts the removed
//! axes preceding the given axis.
int isl_get_precending_removed_axes_counts(isl_set __isl_keep* a, int level);

//! If the min and max bounds of an axis are the same, isl will remove this axis after ast_build. Tells whether the
//! given axis will be removed by isl.
bool isl_is_removed_axis(isl_set __isl_keep* a, int level);

//! Get the maximum level of axis that has the same domain.
int isl_max_level_compatible(isl_set* __isl_keep a, isl_set* __isl_keep b);

Expand Down
56 changes: 50 additions & 6 deletions cinn/poly/stage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -668,10 +668,21 @@ bool ComputeAtRelation::IsCompatible(Stage *self) {

void Stage::Vectorize(int level, int factor) {
AssertAxisIsNotLocked(level);
CHECK_GE(level, 0);
CHECK_LT(level, n_out_dims());
CHECK_GT(factor, 0);
auto dim_name = ith_dim_name(level);
vectorize_info_.set(level /*inner*/, factor);
if (factor == 1) {
LOG(INFO) << "Vectorize-factor 1 has no sense, skip it";
return;
}
auto transformed_domain = this->transformed_domain();
if (isl_is_removed_axis(transformed_domain.get(), level)) {
LOG(INFO) << "Vectorizing for-1 has no sense, skip it";
return;
}
int removed_axes_counts = isl_get_precending_removed_axes_counts(transformed_domain.get(), level);
VLOG(3) << "removed_axes_counts are " << removed_axes_counts << " before axis " << ith_dim_name(level);
vectorize_info_.set(level - removed_axes_counts /*inner*/, factor);
}

void Stage::Vectorize(const std::string &axis, int factor) {
Expand All @@ -684,13 +695,30 @@ void Stage::Vectorize(const std::string &axis, int factor) {
//! Vectorizes the axis referred to by `axis` by forwarding its id and `factor` to another Vectorize overload.
void Stage::Vectorize(const Iterator &axis, int factor) {
  Vectorize(axis.id, factor);
}

//! Marks the axis at `level` to be executed in parallel.
//! Nothing is recorded when the axis has equal min and max bounds (isl removes such for-1 axes during
//! ast build). Otherwise the recorded level is decreased by the number of preceding removed axes.
//! @param level The axis to parallelize; must be non-negative and not locked.
void Stage::Parallel(int level) {
  CHECK_GE(level, 0);
  AssertAxisIsNotLocked(level);
  auto transformed_domain = this->transformed_domain();
  LOG(INFO) << "transformed_domain" << transformed_domain;
  if (isl_is_removed_axis(transformed_domain.get(), level)) {
    LOG(INFO) << "Paralleling for-1 has no sense, skip it";
    return;
  }
  // Insert the level only once, after the skip check and with the removed axes accounted for.
  const int removed_axes_counts = isl_get_precending_removed_axes_counts(transformed_domain.get(), level);
  VLOG(3) << "removed_axes_counts are " << removed_axes_counts << " before axis " << ith_dim_name(level);
  parallel_info_.insert(level - removed_axes_counts);
}

//! Marks the axis at `level` to be unrolled.
//! Nothing is recorded when the axis has equal min and max bounds (isl removes such for-1 axes during
//! ast build). Otherwise the recorded level is decreased by the number of preceding removed axes.
//! @param level The axis to unroll; must be non-negative and not locked.
void Stage::Unroll(int level) {
  CHECK_GE(level, 0);
  AssertAxisIsNotLocked(level);
  auto transformed_domain = this->transformed_domain();
  if (isl_is_removed_axis(transformed_domain.get(), level)) {
    LOG(INFO) << "Unrolling for-1 has no sense, skip it";
    return;
  }
  // Insert the level only once, after the skip check and with the removed axes accounted for.
  const int removed_axes_counts = isl_get_precending_removed_axes_counts(transformed_domain.get(), level);
  VLOG(3) << "removed_axes_counts are " << removed_axes_counts << " before axis " << ith_dim_name(level);
  unroll_info_.insert(level - removed_axes_counts);
}

std::string Stage::ith_dim_name(int level) {
Expand Down Expand Up @@ -1034,8 +1062,16 @@ isl_map *__isl_give GatherAccesses(Stage *stage, const std::string &tensor_name)

//! Attaches `info` to the forloop at `level`.
//! Nothing is recorded when the axis has equal min and max bounds (isl removes such for-1 axes during
//! ast build). Otherwise the info is stored under the level decreased by the number of preceding
//! removed axes.
//! @param level The axis the info belongs to; must be in [0, number of output dims of transform_).
//! @param info The forloop info to store.
void Stage::AddForloopInfo(int level, const StageForloopInfo &info) {
  int num_levels = isl_map_dim(transform_.get(), isl_dim_out);
  CHECK_GE(level, 0);
  CHECK_LT(level, num_levels);
  auto transformed_domain = this->transformed_domain();
  if (isl_is_removed_axis(transformed_domain.get(), level)) {
    LOG(INFO) << "for-1 has no sense, skip it";
    return;
  }
  // Store the info only once, after the skip check and under the adjusted level.
  const int removed_axes_counts = isl_get_precending_removed_axes_counts(transformed_domain.get(), level);
  VLOG(3) << "removed_axes_counts are " << removed_axes_counts << " before axis " << ith_dim_name(level);
  forloop_infos_[level - removed_axes_counts] = info;
}

void Stage::CopyTransform(Stage *other, int level) {
Expand Down Expand Up @@ -1150,10 +1186,18 @@ void Stage::CopyTransform(Stage *other, int level) {
//! Copies forloop infos from a target stage onto the axes of this stage that have the same dimension name.
//! Axes of this stage with equal min and max bounds are skipped, and the infos of the remaining axes
//! are stored under levels decreased by the number of skipped axes before them.
//! @param target_forloop_infos The target's forloop infos, keyed by the target's output dimension index.
//! @param target_transform The target's transform, used to look up dimensions by name.
void Stage::CopyLoopInfo(std::map<int, StageForloopInfo> target_forloop_infos, const isl::map &target_transform) {
  std::vector<std::string> this_dim_names = isl_get_dim_names(transform_, isl_dim_out);
  // Computed once for the whole loop; assumes transformed_domain() does not depend on forloop_infos_,
  // which is the only member written below.
  auto transformed_domain = this->transformed_domain();
  const int num_dims      = static_cast<int>(this_dim_names.size());
  int removed_axes_counts = 0;
  for (int i = 0; i < num_dims; i++) {
    if (isl_is_removed_axis(transformed_domain.get(), i)) {
      LOG(INFO) << "for-1 has no sense, skip it";
      removed_axes_counts++;
      continue;
    }
    int index = isl_map_find_dim_by_name(target_transform.get(), isl_dim_out, this_dim_names[i].c_str());
    auto it   = target_forloop_infos.find(index);
    if (it != target_forloop_infos.end()) {
      // Isl ast build will remove for-1 axes, so we decrease the level correspondingly.
      forloop_infos_[i - removed_axes_counts] = it->second;
    }
  }
}
Expand Down

0 comments on commit 967911c

Please sign in to comment.