From 246287c62e8629e76dc0cf440e4295e1e1890511 Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Mon, 19 Oct 2020 09:37:18 +0000
Subject: [PATCH 1/5] refine gpu kernel config for Paddle

---
 paddle/fluid/operators/bce_loss_op.cu         |  16 +--
 paddle/fluid/operators/bilateral_slice_op.cu  |  30 +++--
 paddle/fluid/operators/cumsum_op.cu           |   2 +-
 paddle/fluid/operators/interpolate_op.cu      |  42 ++++---
 paddle/fluid/operators/interpolate_v2_op.cu   |  42 ++++---
 .../fluid/operators/math/segment_pooling.cu   |   2 +-
 paddle/fluid/operators/mish_op.cu             |  20 +--
 paddle/fluid/operators/mv_op.cu               |   2 +-
 paddle/fluid/operators/segment_pool_op.cu     |   2 +-
 paddle/fluid/operators/stack_op.cu            |   2 +-
 paddle/fluid/operators/transpose_op.cu        |   2 +-
 paddle/fluid/operators/where_op.cu            |   2 +-
 paddle/fluid/platform/gpu_launch_config.h     | 114 +++++++++++++----
 .../fluid/platform/gpu_launch_param_config.h  | 103 ----------------
 .../unittests/test_bilinear_interp_op.py      |   2 +
 15 files changed, 175 insertions(+), 208 deletions(-)
 mode change 100644 => 100755 paddle/fluid/platform/gpu_launch_config.h
 delete mode 100755 paddle/fluid/platform/gpu_launch_param_config.h

diff --git a/paddle/fluid/operators/bce_loss_op.cu b/paddle/fluid/operators/bce_loss_op.cu
index 16db4f05e31d3..1a967c57385a0 100644
--- a/paddle/fluid/operators/bce_loss_op.cu
+++ b/paddle/fluid/operators/bce_loss_op.cu
@@ -70,7 +70,7 @@ class BCELossCUDAKernel : public framework::OpKernel<T> {
     auto x_numel = x->numel();
     platform::GpuLaunchConfig config =
-        platform::getGpuLaunchConfig(x_numel, ctx);
+        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), x_numel);
 
     Tensor x_cpu;
     framework::TensorCopy(*x, platform::CPUPlace(), &x_cpu);
@@ -89,9 +89,9 @@ class BCELossCUDAKernel : public framework::OpKernel<T> {
 
     auto& dev_ctx = ctx.cuda_device_context();
 
-    GPUBCELossForward<
-        T><<<config.blocks, config.threads, 0, dev_ctx.stream()>>>(
-        x_data, labels->data<T>(), out_data, x_numel);
+    GPUBCELossForward<T><<<config.block_per_grid, config.thread_per_block, 0,
+                           dev_ctx.stream()>>>(x_data, labels->data<T>(),
+                                               out_data, x_numel);
   }
 };
 
@@ -106,12 +106,12 @@ class BCELossGradCUDAKernel : public framework::OpKernel<T> {
     auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
 
     int x_numel = x->numel();
-    platform::GpuLaunchConfig config =
-        platform::getGpuLaunchConfig(x_numel, ctx);
     auto& dev_ctx = ctx.cuda_device_context();
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(dev_ctx, x_numel);
 
-    GPUBCELossBackward<
-        T><<<config.blocks, config.threads, 0, dev_ctx.stream()>>>(
+    GPUBCELossBackward<T><<<config.block_per_grid, config.thread_per_block, 0,
+                            dev_ctx.stream()>>>(
         x->data<T>(), labels->data<T>(), dout->data<T>(), dx_data, x_numel);
   }
 };
diff --git a/paddle/fluid/operators/bilateral_slice_op.cu b/paddle/fluid/operators/bilateral_slice_op.cu
index e46950f61887d..e56a4be53d149 100644
--- a/paddle/fluid/operators/bilateral_slice_op.cu
+++ b/paddle/fluid/operators/bilateral_slice_op.cu
@@ -165,10 +165,11 @@ class BilateralSliceOpCUDAKernel : public framework::OpKernel<T> {
     int total_count = batch_size * h * w * output_dims[1];
 
     platform::GpuLaunchConfig config =
-        platform::getGpuLaunchConfig(total_count, ctx);
+        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), total_count);
 
-    BilateralSliceCudaForwardKernel<T><<<config.blocks, config.threads, 0,
-                                         ctx.cuda_device_context().stream()>>>(
+    BilateralSliceCudaForwardKernel<
+        T><<<config.block_per_grid, config.thread_per_block, 0,
+             ctx.cuda_device_context().stream()>>>(
         output_data, grid_data, guide_data, input_data, grid_sizes, has_offset,
         total_count, output_dims[1]);
   }
@@ -472,24 +473,29 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
     grid_sizes.input_chans = input_chans;
 
     platform::GpuLaunchConfig config =
-        platform::getGpuLaunchConfig(grid_count, ctx, 512);
+        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), grid_count);
 
-    BilateralSliceCudaGridGradKernel<T><<<config.blocks, config.threads, 0,
-                                          ctx.cuda_device_context().stream()>>>(
+    BilateralSliceCudaGridGradKernel<
+        T><<<config.block_per_grid, config.thread_per_block, 0,
+             ctx.cuda_device_context().stream()>>>(
         grid_grad_data, output_grad_data, guide_data, input_data, grid_sizes,
         has_offset, grid_count, output_chans);
 
-    config = platform::getGpuLaunchConfig(guide_count, ctx, 512);
+    config =
+        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), guide_count);
 
-    BilateralSliceCudaGuideGradKernel<T><<<
-        config.blocks, config.threads, 0, ctx.cuda_device_context().stream()>>>(
+    BilateralSliceCudaGuideGradKernel<
+        T><<<config.block_per_grid, config.thread_per_block, 0,
+             ctx.cuda_device_context().stream()>>>(
         guide_grad_data, output_grad_data, grid_data, guide_data, input_data,
        grid_sizes, has_offset, guide_count, output_chans);
 
-    config = platform::getGpuLaunchConfig(input_count, ctx, 512);
+    config =
+        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), input_count);
 
-    BilateralSliceCudaInputGradKernel<T><<<
-        config.blocks, config.threads, 0, ctx.cuda_device_context().stream()>>>(
+    BilateralSliceCudaInputGradKernel<
+        T><<<config.block_per_grid, config.thread_per_block, 0,
+             ctx.cuda_device_context().stream()>>>(
         input_grad_data, output_grad_data, grid_data, guide_data, grid_sizes,
         has_offset, input_count, output_chans);
   }
diff --git a/paddle/fluid/operators/cumsum_op.cu b/paddle/fluid/operators/cumsum_op.cu
index cff0a101e03d5..85cbf444a564e 100644
--- a/paddle/fluid/operators/cumsum_op.cu
+++ b/paddle/fluid/operators/cumsum_op.cu
@@ -17,7 +17,7 @@ limitations under the License. */
 #include
 #include
 #include "paddle/fluid/operators/cum_op.h"
-#include "paddle/fluid/platform/gpu_launch_param_config.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 using Tensor = paddle::framework::Tensor;
 using LoDTensor = paddle::framework::LoDTensor;
diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu
index 5f17f2960573c..9be385683652a 100644
--- a/paddle/fluid/operators/interpolate_op.cu
+++ b/paddle/fluid/operators/interpolate_op.cu
@@ -887,10 +887,10 @@ static void Interpolate1DCUDAFwd(const framework::ExecutionContext& ctx,
   int pixelNum = n * out_cw;
 
   platform::GpuLaunchConfig config =
-      platform::getGpuLaunchConfig(pixelNum, ctx);
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
   if ("linear" == interp_method) {
-    KeLinearInterpFw<T><<<config.blocks, config.threads, 0,
+    KeLinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
                           ctx.cuda_device_context().stream()>>>(
         input_data, in_w, in_cw, output_data, out_w, n, out_cw, c, ratio_w,
         align_corners, align_mode, data_layout);
@@ -981,21 +981,22 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
   int pixelNum = n * out_chw;
 
   platform::GpuLaunchConfig config =
-      platform::getGpuLaunchConfig(pixelNum, ctx);
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
   if ("nearest" == interp_method) {
-    KeNearestNeighborInterpFw<T><<<config.blocks, config.threads, 0,
-                                   ctx.cuda_device_context().stream()>>>(
+    KeNearestNeighborInterpFw<
+        T><<<config.block_per_grid, config.thread_per_block, 0,
+             ctx.cuda_device_context().stream()>>>(
         input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
         out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
   } else if ("bilinear" == interp_method) {
-    KeBilinearInterpFw<T><<<config.blocks, config.threads, 0,
+    KeBilinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
                            ctx.cuda_device_context().stream()>>>(
         input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
         out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
   } else if ("bicubic" == interp_method) {
-    KeBicubicInterpFw<
-        T><<<config.blocks, config.threads, 0, ctx.cuda_device_context().stream()>>>(
+    KeBicubicInterpFw<T><<<config.blocks, config.threads, 0,
+                           ctx.cuda_device_context().stream()>>>(
         input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
         out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
   }
@@ -1097,10 +1098,10 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
   int pixelNum = n * out_cdhw;
 
   platform::GpuLaunchConfig config =
-      platform::getGpuLaunchConfig(pixelNum, ctx);
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
   if ("trilinear" == interp_method) {
-    KeTrilinearInterpFw<T><<<config.blocks, config.threads, 0,
+    KeTrilinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
                              ctx.cuda_device_context().stream()>>>(
         input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
         out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
         align_mode, data_layout);
@@ -1176,10 +1177,10 @@ static void Interpolate1DCUDABwd(const framework::ExecutionContext& ctx,
   int pixelNum = n * out_cw;
 
   platform::GpuLaunchConfig config =
-      platform::getGpuLaunchConfig(pixelNum, ctx);
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
   if ("linear" == interp_method) {
-    KeLinearInterpBw<T><<<config.blocks, config.threads, 0,
+    KeLinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
                           ctx.cuda_device_context().stream()>>>(
         input_grad_data, in_w, in_cw, output_grad_data, out_w, n, out_cw, c,
         ratio_w, align_corners, align_mode, data_layout);
@@ -1267,22 +1268,23 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
   int pixelNum = n * out_chw;
 
   platform::GpuLaunchConfig config =
-      platform::getGpuLaunchConfig(pixelNum, ctx);
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
   if ("nearest" == interp_method) {
-    KeNearestNeighborInterpBw<T><<<config.blocks, config.threads, 0,
-                                   ctx.cuda_device_context().stream()>>>(
+    KeNearestNeighborInterpBw<
+        T><<<config.block_per_grid, config.thread_per_block, 0,
+             ctx.cuda_device_context().stream()>>>(
         input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
         out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
   } else if ("bilinear" == interp_method) {
-    KeBilinearInterpBw<T><<<config.blocks, config.threads, 0,
+    KeBilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
                            ctx.cuda_device_context().stream()>>>(
         input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
         out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
         data_layout);
   } else if ("bicubic" == interp_method) {
-    KeBicubicInterpBw<
-        T><<<config.blocks, config.threads, 0, ctx.cuda_device_context().stream()>>>(
+    KeBicubicInterpBw<T><<<config.blocks, config.threads, 0,
+                           ctx.cuda_device_context().stream()>>>(
         input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
         out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
   }
@@ -1378,10 +1380,10 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
   int pixelNum = n * out_cdhw;
 
   platform::GpuLaunchConfig config =
-      platform::getGpuLaunchConfig(pixelNum, ctx);
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
   if ("trilinear" == interp_method) {
-    KeTrilinearInterpBw<T><<<config.blocks, config.threads, 0,
+    KeTrilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
                              ctx.cuda_device_context().stream()>>>(
         input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data,
         out_d, out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w,
         align_corners, align_mode, data_layout);
diff --git a/paddle/fluid/operators/interpolate_v2_op.cu b/paddle/fluid/operators/interpolate_v2_op.cu
index 816539c3b5fdb..88de1dbb3c1dd 100644
--- a/paddle/fluid/operators/interpolate_v2_op.cu
+++ b/paddle/fluid/operators/interpolate_v2_op.cu
@@ -899,10 +899,10 @@ static void Interpolate1DCUDAFwd(const framework::ExecutionContext& ctx,
   int pixelNum = n * out_cw;
 
   platform::GpuLaunchConfig config =
-      platform::getGpuLaunchConfig(pixelNum, ctx);
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
   if ("linear" == interp_method) {
-    KeLinearInterpFw<T><<<config.blocks, config.threads, 0,
+    KeLinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
                           ctx.cuda_device_context().stream()>>>(
         input_data, in_w, in_cw, output_data, out_w, n, out_cw, c, ratio_w,
         align_corners, align_mode, data_layout);
@@ -1018,21 +1018,22 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
   int pixelNum = n * out_chw;
 
   platform::GpuLaunchConfig config =
-      platform::getGpuLaunchConfig(pixelNum, ctx);
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
   if ("nearest" == interp_method) {
-    KeNearestNeighborInterpFw<T><<<config.blocks, config.threads, 0,
-                                   ctx.cuda_device_context().stream()>>>(
+    KeNearestNeighborInterpFw<
+        T><<<config.block_per_grid, config.thread_per_block, 0,
+             ctx.cuda_device_context().stream()>>>(
        input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
         out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
   } else if ("bilinear" == interp_method) {
-    KeBilinearInterpFw<T><<<config.blocks, config.threads, 0,
+    KeBilinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
                            ctx.cuda_device_context().stream()>>>(
         input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
         out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
   } else if ("bicubic" == interp_method) {
-    KeBicubicInterpFw<
-        T><<<config.blocks, config.threads, 0, ctx.cuda_device_context().stream()>>>(
+    KeBicubicInterpFw<T><<<config.blocks, config.threads, 0,
+                           ctx.cuda_device_context().stream()>>>(
         input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
         out_chw, c, ratio_h, ratio_w, align_corners,
         data_layout);
   }
@@ -1167,10 +1168,10 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
   int pixelNum = n * out_cdhw;
 
   platform::GpuLaunchConfig config =
-      platform::getGpuLaunchConfig(pixelNum, ctx);
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
   if ("trilinear" == interp_method) {
-    KeTrilinearInterpFw<T><<<config.blocks, config.threads, 0,
+    KeTrilinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
                              ctx.cuda_device_context().stream()>>>(
         input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
         out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
         align_mode, data_layout);
@@ -1259,10 +1260,10 @@ static void Interpolate1DCUDABwd(const framework::ExecutionContext& ctx,
   int pixelNum = n * out_cw;
 
   platform::GpuLaunchConfig config =
-      platform::getGpuLaunchConfig(pixelNum, ctx);
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
   if ("linear" == interp_method) {
-    KeLinearInterpBw<T><<<config.blocks, config.threads, 0,
+    KeLinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
                           ctx.cuda_device_context().stream()>>>(
         input_grad_data, in_w, in_cw, output_grad_data, out_w, n, out_cw, c,
         ratio_w, align_corners, align_mode, data_layout);
@@ -1376,22 +1377,23 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
   int pixelNum = n * out_chw;
 
   platform::GpuLaunchConfig config =
-      platform::getGpuLaunchConfig(pixelNum, ctx);
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
   if ("nearest" == interp_method) {
-    KeNearestNeighborInterpBw<T><<<config.blocks, config.threads, 0,
-                                   ctx.cuda_device_context().stream()>>>(
+    KeNearestNeighborInterpBw<
+        T><<<config.block_per_grid, config.thread_per_block, 0,
+             ctx.cuda_device_context().stream()>>>(
         input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
         out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
   } else if ("bilinear" == interp_method) {
-    KeBilinearInterpBw<T><<<config.blocks, config.threads, 0,
+    KeBilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
                            ctx.cuda_device_context().stream()>>>(
         input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
         out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
         data_layout);
   } else if ("bicubic" == interp_method) {
-    KeBicubicInterpBw<
-        T><<<config.blocks, config.threads, 0, ctx.cuda_device_context().stream()>>>(
+    KeBicubicInterpBw<T><<<config.blocks, config.threads, 0,
+                           ctx.cuda_device_context().stream()>>>(
         input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
         out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
   }
@@ -1520,10 +1522,10 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
   int pixelNum = n * out_cdhw;
 
   platform::GpuLaunchConfig config =
-      platform::getGpuLaunchConfig(pixelNum, ctx);
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
   if ("trilinear" == interp_method) {
-    KeTrilinearInterpBw<T><<<config.blocks, config.threads, 0,
+    KeTrilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
                              ctx.cuda_device_context().stream()>>>(
         input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data,
         out_d, out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w,
         align_corners, align_mode, data_layout);
diff --git a/paddle/fluid/operators/math/segment_pooling.cu b/paddle/fluid/operators/math/segment_pooling.cu
index 37155fa184e23..0b615cefac4ee 100644
--- a/paddle/fluid/operators/math/segment_pooling.cu
+++ b/paddle/fluid/operators/math/segment_pooling.cu
@@ -17,7 +17,7 @@ limitations under the License.
 */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/segment_pooling.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
-#include "paddle/fluid/platform/gpu_launch_param_config.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/mish_op.cu b/paddle/fluid/operators/mish_op.cu
index 77817e526e13d..6513e5d95e4ac 100644
--- a/paddle/fluid/operators/mish_op.cu
+++ b/paddle/fluid/operators/mish_op.cu
@@ -87,8 +87,9 @@ class MishCUDAKernel : public framework::OpKernel<T> {
 
     const int numel = x->numel();
 
-    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
-    KeMishFw<T><<<config.blocks, config.threads, 0,
-                  ctx.cuda_device_context().stream()>>>(x_data, out_data,
-                                                        numel, threshold);
+    auto& dev_ctx = ctx.cuda_device_context();
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(dev_ctx, numel);
+    KeMishFw<T><<<config.block_per_grid, config.thread_per_block, 0,
+                  dev_ctx.stream()>>>(x_data, out_data, numel, threshold);
   }
@@ -108,8 +109,9 @@ class MishFP32CUDAKernel : public framework::OpKernel<float> {
 
     const int numel = x->numel();
 
-    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
-    KeMishFwFP32<<<config.blocks, config.threads, 0,
-                   ctx.cuda_device_context().stream()>>>(x_data, out_data,
-                                                         numel, threshold);
+    auto& dev_ctx = ctx.cuda_device_context();
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(dev_ctx, numel);
+    KeMishFwFP32<<<config.block_per_grid, config.thread_per_block, 0,
+                   dev_ctx.stream()>>>(x_data, out_data, numel, threshold);
   }
@@ -131,8 +133,9 @@ class MishGradCUDAKernel : public framework::OpKernel<T> {
 
     const int numel = x->numel();
 
-    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
-    KeMishBw<T><<<config.blocks, config.threads, 0,
-                  ctx.cuda_device_context().stream()>>>(
+    auto& dev_ctx = ctx.cuda_device_context();
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(dev_ctx, numel);
+    KeMishBw<T><<<config.block_per_grid, config.thread_per_block, 0,
+                  dev_ctx.stream()>>>(
         x_data, dout_data, dx_data, numel, threshold);
   }
@@ -154,8 +157,9 @@ class MishGradFP32CUDAKernel : public framework::OpKernel<float> {
 
     const int numel = x->numel();
 
-    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
-    KeMishBwFP32<<<config.blocks, config.threads, 0,
-                   ctx.cuda_device_context().stream()>>>(
+    auto& dev_ctx = ctx.cuda_device_context();
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(dev_ctx, numel);
+    KeMishBwFP32<<<config.block_per_grid, config.thread_per_block, 0,
+                   dev_ctx.stream()>>>(
         x_data, dout_data, dx_data, numel, threshold);
   }
diff --git a/paddle/fluid/operators/mv_op.cu b/paddle/fluid/operators/mv_op.cu
index b6d829392e3ba..ee638ede22b64 100644
--- a/paddle/fluid/operators/mv_op.cu
+++ b/paddle/fluid/operators/mv_op.cu
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/mv_op.h"
-#include "paddle/fluid/platform/gpu_launch_param_config.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/segment_pool_op.cu b/paddle/fluid/operators/segment_pool_op.cu
index dc92d7fcc3a87..379a07a26dd5c 100644
--- a/paddle/fluid/operators/segment_pool_op.cu
+++ b/paddle/fluid/operators/segment_pool_op.cu
@@ -15,7 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/operators/segment_pool_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
-#include "paddle/fluid/platform/gpu_launch_param_config.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
diff --git a/paddle/fluid/operators/stack_op.cu b/paddle/fluid/operators/stack_op.cu
index b3784d3c032d9..4800f5f9eb533 100644
--- a/paddle/fluid/operators/stack_op.cu
+++ b/paddle/fluid/operators/stack_op.cu
@@ -16,7 +16,7 @@
 #include
 #include
 #include "paddle/fluid/operators/stack_op.h"
-#include "paddle/fluid/platform/gpu_launch_param_config.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace plat = paddle::platform;
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/transpose_op.cu b/paddle/fluid/operators/transpose_op.cu
index 79dd29ebc691c..0679668cf1b5a 100644
--- a/paddle/fluid/operators/transpose_op.cu
+++ b/paddle/fluid/operators/transpose_op.cu
@@ -20,7 +20,7 @@ limitations under the License.
 */
 #include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/float16.h"
-#include "paddle/fluid/platform/gpu_launch_param_config.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/where_op.cu b/paddle/fluid/operators/where_op.cu
index 39983e7de03b5..721c6e5390e85 100644
--- a/paddle/fluid/operators/where_op.cu
+++ b/paddle/fluid/operators/where_op.cu
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/where_op.h"
-#include "paddle/fluid/platform/gpu_launch_param_config.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace platform = paddle::platform;
 
diff --git a/paddle/fluid/platform/gpu_launch_config.h b/paddle/fluid/platform/gpu_launch_config.h
old mode 100644
new mode 100755
index fd6e80527caf6..4ed20743bb0e9
--- a/paddle/fluid/platform/gpu_launch_config.h
+++ b/paddle/fluid/platform/gpu_launch_config.h
@@ -1,49 +1,103 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Used to compute GPU launch parameters
 
 #pragma once
 
-#include
+#ifdef PADDLE_WITH_CUDA
 
-#include "paddle/fluid/platform/cuda_primitives.h"
+#include
+#include
+#include
+#include
+#include
 
 namespace paddle {
 namespace platform {
 
-struct GpuLaunchConfig {
-  // Number of threads per block.
-  int threads;
-  // Number of blocks for GPU kernel launch.
-  int blocks;
+inline int DivUp(int a, int b) { return (a + b - 1) / b; }
 
-  GpuLaunchConfig(int threads, int blocks) : threads(threads), blocks(blocks) {}
+struct GpuLaunchConfig {
+  dim3 theory_thread_count = dim3(0, 0, 0);
+  dim3 thread_per_block = dim3(0, 0, 0);
+  dim3 block_per_grid = dim3(0, 0, 0);
 };
 
-inline GpuLaunchConfig getGpuLaunchConfig(
-    const int N, const framework::ExecutionContext& ctx,
-    int max_threads = 1024) {
-  int threads =
-      std::min(max_threads, ctx.cuda_device_context().GetMaxThreadsPerBlock());
-  int physical_thread_count =
-      std::min(ctx.cuda_device_context().GetMaxPhysicalThreadCount(), N);
-  int blocks = std::min((physical_thread_count + threads - 1) / threads,
-                        ctx.cuda_device_context().GetSMCount());
+inline GpuLaunchConfig GetGpuLaunchConfig1D(
+    const platform::CUDADeviceContext& context, int element_count) {
+  PADDLE_ENFORCE_GT(element_count, 0, platform::errors::InvalidArgument(
+                                          "element count should be greater "
+                                          "than 0, but received value is %d.",
+                                          element_count));
+
+  const int theory_thread_count = element_count;
+  // Get the max threads of all SMs.
+  int max_physical_threads = context.GetMaxPhysicalThreadCount();
+  int sm = context.GetSMCount();
+
+  // Compute the physical threads we need; it must not exceed the max
+  // threads of all SMs.
+  const int physical_thread_count =
+      std::min(max_physical_threads, theory_thread_count);
+
+  // Must be queried from the device.
+  const int thread_per_block = std::min(1024, context.GetMaxThreadsPerBlock());
+  // Cap the block count at factor * sm; factor is an empirical value.
+  int factor = 4;
+  const int block_count =
+      std::min(DivUp(physical_thread_count, thread_per_block), factor * sm);
 
-  GpuLaunchConfig config(threads, blocks);
+  GpuLaunchConfig config;
+  config.theory_thread_count.x = theory_thread_count;
+  config.thread_per_block.x = thread_per_block;
+  config.block_per_grid.x = block_count;
+  return config;
+}
+
+inline GpuLaunchConfig GetGpuLaunchConfig2D(
+    const platform::CUDADeviceContext& context, int xdim, int ydim) {
+  PADDLE_ENFORCE_GT(xdim, 0, platform::errors::InvalidArgument(
+                                 "x dim number should be greater than 0,"
+                                 " but received value is: %d", xdim));
+  PADDLE_ENFORCE_GT(ydim, 0, platform::errors::InvalidArgument(
+                                 "y dim number should be greater than 0,"
+                                 " but received value is: %d", ydim));
+
+  const int kThreadsPerBlock = 256;
+  int block_cols = std::min(xdim, kThreadsPerBlock);
+  int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+
+  int max_physical_threads = context.GetMaxPhysicalThreadCount();
+  const int max_blocks = std::max(max_physical_threads / kThreadsPerBlock, 1);
+
+  GpuLaunchConfig config;
+  // Note: the block size is not aligned to 32; align it yourself if needed.
+  config.theory_thread_count = dim3(xdim, ydim, 1);
+  config.thread_per_block = dim3(block_cols, block_rows, 1);
+
+  int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);
+  int grid_y = std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1));
+
+  config.block_per_grid = dim3(grid_x, grid_y, 1);
   return config;
 }
 
+// 3D will add later
+
 }  // namespace platform
 }  // namespace paddle
+
+#endif
diff --git a/paddle/fluid/platform/gpu_launch_param_config.h b/paddle/fluid/platform/gpu_launch_param_config.h
deleted file mode 100755
index 40f4ef975e76c..0000000000000
--- a/paddle/fluid/platform/gpu_launch_param_config.h
+++ /dev/null
@@ -1,103 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Used for compute gpu launch parameter - -#pragma once - -#ifdef PADDLE_WITH_CUDA - -#include -#include -#include -#include -#include - -namespace paddle { -namespace platform { - -inline int DivUp(int a, int b) { return (a + b - 1) / b; } - -struct GpuLaunchParamConfig { - dim3 theory_thread_count = dim3(0, 0, 0); - dim3 thread_per_block = dim3(0, 0, 0); - dim3 block_per_grid = dim3(0, 0, 0); -}; - -inline GpuLaunchParamConfig GetGpuLaunchConfig1D( - const platform::CUDADeviceContext& context, int element_count) { - PADDLE_ENFORCE_GT(element_count, 0, platform::errors::InvalidArgument( - "element count should greater than 0," - " but received value is %d.", - element_count)); - - const int theory_thread_count = element_count; - // Get Max threads in all SM - int max_pyhsical_threads = context.GetMaxPhysicalThreadCount(); - int sm = context.GetSMCount(); - - // Compute pyhsical threads we need, should small than max sm threads - const int physical_thread_count = - std::min(max_pyhsical_threads, theory_thread_count); - - // Need get from device - const int thread_per_block = std::min(1024, context.GetMaxThreadsPerBlock()); - // Suppose block count small than factor * sm, factor is a experiments value. - int factor = 4; - const int block_count = - std::min(DivUp(physical_thread_count, thread_per_block), factor * sm); - - GpuLaunchParamConfig config; - config.theory_thread_count.x = theory_thread_count; - config.thread_per_block.x = thread_per_block; - config.block_per_grid.x = block_count; - return config; -} - -inline GpuLaunchParamConfig GetGpuLaunchConfig2D( - const platform::CUDADeviceContext& context, int xdim, int ydim) { - PADDLE_ENFORCE_GT(xdim, 0, platform::errors::InvalidArgument( - "x dim number should greater than 0," - " but received value is:%d", - xdim)); - PADDLE_ENFORCE_GT(ydim, 0, platform::errors::InvalidArgument( - "y dim number should greater than 0," - " but received value is:%d", - ydim)); - - const int kThreadsPerBlock = 256; - int block_cols = std::min(xdim, kThreadsPerBlock); - int block_rows = std::max(kThreadsPerBlock / block_cols, 1); - - int max_physical_threads = context.GetMaxPhysicalThreadCount(); - const int max_blocks = std::max(max_physical_threads / kThreadsPerBlock, 1); - - GpuLaunchParamConfig config; - // Noticed, block size is not align to 32, if needed do it yourself. 
-  config.theory_thread_count = dim3(xdim, ydim, 1);
-  config.thread_per_block = dim3(block_cols, block_rows, 1);
-
-  int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);
-  int grid_y = std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1));
-
-  config.block_per_grid = dim3(grid_x, grid_y, 1);
-  return config;
-}
-
-// 3D will add later
-
-}  // namespace platform
-}  // namespace paddle
-
-#endif
diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
index 287e85cb271f8..d152836fd61ac 100755
--- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
+++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
@@ -531,4 +531,6 @@ def test_case(self):
 
 
 if __name__ == "__main__":
+    import paddle
+    paddle.enable_static()
     unittest.main()

From 2c5144fcb0c891c0cdc32fb66a9f56fa9accd605 Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Mon, 19 Oct 2020 09:48:52 +0000
Subject: [PATCH 2/5] refine the config

---
 paddle/fluid/platform/gpu_launch_config.h                   | 6 +++---
 .../paddle/fluid/tests/unittests/test_bilinear_interp_op.py | 2 --
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/platform/gpu_launch_config.h b/paddle/fluid/platform/gpu_launch_config.h
index 4ed20743bb0e9..f4dbe9b718f7d 100755
--- a/paddle/fluid/platform/gpu_launch_config.h
+++ b/paddle/fluid/platform/gpu_launch_config.h
@@ -30,9 +30,9 @@ namespace platform {
 inline int DivUp(int a, int b) { return (a + b - 1) / b; }
 
 struct GpuLaunchConfig {
-  dim3 theory_thread_count = dim3(0, 0, 0);
-  dim3 thread_per_block = dim3(0, 0, 0);
-  dim3 block_per_grid = dim3(0, 0, 0);
+  dim3 theory_thread_count = dim3(1, 1, 1);
+  dim3 thread_per_block = dim3(1, 1, 1);
+  dim3 block_per_grid = dim3(1, 1, 1);
 };
 
 inline GpuLaunchConfig GetGpuLaunchConfig1D(
diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
index d152836fd61ac..287e85cb271f8 100755
--- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
+++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
@@ -531,6 +531,4 @@ def test_case(self):
 
 
 if __name__ == "__main__":
-    import paddle
-    paddle.enable_static()
     unittest.main()

From ee39c33df9285c7a1248a19bbdeb267f3e29b394 Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Mon, 19 Oct 2020 12:47:25 +0000
Subject: [PATCH 3/5] refine the code

---
 paddle/fluid/operators/interpolate_op.cu    | 4 ++--
 paddle/fluid/operators/interpolate_v2_op.cu | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu
index 9be385683652a..6be7dbdc110d5 100644
--- a/paddle/fluid/operators/interpolate_op.cu
+++ b/paddle/fluid/operators/interpolate_op.cu
@@ -995,7 +995,7 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
         input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
         out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
   } else if ("bicubic" == interp_method) {
-    KeBicubicInterpFw<T><<<config.blocks, config.threads, 0,
+    KeBicubicInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
                            ctx.cuda_device_context().stream()>>>(
         input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
         out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
@@ -1283,7 +1283,7 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
         out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
         data_layout);
   } else if ("bicubic" == interp_method) {
-    KeBicubicInterpBw<T><<<config.blocks, config.threads, 0,
+    KeBicubicInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
                            ctx.cuda_device_context().stream()>>>(
         input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
         out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
diff --git a/paddle/fluid/operators/interpolate_v2_op.cu b/paddle/fluid/operators/interpolate_v2_op.cu
index 88de1dbb3c1dd..90abcaa8b472a 100644
--- a/paddle/fluid/operators/interpolate_v2_op.cu
+++ b/paddle/fluid/operators/interpolate_v2_op.cu
@@ -1032,7 +1032,7 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
         input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
         out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
   } else if ("bicubic" == interp_method) {
-    KeBicubicInterpFw<T><<<config.blocks, config.threads, 0,
+    KeBicubicInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
                            ctx.cuda_device_context().stream()>>>(
         input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
         out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
@@ -1392,7 +1392,7 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
         out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
         data_layout);
   } else if ("bicubic" == interp_method) {
-    KeBicubicInterpBw<T><<<config.blocks, config.threads, 0,
+    KeBicubicInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
                            ctx.cuda_device_context().stream()>>>(
         input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
         out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);

From 710a560e00a8d8e829733390ee08bc957ea1862d Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Mon, 19 Oct 2020 12:57:24 +0000
Subject: [PATCH 4/5] refine the code

---
 paddle/fluid/platform/gpu_launch_config.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/platform/gpu_launch_config.h b/paddle/fluid/platform/gpu_launch_config.h
index f4dbe9b718f7d..83511f242958f 100755
--- a/paddle/fluid/platform/gpu_launch_config.h
+++ b/paddle/fluid/platform/gpu_launch_config.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.

From 80a9762c9f26f46193690bb07ede65930f1f8aef Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Mon, 19 Oct 2020 13:00:06 +0000
Subject: [PATCH 5/5] refine the code

---
 paddle/fluid/platform/gpu_launch_config.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/platform/gpu_launch_config.h b/paddle/fluid/platform/gpu_launch_config.h
index 83511f242958f..d2d57995b728e 100755
--- a/paddle/fluid/platform/gpu_launch_config.h
+++ b/paddle/fluid/platform/gpu_launch_config.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -95,7 +95,7 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(
   return config;
 }
 
-// 3D will add later
+// TODO(wangchaochaohu): 3D will add later
 
 }  // namespace platform
 }  // namespace paddle
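
For reference, a minimal usage sketch of the 1D launch config this series introduces (the kernel, variable, and data names below are hypothetical, not part of the patches). Because GetGpuLaunchConfig1D caps block_per_grid at factor * SM count while theory_thread_count equals the element count, a kernel can be launched with fewer threads than elements, so the kernel body should iterate with a grid-stride loop:

    #include "paddle/fluid/platform/gpu_launch_config.h"

    namespace platform = paddle::platform;

    // Hypothetical element-wise kernel: the grid-stride loop covers the case
    // where block_per_grid.x * thread_per_block.x < n.
    template <typename T>
    __global__ void ElementwiseCopy(const T* in, T* out, int n) {
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += blockDim.x * gridDim.x) {
        out[i] = in[i];
      }
    }

    // Sketch of a launch site inside an OpKernel<T>::Compute, mirroring the
    // pattern applied to bce_loss_op.cu and mish_op.cu above (assumed names:
    // ctx, in_data, out_data, numel):
    //
    //   auto& dev_ctx = ctx.cuda_device_context();
    //   platform::GpuLaunchConfig config =
    //       platform::GetGpuLaunchConfig1D(dev_ctx, numel);
    //   ElementwiseCopy<T><<<config.block_per_grid, config.thread_per_block,
    //                        0, dev_ctx.stream()>>>(in_data, out_data, numel);

GetGpuLaunchConfig2D fills all three dim3 fields, so the same <<<block_per_grid, thread_per_block, ...>>> launch syntax works unchanged for 2D kernels.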