
Commit 0ed815f

New TransformPropagator algorithm (#1763)
I completely rewrote the `TransformPropagator`. The new `TransformPropagator` explicitly keeps track of which root IDs of the starting tensor are preserved; a `RootIDInfo` record stores this information for each root ID. `view` is not treated differently from other ops. During propagation, I run a Dijkstra-style search to find, for each tensor in the graph, the path that preserves the greatest amount of information. Each tensor is replayed only once.
1 parent 6c19520 commit 0ed815f
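
For readers who want the propagation idea in runnable form, below is a minimal, self-contained C++ sketch of the Dijkstra-style search the commit message describes. It is not the nvfuser implementation: the toy graph, the single integer "info" score, and the printf standing in for a replay are all illustrative assumptions; the real code tracks per-root-ID preservation via `RootIDInfo` rather than one scalar.

// Minimal sketch (hypothetical, not the nvfuser code): tensors are graph
// nodes, edge weights model how much root-domain information survives a
// replay across that producer/consumer edge, and a max-heap Dijkstra
// picks, per tensor, the path preserving the most information.
#include <algorithm>
#include <cstdio>
#include <queue>
#include <unordered_map>
#include <utility>
#include <vector>

struct Node {
  // (neighbor id, information preserved when replaying across this edge)
  std::vector<std::pair<int, int>> neighbors;
};

int main() {
  // Toy graph: node 0 plays the role of the starting tensor.
  std::vector<Node> graph = {
      {{{1, 3}, {2, 2}}}, // 0 -> 1 preserves 3 units, 0 -> 2 preserves 2
      {{{3, 2}}},         // 1 -> 3 preserves 2
      {{{3, 3}}},         // 2 -> 3 preserves 3
      {{}},               // 3 has no consumers
  };

  std::unordered_map<int, int> best; // node -> info preserved on best path
  std::priority_queue<std::pair<int, int>> pq; // max-heap of (info, node)
  pq.push({1000, 0}); // the starting tensor preserves all of its own info

  while (!pq.empty()) {
    auto [info, v] = pq.top();
    pq.pop();
    if (best.count(v)) {
      continue; // already replayed once, via a better (or equal) path
    }
    best[v] = info;
    std::printf("replay tensor %d (info preserved: %d)\n", v, info);
    for (auto [u, w] : graph[v].neighbors) {
      // Information only shrinks along a path: the bottleneck is the
      // minimum of what we carried in and what this edge preserves.
      pq.push({std::min(info, w), u});
    }
  }
}

Because finalized nodes are skipped when popped again, each tensor is replayed exactly once, along its best-scoring path, which is the replay-once property the commit message calls out.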

File tree

5 files changed: +519 -123 lines changed

torch/csrc/jit/codegen/cuda/ir_interface_nodes.h

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_internal_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/mma_type.h>

 #include <torch/csrc/jit/ir/ir.h>

@@ -158,6 +159,7 @@ class TransformPropagator;
 class TransformIter;
 class TransformReplay;
 class OptOutMutator;
+class TensorDomain;

 namespace ir_utils {
 class TVDomainGuard;

torch/csrc/jit/codegen/cuda/ir_internal_nodes.h

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@

 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/mma_type.h>
 #include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>

torch/csrc/jit/codegen/cuda/test/test_gpu.cpp

Lines changed: 90 additions & 0 deletions
@@ -23574,6 +23574,96 @@ TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) {
       executor_cache.fusion(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
 }

+TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) {
+  // https://github.com/csarofeen/pytorch/issues/1760
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+
+  auto tvs = Welford(tv0, {1});
+  fusion.addOutput(tvs.var_sum);
+
+  tvs.avg->split(1, 1);
+  tvs.avg->split(1, 2);
+  tvs.avg->split(1, 3);
+  tvs.var_sum->split(1, 1);
+  tvs.var_sum->split(1, 2);
+  tvs.var_sum->split(1, 3);
+  tvs.n->split(1, 1);
+  tvs.n->split(1, 2);
+  tvs.n->split(1, 3);
+
+  auto tvs2 = tvs.rFactor({1, 4});
+
+  TransformPropagator::from(tvs2.var_sum);
+
+  // check that the resulting tensors in tvs2 are identical
+  auto checkSiblingConsistency = [](TensorView* replay, TensorView* target) {
+    auto replay_root = replay->getRootDomain();
+    auto replay_dom = replay->domain()->domain();
+    auto target_root = target->getRootDomain();
+    auto target_dom = target->domain()->domain();
+    std::unordered_map<IterDomain*, IterDomain*> target2replay_map;
+    TORCH_CHECK(replay_root.size() == target_root.size());
+    target2replay_map.reserve(replay_root.size());
+    std::transform(
+        target_root.begin(),
+        target_root.end(),
+        replay_root.begin(),
+        std::inserter(target2replay_map, target2replay_map.begin()),
+        [](auto a, auto b) { return std::make_pair(a, b); });
+    BestEffortReplay replay_(replay_dom, target_dom, target2replay_map);
+    auto r = replay_.getReplay();
+    for (int64_t i = 0; i < replay_dom.size(); i++) {
+      auto target_id = target_dom[i];
+      auto replay_it = r.find(target_id);
+      TORCH_CHECK(replay_it != r.end());
+      TORCH_CHECK(
+          replay_it->second == replay_dom[i],
+          "IterDomain mismatch when checking ",
+          replay,
+          " and ",
+          target,
+          " at ",
+          i,
+          ", got ",
+          replay_it->second,
+          " and ",
+          replay_dom[i]);
+    }
+  };
+  std::vector<TensorView*> siblings[] = {
+      {tvs.avg, tvs.var_sum, tvs.n}, {tvs2.avg, tvs2.var_sum, tvs2.n}};
+  for (auto tensors : siblings) {
+    for (auto t1 : tensors) {
+      for (auto t2 : tensors) {
+        checkSiblingConsistency(t1, t2);
+      }
+    }
+  }
+}
+
+TEST_F(NVFuserTest, FusionTransformPropagatePosition_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(4);
+  auto tv1 = makeSymbolicTensor(6);
+  fusion.addInput(tv0);
+
+  auto tv2 = broadcast(tv0, {false, false, true, false, false, true});
+  auto tv3 = add(tv1, tv2);
+  fusion.addOutput(tv3);
+
+  tv0->merge(2);
+  tv0->merge(0);
+  TransformPropagator::from(tv0);
+
+  TORCH_CHECK(tv1->nDims() == 4);
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
