From 3c1cff822d696ea330e528435b4f0f555f305aab Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" <165712832+naromero77amd@users.noreply.github.com> Date: Mon, 21 Apr 2025 17:36:36 -0500 Subject: [PATCH] [release/2.5][ROCm][TunableOp] Improve identification of fastest solution (#144942) (#2018) This PR addresses some stability issues with identifying the fastest solution on AMD GPUs, particularly the MI300. Changes include: - An improved timer, StreamTimerNoSync - More aggressive skipping of slow solutions - Additional statistics that can be used for diagnostics PYTORCH_TUNABLEOP_VERBOSE=3 Pull Request resolved: https://github.com/pytorch/pytorch/pull/144942 Approved by: https://github.com/jeffdaily (cherry picked from commit fd0cd6a08f706b7bb1dedb296217b6441e4fb9ff) --- aten/src/ATen/cuda/tunable/GemmHipblaslt.h | 7 -- aten/src/ATen/cuda/tunable/GemmRocblas.h | 3 - aten/src/ATen/cuda/tunable/StreamTimer.cpp | 25 +++- aten/src/ATen/cuda/tunable/StreamTimer.h | 16 +++ aten/src/ATen/cuda/tunable/TunableGemm.h | 17 ++- aten/src/ATen/cuda/tunable/TunableOp.h | 130 +++++++++++++++++++-- 6 files changed, 172 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index b0a3d7ef6e1d..7564b37ca5d7 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -625,13 +625,6 @@ auto GetHipBlasLtTypeStringAndOps() { heuristic_result)); TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle)); - // Sort heuristic_result by algo index to make sure the order of returned algos is deterministic. - std::sort(heuristic_result.begin(), - heuristic_result.end(), - [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) { - return hipblaslt_ext::getIndexFromAlgo(a.algo) < hipblaslt_ext::getIndexFromAlgo(b.algo); - }); - int returned_algo_count = heuristic_result.size(); std::vector>>> ret; for (int i = 0; i < returned_algo_count; i++) { diff --git a/aten/src/ATen/cuda/tunable/GemmRocblas.h b/aten/src/ATen/cuda/tunable/GemmRocblas.h index 026836fc73cc..182d597fe29c 100644 --- a/aten/src/ATen/cuda/tunable/GemmRocblas.h +++ b/aten/src/ATen/cuda/tunable/GemmRocblas.h @@ -192,9 +192,6 @@ auto GetRocBlasGemmTypeStringAndOps() { rocblas_gemm_flags_none, solutions.data(), &solution_size)); - // Sort the solutions in ascending order to make the solution vector deterministic across runs - std::sort(solutions.begin(), solutions.end()); - std::vector>>>> ret; for (size_t i = 0; i < solutions.size(); ++i) { auto callable = std::make_unique>(solutions[i]); diff --git a/aten/src/ATen/cuda/tunable/StreamTimer.cpp b/aten/src/ATen/cuda/tunable/StreamTimer.cpp index ed24a29d9919..8b9e6f05cbf1 100644 --- a/aten/src/ATen/cuda/tunable/StreamTimer.cpp +++ b/aten/src/ATen/cuda/tunable/StreamTimer.cpp @@ -24,7 +24,7 @@ StreamTimer::StreamTimer() { StreamTimer::~StreamTimer() = default; void StreamTimer::Start() { - AT_CUDA_CHECK(cudaDeviceSynchronize()); + AT_CUDA_CHECK(cudaEventSynchronize(start_)); AT_CUDA_CHECK(cudaEventRecord(start_, at::cuda::getCurrentCUDAStream())); } @@ -40,4 +40,27 @@ float StreamTimer::Duration() { return time; } +StreamTimerNoSync::StreamTimerNoSync() { + AT_CUDA_CHECK(cudaEventCreate(&start_)); + AT_CUDA_CHECK(cudaEventCreate(&end_)); +} + +StreamTimerNoSync::~StreamTimerNoSync() = default; + +void StreamTimerNoSync::Start() { + AT_CUDA_CHECK(cudaEventRecord(start_, at::cuda::getCurrentCUDAStream())); +} + +void StreamTimerNoSync::End() { + AT_CUDA_CHECK(cudaEventRecord(end_, at::cuda::getCurrentCUDAStream())); +} + +float StreamTimerNoSync::Duration() { + auto time = std::numeric_limits::quiet_NaN(); + AT_CUDA_CHECK(cudaEventSynchronize(end_)); + // time is in ms with a resolution of 1 us + AT_CUDA_CHECK(cudaEventElapsedTime(&time, start_, end_)); + return time; +} + } // namespace at::cuda::tunable diff --git a/aten/src/ATen/cuda/tunable/StreamTimer.h b/aten/src/ATen/cuda/tunable/StreamTimer.h index c83291d1b0e5..15ed5e769975 100644 --- a/aten/src/ATen/cuda/tunable/StreamTimer.h +++ b/aten/src/ATen/cuda/tunable/StreamTimer.h @@ -31,4 +31,20 @@ class StreamTimer : public ITimer { cudaEvent_t end_{}; }; +class StreamTimerNoSync : public ITimer { + public: + StreamTimerNoSync(); + ~StreamTimerNoSync() override; + + void Start() override; + + void End() override; + + float Duration() override; + + private: + cudaEvent_t start_{}; + cudaEvent_t end_{}; +}; + } // namespace at::cuda::tunable diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index e4e2c0f51862..8c4190bda977 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -14,7 +14,6 @@ #include #include #endif -#include #include #include #include @@ -198,7 +197,7 @@ inline const char* TypeName(c10::complex v) { } template -class GemmTunableOp : public TunableOp, StreamTimer> { +class GemmTunableOp : public TunableOp> { public: GemmTunableOp() { this->RegisterOp(std::string("Default"), std::make_unique>()); @@ -223,6 +222,8 @@ class GemmTunableOp : public TunableOp, StreamTimer> { } } #endif + + this->RegisterOp(std::string("Default"), std::make_unique>()); } std::string Signature() override { @@ -231,7 +232,7 @@ class GemmTunableOp : public TunableOp, StreamTimer> { }; template -class GemmAndBiasTunableOp : public TunableOp, StreamTimer> { +class GemmAndBiasTunableOp : public TunableOp> { public: GemmAndBiasTunableOp() { this->RegisterOp(std::string("Default"), std::make_unique>()); @@ -249,6 +250,8 @@ class GemmAndBiasTunableOp : public TunableOp, StreamTimer> } } #endif + + this->RegisterOp(std::string("Default"), std::make_unique>()); } std::string Signature() override { @@ -257,7 +260,7 @@ class GemmAndBiasTunableOp : public TunableOp, StreamTimer> }; template -class GemmStridedBatchedTunableOp : public TunableOp, StreamTimer> { +class GemmStridedBatchedTunableOp : public TunableOp> { public: GemmStridedBatchedTunableOp() { this->RegisterOp(std::string("Default"), std::make_unique>()); @@ -282,6 +285,8 @@ class GemmStridedBatchedTunableOp : public TunableOp } } #endif + + this->RegisterOp(std::string("Default"), std::make_unique>()); } std::string Signature() override { @@ -290,7 +295,7 @@ class GemmStridedBatchedTunableOp : public TunableOp }; template -class ScaledGemmTunableOp : public TunableOp, StreamTimer> { +class ScaledGemmTunableOp : public TunableOp> { public: ScaledGemmTunableOp() { this->RegisterOp(std::string("Default"), std::make_unique>()); @@ -300,6 +305,8 @@ class ScaledGemmTunableOp : public TunableOp, StreamTimer> this->RegisterOp(std::move(name), std::move(op)); } #endif + + this->RegisterOp(std::string("Default"), std::make_unique>()); } std::string Signature() override { diff --git a/aten/src/ATen/cuda/tunable/TunableOp.h b/aten/src/ATen/cuda/tunable/TunableOp.h index b1c607c72e0c..624158d9ca3c 100644 --- a/aten/src/ATen/cuda/tunable/TunableOp.h +++ b/aten/src/ATen/cuda/tunable/TunableOp.h @@ -10,6 +10,7 @@ #pragma once #include +#include #include #include @@ -35,7 +36,57 @@ class Callable { } }; -template +namespace { + +/** http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance */ + +class Stats { + public: + Stats() { + _n = 0UL; + _mean = 0.0; + _M2 = 0.0; + _sum = 0.0; + _min = 0.0; + _max = 0.0; + } + + void sample_value(const double x) { + double delta = 0; + _sum = _sum + x; + if (0UL == _n) { + _min = x; + _max = x; + } + else { + _min = _min < x ? _min : x; + _max = _max > x ? _max : x; + } + _n = _n + 1UL; + delta = x - _mean; + _mean = _mean + delta/_n; + _M2 = _M2 + delta * (x - _mean); + } + + double variance() const { + return _M2/(_n-1); + } + + double stddev() const { + return std::sqrt(variance()); + } + + unsigned long _n; + double _mean; + double _M2; + double _sum; + double _min; + double _max; +}; + +} // anonymous namespace + +template class TunableOp { public: virtual ~TunableOp() = default; @@ -100,10 +151,17 @@ class TunableOp { } } - static double Profile(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + static double ProfileSimple(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { TuningContext* ctx = getTuningContext(); bool do_flush = ctx->IsICacheFlushEnabled(); - TimerT timer{}; + StreamTimerNoSync timer{}; + + // Small Mandatory Warmup + // Reduces outliers + for (size_t i = 0; i < 2; i++) { + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + } + timer.Start(); for (size_t i = 0; i < num_iter; i++) { if (do_flush) { @@ -115,6 +173,32 @@ class TunableOp { return timer.Duration() / num_iter; } + static Stats ProfileStats(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); + std::vector timer(num_iter); + + // Small Mandatory Warmup + // Reduces outliers + for (size_t i = 0; i < 2; i++) { + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + } + + for (size_t i = 0; i < num_iter; i++) { + timer[i].Start(); + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + timer[i].End(); + if (do_flush) { + at::cuda::flush_icache(); + } + } + Stats s; + for (size_t i = 0; i < num_iter; i++) { + s.sample_value(timer[i].Duration()); + } + return s; + } + protected: virtual ResultEntry FindFastest(const ParamsT* params) { TuningContext* ctx = getTuningContext(); @@ -184,14 +268,25 @@ class TunableOp { } // collect a small profile - constexpr const int approx_num_iter = 3; - auto approx_duration = Profile(candidate, reusable_params, approx_num_iter, offset); + int approx_num_iter = 3; + auto s = ProfileStats(candidate, reusable_params, approx_num_iter, offset); + double approx_duration = s._mean; // bail if too slow - if (approx_duration > 2 * min_duration_ms) { + if (approx_duration > 1.5 * min_duration_ms) { TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); continue; } + // 2nd phase skip, more aggressive + approx_num_iter = 10; + s = ProfileStats(candidate, reusable_params, approx_num_iter, offset); + approx_duration = s._mean; + // bail if too slow + if (approx_duration > 1.15 * min_duration_ms) { + TUNABLE_LOG3("├──2nd skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + // for warmup does user set max duration, max iters, or both? // warmup is allowed to be skipped by setting either iterations or duration to 0 double max_warmup_duration = ctx->GetMaxWarmupDurationMs(); @@ -238,12 +333,27 @@ class TunableOp { "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]); TUNABLE_LOG3("├──offset at ", offset); WarmUp(candidate, reusable_params, warmup_iter, offset); - auto duration_ms = Profile(candidate, reusable_params, tuning_iter, offset); - if (duration_ms < min_duration_ms) { - TUNABLE_LOG3("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]); - min_duration_ms = duration_ms; + s = ProfileStats(candidate, reusable_params, tuning_iter, offset); + auto s_stddev = s.stddev(); + // Assume normal distribution. + // Solution with smallest mean + 2*sigma will be a better solution? + // if ((s._mean + 2*s_stddev) < (min_duration_ms + 2*min_stddev_ms)) { + if (s._mean < min_duration_ms) { + TUNABLE_LOG3("├──found better instance id=", i, ". " , s._mean, "ms. ", op_names_[i], + " min ", s._min, + " max ", s._max, + " mean ", s._mean, + " std ", s_stddev); + min_duration_ms = s._mean; id_name = op_names_[i]; } + else { + TUNABLE_LOG3("├──found slower instance id=", i, ". " , s._mean, "ms. ", op_names_[i], + " min ", s._min, + " max ", s._max, + " mean ", s._mean, + " std ", s_stddev); + } } for (size_t i = 0; i < reusable_params.size(); i++) {