Skip to content

Commit

Permalink
Fix TransformReplay::getMatchedLeafPosWithoutReplay* (csarofeen#1791)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm committed Jul 1, 2022
1 parent 28cbaf9 commit 45f5203
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 49 deletions.
66 changes: 66 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24079,6 +24079,72 @@ TEST_F(NVFuserTest, FusionIssue1785Repro_CUDA) {
testValidate(&fusion, cg_outputs, {in1, in2}, {tv_ref}, __LINE__, __FILE__);
}

// Checks that TransformReplay::getMatchedLeafPosWithoutReplayPasC/CasP return
// the same position that TransformPropagator records in replayed_pos_ once it
// has actually propagated the transformations.
TEST_F(NVFuserTest, FusionSkipReplay_CUDA) {
  // Propagator that performs the regular propagation step and then
  // cross-checks the position recorded for `to` against the
  // "without replay" query evaluated at the position recorded for `from`.
  struct TransformPropagatorWithCheck : public TransformPropagator {
    using TransformPropagator::TransformPropagator;

    void propagateTvPasC(TensorView* from, TensorView* to) override {
      TransformPropagator::propagateTvPasC(from, to);
      auto from_pos = replayed_pos_.at(from);
      auto to_pos = replayed_pos_.at(to);
      TORCH_CHECK(
          TransformReplay::getMatchedLeafPosWithoutReplayPasC(
              to, from, from_pos) == to_pos);
    }

    void propagateTvCasP(TensorView* from, TensorView* to) override {
      TransformPropagator::propagateTvCasP(from, to);
      auto from_pos = replayed_pos_.at(from);
      auto to_pos = replayed_pos_.at(to);
      TORCH_CHECK(
          TransformReplay::getMatchedLeafPosWithoutReplayCasP(
              to, from, from_pos) == to_pos);
    }

    void propagateTvSibling(TensorView* from, TensorView* to) override {
      // Forward to the sibling implementation of the base class. Forwarding
      // to propagateTvCasP here would run a consumer-as-producer replay on a
      // sibling pair instead of the sibling propagation under test.
      TransformPropagator::propagateTvSibling(from, to);
      auto from_pos = replayed_pos_.at(from);
      auto to_pos = replayed_pos_.at(to);
      TORCH_CHECK(from_pos == to_pos);
      TORCH_CHECK(TransformReplay::fullSelfMatching(from, to));
    }
  };

  // Case 1: a 1-D input is broadcast to 2-D and added to a 2-D input; the
  // output is split on axis 1 and the transformation is propagated from it.
  {
    Fusion fusion;
    FusionGuard fg(&fusion);

    TensorView* tv0 = makeContigTensor(1);
    TensorView* tv1 = makeContigTensor(2);
    fusion.addInput(tv0);
    fusion.addInput(tv1);

    auto tv2 = broadcast(tv0, {false, true});
    auto tv3 = add(tv2, tv1);
    fusion.addOutput(tv3);

    tv3->split(1, 2, false);

    TransformPropagatorWithCheck propagator(tv3);
    MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
  }

  // Case 2: a 3-D input is reduced over axes 0 and 2 and passed through sin;
  // the input is split on axis 1 and the transformation is propagated from it.
  {
    Fusion fusion;
    FusionGuard fg(&fusion);

    TensorView* tv0 = makeContigTensor(3);
    fusion.addInput(tv0);

    auto tv1 = sum(tv0, {0, 2});
    auto tv2 = sin(tv1);
    fusion.addOutput(tv2);

    tv0->split(1, 2, false);

    TransformPropagatorWithCheck propagator(tv0);
    MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator);
  }
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)
157 changes: 108 additions & 49 deletions torch/csrc/jit/codegen/cuda/transform_replay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,101 +644,160 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
return replayCasP(consumer, producer, compute_at_axis, root_map);
}

namespace {

int getMatchedLeafPosWithoutReplay(
int TransformReplay::getMatchedLeafPosWithoutReplayPasC(
const TensorView* producer,
const TensorView* consumer,
int consumer_or_producer_pos,
bool consumer_pos = true) {
FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplay");
int consumer_pos) {
FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplayPasC");

const auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
id_map c2p_root_map = pairwise_map.mapConsumerToProducer(
consumer->domain(), producer->domain());

const auto c2p_root_map =
PairwiseRootDomainMap(producer, consumer)
.mapConsumerToProducer(consumer->domain(), producer->domain());
// IterDomains in `consumer` root also in `producer` root
const auto consumer_domain = consumer->domain()->domain();

// IterDomains in consumer root also in producer root
std::unordered_set<Val*> mapped_consumer_roots;
for (auto entry : c2p_root_map) {
mapped_consumer_roots.emplace(entry.first);
}

const auto consumer_domain = consumer->domain()->domain();

auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween(
auto unskippable_consumer_ids_vec = DependencyCheck::getAllValsBetween(
mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});

std::unordered_set<Val*> mapped_consumer_domain_ids(
mapped_consumer_domain_ids_vec.begin(),
mapped_consumer_domain_ids_vec.end());
std::unordered_set<Val*> unskippable_consumer_ids(
unskippable_consumer_ids_vec.begin(), unskippable_consumer_ids_vec.end());

// IterDomains in `producer` root also in `consumer` root
const auto producer_domain = producer->domain()->domain();

auto it_consumer = consumer_domain.begin();
auto it_producer = producer_domain.begin();

auto best_effort_PasC = BestEffortReplay::replayPasC(
producer, consumer, -1, PairwiseRootDomainMap(producer, consumer));

auto c2p_map = best_effort_PasC.getReplay();
id_map c2p_map =
BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
.getReplay();

int mismatched_consumer_pos = 0;
int mismatched_producer_pos = 0;
while (it_consumer != consumer_domain.end()) {
if (consumer_pos == mismatched_consumer_pos) {
return mismatched_producer_pos;
}

auto consumer_id = *it_consumer;
if (!mapped_consumer_domain_ids.count(consumer_id)) {
if (unskippable_consumer_ids.count(consumer_id) == 0) {
++it_consumer;
mismatched_consumer_pos++;
++mismatched_consumer_pos;
continue;
}

auto c2p_it = c2p_map.find(consumer_id);
if (c2p_it == c2p_map.end()) {
break;
if (it_producer == producer_domain.end()) {
return -1;
}

if (it_producer == producer_domain.end()) {
break;
auto c2p_it = c2p_map.find(consumer_id);
if (c2p_it == c2p_map.end()) {
return -1;
}

auto producer_id = *it_producer;

if (c2p_it->second == producer_id) {
++mismatched_consumer_pos;
++mismatched_producer_pos;
++it_consumer;
++it_producer;
if (consumer_pos) {
if (consumer_or_producer_pos == mismatched_consumer_pos) {
return mismatched_producer_pos;
}
} else {
if (consumer_or_producer_pos == mismatched_producer_pos) {
return mismatched_consumer_pos;
}
}
} else {
break;
return -1;
}
}
if (consumer_pos == mismatched_consumer_pos) {
return mismatched_producer_pos;
}
return -1;
}

} // namespace

int TransformReplay::getMatchedLeafPosWithoutReplayPasC(
const TensorView* producer,
const TensorView* consumer,
int consumer_pos) {
return getMatchedLeafPosWithoutReplay(producer, consumer, consumer_pos, true);
}

int TransformReplay::getMatchedLeafPosWithoutReplayCasP(
const TensorView* consumer,
const TensorView* producer,
int producer_pos) {
return getMatchedLeafPosWithoutReplay(
producer, consumer, producer_pos, false);
FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplayCasP");

const auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
id_map p2c_root_map = pairwise_map.mapProducerToConsumer(
producer->domain(), consumer->domain());

// IterDomains in `producer` root that are not reduction
const auto producer_domain = producer->domain()->domain();
auto unskippable_producer_ids_vec =
TensorDomain::noReductions(producer_domain);
std::unordered_set<IterDomain*> unskippable_producer_ids(
unskippable_producer_ids_vec.begin(), unskippable_producer_ids_vec.end());

// IterDomains in `consumer` root also in `producer` root
const auto consumer_domain = consumer->domain()->domain();

std::unordered_set<Val*> mapped_consumer_roots;
for (auto entry : p2c_root_map) {
mapped_consumer_roots.emplace(entry.second);
}

auto unskippable_consumer_ids_vec = DependencyCheck::getAllValsBetween(
mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});

std::unordered_set<Val*> unskippable_consumer_ids(
unskippable_consumer_ids_vec.begin(), unskippable_consumer_ids_vec.end());

auto it_producer = producer_domain.begin();
auto it_consumer = consumer_domain.begin();

id_map replay_map =
BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_map)
.getReplay();

int mismatched_producer_pos = 0;
int mismatched_consumer_pos = 0;
while (it_producer != producer_domain.end()) {
if (producer_pos == mismatched_producer_pos) {
return mismatched_consumer_pos;
}

auto producer_id = *it_producer;
if (unskippable_producer_ids.count(producer_id) == 0) {
++it_producer;
++mismatched_producer_pos;
continue;
}

if (it_consumer == consumer_domain.end()) {
return -1;
}

auto consumer_id = *it_consumer;
if (unskippable_consumer_ids.count(consumer_id) == 0) {
++it_consumer;
++mismatched_consumer_pos;
continue;
}

auto replay_it = replay_map.find(producer_id);
if (replay_it == replay_map.end()) {
return -1;
}

if (replay_it->second == consumer_id) {
++mismatched_producer_pos;
++mismatched_consumer_pos;
++it_producer;
++it_consumer;
} else {
return -1;
}
}
if (producer_pos == mismatched_producer_pos) {
return mismatched_consumer_pos;
}
return -1;
}

bool TransformReplay::fullSelfMatching(
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/transform_replay.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class TORCH_CUDA_CU_API TransformReplay {

class TORCH_CUDA_CU_API TransformPropagator
: public MaxRootDomainInfoSpanningTree::Propagator {
protected:
std::unordered_map<TensorView*, size_t> replayed_pos_;

public:
Expand Down

0 comments on commit 45f5203

Please sign in to comment.