Skip to content

Commit

Permalink
[ROCM] bugfix for unittest (#32392)
Browse files Browse the repository at this point in the history
* fix test_unpool_op

* fix test_inplace_addto_strategy

* fix test_conv2d_fusion_op

* fix test_imperative_lod_tensor_to_selected_rows, test_imperative_selected_rows_to_lod_tensor

* fix test_dot_op

* fix test_correlation_op

* fix tracer

* fix test_memcpy_op
  • Loading branch information
ronny1996 authored May 6, 2021
1 parent efdb0a7 commit 3139262
Show file tree
Hide file tree
Showing 12 changed files with 193 additions and 91 deletions.
1 change: 0 additions & 1 deletion cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ function(op_library TARGET)
list(REMOVE_ITEM miopen_cu_cc_srcs "affine_grid_cudnn_op.cu.cc")
list(REMOVE_ITEM miopen_cu_cc_srcs "grid_sampler_cudnn_op.cu.cc")
list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
list(REMOVE_ITEM hip_srcs "correlation_op.cu")
list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
Expand Down
49 changes: 38 additions & 11 deletions paddle/fluid/operators/conv_cudnn_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -699,24 +699,51 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {

// ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f;
#ifdef PADDLE_WITH_HIP
// MIOpen only supports beta == 0.0f
ScalingParamType<T> beta = 0.0f;
#else
ScalingParamType<T> beta = ctx.Attr<bool>("use_addto") ? 1.0f : 0.0f;
#endif
VLOG(4) << "Conv_grad: use_addto = " << ctx.Attr<bool>("use_addto");

if (input_grad) {
// When beta is 0, it is unnecessary to reset input_grad.
// When beta is 1, the output cannot be reset since the addto strategy is used.
#ifdef PADDLE_WITH_HIP
workspace_handle.RunFunc(
[&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionBackwardData(
handle, &alpha, args1.odesc.desc(), output_grad_data,
args1.wdesc.desc(), filter_data, args1.cdesc.desc(),
data_algo, &beta, args1.idesc.desc(),
transformed_input_grad_data, cudnn_workspace_ptr,
workspace_size));
},
workspace_size);
if (ctx.Attr<bool>("use_addto")) {
Tensor temp_tensor(transformed_input_grad.type());
temp_tensor.Resize(transformed_input_grad.dims());
T* temp_tensor_data = temp_tensor.mutable_data<T>(ctx.GetPlace());
workspace_handle.RunFunc(
[&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionBackwardData(
handle, &alpha, args1.odesc.desc(), output_grad_data,
args1.wdesc.desc(), filter_data, args1.cdesc.desc(),
data_algo, &beta, args1.idesc.desc(), temp_tensor_data,
cudnn_workspace_ptr, workspace_size));
},
workspace_size);
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenOpTensor(
handle, miopenTensorOpAdd, &alpha, args1.idesc.desc(),
transformed_input_grad_data, &alpha, args1.idesc.desc(),
temp_tensor_data, &beta, args1.idesc.desc(),
transformed_input_grad_data));
} else {
workspace_handle.RunFunc(
[&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionBackwardData(
handle, &alpha, args1.odesc.desc(), output_grad_data,
args1.wdesc.desc(), filter_data, args1.cdesc.desc(),
data_algo, &beta, args1.idesc.desc(),
transformed_input_grad_data, cudnn_workspace_ptr,
workspace_size));
},
workspace_size);
}

#else
for (int i = 0; i < groups; i++) {
workspace_handle.RunFunc(
Expand Down
70 changes: 6 additions & 64 deletions paddle/fluid/operators/conv_miopen_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,28 +146,8 @@ struct SearchAlgorithm<miopenConvFwdAlgorithm_t> {
cudnn_workspace_ptr, workspace_size, false));
};

if (!exhaustive_search && !deterministic) {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.fwd_algo;
} else {
auto& temp = ctx.cuda_device_context();
AlgorithmsCache<algo_t>& algo_cache =
*(framework::ConvSearchCache::Instance().GetForward());

auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());

VLOG(10) << "miopenConvolutionFwdAlgoPerf_t:"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< args.s << ", args.p" << args.p << ", args.d" << args.d;

algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
return find_result.fwd_algo;
});
}
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.fwd_algo;
VLOG(3) << "choose algo " << algo;
return algo;
}
Expand Down Expand Up @@ -208,27 +188,8 @@ struct SearchAlgorithm<miopenConvBwdDataAlgorithm_t> {
cudnn_workspace_ptr, workspace_size, false));
};

if (!exhaustive_search && !deterministic) {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.bwd_data_algo;
} else {
AlgorithmsCache<algo_t>& algo_cache =
*(framework::ConvSearchCache::Instance().GetBackwardData());

auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());

VLOG(10) << "miopenConvolutionFwdAlgoPerf_t"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< args.s << ", args.p" << args.p << ", args.d" << args.d;

algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
return find_result.bwd_data_algo;
});
}
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.bwd_data_algo;
VLOG(3) << "choose algo " << algo;
return algo;
}
Expand Down Expand Up @@ -269,27 +230,8 @@ struct SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t> {
cudnn_workspace_ptr, workspace_size, false));
};

if (!exhaustive_search && !deterministic) {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.bwd_weights_algo;
} else {
AlgorithmsCache<algo_t>& algo_cache =
*(framework::ConvSearchCache::Instance().GetBackwardFilter());

auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());

VLOG(10) << "miopenConvolutionFwdAlgoPerf_t:"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< args.s << ", args.p" << args.p << ", args.d" << args.d;

algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
return find_result.bwd_weights_algo;
});
}
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.bwd_weights_algo;
VLOG(3) << "choose algo " << algo;
return algo;
}
Expand Down
21 changes: 16 additions & 5 deletions paddle/fluid/operators/correlation_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,45 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_WITH_HIP
// HIP not supported yet

#include <algorithm>
#include <string>
#include "paddle/fluid/framework/op_registry.h"

// When compiling with hipcc, define __syncwarp() in terms of the warp vote
// __all(1) so that code written against the CUDA spelling still builds.
// NOTE(review): presumably hipcc provides no native __syncwarp() and __all(1)
// is relied on as a wavefront-wide convergence point -- confirm against the
// HIP documentation.
#ifdef __HIPCC__
#define __syncwarp() __all(1)
#endif

namespace paddle {
namespace operators {

// Per-block thread count constant: 64 when compiled with hipcc, 32 otherwise.
// NOTE(review): presumably chosen to match the hardware wavefront/warp width
// on each platform -- confirm against the kernel launch sites.
#ifdef __HIPCC__
#define THREADS_PER_BLOCK 64
#else
#define THREADS_PER_BLOCK 32
#endif
// 32-bit all-ones lane mask; passed to __shfl_down_sync in warpReduceSum.
#define FULL_MASK 0xffffffff

using framework::Tensor;

// Sums `val` across the lanes of the calling warp (wavefront on HIP) with a
// shuffle-down tree; the complete sum ends up in lane 0. Values left in the
// other lanes are partial sums and should not be used.
//
// The first shuffle distance is taken from the built-in `warpSize` instead of
// a hard-coded 16. A fixed 16 only covers a 32-lane warp: on a 64-lane
// wavefront, lanes 32..63 would never be folded into lane 0 and their
// contributions would be silently dropped. With a 32-lane warp
// `warpSize / 2` is still 16, so the CUDA path is unchanged.
//
// T must be a type accepted by the platform's shuffle intrinsic.
template <typename T>
__forceinline__ __device__ T warpReduceSum(T val) {
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
#ifdef __HIPCC__
    // HIP has no *_sync shuffle variant; the whole wavefront participates.
    val += __shfl_down(val, offset);
#else
    val += __shfl_down_sync(FULL_MASK, val, offset);
#endif
  }
  return val;
}

template <typename T>
__forceinline__ __device__ T blockReduceSum(T val) {
#ifdef __HIPCC__
static __shared__ T shared[64];
#else
static __shared__ T shared[32];
#endif
int lane = threadIdx.x % warpSize;
int wid = threadIdx.x / warpSize;

Expand Down Expand Up @@ -483,5 +496,3 @@ REGISTER_OP_CUDA_KERNEL(correlation, ops::CorrelationCUDAKernel<float>,
ops::CorrelationCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(correlation_grad, ops::CorrelationCUDAGradKernel<float>,
ops::CorrelationCUDAGradKernel<double>);
#endif // not PADDLE_WITH_HIP
3 changes: 1 addition & 2 deletions paddle/fluid/operators/fused/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ if (WITH_GPU OR WITH_ROCM)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_batch_norm_act);\n")
endif()
# conv_fusion_op needs cudnn 7 above
# HIP not support cudnnConvolutionBiasActivationForward
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7100))
if (NOT ${CUDNN_VERSION} VERSION_LESS 7100)
op_library(conv_fusion_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
endif()
Expand Down
83 changes: 81 additions & 2 deletions paddle/fluid/operators/fused/conv_fusion_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/math/padding.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

DECLARE_int64(cudnn_exhaustive_search_times);

namespace paddle {
namespace operators {

#if CUDNN_VERSION >= 7100
#if PADDLE_WITH_HIP || CUDNN_VERSION >= 7100
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
Expand Down Expand Up @@ -162,7 +166,78 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
if (input->dims().size() == 5) {
layout = DataLayout::kNCDHW;
}
#ifdef PADDLE_WITH_HIP
miopenConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(padding_common, strides, dilations);
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenSetConvolutionGroupCount(cudnn_conv_desc,
groups));
// Now only support NCHW
std::vector<int> bias_dim = {
1, static_cast<int>(transformed_output.dims()[1]), 1, 1};
miopenTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize<int>(transformed_input.dims()));
miopenTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize<int>(transformed_output.dims()));
miopenTensorDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize<int>(filter->dims()));
miopenTensorDescriptor_t cudnn_bias_desc =
bias_desc.descriptor<T>(layout, bias_dim);
miopenActivationDescriptor_t cudnn_act_desc =
act_desc.descriptor<T>(activation);

miopenConvFwdAlgorithm_t algo;
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();

auto x_dims = framework::vectorize(transformed_input.dims());
auto f_dims = framework::vectorize(filter->dims());

size_t workspace_size = 0;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionForwardGetWorkSpaceSize(
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_output_desc, &workspace_size));
int find_count;
miopenConvAlgoPerf_t find_result;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenFindConvolutionForwardAlgorithm(
handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc, output_data,
kNUM_CUDNN_FWD_ALGS, &find_count, &find_result,
cudnn_workspace_ptr, workspace_size, false));
};
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.fwd_algo;
VLOG(3) << "cuDNN forward algo " << algo;

{
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto cudnn_func = [&](void* cudnn_workspace) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, algo, &beta, cudnn_output_desc,
output_data, cudnn_workspace, workspace_size));
};
workspace_handle.RunFunc(cudnn_func, workspace_size);
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionForwardBias(
handle, &alpha, cudnn_bias_desc, bias_data, &beta,
cudnn_output_desc, output_data));
if (activation != "identity") {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationForward(
handle, cudnn_act_desc, &alpha, cudnn_output_desc, output_data,
&beta, cudnn_output_desc, output_data));
}
if (residual) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenOpTensor(
handle, miopenTensorOpAdd, &alpha, cudnn_output_desc, output_data,
&alpha, cudnn_output_desc, residual_data, &beta, cudnn_output_desc,
output_data));
}
}
#else // PADDLE_WITH_HIP
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(padding_common, strides, dilations);
PADDLE_ENFORCE_CUDA_SUCCESS(
Expand Down Expand Up @@ -327,6 +402,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
#endif
std::vector<int> channels = ctx.Attr<std::vector<int>>("split_channels");
if (channels.size()) {
auto outs = ctx.MultiOutput<framework::Tensor>("Outputs");
Expand Down Expand Up @@ -358,8 +434,11 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle

#if CUDNN_VERSION >= 7100
namespace ops = paddle::operators;
#if CUDNN_VERSION >= 7100
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
ops::CUDNNConvFusionOpKernel<double>);
#endif
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>);
#endif
8 changes: 8 additions & 0 deletions paddle/fluid/operators/math/unpooling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ class Unpool2dMaxFunctor<platform::CUDADeviceContext, T> {
const T* input_data = input.data<T>();
const int* indices_data = indices.data<int>();
T* output_data = output->mutable_data<T>(context.GetPlace());
#ifdef __HIPCC__
int threads = 256;
#else
int threads = 1024;
#endif
int grid = (input.numel() + threads - 1) / threads;
KernelUnpool2dMax<T><<<grid, threads, 0, context.stream()>>>(
input.numel(), input_data, indices_data, input_height, input_width,
Expand Down Expand Up @@ -117,7 +121,11 @@ class Unpool2dMaxGradFunctor<platform::CUDADeviceContext, T> {
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
#ifdef __HIPCC__
int threads = 256;
#else
int threads = 1024;
#endif
int grid = (input.numel() + threads - 1) / threads;
KernelUnpool2dMaxGrad<T><<<grid, threads, 0, context.stream()>>>(
input.numel(), input_data, indices_data, input_height, input_width,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/memcpy_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double,
ops::MemcpyKernel, plat::float16,
ops::MemcpyKernel);

#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM)
REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double,
ops::MemcpyKernel, int, ops::MemcpyKernel,
int64_t, ops::MemcpyKernel, bool,
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/platform/dynload/miopen.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(miopenActivationBackward); \
__macro(miopenConvolutionBackwardWeights); \
__macro(miopenConvolutionForward); \
__macro(miopenConvolutionForwardBias); \
__macro(miopenConvolutionBackwardBias); \
__macro(miopenConvolutionForwardGetWorkSpaceSize); \
__macro(miopenConvolutionBackwardDataGetWorkSpaceSize); \
Expand Down
Loading

0 comments on commit 3139262

Please sign in to comment.