Skip to content

Commit

Permalink
Cherry pick sorting patch (#85620)
Browse files Browse the repository at this point in the history
Fixes csarofeen/pytorch#1947

Cherry-picked patch for torchbench issues where fusion segmenter asserts in nvfuser:
1. test the groups comes with the same order as they are merged.
2. Fix detection of un-mappable root domains:
    ComputeAtRootDomainMap flags domains that should not be mapped due to
    reductions. Previously, checking if a domain potentially causes an
    invalid mapping is only done with one domain in each group of domains
    that are found to be mappable so far. That's not actually sufficient as
    the unmappable domain set is created just once with no root mapping
    information. The fix is to check all consumer domains of a producer
    tensor. A small other fix is also done to address a different problem
    discovered after the first fix.

Pull Request resolved: pytorch/pytorch#85620
Approved by: https://github.com/csarofeen, https://github.com/davidberard98
  • Loading branch information
shmsong authored and pytorchmergebot committed Sep 27, 2022
1 parent 96601eb commit ad5233c
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 78 deletions.
38 changes: 34 additions & 4 deletions torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2598,9 +2598,39 @@ bool CombineReductions::shouldRun(
return false;
}

bool SegmentCandidateFinder::codeGenSupportedMerge(SegmentedEdge* edge) {
namespace {

//! Returns true if group1 and group2 are an immediate producer-consumer pair.
bool areDirectlyConnected(SegmentedGroup* group1, SegmentedGroup* group2) {
// Check if group1 is a immediate consumer of group2
if (std::any_of(
group1->producer_edges.begin(),
group1->producer_edges.end(),
[group2](SegmentedEdge* edge) { return edge->from == group2; })) {
return true;
}

// Check if group1 is a immediate producer of group2
if (std::any_of(
group1->consumer_edges.begin(),
group1->consumer_edges.end(),
[group2](SegmentedEdge* edge) { return edge->to == group2; })) {
return true;
}

return false;
}

} // namespace

bool SegmentCandidateFinder::codeGenSupportedMerge(
SegmentedGroup* group1,
SegmentedGroup* group2) {
TORCH_INTERNAL_ASSERT(
areDirectlyConnected(group1, group2),
"only support testing immediate producer-consumer groups");
Fusion* fusion = segmented_fusion_->completeFusion();
auto h = tryMerge(fusion, runtime_info_, edge->from, edge->to);
auto h = tryMerge(fusion, runtime_info_, group1, group2);
return h.has_value();
}

Expand Down Expand Up @@ -2827,7 +2857,7 @@ void SegmentCandidateFinder::findSegments() {

auto candidate_it = candidates.begin();
while (candidate_it != candidates.end() &&
!codeGenSupportedMerge(candidate_it->edge)) {
!codeGenSupportedMerge(group, candidate_it->group)) {
candidate_it++;
}
if (candidate_it == candidates.end()) {
Expand Down Expand Up @@ -2896,7 +2926,7 @@ void SegmentCandidateFinder::finalMerge() {
for (auto consumer : all_consumers_of_producer_group) {
if (!producer_check->isConsumerOfAny(
consumer, all_consumers_of_producer_group) &&
codeGenSupportedMerge(consumer_edge_map.at(consumer))) {
codeGenSupportedMerge(producer_group, consumer)) {
to_merge_.emplace_back(producer_group);
to_merge_.emplace_back(consumer);
producer_group->merged_ = true;
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/fusion_segmenter.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ class TORCH_CUDA_CU_API SegmentCandidateFinder {

SegmentedGroup* mergeNodes();

bool codeGenSupportedMerge(SegmentedEdge* edge);
bool codeGenSupportedMerge(SegmentedGroup* group1, SegmentedGroup* group2);

void findSegments();

Expand Down
132 changes: 64 additions & 68 deletions torch/csrc/jit/codegen/cuda/root_domain_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,12 @@ class FindInputDomains : BackwardVisitor {
private:
FindInputDomains(TensorView* tv, const IterDomain* id)
: BackwardVisitor(false), tv_(tv) {
input_keys.insert(DomainKey(tv_->domain(), id));
input_keys_.insert(DomainKey(tv_->domain(), id));
}

DomainKeySet find() {
traverseFrom(tv_->fusion(), {tv_});
return input_keys;
return input_keys_;
}

void handle(Expr* expr) override {
Expand All @@ -249,21 +249,21 @@ class FindInputDomains : BackwardVisitor {
.mapConsumerToProducer(out_tv->domain(), in_tv->domain());
for (auto root_dom : out_tv->getRootDomain()) {
DomainKey out_key({out_tv->domain(), root_dom});
if (input_keys.find(out_key) == input_keys.end()) {
if (input_keys_.find(out_key) == input_keys_.end()) {
continue;
}
auto input_id_it = c2p.find(root_dom);
if (input_id_it == c2p.end()) {
continue;
}
DomainKey input_key(in_tv->domain(), input_id_it->second);
input_keys.insert(input_key);
input_keys_.insert(input_key);
}
}

private:
TensorView* tv_ = nullptr;
DomainKeySet input_keys;
DomainKeySet input_keys_;

public:
static DomainKeySet find(TensorView* tv, const IterDomain* id) {
Expand All @@ -285,6 +285,10 @@ void UnmappableReductionDomains::handleReductionOutput(TensorView* out_tv) {
auto use_chains = DependencyCheck::getAllUseChains(out_tv);
for (const auto& chain : use_chains) {
for (const auto& tv : ir_utils::filterByType<TensorView>(chain)) {
// Do not include the tensor itself in its consumers
if (tv == out_tv) {
continue;
}
const auto& root_domain = tv->getRootDomain();
for (const auto& id : root_domain) {
DomainKey consumer_key(tv->domain(), id);
Expand Down Expand Up @@ -327,30 +331,41 @@ void UnmappableReductionDomains::handle(WelfordOp* op) {
}

bool UnmappableReductionDomains::isReductionOutputMapped(
const std::vector<DomainKey>& consumer_domains,
const DomainKeySet& consumer_domains,
const ComputeAtRootDomainMap& root_map) const {
// Check each reduction domain if any of the consumer domains
// conflicts with it
for (const auto& kv : reduction_domains_) {
const DomainKey& reduction_domain = kv.first;
// Domains that must not be mapped with the reduction domain
const DomainKeySet& incompatible_domains = kv.second;
DomainKey consumer_domain_with_reduction;
bool reduction_found = false;
// Input domains to the reduction domain
const auto& input_keys = reduction_domain_inputs_.at(reduction_domain);
for (const DomainKey& consumer_domain : consumer_domains) {
for (const auto& input_key : input_keys) {
if (input_key == consumer_domain) {
consumer_domain_with_reduction = consumer_domain;
reduction_found = true;
break;
}
}
}
if (!reduction_found) {
// Check if any of the consumer domains is an input to the
// reduction
auto it = std::find_if(
consumer_domains.begin(),
consumer_domains.end(),
[&](const auto& consumer_domain) {
return std::find(
input_keys.begin(), input_keys.end(), consumer_domain) !=
input_keys.end();
});
// None of the consumer domains is used for the reduction
// domain. They should be safe with respect to this reduction
// domain
if (it == consumer_domains.end()) {
continue;
}
// Make sure no incompatible domains will be merged with the reduction
// domain.

// A consumer domain that is an input to the reduction domain
const DomainKey& input_to_reduction = *it;

// Check if mapping input_to_reduction with the other domains in
// consumer_domains. If there's a domain that is a consumer of the
// reduction, they must not be mapped together
for (const auto& consumer_domain : consumer_domains) {
if (consumer_domain == consumer_domain_with_reduction) {
if (consumer_domain == input_to_reduction) {
continue;
}
if (std::any_of(
Expand All @@ -370,6 +385,27 @@ bool UnmappableReductionDomains::isReductionOutputMapped(
return false;
}

std::string UnmappableReductionDomains::toString() const {
std::stringstream ss;
ss << "Reduction-to-consumer map\n";
for (const auto& kv : reduction_domains_) {
ss << "\tReduction: " << kv.first.toString() << "\n";
for (const auto& mapped_val : kv.second) {
ss << "\t\tConsumer domain: " << mapped_val.toString() << "\n";
}
}

ss << "Reduction-to-producer map\n";
for (const auto& kv : reduction_domain_inputs_) {
ss << "\tReduction: " << kv.first.toString() << "\n";
for (const auto& mapped_val : kv.second) {
ss << "\t\tProducer domain: " << mapped_val.toString() << "\n";
}
}

return ss.str();
}

void ComputeAtRootDomainMap::build(bool map_through_reduction) {
// Make sure we start from scratch. Throw away previous results.
eq_set_.clear();
Expand Down Expand Up @@ -712,7 +748,7 @@ void ComputeAtRootDomainMapBuilder::setInvalid(
}

bool ComputeAtRootDomainMapBuilder::isInvalid(
const std::vector<DomainKey>& domains) const {
const DomainKeySet& domains) const {
// First, collect all invalid mappings for each of the keys in domains
DomainKeyMap<DomainKeySet> invalid_key_map;
for (const auto& key : domains) {
Expand All @@ -729,8 +765,9 @@ bool ComputeAtRootDomainMapBuilder::isInvalid(

// Next, check if any pair is invalid to map.
const auto num_keys = domains.size();
const std::vector<DomainKey> domains_vec({domains.begin(), domains.end()});
for (const auto i : c10::irange(num_keys)) {
const auto& key_i = domains[i];
const auto& key_i = domains_vec[i];
// If no invalid keys found for key_i, it can be skipped.
const auto invalid_key_map_it = invalid_key_map.find(key_i);
if (invalid_key_map_it == invalid_key_map.end()) {
Expand All @@ -743,7 +780,7 @@ bool ComputeAtRootDomainMapBuilder::isInvalid(
// If any other key in domains is identified mappable with any of
// the keys in this set, the mapping with key_i is invalid.
for (const auto j : c10::irange(i + 1, num_keys)) {
const auto& key_j = domains[j];
const auto& key_j = domains_vec[j];
if (std::any_of(
invalid_keys_for_i.begin(),
invalid_keys_for_i.end(),
Expand Down Expand Up @@ -1051,61 +1088,20 @@ void ComputeAtRootDomainMapBuilder::handle(TensorView* tv) {
}
}

// Checks whether all consumers of a producer can be joined without
// introducing unsupported mappings. Specifically, if a domain of a
// consumer has a mapped iteration domain in another consumer that
// does not correspond to the same producer iteration domain, mapping
// the consumer domains would result in the producer iteration domain
// mapped to two different consumer iteration domains, requiring
// recomputations.
bool ComputeAtRootDomainMapBuilder::hasMatchingDomains(
const std::vector<DomainKey>& unique_domains) {
for (const auto& key : unique_domains) {
for (const auto& other_key : unique_domains) {
if (key == other_key) {
continue;
}
const auto& other_root = other_key.td()->getRootDomain();
if (std::any_of(
other_root.begin(), other_root.end(), [&](const IterDomain* id) {
return root_map_.canMap(key, other_key.td(), id);
})) {
return true;
}
}
}
return false;
}

// Checks whether all consumers of a producer can be joined without
// introducing unsupported mappings, i.e., requiring recomputations.
bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) {
if (domains.size() <= 1) {
return true;
}
// Filter out equivalent domains
std::vector<DomainKey> unique_domains;
for (const auto& domain : domains) {
if (std::none_of(
unique_domains.begin(),
unique_domains.end(),
[&](const auto& unique_dom) {
return root_map_.canMap(domain, unique_dom);
})) {
unique_domains.push_back(domain);
}
}
if (hasMatchingDomains(unique_domains)) {
return false;
}

// Can't map if reduction output domains would be mapped
if (incompatible_domains_.isReductionOutputMapped(
unique_domains, root_map_) &&
if (incompatible_domains_.isReductionOutputMapped(domains, root_map_) &&
!map_through_reduction_) {
return false;
}
// Make sure mapping these domains won't cause any invalid mapping
if (isInvalid(unique_domains)) {
if (isInvalid(domains)) {
return false;
}
return true;
Expand Down
11 changes: 7 additions & 4 deletions torch/csrc/jit/codegen/cuda/root_domain_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ class DomainKey {
return td() == other.td() && id() == other.id() &&
concreteId() == other.concreteId();
}
bool operator!=(const DomainKey& other) const {
return !(*this == other);
}

std::string toString() const;

Expand Down Expand Up @@ -183,9 +186,11 @@ class TORCH_CUDA_CU_API UnmappableReductionDomains : private IterVisitor {
//! reduction outputs within the corresponding reduction loop is not
//! possible. This routine is used to build root domain mappings.
bool isReductionOutputMapped(
const std::vector<DomainKey>& consumer_domains,
const DomainKeySet& consumer_domains,
const ComputeAtRootDomainMap& root_map) const;

std::string toString() const;

private:
using IterVisitor::handle;
void handle(ReductionOp* op) override;
Expand Down Expand Up @@ -365,7 +370,7 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder
void setInvalid(const DomainKey& key1, const DomainKey& key2);

//! Check if no pair of domains is invalid to map
bool isInvalid(const std::vector<DomainKey>& domains) const;
bool isInvalid(const DomainKeySet& domains) const;

//! Track a pair of producer-consumer domains as potentially mappable. Inserts
//! entries into pending_map_, but does not add anything into the root_map_
Expand Down Expand Up @@ -453,8 +458,6 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder
//! mapping is done separately for each concrete domain.
void mapAllPendingMappings(const TensorDomain* td, IterDomain* id);

bool hasMatchingDomains(const std::vector<DomainKey>& unique_domains);

bool safeToMap(const DomainKeySet& domains);

private:
Expand Down
34 changes: 33 additions & 1 deletion torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3779,7 +3779,39 @@ TEST_F(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) {
testValidate(&fusion, outputs, aten_inputs, {t3, t4}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionComputeAtFailDueToRootMapping_CUDA) {
// Repro of issue #1950
TEST_F(NVFuserTest, FusionRootMappingRepro1950_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeSymbolicTensor(3);
auto tv1 = makeSymbolicTensor(3);
auto tv2 = makeSymbolicTensor(3);

fusion.addInput(tv0);
fusion.addInput(tv1);
fusion.addInput(tv2);

auto tv3 = set(tv0);
auto tv4 = mul(tv1, tv3);
auto tv5 = mul(tv1, tv2);
auto tv6 = mul(tv5, tv3);
auto tv7 = sum(tv6, {2});
auto tv8 = broadcast(tv7, {false, false, true});
auto tv9 = mul(tv3, tv8);

// Issue #1950 was caused by a particular traversal ordering based
// on the output tensor ordering as below
fusion.addOutput(tv9);
fusion.addOutput(tv5);
fusion.addOutput(tv4);

ComputeAtRootDomainMap root_map;
root_map.build();

checkIdMapped(root_map, tv4, tv4->axis(-1), tv9, tv9->axis(-1), false);
}

TEST_F(NVFuserTest, FusionDetectSelfMappedDomains_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

Expand Down

0 comments on commit ad5233c

Please sign in to comment.