Support PYTORCH_NVFUSER_DUMP=transform_propagator (csarofeen#1812)
zasdfgbnm authored Jul 11, 2022
1 parent de6b7ca commit 4413c8f
Showing 3 changed files with 75 additions and 3 deletions.
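The new dump option is driven by the existing PYTORCH_NVFUSER_DUMP environment variable, using the transform_propagator token that utils.cpp below learns to parse. The usual way to enable it is to set PYTORCH_NVFUSER_DUMP=transform_propagator in the environment before launching the process. A minimal sketch of flipping it on from a standalone C++ test follows; the setenv call and the surrounding scaffolding are illustrative assumptions, not part of this commit.

#include <cstdlib>

int main() {
  // Same effect as launching the process with
  //   PYTORCH_NVFUSER_DUMP=transform_propagator
  // The fuser reads this variable when it parses its debug dump options.
  setenv("PYTORCH_NVFUSER_DUMP", "transform_propagator", /*overwrite=*/1);

  // ... build a Fusion, schedule a reference TensorView, and run
  // TransformPropagator; propagateC2P/propagateP2C/propagateSibling will
  // now print the propagation path and replay results to stdout ...
  return 0;
}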
67 changes: 67 additions & 0 deletions torch/csrc/jit/codegen/cuda/transform_replay.cpp
@@ -863,6 +863,12 @@ void TransformPropagator::propagateC2P(TensorView* from, TensorView* to) {
// TransformPropagator to skip the replay when not necessary.
int new_pos =
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator);
if (debug) {
std::cout << "TransformPropagator::propagateC2P" << std::endl;
std::cout << " from: " << from << " @ " << pos << std::endl;
std::cout << " to: " << to << std::endl;
}
if (new_pos < 0) {
auto replay = TransformReplay::replayPasC(to, from, pos);
TORCH_INTERNAL_ASSERT(
@@ -874,6 +880,11 @@ void TransformPropagator::propagateC2P(TensorView* from, TensorView* to) {
" but that would invalidate previously compute at position or max producer position.");
to->setDomain(replay.first);
new_pos = replay.second;
if (debug) {
std::cout << " replayed: " << to << " @ " << new_pos << std::endl;
}
} else if (debug) {
std::cout << " replay skipped. result position: " << new_pos << std::endl;
}
replayed_pos_[to] = new_pos;
}
@@ -883,6 +894,12 @@ void TransformPropagator::propagateP2C(TensorView* from, TensorView* to) {
// See note [Using multiple TransformPropagators]
int new_pos =
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator);
if (debug) {
std::cout << "TransformPropagator::propagateP2C" << std::endl;
std::cout << " from: " << from << " @ " << pos << std::endl;
std::cout << " to: " << to << std::endl;
}
if (new_pos < 0) {
auto replay = TransformReplay::replayCasP(to, from, pos);
TORCH_INTERNAL_ASSERT(
@@ -894,13 +911,24 @@ void TransformPropagator::propagateP2C(TensorView* from, TensorView* to) {
" but that would invalidate previously compute at position or max producer position.");
to->setDomain(replay.first);
new_pos = replay.second;
if (debug) {
std::cout << " replayed: " << to << " @ " << new_pos << std::endl;
}
} else if (debug) {
std::cout << " replay skipped. result position: " << new_pos << std::endl;
}
replayed_pos_[to] = new_pos;
}

void TransformPropagator::propagateSibling(TensorView* from, TensorView* to) {
int pos = replayed_pos_.at(from);
// See note [Using multiple TransformPropagators]
bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator);
if (debug) {
std::cout << "TransformPropagator::propagateSibling" << std::endl;
std::cout << " from: " << from << " @ " << pos << std::endl;
std::cout << " to: " << to << std::endl;
}
if (!TransformReplay::fullSelfMatching(to, from)) {
auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
TORCH_INTERNAL_ASSERT(
@@ -911,6 +939,11 @@ void TransformPropagator::propagateSibling(TensorView* from, TensorView* to) {
replay,
" but that would invalidate previously compute at position or max producer position.");
to->setDomain(replay);
if (debug) {
std::cout << " replayed: " << to << " @ " << pos << std::endl;
}
} else if (debug) {
std::cout << " replay skipped. result position: " << pos << std::endl;
}
replayed_pos_[to] = pos;
}
@@ -932,6 +965,12 @@ void MostInlinedTransformPropagator::propagateC2P(
// See note [Using multiple TransformPropagators]
int new_pos =
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator);
if (debug) {
std::cout << "MostInlinedTransformPropagator::propagateC2P" << std::endl;
std::cout << " from: " << from << std::endl;
std::cout << " to: " << to << std::endl;
}
if (new_pos < 0) {
auto replay = TransformReplay::replayPasC(to, from, pos);
TORCH_INTERNAL_ASSERT(
@@ -942,6 +981,11 @@ void MostInlinedTransformPropagator::propagateC2P(
replay.first,
" but that would invalidate previously compute at position or max producer position.");
to->setDomain(replay.first);
if (debug) {
std::cout << " replayed: " << to << std::endl;
}
} else if (debug) {
std::cout << " replay skipped" << std::endl;
}
}

@@ -952,6 +996,12 @@ void MostInlinedTransformPropagator::propagateP2C(
// See note [Using multiple TransformPropagators]
int new_pos =
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator);
if (debug) {
std::cout << "MostInlinedTransformPropagator::propagateP2C" << std::endl;
std::cout << " from: " << from << std::endl;
std::cout << " to: " << to << std::endl;
}
if (new_pos < 0) {
auto replay = TransformReplay::replayCasP(to, from, pos);
TORCH_INTERNAL_ASSERT(
@@ -962,13 +1012,25 @@ void MostInlinedTransformPropagator::propagateP2C(
replay.first,
" but that would invalidate previously compute at position or max producer position.");
to->setDomain(replay.first);
if (debug) {
std::cout << " replayed: " << to << std::endl;
}
} else if (debug) {
std::cout << " replay skipped" << std::endl;
}
}

void MostInlinedTransformPropagator::propagateSibling(
TensorView* from,
TensorView* to) {
// See note [Using multiple TransformPropagators]
bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator);
if (debug) {
std::cout << "MostInlinedTransformPropagator::propagateSibling"
<< std::endl;
std::cout << " from: " << from << std::endl;
std::cout << " to: " << to << std::endl;
}
if (!TransformReplay::fullSelfMatching(to, from)) {
auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
TORCH_INTERNAL_ASSERT(
@@ -979,6 +1041,11 @@ void MostInlinedTransformPropagator::propagateSibling(
replay,
" but that would invalidate previously compute at position or max producer position.");
to->setDomain(replay);
if (debug) {
std::cout << " replayed: " << to << std::endl;
}
} else if (debug) {
std::cout << " replay skipped" << std::endl;
}
}

5 changes: 4 additions & 1 deletion torch/csrc/jit/codegen/cuda/utils.cpp
@@ -35,7 +35,8 @@ auto parseDebugDumpOptions() {
{DebugDumpOption::SchedulerDebug, false},
{DebugDumpOption::ParallelDimensions, false},
{DebugDumpOption::Halo, false},
{DebugDumpOption::PerfDebugVerbose, false}};
{DebugDumpOption::PerfDebugVerbose, false},
{DebugDumpOption::TransformPropagator, false}};

if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_DUMP")) {
c10::string_view options_view(dump_options);
@@ -82,6 +83,8 @@ auto parseDebugDumpOptions() {
options_map[DebugDumpOption::Halo] = true;
} else if (token == "perf_debug_verbose") {
options_map[DebugDumpOption::PerfDebugVerbose] = true;
} else if (token == "transform_propagator") {
options_map[DebugDumpOption::TransformPropagator] = true;
} else {
TORCH_CHECK(
false,
6 changes: 4 additions & 2 deletions torch/csrc/jit/codegen/cuda/utils.h
@@ -42,8 +42,10 @@ enum class DebugDumpOption {
SchedulerDebug, //! Dump scheduler heuristic parameters
ParallelDimensions, //!< Dump known parallel dimensions
Halo, //! Halo information of tensors
PerfDebugVerbose //! When running kernels, print verbose information
//! associated with what's running
PerfDebugVerbose, //! When running kernels, print verbose information
//! associated with what's running
TransformPropagator, //! When running TransformPropagator, print propagation
//! path and replay result
};

TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);
