Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#13 from Superjomn/init/scheduler
Browse files Browse the repository at this point in the history
init schedule
  • Loading branch information
Superjomn authored Feb 4, 2020
2 parents fb1888e + 3624658 commit 06105cc
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cinn/poly/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ cc_library(poly SRCS
map.cc
element.cc
isl_utils.cc
schedule.cc
ast_gen.cc
DEPS common)

cc_test(test_poly_element SRCS element_test.cc DEPS poly)
1 change: 1 addition & 0 deletions cinn/poly/ast_gen.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

1 change: 1 addition & 0 deletions cinn/poly/ast_gen.h
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#pragma once
2 changes: 2 additions & 0 deletions cinn/poly/element.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,7 @@ std::string OuterName(const std::string &name) { return name + "_outer"; }
std::string InnerName(const Iterator &iterator) { return InnerName(iterator.id); }
std::string OuterName(const Iterator &iterator) { return OuterName(iterator.id); }

const char *Element::id() const { return isl_set_get_tuple_name(domain_.get()); }

} // namespace poly
} // namespace cinn
8 changes: 8 additions & 0 deletions cinn/poly/element.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class Element {
public:
explicit Element(isl::set domain);

/**
* The id of this element, should be unique across the schedule.
*/
const char* id() const;

/**
* Split the loop level of into two new loop levels.
* @param level the level to split.
Expand Down Expand Up @@ -61,6 +66,9 @@ class Element {
*/
Iterator Fuse(const Iterator& level0, const Iterator& level1);

const isl::set& domain() const { return domain_; }
const isl::map& schedule() const { return schedule_; }

private:
/**
* Initialize with an identity schedule.
Expand Down
8 changes: 8 additions & 0 deletions cinn/poly/isl_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,13 @@ std::vector<std::string> GetDimNames(const isl::set &x) {
return res;
}

std::vector<std::string> poly::GetDimNames(const isl::map &x, isl_dim_type dim_type) {
std::vector<std::string> res;
for (int i = 0; i < isl_map_dim(x.get(), dim_type); i++) {
res.push_back(isl_map_get_dim_name(x.get(), dim_type, i));
}
return res;
}

} // namespace poly
} // namespace cinn
3 changes: 3 additions & 0 deletions cinn/poly/isl_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ namespace cinn {
namespace poly {

//! Get dimension names from isl containers.
// @{
std::vector<std::string> GetDimNames(const isl::set &x);
std::vector<std::string> GetDimNames(const isl::map &x, isl_dim_type dim_type);
// @}

} // namespace poly
} // namespace cinn
63 changes: 63 additions & 0 deletions cinn/poly/schedule.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include "cinn/poly/schedule.h"
#include "cinn/utils/string.h"

#include <sstream>

namespace cinn {
namespace poly {

std::string TimeSchedule::__str__() const {
CHECK(!time_dims.empty());

// generate range: [dup, t0, t1...]
std::vector<std::string> range_dims({"dup"});
for (int i = 0; i < time_dims.size(); i++) {
range_dims.push_back("t" + std::to_string(i));
}

// generate conditions
std::vector<std::string> conds;
for (int i = 0; i < time_dims.size(); i++) {
conds.push_back(std::to_string(time_dims[i].time));
conds.push_back(time_dims[i].dim);
}

return utils::StringFormat("{ %s[%s] -> [%s]: %s",
id.c_str(),
utils::Join(domain_dims, ", ").c_str(),
utils::Join(range_dims, ", ").c_str(),
utils::Join(conds, " and ").c_str());
}

void Scheduler::RegisterElement(const Element &x) {
CHECK(!registration_finalized_) << "element registration has been finalized.";
space_size_ = std::max(space_size_, isl_map_dim(x.schedule().get(), isl_dim_out));

// Use the dimensions from element's schedule's range as the new domain dimensions because in Element, the schedule is
// like '{ S0[i,j] -> S0[i_outer, i_inner, j] }', the scheduler should schedule base on the range.
TimeSchedule schedule(GetDimNames(x.schedule(), isl_dim_out));
schedule_.emplace(x.id(), std::move(schedule));
}

void Scheduler::FinalizeRegistration() {
CHECK_GT(space_size_, 0) << "No valid dimension is collected, use RegisterElement to collect some elements";
CHECK(!schedule_.empty()) << "No valid dimension is collected, use RegisterElement to collect some elements";
registration_finalized_ = false;

for (auto &item : schedule_) {
item.second.ResizeTimeSpace(space_size_);
}
}

Scheduler &Scheduler::After(const Element &a, const Element &b, int level) {
CHECK_LT(level, space_size_);
depend_flow_graph_[b.id()].depend_level[a.id()] = level;
return *this;
}

Scheduler &Scheduler::Before(const Element &a, const Element &b, int level) { return After(b, a, level); }

std::unordered_map<std::string, isl::map> Scheduler::BuildSchedule() const {}

} // namespace poly
} // namespace cinn
120 changes: 120 additions & 0 deletions cinn/poly/schedule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#pragma once

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "cinn/poly/element.h"
#include "cinn/poly/isl_utils.h"
#include "cinn/poly/map.h"

namespace cinn {
namespace poly {

struct TimeDim {
//! time of this dimension.
int time;
//! name of this dimension.
std::string dim;

TimeDim(std::string dim, int time) : dim(std::move(dim)), time(time) {}
};

struct DependFlow {
//! Map from the depended Element.id to the level.
std::unordered_map<std::string, int> depend_level;
};

/**
* The range of the schedule.
*/
struct TimeSchedule {
//! ISL range format, such as '[dup, t0, t1]: dup=0 and t0=0 and t1=i]'
std::string __str__() const;

TimeSchedule(const std::vector<std::string> &dims) {
domain_dims = dims;
for (auto &dim : domain_dims) {
time_dims.emplace_back(dim, 0);
}
}

void ResizeTimeSpace(int size) { time_dims.resize(size); }

//! Get the isl map.
isl::map to_isl(isl::ctx ctx) const { return isl::map(ctx, __str__()); }

std::string id;
std::vector<std::string> domain_dims;
int duplicate_id{};
std::vector<TimeDim> time_dims;
};

/**
* Scheduler - Perform schedule on polyhedral model.
* It takes a normal schedule as input, and transform it into
*
*/
class Scheduler {
public:
/**
* Constructor.
* @param schedule A normal isl schedule, such as '{ S[i,j] -> [i,j] }'
*
* The schedule input can be transformed, that's ok, such as
* '{ S[i,j] -> [i_outer, i_inner, j]: i_outer=floor(i/4) and i_inner=i%4 }'
* that's OK.
*/
Scheduler() = default;

/**
* Register an Element to the scheduler.
*/
void RegisterElement(const Element &x);

/**
* Finalize the registration.
*/
void FinalizeRegistration();

/**
* Mark this should schedule after another.
*
* @param b
* @param level
*/
Scheduler &After(const Element &a, const Element &b, int level);
/**
* Mark this should schedule before another.
* @param b
* @param level
*/
Scheduler &Before(const Element &a, const Element &b, int level);

/**
* Build and create schedule.
*/
std::unordered_map<std::string, isl::map> BuildSchedule() const;

private:
/**
* The polyhedral schedule, any schedule is performed on it.
* We use the time-space map to record the schedule infomation, the format is borrowed from Tiramisu project:
* [redundant,
*
*/
int space_size_{};
//! Tell if the element registration is finalized.
bool registration_finalized_{false};
//! map from Schedule id to time schedule.
std::unordered_map<std::string, TimeSchedule> schedule_;
//! The graph constructed from the dependency and level. There should be only one element which doesn't has
//! dependency and that is the start point.
std::unordered_map<std::string, DependFlow> depend_flow_graph_;
};

} // namespace poly
} // namespace cinn

0 comments on commit 06105cc

Please sign in to comment.