Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

View reduction devel merge #2042

Merged
merged 8 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion test/cpp/jit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ if(USE_CUDA)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_definition.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_cache.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_record.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu1.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu2.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu3.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp)
Expand Down
12 changes: 2 additions & 10 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1416,16 +1416,8 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
}

//! Check if IterDomain is a reduction axis with size of 1, i.e.
//! a "squeeze" operator.
//!
//! NOTE: Detection of trivial reduction here is not
//! comprehensive. See detectTrivialReductionDerivedDomains for more
//! comprehensive analysis. We typically use this for root domain trivial
//! reduction checks. So we ship to the correct scheduler. It may
//! not be incredibly robust, but it makes sense to keep it for now.
bool isTrivialReduction() const {
return isReduction() && extent()->isOneInt();
}
//! a "squeeze" operator, or solely derived from such axes.
bool isTrivialReduction() const;

//! Split for stride by a given factor. It effectively does an inner
//! split by the factor and sets the inner domain as a Stride
Expand Down
37 changes: 36 additions & 1 deletion torch/csrc/jit/codegen/cuda/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,37 @@ IterDomain* IterDomain::cloneWithoutRFactor() const {
return cloned;
}

bool IterDomain::isTrivialReduction() const {
if (!isReduction()) {
return false;
}

if (extent()->isOneInt()) {
return true;
}

// If this domain is an output of an expression, i.e., not a root
// domain, check if all root domains are trivial reductions. This is
// almost the same as the analysis done in TrivialReductionInfo, but
// is limited within a single tensor, whereas TrivialReductionInfo
// does more expensive analysis potentially traversing through
// rfactor domains
if (definition()) {
// Note: There's no const version of IterVisitor.
auto id_inputs = InputsOf::output(fusion(), const_cast<IterDomain*>(this));
if (std::all_of(
ir_utils::filterByType<IterDomain>(id_inputs).begin(),
ir_utils::filterByType<IterDomain>(id_inputs).end(),
[](IterDomain* root_id) {
return root_id->isReduction() && root_id->extent()->isOneInt();
})) {
return true;
}
}

return false;
}

std::vector<IterDomain*> IterDomain::clone(
const std::vector<IterDomain*>& domains) {
std::vector<IterDomain*> cloned_domains;
Expand All @@ -1744,7 +1775,11 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
outer->isReduction() == inner->isReduction() ||
(!outer->isReduction() && inner->isTrivialReduction()) ||
(outer->isTrivialReduction() && !inner->isReduction()),
"Merging IterDomains requires that their iteration types match.");
"Merging IterDomains requires that their iteration types match. ",
"Outer: ",
outer->toString(),
", Inner: ",
inner->toString());
TORCH_CHECK(
(outer->isGather() && inner->isGather()) ||
(!outer->isGather() && !inner->isGather()),
Expand Down
Loading