Skip to content

Commit

Permalink
Some further cleanup for the new computeAt interface (#1793)
Browse files Browse the repository at this point in the history
Revert MaxProducerPosUpdater to old algo.
  • Loading branch information
zasdfgbnm authored Jul 1, 2022
1 parent 45f5203 commit d0d0908
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 18 deletions.
10 changes: 9 additions & 1 deletion torch/csrc/jit/codegen/cuda/compute_at.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ void ComputeAt::runAt(
TensorView* consumer,
unsigned int consumer_position,
ComputeAtMode mode) {
FUSER_PERF_SCOPE("ComputeAt::run");
FUSER_PERF_SCOPE("ComputeAt::runAt");

// Make sure the correct fusion is setup between this and consumer.
TORCH_CHECK(
Expand All @@ -175,6 +175,10 @@ void ComputeAt::runAt(
consumer,
" are not in the same fusion.");

if (mode == ComputeAtMode::MostInlined) {
consumer_position = consumer->nDims();
}

FusionGuard fg(producer->fusion());

auto selected = getPropagationSubgraph(producer, consumer);
Expand Down Expand Up @@ -206,6 +210,10 @@ void ComputeAt::runWith(
consumer,
" are not in the same fusion.");

if (mode == ComputeAtMode::MostInlined) {
producer_position = producer->nDims();
}

FusionGuard fg(producer->fusion());

auto selected = getPropagationSubgraph(producer, consumer);
Expand Down
94 changes: 77 additions & 17 deletions torch/csrc/jit/codegen/cuda/inline_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,15 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) {
int pos = getReplayPosPasC(to, from);
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
// TODO: Can we make TransformPropagator do the transformation, and
// InlinePropagator only set the CA positions?
// TORCH_CHECK(to_pos >= 0);
if (mode_ != ComputeAtMode::MostInlined) {
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from consumer ",
from,
" to producer ",
to,
" because this would require replay.");
}
if (to_pos < 0) {
auto replay = TransformReplay::replayPasC(to, from, pos);
TORCH_INTERNAL_ASSERT(
Expand All @@ -335,9 +341,15 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) {
int pos = getReplayPosCasP(to, from);
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
// TODO: Can we make TransformPropagator do the transformation, and
// InlinePropagator only set the CA positions?
// TORCH_CHECK(to_pos >= 0);
if (mode_ != ComputeAtMode::MostInlined) {
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from producer ",
from,
" to consumer ",
to,
" because this would require replay.");
}
if (to_pos < 0) {
auto replay = TransformReplay::replayCasP(to, from, pos);
TORCH_INTERNAL_ASSERT(
Expand Down Expand Up @@ -373,23 +385,71 @@ void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) {
recordReplayedPos(to, from_pos);
}

namespace {

// Try to find the aligned position on consumer's domain corresponding to the
// compute at position of producer domain.
void MaxProducerPosUpdater::handle(TensorView* consumer) {
// compute at position of producer domain. Used in computeAt pass only. No
// checking on actual producer-consumer relationship.
unsigned int getConsumerPosAlignedToProducerCA(
TensorView* consumer,
TensorView* producer) {
// Locate consumer's position that aligns with
// the producer's new compute at axis. We need broadcast axes forwarded so we
// need to replay PasC as CasP will not forward broadcast dims. For example
// if we have:
// T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 )
// produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will
// have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to
// NVFuserTest.FusionComplexBCast1_CUDA

auto c2p_map =
BestEffortReplay::replayPasC(
producer,
consumer,
-1,
// Compute at root domain may not be valid here, as all
// producers don't have to be able to map into consumer at
// max producer position. Since computeAt should be valid
// and this mechanism is only intended to lower produce
// position of consumer, we can simply use the pairwise map.
PairwiseRootDomainMap(producer, consumer))
.getReplay();

// Find the innermost position of consumer that has
// been mapped within the producer ca axis.
unsigned int consumer_pos = consumer->nDims();
while (consumer_pos > 0) {
for (auto producer : ir_utils::producerTvsOf(consumer)) {
auto producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(
producer, consumer, consumer_pos);
if (producer_pos >= 0 &&
producer_pos <= producer->getComputeAtPosition()) {
goto finished;
}
auto consumer_id = consumer->axis((int)consumer_pos - 1);
auto p_dom = producer->domain()->domain();
if (std::any_of(
p_dom.begin(),
p_dom.begin() + producer->getComputeAtPosition(),
[&consumer_id, &c2p_map](IterDomain* p_id) {
auto c_id_it = c2p_map.find(consumer_id);
if (c_id_it != c2p_map.end()) {
return c_id_it->second == p_id;
}
return false;
})) {
break;
}
consumer_pos--;
}
finished:
consumer->setMaxProducer(consumer_pos, true);

return consumer_pos;
}

} // namespace

// Updates the max producer position of `consumer`: for every producer
// TensorView of `consumer`, compute the consumer position aligned to that
// producer's compute-at position, and record the largest such position on
// the consumer. If `consumer` has no producers the recorded position is 0.
void MaxProducerPosUpdater::handle(TensorView* consumer) {
unsigned int max_aligned_pos = 0;
for (auto producer : ir_utils::producerTvsOf(consumer)) {
const unsigned int aligned_pos =
getConsumerPosAlignedToProducerCA(consumer, producer);
// Keep the running maximum over all producers.
if (max_aligned_pos < aligned_pos) {
max_aligned_pos = aligned_pos;
}
}
consumer->setMaxProducer(max_aligned_pos);
}

void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) {
Expand Down

0 comments on commit d0d0908

Please sign in to comment.