Skip to content

Commit

Permalink
More cleanup on InlinePropagator (csarofeen#1800)
Browse files Browse the repository at this point in the history
I just realized that `InlinePropagator` can be further simplified because it no longer replays.

Since `InlinePropagator` is no longer doing replay, it is more like a "for each" problem rather than a propagation problem:

For each tensor `tv`, if we already know what is the max position of `tv` that is mapped to the reference tensor's selected outer dimensions(stored in `mapped_reference_pos_` in the code), setting the CA position is a very local operation, and is as simple as checking `tv` itself and all its consumers to determine the inline position.

`InlinePropagator` is not completely a "for each" problem only because the computation of `mapped_reference_pos_` is a propagation problem.

This cleanup reorganizes the code of `InlinePropagator` so it is clear that `InlinePropagator` is nothing but a two-step process:
Step 1: Do a propagation to find the `mapped_reference_pos_` for all tensors.
Step 2: For each tensor, check itself and its consumers to determine the CA position.

Conceptually, I would like to split step 1 with step 2. Because this split makes these concepts decoupled. Especially, this PR makes `mapped_reference_pos_` only contain info about the reference tensor, and is independent of the CA position (Currently, this is not true for best effort and most inlined computeAt without this PR). Now, in my view, `InlinePropagator` is conceptually very simple and easy to understand.

In terms of implementation, step 1 and step 2 can be interleaved, because when we don't need to know the `mapped_reference_pos_` for `tv`'s consumer in order to compute the CA position of `tv`. So a one-pass traverse could do both step 1 and step 2 altogether.
  • Loading branch information
zasdfgbnm authored Jul 5, 2022
1 parent 8d384da commit 5f375d0
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 122 deletions.
155 changes: 49 additions & 106 deletions torch/csrc/jit/codegen/cuda/inline_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,129 +104,56 @@ size_t MaxPosCalculator::getMaxPosSelf(
return std::distance(dom.begin(), iter);
}

// Return the max position in consumer that producer can be inlined to
// Cannot inline:
// Reduction dimensions in producer
// Block broadcast dimensions in producer
// Vectorized dimensions in producer or consumer
// Unrolled dimensions in producer or consumer
// Dimensions derived from root dimensions that exist in both but are
// unmappable
size_t MaxPosCalculator::getMaxPosC2P(
TensorView* consumer,
TensorView* producer) const {
// Limit max position based on vectorized dims in consumer.
auto max_consumer_pos = getMaxPosSelf(consumer, true, false, true);

auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer);
auto replay_PasC =
BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_root_map);
auto c2p_replay_map = replay_PasC.getReplay();

for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0;
consumer_pos--) {
auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1));
if (map_it != c2p_replay_map.end()) {
auto p_id = map_it->second;
if (!isAllowedID(p_id, producer, true, false, false)) {
max_consumer_pos = consumer_pos - 1;
}
}
}

return max_consumer_pos;
}

// Return the max position in producer that can be inlined to consumer
// Cannot inline:
// Reduction dimensions in producer
// Vectorized dimensions in producer or consumer
// Unrolled dimensions in producer or consumer
// Dimensions derived from root dimensions that exist in both but are
// unmappable
size_t MaxPosCalculator::getMaxPosP2C(
// Vectorized dimensions in consumer
// Unrolled dimensions in consumer
size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
TensorView* producer,
TensorView* consumer) const {
auto max_producer_pos = getMaxPosSelf(producer, false, false, false);

auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer);
auto replay_CasP =
BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map);
auto p2c_replay_map = replay_CasP.getReplay();

for (size_t producer_pos = max_producer_pos; producer_pos > 0;
producer_pos--) {
auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1));
for (size_t producer_pos = 0; producer_pos < producer->nDims();
producer_pos++) {
auto map_it = p2c_replay_map.find(producer->axis(producer_pos));
if (map_it != p2c_replay_map.end()) {
auto c_id = map_it->second;
if (!isAllowedID(c_id, consumer, true, false, true)) {
max_producer_pos = producer_pos - 1;
return producer_pos;
}
}
}

return max_producer_pos;
return producer->nDims();
}

size_t InlinePropagator::getMaxPosAll(TensorView* tv) {
auto max_pos = max_pos_calc.getMaxPosSelf(tv, false, false, false);
for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
// consumers are always replayed consistently
max_pos =
std::min<size_t>(max_pos, max_pos_calc.getMaxPosP2C(tv, consumer_tv));
max_pos = std::min<size_t>(
max_pos, max_pos_calc.getMaxProducerPosFromConsumer(tv, consumer_tv));
}
return max_pos;
}

size_t InlinePropagator::getFromPosC2P(TensorView* from, TensorView* to) {
size_t max_pos = max_pos_calc.getMaxPosC2P(from, to);
size_t pos = mapped_reference_pos_.at(from);

if (mode_ == ComputeAtMode::BestEffort) {
return std::min(pos, max_pos);
} else if (mode_ == ComputeAtMode::MostInlined) {
return max_pos;
}

TORCH_INTERNAL_ASSERT(
pos <= max_pos,
"Invalid compute at position detected in compute at when trying to propagate the CA position from consumer: ",
from,
" to producer: ",
to,
" tried to do this at position: ",
pos,
" but max position that's allowed is ",
max_pos);
return pos;
}

size_t InlinePropagator::getFromPosP2C(TensorView* from, TensorView* to) {
size_t max_pos = max_pos_calc.getMaxPosP2C(from, to);
size_t pos = mapped_reference_pos_.at(from);

if (mode_ == ComputeAtMode::BestEffort) {
return std::min(pos, max_pos);
} else if (mode_ == ComputeAtMode::MostInlined) {
return max_pos;
}

TORCH_INTERNAL_ASSERT(
pos <= max_pos,
"Invalid compute at position detected in compute at when trying to propagate the CA position from producer: ",
from,
" to consumer: ",
to,
" tried to do this at position: ",
pos,
" but max position that's allowed is ",
max_pos);
return pos;
}

void InlinePropagator::setCAPos(TensorView* tv, size_t pos) {
void InlinePropagator::setCAPos(TensorView* tv) {
size_t pos = mapped_reference_pos_.at(tv);
if (selected_.count(tv) && !tv->isFusionInput()) {
pos = std::min<size_t>(pos, getMaxPosAll(tv));
auto max_pos = getMaxPosAll(tv);
if (mode_ == ComputeAtMode::Standard) {
TORCH_INTERNAL_ASSERT(
pos <= max_pos,
"Invalid compute at position detected in InlinePropagator when trying to set the CA position of: ",
tv,
" to ",
pos,
", max position that's allowed is ",
max_pos);
} else {
pos = std::min<size_t>(pos, max_pos);
}
// hoist inner most broadcast
while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) {
pos--;
Expand Down Expand Up @@ -262,10 +189,16 @@ InlinePropagator::InlinePropagator(
void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
if (is_first_) {
is_first_ = false;
setCAPos(reference_, reference_pos_);
mapped_reference_pos_[reference_] = reference_pos_;
setCAPos(reference_);
}
// Step 1: find mapped_reference_pos_[to]
int from_pos;
if (mode_ != ComputeAtMode::MostInlined) {
from_pos = mapped_reference_pos_.at(from);
} else {
from_pos = from->nDims();
}
int from_pos = getFromPosC2P(from, to);
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
TORCH_CHECK(
Expand All @@ -275,17 +208,24 @@ void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
" to producer ",
to,
" because this would require replay.");
setCAPos(to, to_pos);
mapped_reference_pos_[to] = to_pos;
// Step 2: set CA position of `to`
setCAPos(to);
}

void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
if (is_first_) {
is_first_ = false;
setCAPos(reference_, reference_pos_);
mapped_reference_pos_[reference_] = reference_pos_;
setCAPos(reference_);
}
// Step 1: find mapped_reference_pos_[to]
int from_pos;
if (mode_ != ComputeAtMode::MostInlined) {
from_pos = mapped_reference_pos_.at(from);
} else {
from_pos = from->nDims();
}
int from_pos = getFromPosP2C(from, to);
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
TORCH_CHECK(
Expand All @@ -295,16 +235,18 @@ void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
" to consumer ",
to,
" because this would require replay.");
setCAPos(to, to_pos);
mapped_reference_pos_[to] = to_pos;
// Step 2: set CA position of `to`
setCAPos(to);
}

void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
if (is_first_) {
is_first_ = false;
setCAPos(reference_, reference_pos_);
mapped_reference_pos_[reference_] = reference_pos_;
setCAPos(reference_);
}
// Step 1: find mapped_reference_pos_[to]
auto from_pos = mapped_reference_pos_.at(from);
TORCH_CHECK(
TransformReplay::fullSelfMatching(to, from),
Expand All @@ -313,8 +255,9 @@ void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
" to sibling ",
to,
" because this would require replay.");
setCAPos(to, from_pos);
mapped_reference_pos_[to] = from_pos;
// Step 2: set CA position of `to`
setCAPos(to);
}

namespace {
Expand Down
20 changes: 4 additions & 16 deletions torch/csrc/jit/codegen/cuda/inline_propagator.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,9 @@ class MaxPosCalculator {

// Returns the maximum position producer can be inlined based on consumer
// given the set ComputeAtMode
size_t getMaxPosC2P(TensorView* from, TensorView* to) const;

// Returns the maximum position consumer can be inlined based on producer
// given the set ComputeAtMode
size_t getMaxPosP2C(TensorView* from, TensorView* to) const;
size_t getMaxProducerPosFromConsumer(
TensorView* producer,
TensorView* consumer) const;

MaxPosCalculator(ComputeAtMode mode);
};
Expand All @@ -74,16 +72,6 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {
// that can be shared across both directions.
size_t getMaxPosAll(TensorView* tv);

// Returns the inline position in consumer that producer should be inlined as
// based on consumer, taking into consideration the max possible returned by
// getMaxPos{P2C, C2P}, the compute at mode type.
size_t getFromPosC2P(TensorView* from, TensorView* to);

// Returns the inline position in producer that consumer should be inlined as
// based on producer, taking into consideration the max possible returned by
// getMaxPos{P2C, C2P}, the compute at mode type.
size_t getFromPosP2C(TensorView* from, TensorView* to);

// We use mapped_reference_pos_ to keep track of the outer axes information of
// the reference tensor. That is, mapped_reference_pos_[tv] answers the
// question "What outer axes in tv are shared with the specified reference
Expand All @@ -95,7 +83,7 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {

// Actually set the computeAt position. This does not necessarily equal to
// mapped_reference_pos_[tv] because we don't want to inline certain things.
void setCAPos(TensorView* tv, size_t pos);
void setCAPos(TensorView* tv);

const MaxPosCalculator max_pos_calc;
std::unordered_set<TensorView*> selected_;
Expand Down

0 comments on commit 5f375d0

Please sign in to comment.