Skip to content

Commit

Permalink
validateDomain in TransformPropagator (csarofeen#1796)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Jul 1, 2022
1 parent c077085 commit 3f2c263
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions torch/csrc/jit/codegen/cuda/transform_replay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,19 @@ bool TransformReplay::fullSelfMatching(
return true;
}

namespace {

// Make sure if tv is set to new_td it doesn't violate set compute at and max
// produce at positions.
bool validateDomain(TensorView* tv, TensorDomain* new_td) {
auto first_mismatch =
BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td);
return first_mismatch >= (int)tv->getMaxProducerPosition() &&
first_mismatch >= (int)tv->getComputeAtPosition();
}

} // namespace

void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) {
int pos = replayed_pos_.at(from);
// Note: [Using multiple TransformPropagators]
Expand All @@ -849,6 +862,13 @@ void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) {
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
if (new_pos < 0) {
auto replay = TransformReplay::replayPasC(to, from, pos);
TORCH_INTERNAL_ASSERT(
validateDomain(to, replay.first),
"Tried to set the domain of ",
to,
" to ",
replay.first,
" but that would invalidate previously compute at position or max producer position.");
to->setDomain(replay.first);
new_pos = replay.second;
}
Expand All @@ -862,6 +882,13 @@ void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) {
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
if (new_pos < 0) {
auto replay = TransformReplay::replayCasP(to, from, pos);
TORCH_INTERNAL_ASSERT(
validateDomain(to, replay.first),
"Tried to set the domain of ",
to,
" to ",
replay.first,
" but that would invalidate previously compute at position or max producer position.");
to->setDomain(replay.first);
new_pos = replay.second;
}
Expand All @@ -873,6 +900,13 @@ void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) {
// See note [Using multiple TransformPropagators]
if (!TransformReplay::fullSelfMatching(to, from)) {
auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
TORCH_INTERNAL_ASSERT(
validateDomain(to, replay),
"Tried to set the domain of ",
to,
" to ",
replay,
" but that would invalidate previously compute at position or max producer position.");
to->setDomain(replay);
}
replayed_pos_[to] = pos;
Expand Down

0 comments on commit 3f2c263

Please sign in to comment.