diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp
index f775703dd120..7ac8bab92e51 100644
--- a/torch/csrc/jit/codegen/cuda/executor.cpp
+++ b/torch/csrc/jit/codegen/cuda/executor.cpp
@@ -358,6 +358,16 @@ LaunchParams FusionExecutor::computeLaunchParams(
           });
   auto& parallel_iter_extents = parallel_iter_extent_entry.get();
 
+  auto simplified_parallel_iter_extent_entry =
+      executor_utils::caching::ExecutorCompileTimeEntry<
+          executor_utils::caching::SimplifiedParallelIterExtentMap>(
+          data_cache, [&parallel_binding_ids, &lower]() {
+            return executor_utils::getSimplifiedParallelIterExtents(
+                lower, parallel_binding_ids);
+          });
+  auto& simplified_parallel_iter_extents =
+      simplified_parallel_iter_extent_entry.get();
+
   auto warp_padded_parallel_entry =
       executor_utils::caching::ExecutorCompileTimeEntry<
          executor_utils::caching::WarpPaddedParallelExtents>(
@@ -409,7 +419,7 @@ LaunchParams FusionExecutor::computeLaunchParams(
   }
 
   // Run through the rest of the parallel IterDomains and infer their size
-  for (auto& entry : parallel_iter_extents) {
+  for (auto& entry : simplified_parallel_iter_extents) {
     FUSER_PERF_SCOPE("FusionExecutor::ParallelBindingResolution");
     auto p_type = entry.first;
     auto parallel_extents = entry.second;
diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp
index 08ed39ad2aa7..dcd779aaf779 100644
--- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp
@@ -966,6 +966,7 @@ ExecutorCompileTimeEntry<EntryClass>::ExecutorCompileTimeEntry(
 // Template instantiation
 template class ExecutorCompileTimeEntry<ParallelBindingIterDomains>;
 template class ExecutorCompileTimeEntry<ParallelIterExtentMap>;
+template class ExecutorCompileTimeEntry<SimplifiedParallelIterExtentMap>;
 template class ExecutorCompileTimeEntry<WarpPaddedParallelExtents>;
 template class ExecutorCompileTimeEntry<VectorizedTensorValidation>;
 template class ExecutorCompileTimeEntry<InputAliasIndices>;
@@ -986,20 +987,55 @@ std::vector<IterDomain*> getParallelBindingsIterDomains(
   return parallel_ids;
 }
 
+void insertParallelExtent(
+    GpuLower& lower,
+    IterDomain* binding_id,
+    const std::unique_ptr<ParallelExtentMap>& parallel_iter_extents_ptr) {
+  auto kir_extent = lower.lowerValue(binding_id->extent());
+  const auto it =
+      parallel_iter_extents_ptr->find(binding_id->getParallelType());
+  if (it != parallel_iter_extents_ptr->end()) {
+    it->second.push_back(kir_extent);
+  } else {
+    parallel_iter_extents_ptr->operator[](binding_id->getParallelType()) = {
+        kir_extent};
+  }
+}
+
 std::unique_ptr<ParallelExtentMap> getParallelIterExtents(
     GpuLower& lower,
     std::vector<IterDomain*>& parallel_binding_ids) {
   auto parallel_iter_extents_ptr = std::make_unique<ParallelExtentMap>();
   for (auto id : parallel_binding_ids) {
-    // TODO(kir): we should rewrite this logic based on the Kernel object
-    auto kir_extent = lower.lowerValue(id->extent());
-    const auto it = parallel_iter_extents_ptr->find(id->getParallelType());
-    if (it != parallel_iter_extents_ptr->end()) {
-      it->second.push_back(kir_extent);
-    } else {
-      parallel_iter_extents_ptr->operator[](id->getParallelType()) = {
-          kir_extent};
+    insertParallelExtent(lower, id, parallel_iter_extents_ptr);
+  }
+
+  return parallel_iter_extents_ptr;
+}
+
+std::unique_ptr<ParallelExtentMap> getSimplifiedParallelIterExtents(
+    GpuLower& lower,
+    std::vector<IterDomain*>& parallel_binding_ids) {
+  auto parallel_iter_extents_ptr = std::make_unique<ParallelExtentMap>();
+  auto& parallel_map = lower.caParallelMap();
+  std::vector<IterDomain*> mapped;
+  bool is_tidx_warp_padded = lower.getWarpPaddedParallelInfo().is_tidx_padded;
+
+  for (auto id : parallel_binding_ids) {
+    if (std::any_of(
+            mapped.begin(),
+            mapped.end(),
+            [id, &parallel_map](IterDomain* mapped_id) {
+              return parallel_map.areMapped(mapped_id, id);
+            })) {
+      if (id->getParallelType() != ParallelType::TIDx || !is_tidx_warp_padded) {
+        continue;
+      }
     }
+
+    insertParallelExtent(
+        lower, parallel_map.getConcreteMappedID(id), parallel_iter_extents_ptr);
+    mapped.push_back(id);
   }
 
   return parallel_iter_extents_ptr;
diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.h b/torch/csrc/jit/codegen/cuda/executor_utils.h
index 9ed457dd6d9c..f29da30af3eb 100644
--- a/torch/csrc/jit/codegen/cuda/executor_utils.h
+++ b/torch/csrc/jit/codegen/cuda/executor_utils.h
@@ -80,6 +80,7 @@ namespace caching {
 enum class CompileTimeEntryType {
   PARALLEL_BINDING_ITERDOMAINS,
   PARALLEL_ITER_EXTENT_MAP,
+  SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP,
   WARP_PADDED_PARALLEL_EXTENTS,
   VECTORIZED_TENSOR_VALIDATION,
   INPUT_ALIAS_INDICES,
@@ -114,6 +115,27 @@ class ParallelIterExtentMap {
       CompileTimeEntryType::PARALLEL_ITER_EXTENT_MAP;
 };
 
+//! Compile-time info to be cached in each FusionExecutor:
+//!  SimplifiedParallelIterExtentMap
+//!    This entry type is a simplified version of ParallelIterExtentMap.
+//!
+//!  For launch parameter binding we only need the most concrete iterdomain
+//!  in each disjoint set stored in CaParallelMap. This entry stores the
+//!  remaining list of extents for binding after this simplification.
+//!
+//!  We still need ParallelIterExtentMap since we want to bind the concrete
+//!  values to the extents of all parallelized iterdomains. We would be
+//!  able to save these bindings if the integer machine had a notion of
+//!  equality and could be configured at compile time, but that is a
+//!  longer-term target.
+class SimplifiedParallelIterExtentMap {
+ public:
+  using DataType =
+      std::unordered_map<ParallelType, std::vector<const kir::Val*>, TypeHash>;
+  static const CompileTimeEntryType EntryType =
+      CompileTimeEntryType::SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP;
+};
+
 //! WarpPaddedExtentsInfo:
 //!  Auxiliary data type for entry class WarpPaddedParallelExtents
 struct WarpPaddedExtentsInfo {
@@ -269,6 +291,12 @@ std::unique_ptr<ParallelExtentMap> getParallelIterExtents(
     GpuLower& lower,
     std::vector<IterDomain*>& parallel_binding_ids);
 
+//! Returns the simplified set of extents necessary for launch parameter
+//!  binding.
+std::unique_ptr<ParallelExtentMap> getSimplifiedParallelIterExtents(
+    GpuLower& lower,
+    std::vector<IterDomain*>& parallel_binding_ids);
+
 //! Returns the symbolic or constant extetns of warp padded parallel
 //!  iterdomains in the given vector.
 std::unique_ptr<WarpPaddedExtentsInfo> getWarpPaddedExtentsInfo(
diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
index 1d7c452d4cdb..0f4b523b6ba0 100644
--- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
+++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
@@ -30,7 +30,6 @@ c10::optional<int64_t> ExpressionEvaluator::evaluate(Val* value) {
   if (evaluator_precomputed_integers_ != nullptr) {
     return evaluator_precomputed_integers_->getMaybeValueFor(value);
   } else {
-    FUSER_PERF_SCOPE("ExpressionEvaluator::evaluate");
     auto maybe_concrete_value = getValue(value);
     if (!maybe_concrete_value.has_value()) {
       if (value->definition() != nullptr) {
diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp
index dfdf69740900..cfa88d0760bb 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp
@@ -433,6 +433,10 @@ FusionKernelRuntime::FusionKernelRuntime(
   }
 
   is_segmented_ = segmented;
+
+  if (is_segmented_) {
+    prepareRuntimeOrder();
+  }
 }
 
 std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
@@ -483,7 +487,6 @@ std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
     executors_[group_id].compileFusion(
         fusion_to_run.get(), options, inputs, launch_params);
   } else {
-    FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput::FetchFromCache");
     // Load launch params for reduction and normalization kernels
     if (scheduler_entry->hasReductionParam()) {
       launch_params = scheduler_entry->reductionParams().lparams;
@@ -493,7 +496,6 @@ std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
   }
 
   if (profiling_) {
-    FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput::profiling_");
     most_recent_executor_log_.fusion_executor = &executors_[group_id];
     most_recent_executor_log_.launch_constraints = launch_params;
     if (scheduler_entry->hasReductionParam()) {
@@ -508,40 +510,21 @@ std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
   return executors_[group_id].runFusion(inputs, launch_params, input_id);
 }
 
-std::vector<at::Tensor> FusionKernelRuntime::runMultiKernelWithInput(
-    const at::ArrayRef<IValue>& inputs,
-    size_t input_id) {
-  FUSER_PERF_SCOPE("FusionKernelRuntime::runMultiKernelWithInput");
+void FusionKernelRuntime::prepareRuntimeOrder() {
+  // Setup group run order:
+  std::unordered_set<Val*> available_input;
 
-  TORCH_INTERNAL_ASSERT(
-      inputs.size() == segmented_fusion_->inputs().size(),
-      "Inputs were not set up correctly, recieved ",
-      inputs.size(),
-      " inputs but expecting ",
-      segmented_fusion_->inputs().size());
-
-  // Map to keep track of currently available tensors
-  std::unordered_map<Val*, IValue> tensor_map;
-
-  // Bind input in the tensor_map
-  for (size_t i = 0; i < inputs.size(); i++) {
-    tensor_map.emplace(segmented_fusion_->inputs()[i], inputs[i]);
-
-    // Bind tensorview inputs values in case some segmented group
-    //  needs it down the road.
-    // TODO: we probably have done this already up to this point
-    //      should consider caching the expression evaluators, both
-    //      more convenient and safer than replication
-    if (inputs[i].isTensor()) {
-      auto aten_tensor = inputs[i].toTensor();
-      TORCH_INTERNAL_ASSERT(
-          segmented_fusion_->inputs()[i]->getValType() == ValType::TensorView);
-      auto input_tv = segmented_fusion_->inputs()[i]->as<TensorView>();
+  // Set up the order in which tensor dimensions are bound
+  for (size_t i : c10::irange(segmented_fusion_->inputs().size())) {
+    auto input_val = segmented_fusion_->inputs()[i];
+    available_input.insert(input_val);
+
+    if (auto input_tv = dynamic_cast<TensorView*>(input_val)) {
       auto root_dom = TensorDomain::noReductions(input_tv->getRootDomain());
-      for (size_t dim = 0; dim < root_dom.size(); dim++) {
+      for (size_t dim : c10::irange(root_dom.size())) {
         const auto extent = root_dom[dim]->extent();
-        const auto value = aten_tensor.sizes()[dim];
-        tensor_map.emplace(extent, value);
+        available_input.insert(extent);
+        runtime_workspace_.group_extent_binding_order.push_back(extent);
       }
     }
   }
@@ -554,38 +537,24 @@
     bool one_ran = false;
 
     // Find the first segment with all inputs available to run
-    for (size_t group_i = 0; group_i < segmented_fusion_->groups().size();
-         group_i++) {
+    for (size_t group_i : c10::irange(segmented_fusion_->groups().size())) {
       auto& group = segmented_fusion_->groups()[group_i];
       if (group_ran[group_i]) {
         continue;
       }
       const auto& group_inputs = group->inputs();
       bool ready_to_run = std::all_of(
-          group_inputs.begin(), group_inputs.end(), [&tensor_map](Val* val) {
-            return tensor_map.find(val) != tensor_map.end();
-          });
+          group_inputs.begin(),
+          group_inputs.end(),
+          [&available_input](Val* val) { return available_input.count(val); });
 
       if (ready_to_run) {
-        std::vector<IValue> group_runtime_inputs;
-        group_runtime_inputs.reserve(group_inputs.size());
-
-        // Prepare input vector
-        for (auto input : group_inputs) {
-          group_runtime_inputs.push_back(tensor_map.at(input));
-        }
-
-        // Run graph segment
-        auto group_runtime_outputs =
-            runKernelWithInput(group_runtime_inputs, input_id, group);
-
+        runtime_workspace_.group_run_order.push_back(group);
         const auto& group_outputs = group->outputs();
 
         // Insert graph segment output to tensor map
-        for (size_t group_out_i = 0; group_out_i < group_outputs.size();
-             group_out_i++) {
-          tensor_map.emplace(
-              group_outputs[group_out_i], group_runtime_outputs[group_out_i]);
+        for (size_t group_out_i : c10::irange(group_outputs.size())) {
+          available_input.insert(group_outputs[group_out_i]);
         }
         group_ran[group_i] = true;
         one_ran = true;
@@ -595,37 +564,100 @@ std::vector<at::Tensor> FusionKernelRuntime::runMultiKernelWithInput(
       one_ran,
       "Couldn't run all groups, something must have gone wrong in segmentation.");
   }
+}
 
-  // Produce final global output
-  std::vector<IValue> fusion_outputs;
-  for (auto output : segmented_fusion_->outputs()) {
-    const auto iter = tensor_map.find(output);
-    if (iter != tensor_map.end()) {
-      fusion_outputs.push_back(iter->second);
-    } else {
-      // This is the check for an empty tensor;
-      TORCH_INTERNAL_ASSERT(
-          output->as<TensorView>()->nDims() == 0 &&
-              output->getDataType().has_value() &&
-              output->getDataType().value() == DataType::Float,
-          "Non empty tensor cannot be found at tensor_map in ",
-          __FUNCTION__);
-      fusion_outputs.emplace_back(at::Tensor());
+std::vector<at::Tensor> FusionKernelRuntime::runWithInput(
+    const at::ArrayRef<IValue>& inputs,
+    size_t input_id) {
+  if (is_segmented_) {
+    FUSER_PERF_SCOPE("FusionKernelRuntime::runMultiKernelWithInput");
+
+    TORCH_INTERNAL_ASSERT(
+        inputs.size() == segmented_fusion_->inputs().size(),
+        "Inputs were not set up correctly, received ",
+        inputs.size(),
+        " inputs but expecting ",
+        segmented_fusion_->inputs().size());
+
+    int extent_index_ = 0;
+    // Bind input in the tensor_map
+    for (size_t i = 0; i < inputs.size(); i++) {
+      runtime_workspace_.tensor_map.emplace(
+          segmented_fusion_->inputs()[i], inputs[i]);
+
+      // Bind tensorview inputs values in case some segmented group
+      //  needs it down the road.
+      // TODO: we probably have done this already up to this point
+      //      should consider caching the expression evaluators, both
+      //      more convenient and safer than replication
+      if (inputs[i].isTensor()) {
+        auto aten_tensor = inputs[i].toTensor();
+        for (auto dim_size : aten_tensor.sizes()) {
+          runtime_workspace_.tensor_map.emplace(
+              runtime_workspace_.group_extent_binding_order[extent_index_++],
+              dim_size);
+        }
+      }
+    }
+
+    for (auto group_to_run : runtime_workspace_.group_run_order) {
+      // Prepare input vector
+      for (auto input : group_to_run->inputs()) {
+        runtime_workspace_.group_runtime_inputs.push_back(
+            runtime_workspace_.tensor_map.at(input));
+      }
+      // Run graph segment
+      runtime_workspace_.group_runtime_outputs = runKernelWithInput(
+          runtime_workspace_.group_runtime_inputs, input_id, group_to_run);
+
+      const auto& group_outputs = group_to_run->outputs();
+
+      // Insert graph segment output to tensor map
+      for (unsigned int group_out_i = 0; group_out_i < group_outputs.size();
+           group_out_i++) {
+        runtime_workspace_.tensor_map.emplace(
+            group_outputs[group_out_i],
+            runtime_workspace_.group_runtime_outputs[group_out_i]);
+      }
+      runtime_workspace_.group_runtime_inputs.clear();
+      runtime_workspace_.group_runtime_outputs.clear();
     }
-  }
 
-  std::vector<at::Tensor> fusion_output_tensors;
-  std::transform(
-      fusion_outputs.begin(),
-      fusion_outputs.end(),
-      std::back_inserter(fusion_output_tensors),
-      [](IValue ival) {
+    // Produce final global output
+    std::vector<IValue> fusion_outputs;
+    for (auto output : segmented_fusion_->outputs()) {
+      const auto iter = runtime_workspace_.tensor_map.find(output);
+      if (iter != runtime_workspace_.tensor_map.end()) {
+        fusion_outputs.push_back(iter->second);
+      } else {
+        // This is the check for an empty tensor;
         TORCH_INTERNAL_ASSERT(
-            ival.isTensor(), "Cannot output non-tensor objects from a fusion.");
-        return ival.toTensor();
-      });
+            output->as<TensorView>()->nDims() == 0 &&
+                output->getDataType().has_value() &&
+                output->getDataType().value() == DataType::Float,
+            "Non empty tensor cannot be found at tensor_map in ",
+            __FUNCTION__);
+        fusion_outputs.emplace_back(at::Tensor());
+      }
+    }
 
-  return fusion_output_tensors;
+    std::vector<at::Tensor> fusion_output_tensors;
+    std::transform(
+        fusion_outputs.begin(),
+        fusion_outputs.end(),
+        std::back_inserter(fusion_output_tensors),
+        [](IValue ival) {
+          TORCH_INTERNAL_ASSERT(
+              ival.isTensor(),
+              "Cannot output non-tensor objects from a fusion.");
+          return ival.toTensor();
+        });
+
+    runtime_workspace_.tensor_map.clear();
+    return fusion_output_tensors;
+  } else {
+    return runKernelWithInput(inputs, input_id);
+  }
 }
 
 const std::vector<FusionKernelRuntime::SchedulerEntryPtr>& FusionKernelRuntime::
diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.h b/torch/csrc/jit/codegen/cuda/kernel_cache.h
index dec29181628d..fc8c2a65497c 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_cache.h
+++ b/torch/csrc/jit/codegen/cuda/kernel_cache.h
@@ -61,13 +61,7 @@ class TORCH_CUDA_CU_API FusionKernelRuntime {
   //! Unified interface to run the managed kernels with given input
   std::vector<at::Tensor> runWithInput(
       const at::ArrayRef<IValue>& inputs,
-      size_t input_id) {
-    if (is_segmented_) {
-      return runMultiKernelWithInput(inputs, input_id);
-    } else {
-      return runKernelWithInput(inputs, input_id);
-    }
-  }
+      size_t input_id);
 
   //! Turn On/Off profiling
   void profile(bool to_profile = true) {
@@ -151,6 +145,8 @@ class TORCH_CUDA_CU_API FusionKernelRuntime {
   //! Access the list of schedulers maintained in this runtime instance
   const std::vector<SchedulerEntryPtr>& schedulers();
 
+  void prepareRuntimeOrder();
+
  private:
   //! Entries indexed by groupID:
   //!  Executors holding compiled kernels
@@ -174,6 +170,22 @@ class TORCH_CUDA_CU_API FusionKernelRuntime {
   //! TODO: unify the segmented and un-segmented code-path
   std::unique_ptr<executor_utils::caching::ExecutorCompileTimeInfoCache>
       single_kernel_fusion_data_cache_ = nullptr;
 
+  //! Pre-allocated runtime workspace to speed up kernel launch preparation.
+  struct RuntimeWorkSpace {
+    //! Temporary space to save intermediate tensors for segmented fusion
+    std::unordered_map<Val*, IValue> tensor_map;
+
+    //! Pre-determined order to run the segmented groups
+    std::vector<SegmentedGroup*> group_run_order;
+
+    //! Pre-determined order to bind tensor input metadata
+    std::vector<Val*> group_extent_binding_order;
+
+    //! Pre-allocated workspace to hold group inputs and outputs
+    std::vector<IValue> group_runtime_inputs;
+    std::vector<at::Tensor> group_runtime_outputs;
+  } runtime_workspace_;
+
   //! Utility to speed up integer evaluation at runtime
   std::unique_ptr<FusionPrecomputedIntegers> precomputed_integers_;
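Note on the kernel_cache changes above: the segment run order and the extent-binding order are now computed once in prepareRuntimeOrder() and replayed by every runWithInput() call, which then only binds concrete tensor sizes and walks the cached order. The snippet below is a minimal standalone sketch of that ready-list scheduling step, not nvFuser code: Group, prepareRunOrder and the string-valued "available" set are invented stand-ins for SegmentedGroup*, Val* and the runtime_workspace_ bookkeeping in the patch.

// Sketch: prepare the segment run order once, under the assumption that a
// group is runnable as soon as all of its inputs are available.
#include <algorithm>
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

struct Group {
  std::vector<std::string> inputs;  // values this segment consumes
  std::vector<std::string> outputs; // values this segment produces
};

// Repeatedly pick any unscheduled group whose inputs are all available,
// append it to the run order, and mark its outputs as available.
std::vector<const Group*> prepareRunOrder(
    const std::vector<Group>& groups,
    const std::vector<std::string>& global_inputs) {
  std::vector<const Group*> run_order;
  std::unordered_set<std::string> available(
      global_inputs.begin(), global_inputs.end());
  std::vector<bool> scheduled(groups.size(), false);

  while (run_order.size() < groups.size()) {
    bool one_scheduled = false;
    for (size_t i = 0; i < groups.size(); ++i) {
      if (scheduled[i]) {
        continue;
      }
      const bool ready = std::all_of(
          groups[i].inputs.begin(), groups[i].inputs.end(),
          [&](const std::string& v) { return available.count(v) > 0; });
      if (ready) {
        run_order.push_back(&groups[i]);
        available.insert(groups[i].outputs.begin(), groups[i].outputs.end());
        scheduled[i] = true;
        one_scheduled = true;
      }
    }
    // Mirrors the "Couldn't run all groups" assert in the patch: a
    // well-formed segmentation always makes progress.
    if (!one_scheduled) {
      assert(false && "segmentation has an unsatisfiable dependency");
      break;
    }
  }
  return run_order;
}

int main() {
  // g1 depends on g0's output, so the prepared order is {g0, g1}.
  std::vector<Group> groups = {
      {{"t0_out"}, {"final"}}, // g1: needs an intermediate value
      {{"in0"}, {"t0_out"}}};  // g0: only needs a global input
  auto order = prepareRunOrder(groups, {"in0"});
  assert(order.size() == 2 && order[0] == &groups[1]);
  return 0;
}

Doing the dependency resolution once at construction time is what lets the hot path in runWithInput() drop the per-call tensor-availability search and the perf scopes removed elsewhere in this patch.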
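Similarly, a standalone sketch of the deduplication behind getSimplifiedParallelIterExtents(): when several parallelized iterdomains fall into the same disjoint set of the compute-at parallel map, only one representative extent per set has to be resolved for launch-parameter binding. ParallelTag, Binding and the integer "extent" below are hypothetical stand-ins, and the warp-padded TIDx exception handled in the patch is deliberately left out.

// Sketch: keep one extent per (parallel type, disjoint set) pair.
#include <cstdio>
#include <map>
#include <set>
#include <utility>
#include <vector>

enum class ParallelTag { TIDx, TIDy, BIDx };

struct Binding {
  ParallelTag tag;
  int disjoint_set; // id of the set this domain belongs to in the CA map
  long extent;      // symbolic extent, stood in for by a plain integer
};

// Record only the first extent seen for each (tag, set) pair; every later
// member of the same set is redundant for launch-parameter binding.
std::map<ParallelTag, std::vector<long>> simplify(
    const std::vector<Binding>& bindings) {
  std::map<ParallelTag, std::vector<long>> simplified;
  std::set<std::pair<ParallelTag, int>> seen;
  for (const Binding& b : bindings) {
    if (!seen.insert({b.tag, b.disjoint_set}).second) {
      continue; // another domain of this set was already recorded
    }
    simplified[b.tag].push_back(b.extent);
  }
  return simplified;
}

int main() {
  // Three TIDx domains in the same set collapse to a single extent to bind.
  std::vector<Binding> bindings = {
      {ParallelTag::TIDx, 0, 128},
      {ParallelTag::TIDx, 0, 128},
      {ParallelTag::BIDx, 1, 1024},
      {ParallelTag::TIDx, 0, 128}};
  auto simplified = simplify(bindings);
  std::printf(
      "TIDx extents to bind: %zu\n", simplified[ParallelTag::TIDx].size());
  return 0;
}

The full ParallelIterExtentMap is still cached, as the new header comment explains, because the concrete values must eventually be bound to every parallelized iterdomain's extent; the simplified map only shrinks the work done while inferring launch parameters.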