diff --git a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp index 27e4d21928e2..3223321927a7 100644 --- a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp @@ -1259,7 +1259,7 @@ inline void inferGroupInputs( } } // namespace -FusionSegmentRuntime::SchedulerEntryPtr SegmentedFusion::makeSchedulerEntry( +FusionKernelRuntime::SchedulerEntryPtr SegmentedFusion::makeSchedulerEntry( SegmentedGroup* sg, ExpressionEvaluator& ee) { ExpressionEvaluator local_ee(&fusion_); @@ -1268,12 +1268,12 @@ FusionSegmentRuntime::SchedulerEntryPtr SegmentedFusion::makeSchedulerEntry( return SchedulerEntry::makeEntry(sg->heuristic(), &fusion_, local_ee); } -std::unique_ptr SegmentedFusion::makeHeuristics( +std::unique_ptr SegmentedFusion::makeHeuristics( const at::ArrayRef& inputs) { - auto ret = std::make_unique(); + auto ret = std::make_unique(); auto evaluator = executor_utils::bindFusionInputs(inputs, &fusion_); for (auto g : groups()) { - ret->emplace_back(makeSchedulerEntry(g, evaluator)); + ret->emplaceBack(makeSchedulerEntry(g, evaluator)); } return ret; } diff --git a/torch/csrc/jit/codegen/cuda/fusion_segmenter.h b/torch/csrc/jit/codegen/cuda/fusion_segmenter.h index 5e9d53a4a9a5..92909178a00a 100644 --- a/torch/csrc/jit/codegen/cuda/fusion_segmenter.h +++ b/torch/csrc/jit/codegen/cuda/fusion_segmenter.h @@ -100,7 +100,7 @@ class TORCH_CUDA_CU_API SegmentedGroup { private: friend class SegmentCandidateFinder; friend class SegmentedFusion; - friend class FusionSegmentRuntime; + friend class FusionKernelRuntime; //! unique identifier of group in the segmented fusion int group_id_ = -1; @@ -174,23 +174,47 @@ class TORCH_CUDA_CU_API SegmentedGroup { std::ostream& operator<<(std::ostream& os, const SegmentedGroup* group); -//! Auxiliary class for managing a list of heuristics instances for the -//! Segmented Groups -class TORCH_CUDA_CU_API SegmentHeuristics { - using SchedulerEntryPtr = std::unique_ptr; +//! Auxiliary class for storing heuristics. The managed data is either +//! a single scheduler entry for complete fusion, +//! or a vector of schedulers, one for each segment, for segmented fusion. +class TORCH_CUDA_CU_API FusionHeuristics { + using SchedulerEntryOwningPtr = std::unique_ptr; public: - explicit SegmentHeuristics() = default; - void emplace_back(SchedulerEntryPtr&& pt) { + //! Constructor for segmented fusion case. Created with empty list and + //! uses emplaceBack for inserting heuristics in order + explicit FusionHeuristics() = default; + + //! Constructor for complete fusion case, generates the scheduler entry + //! for the fusion owning the given expression + explicit FusionHeuristics( + ScheduleHeuristic schedule_heuristic, + ExpressionEvaluator& expr_eval) { + heuristics_.emplace_back(SchedulerEntry::makeEntry( + schedule_heuristic, expr_eval.fusion(), expr_eval)); + is_segmented_ = false; + } + + //! Place a scheduler entry on the list. Applies to segmented fusion only. + void emplaceBack(SchedulerEntryOwningPtr&& pt) { + TORCH_INTERNAL_ASSERT(is_segmented_); heuristics_.emplace_back(std::move(pt)); } - const std::vector& heuristics() const { + //! Returns list of schedulers for a segmneted fusion. + const std::vector& heuristicsList() const { return heuristics_; } + //! Returns the single scheduler for a complete fusion. 
+ SchedulerEntry* singleHeuristics() { + TORCH_INTERNAL_ASSERT(!is_segmented_); + return heuristics_.begin()->get(); + } + private: - std::vector heuristics_; + std::vector heuristics_; + bool is_segmented_ = true; }; //! Exported Interface for representing segmented fusion graph @@ -237,7 +261,7 @@ class TORCH_CUDA_CU_API SegmentedFusion { std::unique_ptr makeFusion(SegmentedGroup* sg); //! Make heuristics for all groups in this segmented fusion - std::unique_ptr makeHeuristics( + std::unique_ptr makeHeuristics( const at::ArrayRef& inputs); //! Inline Debug print for segmented fusion diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp index d43a5bdcd723..f1e77dab3e58 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp @@ -281,53 +281,31 @@ FusionExecutorCache::FusionExecutorCache(std::unique_ptr&& fusion) : fusion_(std::move(fusion)) { FUSER_PERF_SCOPE("FusionExecutorCache::FusionExecutorCache"); - // case of segmented fusion - // TODO: might be worthwhile re-using the SchedulerEntry infrastructure for - // single-kernel fusion as well. - const bool segmented = - !SchedulerEntry::proposeHeuristics(fusion_.get()).has_value(); + //! Try to schedule the complete fusion + const auto maybe_complete_fusion_scheduler = + SchedulerEntry::proposeHeuristics(fusion_.get()); + + //! Decide if this fusion is segmented or not + const bool segmented = !maybe_complete_fusion_scheduler.has_value(); if (segmented) { + // Segment the fusion through FusionSegmenter and + // initialize the caching for segmented heuristics fusion_segments_ = fusion_->segment(); - fusion_segment_runtime_cache_.initCache(fusion_segments_.get()); + fusion_kernel_runtime_cache_.initSegmentCache(fusion_segments_.get()); if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) { fusion_segments_->print(); } - return; - } - - // In the case that the fusion isn't segmented but user - // wants segmented fusion in the debug print. Will - // print math of the composite fusion as placeholder - if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) { - fusion_->printMath(); - } - - // avoid putting `has_nontrivial_reduction_` in the initializer list - has_nontrivial_reduction_ = fusion_->hasReduction(); - - if (has_nontrivial_reduction_) { - FusionGuard fg(fusion_.get()); - - // Use dependency check to find the reduction tv as it returns used values - // instead of exprs. - - // The call is relatively heavy weight, consider caching - auto all_values = DependencyCheck::getAllValsBetween( - {fusion_->inputs().begin(), fusion_->inputs().end()}, - fusion_->outputs()); - - // Separate the reduction TensorViews from the other TensorViews - // Ignore input TensorViews - for (auto tv : ir_utils::filterByType(all_values)) { - if (tv->hasReduction()) { - reduction_tv_.push_back(tv); - } + } else { + // Initialize single kernel case + fusion_kernel_runtime_cache_.initSingleKernelCache( + fusion_.get(), maybe_complete_fusion_scheduler.value()); + // In the case that the fusion isn't segmented but user + // wants segmented fusion in the debug print. 
Will + // print math of the composite fusion as placeholder + if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) { + fusion_->printMath(); } - - TORCH_INTERNAL_ASSERT( - !reduction_tv_.empty(), - "Could not find any reduction TensorViews in the fusion."); } } @@ -335,25 +313,6 @@ std::vector FusionExecutorCache::runFusionWithInputs( const at::ArrayRef& inputs) { FUSER_PERF_SCOPE("runFusionWithInputs"); - // TODO: This seems overly conservative to send to normalization scheduler. We - // may want to check there's a "residual path" around the reduction. - auto detect_normalization_fusion = [&]() { - for (auto expr : fusion_->exprs()) { - if (expr->getExprType() == ExprType::BroadcastOp) { - auto output = expr->output(0); - auto input_def_expr = expr->input(0)->definition(); - if (!fusion_->unordered_uses(output).empty() && - input_def_expr != nullptr && - input_def_expr->getExprType() == ExprType::ReductionOp) { - return true; - } - } - } - return false; - }; - - LaunchParams launch_params; - // get unique id `unique_id` for given input set `inputs`; auto id_lookup_ret = inputs_id_lookup_.lookupId(inputs); if (id_lookup_ret.eviction) { @@ -361,152 +320,68 @@ std::vector FusionExecutorCache::runFusionWithInputs( } const size_t unique_id = id_lookup_ret.id; - const int device_index = getCommonDeviceCUDA(inputs); - TORCH_CHECK(device_index >= 0, "device is not coherent for fusion inputs"); - - // Manage Segmented Fusion through FusionSegmentRuntimeCache - if (isSegmented()) { - auto seg_runtime = fusion_segment_runtime_cache_.getRt(inputs, unique_id); - // Propagate the unique_id so the contained fusionExecutors in the runtime - // entry will cache the buffer sizes and launch params based on this id. - return seg_runtime->runWithInput(inputs, unique_id); - } - - if (code_to_fe_lookup_.count(unique_id) == 0) { - // enter when we get a new input set. We need to search for compatible - // entries in cached `FusionExecutor` or compile new one as needed. - - // caching strategy is different for pw-fusion and reduction-fusion. - if (has_nontrivial_reduction_) { - bool isNormalizationFusion = detect_normalization_fusion(); - // Generate the reduction parameters - auto reduction_params = (isNormalizationFusion) - ? getNormalizationHeuristics(fusion_.get(), inputs, reduction_tv_) - : getReductionHeuristics( - fusion_.get(), inputs, reduction_tv_.front()); - - TORCH_INTERNAL_ASSERT( - reduction_params.has_value(), - "Error getting reduction heuristics for scheduling."); - - launch_params = reduction_params.value().lparams; - - // cache based on launch parameters - auto fusion_executor = - &red_fusion_executor_cache_[device_index][reduction_params.value()]; - - if (!fusion_executor->compiled()) { - // HEURISTIC NOT COMPILED, COMPILE A KERNEL - - // We clone *fusion_ to fusion so we can leave the unscheduled - // computational graph intact for future compilation. 
- Fusion fusion_clone = *fusion_; - FusionGuard fg(&fusion_clone); - - // Separate the reduction TensorViews from the other TensorViews - // Ignore input TensorViews - std::vector clone_reduction_tv; - std::vector clone_other_tv; - auto all_values = DependencyCheck::getAllValsBetween( - {fusion_clone.inputs().begin(), fusion_clone.inputs().end()}, - fusion_clone.outputs()); - - for (auto tv : ir_utils::filterByType(all_values)) { - if (tv->hasReduction()) { - clone_reduction_tv.push_back(tv); - } else if (!fusion_clone.hasInput(tv)) { - clone_other_tv.push_back(tv); - } - } - - if (isNormalizationFusion) { - scheduleNormalization( - &fusion_clone, - reduction_params.value(), - clone_reduction_tv, - clone_other_tv); - } else { - auto single_reduction_tv = clone_reduction_tv.front(); - - // Heavy weight call - auto outputs_of_reduction = - DependencyCheck::getAllOutputsOf({single_reduction_tv}); - - auto tv_entries = - ir_utils::filterByType(outputs_of_reduction); - - std::vector tv_outputs_of_reduction( - tv_entries.begin(), tv_entries.end()); - - scheduleReduction( - &fusion_clone, - reduction_params.value(), - single_reduction_tv, - tv_outputs_of_reduction); - } - - // This means we have not found a previously generated kernel that is - // compatible with the new reduction params. We need to finish codegen. - CompileOptions options; - options.device = c10::Device(DeviceType::CUDA, device_index); - fusion_executor->compileFusion(&fusion_clone, options); - } - // record new short cut to `FusionExecutor` - code_to_fe_lookup_[unique_id] = fusion_executor; - - } else { - // Handle pointwise operations - if (pw_fusion_executor_cache_.count(device_index) == 0) { - pw_fusion_executor_cache_[device_index] = - std::make_unique(); - CompileOptions options; - options.device = c10::Device(DeviceType::CUDA, device_index); - // We do not need to copy fusion_ because we are not generating - // multiple kernels for point-wise operations. - auto fusion_clone = *fusion_; - scheduleFusion(&fusion_clone, inputs); - pw_fusion_executor_cache_[device_index]->compileFusion( - &fusion_clone, options); - } - // record new short cut to `FusionExecutor` - code_to_fe_lookup_[unique_id] = - pw_fusion_executor_cache_[device_index].get(); - } - } - - return code_to_fe_lookup_[unique_id]->runFusion( - inputs, launch_params, unique_id); + // Manage Segmented Fusion through FusionKernelRuntimeCache + auto fusion_kernel_runtime = + fusion_kernel_runtime_cache_.getRt(inputs, unique_id); + // Propagate the unique_id so the contained fusionExecutors in the runtime + // entry will cache the buffer sizes and launch params based on this id. 
+ return fusion_kernel_runtime->runWithInput(inputs, unique_id); } -FusionSegmentRuntime::FusionSegmentRuntime( +FusionKernelRuntime::FusionKernelRuntime( SegmentedFusion* segmented_fusion, - std::unique_ptr& heuristics, + std::unique_ptr& heuristics, size_t input_id) : executors_(segmented_fusion->groups().size()), heuristics_(std::move(heuristics)), segmented_fusion_(segmented_fusion) {} -// Largely duplicated from FusionExecutorCache -std::vector FusionSegmentRuntime::runSegmentWithInput( - SegmentedGroup* sg, +FusionKernelRuntime::FusionKernelRuntime( + Fusion* fusion, + std::unique_ptr& heuristics, + size_t input_id) + : executors_(1), + heuristics_(std::move(heuristics)), + is_segmented_(false), + complete_fusion_(fusion) {} + +std::vector FusionKernelRuntime::runKernelWithInput( const at::ArrayRef& inputs, - size_t input_id) { - auto group_id = sg->groupId(); + size_t input_id, + SegmentedGroup* sg) { + // This function will be called once on un-segmented fusion, + // for segmented fusion, this function will be called on each segment + // In the case of segmented fusion, segmented group needs to be given so + // a kernel is compiled and run for a segmented group + // In the case of complete fusion, sg = nullptr, and the original fusion + // is complied and run + auto group_id = sg ? sg->groupId() : 0; const int device_index = getCommonDeviceCUDA(inputs); + TORCH_CHECK(device_index >= 0, "device is not coherent for fusion inputs"); + LaunchParams launch_params; auto scheduler_entry = schedulers()[group_id].get(); - // Check that the heuristics are matched - TORCH_INTERNAL_ASSERT(scheduler_entry->heuristc() == sg->heuristic()); + // Check that the heuristics are matched, in the case of segmented fusion + TORCH_INTERNAL_ASSERT(!sg || scheduler_entry->heuristc() == sg->heuristic()); if (!executors_[group_id].compiled()) { - std::unique_ptr fusion_seg = segmented_fusion_->makeFusion(sg); + std::unique_ptr fusion_to_run; + if (sg) { + // Running a segment group as a single kernel, + // make a fusion to run from segmented fusion + fusion_to_run = segmented_fusion_->makeFusion(sg); + } else { + // Without a segmented group defaults to compiling the + // complete fusion + fusion_to_run = std::make_unique(*complete_fusion_); + } CompileOptions options; options.device = c10::Device(DeviceType::CUDA, device_index); - FusionGuard fg(fusion_seg.get()); - scheduler_entry->schedule(fusion_seg.get()); - executors_[group_id].compileFusion(fusion_seg.get(), options); + FusionGuard fg(fusion_to_run.get()); + scheduler_entry->schedule(fusion_to_run.get()); + executors_[group_id].compileFusion(fusion_to_run.get(), options); } // Load launch params for reduction and normalization kernels @@ -517,7 +392,7 @@ std::vector FusionSegmentRuntime::runSegmentWithInput( return executors_[group_id].runFusion(inputs, launch_params, input_id); } -std::vector FusionSegmentRuntime::runWithInput( +std::vector FusionKernelRuntime::runMultiKernelWithInput( const at::ArrayRef& inputs, size_t input_id) { TORCH_INTERNAL_ASSERT( @@ -584,7 +459,7 @@ std::vector FusionSegmentRuntime::runWithInput( // Run graph segment auto group_runtime_outputs = - runSegmentWithInput(group, group_runtime_inputs, input_id); + runKernelWithInput(group_runtime_inputs, input_id, group); const auto& group_outputs = group->outputs(); @@ -635,13 +510,13 @@ std::vector FusionSegmentRuntime::runWithInput( return fusion_output_tensors; } -const std::vector& -FusionSegmentRuntime::schedulers() { - return heuristics_->heuristics(); +const std::vector& 
FusionKernelRuntime:: + schedulers() { + return heuristics_->heuristicsList(); } namespace { -using HashType = FusionSegmentRuntime::HashType; +using HashType = FusionKernelRuntime::HashType; // Use a slightly more nontrivial combine to avoid collision // (from Boost) inline HashType combineHash(HashType a, HashType b) { @@ -652,38 +527,38 @@ inline HashType combineHash(HashType a, HashType b) { } } // namespace -FusionSegmentRuntime::HashType FusionSegmentRuntime::getHash( - SegmentHeuristics* sh) { +FusionKernelRuntime::HashType FusionKernelRuntime::getHash( + FusionHeuristics* sh) { HashType h = 0; - for (auto& se_pt : sh->heuristics()) { + for (auto& se_pt : sh->heuristicsList()) { h = combineHash(h, SchedulerEntryHash()(*se_pt)); } return h; } -FusionSegmentRuntime::HeuristicTag::HeuristicTag(SegmentHeuristics* sh) { +FusionKernelRuntime::HeuristicTag::HeuristicTag(FusionHeuristics* sh) { heuristics_ = sh; - hash_ = FusionSegmentRuntime::getHash(sh); + hash_ = FusionKernelRuntime::getHash(sh); } -bool FusionSegmentRuntime::HeuristicTag::operator==( - const FusionSegmentRuntime::HeuristicTag& other) const { - if (heuristics_->heuristics().size() != - other.heuristics_->heuristics().size()) { +bool FusionKernelRuntime::HeuristicTag::operator==( + const FusionKernelRuntime::HeuristicTag& other) const { + if (heuristics_->heuristicsList().size() != + other.heuristics_->heuristicsList().size()) { return false; } - auto& heuristics = heuristics_->heuristics(); + auto& heuristics = heuristics_->heuristicsList(); return std::equal( heuristics.begin(), heuristics.end(), - other.heuristics_->heuristics().begin(), + other.heuristics_->heuristicsList().begin(), [](const SchedulerEntryPtr& a, const SchedulerEntryPtr& b) { return a->sameAs(b.get()); }); } -void FusionSegmentRuntimeCache::evictId(size_t input_id) { +void FusionKernelRuntimeCache::evictId(size_t input_id) { TORCH_INTERNAL_ASSERT(id_to_rt_.count(input_id) != 0); // Evict the stored input tensor meta data @@ -692,7 +567,7 @@ void FusionSegmentRuntimeCache::evictId(size_t input_id) { id_to_rt_.erase(input_id); } -FusionSegmentRuntime* FusionSegmentRuntimeCache::getRt( +FusionKernelRuntime* FusionKernelRuntimeCache::getRt( const at::ArrayRef& inputs, size_t input_id) { // Look up by input_id first @@ -705,26 +580,42 @@ FusionSegmentRuntime* FusionSegmentRuntimeCache::getRt( return seg_runtime; } -FusionSegmentRuntime* FusionSegmentRuntimeCache::getRtById(size_t input_id) { +FusionKernelRuntime* FusionKernelRuntimeCache::getRtById(size_t input_id) { if (id_to_rt_.count(input_id) == 0) { return nullptr; } return id_to_rt_.at(input_id); } -FusionSegmentRuntime* FusionSegmentRuntimeCache::getRtByHeuristics( +FusionKernelRuntime* FusionKernelRuntimeCache::getRtByHeuristics( const at::ArrayRef& inputs, size_t input_id) { auto dev_id = getCommonDeviceCUDA(inputs); - auto heuristics = segmented_fusion_->makeHeuristics(inputs); + std::unique_ptr heuristics; + if (is_segmented_) { + heuristics = segmented_fusion_->makeHeuristics(inputs); + } else { + auto evaluator = executor_utils::bindFusionInputs(inputs, complete_fusion_); + heuristics = std::make_unique( + complete_fusion_heuristic_, evaluator); + } + HeuristicTag tag(heuristics.get()); auto rt = at(dev_id, tag); // Heuristics miss if (rt == nullptr) { // Construct new runtime instance - auto new_rt = std::make_unique( - segmented_fusion_, heuristics, input_id); + + std::unique_ptr new_rt; + + if (is_segmented_) { + new_rt = std::make_unique( + segmented_fusion_, heuristics, input_id); + } 
else { + new_rt = std::make_unique( + complete_fusion_, heuristics, input_id); + } rt = new_rt.get(); // Cache the new instance @@ -737,11 +628,20 @@ FusionSegmentRuntime* FusionSegmentRuntimeCache::getRtByHeuristics( return rt; } -void FusionSegmentRuntimeCache::initCache(SegmentedFusion* sf) { - segmented_fusion_ = sf; +void FusionKernelRuntimeCache::initSegmentCache( + SegmentedFusion* segmented_fusion) { + is_segmented_ = true; + segmented_fusion_ = segmented_fusion; +} + +void FusionKernelRuntimeCache::initSingleKernelCache( + Fusion* fusion, + ScheduleHeuristic schedule_heuristic) { + complete_fusion_ = fusion; + complete_fusion_heuristic_ = schedule_heuristic; } -FusionSegmentRuntime* FusionSegmentRuntimeCache::at( +FusionKernelRuntime* FusionKernelRuntimeCache::at( int dev_id, HeuristicTag tag) { // Get cache for the device id @@ -764,7 +664,7 @@ FusionSegmentRuntime* FusionSegmentRuntimeCache::at( return cache_entry_ptr.get(); } -void FusionSegmentRuntimeCache::insertEntry( +void FusionKernelRuntimeCache::insertEntry( int dev_id, HeuristicTag tag, SegRuntimePtr&& rt_pt) { diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.h b/torch/csrc/jit/codegen/cuda/kernel_cache.h index 31f1c2bbfb74..f9816b1d95ff 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.h +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.h @@ -19,43 +19,63 @@ namespace fuser { namespace cuda { class SegmentedGroup; -class SegmentHeuristics; +class FusionHeuristics; -//! Implementation of a graph runtime with simple scheduling to support -//! multi-kernel fusion -class TORCH_CUDA_CU_API FusionSegmentRuntime { +//! FusionKernelRuntime is the unified interface from fusion graphs into +//! caching, compilation into kernels, and kernel launches. +//! +//! Each instance is also a cache entry tracked by FusionKernelRuntimeCache. +//! +//! Two types of instance can be created, one for complete/single-kernel fusion +//! and one for segmented/multi-kernel fusion. +//! Conceptually this is a generalization of FusionExecutor that supports both +//! single-kernel and multi-kernel caching/compiling/launching +class TORCH_CUDA_CU_API FusionKernelRuntime { public: - //! Type notations within FusionSegmentRuntime Context + //! Type notations within FusionKernelRuntime Context using HashType = size_t; using SchedulerEntryPtr = std::unique_ptr; - explicit FusionSegmentRuntime( + //! Create a runtime instance for segmented fusion + explicit FusionKernelRuntime( SegmentedFusion* segmented_fusion, - std::unique_ptr& heuristics, + std::unique_ptr& heuristics, size_t input_id); - //! FusionExecutorCache API for evicting an input id + //! Create a runtime instance for complete/single-kernel fusion + explicit FusionKernelRuntime( + Fusion* fusion, + std::unique_ptr& heuristics, + size_t input_id); + + //! Evicts internally cached parameters based on input sizes. + //! An interface used by runtime caches. void evictCache(size_t input_id) { for (auto& fe : executors_) { fe.evictCache(input_id); } } - //! FusionExecutorCache API for running the segmented fusion with given global - //! inputs + //! Unified interface to run the managed kernels with given input std::vector runWithInput( const at::ArrayRef& inputs, - size_t input_id); + size_t input_id) { + if (is_segmented_) { + return runMultiKernelWithInput(inputs, input_id); + } else { + return runKernelWithInput(inputs, input_id); + } + } //! 
Cache Interface: Common utility for computing hash of scheduler entires - static HashType getHash(SegmentHeuristics* sh); + static HashType getHash(FusionHeuristics* sh); //! Cache Interface: trivially copied and easily compared - //! descriptor for FusionSegmentRuntime + //! descriptor for a FusionKernelRuntime instance class HeuristicTag { public: //! Computes hash upon creation - explicit HeuristicTag(SegmentHeuristics*); + explicit HeuristicTag(FusionHeuristics*); //! Tag equal abstracts the heuristics equivalence bool operator==(const HeuristicTag& other) const; @@ -67,7 +87,7 @@ class TORCH_CUDA_CU_API FusionSegmentRuntime { private: HashType hash_; - SegmentHeuristics* heuristics_; + FusionHeuristics* heuristics_; }; class HeuristicTagHash { @@ -78,13 +98,23 @@ class TORCH_CUDA_CU_API FusionSegmentRuntime { }; private: - //! Run one segment of the segmented fusion, compiles if not done so - std::vector runSegmentWithInput( - SegmentedGroup* sg, + //! Interface to run a single kernel, either one kernel for single-kernel + //! fusions, + //! or a kernel for a segmentedGrouup in a segmented fusion. Returns the + //! kernel outputs. + std::vector runKernelWithInput( + const at::ArrayRef& inputs, + size_t input_id, + SegmentedGroup* sg = nullptr); + + //! Interface to run a the whole graph in a segmented fusion and return the + //! complete + //! fusion outputs. + std::vector runMultiKernelWithInput( const at::ArrayRef& inputs, size_t input_id); - //! Accessor class for the internal schedulers maintained in this runtime + //! Access the list of schedulers maintained in this runtime instance const std::vector& schedulers(); private: @@ -94,16 +124,27 @@ class TORCH_CUDA_CU_API FusionSegmentRuntime { std::vector executors_; //! Heuristics object holding scheduler entries for all segments - std::unique_ptr heuristics_; + std::unique_ptr heuristics_; + + // Checks if this runtime instance is for a single-kernel fusion (false) or a + // segmented fusion (true). + bool is_segmented_ = true; - // States - SegmentedFusion* segmented_fusion_; + // Maintain the original segmented fusion that this runtime is maintaining + // heuristics for. Applies only in the segmented fusion case, i.e. + // is_segmented==true + SegmentedFusion* segmented_fusion_ = nullptr; + + // Maintain the original fusion that this runtime is maintaining + // heuristics for. Applies only in the single-kernel fusion case, i.e. + // is_segmented==false + Fusion* complete_fusion_ = nullptr; }; //! Object holding cache entries for segmented fusion -class TORCH_CUDA_CU_API FusionSegmentRuntimeCache { +class TORCH_CUDA_CU_API FusionKernelRuntimeCache { public: - explicit FusionSegmentRuntimeCache() = default; + explicit FusionKernelRuntimeCache() = default; //! Evict the cacheEntry by id. //! removes ID to RT lookup and corresponding @@ -112,22 +153,28 @@ class TORCH_CUDA_CU_API FusionSegmentRuntimeCache { void evictId(size_t input_id); //! Interface for registering segmented fusion for caching heuristics - void initCache(SegmentedFusion* sf); + void initSegmentCache(SegmentedFusion* sf); + + //! Interface for registering complete fusion for caching single kernel + //! heuristics + void initSingleKernelCache( + Fusion* fusion, + ScheduleHeuristic schedule_heuristic); - //! API for collecting FusionSegmentRuntime entry from cache, + //! API for collecting FusionKernelRuntime entry from cache, //! contains a two level lookup, //! if input_id is hit -> returns cached //! 
if input_id miss -> lookup with heuristics -> return cached if found //! if heuristics miss -> create a new entry and return created - FusionSegmentRuntime* getRt( + FusionKernelRuntime* getRt( const at::ArrayRef& inputs, size_t input_id); private: - using HeuristicTag = FusionSegmentRuntime::HeuristicTag; - using HeuristicTagHash = FusionSegmentRuntime::HeuristicTagHash; - //! FusionSegmentRuntime cache based on HeuristicTag lookup - using SegRuntimePtr = std::unique_ptr; + using HeuristicTag = FusionKernelRuntime::HeuristicTag; + using HeuristicTagHash = FusionKernelRuntime::HeuristicTagHash; + //! FusionKernelRuntime cache based on HeuristicTag lookup + using SegRuntimePtr = std::unique_ptr; using SegRuntimeCache = std::unordered_map; //! One cache per device id @@ -138,23 +185,34 @@ class TORCH_CUDA_CU_API FusionSegmentRuntimeCache { //! Currently don't have releasing entry at this level since //! we would not release compiled kernels at this point void insertEntry(int dev_id, HeuristicTag tag, SegRuntimePtr&& rt); - FusionSegmentRuntime* at(int dev_id, HeuristicTag tag); + FusionKernelRuntime* at(int dev_id, HeuristicTag tag); private: + //! Checks if this cache is for segmented fusion or not + bool is_segmented_ = false; + + //! Store the heuristic corresponding to the complete fusion if any + ScheduleHeuristic complete_fusion_heuristic_ = ScheduleHeuristic::PointWise; + + //! Contains the complete fusion + Fusion* complete_fusion_ = nullptr; + + //! Data structure hosting the actual caches SegRuntimeCacheGroup seg_runtime_cache_group_; + //! Input_id to runtime shortcut - std::unordered_map id_to_rt_; + std::unordered_map id_to_rt_; //! Reference to the segmented fusion held in FusionExecutorCache SegmentedFusion* segmented_fusion_ = nullptr; //! In case of cache hit by input id, return pointer to that entry, //! returns nullptr if input_id miss - FusionSegmentRuntime* getRtById(size_t input_id); + FusionKernelRuntime* getRtById(size_t input_id); //! In case of input id miss, evaluate heuristics and find a hit by heuristics //! in case of heuristics miss, create a new entry - FusionSegmentRuntime* getRtByHeuristics( + FusionKernelRuntime* getRtByHeuristics( const at::ArrayRef& inputs, size_t input_id); }; @@ -313,69 +371,21 @@ class TORCH_CUDA_CU_API FusionExecutorCache { //! evict cached short cut entry in `code_to_fe_lookup_` as well as cached //! entry in `FusionExecutor` void evictCache(size_t cache_id) { - // Handling segmented fusion differently - if (isSegmented()) { - fusion_segment_runtime_cache_.evictId(cache_id); - return; - } - - auto iter = code_to_fe_lookup_.find(cache_id); - TORCH_INTERNAL_ASSERT( - iter != code_to_fe_lookup_.end(), - "evict cache failed to find an entry"); - // evict nested lookup entry in nested `FusionExecutor` - (iter->second)->evictCache(cache_id); - code_to_fe_lookup_.erase(iter); + fusion_kernel_runtime_cache_.evictId(cache_id); }; private: //! original un-scheduled `Fusion`; std::unique_ptr fusion_; - // I'm trading the const model in favor of assigning - // `has_nontrivial_reduction_` in the body of constructor, instead of the - // initializer list; Because of the move statement used in the constructor, - // it's tricky to maintain the code if we have `has_nontrivial_reduction_` as - // a const member and initizlize it in the initializer list, where the order - // of initialization is controled by the order of declaration instead of their - // order in the list - // - //! 
cache fusion->hasReduction() because it's expensive; - bool has_nontrivial_reduction_ = false; - - //! cache reduction_tv_ to avoid searching repetitively at runtime - std::vector reduction_tv_; - - //! TODO: ugly logic for now. We should integrate the hashing of cache for - //! different kernels. (alternatively we could do so in scheduler). - //! ugly bits now: - //! The fact that we have heuristics only for reduction, but use a general - //! kernel for all point-wise fusion ended up with this: - //! 1. For point-wise fusion, we have a single `FusionExecutor` in - //! `pw_fusion_executor_cache_` - //! 2. For reduction fusion we have a hash table with ReductionParams as entry - //! pointing to the actual `FusionExecutor` in `red_fusion_executor_cache_` - //! - //! Both cache_ key on device_index, because `FusionExecutor` is designated to - //! a single device - std::unordered_map> - pw_fusion_executor_cache_; - std::unordered_map< - int, - std::unordered_map> - red_fusion_executor_cache_; - - //! short cut to FusionExecutor for input set encoded with id; - std::unordered_map code_to_fe_lookup_; - //! inputs to unique_id lookup table; InputsIdLookup inputs_id_lookup_; - //! Multi-Kernel fusion segment caching + //! Multi-Kernel fusion segment when applies std::unique_ptr fusion_segments_ = nullptr; //! Caching for segmented fusions - FusionSegmentRuntimeCache fusion_segment_runtime_cache_; + FusionKernelRuntimeCache fusion_kernel_runtime_cache_; }; class GraphCache { diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 1065a2169553..17fe6c2e532a 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -72,7 +72,7 @@ void GpuLower::replaceSymbolicSizes() { // TODO(kir): consider a different implementation which doesn't // hijack the kir_val_map_ // Currently turn off this part for inputs of segmented fusion, - // since FusionSegmentRuntime will provide these as integer inputs + // since FusionKernelRuntime will provide these as integer inputs if (kir_val_map_.find(orig_size) == kir_val_map_.end() && !orig_size->isFusionInput()) { std::stringstream ss;
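The sketches below are stand-alone illustrations of the patterns this patch introduces; they stub out the nvFuser types they mention rather than reproduce them. First, the ownership model of the new FusionHeuristics class from fusion_segmenter.h: one container that holds either a single scheduler entry (complete fusion) or one entry per segment (segmented fusion), guarded by an is_segmented_ flag. The real single-kernel constructor takes a ScheduleHeuristic and an ExpressionEvaluator and calls SchedulerEntry::makeEntry; here it is simplified to accept a ready-made entry, and SchedulerEntry is an empty stand-in.

// Simplified, self-contained sketch of the FusionHeuristics ownership model.
#include <cassert>
#include <memory>
#include <vector>

struct SchedulerEntry {}; // stand-in for the real nvFuser scheduler entry

class FusionHeuristics {
  using SchedulerEntryOwningPtr = std::unique_ptr<SchedulerEntry>;

 public:
  // Segmented case: start empty, schedulers are appended one per segment.
  FusionHeuristics() = default;

  // Complete-fusion case: exactly one scheduler entry, created up front.
  // (Simplified signature; the real constructor builds the entry itself.)
  explicit FusionHeuristics(SchedulerEntryOwningPtr single_entry)
      : is_segmented_(false) {
    heuristics_.emplace_back(std::move(single_entry));
  }

  // Only valid for the segmented case, mirroring the TORCH_INTERNAL_ASSERT.
  void emplaceBack(SchedulerEntryOwningPtr&& entry) {
    assert(is_segmented_);
    heuristics_.emplace_back(std::move(entry));
  }

  const std::vector<SchedulerEntryOwningPtr>& heuristicsList() const {
    return heuristics_;
  }

  // Only valid for the complete-fusion case.
  SchedulerEntry* singleHeuristics() {
    assert(!is_segmented_);
    return heuristics_.front().get();
  }

 private:
  std::vector<SchedulerEntryOwningPtr> heuristics_;
  bool is_segmented_ = true;
};

int main() {
  FusionHeuristics segmented;                      // one entry per segment
  segmented.emplaceBack(std::make_unique<SchedulerEntry>());
  segmented.emplaceBack(std::make_unique<SchedulerEntry>());
  assert(segmented.heuristicsList().size() == 2);

  FusionHeuristics complete(std::make_unique<SchedulerEntry>());
  assert(complete.singleHeuristics() != nullptr);  // single-kernel accessor
  return 0;
}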
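Next, a minimal model of the unified dispatch in FusionKernelRuntime (kernel_cache.h/.cpp): runWithInput forwards to runMultiKernelWithInput for segmented fusions and to runKernelWithInput otherwise, and runKernelWithInput treats a null SegmentedGroup pointer as "compile and run the complete fusion" using executor slot 0. Executors, tensors, and compilation are stubbed; the class name FusionKernelRuntimeSketch exists only for this illustration.

// Stand-alone model of the single- vs multi-kernel control flow.
#include <cstdio>
#include <utility>
#include <vector>

struct SegmentedGroup { int id; }; // stand-in for the real segmented group

class FusionKernelRuntimeSketch {
 public:
  explicit FusionKernelRuntimeSketch(std::vector<SegmentedGroup> groups)
      : groups_(std::move(groups)), is_segmented_(!groups_.empty()) {}

  void runWithInput() {
    if (is_segmented_) {
      runMultiKernelWithInput();
    } else {
      runKernelWithInput(/*sg=*/nullptr);
    }
  }

 private:
  // sg == nullptr means "complete fusion"; the cached executor slot is then
  // group 0, mirroring `auto group_id = sg ? sg->groupId() : 0;` in the patch.
  void runKernelWithInput(const SegmentedGroup* sg) {
    int group_id = sg ? sg->id : 0;
    std::printf("compile-if-needed and launch kernel for group %d\n", group_id);
  }

  // Segmented path: each group becomes one kernel launch; threading of
  // intermediate tensors between segments is omitted here.
  void runMultiKernelWithInput() {
    for (const auto& g : groups_) {
      runKernelWithInput(&g);
    }
  }

  std::vector<SegmentedGroup> groups_;
  bool is_segmented_ = true;
};

int main() {
  FusionKernelRuntimeSketch single({});            // complete fusion -> one kernel
  single.runWithInput();
  FusionKernelRuntimeSketch segmented({{0}, {1}}); // segmented -> kernel per group
  segmented.runWithInput();
  return 0;
}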
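Finally, a sketch of the heuristics-keyed level of FusionKernelRuntimeCache: a HeuristicTag-like key whose hash folds per-SchedulerEntry hashes with a Boost-style combiner, stored in one unordered_map per device id (SegRuntimeCacheGroup holding SegRuntimeCache in the real code). The per-entry hash and sameAs() comparison are reduced to plain size_t values, and the combiner constants are the usual Boost ones rather than values taken from this hunk, which elides the helper's body.

// Sketch of the tag-keyed, per-device runtime cache lookup.
#include <cstddef>
#include <unordered_map>
#include <vector>

using HashType = size_t;

// Slightly non-trivial combine to avoid collisions (Boost-style; assumed).
inline HashType combineHash(HashType a, HashType b) {
  return a ^ (b + 0x9e3779b9 + (a << 6) + (a >> 2));
}

struct Tag {
  std::vector<HashType> entry_hashes; // stand-in for per-SchedulerEntry hashes
  bool operator==(const Tag& other) const {
    return entry_hashes == other.entry_hashes; // real code uses sameAs()
  }
};

struct TagHash {
  HashType operator()(const Tag& tag) const {
    HashType h = 0;
    for (auto e : tag.entry_hashes) {
      h = combineHash(h, e);
    }
    return h;
  }
};

int main() {
  // One cache per device id, each keyed by the heuristics tag.
  std::unordered_map<int, std::unordered_map<Tag, int, TagHash>> cache_per_device;
  Tag tag{{1, 2, 3}};
  cache_per_device[0][tag] = 42;                       // insert on heuristics miss
  return cache_per_device[0].count(tag) == 1 ? 0 : 1;  // later lookups hit by tag
}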