Transform propagator skip replay when possible (#1782)
This comment in the code describes what this PR is doing:

```C++
  // Note: [Using multiple TransformPropagators]
  // There are cases where we use multiple TransformPropagators along different
  // spanning trees with different references in the same fusion. Some of these
  // spanning trees could overlap. When there are overlapping nodes, a
  // TransformPropagator needs to respect the replay done by the others, because
  // the current TransformPropagator might not have the most complete
  // information on how to do the correct transformation. The logic below tells
  // TransformPropagator to skip the replay when it is not necessary.
```
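
The new FusionTransformPropagatorNoOverwrite_CUDA test below exercises exactly this situation. Here is a condensed sketch of the scenario, using the same fusion and test helpers as the test (not an authoritative restatement of it):

```C++
// Two propagations whose spanning trees overlap on tv1. Without the skip
// logic, the second propagation would overwrite the schedule that tv1
// received from the first.
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

auto tv0 = makeSymbolicTensor(1);
fusion->addInput(tv0);
auto tv1 = broadcast(tv0, {true, false, true});
auto tv2 = sin(tv1);
fusion->addOutput(tv2);

tv0->split(0, 2);  // schedule the input
tv2->split(1, 2);  // schedule the output
tv2->split(0, 4);

// First propagation, rooted at the output.
TransformPropagator propagator1(tv2);
MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator1);

// Second propagation, rooted at the input; its spanning tree overlaps the
// first on tv1. With this PR, replay on already-consistent nodes is skipped,
// so the first propagation's work on tv1 is respected rather than redone.
TransformPropagator propagator2(tv0);
MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator2);
```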
zasdfgbnm committed Jun 30, 2022
1 parent ebf23a5 commit fe93bf5
Showing 4 changed files with 225 additions and 110 deletions.
102 changes: 7 additions & 95 deletions torch/csrc/jit/codegen/cuda/compute_at.cpp
@@ -342,96 +342,6 @@ void ComputeAt::runWith(
ca.runPass();
}

namespace {

// Checks if producer and consumer are transformed consistently so as to
// satisfy the provided compute at position. This means no replay is actually
// necessary for the compute at requested. If consumer_pos is true,
// consumer_or_producer_pos is relative to the consumer, and skipReplay
// returns the associated position in the producer.
//
// If producer and consumer are not transformed consistently with the provided
// position, returns -1.
int skipReplay(
const TensorView* producer,
const TensorView* consumer,
int consumer_or_producer_pos,
bool consumer_pos = true) {
FUSER_PERF_SCOPE("transform_replay.cpp::skipReplay");

const auto c2p_root_map =
PairwiseRootDomainMap(producer, consumer)
.mapConsumerToProducer(consumer->domain(), producer->domain());

// IterDomains in consumer root also in producer root
std::unordered_set<Val*> mapped_consumer_roots;
for (auto entry : c2p_root_map) {
mapped_consumer_roots.emplace(entry.first);
}

const auto consumer_domain = consumer->domain()->domain();

auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween(
mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});

std::unordered_set<Val*> mapped_consumer_domain_ids(
mapped_consumer_domain_ids_vec.begin(),
mapped_consumer_domain_ids_vec.end());

const auto producer_domain = producer->domain()->domain();

auto it_consumer = consumer_domain.begin();
auto it_producer = producer_domain.begin();

auto best_effort_PasC = BestEffortReplay::replayPasC(
producer, consumer, -1, PairwiseRootDomainMap(producer, consumer));

auto c2p_map = best_effort_PasC.getReplay();

int mismatched_consumer_pos = 0;
int mismatched_producer_pos = 0;
while (it_consumer != consumer_domain.end()) {
auto consumer_id = *it_consumer;
if (!mapped_consumer_domain_ids.count(consumer_id)) {
++it_consumer;
mismatched_consumer_pos++;
continue;
}

auto c2p_it = c2p_map.find(consumer_id);
if (c2p_it == c2p_map.end()) {
break;
}

if (it_producer == producer_domain.end()) {
break;
}

auto producer_id = *it_producer;

if (c2p_it->second == producer_id) {
++mismatched_consumer_pos;
++mismatched_producer_pos;
++it_consumer;
++it_producer;
if (consumer_pos) {
if (consumer_or_producer_pos == mismatched_consumer_pos) {
return mismatched_producer_pos;
}
} else {
if (consumer_or_producer_pos == mismatched_producer_pos) {
return mismatched_consumer_pos;
}
}
} else {
break;
}
}
return -1;
}

} // namespace

// Actually applies transformation
unsigned int ComputeAt::backwardComputeAt_impl(
TensorView* producer,
@@ -460,9 +370,11 @@ unsigned int ComputeAt::backwardComputeAt_impl(
max_consumer_compute_at_pos);
}

// Short cut if no replay is necessary
auto maybe_producer_pos =
skipReplay(producer, consumer, (int)consumer_compute_at_pos, true);
// Checks if producer and consumer are transformed consistently so as to
// satisfy the provided compute at position. This means no replay is actually
// necessary for the compute at requested.
auto maybe_producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(
producer, consumer, consumer_compute_at_pos);
if (maybe_producer_pos >= 0) {
if (!producer->isFusionInput()) {
producer->setComputeAt((unsigned int)maybe_producer_pos);
@@ -536,8 +448,8 @@ unsigned int ComputeAt::forwardComputeAt_impl(
}

// Short cut if no replay is necessary
auto maybe_consumer_pos =
skipReplay(producer, consumer, (int)producer_compute_at_pos, false);
auto maybe_consumer_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP(
consumer, producer, producer_compute_at_pos);
if (maybe_consumer_pos > -1) {
if (!producer->isFusionInput()) {
producer->setComputeAt(producer_compute_at_pos);
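
The two entry points used above are new public members of TransformReplay. Their declarations live in the header, whose diff is not loaded on this page; the following sketch reconstructs them from the definitions in transform_replay.cpp below (the static qualifier is an assumption based on how they are called):

```C++
// Reconstructed sketch, not the verbatim header diff. Both position queries
// return the matched leaf position in the other tensor when producer and
// consumer are already transformed consistently up to the given position, or
// -1 when an actual replay is required.
class TensorView;  // forward declaration; defined elsewhere in the codegen

class TransformReplay {
 public:
  // ... existing members (replayPasC, replayCasP, fullSelfReplay) elided ...

  // Maps consumer_pos in consumer to the corresponding producer position.
  static int getMatchedLeafPosWithoutReplayPasC(
      const TensorView* producer,
      const TensorView* consumer,
      int consumer_pos);

  // Maps producer_pos in producer to the corresponding consumer position.
  static int getMatchedLeafPosWithoutReplayCasP(
      const TensorView* consumer,
      const TensorView* producer,
      int producer_pos);

  // True when every leaf IterDomain of `replay` lines up with the
  // corresponding leaf of `target` under a root-to-root mapping.
  static bool fullSelfMatching(
      const TensorView* replay,
      const TensorView* target);
};
```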
47 changes: 40 additions & 7 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -23710,7 +23710,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) {
for (auto tensors : siblings) {
for (auto t1 : tensors) {
for (auto t2 : tensors) {
checkSiblingConsistency(t1, t2);
TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2));
}
}
}
@@ -23769,7 +23769,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) {
for (auto tensors : siblings) {
for (auto t1 : tensors) {
for (auto t2 : tensors) {
checkSiblingConsistency(t1, t2);
TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2));
}
}
}
@@ -23922,7 +23922,7 @@ TEST_F(NVFuserTest, FusionTransformPropagatorSelector) {
TORCH_CHECK(tv4->nDims() == 1);
}

TEST_F(NVFuserTest, FusionTransormPropagatorPos_CUDA) {
TEST_F(NVFuserTest, FusionTransformPropagatorPos_CUDA) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

@@ -23939,10 +23939,9 @@ TEST_F(NVFuserTest, FusionTransormPropagatorPos_CUDA) {
TransformPropagator propagator(tv1, 2);
MaxRootDomainInfoSpanningTree(tv1, 2).traverse(&propagator);

TORCH_CHECK(tv0->nDims() == 3);
TORCH_CHECK(tv0->axis(0)->extent()->evaluateInt() == 11);
TORCH_CHECK(tv0->axis(1)->extent()->evaluateInt() == 2);
TORCH_CHECK(tv0->axis(2)->extent()->evaluateInt() == 105);
auto expect = makeConcreteTensor({22, 105});
expect->split(0, 2);
TORCH_CHECK(TransformReplay::fullSelfMatching(expect, tv0));
}

TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) {
@@ -23996,6 +23995,40 @@ to: 2
TORCH_CHECK(printer2.ss.str() == expect);
}

TEST_F(NVFuserTest, FusionTransformPropagatorNoOverwrite_CUDA) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

auto tv0 = makeSymbolicTensor(1);
fusion->addInput(tv0);
auto tv1 = broadcast(tv0, {true, false, true});
auto tv2 = sin(tv1);
fusion->addOutput(tv2);

tv0->split(0, 2);
tv2->split(1, 2);
tv2->split(0, 4);

MaxRootDomainInfoSpanningTree path1(tv2);
TransformPropagator propagator1(tv2);
path1.traverse(&propagator1);

MaxRootDomainInfoSpanningTree path2(tv0);
TransformPropagator propagator2(tv0);
path2.traverse(&propagator2);

TORCH_CHECK(tv1->axis(0)->isBroadcast());
TORCH_CHECK(tv1->axis(1)->isBroadcast());
TORCH_CHECK(!tv1->axis(2)->isBroadcast());
TORCH_CHECK(!tv1->axis(3)->isBroadcast());
TORCH_CHECK(tv1->axis(4)->isBroadcast());

auto expect = makeSymbolicTensor(3);
expect->split(1, 2);
expect->split(0, 4);
TORCH_CHECK(TransformReplay::fullSelfMatching(expect, tv1));
}

TEST_F(NVFuserTest, FusionIssue1785Repro_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
165 changes: 157 additions & 8 deletions torch/csrc/jit/codegen/cuda/transform_replay.cpp
@@ -644,24 +644,173 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
return replayCasP(consumer, producer, compute_at_axis, root_map);
}

namespace {

int getMatchedLeafPosWithoutReplay(
const TensorView* producer,
const TensorView* consumer,
int consumer_or_producer_pos,
bool consumer_pos = true) {
FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplay");

const auto c2p_root_map =
PairwiseRootDomainMap(producer, consumer)
.mapConsumerToProducer(consumer->domain(), producer->domain());

// IterDomains in consumer root also in producer root
std::unordered_set<Val*> mapped_consumer_roots;
for (auto entry : c2p_root_map) {
mapped_consumer_roots.emplace(entry.first);
}

const auto consumer_domain = consumer->domain()->domain();

auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween(
mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});

std::unordered_set<Val*> mapped_consumer_domain_ids(
mapped_consumer_domain_ids_vec.begin(),
mapped_consumer_domain_ids_vec.end());

const auto producer_domain = producer->domain()->domain();

auto it_consumer = consumer_domain.begin();
auto it_producer = producer_domain.begin();

auto best_effort_PasC = BestEffortReplay::replayPasC(
producer, consumer, -1, PairwiseRootDomainMap(producer, consumer));

auto c2p_map = best_effort_PasC.getReplay();

int mismatched_consumer_pos = 0;
int mismatched_producer_pos = 0;
while (it_consumer != consumer_domain.end()) {
auto consumer_id = *it_consumer;
if (!mapped_consumer_domain_ids.count(consumer_id)) {
++it_consumer;
mismatched_consumer_pos++;
continue;
}

auto c2p_it = c2p_map.find(consumer_id);
if (c2p_it == c2p_map.end()) {
break;
}

if (it_producer == producer_domain.end()) {
break;
}

auto producer_id = *it_producer;

if (c2p_it->second == producer_id) {
++mismatched_consumer_pos;
++mismatched_producer_pos;
++it_consumer;
++it_producer;
if (consumer_pos) {
if (consumer_or_producer_pos == mismatched_consumer_pos) {
return mismatched_producer_pos;
}
} else {
if (consumer_or_producer_pos == mismatched_producer_pos) {
return mismatched_consumer_pos;
}
}
} else {
break;
}
}
return -1;
}

} // namespace

int TransformReplay::getMatchedLeafPosWithoutReplayPasC(
const TensorView* producer,
const TensorView* consumer,
int consumer_pos) {
return getMatchedLeafPosWithoutReplay(producer, consumer, consumer_pos, true);
}

int TransformReplay::getMatchedLeafPosWithoutReplayCasP(
const TensorView* consumer,
const TensorView* producer,
int producer_pos) {
return getMatchedLeafPosWithoutReplay(
producer, consumer, producer_pos, false);
}

bool TransformReplay::fullSelfMatching(
const TensorView* replay,
const TensorView* target) {
auto replay_root = replay->getRootDomain();
auto replay_dom = replay->domain()->domain();
auto target_root = target->getRootDomain();
auto target_dom = target->domain()->domain();
std::unordered_map<IterDomain*, IterDomain*> target2replay_map;
if (replay_root.size() != target_root.size()) {
return false;
}
target2replay_map.reserve(replay_root.size());
std::transform(
target_root.begin(),
target_root.end(),
replay_root.begin(),
std::inserter(target2replay_map, target2replay_map.begin()),
[](auto a, auto b) { return std::make_pair(a, b); });
BestEffortReplay replay_(replay_dom, target_dom, target2replay_map);
auto r = replay_.getReplay();
for (int64_t i = 0; i < replay_dom.size(); i++) {
auto target_id = target_dom[i];
auto replay_it = r.find(target_id);
if (replay_it == r.end() || replay_it->second != replay_dom[i]) {
return false;
}
}
return true;
}

void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) {
int pos = replayed_pos_.at(from);
auto replay = TransformReplay::replayPasC(to, from, pos);
to->setDomain(replay.first);
replayed_pos_[to] = replay.second;
// Note: [Using multiple TransformPropagators]
// There are cases where we use multiple TransformPropagators along different
// spanning trees with different references in the same fusion. Some of these
// spanning trees could overlap. When there are overlapping nodes, a
// TransformPropagator needs to respect the replay done by the others, because
// the current TransformPropagator might not have the most complete
// information on how to do the correct transformation. The logic below tells
// TransformPropagator to skip the replay when it is not necessary.
int new_pos =
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
if (new_pos < 0) {
auto replay = TransformReplay::replayPasC(to, from, pos);
to->setDomain(replay.first);
new_pos = replay.second;
}
replayed_pos_[to] = new_pos;
}

void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) {
int pos = replayed_pos_.at(from);
auto replay = TransformReplay::replayCasP(to, from, pos);
to->setDomain(replay.first);
replayed_pos_[to] = replay.second;
// See note [Using multiple TransformPropagators]
int new_pos =
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
if (new_pos < 0) {
auto replay = TransformReplay::replayCasP(to, from, pos);
to->setDomain(replay.first);
new_pos = replay.second;
}
replayed_pos_[to] = new_pos;
}

void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) {
int pos = replayed_pos_.at(from);
auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
to->setDomain(replay);
// See note [Using multiple TransformPropagators]
if (!TransformReplay::fullSelfMatching(to, from)) {
auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
to->setDomain(replay);
}
replayed_pos_[to] = pos;
}

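
For intuition about the new fullSelfMatching helper, here is a minimal hypothetical example. It is not part of the diff; it assumes the same Fusion/FusionGuard setup and the makeSymbolicTensor helper used by the tests above:

```C++
// fullSelfMatching returns true only when the two tensors' leaf domains are
// produced by matching transformations over positionally mapped root domains.
Fusion fusion;
FusionGuard fg(&fusion);

auto a = makeSymbolicTensor(2);  // root: [i0, i1]
auto b = makeSymbolicTensor(2);  // root: [j0, j1]

a->split(0, 4);
b->split(0, 4);
// Identical transform histories, so the leaves line up.
TORCH_CHECK(TransformReplay::fullSelfMatching(a, b));

b->split(1, 2);
// b now has a transform that a lacks, so the match fails.
TORCH_CHECK(!TransformReplay::fullSelfMatching(a, b));
```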
(diff for the fourth changed file not loaded; by the totals above, 21 additions and 0 deletions)
