Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transform propagator skip replay when possible #1782

Merged
merged 17 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 7 additions & 95 deletions torch/csrc/jit/codegen/cuda/compute_at.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,96 +342,6 @@ void ComputeAt::runWith(
ca.runPass();
}

namespace {

// Checks whether producer and consumer are already transformed consistently
// enough to satisfy the provided compute at position. If so, no replay is
// actually necessary for the requested compute at. If consumer_pos is true,
// consumer_or_producer_pos is relative to the consumer and skipReplay returns
// the associated position in producer; otherwise it is relative to the
// producer and the associated position in consumer is returned.
//
// If producer and consumer are not transformed consistently with the provided
// position, returns -1.
// Parameters:
//   producer, consumer: the tensor pair whose leaf domains are compared.
//   consumer_or_producer_pos: the requested leaf position.
//   consumer_pos: true if the requested position is counted in the consumer's
//     leaf domain, false if it is counted in the producer's.
int skipReplay(
const TensorView* producer,
const TensorView* consumer,
int consumer_or_producer_pos,
bool consumer_pos = true) {
FUSER_PERF_SCOPE("transform_replay.cpp::skipReplay");

// Root-level consumer -> producer mapping for this pair.
const auto c2p_root_map =
PairwiseRootDomainMap(producer, consumer)
.mapConsumerToProducer(consumer->domain(), producer->domain());

// IterDomains in consumer root also in producer root
std::unordered_set<Val*> mapped_consumer_roots;
for (auto entry : c2p_root_map) {
mapped_consumer_roots.emplace(entry.first);
}

const auto consumer_domain = consumer->domain()->domain();

// Every value between the mapped consumer roots and the consumer leaf domain.
// A consumer leaf ID that is absent from this set is skipped by the loop
// below.
auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween(
mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});

std::unordered_set<Val*> mapped_consumer_domain_ids(
mapped_consumer_domain_ids_vec.begin(),
mapped_consumer_domain_ids_vec.end());

const auto producer_domain = producer->domain()->domain();

auto it_consumer = consumer_domain.begin();
auto it_producer = producer_domain.begin();

// Consumer ID -> producer ID map obtained from a best-effort replay of the
// producer as the consumer; used below to compare leaf IDs pairwise.
auto best_effort_PasC = BestEffortReplay::replayPasC(
producer, consumer, -1, PairwiseRootDomainMap(producer, consumer));

auto c2p_map = best_effort_PasC.getReplay();

// Despite the names, these count how many leaf positions have been consumed
// so far on each side (skipped consumer IDs included on the consumer side).
int mismatched_consumer_pos = 0;
int mismatched_producer_pos = 0;
while (it_consumer != consumer_domain.end()) {
auto consumer_id = *it_consumer;
// Consumer leaf IDs outside the mapped set advance only the consumer side.
// NOTE(review): the requested position is not compared on this path, so a
// consumer-relative position that lands right after a skipped ID is never
// reported as matched - confirm whether that is intended.
if (!mapped_consumer_domain_ids.count(consumer_id)) {
++it_consumer;
mismatched_consumer_pos++;
continue;
}

// No producer counterpart recorded for this consumer ID: stop matching.
auto c2p_it = c2p_map.find(consumer_id);
if (c2p_it == c2p_map.end()) {
break;
}

// Producer leaf domain exhausted before the requested position was reached.
if (it_producer == producer_domain.end()) {
break;
}

auto producer_id = *it_producer;

if (c2p_it->second == producer_id) {
// The current leaf IDs correspond; advance both sides in lock step.
++mismatched_consumer_pos;
++mismatched_producer_pos;
++it_consumer;
++it_producer;
// The comparison happens after the increment, so a requested position of 0
// can never be matched here and falls through to the -1 return.
if (consumer_pos) {
if (consumer_or_producer_pos == mismatched_consumer_pos) {
return mismatched_producer_pos;
}
} else {
if (consumer_or_producer_pos == mismatched_producer_pos) {
return mismatched_consumer_pos;
}
}
} else {
// First leaf ID that does not line up ends the search.
break;
}
}
// The requested position was not reached with all preceding IDs matching.
return -1;
}

} // namespace

// Actually applies transformation
unsigned int ComputeAt::backwardComputeAt_impl(
TensorView* producer,
Expand Down Expand Up @@ -460,9 +370,11 @@ unsigned int ComputeAt::backwardComputeAt_impl(
max_consumer_compute_at_pos);
}

// Short cut if no replay is necessary
auto maybe_producer_pos =
skipReplay(producer, consumer, (int)consumer_compute_at_pos, true);
// Checks whether producer and consumer are already transformed consistently
// enough to satisfy the provided compute at position. If so, no replay is
// actually necessary for the requested compute at.
auto maybe_producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(
producer, consumer, consumer_compute_at_pos);
if (maybe_producer_pos >= 0) {
if (!producer->isFusionInput()) {
producer->setComputeAt((unsigned int)maybe_producer_pos);
Expand Down Expand Up @@ -536,8 +448,8 @@ unsigned int ComputeAt::forwardComputeAt_impl(
}

// Short cut if no replay is necessary
auto maybe_consumer_pos =
skipReplay(producer, consumer, (int)producer_compute_at_pos, false);
auto maybe_consumer_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP(
consumer, producer, producer_compute_at_pos);
if (maybe_consumer_pos > -1) {
if (!producer->isFusionInput()) {
producer->setComputeAt(producer_compute_at_pos);
Expand Down
47 changes: 40 additions & 7 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23710,7 +23710,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) {
for (auto tensors : siblings) {
for (auto t1 : tensors) {
for (auto t2 : tensors) {
checkSiblingConsistency(t1, t2);
TORCH_CHECK(TransformReplay::fullyMatching(t1, t2));
}
}
}
Expand Down Expand Up @@ -23769,7 +23769,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) {
for (auto tensors : siblings) {
for (auto t1 : tensors) {
for (auto t2 : tensors) {
checkSiblingConsistency(t1, t2);
TORCH_CHECK(TransformReplay::fullyMatching(t1, t2));
}
}
}
Expand Down Expand Up @@ -23922,7 +23922,7 @@ TEST_F(NVFuserTest, FusionTransformPropagatorSelector) {
TORCH_CHECK(tv4->nDims() == 1);
}

TEST_F(NVFuserTest, FusionTransormPropagatorPos_CUDA) {
TEST_F(NVFuserTest, FusionTransformPropagatorPos_CUDA) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

Expand All @@ -23939,10 +23939,9 @@ TEST_F(NVFuserTest, FusionTransormPropagatorPos_CUDA) {
TransformPropagator propagator(tv1, 2);
MaxRootDomainInfoSpanningTree(tv1, 2).traverse(&propagator);

TORCH_CHECK(tv0->nDims() == 3);
TORCH_CHECK(tv0->axis(0)->extent()->evaluateInt() == 11);
TORCH_CHECK(tv0->axis(1)->extent()->evaluateInt() == 2);
TORCH_CHECK(tv0->axis(2)->extent()->evaluateInt() == 105);
auto expect = makeConcreteTensor({22, 105});
expect->split(0, 2);
TORCH_CHECK(TransformReplay::fullyMatching(expect, tv0));
}

TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) {
Expand Down Expand Up @@ -23996,6 +23995,40 @@ to: 2
TORCH_CHECK(printer2.ss.str() == expect);
}

// Runs two TransformPropagators with different reference tensors (tv2, then
// tv0) over the same fusion and checks the resulting state of the tensor in
// between, tv1.
// NOTE(review): judging by the test name, the intent is presumably that the
// second propagation does not overwrite what the first one replayed on tv1 -
// confirm.
TEST_F(NVFuserTest, FusionTransformPropagatorNoOverwrite_CUDA) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

// Fusion: tv0 (input) -> broadcast -> tv1 -> sin -> tv2 (output).
auto tv0 = makeSymbolicTensor(1);
fusion->addInput(tv0);
auto tv1 = broadcast(tv0, {true, false, true});
auto tv2 = sin(tv1);
fusion->addOutput(tv2);

// Transform both ends of the chain differently; tv1 itself is left alone.
tv0->split(0, 2);
tv2->split(1, 2);
tv2->split(0, 4);

// First propagation uses tv2 as the reference.
MaxRootDomainInfoSpanningTree path1(tv2);
TransformPropagator propagator1(tv2);
path1.traverse(&propagator1);

// Second propagation uses tv0 as the reference.
MaxRootDomainInfoSpanningTree path2(tv0);
TransformPropagator propagator2(tv0);
path2.traverse(&propagator2);

// Expected broadcast pattern over the first five leaf axes of tv1.
TORCH_CHECK(tv1->axis(0)->isBroadcast());
TORCH_CHECK(tv1->axis(1)->isBroadcast());
TORCH_CHECK(!tv1->axis(2)->isBroadcast());
TORCH_CHECK(!tv1->axis(3)->isBroadcast());
TORCH_CHECK(tv1->axis(4)->isBroadcast());

// tv1 must match a tensor that received the same two splits applied to tv2.
auto expect = makeSymbolicTensor(3);
expect->split(1, 2);
expect->split(0, 4);
TORCH_CHECK(TransformReplay::fullyMatching(expect, tv1));
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)
165 changes: 157 additions & 8 deletions torch/csrc/jit/codegen/cuda/transform_replay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,24 +644,173 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
return replayCasP(consumer, producer, compute_at_axis, root_map);
}

namespace {

int getMatchedLeafPosWithoutReplay(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add comments on the parameters and the return value.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that these functions need a better doc. I will work on it in a new PR.

const TensorView* producer,
const TensorView* consumer,
int consumer_or_producer_pos,
bool consumer_pos = true) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We often use names like consumer_pos to indicate a position in a consumer domain, so this could be confusing. Maybe something like is_producer_as_consumer?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part of the code has been replaced by #1791, so this issue no longer exists 😉

FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplay");

const auto c2p_root_map =
PairwiseRootDomainMap(producer, consumer)
.mapConsumerToProducer(consumer->domain(), producer->domain());

// IterDomains in consumer root also in producer root
std::unordered_set<Val*> mapped_consumer_roots;
for (auto entry : c2p_root_map) {
mapped_consumer_roots.emplace(entry.first);
}

const auto consumer_domain = consumer->domain()->domain();

auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween(
mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});

std::unordered_set<Val*> mapped_consumer_domain_ids(
mapped_consumer_domain_ids_vec.begin(),
mapped_consumer_domain_ids_vec.end());

const auto producer_domain = producer->domain()->domain();

auto it_consumer = consumer_domain.begin();
auto it_producer = producer_domain.begin();

auto best_effort_PasC = BestEffortReplay::replayPasC(
producer, consumer, -1, PairwiseRootDomainMap(producer, consumer));

auto c2p_map = best_effort_PasC.getReplay();

int mismatched_consumer_pos = 0;
int mismatched_producer_pos = 0;
while (it_consumer != consumer_domain.end()) {
auto consumer_id = *it_consumer;
if (!mapped_consumer_domain_ids.count(consumer_id)) {
++it_consumer;
mismatched_consumer_pos++;
continue;
}

auto c2p_it = c2p_map.find(consumer_id);
if (c2p_it == c2p_map.end()) {
break;
}

if (it_producer == producer_domain.end()) {
break;
}

auto producer_id = *it_producer;

if (c2p_it->second == producer_id) {
++mismatched_consumer_pos;
++mismatched_producer_pos;
++it_consumer;
++it_producer;
if (consumer_pos) {
if (consumer_or_producer_pos == mismatched_consumer_pos) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this was what was confusing me. We're not returning the actual mismatched location, we're just returning the provided position if the mismatched position is further to the right. So this function can only return the provided position, or -1 meaning there's a mismatch before we hit the provided position.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should either return this function to just being a bool, or we should actually return the position of the mismatch and compare that position outside of this function.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're just returning the provided position if the mismatched position is further to the right.

I think we are not just returning the provided position, but instead, if you provide a consumer position, we are returning the corresponding producer position of the given consumer position. In this case, we can not return the actual mismatch position, because it does not provide the information about the corresponding producer position of the given consumer position?

I copy-pasted this function from computeAt, and actually, there is one thing that I don't understand about this function. Why are we skipping unmappable dims in the consumer but not in the producer? What breaks the symmetry?

Copy link
Owner

@csarofeen csarofeen Jun 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copy-pasted this function from computeAt

Yeah it's my fault, this function is really just returning yes or no, it shouldn't be returning an int. It could return an int, but should be rewritten to do so.

Why are we skipping unmappable dims in the consumer but not in the producer? What breaks the symmetry?
Unmappable dims (might be a bad name) come from patterns associated with reductions, and really associated with normalization patterns.

If I have:
T1 = set(T0)
T2 = sum(T1, {1})
T3 = broadcast(T2, {false, true})
T4 = add(T1, T3)

T1 cannot be fully inlined to T4. So T1's second dimension (I believe) is marked as unmappable to T4. It's not okay to fully inline T2 into T4, but it's fine to fully inline T1 with T2.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@naoyam can tell me if I'm wrong on the details of the example, but I believe the principles stand.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it shouldn't be returning an int. It could return an int, but should be rewritten to do so.

I still don't understand. Are you saying that the algorithm we used to compute the corresponding position is wrong, and we should rewrite to compute it differently? Or are you saying that the algorithm is OK, but we should split up the testing of "needs replay" and "find the corresponding position" into two things? Or something else?

If I have:
T1 = set(T0)
T2 = sum(T1, {1})
T3 = broadcast(T2, {false, true})
T4 = add(T1, T3)

Hmm, I am not sure if this example is related. I think getMatchedLeafPosWithoutReplay only looks at PairwiseRootDomainMap, which will map all the dims in T4 to T1?

TEST_F(NVFuserTest, TMP) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  auto tv0 = makeSymbolicTensor(2);
  fusion->addInput(tv0);
  auto tv1 = set(tv0);
  auto tv2 = sum(tv1, {1});
  auto tv3 = broadcast(tv2, {false, true});
  auto tv4 = add(tv1, tv3);
  fusion->addOutput(tv4);

  auto producer = tv1;
  auto consumer = tv4;

  const auto c2p_root_map =
      PairwiseRootDomainMap(producer, consumer)
          .mapConsumerToProducer(consumer->domain(), producer->domain());

  // IterDomains in consumer root also in producer root
  std::unordered_set<Val*> mapped_consumer_roots;
  for (auto entry : c2p_root_map) {
    mapped_consumer_roots.emplace(entry.first);
  }

  fusion->print();
  std::cout << ir_utils::toString(mapped_consumer_roots) << std::endl;
}

This gives iS8{i1}, iS9{i2}.

I think "unmapped IDs in consumer but not in producer" here refers to new broadcasting dims, and we are saying that, if we see a leaf ID that completely comes from new broadcasting dims, then we can ignore it. But why don't we want to symmetrically do the same for reductions in the producer? For non-trivial reduction, I think this is because it could not be inlined, but for trivial reduction, it should be safe to just ignore it as well?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the algorithm is correct, and it should not be symmetric. I was considering trivial reduction, but trivial reduction could likely be ignored as well.

I was trying to say the asymmetry is due to the unmappable dims, which are related to particular graph patterns with reduction (the pattern mentioned above).

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is my mistake, I misread the return statements, I thought we were checking the return value was equal to the provided value to return that value. We're swapping produce/consume so they're different.

Copy link
Collaborator Author

@zasdfgbnm zasdfgbnm Jun 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The algorithm is actually not correct😉: #1791

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it shouldn't be returning an int. It could return an int, but should be rewritten to do so.

I still don't understand. Are you saying that the algorithm we used to compute the corresponding position is wrong, and we should rewrite to compute it differently? Or are you saying that the algorithm is OK, but we should split up the testing of "needs replay" and "find the corresponding position" into two things? Or something else?

If I have:
T1 = set(T0)
T2 = sum(T1, {1})
T3 = broadcast(T2, {false, true})
T4 = add(T1, T3)

Hmm, I am not sure if this example is related. I think getMatchedLeafPosWithoutReplay only looks at PairwiseRootDomainMap, which will map all the dims in T4 to T1?

TEST_F(NVFuserTest, TMP) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  auto tv0 = makeSymbolicTensor(2);
  fusion->addInput(tv0);
  auto tv1 = set(tv0);
  auto tv2 = sum(tv1, {1});
  auto tv3 = broadcast(tv2, {false, true});
  auto tv4 = add(tv1, tv3);
  fusion->addOutput(tv4);

  auto producer = tv1;
  auto consumer = tv4;

  const auto c2p_root_map =
      PairwiseRootDomainMap(producer, consumer)
          .mapConsumerToProducer(consumer->domain(), producer->domain());

  // IterDomains in consumer root also in producer root
  std::unordered_set<Val*> mapped_consumer_roots;
  for (auto entry : c2p_root_map) {
    mapped_consumer_roots.emplace(entry.first);
  }

  fusion->print();
  std::cout << ir_utils::toString(mapped_consumer_roots) << std::endl;
}

This gives iS8{i1}, iS9{i2}.

I think "unmapped IDs in consumer but not in producer" here refers to new broadcasting dims, and we are saying that, if we see a leaf ID that completely comes from new broadcasting dims, then we can ignore it. But why don't we want to symmetrically do the same for reductions in the producer? For non-trivial reduction, I think this is because it could not be inlined, but for trivial reduction, it should be safe to just ignore it as well?

The constraint about reductions not being able to get inlined is reflected in ComputeAtRootDomainMap. PairwiseRootDomainMap does not consider constraints due to inlining but just looks at a pair of a producer and consumer.

return mismatched_producer_pos;
}
} else {
if (consumer_or_producer_pos == mismatched_producer_pos) {
return mismatched_consumer_pos;
}
}
} else {
break;
}
}
return -1;
}

} // namespace

// Producer-as-consumer flavour of the leaf position lookup: `consumer_pos` is
// a leaf position in `consumer`, and the value returned is whatever
// getMatchedLeafPosWithoutReplay reports for it.
int TransformReplay::getMatchedLeafPosWithoutReplayPasC(
    const TensorView* producer,
    const TensorView* consumer,
    int consumer_pos) {
  // Tells the shared helper that the position is counted in the consumer.
  constexpr bool kPosIsRelativeToConsumer = true;
  const int matched_pos = getMatchedLeafPosWithoutReplay(
      producer, consumer, consumer_pos, kPosIsRelativeToConsumer);
  return matched_pos;
}

// Consumer-as-producer flavour of the leaf position lookup: `producer_pos` is
// a leaf position in `producer`, and the value returned is whatever
// getMatchedLeafPosWithoutReplay reports for it. Note that the shared helper
// takes (producer, consumer) in that order, the reverse of this signature.
int TransformReplay::getMatchedLeafPosWithoutReplayCasP(
    const TensorView* consumer,
    const TensorView* producer,
    int producer_pos) {
  // Tells the shared helper that the position is counted in the producer.
  constexpr bool kPosIsRelativeToConsumer = false;
  const int matched_pos = getMatchedLeafPosWithoutReplay(
      producer, consumer, producer_pos, kPosIsRelativeToConsumer);
  return matched_pos;
}

bool TransformReplay::fullyMatching(
csarofeen marked this conversation as resolved.
Show resolved Hide resolved
const TensorView* replay,
const TensorView* target) {
auto replay_root = replay->getRootDomain();
auto replay_dom = replay->domain()->domain();
auto target_root = target->getRootDomain();
auto target_dom = target->domain()->domain();
std::unordered_map<IterDomain*, IterDomain*> target2replay_map;
if (replay_root.size() != target_root.size()) {
return false;
}
target2replay_map.reserve(replay_root.size());
std::transform(
target_root.begin(),
target_root.end(),
replay_root.begin(),
std::inserter(target2replay_map, target2replay_map.begin()),
[](auto a, auto b) { return std::make_pair(a, b); });
BestEffortReplay replay_(replay_dom, target_dom, target2replay_map);
auto r = replay_.getReplay();
for (int64_t i = 0; i < replay_dom.size(); i++) {
auto target_id = target_dom[i];
auto replay_it = r.find(target_id);
if (replay_it == r.end() || replay_it->second != replay_dom[i]) {
return false;
}
}
return true;
}

// Propagates the transformations of `from` (treated as the consumer) to `to`
// (treated as the producer) up to the position recorded for `from`, and
// records the resulting position for `to`. The replay is performed only when
// `to` does not already match.
void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) {
  int pos = replayed_pos_.at(from);
  // Note: [Using multiple TransformPropagators]
  // There are cases that we use multiple TransformPropagators along different
  // spanning trees with different references in the same fusion. Some of these
  // spanning trees could overlap. In cases when there are overlapping nodes,
  // TransformPropagator needs to respect the replay of others, because the
  // current TransformPropagator might not contain the most amount of
  // information on how to do the correct transformation. The logic below tells
  // TransformPropagator to skip the replay when not necessary.
  int new_pos =
      TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
  if (new_pos < 0) {
    // A negative result means no matching position exists without a replay,
    // so `to` has to be replayed as `from`.
    auto replay = TransformReplay::replayPasC(to, from, pos);
    to->setDomain(replay.first);
    new_pos = replay.second;
  }
  replayed_pos_[to] = new_pos;
}

void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) {
int pos = replayed_pos_.at(from);
auto replay = TransformReplay::replayCasP(to, from, pos);
to->setDomain(replay.first);
replayed_pos_[to] = replay.second;
// See note [Using multiple TransformPropagators]
int new_pos =
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
if (new_pos < 0) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like I'm missing something here, why would new_pos be less than 0 here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is by protocol. <0 means impossible to find matched position without replay, that is, replay is required. See comment:

  // Returns the leaf position in consumer that matches with `producer_pos` in
  // producer. Returns -1 if matching is impossible. This function can be used
  // to test if replay is needed for getting matching outer dims.

auto replay = TransformReplay::replayCasP(to, from, pos);
to->setDomain(replay.first);
new_pos = replay.second;
}
replayed_pos_[to] = new_pos;
}

// Propagates the transformations of `from` to its sibling `to` and records the
// position of `from` for `to`. The replay is skipped when `to` already fully
// matches `from`, so that transformations applied to `to` by another
// TransformPropagator are not overwritten needlessly.
void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) {
  int pos = replayed_pos_.at(from);
  if (!TransformReplay::fullyMatching(to, from)) {
    auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
    to->setDomain(replay);
  }
  replayed_pos_[to] = pos;
}

Expand Down
Loading