diff --git a/paddle/fluid/operators/conv_base_helper.h b/paddle/fluid/operators/conv_base_helper.h
index 8425dcb521ab6..b52936c197218 100644
--- a/paddle/fluid/operators/conv_base_helper.h
+++ b/paddle/fluid/operators/conv_base_helper.h
@@ -44,7 +44,13 @@ struct SearchAlgorithm {};
 template <typename AlgoT>
 struct SearchResult {
   SearchResult() {}
+  explicit SearchResult(const phi::autotune::DnnNode& node)
+      : algo(static_cast<AlgoT>(node.algo)),
+        workspace_size(node.workspace_size) {}
+  explicit SearchResult(AlgoT a) : algo(a) {}
+  explicit SearchResult(AlgoT a, float t, size_t size)
+      : algo(a), time(t), workspace_size(size) {}
 
   AlgoT algo = static_cast<AlgoT>(0);
   float time = -1.f;
@@ -76,28 +82,50 @@ struct ConvArgsBase {
   // dilations
   std::vector<int> d;
 
+  // groups
+  int group;
+
+  // data format
+  DataLayout data_layout;
+
   ConvArgsBase(const framework::Tensor* x,
                const framework::Tensor* w,
                const framework::Tensor* o,
                const std::vector<int> s,
                const std::vector<int> p,
                const std::vector<int> d,
-               DataT dtype)
-      : x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
+               DataT dtype,
+               int g,
+               DataLayout layout)
+      : x(x),
+        w(w),
+        o(o),
+        s(s),
+        p(p),
+        d(d),
+        cudnn_dtype(dtype),
+        group(g),
+        data_layout(layout) {}
 
   template <typename T>
-  size_t GetCacheKey() const {
+  phi::autotune::ConvCacheKey Convert2ConvCacheKey() const {
     auto x_shape = phi::vectorize(x->dims());
     auto w_shape = phi::vectorize(w->dims());
     VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape
-             << ", strides=" << s << ", paddings=" << p << ", dilations=" << d;
-    return phi::autotune::ConvKey(
+             << ", strides=" << s << ", paddings=" << p << ", dilations=" << d
+             << ", data=" << paddle::experimental::CppTypeToDataType<T>::Type()
+             << ", group=" << group
+             << ", data layout=" << static_cast<int64_t>(data_layout);
+
+    return phi::autotune::ConvCacheKey(
         x_shape, w_shape, p, s, d,
-        paddle::experimental::CppTypeToDataType<T>::Type());
+        paddle::experimental::CppTypeToDataType<T>::Type(),
+        group,
+        static_cast<int64_t>(data_layout));
   }
 };
diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h
index 1b8d421d133f1..e6fcf2be286ec 100644
--- a/paddle/fluid/operators/conv_cudnn_helper.h
+++ b/paddle/fluid/operators/conv_cudnn_helper.h
@@ -191,32 +191,36 @@ struct SearchAlgorithm {
     SetConvMathType(ctx, dtype, args.cdesc);
 
     if (deterministic) {
-      result = FindAlgoDeterministic();
+      result = FindAlgoDeterministic(args);
     } else {
       // 1. Once turning on exhaustive FLAGS, always get exhaustive_search.
       // 2. Once turning on auto-tune, runn heuristic search(default) before
       //    auto-tune process, run exhaustive_search during mentioned process.
       // 3. After auto-tune process, run cached algorithm if cached, run
       //    default mode for the rest.
-      size_t key = args.GetCacheKey<T>();
+      auto key = args.Convert2ConvCacheKey<T>();
       auto& cache = phi::autotune::AutoTuneCache::Instance().GetConvForward();
       if (cache.Find(key)) {
-        result.algo = static_cast<AlgoT>(cache.Get(key));
+        auto t = cache.Get(key);
+        result.algo = static_cast<AlgoT>(t.algo);
+        result.workspace_size = t.workspace_size;
       } else {
         bool use_autotune =
             phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
         if (exhaustive_search || use_autotune) {
           result = FindAlgoExhaustiveSearch<T>(args, ctx);
-          cache.Set(key, static_cast<int64_t>(result.algo));
         } else {
           result = FindAlgoHeuristic(args, ctx);
         }
+        phi::autotune::DnnNode node(static_cast<int64_t>(result.algo),
+                                    result.workspace_size);
+        cache.Set(key, node);
       }
     }
 
     VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search
             << ", deterministic=" << deterministic
-            << ", choose algo=" << result.algo << ", workspace="
-            << ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB";
+            << ", choose algo=" << result.algo
+            << ", workspace=" << ToMegaBytes(result.workspace_size) << " MB";
     return result;
   }
@@ -236,8 +240,9 @@ struct SearchAlgorithm {
   }
 
  private:
-  static SearchResult<AlgoT> FindAlgoDeterministic() {
-    return SearchResult<AlgoT>(static_cast<AlgoT>(1));
+  static SearchResult<AlgoT> FindAlgoDeterministic(const ConvArgs& args) {
+    auto workspace_size = GetWorkspaceSize(args, static_cast<AlgoT>(1));
+    return SearchResult<AlgoT>(static_cast<AlgoT>(1), -1.0, workspace_size);
   }
 
   // Heuristic search mode, calling the cudnnGetXxxAlgorithm.
@@ -298,6 +303,7 @@ struct SearchAlgorithm {
                                              workspace_size_limit,
                                              &(result.algo)));
 #endif
+    result.workspace_size = GetWorkspaceSize(args, result.algo);
     return result;
   }
@@ -343,6 +349,7 @@ struct SearchAlgorithm {
     ChooseAlgoByWorkspace(
         perf_results, workspace_size_limit, &result);
+    result.workspace_size = GetWorkspaceSize(args, result.algo);
 
     return result;
   }
@@ -394,33 +401,37 @@ struct SearchAlgorithm {
     SetConvMathType(ctx, dtype, args.cdesc);
 
     if (deterministic) {
-      result = FindAlgoDeterministic();
+      result = FindAlgoDeterministic(args);
     } else {
       // 1. Once turning on exhaustive FLAGS, always get exhaustive_search.
       // 2. Once turning on auto-tune, runn heuristic search(default) before
       //    auto-tune process, run exhaustive_search during mentioned process.
       // 3. After auto-tune process, run cached algorithm if cached, run
       //    default mode for the rest.
-      size_t key = args.GetCacheKey<T>();
+      auto key = args.Convert2ConvCacheKey<T>();
       auto& cache =
           phi::autotune::AutoTuneCache::Instance().GetConvBackwardData();
       if (cache.Find(key)) {
-        result.algo = static_cast<AlgoT>(cache.Get(key));
+        auto t = cache.Get(key);
+        result.algo = static_cast<AlgoT>(t.algo);
+        result.workspace_size = t.workspace_size;
       } else {
         bool use_autotune =
             phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
         if (exhaustive_search || use_autotune) {
           result = FindAlgoExhaustiveSearch<T>(args, ctx);
-          cache.Set(key, static_cast<int64_t>(result.algo));
         } else {
           result = FindAlgoHeuristic(args, ctx);
         }
+        phi::autotune::DnnNode node(static_cast<int64_t>(result.algo),
+                                    result.workspace_size);
+        cache.Set(key, node);
       }
     }
 
     VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search
             << ", deterministic=" << deterministic
-            << ", choose algo=" << result.algo << ", workspace="
-            << ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB";
+            << ", choose algo=" << result.algo
+            << ", workspace=" << ToMegaBytes(result.workspace_size) << " MB";
     return result;
   }
@@ -440,8 +451,11 @@ struct SearchAlgorithm {
   }
 
  private:
-  static SearchResult<AlgoT> FindAlgoDeterministic() {
-    return SearchResult<AlgoT>(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1);
+  static SearchResult<AlgoT> FindAlgoDeterministic(const ConvArgs& args) {
+    auto workspace_size =
+        GetWorkspaceSize(args, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1);
+    return SearchResult<AlgoT>(
+        CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, -1.0, workspace_size);
   }
 
   static SearchResult<AlgoT> FindAlgoHeuristic(const ConvArgs& args,
@@ -513,7 +527,7 @@ struct SearchAlgorithm {
                                                  workspace_size_limit,
                                                  &(result.algo)));
 #endif
-
+    result.workspace_size = GetWorkspaceSize(args, result.algo);
     return result;
   }
@@ -559,6 +573,7 @@ struct SearchAlgorithm {
     ChooseAlgoByWorkspace(
         perf_results, workspace_size_limit, &result);
+    result.workspace_size = GetWorkspaceSize(args, result.algo);
 
     return result;
   }
@@ -609,33 +624,37 @@ struct SearchAlgorithm {
     SetConvMathType(ctx, dtype, args.cdesc);
 
     if (deterministic) {
-      result = FindAlgoDeterministic();
+      result = FindAlgoDeterministic(args);
     } else {
       // 1. Once turning on exhaustive FLAGS, always get exhaustive_search.
       // 2. Once turning on auto-tune, runn heuristic search(default) before
       //    auto-tune process, run exhaustive_search during mentioned process.
       // 3. After auto-tune process, run cached algorithm if cached, run
       //    default mode for the rest.
-      size_t key = args.GetCacheKey<T>();
+      auto key = args.Convert2ConvCacheKey<T>();
       auto& cache =
           phi::autotune::AutoTuneCache::Instance().GetConvBackwardFilter();
       if (cache.Find(key)) {
-        result.algo = static_cast<AlgoT>(cache.Get(key));
+        auto t = cache.Get(key);
+        result.algo = static_cast<AlgoT>(t.algo);
+        result.workspace_size = t.workspace_size;
       } else {
         bool use_autotune =
             phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
         if (exhaustive_search || use_autotune) {
           result = FindAlgoExhaustiveSearch<T>(args, ctx);
-          cache.Set(key, static_cast<int64_t>(result.algo));
         } else {
           result = FindAlgoHeuristic(args, ctx);
         }
+        phi::autotune::DnnNode node(static_cast<int64_t>(result.algo),
+                                    result.workspace_size);
+        cache.Set(key, node);
       }
     }
 
     VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search
             << ", deterministic=" << deterministic
-            << ", choose algo=" << result.algo << ", workspace="
-            << ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB";
+            << ", choose algo=" << result.algo
+            << ", workspace=" << ToMegaBytes(result.workspace_size) << " MB";
     return result;
   }
@@ -656,8 +675,11 @@ struct SearchAlgorithm {
   }
 
  private:
-  static SearchResult<AlgoT> FindAlgoDeterministic() {
-    return SearchResult<AlgoT>(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1);
+  static SearchResult<AlgoT> FindAlgoDeterministic(const ConvArgs& args) {
+    auto workspace_size =
+        GetWorkspaceSize(args, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1);
+    return SearchResult<AlgoT>(
+        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, -1.0, workspace_size);
   }
 
   static SearchResult<AlgoT> FindAlgoHeuristic(const ConvArgs& args,
@@ -718,6 +740,7 @@ struct SearchAlgorithm {
                                                   &(result.algo)));
 #endif
 
+    result.workspace_size = GetWorkspaceSize(args, result.algo);
     return result;
   }
@@ -786,6 +809,7 @@ struct SearchAlgorithm {
       ChooseAlgo(perf_results, workspace_size_limit, &result);
     }
 
+    result.workspace_size = GetWorkspaceSize(args, result.algo);
     return result;
   }
diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc
index 54063e1192a61..28dddc1fbebdd 100644
--- a/paddle/fluid/platform/flags.cc
+++ b/paddle/fluid/platform/flags.cc
@@ -984,6 +984,17 @@ PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait");
  */
 PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune.");
 
+/**
+ * Conv Search cache max number related FLAG
+ * Name: FLAGS_search_cache_max_number
+ * Since Version: 2.3.0
+ * Value Range: int32, default=1000000
+ * Example:
+ */
+PADDLE_DEFINE_EXPORTED_int32(search_cache_max_number,
+                             1000000,
+                             "search_cache_max_number.");
+
 /**
  * Preformance related FLAG
  * Name: einsum_opt
diff --git a/paddle/phi/kernels/autotune/cache.cc b/paddle/phi/kernels/autotune/cache.cc
index 838f2dd265eb3..ad7a2b134a20c 100644
--- a/paddle/phi/kernels/autotune/cache.cc
+++ b/paddle/phi/kernels/autotune/cache.cc
@@ -21,21 +21,6 @@
 namespace phi {
 namespace autotune {
 
-// Define the cache key of operator
-size_t ConvKey(const std::vector<int64_t>& x_dims,
-               const std::vector<int64_t>& w_dims,
-               const std::vector<int>& strides,
-               const std::vector<int>& paddings,
-               const std::vector<int>& dilations,
-               phi::DataType dtype) {
-  return GetKey(x_dims,
-                w_dims,
-                strides,
-                paddings,
-                dilations,
-                static_cast<int64_t>(dtype));
-}
-
 size_t TransposeKey(const std::vector<int64_t>& x_dims,
                     const std::vector<int32_t>& perm,
                     phi::DataType dtype) {
@@ -73,6 +58,19 @@ void AutoTuneCache::UpdateStatus() {
     cache_hits += v.second.CacheHits();
     cache_misses += v.second.CacheMisses();
   }
+
+  for (auto& v : cudnn_auto_tune_map_) {
+    VLOG(4) << "AlgoType: " << std::setfill(' ') << std::setw(name_width)
+            << AlgorithmTypeString(v.first)
" << v.second.Size() + << " Hits: " << v.second.CacheHits() + << " Misses: " << v.second.CacheMisses() + << " Hit Rate: " << v.second.CacheHitRate(); + size += v.second.Size(); + cache_hits += v.second.CacheHits(); + cache_misses += v.second.CacheMisses(); + } + total_size_ = size; total_cache_hits_ = cache_hits; total_cache_misses_ = cache_misses; diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h index 1263cf40e567e..aacebc66570cb 100644 --- a/paddle/phi/kernels/autotune/cache.h +++ b/paddle/phi/kernels/autotune/cache.h @@ -24,6 +24,8 @@ #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" +DECLARE_int32(search_cache_max_number); + inline void HashCombine(std::size_t* seed) {} // combine hash value @@ -32,6 +34,7 @@ template inline void HashCombine(std::size_t* seed, const T& v, Rest... rest) { std::hash hasher; *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); + *seed *= 0x00000100000001B3; HashCombine(seed, rest...); } @@ -41,7 +44,7 @@ namespace std { template struct hash> { std::size_t operator()(std::vector const& vec) const noexcept { - std::size_t seed = 0; + std::size_t seed = 0xcbf29ce484222325; for (auto val : vec) { HashCombine(&seed, val); } @@ -53,6 +56,14 @@ struct hash> { namespace phi { namespace autotune { +struct DnnNode { + DnnNode() {} + explicit DnnNode(int64_t a, size_t size) : algo(a), workspace_size(size) {} + + int64_t algo; + size_t workspace_size = 0; +}; + template size_t GetKey(Args&&... args) { size_t seed = 0; @@ -60,13 +71,130 @@ size_t GetKey(Args&&... args) { return seed; } -// Define the cache key of operator -size_t ConvKey(const std::vector& x_dims, - const std::vector& w_dims, - const std::vector& strides, - const std::vector& paddings, - const std::vector& dilations, - phi::DataType dtype); +struct ConvCacheKey { + ConvCacheKey() {} + explicit ConvCacheKey(const std::vector& x_dims, + const std::vector& w_dims, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + phi::DataType dtype, + int groups, + int64_t data_layout) + : x_dims_(x_dims), + w_dims_(w_dims), + strides_(strides), + paddings_(paddings), + dilations_(dilations), + dtype_(dtype), + groups_(groups), + data_layout_(data_layout) {} + size_t hash_value() const { + return GetKey(x_dims_, + w_dims_, + strides_, + paddings_, + dilations_, + static_cast(dtype_), + groups_, + data_layout_); + } + std::vector x_dims_; + std::vector w_dims_; + std::vector strides_; + std::vector paddings_; + std::vector dilations_; + phi::DataType dtype_; + int groups_; + int64_t data_layout_; +}; + +struct ConvCacheKeyHash { + size_t operator()(const ConvCacheKey& cache) const { + return cache.hash_value(); + } +}; + +struct ConvCacheKeyEqual { + size_t operator()(const ConvCacheKey& first, + const ConvCacheKey& second) const { + if (first.x_dims_ != second.x_dims_) return false; + if (first.w_dims_ != second.w_dims_) return false; + if (first.strides_ != second.strides_) return false; + if (first.paddings_ != second.paddings_) return false; + if (first.dilations_ != second.dilations_) return false; + if (first.dtype_ != second.dtype_) return false; + if (first.groups_ != second.groups_) return false; + if (first.data_layout_ != second.data_layout_) return false; + + return true; + } +}; + +class CudnnAlgorithmsCacheMap { + public: + CudnnAlgorithmsCacheMap() : cache_mutex_(new std::mutex()) { hash_.clear(); } + + DnnNode Get(const ConvCacheKey& key) { + std::lock_guard lock(*cache_mutex_); + 
+    PADDLE_ENFORCE_NE(
+        hash_.find(key),
+        hash_.end(),
+        phi::errors::PreconditionNotMet("The key does not exist."));
+    return hash_[key];
+  }
+
+  bool Find(const ConvCacheKey& key) {
+    bool ret = false;
+    std::lock_guard<std::mutex> lock(*cache_mutex_);
+    if (hash_.find(key) != hash_.end()) {
+      cache_hits_++;
+      ret = true;
+    } else {
+      cache_misses_++;
+    }
+    return ret;
+  }
+
+  void Clean() {
+    std::lock_guard<std::mutex> lock(*cache_mutex_);
+    hash_.clear();
+    cache_hits_ = 0;
+    cache_misses_ = 0;
+  }
+
+  void Set(const ConvCacheKey& key, DnnNode algo) {
+    std::lock_guard<std::mutex> lock(*cache_mutex_);
+    if (hash_.size() > static_cast<size_t>(FLAGS_search_cache_max_number)) {
+      hash_.clear();
+    }
+    hash_[key] = algo;
+  }
+
+  int64_t CacheMisses() const { return cache_misses_; }
+
+  int64_t CacheHits() const { return cache_hits_; }
+
+  float CacheHitRate() const {
+    int64_t num_accesses = cache_hits_ + cache_misses_;
+    float cache_hit_rate = 0.;
+    if (num_accesses != 0) {
+      cache_hit_rate =
+          static_cast<float>(cache_hits_) / static_cast<float>(num_accesses);
+    }
+    return cache_hit_rate;
+  }
+
+  int64_t Size() const { return hash_.size(); }
+
+ private:
+  std::unordered_map<ConvCacheKey, DnnNode, ConvCacheKeyHash, ConvCacheKeyEqual>
+      hash_;
+  std::shared_ptr<std::mutex> cache_mutex_;
+
+  int64_t cache_hits_{0};
+  int64_t cache_misses_{0};
+};
 
 size_t TransposeKey(const std::vector<int64_t>& x_dims,
                     const std::vector<int32_t>& perm,
@@ -77,7 +205,7 @@ class AlgorithmsCache {
  public:
   AlgorithmsCache() : cache_mutex_(new std::mutex()) { hash_.clear(); }
 
-  AlgorithmT Get(size_t key) {
+  AlgorithmT Get(const size_t& key) {
     std::lock_guard<std::mutex> lock(*cache_mutex_);
     PADDLE_ENFORCE_NE(
         hash_.find(key),
         hash_.end(),
@@ -86,7 +214,7 @@ class AlgorithmsCache {
     return hash_[key];
   }
 
-  bool Find(size_t key) {
+  bool Find(const size_t& key) {
     bool ret = false;
     std::lock_guard<std::mutex> lock(*cache_mutex_);
     if (hash_.find(key) != hash_.end()) {
@@ -105,7 +233,7 @@ class AlgorithmsCache {
     cache_misses_ = 0;
   }
 
-  void Set(size_t key, AlgorithmT algo) {
+  void Set(const size_t& key, AlgorithmT algo) {
     std::lock_guard<std::mutex> lock(*cache_mutex_);
     hash_[key] = algo;
   }
@@ -143,9 +271,12 @@ enum class AlgorithmType {
 };
 
 // AlgorithmsConfigKey -> AlgorithmsID
+// TODO(hong): use cudnnConvolutionFwdAlgo_t
 using AlgorithmsCacheMap = AlgorithmsCache<int64_t>;
 // AlgorithmType -> AlgorithmsCache
 using AlgorithmsTypeMap = std::unordered_map<int64_t, AlgorithmsCacheMap>;
+using CudnnAlgorithmsTypeMap =
+    std::unordered_map<int64_t, CudnnAlgorithmsCacheMap>;
 
 class AutoTuneCache {
  public:
@@ -158,16 +289,19 @@ class AutoTuneCache {
     return auto_tune_map_[static_cast<int64_t>(algo_type)];
   }
 
-  AlgorithmsCacheMap& GetConvForward() {
-    return Get(AlgorithmType::kConvForward);
+  CudnnAlgorithmsCacheMap& GetConvForward() {
+    return cudnn_auto_tune_map_[static_cast<int64_t>(
+        AlgorithmType::kConvForward)];
   }
 
-  AlgorithmsCacheMap& GetConvBackwardData() {
-    return Get(AlgorithmType::kConvBackwardData);
+  CudnnAlgorithmsCacheMap& GetConvBackwardData() {
+    return cudnn_auto_tune_map_[static_cast<int64_t>(
+        AlgorithmType::kConvBackwardData)];
   }
 
-  AlgorithmsCacheMap& GetConvBackwardFilter() {
-    return Get(AlgorithmType::kConvBackwardFilter);
+  CudnnAlgorithmsCacheMap& GetConvBackwardFilter() {
+    return cudnn_auto_tune_map_[static_cast<int64_t>(
+        AlgorithmType::kConvBackwardFilter)];
   }
 
   AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); }
@@ -176,6 +310,10 @@ class AutoTuneCache {
     for (auto& v : auto_tune_map_) {
       v.second.Clean();
     }
+
+    for (auto& v : cudnn_auto_tune_map_) {
+      v.second.Clean();
+    }
   }
 
   void UpdateStatus();
@@ -206,14 +344,25 @@ class AutoTuneCache {
 
   void Register(const AlgorithmType& algo_type) {
     std::lock_guard<std::mutex> lock(*autotune_cache_mutex_);
-    int64_t key = static_cast<int64_t>(algo_type);
-    if (auto_tune_map_.find(key) == auto_tune_map_.end()) {
-      AlgorithmsCacheMap cache;
-      auto_tune_map_[key] = cache;
+    if (algo_type == AlgorithmType::kConvForward ||
+        algo_type == AlgorithmType::kConvBackwardData ||
+        algo_type == AlgorithmType::kConvBackwardFilter) {
+      int64_t key = static_cast<int64_t>(algo_type);
+      if (auto_tune_map_.find(key) == auto_tune_map_.end()) {
+        CudnnAlgorithmsCacheMap cache;
+        cudnn_auto_tune_map_[key] = cache;
+      }
+    } else {
+      int64_t key = static_cast<int64_t>(algo_type);
+      if (auto_tune_map_.find(key) == auto_tune_map_.end()) {
+        AlgorithmsCacheMap cache;
+        auto_tune_map_[key] = cache;
+      }
     }
   }
 
   AlgorithmsTypeMap auto_tune_map_;
+  CudnnAlgorithmsTypeMap cudnn_auto_tune_map_;
  std::shared_ptr<std::mutex> autotune_cache_mutex_;
   int64_t total_cache_hits_{0};
   int64_t total_cache_misses_{0};
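For reference, a minimal usage sketch of the new cache pieces introduced above (ConvCacheKey, DnnNode, CudnnAlgorithmsCacheMap via AutoTuneCache), mirroring how the cuDNN search paths earlier in this patch and the updated unit test below exercise them. This is not code from the patch: the function name, shapes, dtype, layout value, and the algorithm/workspace numbers are made-up placeholders, and it assumes the Paddle source tree for the include.

// Illustrative sketch only -- mirrors the Find/Get/Set pattern used by
// SearchAlgorithm::Find in conv_cudnn_helper.h. All values are placeholders.
#include "paddle/phi/kernels/autotune/cache.h"

void ConvCacheUsageSketch() {
  auto& cache = phi::autotune::AutoTuneCache::Instance().GetConvForward();

  // A structured key: dims, strides, paddings, dilations, dtype, groups and
  // data layout all take part in hashing (ConvCacheKeyHash) and equality
  // (ConvCacheKeyEqual).
  phi::autotune::ConvCacheKey key({1, 3, 224, 224},  // x_dims
                                  {64, 3, 7, 7},     // w_dims
                                  {2, 2},            // strides
                                  {3, 3},            // paddings
                                  {1, 1},            // dilations
                                  phi::DataType::FLOAT32,
                                  /*groups=*/1,
                                  /*data_layout=*/0);

  if (cache.Find(key)) {
    // Cache hit: both the algorithm id and its workspace size are reused,
    // so the workspace size does not have to be re-queried.
    phi::autotune::DnnNode node = cache.Get(key);
    (void)node.algo;
    (void)node.workspace_size;
  } else {
    // Cache miss: run the heuristic/exhaustive search, then store the chosen
    // algorithm together with its workspace size.
    phi::autotune::DnnNode node(/*algo=*/1, /*workspace_size=*/64 * 1024);
    cache.Set(key, node);
  }
}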
diff --git a/paddle/phi/kernels/autotune/cache_test.cc b/paddle/phi/kernels/autotune/cache_test.cc
index 53574c3d0c9ac..29affd45f0f5c 100644
--- a/paddle/phi/kernels/autotune/cache_test.cc
+++ b/paddle/phi/kernels/autotune/cache_test.cc
@@ -34,20 +34,23 @@ TEST(AlgosCache, AlgosCache) {
   std::vector<int> dilations = {1, 1};
   phi::DataType dtype = paddle::experimental::CppTypeToDataType<float>::Type();
 
-  auto key = phi::autotune::ConvKey(
-      x_shape, w_shape, paddings, strides, dilations, dtype);
+  phi::autotune::ConvCacheKey key(
+      x_shape, w_shape, paddings, strides, dilations, dtype, 0, 0);
   EXPECT_EQ(cache.Find(key), false);
-  cache.Set(key, ConvAlgos::GEMMKernel);
+  phi::autotune::DnnNode node(static_cast<int64_t>(ConvAlgos::GEMMKernel), 0);
+  cache.Set(key, node);
   EXPECT_EQ(cache.Size(), 1);
   EXPECT_EQ(cache.Find(key), true);
 
   auto algo = cache.Get(key);
-  EXPECT_EQ(algo, ConvAlgos::GEMMKernel);
+  EXPECT_EQ(algo.algo, ConvAlgos::GEMMKernel);
 
   x_shape = {4, 128, 128, 3};
-  key = phi::autotune::ConvKey(
-      x_shape, w_shape, paddings, strides, dilations, dtype);
-  EXPECT_EQ(cache.Find(key), false);
-  cache.Set(key, ConvAlgos::CuDNNKernel_1);
+  phi::autotune::ConvCacheKey key1(
+      x_shape, w_shape, paddings, strides, dilations, dtype, 0, 1);
+  EXPECT_EQ(cache.Find(key1), false);
+  phi::autotune::DnnNode node1(static_cast<int64_t>(ConvAlgos::CuDNNKernel_1),
+                               0);
+  cache.Set(key1, node1);
   EXPECT_EQ(cache.Size(), 2);
   EXPECT_EQ(cache.CacheHits(), 1);
   EXPECT_EQ(cache.CacheMisses(), 2);
diff --git a/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu
index ef70907b59a61..fb9580427e1f4 100644
--- a/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu
@@ -254,6 +254,8 @@ void ConvCudnnGradGradKernel(
   auto dtype = paddle::platform::CudnnDataType<T>::type;
 
   auto handle = ctx.cudnn_handle();
+  auto layout = paddle::platform::GetCudnnTensorFormat(
+      paddle::platform::DataLayout::kNCHW);
 
   paddle::operators::ConvArgs args1{&transformed_ddX,
                                     W,
@@ -261,28 +263,36 @@ void ConvCudnnGradGradKernel(
                                     strides,
                                     padding_common,
                                     dilations,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    paddle::platform::DataLayout::kNCHW};
   paddle::operators::ConvArgs args2{&transformed_X,
                                     ddW,
                                     &transformed_ddO_channel,
                                     strides,
                                     padding_common,
                                     dilations,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    paddle::platform::DataLayout::kNCHW};
   paddle::operators::ConvArgs args3{&transformed_ddX,
                                     dW,
                                     &transformed_dO_channel,
                                     strides,
                                     padding_common,
                                     dilations,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    paddle::platform::DataLayout::kNCHW};
   paddle::operators::ConvArgs args4{&transformed_dX,
                                     ddW,
                                     &transformed_dO_channel,
                                     strides,
                                     padding_common,
                                     dilations,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    paddle::platform::DataLayout::kNCHW};
 
 #ifdef PADDLE_WITH_HIP
   paddle::operators::SearchResult fwd_result1;
@@ -298,9 +308,6 @@ void ConvCudnnGradGradKernel(
       filter_result;
 #endif
 
-  auto layout = paddle::platform::GetCudnnTensorFormat(
-      paddle::platform::DataLayout::kNCHW);
-
   // ddo = conv(ddI, W) + conv(I, ddW)
   size_t workspace_size = 0;
diff --git a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
index 4e9c37879c002..bc7a8b4f37840 100644
--- a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
@@ -251,27 +251,33 @@ void ConvCudnnGradKernel(const Context& ctx,
   T* input_grad_data = nullptr;
   T* transformed_input_grad_data = nullptr;
 
+  paddle::platform::DataLayout layout =
+      compute_format == paddle::platform::DataLayout::kNHWC
+          ? paddle::platform::DataLayout::kNHWC
+          : paddle::platform::DataLayout::kNCHW;
+
   paddle::operators::ConvArgs args1{&transformed_input_grad,
                                     &transformed_filter_channel,
                                     &transformed_output_grad_channel,
                                     strides,
                                     padding_common,
                                     dilations,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    layout};
   paddle::operators::ConvArgs args2{&transformed_input,
                                     &transformed_filter_grad_channel,
                                     &transformed_output_grad_channel,
                                     strides,
                                     padding_common,
                                     dilations,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    layout};
 
   auto handle = ctx.cudnn_handle();
   // TODO(phlrain): replace paddle::platform::DataLaytout to phi::DataLayout
-  paddle::platform::DataLayout layout =
-      compute_format == paddle::platform::DataLayout::kNHWC
-          ? paddle::platform::DataLayout::kNHWC
-          : paddle::platform::DataLayout::kNCHW;
+
   if (transformed_input.dims().size() == 5) {
     layout = compute_format == paddle::platform::DataLayout::kNHWC
                 ? paddle::platform::DataLayout::kNDHWC
@@ -368,8 +374,7 @@ void ConvCudnnGradKernel(const Context& ctx,
     using search1 = paddle::operators::SearchAlgorithm;
     bwd_result = search1::Find(args1, exhaustive_search, deterministic, ctx);
-    workspace_size_d = std::max(
-        workspace_size_d, search1::GetWorkspaceSize(args1, bwd_result.algo));
+    workspace_size_d = std::max(workspace_size_d, bwd_result.workspace_size);
 #endif
   }
@@ -400,8 +405,7 @@ void ConvCudnnGradKernel(const Context& ctx,
         search2::Find(args2, exhaustive_search, deterministic, ctx);
     VLOG(3) << "filter algo: " << filter_result.algo << ", time "
             << filter_result.time;
-    workspace_size_w = std::max(
-        workspace_size_w, search2::GetWorkspaceSize(args2, filter_result.algo));
+    workspace_size_w = std::max(workspace_size_w, filter_result.workspace_size);
 #endif
   }
diff --git a/paddle/phi/kernels/gpudnn/conv_kernel.cu b/paddle/phi/kernels/gpudnn/conv_kernel.cu
index bd95a32bc724f..aa591a34a4399 100644
--- a/paddle/phi/kernels/gpudnn/conv_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_kernel.cu
@@ -213,7 +213,9 @@ void ConvCudnnKernel(const Context& ctx,
                                    strides,
                                    padding_common,
                                    dilations,
-                                   dtype};
+                                   dtype,
+                                   groups,
+                                   compute_format};
 
   auto handle = ctx.cudnn_handle();
   auto workspace_handle = ctx.cudnn_workspace_handle();
@@ -314,7 +316,7 @@ void ConvCudnnKernel(const Context& ctx,
   using search = paddle::operators::SearchAlgorithm;
   fwd_result = search::Find(args, exhaustive_search, deterministic, ctx);
-  workspace_size = search::GetWorkspaceSize(args, fwd_result.algo);
+  workspace_size = fwd_result.workspace_size;
 #endif
 
 #if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1)
diff --git a/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu
index 0ce16f66becfa..626ef79d56483 100644
--- a/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu
@@ -179,14 +179,18 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
                                     strides,
                                     padding_common,
                                     dilations_,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    layout};
   paddle::operators::ConvArgs args2{&transformed_dout,
                                     &filter,
                                     &x_transpose,
                                     strides,
                                     padding_common,
                                     dilations_,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    layout};
 
 #ifdef PADDLE_WITH_HIP
   paddle::operators::SearchResult fwd_result;
@@ -625,6 +629,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
   auto dtype = paddle::platform::CudnnDataType<T>::type;
 
   auto handle = ctx.cudnn_handle();
+  auto layout = paddle::platform::GetCudnnTensorFormat(GPUDNNDataLayout::kNCHW);
 
   paddle::operators::ConvArgs args1{&transformed_ddout_channel,
                                     &filter,
@@ -632,14 +637,18 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
                                     strides,
                                     padding_common,
                                     dilations_,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    GPUDNNDataLayout::kNCHW};
   paddle::operators::ConvArgs args2{&transformed_ddout_channel,
                                     &ddfilter,
                                     &transformed_x,
                                     strides,
                                     padding_common,
                                     dilations_,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    GPUDNNDataLayout::kNCHW};
 
   paddle::operators::ConvArgs args3{&transformed_dout,
                                     dfilter,
@@ -647,14 +656,18 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
                                     strides,
                                     padding_common,
                                     dilations_,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    GPUDNNDataLayout::kNCHW};
   paddle::operators::ConvArgs args4{&transformed_dout,
                                     &ddfilter,
                                     &transformed_dx_channel,
                                     strides,
                                     padding_common,
                                     dilations_,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    GPUDNNDataLayout::kNCHW};
 #ifdef PADDLE_WITH_HIP
   paddle::operators::SearchResult bwd_result1;
   paddle::operators::SearchResult bwd_result2;
@@ -669,8 +682,6 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
   paddle::operators::SearchResult fwd_result;
 #endif
 
-  auto layout = paddle::platform::GetCudnnTensorFormat(GPUDNNDataLayout::kNCHW);
-
   // ddo = conv(ddI, filter) + conv(I, ddfilter)
   size_t workspace_size = 0;
diff --git a/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu b/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu
index 58ead4c3287f8..2289541da1659 100644
--- a/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu
@@ -205,7 +205,9 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
                                     strides,
                                     padding_common,
                                     dilations_,
-                                    dtype};
+                                    dtype,
+                                    groups,
+                                    data_layout};
   args.handle = handle;
   args.idesc.set(transformed_out, iwo_groups);
   args.wdesc.set(filter, layout_tensor, iwo_groups);
diff --git a/python/paddle/fluid/tests/unittests/test_switch_autotune.py b/python/paddle/fluid/tests/unittests/test_switch_autotune.py
index a22df61ace8c7..3daf8752d8c08 100644
--- a/python/paddle/fluid/tests/unittests/test_switch_autotune.py
+++ b/python/paddle/fluid/tests/unittests/test_switch_autotune.py
@@ -71,11 +71,8 @@ def get_expected_res(self, step_id, enable_autotune):
         }
         if paddle.is_compiled_with_cuda():
             # Total 3 * num_iters cache accesses, only iter 2 hits the cache.
-            if enable_autotune and step_id >= 1:
-                expected_res["cache_size"] = 3
-            if enable_autotune and step_id == 2:
-                expected_res["cache_hit_rate"] = np.round(
-                    float(3) / float(9), 5)
+            expected_res["cache_size"] = 3
+            expected_res["cache_hit_rate"] = (step_id + 0.0) / (step_id + 1.0)
         return expected_res
 
     def test_autotune(self):
@@ -91,11 +88,11 @@ def test_autotune(self):
     def check_status(self, expected_res):
         status = paddle.fluid.core.autotune_status()
         for key in status.keys():
+            v = status[key]
             if key == "cache_hit_rate":
-                v = np.round(status[key], 5)
+                self.assertTrue(np.allclose(v, expected_res[key]))
             else:
-                v = status[key]
-                self.assertEqual(v, expected_res[key])
+                self.assertEqual(v, expected_res[key])
 
 
 class TestDygraphAutoTuneStatus(TestAutoTune):