@@ -713,22 +713,28 @@ class PoolingDescriptor {
713713class AlgorithmDesc {
714714 public:
715715 typedef int64 Index;
716- AlgorithmDesc () : algo_(kDefaultAlgorithm ), tensor_ops_enabled_(true ) {}
716+ AlgorithmDesc () : algo_(kDefaultAlgorithm ), tensor_ops_enabled_(true ), scratch_size_( 0 ) {}
717717 AlgorithmDesc (Index a, bool use_tensor_ops)
718- : algo_(a), tensor_ops_enabled_(use_tensor_ops) {}
718+ : algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(0 ) {}
719+ AlgorithmDesc (Index a, bool use_tensor_ops, size_t scratch_size)
720+ : algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(scratch_size) {}
719721 bool is_default () const { return algo_ == kDefaultAlgorithm ; }
720722 bool tensor_ops_enabled () const { return tensor_ops_enabled_; }
721723 Index algo_id () const { return algo_; }
724+ size_t scratch_size () const { return scratch_size_; }
725+ void set_scratch_size (size_t val) { scratch_size_ = val; }
722726 bool operator ==(const AlgorithmDesc& other) const {
723727 return this ->algo_ == other.algo_ &&
724- this ->tensor_ops_enabled_ == other.tensor_ops_enabled_ ;
728+ this ->tensor_ops_enabled_ == other.tensor_ops_enabled_ &&
729+ this ->scratch_size_ == other.scratch_size_ ;
725730 }
726731 uint64 hash () const ;
727732
728733 private:
729734 enum { kDefaultAlgorithm = -1 };
730735 Index algo_;
731736 bool tensor_ops_enabled_;
737+ size_t scratch_size_;
732738};
733739
734740// Describes the result from a perf experiment.
@@ -743,15 +749,12 @@ class ProfileResult {
743749 elapsed_time_in_ms_ != std::numeric_limits<float >::max ());
744750 }
745751 AlgorithmDesc algorithm () const { return algorithm_; }
746- size_t scratch_size () const { return scratch_size_; }
747752 void set_algorithm (AlgorithmDesc val) { algorithm_ = val; }
748- void set_scratch_size (size_t val) { scratch_size_ = val; }
749753 float elapsed_time_in_ms () const { return elapsed_time_in_ms_; }
750754 void set_elapsed_time_in_ms (float val) { elapsed_time_in_ms_ = val; }
751755
752756 private:
753757 AlgorithmDesc algorithm_;
754- size_t scratch_size_ = 0 ;
755758 float elapsed_time_in_ms_ = std::numeric_limits<float >::max();
756759};
757760
@@ -764,18 +767,9 @@ class ProfileResult {
764767class AlgorithmConfig {
765768 public:
766769 AlgorithmConfig ()
767- : algorithm_(),
768- algorithm_no_scratch_ (),
769- algorithm_scratch_size_(0 ) {}
770+ : algorithm_(), algorithm_no_scratch_() {}
770771 explicit AlgorithmConfig (AlgorithmDesc algorithm)
771- : algorithm_(algorithm),
772- algorithm_no_scratch_(),
773- algorithm_scratch_size_(0 ) {}
774- AlgorithmConfig (AlgorithmDesc algorithm, AlgorithmDesc algorithm_no_scratch,
775- size_t algorithm_scratch_size = 0 )
776- : algorithm_(algorithm),
777- algorithm_no_scratch_(algorithm_no_scratch),
778- algorithm_scratch_size_(0 ) {}
772+ : algorithm_(algorithm), algorithm_no_scratch_() {}
779773 AlgorithmDesc algorithm () const { return algorithm_; }
780774 void set_algorithm (AlgorithmDesc val) { algorithm_ = val; }
781775 AlgorithmDesc algorithm_no_scratch () const { return algorithm_no_scratch_; }
@@ -790,17 +784,10 @@ class AlgorithmConfig {
790784 return !(*this == other);
791785 }
792786 string ToString () const ;
793- size_t algorithm_scratch_size () const {
794- return algorithm_scratch_size_;
795- }
796- void set_algorithm_scratch_size (size_t val) {
797- algorithm_scratch_size_ = val;
798- }
799787
800788 private:
801789 AlgorithmDesc algorithm_;
802790 AlgorithmDesc algorithm_no_scratch_;
803- size_t algorithm_scratch_size_;
804791};
805792
806793// Describes a local response normalization (LRN). LRN is used e.g. in
0 commit comments