Move MaxProducerPosUpdater into InlinePropagator::tearDown (#1825)
zasdfgbnm authored Jul 17, 2022
1 parent 9135a96 commit 63630f1
Showing 6 changed files with 81 additions and 125 deletions.
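
For readers skimming the flattened diff below, the shape of the change: MaxInfoSpanningTree::Propagator gains optional setUp() and tearDown() hooks that traverse() invokes before and after walking the spanning tree, and InlinePropagator uses tearDown() to perform the max-producer-position update that previously required a second traversal with a dedicated MaxProducerPosUpdater. The following standalone sketch illustrates only the hook pattern; Node, SpanningTree, and InlineLikePropagator are simplified stand-ins, not the nvFuser types.

// Minimal sketch of the setUp/traverse/tearDown pattern this commit introduces.
// All types here are simplified stand-ins for TensorView / MaxInfoSpanningTree.
#include <iostream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

struct Node {
  std::string name;
};

// Mirrors the new interface: the hooks default to no-ops, so existing
// propagators keep compiling unchanged.
struct Propagator {
  virtual ~Propagator() = default;
  virtual void setUp() {}
  virtual void tearDown() {}
  virtual void propagate(Node* from, Node* to) = 0;
};

// Stand-in for MaxInfoSpanningTree::traverse(): it now brackets the per-edge
// propagation with setUp() and tearDown().
struct SpanningTree {
  std::vector<std::pair<Node*, Node*>> edges;
  void traverse(Propagator* p) {
    p->setUp();
    for (auto& [from, to] : edges) {
      p->propagate(from, to);
    }
    p->tearDown();
  }
};

// In the spirit of InlinePropagator after this change: per-edge work records
// which nodes were touched, and the formerly separate cleanup pass runs once
// in tearDown() instead of in a second traversal.
struct InlineLikePropagator : Propagator {
  std::unordered_set<Node*> needs_update;  // analogous to needs_update_max_producer_
  void propagate(Node* /*from*/, Node* to) override {
    needs_update.insert(to);
  }
  void tearDown() override {
    for (Node* n : needs_update) {
      std::cout << "final update for " << n->name << "\n";
    }
  }
};

int main() {
  Node a{"a"}, b{"b"}, c{"c"};
  SpanningTree tree{{{&a, &b}, {&b, &c}}};
  InlineLikePropagator prop;
  tree.traverse(&prop);  // setUp -> propagate each edge -> tearDown
  return 0;
}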
4 changes: 0 additions & 4 deletions torch/csrc/jit/codegen/cuda/compute_at.cpp
@@ -186,7 +186,6 @@ void ComputeAt::runAt(

InlinePropagator inline_propagator(
consumer, consumer_position, mode, selector.selected());
MaxProducerPosUpdater updater;

MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);

@@ -199,7 +198,6 @@ void ComputeAt::runAt(
}

path.traverse(&inline_propagator);
path.traverse(&updater);
}

void ComputeAt::runWith(
@@ -228,7 +226,6 @@ void ComputeAt::runWith(

InlinePropagator inline_propagator(
producer, producer_position, mode, selector.selected());
MaxProducerPosUpdater updater;

MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector);

@@ -240,7 +237,6 @@ void ComputeAt::runWith(
path.traverse(&propagator);
}
path.traverse(&inline_propagator);
path.traverse(&updater);
}

} // namespace cuda
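In both ComputeAt::runAt and ComputeAt::runWith, the four deleted lines are the MaxProducerPosUpdater declaration and the extra path.traverse(&updater) call. With that logic folded into InlinePropagator::tearDown, a single traversal with inline_propagator now performs both the inlining and the max-producer-position update.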
175 changes: 74 additions & 101 deletions torch/csrc/jit/codegen/cuda/inline_propagator.cpp
@@ -166,7 +166,13 @@ void InlinePropagator::setCAPos(TensorView* tv) {
while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) {
pos--;
}
tv->setComputeAt(pos);
auto current_ca_pos = tv->getComputeAtPosition();
if (pos > current_ca_pos) {
tv->setComputeAt(pos);
for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
needs_update_max_producer_.insert(consumer_tv);
}
}
}
}

@@ -194,82 +200,9 @@ InlinePropagator::InlinePropagator(
reference_pos_ = reference_pos;
}

void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
if (is_first_) {
is_first_ = false;
mapped_reference_pos_[reference_] = reference_pos_;
setCAPos(reference_);
}
// Step 1: find mapped_reference_pos_[to]
int from_pos;
if (mode_ != ComputeAtMode::MostInlined) {
from_pos = mapped_reference_pos_.at(from);
} else {
from_pos = from->nDims();
}
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from consumer ",
from,
" at ",
from_pos,
" to producer ",
to,
" because this would require replay.");
mapped_reference_pos_[to] = to_pos;
// Step 2: set CA position of `to`
setCAPos(to);
}

void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
if (is_first_) {
is_first_ = false;
mapped_reference_pos_[reference_] = reference_pos_;
setCAPos(reference_);
}
// Step 1: find mapped_reference_pos_[to]
int from_pos;
if (mode_ != ComputeAtMode::MostInlined) {
from_pos = mapped_reference_pos_.at(from);
} else {
from_pos = from->nDims();
}
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from producer ",
from,
" at ",
from_pos,
" to consumer ",
to,
" because this would require replay.");
mapped_reference_pos_[to] = to_pos;
// Step 2: set CA position of `to`
setCAPos(to);
}

void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
if (is_first_) {
is_first_ = false;
mapped_reference_pos_[reference_] = reference_pos_;
setCAPos(reference_);
}
// Step 1: find mapped_reference_pos_[to]
auto from_pos = mapped_reference_pos_.at(from);
TORCH_CHECK(
TransformReplay::fullSelfMatching(to, from),
"Unable to propagate CA position from ",
from,
" to sibling ",
to,
" because this would require replay.");
mapped_reference_pos_[to] = from_pos;
// Step 2: set CA position of `to`
setCAPos(to);
void InlinePropagator::setUp() {
mapped_reference_pos_[reference_] = reference_pos_;
setCAPos(reference_);
}

namespace {
Expand Down Expand Up @@ -328,38 +261,78 @@ unsigned int getConsumerPosAlignedToProducerCA(

} // namespace

// Try to find the aligned position on consumer's domain corresponding to the
// compute at position of producer domain.
void MaxProducerPosUpdater::handle(TensorView* consumer) {
unsigned int consumer_pos = 0;
for (auto producer : ir_utils::producerTvsOf(consumer)) {
consumer_pos = std::max(
consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer));
void InlinePropagator::tearDown() {
for (auto consumer : needs_update_max_producer_) {
unsigned int consumer_pos = 0;
for (auto producer : ir_utils::producerTvsOf(consumer)) {
consumer_pos = std::max(
consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer));
}
consumer->setMaxProducer(consumer_pos);
}
consumer->setMaxProducer(consumer_pos);
}

void MaxProducerPosUpdater::propagateC2P(TensorView* from, TensorView* to) {
if (updated_.empty()) {
// handle the reference tensor
updated_.insert(nullptr);
propagateC2P(nullptr, from);
}
for (auto consumer_tv : ir_utils::consumerTvsOf(to)) {
if (updated_.count(consumer_tv) > 0) {
continue;
}
handle(consumer_tv);
updated_.insert(consumer_tv);
void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
// Step 1: find mapped_reference_pos_[to]
int from_pos;
if (mode_ != ComputeAtMode::MostInlined) {
from_pos = mapped_reference_pos_.at(from);
} else {
from_pos = from->nDims();
}
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from consumer ",
from,
" at ",
from_pos,
" to producer ",
to,
" because this would require replay.");
mapped_reference_pos_[to] = to_pos;
// Step 2: set CA position of `to`
setCAPos(to);
}

void MaxProducerPosUpdater::propagateP2C(TensorView* from, TensorView* to) {
propagateC2P(from, to);
void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
// Step 1: find mapped_reference_pos_[to]
int from_pos;
if (mode_ != ComputeAtMode::MostInlined) {
from_pos = mapped_reference_pos_.at(from);
} else {
from_pos = from->nDims();
}
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from producer ",
from,
" at ",
from_pos,
" to consumer ",
to,
" because this would require replay.");
mapped_reference_pos_[to] = to_pos;
// Step 2: set CA position of `to`
setCAPos(to);
}

void MaxProducerPosUpdater::propagateSibling(TensorView* from, TensorView* to) {
propagateC2P(from, to);
void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
// Step 1: find mapped_reference_pos_[to]
auto from_pos = mapped_reference_pos_.at(from);
TORCH_CHECK(
TransformReplay::fullSelfMatching(to, from),
"Unable to propagate CA position from ",
from,
" to sibling ",
to,
" because this would require replay.");
mapped_reference_pos_[to] = from_pos;
// Step 2: set CA position of `to`
setCAPos(to);
}

} // namespace cuda
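The restructuring above has two parts. First, the per-tensor bookkeeping moves out of the propagate* methods: the is_first_ special case becomes the new setUp(), and setCAPos() only raises a compute-at position when the new position is strictly greater, recording the affected consumers in needs_update_max_producer_. Second, tearDown() walks that set once and sets each consumer's max producer position to the maximum aligned position over all of its producers. Below is a standalone toy version of that final pass, with strings standing in for TensorViews and a trivial aligned_pos in place of getConsumerPosAlignedToProducerCA; it is illustrative only, not the fuser code.

// Toy version of the deferred max-producer-position update performed by
// InlinePropagator::tearDown(). All names and data here are made up.
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
  // consumer -> its producers
  std::unordered_map<std::string, std::vector<std::string>> producers_of{
      {"c", {"a", "b"}}, {"d", {"c"}}};
  // compute-at position chosen for each producer during propagation
  std::unordered_map<std::string, unsigned int> ca_pos{
      {"a", 2}, {"b", 1}, {"c", 3}};

  // During propagation: whenever a tensor's CA position is raised, its
  // consumers are recorded instead of being updated immediately.
  std::unordered_set<std::string> needs_update_max_producer{"c", "d"};

  // Stand-in for getConsumerPosAlignedToProducerCA: here we simply use the
  // producer's CA position.
  auto aligned_pos = [&](const std::string& /*consumer*/,
                         const std::string& producer) {
    return ca_pos.count(producer) ? ca_pos.at(producer) : 0u;
  };

  // tearDown(): one pass over the recorded consumers, taking the max over
  // all of their producers.
  for (const auto& consumer : needs_update_max_producer) {
    unsigned int consumer_pos = 0;
    for (const auto& producer : producers_of[consumer]) {
      consumer_pos = std::max(consumer_pos, aligned_pos(consumer, producer));
    }
    std::cout << consumer << ": maxProducer = " << consumer_pos << "\n";
  }
  return 0;
}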
19 changes: 3 additions & 16 deletions torch/csrc/jit/codegen/cuda/inline_propagator.h
@@ -91,10 +91,10 @@ class TORCH_CUDA_CU_API InlinePropagator

const MaxPosCalculator max_pos_calc;
std::unordered_set<TensorView*> selected_;
std::unordered_set<TensorView*> needs_update_max_producer_;
TensorView* reference_;
size_t reference_pos_;
ComputeAtMode mode_ = ComputeAtMode::Standard;
bool is_first_ = true;

public:
InlinePropagator(
@@ -117,24 +117,11 @@ class TORCH_CUDA_CU_API InlinePropagator

// Actually propagate the transformations for the inlining pass. Uses the
// functions above to figure out what position to do the propagation at.
virtual void setUp() override;
virtual void propagateC2P(TensorView* from, TensorView* to) override;
virtual void propagateP2C(TensorView* from, TensorView* to) override;
virtual void propagateSibling(TensorView* from, TensorView* to) override;
};

// This is actually not a propagation, it only sets the max producer position of
// the tensors, and it is not needed to compute the max producer position in a
// specific order. But MaxInfoSpanningTree provides a very convenient API to
// visit the tensors, so I just use it for cleaner code.
class TORCH_CUDA_CU_API MaxProducerPosUpdater
: public MaxInfoSpanningTree::Propagator {
std::unordered_set<TensorView*> updated_;
void handle(TensorView* tv);

public:
virtual void propagateC2P(TensorView* from, TensorView* to) override;
virtual void propagateP2C(TensorView* from, TensorView* to) override;
virtual void propagateSibling(TensorView* from, TensorView* to) override;
virtual void tearDown() override;
};

} // namespace cuda
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp
@@ -147,6 +147,7 @@ void MaxInfoSpanningTree::traverse(Propagator* propagator) {
if (path_.empty()) {
compute_spanning_tree();
}
propagator->setUp();
for (const auto& next_hop : path_) {
switch (next_hop.type) {
case NextHopType::SIBLING:
@@ -160,6 +161,7 @@
break;
}
}
propagator->tearDown();
}

MaxRootDomainInfoSpanningTree::RootDomainInfo::operator bool() const {
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/maxinfo_propagator.h
@@ -46,6 +46,8 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree {

// This is the interface to implement the actual propagation
struct Propagator {
virtual void setUp() {}
virtual void tearDown() {}
virtual void propagateC2P(TensorView* from, TensorView* to) = 0;
virtual void propagateP2C(TensorView* from, TensorView* to) = 0;
virtual void propagateSibling(TensorView* from, TensorView* to) = 0;
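Because the new setUp() and tearDown() virtuals have empty default bodies, existing Propagator implementations are unaffected; within this commit only InlinePropagator overrides them. traverse() calls setUp() once before the hop loop and tearDown() once after it, giving every propagator a well-defined start and end of the walk.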
4 changes: 0 additions & 4 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
@@ -871,10 +871,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
InlinePropagator inline_inner_most(
reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors);
spanning_tree.traverse(&inline_inner_most);

// Fix max producer position
MaxProducerPosUpdater updater;
spanning_tree.traverse(&updater);
}

} // namespace cuda
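The pointwise scheduler gets the same simplification: the explicit "Fix max producer position" pass is dropped, since spanning_tree.traverse(&inline_inner_most) now ends with InlinePropagator::tearDown() performing that update.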
