@@ -23574,6 +23574,96 @@ TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) {
23574
23574
executor_cache.fusion(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
23575
23575
}
23576
23576
23577
+ TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) {
23578
+ // https://github.com/csarofeen/pytorch/issues/1760
23579
+ Fusion fusion;
23580
+ FusionGuard fg(&fusion);
23581
+
23582
+ auto tv0 = makeSymbolicTensor(2);
23583
+ fusion.addInput(tv0);
23584
+
23585
+ auto tvs = Welford(tv0, {1});
23586
+ fusion.addOutput(tvs.var_sum);
23587
+
23588
+ tvs.avg->split(1, 1);
23589
+ tvs.avg->split(1, 2);
23590
+ tvs.avg->split(1, 3);
23591
+ tvs.var_sum->split(1, 1);
23592
+ tvs.var_sum->split(1, 2);
23593
+ tvs.var_sum->split(1, 3);
23594
+ tvs.n->split(1, 1);
23595
+ tvs.n->split(1, 2);
23596
+ tvs.n->split(1, 3);
23597
+
23598
+ auto tvs2 = tvs.rFactor({1, 4});
23599
+
23600
+ TransformPropagator::from(tvs2.var_sum);
23601
+
23602
+ // check that the resulting tensors in tvs2 are identical
23603
+ auto checkSiblingConsistency = [](TensorView* replay, TensorView* target) {
23604
+ auto replay_root = replay->getRootDomain();
23605
+ auto replay_dom = replay->domain()->domain();
23606
+ auto target_root = target->getRootDomain();
23607
+ auto target_dom = target->domain()->domain();
23608
+ std::unordered_map<IterDomain*, IterDomain*> target2replay_map;
23609
+ TORCH_CHECK(replay_root.size() == target_root.size());
23610
+ target2replay_map.reserve(replay_root.size());
23611
+ std::transform(
23612
+ target_root.begin(),
23613
+ target_root.end(),
23614
+ replay_root.begin(),
23615
+ std::inserter(target2replay_map, target2replay_map.begin()),
23616
+ [](auto a, auto b) { return std::make_pair(a, b); });
23617
+ BestEffortReplay replay_(replay_dom, target_dom, target2replay_map);
23618
+ auto r = replay_.getReplay();
23619
+ for (int64_t i = 0; i < replay_dom.size(); i++) {
23620
+ auto target_id = target_dom[i];
23621
+ auto replay_it = r.find(target_id);
23622
+ TORCH_CHECK(replay_it != r.end());
23623
+ TORCH_CHECK(
23624
+ replay_it->second == replay_dom[i],
23625
+ "IterDomain mismatch when checking ",
23626
+ replay,
23627
+ " and ",
23628
+ target,
23629
+ " at ",
23630
+ i,
23631
+ ", got ",
23632
+ replay_it->second,
23633
+ " and ",
23634
+ replay_dom[i]);
23635
+ }
23636
+ };
23637
+ std::vector<TensorView*> siblings[] = {
23638
+ {tvs.avg, tvs.var_sum, tvs.n}, {tvs2.avg, tvs2.var_sum, tvs2.n}};
23639
+ for (auto tensors : siblings) {
23640
+ for (auto t1 : tensors) {
23641
+ for (auto t2 : tensors) {
23642
+ checkSiblingConsistency(t1, t2);
23643
+ }
23644
+ }
23645
+ }
23646
+ }
23647
+
23648
+ TEST_F(NVFuserTest, FusionTransformPropagatePosition_CUDA) {
23649
+ Fusion fusion;
23650
+ FusionGuard fg(&fusion);
23651
+
23652
+ auto tv0 = makeSymbolicTensor(4);
23653
+ auto tv1 = makeSymbolicTensor(6);
23654
+ fusion.addInput(tv0);
23655
+
23656
+ auto tv2 = broadcast(tv0, {false, false, true, false, false, true});
23657
+ auto tv3 = add(tv1, tv2);
23658
+ fusion.addOutput(tv3);
23659
+
23660
+ tv0->merge(2);
23661
+ tv0->merge(0);
23662
+ TransformPropagator::from(tv0);
23663
+
23664
+ TORCH_CHECK(tv1->nDims() == 4);
23665
+ }
23666
+
23577
23667
} // namespace jit
23578
23668
} // namespace torch
23579
23669
#endif // #if defined(USE_CUDA)
0 commit comments