InlinePropagator please don't replay (csarofeen#1797)
This PR makes `InlinePropagator` just set compute-at positions. It will not replay any tensor. If you want to replay, please use `TransformPropagator` and friends to do so.

Currently, `InlinePropagator` already asserts that no replay happens for standard and best-effort compute-at, so this PR is mostly about making most-inlined compute-at work the same way.

This PR also does a lot of cleanup to remove the word "replay" from comments, variable names, and function names in `InlinePropagator`.

I also cleaned up `recordReplayedPos` and `retrieveReplayedPos`; the logic is now much easier to understand.
zasdfgbnm authored Jul 1, 2022
1 parent 3f2c263 commit 38c7f3c
Showing 6 changed files with 156 additions and 141 deletions.
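With this change, schedulers first propagate transformations and only then set inline positions. A minimal sketch of the resulting pattern, mirroring the `ComputeAt::runAt` diff below (the surrounding fusion/selector setup is elided, and `transform_propagator` is a renamed local for clarity):

```cpp
// Sketch of the decoupled propagation pattern after this PR.
MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);

// Step 1: replay transformations along the spanning tree.
if (mode == ComputeAtMode::MostInlined) {
  MostInlinedTransformPropagator transform_propagator;
  path.traverse(&transform_propagator);
} else {
  TransformPropagator transform_propagator(consumer, consumer_position);
  path.traverse(&transform_propagator);
}

// Step 2: only set compute-at positions; InlinePropagator never replays.
InlinePropagator inline_propagator(
    selector.selected(), consumer, consumer_position, mode);
path.traverse(&inline_propagator);
```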
21 changes: 17 additions & 4 deletions torch/csrc/jit/codegen/cuda/compute_at.cpp
@@ -184,13 +184,20 @@ void ComputeAt::runAt(
   auto selected = getPropagationSubgraph(producer, consumer);
   InlinePropagatorSelector selector(selected);
 
-  TransformPropagator propagator(consumer, consumer_position);
   InlinePropagator inline_propagator(
       selector.selected(), consumer, consumer_position, mode);
   MaxProducerPosUpdater updater;
 
   MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);
-  path.traverse(&propagator);
+
+  if (mode == ComputeAtMode::MostInlined) {
+    MostInlinedTransformPropagator propagator;
+    path.traverse(&propagator);
+  } else {
+    TransformPropagator propagator(consumer, consumer_position);
+    path.traverse(&propagator);
+  }
+
   path.traverse(&inline_propagator);
   path.traverse(&updater);
 }
@@ -219,13 +226,19 @@ void ComputeAt::runWith(
   auto selected = getPropagationSubgraph(producer, consumer);
   InlinePropagatorSelector selector(selected);
 
-  TransformPropagator propagator(producer, producer_position);
   InlinePropagator inline_propagator(
       selector.selected(), producer, producer_position, mode);
   MaxProducerPosUpdater updater;
 
   MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector);
-  path.traverse(&propagator);
+
+  if (mode == ComputeAtMode::MostInlined) {
+    MostInlinedTransformPropagator propagator;
+    path.traverse(&propagator);
+  } else {
+    TransformPropagator propagator(producer, producer_position);
+    path.traverse(&propagator);
+  }
   path.traverse(&inline_propagator);
   path.traverse(&updater);
 }
174 changes: 55 additions & 119 deletions torch/csrc/jit/codegen/cuda/inline_propagator.cpp
@@ -178,22 +178,11 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv) {
   return max_pos;
 }
 
-size_t InlinePropagator::adjustComputeAtPos(TensorView* tv, size_t pos) {
-  pos = std::min<size_t>(pos, getMaxPosAll(tv));
-
-  // hoist inner most broadcast
-  while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) {
-    pos--;
-  }
-
-  return pos;
-}
-
-size_t InlinePropagator::getReplayPosPasC(
+size_t InlinePropagator::getFromPosPasC(
     TensorView* producer,
     TensorView* consumer) {
   size_t max_pos = max_pos_calc.getMaxPosPasC(producer, consumer);
-  size_t pos = retrieveReplayedPos(consumer);
+  size_t pos = mapped_reference_pos_.at(consumer);
 
   if (mode_ == ComputeAtMode::BestEffort) {
     return std::min(pos, max_pos);
@@ -203,22 +192,22 @@ size_t InlinePropagator::getReplayPosPasC(
 
   TORCH_INTERNAL_ASSERT(
       pos <= max_pos,
-      "Invalid compute at position detected in compute at when trying to replay producer: ",
-      producer,
-      " as consumer: ",
+      "Invalid compute at position detected in compute at when trying to propagate the CA position from consumer: ",
       consumer,
+      " to producer: ",
+      producer,
       " tried to do this at position: ",
       pos,
       " but max position that's allowed is ",
       max_pos);
   return pos;
 }
 
-size_t InlinePropagator::getReplayPosCasP(
+size_t InlinePropagator::getFromPosCasP(
     TensorView* consumer,
     TensorView* producer) {
   size_t max_pos = max_pos_calc.getMaxPosCasP(consumer, producer);
-  size_t pos = retrieveReplayedPos(producer);
+  size_t pos = mapped_reference_pos_.at(producer);
 
   if (mode_ == ComputeAtMode::BestEffort) {
     return std::min(pos, max_pos);
@@ -228,42 +217,28 @@ size_t InlinePropagator::getReplayPosCasP(
 
   TORCH_INTERNAL_ASSERT(
       pos <= max_pos,
-      "Invalid compute at position detected in compute at when trying to replay consumer: ",
-      consumer,
-      " as producer: ",
+      "Invalid compute at position detected in compute at when trying to propagate the CA position from producer: ",
       producer,
+      " to consumer: ",
+      consumer,
       " tried to do this at position: ",
       pos,
       " but max position that's allowed is ",
       max_pos);
   return pos;
 }
 
-void InlinePropagator::recordReplayedPos(TensorView* tv, size_t pos) {
-  if (selected_.count(tv)) {
-    auto new_pos = adjustComputeAtPos(tv, pos);
-    if (pos != new_pos) {
-      replayed_pos_[tv] = pos;
-      pos = new_pos;
+void InlinePropagator::setCAPos(TensorView* tv, size_t pos) {
+  if (selected_.count(tv) && !tv->isFusionInput()) {
+    pos = std::min<size_t>(pos, getMaxPosAll(tv));
+    // hoist inner most broadcast
+    while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) {
+      pos--;
     }
-    if (!tv->isFusionInput()) {
-      tv->setComputeAt(pos);
-    } else {
-      replayed_pos_[tv] = pos;
-    }
-  } else {
-    replayed_pos_[tv] = pos;
+    tv->setComputeAt(pos);
   }
 }
 
-size_t InlinePropagator::retrieveReplayedPos(TensorView* tv) {
-  auto it = replayed_pos_.find(tv);
-  if (it != replayed_pos_.end()) {
-    return it->second;
-  }
-  return tv->getComputeAtPosition();
-}
-
 InlinePropagator::InlinePropagator(
     std::unordered_set<TensorView*> selected,
     TensorView* reference,
@@ -288,101 +263,62 @@ InlinePropagator::InlinePropagator(
       ".");
 }
 
-namespace {
-
-// Make sure if tv is set to new_td it doesn't violate set compute at and max
-// produce at positions.
-bool validateDomain(TensorView* tv, TensorDomain* new_td) {
-  auto first_mismatch =
-      BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td);
-  return first_mismatch >= (int)tv->getMaxProducerPosition() &&
-      first_mismatch >= (int)tv->getComputeAtPosition();
-}
-
-} // namespace
-
 void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) {
   if (is_first_) {
     is_first_ = false;
-    recordReplayedPos(reference_, reference_pos_);
+    setCAPos(reference_, reference_pos_);
+    mapped_reference_pos_[reference_] = reference_pos_;
   }
-  int pos = getReplayPosPasC(to, from);
+  int from_pos = getFromPosPasC(to, from);
   auto to_pos =
-      TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
-  if (mode_ != ComputeAtMode::MostInlined) {
-    TORCH_CHECK(
-        to_pos >= 0,
-        "Unable to propagate CA position from consumer ",
-        from,
-        " to producer ",
-        to,
-        " because this would require replay.");
-  }
-  if (to_pos < 0) {
-    auto replay = TransformReplay::replayPasC(to, from, pos);
-    TORCH_INTERNAL_ASSERT(
-        validateDomain(to, replay.first),
-        "Tried to set the domain of ",
-        to,
-        " to ",
-        replay.first,
-        " but that would invalidate previously compute at position or max producer position.");
-    to->setDomain(replay.first);
-    to_pos = replay.second;
-  }
-  recordReplayedPos(to, to_pos);
+      TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
+  TORCH_CHECK(
+      to_pos >= 0,
+      "Unable to propagate CA position from consumer ",
+      from,
+      " to producer ",
+      to,
+      " because this would require replay.");
+  setCAPos(to, to_pos);
+  mapped_reference_pos_[to] = to_pos;
 }
 
 void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) {
   if (is_first_) {
     is_first_ = false;
-    recordReplayedPos(reference_, reference_pos_);
+    setCAPos(reference_, reference_pos_);
+    mapped_reference_pos_[reference_] = reference_pos_;
   }
-  int pos = getReplayPosCasP(to, from);
+  int from_pos = getFromPosCasP(to, from);
   auto to_pos =
-      TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
-  if (mode_ != ComputeAtMode::MostInlined) {
-    TORCH_CHECK(
-        to_pos >= 0,
-        "Unable to propagate CA position from producer ",
-        from,
-        " to consumer ",
-        to,
-        " because this would require replay.");
-  }
-  if (to_pos < 0) {
-    auto replay = TransformReplay::replayCasP(to, from, pos);
-    TORCH_INTERNAL_ASSERT(
-        validateDomain(to, replay.first),
-        "Tried to set the domain of ",
-        to,
-        " to ",
-        replay.first,
-        " but that would invalidate previously compute at position or max producer position.");
-    to->setDomain(replay.first);
-    to_pos = replay.second;
-  }
-  recordReplayedPos(to, to_pos);
+      TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
+  TORCH_CHECK(
+      to_pos >= 0,
+      "Unable to propagate CA position from producer ",
+      from,
+      " to consumer ",
+      to,
+      " because this would require replay.");
+  setCAPos(to, to_pos);
+  mapped_reference_pos_[to] = to_pos;
 }
 
 void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) {
   if (is_first_) {
     is_first_ = false;
-    recordReplayedPos(reference_, reference_pos_);
-  }
-  auto from_pos = retrieveReplayedPos(from);
-  if (!TransformReplay::fullSelfMatching(to, from)) {
-    auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
-    TORCH_INTERNAL_ASSERT(
-        validateDomain(to, replay),
-        "Tried to set the domain of ",
-        to,
-        " to ",
-        replay,
-        " but that would invalidate previously compute at position or max producer position.");
-    to->setDomain(replay);
+    setCAPos(reference_, reference_pos_);
+    mapped_reference_pos_[reference_] = reference_pos_;
   }
-  recordReplayedPos(to, from_pos);
+  auto from_pos = mapped_reference_pos_.at(from);
+  TORCH_CHECK(
+      TransformReplay::fullSelfMatching(to, from),
+      "Unable to propagate CA position from ",
+      from,
+      " to sibling ",
+      to,
+      " because this would require replay.");
+  setCAPos(to, from_pos);
+  mapped_reference_pos_[to] = from_pos;
 }
 
 namespace {
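The new `setCAPos` caps the requested position at `getMaxPosAll` and then hoists past trailing broadcast axes before calling `tv->setComputeAt`. A standalone sketch of that trimming step, with a hypothetical `is_broadcast` vector standing in for `tv->axis(i)->isBroadcast()`:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch of the position trimming done by InlinePropagator::setCAPos.
// is_broadcast[i] stands in for tv->axis(i)->isBroadcast().
size_t trimCAPos(size_t pos, size_t max_pos, const std::vector<bool>& is_broadcast) {
  pos = std::min(pos, max_pos); // never exceed the maximum inlinable position
  // hoist innermost broadcast axes: they are never inlined
  while (pos > 0 && is_broadcast[pos - 1]) {
    pos--;
  }
  return pos;
}

// For a leaf domain [iS0, iS1, bS2] (bS2 a broadcast) and a requested
// position of 3: trimCAPos(3, 3, {false, false, true}) returns 2.
```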
36 changes: 18 additions & 18 deletions torch/csrc/jit/codegen/cuda/inline_propagator.h
@@ -51,18 +51,18 @@ class MaxPosCalculator {
       bool allow_unmappable) const;
 
  public:
-  // Returns the position at which tv can be relayed within.
+  // Returns the position at which tv can be inlined within.
   size_t getMaxPosSelf(
       TensorView* tv,
       bool allow_reduction,
       bool allow_vectorize,
       bool allow_unmappable) const;
 
-  // Returns the maximum position producer can be replayed based on consumer
+  // Returns the maximum position producer can be inlined based on consumer
   // given the set ComputeAtMode
   size_t getMaxPosPasC(TensorView* producer, TensorView* consumer) const;
 
-  // Returns the maximum position consumer can be replayed based on producer
+  // Returns the maximum position consumer can be inlined based on producer
   // given the set ComputeAtMode
   size_t getMaxPosCasP(TensorView* consumer, TensorView* producer) const;
@@ -74,34 +74,34 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {
   // that can be shared across both directions.
   size_t getMaxPosAll(TensorView* tv);
 
-  // Returns position of getMaxPosAll while also hoisting outside broadcast
-  // dimensions.
-  size_t adjustComputeAtPos(TensorView* tv, size_t pos);
-
-  // Returns the replay position in consumer that producer should be replayed as
+  // Returns the inline position in consumer that producer should be inlined as
   // based on consumer, taking into consideration the max possible returned by
   // getMaxPos{PasC, CasP}, the compute at mode type.
-  size_t getReplayPosPasC(TensorView* producer, TensorView* consumer);
+  size_t getFromPosPasC(TensorView* producer, TensorView* consumer);
 
-  // Returns the replay position in producer that consumer should be replayed as
+  // Returns the inline position in producer that consumer should be inlined as
   // based on producer, taking into consideration the max possible returned by
   // getMaxPos{PasC, CasP}, the compute at mode type.
-  size_t getReplayPosCasP(TensorView* consumer, TensorView* producer);
+  size_t getFromPosCasP(TensorView* consumer, TensorView* producer);
 
-  // Sets the compute at position of tv and records the position in
-  // replayed_pos_
-  void recordReplayedPos(TensorView* tv, size_t pos);
+  // We use mapped_reference_pos_ to keep track of the outer axes information of
+  // the reference tensor. That is, mapped_reference_pos_[tv] answers the
+  // question "What outer axes in tv are shared with the specified reference
+  // tensor's outer axes?". However, when we actually set the CA position of tv,
+  // we might not want to set it as mapped_reference_pos_[tv] because because we
+  // don't want to inline certain things (such as vectorized dimensions, inner
+  // most broadcasting, etc.).
+  std::unordered_map<TensorView*, size_t> mapped_reference_pos_;
 
-  // Returns the entry for tv in replayed_pos_ if it exists, else returns the
-  // compute at position of tv.
-  size_t retrieveReplayedPos(TensorView* tv);
+  // Actually set the computeAt position. This does not necessarily equal to
+  // mapped_reference_pos_[tv] because we don't want to inline certain things.
+  void setCAPos(TensorView* tv, size_t pos);
 
   const MaxPosCalculator max_pos_calc;
   std::unordered_set<TensorView*> selected_;
   TensorView* reference_;
   size_t reference_pos_;
   ComputeAtMode mode_ = ComputeAtMode::Standard;
-  std::unordered_map<TensorView*, size_t> replayed_pos_;
   bool is_first_ = true;
 
  public:
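The `mapped_reference_pos_` comment above is the heart of the new bookkeeping: the position mapped to the reference can be larger than the compute-at position that actually gets set. A hypothetical illustration (names and values made up for exposition):

```cpp
// Suppose three outer axes of tv map to the reference, but tv's axis 2 is a
// broadcast. InlinePropagator then records the full mapped position while
// inlining less:
//
//   mapped_reference_pos_[tv] = 3; // what tv shares with the reference
//   tv->setComputeAt(2);           // setCAPos hoisted the trailing broadcast
//
// getFromPosPasC/getFromPosCasP read mapped_reference_pos_.at(...), not the
// compute-at position, so propagation further down the spanning tree keeps
// the full shared-axis information.
```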
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -157,6 +157,7 @@ enum class ComputeAtMode { Standard, BestEffort, MostInlined };
 class InlinePropagator;
 class MaxProducerPosUpdater;
 class TransformPropagator;
+struct MostInlinedTransformPropagator;
 class TransformIter;
 class TransformReplay;
 class OptOutMutator;
@@ -457,6 +458,7 @@ class TORCH_CUDA_CU_API TensorView : public Val {
   void applyMmaSwizzle(MmaOptions options);
 
   friend TORCH_CUDA_CU_API TransformPropagator;
+  friend TORCH_CUDA_CU_API MostInlinedTransformPropagator;
   friend TORCH_CUDA_CU_API TransformReplay;
   friend TORCH_CUDA_CU_API OptOutMutator;
   friend TORCH_CUDA_CU_API InlinePropagator;