New TransformPropagator algorithm (#1763)
I completely rewrote the `TransformPropagator`. The new `TransformPropagator` explicitly keeps track of which root IDs of the starting tensor are preserved, storing this information per root ID in a `RootIDInfo`. `view` is not treated differently from other ops. During propagation, a Dijkstra-style traversal finds, for each tensor in the graph, the path that preserves the most information (see the sketch after the file summary below). Each tensor is replayed only once.
zasdfgbnm authored Jun 21, 2022
1 parent 6c19520 commit 0ed815f
Showing 5 changed files with 519 additions and 123 deletions.
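
As a minimal illustration of the traversal the commit message describes (not the actual nvfuser implementation — `Tensor`, `Candidate`, `propagationPath`, and `preservedInfo` are hypothetical stand-ins), the sketch below runs Dijkstra with a max-heap keyed on how much root-ID information each candidate replay would preserve, so every tensor is settled exactly once, from the neighbor that keeps the most information:

// Hypothetical sketch: propagate transforms through the tensor graph in a
// Dijkstra-like order, maximizing preserved root-ID information instead of
// minimizing a distance. Each tensor is settled (replayed) exactly once.
#include <climits>
#include <functional>
#include <queue>
#include <unordered_map>
#include <vector>

struct Tensor {}; // stand-in for nvfuser's TensorView

struct Candidate {
  Tensor* tensor;      // tensor we would replay next
  Tensor* replay_from; // neighbor whose transforms we would replay from
  int info;            // how many root IDs of the starting tensor survive
  bool operator<(const Candidate& other) const {
    return info < other.info; // std::priority_queue pops the largest info first
  }
};

// Returns, for each reachable tensor, the neighbor it should be replayed from.
// `neighbors` is the producer/consumer (and sibling) adjacency of the graph;
// `preservedInfo` estimates how much information survives one replay step.
std::unordered_map<Tensor*, Tensor*> propagationPath(
    Tensor* start,
    const std::unordered_map<Tensor*, std::vector<Tensor*>>& neighbors,
    const std::function<int(Tensor*, Tensor*, int)>& preservedInfo) {
  std::unordered_map<Tensor*, Tensor*> replay_from;
  std::priority_queue<Candidate> heap;
  heap.push({start, nullptr, INT_MAX});
  while (!heap.empty()) {
    Candidate c = heap.top();
    heap.pop();
    if (replay_from.count(c.tensor)) {
      continue; // already settled along a path with at least as much information
    }
    replay_from.emplace(c.tensor, c.replay_from);
    auto it = neighbors.find(c.tensor);
    if (it == neighbors.end()) {
      continue; // no neighbors to propagate to
    }
    for (Tensor* next : it->second) {
      if (!replay_from.count(next)) {
        heap.push({next, c.tensor, preservedInfo(c.tensor, next, c.info)});
      }
    }
  }
  return replay_from;
}

In the real pass the information measure is computed per root IterDomain via `RootIDInfo`; this sketch only captures the traversal order and the replay-once guarantee.
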
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -5,6 +5,7 @@
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_internal_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/mma_type.h>
 
 #include <torch/csrc/jit/ir/ir.h>

@@ -158,6 +159,7 @@ class TransformPropagator;
 class TransformIter;
 class TransformReplay;
 class OptOutMutator;
+class TensorDomain;
 
 namespace ir_utils {
 class TVDomainGuard;
1 change: 0 additions & 1 deletion torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -4,7 +4,6 @@
 
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/mma_type.h>
 #include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>
 
90 changes: 90 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -23574,6 +23574,96 @@ TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) {
       executor_cache.fusion(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
 }
 
+TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) {
+  // https://github.com/csarofeen/pytorch/issues/1760
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+
+  auto tvs = Welford(tv0, {1});
+  fusion.addOutput(tvs.var_sum);
+
+  tvs.avg->split(1, 1);
+  tvs.avg->split(1, 2);
+  tvs.avg->split(1, 3);
+  tvs.var_sum->split(1, 1);
+  tvs.var_sum->split(1, 2);
+  tvs.var_sum->split(1, 3);
+  tvs.n->split(1, 1);
+  tvs.n->split(1, 2);
+  tvs.n->split(1, 3);
+
+  auto tvs2 = tvs.rFactor({1, 4});
+
+  TransformPropagator::from(tvs2.var_sum);
+
+  // check that the tensors in tvs and tvs2 are consistent with their siblings
+  auto checkSiblingConsistency = [](TensorView* replay, TensorView* target) {
+    auto replay_root = replay->getRootDomain();
+    auto replay_dom = replay->domain()->domain();
+    auto target_root = target->getRootDomain();
+    auto target_dom = target->domain()->domain();
+    std::unordered_map<IterDomain*, IterDomain*> target2replay_map;
+    TORCH_CHECK(replay_root.size() == target_root.size());
+    target2replay_map.reserve(replay_root.size());
+    std::transform(
+        target_root.begin(),
+        target_root.end(),
+        replay_root.begin(),
+        std::inserter(target2replay_map, target2replay_map.begin()),
+        [](auto a, auto b) { return std::make_pair(a, b); });
+    BestEffortReplay replay_(replay_dom, target_dom, target2replay_map);
+    auto r = replay_.getReplay();
+    for (int64_t i = 0; i < replay_dom.size(); i++) {
+      auto target_id = target_dom[i];
+      auto replay_it = r.find(target_id);
+      TORCH_CHECK(replay_it != r.end());
+      TORCH_CHECK(
+          replay_it->second == replay_dom[i],
+          "IterDomain mismatch when checking ",
+          replay,
+          " and ",
+          target,
+          " at ",
+          i,
+          ", got ",
+          replay_it->second,
+          " and ",
+          replay_dom[i]);
+    }
+  };
+  std::vector<TensorView*> siblings[] = {
+      {tvs.avg, tvs.var_sum, tvs.n}, {tvs2.avg, tvs2.var_sum, tvs2.n}};
+  for (auto tensors : siblings) {
+    for (auto t1 : tensors) {
+      for (auto t2 : tensors) {
+        checkSiblingConsistency(t1, t2);
+      }
+    }
+  }
+}
+
+TEST_F(NVFuserTest, FusionTransformPropagatePosition_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(4);
+  auto tv1 = makeSymbolicTensor(6);
+  fusion.addInput(tv0);
+
+  auto tv2 = broadcast(tv0, {false, false, true, false, false, true});
+  auto tv3 = add(tv1, tv2);
+  fusion.addOutput(tv3);
+
+  tv0->merge(2);
+  tv0->merge(0);
+  TransformPropagator::from(tv0);
+
+  TORCH_CHECK(tv1->nDims() == 4);
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)