From 115d1f42c148ce95a4f436f9a0e1e0e606186487 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Wed, 28 Sep 2022 18:01:42 -0700
Subject: [PATCH 1/6] Cleanup trivial reduction workarounds

---
 torch/csrc/jit/codegen/cuda/inlining.cpp      | 25 +++---
 torch/csrc/jit/codegen/cuda/inlining.h        | 17 ++--
 torch/csrc/jit/codegen/cuda/ir_nodes.cpp      | 48 ++++-------
 .../cuda/scheduler/reduction_utils.cpp        |  7 +-
 .../csrc/jit/codegen/cuda/scheduler/utils.cpp | 86 +++----------------
 torch/csrc/jit/codegen/cuda/scheduler/utils.h | 12 +--
 torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 64 ++++++++++++++
 7 files changed, 115 insertions(+), 144 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/inlining.cpp b/torch/csrc/jit/codegen/cuda/inlining.cpp
index da6d229c68f8b..621b88b842c0a 100644
--- a/torch/csrc/jit/codegen/cuda/inlining.cpp
+++ b/torch/csrc/jit/codegen/cuda/inlining.cpp
@@ -153,29 +153,26 @@ size_t MaxPosCalculator::getMaxPosAll(
   return max_pos;
 }
 
-void inlineMost(const std::unordered_set<IterDomain*>& uninlinable_ids) {
-  inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids);
+void inlineMost() {
+  inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()));
 }
 
-void inlineMost(
-    const std::vector<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+void inlineMost(const std::vector<TensorView*>& tvs) {
   if (tvs.empty()) {
     return;
   }
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto tv : tvs) {
     tv->inlineAt(-1, true, &calc);
   }
 }
 
 void inlineMost(
-    const std::unordered_set<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+    const std::unordered_set<TensorView*>& tvs) {
   if (tvs.empty()) {
     return;
   }
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto tv : tvs) {
     tv->inlineAt(-1, true, &calc);
   }
 }
@@ -276,10 +273,9 @@ std::unordered_map<TensorView*, size_t> getPositionsMappedTo(
 void inlineAllAt(
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+    bool best_effort) {
   auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto pair : mapped_positions) {
     pair.first->inlineAt(pair.second, best_effort, &calc);
   }
@@ -289,10 +285,9 @@ void inlineSelectedAt(
     const std::unordered_set<TensorView*>& selected,
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+    bool best_effort) {
   auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto pair : mapped_positions) {
     if (selected.count(pair.first) > 0) {
       pair.first->inlineAt(pair.second, best_effort, &calc);
diff --git a/torch/csrc/jit/codegen/cuda/inlining.h b/torch/csrc/jit/codegen/cuda/inlining.h
index 3b15eb23f9877..0a3ce1e8012a1 100644
--- a/torch/csrc/jit/codegen/cuda/inlining.h
+++ b/torch/csrc/jit/codegen/cuda/inlining.h
@@ -64,26 +64,20 @@ class MaxPosCalculator {
 
 // Inline to the right most allowed position for all tensors in the current
 // fusion.
-TORCH_CUDA_CU_API void inlineMost(
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+TORCH_CUDA_CU_API void inlineMost();
 
 // Inline to the right most allowed position for the selected tensors in the
 // current fusion.
-TORCH_CUDA_CU_API void inlineMost(
-    const std::vector<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+TORCH_CUDA_CU_API void inlineMost(const std::vector<TensorView*>& tvs);
 
 // Inline to the right most allowed position for the selected tensors in the
 // current fusion.
-TORCH_CUDA_CU_API void inlineMost(
-    const std::unordered_set<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+TORCH_CUDA_CU_API void inlineMost(const std::unordered_set<TensorView*>& tvs);
 
 // Inline to the position corresponding to the reference position in the
 // reference tensor for all tensors in the current fusion.
 TORCH_CUDA_CU_API void inlineAllAt(
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort = false,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+    bool best_effort = false);
 
 // Inline to the position corresponding to the reference position in the
 // reference tensor for selected tensors in the current fusion.
@@ -91,8 +85,7 @@ TORCH_CUDA_CU_API void inlineSelectedAt(
     const std::unordered_set<TensorView*>& selected,
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort = false,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+    bool best_effort = false);
 
 } // namespace cuda
 } // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index 3319bf28a18a9..9b3fabd1609d4 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -1606,49 +1606,37 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
   TORCH_CHECK(
       !outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(),
       "Merging IterDomains with ending values that are 0 is not supported at this time.");
-  TORCH_CHECK(
-      outer->isReduction() == inner->isReduction() ||
-          (!outer->isReduction() && inner->isTrivialReduction()) ||
-          (outer->isTrivialReduction() && !inner->isReduction()),
-      "Merging IterDomains requires that their iteration types match.");
-  TORCH_CHECK(
-      (outer->isGather() && inner->isGather()) ||
-          (!outer->isGather() && !inner->isGather()),
-      "Merging gather and non-gather domains is not supported.");
-
-  TORCH_CHECK(
-      !outer->isStride() && !inner->isStride(),
-      "No support for merging stride domains");
 
   Val* merged_id_size = mul(outer->extent(), inner->extent());
 
   IterType itype = outer->getIterType();
 
-  if (outer->isBroadcast() && inner->isBroadcast()) {
-    itype = IterType::Broadcast;
+  if (inner->getIterType() == itype) {
+    goto itype_infer_finished;
   }
 
-  if ((outer->isBroadcast() || inner->isBroadcast()) &&
-      (outer->getIterType() == IterType::Iteration ||
-       inner->getIterType() == IterType::Iteration)) {
-    itype = IterType::Iteration;
+  if (inner->isTrivialReduction()) {
+    goto itype_infer_finished;
   }
 
-  // Merging trivial reduction with iter domain, that's fine, just make it an
-  // iter domain.
-  if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
-      (outer->getIterType() == IterType::Iteration ||
-       inner->getIterType() == IterType::Iteration)) {
-    itype = IterType::Iteration;
+  if (outer->isTrivialReduction()) {
+    itype = inner->getIterType();
+    goto itype_infer_finished;
   }
 
-  // Merging trivial reduction with broadcasting, that's fine, just make it a
-  // broadcasting.
-  if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
-      (outer->isBroadcast() || inner->isBroadcast())) {
-    itype = IterType::Broadcast;
+  if (inner->isBroadcast()) {
+    goto itype_infer_finished;
   }
 
+  if (outer->isBroadcast()) {
+    itype = inner->getIterType();
+    goto itype_infer_finished;
+  }
+
+  TORCH_CHECK(
+      false, "Merging IterDomains requires that their iteration types match.");
+
+itype_infer_finished:
   Val* expanded_extent = nullptr;
   if (outer->hasExpandedExtent() || inner->hasExpandedExtent()) {
     if (outer->hasExpandedExtent() && inner->hasExpandedExtent()) {
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
index ae9ecd88bbdc3..f88c34eb3f59a 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
@@ -330,13 +330,8 @@ void multiReductionInliner(
     }
   }
 
-  // Find iter domains that are mapped to a trivial reduction, these should
-  // never be inlined.
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
-      scheduler_utils::getTrivialReductionMap(fusion);
-
   // Inline the schedule
-  inlineMost(mapped_to_trivial_reduction);
+  inlineMost();
 }
 
 namespace {
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
index d985da926354b..e38a899ef7b17 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
@@ -21,26 +21,20 @@ namespace scheduler_utils {
 
 // Returns number of "valid" dimensions. e.g. if tv has
 // [I1, R2, I3, I4, R3{1}]
-// where R3{1} is in dont_merge, resulting domain should be:
-// [I1, I3*I4, R2, R3{1}] with return value 3
+// resulting domain should be:
+// [I1, I3*I4, R2*R3{1}] with return value 3
 //
 // if tv has
 // [R1, I2, R3, I4, R4, R5{1}, R6{1}]
-// where R5{1} and R6{1} are in dont_merge, resulting domain should be:
-// [I2*I4, R1*R3, R4, R5{1}, R6{1}]
+// resulting domain should be:
+// [I2*I4, R1*R3*R5{1}*R6{1}, R4]
 // with return value 3
-size_t merge_3d(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
+size_t merge_3d(TensorView* tv) {
   bool active_is_reduction = false;
   bool first_dim = true;
   int prev_i = -1;
 
   for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (dont_merge.count(tv->axis(i))) {
-      continue;
-    }
-
     if (first_dim) {
       active_is_reduction = tv->axis(i)->isReduction();
       prev_i = i;
@@ -67,10 +61,6 @@
   for (int i = static_cast<int>(tv->nDims()) - 2; i >= 0; i--) {
     auto id = tv->axis(i);
 
-    if (dont_merge.count(id)) {
-      continue;
-    }
-
     if (first_dim) {
       active_is_reduction = id->isReduction();
       prev_i = i;
@@ -96,10 +86,6 @@
   prev_i = -1;
 
   for (int i = static_cast<int>(tv->nDims()) - 3; i >= 0; i--) {
-    if (dont_merge.count(tv->axis(i))) {
-      continue;
-    }
-
     if (first_dim) {
       active_is_reduction = tv->axis(i)->isReduction();
       prev_i = i;
@@ -114,7 +100,7 @@
   if (prev_i == -1) {
     // Two dimensional, put merged dimensions first
     tv->reorder({{-1, 0}, {-2, 1}});
-    // [outer, inner, dont_merge...]
+    // [outer, inner]
     if (tv->axis(0)->isReduction()) {
       // put reductions as second axis
       tv->reorder({{0, 1}, {1, 0}});
@@ -195,13 +181,11 @@ c10::optional<size_t> mergeDims(
   return left;
 }
 
-size_t mergeReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
+size_t mergeReduction(TensorView* tv) {
   int prev_i = -1;
   size_t num_merged = 0;
   for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (!tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
+    if (!tv->axis(i)->isReduction()) {
       continue;
     }
     if (prev_i == -1) {
   return prev_i == -1 ? 0 : num_merged + 1;
 }
 
-size_t mergeNonReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
+size_t mergeNonReduction(TensorView* tv) {
   int prev_i = -1;
   size_t num_merged = 0;
   if (tv->nDims() == 0) {
     return 0;
   }
   for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
+    if (tv->axis(i)->isReduction()) {
       continue;
     }
     if (prev_i == -1) {
@@ -905,63 +887,21 @@ PersistentBufferSizeReturn persistentBufferSize(
   return persistent_buffer_size;
 }
 
-std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion) {
-  auto all_tvs = ir_utils::allTvs(fusion);
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction;
-  for (auto tv : all_tvs) {
-    // root domain vs domain shouldn't matter as at this point we shouldn't have
-    // any transformations.
-    for (auto id : tv->getRootDomain()) {
-      if (id->isTrivialReduction()) {
-        mapped_to_trivial_reduction.emplace(id);
-      }
-    }
-  }
-
-  if (!mapped_to_trivial_reduction.empty()) {
-    // Use the loop map as that is the most permissive
-    auto ca_map = ComputeAtMap(fusion);
-    // Make a copy we need to check mappings of all
-    auto trivial_ids = mapped_to_trivial_reduction;
-    for (auto tv : all_tvs) {
-      for (auto id : tv->getRootDomain()) {
-        if (!id->extent()->isOneInt()) {
-          continue;
-        }
-        if (std::any_of(
-                trivial_ids.begin(),
-                trivial_ids.end(),
-                [&ca_map, &id](IterDomain* trivial_id) {
-                  return ca_map.areMapped(
-                      id, trivial_id, IdMappingMode::PERMISSIVE);
-                })) {
-          mapped_to_trivial_reduction.emplace(id);
-        }
-      }
-    }
-  }
-  return mapped_to_trivial_reduction;
-}
-
 std::pair<bool, bool> canonicalDimReduction(
     Fusion* fusion,
     TensorView* tv,
     bool schedule_3D) {
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
-      getTrivialReductionMap(fusion);
-
   TORCH_INTERNAL_ASSERT(tv != nullptr);
 
   if (!schedule_3D) {
     // We coalesce all reduction axes to the right;
-    bool has_red_axis = mergeReduction(tv, mapped_to_trivial_reduction) > 0;
+    bool has_red_axis = mergeReduction(tv) > 0;
 
-    bool has_iter_axis = mergeNonReduction(tv, mapped_to_trivial_reduction) > 0;
+    bool has_iter_axis = mergeNonReduction(tv) > 0;
     return {has_iter_axis, has_red_axis};
   } else {
     TORCH_INTERNAL_ASSERT(
-        merge_3d(tv, mapped_to_trivial_reduction) == 3,
-        "Tried 3D merge, but result is not 3D.");
+        merge_3d(tv) == 3, "Tried 3D merge, but result is not 3D.");
     return {true, true};
   }
 }
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.h b/torch/csrc/jit/codegen/cuda/scheduler/utils.h
index 373a879f740d5..b5dbe162f0e93 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/utils.h
+++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.h
@@ -78,16 +78,12 @@ TORCH_CUDA_CU_API inline c10::optional<size_t> mergeDims(
 }
 
 // Merge all reduction to the right side and returns total number of
-// reduction axes. Don't merge is typically used for trivial reductions.
-size_t mergeReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge = {});
+// reduction axes.
+size_t mergeReduction(TensorView* tv);
 
 // merge all non-reduction axes to the left side and returns total number of
-// iteration axes. Don't merge is typically used for trivial reductions.
-size_t mergeNonReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge = {});
+// iteration axes.
+size_t mergeNonReduction(TensorView* tv);
 
 // Propagate the parallelization from the selected dimensions of the reference
 // tensor to their corresponding dimensions in all selected tensors in the DAG.
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index ee5e55bd592e1..9c1b013ff4867 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -7853,6 +7853,70 @@ TEST_F(NVFuserTest, FusionReductionScheduler_CUDA) {
       lparams);
 }
 
+TEST_F(NVFuserTest, FusionReductionWithTrivial1_CUDA) {
+  constexpr int bid_x = 80;
+  constexpr int tid_x = 4096;
+
+  std::vector<std::vector<int64_t>> shapes = {
+      {-1, -1, 1}, {-1, 1, -1}, {1, -1, -1}};
+
+  for (auto shape : shapes) {
+    std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+    Fusion& fusion = *fusion_ptr;
+    FusionGuard fg(&fusion);
+
+    std::vector<std::vector<int64_t>> reduction_dims = {
+        {0},
+        {1},
+        {2},
+        {0, 1},
+        {0, 2},
+        {1, 2},
+        {0, 1, 2},
+    };
+
+    // Set up your input tensor views
+    TensorView* tv0 = makeConcreteTensor(shape);
+    fusion.addInput(tv0);
+
+    for (auto rdims : reduction_dims) {
+      std::vector<int> rdims_(rdims.begin(), rdims.end());
+      auto tv = sum(tv0, rdims_);
+      fusion.addOutput(tv);
+    }
+
+    const auto options =
+        at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+    auto concrete_shape = shape;
+    std::deque<int64_t> concrete_values = {bid_x, tid_x};
+    for (auto& s : concrete_shape) {
+      if (s == -1) {
+        s = concrete_values.front();
+        concrete_values.pop_front();
+      }
+    }
+
+    at::Tensor aten_input = at::randn(concrete_shape, options);
+    std::vector<at::Tensor> aten_outputs;
+    for (auto rdims : reduction_dims) {
+      aten_outputs.push_back(aten_input.sum(rdims));
+    }
+
+    FusionExecutorCache executor_cache(std::move(fusion_ptr));
+    auto cg_outputs = executor_cache.runFusionWithInputs({aten_input});
+
+    testValidate(
+        &fusion,
+        cg_outputs,
+        {aten_input},
+        aten_outputs,
+        __LINE__,
+        __FILE__,
+        "");
+  }
+}
+
 // Simple reduction parallelized on a symbolic size.
 TEST_F(NVFuserTest, FusionSymbolicReduction_CUDA) {
   Fusion fusion;

From 1efe33cdf3574f478b04db86140c92802c9bc68b Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Wed, 28 Sep 2022 18:05:47 -0700
Subject: [PATCH 2/6] comment

---
 torch/csrc/jit/codegen/cuda/scheduler/utils.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
index e38a899ef7b17..036e9c920824a 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
@@ -27,7 +27,7 @@ namespace scheduler_utils {
 // if tv has
 // [R1, I2, R3, I4, R4, R5{1}, R6{1}]
 // resulting domain should be:
-// [I2*I4, R1*R3*R5{1}*R6{1}, R4]
+// [I2*I4, R1*R3, R4*R5{1}*R6{1}]
 // with return value 3
 size_t merge_3d(TensorView* tv) {
   bool active_is_reduction = false;

From bb031cc5412014c48461478af4d16a1d9801be19 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Wed, 28 Sep 2022 18:06:36 -0700
Subject: [PATCH 3/6] rename

---
 torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index 9c1b013ff4867..92aeea64a9bdb 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -7853,7 +7853,7 @@ TEST_F(NVFuserTest, FusionReductionScheduler_CUDA) {
       lparams);
 }
 
-TEST_F(NVFuserTest, FusionReductionWithTrivial1_CUDA) {
+TEST_F(NVFuserTest, FusionReductionWithTrivial_CUDA) {
   constexpr int bid_x = 80;
   constexpr int tid_x = 4096;
 

From 050b14aaf24f89831f8e8443cfca72c8f372bed5 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Wed, 28 Sep 2022 20:39:49 -0700
Subject: [PATCH 4/6] lint

---
 torch/csrc/jit/codegen/cuda/inlining.cpp      | 3 +--
 torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/inlining.cpp b/torch/csrc/jit/codegen/cuda/inlining.cpp
index 621b88b842c0a..eb2c4b3fb5db5 100644
--- a/torch/csrc/jit/codegen/cuda/inlining.cpp
+++ b/torch/csrc/jit/codegen/cuda/inlining.cpp
@@ -167,8 +167,7 @@ void inlineMost(const std::vector<TensorView*>& tvs) {
   }
 }
 
-void inlineMost(
-    const std::unordered_set<TensorView*>& tvs) {
+void inlineMost(const std::unordered_set<TensorView*>& tvs) {
   if (tvs.empty()) {
     return;
   }
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index 92aeea64a9bdb..adf79ee121541 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -7902,7 +7902,7 @@ TEST_F(NVFuserTest, FusionReductionWithTrivial_CUDA) {
     for (auto rdims : reduction_dims) {
      aten_outputs.push_back(aten_input.sum(rdims));
     }
-
+
     FusionExecutorCache executor_cache(std::move(fusion_ptr));
     auto cg_outputs = executor_cache.runFusionWithInputs({aten_input});

From 235e78b5343f6d13ed4d33f51925a32a39cae891 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Thu, 29 Sep 2022 19:10:18 -0700
Subject: [PATCH 5/6] no goto

---
 torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 65 ++++++++++++++----------
 1 file changed, 38 insertions(+), 27 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index 9b3fabd1609d4..410f008d59cb2 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -1597,6 +1597,43 @@ std::vector<IterDomain*> IterDomain::clone(
   return cloned_domains;
 }
 
+IterType inferIterType(IterDomain* i1, IterDomain* i2) {
+  // The iter type inference is a pattern matching of the rules below:
+  //
+  // X + X = X
+  // trivial reduction + X = X
+  // X + trivial reduction = X
+  // broadcasting + X = X
+  // X + broadcasting = X
+  // fail
+  //
+  // The rules are applied one by one in order. For each rule, we test if the
+  // given (outer, inner) matches the pattern. If it does, then we stop
+  // proceeding and get a result. If we reach the end without finding any
+  // matching pattern, the merge is invalid and an error is reported.
+  //
+  // Note that based on the above rules:
+  // broadcasting + (non-trivial) reduction = reduction
+  // broadcasting + trivial reduction = broadcasting
+  if (i1->getIterType() == i2->getIterType()) {
+    return i1->getIterType();
+  }
+  if (i1->isTrivialReduction()) {
+    return i2->getIterType();
+  }
+  if (i2->isTrivialReduction()) {
+    return i1->getIterType();
+  }
+  if (i1->isBroadcast()) {
+    return i2->getIterType();
+  }
+  if (i2->isBroadcast()) {
+    return i1->getIterType();
+  }
+  TORCH_CHECK(
+      false, "Merging IterDomains requires that their iteration types match.");
+}
+
 // Merging does not propagate the start and stop values of the input
 // domains to the merged output domain. The actual range of the
 // domains is enforced by predicates. Note that since only root
@@ -1609,34 +1646,8 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
 
   Val* merged_id_size = mul(outer->extent(), inner->extent());
 
-  IterType itype = outer->getIterType();
-
-  if (inner->getIterType() == itype) {
-    goto itype_infer_finished;
-  }
-
-  if (inner->isTrivialReduction()) {
-    goto itype_infer_finished;
-  }
-
-  if (outer->isTrivialReduction()) {
-    itype = inner->getIterType();
-    goto itype_infer_finished;
-  }
-
-  if (inner->isBroadcast()) {
-    goto itype_infer_finished;
-  }
-
-  if (outer->isBroadcast()) {
-    itype = inner->getIterType();
-    goto itype_infer_finished;
-  }
-
-  TORCH_CHECK(
-      false, "Merging IterDomains requires that their iteration types match.");
+  IterType itype = inferIterType(outer, inner);
 
-itype_infer_finished:
   Val* expanded_extent = nullptr;
   if (outer->hasExpandedExtent() || inner->hasExpandedExtent()) {
     if (outer->hasExpandedExtent() && inner->hasExpandedExtent()) {

From 1229502cf18e1e7803e2501c1f92689cdc4b6272 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Thu, 29 Sep 2022 19:13:59 -0700
Subject: [PATCH 6/6] comment

---
 torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index adf79ee121541..8711154c9e732 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -7853,6 +7853,10 @@ TEST_F(NVFuserTest, FusionReductionScheduler_CUDA) {
       lparams);
 }
 
+// This test checks that the system correctly handles the case where both
+// normal and trivial reductions exist in the fusion. Trivial reductions
+// deserve testing because a trivial reduction is handled more like a
+// broadcast than a reduction.
 TEST_F(NVFuserTest, FusionReductionWithTrivial_CUDA) {
   constexpr int bid_x = 80;
   constexpr int tid_x = 4096;
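
Reviewer note: the C++ sketch below is illustrative only and is not part of the patches. It reuses helpers already seen in PATCH 1/6 (makeConcreteTensor, sum, inlineMost, TensorView::merge); the concrete shape and the merge are assumptions chosen to walk through the inferIterType() rules from PATCH 5/6, not code from this series.

// Minimal sketch: why the dont_merge / uninlinable_ids plumbing can go away.
// Axis 2 below has constant extent 1, so reducing over it yields a trivial
// reduction domain R2{1} (it iterates exactly once).
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = makeConcreteTensor({-1, -1, 1}); // [I0, I1, I2{1}]
fusion.addInput(tv0);

TensorView* tv1 = sum(tv0, {2}); // root domain [I0, I1, R2{1}]
fusion.addOutput(tv1);

// Rule walk-through for merging I1 with R2{1}: the two types differ
// (Iteration vs. Reduction), I1 is not a trivial reduction, R2{1} is,
// so "X + trivial reduction = X" applies and the merged domain keeps
// I1's Iteration type.
tv1->merge(1, 2); // [I0, I1*R2{1}]

// Schedulers therefore no longer need getTrivialReductionMap() to keep
// R2{1} out of mergeReduction()/mergeNonReduction(), nor an
// uninlinable_ids set when inlining; the plain call suffices:
inlineMost();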