Adding sibling path for MaxInfoSpanningTree (#1776)
The sibling path is required to generate a consistent replay in some cases where `MaxInfoSpanningTree` is used with a selector, for example, when the producer of a Welford is excluded from the propagation section. See the test `FusionTransformPropagateSelectorSibling_CUDA` for a detailed example. Besides, since siblings must be transformed exactly the same way, the sibling path is a perfect next hop for preserving information.

If you want a spanning tree without sibling paths, override `allowSibling` to `return false` in your selector.
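
For illustration, a selector that opts out of sibling propagation might look like the following. This is a minimal sketch, not part of the commit; `NoSiblingSelector` is a hypothetical name, and the permissive `allowPasC`/`allowCasP` bodies are placeholders:

struct NoSiblingSelector : public MaxInfoSpanningTree::Selector {
  bool allowPasC(TensorView* from, TensorView* to) override {
    return true; // placeholder: accept every producer-as-consumer hop
  }
  bool allowCasP(TensorView* from, TensorView* to) override {
    return true; // placeholder: accept every consumer-as-producer hop
  }
  bool allowSibling(TensorView* from, TensorView* to) override {
    return false; // opt out of the new sibling path entirely
  }
};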
zasdfgbnm authored Jun 28, 2022
1 parent 86f46aa commit 33a824d
Showing 7 changed files with 221 additions and 54 deletions.
41 changes: 28 additions & 13 deletions torch/csrc/jit/codegen/cuda/ir_utils.cpp
@@ -529,6 +529,22 @@ TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf(Val* val) {
   return uniqueEntries<Val>(consumer_vals);
 }
 
+// Return immediate siblings of val
+TORCH_CUDA_CU_API std::vector<Val*> siblingValsOf(Val* val) {
+  std::vector<Val*> sibling_vals;
+  auto def = val->definition();
+  if (def != nullptr) {
+    auto outs = def->outputs();
+    for (auto sibling_val : outs) {
+      if (sibling_val == val) {
+        continue;
+      }
+      sibling_vals.emplace_back(sibling_val);
+    }
+  }
+  return sibling_vals;
+}
+
 // Return immediate producers of val
 TORCH_CUDA_CU_API std::vector<Val*> producerValsOf(
     const std::vector<Val*>& vals) {
@@ -556,22 +572,21 @@ TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf(
 }
 
 std::vector<TensorView*> producerTvsOf(TensorView* tv) {
-  if (tv->definition() == nullptr) {
-    return {};
-  }
-  auto producer_vals =
-      ir_utils::filterByType<TensorView>(tv->definition()->inputs());
-  return uniqueEntries<TensorView>(
-      {producer_vals.begin(), producer_vals.end()});
+  auto producer_vals = producerValsOf(tv);
+  auto producer_tvs = ir_utils::filterByType<TensorView>(producer_vals);
+  return {producer_tvs.begin(), producer_tvs.end()};
 }
 
 std::vector<TensorView*> consumerTvsOf(TensorView* tv) {
-  std::vector<TensorView*> consumer_tvs;
-  for (auto use_expr : tv->uses()) {
-    auto outputs = ir_utils::filterByType<TensorView>(use_expr->outputs());
-    consumer_tvs.insert(consumer_tvs.end(), outputs.begin(), outputs.end());
-  }
-  return uniqueEntries<TensorView>(consumer_tvs);
+  auto consumer_vals = consumerValsOf(tv);
+  auto consumer_tvs = ir_utils::filterByType<TensorView>(consumer_vals);
+  return {consumer_tvs.begin(), consumer_tvs.end()};
 }
 
+TORCH_CUDA_CU_API std::vector<TensorView*> siblingTvsOf(TensorView* tv) {
+  auto sibling_vals = siblingValsOf(tv);
+  auto sibling_tvs = ir_utils::filterByType<TensorView>(sibling_vals);
+  return {sibling_tvs.begin(), sibling_tvs.end()};
+}
+
 std::vector<TensorView*> producerTvsOf(const std::vector<TensorView*>& tvs) {
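
To make the sibling relation concrete: all outputs of a multi-output expression such as a Welford share one definition and are therefore siblings of one another. A sketch under the usual fuser test API (`makeSymbolicTensor` and `Welford` are assumed from the fusion tests, not defined in this file):

// tv0 is a symbolic 2-D input; Welford reduces axis 1 and yields a single
// WelfordOp expression with three outputs.
TensorView* tv0 = makeSymbolicTensor(2);
auto result = Welford(tv0, {1});
// result.avg, result.var_sum, and result.n are mutual siblings, so
// siblingTvsOf(result.avg) returns {result.var_sum, result.n}.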
20 changes: 20 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_utils.h
@@ -181,6 +181,16 @@ TORCH_CUDA_CU_API std::vector<Val*> producerValsOf(Val* val);
 // code.
 TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf(Val* val);
 
+// Return immediate siblings of val. This function can be used on any Val and
+// will return siblings through Exprs.
+//
+// Warning: returned vals are not guaranteed to be between fusion inputs and
+// outputs. This function simply uses val->definition() or val->uses(), which
+// is limited to not go through fusion inputs/outputs, but if on a path that
+// isn't strictly between fusion inputs/outputs, it could effectively return
+// dead code.
+TORCH_CUDA_CU_API std::vector<Val*> siblingValsOf(Val* val);
+
 // Return immediate producers of vals, this function can be used on any vals and
 // will return producers through Exprs.
 //
@@ -223,6 +233,16 @@ TORCH_CUDA_CU_API std::vector<TensorView*> producerTvsOf(TensorView* tv);
 // code.
 TORCH_CUDA_CU_API std::vector<TensorView*> consumerTvsOf(TensorView* tv);
 
+// Return immediate siblings of tv. This function will return all immediate
+// siblings of tv through Exprs.
+//
+// Warning: returned tvs are not guaranteed to be between fusion inputs and
+// outputs. This function simply uses tv->definition() or tv->uses(), which is
+// limited to not go through fusion inputs/outputs, but if on a path that
+// isn't strictly between fusion inputs/outputs, it could effectively return
+// dead code.
+TORCH_CUDA_CU_API std::vector<TensorView*> siblingTvsOf(TensorView* tv);
+
 // Return immediate producers of tvs, this function will return all immediate
 // producers of tvs through Exprs.
 //
36 changes: 36 additions & 0 deletions torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp
@@ -80,6 +80,13 @@ void MaxInfoSpanningTree::compute_spanning_tree() {
     return selector_->allowCasP(from, to);
   };
 
+  auto allowSibling = [this](TensorView* from, TensorView* to) {
+    if (selector_ == nullptr) {
+      return true;
+    }
+    return selector_->allowSibling(from, to);
+  };
+
   while (!candidates.empty()) {
     const auto next_hop_info = candidates.back();
     const auto& next_hop = next_hop_info.next_hop;
@@ -91,6 +98,21 @@ void MaxInfoSpanningTree::compute_spanning_tree() {
     }
     replayed.emplace(next_hop.to);
 
+    for (auto sibling_tv : ir_utils::siblingTvsOf(next_hop.to)) {
+      if (replayed.count(sibling_tv) ||
+          !allowSibling(next_hop.to, sibling_tv)) {
+        continue;
+      }
+      insertNextHop(
+          {.next_hop =
+               {.type = NextHopType::SIBLING,
+                .from = next_hop.to,
+                .to = sibling_tv},
+           .info_from = next_hop_info.info_to,
+           .info_to = computeInfoSibling(
+               next_hop.to, sibling_tv, next_hop_info.info_to)});
+    }
+
     for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) {
       if (replayed.count(consumer_tv) || !allowCasP(next_hop.to, consumer_tv)) {
         continue;
@@ -127,6 +149,9 @@ void MaxInfoSpanningTree::traverse(Propagator* propagator) {
   }
   for (const auto& next_hop : path_) {
     switch (next_hop.type) {
+      case NextHopType::SIBLING:
+        propagator->propagateTvSibling(next_hop.from, next_hop.to);
+        break;
       case NextHopType::C_AS_P:
        propagator->propagateTvCasP(next_hop.from, next_hop.to);
        break;
@@ -380,6 +405,17 @@ MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo(
   return std::make_shared<RootDomainInfo>(std::move(result));
 }
 
+// Given the preserved reference root ID info of a tensor, compute the
+// corresponding info in its sibling. Since this info has nothing to do with
+// replay state, the sibling info is always identical by definition.
+std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree::
+    computeInfoSibling(
+        TensorView* from,
+        TensorView* to,
+        std::shared_ptr<Information> from_info) const {
+  return from_info;
+}
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
16 changes: 14 additions & 2 deletions torch/csrc/jit/codegen/cuda/maxinfo_propagator.h
@@ -29,8 +29,9 @@ namespace cuda {
  * MaxInfoSpanningTree::Information and implement `operator<` which is used to
  * tell which path contains more information, and `operator bool` which is used
  * to tell if there is any information stored. You also need to implement
- * computeInfoPasC and computeInfoCasP, which are the functions that compute
- * information of the `to` tensor from the information of the `from` tensor.
+ * computeInfoPasC, computeInfoCasP, and computeInfoSibling, which are the
+ * functions that compute information of the `to` tensor from the information of
+ * the `from` tensor.
  */
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 class TORCH_CUDA_CU_API MaxInfoSpanningTree {
@@ -40,12 +41,14 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree {
   struct Selector {
     virtual bool allowPasC(TensorView* from, TensorView* to) = 0;
     virtual bool allowCasP(TensorView* from, TensorView* to) = 0;
+    virtual bool allowSibling(TensorView* from, TensorView* to) = 0;
   };
 
   // This is the interface to implement the actual propagation
   struct Propagator {
     virtual void propagateTvPasC(TensorView* from, TensorView* to) = 0;
     virtual void propagateTvCasP(TensorView* from, TensorView* to) = 0;
+    virtual void propagateTvSibling(TensorView* from, TensorView* to) = 0;
   };
 
   // This is the interface that specifies the structure of information used to
@@ -71,6 +74,7 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree {
 
  private:
   enum class NextHopType {
+    SIBLING,
     C_AS_P,
     P_AS_C,
   };
@@ -109,6 +113,10 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree {
       TensorView* from,
       TensorView* to,
       std::shared_ptr<Information> from_info) const = 0;
+  virtual std::shared_ptr<Information> computeInfoSibling(
+      TensorView* from,
+      TensorView* to,
+      std::shared_ptr<Information> from_info) const = 0;
 
  public:
   MaxInfoSpanningTree(
@@ -190,6 +198,10 @@ class TORCH_CUDA_CU_API MaxRootDomainInfoSpanningTree
       TensorView* from,
       TensorView* to,
       std::shared_ptr<Information> from_info) const override;
+  virtual std::shared_ptr<Information> computeInfoSibling(
+      TensorView* from,
+      TensorView* to,
+      std::shared_ptr<Information> from_info) const override;
 
  private:
   static std::shared_ptr<RootDomainInfo> getReferenceRootIDInfo(TensorView* tv);
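
For reference, here is how the new hop surfaces through the `Propagator` interface: `traverse` dispatches one callback per edge of the spanning tree, now including `propagateTvSibling` for `SIBLING` hops. A minimal counting propagator (a sketch, not part of the commit; `CountingPropagator` and `reference_tv` are hypothetical names):

struct CountingPropagator : public MaxInfoSpanningTree::Propagator {
  int num_pasc = 0;
  int num_casp = 0;
  int num_sibling = 0;
  void propagateTvPasC(TensorView* from, TensorView* to) override {
    num_pasc++; // producer replayed as consumer
  }
  void propagateTvCasP(TensorView* from, TensorView* to) override {
    num_casp++; // consumer replayed as producer
  }
  void propagateTvSibling(TensorView* from, TensorView* to) override {
    num_sibling++; // sibling hop: outputs of the same expression
  }
};

// Usage sketch:
//   MaxRootDomainInfoSpanningTree tree(reference_tv);
//   CountingPropagator counter;
//   tree.traverse(&counter);
//   // counter.num_sibling now counts the SIBLING hops taken.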