7 changes: 0 additions & 7 deletions aten/src/ATen/cuda/tunable/GemmHipblaslt.h
@@ -565,13 +565,6 @@ auto GetHipBlasLtTypeStringAndOps() {
heuristic_result));
TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));

// Sort heuristic_result by algo index to make sure the order of returned algos is deterministic.
std::sort(heuristic_result.begin(),
heuristic_result.end(),
[](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
return hipblaslt_ext::getIndexFromAlgo(a.algo) < hipblaslt_ext::getIndexFromAlgo(b.algo);
});

int returned_algo_count = heuristic_result.size();
std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
for (int i = 0; i < returned_algo_count; i++) {
3 changes: 0 additions & 3 deletions aten/src/ATen/cuda/tunable/GemmRocblas.h
@@ -191,9 +191,6 @@ auto GetRocBlasGemmTypeStringAndOps() {
rocblas_gemm_flags_none,
solutions.data(),
&solution_size));
// Sort the solutions in ascending order to make the solution vector deterministic across runs
std::sort(solutions.begin(), solutions.end());

std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
for (size_t i = 0; i < solutions.size(); ++i) {
auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
25 changes: 24 additions & 1 deletion aten/src/ATen/cuda/tunable/StreamTimer.cpp
@@ -24,7 +24,7 @@ StreamTimer::~StreamTimer() {
}

void StreamTimer::Start() {
AT_CUDA_CHECK(cudaDeviceSynchronize());
AT_CUDA_CHECK(cudaEventSynchronize(start_));
AT_CUDA_CHECK(cudaEventRecord(start_, at::cuda::getCurrentCUDAStream()));
}

@@ -40,4 +40,27 @@ float StreamTimer::Duration() {
return time;
}

StreamTimerNoSync::StreamTimerNoSync() {
AT_CUDA_CHECK(cudaEventCreate(&start_));
AT_CUDA_CHECK(cudaEventCreate(&end_));
}

StreamTimerNoSync::~StreamTimerNoSync() = default;

void StreamTimerNoSync::Start() {
AT_CUDA_CHECK(cudaEventRecord(start_, at::cuda::getCurrentCUDAStream()));
}

void StreamTimerNoSync::End() {
AT_CUDA_CHECK(cudaEventRecord(end_, at::cuda::getCurrentCUDAStream()));
}

float StreamTimerNoSync::Duration() {
auto time = std::numeric_limits<float>::quiet_NaN();
AT_CUDA_CHECK(cudaEventSynchronize(end_));
// time is in ms with a resolution of 1 us
AT_CUDA_CHECK(cudaEventElapsedTime(&time, start_, end_));
return time;
}

} // namespace at::cuda::tunable
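
The change above replaces the device-wide cudaDeviceSynchronize() in StreamTimer::Start() with a narrower cudaEventSynchronize(start_), and adds StreamTimerNoSync, which records a start and an end event on the current stream with no synchronization up front and waits only on the end event in Duration(). As a rough illustration of that pattern, here is a minimal standalone sketch against the plain CUDA runtime (compiled with nvcc); the workload kernel and sizes are placeholders, and the ATen wrappers (AT_CUDA_CHECK, at::cuda::getCurrentCUDAStream()) are intentionally omitted:

```cpp
// Minimal sketch of the event-based timing pattern StreamTimerNoSync follows.
// The "workload" kernel stands in for the GEMM candidate being profiled.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void workload(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] = x[i] * 2.0f + 1.0f;
}

int main() {
  const int n = 1 << 20;
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));

  cudaEvent_t start, end;
  cudaEventCreate(&start);
  cudaEventCreate(&end);

  // No cudaDeviceSynchronize() before recording: both events are simply
  // enqueued on the (default) stream around the work being measured.
  cudaEventRecord(start);
  workload<<<(n + 255) / 256, 256>>>(d, n);
  cudaEventRecord(end);

  // Only the end event is waited on; the elapsed time is reported in ms
  // with roughly 1 us resolution.
  cudaEventSynchronize(end);
  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, end);
  std::printf("%.3f ms\n", ms);

  cudaEventDestroy(start);
  cudaEventDestroy(end);
  cudaFree(d);
  return 0;
}
```

Timing this way measures only what is enqueued between the two events on that stream, at roughly 1 us resolution, without stalling the whole device the way cudaDeviceSynchronize() does.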
16 changes: 16 additions & 0 deletions aten/src/ATen/cuda/tunable/StreamTimer.h
@@ -31,4 +31,20 @@ class StreamTimer : public ITimer {
cudaEvent_t end_;
};

class StreamTimerNoSync : public ITimer {
public:
StreamTimerNoSync();
~StreamTimerNoSync() override;

void Start() override;

void End() override;

float Duration() override;

private:
cudaEvent_t start_{};
cudaEvent_t end_{};
};

} // namespace at::cuda::tunable
17 changes: 12 additions & 5 deletions aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -14,7 +14,6 @@
#include <ATen/cuda/tunable/GemmHipblaslt.h>
#include <ATen/cuda/tunable/GemmRocblas.h>
#endif
#include <ATen/cuda/tunable/StreamTimer.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Float8_e4m3fn.h>
@@ -190,7 +189,7 @@ inline std::string TypeName(c10::complex<float> v) {
}

template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
class GemmTunableOp : public TunableOp<GemmParams<T>> {
public:
GemmTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
@@ -215,6 +214,8 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
}
}
#endif

this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
}

std::string Signature() override {
@@ -223,7 +224,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
};

template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer> {
class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>> {
public:
GemmAndBiasTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
@@ -241,6 +242,8 @@ class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer>
}
}
#endif

this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
}

std::string Signature() override {
@@ -249,7 +252,7 @@ class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer>
};

template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>> {
public:
GemmStridedBatchedTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
@@ -274,6 +277,8 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
}
}
#endif

this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
}

std::string Signature() override {
@@ -282,7 +287,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
};

template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>> {
public:
ScaledGemmTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
@@ -292,6 +297,8 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
this->RegisterOp(std::move(name), std::move(op));
}
#endif

this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
}

std::string Signature() override {
130 changes: 120 additions & 10 deletions aten/src/ATen/cuda/tunable/TunableOp.h
@@ -10,6 +10,7 @@
#pragma once

#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/StreamTimer.h>
#include <ATen/cuda/Sleep.h>
#include <c10/cuda/CUDACachingAllocator.h>

@@ -38,7 +39,57 @@ class Callable {
}
};

template <typename ParamsT, typename TimerT>
namespace {

/** http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance */

class Stats {
public:
Stats() {
_n = 0UL;
_mean = 0.0;
_M2 = 0.0;
_sum = 0.0;
_min = 0.0;
_max = 0.0;
}

void sample_value(const double x) {
double delta = 0;
_sum = _sum + x;
if (0UL == _n) {
_min = x;
_max = x;
}
else {
_min = _min < x ? _min : x;
_max = _max > x ? _max : x;
}
_n = _n + 1UL;
delta = x - _mean;
_mean = _mean + delta/_n;
_M2 = _M2 + delta * (x - _mean);
}

double variance() const {
return _M2/(_n-1);
}

double stddev() const {
return std::sqrt(variance());
}

unsigned long _n;
double _mean;
double _M2;
double _sum;
double _min;
double _max;
};

} // anonymous namespace

template <typename ParamsT>
class TunableOp {
public:
TunableOp() = default;
@@ -99,10 +150,17 @@ class TunableOp {
}
}

static double Profile(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
static double ProfileSimple(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
TuningContext* ctx = getTuningContext();
bool do_flush = ctx->IsICacheFlushEnabled();
TimerT timer{};
StreamTimerNoSync timer{};

// Small Mandatory Warmup
// Reduces outliers
for (size_t i = 0; i < 2; i++) {
TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
}

timer.Start();
for (size_t i = 0; i < num_iter; i++) {
if (do_flush) {
@@ -114,6 +172,32 @@
return timer.Duration() / num_iter;
}

static Stats ProfileStats(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
TuningContext* ctx = getTuningContext();
bool do_flush = ctx->IsICacheFlushEnabled();
std::vector<StreamTimerNoSync> timer(num_iter);

// Small Mandatory Warmup
// Reduces outliers
for (size_t i = 0; i < 2; i++) {
TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
}

for (size_t i = 0; i < num_iter; i++) {
timer[i].Start();
TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
timer[i].End();
if (do_flush) {
at::cuda::flush_icache();
}
}
Stats s;
for (size_t i = 0; i < num_iter; i++) {
s.sample_value(timer[i].Duration());
}
return s;
}

protected:
virtual ResultEntry FindFastest(const ParamsT* params) {
TuningContext* ctx = getTuningContext();
@@ -183,14 +267,25 @@
}

// collect a small profile
constexpr const int approx_num_iter = 3;
auto approx_duration = Profile(candidate, reusable_params, approx_num_iter, offset);
int approx_num_iter = 3;
auto s = ProfileStats(candidate, reusable_params, approx_num_iter, offset);
double approx_duration = s._mean;
// bail if too slow
if (approx_duration > 2 * min_duration_ms) {
if (approx_duration > 1.5 * min_duration_ms) {
TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}

// 2nd phase skip, more aggressive
approx_num_iter = 10;
s = ProfileStats(candidate, reusable_params, approx_num_iter, offset);
approx_duration = s._mean;
// bail if too slow
if (approx_duration > 1.15 * min_duration_ms) {
TUNABLE_LOG3("├──2nd skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}

// for warmup does user set max duration, max iters, or both?
// warmup is allowed to be skipped by setting either iterations or duration to 0
double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
@@ -237,12 +332,27 @@
"instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
TUNABLE_LOG3("├──offset at ", offset);
WarmUp(candidate, reusable_params, warmup_iter, offset);
auto duration_ms = Profile(candidate, reusable_params, tuning_iter, offset);
if (duration_ms < min_duration_ms) {
TUNABLE_LOG3("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]);
min_duration_ms = duration_ms;
s = ProfileStats(candidate, reusable_params, tuning_iter, offset);
auto s_stddev = s.stddev();
// Assume normal distribution.
// Solution with smallest mean + 2*sigma will be a better solution?
// if ((s._mean + 2*s_stddev) < (min_duration_ms + 2*min_stddev_ms)) {
if (s._mean < min_duration_ms) {
TUNABLE_LOG3("├──found better instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
" min ", s._min,
" max ", s._max,
" mean ", s._mean,
" std ", s_stddev);
min_duration_ms = s._mean;
id_name = op_names_[i];
}
else {
TUNABLE_LOG3("├──found slower instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
" min ", s._min,
" max ", s._max,
" mean ", s._mean,
" std ", s_stddev);
}
}

for (size_t i = 0; i < reusable_params.size(); i++) {
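
For reference, the Stats helper added to TunableOp.h follows the single-pass mean/variance recurrence from the linked Wikipedia page (Welford's method): each sample moves the running mean by delta / n and accumulates _M2 with delta * (x - new_mean), so ProfileStats can report min, max, mean and stddev without storing every per-iteration timing. Below is a small standalone sketch of the same update, using made-up timing samples purely for illustration (the real helper lives in an anonymous namespace and is fed one StreamTimerNoSync Duration() per iteration):

```cpp
// Standalone sketch of the running-statistics update used by the Stats helper
// above (Welford's online algorithm). Sample variance is M2 / (n - 1).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical per-iteration timings in ms.
  std::vector<double> samples = {1.02, 0.98, 1.05, 0.97, 1.10};

  unsigned long n = 0;
  double mean = 0.0, m2 = 0.0, min = 0.0, max = 0.0;
  for (double x : samples) {
    if (n == 0) { min = x; max = x; }
    min = std::min(min, x);
    max = std::max(max, x);
    n += 1;
    double delta = x - mean;
    mean += delta / n;
    m2 += delta * (x - mean);  // uses the updated mean, as in Stats::sample_value
  }

  double variance = m2 / (n - 1);  // undefined for n < 2, matching the helper
  std::printf("n=%lu mean=%.4f stddev=%.4f min=%.4f max=%.4f\n",
              n, mean, std::sqrt(variance), min, max);
  return 0;
}
```

FindFastest then compares candidates on the mean (the mean + 2*sigma variant is left commented out in the diff), after the two quick ProfileStats probes have already discarded anything slower than 1.5x and 1.15x of the current best.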