From 68425c21f79eef1ffa8de3e675b517d3c38c6d49 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sun, 24 Jul 2022 21:29:57 -0700 Subject: [PATCH 1/4] Cleanup copy-pasted code in tests --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 53 +------------- .../cuda/test/test_gpu_fused_reduction.cpp | 24 +------ .../jit/codegen/cuda/test/test_gpu_shift.cpp | 43 +----------- .../codegen/cuda/test/test_gpu_tensorcore.cpp | 43 +----------- .../jit/codegen/cuda/test/test_gpu_view.cpp | 28 +------- torch/csrc/jit/codegen/cuda/test/test_utils.h | 69 +++++++++++++++++++ 6 files changed, 74 insertions(+), 186 deletions(-) create mode 100644 torch/csrc/jit/codegen/cuda/test/test_utils.h diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 1bc86fe15f50a..cf8dd35562528 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -57,58 +58,6 @@ using namespace at::indexing; namespace { -// Make a tensor that is known to be fully contiguous of dimensionality=ndims, -// but unknown sizes -TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) { - return TensorViewBuilder() - .ndims(ndims) - .dtype(dtype) - .contiguity(std::vector(ndims, true)) - .build(); -} - -// Make a tensor that is known to be non-contiguous of dimensionality=ndims, -// but unknown sizes -TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) { - return TensorViewBuilder().ndims(ndims).dtype(dtype).build(); -} - -// Make a non-contiguous tensor of compile-time known sizes -TensorView* makeConcreteTensor( - std::vector shape, - DataType dtype = DataType::Float) { - return TensorViewBuilder().shape(shape).dtype(dtype).build(); -} - -TensorView* makeContigConcreteTensor( - std::vector shape, - DataType dtype = DataType::Float) { - return TensorViewBuilder() - .shape(shape) - .dtype(dtype) - .contiguity(std::vector(shape.size(), true)) - .build(); -} - -void checkIntValue( - ExpressionEvaluator& evaluator, - Val* val, - Int::ScalarType expected_value) { - TORCH_CHECK(val->isAnInt()); - const auto actual_value = evaluator.evaluate(val); - TORCH_CHECK(actual_value.has_value()); - TORCH_CHECK(actual_value.value() == expected_value); -} - -void checkIntValue( - kir::ExpressionEvaluator& evaluator, - const Val* val, - Int::ScalarType expected_value) { - const auto actual_value = evaluator.evaluate(val); - TORCH_CHECK(actual_value.has_value()); - TORCH_CHECK(actual_value.value() == expected_value); -} - TensorView* loweredTv(TensorView* tv, GpuLower& gpulw) { auto used_tvs = ir_utils::allTvs(gpulw.kernel()->as()); TensorView* matching_tv = nullptr; diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp index 482ef5daf3251..d87927668301d 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -46,29 +47,6 @@ using namespace at::indexing; namespace { -// Make a tensor that is known to be fully contiguous of dimensionality=ndims, -// but unknown sizes -TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) { - return TensorViewBuilder() - .ndims(ndims) - .dtype(dtype) - .contiguity(std::vector(ndims, true)) - .build(); -} - -// Make a tensor that is known to be 
non-contiguous of dimensionality=ndims, -// but unknown sizes -TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) { - return TensorViewBuilder().ndims(ndims).dtype(dtype).build(); -} - -// Make a non-contiguous tensor of compile-time known sizes -TensorView* makeConcreteTensor( - std::vector shape, - DataType dtype = DataType::Float) { - return TensorViewBuilder().shape(shape).dtype(dtype).build(); -} - class KernelExprVisitor : private kir::IrVisitor { public: static std::vector getAllExprs(const kir::Kernel* kernel) { diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp index f4c60e2b1c11e..b2302013f5fd9 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -46,48 +47,6 @@ using namespace at::indexing; namespace { -// Make a tensor that is known to be fully contiguous of dimensionality=ndims, -// but unknown sizes -TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) { - return TensorViewBuilder() - .ndims(ndims) - .dtype(dtype) - .contiguity(std::vector(ndims, true)) - .build(); -} - -// Make a tensor that is known to be non-contiguous of dimensionality=ndims, -// but unknown sizes -TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) { - return TensorViewBuilder().ndims(ndims).dtype(dtype).build(); -} - -// Make a non-contiguous tensor of compile-time known sizes -TensorView* makeConcreteTensor( - std::vector shape, - DataType dtype = DataType::Float) { - return TensorViewBuilder().shape(shape).dtype(dtype).build(); -} - -void checkIntValue( - ExpressionEvaluator& evaluator, - Val* val, - Int::ScalarType expected_value) { - TORCH_CHECK(val->isAnInt()); - const auto actual_value = evaluator.evaluate(val); - TORCH_CHECK(actual_value.has_value()); - TORCH_CHECK(actual_value.value() == expected_value); -} - -void checkIntValue( - kir::ExpressionEvaluator& evaluator, - const Val* val, - Int::ScalarType expected_value) { - const auto actual_value = evaluator.evaluate(val); - TORCH_CHECK(actual_value.has_value()); - TORCH_CHECK(actual_value.value() == expected_value); -} - // Used to signify invalid ranges, i.e., values at offset 0 to // start_offset, and values at offset stop_offset to the end of the // domain. 
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index 2eb04a5b2bfd0..55bc4adf02702 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -49,48 +50,6 @@ using namespace at::indexing; namespace { -// Make a tensor that is known to be fully contiguous of dimensionality=ndims, -// but unknown sizes -TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) { - return TensorViewBuilder() - .ndims(ndims) - .dtype(dtype) - .contiguity(std::vector(ndims, true)) - .build(); -} - -// Make a tensor that is known to be non-contiguous of dimensionality=ndims, -// but unknown sizes -TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) { - return TensorViewBuilder().ndims(ndims).dtype(dtype).build(); -} - -// Make a non-contiguous tensor of compile-time known sizes -TensorView* makeConcreteTensor( - std::vector shape, - DataType dtype = DataType::Float) { - return TensorViewBuilder().shape(shape).dtype(dtype).build(); -} - -void checkIntValue( - ExpressionEvaluator& evaluator, - Val* val, - Int::ScalarType expected_value) { - TORCH_CHECK(val->isAnInt()); - const auto actual_value = evaluator.evaluate(val); - TORCH_CHECK(actual_value.has_value()); - TORCH_CHECK(actual_value.value() == expected_value); -} - -void checkIntValue( - kir::ExpressionEvaluator& evaluator, - const Val* val, - Int::ScalarType expected_value) { - const auto actual_value = evaluator.evaluate(val); - TORCH_CHECK(actual_value.has_value()); - TORCH_CHECK(actual_value.value() == expected_value); -} - bool cudaArchGuardShouldSkip(int required_major, int required_minor) { int capability_major = at::cuda::getCurrentDeviceProperties()->major; int capability_minor = at::cuda::getCurrentDeviceProperties()->minor; diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp index 412e86ceb0e6a..c12babc65c5f3 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -49,33 +50,6 @@ namespace jit { using namespace torch::jit::fuser::cuda; using namespace at::indexing; -namespace { - -// Make a tensor that is known to be fully contiguous of dimensionality=ndims, -// but unknown sizes -TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) { - return TensorViewBuilder() - .ndims(ndims) - .dtype(dtype) - .contiguity(std::vector(ndims, true)) - .build(); -} - -// Make a tensor that is known to be non-contiguous of dimensionality=ndims, -// but unknown sizes -TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) { - return TensorViewBuilder().ndims(ndims).dtype(dtype).build(); -} - -// Make a non-contiguous tensor of compile-time known sizes -TensorView* makeConcreteTensor( - std::vector shape, - DataType dtype = DataType::Float) { - return TensorViewBuilder().shape(shape).dtype(dtype).build(); -} - -} // namespace - TEST_F(NVFuserTest, FusionViewDtypeSameSizeOutput_CUDA) { Fusion fusion; FusionGuard fg(&fusion); diff --git a/torch/csrc/jit/codegen/cuda/test/test_utils.h b/torch/csrc/jit/codegen/cuda/test/test_utils.h new file mode 100644 index 0000000000000..4886785e355c6 --- /dev/null +++ 
b/torch/csrc/jit/codegen/cuda/test/test_utils.h @@ -0,0 +1,69 @@ +#pragma once + +#include + +#include + +// Tests go in torch::jit +namespace torch { +namespace jit { + +using namespace torch::jit::fuser::cuda; + +namespace { + +// Make a tensor that is known to be fully contiguous of dimensionality=ndims, +// but unknown sizes +TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) { + return TensorViewBuilder() + .ndims(ndims) + .dtype(dtype) + .contiguity(std::vector(ndims, true)) + .build(); +} + +// Make a tensor that is known to be non-contiguous of dimensionality=ndims, +// but unknown sizes +TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) { + return TensorViewBuilder().ndims(ndims).dtype(dtype).build(); +} + +// Make a non-contiguous tensor of compile-time known sizes +TensorView* makeConcreteTensor( + std::vector shape, + DataType dtype = DataType::Float) { + return TensorViewBuilder().shape(shape).dtype(dtype).build(); +} + +TensorView* makeContigConcreteTensor( + std::vector shape, + DataType dtype = DataType::Float) { + return TensorViewBuilder() + .shape(shape) + .dtype(dtype) + .contiguity(std::vector(shape.size(), true)) + .build(); +} + +void checkIntValue( + ExpressionEvaluator& evaluator, + Val* val, + Int::ScalarType expected_value) { + TORCH_CHECK(val->isAnInt()); + const auto actual_value = evaluator.evaluate(val); + TORCH_CHECK(actual_value.has_value()); + TORCH_CHECK(actual_value.value() == expected_value); +} + +void checkIntValue( + kir::ExpressionEvaluator& evaluator, + const Val* val, + Int::ScalarType expected_value) { + const auto actual_value = evaluator.evaluate(val); + TORCH_CHECK(actual_value.has_value()); + TORCH_CHECK(actual_value.value() == expected_value); +} + +} // namespace +} // namespace jit +} // namespace torch From 524497e60d17aeb94e4ebece90a6756944d3a297 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 25 Jul 2022 00:18:06 -0700 Subject: [PATCH 2/4] Refactor heuristics params to make it more extensible --- benchmarks/cpp/nvfuser/bert.cpp | 36 ++--- benchmarks/cpp/nvfuser/broadcast.cpp | 3 +- benchmarks/cpp/nvfuser/reduction.cpp | 3 +- benchmarks/cpp/nvfuser/scale_bias_relu.cpp | 6 +- benchmarks/cpp/nvfuser/utils.cpp | 30 ++-- benchmarks/cpp/nvfuser/utils.h | 1 + torch/csrc/jit/codegen/cuda/kernel_cache.cpp | 37 +---- torch/csrc/jit/codegen/cuda/kernel_cache.h | 3 +- .../jit/codegen/cuda/scheduler/heuristic.h | 37 +++++ .../codegen/cuda/scheduler/normalization.cpp | 105 ++++++------- .../codegen/cuda/scheduler/normalization.h | 4 +- .../jit/codegen/cuda/scheduler/pointwise.cpp | 44 +++--- .../jit/codegen/cuda/scheduler/pointwise.h | 4 +- .../cuda/scheduler/pointwise_heuristic.h | 41 ++--- .../jit/codegen/cuda/scheduler/reduction.cpp | 140 +++++++++--------- .../jit/codegen/cuda/scheduler/reduction.h | 4 +- .../cuda/scheduler/reduction_heuristic.h | 75 +++++----- .../jit/codegen/cuda/scheduler/registry.cpp | 46 ++---- .../jit/codegen/cuda/scheduler/registry.h | 50 +++---- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 96 ++++++------ 20 files changed, 378 insertions(+), 387 deletions(-) create mode 100644 torch/csrc/jit/codegen/cuda/scheduler/heuristic.h diff --git a/benchmarks/cpp/nvfuser/bert.cpp b/benchmarks/cpp/nvfuser/bert.cpp index f105cfe4a4e35..008785c8cf043 100644 --- a/benchmarks/cpp/nvfuser/bert.cpp +++ b/benchmarks/cpp/nvfuser/bert.cpp @@ -133,8 +133,8 @@ static void MagicScheduler_DivMaxSoftDropFwd( std::vector cg_outputs; auto norm_params = getPersistentHeuristics(&fusion, 
at_inputs); - TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!"); - schedulePersistentKernel(&fusion, norm_params.value()); + TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!"); + schedulePersistentKernel(&fusion, *norm_params); FusionExecutor fe; fe.compileFusion(&fusion); @@ -143,7 +143,7 @@ static void MagicScheduler_DivMaxSoftDropFwd( cudaDeviceSynchronize(); for (auto _ : benchmark_state) { CudaKernelTimer timer; - cg_outputs = fe.runFusion({t0, t1}, norm_params.value().lparams); + cg_outputs = fe.runFusion({t0, t1}, norm_params->lparams); benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0); } // Sync everything up before we're finished, don't want to run ahead on the @@ -193,8 +193,8 @@ static void MagicScheduler_DivMaxSoftDropBwd( std::vector cg_outputs; auto norm_params = getPersistentHeuristics(&fusion, at_inputs); - TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!"); - schedulePersistentKernel(&fusion, norm_params.value()); + TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!"); + schedulePersistentKernel(&fusion, *norm_params); FusionExecutor fe; fe.compileFusion(&fusion); @@ -203,7 +203,7 @@ static void MagicScheduler_DivMaxSoftDropBwd( cudaDeviceSynchronize(); for (auto _ : benchmark_state) { CudaKernelTimer timer; - cg_outputs = fe.runFusion({t0, t1, t2, t3}, norm_params.value().lparams); + cg_outputs = fe.runFusion({t0, t1, t2, t3}, norm_params->lparams); benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0); } // Sync everything up before we're finished, don't want to run ahead on the @@ -308,8 +308,8 @@ static void MagicScheduler_BiasDropoutAddLayernormFwd( std::vector cg_outputs; auto norm_params = getPersistentHeuristics(&fusion, at_inputs); - TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!"); - schedulePersistentKernel(&fusion, norm_params.value()); + TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!"); + schedulePersistentKernel(&fusion, *norm_params); FusionExecutor fe; fe.compileFusion(&fusion); @@ -319,7 +319,7 @@ static void MagicScheduler_BiasDropoutAddLayernormFwd( cudaDeviceSynchronize(); for (auto _ : benchmark_state) { CudaKernelTimer timer; - cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams); + cg_outputs = fe.runFusion(at_inputs, norm_params->lparams); benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0); } // Sync everything up before we're finished, don't want to run ahead on the @@ -423,8 +423,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd1( std::vector cg_outputs; auto norm_params = getReductionHeuristics(&fusion, at_inputs); - TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!"); - scheduleReduction(&fusion, norm_params.value()); + TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!"); + scheduleReduction(&fusion, *norm_params); FusionExecutor fe; fe.compileFusion(&fusion); @@ -434,7 +434,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd1( cudaDeviceSynchronize(); for (auto _ : benchmark_state) { clearL2Cache(); - cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams); + cg_outputs = fe.runFusion(at_inputs, norm_params->lparams); benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0); } // Sync everything up before we're finished, don't want to run ahead on the @@ -534,8 +534,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd2( std::vector cg_outputs; auto norm_params = getPersistentHeuristics(&fusion, at_inputs); - 
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!"); - schedulePersistentKernel(&fusion, norm_params.value()); + TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!"); + schedulePersistentKernel(&fusion, *norm_params); FusionExecutor fe; fe.compileFusion(&fusion); @@ -545,7 +545,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd2( cudaDeviceSynchronize(); for (auto _ : benchmark_state) { CudaKernelTimer timer; - cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams); + cg_outputs = fe.runFusion(at_inputs, norm_params->lparams); benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0); } // Sync everything up before we're finished, don't want to run ahead on the @@ -625,8 +625,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd3( std::vector cg_outputs; auto norm_params = getReductionHeuristics(&fusion, at_inputs); - TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!"); - scheduleReduction(&fusion, norm_params.value()); + TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!"); + scheduleReduction(&fusion, *norm_params); FusionExecutor fe; fe.compileFusion(&fusion); @@ -636,7 +636,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd3( cudaDeviceSynchronize(); for (auto _ : benchmark_state) { CudaKernelTimer timer; - cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams); + cg_outputs = fe.runFusion(at_inputs, norm_params->lparams); benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0); } // Sync everything up before we're finished, don't want to run ahead on the diff --git a/benchmarks/cpp/nvfuser/broadcast.cpp b/benchmarks/cpp/nvfuser/broadcast.cpp index 8411444ca96a2..05e8e052f4b26 100644 --- a/benchmarks/cpp/nvfuser/broadcast.cpp +++ b/benchmarks/cpp/nvfuser/broadcast.cpp @@ -69,8 +69,7 @@ static void NvFuserScheduler_Broadcast( auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo(); auto executor_instance = compile_log.fusion_executor; - TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value()); - auto params = toString(compile_log.pointwise_params.value()); + auto params = toString(compile_log.params); auto lparams = toString(compile_log.fusion_executor->lastLaunchParams()); benchmark_state.SetLabel(params + lparams); diff --git a/benchmarks/cpp/nvfuser/reduction.cpp b/benchmarks/cpp/nvfuser/reduction.cpp index 3fd1bcb59dfc6..d6fc0ca327ae7 100644 --- a/benchmarks/cpp/nvfuser/reduction.cpp +++ b/benchmarks/cpp/nvfuser/reduction.cpp @@ -65,8 +65,7 @@ static void NvFuserScheduler_Reduction( auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo(); auto executor_instance = compile_log.fusion_executor; - TORCH_INTERNAL_ASSERT(compile_log.reduction_params.has_value()); - auto rparams = toString(compile_log.reduction_params.value()); + auto rparams = toString(compile_log.params); auto lparams = toString(compile_log.fusion_executor->lastLaunchParams()); benchmark_state.SetLabel(rparams + lparams); diff --git a/benchmarks/cpp/nvfuser/scale_bias_relu.cpp b/benchmarks/cpp/nvfuser/scale_bias_relu.cpp index 6bb7fc18aa0b0..74dbb5324cbab 100644 --- a/benchmarks/cpp/nvfuser/scale_bias_relu.cpp +++ b/benchmarks/cpp/nvfuser/scale_bias_relu.cpp @@ -135,8 +135,7 @@ static void NvFuserScheduler_SBR( auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo(); auto executor_instance = compile_log.fusion_executor; - TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value()); - auto params = toString(compile_log.pointwise_params.value()); + auto 
params = toString(compile_log.params); auto lparams = toString(compile_log.fusion_executor->lastLaunchParams()); benchmark_state.SetLabel(params + lparams); @@ -238,8 +237,7 @@ static void NvFuserScheduler_SBR_Norm( auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo(); auto executor_instance = compile_log.fusion_executor; - TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value()); - auto params = toString(compile_log.pointwise_params.value()); + auto params = toString(compile_log.params); auto lparams = toString(compile_log.fusion_executor->lastLaunchParams()); benchmark_state.SetLabel(params + lparams); diff --git a/benchmarks/cpp/nvfuser/utils.cpp b/benchmarks/cpp/nvfuser/utils.cpp index c15248bce71d7..3915f7d652989 100644 --- a/benchmarks/cpp/nvfuser/utils.cpp +++ b/benchmarks/cpp/nvfuser/utils.cpp @@ -89,6 +89,20 @@ std::string toString(PointwiseParams params) { return ss.str(); } +std::string toString(const std::shared_ptr& params) { + auto rparams = std::dynamic_pointer_cast(params); + if (rparams) { + return toString(*rparams); + } + auto pparams = std::dynamic_pointer_cast(params); + if (pparams) { + return toString(*pparams); + } + TORCH_INTERNAL_ASSERT( + false, + "Unknown heuristic parameter type. Did you just added a new heuristic parameter type but forget to update here?"); +} + std::string toString(LaunchParams lparams) { std::stringstream ss; lparams.toString(); @@ -123,9 +137,7 @@ TensorView* makeContigTensor(size_t ndims, DataType dtype) { .build(); } -TensorView* makeConcreteTensor( - std::vector shape, - DataType dtype) { +TensorView* makeConcreteTensor(std::vector shape, DataType dtype) { return TensorViewBuilder().shape(shape).dtype(dtype).build(); } @@ -157,15 +169,9 @@ void runBenchmarkIterations( auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo(); auto executor_instance = compile_log.fusion_executor; - if (compile_log.reduction_params.has_value()) { - auto rparams = toString(compile_log.reduction_params.value()); - auto lparams = toString(compile_log.fusion_executor->lastLaunchParams()); - benchmark_state.SetLabel(rparams + lparams); - } else if (compile_log.pointwise_params.has_value()){ - auto pparams = toString(compile_log.pointwise_params.value()); - auto lparams = toString(compile_log.fusion_executor->lastLaunchParams()); - benchmark_state.SetLabel(pparams + lparams); - } + auto params = toString(compile_log.params); + auto lparams = toString(compile_log.fusion_executor->lastLaunchParams()); + benchmark_state.SetLabel(params + lparams); executor_instance->setMeasureKernelTimeFlag(true); diff --git a/benchmarks/cpp/nvfuser/utils.h b/benchmarks/cpp/nvfuser/utils.h index 176290fd76f34..e24fdfb127dab 100644 --- a/benchmarks/cpp/nvfuser/utils.h +++ b/benchmarks/cpp/nvfuser/utils.h @@ -38,6 +38,7 @@ TensorView* makeContigConcreteTensor( std::string toString(ReductionParams rparams); std::string toString(PointwiseParams params); +std::string toString(const std::shared_ptr& params); std::string toString(LaunchParams lparams); // Run benchmark iterations with provided inputs. 
If not segmented, report diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp index 1e43807cb2a1e..e1ed1d56c496d 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp @@ -341,32 +341,16 @@ std::vector FusionKernelRuntime::runKernelWithInput( options.index_mode = scheduler_entry->indexMode(); FusionGuard fg(fusion_to_run.get()); scheduler_entry->schedule(fusion_to_run.get()); - // Load launch params for reduction and normalization kernels - if (scheduler_entry->hasReductionParam()) { - launch_params = scheduler_entry->reductionParams().lparams; - } else { - launch_params = scheduler_entry->pointwiseParams().lparams; - } + launch_params = scheduler_entry->params()->lparams; executors_[group_id].compileFusion( fusion_to_run.get(), inputs, launch_params, options); } else { - // Load launch params for reduction and normalization kernels - if (scheduler_entry->hasReductionParam()) { - launch_params = scheduler_entry->reductionParams().lparams; - } else { - launch_params = scheduler_entry->pointwiseParams().lparams; - } + launch_params = scheduler_entry->params()->lparams; } if (profiling_) { most_recent_executor_log_.fusion_executor = &executors_[group_id]; - if (scheduler_entry->hasReductionParam()) { - most_recent_executor_log_.reduction_params = - scheduler_entry->reductionParams(); - } else { - most_recent_executor_log_.pointwise_params = - scheduler_entry->pointwiseParams(); - } + most_recent_executor_log_.params = scheduler_entry->params()->clone(); } auto& executor = executors_[group_id]; @@ -395,11 +379,7 @@ std::vector FusionKernelRuntime::runKernelWithInput( } } std::cout << "Compiler log: " << executor.compilerLog() << "\n"; - if (scheduler_entry->hasReductionParam()) { - std::cout << scheduler_entry->reductionParams().toString() << "\n"; - } else { - std::cout << scheduler_entry->pointwiseParams().toString() << "\n"; - } + std::cout << scheduler_entry->params()->toString() << "\n"; std::cout << "With arguments: " << executor.lastLaunchParams().toString(); std::cout << executor.kernelName() << " " << executor.bytesProcessed() << " bytes/ " << std::setprecision(3) << executor.kernelTimeMs() @@ -604,13 +584,8 @@ void FusionKernelRuntime::updateHeuristicsLaunchParams( update_heuristics->heuristicsList().size() == scheduler_list_length); for (const auto i : c10::irange(scheduler_list_length)) { auto& schedulerPtr = heuristics_->heuristicsList()[i]; - if (schedulerPtr->hasReductionParam()) { - schedulerPtr->updateLaunchConstraint( - update_heuristics->heuristicsList()[i]->reductionParams().lparams); - } else { - schedulerPtr->updateLaunchConstraint( - update_heuristics->heuristicsList()[i]->pointwiseParams().lparams); - } + schedulerPtr->updateLaunchConstraint( + update_heuristics->heuristicsList()[i]->params()->lparams); } } diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.h b/torch/csrc/jit/codegen/cuda/kernel_cache.h index f3faaec18b323..f67742d10f3f4 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.h +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.h @@ -25,8 +25,7 @@ class SchedulerRuntimeInfo; // Utilities for benchmarking and profiling struct ExecutorLog { - c10::optional reduction_params = c10::nullopt; - c10::optional pointwise_params = c10::nullopt; + std::shared_ptr params = nullptr; FusionExecutor* fusion_executor = nullptr; }; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/heuristic.h b/torch/csrc/jit/codegen/cuda/scheduler/heuristic.h new file mode 100644 
index 0000000000000..058c72e592ad1 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/scheduler/heuristic.h @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +class HeuristicParams { + public: + std::string tag = ""; + + LaunchParams lparams; + + virtual std::string toString() const { + return "Undefined Heuristic Params"; + } + + virtual size_t hash() const = 0; + + virtual ~HeuristicParams() = default; + + virtual bool sameAs(const std::shared_ptr& other) const = 0; + + virtual std::shared_ptr clone() const = 0; + + HeuristicParams() = default; + HeuristicParams(const std::string& tag) : tag(tag) {} +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp index bea1767a80c52..d8a3dc9345000 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp @@ -33,7 +33,7 @@ int64_t roundUpPow2Or8(const int64_t x) { // Copied from reduction scheduler, should generalize. Simply needed to take out // grid reductions. -ReductionParams innerPersistentHeuristic( +std::shared_ptr innerPersistentHeuristic( const int64_t total_reduction_numel, const int64_t total_iteration_numel, const int64_t inner_most_dimension_numel, @@ -409,51 +409,52 @@ ReductionParams innerPersistentHeuristic( int64_t gdimy = LaunchParams::UNINITIALIZED_VAL; int64_t gdimz = LaunchParams::UNINITIALIZED_VAL; - ReductionParams rparams; + auto rparams = std::make_shared(); - rparams.persistent_kernel = true; - rparams.fastest_dim = true; + rparams->persistent_kernel = true; + rparams->fastest_dim = true; // Inner reduction domain - rparams.cross_block_inner_reduction = true; - rparams.block_dim_inner_reduction = ParallelType::TIDx; - rparams.pad_inner_reduction_to_warp = pad_bdimx; - rparams.batches_per_block_inner_reduction = batches_per_block_inner_reduction; + rparams->cross_block_inner_reduction = true; + rparams->block_dim_inner_reduction = ParallelType::TIDx; + rparams->pad_inner_reduction_to_warp = pad_bdimx; + rparams->batches_per_block_inner_reduction = + batches_per_block_inner_reduction; // For persistent schedules always have to mark the reduction unrolled // otherwise rfactor can fail - rparams.unroll_factor_inner_reduction = inner_reduction_unroll_factor; - rparams.vectorize_inner_reduction = vectorize; + rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor; + rparams->vectorize_inner_reduction = vectorize; // Iter domain - rparams.multiple_reds_per_blk = bdimy > 1; - if (rparams.multiple_reds_per_blk) { - rparams.block_dim_iter_dom = ParallelType::TIDy; + rparams->multiple_reds_per_blk = bdimy > 1; + if (rparams->multiple_reds_per_blk) { + rparams->block_dim_iter_dom = ParallelType::TIDy; } if (godim > 1) { - rparams.grid_dim_iter_dom = ParallelType::BIDx; + rparams->grid_dim_iter_dom = ParallelType::BIDx; if (godim > scheduler_utils::x_grid_limit) { - rparams.split_grid_dim_iter_dom = true; + rparams->split_grid_dim_iter_dom = true; gdimx = scheduler_utils::x_grid_limit; } } if (iter_unroll_factor > 1) { - rparams.unroll_factor_iter_dom = iter_unroll_factor; + rparams->unroll_factor_iter_dom = iter_unroll_factor; } // Outer reduction domain - rparams.schedule_3D = total_reduction_numel != inner_most_dimension_numel; - if (rparams.schedule_3D) { - rparams.batches_per_block_outer_reduction = + 
rparams->schedule_3D = total_reduction_numel != inner_most_dimension_numel; + if (rparams->schedule_3D) { + rparams->batches_per_block_outer_reduction = batches_per_block_outer_reduction; - rparams.block_dim_outer_reduction = ParallelType::TIDz; - rparams.cross_block_outer_reduction = true; - rparams.unroll_factor_outer_reduction = outer_reduction_unroll_factor; + rparams->block_dim_outer_reduction = ParallelType::TIDz; + rparams->cross_block_outer_reduction = true; + rparams->unroll_factor_outer_reduction = outer_reduction_unroll_factor; } - rparams.lparams = LaunchParams( + rparams->lparams = LaunchParams( gdimx, gdimy, gdimz, @@ -461,7 +462,7 @@ ReductionParams innerPersistentHeuristic( bdimy, LaunchParams::UNINITIALIZED_VAL); - rparams.tag = "Inner Persistent Heuristic.\n"; + rparams->tag = "Inner Persistent Heuristic.\n"; if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { std::cerr << "\n===== Reduction Stats ========\n" @@ -478,7 +479,7 @@ ReductionParams innerPersistentHeuristic( << "\n" << "block(" << (pad_bdimx ? padded_bdimx : bdimx) << ", " << bdimy << ", " << bdimz << ")"; - std::cerr << rparams.toString() << std::endl; + std::cerr << rparams->toString() << std::endl; } return rparams; @@ -487,7 +488,7 @@ ReductionParams innerPersistentHeuristic( // Copied from reduction scheduler, should generalize. Simply needed to take out // grid reductions. // TODO: Check adding iteration domain unrolling -ReductionParams outerPersistentHeuristic( +std::shared_ptr outerPersistentHeuristic( const int64_t total_reduction_numel, const int64_t total_iteration_numel, const int64_t n_tensor_inputs, @@ -695,47 +696,47 @@ ReductionParams outerPersistentHeuristic( gdimx = ceilDiv(total_iteration_numel, bdimx); - ReductionParams rparams; - rparams.batches_per_block_inner_reduction = batches_per_block; - rparams.persistent_kernel = true; + auto rparams = std::make_shared(); + rparams->batches_per_block_inner_reduction = batches_per_block; + rparams->persistent_kernel = true; - rparams.fastest_dim = false; - rparams.cross_block_inner_reduction = true; - rparams.cross_grid_inner_reduction = false; - rparams.multiple_reds_per_blk = bdimx > 1; + rparams->fastest_dim = false; + rparams->cross_block_inner_reduction = true; + rparams->cross_grid_inner_reduction = false; + rparams->multiple_reds_per_blk = bdimx > 1; - if (rparams.multiple_reds_per_blk) { - rparams.block_dim_iter_dom = ParallelType::TIDx; + if (rparams->multiple_reds_per_blk) { + rparams->block_dim_iter_dom = ParallelType::TIDx; } - rparams.grid_dim_iter_dom = ParallelType::BIDx; - rparams.split_grid_dim_iter_dom = gdimx > scheduler_utils::x_grid_limit; + rparams->grid_dim_iter_dom = ParallelType::BIDx; + rparams->split_grid_dim_iter_dom = gdimx > scheduler_utils::x_grid_limit; - if (rparams.block_dim_iter_dom == ParallelType::TIDx) { - rparams.block_dim_inner_reduction = ParallelType::TIDy; + if (rparams->block_dim_iter_dom == ParallelType::TIDx) { + rparams->block_dim_inner_reduction = ParallelType::TIDy; } else { - rparams.block_dim_inner_reduction = ParallelType::TIDx; + rparams->block_dim_inner_reduction = ParallelType::TIDx; } // Always need to mark inner reduction unroll for rfactor in outer persitent // kernels - rparams.unroll_factor_inner_reduction = inner_reduction_unroll_factor; + rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor; - rparams.unroll_factor_iter_dom = iter_unroll_factor; + rparams->unroll_factor_iter_dom = iter_unroll_factor; if (iter_unroll_factor > 1) { - rparams.vectorize_iter_dom = 
vectorize; + rparams->vectorize_iter_dom = vectorize; } - rparams.lparams = LaunchParams( + rparams->lparams = LaunchParams( LaunchParams::UNINITIALIZED_VAL, LaunchParams::UNINITIALIZED_VAL, LaunchParams::UNINITIALIZED_VAL, - rparams.multiple_reds_per_blk ? bdimx : bdimy, + rparams->multiple_reds_per_blk ? bdimx : bdimy, LaunchParams::UNINITIALIZED_VAL, LaunchParams::UNINITIALIZED_VAL); - rparams.tag = "Outer persistent kernel heuristic.\n"; + rparams->tag = "Outer persistent kernel heuristic.\n"; if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { std::cerr << "\n===== Reduction Stats ========\n" @@ -749,7 +750,7 @@ ReductionParams outerPersistentHeuristic( << "max_multi_reduction_factor: " << max_multi_reduction_factor << "\n" << "block(" << bdimx << ", " << bdimy << ", 1)" << std::endl; - std::cerr << rparams.toString() << std::endl; + std::cerr << rparams->toString() << std::endl; } return rparams; @@ -757,7 +758,7 @@ ReductionParams outerPersistentHeuristic( } // namespace -ReductionParams persistentHeuristic( +std::shared_ptr persistentHeuristic( const int64_t total_reduction_numel, const int64_t total_iteration_numel, const int64_t inner_most_dimension_numel, @@ -767,7 +768,7 @@ ReductionParams persistentHeuristic( const int64_t max_persistent_buffer_size, size_t vectorize_factor, bool project_persistent_buffers) { - ReductionParams rparams; + std::shared_ptr rparams; if (fastest_dim_reduction) { rparams = innerPersistentHeuristic( total_reduction_numel, @@ -786,11 +787,11 @@ ReductionParams persistentHeuristic( max_persistent_buffer_size, vectorize_factor); } - rparams.project_persistent_buffers = project_persistent_buffers; + rparams->project_persistent_buffers = project_persistent_buffers; return rparams; } -TORCH_CUDA_CU_API c10::optional getPersistentHeuristics( +TORCH_CUDA_CU_API std::shared_ptr getPersistentHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { @@ -946,7 +947,7 @@ TORCH_CUDA_CU_API c10::optional getPersistentHeuristics( project_persistent_buffers); } -TORCH_CUDA_CU_API c10::optional getPersistentHeuristics( +TORCH_CUDA_CU_API std::shared_ptr getPersistentHeuristics( Fusion* fusion, const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache) { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.h b/torch/csrc/jit/codegen/cuda/scheduler/normalization.h index 298e94cdb8eb3..dbf2eb895f0ff 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/normalization.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/normalization.h @@ -18,12 +18,12 @@ namespace cuda { class SchedulerRuntimeInfo; class HeuristicSummary; -TORCH_CUDA_CU_API c10::optional getPersistentHeuristics( +TORCH_CUDA_CU_API std::shared_ptr getPersistentHeuristics( Fusion* fusion, const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache = nullptr); -TORCH_CUDA_CU_API c10::optional getPersistentHeuristics( +TORCH_CUDA_CU_API std::shared_ptr getPersistentHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache = nullptr); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index 03a2cb7bedaa7..afdedb503d589 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -186,7 +186,7 @@ class DomainMap { } // namespace -c10::optional getPointwiseHeuristics( +std::shared_ptr getPointwiseHeuristics( Fusion* fusion, const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache) 
{ @@ -194,7 +194,7 @@ c10::optional getPointwiseHeuristics( return getPointwiseHeuristics(fusion, runtime_info, data_cache); } -c10::optional getPointwiseHeuristics( +std::shared_ptr getPointwiseHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { @@ -257,10 +257,7 @@ c10::optional getPointwiseHeuristics( std::vector>(); }); broadcast_byte_multiples_entry.get(); - - PointwiseParams params; - params.tag = "Pointwise heuristics"; - return params; + return std::make_shared("Pointwise heuristics"); } // Find all vectorizable inputs/outputs @@ -298,8 +295,7 @@ c10::optional getPointwiseHeuristics( max_unroll_factor = 1; } - PointwiseParams params; - params.tag = "Pointwise heuristics"; + auto params = std::make_shared("Pointwise heuristics"); /* * 2D pointwise scheduling logic. What is expected is there's some @@ -467,7 +463,7 @@ c10::optional getPointwiseHeuristics( // Vectorizing innermost domains // Don't try to vectorize if it's not recommended - params.unroll_factor = 1; + params->unroll_factor = 1; // Compute maximum vectorize factor that can be used size_t vectorize_factor = max_unroll_factor; @@ -497,27 +493,27 @@ c10::optional getPointwiseHeuristics( } if (vectorize_factor == 1) { - params.vectorize = false; - params.unroll_factor = max_unroll_factor; + params->vectorize = false; + params->unroll_factor = max_unroll_factor; } else { - params.vectorize = true; - params.unroll_factor = vectorize_factor; + params->vectorize = true; + params->unroll_factor = vectorize_factor; } TORCH_INTERNAL_ASSERT(right_elem_count > 0 || break_point == 0); TORCH_INTERNAL_ASSERT(!(bdimy > 1 && gdim_right > 1)); - params.break_point = break_point; - params.flip_grid_binding = flip_grid_binding; - params.split_block = bdimy > 1; + params->break_point = break_point; + params->flip_grid_binding = flip_grid_binding; + params->split_block = bdimy > 1; - params.lparams.bind(bdimx, ParallelType::TIDx); - if (params.split_block) { - params.lparams.bind(bdimy, ParallelType::TIDy); + params->lparams.bind(bdimx, ParallelType::TIDx); + if (params->split_block) { + params->lparams.bind(bdimy, ParallelType::TIDy); } if ((flip_grid_binding && gdim_right > 65535) || (!flip_grid_binding && gdim_left > 65535)) { - params.split_grid_y_dim = true; + params->split_grid_y_dim = true; } if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { @@ -535,7 +531,7 @@ c10::optional getPointwiseHeuristics( << (right_elem_count > 0 ? 
n_elems / right_elem_count : 0) << " RHS elems: " << right_elem_count << std::endl; std::cerr << std::endl; - std::cerr << params.toString() << std::endl; + std::cerr << params->toString() << std::endl; } return params; @@ -548,9 +544,9 @@ LaunchParams schedulePointwise( FUSER_PERF_SCOPE("scheduleFusion"); auto params = getPointwiseHeuristics(fusion, runtime_inputs); TORCH_INTERNAL_ASSERT( - params.has_value(), "Could not schedule pointwise operation."); - schedulePointwise(fusion, params.value()); - return params.value().lparams; + params != nullptr, "Could not schedule pointwise operation."); + schedulePointwise(fusion, *params); + return params->lparams; } bool hasReferenceTensorView(Fusion* fusion) { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h index 57b77bb20cc9c..6cba29cd6b4b9 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h @@ -13,12 +13,12 @@ namespace cuda { class SchedulerRuntimeInfo; class HeuristicSummary; -TORCH_CUDA_CU_API c10::optional getPointwiseHeuristics( +TORCH_CUDA_CU_API std::shared_ptr getPointwiseHeuristics( Fusion* fusion, const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache = nullptr); -TORCH_CUDA_CU_API c10::optional getPointwiseHeuristics( +TORCH_CUDA_CU_API std::shared_ptr getPointwiseHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache = nullptr); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h index 73d49bb985ad2..b63576f08e3f0 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include @@ -13,7 +13,7 @@ namespace cuda { // schedule. Warning: equal operator is intended for use in caching the kernel // associated with these reduction parameters. It does not check if the launch // parameters are equivelent! -class PointwiseParams { +class PointwiseParams : public HeuristicParams { public: // vectorize if true, otherwise unroll bool vectorize = false; @@ -39,12 +39,16 @@ class PointwiseParams { // Unroll or vectorization factor size_t unroll_factor = 1; - std::string tag = ""; - - LaunchParams lparams; + using HeuristicParams::HeuristicParams; // Warning: Does not check launch parameters! - bool operator==(const PointwiseParams& other) const { + bool sameAs( + const std::shared_ptr& other_base) const override { + auto other_casted = std::dynamic_pointer_cast(other_base); + if (other_casted == nullptr) { + return false; + } + const PointwiseParams& other = *other_casted; bool attr_equal = other.vectorize == vectorize && other.break_point == break_point && other.split_block == split_block && other.split_grid_y_dim == split_grid_y_dim && @@ -53,7 +57,7 @@ class PointwiseParams { return attr_equal; } - std::string toString() const { + std::string toString() const override { std::stringstream ss; ss << "\n===== Pointwise Parameters ========\n" << (tag == "" ? "" : "Tag: ") << tag << " Pointwise Characteristics:\n" @@ -82,20 +86,21 @@ class PointwiseParams { ss << "====================================\n"; return ss.str(); } -}; -// Warning: Hash is not based on launch parameters! 
-class PointwiseParamsHash { - public: - size_t operator()(const PointwiseParams& pp) const { - size_t attr_hash = static_cast(pp.vectorize) ^ - static_cast(pp.break_point) << 4 ^ - static_cast(pp.split_block) << 5 ^ - static_cast(pp.split_grid_y_dim) << 6 ^ - static_cast(pp.unroll_factor) << 9 ^ - static_cast(pp.flip_grid_binding) << 10; + // Warning: Hash is not based on launch parameters! + size_t hash() const override { + size_t attr_hash = static_cast(vectorize) ^ + static_cast(break_point) << 4 ^ + static_cast(split_block) << 5 ^ + static_cast(split_grid_y_dim) << 6 ^ + static_cast(unroll_factor) << 9 ^ + static_cast(flip_grid_binding) << 10; return attr_hash; } + + std::shared_ptr clone() const override { + return std::make_shared(*this); + } }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp index 4bf85152f5881..e8781b5f7f6c7 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp @@ -62,7 +62,7 @@ void reduceProductTo(int64_t& z, int64_t& y, int64_t& x, const int64_t max) { } } -ReductionParams innerReductionHeuristic( +std::shared_ptr innerReductionHeuristic( const int64_t total_reduction_numel, const int64_t total_iteration_numel, const int64_t inner_most_dimension_numel, @@ -377,21 +377,21 @@ ReductionParams innerReductionHeuristic( // require iterating over this entire function. } - ReductionParams rparams; - rparams.fastest_dim = true; - rparams.cross_block_inner_reduction = true; - rparams.block_dim_inner_reduction = ParallelType::TIDx; - rparams.cross_grid_inner_reduction = gridim > 1; - rparams.multiple_reds_per_blk = bdimy > 1; + auto rparams = std::make_shared(); + rparams->fastest_dim = true; + rparams->cross_block_inner_reduction = true; + rparams->block_dim_inner_reduction = ParallelType::TIDx; + rparams->cross_grid_inner_reduction = gridim > 1; + rparams->multiple_reds_per_blk = bdimy > 1; bool pad_bdimx = bdimx > 16 && bdimx * bdimy < (int64_t)at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock; // If barely just covering reduction dim, don't pad to the next warp pad_bdimx = pad_bdimx && bdimx * inner_reduction_unroll_factor != inner_most_dimension_numel; - rparams.pad_inner_reduction_to_warp = pad_bdimx; + rparams->pad_inner_reduction_to_warp = pad_bdimx; - if (rparams.pad_inner_reduction_to_warp) { + if (rparams->pad_inner_reduction_to_warp) { // Adjust bdimx based on padding auto min_warp_size = (int64_t)at::cuda::getCurrentDeviceProperties()->warpSize; @@ -400,24 +400,24 @@ ReductionParams innerReductionHeuristic( : bdimx + min_warp_size - bdimx % min_warp_size; } - rparams.unroll_factor_inner_reduction = inner_reduction_unroll_factor; - rparams.vectorize_inner_reduction = vectorize; + rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor; + rparams->vectorize_inner_reduction = vectorize; - if (rparams.multiple_reds_per_blk) { - rparams.block_dim_iter_dom = ParallelType::TIDy; + if (rparams->multiple_reds_per_blk) { + rparams->block_dim_iter_dom = ParallelType::TIDy; } - rparams.unroll_factor_iter_dom = iter_unroll_factor; + rparams->unroll_factor_iter_dom = iter_unroll_factor; - rparams.schedule_3D = total_reduction_numel != inner_most_dimension_numel; + rparams->schedule_3D = total_reduction_numel != inner_most_dimension_numel; // Outer reduction domain - if (rparams.schedule_3D) { - rparams.cross_grid_outer_reduction = grodim > 1; + if (rparams->schedule_3D) { + 
rparams->cross_grid_outer_reduction = grodim > 1; if (bdimz > 1) { - rparams.block_dim_outer_reduction = ParallelType::TIDz; - rparams.cross_block_outer_reduction = true; + rparams->block_dim_outer_reduction = ParallelType::TIDz; + rparams->cross_block_outer_reduction = true; } - rparams.unroll_factor_outer_reduction = outer_reduction_unroll_factor; + rparams->unroll_factor_outer_reduction = outer_reduction_unroll_factor; } int64_t gdimx = LaunchParams::UNINITIALIZED_VAL; @@ -428,38 +428,38 @@ ReductionParams innerReductionHeuristic( // gdimx assigned to grdim. Otherwise it's helpful to pull godim into gdimx in // case it's larger than gdimy can hold, as not doing so can thrash the cache. - if (rparams.cross_grid_inner_reduction) { - rparams.grid_dim_inner_reduction = ParallelType::BIDx; - rparams.split_grid_dim_inner_reduction = true; + if (rparams->cross_grid_inner_reduction) { + rparams->grid_dim_inner_reduction = ParallelType::BIDx; + rparams->split_grid_dim_inner_reduction = true; gdimx = std::min(gridim, scheduler_utils::x_grid_limit); - rparams.grid_dim_iter_dom = ParallelType::BIDy; + rparams->grid_dim_iter_dom = ParallelType::BIDy; if (godim > scheduler_utils::y_grid_limit) { - rparams.split_grid_dim_iter_dom = true; + rparams->split_grid_dim_iter_dom = true; gdimy = std::min(godim, scheduler_utils::y_grid_limit); } } else { - rparams.grid_dim_iter_dom = ParallelType::BIDx; + rparams->grid_dim_iter_dom = ParallelType::BIDx; if (gdimx > scheduler_utils::x_grid_limit) { - rparams.split_grid_dim_iter_dom = true; + rparams->split_grid_dim_iter_dom = true; gdimx = godim; } } - if (rparams.cross_grid_outer_reduction) { - if (rparams.cross_block_inner_reduction) { - rparams.grid_dim_outer_reduction = ParallelType::BIDz; + if (rparams->cross_grid_outer_reduction) { + if (rparams->cross_block_inner_reduction) { + rparams->grid_dim_outer_reduction = ParallelType::BIDz; gdimz = std::min(grodim, scheduler_utils::z_grid_limit); - rparams.split_grid_dim_outer_reduction = true; + rparams->split_grid_dim_outer_reduction = true; } else { - rparams.grid_dim_outer_reduction = ParallelType::BIDy; + rparams->grid_dim_outer_reduction = ParallelType::BIDy; gdimy = std::min(grodim, scheduler_utils::y_grid_limit); - rparams.split_grid_dim_outer_reduction = true; + rparams->split_grid_dim_outer_reduction = true; } } - rparams.lparams = LaunchParams( + rparams->lparams = LaunchParams( gdimx, gdimy, gdimz, @@ -478,20 +478,20 @@ ReductionParams innerReductionHeuristic( << "max_input_dtype_size: " << max_input_dtype_size << "\n" << "block(" << bdimx << ", " << bdimy << ", " << bdimz << ")" << std::endl; - std::cerr << rparams.toString() << std::endl; + std::cerr << rparams->toString() << std::endl; } // If 3d, check if it's supported by the scheduler, otherwise force 1D // schedule - if (rparams.schedule_3D) { - if (rparams.multiple_reds_per_blk && - (rparams.cross_grid_inner_reduction || - rparams.cross_grid_outer_reduction)) { + if (rparams->schedule_3D) { + if (rparams->multiple_reds_per_blk && + (rparams->cross_grid_inner_reduction || + rparams->cross_grid_outer_reduction)) { if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { std::cerr << "\n===== UNSUPPORTED REDUCTION HEURISTIC ========\n"; - std::cerr << rparams.multiple_reds_per_blk << ", " - << (rparams.unroll_factor_inner_reduction > 1) << ", " - << rparams.cross_grid_inner_reduction << std::endl; + std::cerr << rparams->multiple_reds_per_blk << ", " + << (rparams->unroll_factor_inner_reduction > 1) << ", " + << 
rparams->cross_grid_inner_reduction << std::endl; } return innerReductionHeuristic( total_reduction_numel, @@ -506,7 +506,7 @@ ReductionParams innerReductionHeuristic( return rparams; } -ReductionParams OuterReductionHeuristic( +std::shared_ptr outerReductionHeuristic( const int64_t total_reduction_numel, const int64_t total_iteration_numel, const int64_t n_tensor_inputs, @@ -768,13 +768,13 @@ ReductionParams OuterReductionHeuristic( // Always disabled for now. // bool flip_grid = gidim > 1 && gidim < 8; const bool flip_grid = false; - ReductionParams rparams; + auto rparams = std::make_shared(); // cross grid implies cross block - rparams.cross_block_inner_reduction = bdimy > 1 || grdim > 1; - rparams.cross_grid_inner_reduction = grdim > 1; - if (rparams.cross_grid_inner_reduction) { - rparams.split_grid_dim_inner_reduction = true; - rparams.grid_dim_inner_reduction = + rparams->cross_block_inner_reduction = bdimy > 1 || grdim > 1; + rparams->cross_grid_inner_reduction = grdim > 1; + if (rparams->cross_grid_inner_reduction) { + rparams->split_grid_dim_inner_reduction = true; + rparams->grid_dim_inner_reduction = flip_grid ? ParallelType::BIDx : ParallelType::BIDy; if (flip_grid) { gdimx = std::min(grdim, scheduler_utils::x_grid_limit); @@ -782,17 +782,17 @@ ReductionParams OuterReductionHeuristic( gdimy = std::min(grdim, scheduler_utils::y_grid_limit); } } - rparams.multiple_reds_per_blk = bdimx > 1 || iter_unroll_factor > 1; + rparams->multiple_reds_per_blk = bdimx > 1 || iter_unroll_factor > 1; - if (rparams.multiple_reds_per_blk) { - rparams.block_dim_iter_dom = ParallelType::TIDx; + if (rparams->multiple_reds_per_blk) { + rparams->block_dim_iter_dom = ParallelType::TIDx; } - rparams.grid_dim_iter_dom = + rparams->grid_dim_iter_dom = flip_grid ? ParallelType::BIDy : ParallelType::BIDx; if (gidim > (flip_grid ? scheduler_utils::y_grid_limit : scheduler_utils::x_grid_limit)) { - rparams.split_grid_dim_iter_dom = true; + rparams->split_grid_dim_iter_dom = true; if (flip_grid) { gdimy = scheduler_utils::y_grid_limit; } else { @@ -800,29 +800,29 @@ ReductionParams OuterReductionHeuristic( } } - rparams.flip_grid = flip_grid; + rparams->flip_grid = flip_grid; - if (rparams.cross_block_inner_reduction) { - if (rparams.block_dim_iter_dom == ParallelType::TIDx) { - rparams.block_dim_inner_reduction = ParallelType::TIDy; + if (rparams->cross_block_inner_reduction) { + if (rparams->block_dim_iter_dom == ParallelType::TIDx) { + rparams->block_dim_inner_reduction = ParallelType::TIDy; } else { - rparams.block_dim_inner_reduction = ParallelType::TIDx; + rparams->block_dim_inner_reduction = ParallelType::TIDx; } } - rparams.unroll_factor_inner_reduction = inner_reduction_unroll_factor; + rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor; - rparams.unroll_factor_iter_dom = iter_unroll_factor; + rparams->unroll_factor_iter_dom = iter_unroll_factor; if (iter_unroll_factor > 1) { - rparams.vectorize_iter_dom = vectorize; + rparams->vectorize_iter_dom = vectorize; } - rparams.lparams = LaunchParams( + rparams->lparams = LaunchParams( gdimx, gdimy, LaunchParams::UNINITIALIZED_VAL, - rparams.multiple_reds_per_blk ? bdimx : bdimy, - rparams.multiple_reds_per_blk ? bdimy : LaunchParams::UNINITIALIZED_VAL, + rparams->multiple_reds_per_blk ? bdimx : bdimy, + rparams->multiple_reds_per_blk ? 
bdimy : LaunchParams::UNINITIALIZED_VAL, LaunchParams::UNINITIALIZED_VAL); if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { @@ -833,14 +833,14 @@ ReductionParams OuterReductionHeuristic( << "n_tensor_inputs: " << n_tensor_inputs << "\n" << "max_input_dtype_size: " << max_input_dtype_size << "\n" << "block(" << bdimx << ", " << bdimy << ", 1)" << std::endl; - std::cerr << rparams.toString() << std::endl; + std::cerr << rparams->toString() << std::endl; } return rparams; } } // namespace -ReductionParams reductionHeuristic( +std::shared_ptr reductionHeuristic( const int64_t total_reduction_numel, const int64_t total_iteration_numel, const int64_t inner_most_dimension_numel, @@ -858,7 +858,7 @@ ReductionParams reductionHeuristic( vectorize_factor); } else { // 3D schedules not enabled for outer reductions - return OuterReductionHeuristic( + return outerReductionHeuristic( total_reduction_numel, total_iteration_numel, n_tensor_inputs, @@ -867,7 +867,7 @@ ReductionParams reductionHeuristic( } } -TORCH_CUDA_CU_API c10::optional getReductionHeuristics( +TORCH_CUDA_CU_API std::shared_ptr getReductionHeuristics( Fusion* fusion, const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache) { @@ -878,7 +878,7 @@ TORCH_CUDA_CU_API c10::optional getReductionHeuristics( return getReductionHeuristics(fusion, runtime_info, data_cache); } -TORCH_CUDA_CU_API c10::optional getReductionHeuristics( +TORCH_CUDA_CU_API std::shared_ptr getReductionHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction.h b/torch/csrc/jit/codegen/cuda/scheduler/reduction.h index 7e517b1c75aaf..c09608e74b07b 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction.h @@ -13,12 +13,12 @@ namespace cuda { class SchedulerRuntimeInfo; class HeuristicSummary; -TORCH_CUDA_CU_API c10::optional getReductionHeuristics( +TORCH_CUDA_CU_API std::shared_ptr getReductionHeuristics( Fusion* fusion, const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache = nullptr); -TORCH_CUDA_CU_API c10::optional getReductionHeuristics( +TORCH_CUDA_CU_API std::shared_ptr getReductionHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache = nullptr); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h b/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h index 4df5e288eadc5..55e17b4ef6487 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include @@ -13,7 +13,7 @@ namespace cuda { // schedule. Warning: equal operator is intended for use in caching the kernel // associated with these reduction parameters. It does not check if the launch // parameters are equivelent! -class ReductionParams { +class ReductionParams : public HeuristicParams { public: // Reducing inner most dimension? bool fastest_dim = false; @@ -100,18 +100,22 @@ class ReductionParams { // parameters, not used for equivalence/hashing. ParallelType grid_dim_outer_reduction = ParallelType::Serial; - std::string tag = ""; - - LaunchParams lparams; - bool isUnrolled() const { return unroll_factor_inner_reduction > 1 || unroll_factor_iter_dom > 1 || unroll_factor_outer_reduction > 1; } public: + using HeuristicParams::HeuristicParams; + // Warning: Does not check launch parameters! 
-  bool operator==(const ReductionParams& other) const {
+  bool sameAs(
+      const std::shared_ptr<HeuristicParams>& other_base) const override {
+    auto other_casted = std::dynamic_pointer_cast<ReductionParams>(other_base);
+    if (other_casted == nullptr) {
+      return false;
+    }
+    const ReductionParams& other = *other_casted;
     bool attr_equal = other.fastest_dim == fastest_dim &&
         other.persistent_kernel == persistent_kernel &&
         other.project_persistent_buffers == project_persistent_buffers &&
@@ -139,7 +143,7 @@ class ReductionParams {
     return attr_equal;
   }
 
-  std::string toString() const {
+  std::string toString() const override {
     std::stringstream ss;
     ss << "\n===== Reduction Parameters ========\n"
       << (tag == "" ? "" : "Tag: ") << tag << "\n"
@@ -216,38 +220,37 @@ class ReductionParams {
     ss << "====================================\n";
     return ss.str();
   }
-};
 
-// Warning: Hash is not based on launch parameters!
-class ReductionParamsHash {
- public:
-  size_t operator()(const ReductionParams& rp) const {
+  // Warning: Hash is not based on launch parameters!
+  size_t hash() const override {
     constexpr size_t bits = sizeof(std::size_t) * 8;
-    size_t attr_hash = static_cast<size_t>(rp.fastest_dim) << (bits - 1) ^
-        static_cast<size_t>(rp.persistent_kernel) << (bits - 2) ^
-        static_cast<size_t>(rp.project_persistent_buffers) << (bits - 3) ^
-        static_cast<size_t>(rp.schedule_3D) << (bits - 4) ^
-        static_cast<size_t>(rp.flip_grid) << (bits - 5) ^
-        static_cast<size_t>(rp.cross_block_inner_reduction) << (bits - 6) ^
-        static_cast<size_t>(rp.cross_grid_inner_reduction) << (bits - 7) ^
-        static_cast<size_t>(rp.unroll_factor_inner_reduction) << (bits - 8) ^
-        static_cast<size_t>(rp.vectorize_inner_reduction) << (bits - 9) ^
-        static_cast<size_t>(rp.split_grid_dim_inner_reduction) << (bits - 10) ^
-        static_cast<size_t>(rp.pad_inner_reduction_to_warp) << (bits - 11) ^
-        static_cast<size_t>(rp.batches_per_block_inner_reduction)
-        << (bits - 12) ^
-        static_cast<size_t>(rp.multiple_reds_per_blk) << (bits - 13) ^
-        static_cast<size_t>(rp.unroll_factor_iter_dom) << (bits - 14) ^
-        static_cast<size_t>(rp.vectorize_iter_dom) << (bits - 15) ^
-        static_cast<size_t>(rp.split_grid_dim_iter_dom) << (bits - 16) ^
-        static_cast<size_t>(rp.cross_block_outer_reduction) << (bits - 17) ^
-        static_cast<size_t>(rp.cross_grid_outer_reduction) << (bits - 18) ^
-        static_cast<size_t>(rp.split_grid_dim_outer_reduction) << (bits - 19) ^
-        static_cast<size_t>(rp.batches_per_block_outer_reduction)
-        << (bits - 20) ^
-        static_cast<size_t>(rp.unroll_factor_outer_reduction) << (bits - 21);
+    size_t attr_hash = static_cast<size_t>(fastest_dim) << (bits - 1) ^
+        static_cast<size_t>(persistent_kernel) << (bits - 2) ^
+        static_cast<size_t>(project_persistent_buffers) << (bits - 3) ^
+        static_cast<size_t>(schedule_3D) << (bits - 4) ^
+        static_cast<size_t>(flip_grid) << (bits - 5) ^
+        static_cast<size_t>(cross_block_inner_reduction) << (bits - 6) ^
+        static_cast<size_t>(cross_grid_inner_reduction) << (bits - 7) ^
+        static_cast<size_t>(unroll_factor_inner_reduction) << (bits - 8) ^
+        static_cast<size_t>(vectorize_inner_reduction) << (bits - 9) ^
+        static_cast<size_t>(split_grid_dim_inner_reduction) << (bits - 10) ^
+        static_cast<size_t>(pad_inner_reduction_to_warp) << (bits - 11) ^
+        static_cast<size_t>(batches_per_block_inner_reduction) << (bits - 12) ^
+        static_cast<size_t>(multiple_reds_per_blk) << (bits - 13) ^
+        static_cast<size_t>(unroll_factor_iter_dom) << (bits - 14) ^
+        static_cast<size_t>(vectorize_iter_dom) << (bits - 15) ^
+        static_cast<size_t>(split_grid_dim_iter_dom) << (bits - 16) ^
+        static_cast<size_t>(cross_block_outer_reduction) << (bits - 17) ^
+        static_cast<size_t>(cross_grid_outer_reduction) << (bits - 18) ^
+        static_cast<size_t>(split_grid_dim_outer_reduction) << (bits - 19) ^
+        static_cast<size_t>(batches_per_block_outer_reduction) << (bits - 20) ^
+        static_cast<size_t>(unroll_factor_outer_reduction) << (bits - 21);
     return attr_hash;
   }
+
+  std::shared_ptr<HeuristicParams> clone() const override {
+    return std::make_shared<ReductionParams>(*this);
+  }
 };
 
 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
index c0ce9eb011a85..75a41278d767f 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
@@ -721,18 +721,7 @@ bool SchedulerEntry::sameAs(const SchedulerEntry* other) {
   if (index_mode_ != other->index_mode_) {
     return false;
   }
-  // Heuristic equal should imply has_reduction_param_ equal,
-  // need to double check if it is the case before removing
-  // the below one.
-  if (has_reduction_param_ != other->has_reduction_param_) {
-    return false;
-  }
-  if (has_reduction_param_) {
-    return rparams_ == other->rparams_;
-  } else {
-    return pparams_ == other->pparams_;
-  }
-  return true;
+  return params_->sameAs(other->params_);
 }
 
 namespace {
@@ -834,7 +823,7 @@ class ReductionScheduler : public SchedulerEntry {
       Fusion* fusion,
       SchedulerRuntimeInfo& runtime_info,
       HeuristicSummary* data_cache = nullptr)
-      : SchedulerEntry(ScheduleHeuristic::Reduction, true) {
+      : SchedulerEntry(ScheduleHeuristic::Reduction) {
     computeHeuristics(fusion, runtime_info, data_cache);
   }
 
@@ -964,7 +953,7 @@ class ReductionScheduler : public SchedulerEntry {
 
   void schedule(Fusion* fusion) override {
     FUSER_PERF_SCOPE("Schedule Single Reduction");
-    scheduleReduction(fusion, rparams());
+    scheduleReduction(fusion, reductionParams());
   }
 
  private:
@@ -972,9 +961,8 @@ class ReductionScheduler : public SchedulerEntry {
       Fusion* fusion,
       SchedulerRuntimeInfo& runtime_info,
       HeuristicSummary* data_cache = nullptr) {
-    auto param = getReductionHeuristics(fusion, runtime_info, data_cache);
-    TORCH_INTERNAL_ASSERT(param.has_value());
-    rparams() = param.value();
+    params_ = getReductionHeuristics(fusion, runtime_info, data_cache);
+    TORCH_INTERNAL_ASSERT(params_ != nullptr);
   }
 };
 
@@ -984,7 +972,7 @@ class PointWiseScheduler : public SchedulerEntry {
       Fusion* fusion,
       SchedulerRuntimeInfo& runtime_info,
       HeuristicSummary* data_cache = nullptr)
-      : SchedulerEntry(ScheduleHeuristic::PointWise, false) {
+      : SchedulerEntry(ScheduleHeuristic::PointWise) {
     computeHeuristics(fusion, runtime_info, data_cache);
   }
 
@@ -1027,16 +1015,15 @@ class PointWiseScheduler : public SchedulerEntry {
 
   void schedule(Fusion* fusion) override {
     FUSER_PERF_SCOPE("Schedule PointWise Fusion");
-    schedulePointwise(fusion, pparams());
+    schedulePointwise(fusion, pointwiseParams());
   }
 
   void computeHeuristics(
       Fusion* fusion,
       SchedulerRuntimeInfo& runtime_info,
       HeuristicSummary* data_cache = nullptr) {
-    auto pparam = getPointwiseHeuristics(fusion, runtime_info, data_cache);
-    TORCH_INTERNAL_ASSERT(pparam.has_value());
-    pparams() = pparam.value();
+    params_ = getPointwiseHeuristics(fusion, runtime_info, data_cache);
+    TORCH_INTERNAL_ASSERT(params_ != nullptr);
   }
 };
 
@@ -1046,13 +1033,13 @@ class PersistentKernelScheduler : public SchedulerEntry {
       Fusion* fusion,
       SchedulerRuntimeInfo& runtime_info,
       HeuristicSummary* data_cache = nullptr)
-      : SchedulerEntry(ScheduleHeuristic::Persistent, true) {
+      : SchedulerEntry(ScheduleHeuristic::Persistent) {
     computeHeuristics(fusion, runtime_info, data_cache);
   }
 
   void schedule(Fusion* fusion) override {
     FUSER_PERF_SCOPE("Schedule Persistent Fusion");
-    schedulePersistentKernel(fusion, rparams());
+    schedulePersistentKernel(fusion, reductionParams());
   }
 
   static bool canScheduleCompileTime(Fusion* fusion) {
@@ -1263,9 +1250,8 @@ class PersistentKernelScheduler : public SchedulerEntry {
       Fusion* fusion,
       SchedulerRuntimeInfo& runtime_info,
       HeuristicSummary* data_cache = nullptr) {
-    auto param = getPersistentHeuristics(fusion, runtime_info, data_cache);
-    TORCH_INTERNAL_ASSERT(param.has_value());
-    rparams() = param.value();
+    params_ = getPersistentHeuristics(fusion, runtime_info, data_cache);
+    TORCH_INTERNAL_ASSERT(params_ != nullptr);
   }
 };
 
@@ -1367,11 +1353,7 @@ c10::optional<ScheduleHeuristic> SchedulerEntry::proposeHeuristics(
 }
 
 size_t SchedulerEntryHash::operator()(const SchedulerEntry& se) const {
-  if (se.hasReductionParam()) {
-    return ReductionParamsHash()(se.reductionParams());
-  } else {
-    return PointwiseParamsHash()(se.pointwiseParams());
-  }
+  return se.params()->hash();
 }
 
 std::string toString(ScheduleHeuristic sh) {
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.h b/torch/csrc/jit/codegen/cuda/scheduler/registry.h
index dcfee08ec08ba..0af9163518063 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/registry.h
+++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.h
@@ -2,9 +2,14 @@
 #include
 #include
 #include
+#include
+#include
+#include
 #include
 #include
+#include
+
 
 namespace torch {
 namespace jit {
 namespace fuser {
@@ -158,10 +163,6 @@ class TORCH_CUDA_CU_API SchedulerEntry {
   //! Heuristic comparison
   bool sameAs(const SchedulerEntry* other);
 
-  bool hasReductionParam() const {
-    return has_reduction_param_;
-  }
-
   ScheduleHeuristic heuristic() const {
     return heuristc_;
   }
@@ -170,51 +171,38 @@ class TORCH_CUDA_CU_API SchedulerEntry {
     return index_mode_;
   }
 
+  const std::shared_ptr<HeuristicParams>& params() const {
+    return params_;
+  }
+
   const ReductionParams& reductionParams() const {
+    auto rparams = std::dynamic_pointer_cast<ReductionParams>(params_);
     TORCH_INTERNAL_ASSERT(
-        has_reduction_param_, "This schedule heuristic is not reduction.");
-    return rparams_;
+        rparams != nullptr, "Heuristic parameter is not a reduction parameter");
+    return *rparams;
   }
 
   const PointwiseParams& pointwiseParams() const {
+    auto pparams = std::dynamic_pointer_cast<PointwiseParams>(params_);
     TORCH_INTERNAL_ASSERT(
-        !has_reduction_param_, "This schedule heuristic is not pointwise.");
-    return pparams_;
+        pparams != nullptr, "Heuristic parameter is not a pointwise parameter");
+    return *pparams;
   }
 
   void updateLaunchConstraint(const LaunchParams& launch_params) {
-    if (hasReductionParam()) {
-      rparams_.lparams = launch_params;
-    } else {
-      pparams_.lparams = launch_params;
-    }
+    params_->lparams = launch_params;
   }
 
  protected:
-  explicit SchedulerEntry(ScheduleHeuristic heuristic, bool has_reduction_param)
-      : heuristc_(heuristic), has_reduction_param_(has_reduction_param) {}
+  explicit SchedulerEntry(ScheduleHeuristic heuristic) : heuristc_(heuristic) {}
 
-  ReductionParams& rparams() {
-    return rparams_;
-  }
-
-  PointwiseParams& pparams() {
-    return pparams_;
-  }
+  //! Heuristic parameters if applicable
+  std::shared_ptr<HeuristicParams> params_ = nullptr;
 
  private:
   //! What kind of heuristics does this entry have?
   const ScheduleHeuristic heuristc_;
 
-  //! Has reduction params if true, else has pointwise params
-  const bool has_reduction_param_;
-
-  //! Reduction parameters if applicable
-  ReductionParams rparams_;
-
-  //! Pointwise parameters if applicable
-  PointwiseParams pparams_;
-
   //! Kernel Index Mode
   KernelIndexMode index_mode_ = KernelIndexMode::INT64;
 };
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index cf8dd35562528..33be276bb715b 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -5714,12 +5714,11 @@ TEST_F(NVFuserTest, FusionAdvancedIndexing6_CUDA) {
   std::vector reduction_axes{0, 1};
   auto reduction_params = getReductionHeuristics(&fusion, {input0, input1});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
+  scheduleReduction(&fusion, *reduction_params);
 
   FusionExecutor fe;
-  fe.compileFusion(&fusion, {input0, input1}, reduction_params.value().lparams);
-  auto cg_outputs =
-      fe.runFusion({input0, input1}, reduction_params.value().lparams);
+  fe.compileFusion(&fusion, {input0, input1}, reduction_params->lparams);
+  auto cg_outputs = fe.runFusion({input0, input1}, reduction_params->lparams);
 
   auto aten_output = input0.add(input1).to(at::kDouble).sum(reduction_axes);
 
@@ -5731,7 +5730,7 @@ TEST_F(NVFuserTest, FusionAdvancedIndexing6_CUDA) {
       __LINE__,
       __FILE__,
       "",
-      reduction_params.value().lparams);
+      reduction_params->lparams);
 }
 
 TEST_F(NVFuserTest, FusionAdvancedIndexing7_CUDA) {
@@ -7661,9 +7660,9 @@ TEST_F(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) {
   // Apply reduction heuristic
   auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
+  scheduleReduction(&fusion, *reduction_params);
 
-  auto lparams = reduction_params.value().lparams;
+  auto lparams = reduction_params->lparams;
 
   FusionExecutor fe;
   fe.compileFusion(&fusion, {aten_input}, lparams);
@@ -7790,9 +7789,9 @@ TEST_F(NVFuserTest, FusionReductionScheduler_CUDA) {
   // Apply reduction heuristic
   auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
+  scheduleReduction(&fusion, *reduction_params);
 
-  auto lparams = reduction_params.value().lparams;
+  auto lparams = reduction_params->lparams;
 
   FusionExecutor fe;
   fe.compileFusion(&fusion, {aten_input}, lparams);
@@ -7897,8 +7896,8 @@ TEST_F(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) {
   // Apply reduction heuristic
   auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-  auto lparams = reduction_params.value().lparams;
+  scheduleReduction(&fusion, *reduction_params);
+  auto lparams = reduction_params->lparams;
 
   FusionExecutor fe;
   fe.compileFusion(&fusion, {aten_input}, lparams);
@@ -7940,8 +7939,8 @@ TEST_F(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) {
 
   auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-  auto lparams = reduction_params.value().lparams;
+  scheduleReduction(&fusion, *reduction_params);
+  auto lparams = reduction_params->lparams;
 
   FusionExecutor fe;
   fe.compileFusion(&fusion, {aten_input}, lparams);
@@ -8015,9 +8014,9 @@ TEST_F(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) {
     auto aten_output = aten_input.to(at::kDouble).sum({0});
 
     auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-    TORCH_CHECK(reduction_params.has_value(), "Reduction is not found!");
-    scheduleReduction(&fusion, reduction_params.value());
-    auto lparams = reduction_params.value().lparams;
+    TORCH_CHECK(reduction_params != nullptr, "Reduction is not found!");
+    scheduleReduction(&fusion, *reduction_params);
+    auto lparams = reduction_params->lparams;
 
     FusionExecutor fe;
     fe.compileFusion(&fusion, {aten_input}, lparams);
@@ -8097,9 +8096,9 @@ TEST_F(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) {
                      : at::randn({rdim, odim}, options));
 
     auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-    TORCH_CHECK(reduction_params.has_value(), "Reduction is not found!");
-    scheduleReduction(&fusion, reduction_params.value());
-    auto lparams = reduction_params.value().lparams;
+    TORCH_CHECK(reduction_params != nullptr, "Reduction is not found!");
+    scheduleReduction(&fusion, *reduction_params);
+    auto lparams = reduction_params->lparams;
 
     FusionExecutor fe;
     fe.compileFusion(&fusion, {aten_input}, lparams);
@@ -8757,9 +8756,9 @@ TEST_F(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) {
 
   auto reduction_params = getPersistentHeuristics(&fusion, {aten_input});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  schedulePersistentKernel(&fusion, reduction_params.value());
+  schedulePersistentKernel(&fusion, *reduction_params);
 
-  auto lparams = reduction_params.value().lparams;
+  auto lparams = reduction_params->lparams;
 
   torch::jit::fuser::cuda::FusionExecutor fe;
   fe.compileFusion(&fusion, {aten_input}, lparams);
@@ -8812,9 +8811,9 @@ TEST_F(NVFuserTest, FusionTestMaskSoftmax_CUDA) {
       getPersistentHeuristics(&fusion, {aten_input, aten_mask});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
 
-  schedulePersistentKernel(&fusion, reduction_params.value());
+  schedulePersistentKernel(&fusion, *reduction_params);
 
-  auto lparams = reduction_params.value().lparams;
+  auto lparams = reduction_params->lparams;
 
   torch::jit::fuser::cuda::FusionExecutor fe;
   fe.compileFusion(&fusion, {aten_input, aten_mask}, lparams);
@@ -10870,11 +10869,11 @@ TEST_F(NVFuserTest, FusionReductionHalf_CUDA) {
 
   auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
+  scheduleReduction(&fusion, *reduction_params);
 
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
 
-  auto lparams = reduction_params.value().lparams;
+  auto lparams = reduction_params->lparams;
 
   FusionExecutor fe;
   fe.compileFusion(&fusion, {aten_input}, lparams);
@@ -10942,8 +10941,8 @@ TEST_F(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) {
   // Apply reduction heuristic
   auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-  auto lparams = reduction_params.value().lparams;
+  scheduleReduction(&fusion, *reduction_params);
+  auto lparams = reduction_params->lparams;
 
   FusionExecutor fe;
   fe.compileFusion(&fusion, {aten_input}, lparams);
@@ -10989,8 +10988,8 @@ TEST_F(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) {
 
   auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-  auto lparams = reduction_params.value().lparams;
+  scheduleReduction(&fusion, *reduction_params);
+  auto lparams = reduction_params->lparams;
 
   FusionExecutor fe;
   fe.compileFusion(&fusion, {aten_input}, lparams);
@@ -11035,8 +11034,8 @@ TEST_F(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) {
   // Apply reduction heuristic
   auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-  auto lparams = reduction_params.value().lparams;
+  scheduleReduction(&fusion, *reduction_params);
+  auto lparams = reduction_params->lparams;
 
   FusionExecutor fe;
   fe.compileFusion(&fusion, {aten_input}, lparams);
@@ -12949,9 +12948,9 @@ TEST_F(NVFuserTest, FusionWelfordSchedule_CUDA) {
   at::Tensor t0 = at::randn({M, N}, options);
   // TODO: Why do we use launch params from here, but not scheduling???
   auto reduction_params = getReductionHeuristics(&fusion, {t0});
-  scheduleReduction(&fusion, reduction_params.value());
+  scheduleReduction(&fusion, *reduction_params);
 
-  auto lparams = reduction_params.value().lparams;
+  auto lparams = reduction_params->lparams;
   FusionExecutor fe;
   fe.compileFusion(&fusion, {t0}, lparams);
   auto outputs = fe.runFusion({t0}, lparams);
@@ -12971,7 +12970,7 @@ TEST_F(NVFuserTest, FusionWelfordSchedule_CUDA) {
       __LINE__,
       __FILE__,
       "validate welford",
-      reduction_params.value().lparams);
+      reduction_params->lparams);
 }
 
 namespace {
@@ -13025,9 +13024,9 @@ void testWelford(DataType dtype, int red_axis, int odim, int rdim) {
   }
 
   auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-  scheduleReduction(&fusion, reduction_params.value());
+  scheduleReduction(&fusion, *reduction_params);
 
-  auto lparams = reduction_params.value().lparams;
+  auto lparams = reduction_params->lparams;
 
   FusionExecutor fe;
   fe.compileFusion(&fusion, {aten_input}, lparams);
@@ -13053,7 +13052,7 @@ void testWelford(DataType dtype, int red_axis, int odim, int rdim) {
       __LINE__,
       __FILE__,
       "validate welford",
-      reduction_params.value().lparams);
+      reduction_params->lparams);
 }
 
 } // namespace
@@ -13764,8 +13763,9 @@ TEST_F(NVFuserTest, FusionMultipleVectorize_CUDA) {
   auto outputs = executor_cache.runFusionWithInputs({t0, t1});
 
   auto runtime1 = executor_cache.getMostRecentKernelRuntime();
-  auto log1 = executor_cache.getMostRecentExecutorInfo().pointwise_params;
-  TORCH_CHECK(log1.has_value());
+  auto log1 = std::dynamic_pointer_cast<PointwiseParams>(
+      executor_cache.getMostRecentExecutorInfo().params);
+  TORCH_CHECK(log1 != nullptr);
   TORCH_CHECK(log1->vectorize);
 
   testValidate(
@@ -13777,8 +13777,9 @@ TEST_F(NVFuserTest, FusionMultipleVectorize_CUDA) {
   outputs = executor_cache.runFusionWithInputs({t0, t1});
 
   auto runtime2 = executor_cache.getMostRecentKernelRuntime();
-  auto log2 = executor_cache.getMostRecentExecutorInfo().pointwise_params;
-  TORCH_CHECK(log2.has_value());
+  auto log2 = std::dynamic_pointer_cast<PointwiseParams>(
+      executor_cache.getMostRecentExecutorInfo().params);
+  TORCH_CHECK(log2 != nullptr);
   TORCH_CHECK(log2->vectorize);
 
   testValidate(
@@ -13790,8 +13791,9 @@ TEST_F(NVFuserTest, FusionMultipleVectorize_CUDA) {
   outputs = executor_cache.runFusionWithInputs({t0, t1});
 
   auto runtime3 = executor_cache.getMostRecentKernelRuntime();
-  auto log3 = executor_cache.getMostRecentExecutorInfo().pointwise_params;
-  TORCH_CHECK(log3.has_value());
+  auto log3 = std::dynamic_pointer_cast<PointwiseParams>(
+      executor_cache.getMostRecentExecutorInfo().params);
+  TORCH_CHECK(log3 != nullptr);
   TORCH_CHECK(log3->vectorize);
 
   testValidate(
@@ -16186,10 +16188,10 @@ TEST_F(NVFuserTest, FusionZeroSizeTensorReduction_CUDA) {
 
   auto reduction_params = getReductionHeuristics(&fusion, {input0, input1});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
+  scheduleReduction(&fusion, *reduction_params);
 
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  auto lparams = reduction_params.value().lparams;
+  auto lparams = reduction_params->lparams;
 
   FusionExecutor fe;
   fe.compileFusion(&fusion, {input0, input1}, lparams);
   auto cg_outputs = fe.runFusion({input0, input1}, lparams);
@@ -16234,9 +16236,9 @@ TEST_F(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) {
 
   auto reduction_params = getPersistentHeuristics(&fusion, {input0, input1});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  schedulePersistentKernel(&fusion, reduction_params.value());
+  schedulePersistentKernel(&fusion, *reduction_params);
 
-  auto lparams = reduction_params.value().lparams;
+  auto lparams = reduction_params->lparams;
   FusionExecutor fe;
   fe.compileFusion(&fusion, {input0, input1}, lparams);
   auto cg_outputs = fe.runFusion({input0, input1}, lparams);

From eb0078101fd1de53c5e762b20d729b220dc4e5a6 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Mon, 25 Jul 2022 00:33:13 -0700
Subject: [PATCH 3/4] cleanup

---
 torch/csrc/jit/codegen/cuda/scheduler/registry.h | 2 --
 1 file changed, 2 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.h b/torch/csrc/jit/codegen/cuda/scheduler/registry.h
index 0af9163518063..7d2af85bfad0e 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/registry.h
+++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.h
@@ -8,8 +8,6 @@
 #include
 #include
-#include
-
 
 namespace torch {
 namespace jit {
 namespace fuser {

From 4eedbe77d35da81c7b1b74d8b2811c5f094835c0 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Mon, 25 Jul 2022 00:44:23 -0700
Subject: [PATCH 4/4] cleanup english

---
 .../csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h | 8 ++++----
 .../csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h
index b63576f08e3f0..3d2cb5ee9521f 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h
+++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h
@@ -9,10 +9,10 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-// Parameters the Reduction Heuristic Generates to describe the optimial
-// schedule. Warning: equal operator is intended for use in caching the kernel
-// associated with these reduction parameters. It does not check if the launch
-// parameters are equivelent!
+// Parameters of the pointwise heuristic to describe the optimal schedule.
+// Warning: equal operator is intended for use in caching the kernel associated
+// with these pointwise parameters. It does not check if the launch parameters
+// are equivalent!
 class PointwiseParams : public HeuristicParams {
  public:
   // vectorize if true, otherwise unroll
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h b/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h
index 55e17b4ef6487..5349b64aeaffc 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h
+++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h
@@ -9,10 +9,10 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-// Parameters the Reduction Heuristic Generates to describe the optimial
-// schedule. Warning: equal operator is intended for use in caching the kernel
-// associated with these reduction parameters. It does not check if the launch
-// parameters are equivelent!
+// Parameters of the reduction heuristic to describe the optimal schedule.
+// Warning: equal operator is intended for use in caching the kernel associated
+// with these reduction parameters. It does not check if the launch parameters
+// are equivalent!
 class ReductionParams : public HeuristicParams {
  public:
   // Reducing inner most dimension?