From b7a206e93b4ac823c791c87f12859cf7af264a4c Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Fri, 9 Sep 2022 20:48:51 -0400 Subject: [PATCH] Move scheduler vectorize utilities into their own file (#1959) --- build_variables.bzl | 1 + .../codegen/cuda/scheduler/normalization.cpp | 2 +- .../jit/codegen/cuda/scheduler/pointwise.cpp | 2 +- .../jit/codegen/cuda/scheduler/reduction.cpp | 2 +- .../csrc/jit/codegen/cuda/scheduler/utils.cpp | 260 ---------------- .../cuda/scheduler/vectorize_helper.cpp | 283 ++++++++++++++++++ .../codegen/cuda/scheduler/vectorize_helper.h | 14 +- 7 files changed, 291 insertions(+), 273 deletions(-) create mode 100644 torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp diff --git a/build_variables.bzl b/build_variables.bzl index 37840ae557274..eb058241445ec 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -730,6 +730,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp", "torch/csrc/jit/codegen/cuda/scheduler/registry.cpp", "torch/csrc/jit/codegen/cuda/scheduler/utils.cpp", + "torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp", "torch/csrc/jit/codegen/cuda/type_inference.cpp", "torch/csrc/jit/codegen/cuda/type_promotion.cpp", "torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp", diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp index 0ab3b4676e9e1..459974b8d2884 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp @@ -909,7 +909,7 @@ TORCH_CUDA_CU_API std::shared_ptr getPersistentHeuristics( } // Try expanding vectorization to contig merged domains - vectorize_factor = scheduler_utils::expandVectorizationToContigMergedDomains( + vectorize_factor = vectorize_helper::expandVectorizationToContigMergedDomains( fusion, runtime_info, vectorizable_inputs_outputs, diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index 936cd041c0d30..bd887a9a1754a 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -344,7 +344,7 @@ std::shared_ptr getPointwiseHeuristics( // TODO: This is an expensive function that shouldn't be in heuristics without // caching. 
auto expanded_vector_word_size = - scheduler_utils::expandVectorizationToContigMergedDomains( + vectorize_helper::expandVectorizationToContigMergedDomains( fusion, runtime_info, vectorizable_inputs_outputs, diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp index e6ef08b3c2568..3037f8469dad4 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp @@ -954,7 +954,7 @@ TORCH_CUDA_CU_API std::shared_ptr getReductionHeuristics( } // Try expanding vectorization to contig merged domains - vectorize_factor = scheduler_utils::expandVectorizationToContigMergedDomains( + vectorize_factor = vectorize_helper::expandVectorizationToContigMergedDomains( fusion, runtime_info, vectorizable_inputs_outputs, diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index b8680d3aa8dff..d985da926354b 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -1620,89 +1620,6 @@ BroadcastMultipleInformation getBroadcastMultiples( return bcast_info; } -size_t collectMaxVectorizeSizeWithContigMerge( - TensorView* tv, - IterDomain* leaf_merged_domain, - size_t max_vector_size_in_byte, - ExpressionEvaluator& expression_evaluator, - DataType index_type) { - // Maybe too conservative, but only handles fully contiguous tensors - // TODO: Relax the contiguity constraint to be similar to that in index - // computing. Just looking for all merged root domains in the right order, - // all merged root dimensions are contiguous, all merged root dimensions are - // next to eachother (exlcuding broadcast). - if (std::any_of( - tv->domain()->contiguity().begin(), - tv->domain()->contiguity().end(), - [](const auto contig) { return !contig; })) { - return 1; - } - - auto dtype_size = dataTypeSize(tv->dtype(), index_type); - const size_t max_vector_size = max_vector_size_in_byte / dtype_size; - - // Assume no halo-related expression appears in the fusion. No - // broadcast is merged, so indexability can be assumed to be true. - ContigIDs contigIds( - {leaf_merged_domain}, - tv->getMaybeRFactorDomain(), - tv->domain()->contiguity(), - {}, - {}, - true, - true); - - auto innermost_root_id = tv->getMaybeRFactorDomain().back(); - auto indexed_id = contigIds.rootToIndexedID().at(innermost_root_id); - - size_t merged_size = 1; - // If the indexed ID is a contig merged domain, i.e., it is - // different from innermost_root_id, we accumulate the extents of - // all the root domains covered by the contig indexed ID. Otherwise, - // just look at the extent of the innermost root ID. 
- if (indexed_id != innermost_root_id) { - const auto& within_root = contigIds.withinContigIDs().at(indexed_id); - for (auto root_id : tv->getMaybeRFactorDomain()) { - if (within_root.find(root_id) == within_root.end()) { - continue; - } - auto maybe_dimension_size = - expression_evaluator.evaluate(root_id->extent()); - TORCH_INTERNAL_ASSERT( - maybe_dimension_size.has_value(), - "Unknown extent of tv: ", - tv->toString(), - ", id: ", - root_id->toString()); - merged_size *= maybe_dimension_size->as(); - } - } else { - auto maybe_dimension_size = - expression_evaluator.evaluate(innermost_root_id->extent()); - TORCH_INTERNAL_ASSERT( - maybe_dimension_size.has_value(), - "Unknown extent of tv: ", - tv->toString(), - ", id: ", - innermost_root_id->toString()); - merged_size = maybe_dimension_size->as(); - } - - size_t vector_size = 1; - size_t next_vector_size = vector_size * 2; - - // Try until vector size exceeds the max allowed size - while (next_vector_size <= max_vector_size) { - if (merged_size % next_vector_size != 0) { - break; - } - vector_size = next_vector_size; - next_vector_size *= 2; - } - - return vector_size; -} - namespace matmul_utils { void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) { @@ -2260,183 +2177,6 @@ void BoundedDirectionalTransformPropagator::bothWays( propagate(from, pos, included_tvs, *options); } -// Grab all values and expressions used to make the merged_domain and remove -// them from the fusion -void cleanUpInnermostMergedDomains( - const std::vector& root_domain, - IterDomain* merged_domain) { - TORCH_INTERNAL_ASSERT(merged_domain != nullptr); - TORCH_INTERNAL_ASSERT(!root_domain.empty()); - - std::unordered_set root_set({root_domain.begin(), root_domain.end()}); - - auto vals = DependencyCheck::getAllValsBetween(root_set, {merged_domain}); - - for (auto it = vals.rbegin(); it != vals.rend(); ++it) { - TORCH_INTERNAL_ASSERT((*it)->isA()); - auto id = (*it)->as(); - if (root_set.find(id) != root_set.end()) { - continue; - } - Fusion* fusion = id->container()->as(); - auto id_def = id->definition(); - TORCH_INTERNAL_ASSERT( - id_def->isA(), - "Invalid ID: ", - id->toString(), - ". Expected definition of a Merge expression: ", - (id_def != nullptr ? id_def->toString() : "nullptr")); - fusion->removeExpr(id_def); - fusion->removeVal(id); - } -} - -// Merge innermost domains for finding the widest vectorizable -// size. Return the merged domain or nullptr if no merge is done. -IterDomain* mergeInnermostDomains( - const std::vector& domain, - int num_merged_domains) { - const auto ndims = domain.size(); - IterDomain* merged_id = nullptr; - bool is_merge_done = false; - for (const auto i : c10::irange(num_merged_domains)) { - auto id = domain.at(ndims - 1 - i); - // broadcast and trivial reductions are ignored - if (id->isBroadcast() || id->isTrivialReduction()) { - continue; - } - if (merged_id == nullptr) { - merged_id = id; - } else { - auto id_inner = merged_id; - auto id_outer = id; - merged_id = IterDomain::merge(id_outer, id_inner); - is_merge_done = true; - } - } - return is_merge_done ? merged_id : nullptr; -} - -//! Attempt to expand vectorized domains to contig merged domains. Break point -//! identifies the point in which you can't propagate contiguous merges. For -//! example in pointwise this is the point where we want to split the -//! parallelization to take advantage of broadcast, and for reduction -//! schedulers it's the point where we switch from a reduction domain to an -//! iter domain (or vice versa). 
-size_t expandVectorizationToContigMergedDomains( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - const std::vector vectorizable_inputs_outputs, - TensorView* reference_tv, - int break_point, - size_t default_word_size) { - size_t max_expand_size = SchedulerRuntimeInfo::max_alignment_size_in_byte; - size_t common_alignment_size = - SchedulerRuntimeInfo::max_alignment_size_in_byte; - - for (auto inp_out : vectorizable_inputs_outputs) { - auto dtype_size = dataTypeSize( - inp_out->dtype(), indexModeToDtype(runtime_info.getIndexMode())); - - max_expand_size = std::min( - max_expand_size, - SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_size); - max_expand_size = std::min( - max_expand_size, runtime_info.getMaxVectorizableWidth(inp_out)); - common_alignment_size = - std::min(common_alignment_size, runtime_info.getAlignmentSize(inp_out)); - } - - // If there's no possibility to increase vector size of provided tensors, - // then don't bother doing a more complex analysis to try and do so, just - // return early. - if (max_expand_size == default_word_size) { - return default_word_size; - } - - auto ca_map = ComputeAtMap(fusion); - - // Merge the domains right of the break point - const auto& ref_root = reference_tv->getMaybeRFactorDomain(); - const int num_merged_domains = - static_cast(ref_root.size()) - static_cast(break_point); - - // No expansion with no merged domain - if (num_merged_domains == 0) { - return default_word_size; - } - - // Merge the domains but don't modify TensorDomain - auto merged_domain = mergeInnermostDomains(ref_root, num_merged_domains); - - // No expansion is done if no merge is done. - if (merged_domain == nullptr) { - return default_word_size; - } - - // Find the vectorizable word size with the merged domains - size_t word_size = scheduler_utils::collectMaxVectorizeSizeWithContigMerge( - reference_tv, - merged_domain, - common_alignment_size, - runtime_info.expressionEvaluator(), - indexModeToDtype(runtime_info.getIndexMode())); - - cleanUpInnermostMergedDomains(ref_root, merged_domain); - - // Stop if the reference doesn't get a larger word size. - if (word_size <= default_word_size) { - return default_word_size; - } - - // Check the other TVs and take the minimum of the valid word sizes - for (const auto tv : vectorizable_inputs_outputs) { - if (tv == reference_tv) { - continue; - } - - const auto& tv_root = tv->getMaybeRFactorDomain(); - - int tv_num_merged_domains = 0; - for (const auto i : c10::irange(num_merged_domains)) { - if (i == tv_root.size()) { - break; - } - auto ref_id = ref_root.at(ref_root.size() - 1 - i); - IterDomain* tv_id = tv_root.at(tv_root.size() - 1 - i); - // If not mapped, stop expanding. 
- if (!ca_map.areMapped(ref_id, tv_id, IdMappingMode::EXACT)) { - break; - } else { - ++tv_num_merged_domains; - } - } - - size_t tv_word_size = 1; - if (tv_num_merged_domains > 1) { - auto tv_merged_domain = - mergeInnermostDomains(tv_root, tv_num_merged_domains); - if (tv_merged_domain == nullptr) { - tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv); - } else { - tv_word_size = scheduler_utils::collectMaxVectorizeSizeWithContigMerge( - tv, - tv_merged_domain, - common_alignment_size, - runtime_info.expressionEvaluator(), - indexModeToDtype(runtime_info.getIndexMode())); - cleanUpInnermostMergedDomains(tv_root, tv_merged_domain); - } - } else { - tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv); - } - - word_size = std::min(word_size, tv_word_size); - } - - return word_size; -} - DisjointSets disjointViewSets(Fusion* fusion) { // Start from the exact iter domain graph of the fusion IterDomainGraph id_graph(fusion); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp new file mode 100644 index 0000000000000..11a207f2ac91a --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp @@ -0,0 +1,283 @@ +#include + +#include +#include +#include +#include +#include + +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { +namespace vectorize_helper { + +// Grab all values and expressions used to make the merged_domain and remove +// them from the fusion +void cleanUpInnermostMergedDomains( + const std::vector& root_domain, + IterDomain* merged_domain) { + TORCH_INTERNAL_ASSERT(merged_domain != nullptr); + TORCH_INTERNAL_ASSERT(!root_domain.empty()); + + std::unordered_set root_set({root_domain.begin(), root_domain.end()}); + + auto vals = DependencyCheck::getAllValsBetween(root_set, {merged_domain}); + + for (auto it = vals.rbegin(); it != vals.rend(); ++it) { + TORCH_INTERNAL_ASSERT((*it)->isA()); + auto id = (*it)->as(); + if (root_set.find(id) != root_set.end()) { + continue; + } + Fusion* fusion = id->container()->as(); + auto id_def = id->definition(); + TORCH_INTERNAL_ASSERT( + id_def->isA(), + "Invalid ID: ", + id->toString(), + ". Expected definition of a Merge expression: ", + (id_def != nullptr ? id_def->toString() : "nullptr")); + fusion->removeExpr(id_def); + fusion->removeVal(id); + } +} + +// Merge innermost domains for finding the widest vectorizable +// size. Return the merged domain or nullptr if no merge is done. +IterDomain* mergeInnermostDomains( + const std::vector& domain, + int num_merged_domains) { + const auto ndims = domain.size(); + IterDomain* merged_id = nullptr; + bool is_merge_done = false; + for (const auto i : c10::irange(num_merged_domains)) { + auto id = domain.at(ndims - 1 - i); + // broadcast and trivial reductions are ignored + if (id->isBroadcast() || id->isTrivialReduction()) { + continue; + } + if (merged_id == nullptr) { + merged_id = id; + } else { + auto id_inner = merged_id; + auto id_outer = id; + merged_id = IterDomain::merge(id_outer, id_inner); + is_merge_done = true; + } + } + return is_merge_done ? merged_id : nullptr; +} + +size_t collectMaxVectorizeSizeWithContigMerge( + TensorView* tv, + IterDomain* leaf_merged_domain, + size_t max_vector_size_in_byte, + ExpressionEvaluator& expression_evaluator, + DataType index_type) { + // Maybe too conservative, but only handles fully contiguous tensors + // TODO: Relax the contiguity constraint to be similar to that in index + // computing. 
Just looking for all merged root domains in the right order,
+  // all merged root dimensions are contiguous, all merged root dimensions are
+  // next to each other (excluding broadcast).
+  if (std::any_of(
+          tv->domain()->contiguity().begin(),
+          tv->domain()->contiguity().end(),
+          [](const auto contig) { return !contig; })) {
+    return 1;
+  }
+
+  auto dtype_size = dataTypeSize(tv->dtype(), index_type);
+  const size_t max_vector_size = max_vector_size_in_byte / dtype_size;
+
+  // Assume no halo-related expression appears in the fusion. No
+  // broadcast is merged, so indexability can be assumed to be true.
+  ContigIDs contigIds(
+      {leaf_merged_domain},
+      tv->getMaybeRFactorDomain(),
+      tv->domain()->contiguity(),
+      {},
+      {},
+      true,
+      true);
+
+  auto innermost_root_id = tv->getMaybeRFactorDomain().back();
+  auto indexed_id = contigIds.rootToIndexedID().at(innermost_root_id);
+
+  size_t merged_size = 1;
+  // If the indexed ID is a contig merged domain, i.e., it is
+  // different from innermost_root_id, we accumulate the extents of
+  // all the root domains covered by the contig indexed ID. Otherwise,
+  // just look at the extent of the innermost root ID.
+  if (indexed_id != innermost_root_id) {
+    const auto& within_root = contigIds.withinContigIDs().at(indexed_id);
+    for (auto root_id : tv->getMaybeRFactorDomain()) {
+      if (within_root.find(root_id) == within_root.end()) {
+        continue;
+      }
+      auto maybe_dimension_size =
+          expression_evaluator.evaluate(root_id->extent());
+      TORCH_INTERNAL_ASSERT(
+          maybe_dimension_size.has_value(),
+          "Unknown extent of tv: ",
+          tv->toString(),
+          ", id: ",
+          root_id->toString());
+      merged_size *= maybe_dimension_size->as();
+    }
+  } else {
+    auto maybe_dimension_size =
+        expression_evaluator.evaluate(innermost_root_id->extent());
+    TORCH_INTERNAL_ASSERT(
+        maybe_dimension_size.has_value(),
+        "Unknown extent of tv: ",
+        tv->toString(),
+        ", id: ",
+        innermost_root_id->toString());
+    merged_size = maybe_dimension_size->as();
+  }
+
+  size_t vector_size = 1;
+  size_t next_vector_size = vector_size * 2;
+
+  // Try until vector size exceeds the max allowed size
+  while (next_vector_size <= max_vector_size) {
+    if (merged_size % next_vector_size != 0) {
+      break;
+    }
+    vector_size = next_vector_size;
+    next_vector_size *= 2;
+  }
+
+  return vector_size;
+}
+
+//! Attempt to expand vectorized domains to contig merged domains. Break point
+//! identifies the point in which you can't propagate contiguous merges. For
+//! example in pointwise this is the point where we want to split the
+//! parallelization to take advantage of broadcast, and for reduction
+//! schedulers it's the point where we switch from a reduction domain to an
+//! iter domain (or vice versa).
+size_t expandVectorizationToContigMergedDomains( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + const std::vector vectorizable_inputs_outputs, + TensorView* reference_tv, + int break_point, + size_t default_word_size) { + size_t max_expand_size = SchedulerRuntimeInfo::max_alignment_size_in_byte; + size_t common_alignment_size = + SchedulerRuntimeInfo::max_alignment_size_in_byte; + + for (auto inp_out : vectorizable_inputs_outputs) { + auto dtype_size = dataTypeSize( + inp_out->dtype(), indexModeToDtype(runtime_info.getIndexMode())); + + max_expand_size = std::min( + max_expand_size, + SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_size); + max_expand_size = std::min( + max_expand_size, runtime_info.getMaxVectorizableWidth(inp_out)); + common_alignment_size = + std::min(common_alignment_size, runtime_info.getAlignmentSize(inp_out)); + } + + // If there's no possibility to increase vector size of provided tensors, + // then don't bother doing a more complex analysis to try and do so, just + // return early. + if (max_expand_size == default_word_size) { + return default_word_size; + } + + auto ca_map = ComputeAtMap(fusion); + + // Merge the domains right of the break point + const auto& ref_root = reference_tv->getMaybeRFactorDomain(); + const int num_merged_domains = + static_cast(ref_root.size()) - static_cast(break_point); + + // No expansion with no merged domain + if (num_merged_domains == 0) { + return default_word_size; + } + + // Merge the domains but don't modify TensorDomain + auto merged_domain = mergeInnermostDomains(ref_root, num_merged_domains); + + // No expansion is done if no merge is done. + if (merged_domain == nullptr) { + return default_word_size; + } + + // Find the vectorizable word size with the merged domains + size_t word_size = collectMaxVectorizeSizeWithContigMerge( + reference_tv, + merged_domain, + common_alignment_size, + runtime_info.expressionEvaluator(), + indexModeToDtype(runtime_info.getIndexMode())); + + cleanUpInnermostMergedDomains(ref_root, merged_domain); + + // Stop if the reference doesn't get a larger word size. + if (word_size <= default_word_size) { + return default_word_size; + } + + // Check the other TVs and take the minimum of the valid word sizes + for (const auto tv : vectorizable_inputs_outputs) { + if (tv == reference_tv) { + continue; + } + + const auto& tv_root = tv->getMaybeRFactorDomain(); + + int tv_num_merged_domains = 0; + for (const auto i : c10::irange(num_merged_domains)) { + if (i == tv_root.size()) { + break; + } + auto ref_id = ref_root.at(ref_root.size() - 1 - i); + IterDomain* tv_id = tv_root.at(tv_root.size() - 1 - i); + // If not mapped, stop expanding. 
+ if (!ca_map.areMapped(ref_id, tv_id, IdMappingMode::EXACT)) { + break; + } else { + ++tv_num_merged_domains; + } + } + + size_t tv_word_size = 1; + if (tv_num_merged_domains > 1) { + auto tv_merged_domain = + mergeInnermostDomains(tv_root, tv_num_merged_domains); + if (tv_merged_domain == nullptr) { + tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv); + } else { + tv_word_size = collectMaxVectorizeSizeWithContigMerge( + tv, + tv_merged_domain, + common_alignment_size, + runtime_info.expressionEvaluator(), + indexModeToDtype(runtime_info.getIndexMode())); + cleanUpInnermostMergedDomains(tv_root, tv_merged_domain); + } + } else { + tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv); + } + + word_size = std::min(word_size, tv_word_size); + } + + return word_size; +} + +} // namespace vectorize_helper +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h index 0a67d00618e23..a9b959b495d60 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h @@ -2,21 +2,15 @@ #include #include -#include #include -#include + +#include namespace torch { namespace jit { namespace fuser { namespace cuda { - -// TODO: Put implementations in a vectorize_helper.cpp -namespace scheduler_utils { - -// Moved the definition of these to -// torch/csrc/jit/codegen/cuda/scheduler/utils.cpp as making new CPP files is -// painful for multiple reasons. +namespace vectorize_helper { // Grab all values and expressions used to make the merged_domain and remove // them from the fusion @@ -44,7 +38,7 @@ size_t expandVectorizationToContigMergedDomains( int break_point, size_t default_word_size); -} // namespace scheduler_utils +} // namespace vectorize_helper } // namespace cuda } // namespace fuser } // namespace jit
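
The word size returned by collectMaxVectorizeSizeWithContigMerge is simply the largest power of two that divides the merged innermost extent, capped by the common alignment divided by the element size. The following is a minimal standalone sketch of that search; maxVectorWordSize and its parameters are illustrative names only, not part of the nvFuser API, and the real helper derives the merged extent from ContigIDs and the expression evaluator rather than taking it as an argument.

#include <cstddef>
#include <cstdint>
#include <iostream>

// Hypothetical stand-ins for what the helper derives from the fusion: the
// product of the contiguous innermost root-domain extents (merged_extent),
// the common alignment of the tensors in bytes, and the element size of the
// dtype being vectorized.
size_t maxVectorWordSize(
    int64_t merged_extent,
    size_t alignment_in_bytes,
    size_t dtype_size_in_bytes) {
  // The widest vector access cannot exceed the alignment guarantee,
  // expressed in elements of this dtype.
  const size_t max_vector_size = alignment_in_bytes / dtype_size_in_bytes;

  // Grow the word size by powers of two while it still evenly divides the
  // merged extent and stays within the alignment-derived cap.
  size_t vector_size = 1;
  size_t next_vector_size = vector_size * 2;
  while (next_vector_size <= max_vector_size &&
         merged_extent % static_cast<int64_t>(next_vector_size) == 0) {
    vector_size = next_vector_size;
    next_vector_size *= 2;
  }
  return vector_size;
}

int main() {
  // A fully contiguous [4, 6, 10] float tensor merged into one innermost
  // extent of 240 with 16-byte alignment: 16 / 4 = 4 elements max, and
  // 240 % 4 == 0, so the word size is 4.
  std::cout << maxVectorWordSize(4 * 6 * 10, 16, sizeof(float)) << "\n"; // 4
  // The same extent in half precision (2-byte elements) allows 8.
  std::cout << maxVectorWordSize(4 * 6 * 10, 16, 2) << "\n"; // 8
  // An odd innermost extent cannot be vectorized at all.
  std::cout << maxVectorWordSize(7 * 9, 16, sizeof(float)) << "\n"; // 1
  return 0;
}

In short: a contiguous float tensor whose innermost contiguous root extents multiply to 240 vectorizes by 4 elements under a 16-byte alignment guarantee, while the same extent in half precision allows 8.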
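
Downstream, expandVectorizationToContigMergedDomains applies the same search to the reference tensor's root domains to the right of the break point and then clamps the result by every other vectorizable input/output, each of which only contributes the trailing root domains it exact-maps to the reference. The toy walk-through below mirrors that flow with plain integers standing in for IterDomain extents; it is a simplified sketch that ignores broadcast and trivial-reduction domains, the ComputeAtMap exact-mapping check, and per-tensor alignment differences, all of which the real helper handles.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Largest power of two (up to max_elements) that divides extent; the same
// search as in the previous sketch, repeated so this example stands alone.
size_t wordSizeFor(int64_t extent, size_t max_elements) {
  size_t word = 1;
  while (word * 2 <= max_elements &&
         extent % static_cast<int64_t>(word * 2) == 0) {
    word *= 2;
  }
  return word;
}

int main() {
  // Hypothetical reference tensor with root extents [8, 3, 10, 4] and a
  // break point after the first dimension: only the last three dims merge.
  const std::vector<int64_t> ref_root = {8, 3, 10, 4};
  const int break_point = 1;
  const size_t max_elements = 4; // e.g. 16-byte alignment / 4-byte dtype

  int64_t merged_extent = 1;
  for (size_t i = static_cast<size_t>(break_point); i < ref_root.size(); ++i) {
    merged_extent *= ref_root[i];
  }
  size_t word_size = wordSizeFor(merged_extent, max_elements); // 120 -> 4

  // Another vectorizable tensor only contributes the trailing dimensions it
  // shares with the reference (here: the last two, 10 and 4).
  const std::vector<int64_t> other_trailing = {10, 4};
  int64_t other_extent = 1;
  for (auto e : other_trailing) {
    other_extent *= e;
  }
  word_size = std::min(word_size, wordSizeFor(other_extent, max_elements));

  std::cout << word_size << "\n"; // 4
  return 0;
}

In the schedulers touched by this patch, the returned value simply replaces vectorize_factor (or expanded_vector_word_size) at the existing call sites, now qualified with the vectorize_helper namespace.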