diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
index af224cb5be8ab..314f4b343d481 100644
--- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -130,6 +130,7 @@
     'fused_dot_product_attention',
     'nce',
     'lars_momentum',
+    'max_pool2d_v2',
     'recv_v2',
     'rnn_',
     'row_conv',
diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
index 9551bfc425ebc..4a4c4707ac9a8 100644
--- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
+++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
@@ -29,4 +29,5 @@
    'fused_rotary_position_embedding',
    'fused_bias_dropout_residual_layer_norm',
    'fused_dot_product_attention',
+    'max_pool2d_v2',
]
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 53df4c25034ab..32e9ffd3a5c63 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -254,6 +254,14 @@ bool IsCompiledWithCUDA() {
 #endif
 }
 
+bool IsCompiledWithCudnnFrontend() {
+#ifndef PADDLE_WITH_CUDNN_FRONTEND
+  return false;
+#else
+  return true;
+#endif
+}
+
 bool IsCompiledWithDISTRIBUTE() {
 #if !defined(PADDLE_WITH_DISTRIBUTE)
   return false;
@@ -2124,6 +2132,7 @@ All parameter, weight, gradient are variables in Paddle.
       });
   m.def("is_compiled_with_avx", IsCompiledWithAVX);
   m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
+  m.def("is_compiled_with_cudnn_frontend", IsCompiledWithCudnnFrontend);
   m.def("is_compiled_with_rocm", IsCompiledWithROCM);
   m.def("is_compiled_with_custom_device", IsCompiledWithCustomDevice);
   m.def("is_compiled_with_ipu", IsCompiledWithIPU);
diff --git a/paddle/phi/api/yaml/fused_backward.yaml b/paddle/phi/api/yaml/fused_backward.yaml
index 649e427b25a34..8a2a9786a837a 100644
--- a/paddle/phi/api/yaml/fused_backward.yaml
+++ b/paddle/phi/api/yaml/fused_backward.yaml
@@ -51,3 +51,14 @@
     func : fused_rotary_position_embedding_grad
     data_type : out_q_grad
   support_dygraph_mode : true
+
+- backward_op : max_pool2d_v2_grad
+  forward : max_pool2d_v2(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, str data_format = "NCHW", bool global_pooling = false, bool adaptive = false) -> Tensor(out), Tensor(saved_idx)
+  args : (Tensor x, Tensor out, Tensor saved_idx, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, str data_format, bool global_pooling, bool adaptive)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : max_pool2d_v2_grad
+    param : [x, out, saved_idx, out_grad, kernel_size, strides, paddings, data_format, global_pooling, adaptive]
diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml
index 1b429fc958de7..235ddaaacc694 100644
--- a/paddle/phi/api/yaml/fused_ops.yaml
+++ b/paddle/phi/api/yaml/fused_ops.yaml
@@ -383,6 +383,21 @@
     func : layer_norm_act_xpu
     data_type : x
 
+# This op is implemented using the CUDNN Frontend API and serves as a supplement
+# to the legacy max pooling implementation. It shows better performance with the
+# NHWC layout and half precision.
+- op : max_pool2d_v2
+  args : (Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, str data_format = "NCHW", bool global_pooling = false, bool adaptive = false)
+  output : Tensor(out), Tensor(saved_idx)
+  infer_meta :
+    func : MaxPoolV2InferMeta
+    param : [x, kernel_size, strides, paddings, data_format, global_pooling, adaptive]
+  kernel :
+    func : max_pool2d_v2
+    param : [x, kernel_size, strides, paddings, data_format, global_pooling, adaptive]
+  intermediate : saved_idx
+  backward : max_pool2d_v2_grad
+
 - op : multi_encoder_xpu
   args : (Tensor x, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx)
   output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16)
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index b1b06fdbfed71..39cec09e3db86 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -2349,6 +2349,37 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
   mask->set_dtype(phi::CppTypeToDataType<int>::Type());
 }
 
+void MaxPoolV2InferMeta(const MetaTensor& x,
+                        const std::vector<int>& kernel_size,
+                        const std::vector<int>& strides,
+                        const std::vector<int>& paddings,
+                        const std::string& data_format,
+                        bool global_pooling,
+                        bool adaptive,
+                        MetaTensor* out,
+                        MetaTensor* saved_idx,
+                        MetaConfig config) {
+  PADDLE_ENFORCE_EQ(adaptive,
+                    false,
+                    phi::errors::InvalidArgument(
+                        "max_pool2d_v2 op does not support adaptive."));
+  Pool2DInferMeta(x,
+                  kernel_size,
+                  strides,
+                  paddings,
+                  false,
+                  false,
+                  data_format,
+                  "max",
+                  global_pooling,
+                  adaptive,
+                  "EXPLICIT",
+                  out,
+                  config);
+  saved_idx->set_dims(out->dims());
+  saved_idx->set_dtype(phi::CppTypeToDataType<int>::Type());
+}
+
 void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out) {
   out->set_dims(common::make_ddim({}));
   out->set_dtype(x.dtype());
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 0126b76754fef..aaa85da1f8524 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -350,6 +350,17 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
                                MetaTensor* mask,
                                MetaConfig config = MetaConfig());
 
+void MaxPoolV2InferMeta(const MetaTensor& x,
+                        const std::vector<int>& kernel_size,
+                        const std::vector<int>& strides,
+                        const std::vector<int>& paddings,
+                        const std::string& data_format,
+                        bool global_pooling,
+                        bool adaptive,
+                        MetaTensor* out,
+                        MetaTensor* saved_idx,
+                        MetaConfig config = MetaConfig());
+
 void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out);
 
 void ModeInferMeta(const MetaTensor& x,
diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt
index f6ed266577bac..c0ef08cb0b5ef 100644
--- a/paddle/phi/kernels/CMakeLists.txt
+++ b/paddle/phi/kernels/CMakeLists.txt
@@ -222,7 +222,9 @@ if(NOT WITH_CUDNN_FRONTEND)
       "fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu"
       "fusion/gpu/fused_scale_bias_add_relu_kernel.cu"
       "fusion/gpu/fused_dconv_drelu_dbn_kernel.cu"
-      "fusion/gpu/fused_dot_product_attention_op.cu")
+      "fusion/gpu/fused_dot_product_attention_op.cu"
+      "fusion/gpu/max_pool2d_v2_grad_kernel.cu"
+      "fusion/gpu/max_pool2d_v2_kernel.cu")
 endif()
 
 set(cc_search_pattern
diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h
index fcb9058cd0a76..0554ab526d5ee 100644
--- a/paddle/phi/kernels/autotune/cache.h
+++ b/paddle/phi/kernels/autotune/cache.h
@@ -61,7 +61,9 @@ enum class AlgorithmType {
   kDgradDreluBnBwdWeight = 16,
   kDbnApply = 17,
   kBnActWgrad = 18,
-  kAlgorithmCount = 19
+  kPoolingForwardV8 = 19,
+  kPoolingBackwardV8 = 20,
+  kAlgorithmCount = 21
 #endif
 };
 
diff --git a/paddle/phi/kernels/fusion/gpu/max_pool2d_v2_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/max_pool2d_v2_grad_kernel.cu
new file mode 100644
index 0000000000000..9cd45357f0bfb
--- /dev/null
+++ b/paddle/phi/kernels/fusion/gpu/max_pool2d_v2_grad_kernel.cu
@@ -0,0 +1,255 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <array>
+
+#include "paddle/phi/backends/gpu/gpu_dnn.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/core/flags.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/autotune/cache.h"
+#include "paddle/phi/kernels/funcs/pooling.h"
+#include "paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h"
+#include "paddle/phi/kernels/gpudnn/pool_gpudnn.h"
+
+PHI_DECLARE_bool(cudnn_exhaustive_search);
+
+namespace phi {
+
+template <typename T1, typename T2, typename Context>
+void MaxPoolV2GradCUDNNKernel(const Context& ctx,
+                              const DenseTensor& x,
+                              const DenseTensor& out,
+                              const DenseTensor& saved_idx,
+                              const DenseTensor& dout,
+                              const std::vector<int>& kernel_size,
+                              const std::vector<int>& strides,
+                              const std::vector<int>& paddings,
+                              const std::string& data_format,
+                              bool global_pooling,
+                              bool adaptive,
+                              DenseTensor* dx) {
+  PADDLE_ENFORCE_GE(ctx.GetComputeCapability(),
+                    80,
+                    phi::errors::PreconditionNotMet(
+                        "This op only supports Ampere and later devices, "
+                        "but got compute capability: %d.",
+                        ctx.GetComputeCapability()));
+  // Additional options
+  bool exhaustive_search = FLAGS_cudnn_exhaustive_search;
+  bool deterministic = FLAGS_cudnn_deterministic;
+  PADDLE_ENFORCE_EQ(exhaustive_search && deterministic,
+                    false,
+                    phi::errors::InvalidArgument(
+                        "Can't set exhaustive_search True and "
+                        "FLAGS_cudnn_deterministic True at same time."));
+  // Allocate output tensors
+  ctx.template Alloc<T1>(dx);
+  // Update paddings
+  std::vector<int> paddings_ = paddings;
+  std::vector<int> kernel_size_ = kernel_size;
+  const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
+  PADDLE_ENFORCE_EQ(
+      channel_last,
+      true,
+      phi::errors::InvalidArgument(
+          "NCHW layout is currently not supported for max pooling bwd."));
+  const std::string padding_algorithm = "EXPLICIT";
+
+  auto x_dims = x.dims();
+  DDim data_dims;
+  if (channel_last) {
+    data_dims = slice_ddim(x_dims, 1, x_dims.size() - 1);
+  } else {
+    data_dims = slice_ddim(x_dims, 2, x_dims.size());
+  }
+  funcs::UpdatePadding(&paddings_,
+                       global_pooling,
+                       adaptive,
+                       padding_algorithm,
+                       data_dims,
+                       strides,
+                       kernel_size_);
+
+  const auto data_dim = data_dims.size();
+  std::vector<int64_t> pre_padding(data_dim, 0);
+  std::vector<int64_t> post_padding(data_dim, 0);
+  for (size_t i = 0; i < data_dim; ++i) {
+    pre_padding[i] = static_cast<int64_t>(paddings_[2 * i]);
+    post_padding[i] = static_cast<int64_t>(paddings_[2 * i + 1]);
+  }
+
+  if (global_pooling) {
+    funcs::UpdateKernelSize(&kernel_size_, data_dims);
+  }
+
+  using helper = CudnnFrontendConvHelper;
+  auto kernel_size_int64 = helper::GetInt64Array(kernel_size_);
+  auto strides_int64 = helper::GetInt64Array(strides);
+
+  // Create tensor descriptors
+  auto& plan_cache = phi::autotune::AutoTuneCache::Instance().GetConvV8(
+      phi::autotune::AlgorithmType::kPoolingBackwardV8);
+
+  T2* saved_idx_data = const_cast<T2*>(saved_idx.data<T2>());
+  T1* dout_data = const_cast<T1*>(dout.data<T1>());
+  T1* dx_data = dx->data<T1>();
+
+  auto uid = [](std::string name) {
+    const std::map<std::string, int64_t> _uid = {
+        {"saved_idx", 0}, {"dout", 1}, {"dx", 2}};
+    PADDLE_ENFORCE_GT(_uid.count(name),
+                      0,
+                      phi::errors::InvalidArgument(
+                          "The tensor name %s is unknown. "
+                          "Should be in one of [saved_idx, dout, dx].",
+                          name));
+    return _uid.at(name);
+  };
+
+  cudnnHandle_t handle = const_cast<cudnnHandle_t>(ctx.cudnn_handle());
+  auto workspace_handle = ctx.cudnn_workspace_handle();
+
+  auto layout = GetLayoutFromStr(data_format);
+  auto layout_format = phi::backends::gpu::GetCudnnTensorFormat(layout);
+  auto input_dtype = phi::backends::gpu::CudnnDataType<T1>::type;
+  auto saved_idx_dtype = CudnnIndexType<T2>::type;
+
+  // Create plan and execute
+  std::vector<void*> data_ptrs({saved_idx_data, dout_data, dx_data});
+  std::vector<int64_t> uids({uid("saved_idx"), uid("dout"), uid("dx")});
+
+  // Create feature vector for plan caching
+  cudnn_frontend::feature_vector_t feature_vector;
+  auto dim_x = phi::vectorize(x.dims());
+  phi::autotune::BuildFeatureVector(&feature_vector,
+                                    dim_x,
+                                    kernel_size_int64,
+                                    strides_int64,
+                                    pre_padding,
+                                    post_padding,
+                                    data_format,
+                                    input_dtype,
+                                    saved_idx_dtype);
+
+  if (plan_cache.FindPlan(feature_vector, handle)) {
+    const cudnn_frontend::ExecutionPlan* cached_plan = nullptr;
+    int64_t workspace_size = 0;
+    plan_cache.GetPlanAndWorkspaceSize(
+        feature_vector, &cached_plan, &workspace_size, handle);
+    helper::ExecutePlan(handle,
+                        &workspace_handle,
+                        &data_ptrs,
+                        &uids,
+                        cached_plan->get_raw_desc(),
+                        workspace_size);
+    return;
+  }
+
+  auto saved_idx_desc =
+      helper::GetTensorDescriptor(&saved_idx, uid("saved_idx"), layout_format);
+  auto dout_desc =
+      helper::GetTensorDescriptor(&dout, uid("dout"), layout_format);
+  auto dx_desc = helper::GetTensorDescriptor(dx, uid("dx"), layout_format);
+
+  // Create maxpooling descriptor
+  auto const nan_opt = CUDNN_NOT_PROPAGATE_NAN;
+  auto const mode = cudnn_frontend::cudnnResampleMode_t::CUDNN_RESAMPLE_MAXPOOL;
+  auto const padding_mode =
+      cudnn_frontend::cudnnPaddingMode_t::CUDNN_NEG_INF_PAD;
+  auto pool_desc = cudnn_frontend::ResampleDescBuilder_v8()
+                       .setComputeType(CUDNN_DATA_FLOAT)
+                       .setNanPropagation(nan_opt)
+                       .setResampleMode(mode)
+                       .setPaddingMode(padding_mode)
+                       .setSpatialDim(data_dim, kernel_size_int64.data())
+                       .setSpatialStride(data_dim, strides_int64.data())
+                       .setPrePadding(data_dim, pre_padding.data())
+                       .setPostPadding(data_dim, post_padding.data())
+                       .build();
+
+  // Create maxpooling bwd op
+  auto pool_op = cudnn_frontend::OperationBuilder(
+                     CUDNN_BACKEND_OPERATION_RESAMPLE_BWD_DESCRIPTOR)
+                     .setdxDesc(dx_desc)
+                     .setdyDesc(dout_desc)
+                     .setidxDesc(saved_idx_desc)
+                     .setResampleDesc(pool_desc)
+                     .build();
+
+  // Create op graph
+  std::array<cudnn_frontend::Operation const*, 1> ops = {&pool_op};
+  auto op_graph = cudnn_frontend::OperationGraphBuilder()
+                      .setHandle(handle)
+                      .setOperationGraph(ops.size(), ops.data())
+                      .build();
+
+  auto plans = helper::FindExecutionPlans(&op_graph,
+                                          exhaustive_search,
+                                          deterministic,
+                                          &data_ptrs,
+                                          &uids,
+                                          handle,
+                                          &workspace_handle);
+
+  helper::ExecutePlansAndCache(handle,
+                               &workspace_handle,
+                               &data_ptrs,
+                               &uids,
+                               &plans,
+                               exhaustive_search,
+                               feature_vector,
+                               &plan_cache);
+}
+
+template <typename T, typename Context>
+void MaxPool2dV2GradCUDNNKernel(const Context& ctx,
+                                const DenseTensor& x,
+                                const DenseTensor& out,
+                                const DenseTensor& saved_idx,
+                                const DenseTensor& dout,
+                                const std::vector<int>& kernel_size,
+                                const std::vector<int>& strides,
+                                const std::vector<int>& paddings,
+                                const std::string& data_format,
+                                bool global_pooling,
+                                bool adaptive,
+                                DenseTensor* dx) {
+  MaxPoolV2GradCUDNNKernel<T, int, Context>(ctx,
+                                            x,
+                                            out,
+                                            saved_idx,
+                                            dout,
+                                            kernel_size,
+                                            strides,
+                                            paddings,
+                                            data_format,
+                                            global_pooling,
+                                            adaptive,
+                                            dx);
+}
+
+}  // namespace phi
+
+using phi::dtype::float16;
+
+PD_REGISTER_KERNEL(max_pool2d_v2_grad,  // cuda_only
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MaxPool2dV2GradCUDNNKernel,
+                   float,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(2).SetDataType(phi::CppTypeToDataType<int>::Type());
+}
diff --git a/paddle/phi/kernels/fusion/gpu/max_pool2d_v2_kernel.cu b/paddle/phi/kernels/fusion/gpu/max_pool2d_v2_kernel.cu
new file mode 100644
index 0000000000000..46cabfe8b2d85
--- /dev/null
+++ b/paddle/phi/kernels/fusion/gpu/max_pool2d_v2_kernel.cu
@@ -0,0 +1,236 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <array>
+
+#include "paddle/phi/backends/gpu/gpu_dnn.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/core/flags.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/autotune/cache.h"
+#include "paddle/phi/kernels/funcs/pooling.h"
+#include "paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h"
+#include "paddle/phi/kernels/gpudnn/pool_gpudnn.h"
+
+PHI_DECLARE_bool(cudnn_exhaustive_search);
+
+namespace phi {
+
+template <typename T1, typename T2, typename Context>
+void MaxPoolV2CUDNNKernel(const Context& ctx,
+                          const DenseTensor& x,
+                          const std::vector<int>& kernel_size,
+                          const std::vector<int>& strides,
+                          const std::vector<int>& paddings,
+                          const std::string& data_format,
+                          bool global_pooling,
+                          bool adaptive,
+                          DenseTensor* out,
+                          DenseTensor* saved_idx) {
+  PADDLE_ENFORCE_GE(ctx.GetComputeCapability(),
+                    80,
+                    phi::errors::PreconditionNotMet(
+                        "This op only supports Ampere and later devices, "
+                        "but got compute capability: %d.",
+                        ctx.GetComputeCapability()));
+  // Additional options
+  bool exhaustive_search = FLAGS_cudnn_exhaustive_search;
+  bool deterministic = FLAGS_cudnn_deterministic;
+  PADDLE_ENFORCE_EQ(exhaustive_search && deterministic,
+                    false,
+                    phi::errors::InvalidArgument(
+                        "Can't set exhaustive_search True and "
+                        "FLAGS_cudnn_deterministic True at same time."));
+  // Allocate output tensors
+  ctx.template Alloc<T1>(out);
+  ctx.template Alloc<T2>(saved_idx);
+  // Update paddings
+  std::vector<int> paddings_ = paddings;
+  std::vector<int> kernel_size_ = kernel_size;
+  const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
+  const std::string padding_algorithm = "EXPLICIT";
+
+  auto x_dims = x.dims();
+  DDim data_dims;
+  if (channel_last) {
+    data_dims = slice_ddim(x_dims, 1, x_dims.size() - 1);
+  } else {
+    data_dims = slice_ddim(x_dims, 2, x_dims.size());
+  }
+  funcs::UpdatePadding(&paddings_,
+                       global_pooling,
+                       adaptive,
+                       padding_algorithm,
+                       data_dims,
+                       strides,
+                       kernel_size_);
+
+  const auto data_dim = data_dims.size();
+  std::vector<int64_t> pre_padding(data_dim, 0);
+  std::vector<int64_t> post_padding(data_dim, 0);
+  for (size_t i = 0; i < data_dim; ++i) {
+    pre_padding[i] = static_cast<int64_t>(paddings_[2 * i]);
+    post_padding[i] = static_cast<int64_t>(paddings_[2 * i + 1]);
+  }
+
+  if (global_pooling) {
+    funcs::UpdateKernelSize(&kernel_size_, data_dims);
+  }
+
+  using helper = CudnnFrontendConvHelper;
+  auto kernel_size_int64 = helper::GetInt64Array(kernel_size_);
+  auto strides_int64 = helper::GetInt64Array(strides);
+
+  // Prepare for execution
+  auto& plan_cache = phi::autotune::AutoTuneCache::Instance().GetConvV8(
+      phi::autotune::AlgorithmType::kPoolingForwardV8);
+
+  T1* input_data = const_cast<T1*>(x.data<T1>());
+  T1* output_data = out->data<T1>();
+  T2* saved_idx_data = saved_idx->data<T2>();
+
+  cudnnHandle_t handle = const_cast<cudnnHandle_t>(ctx.cudnn_handle());
+  auto workspace_handle = ctx.cudnn_workspace_handle();
+
+  auto layout = GetLayoutFromStr(data_format);
+  auto layout_format = phi::backends::gpu::GetCudnnTensorFormat(layout);
+  auto input_dtype = phi::backends::gpu::CudnnDataType<T1>::type;
+  auto saved_idx_dtype = CudnnIndexType<T2>::type;
+
+  // Create plan and execute
+  std::vector<void*> data_ptrs({input_data, output_data, saved_idx_data});
+  std::vector<int64_t> uids({'x', 'o', 's'});
+
+  // Create feature vector for plan caching
+  cudnn_frontend::feature_vector_t feature_vector;
+  auto dim_x = phi::vectorize(x.dims());
+
+  phi::autotune::BuildFeatureVector(&feature_vector,
+                                    dim_x,
+                                    kernel_size_int64,
+                                    strides_int64,
+                                    pre_padding,
+                                    post_padding,
+                                    data_format,
+                                    input_dtype,
+                                    saved_idx_dtype);
+
+  // Query cache and execute
+  if (plan_cache.FindPlan(feature_vector, handle)) {
+    const cudnn_frontend::ExecutionPlan* cached_plan = nullptr;
+    int64_t workspace_size = 0;
+    plan_cache.GetPlanAndWorkspaceSize(
+        feature_vector, &cached_plan, &workspace_size, handle);
+    helper::ExecutePlan(handle,
+                        &workspace_handle,
+                        &data_ptrs,
+                        &uids,
+                        cached_plan->get_raw_desc(),
+                        workspace_size);
+    return;
+  }
+
+  // Create tensor descriptors
+  auto x_desc = helper::GetTensorDescriptor(&x, 'x', layout_format);
+  auto out_desc = helper::GetTensorDescriptor(out, 'o', layout_format);
+  auto saved_idx_desc =
+      helper::GetTensorDescriptor(saved_idx, 's', layout_format);
+
+  // Create maxpooling descriptor
+  auto const nan_opt = CUDNN_NOT_PROPAGATE_NAN;
+  auto const mode = cudnn_frontend::cudnnResampleMode_t::CUDNN_RESAMPLE_MAXPOOL;
+  auto const padding_mode =
+      cudnn_frontend::cudnnPaddingMode_t::CUDNN_NEG_INF_PAD;
+  auto pool_desc = cudnn_frontend::ResampleDescBuilder_v8()
+                       .setComputeType(CUDNN_DATA_FLOAT)
+                       .setNanPropagation(nan_opt)
+                       .setResampleMode(mode)
+                       .setPaddingMode(padding_mode)
+                       .setSpatialDim(data_dim, kernel_size_int64.data())
+                       .setSpatialStride(data_dim, strides_int64.data())
+                       .setPrePadding(data_dim, pre_padding.data())
+                       .setPostPadding(data_dim, post_padding.data())
+                       .build();
+
+  // Create maxpooling op
+  auto pool_op = cudnn_frontend::OperationBuilder(
+                     CUDNN_BACKEND_OPERATION_RESAMPLE_FWD_DESCRIPTOR)
+                     .setxDesc(x_desc)
+                     .setyDesc(out_desc)
+                     .setidxDesc(saved_idx_desc)
+                     .setResampleDesc(pool_desc)
+                     .build();
+
+  // Create op graph
+  std::array<cudnn_frontend::Operation const*, 1> ops = {&pool_op};
+  auto op_graph = cudnn_frontend::OperationGraphBuilder()
+                      .setHandle(handle)
+                      .setOperationGraph(ops.size(), ops.data())
+                      .build();
+
+  auto plans = helper::FindExecutionPlans(&op_graph,
+                                          exhaustive_search,
+                                          deterministic,
+                                          &data_ptrs,
+                                          &uids,
+                                          handle,
+                                          &workspace_handle);
+
+  helper::ExecutePlansAndCache(handle,
+                               &workspace_handle,
+                               &data_ptrs,
+                               &uids,
+                               &plans,
+                               exhaustive_search,
+                               feature_vector,
+                               &plan_cache);
+}
+
+template <typename T, typename Context>
+void MaxPool2dV2CUDNNKernel(const Context& ctx,
+                            const DenseTensor& x,
+                            const std::vector<int>& kernel_size,
+                            const std::vector<int>& strides,
+                            const std::vector<int>& paddings,
+                            const std::string& data_format,
+                            bool global_pooling,
+                            bool adaptive,
+                            DenseTensor* out,
+                            DenseTensor* saved_idx) {
+  // TODO(tizheng): support int8 mask
+  MaxPoolV2CUDNNKernel<T, int, Context>(ctx,
+                                        x,
+                                        kernel_size,
+                                        strides,
+                                        paddings,
+                                        data_format,
+                                        global_pooling,
+                                        adaptive,
+                                        out,
+                                        saved_idx);
+}
+
+}  // namespace phi
+
+using phi::dtype::float16;
+
+PD_REGISTER_KERNEL(max_pool2d_v2,  // cuda_only
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MaxPool2dV2CUDNNKernel,
+                   float,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
+  kernel->OutputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
+}
diff --git a/paddle/phi/kernels/gpudnn/pool_gpudnn.h b/paddle/phi/kernels/gpudnn/pool_gpudnn.h
index d830aad6b4f4f..cd2758109f28c 100644
--- a/paddle/phi/kernels/gpudnn/pool_gpudnn.h
+++ b/paddle/phi/kernels/gpudnn/pool_gpudnn.h
@@ -29,6 +29,21 @@ template <typename T>
 using ScalingParamType =
     typename phi::backends::gpu::CudnnDataType<T>::ScalingParamType;
 
+template <typename T>
+class CudnnIndexType;
+
+template <>
+class CudnnIndexType<int> {
+ public:
+  static const cudnnDataType_t type = CUDNN_DATA_INT32;
+};
+
+template <>
+class CudnnIndexType<int8_t> {
+ public:
+  static const cudnnDataType_t type = CUDNN_DATA_INT8;
+};
+
 inline GPUDNNDataLayout GetLayoutFromStr(std::string data_format) {
   if (data_format == "NHWC") {
     return GPUDNNDataLayout::kNHWC;
diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt
index ed0f40f982d23..2d3116d5ad69b 100644
--- a/test/legacy_test/CMakeLists.txt
+++ b/test/legacy_test/CMakeLists.txt
@@ -1082,6 +1082,7 @@ set_tests_properties(
   test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass
   PROPERTIES TIMEOUT 120)
 set_tests_properties(test_conv_nn_grad PROPERTIES TIMEOUT 220)
+set_tests_properties(test_pool_max_op PROPERTIES TIMEOUT 500)
 set_tests_properties(test_program_prune_backward PROPERTIES TIMEOUT 120)
 set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 1000)
 set_tests_properties(test_imperative_optimizer PROPERTIES TIMEOUT 250)
diff --git a/test/legacy_test/test_pool_max_op.py b/test/legacy_test/test_pool_max_op.py
index 23740d39b8ef3..f2186dee7c339 100644
--- a/test/legacy_test/test_pool_max_op.py
+++ b/test/legacy_test/test_pool_max_op.py
@@ -469,5 +469,241 @@ def test_check_grad(self):
 create_test_bf16_class(TestCastAdaptive2d)
 
 
+def skip_unit_test():
+    return (
+        not core.is_compiled_with_cuda()
+        or not core.is_compiled_with_cudnn_frontend()
+        or paddle.device.cuda.get_device_capability()[0] < 8
+    )
+
+
+@unittest.skipIf(
+    skip_unit_test(),
+    "Only support Ampere or later devices; "
+    "Paddle should be built with WITH_CUDNN_FRONTEND=ON.",
+)
+class TestMaxPool2dV2Op(OpTest):
+    def setUp(self):
+        self.init_layout()
+        self.init_test_case()
+        self.init_global()
+        self.init_adaptive()
+        self.init_dtype()
+
+        if self.is_bfloat16_op():
+            input = np.random.random(self.shape).astype(np.float32)
+            input = convert_uint16_to_float(
+                convert_float_to_uint16(np.round(input * 100.0, 2))
+            )
+
+        else:
+            input = np.random.random(self.shape).astype(self.dtype)
+            input = np.round(input * 100.0, 2)
+
+        output, _ = self.pool_forward_naive(
+            input,
+            self.ksize,
+            self.strides,
+            self.paddings,
+            self.global_pool,
+            self.adaptive,
+        )
+        if self.is_bfloat16_op():
+            output = output.astype(np.float32)
+        else:
+            output = output.astype(self.dtype)
+
+        self.attrs = {
+            'strides': self.strides,
+            'paddings': self.paddings,
+            'kernel_size': self.ksize,
+            'data_format': self.data_format,
+            'global_pooling': self.global_pool,
+            'adaptive': self.adaptive,
+        }
+
+        if self.data_format == 'NHWC':
+            input = input.transpose((0, 2, 3, 1))
+            output = output.transpose((0, 2, 3, 1))
+
+        saved_idx = np.zeros(shape=output.shape, dtype=np.int32)
+
+        if self.is_bfloat16_op():
+            self.inputs = {
+                'x': convert_float_to_uint16(
+                    input, data_format=self.data_format
+                )
+            }
+            self.outputs = {
+                'out': convert_float_to_uint16(
+                    output, data_format=self.data_format
+                ),
+                'saved_idx': saved_idx,
+            }
+            self.inputs_fp32 = {'x': input}
+
+        else:
+            self.inputs = {'x': input}
+            self.outputs = {'out': output, 'saved_idx': saved_idx}
+
+    def init_layout(self):
+        self.data_format = "NHWC"
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(
+                place, no_check_set=['saved_idx'], check_dygraph=False
+            )
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place,
+                {'x'},
+                ['out'],
+                max_relative_error=0.05,
+                check_dygraph=False,
+            )
+
+    def init_test_case(self):
+        self.op_type = "max_pool2d_v2"
+        self.pool_forward_naive = max_pool2D_forward_naive
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [1, 1]
+
+    def init_global(self):
+        self.global_pool = True
+
+    def init_adaptive(self):
+        self.adaptive = False
+
+
+class TestCase8(TestMaxPool2dV2Op):
+    def init_global(self):
+        self.global_pool = False
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place,
+                {'x'},
+                ['out'],
+                max_relative_error=0.5,
+                check_dygraph=False,
+            )
+
+
+class TestCase9(TestMaxPool2dV2Op):
+    def init_test_case(self):
+        self.op_type = "max_pool2d_v2"
+        self.pool_forward_naive = max_pool2D_forward_naive
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [2, 2]
+        self.paddings = [0, 0]
+
+    def init_global(self):
+        self.global_pool = True
+
+
+class TestCase10(TestCase9):
+    def init_global(self):
+        self.global_pool = False
+
+
+def create_test_fp16_class(parent):
+    class TestMaxPool2dV2FP16(parent):
+        def init_dtype(self):
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(
+                        place, no_check_set=['saved_idx'], check_dygraph=False
+                    )
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(
+                    place, {'x'}, ['out'], check_dygraph=False
+                )
+
+    cls_name = "{}_{}".format(parent.__name__, "FP16OP")
+    TestMaxPool2dV2FP16.__name__ = cls_name
+    globals()[cls_name] = TestMaxPool2dV2FP16
+
+
+create_test_fp16_class(TestMaxPool2dV2Op)
+create_test_fp16_class(TestCase8)
+create_test_fp16_class(TestCase9)
+create_test_fp16_class(TestCase10)
+
+
+def create_test_bf16_class(parent):
+    @unittest.skipIf(
+        skip_unit_test() or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+        "core is not compiled with CUDA or does not support bfloat16",
+    )
+    class TestMaxPool2dV2BF16(parent):
+        def init_dtype(self):
+            self.dtype = np.uint16
+
+        def get_numeric_grad(self, place, check_name):
+            scope = core.Scope()
+            self._check_grad_helper()
+            op = create_op(
+                scope, self.op_type, self.inputs, self.outputs, self.attrs
+            )
+            return get_numeric_gradient(
+                place,
+                scope,
+                op,
+                self.inputs_fp32,
+                check_name,
+                ['out'],
+                delta=0.005,
+            )
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_output_with_place(
+                    place, no_check_set=['saved_idx'], check_dygraph=False
+                )
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            numeric_grads = self.get_numeric_grad(place, 'x')
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(
+                    place,
+                    {'x'},
+                    ['out'],
+                    user_defined_grads=[numeric_grads],
+                    check_dygraph=False,
+                )
+
+    cls_name = "{}_{}".format(parent.__name__, "BF16OP")
+    TestMaxPool2dV2BF16.__name__ = cls_name
+    globals()[cls_name] = TestMaxPool2dV2BF16
+
+
+create_test_bf16_class(TestMaxPool2dV2Op)
+create_test_bf16_class(TestCase8)
+create_test_bf16_class(TestCase9)
+create_test_bf16_class(TestCase10)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/white_list/no_check_set_white_list.py b/test/white_list/no_check_set_white_list.py
index 806b0891ea92e..16bf755eecf6e 100644
--- a/test/white_list/no_check_set_white_list.py
+++ b/test/white_list/no_check_set_white_list.py
@@ -39,4 +39,5 @@
     'rmsprop',
     'rrelu',
     'layer_norm',
+    'max_pool2d_v2',
 ]
diff --git a/test/white_list/op_accuracy_white_list.py b/test/white_list/op_accuracy_white_list.py
index 5ad871e071ba4..3027f4960c050 100644
--- a/test/white_list/op_accuracy_white_list.py
+++ b/test/white_list/op_accuracy_white_list.py
@@ -41,6 +41,7 @@
     'lrn',
     'match_matrix_tensor',
     'matmul',
+    'max_pool2d_v2',
     'max_pool2d_with_index',
     'max_pool3d_with_index',
     'minus',
diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh
index a482de9074eac..36675f650578a 100644
--- a/tools/gpups_test.sh
+++ b/tools/gpups_test.sh
@@ -110,6 +110,7 @@ parallel_list="^init_phi_test$|\
 ^test_gather_nd_op$|\
 ^test_index_select_op$|\
 ^test_pass_base_list$|\
+^test_pool_max_op$|\
 ^test_roll_op$|\
 ^test_switch_autotune$|\
 ^test_tcp_store$|\
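
Usage note: a minimal, hypothetical sketch (not part of the patch) of the runtime
gate a caller can apply before relying on the new cuDNN-Frontend pooling path. It
mirrors skip_unit_test() from test_pool_max_op.py above; the paddle.base.core
import path is an assumption (older branches expose the same functions via
paddle.fluid.core).

    import paddle
    from paddle.base import core  # assumption: may be paddle.fluid.core on older branches

    def can_use_max_pool2d_v2():
        # The new kernels require a CUDA build with WITH_CUDNN_FRONTEND=ON and an
        # Ampere (SM80) or newer GPU; see the PADDLE_ENFORCE_GE checks in the
        # kernels and skip_unit_test() in the tests.
        return (
            core.is_compiled_with_cuda()
            and core.is_compiled_with_cudnn_frontend()
            and paddle.device.cuda.get_device_capability()[0] >= 8
        )

    print(can_use_max_pool2d_v2())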