Skip to content

Commit 008b9db

Browse files
authored
Merge pull request #71 from ROCmSoftwarePlatform/refactor-algorithmconfig-profileresult
Refactor AlgorithmConfig / ProfileResult and put scratch size in AlgorithmDesc
2 parents f42b528 + fc45b89 commit 008b9db

File tree

5 files changed

+22
-38
lines changed

5 files changed

+22
-38
lines changed

tensorflow/core/kernels/conv_grad_filter_ops.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,6 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
987987
!profile_result.algorithm().is_default(),
988988
errors::NotFound("Failed to find backward filter algorithm!"));
989989
algorithm_config.set_algorithm(profile_result.algorithm());
990-
algorithm_config.set_algorithm_scratch_size(profile_result.scratch_size());
991990
algorithm_config.set_algorithm_no_scratch(profile_result.algorithm());
992991
#endif
993992
AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters,

tensorflow/core/kernels/conv_grad_input_ops.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,6 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
10371037

10381038

10391039
algorithm_config.set_algorithm(profile_result.algorithm());
1040-
algorithm_config.set_algorithm_scratch_size(profile_result.scratch_size());
10411040
// TODO - Add support for no-scratch algorithm
10421041
algorithm_config.set_algorithm_no_scratch(AlgorithmDesc());
10431042
#endif

tensorflow/core/kernels/conv_ops.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,6 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
792792
errors::NotFound("Failed to find conv algorithm!"));
793793

794794
algorithm_config.set_algorithm(profile_result.algorithm());
795-
algorithm_config.set_algorithm_scratch_size(profile_result.scratch_size());
796795
// TODO - Add support for no-scratch algorithm
797796
algorithm_config.set_algorithm_no_scratch(AlgorithmDesc());
798797
#endif

tensorflow/stream_executor/dnn.h

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -713,22 +713,28 @@ class PoolingDescriptor {
713713
class AlgorithmDesc {
714714
public:
715715
typedef int64 Index;
716-
AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {}
716+
AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true), scratch_size_(0) {}
717717
AlgorithmDesc(Index a, bool use_tensor_ops)
718-
: algo_(a), tensor_ops_enabled_(use_tensor_ops) {}
718+
: algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(0) {}
719+
AlgorithmDesc(Index a, bool use_tensor_ops, size_t scratch_size)
720+
: algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(scratch_size) {}
719721
bool is_default() const { return algo_ == kDefaultAlgorithm; }
720722
bool tensor_ops_enabled() const { return tensor_ops_enabled_; }
721723
Index algo_id() const { return algo_; }
724+
size_t scratch_size() const { return scratch_size_; }
725+
void set_scratch_size(size_t val) { scratch_size_ = val; }
722726
bool operator==(const AlgorithmDesc& other) const {
723727
return this->algo_ == other.algo_ &&
724-
this->tensor_ops_enabled_ == other.tensor_ops_enabled_;
728+
this->tensor_ops_enabled_ == other.tensor_ops_enabled_ &&
729+
this->scratch_size_ == other.scratch_size_;
725730
}
726731
uint64 hash() const;
727732

728733
private:
729734
enum { kDefaultAlgorithm = -1 };
730735
Index algo_;
731736
bool tensor_ops_enabled_;
737+
size_t scratch_size_;
732738
};
733739

734740
// Describes the result from a perf experiment.
@@ -743,15 +749,12 @@ class ProfileResult {
743749
elapsed_time_in_ms_ != std::numeric_limits<float>::max());
744750
}
745751
AlgorithmDesc algorithm() const { return algorithm_; }
746-
size_t scratch_size() const { return scratch_size_; }
747752
void set_algorithm(AlgorithmDesc val) { algorithm_ = val; }
748-
void set_scratch_size(size_t val) { scratch_size_ = val; }
749753
float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
750754
void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
751755

752756
private:
753757
AlgorithmDesc algorithm_;
754-
size_t scratch_size_ = 0;
755758
float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
756759
};
757760

@@ -764,18 +767,9 @@ class ProfileResult {
764767
class AlgorithmConfig {
765768
public:
766769
AlgorithmConfig()
767-
: algorithm_(),
768-
algorithm_no_scratch_(),
769-
algorithm_scratch_size_(0) {}
770+
: algorithm_(), algorithm_no_scratch_() {}
770771
explicit AlgorithmConfig(AlgorithmDesc algorithm)
771-
: algorithm_(algorithm),
772-
algorithm_no_scratch_(),
773-
algorithm_scratch_size_(0) {}
774-
AlgorithmConfig(AlgorithmDesc algorithm, AlgorithmDesc algorithm_no_scratch,
775-
size_t algorithm_scratch_size = 0)
776-
: algorithm_(algorithm),
777-
algorithm_no_scratch_(algorithm_no_scratch),
778-
algorithm_scratch_size_(0) {}
772+
: algorithm_(algorithm), algorithm_no_scratch_() {}
779773
AlgorithmDesc algorithm() const { return algorithm_; }
780774
void set_algorithm(AlgorithmDesc val) { algorithm_ = val; }
781775
AlgorithmDesc algorithm_no_scratch() const { return algorithm_no_scratch_; }
@@ -790,17 +784,10 @@ class AlgorithmConfig {
790784
return !(*this == other);
791785
}
792786
string ToString() const;
793-
size_t algorithm_scratch_size() const {
794-
return algorithm_scratch_size_;
795-
}
796-
void set_algorithm_scratch_size(size_t val) {
797-
algorithm_scratch_size_ = val;
798-
}
799787

800788
private:
801789
AlgorithmDesc algorithm_;
802790
AlgorithmDesc algorithm_no_scratch_;
803-
size_t algorithm_scratch_size_;
804791
};
805792

806793
// Describes a local response normalization (LRN). LRN is used e.g. in

tensorflow/stream_executor/rocm/rocm_dnn.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1746,8 +1746,9 @@ bool MIOpenSupport::DoConvolveImpl(
17461746

17471747
} else {
17481748
// An algorithm has been specified.
1749-
algo_sz.first = ToConvForwardAlgo(algorithm_config.algorithm());
1750-
algo_sz.second = algorithm_config.algorithm_scratch_size();
1749+
dnn::AlgorithmDesc algo = algorithm_config.algorithm();
1750+
algo_sz.first = ToConvForwardAlgo(algo);
1751+
algo_sz.second = algo.scratch_size();
17511752

17521753
size_t size_in_bytes = algo_sz.second;
17531754
if (size_in_bytes != 0) {
@@ -1801,9 +1802,8 @@ bool MIOpenSupport::DoConvolveImpl(
18011802
return false;
18021803
}
18031804
if (status == miopenStatusSuccess) {
1804-
dnn::AlgorithmDesc algotype(algo_sz.first, false);
1805+
dnn::AlgorithmDesc algotype(algo_sz.first, false, algo_sz.second);
18051806
output_profile_result->set_algorithm(algotype);
1806-
output_profile_result->set_scratch_size(algo_sz.second);
18071807
output_profile_result->set_elapsed_time_in_ms(
18081808
timer->GetElapsedMilliseconds());
18091809
}
@@ -2307,8 +2307,9 @@ bool MIOpenSupport::DoConvolveBackwardDataImpl(
23072307

23082308
} else {
23092309
// An algorithm has been specified.
2310-
algo_sz.first = ToConvBackwardDataAlgo(algorithm_config.algorithm());
2311-
algo_sz.second = algorithm_config.algorithm_scratch_size();
2310+
dnn::AlgorithmDesc algo = algorithm_config.algorithm();
2311+
algo_sz.first = ToConvBackwardDataAlgo(algo);
2312+
algo_sz.second = algo.scratch_size();
23122313

23132314
size_t size_in_bytes = algo_sz.second;
23142315
if (size_in_bytes != 0) {
@@ -2362,9 +2363,8 @@ bool MIOpenSupport::DoConvolveBackwardDataImpl(
23622363
if (is_profiling) {
23632364
timer->Stop(AsROCMStream(stream));
23642365
if (status == miopenStatusSuccess) {
2365-
dnn::AlgorithmDesc algotype(algo_sz.first, false);
2366+
dnn::AlgorithmDesc algotype(algo_sz.first, false, algo_sz.second);
23662367
output_profile_result->set_algorithm(algotype);
2367-
output_profile_result->set_scratch_size(algo_sz.second);
23682368
output_profile_result->set_elapsed_time_in_ms(
23692369
timer->GetElapsedMilliseconds());
23702370
}
@@ -2530,8 +2530,9 @@ bool MIOpenSupport::DoConvolveBackwardFilterImpl(
25302530

25312531
} else {
25322532
// An algorithm has been specified.
2533+
dnn::AlgorithmDesc algo = algorithm_config.algorithm();
25332534
algo_sz.first = ToConvBackwardFilterAlgo(algorithm_config.algorithm());
2534-
algo_sz.second = algorithm_config.algorithm_scratch_size();
2535+
algo_sz.second = algo.scratch_size();
25352536

25362537
size_t size_in_bytes = algo_sz.second;
25372538

@@ -2585,9 +2586,8 @@ bool MIOpenSupport::DoConvolveBackwardFilterImpl(
25852586
if (is_profiling) {
25862587
timer->Stop(AsROCMStream(stream));
25872588
if (status == miopenStatusSuccess) {
2588-
dnn::AlgorithmDesc algotype(algo_sz.first, false);
2589+
dnn::AlgorithmDesc algotype(algo_sz.first, false, algo_sz.second);
25892590
output_profile_result->set_algorithm(algotype);
2590-
output_profile_result->set_scratch_size(algo_sz.second);
25912591
output_profile_result->set_elapsed_time_in_ms(
25922592
timer->GetElapsedMilliseconds());
25932593
}

0 commit comments

Comments
 (0)