Skip to content

Commit

Permalink
Segmenter bug fix, and deterministic iteration ordering. (#1865)
Browse files Browse the repository at this point in the history
  • Loading branch information
shmsong authored Jul 25, 2022
1 parent 1b665b9 commit a6b3e70
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 55 deletions.
41 changes: 41 additions & 0 deletions torch/csrc/jit/codegen/cuda/disjoint_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,47 @@ class VectorOfUniqueEntries {
return set_.find(entry) != set_.end();
}

// Erase the given entry from both underlying containers if there is a
// match; does nothing when the entry is not present.
//
// The set is consulted first: its erase() reports how many keys were
// removed, so a miss returns immediately instead of paying for a linear
// scan of the vector.
// NOTE(review): this relies on set_ and vector_ always holding the same
// entries (has() already treats set_ as the membership authority) --
// confirm pushBack() maintains that invariant.
void erase(T entry) {
  if (set_.erase(entry) == 0) {
    // Not a member, so there is nothing to remove from the vector either.
    return;
  }

  // Erase-remove idiom: compact the remaining entries, preserving their
  // relative order, then drop the tail.
  vector_.erase(
      std::remove(vector_.begin(), vector_.end(), entry), vector_.end());
}

// Insert every element of the range [begin, end) at the end of the
// container, in iteration order. Each element is forwarded to pushBack(),
// so how repeated entries are handled is decided there.
template <typename InputIt>
void insert(InputIt begin, InputIt end) {
  // Pre-increment: avoids the temporary iterator copy that post-increment
  // creates, which matters for non-trivial iterator types.
  for (InputIt cursor = begin; cursor != end; ++cursor) {
    pushBack(*cursor);
  }
}

// Returns an iterator pointing to the beginning of the vector container,
// as obtained through vector(). Together with end() this lets a const
// VectorOfUniqueEntries be used directly in a range-based for loop.
// The return type is deduced from vector().begin().
auto begin() const {
return vector().begin();
}

// Returns an iterator pointing to the end (one past the last entry) of the
// vector container, as obtained through vector(). Counterpart of the const
// begin(); the return type is deduced from vector().end().
auto end() const {
return vector().end();
}

// Returns an iterator pointing to the beginning of the vector container
// for a non-const object; it delegates to vector() just like the const
// overload.
// NOTE(review): whether the deduced iterator allows mutation depends on
// what vector() returns for a non-const object -- its declaration is not
// visible here, confirm before relying on it.
auto begin() {
return vector().begin();
}

// Returns an iterator pointing to the end (one past the last entry) of the
// vector container for a non-const object; it delegates to vector() just
// like the const overload.
// NOTE(review): the constness of the deduced iterator follows vector()'s
// return type, which is not visible here -- confirm.
auto end() {
return vector().end();
}

std::string toString() {
std::stringstream ss;
ss << "{ ";
Expand Down
108 changes: 53 additions & 55 deletions torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ namespace jit {
namespace fuser {
namespace cuda {

namespace {

// File-local alias for a set of segmented groups backed by
// VectorOfUniqueEntries, whose begin()/end() iterate its vector container.
// NOTE(review): presumably chosen over std::unordered_set so that iteration
// over groups is deterministic -- confirm against the container's
// pushBack() semantics.
using GroupSet = VectorOfUniqueEntries<SegmentedGroup*>;

} // namespace

std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::getNeighborGroups() {
std::vector<NeighborGroup> neighbors;
for (auto inp : producer_edges) {
Expand Down Expand Up @@ -75,7 +81,7 @@ std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::
return {};
}

std::vector<bool> can_merge(true, neighbors.size());
std::vector<bool> can_merge(neighbors.size(), true);

// Find neighbors with a level that is only 1 different than this group's level
for (const auto i : c10::irange(neighbors.size())) {
Expand Down Expand Up @@ -155,16 +161,16 @@ void insertUniquePredicated(
std::vector<Val*>& v,
const std::vector<SegmentedEdge*>& e,
PREDICATE pred) {
std::unordered_set<Val*> to_add;
std::transform(
e.cbegin(),
e.cend(),
std::inserter(to_add, to_add.end()),
[](SegmentedEdge* se) { return se->val; });
VectorOfUniqueEntries<Val*> to_add;
for (auto edge : e) {
to_add.pushBack(edge->val);
}

std::copy_if(
to_add.begin(), to_add.end(), std::back_inserter(v), [pred](Val* val) {
return pred(val);
});
to_add.vector().begin(),
to_add.vector().end(),
std::back_inserter(v),
[pred](Val* val) { return pred(val); });
}

void SegmentedGroup::finalize() {
Expand Down Expand Up @@ -811,7 +817,6 @@ void SegmentedFusion::finalize() {
//! currently O(n^2). O(nlogn) would be a reasonable
//! goal to achieve.
class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
using GroupSet = std::unordered_set<SegmentedGroup*>;
using GroupSetOwningPtr = std::unique_ptr<GroupSet>;
using DependencyMap = std::unordered_map<SegmentedGroup*, GroupSetOwningPtr>;

Expand All @@ -829,7 +834,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
const std::vector<SegmentedGroup*>& groups_to_check) {
auto& producers_of_group = getAllKnownProducersSet(group);
for (const auto& potential_producer : groups_to_check) {
if (producers_of_group->count(potential_producer)) {
if (producers_of_group->has(potential_producer)) {
return true;
}
}
Expand All @@ -841,7 +846,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
if (it == known_producers_of_.end()) {
return false;
}
return it->second->count(b);
return it->second->has(b);
}

bool isProducerOf(SegmentedGroup* a, SegmentedGroup* b) {
Expand Down Expand Up @@ -872,18 +877,14 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
GroupSet values_between;
auto& all_producers_of_consumer = known_producers_of_.at(consumer);
TORCH_INTERNAL_ASSERT(
all_producers_of_consumer->count(producer),
all_producers_of_consumer->has(producer),
"Fusion segment: Trying to compute path between two nodes that are not producer-consumer pairs");

std::copy_if(
all_producers_of_consumer->begin(),
all_producers_of_consumer->end(),
std::inserter(values_between, values_between.end()),
[this, producer](SegmentedGroup* producer_of_consumer) {
// Checks if producer is on the producer path of this intermediate
// node
return known_producers_of_.at(producer_of_consumer)->count(producer);
});
for (auto producer_of_consumer : *all_producers_of_consumer) {
if (known_producers_of_.at(producer_of_consumer)->has(producer)) {
values_between.pushBack(producer_of_consumer);
}
}

return values_between;
}
Expand All @@ -892,7 +893,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
//! used for generating assertions after transforms
bool isproducerMapDAG() const {
for (auto& it : known_producers_of_) {
if (it.second->count(it.first)) {
if (it.second->has(it.first)) {
return false;
}
}
Expand All @@ -909,7 +910,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
void addConsumersToWorkList(SegmentedGroup* producer, GroupSet& to_visit) {
for (auto e : producer->consumer_edges) {
// A consumer wouldn't have been visited before any of its producers
to_visit.insert(e->to);
to_visit.pushBack(e->to);
}
}

Expand All @@ -922,7 +923,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
SegmentedGroup* from) {
auto& producer_set_to_merge = *getAllKnownProducersSet(from);
for (auto group : producer_set_to_merge) {
getAllKnownProducersSet(into)->insert(group);
getAllKnownProducersSet(into)->pushBack(group);
}
}

Expand All @@ -943,8 +944,8 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {

GroupSet intersection;
for (auto group : smaller_group_set) {
if (bigger_group_set.count(group)) {
intersection.insert(group);
if (bigger_group_set.has(group)) {
intersection.pushBack(group);
}
}
return intersection;
Expand All @@ -956,7 +957,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
};

//! Finds the common producers of given set of groups
GroupDependencyAnalysis::GroupSet GroupDependencyAnalysis::getCommonProducersOf(
GroupSet GroupDependencyAnalysis::getCommonProducersOf(
std::vector<SegmentedGroup*> groups) {
if (groups.empty()) {
return {};
Expand Down Expand Up @@ -1006,9 +1007,9 @@ void GroupDependencyAnalysis::mergeGroups(
// update producer maps of other groups
for (auto& it : known_producers_of_) {
// for all groups that are produced by either a or b
if (it.second->count(a) || it.second->count(b)) {
if (it.second->has(a) || it.second->has(b)) {
// insert ab as the new producer
it.second->insert(ab);
it.second->pushBack(ab);
// all producers of both a and b are now producers of `it`
mergeAllKnownProducersIntoFrom(it.first, ab);
}
Expand Down Expand Up @@ -1054,7 +1055,7 @@ void GroupDependencyAnalysis::mergeGroups(
it.second->erase(merged_producer);
}
// insert the new group as producer
it.second->insert(merged);
it.second->pushBack(merged);
}
}
}
Expand All @@ -1068,11 +1069,11 @@ void GroupDependencyAnalysis::computeAllProducers() {

// Collect source nodes, with no producers we are guaranteed
// a source node on a DAG
std::copy_if(
segmented_fusion_->cgroups().begin(),
segmented_fusion_->cgroups().end(),
std::inserter(visited, visited.end()),
[](SegmentedGroup* group) { return group->producer_edges.empty(); });
for (auto group : segmented_fusion_->cgroups()) {
if (group->producer_edges.empty()) {
visited.pushBack(group);
}
}

// visited now only contain source nodes
// they can go backward to nowhere
Expand All @@ -1086,20 +1087,18 @@ void GroupDependencyAnalysis::computeAllProducers() {
if (std::all_of(
visiting_group->producer_edges.begin(),
visiting_group->producer_edges.end(),
[&visited](SegmentedEdge* e) {
return visited.count(e->from);
})) {
[&visited](SegmentedEdge* e) { return visited.has(e->from); })) {
// filter multi-edges
GroupSet producers_of_visiting_group;
for (auto edge : visiting_group->producer_edges) {
producers_of_visiting_group.insert(edge->from);
producers_of_visiting_group.pushBack(edge->from);
}

// populate all possible paths
// from producer backward, including
// the producer
for (auto producer : producers_of_visiting_group) {
getAllKnownProducersSet(visiting_group)->insert(producer);
getAllKnownProducersSet(visiting_group)->pushBack(producer);
mergeAllKnownProducersIntoFrom(visiting_group, producer);
}
to_update = visiting_group;
Expand All @@ -1109,7 +1108,7 @@ void GroupDependencyAnalysis::computeAllProducers() {
if (to_update) {
addConsumersToWorkList(to_update, to_visit);
to_visit.erase(to_update);
visited.insert(to_update);
visited.pushBack(to_update);
} else {
TORCH_INTERNAL_ASSERT(false, "unreachable, original graph not a DAG");
}
Expand Down Expand Up @@ -2060,7 +2059,6 @@ bool SegmentCandidateFinder::TranslateWelfordInFusion(
//! This pass tries to merge nodes with the same reduction type based
//! on the graph structure.
class CombineReductions {
using GroupSet = std::unordered_set<SegmentedGroup*>;
using GroupVec = std::vector<SegmentedGroup*>;
class ReductionSignature;

Expand Down Expand Up @@ -2240,7 +2238,7 @@ class CombineReductions {
groups_with_reductions_.begin(),
groups_with_reductions_.end(),
[&all_groups_to_merge](SegmentedGroup* group) {
return all_groups_to_merge.count(group);
return all_groups_to_merge.has(group);
}),
groups_with_reductions_.end());

Expand Down Expand Up @@ -2374,7 +2372,7 @@ class CombineReductions {
groups_with_reductions_.begin(),
groups_with_reductions_.end(),
[&groups_to_merge_set](SegmentedGroup* group) {
return groups_to_merge_set.count(group);
return groups_to_merge_set.has(group);
}),
groups_with_reductions_.end());

Expand Down Expand Up @@ -2414,8 +2412,8 @@ class CombineReductions {
maybe_consumer, maybe_producer)) {
auto groups_to_check =
dependency_analysis->valuesBetween(maybe_producer, maybe_consumer);
groups_to_check.insert(maybe_producer);
groups_to_check.insert(maybe_consumer);
groups_to_check.pushBack(maybe_producer);
groups_to_check.pushBack(maybe_consumer);

// Check that either no group has a reduction or all groups have the same
// reduction signature
Expand All @@ -2428,13 +2426,13 @@ class CombineReductions {
// output edge does not generate much saving of global memory access
// we want to postpone merging these edges till the very final pass
for (auto producer_edge_of_group : group->producer_edges) {
if (groups_to_check.count(producer_edge_of_group->from) &&
if (groups_to_check.has(producer_edge_of_group->from) &&
producer_edge_of_group->val->isFusionOutput()) {
return {};
}
}
for (auto consumer_edge_of_group : group->consumer_edges) {
if (groups_to_check.count(consumer_edge_of_group->to) &&
if (groups_to_check.has(consumer_edge_of_group->to) &&
consumer_edge_of_group->val->isFusionOutput()) {
return {};
}
Expand Down Expand Up @@ -2653,11 +2651,11 @@ void SegmentCandidateFinder::findSegments() {

// Expressions to exclude from segmentation because they're just derived from
// unary ops on inputs to the complete fusion
std::unordered_set<Expr*> excluded_inp_unary_exprs;
VectorOfUniqueEntries<Expr*> excluded_inp_unary_exprs;

// "Terminating" outputs from the excluded input unary exprs, these will be
// treated as complete fusion inputs.
std::unordered_set<Val*> forwarded_inputs;
VectorOfUniqueEntries<Val*> forwarded_inputs;
{
std::deque<Expr*> to_visit;
for (auto inp : completeFusion()->inputs()) {
Expand All @@ -2677,8 +2675,8 @@ void SegmentCandidateFinder::findSegments() {
}

if (expr->output(0)->uses().size() > 1) {
excluded_inp_unary_exprs.emplace(expr);
forwarded_inputs.emplace(expr->output(0));
excluded_inp_unary_exprs.pushBack(expr);
forwarded_inputs.pushBack(expr->output(0));
continue;
}

Expand Down Expand Up @@ -2735,7 +2733,7 @@ void SegmentCandidateFinder::findSegments() {
continue;
}

if (excluded_inp_unary_exprs.count(expr)) {
if (excluded_inp_unary_exprs.has(expr)) {
continue;
}

Expand Down

0 comments on commit a6b3e70

Please sign in to comment.