diff --git a/benchmarks/cpp/nvfuser/heuristic_lookup.cpp b/benchmarks/cpp/nvfuser/heuristic_lookup.cpp index 64b1ecfb756d4..3bd4ec0b1607d 100644 --- a/benchmarks/cpp/nvfuser/heuristic_lookup.cpp +++ b/benchmarks/cpp/nvfuser/heuristic_lookup.cpp @@ -99,12 +99,15 @@ static void LayerNormBackward_HeuristicLookup( auto runtime = getLayerBackwardNormRuntime( std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape); + + KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + TORCH_INTERNAL_ASSERT( - runtime->getMaybeHeuristicsFor(aten_inputs).has_value()); + runtime->getMaybeHeuristicsFor(args).has_value()); for (auto _ : benchmark_state) { // Setup (not included in the measurement) - runtime->getMaybeHeuristicsFor(aten_inputs); + runtime->getMaybeHeuristicsFor(args); } } @@ -152,12 +155,15 @@ static void LayerNormForward_HeuristicLookup( auto runtime = getLayerForwardNormRuntime( std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape); + + KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + TORCH_INTERNAL_ASSERT( - runtime->getMaybeHeuristicsFor(aten_inputs).has_value()); + runtime->getMaybeHeuristicsFor(args).has_value()); for (auto _ : benchmark_state) { // Setup (not included in the measurement) - runtime->getMaybeHeuristicsFor(aten_inputs); + runtime->getMaybeHeuristicsFor(args); } } diff --git a/benchmarks/cpp/nvfuser/shape_inference.cpp b/benchmarks/cpp/nvfuser/shape_inference.cpp index 2e5e23ed7442e..fd628a163abce 100644 --- a/benchmarks/cpp/nvfuser/shape_inference.cpp +++ b/benchmarks/cpp/nvfuser/shape_inference.cpp @@ -100,8 +100,11 @@ void LayerNormBackward_ShapeInference_Base( auto runtime = getLayerBackwardNormRuntime( std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape); + + KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + TORCH_INTERNAL_ASSERT( - runtime->getMaybeHeuristicsFor(aten_inputs).has_value()); + runtime->getMaybeHeuristicsFor(args).has_value()); fec->profile(true); fec->disableKernelLaunch(); @@ -172,8 +175,10 @@ void LayerNormForward_ShapeInferenceBase( auto runtime = getLayerForwardNormRuntime( std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape); + KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + TORCH_INTERNAL_ASSERT( - runtime->getMaybeHeuristicsFor(aten_inputs).has_value()); + runtime->getMaybeHeuristicsFor(args).has_value()); fec->profile(true); fec->disableKernelLaunch(); diff --git a/torch/csrc/jit/codegen/cuda/evaluator_common.cpp b/torch/csrc/jit/codegen/cuda/evaluator_common.cpp index 83107569dc54b..f6d84cbb208c2 100644 --- a/torch/csrc/jit/codegen/cuda/evaluator_common.cpp +++ b/torch/csrc/jit/codegen/cuda/evaluator_common.cpp @@ -383,19 +383,19 @@ KernelPrecomputedIntegers::KernelPrecomputedIntegers(kir::Kernel* kernel) { initializeIntegerMachine(); } +// TODO: put this to base class void KernelPrecomputedIntegers::bindTensorMetaData( TensorView* tv, - const at::Tensor& at_tensor) { - std::vector> ret; + const TensorArgAbstract* tensor_arg_abstract) { const auto root_domain = TensorDomain::noReductions(tv->domain()->getMaybeRFactorDomain()); TORCH_INTERNAL_ASSERT( - at_tensor.ndimension() == static_cast(root_domain.size()), + tensor_arg_abstract->getRank() == static_cast(root_domain.size()), "Something went wrong configuring launch. 
Inputs do not match."); for (const auto dim : c10::irange(root_domain.size())) { auto extent = root_domain[dim]->extent(); - auto value = at_tensor.sizes()[dim]; + auto value = tensor_arg_abstract->getSize(dim); bindValue(extent->evaluatorIndex(), value); } } @@ -434,22 +434,37 @@ void KernelPrecomputedIntegers::initializeNamedScalars() { } } +// TODO: merge this one with above. void KernelPrecomputedIntegers::bindKernelInputs( kir::Kernel* kernel, - const at::ArrayRef& aten_inputs) { + const KernelArgumentHolder& args) { if (hasValidValues()) { invalidate(); } const auto& inputs = kernel->inputs(); + TORCH_INTERNAL_ASSERT( + args.size() == inputs.size(), "kernel inputs size does not match args"); for (const auto i : c10::irange(inputs.size())) { + auto arg = args[i]; const auto input = inputs[i]; if (auto tensor_input = dynamic_cast(input)) { - const auto aten_tensor = aten_inputs[i].toTensor(); - bindTensorMetaData(tensor_input, aten_tensor); + if (const auto& tensor_arg_abstract = + dynamic_cast(arg)) { + bindTensorMetaData(tensor_input, tensor_arg_abstract); + } else { + // TODO: cpu scalar of int type should be bound as scalar int as well + TORCH_CHECK( + arg->isType(ArgType::CpuScalarTensor), + "binding input to TensorView expects input arg to be of tensor type"); + } } else if (input->isScalar() && input->dtype() == DataType::Int) { - bindValue(input->evaluatorIndex(), aten_inputs[i].toInt()); + TORCH_CHECK( + arg->isType(ArgType::Long), + "binding input to integer type expects input arg to be a scalar of Long type"); + precomputedIntegersBaseType::bindValue( + input->evaluatorIndex(), *static_cast(arg->arg())); } } } @@ -489,38 +504,51 @@ FusionPrecomputedIntegers::FusionPrecomputedIntegers(Fusion* fusion) initializeIntegerMachine(); } +// TODO: put this to base class void FusionPrecomputedIntegers::bindTensorMetaData( TensorView* tv, - const at::Tensor& at_tensor) { + const TensorArgAbstract* tensor_arg_abstract) { const auto root_domain = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); TORCH_INTERNAL_ASSERT( - at_tensor.ndimension() == static_cast(root_domain.size()), + tensor_arg_abstract->getRank() == static_cast(root_domain.size()), "Something went wrong configuring launch. 
Inputs do not match."); for (const auto dim : c10::irange(root_domain.size())) { auto extent = root_domain[dim]->extent(); - auto value = at_tensor.sizes()[dim]; + auto value = tensor_arg_abstract->getSize(dim); precomputedIntegersBaseType::bindValue(extent->evaluatorIndex(), value); } } void FusionPrecomputedIntegers::bindFusionInputs( - const at::ArrayRef& aten_inputs) { + const KernelArgumentHolder& args) { if (hasValidValues()) { precomputedIntegersBaseType::invalidate(); } const auto& inputs = fusion_->inputs(); + TORCH_INTERNAL_ASSERT( + args.size() == inputs.size(), "kernel inputs size does not match args"); for (const auto i : c10::irange(inputs.size())) { const auto input = inputs[i]; + const ArgAbstract* arg = args[i]; if (auto tensor_input = dynamic_cast(input)) { - const auto aten_tensor = aten_inputs[i].toTensor(); - bindTensorMetaData(tensor_input, aten_tensor); + if (const auto& tensor_arg_abstract = + dynamic_cast(arg)) { + bindTensorMetaData(tensor_input, tensor_arg_abstract); + } else { + TORCH_CHECK( + arg->isType(ArgType::CpuScalarTensor), + "binding input to TensorView expects input arg to be of tensor type"); + } } else if (input->isScalar() && input->getDataType() == DataType::Int) { + TORCH_CHECK( + arg->isType(ArgType::Long), + "binding input to integer type expects input arg to be a scalar of Long type"); precomputedIntegersBaseType::bindValue( - input->evaluatorIndex(), aten_inputs[i].toInt()); + input->evaluatorIndex(), *static_cast(arg->arg())); } } } diff --git a/torch/csrc/jit/codegen/cuda/evaluator_common.h b/torch/csrc/jit/codegen/cuda/evaluator_common.h index 7cbe37c602b9e..34e57a124fcae 100644 --- a/torch/csrc/jit/codegen/cuda/evaluator_common.h +++ b/torch/csrc/jit/codegen/cuda/evaluator_common.h @@ -279,10 +279,12 @@ class FusionPrecomputedIntegers FusionPrecomputedIntegers(Fusion* fusion); //! Bind concrete values from fusion runtime inputs - void bindFusionInputs(const at::ArrayRef& aten_inputs); + void bindFusionInputs(const KernelArgumentHolder& args); private: - void bindTensorMetaData(TensorView* tv, const at::Tensor& at_tensor); + void bindTensorMetaData( + TensorView* tv, + const TensorArgAbstract* tensor_arg_abstract); private: Fusion* fusion_ = nullptr; @@ -302,9 +304,7 @@ class KernelPrecomputedIntegers KernelPrecomputedIntegers(kir::Kernel* kernel); //! Bind concrete values from fusion runtime inputs - void bindKernelInputs( - kir::Kernel* kernel, - const at::ArrayRef& aten_inputs); + void bindKernelInputs(kir::Kernel* kernel, const KernelArgumentHolder& args); //! Bind concrete values from launch constraints void bindParallelExtents( @@ -317,7 +317,9 @@ class KernelPrecomputedIntegers void bindConcreteParallelTypeValue(ParallelType pt, int64_t value); private: - void bindTensorMetaData(TensorView* tv, const at::Tensor& at_tensor); + void bindTensorMetaData( + TensorView* tv, + const TensorArgAbstract* tensor_arg_abstract); //! Iterate through all the named scalars corresponding //! 
to thread sizes and pre-group them by their parallel diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index 743feafcb4351..358be90876689 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -161,9 +161,8 @@ void FusionExecutor::debugCompileFusionFromStr( void FusionExecutor::compileFusion( Fusion* fusion, - const at::ArrayRef& inputs, - const LaunchParams& launch_constraints, - CompileOptions options) { + const KernelArgumentHolder& args, + const LaunchParams& launch_constraints) { FUSER_PERF_SCOPE("compileFusion"); TORCH_INTERNAL_ASSERT( @@ -181,7 +180,9 @@ void FusionExecutor::compileFusion( fusion->printMath(); } - options_ = options; + // TODO: refactor the options_ passed through + options_.device = c10::Device(c10::DeviceType::CUDA, args.getDeviceIndex()); + options_.index_mode = args.getIndexMode(); c10::DeviceGuard dg(options_.device); TORCH_INTERNAL_ASSERT( @@ -240,8 +241,8 @@ void FusionExecutor::compileFusion( // TODO: pass block_size here; c10::optional block_size = c10::nullopt; - if (!inputs.empty()) { - auto expr_eval = executor_utils::bindKernelInputs(inputs, kernel); + if (!args.empty()) { + auto expr_eval = executor_utils::bindKernelInputs(args, kernel); auto launch_params = computeLaunchParams(launch_constraints, expr_eval, warp_size_); block_size = launch_params.nThreads(); @@ -249,8 +250,15 @@ void FusionExecutor::compileFusion( block_size > 0, "launch param inferred block size < 0"); } - block_size_high_water_mark = - block_size.has_value() ? block_size.value() : block_size_high_water_mark; + // TODO: high water mark should be computed via occupancy API after + // compilation. + + // Basically setting high water martk as 1 when we don't provide args for + // compilation, it will just generate a kernel that gets ditched at the first + // run - not great. We should have better heuristics. + block_size_high_water_mark = std::max( + (block_size.has_value() ? 
block_size.value() : 1), + block_size_high_water_mark); std::tie(compiled_kernel_, last_compiler_log_) = executor_utils::nvrtcCompile( structured_code, (kernelNamespace() + "::" + kernelName()).c_str(), @@ -327,22 +335,20 @@ at::Tensor inferAndAlloc( } const auto at_type = data_type_to_aten(tv->dtype()); + const auto tensor_options = + at::TensorOptions().dtype(at_type).device(options.device); + c10::IntArrayRef isizes(inferred_sizes); if (zero_init) { - const auto tensor_options = - at::TensorOptions().dtype(at_type).device(options.device); - c10::IntArrayRef isizes(inferred_sizes); auto zeros = at::zeros(isizes, tensor_options); if (expanded_dim) { return zeros.expand(expanded_sizes); } return zeros; } else { - c10::IntArrayRef isizes(inferred_sizes); // Non Variable type guard for empty_cuda call at::AutoDispatchBelowADInplaceOrView non_variable_type_mode; - auto empty = at::native::empty_cuda( - isizes, at_type, c10::nullopt, options.device, c10::nullopt); + auto empty = at::empty(isizes, tensor_options); if (expanded_dim) { return empty.expand(expanded_sizes); } @@ -656,25 +662,82 @@ FusionExecutor::GlobalBuffers FusionExecutor::allocGlobalVals( } std::vector FusionExecutor::allocOutputs( - const at::ArrayRef& inputs, kir::ExpressionEvaluator& expr_eval, const std::unordered_set& alias_indices) { FUSER_PERF_SCOPE("FusionExecutor::AllocOutputs"); const auto kernel = lowered_->kernel(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector outputs; + for (const auto out_i : c10::irange(kernel->outputs().size())) { + // TODO: FIX this short-cut where we trivially forward inputs to outputs + if (kernel->outputs()[out_i]->isFusionInput()) { + TORCH_INTERNAL_ASSERT(false, "trivial input forwarding NOT IMPLEMENTED"); + // for (auto inp_i : c10::irange(kernel->inputs().size())) { + // if (kernel->inputs()[inp_i] == kernel->outputs()[out_i]) { + // TORCH_INTERNAL_ASSERT( + // inp_i < inputs.size(), + // "Issue with an input showing up as output, couldn't find + // input."); + // TORCH_INTERNAL_ASSERT( + // inputs[inp_i].isTensor(), + // "Cannot register a scalar as an output in a fusion."); + // outputs.push_back(inputs[inp_i].toTensor()); + // break; + // } + // } + } else { + TORCH_INTERNAL_ASSERT( + kernel->outputs()[out_i]->isA(), + "Cannot allocate outputs that are not tensors."); + auto output = kernel->outputs()[out_i]->as(); + if (alias_indices.count(out_i) != 0) { + // aliasing to inputs, no need to allocate real output, just push empty + // tensor here. + outputs.emplace_back(); + } else { + outputs.push_back( + inferAndAllocOutput(output, expr_eval, options_, false)); + } + } + } + return outputs; +} + +void FusionExecutor::setUsedTVs() { + auto used_vals = fusion_->usedMathVals(); + auto used_tvs = ir_utils::filterByType(used_vals); + used_tvs_.clear(); + used_tvs_.insert(used_tvs_.begin(), used_tvs.begin(), used_tvs.end()); +} + +KernelArgumentHolder FusionExecutor::evaluateOutputSizes( + const KernelArgumentHolder& args, + kir::ExpressionEvaluator& expr_eval, + const std::unordered_set& alias_indices) { + FUSER_PERF_SCOPE("FusionExecutor::AllocOutputs"); + const auto kernel = lowered_->kernel(); + + KernelArgumentHolder ret(args.getIndexMode()); + ret.setDeviceIndex(args.getDeviceIndex()); + + CompileOptions meta_options = options_; + meta_options.device = c10::Device(DeviceType::Meta, 0); + for (const auto out_i : c10::irange(kernel->outputs().size())) { // If the output is just trivially the input, just "copy" it over. 
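      // (no real allocation happens for such an output; the meta data already
      // recorded for the matching input argument is reused as the entry)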
if (kernel->outputs()[out_i]->isFusionInput()) { for (auto inp_i : c10::irange(kernel->inputs().size())) { if (kernel->inputs()[inp_i] == kernel->outputs()[out_i]) { TORCH_INTERNAL_ASSERT( - inp_i < inputs.size(), + inp_i < args.size(), "Issue with an input showing up as output, couldn't find input."); + + auto tensor_arg_abstract = + dynamic_cast(args[inp_i]); TORCH_INTERNAL_ASSERT( - inputs[inp_i].isTensor(), + tensor_arg_abstract, "Cannot register a scalar as an output in a fusion."); - outputs.push_back(inputs[inp_i].toTensor()); + ret.push(tensor_arg_abstract); break; } } @@ -685,51 +748,109 @@ std::vector FusionExecutor::allocOutputs( auto output = kernel->outputs()[out_i]->as(); if (alias_indices.count(out_i) != 0) { // aliasing to inputs, no need to allocate real output - outputs.push_back( - inferAndAlloc(output, {}, expr_eval, {}, options_, false)); + // but we still need to push an entry here. + ret.push(int64_t(0)); } else { - // Allocate a real output - outputs.push_back( - inferAndAllocOutput(output, expr_eval, options_, false)); + // TODO: we are using meta here, which is bad since it doesn't account + // for devices. Switch to fake tensor instead + ret.push(inferAndAllocOutput(output, expr_eval, meta_options, false)); } } } - return outputs; + return ret; } -void FusionExecutor::setUsedTVs() { - auto used_vals = fusion_->usedMathVals(); - auto used_tvs = ir_utils::filterByType(used_vals); - used_tvs_.clear(); +KernelArgumentHolder FusionExecutor::inferOutputSizes( + const KernelArgumentHolder& args, + const LaunchParams& launch_constraints) { + FUSER_PERF_SCOPE("FusionExecutor::RunFusion"); + + ExecutorEntry* executor_entry = nullptr; + c10::optional opt_code = args.getCacheId(); + if (opt_code.has_value()) { + executor_entry = &executor_entry_lookup_[*opt_code]; + } + + executor_utils::initializeCudaContext(); + TORCH_INTERNAL_ASSERT(lowered_); + + TORCH_INTERNAL_ASSERT( + !executor_entry || !executor_entry->init, + "compile kernel shouldn't hit a pre-existing cache"); + FUSER_PERF_SCOPE("ExecutorRunFusion::ValidateAndInitialize"); + // TODO: validate kernel inputs currently won't be happy, since our fusion + // args are mapped with `meta` tensor instead of `cuda` tensor, check if this + // would be resolved with FakeTensor + // executor_utils::validateKernelInputs(fusion_, args, options_.device); + + if (!evaluator_precomputed_integers_) { + evaluator_precomputed_integers_ = + std::make_unique(lowered_->kernel()); + } + + kir::ExpressionEvaluator expr_eval; + evaluator_precomputed_integers_->bindKernelInputs(lowered_->kernel(), args); + expr_eval.precomputedIntegers() = evaluator_precomputed_integers_.get(); + + // I think this binds something to expr_eval, so even though we are not using + // launch_params_, we still need this in order to infer output shapes. 
+ launch_params_ = + computeLaunchParams(launch_constraints, expr_eval, warp_size_); - for (auto tv : used_tvs) - used_tvs_.push_back(tv); + executor_utils::validateVectorizedTensors( + lowered_.get()->kernel(), args, {}, compileTimeDataCache(), expr_eval); + + auto alias_indices_entry = executor_utils::caching::ExecutorCompileTimeEntry< + executor_utils::caching::InputAliasIndices>( + compileTimeDataCache(), [&]() { + return std::make_unique>>( + fusion_->getInputAliasIndices()); + }); + + auto& alias_indices = alias_indices_entry.get(); + + // NOLINTNEXTLINE(bugprone-branch-clone) + auto output_alias_indices_entry = + executor_utils::caching::ExecutorCompileTimeEntry< + executor_utils::caching::OutputAliasIndices>( + compileTimeDataCache(), [&]() { + return std::make_unique>( + fusion_->getOutputAliasIndices()); + }); + + auto& output_alias_indices = output_alias_indices_entry.get(); + + auto ret = evaluateOutputSizes(args, expr_eval, output_alias_indices); + + for (const auto& entry : alias_indices) { + auto aliased_output_index = entry.first; + auto aliased_input_index = entry.second; + TORCH_INTERNAL_ASSERT( + args[aliased_input_index]->isType(ArgType::Tensor), + "alias io only supports tensor"); + ret.swap(aliased_output_index, args[aliased_input_index]); + } + + return ret; } std::vector FusionExecutor::runFusion( - const at::ArrayRef& inputs, - const std::vector& outputs, + KernelArgumentHolder& args, const LaunchParams& launch_constraints, - const c10::optional& opt_code) { + const std::vector& outputs) { FUSER_PERF_SCOPE("FusionExecutor::RunFusion"); TORCH_INTERNAL_ASSERT(compiled()); TORCH_INTERNAL_ASSERT( fusion_id_ > 0, "Cannot run fusion, it was not compiled."); TORCH_INTERNAL_ASSERT( - !opt_code.has_value() || outputs.empty(), + !args.getCacheId().has_value() || outputs.empty(), "short cut input cache is not compatible with pre-allocated output"); if (isDebugDumpEnabled(DebugDumpOption::FusionArgs)) { std::cout << "Arguments for fusion" << fusion_id_ << ":" << std::endl << "Inputs:" << std::endl; - for (const auto& input : inputs) { - if (input.isTensor()) { - const auto& input_tensor = input.toTensor(); - std::cout << " " << input_tensor.scalar_type() << " " - << input.toTensor().sizes() - << " (strides = " << input.toTensor().strides() << ")" - << std::endl; - } + for (auto i : c10::irange(args.size())) { + args[i]->print(); } std::cout << "Outputs:" << std::endl; for (const auto& output : outputs) { @@ -740,8 +861,8 @@ std::vector FusionExecutor::runFusion( } ExecutorEntry* executor_entry = nullptr; - if (opt_code.has_value()) { - executor_entry = &executor_entry_lookup_[*opt_code]; + if (args.getCacheId().has_value()) { + executor_entry = &executor_entry_lookup_[*args.getCacheId()]; } c10::DeviceGuard dg(options_.device); @@ -750,7 +871,7 @@ std::vector FusionExecutor::runFusion( TORCH_INTERNAL_ASSERT(lowered_); launch_params_ = LaunchParams(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector allocated_outputs = outputs; + std::vector allocated_outputs; GlobalBuffers global_buffers; uint64_t rand_offset = 0; @@ -772,17 +893,28 @@ std::vector FusionExecutor::runFusion( options_.device, c10::nullopt)); } + // Note: aliased output is not returned as output. 
But we still need it + // for kernel execution, so would need to push them to args for (const auto& entry : executor_entry->io_alias_indices) { + auto aliased_output_index = entry.first; + auto aliased_input_index = entry.second; + auto tensor_arg_abstract = + dynamic_cast(args[aliased_input_index]); TORCH_INTERNAL_ASSERT( - inputs[entry.second].isTensor(), "alias io only supports tensor"); - allocated_outputs[entry.first] = inputs[entry.second].toTensor(); + tensor_arg_abstract, "alias io only supports tensor"); + allocated_outputs[aliased_output_index] = + tensor_arg_abstract->getTensor(); } + args.push(allocated_outputs); } else { TORCH_INTERNAL_ASSERT( outputs.size() == fusion_->outputs().size(), __func__, " provided number of outputs does match fusion output"); + allocated_outputs = outputs; + args.push(outputs); } + { FUSER_PERF_SCOPE("ExecutorRunFusion::IntermediateBufferAlloc"); for (const auto i : c10::irange(executor_entry->buffer_sizes.size())) { @@ -811,7 +943,7 @@ std::vector FusionExecutor::runFusion( // code path to take when either: // 1. no opt_code is provided or // 2. `executor_entry` is not initialized - executor_utils::validateKernelInputs(fusion_, inputs, options_.device); + executor_utils::validateKernelInputs(fusion_, args, options_.device); if (!evaluator_precomputed_integers_) { evaluator_precomputed_integers_ = @@ -819,8 +951,7 @@ std::vector FusionExecutor::runFusion( } kir::ExpressionEvaluator expr_eval; - evaluator_precomputed_integers_->bindKernelInputs( - lowered_->kernel(), inputs); + evaluator_precomputed_integers_->bindKernelInputs(lowered_->kernel(), args); expr_eval.precomputedIntegers() = evaluator_precomputed_integers_.get(); launch_params_ = @@ -879,7 +1010,7 @@ std::vector FusionExecutor::runFusion( executor_utils::validateVectorizedTensors( lowered_.get()->kernel(), - inputs, + args, outputs, compileTimeDataCache(), expr_eval); @@ -894,7 +1025,6 @@ std::vector FusionExecutor::runFusion( auto& alias_indices = alias_indices_entry.get(); - // ditch pre-allocated outputs if the number doesn't match. // NOLINTNEXTLINE(bugprone-branch-clone) if (outputs.empty()) { auto output_alias_indices_entry = @@ -907,15 +1037,22 @@ std::vector FusionExecutor::runFusion( auto& output_alias_indices = output_alias_indices_entry.get(); - allocated_outputs = allocOutputs(inputs, expr_eval, output_alias_indices); + allocated_outputs = allocOutputs(expr_eval, output_alias_indices); for (const auto& entry : alias_indices) { + auto aliased_output_index = entry.first; + auto aliased_input_index = entry.second; + auto tensor_arg_abstract = + dynamic_cast(args[aliased_input_index]); TORCH_INTERNAL_ASSERT( - inputs[entry.second].isTensor(), "alias io only supports tensor"); - allocated_outputs[entry.first] = inputs[entry.second].toTensor(); + tensor_arg_abstract, "alias io only supports tensor"); + allocated_outputs[aliased_output_index] = + tensor_arg_abstract->getTensor(); } + args.push(allocated_outputs); } else { - // TODO: Update for aliasing, validate the outputs are the right sizes. 
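      // caller provided pre-allocated outputs: reuse them directly, push them
      // into the kernel argument list, and validate their shapes/devices below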
+ allocated_outputs = outputs; + args.push(outputs); executor_utils::validateKernelOutputs( fusion_, allocated_outputs, options_.device); } @@ -957,15 +1094,12 @@ std::vector FusionExecutor::runFusion( } } - KernelArgumentHolder kernel_arguments(options_.index_mode); - { - FUSER_PERF_SCOPE("ExecutorRunFusion::FillKernelArgStructure"); - kernel_arguments.push(inputs); - kernel_arguments.push(allocated_outputs); - kernel_arguments.push(global_buffers.buffers); - if (lowered_->kernel()->summary().max_rng_offsets >= 0) { - kernel_arguments.appendPhiloxRNGSeed(rand_offset); - } + // push back global buffers + args.push(global_buffers.buffers); + + // push back RNG state if needed + if (lowered_->kernel()->summary().max_rng_offsets >= 0) { + args.appendPhiloxRNGSeed(rand_offset); } if (isDebugDumpEnabled(DebugDumpOption::LaunchParam)) { @@ -975,17 +1109,11 @@ std::vector FusionExecutor::runFusion( if (isDebugDumpEnabled(DebugDumpOption::KernelArgs)) { std::cout << "Arguments for kernel" << fusion_id_ << ":" << std::endl << "Inputs:" << std::endl; - for (const auto& input : inputs) { - if (input.isTensor()) { - const auto& input_tensor = input.toTensor(); - std::cout << " " << input_tensor.scalar_type() << " " - << input.toTensor().sizes() - << " (strides = " << input.toTensor().strides() - << ", address = " << input.toTensor().data_ptr() << ")" - << std::endl; - } + for (auto i : c10::irange(args.size())) { + args[i]->print(); } std::cout << "Outputs:" << std::endl; + // note: add aliased outputs here. for (const auto& output : allocated_outputs) { std::cout << " " << output.scalar_type() << " " << output.sizes() << " (strides = " << output.strides() @@ -1040,7 +1168,7 @@ std::vector FusionExecutor::runFusion( launch_params_.bdimz(), launch_params_.smem(), stream, - kernel_arguments.getBuffer(), + args.getBuffer(), nullptr)); } else { #ifndef __HIP_PLATFORM_HCC__ @@ -1056,7 +1184,7 @@ std::vector FusionExecutor::runFusion( launch_params_.bdimz(), launch_params_.smem(), stream, - kernel_arguments.getBuffer())); + args.getBuffer())); #else TORCH_INTERNAL_ASSERT( false, "Cross grid communication not supported with HIP."); @@ -1076,10 +1204,11 @@ std::vector FusionExecutor::runFusion( bytes_processed_ = 0; // Figure how many bytes are inputs, outputs, and temporary buffers - for (auto input : inputs) { - if (input.isTensor()) { - bytes_processed_ += input.toTensor().numel() * - dataTypeSize(aten_to_data_type(input.toTensor().scalar_type())); + for (auto i : c10::irange(args.size())) { + if (auto tensor_arg_abstract = + dynamic_cast(args[i])) { + bytes_processed_ += tensor_arg_abstract->numel() * + dataTypeSize(tensor_arg_abstract->getDataType()); } } for (const auto& output : allocated_outputs) { diff --git a/torch/csrc/jit/codegen/cuda/executor.h b/torch/csrc/jit/codegen/cuda/executor.h index 7ff6b7da3aaad..4a7c9d61e3e81 100644 --- a/torch/csrc/jit/codegen/cuda/executor.h +++ b/torch/csrc/jit/codegen/cuda/executor.h @@ -33,17 +33,46 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { int id, CompileOptions options = CompileOptions()); + //! infers output sizes via returning non-allocated KernelArgumentHolder. + //! this function is useful for async compilation for segmented fusion + KernelArgumentHolder inferOutputSizes( + const KernelArgumentHolder& args, + const LaunchParams& launch_constraints); + + void compileFusion( + Fusion* fusion, + const KernelArgumentHolder& args, + const LaunchParams& launch_constraints = LaunchParams()); + + // TODO: merge it with the overload above. + //! 
This API is merely here so we don't have to go back and update all cpp + //! tests. void compileFusion( Fusion* fusion, const at::ArrayRef& inputs = {}, + const LaunchParams& launch_constraints = LaunchParams()) { + KernelArgumentHolder args = + KernelArgumentHolder::createKernelArgumentHolder(inputs); + compileFusion(fusion, args, launch_constraints); + } + + std::vector runFusion( + KernelArgumentHolder& args, const LaunchParams& launch_constraints = LaunchParams(), - CompileOptions options = CompileOptions()); + const std::vector& outputs = {}); std::vector runFusion( const at::ArrayRef& inputs, const std::vector& outputs, const LaunchParams& launch_constraints = LaunchParams(), - const c10::optional& opt_code = c10::nullopt); + const c10::optional& opt_code = c10::nullopt) { + KernelArgumentHolder args = + KernelArgumentHolder::createKernelArgumentHolder(inputs); + if (opt_code.has_value()) { + args.setCacheId(*opt_code); + } + return runFusion(args, launch_constraints, outputs); + } std::vector runFusion( const at::ArrayRef& inputs, @@ -188,7 +217,6 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { // skip allocating real storage for those, but still maintain its spot to // maintain the indexing from output aliases to inputs std::vector allocOutputs( - const at::ArrayRef& inputs, kir::ExpressionEvaluator& expr_eval, const std::unordered_set& alias_indices = {}); @@ -202,6 +230,15 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { return &compile_time_info_cache_; } + //! returns KernelArgumentHolder representing the output sizes from kernel + //! execution. Note: 1. this API would ignoring aliased outputs and instead + //! pushing scalar int 0 as a place holder; 2. this API doesn't actually + //! allocate output in memory, but rather is used just to infer output sizes. 
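  //! 3. output meta data is computed with tensors on the meta device, so no
  //! CUDA memory is allocated while inferring sizes.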
+ KernelArgumentHolder evaluateOutputSizes( + const KernelArgumentHolder& args, + kir::ExpressionEvaluator& expr_eval, + const std::unordered_set& alias_indices = {}); + private: CompileOptions options_; diff --git a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp index dc2e4d1fa49f9..bc1ce2a4b7bc2 100644 --- a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp @@ -120,6 +120,24 @@ std::unique_ptr getTensorArg( } // namespace +KernelArgumentHolder KernelArgumentHolder::createKernelArgumentHolder( + const c10::ArrayRef& inputs) { + if (inputs.empty()) { + // default to int32 on device 0 + KernelArgumentHolder args(KernelIndexMode::INT32); + args.setDeviceIndex(0); + return args; + } + auto device_index = getCommonDeviceCUDA(inputs); + auto index_mode = collectIndexMode(inputs); + + KernelArgumentHolder args(index_mode); + args.setDeviceIndex(device_index); + args.push(inputs); + + return args; +} + // Push a tensor to the arguments void KernelArgumentHolder::push(const at::Tensor& tensor) { changed_ = true; @@ -188,7 +206,9 @@ void KernelArgumentHolder::push(const at::Tensor& tensor) { c10::ScalarType dtype = tensor.scalar_type(); std::unique_ptr tensor_arg = getTensorArg(dtype, nDims, index_mode_); + tensor_arg->setTensor(tensor); tensor_arg->setPointer(tensor.data_ptr()); + tensor_arg->setDataType(aten_to_data_type(dtype)); for (const auto i : c10::irange(nDims)) { tensor_arg->setSize(i, tensor.sizes()[i]); tensor_arg->setStride(i, tensor.strides()[i]); @@ -230,6 +250,10 @@ void KernelArgumentHolder::push(const IValue& val) { " Tried to create argument to send to a fused kernel, but got a non-scalar type."); } +void KernelArgumentHolder::push(int64_t val) { + arguments_.push_back(std::make_unique(val)); +} + void KernelArgumentHolder::push(const at::PhiloxCudaState& val) { arguments_.push_back(std::make_unique(val)); } @@ -266,6 +290,17 @@ void KernelArgumentHolder::push(const std::vector& tensors) { } } +void KernelArgumentHolder::push(const ArgAbstract* arg) { + changed_ = true; + arguments_.emplace_back(arg->copy_unique_ptr()); +} + +void KernelArgumentHolder::swap(int i, const ArgAbstract* arg) { + changed_ = true; + auto holder = arg->copy_unique_ptr(); + arguments_[i].swap(holder); +} + void KernelArgumentHolder::appendPhiloxRNGSeed(uint64_t rand_offset) { at::PhiloxCudaState philox_engine_inputs; auto gen = at::cuda::detail::getDefaultCUDAGenerator(); diff --git a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h index c135328a3acc1..ddf4379373269 100644 --- a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h +++ b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -21,7 +22,7 @@ struct TensorArgCodegen { T* data; std::array size; std::array stride; - constexpr int nDims() { + constexpr int nDims() const { return N; } void setSize(int i, nvfuser_index_t s) { @@ -30,6 +31,12 @@ struct TensorArgCodegen { void setStride(int i, nvfuser_index_t s) { stride[i] = s; } + nvfuser_index_t getSize(int i) const { + return size[i]; + } + nvfuser_index_t getStride(int i) const { + return stride[i]; + } }; // 0-Dim GPU based tensor @@ -40,7 +47,7 @@ struct TensorArgCodegen { }; T* data; - constexpr int nDims() { + constexpr int nDims() const { return 0; } void setSize(int, nvfuser_index_t) { @@ -49,6 +56,12 @@ struct TensorArgCodegen { void setStride(int, 
nvfuser_index_t) { TORCH_INTERNAL_ASSERT(false, "Tried to set stride of a 0-dim tensor"); } + nvfuser_index_t getSize(int i) const { + TORCH_INTERNAL_ASSERT(false, "Tried to get size of a 0-dim tensor"); + } + nvfuser_index_t getStride(int i) const { + TORCH_INTERNAL_ASSERT(false, "Tried to get stride of a 0-dim tensor"); + } }; // Specialization for 0-dim case that's easy to pass in a CPU based tensor @@ -62,62 +75,151 @@ struct CpuScalarTensorCodegen { T data; }; +// TODO: macro this and the printer below +enum class ArgType { + PhiloxCudaState, + Long, + Double, + ComplexDouble, + Bool, + Tensor, + CpuScalarTensor +}; + +inline std::string argTypeToString(ArgType type) { + std::string ret; + switch (type) { + case ArgType::PhiloxCudaState: + ret = "PhiloxCudaState"; + break; + case ArgType::Long: + ret = "Long"; + break; + case ArgType::Double: + ret = "Double"; + break; + case ArgType::ComplexDouble: + ret = "ComplexDouble"; + break; + case ArgType::Bool: + ret = "Bool"; + break; + case ArgType::Tensor: + ret = "Tensor"; + break; + case ArgType::CpuScalarTensor: + ret = "CpuScalarTensor"; + break; + } + return ret; +} + struct ArgAbstract { virtual ~ArgAbstract() = default; + virtual const void* arg() const = 0; virtual void* arg() = 0; + virtual bool isType(ArgType type) const = 0; + virtual ArgType type() const = 0; + virtual std::unique_ptr copy_unique_ptr() const = 0; + virtual void print() const { + printf("input type: %s\n", argTypeToString(type()).c_str()); + }; }; +#define DEF_HELPEE_FUNC(TARGET_TYPE, ARG_NAME) \ + bool isType(ArgType type) const override { \ + return ArgType::TARGET_TYPE == type; \ + } \ + ArgType type() const override { \ + return ArgType::TARGET_TYPE; \ + } \ + const void* arg() const override { \ + return &ARG_NAME; \ + } \ + void* arg() override { \ + return &ARG_NAME; \ + } \ + std::unique_ptr copy_unique_ptr() const override { \ + return std::make_unique(*this); \ + } + +#define DEF_PRINT_FUNC \ + void print() const override { \ + std::cout << val_ << std::endl; \ + } + struct PhiloxCudaStateArg : public ArgAbstract { at::PhiloxCudaState val_; PhiloxCudaStateArg(at::PhiloxCudaState _val) : val_(_val){}; - void* arg() override { - return &val_; - } + DEF_HELPEE_FUNC(PhiloxCudaState, val_) }; struct LongArg : public ArgAbstract { int64_t val_; explicit LongArg(int64_t _val) : val_(_val) {} - void* arg() override { - return &val_; - } + DEF_HELPEE_FUNC(Long, val_) + DEF_PRINT_FUNC }; struct DoubleArg : public ArgAbstract { double val_; explicit DoubleArg(double _val) : val_(_val) {} - void* arg() override { - return &val_; - } + DEF_HELPEE_FUNC(Double, val_) + DEF_PRINT_FUNC }; struct ComplexDoubleArg : public ArgAbstract { c10::complex val_; explicit ComplexDoubleArg(c10::complex _val) : val_(_val) {} - void* arg() override { - return &val_; - } + DEF_HELPEE_FUNC(ComplexDouble, val_) + DEF_PRINT_FUNC }; struct BoolArg : public ArgAbstract { bool val_; explicit BoolArg(bool _val) : val_(_val) {} - void* arg() override { - return &val_; - } + DEF_HELPEE_FUNC(Bool, val_) + DEF_PRINT_FUNC }; struct TensorArgAbstract : ArgAbstract { virtual void setSize(int i, int64_t size) = 0; virtual void setStride(int i, int64_t stride) = 0; virtual void setPointer(void* ptr) = 0; + virtual void setDataType(DataType data_type) = 0; + virtual void setTensor(at::Tensor tensor) = 0; + + virtual int64_t getRank() const = 0; + virtual int64_t getSize(int i) const = 0; + virtual int64_t getStride(int i) const = 0; + virtual void* getPointer() const = 0; + virtual DataType 
getDataType() const = 0; + virtual int64_t numel() const = 0; + virtual at::Tensor getTensor() const = 0; + + // TODO: clean it up and also print out dtype + void print() const override { + auto rank = getRank(); + std::cout << "tensor dtype: " << getDataType() << " sizes: ("; + for (auto i = 0; i < rank; i++) { + std::cout << getSize(i) << ", "; + } + std::cout << ") stride: ("; + for (auto i = 0; i < rank; i++) { + std::cout << getStride(i) << ", "; + } + std::cout << ") pointer: " << getPointer() << std::endl; + } }; -// This should match the tensor used in the code generation (almost exactly) template // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct TensorArg : public TensorArgAbstract { TENSOR_TYPE instance_; + // TODO: this is ugly, we should be extracting data type from `instance_` + // instead + DataType data_type_ = DataType::Null; + at::Tensor tensor_; void setSize(int i, int64_t size) override { instance_.setSize(i, (nvfuser_index_t)size); @@ -128,10 +230,40 @@ struct TensorArg : public TensorArgAbstract { void setPointer(void* ptr) override { instance_.data = static_cast(ptr); } + void setDataType(DataType data_type) override { + data_type_ = data_type; + } + void setTensor(at::Tensor tensor) override { + tensor_ = tensor; + } - void* arg() override { - return &instance_; + int64_t getSize(int i) const override { + return instance_.getSize(i); + } + int64_t getStride(int i) const override { + return instance_.getStride(i); + } + int64_t getRank() const override { + return instance_.nDims(); + } + void* getPointer() const override { + return instance_.data; + } + DataType getDataType() const override { + return data_type_; + } + at::Tensor getTensor() const override { + return tensor_; + } + int64_t numel() const override { + int64_t ret = 1; + for (auto i : c10::irange(instance_.nDims())) { + ret *= instance_.getSize(i); + } + return ret; } + + DEF_HELPEE_FUNC(Tensor, instance_) }; template @@ -144,16 +276,37 @@ struct CpuScalarTensorArg : public ArgAbstract { instance_.data = _data; } - void* arg() override { - return &instance_; - } + DEF_HELPEE_FUNC(CpuScalarTensor, instance_) }; -class KernelArgumentHolder { +// TODO: This class needs some further clean up and refactor +//! KernelArgumentHolder copies meta information from kernel inputs, including +//! tensor sizes/shapes/dtype/memory_ptr and copies scalar inputs. It is used +//! for both compilation as well as kernel execution. The important thing is to +//! strip ownership of tensor from KernelArgumentHolder, so that during async +//! compilation, we are not unnecessarily holding memory that is not needed. +class TORCH_CUDA_CU_API KernelArgumentHolder { public: + //! create KernelArgumentHolder from c10 inputs. Note that we we not taking + //! the ownership of the memory from the original inputs, but just recording + //! its meta data for kernel execution/compilation. 
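  //! The device index and index mode (32- vs 64-bit indexing) are derived from
  //! the inputs; an empty input list defaults to INT32 mode on device 0.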
+ static KernelArgumentHolder createKernelArgumentHolder( + const c10::ArrayRef& inputs); + + KernelIndexMode getIndexMode() const { + return index_mode_; + } + explicit KernelArgumentHolder(KernelIndexMode index_mode) : index_mode_(index_mode) {} + KernelArgumentHolder(const KernelArgumentHolder& self) + : device_index_(self.getDeviceIndex()), index_mode_(self.getIndexMode()) { + for (const auto& arg : self.arguments_) { + push(arg.get()); + } + } + // Push a tensor to the arguments void push(const at::Tensor& tensor); @@ -170,12 +323,60 @@ class KernelArgumentHolder { void push(const std::vector& tensors); + void push(const ArgAbstract* arg); + + void swap(int i, const ArgAbstract* arg); + + // push int64 + void push(int64_t val); + + const ArgAbstract* back() const { + return arguments_.back().get(); + } + void appendPhiloxRNGSeed(uint64_t rand_offset); + const ArgAbstract* operator[](int ind) const { + return arguments_.at(ind).get(); + }; + + size_t size() const { + return arguments_.size(); + } + + bool empty() const { + return arguments_.empty(); + } + + void setDeviceIndex(int index) { + device_index_ = index; + } + + int getDeviceIndex() const { + return device_index_; + } + + void setCacheId(size_t id) { + cache_id_ = id; + } + + c10::optional getCacheId() const { + return cache_id_; + } + + void print() const { + for (const auto& arg : arguments_) { + arg->print(); + } + } + private: std::vector> arguments_; std::vector void_ptrs_; bool changed_ = true; + + int device_index_ = 0; + c10::optional cache_id_ = c10::nullopt; KernelIndexMode index_mode_ = KernelIndexMode::INT64; }; diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp index b0f35962c5ddb..e94e7afd53b8b 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -160,6 +160,7 @@ bool validateKernelArgTensor( at::ScalarType arg_data_type = arg.scalar_type(); DataType param_data_type = *param->getDataType(); bool match = false; + // TODO: remove this switch with `aten_to_data_type` switch (arg_data_type) { case at::ScalarType::Double: match = param_data_type == DataType::Double; @@ -201,36 +202,36 @@ bool validateKernelArgTensor( // Return false if arg_type doesn't match the type in param bool validateKernelArgScalar( - const c10::IValue& arg, + const ArgAbstract* arg, const Val* param, std::stringstream& msg) { - if (!arg.isScalar()) { - msg << "Argument is a scalar, but the parameter is not." 
- << "\n"; - return false; - } + TORCH_INTERNAL_ASSERT( + param->getDataType().has_value(), "kernel param should have data type"); DataType param_type = *param->getDataType(); bool match = false; - switch (arg.toScalar().type()) { - case c10::ScalarType::Long: + switch (arg->type()) { + case ArgType::Long: match = param_type == DataType::Int || param_type == DataType::Int32; break; - case c10::ScalarType::ComplexDouble: - match = param_type == DataType::ComplexDouble || - param_type == DataType::ComplexFloat; - break; - case c10::ScalarType::Double: + case ArgType::Double: match = param_type == DataType::Double || param_type == DataType::Float || param_type == DataType::Half || param_type == DataType::BFloat16; break; - case c10::ScalarType::Bool: + case ArgType::Bool: match = param_type == DataType::Bool; break; + case ArgType::ComplexDouble: + match = param_type == DataType::ComplexDouble || + param_type == DataType::ComplexFloat; + break; default: - match = false; + // TODO: We need to verify that param is actually a scalar + msg << "Argument is not a scalar, but the parameter is." + << "\n"; + return false; } if (!match) { - msg << "Argument type is " << arg.toScalar().type() + msg << "Argument type is " << argTypeToString(arg->type()) << ", but the parameter is " << param_type << "\n"; } return match; @@ -239,12 +240,23 @@ bool validateKernelArgScalar( // Return false if arg and param don't match up and if arg's device (if a // tensor) doesn't match provided device bool validateKernelArg( - const c10::IValue& arg, + const ArgAbstract* arg, const Val* param, const c10::Device& device, std::stringstream& msg) { - if (arg.isTensor()) { - return validateKernelArgTensor(arg.toTensor(), param, device, msg); + if (auto tensor_arg_abstract = dynamic_cast(arg)) { + // TODO: don't use get tensor here. 
We would want to remove tensor reference + // for async compilation + return validateKernelArgTensor( + tensor_arg_abstract->getTensor(), param, device, msg); + } else if (arg->isType(ArgType::CpuScalarTensor)) { + // TODO: merge this one with above + // TODO: we need to check cpu scalar dtyp matches param + bool match = param->as()->isCpuScalar(); + if (!match) { + msg << "Argument is scalar type, but kernel parameter is not\n"; + } + return match; } else { return validateKernelArgScalar(arg, param, msg); } @@ -332,7 +344,7 @@ bool checkValidMisalignedTensors( void validateKernelInputs( Fusion* fusion, - const at::ArrayRef& inputs, + const KernelArgumentHolder& args, const c10::Device& device) { FUSER_PERF_SCOPE("executor_utils::ValidateKernelInputs"); @@ -340,13 +352,12 @@ void validateKernelInputs( FusionGuard fg(fusion); // Check inputs TORCH_INTERNAL_ASSERT( - inputs.size() == fusion->inputs().size(), - "Wrong number of kernel inputs."); + args.size() == fusion->inputs().size(), "Wrong number of kernel inputs."); std::stringstream msg; bool mismatch = false; - for (const auto i : c10::irange(inputs.size())) { - const IValue& arg = inputs[i]; + for (const auto i : c10::irange(args.size())) { + const ArgAbstract* arg = args[i]; const Val* param = fusion->inputs()[i]; mismatch = !validateKernelArg(arg, param, device, msg) || mismatch; } @@ -373,7 +384,7 @@ void validateKernelOutputs( for (const auto i : c10::irange(outputs.size())) { const at::Tensor& arg = outputs[i]; const Val* param = fusion->outputs()[i]; - mismatch = !validateKernelArg(arg, param, device, msg) || mismatch; + mismatch = !validateKernelArgTensor(arg, param, device, msg) || mismatch; } TORCH_INTERNAL_ASSERT( !mismatch, "Found one or more invalid arguments: ", msg.str()); @@ -557,13 +568,9 @@ void validateAlignedVectorizeExtents( } void validateAlignedVectorizedFusionInputOutput( - const IValue& aten_val, + const at::Tensor& aten_tensor, int word_size, TensorView* tv) { - TORCH_INTERNAL_ASSERT(aten_val.isTensor()); - - const auto& aten_tensor = aten_val.toTensor(); - TORCH_INTERNAL_ASSERT( reinterpret_cast(aten_tensor.data_ptr()) % (word_size * aten_tensor.dtype().itemsize()) == @@ -614,7 +621,7 @@ void validateAlignedVectorizedFusionInputOutput( void validateAlignedVectorizedTensors( kir::Kernel* kernel, - const at::ArrayRef& inputs, + const KernelArgumentHolder& args, const std::vector& outputs, caching::ExecutorCompileTimeInfoCache* data_cache, kir::ExpressionEvaluator& expr_eval) { @@ -639,9 +646,12 @@ void validateAlignedVectorizedTensors( .aligned_vectorized_inp_tensor_pos) { auto tv = kernel->inputs().at(pos)->as(); auto word_size = kernel->summary().vectorized_accesses.at(tv); - validateAlignedVectorizedFusionInputOutput(inputs[pos], word_size, tv); + auto tensor_arg_abstract = + dynamic_cast(args[pos]); + TORCH_INTERNAL_ASSERT(tensor_arg_abstract, "alias io only supports tensor"); + validateAlignedVectorizedFusionInputOutput( + tensor_arg_abstract->getTensor(), word_size, tv); } - if (!outputs.empty()) { for (auto pos : tensor_vectorization_validation_entry.get() .aligned_vectorized_out_tensor_pos) { @@ -657,7 +667,7 @@ void validateAlignedVectorizedTensors( // could be improved to include shared memory. 
void validateMisalignedVectorizedTensors( kir::Kernel* kernel, - const at::ArrayRef& inputs, + const KernelArgumentHolder& args, const std::vector& outputs, caching::ExecutorCompileTimeInfoCache* data_cache, kir::ExpressionEvaluator& expr_eval) { @@ -678,7 +688,13 @@ void validateMisalignedVectorizedTensors( inp_misaligned_tensors_pos.begin(), inp_misaligned_tensors_pos.end(), std::back_inserter(inp_misaligned_tensors), - [&inputs](int idx) { return inputs[idx]; }); + [&args](int idx) { + auto tensor_arg_abstract = + dynamic_cast(args[idx]); + TORCH_INTERNAL_ASSERT( + tensor_arg_abstract, "alias io only supports tensor"); + return tensor_arg_abstract->getTensor(); + }); const auto& out_misaligned_tensors_pos = tensor_vectorization_validation_entry.get().out_misaligned_tensors_pos; @@ -732,61 +748,62 @@ void validateVectorizedSplits( void validateVectorizedTensors( kir::Kernel* kernel, - const at::ArrayRef& inputs, + const KernelArgumentHolder& args, const std::vector& outputs, caching::ExecutorCompileTimeInfoCache* data_cache, kir::ExpressionEvaluator& expr_eval) { FUSER_PERF_SCOPE("FusionExecutor::validateVectorizedTensors"); validateAlignedVectorizedTensors( - kernel, inputs, outputs, data_cache, expr_eval); + kernel, args, outputs, data_cache, expr_eval); validateMisalignedVectorizedTensors( - kernel, inputs, outputs, data_cache, expr_eval); + kernel, args, outputs, data_cache, expr_eval); validateVectorizedSplits(kernel, expr_eval); } -kir::ExpressionEvaluator bindKernelInputs( - const at::ArrayRef& aten_inputs, - kir::Kernel* kernel, - bool check_consistency) { - FUSER_PERF_SCOPE("executor_utils::BindKernelInputs"); - - TORCH_INTERNAL_ASSERT( - kernel->inputs().size() == aten_inputs.size(), - "Something went wrong configuring launch. Inputs no longer match."); - - kir::ExpressionEvaluator expr_eval; - const auto& inputs = kernel->inputs(); - - for (const auto i : c10::irange(inputs.size())) { - const auto input = inputs[i]; +namespace { - if (auto tensor_input = dynamic_cast(input)) { +template +void bindInputForExprEvaluation( + Val* val, + const ArgAbstract* arg, + bool check_consistency, + EXPR_EVALUATOR& expr_eval) { + if (val->getValType() == ValType::TensorView) { + TensorView* cg_tensor = val->as(); + auto root_domain = + TensorDomain::noReductions(cg_tensor->getMaybeRFactorDomain()); + + if (root_domain.size() == 0) { + TORCH_INTERNAL_ASSERT( + arg->isType(ArgType::CpuScalarTensor) || + (arg->isType(ArgType::Tensor) && + dynamic_cast(arg)->getRank() == 0), + "Something went wrong configuring launch. Inputs is not rank 0 tensor"); + } else { TORCH_INTERNAL_ASSERT( - aten_inputs[i].isTensor(), - "Something went wrong configuring launch. Inputs no longer match at index:", - i); + arg->isType(ArgType::Tensor), + "Something went wrong configuring launch. Inputs do not match."); - const auto aten_tensor = aten_inputs[i].toTensor(); - const auto root_domain = TensorDomain::noReductions( - tensor_input->domain()->getMaybeRFactorDomain()); + auto tensor_arg_abstract = dynamic_cast(arg); TORCH_INTERNAL_ASSERT( - aten_tensor.ndimension() == static_cast(root_domain.size()), - "Something went wrong configuring launch. Inputs no longer match."); + tensor_arg_abstract && + tensor_arg_abstract->getRank() == (int64_t)root_domain.size(), + "Something went wrong configuring launch. 
Inputs rank does not match."); for (const auto dim : c10::irange(root_domain.size())) { + const auto tensor_arg_size = tensor_arg_abstract->getSize(dim); + const auto tensor_arg_stride = tensor_arg_abstract->getStride(dim); const auto extent = root_domain[dim]->extent(); if (root_domain[dim]->hasExpandedExtent()) { TORCH_INTERNAL_ASSERT( - aten_tensor.strides()[dim] == 0, - "Execting an expanded dimension on ", - inputs[i]->toString(), - " dimension ", + tensor_arg_stride == 0, + "Execting an expanded dimension on dimension ", dim, " but found stride ", - aten_tensor.strides()[dim]); + tensor_arg_stride); // Could support dynamic size on expanded dimension, so may not have // an inferable expanded extent here. This check might be better to do // once all values are bound. @@ -794,18 +811,17 @@ kir::ExpressionEvaluator bindKernelInputs( expr_eval.evaluate(root_domain[dim]->expandedExtent()); if (maybe_expanded_size.has_value()) { TORCH_CHECK( - *maybe_expanded_size == aten_tensor.sizes()[dim], + *maybe_expanded_size == tensor_arg_size, "Expecting expanded extent of ", *maybe_expanded_size, " but recieved value of ", - aten_tensor.sizes()[dim]); + tensor_arg_size); } } - const auto value = root_domain[dim]->hasExpandedExtent() - ? 1 - : aten_tensor.sizes()[dim]; - if (value == 0 && tensor_input->uses().empty()) { + const auto value = + root_domain[dim]->hasExpandedExtent() ? 1 : tensor_arg_size; + if (value == 0 && cg_tensor->uses().empty()) { // If there's no uses, ignore there's a size-0 dimension. continue; } @@ -829,104 +845,56 @@ kir::ExpressionEvaluator bindKernelInputs( expr_eval.bind(extent, value); } } - // NOLINTNEXTLINE: https://bugs.llvm.org/show_bug.cgi?id=48525 - } else if (input->isScalar() && input->dtype() == DataType::Int) { - TORCH_INTERNAL_ASSERT( - aten_inputs[i].type()->kind() == c10::TypeKind::IntType, - "kernel expected Scalar Int inputs, but found", - aten_inputs[i].type()->str()); - expr_eval.bind(input, aten_inputs[i].toInt()); } + } else if ( + val->getValType().value() == ValType::Scalar && + val->getDataType().value() == DataType::Int) { + TORCH_INTERNAL_ASSERT( + arg->isType(ArgType::Long), + "fusion expected Scalar Int inputs, but found", + argTypeToString(arg->type())); + expr_eval.bind(val, *static_cast(arg->arg())); } +} +} // namespace + +kir::ExpressionEvaluator bindKernelInputs( + const KernelArgumentHolder& args, + kir::Kernel* kernel, + bool check_consistency) { + FUSER_PERF_SCOPE("executor_utils::BindKernelInputs"); + + TORCH_INTERNAL_ASSERT( + kernel->inputs().size() == args.size(), + "Something went wrong configuring launch. Inputs no longer match."); + + kir::ExpressionEvaluator expr_eval; + const auto& inputs = kernel->inputs(); + + for (const auto i : c10::irange(inputs.size())) { + bindInputForExprEvaluation( + inputs[i], args[i], check_consistency, expr_eval); + } return expr_eval; } ExpressionEvaluator bindFusionInputs( - const at::ArrayRef& aten_inputs, + const KernelArgumentHolder& args, Fusion* fusion) { FUSER_PERF_SCOPE("executor_utils::BindFusionInputs"); + auto inputs = fusion->inputs(); TORCH_INTERNAL_ASSERT( - fusion->inputs().size() == aten_inputs.size(), + inputs.size() == args.size(), "Something went wrong configuring launch. Inputs do not match."); ExpressionEvaluator expr_eval(fusion); - auto inputs = fusion->inputs(); // This should probably move to EvaluationContext as we may want to bind // input values frequently. Bind fusion input values to runtime values. 
- for (const auto i : c10::irange(fusion->inputs().size())) { - if (inputs[i]->getValType() == ValType::TensorView) { - TensorView* cg_tensor = inputs[i]->as(); - - TORCH_INTERNAL_ASSERT( - aten_inputs[i].isTensor(), - "Something went wrong configuring launch. Inputs do not match."); - - auto aten_tensor = aten_inputs[i].toTensor(); - auto root_domain = - TensorDomain::noReductions(cg_tensor->getMaybeRFactorDomain()); - TORCH_INTERNAL_ASSERT( - aten_tensor.ndimension() == (int64_t)root_domain.size(), - "Something went wrong configuring launch. Inputs do not match."); - for (const auto dim : c10::irange(root_domain.size())) { - const auto extent = root_domain[dim]->extent(); - if (root_domain[dim]->hasExpandedExtent()) { - TORCH_INTERNAL_ASSERT( - aten_tensor.strides()[dim] == 0, - "Execting an expanded dimension on ", - inputs[i]->toString(), - " dimension ", - dim, - " but found stride ", - aten_tensor.strides()[dim]); - // Could support dynamic size on expanded dimension, so may not have - // an inferable expanded extent here. This check might be better to do - // once all values are bound. - auto maybe_expanded_size = - expr_eval.evaluate(root_domain[dim]->expandedExtent()); - if (maybe_expanded_size.has_value()) { - TORCH_CHECK( - *maybe_expanded_size == aten_tensor.sizes()[dim], - "Expecting expanded extent of ", - *maybe_expanded_size, - " but recieved value of ", - aten_tensor.sizes()[dim]); - } - } - - const auto value = root_domain[dim]->hasExpandedExtent() - ? 1 - : aten_tensor.sizes()[dim]; - if (value == 0 && cg_tensor->uses().empty()) { - // If there's no uses, ignore there's a size-0 dimension. - continue; - } - TORCH_INTERNAL_ASSERT(value != 0, "Cannot handle size-0 dimensions"); - const auto prev_value = expr_eval.evaluate(extent); - if (prev_value.has_value()) { - TORCH_CHECK( - *prev_value == value, - "Attempting to bind ", - extent, - " to ", - value, - " but it's already set to ", - *prev_value); - } else { - expr_eval.bind(extent, value); - } - } - } else if ( - inputs[i]->getValType().value() == ValType::Scalar && - inputs[i]->getDataType().value() == DataType::Int) { - TORCH_INTERNAL_ASSERT( - aten_inputs[i].type()->kind() == c10::TypeKind::IntType, - "fusion expected Scalar Int inputs, but found", - aten_inputs[i].type()->str()); - expr_eval.bind(inputs[i], aten_inputs[i].toInt()); - } + for (const auto i : c10::irange(inputs.size())) { + bindInputForExprEvaluation(inputs[i], args[i], true, expr_eval); } return expr_eval; } diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.h b/torch/csrc/jit/codegen/cuda/executor_utils.h index 37817838f3869..af3b4d9372d41 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.h +++ b/torch/csrc/jit/codegen/cuda/executor_utils.h @@ -9,6 +9,7 @@ #include +#include #include #include #include @@ -30,7 +31,7 @@ std::string kernelPreamble(); void validateKernelInputs( Fusion* fusion, - const at::ArrayRef& inputs, + const KernelArgumentHolder& args, const c10::Device& device); void validateKernelOutputs( @@ -40,13 +41,13 @@ void validateKernelOutputs( //! Bind kernel input values to runtime values kir::ExpressionEvaluator bindKernelInputs( - const at::ArrayRef& aten_inputs, + const KernelArgumentHolder& args, kir::Kernel* kernel, bool check_consistency = true); //! 
Bind fusion input values to runtime values TORCH_CUDA_CU_API ExpressionEvaluator -bindFusionInputs(const at::ArrayRef& aten_inputs, Fusion* fusion); +bindFusionInputs(const KernelArgumentHolder& args, Fusion* fusion); struct NvrtcFunction { CUmodule module = CUmodule(); @@ -303,7 +304,7 @@ std::unique_ptr getWarpPaddedExtentsInfo( void validateVectorizedTensors( kir::Kernel* kernel, - const at::ArrayRef& inputs, + const KernelArgumentHolder& args, const std::vector& outputs, caching::ExecutorCompileTimeInfoCache* data_cache, kir::ExpressionEvaluator& expr_eval); diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index 595ff6168433c..2a7fbf5b2827d 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -51,9 +51,9 @@ void swap(Fusion& a, Fusion& b) noexcept { } std::unique_ptr Fusion::segment( - const at::ArrayRef& inputs) { + const KernelArgumentHolder& args) { FUSER_PERF_SCOPE("Segment Fusion"); - return SegmentCandidateFinder::segment(this, inputs); + return SegmentCandidateFinder::segment(this, args); } IrCloner Fusion::copy(const Fusion* from, Fusion* to) { diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h index cf7b035e971f5..626d2ab7be6c0 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.h +++ b/torch/csrc/jit/codegen/cuda/fusion.h @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -168,8 +169,7 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer { bool isStochastic(); //! Run fusion segmentation algorithm to create a segmented fusion - std::unique_ptr segment( - const at::ArrayRef& inputs); + std::unique_ptr segment(const KernelArgumentHolder& args); const auto& inputs() const { return inputs_; diff --git a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp index 4e76bffe665b4..88d3e679c0d75 100644 --- a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp @@ -1722,7 +1722,7 @@ class TranslateApplicableWelford { //! returns true if any welford has been translated static bool run( SegmentedFusion* segmented_fusion, - const at::ArrayRef& runtime_inputs) { + const KernelArgumentHolder& runtime_inputs) { TranslateApplicableWelford translate_welford( segmented_fusion, runtime_inputs); return translate_welford.translated_any_welford_; @@ -1730,7 +1730,7 @@ class TranslateApplicableWelford { //! Try translation on complete fusion, //! returns true if any welford has been translated - static bool run(Fusion* fusion, const at::ArrayRef& runtime_inputs) { + static bool run(Fusion* fusion, const KernelArgumentHolder& runtime_inputs) { TranslateApplicableWelford translate_welford(fusion, runtime_inputs); return translate_welford.translated_any_welford_; } @@ -1738,11 +1738,11 @@ class TranslateApplicableWelford { private: explicit TranslateApplicableWelford( SegmentedFusion* segmented_fusion, - const at::ArrayRef& runtime_inputs); + const KernelArgumentHolder& runtime_inputs); explicit TranslateApplicableWelford( Fusion* fusion, - const at::ArrayRef& runtime_inputs); + const KernelArgumentHolder& runtime_inputs); //! Given vector of welford ops from the same fusion, //! checks if translating all of them result in a @@ -1774,7 +1774,7 @@ class TranslateApplicableWelford { bool translated_any_welford_ = false; //! a reference to global fusion runtime inputs - const at::ArrayRef& runtime_inputs_; + const KernelArgumentHolder& runtime_inputs_; //! 
For translation within group only, //! group boundary at test copy @@ -1785,7 +1785,7 @@ class TranslateApplicableWelford { TranslateApplicableWelford::TranslateApplicableWelford( Fusion* fusion, - const at::ArrayRef& runtime_inputs) + const KernelArgumentHolder& runtime_inputs) : runtime_inputs_(runtime_inputs) { auto exprs = fusion->exprs(); std::vector orignal_welfords( @@ -1802,7 +1802,7 @@ TranslateApplicableWelford::TranslateApplicableWelford( TranslateApplicableWelford::TranslateApplicableWelford( SegmentedFusion* segmented_fusion, - const at::ArrayRef& runtime_inputs) + const KernelArgumentHolder& runtime_inputs) : runtime_inputs_(runtime_inputs) { std::vector translated_groups; std::vector welford_to_translate; @@ -2046,7 +2046,7 @@ void TranslateApplicableWelford::translateSingleWelford(WelfordOp* welford) { bool SegmentCandidateFinder::TranslateWelfordInFusion( Fusion* fusion, - const at::ArrayRef& runtime_inputs) { + const KernelArgumentHolder& runtime_inputs) { return TranslateApplicableWelford::run(fusion, runtime_inputs); } @@ -2616,7 +2616,7 @@ ScheduleHeuristic SegmentCandidateFinder::deriveHeuristic( SegmentCandidateFinder::SegmentCandidateFinder( std::unique_ptr fusion, - const at::ArrayRef& inputs, + const KernelArgumentHolder& inputs, SegmentCandidateFinderOptions options) : options_(options), runtime_info_(fusion.get(), inputs, true), @@ -3100,7 +3100,7 @@ FusionKernelRuntime::SchedulerEntryPtr SegmentedFusion:: } std::unique_ptr SegmentedFusion::makeInitialHeuristics( - const at::ArrayRef& inputs) { + const KernelArgumentHolder& inputs) { auto ret = std::make_unique(); SchedulerRuntimeInfo runtime_info(completeFusion(), inputs, true); for (auto g : groups()) { diff --git a/torch/csrc/jit/codegen/cuda/fusion_segmenter.h b/torch/csrc/jit/codegen/cuda/fusion_segmenter.h index 1af3374a58748..abdd4b7e7c42d 100644 --- a/torch/csrc/jit/codegen/cuda/fusion_segmenter.h +++ b/torch/csrc/jit/codegen/cuda/fusion_segmenter.h @@ -307,7 +307,7 @@ class TORCH_CUDA_CU_API SegmentedFusion { //! Make heuristics for all groups in this segmented fusion std::unique_ptr makeInitialHeuristics( - const at::ArrayRef& inputs); + const KernelArgumentHolder& inputs); //! 
Inline Debug print for segmented fusion std::string toString(int verbosity) const; @@ -445,7 +445,7 @@ class TORCH_CUDA_CU_API SegmentCandidateFinder { // Perform segmentation on a copy of the given fusion static std::unique_ptr segment( const Fusion* fusion, - const at::ArrayRef& inputs, + const KernelArgumentHolder& inputs, SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) { auto fusion_copy = std::make_unique(*fusion); if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) { @@ -460,7 +460,7 @@ class TORCH_CUDA_CU_API SegmentCandidateFinder { // Perform segmentation on and take ownership of the given fusion static std::unique_ptr segment( std::unique_ptr fusion, - const at::ArrayRef& inputs, + const KernelArgumentHolder& inputs, SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) { SegmentCandidateFinder scf(std::move(fusion), inputs, options); if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) { @@ -473,13 +473,13 @@ class TORCH_CUDA_CU_API SegmentCandidateFinder { static bool TranslateWelfordInFusion( Fusion* fusion, - const at::ArrayRef& runtime_inputs); + const KernelArgumentHolder& runtime_inputs); private: // Perform segmentation on and take ownership of the given fusion SegmentCandidateFinder( std::unique_ptr fusion, - const at::ArrayRef& inputs, + const KernelArgumentHolder& inputs, SegmentCandidateFinderOptions options); void resetTraversal(); @@ -612,7 +612,7 @@ class TORCH_CUDA_CU_API SegmentCandidateFinder { //! TODO: //! implement the expression evaluator transfer and //! remove runtime_inputs_ in a follow up. - const at::ArrayRef& runtime_inputs_; + const KernelArgumentHolder& runtime_inputs_; }; // TODO: Make as member functions on classes instead of global scope diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp index e1ed1d56c496d..924a0538667c7 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp @@ -8,6 +8,8 @@ #include #include +#include +#include #include #include @@ -18,28 +20,12 @@ namespace cuda { namespace { -// Check device of TensorType in all inputs ensure all tensors are on cuda -// devices. -// return common device index (or -1 if device differs). 
-int getCommonDeviceCUDA(const at::ArrayRef& inputs) { - int index = -1; - for (const auto& input : inputs) { - if (!input.isTensor()) { - continue; - } - const auto& device = input.toTensor().device(); - // skip cpu scalar tensor as they'll be promoted to scalar later - if (device.is_cpu() && is_cpu_scalar(input.toTensor())) { - continue; - } - TORCH_CHECK(device.is_cuda(), "nvfuser only supports cuda device"); - auto cur_index = device.index(); - if (index != -1 && index != cur_index) { - return -1; - } - index = (int)cur_index; // NOLINT - } - return index; +#define THREAD_POOL_SIZE 10 + +// TODO: clean this up with some knobs +c10::ThreadPool* getThreadPool() { + static c10::ThreadPool pool(THREAD_POOL_SIZE); + return &pool; } void encodeBuffer(size_t value, std::string& buffer) { @@ -53,8 +39,7 @@ void encodeBuffer(size_t value, std::string& buffer) { } // namespace InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId( - const at::ArrayRef& inputs, - const SchedulerRuntimeInfo* additional_info) { + const at::ArrayRef& inputs) { IdLookupReturn ret; // lock mutex_ because we are touching encoding_ @@ -123,6 +108,42 @@ FusionExecutorCache::FusionExecutorCache(std::unique_ptr fusion) } } +KernelArgumentHolder FusionExecutorCache::prepareInputs( + const at::ArrayRef& inputs) { + FUSER_PERF_SCOPE("FusionExecutorCache::prepareInputs"); + + KernelArgumentHolder args = + KernelArgumentHolder::createKernelArgumentHolder(inputs); + + // TODO: move InputsIdLookup inside KernelArgumentHolder; + auto id_lookup_ret = inputs_id_lookup_.lookupId(inputs); + if (id_lookup_ret.eviction) { + evictCache(id_lookup_ret.evict_id); + } + + args.setCacheId(id_lookup_ret.id); + return args; +} + +bool FusionExecutorCache::isCompiled(const at::ArrayRef& inputs) { + FUSER_PERF_SCOPE("FusionExecutorCache::isCompiled"); + + // Access kernels associated with the common device id + KernelArgumentHolder args = prepareInputs(inputs); + + return getKernelRuntimeFor(args)->isCompiled(); +} + +void FusionExecutorCache::compileFusionAsync( + const at::ArrayRef& inputs) { + FUSER_PERF_SCOPE("FusionExecutorCache::compileFusionAsync"); + + KernelArgumentHolder args = prepareInputs(inputs); + auto kernel_runtime = getKernelRuntimeFor(args); + + kernel_runtime->startAsyncCompile(args); +} + // Note [ Permutation support in nvfuser ] // // Background: @@ -171,25 +192,23 @@ std::vector FusionExecutorCache::runFusionWithInputs( perm_inputs = inputs_vec; } - SchedulerRuntimeInfo runtime_info(fusion(), perm_inputs); + KernelArgumentHolder args = prepareInputs(perm_inputs); - auto id_lookup_ret = inputs_id_lookup_.lookupId(perm_inputs, &runtime_info); - if (id_lookup_ret.eviction) { - evictCache(id_lookup_ret.evict_id); - } - - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - const size_t unique_id = id_lookup_ret.id; - auto kernel_runtime = getKernelRuntimeFor(perm_inputs, unique_id); + auto kernel_runtime = getKernelRuntimeFor(args); most_recent_runtime_ = kernel_runtime; - auto outputs = kernel_runtime->runWithInput(perm_inputs, unique_id); + auto outputs = kernel_runtime->runWithInput(args); // permute output tensor returned by kernel execution. 
See Part_3 in Note [ // Permutation support in nvfuser ] for (const auto& pair : fusion_->getPermutationOutputMap()) { - outputs[pair.first] = outputs[pair.first].permute(pair.second); + if (pair.first < outputs.size()) { + outputs[pair.first] = outputs[pair.first].permute(pair.second); + } } + // removing aliased outputs, since those are only used by input tensor update + // by fusion. It is not semantically correct to actually return them as + // outputs from fusion. int offset = 0; for (const auto& v : aliased_output_indices_) { outputs.erase(outputs.begin() + v - offset); @@ -207,18 +226,16 @@ void FusionExecutorCache::evictCache(size_t cache_id) { } FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( - const at::ArrayRef& inputs, - size_t unique_id) { + const KernelArgumentHolder& args) { // Check for id hit case + auto unique_id = *args.getCacheId(); auto id_it = id_to_kernel_runtime_.find(unique_id); if (id_it != id_to_kernel_runtime_.end()) { return id_it->second; } // Access kernels associated with the common device id - auto device_index = getCommonDeviceCUDA(inputs); - TORCH_CHECK(device_index >= 0, "device is not coherent for fusion inputs"); - auto& kernel_runtimes = kernel_runtimes_[device_index]; + auto& kernel_runtimes = kernel_runtimes_[args.getDeviceIndex()]; // Check for re-use hit case // a kernel runtime is re-usable if all the compiled @@ -228,8 +245,8 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( auto reuse_it = std::find_if( kernel_runtimes.begin(), kernel_runtimes.end(), - [&inputs, &new_heuristics](auto& kernel_runtime) { - auto maybe_heuristics = kernel_runtime->getMaybeHeuristicsFor(inputs); + [&args, &new_heuristics](auto& kernel_runtime) { + auto maybe_heuristics = kernel_runtime->getMaybeHeuristicsFor(args); if (!maybe_heuristics.has_value()) { return false; } @@ -244,7 +261,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( } else { // graph miss, need to re-build an optimized graph for this case kernel_runtimes.emplace_back( - std::make_unique(fusion_.get(), inputs)); + std::make_unique(fusion_.get(), args)); kernel_runtime = kernel_runtimes.back().get(); if (profiling_) { kernel_runtime->profile(true); @@ -257,7 +274,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( FusionKernelRuntime::FusionKernelRuntime( Fusion* fusion, - const at::ArrayRef& inputs) { + const KernelArgumentHolder& args) { FUSER_PERF_SCOPE("FusionKernelRuntime::FusionKernelRuntime"); // Make a copy of fusion and do segmentation and translation @@ -265,7 +282,7 @@ FusionKernelRuntime::FusionKernelRuntime( auto fusion_copy = std::make_unique(*fusion); // Run segmentation on the copied fusion - SchedulerRuntimeInfo runtime_info(fusion_copy.get(), inputs, true); + SchedulerRuntimeInfo runtime_info(fusion_copy.get(), args, true); // Initialize the evaluator simplifer precomputed_integers_ = @@ -284,13 +301,13 @@ FusionKernelRuntime::FusionKernelRuntime( if (segmented) { // Take ownership and segment transformed fusion segmented_fusion_ = - SegmentCandidateFinder::segment(std::move(fusion_copy), inputs); + SegmentCandidateFinder::segment(std::move(fusion_copy), args); } else { segmented_fusion_ = SegmentedFusion::fromCompleteFusion( std::move(fusion_copy), maybe_complete_fusion_heuristic.value()); } - heuristics_ = segmented_fusion_->makeInitialHeuristics(inputs); + heuristics_ = segmented_fusion_->makeInitialHeuristics(args); executors_ = std::vector(segmented_fusion_->groups().size()); if 
(isDebugDumpEnabled(DebugDumpOption::FusionSegments)) { segmented_fusion_->print(); @@ -307,10 +324,10 @@ FusionKernelRuntime::FusionKernelRuntime( } std::vector FusionKernelRuntime::runKernelWithInput( - const at::ArrayRef& inputs, - size_t input_id, + KernelArgumentHolder& args, SegmentedGroup* sg) { FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput"); + std::lock_guard guard(mutex_); // This function will be called once on un-segmented fusion, // for segmented fusion, this function will be called on each segment // In the case of segmented fusion, segmented group needs to be given so @@ -319,8 +336,6 @@ std::vector FusionKernelRuntime::runKernelWithInput( // is complied and run TORCH_INTERNAL_ASSERT(sg, "runKernelWithInput: need valid group to run"); auto group_id = sg->groupId(); - const int device_index = getCommonDeviceCUDA(inputs); - TORCH_CHECK(device_index >= 0, "device is not coherent for fusion inputs"); LaunchParams launch_params; @@ -336,14 +351,11 @@ std::vector FusionKernelRuntime::runKernelWithInput( // Running a segment group as a single kernel, // make a fusion to run from segmented fusion fusion_to_run = segmented_fusion_->makeFusion(sg); - CompileOptions options; - options.device = c10::Device(DeviceType::CUDA, device_index); - options.index_mode = scheduler_entry->indexMode(); FusionGuard fg(fusion_to_run.get()); scheduler_entry->schedule(fusion_to_run.get()); launch_params = scheduler_entry->params()->lparams; executors_[group_id].compileFusion( - fusion_to_run.get(), inputs, launch_params, options); + fusion_to_run.get(), args, launch_params); } else { launch_params = scheduler_entry->params()->lparams; } @@ -358,7 +370,7 @@ std::vector FusionKernelRuntime::runKernelWithInput( executor.setMeasureKernelTimeFlag(true); } - auto outputs = executor.runFusion(inputs, launch_params, input_id); + auto outputs = executor.runFusion(args, launch_params); // Print relevant information all at once for easy debuging of perf if (isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) { @@ -369,14 +381,8 @@ std::vector FusionKernelRuntime::runKernelWithInput( segmented_fusion_->completeFusion()->printMath(); } std::cout << "With inputs:\n"; - for (auto inp : inputs) { - if (inp.isTensor()) { - auto inp_tensor = inp.toTensor(); - std::cout << " " << inp_tensor.dtype() << " " << inp_tensor.sizes() - << " " << inp_tensor.strides() << "\n"; - } else { - std::cout << " " << inp << "\n"; - } + for (auto i : c10::irange(args.size())) { + args[i]->print(); } std::cout << "Compiler log: " << executor.compilerLog() << "\n"; std::cout << scheduler_entry->params()->toString() << "\n"; @@ -451,77 +457,211 @@ void FusionKernelRuntime::prepareRuntimeOrder() { } } -std::vector FusionKernelRuntime::runWithInput( - const at::ArrayRef& inputs, - size_t input_id) { - FUSER_PERF_SCOPE("FusionKernelRuntime::runMultiKernelWithInput"); +// passing args by value, since we will be modify this +void FusionKernelRuntime::startAsyncCompile(KernelArgumentHolder& args_old) { + // only single compilation is supported at this moment. 
+ std::unique_lock unique_lock(mutex_, std::try_to_lock); + TORCH_CHECK( + unique_lock.owns_lock(), + "Calling startAsyncCompile on a FusionKernelRuntime that's already starting a compilation thread is not supported"); + std::unique_lock unique_lock2(compiling_, std::try_to_lock); + TORCH_CHECK( + unique_lock2.owns_lock(), + "Calling startAsyncCompile on a FusionKernelRuntime that's already starting a compilation thread is not supported 2"); + + // for some reason I can't seem to move unique_lock and it keeps using copy. + // auto compile_fusion = [args = std::move(args_old), lock = + // std::move(unique_lock), this] () mutable { + auto compile_fusion = [args = std::move(args_old), this]() mutable { + std::lock_guard guard(compiling_); + + // locking mutex_ since we are touching executors_ during compilation. + // c10::DeviceGuard dg(c10::Device(DeviceType::CUDA, + // args.getDeviceIndex())); CUDAGuard uses runtime API directly, which is + // thread safe. + c10::cuda::CUDAGuard dg(args.getDeviceIndex()); + + FUSER_PERF_SCOPE("FusionKernelRuntime::startAsyncCompile"); - TORCH_INTERNAL_ASSERT( - inputs.size() == segmented_fusion_->inputs().size(), - "Inputs were not set up correctly, recieved ", - inputs.size(), - " inputs but expecting ", - segmented_fusion_->inputs().size()); + TORCH_INTERNAL_ASSERT( + args.size() == segmented_fusion_->inputs().size(), + "Inputs were not set up correctly, recieved ", + args.size(), + " inputs but expecting ", + segmented_fusion_->inputs().size()); + + c10::Device device(c10::DeviceType::CUDA, args.getDeviceIndex()); + std::unordered_map tensor_map; + mapFusionInputsToArgs(tensor_map, args); + + // TODO: compilation can happen in parallel! We can have output sizes + // inferred on un-compiled kernel and setup all tensor_map prior to + // compilation. 
+ for (auto group_to_run : runtime_workspace_.group_run_order) { + // TODO: index mode should be updated per segmented kernel + // Prepare input vector + KernelArgumentHolder group_runtime_inputs(args.getIndexMode()); + group_runtime_inputs.setDeviceIndex(args.getDeviceIndex()); + for (auto input : group_to_run->inputs()) { + group_runtime_inputs.push(tensor_map.at(input)); + } + + // Run graph segment + KernelArgumentHolder group_runtime_outputs = + compileKernel(group_runtime_inputs, group_to_run); + + // map output args to tensor map + const auto& group_outputs = group_to_run->outputs(); + for (const size_t group_out_i : c10::irange(group_outputs.size())) { + args.push(group_runtime_outputs[group_out_i]); + tensor_map.emplace(group_outputs[group_out_i], args.back()); + } + } + }; + + getThreadPool()->run(compile_fusion); +} + +// TODO: replace the boilerplate in runKernelWithInput +KernelArgumentHolder FusionKernelRuntime::compileKernel( + const KernelArgumentHolder& args, + SegmentedGroup* sg) { + FUSER_PERF_SCOPE("FusionKernelRuntime::compileKernel"); + // This function will be called once on un-segmented fusion, + // for segmented fusion, this function will be called on each segment + // In the case of segmented fusion, segmented group needs to be given so + // a kernel is compiled and run for a segmented group + // In the case of complete fusion, sg = nullptr, and the original fusion + // is complied and run + TORCH_INTERNAL_ASSERT(sg, "compileKernel: need valid group to run"); + auto group_id = sg->groupId(); + + LaunchParams launch_params; + + auto scheduler_entry = schedulers()[group_id].get(); + + // Check that the heuristics are matched, in the case of segmented fusion + TORCH_INTERNAL_ASSERT(!sg || scheduler_entry->heuristic() == sg->heuristic()); + + if (!executors_[group_id].compiled()) { + FUSER_PERF_SCOPE("FusionKernelRuntime::compileKernel::Compile"); + std::unique_ptr fusion_to_run; + + // Running a segment group as a single kernel, + // make a fusion to run from segmented fusion + fusion_to_run = segmented_fusion_->makeFusion(sg); + FusionGuard fg(fusion_to_run.get()); + scheduler_entry->schedule(fusion_to_run.get()); + launch_params = scheduler_entry->params()->lparams; + + executors_[group_id].compileFusion( + fusion_to_run.get(), args, launch_params); + } else { + // TODO: this is a false negative assert, since we could be compiling + // something for elevated high water mark on block size. + TORCH_CHECK(false, "compiling an already compiled kernel"); + } - c10::Device device(c10::DeviceType::CUDA, 0); - int extent_index_ = 0; - // Bind input in the tensor_map - for (const auto i : c10::irange(inputs.size())) { - runtime_workspace_.tensor_map.emplace( - segmented_fusion_->inputs()[i], inputs[i]); + auto& executor = executors_[group_id]; + auto outputs = executor.inferOutputSizes(args, launch_params); + return outputs; +} + +void FusionKernelRuntime::mapFusionInputsToArgs( + std::unordered_map& tensor_map, + KernelArgumentHolder& args) { + int extent_index = 0; + auto original_args_size = args.size(); + // Bind args in the tensor_map + for (const auto i : c10::irange(original_args_size)) { + tensor_map.emplace(segmented_fusion_->inputs()[i], args[i]); // Bind tensorview inputs values in case some segmented group // needs it down the road. 
// TODO: we probably have done this already up to this point // should consider caching the expression evaluators, both // more convenient and safer than replication - if (inputs[i].isTensor()) { - auto aten_tensor = inputs[i].toTensor(); - device = aten_tensor.device(); - for (auto dim_size : aten_tensor.sizes()) { - runtime_workspace_.tensor_map.emplace( - runtime_workspace_.group_extent_binding_order[extent_index_++], - dim_size); + if (auto tensor_arg_abstract = + dynamic_cast(args[i])) { + // Note this is very ugly way. We are pushing every single extent to args, + // because we don't have a better place to hold them. + auto rank = tensor_arg_abstract->getRank(); + for (const auto dim : c10::irange(rank)) { + args.push(tensor_arg_abstract->getSize(dim)); + tensor_map.emplace( + runtime_workspace_.group_extent_binding_order[extent_index++], + args.back()); } } } +} + +std::vector FusionKernelRuntime::runWithInput( + KernelArgumentHolder& args) { + FUSER_PERF_SCOPE("FusionKernelRuntime::runWithInput"); + + TORCH_INTERNAL_ASSERT( + args.size() == segmented_fusion_->inputs().size(), + "Inputs were not set up correctly, recieved ", + args.size(), + " inputs but expecting ", + segmented_fusion_->inputs().size()); + + c10::Device device(c10::DeviceType::CUDA, args.getDeviceIndex()); + + std::unordered_map tensor_map; + mapFusionInputsToArgs(tensor_map, args); + + // TODO: we don't need this any more, since TensorArgAbstract already holds a + // reference to tensor + std::unordered_map output_holder; if (isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) { std::cout << "=================RUNNING FUSION SEGMENTS=================" << std::endl; } + for (auto group_to_run : runtime_workspace_.group_run_order) { + // TODO: index mode should be updated per segmented kernel // Prepare input vector + KernelArgumentHolder group_runtime_inputs(args.getIndexMode()); + group_runtime_inputs.setDeviceIndex(args.getDeviceIndex()); for (auto input : group_to_run->inputs()) { - runtime_workspace_.group_runtime_inputs.push_back( - runtime_workspace_.tensor_map.at(input)); + group_runtime_inputs.push(tensor_map.at(input)); } + + // TODO: currently we are still outputing PyTorch tensors, instead of + // something abstract. This is quite unsatisfying. 
Prepare input vector + // Run graph segment - runtime_workspace_.group_runtime_outputs = runKernelWithInput( - runtime_workspace_.group_runtime_inputs, input_id, group_to_run); + std::vector group_runtime_outputs = + runKernelWithInput(group_runtime_inputs, group_to_run); const auto& group_outputs = group_to_run->outputs(); // Insert graph segment output to tensor map - for (unsigned int group_out_i = 0; group_out_i < group_outputs.size(); - group_out_i++) { - runtime_workspace_.tensor_map.emplace( - group_outputs[group_out_i], - runtime_workspace_.group_runtime_outputs[group_out_i]); + TORCH_INTERNAL_ASSERT( + group_outputs.size() == group_runtime_outputs.size(), + "output size does not match"); + for (const size_t group_out_i : c10::irange(group_outputs.size())) { + output_holder[group_outputs[group_out_i]] = + group_runtime_outputs[group_out_i]; + + args.push(group_runtime_outputs[group_out_i]); + tensor_map.emplace(group_outputs[group_out_i], args.back()); } - runtime_workspace_.group_runtime_inputs.clear(); - runtime_workspace_.group_runtime_outputs.clear(); } if (isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) { std::cout << "=============FINISHED RUNNING FUSION SEGMENTS============" << std::endl; } + // Produce final global output - std::vector fusion_outputs; + std::vector fusion_outputs; for (auto output : segmented_fusion_->outputs()) { - const auto iter = runtime_workspace_.tensor_map.find(output); - if (iter != runtime_workspace_.tensor_map.end()) { + const auto iter = output_holder.find(output); + if (iter != output_holder.end()) { fusion_outputs.push_back(iter->second); } else { bool empty_type_check = output->getDataType().has_value() && @@ -555,20 +695,7 @@ std::vector FusionKernelRuntime::runWithInput( fusion_outputs.emplace_back(at::empty({0}, tensor_options)); } } - - std::vector fusion_output_tensors; - std::transform( - fusion_outputs.begin(), - fusion_outputs.end(), - std::back_inserter(fusion_output_tensors), - [](IValue ival) { - TORCH_INTERNAL_ASSERT( - ival.isTensor(), "Cannot output non-tensor objects from a fusion."); - return ival.toTensor(); - }); - - runtime_workspace_.tensor_map.clear(); - return fusion_output_tensors; + return fusion_outputs; } const std::vector& FusionKernelRuntime:: @@ -590,11 +717,11 @@ void FusionKernelRuntime::updateHeuristicsLaunchParams( } c10::optional FusionKernelRuntime:: - getMaybeHeuristicsFor(const at::ArrayRef& inputs) { + getMaybeHeuristicsFor(const KernelArgumentHolder& args) { FUSER_PERF_SCOPE("FusionKernelRuntime::getMaybeHeuristicsFor"); auto complete_fusion = segmented_fusion_->completeFusion(); - SchedulerRuntimeInfo runtime_info(complete_fusion, inputs); - precomputed_integers_->bindFusionInputs(inputs); + SchedulerRuntimeInfo runtime_info(complete_fusion, args); + precomputed_integers_->bindFusionInputs(args); precomputed_integers_->evaluate(); runtime_info.expressionEvaluator().bindPrecomputedIntegers( precomputed_integers_.get()); diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.h b/torch/csrc/jit/codegen/cuda/kernel_cache.h index f67742d10f3f4..915e319131061 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.h +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.h @@ -42,7 +42,7 @@ class TORCH_CUDA_CU_API FusionKernelRuntime { public: explicit FusionKernelRuntime( Fusion* fusion, - const at::ArrayRef& inputs); + const KernelArgumentHolder& inputs); //! Type notations within FusionKernelRuntime Context using HashType = size_t; @@ -56,10 +56,34 @@ class TORCH_CUDA_CU_API FusionKernelRuntime { } } + //! 
query if we already have a compiled kernel for execution + bool isCompiled() { + std::unique_lock lock0(mutex_, std::try_to_lock); + std::unique_lock lock1(compiling_, std::try_to_lock); + if (!lock0.owns_lock() || !lock1.owns_lock()) { + // compilation in progress + return false; + } + + return std::all_of( + executors_.begin(), executors_.end(), [](const auto& executor) { + return executor.compiled(); + }); + } + + //! starts compilation async + void startAsyncCompile(KernelArgumentHolder& inputs); + + //! maps entries in `args` to fusion inputs. + //! Note that this function also pushes extra bits like dimension extent into + //! `args` for expression evaluator binding. So consider your `args` polluted + //! after this function and use it with caution. + void mapFusionInputsToArgs( + std::unordered_map& tensor_map, + KernelArgumentHolder& args); + //! Unified interface to run the managed kernels with given input - std::vector runWithInput( - const at::ArrayRef& inputs, - size_t input_id); + std::vector runWithInput(KernelArgumentHolder& args); //! Turn On/Off profiling void profile(bool to_profile = true) { @@ -110,7 +134,7 @@ class TORCH_CUDA_CU_API FusionKernelRuntime { // any segment cannot be scheduled or the parameters don't match using HeuristicsPtr = std::unique_ptr; c10::optional getMaybeHeuristicsFor( - const at::ArrayRef& inputs); + const KernelArgumentHolder& args); //! Copy the launch params given in the parameter heuristics to prepare //! for kernel launch for a new input dimension but same heuristics @@ -121,8 +145,14 @@ class TORCH_CUDA_CU_API FusionKernelRuntime { //! fusions, or a kernel for a segmentedGrouup in a segmented fusion. Returns //! the kernel outputs. std::vector runKernelWithInput( - const at::ArrayRef& inputs, - size_t input_id, + KernelArgumentHolder& args, + SegmentedGroup* sg); + + //! Interface to compile a single kernel, either one kernel for single-kernel + //! fusions, or a kernel for a segmentedGrouup in a segmented fusion. Returns + //! the kernel outputs with tensor that doesn't own memory. + KernelArgumentHolder compileKernel( + const KernelArgumentHolder& args, SegmentedGroup* sg); //! Interface to run a the whole graph in a segmented fusion and return the @@ -154,18 +184,11 @@ class TORCH_CUDA_CU_API FusionKernelRuntime { //! Pre-allocated runtime workspace to speed up kernel launch preparation. struct RuntimeWorkSpace { - //! Temporary space to save intermediate tensors for segmented fusion - std::unordered_map tensor_map; - //! Pre-determined order to run the segmented groups std::vector group_run_order; //! Pre-determined order to bind tensor input meta data std::vector group_extent_binding_order; - - //! Pre-allocated workspace to hold group inputs and outputs - std::vector group_runtime_inputs; - std::vector group_runtime_outputs; } runtime_workspace_; //! Utility to speed up integer evaluation at runtime @@ -174,6 +197,12 @@ class TORCH_CUDA_CU_API FusionKernelRuntime { // States for profiling support bool profiling_ = false; + std::mutex mutex_; + // TODO: remove `compiling_` mutex and rely on `mutex_` only. + // we don't need the second mutex, if only I could figure out how to pass + // unique_lock into lambda + std::mutex compiling_; + // The heuristics and executor for most recent kernel launch ExecutorLog most_recent_executor_log_; }; @@ -208,9 +237,7 @@ class TORCH_CUDA_CU_API InputsIdLookup : public NonCopyable { //! within the lookup cache. This is needed because lookup shortcut is also //! 
cached in nested `GraphCache`, `FusionExecutorCache` and `FusionExecutor`. //! see [ Note -- 2 level cache implementation ] - IdLookupReturn lookupId( - const at::ArrayRef& inputs, - const SchedulerRuntimeInfo* additional_info = nullptr); + IdLookupReturn lookupId(const at::ArrayRef& inputs); //! debugging API that returns the size of lookup table size_t size() const { @@ -304,11 +331,13 @@ class TORCH_CUDA_CU_API InputsIdLookup : public NonCopyable { class TORCH_CUDA_CU_API FusionExecutorCache { public: //! create new fusion executor cache at a given device to handle kernel - //! generation of dynamic sizes; - //! fusion executor is taking the ownership of `fusion`; + //! generation of dynamic sizes + //! fusion executor is taking the ownership of `fusion` explicit FusionExecutorCache(std::unique_ptr fusion); - //! Execute fusion graph with given inputs, create `FusionExecutor` as needed; + //! Execute fusion graph with given inputs, create `FusionExecutor` as needed + //! Note this function also handles permutation & input update outside of + //! codegen. std::vector runFusionWithInputs( const at::ArrayRef& inputs); @@ -359,14 +388,25 @@ class TORCH_CUDA_CU_API FusionExecutorCache { } } + //! converts inputs from IValue to KernelArgumentHolder, also handles cache + //! lookup + KernelArgumentHolder prepareInputs(const at::ArrayRef& inputs); + + //! query if there's a kernel ready to go for given inputs + bool isCompiled(const at::ArrayRef& inputs); + + //! compile a kernel executor for given inputs. Note: the compilation is + //! async, there's some restriction on the user side. e.g. don't overlap + //! compilation and execution for the same FusionExecutor entry. This is + //! experimental at this moment, please use with extra caution. + void compileFusionAsync(const at::ArrayRef& inputs); + private: //! evict cached short cut entry in `code_to_fe_lookup_` as well as cached //! entry in `FusionExecutor` void evictCache(size_t cache_id); - FusionKernelRuntime* getKernelRuntimeFor( - const at::ArrayRef& inputs, - size_t id); + FusionKernelRuntime* getKernelRuntimeFor(const KernelArgumentHolder& inputs); private: //! 
original un-scheduled `Fusion`; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp index ac7f66836a87c..3ceb75dbda11c 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp @@ -405,20 +405,20 @@ bool isConnectedFusionGraph(Fusion* fusion) { } // namespace -SchedulerRuntimeInfo::SchedulerRuntimeInfo( - Fusion* complete_fusion, - const at::ArrayRef& inputs, - bool create_expr_evaluator) - : complete_fusion_(complete_fusion) { +void SchedulerRuntimeInfo::initialize( + const KernelArgumentHolder& args, + bool create_expr_evaluator) { TORCH_INTERNAL_ASSERT( - complete_fusion_->inputs().size() == inputs.size(), + complete_fusion_->inputs().size() == args.size(), "Invalid number of arguments passed in for provided fusion group."); - for (auto inp_i : c10::irange(inputs.size())) { - auto aten_inp = inputs[inp_i]; - if (aten_inp.isTensor()) { + for (auto inp_i : c10::irange(args.size())) { + auto kernel_arg = args[inp_i]; + // Note: we are skipping CpuScalar tensor here + if (auto tensor_arg_abstract = + dynamic_cast(kernel_arg)) { auto fusion_inp = complete_fusion_->inputs()[inp_i]; - auto data_ptr = aten_inp.toTensor().data_ptr(); + auto data_ptr = tensor_arg_abstract->getPointer(); input_ptrs_[fusion_inp] = (size_t)data_ptr; } } @@ -426,9 +426,28 @@ SchedulerRuntimeInfo::SchedulerRuntimeInfo( expression_evaluator_ = std::make_unique(complete_fusion_); if (create_expr_evaluator) { - initializeExpressionEvaluator(inputs); + initializeExpressionEvaluator(args); } - collectIndexModeInfo(inputs); + index_mode_ = args.getIndexMode(); +} + +SchedulerRuntimeInfo::SchedulerRuntimeInfo( + Fusion* complete_fusion, + const KernelArgumentHolder& args, + bool create_expr_evaluator) + : complete_fusion_(complete_fusion) { + initialize(args, create_expr_evaluator); +} + +// TODO: remove this one +SchedulerRuntimeInfo::SchedulerRuntimeInfo( + Fusion* complete_fusion, + const at::ArrayRef& aten_inputs, + bool create_expr_evaluator) + : complete_fusion_(complete_fusion) { + KernelArgumentHolder args = + KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + initialize(args, create_expr_evaluator); } // TODO: Output tensors could have an alignment that is not 16 Bytes passed in @@ -441,11 +460,11 @@ size_t SchedulerRuntimeInfo::ptrOf(TensorView* tv) { } void SchedulerRuntimeInfo::initializeExpressionEvaluator( - const at::ArrayRef& inputs) { + const KernelArgumentHolder& args) { // TODO: refactor bindFusionInputs to better support this // use case, i.e. support construct and bind input. *expression_evaluator_ = - executor_utils::bindFusionInputs(inputs, complete_fusion_); + executor_utils::bindFusionInputs(args, complete_fusion_); } size_t SchedulerRuntimeInfo::computeAlignmentSize(size_t ptr_address) { @@ -666,54 +685,6 @@ size_t SchedulerRuntimeInfo::getInnerDimVectorizableWidth(TensorView* tv) { return vector_size; } -void SchedulerRuntimeInfo::collectIndexModeInfo( - const at::ArrayRef& inputs) { - // TODO: Need to check the output sizes as well. 
- - // Save 1 more bit besides the sign bit to be conservative - constexpr int64_t most_positive_int32_index = - std::numeric_limits::max() / 2; - constexpr int64_t most_negative_int32_index = - std::numeric_limits::min() / 2; - - // Start by setting index mode to int32 - index_mode_ = KernelIndexMode::INT32; - - // Check all runtime inputs, and if any one of - // the input's index exceeds max_int32 will - // fall back to int64 indexing - for (auto ivalue_input : inputs) { - if (ivalue_input.isTensor()) { - auto tensor_input = ivalue_input.toTensor(); - int64_t tensor_most_positive_index = 0; - int64_t tensor_most_negative_index = 0; - for (auto dim_i = 0; dim_i < tensor_input.ndimension(); dim_i++) { - // Ignore broadcast dimensions - if (tensor_input.size(dim_i) > 1) { - // accumulate based on the sign of stride - if (tensor_input.stride(dim_i) > 0) { - // Acuumulate positive stride - tensor_most_positive_index += - (tensor_input.size(dim_i) - 1) * tensor_input.stride(dim_i); - } else { - // Acuumulate negative stride - tensor_most_negative_index += - (tensor_input.size(dim_i) - 1) * tensor_input.stride(dim_i); - } - } - } - - // Fall back to int64 if it can be either too positive - // or too negative. - if (tensor_most_positive_index > most_positive_int32_index || - tensor_most_negative_index < most_negative_int32_index) { - index_mode_ = KernelIndexMode::INT64; - return; - } - } - } -} - bool SchedulerEntry::sameAs(const SchedulerEntry* other) { if (heuristc_ != other->heuristc_) { return false; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.h b/torch/csrc/jit/codegen/cuda/scheduler/registry.h index dd8caf63ccdae..7ed8474935c01 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.h @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include @@ -33,13 +34,19 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable { static constexpr size_t max_alignment_size_in_byte = 16; //! Create runtime info for given fusion and input. Creating and binding - //! evaluator is optional. The evaluator is used to manage intermediate + //! evaluator is optional. The evaluator is used to manage intermediate //! integers in the fusion. We need them for segmenter and schedulers, //! but we don't need them when we are just using this class to provide //! additional encoding for kernel cache lookup. SchedulerRuntimeInfo( Fusion* complete_fusion, - const at::ArrayRef& inputs, + const KernelArgumentHolder& inputs, + bool create_expr_evaluator = false); + + // TODO: Remove this guy below. Everything needs to go into the other ctor + SchedulerRuntimeInfo( + Fusion* complete_fusion, + const at::ArrayRef& aten_inputs, bool create_expr_evaluator = false); //! Lookup for the alignment sizes of the given tv. 
Currently only returns @@ -78,12 +85,11 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable { private: // Bind full fusion inputs to the internal expression evaluator - void initializeExpressionEvaluator(const at::ArrayRef& inputs); + void initializeExpressionEvaluator(const KernelArgumentHolder& inputs); - // check if input is compatible with 32b index mode - void collectIndexModeInfo(const at::ArrayRef& inputs); + // Initialize SchedulerRuntimeInfo + void initialize(const KernelArgumentHolder& args, bool create_expr_evaluator); - private: bool isInputTv(TensorView* tv) { return std::find( complete_fusion_->inputs().begin(), @@ -91,6 +97,7 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable { tv) != complete_fusion_->inputs().end(); } + private: // Returns the offset of tv in the inputs ignoring non tensor views. Used to // access input_sizes, input_strides, input_ptr int offsetTensorPos(TensorView* tv); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 4f72bf93ba36e..f6a6b87a93009 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -15197,7 +15197,13 @@ TEST_F(NVFuserTest, FusionDAGMerging_CUDA) { at::Tensor t0 = at::randn({2, 2, 2, 2, 2}, options); at::Tensor t1 = at::randn({2}, options); - auto fusion_segments = fusion.segment({t0, t1}); + std::vector aten_inputs = {t0, t1}; + + KernelArgumentHolder args(KernelIndexMode::INT32); + args.setDeviceIndex(0); + args.push(aten_inputs); + + auto fusion_segments = fusion.segment(args); TORCH_CHECK(fusion_segments->groups().size() <= 4); } @@ -15567,8 +15573,12 @@ TEST_F(NVFuserTest, FusionSegmentVerticalMerge_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({2, 2, 2}, options); + KernelArgumentHolder args(KernelIndexMode::INT32); + args.setDeviceIndex(0); + args.push(t0); + auto segmented_fusion = - SegmentCandidateFinder::segment(fusion.get(), {t0}, segment_options); + SegmentCandidateFinder::segment(fusion.get(), args, segment_options); TORCH_CHECK(segmented_fusion->groups().size() == 2); } @@ -15607,8 +15617,14 @@ TEST_F(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({2, 2, 2}, options); + KernelArgumentHolder args(KernelIndexMode::INT32); + args.setDeviceIndex(0); + args.push(t0); + c10::IValue scalar = 1.0; + args.push(scalar); + auto segmented_fusion = - SegmentCandidateFinder::segment(fusion.get(), {t0, 1.0}, segment_options); + SegmentCandidateFinder::segment(fusion.get(), args, segment_options); TORCH_CHECK(segmented_fusion->groups().size() == 2); } @@ -15646,8 +15662,12 @@ TEST_F(NVFuserTest, FusionSegmentMixReduction_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({2, 2, 2}, options); + KernelArgumentHolder args(KernelIndexMode::INT32); + args.setDeviceIndex(0); + args.push(t0); + auto segmented_fusion = - SegmentCandidateFinder::segment(fusion.get(), {t0}, segment_options); + SegmentCandidateFinder::segment(fusion.get(), args, segment_options); TORCH_CHECK(segmented_fusion->groups().size() <= 2); } @@ -18215,20 +18235,19 @@ TEST_F(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { at::Tensor at_t17 = at::randn({128, 64, 1024}, options_half); double at_d56 = 1.1111; - std::vector aten_inputs = { - at_t0, - at_t1, - at_t3, - at_t5, - at_t7, - 
at_t11, - at_t13, - at_t15, - at_t17, - at_d56}; + std::vector aten_inputs = { + at_t0, at_t1, at_t3, at_t5, at_t7, at_t11, at_t13, at_t15, at_t17}; + + c10::IValue val = at_d56; + + KernelArgumentHolder args(KernelIndexMode::INT32); + args.setDeviceIndex(0); + args.push(aten_inputs); + args.push(val); + for (auto _ : c10::irange(5)) { auto segmented_fusion = - SegmentCandidateFinder::segment(fusion_ptr.get(), aten_inputs); + SegmentCandidateFinder::segment(fusion_ptr.get(), args); } } @@ -23193,11 +23212,8 @@ TEST_F(NVFuserTest, FusionTestReEntrantGridWelford_CUDA) { tv->axis(-1)->parallelize(ParallelType::Serial); } - CompileOptions co; - co.index_mode = KernelIndexMode::INT32; - FusionExecutor fe; - fe.compileFusion(&fusion, {}, LaunchParams(), co); + fe.compileFusion(&fusion, {}, LaunchParams()); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); at::Tensor t0 = at::randn({X, Y, Y, Z}, options); @@ -25555,6 +25571,65 @@ TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) { testValidate(fusion, {out}, {t0}, {t0}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, AsyncCompilation_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(1); + TensorView* tv2 = makeSymbolicTensor(2); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + + TensorView* tv3 = add(tv0, IrBuilder::create(1)); // Group 0 + TensorView* tv4 = + max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues) + TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce, + // keeps normalization scheduler away) + TensorView* tv6 = add(tv5, tv2); // Group 1 (Broadcast after reduce) + + fusion->addOutput(tv6); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({8, 5}, options); + at::Tensor t1 = at::randn({5}, options); + at::Tensor t2 = at::randn({8, 5}, options); + + auto t3 = t0.add(1.0); + auto t4 = std::get<0>(at::max(t3, 0)); + auto t5 = t4.add(t1); + auto t6 = t5.add(t2); + + FusionExecutorCache executor_cache(std::move(fusion)); + + std::vector aten_inputs = {t0, t1, t2}; + + executor_cache.compileFusionAsync(aten_inputs); + + while (!executor_cache.isCompiled(aten_inputs)) { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + printf("."); + } + + auto outputs = executor_cache.runFusionWithInputs(aten_inputs); + + TORCH_CHECK( + executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation didn't happen"); + TORCH_CHECK( + executor_cache.getMostRecentKernelRuntime() + ->fusionSegments() + ->groups() + .size() == 2, + "segmentation didn't happen as expected"); + + testValidate( + executor_cache.fusion(), outputs, aten_inputs, {t6}, __LINE__, __FILE__); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h b/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h index 0247c33c8a726..3bed835838a5c 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h @@ -283,7 +283,11 @@ ExpressionEvaluator bindInputsAndLaunchParams( Fusion* fusion, const at::ArrayRef& aten_inputs, const LaunchParams& launch_constraints) { - auto expr_eval = executor_utils::bindFusionInputs(aten_inputs, fusion); + // index_mode is not important here + KernelArgumentHolder argument_holder(KernelIndexMode::INT64); + argument_holder.push(aten_inputs); 
+ + auto expr_eval = executor_utils::bindFusionInputs(argument_holder, fusion); for (auto val : fusion->vals()) { if (!val->isA()) { continue; diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp index 5e82014c0c388..c2a9a1a52c59d 100644 --- a/torch/csrc/jit/codegen/cuda/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/utils.cpp @@ -288,6 +288,73 @@ bool is_cpu_scalar(const c10::TensorType& tensor_type) { opt_numel.value() == 1; } +// Check device of TensorType in all inputs ensure all tensors are on cuda +// devices. +// return common device index (or -1 if device differs). +int getCommonDeviceCUDA(const at::ArrayRef& inputs) { + int index = -1; + for (const auto& input : inputs) { + if (!input.isTensor()) { + continue; + } + const auto& device = input.toTensor().device(); + // skip cpu scalar tensor as they'll be promoted to scalar later + if (device.is_cpu() && is_cpu_scalar(input.toTensor())) { + continue; + } + TORCH_CHECK(device.is_cuda(), "nvfuser only supports cuda device"); + auto cur_index = device.index(); + if (index != -1 && index != cur_index) { + return -1; + } + index = (int)cur_index; // NOLINT + } + return index; +} + +KernelIndexMode collectIndexMode(const at::ArrayRef& inputs) { + // Save 1 more bit besides the sign bit to be conservative + constexpr int64_t most_positive_int32_index = + std::numeric_limits::max() / 2; + constexpr int64_t most_negative_int32_index = + std::numeric_limits::min() / 2; + + // Check all runtime inputs, and if any one of + // the input's index exceeds max_int32 will + // fall back to int64 indexing + for (auto ivalue_input : inputs) { + if (ivalue_input.isTensor()) { + auto tensor_input = ivalue_input.toTensor(); + int64_t tensor_most_positive_index = 0; + int64_t tensor_most_negative_index = 0; + for (auto dim_i = 0; dim_i < tensor_input.ndimension(); dim_i++) { + // Ignore broadcast dimensions + if (tensor_input.size(dim_i) > 1) { + // accumulate based on the sign of stride + if (tensor_input.stride(dim_i) > 0) { + // Acuumulate positive stride + tensor_most_positive_index += + (tensor_input.size(dim_i) - 1) * tensor_input.stride(dim_i); + } else { + // Acuumulate negative stride + tensor_most_negative_index += + (tensor_input.size(dim_i) - 1) * tensor_input.stride(dim_i); + } + } + } + + // Fall back to int64 if it can be either too positive + // or too negative. + if (tensor_most_positive_index > most_positive_int32_index || + tensor_most_negative_index < most_negative_int32_index) { + return KernelIndexMode::INT64; + } + } + } + // return index mode as int32 + return KernelIndexMode::INT32; +} + bool isDebugDumpEnabled(DebugDumpOption option) { const static auto dump_options = parseDebugDumpOptions(); return dump_options.at(option); diff --git a/torch/csrc/jit/codegen/cuda/utils.h b/torch/csrc/jit/codegen/cuda/utils.h index 679776b383af0..43b4358cf59b5 100644 --- a/torch/csrc/jit/codegen/cuda/utils.h +++ b/torch/csrc/jit/codegen/cuda/utils.h @@ -2,6 +2,7 @@ #include #include +#include #include namespace torch { @@ -17,6 +18,11 @@ bool is_zero_sized_tensor(const std::shared_ptr& tensor_type); bool is_cpu_scalar(const at::Tensor& tensor); bool is_cpu_scalar(const c10::TensorType& tensor_type); +// TODO: merge these two +// check if input is compatible with 32b index mode +int getCommonDeviceCUDA(const at::ArrayRef& inputs); +KernelIndexMode collectIndexMode(const at::ArrayRef& inputs); + //! Types of debug print-outs //! //! These can be set through the `PYTORCH_NVFUSER_DUMP` environment variable
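For reference, the two helpers relocated into utils.cpp above (getCommonDeviceCUDA and collectIndexMode) are exactly the pieces a caller needs before it can populate a KernelArgumentHolder. The sketch below shows one plausible way to combine them, based only on the call sites visible in this patch (the KernelIndexMode constructor, setDeviceIndex, and the push overloads used in the new tests); the helper name makeArgumentHolder is made up for illustration and this is not the actual body of KernelArgumentHolder::createKernelArgumentHolder.

// Sketch under the assumptions stated above; the headers declaring
// KernelArgumentHolder, getCommonDeviceCUDA and collectIndexMode are omitted,
// matching the stripped include paths elsewhere in this diff.
KernelArgumentHolder makeArgumentHolder(const at::ArrayRef<c10::IValue>& inputs) {
  // All tensor inputs must sit on a single CUDA device; CPU scalar tensors are
  // skipped by getCommonDeviceCUDA and promoted to scalars later.
  const int device_index = getCommonDeviceCUDA(inputs);
  TORCH_CHECK(device_index >= 0, "device is not coherent for fusion inputs");

  // Pick 32-bit indexing only when every tensor's most positive/negative
  // linear index stays within half the int32 range; otherwise fall back to
  // 64-bit indexing (see collectIndexMode above).
  KernelArgumentHolder args(collectIndexMode(inputs));
  args.setDeviceIndex(device_index);

  // Each IValue becomes an abstract argument entry (tensor or scalar) that the
  // binding and scheduling code in this patch reads via getRank()/getSize().
  args.push(inputs);
  return args;
}

Assembled this way, the holder carries the device index and index mode once per input set, which is why runKernelWithInput and SchedulerRuntimeInfo in this patch can drop their per-call getCommonDeviceCUDA and collectIndexModeInfo logic and read both values directly off the args.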