From a6b3e70da5dee51dbc246347228ea21384e46ac3 Mon Sep 17 00:00:00 2001
From: "S. Song" <41357537+shmsong@users.noreply.github.com>
Date: Sun, 24 Jul 2022 23:38:00 -0700
Subject: [PATCH] Segmenter bug fix, and deterministic iteration ordering.
 (#1865)
---
 torch/csrc/jit/codegen/cuda/disjoint_set.h      |  41 +++++++
 .../csrc/jit/codegen/cuda/fusion_segmenter.cpp  | 108 +++++++++---------
 2 files changed, 94 insertions(+), 55 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/disjoint_set.h b/torch/csrc/jit/codegen/cuda/disjoint_set.h
index c3c91d08ac79c..64f7c55dd1c76 100644
--- a/torch/csrc/jit/codegen/cuda/disjoint_set.h
+++ b/torch/csrc/jit/codegen/cuda/disjoint_set.h
@@ -96,6 +96,47 @@ class VectorOfUniqueEntries {
     return set_.find(entry) != set_.end();
   }
 
+  // Erase given entry from the containers if
+  // there is a match.
+  void erase(T entry) {
+    vector_.erase(
+        std::remove_if(
+            vector_.begin(),
+            vector_.end(),
+            [entry](T val) { return val == entry; }),
+        vector_.end());
+
+    set_.erase(entry);
+  }
+
+  // Insert elements at the end of the container.
+  template <typename InputIt>
+  void insert(InputIt begin, InputIt end) {
+    for (auto it = begin; it != end; it++) {
+      pushBack(*it);
+    }
+  }
+
+  // Returns iterator pointing to the beginning of vector container
+  auto begin() const {
+    return vector().begin();
+  }
+
+  // Returns iterator pointing to the end of vector container
+  auto end() const {
+    return vector().end();
+  }
+
+  // Returns iterator pointing to the beginning of vector container
+  auto begin() {
+    return vector().begin();
+  }
+
+  // Returns iterator pointing to the end of vector container
+  auto end() {
+    return vector().end();
+  }
+
   std::string toString() {
     std::stringstream ss;
     ss << "{ ";
diff --git a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp
index 92b9a79835e3c..4e76bffe665b4 100644
--- a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp
+++ b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp
@@ -16,6 +16,12 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
+namespace {
+
+using GroupSet = VectorOfUniqueEntries<SegmentedGroup*>;
+
+} // namespace
+
 std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::getNeighborGroups() {
   std::vector<NeighborGroup> neighbors;
   for (auto inp : producer_edges) {
@@ -75,7 +81,7 @@ std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::
     return {};
   }
 
-  std::vector<bool> can_merge(true, neighbors.size());
+  std::vector<bool> can_merge(neighbors.size(), true);
 
   // Find neighbors with a level that is only 1 differant than this groups level
   for (const auto i : c10::irange(neighbors.size())) {
@@ -155,16 +161,16 @@ void insertUniquePredicated(
     std::vector<Val*>& v,
     const std::vector<SegmentedEdge*>& e,
     PREDICATE pred) {
-  std::unordered_set<Val*> to_add;
-  std::transform(
-      e.cbegin(),
-      e.cend(),
-      std::inserter(to_add, to_add.end()),
-      [](SegmentedEdge* se) { return se->val; });
+  VectorOfUniqueEntries<Val*> to_add;
+  for (auto edge : e) {
+    to_add.pushBack(edge->val);
+  }
+
   std::copy_if(
-      to_add.begin(), to_add.end(), std::back_inserter(v), [pred](Val* val) {
-        return pred(val);
-      });
+      to_add.vector().begin(),
+      to_add.vector().end(),
+      std::back_inserter(v),
+      [pred](Val* val) { return pred(val); });
 }
 
 void SegmentedGroup::finalize() {
@@ -811,7 +817,6 @@ void SegmentedFusion::finalize() {
 //! currently O(n^2). O(nlogn) would be a reasonable
 //! goal to achieve.
 class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
-  using GroupSet = std::unordered_set<SegmentedGroup*>;
   using GroupSetOwningPtr = std::unique_ptr<GroupSet>;
   using DependencyMap = std::unordered_map<SegmentedGroup*, GroupSetOwningPtr>;
 
@@ -829,7 +834,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
       const std::vector<SegmentedGroup*>& groups_to_check) {
     auto& producers_of_group = getAllKnownProducersSet(group);
     for (const auto& potential_producer : groups_to_check) {
-      if (producers_of_group->count(potential_producer)) {
+      if (producers_of_group->has(potential_producer)) {
         return true;
       }
     }
@@ -841,7 +846,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
     if (it == known_producers_of_.end()) {
       return false;
     }
-    return it->second->count(b);
+    return it->second->has(b);
   }
 
   bool isProducerOf(SegmentedGroup* a, SegmentedGroup* b) {
@@ -872,18 +877,14 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
     GroupSet values_between;
     auto& all_producers_of_consumer = known_producers_of_.at(consumer);
     TORCH_INTERNAL_ASSERT(
-        all_producers_of_consumer->count(producer),
+        all_producers_of_consumer->has(producer),
         "Fusion segment: Trying to compute path between two nodes that are not producer-consumer pairs");
 
-    std::copy_if(
-        all_producers_of_consumer->begin(),
-        all_producers_of_consumer->end(),
-        std::inserter(values_between, values_between.end()),
-        [this, producer](SegmentedGroup* producer_of_consumer) {
-          // Checks if producer is on the producer path of this intermediate
-          // node
-          return known_producers_of_.at(producer_of_consumer)->count(producer);
-        });
+    for (auto producer_of_consumer : *all_producers_of_consumer) {
+      if (known_producers_of_.at(producer_of_consumer)->has(producer)) {
+        values_between.pushBack(producer_of_consumer);
+      }
+    }
 
     return values_between;
   }
@@ -892,7 +893,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
   //! used for generating assertions after transforms
   bool isproducerMapDAG() const {
     for (auto& it : known_producers_of_) {
-      if (it.second->count(it.first)) {
+      if (it.second->has(it.first)) {
         return false;
       }
     }
@@ -909,7 +910,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
   void addConsumersToWorkList(SegmentedGroup* producer, GroupSet& to_visit) {
     for (auto e : producer->consumer_edges) {
       // A consumer wouldn't have been worked before any of its producer
-      to_visit.insert(e->to);
+      to_visit.pushBack(e->to);
     }
   }
 
@@ -922,7 +923,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
       SegmentedGroup* from) {
     auto& producer_set_to_merge = *getAllKnownProducersSet(from);
     for (auto group : producer_set_to_merge) {
-      getAllKnownProducersSet(into)->insert(group);
+      getAllKnownProducersSet(into)->pushBack(group);
     }
   }
 
@@ -943,8 +944,8 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
     GroupSet intersection;
     for (auto group : smaller_group_set) {
-      if (bigger_group_set.count(group)) {
-        intersection.insert(group);
+      if (bigger_group_set.has(group)) {
+        intersection.pushBack(group);
       }
     }
     return intersection;
@@ -956,7 +957,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
 };
 
 //! Finds the common producers of given set of groups
-GroupDependencyAnalysis::GroupSet GroupDependencyAnalysis::getCommonProducersOf(
+GroupSet GroupDependencyAnalysis::getCommonProducersOf(
     std::vector<SegmentedGroup*> groups) {
   if (groups.empty()) {
     return {};
   }
@@ -1006,9 +1007,9 @@ void GroupDependencyAnalysis::mergeGroups(
   // update producer maps of other groups
   for (auto& it : known_producers_of_) {
     // for all groups that are produced by either a or b
-    if (it.second->count(a) || it.second->count(b)) {
+    if (it.second->has(a) || it.second->has(b)) {
       // insert ab as the new producer
-      it.second->insert(ab);
+      it.second->pushBack(ab);
       // all producers of both a and b are now producers of `it`
       mergeAllKnownProducersIntoFrom(it.first, ab);
     }
@@ -1054,7 +1055,7 @@ void GroupDependencyAnalysis::mergeGroups(
         it.second->erase(merged_producer);
       }
       // insert the new group as producer
-      it.second->insert(merged);
+      it.second->pushBack(merged);
     }
   }
 }
@@ -1068,11 +1069,11 @@ void GroupDependencyAnalysis::computeAllProducers() {
 
   // Collect source nodes, with no producers we are guaranteed
   // a source node on a DAG
-  std::copy_if(
-      segmented_fusion_->cgroups().begin(),
-      segmented_fusion_->cgroups().end(),
-      std::inserter(visited, visited.end()),
-      [](SegmentedGroup* group) { return group->producer_edges.empty(); });
+  for (auto group : segmented_fusion_->cgroups()) {
+    if (group->producer_edges.empty()) {
+      visited.pushBack(group);
+    }
+  }
 
   // visited now only contain source nodes
   // they can go backward to nowhere
@@ -1086,20 +1087,18 @@ void GroupDependencyAnalysis::computeAllProducers() {
       if (std::all_of(
               visiting_group->producer_edges.begin(),
              visiting_group->producer_edges.end(),
-              [&visited](SegmentedEdge* e) {
-                return visited.count(e->from);
-              })) {
+              [&visited](SegmentedEdge* e) { return visited.has(e->from); })) {
         // filter multi-edges
         GroupSet producers_of_visiting_group;
         for (auto edge : visiting_group->producer_edges) {
-          producers_of_visiting_group.insert(edge->from);
+          producers_of_visiting_group.pushBack(edge->from);
         }
         // populate all possible paths
         // from producer backward, including
         // the producer
         for (auto producer : producers_of_visiting_group) {
-          getAllKnownProducersSet(visiting_group)->insert(producer);
+          getAllKnownProducersSet(visiting_group)->pushBack(producer);
           mergeAllKnownProducersIntoFrom(visiting_group, producer);
         }
         to_update = visiting_group;
       }
@@ -1109,7 +1108,7 @@ void GroupDependencyAnalysis::computeAllProducers() {
     if (to_update) {
       addConsumersToWorkList(to_update, to_visit);
       to_visit.erase(to_update);
-      visited.insert(to_update);
+      visited.pushBack(to_update);
     } else {
       TORCH_INTERNAL_ASSERT(false, "unreachable, original graph not a DAG");
     }
   }
@@ -2060,7 +2059,6 @@ bool SegmentCandidateFinder::TranslateWelfordInFusion(
 //! This pass tries to merge nodes with the same reduction type based
 //! on the graph structure.
 class CombineReductions {
-  using GroupSet = std::unordered_set<SegmentedGroup*>;
   using GroupVec = std::vector<SegmentedGroup*>;
   class ReductionSignature;
 
@@ -2240,7 +2238,7 @@ class CombineReductions {
             groups_with_reductions_.begin(),
             groups_with_reductions_.end(),
             [&all_groups_to_merge](SegmentedGroup* group) {
-              return all_groups_to_merge.count(group);
+              return all_groups_to_merge.has(group);
             }),
         groups_with_reductions_.end());
 
@@ -2374,7 +2372,7 @@ class CombineReductions {
             groups_with_reductions_.begin(),
             groups_with_reductions_.end(),
             [&groups_to_merge_set](SegmentedGroup* group) {
-              return groups_to_merge_set.count(group);
+              return groups_to_merge_set.has(group);
             }),
         groups_with_reductions_.end());
 
@@ -2414,8 +2412,8 @@ class CombineReductions {
             maybe_consumer, maybe_producer)) {
       auto groups_to_check =
           dependency_analysis->valuesBetween(maybe_producer, maybe_consumer);
-      groups_to_check.insert(maybe_producer);
-      groups_to_check.insert(maybe_consumer);
+      groups_to_check.pushBack(maybe_producer);
+      groups_to_check.pushBack(maybe_consumer);
 
       // Check that either no group has a reduction or all groups have the same
       // reduction signature
@@ -2428,13 +2426,13 @@ class CombineReductions {
       // output edge does not generate much saving of global memory access
       // we want to postpone merging these edges till the very final pass
       for (auto producer_edge_of_group : group->producer_edges) {
-        if (groups_to_check.count(producer_edge_of_group->from) &&
+        if (groups_to_check.has(producer_edge_of_group->from) &&
            producer_edge_of_group->val->isFusionOutput()) {
          return {};
        }
      }
      for (auto consumer_edge_of_group : group->consumer_edges) {
-        if (groups_to_check.count(consumer_edge_of_group->to) &&
+        if (groups_to_check.has(consumer_edge_of_group->to) &&
            consumer_edge_of_group->val->isFusionOutput()) {
          return {};
        }
      }
@@ -2653,11 +2651,11 @@ void SegmentCandidateFinder::findSegments() {
 
   // Expressions to exclude from segmentation because they're just derived from
   // unary ops on inputs to the complete fusion
-  std::unordered_set<Expr*> excluded_inp_unary_exprs;
+  VectorOfUniqueEntries<Expr*> excluded_inp_unary_exprs;
 
   // "Terminating" outputs from the excluded input unary exprs, these will be
   // treated as complete fusion inputs.
-  std::unordered_set<Val*> forwarded_inputs;
+  VectorOfUniqueEntries<Val*> forwarded_inputs;
   {
     std::deque<Val*> to_visit;
     for (auto inp : completeFusion()->inputs()) {
@@ -2677,8 +2675,8 @@ void SegmentCandidateFinder::findSegments() {
       }
 
       if (expr->output(0)->uses().size() > 1) {
-        excluded_inp_unary_exprs.emplace(expr);
-        forwarded_inputs.emplace(expr->output(0));
+        excluded_inp_unary_exprs.pushBack(expr);
+        forwarded_inputs.pushBack(expr->output(0));
         continue;
       }
 
@@ -2735,7 +2733,7 @@ void SegmentCandidateFinder::findSegments() {
       continue;
     }
 
-    if (excluded_inp_unary_exprs.count(expr)) {
+    if (excluded_inp_unary_exprs.has(expr)) {
       continue;
     }
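
Note for readers unfamiliar with VectorOfUniqueEntries<T>: the sketch below is a minimal, standalone approximation of the interface this patch relies on (pushBack, has, the new erase, the ranged insert, and begin/end). It is not the class from disjoint_set.h; the name SimpleVectorOfUniqueEntries, the main() driver, and the sample ids array are illustrative only. What it demonstrates is why the segmenter swaps std::unordered_set for this container: iteration goes through the backing vector, so visiting order is insertion order and therefore deterministic from run to run.

// Illustration only: a trimmed-down stand-in that mirrors the
// VectorOfUniqueEntries<T> interface touched by this patch. The real class
// lives in torch/csrc/jit/codegen/cuda/disjoint_set.h.
#include <algorithm>
#include <iostream>
#include <iterator>
#include <unordered_set>
#include <vector>

template <typename T>
class SimpleVectorOfUniqueEntries {
 public:
  // Appends entry only if it has not been seen before; the backing vector
  // records insertion order.
  bool pushBack(T entry) {
    if (set_.insert(entry).second) {
      vector_.push_back(entry);
      return true;
    }
    return false;
  }

  // O(1) membership test through the hash set.
  bool has(T entry) const {
    return set_.count(entry) != 0;
  }

  // Removes entry from both containers if present (mirrors the new erase()).
  void erase(T entry) {
    vector_.erase(
        std::remove(vector_.begin(), vector_.end(), entry), vector_.end());
    set_.erase(entry);
  }

  // Bulk insert, keeping first-seen order (mirrors the new ranged insert()).
  template <typename InputIt>
  void insert(InputIt begin, InputIt end) {
    for (auto it = begin; it != end; ++it) {
      pushBack(*it);
    }
  }

  // Iteration follows the vector, so the order is deterministic.
  auto begin() const {
    return vector_.begin();
  }
  auto end() const {
    return vector_.end();
  }

 private:
  std::vector<T> vector_;
  std::unordered_set<T> set_;
};

int main() {
  SimpleVectorOfUniqueEntries<int> groups;
  int ids[] = {3, 1, 3, 2}; // the duplicate 3 is dropped on insert
  groups.insert(std::begin(ids), std::end(ids));
  groups.erase(1);
  for (int id : groups) {
    std::cout << id << " "; // always prints "3 2 ", in insertion order
  }
  std::cout << std::endl;
  return 0;
}

Keeping a vector alongside the lookup set costs one extra copy of each pointer, but it makes the traversal order of groups reproducible, which is the "deterministic iteration ordering" half of this change; the other half is the one-line argument-order fix to can_merge in getMergeCandidates.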