Skip to content

Commit

Permalink
test the groups the same order as they are merged (#1949)
Browse files Browse the repository at this point in the history
  • Loading branch information
shmsong committed Sep 1, 2022
1 parent 208262b commit 992e17c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
38 changes: 34 additions & 4 deletions torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2598,9 +2598,39 @@ bool CombineReductions::shouldRun(
return false;
}

bool SegmentCandidateFinder::codeGenSupportedMerge(SegmentedEdge* edge) {
namespace {

//! Returns true if group1 and group2 are an immediate producer-consumer pair.
bool areDirectlyConnected(SegmentedGroup* group1, SegmentedGroup* group2) {
// Check if group1 is a immediate consumer of group2
if (std::any_of(
group1->producer_edges.begin(),
group1->producer_edges.end(),
[group2](SegmentedEdge* edge) { return edge->from == group2; })) {
return true;
}

// Check if group1 is a immediate producer of group2
if (std::any_of(
group1->consumer_edges.begin(),
group1->consumer_edges.end(),
[group2](SegmentedEdge* edge) { return edge->to == group2; })) {
return true;
}

return false;
}

} // namespace

bool SegmentCandidateFinder::codeGenSupportedMerge(
SegmentedGroup* group1,
SegmentedGroup* group2) {
TORCH_INTERNAL_ASSERT(
areDirectlyConnected(group1, group2),
"only support testing immediate producer-consumer groups");
Fusion* fusion = segmented_fusion_->completeFusion();
auto h = tryMerge(fusion, runtime_info_, edge->from, edge->to);
auto h = tryMerge(fusion, runtime_info_, group1, group2);
return h.has_value();
}

Expand Down Expand Up @@ -2827,7 +2857,7 @@ void SegmentCandidateFinder::findSegments() {

auto candidate_it = candidates.begin();
while (candidate_it != candidates.end() &&
!codeGenSupportedMerge(candidate_it->edge)) {
!codeGenSupportedMerge(group, candidate_it->group)) {
candidate_it++;
}
if (candidate_it == candidates.end()) {
Expand Down Expand Up @@ -2896,7 +2926,7 @@ void SegmentCandidateFinder::finalMerge() {
for (auto consumer : all_consumers_of_producer_group) {
if (!producer_check->isConsumerOfAny(
consumer, all_consumers_of_producer_group) &&
codeGenSupportedMerge(consumer_edge_map.at(consumer))) {
codeGenSupportedMerge(producer_group, consumer)) {
to_merge_.emplace_back(producer_group);
to_merge_.emplace_back(consumer);
producer_group->merged_ = true;
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/fusion_segmenter.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ class TORCH_CUDA_CU_API SegmentCandidateFinder {

SegmentedGroup* mergeNodes();

bool codeGenSupportedMerge(SegmentedEdge* edge);
bool codeGenSupportedMerge(SegmentedGroup* group1, SegmentedGroup* group2);

void findSegments();

Expand Down

0 comments on commit 992e17c

Please sign in to comment.