diff --git a/cinn/poly/ast_gen.cc b/cinn/poly/ast_gen.cc index f4af440255a34..e175fc8db2dad 100644 --- a/cinn/poly/ast_gen.cc +++ b/cinn/poly/ast_gen.cc @@ -3,17 +3,22 @@ namespace cinn { namespace poly { -isl::ast_node AstGen::operator()(const std::vector &elements, const Scheduler &scheduler) { - // Collect domains. - auto sets = utils::Map, isl::set>(elements, [](const Element &e) { return e.domain(); }); - isl::union_set domain = SetsToUnionSet(sets); +isl::union_set AstGen::domain() { + CHECK(!poly_elements_.empty()); + auto sets = utils::Map, isl::set>(poly_elements_, [](const Element &e) { return e.domain(); }); + return SetsToUnionSet(sets); +} - isl::ctx ctx = elements.front().domain().ctx(); +isl::ctx AstGen::ctx() const { + CHECK(!poly_elements_.empty()); + return poly_elements_.front().domain().ctx(); +} +isl::ast_node AstGen::Build() { // Collect schedule from scheduler. - auto schedules = scheduler.BuildSchedule(); + auto schedules = scheduler_.BuildSchedule(); std::vector maps; - for (auto &ele : elements) { + for (auto &ele : poly_elements_) { auto it = schedules.find(ele.id()); CHECK(it != std::end(schedules)); maps.push_back(it->second); @@ -21,18 +26,41 @@ isl::ast_node AstGen::operator()(const std::vector &elements, const Sch auto schedule = MapsToUnionMap(maps); // Build it. - auto build = isl::ast_build::from_context(context_); + auto ast_build = isl::ast_build::from_context(context_); // Set iterators. if (!iterator_names_.empty()) { - auto iterator_names = scheduler.WrapIteratorNames(iterator_names_); - isl::id_list ids = isl::manage(isl_id_list_alloc(ctx.get(), iterator_names.size())); + auto iterator_names = scheduler_.WrapIteratorNames(iterator_names_); + isl::id_list ids = isl::manage(isl_id_list_alloc(ctx().get(), iterator_names.size())); for (int i = 0; i < iterator_names.size(); i++) { - ids = isl::manage(isl_id_list_add(ids.release(), isl_id_alloc(ctx.get(), iterator_names[i].c_str(), nullptr))); + ids = isl::manage(isl_id_list_add(ids.release(), isl_id_alloc(ctx().get(), iterator_names[i].c_str(), nullptr))); } - build = isl::manage(isl_ast_build_set_iterators(build.release(), ids.release())); + ast_build = isl::manage(isl_ast_build_set_iterators(ast_build.release(), ids.release())); } - auto ast = build.node_from_schedule_map(schedule.intersect_domain(domain)); + // collect iterator map + auto get_domain_by_name = [this](const std::string &name) -> isl::set { + auto ele_it = std::find_if( + poly_elements_.begin(), poly_elements_.end(), [&name](const Element &ele) { return ele.id() == name; }); + CHECK(ele_it != std::end(poly_elements_)); + return ele_it->domain(); + }; + + auto collect = [&](isl::ast_node node, isl::ast_build build) -> isl::ast_node { + auto tuple_name = detail::GetTupleName(node.get()); + auto indice_map = ExtractIslTransformedIndiceMap(get_domain_by_name(tuple_name), build.get()); + transformed_indice_map_[tuple_name] = indice_map; + return node; + }; + + ast_build = ast_build.set_at_each_domain(collect); + + isl::union_map transformed_schedule = transform().apply_range(schedule); + auto schedule_domain = transformed_schedule.intersect_domain(domain()); + VLOG(4) << "domain: " << domain(); + VLOG(4) << "transform schedule " << poly_elements()[0].schedule(); + VLOG(4) << "schedule: " << schedule; + VLOG(4) << "schedule_domain: " << schedule_domain; + auto ast = ast_build.node_from_schedule_map(schedule_domain); VLOG(2) << "\n" << isl_ast_node_to_C_str(ast.get()); return ast; } @@ -42,5 +70,77 @@ AstGen &AstGen::SetIteratorNames(const std::vector &names) { return *this; } +isl::ast_expr CreateIslAstIndexExpression(isl_ast_build *build, const isl::map &access); + +std::map AstGen::ExtractIslTransformedIndiceMap(const isl::set &iterator_domain, + isl_ast_build *build) { + std::map iterator_map; + isl::map identity = isl::manage(isl_set_identity(iterator_domain.copy())); + isl::map schedule = identity; + + identity = identity.apply_domain(schedule); + isl::ast_expr idx_expr = CreateIslAstIndexExpression(build, identity); + isl::space domain_space = iterator_domain.space(); + + for (int i = 1; i < isl_ast_expr_get_op_n_arg(idx_expr.get()); i++) { + if (isl_space_has_dim_name(domain_space.get(), isl_dim_set, i - 1)) { + std::string original_idx_name = isl_space_get_dim_name(domain_space.get(), isl_dim_set, i - 1); + isl::ast_expr transformed_index = isl::manage(isl_ast_expr_get_op_arg(idx_expr.get(), i)); + iterator_map.emplace(original_idx_name, transformed_index); + } + } + + return iterator_map; +} + +const std::map &AstGen::axis2ast(const std::string &tuple_name) const { + auto it = transformed_indice_map_.find(tuple_name); + CHECK(it != transformed_indice_map_.end()) << "no id " << tuple_name; + return it->second; +} + +isl::ast_expr CreateIslAstIndexExpression(isl_ast_build *build, const isl::map &access) { + CHECK(build); + isl::map schedule = isl::manage(isl_map_from_union_map(isl_ast_build_get_schedule(build))); + + // get identity access from schedule. + auto statement = isl_map_get_statement_repr(schedule.get(), isl_dim_in); + auto statement_set = isl::manage(isl_set_read_from_str(isl_map_get_ctx(schedule.get()), + utils::StringFormat("{ %s : }", statement.c_str()).c_str())); + auto identity_access = isl::manage(isl_set_identity(statement_set.release())); + isl::map map = isl::manage(isl_map_reverse(schedule.copy())); + + isl::pw_multi_aff iterator_map = isl::manage(isl_pw_multi_aff_from_map(map.copy())); + isl::pw_multi_aff index_aff = isl::manage(isl_pw_multi_aff_from_map(identity_access.copy())); + + isl::space model2 = iterator_map.space(); + index_aff = isl::manage(isl_pw_multi_aff_align_params(index_aff.copy(), model2.copy())); + isl::space model = index_aff.space(); + iterator_map = isl::manage(isl_pw_multi_aff_align_params(iterator_map.copy(), model.copy())); + iterator_map = isl::manage(isl_pw_multi_aff_pullback_pw_multi_aff(index_aff.copy(), iterator_map.copy())); + isl::ast_expr index_expr = isl::manage(isl_ast_build_access_from_pw_multi_aff(build, iterator_map.copy())); + + return index_expr; +} + +isl::union_map AstGen::transform() { + std::vector transforms; + for (auto &ele : poly_elements()) { + transforms.push_back(ele.schedule()); + } + return MapsToUnionMap(transforms); +} + +namespace detail { + +std::string GetTupleName(isl_ast_node *node) { + auto expr = isl::manage(isl_ast_node_user_get_expr(node)); + auto arg = isl::manage(isl_ast_expr_get_op_arg(expr.get(), 0)); + auto name = isl_id_get_name(isl_ast_expr_get_id(arg.get())); + return name; +} + +} // namespace detail + } // namespace poly } // namespace cinn diff --git a/cinn/poly/ast_gen.h b/cinn/poly/ast_gen.h index 2f70e9b5e701d..eae55d678c844 100644 --- a/cinn/poly/ast_gen.h +++ b/cinn/poly/ast_gen.h @@ -1,6 +1,9 @@ #pragma once #include +#include +#include + #include "cinn/poly/element.h" #include "cinn/poly/isl_utils.h" #include "cinn/poly/schedule.h" @@ -11,7 +14,8 @@ namespace poly { class AstGen { public: - AstGen(const isl::set& context) : context_(context) {} + AstGen(const isl::set& context, const std::vector& elements, const Scheduler& scheduler) + : context_(context), poly_elements_(elements), scheduler_(scheduler) {} /** * Set forloop iterator names. @@ -20,12 +24,42 @@ class AstGen { */ AstGen& SetIteratorNames(const std::vector& names); - isl::ast_node operator()(const std::vector& elements, const Scheduler& scheduler); + isl::ctx ctx() const; + + isl::ast_node Build(); + + const std::vector& poly_elements() const { return poly_elements_; } + + const std::map& axis2ast(const std::string& tuple_name) const; + + private: + //! Return a domain composed of all the elements. + isl::union_set domain(); + + //! Return a map composed of all the transforms. + isl::union_map transform(); + + //! Replace the Expr with the transformed indices. + //! e.g. Stage's expr is C[i,j] ... + //! e.g. with ISL transformed statement S0(c0+1, c1*2), the expr will turn to C[c0+1, c1*2] + static std::map ExtractIslTransformedIndiceMap(const isl::set& iterator_domain, + isl_ast_build* build); private: isl::set context_; + std::vector poly_elements_; + const Scheduler& scheduler_; std::vector iterator_names_; + //! tuple name -> { axis -> isl_ast } + std::map> transformed_indice_map_; }; +namespace detail { + +//! Get tuple name of a ast node. +std::string GetTupleName(isl_ast_node* node); + +} // namespace detail + } // namespace poly } // namespace cinn diff --git a/cinn/poly/ast_gen_test.cc b/cinn/poly/ast_gen_test.cc index 20e49b3bb6a23..9d7b987bb48a2 100644 --- a/cinn/poly/ast_gen_test.cc +++ b/cinn/poly/ast_gen_test.cc @@ -10,14 +10,25 @@ TEST(ast_gen, basic) { Element A(isl::set(ctx, "{ A[i,j,k]: 0 #include +#include "cinn/utils/string.h" + namespace cinn { namespace poly { @@ -58,5 +60,16 @@ isl::union_set SetsToUnionSet(const std::vector &sets) { return uset; } +std::string isl_map_get_statement_repr(__isl_keep isl_map *map, isl_dim_type type) { + CHECK(map); + auto tuple_name = isl_map_get_tuple_name(map, type); + std::vector dims; + + for (int i = 0; i < isl_map_dim(map, type); i++) { + dims.push_back(isl_map_get_dim_name(map, type, i)); + } + return utils::StringFormat("%s[%s]", tuple_name, utils::Join(dims, ", ").c_str()); +} + } // namespace poly } // namespace cinn diff --git a/cinn/poly/isl_utils.h b/cinn/poly/isl_utils.h index 4632f99eeddce..f93f6f80ed0d9 100644 --- a/cinn/poly/isl_utils.h +++ b/cinn/poly/isl_utils.h @@ -21,5 +21,8 @@ void SetDimNames(isl::map* map, isl_dim_type dim_type, const std::vector& maps); isl::union_set SetsToUnionSet(const std::vector& sets); +//! Get a representation of the tuple in the map. +std::string isl_map_get_statement_repr(__isl_keep isl_map* map, isl_dim_type type); + } // namespace poly } // namespace cinn