Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#28 from Superjomn/fea/gen_isl_indice_map
Browse files Browse the repository at this point in the history
isl indice map and fix bug in isl code gen
  • Loading branch information
Superjomn authored Feb 12, 2020
2 parents 60b917a + f17420a commit b7feaa3
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 20 deletions.
126 changes: 113 additions & 13 deletions cinn/poly/ast_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,64 @@
namespace cinn {
namespace poly {

isl::ast_node AstGen::operator()(const std::vector<Element> &elements, const Scheduler &scheduler) {
// Collect domains.
auto sets = utils::Map<std::vector<Element>, isl::set>(elements, [](const Element &e) { return e.domain(); });
isl::union_set domain = SetsToUnionSet(sets);
isl::union_set AstGen::domain() {
CHECK(!poly_elements_.empty());
auto sets = utils::Map<std::vector<Element>, isl::set>(poly_elements_, [](const Element &e) { return e.domain(); });
return SetsToUnionSet(sets);
}

isl::ctx ctx = elements.front().domain().ctx();
isl::ctx AstGen::ctx() const {
CHECK(!poly_elements_.empty());
return poly_elements_.front().domain().ctx();
}

isl::ast_node AstGen::Build() {
// Collect schedule from scheduler.
auto schedules = scheduler.BuildSchedule();
auto schedules = scheduler_.BuildSchedule();
std::vector<isl::map> maps;
for (auto &ele : elements) {
for (auto &ele : poly_elements_) {
auto it = schedules.find(ele.id());
CHECK(it != std::end(schedules));
maps.push_back(it->second);
}
auto schedule = MapsToUnionMap(maps);

// Build it.
auto build = isl::ast_build::from_context(context_);
auto ast_build = isl::ast_build::from_context(context_);
// Set iterators.
if (!iterator_names_.empty()) {
auto iterator_names = scheduler.WrapIteratorNames(iterator_names_);
isl::id_list ids = isl::manage(isl_id_list_alloc(ctx.get(), iterator_names.size()));
auto iterator_names = scheduler_.WrapIteratorNames(iterator_names_);
isl::id_list ids = isl::manage(isl_id_list_alloc(ctx().get(), iterator_names.size()));
for (int i = 0; i < iterator_names.size(); i++) {
ids = isl::manage(isl_id_list_add(ids.release(), isl_id_alloc(ctx.get(), iterator_names[i].c_str(), nullptr)));
ids = isl::manage(isl_id_list_add(ids.release(), isl_id_alloc(ctx().get(), iterator_names[i].c_str(), nullptr)));
}
build = isl::manage(isl_ast_build_set_iterators(build.release(), ids.release()));
ast_build = isl::manage(isl_ast_build_set_iterators(ast_build.release(), ids.release()));
}

auto ast = build.node_from_schedule_map(schedule.intersect_domain(domain));
// collect iterator map
auto get_domain_by_name = [this](const std::string &name) -> isl::set {
auto ele_it = std::find_if(
poly_elements_.begin(), poly_elements_.end(), [&name](const Element &ele) { return ele.id() == name; });
CHECK(ele_it != std::end(poly_elements_));
return ele_it->domain();
};

auto collect = [&](isl::ast_node node, isl::ast_build build) -> isl::ast_node {
auto tuple_name = detail::GetTupleName(node.get());
auto indice_map = ExtractIslTransformedIndiceMap(get_domain_by_name(tuple_name), build.get());
transformed_indice_map_[tuple_name] = indice_map;
return node;
};

ast_build = ast_build.set_at_each_domain(collect);

isl::union_map transformed_schedule = transform().apply_range(schedule);
auto schedule_domain = transformed_schedule.intersect_domain(domain());
VLOG(4) << "domain: " << domain();
VLOG(4) << "transform schedule " << poly_elements()[0].schedule();
VLOG(4) << "schedule: " << schedule;
VLOG(4) << "schedule_domain: " << schedule_domain;
auto ast = ast_build.node_from_schedule_map(schedule_domain);
VLOG(2) << "\n" << isl_ast_node_to_C_str(ast.get());
return ast;
}
Expand All @@ -42,5 +70,77 @@ AstGen &AstGen::SetIteratorNames(const std::vector<std::string> &names) {
return *this;
}

isl::ast_expr CreateIslAstIndexExpression(isl_ast_build *build, const isl::map &access);

std::map<std::string, isl::ast_expr> AstGen::ExtractIslTransformedIndiceMap(const isl::set &iterator_domain,
isl_ast_build *build) {
std::map<std::string, isl::ast_expr> iterator_map;
isl::map identity = isl::manage(isl_set_identity(iterator_domain.copy()));
isl::map schedule = identity;

identity = identity.apply_domain(schedule);
isl::ast_expr idx_expr = CreateIslAstIndexExpression(build, identity);
isl::space domain_space = iterator_domain.space();

for (int i = 1; i < isl_ast_expr_get_op_n_arg(idx_expr.get()); i++) {
if (isl_space_has_dim_name(domain_space.get(), isl_dim_set, i - 1)) {
std::string original_idx_name = isl_space_get_dim_name(domain_space.get(), isl_dim_set, i - 1);
isl::ast_expr transformed_index = isl::manage(isl_ast_expr_get_op_arg(idx_expr.get(), i));
iterator_map.emplace(original_idx_name, transformed_index);
}
}

return iterator_map;
}

const std::map<std::string, isl::ast_expr> &AstGen::axis2ast(const std::string &tuple_name) const {
auto it = transformed_indice_map_.find(tuple_name);
CHECK(it != transformed_indice_map_.end()) << "no id " << tuple_name;
return it->second;
}

isl::ast_expr CreateIslAstIndexExpression(isl_ast_build *build, const isl::map &access) {
CHECK(build);
isl::map schedule = isl::manage(isl_map_from_union_map(isl_ast_build_get_schedule(build)));

// get identity access from schedule.
auto statement = isl_map_get_statement_repr(schedule.get(), isl_dim_in);
auto statement_set = isl::manage(isl_set_read_from_str(isl_map_get_ctx(schedule.get()),
utils::StringFormat("{ %s : }", statement.c_str()).c_str()));
auto identity_access = isl::manage(isl_set_identity(statement_set.release()));
isl::map map = isl::manage(isl_map_reverse(schedule.copy()));

isl::pw_multi_aff iterator_map = isl::manage(isl_pw_multi_aff_from_map(map.copy()));
isl::pw_multi_aff index_aff = isl::manage(isl_pw_multi_aff_from_map(identity_access.copy()));

isl::space model2 = iterator_map.space();
index_aff = isl::manage(isl_pw_multi_aff_align_params(index_aff.copy(), model2.copy()));
isl::space model = index_aff.space();
iterator_map = isl::manage(isl_pw_multi_aff_align_params(iterator_map.copy(), model.copy()));
iterator_map = isl::manage(isl_pw_multi_aff_pullback_pw_multi_aff(index_aff.copy(), iterator_map.copy()));
isl::ast_expr index_expr = isl::manage(isl_ast_build_access_from_pw_multi_aff(build, iterator_map.copy()));

return index_expr;
}

isl::union_map AstGen::transform() {
std::vector<isl::map> transforms;
for (auto &ele : poly_elements()) {
transforms.push_back(ele.schedule());
}
return MapsToUnionMap(transforms);
}

namespace detail {

std::string GetTupleName(isl_ast_node *node) {
auto expr = isl::manage(isl_ast_node_user_get_expr(node));
auto arg = isl::manage(isl_ast_expr_get_op_arg(expr.get(), 0));
auto name = isl_id_get_name(isl_ast_expr_get_id(arg.get()));
return name;
}

} // namespace detail

} // namespace poly
} // namespace cinn
38 changes: 36 additions & 2 deletions cinn/poly/ast_gen.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#pragma once
#include <isl/cpp.h>

#include <map>
#include <string>

#include "cinn/poly/element.h"
#include "cinn/poly/isl_utils.h"
#include "cinn/poly/schedule.h"
Expand All @@ -11,7 +14,8 @@ namespace poly {

class AstGen {
public:
AstGen(const isl::set& context) : context_(context) {}
AstGen(const isl::set& context, const std::vector<Element>& elements, const Scheduler& scheduler)
: context_(context), poly_elements_(elements), scheduler_(scheduler) {}

/**
* Set forloop iterator names.
Expand All @@ -20,12 +24,42 @@ class AstGen {
*/
AstGen& SetIteratorNames(const std::vector<std::string>& names);

isl::ast_node operator()(const std::vector<Element>& elements, const Scheduler& scheduler);
isl::ctx ctx() const;

isl::ast_node Build();

const std::vector<Element>& poly_elements() const { return poly_elements_; }

const std::map<std::string, isl::ast_expr>& axis2ast(const std::string& tuple_name) const;

private:
//! Return a domain composed of all the elements.
isl::union_set domain();

//! Return a map composed of all the transforms.
isl::union_map transform();

//! Replace the Expr with the transformed indices.
//! e.g. Stage's expr is C[i,j] ...
//! e.g. with ISL transformed statement S0(c0+1, c1*2), the expr will turn to C[c0+1, c1*2]
static std::map<std::string, isl::ast_expr> ExtractIslTransformedIndiceMap(const isl::set& iterator_domain,
isl_ast_build* build);

private:
isl::set context_;
std::vector<Element> poly_elements_;
const Scheduler& scheduler_;
std::vector<std::string> iterator_names_;
//! tuple name -> { axis -> isl_ast }
std::map<std::string, std::map<std::string, isl::ast_expr>> transformed_indice_map_;
};

namespace detail {

//! Get tuple name of a ast node.
std::string GetTupleName(isl_ast_node* node);

} // namespace detail

} // namespace poly
} // namespace cinn
19 changes: 15 additions & 4 deletions cinn/poly/ast_gen_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,25 @@ TEST(ast_gen, basic) {
Element A(isl::set(ctx, "{ A[i,j,k]: 0<i,j,k<100 }"));
Element B(isl::set(ctx, "{ B[i,j,k]: 0<i,j,k<100 }"));

Iterator A_i0, A_i1;
Iterator B_i0, B_i1;

std::tie(A_i0, A_i1) = A.Split(Iterator("i"), 4);
std::tie(B_i0, B_i1) = B.Split(Iterator("i"), 4);

Scheduler scheduler;
scheduler.RegisterElement(A);
scheduler.RegisterElement(B);
scheduler.After(A, B, 2);
scheduler.After(A, B, 3);

AstGen gen(isl::set(ctx, "{:}"), {A, B}, scheduler);
gen.SetIteratorNames({"i.outer", "i.inner", "j", "k"});
gen.Build();

AstGen gen(isl::set(ctx, "{:}"));
gen.SetIteratorNames({"i", "j", "k"});
gen({A, B}, scheduler);
auto iters = gen.axis2ast("A");
for (auto& x : iters) {
LOG(INFO) << x.first << " " << x.second;
}
}

} // namespace poly
Expand Down
1 change: 0 additions & 1 deletion cinn/poly/element.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ void Element::InitSchedule() {
Element::Element(const isl::set &domain) : domain_(domain) {
CHECK(!domain_.is_null());
CHECK(!domain_.is_empty());

InitSchedule();
}

Expand Down
13 changes: 13 additions & 0 deletions cinn/poly/isl_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <glog/logging.h>
#include <isl/cpp.h>

#include "cinn/utils/string.h"

namespace cinn {
namespace poly {

Expand Down Expand Up @@ -58,5 +60,16 @@ isl::union_set SetsToUnionSet(const std::vector<isl::set> &sets) {
return uset;
}

std::string isl_map_get_statement_repr(__isl_keep isl_map *map, isl_dim_type type) {
CHECK(map);
auto tuple_name = isl_map_get_tuple_name(map, type);
std::vector<std::string> dims;

for (int i = 0; i < isl_map_dim(map, type); i++) {
dims.push_back(isl_map_get_dim_name(map, type, i));
}
return utils::StringFormat("%s[%s]", tuple_name, utils::Join(dims, ", ").c_str());
}

} // namespace poly
} // namespace cinn
3 changes: 3 additions & 0 deletions cinn/poly/isl_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,8 @@ void SetDimNames(isl::map* map, isl_dim_type dim_type, const std::vector<std::st
isl::union_map MapsToUnionMap(const std::vector<isl::map>& maps);
isl::union_set SetsToUnionSet(const std::vector<isl::set>& sets);

//! Get a representation of the tuple in the map.
std::string isl_map_get_statement_repr(__isl_keep isl_map* map, isl_dim_type type);

} // namespace poly
} // namespace cinn

0 comments on commit b7feaa3

Please sign in to comment.