Refactor TransformPropagator to allow specifying a position and propagating to part of the DAG (#1775)

`MaxInfoPropagator` is renamed to `MaxInfoSpanningTree`, it now only does path-finding, and the propagation is in a separate class `MaxInfoSpanningTree::Propagator`. Same for `MaxRootDomainInfoPropagator`.
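For orientation, here is a minimal sketch of the resulting split. The `Propagator` interface is inferred from the calls `MaxInfoSpanningTree::traverse()` makes in the diff further down; everything beyond `traverse`, `propagateTvPasC`, and `propagateTvCasP` is an assumption, not the verbatim header.
```C++
// Hedged sketch: the spanning tree only records the path; a Propagator
// subclass decides what to do at each hop when the path is traversed.
class MaxInfoSpanningTree {
 public:
  struct Propagator {
    virtual void propagateTvPasC(TensorView* from, TensorView* to) = 0;
    virtual void propagateTvCasP(TensorView* from, TensorView* to) = 0;
    virtual ~Propagator() = default;
  };
  void traverse(Propagator* propagator);
};
```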

`MaxInfoSpanningTree` and `MaxRootDomainInfoSpanningTree` now allow specifying a selector, which controls which subgraph should be included in path-finding.
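A sketch of what such a selector plausibly looks like, inferred from the `selector_->allowPasC(...)` / `selector_->allowCasP(...)` calls in the diff below (the exact signatures are assumptions):
```C++
// Hedged sketch: returning false prunes that edge from path-finding.
// allowPasC guards hops from a tensor to one of its producers;
// allowCasP guards hops from a tensor to one of its consumers.
struct Selector {
  virtual bool allowPasC(TensorView* from, TensorView* to) = 0;
  virtual bool allowCasP(TensorView* from, TensorView* to) = 0;
  virtual ~Selector() = default;
};
```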

`MaxRootDomainInfoSpanningTree` also gets a few new convenience constructors.
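Judging from the static `getReferenceRootIDInfo` overloads added in the diff below, these convenience constructors plausibly just build the reference info and forward to the base class; a sketch under that assumption, not the verbatim declarations:
```C++
// Hedged sketch of the forwarding constructors.
MaxRootDomainInfoSpanningTree(TensorView* reference, Selector* selector = nullptr)
    : MaxInfoSpanningTree(reference, getReferenceRootIDInfo(reference), selector) {}

MaxRootDomainInfoSpanningTree(
    TensorView* reference,
    int64_t leaf_pos,
    Selector* selector = nullptr)
    : MaxInfoSpanningTree(
          reference,
          getReferenceRootIDInfo(reference, leaf_pos),
          selector) {}
```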

`TransformPropagator` is now a subclass of `MaxInfoSpanningTree::Propagator`, so the way to use it has changed.
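The new calling convention, mirroring the pseudo-code below (the constructor arguments follow that pseudo-code; `reference_tv` and `pos` are placeholders): build a spanning tree for path-finding, then hand a propagator to `traverse`.
```C++
// Hedged sketch: path-finding and propagation are now two separate objects.
TransformPropagator propagator(reference_tv, pos);
MaxRootDomainInfoSpanningTree tree(reference_tv, pos);
tree.traverse(&propagator);
```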

Now `MaxInfoSpanningTree` and `MaxRootDomainInfoSpanningTree` store the path after it is generated, so the same path can be traversed multiple times. This is useful for supporting use cases like the new `computeAt`. Pseudo-code:
```C++
void TensorView::computeAt(TensorView* tv, int pos) {
  ComputeAtSubgraphSelector selector(this, tv);
  MaxRootDomainInfoSpanningTree path(tv, pos, &selector);
  TransformPropagator propagator(tv, pos);
  path.traverse(&propagator);
  ComputeAtPosPropagator ca_propagator(tv, pos);
  path.traverse(&ca_propagator);
}
```
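Because the computed path is cached (`traverse()` in the diff below only calls `compute_spanning_tree()` when the stored path is empty), the two `traverse` calls above replay exactly the same spanning tree; path-finding runs once.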
zasdfgbnm authored Jun 26, 2022
1 parent d67e1cd commit a054b3e
Showing 9 changed files with 508 additions and 205 deletions.
189 changes: 132 additions & 57 deletions torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp
```diff
@@ -6,16 +6,24 @@ namespace jit {
 namespace fuser {
 namespace cuda {

-bool MaxInfoPropagator::Information::operator>(const Information& r) const {
+bool MaxInfoSpanningTree::Information::operator>(const Information& r) const {
   return r < *this;
 }

-bool MaxInfoPropagator::Information::operator==(const Information& r) const {
+bool MaxInfoSpanningTree::Information::operator==(const Information& r) const {
   return !(r < *this) && !(*this < r);
 }

-// Dijkstra
-void MaxInfoPropagator::run() {
+// Prim's algorithm
+MaxInfoSpanningTree::MaxInfoSpanningTree(
+    TensorView* reference,
+    std::shared_ptr<Information> reference_info,
+    Selector* selector)
+    : reference_(reference),
+      reference_info_(reference_info),
+      selector_(selector) {}
+
+void MaxInfoSpanningTree::compute_spanning_tree() {
   // A set that allows us to quickly tell if a tensor has been replayed. If yes,
   // then we will not bother computing if a new path to this tensor is worth
   // taking (because the answer is always not worth)
@@ -28,88 +36,114 @@ void MaxInfoPropagator::run() {
   // std::list instead of std::priority_queue because C++'s
   // std::priority_queue does not support increase-key, and might not be
   // deterministic either.
-  std::list<NextHopInfo> propagation(1);
-  propagation.back().from = nullptr;
-  propagation.back().to = reference;
-  propagation.back().info_to = reference_info;
+  std::list<NextHopWithInfo> candidates(1);
+  candidates.back().next_hop.from = nullptr;
+  candidates.back().next_hop.to = reference_;
+  candidates.back().info_to = reference_info_;

-  // Insert the given next hop the correct position in `propagation`. If there
+  // Insert the given next hop the correct position in `candidates`. If there
   // is an existing next hop that preserves more information, then we will just
   // discard `info`.
-  auto insertNextHopInfo = [&](const NextHopInfo& info) {
+  auto insertNextHop = [&](const NextHopWithInfo& info) {
     if (!*(info.info_from)) {
       // When there is no more information about the starting tensor,
-      // we are not interested in continuing the propagation.
+      // we are not interested in continuing the path-finding.
       return;
     }
     // Find if there is already a path to the dest tensor
     auto existing = std::find_if(
-        propagation.begin(), propagation.end(), [&](const NextHopInfo& i) {
-          return i.to == info.to;
+        candidates.begin(), candidates.end(), [&](const NextHopWithInfo& i) {
+          return i.next_hop.to == info.next_hop.to;
         });
     // Only insert if there is no existing path to the dest tensor, or the new
     // path preserves more information about the starting tensor.
-    if (existing == propagation.end() || *existing < info) {
-      if (existing != propagation.end()) {
-        propagation.erase(existing);
+    if (existing == candidates.end() || *existing < info) {
+      if (existing != candidates.end()) {
+        candidates.erase(existing);
       }
-      auto pos = std::upper_bound(propagation.begin(), propagation.end(), info);
-      propagation.insert(pos, info);
+      auto pos = std::upper_bound(candidates.begin(), candidates.end(), info);
+      candidates.insert(pos, info);
     }
   };

+  auto allowPasC = [this](TensorView* from, TensorView* to) {
+    if (selector_ == nullptr) {
+      return true;
+    }
+    return selector_->allowPasC(from, to);
+  };

-  while (!propagation.empty()) {
-    auto next_hop = propagation.back();
-    propagation.pop_back();
+  auto allowCasP = [this](TensorView* from, TensorView* to) {
+    if (selector_ == nullptr) {
+      return true;
+    }
+    return selector_->allowCasP(from, to);
+  };
+
+  while (!candidates.empty()) {
+    const auto next_hop_info = candidates.back();
+    const auto& next_hop = next_hop_info.next_hop;
+    candidates.pop_back();

     if (next_hop.from != nullptr) {
       // nullptr used to start from reference
-      switch (next_hop.type) {
-        case NextHopType::C_AS_P:
-          propagateTvCasP(next_hop.from, next_hop.to);
-          break;
-        case NextHopType::P_AS_C:
-          propagateTvPasC(next_hop.from, next_hop.to);
-          break;
-      }
+      path_.push_back(next_hop);
     }
     replayed.emplace(next_hop.to);

     for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) {
-      if (replayed.count(consumer_tv)) {
+      if (replayed.count(consumer_tv) || !allowCasP(next_hop.to, consumer_tv)) {
         continue;
       }
-      insertNextHopInfo(
-          {.type = NextHopType::C_AS_P,
-           .from = next_hop.to,
-           .to = consumer_tv,
-           .info_from = next_hop.info_to,
-           .info_to =
-               computeInfoCasP(next_hop.to, consumer_tv, next_hop.info_to)});
+      insertNextHop(
+          {.next_hop =
+               {.type = NextHopType::C_AS_P,
+                .from = next_hop.to,
+                .to = consumer_tv},
+           .info_from = next_hop_info.info_to,
+           .info_to = computeInfoCasP(
+               next_hop.to, consumer_tv, next_hop_info.info_to)});
     }

     for (auto producer_tv : ir_utils::producerTvsOf(next_hop.to)) {
-      if (replayed.count(producer_tv)) {
+      if (replayed.count(producer_tv) || !allowPasC(next_hop.to, producer_tv)) {
         continue;
       }
-      insertNextHopInfo(
-          {.type = NextHopType::P_AS_C,
-           .from = next_hop.to,
-           .to = producer_tv,
-           .info_from = next_hop.info_to,
-           .info_to =
-               computeInfoPasC(next_hop.to, producer_tv, next_hop.info_to)});
+      insertNextHop(
+          {.next_hop =
+               {.type = NextHopType::P_AS_C,
+                .from = next_hop.to,
+                .to = producer_tv},
+           .info_from = next_hop_info.info_to,
+           .info_to = computeInfoPasC(
+               next_hop.to, producer_tv, next_hop_info.info_to)});
     }
   }
 }

-MaxRootDomainInfoPropagator::RootDomainInfo::operator bool() const {
+void MaxInfoSpanningTree::traverse(Propagator* propagator) {
+  if (path_.empty()) {
+    compute_spanning_tree();
+  }
+  for (const auto& next_hop : path_) {
+    switch (next_hop.type) {
+      case NextHopType::C_AS_P:
+        propagator->propagateTvCasP(next_hop.from, next_hop.to);
+        break;
+      case NextHopType::P_AS_C:
+        propagator->propagateTvPasC(next_hop.from, next_hop.to);
+        break;
+    }
+  }
+}
+
+MaxRootDomainInfoSpanningTree::RootDomainInfo::operator bool() const {
   return !info.empty();
 }

-bool MaxRootDomainInfoPropagator::RootDomainInfo::operator<(
-    const MaxInfoPropagator::Information& r) const {
-  auto rr = dynamic_cast<const MaxRootDomainInfoPropagator::RootDomainInfo&>(r);
+bool MaxRootDomainInfoSpanningTree::RootDomainInfo::operator<(
+    const Information& r) const {
+  auto rr = dynamic_cast<const RootDomainInfo&>(r);
   if (info.size() != rr.info.size()) {
     return info.size() < rr.info.size();
   }
@@ -174,17 +208,17 @@ std::unordered_set<IterDomain*> mapRFactorToRoot(
 // Given the preserved reference root ID info of a producer, compute
 // the corresponding info in consumer. The given info may be represented by
 // producer's root domain, or rfactor domain, depending on how we reached the
-// producer during propagation. If the given info is already represented with
+// producer during path-finding. If the given info is already represented with
 // producer's rfactor domain, then we directly map it to the consumer's root
 // domain. If the given info is represented with producer's root domain, we need
 // to first map it to the rfactor domain of the producer, then we can map it to
 // the consumer's root domain. The computed info will be represented by root
 // domain as root domain contains the raw information.
-std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
+std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree::
     computeInfoCasP(
         TensorView* from,
         TensorView* to,
-        std::shared_ptr<Information> from_info) {
+        std::shared_ptr<Information> from_info) const {
   RootDomainInfo result;

   TensorView* producer = from;
@@ -231,17 +265,17 @@ std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
 // Given the preserved reference root ID info of a consumer, compute
 // the corresponding info in producer. The given info may be represented by
 // consumer's root domain, or rfactor domain, depending on how we reached the
-// consumer during propagation. If the given info is already represented with
+// consumer during path-finding. If the given info is already represented with
 // consumer's root domain, then we directly map it to the producer's rfactor
 // domain. If the given info is represented with consumer's rfactor domain, we
 // need to first map it to the root domain of the consumer, then we can map it
 // to the producer's rfactor domain. The computed info will be represented by
 // rfactor domain as rfactor domain contains the raw information.
-std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
+std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree::
     computeInfoPasC(
         TensorView* from,
         TensorView* to,
-        std::shared_ptr<Information> from_info) {
+        std::shared_ptr<Information> from_info) const {
   RootDomainInfo result;

   TensorView* producer = to;
@@ -279,9 +313,9 @@ std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
   // We will stop at the rfactor ids in producer, and will not further map
   // them into root ids in producer. This means, we only keep the unprocessed
   // raw information of a tensor. This behavior is important to make sure that
-  // info is as accurate as possible throughout the propagation.
+  // info is as accurate as possible throughout the path-finding.
   //
-  // For example, if we do a C->P->C' propagation, we want to do
+  // For example, in a C->P->C' path, we want to do
   //   C(root) -> P(rfactor) -> C'(root)
   //   instead of
   //   C(root) -> P(rfactor) -> P(root) -> P(rfactor) -> C'(root)
@@ -305,6 +339,47 @@ std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
   return std::make_shared<RootDomainInfo>(std::move(result));
 }

+std::shared_ptr<MaxRootDomainInfoSpanningTree::RootDomainInfo>
+MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo(TensorView* tv) {
+  RootDomainInfo result;
+  const auto& root_domain = tv->getRootDomain();
+  result.info.reserve(root_domain.size());
+  for (auto id : root_domain) {
+    result.info.emplace_back(RootIDInfo{{id}, true, false});
+  }
+  return std::make_shared<RootDomainInfo>(std::move(result));
+}
+
+std::shared_ptr<MaxRootDomainInfoSpanningTree::RootDomainInfo>
+MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo(
+    TensorView* tv,
+    int64_t leaf_pos) {
+  if (leaf_pos < 0) {
+    leaf_pos += int64_t(tv->nDims()) + 1;
+  }
+  TORCH_CHECK(
+      leaf_pos >= 0 && leaf_pos <= tv->nDims(),
+      "MaxRootDomainInfoSpanningTree called on an leaf_pos outside valid range.");
+  RootDomainInfo result;
+  const auto& root_domain = tv->getMaybeRFactorDomain();
+  const auto& leaf_domain = tv->domain()->domain();
+  std::unordered_set<IterDomain*> selected_leaves(
+      leaf_domain.begin(), leaf_domain.begin() + leaf_pos);
+  for (auto id : root_domain) {
+    if (selected_leaves.count(id) > 0) {
+      result.info.emplace_back(RootIDInfo{{id}, true, tv->hasRFactor()});
+      continue;
+    }
+    for (auto selected_leaf_id : selected_leaves) {
+      if (DependencyCheck::isDependencyOf(id, selected_leaf_id)) {
+        result.info.emplace_back(RootIDInfo{{id}, true, tv->hasRFactor()});
+        break;
+      }
+    }
+  }
+  return std::make_shared<RootDomainInfo>(std::move(result));
+}
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
```