Skip to content

Commit

Permalink
Coding style cleanups (csarofeen#1798)
Browse files Browse the repository at this point in the history
Per offline discussion with @csarofeen, this PR performs many renamings for better coding style: for all propagation-related things, I am now using the names `P2C` and `C2P` instead of `CasP` and `PasC`, because "A as B" somewhat implies we want to replay A the same as B, whereas "B to A" sounds more general and is a better phrasing for this case. Also, I modified the order of function arguments to match the order in their names. For example, `PasC` should have `(producer, consumer)` or `(to, from)`, but not `(consumer, producer)` or `(from, to)`, and `C2P` should have `(consumer, producer)` or `(from, to)`, but not `(producer, consumer)` or `(to, from)`.
  • Loading branch information
zasdfgbnm committed Jul 1, 2022
1 parent 38c7f3c commit ef04f6c
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 103 deletions.
66 changes: 30 additions & 36 deletions torch/csrc/jit/codegen/cuda/inline_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ namespace jit {
namespace fuser {
namespace cuda {

bool InlinePropagatorSelector::allowPasC(TensorView* from, TensorView* to) {
bool InlinePropagatorSelector::allowC2P(TensorView* from, TensorView* to) {
return selected_.count(to) > 0;
}

bool InlinePropagatorSelector::allowCasP(TensorView* from, TensorView* to) {
bool InlinePropagatorSelector::allowP2C(TensorView* from, TensorView* to) {
// If the producer is in the selected set, then the consumer must also be
// replayed to obtain a compatible loop structure so that this producer
// can be consumed in this loop.
Expand Down Expand Up @@ -112,9 +112,9 @@ size_t MaxPosCalculator::getMaxPosSelf(
// Unrolled dimensions in producer or consumer
// Dimensions derived from root dimensions that exist in both but are
// unmappable
size_t MaxPosCalculator::getMaxPosPasC(
TensorView* producer,
TensorView* consumer) const {
size_t MaxPosCalculator::getMaxPosC2P(
TensorView* consumer,
TensorView* producer) const {
// Limit max position based on vectorized dims in consumer.
auto max_consumer_pos = getMaxPosSelf(consumer, true, false, true);

Expand Down Expand Up @@ -144,9 +144,9 @@ size_t MaxPosCalculator::getMaxPosPasC(
// Unrolled dimensions in producer or consumer
// Dimensions derived from root dimensions that exist in both but are
// unmappable
size_t MaxPosCalculator::getMaxPosCasP(
TensorView* consumer,
TensorView* producer) const {
size_t MaxPosCalculator::getMaxPosP2C(
TensorView* producer,
TensorView* consumer) const {
auto max_producer_pos = getMaxPosSelf(producer, false, false, false);

auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer);
Expand All @@ -173,16 +173,14 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv) {
for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
// consumers are always replayed consistently
max_pos =
std::min<size_t>(max_pos, max_pos_calc.getMaxPosCasP(consumer_tv, tv));
std::min<size_t>(max_pos, max_pos_calc.getMaxPosP2C(tv, consumer_tv));
}
return max_pos;
}

size_t InlinePropagator::getFromPosPasC(
TensorView* producer,
TensorView* consumer) {
size_t max_pos = max_pos_calc.getMaxPosPasC(producer, consumer);
size_t pos = mapped_reference_pos_.at(consumer);
size_t InlinePropagator::getFromPosC2P(TensorView* from, TensorView* to) {
size_t max_pos = max_pos_calc.getMaxPosC2P(from, to);
size_t pos = mapped_reference_pos_.at(from);

if (mode_ == ComputeAtMode::BestEffort) {
return std::min(pos, max_pos);
Expand All @@ -193,21 +191,19 @@ size_t InlinePropagator::getFromPosPasC(
TORCH_INTERNAL_ASSERT(
pos <= max_pos,
"Invalid compute at position detected in compute at when trying to propagate the CA position from consumer: ",
consumer,
from,
" to producer: ",
producer,
to,
" tried to do this at position: ",
pos,
" but max position that's allowed is ",
max_pos);
return pos;
}

size_t InlinePropagator::getFromPosCasP(
TensorView* consumer,
TensorView* producer) {
size_t max_pos = max_pos_calc.getMaxPosCasP(consumer, producer);
size_t pos = mapped_reference_pos_.at(producer);
size_t InlinePropagator::getFromPosP2C(TensorView* from, TensorView* to) {
size_t max_pos = max_pos_calc.getMaxPosP2C(from, to);
size_t pos = mapped_reference_pos_.at(from);

if (mode_ == ComputeAtMode::BestEffort) {
return std::min(pos, max_pos);
Expand All @@ -218,9 +214,9 @@ size_t InlinePropagator::getFromPosCasP(
TORCH_INTERNAL_ASSERT(
pos <= max_pos,
"Invalid compute at position detected in compute at when trying to propagate the CA position from producer: ",
producer,
from,
" to consumer: ",
consumer,
to,
" tried to do this at position: ",
pos,
" but max position that's allowed is ",
Expand Down Expand Up @@ -263,13 +259,13 @@ InlinePropagator::InlinePropagator(
".");
}

void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) {
void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
if (is_first_) {
is_first_ = false;
setCAPos(reference_, reference_pos_);
mapped_reference_pos_[reference_] = reference_pos_;
}
int from_pos = getFromPosPasC(to, from);
int from_pos = getFromPosC2P(from, to);
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
TORCH_CHECK(
Expand All @@ -283,13 +279,13 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) {
mapped_reference_pos_[to] = to_pos;
}

void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) {
void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
if (is_first_) {
is_first_ = false;
setCAPos(reference_, reference_pos_);
mapped_reference_pos_[reference_] = reference_pos_;
}
int from_pos = getFromPosCasP(to, from);
int from_pos = getFromPosP2C(from, to);
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
TORCH_CHECK(
Expand All @@ -303,7 +299,7 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) {
mapped_reference_pos_[to] = to_pos;
}

void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) {
void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
if (is_first_) {
is_first_ = false;
setCAPos(reference_, reference_pos_);
Expand Down Expand Up @@ -388,11 +384,11 @@ void MaxProducerPosUpdater::handle(TensorView* consumer) {
consumer->setMaxProducer(consumer_pos);
}

void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) {
void MaxProducerPosUpdater::propagateC2P(TensorView* from, TensorView* to) {
if (updated_.empty()) {
// handle the reference tensor
updated_.insert(nullptr);
propagateTvPasC(nullptr, from);
propagateC2P(nullptr, from);
}
for (auto consumer_tv : ir_utils::consumerTvsOf(to)) {
if (updated_.count(consumer_tv) > 0) {
Expand All @@ -403,14 +399,12 @@ void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) {
}
}

void MaxProducerPosUpdater::propagateTvCasP(TensorView* from, TensorView* to) {
propagateTvPasC(from, to);
void MaxProducerPosUpdater::propagateP2C(TensorView* from, TensorView* to) {
propagateC2P(from, to);
}

void MaxProducerPosUpdater::propagateTvSibling(
TensorView* from,
TensorView* to) {
propagateTvPasC(from, to);
void MaxProducerPosUpdater::propagateSibling(TensorView* from, TensorView* to) {
propagateC2P(from, to);
}

} // namespace cuda
Expand Down
28 changes: 14 additions & 14 deletions torch/csrc/jit/codegen/cuda/inline_propagator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector {
std::unordered_set<TensorView*> selected_;

public:
virtual bool allowPasC(TensorView* from, TensorView* to) override;
virtual bool allowCasP(TensorView* from, TensorView* to) override;
virtual bool allowC2P(TensorView* from, TensorView* to) override;
virtual bool allowP2C(TensorView* from, TensorView* to) override;
virtual bool allowSibling(TensorView* from, TensorView* to) override;

InlinePropagatorSelector(std::unordered_set<TensorView*> selected)
Expand Down Expand Up @@ -60,11 +60,11 @@ class MaxPosCalculator {

// Returns the maximum position producer can be inlined based on consumer
// given the set ComputeAtMode
size_t getMaxPosPasC(TensorView* producer, TensorView* consumer) const;
size_t getMaxPosC2P(TensorView* from, TensorView* to) const;

// Returns the maximum position consumer can be inlined based on producer
// given the set ComputeAtMode
size_t getMaxPosCasP(TensorView* consumer, TensorView* producer) const;
size_t getMaxPosP2C(TensorView* from, TensorView* to) const;

MaxPosCalculator(ComputeAtMode mode);
};
Expand All @@ -76,13 +76,13 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {

// Returns the inline position in consumer that producer should be inlined as
// based on consumer, taking into consideration the max possible returned by
// getMaxPos{PasC, CasP}, the compute at mode type.
size_t getFromPosPasC(TensorView* producer, TensorView* consumer);
// getMaxPos{P2C, C2P}, the compute at mode type.
size_t getFromPosC2P(TensorView* from, TensorView* to);

// Returns the inline position in producer that consumer should be inlined as
// based on producer, taking into consideration the max possible returned by
// getMaxPos{PasC, CasP}, the compute at mode type.
size_t getFromPosCasP(TensorView* consumer, TensorView* producer);
// getMaxPos{P2C, C2P}, the compute at mode type.
size_t getFromPosP2C(TensorView* from, TensorView* to);

// We use mapped_reference_pos_ to keep track of the outer axes information of
// the reference tensor. That is, mapped_reference_pos_[tv] answers the
Expand Down Expand Up @@ -115,9 +115,9 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {

// Actually propagate the transformations for the inlining pass. Uses the
// functions above to figure out what position to do the propagation at.
virtual void propagateTvPasC(TensorView* from, TensorView* to) override;
virtual void propagateTvCasP(TensorView* from, TensorView* to) override;
virtual void propagateTvSibling(TensorView* from, TensorView* to) override;
virtual void propagateC2P(TensorView* from, TensorView* to) override;
virtual void propagateP2C(TensorView* from, TensorView* to) override;
virtual void propagateSibling(TensorView* from, TensorView* to) override;
};

// This is actually not a propagation, it only sets the max producer position of
Expand All @@ -129,9 +129,9 @@ class MaxProducerPosUpdater : public MaxInfoSpanningTree::Propagator {
void handle(TensorView* tv);

public:
virtual void propagateTvPasC(TensorView* from, TensorView* to) override;
virtual void propagateTvCasP(TensorView* from, TensorView* to) override;
virtual void propagateTvSibling(TensorView* from, TensorView* to) override;
virtual void propagateC2P(TensorView* from, TensorView* to) override;
virtual void propagateP2C(TensorView* from, TensorView* to) override;
virtual void propagateSibling(TensorView* from, TensorView* to) override;
};

} // namespace cuda
Expand Down
30 changes: 15 additions & 15 deletions torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,18 @@ void MaxInfoSpanningTree::compute_spanning_tree() {
}
};

auto allowPasC = [this](TensorView* from, TensorView* to) {
auto allowC2P = [this](TensorView* from, TensorView* to) {
if (selector_ == nullptr) {
return true;
}
return selector_->allowPasC(from, to);
return selector_->allowC2P(from, to);
};

auto allowCasP = [this](TensorView* from, TensorView* to) {
auto allowP2C = [this](TensorView* from, TensorView* to) {
if (selector_ == nullptr) {
return true;
}
return selector_->allowCasP(from, to);
return selector_->allowP2C(from, to);
};

auto allowSibling = [this](TensorView* from, TensorView* to) {
Expand Down Expand Up @@ -114,7 +114,7 @@ void MaxInfoSpanningTree::compute_spanning_tree() {
}

for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) {
if (replayed.count(consumer_tv) || !allowCasP(next_hop.to, consumer_tv)) {
if (replayed.count(consumer_tv) || !allowP2C(next_hop.to, consumer_tv)) {
continue;
}
insertNextHop(
Expand All @@ -128,7 +128,7 @@ void MaxInfoSpanningTree::compute_spanning_tree() {
}

for (auto producer_tv : ir_utils::producerTvsOf(next_hop.to)) {
if (replayed.count(producer_tv) || !allowPasC(next_hop.to, producer_tv)) {
if (replayed.count(producer_tv) || !allowC2P(next_hop.to, producer_tv)) {
continue;
}
insertNextHop(
Expand All @@ -150,13 +150,13 @@ void MaxInfoSpanningTree::traverse(Propagator* propagator) {
for (const auto& next_hop : path_) {
switch (next_hop.type) {
case NextHopType::SIBLING:
propagator->propagateTvSibling(next_hop.from, next_hop.to);
propagator->propagateSibling(next_hop.from, next_hop.to);
break;
case NextHopType::C_AS_P:
propagator->propagateTvCasP(next_hop.from, next_hop.to);
propagator->propagateP2C(next_hop.from, next_hop.to);
break;
case NextHopType::P_AS_C:
propagator->propagateTvPasC(next_hop.from, next_hop.to);
propagator->propagateC2P(next_hop.from, next_hop.to);
break;
}
}
Expand Down Expand Up @@ -416,20 +416,20 @@ std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree:
return from_info;
}

void SpanningTreePrinter::propagateTvPasC(TensorView* from, TensorView* to) {
stream_ << "propagateTvPasC" << std::endl;
void SpanningTreePrinter::propagateC2P(TensorView* from, TensorView* to) {
stream_ << "propagateC2P" << std::endl;
stream_ << " from: " << from->toString() << std::endl;
stream_ << " to: " << to->toString() << std::endl;
}

void SpanningTreePrinter::propagateTvCasP(TensorView* from, TensorView* to) {
stream_ << "propagateTvCasP" << std::endl;
void SpanningTreePrinter::propagateP2C(TensorView* from, TensorView* to) {
stream_ << "propagateP2C" << std::endl;
stream_ << " from: " << from->toString() << std::endl;
stream_ << " to: " << to->toString() << std::endl;
}

void SpanningTreePrinter::propagateTvSibling(TensorView* from, TensorView* to) {
stream_ << "propagateTvSibling" << std::endl;
void SpanningTreePrinter::propagateSibling(TensorView* from, TensorView* to) {
stream_ << "propagateSibling" << std::endl;
stream_ << " from: " << from->toString() << std::endl;
stream_ << " to: " << to->toString() << std::endl;
}
Expand Down
16 changes: 8 additions & 8 deletions torch/csrc/jit/codegen/cuda/maxinfo_propagator.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree {
// Class to subclass in order to stop traversal, by which limits the nodes in
// the spanning tree.
struct Selector {
virtual bool allowPasC(TensorView* from, TensorView* to) = 0;
virtual bool allowCasP(TensorView* from, TensorView* to) = 0;
virtual bool allowC2P(TensorView* from, TensorView* to) = 0;
virtual bool allowP2C(TensorView* from, TensorView* to) = 0;
virtual bool allowSibling(TensorView* from, TensorView* to) = 0;
};

// This is the interface to implement the actual propagation
struct Propagator {
virtual void propagateTvPasC(TensorView* from, TensorView* to) = 0;
virtual void propagateTvCasP(TensorView* from, TensorView* to) = 0;
virtual void propagateTvSibling(TensorView* from, TensorView* to) = 0;
virtual void propagateC2P(TensorView* from, TensorView* to) = 0;
virtual void propagateP2C(TensorView* from, TensorView* to) = 0;
virtual void propagateSibling(TensorView* from, TensorView* to) = 0;
};

// This is the interface that specifies the structure of information used to
Expand Down Expand Up @@ -237,9 +237,9 @@ class TORCH_CUDA_CU_API SpanningTreePrinter
std::ostream& stream_;

public:
virtual void propagateTvPasC(TensorView* from, TensorView* to) override;
virtual void propagateTvCasP(TensorView* from, TensorView* to) override;
virtual void propagateTvSibling(TensorView* from, TensorView* to) override;
virtual void propagateC2P(TensorView* from, TensorView* to) override;
virtual void propagateP2C(TensorView* from, TensorView* to) override;
virtual void propagateSibling(TensorView* from, TensorView* to) override;
SpanningTreePrinter(std::ostream& stream = std::cout) : stream_(stream) {}
};

Expand Down
Loading

0 comments on commit ef04f6c

Please sign in to comment.