From 8123bf39616ecbedbd87b928170b33285ec10c2f Mon Sep 17 00:00:00 2001
From: itchencheng
Date: Sat, 23 Dec 2023 22:18:03 +0800
Subject: [PATCH] [OpenCL] Fix bugs in buffer-version softmax and add unit test

---
 .../opencl/cl_kernel/buffer/softmax_kernel.cl | 210 ++++++++++++++++-
 lite/kernels/opencl/softmax_buffer_compute.cc | 137 +++++++----
 .../opencl/softmax_buffer_compute_test.cc     | 216 ++++++++++++++++++
 3 files changed, 508 insertions(+), 55 deletions(-)
 create mode 100644 lite/kernels/opencl/softmax_buffer_compute_test.cc

diff --git a/lite/backends/opencl/cl_kernel/buffer/softmax_kernel.cl b/lite/backends/opencl/cl_kernel/buffer/softmax_kernel.cl
index 717caf69ba9..abfaac7cd57 100644
--- a/lite/backends/opencl/cl_kernel/buffer/softmax_kernel.cl
+++ b/lite/backends/opencl/cl_kernel/buffer/softmax_kernel.cl
@@ -18,7 +18,8 @@ __kernel void softmax_width_buffer(__global const CL_DTYPE* input,
                                    __private const int N,
                                    __private const int C,
                                    __private const int H,
-                                   __private const int W) {
+                                   __private const int W,
+                                   __private const float4 mask) {
   int c = get_global_id(0);
   int bh = get_global_id(1);
 
@@ -30,26 +31,31 @@ __kernel void softmax_width_buffer(__global const CL_DTYPE* input,
     /*Compute Max */
     CL_DTYPE4 max_value_v4 = vload4(0, input + offset);
     for (int i = 1; i < W; i += 4) {
-      max_value_v4 = fmax(max_value_v4, vload4(0, input + offset + i));
+      int tmpi = (i + 4 > W) ? W - 4 : i;
+      max_value_v4 = fmax(max_value_v4, vload4(0, input + offset + tmpi));
     }
     CL_DTYPE max_value = max(max(max_value_v4.s0, max_value_v4.s1),
                              max(max_value_v4.s2, max_value_v4.s3));
 
     /*Compute Exp Sum*/
     float4 sum_value_v4 = (float4)0;
     for (int i = 0; i < W; i += 4) {
-      sum_value_v4 += exp(convert_float4(vload4(0, input + offset + i)) -
-                          (float4)max_value);
+      int tmpi = (i + 4 > W) ? W - 4 : i;
+      float4 maski = (i + 4 > W) ? mask : (float4)(1.0f);
+      sum_value_v4 += exp(convert_float4(vload4(0, input + offset + tmpi)) -
+                          (float4)max_value) *
+                      maski;
     }
     float sum_value = sum_value_v4.s0 + sum_value_v4.s1 + sum_value_v4.s2 +
                       sum_value_v4.s3;
 
     /*Compute Result */
     for (int i = 0; i < W; i += 4) {
+      int tmpi = (i + 4 > W) ? W - 4 : i;
       CL_DTYPE4 value =
-          CONVERT_TYPE_TO(convert_float4(exp(vload4(0, input + offset + i) -
+          CONVERT_TYPE_TO(convert_float4(exp(vload4(0, input + offset + tmpi) -
                                              (CL_DTYPE4)max_value)) /
                               (float4)sum_value,
                           CL_DTYPE4);
-      vstore4(value, 0, output + offset + i);
+      vstore4(value, 0, output + offset + tmpi);
     }
   }
 }
@@ -64,8 +70,8 @@ __kernel void softmax_height_buffer(__global const CL_DTYPE* input,
   int b = get_global_id(1);
   const int w_4 = (W + 3) / 4;
   const int c = wc / w_4;  // w4
-  const int w = (wc % w_4) << 2;
-  // const int offset = ((b * C + c) * H + 0) * W + w;
+  int w = (wc % w_4) << 2;
+  w = (w + 4 > W) ? W - 4 : w;
   const int offset = (b * C + c) * H * W + w + 0 * W;
   if (wc < C * W && b < N) {
     /*Compute Max */
@@ -86,3 +92,191 @@ __kernel void softmax_height_buffer(__global const CL_DTYPE* input,
     }
   }
 }
+
+__kernel void softmax_channel_buffer(__global const CL_DTYPE* input,
+                                     __global CL_DTYPE* output,
+                                     __private const int N,
+                                     __private const int C,
+                                     __private const int H,
+                                     __private const int W) {
+  int hw = get_global_id(0);
+  int b = get_global_id(1);
+  const int w_4 = (W + 3) / 4;
+  const int h = hw / w_4;  // w4
+  int w = (hw % w_4) << 2;
+  w = (w + 4 > W) ? W - 4 : w;
+  const int offset = b * C * H * W + h * W + w;
+  const int ch_dim = H * W;
+  if (hw < H * W && b < N) {
+    /*Compute Max */
+    CL_DTYPE4 max_value = vload4(0, input + offset);
+    for (int i = 1; i < C; ++i) {
+      max_value = max(max_value, vload4(0, input + offset + i * ch_dim));
+    }
+    /*Compute Exp Sum*/
+    CL_DTYPE4 sum_value = (CL_DTYPE4)(0.0f);
+    for (int i = 0; i < C; ++i) {
+      sum_value += exp(vload4(0, input + offset + i * ch_dim) - max_value);
+    }
+    /*Compute Result */
+    for (int i = 0; i < C; ++i) {
+      CL_DTYPE4 value =
+          exp(vload4(0, input + offset + i * ch_dim) - max_value) / sum_value;
+      vstore4(value, 0, output + offset + i * ch_dim);
+    }
+  }
+}
+
+__kernel void softmax_batch_buffer(__global const CL_DTYPE* input,
+                                   __global CL_DTYPE* output,
+                                   __private const int N,
+                                   __private const int C,
+                                   __private const int H,
+                                   __private const int W) {
+  int hw = get_global_id(0);
+  int c = get_global_id(1);
+  const int w_4 = (W + 3) / 4;
+  const int h = hw / w_4;  // w4
+  int w = (hw % w_4) << 2;
+  w = (w + 4 > W) ? W - 4 : w;
+  const int offset = c * H * W + h * W + w;
+  const int batch_dim = C * H * W;
+
+  if (hw < H * W && c < C) {
+    /*Compute Max */
+    CL_DTYPE4 max_value = vload4(0, input + offset);
+    for (int i = 1; i < N; ++i) {
+      max_value = max(max_value, vload4(0, input + offset + i * batch_dim));
+    }
+    /*Compute Exp Sum*/
+    CL_DTYPE4 sum_value = (CL_DTYPE4)(0.0f);
+    for (int i = 0; i < N; ++i) {
+      sum_value += exp(vload4(0, input + offset + i * batch_dim) - max_value);
+    }
+    /*Compute Result */
+    for (int i = 0; i < N; ++i) {
+      CL_DTYPE4 value =
+          exp(vload4(0, input + offset + i * batch_dim) - max_value) /
+          sum_value;
+      vstore4(value, 0, output + offset + i * batch_dim);
+    }
+  }
+}
+
+__kernel void softmax_1x1_buffer(__global const CL_DTYPE* input,
+                                 __global CL_DTYPE* output,
+                                 __private const int c_count,
+                                 __private const int c_blks) {
+  const int c_blk_idx = get_global_id(0);
+  const int b_idx = get_global_id(1);
+  const int tid = get_local_id(0);
+
+  int offset = b_idx * c_count;
+
+  __local float4 tmp[8];
+  __local float* tmpx1 = (__local float*)tmp;
+
+  // Compute Max
+  CL_DTYPE4 maxs = vload4(0, input + offset);
+  for (int s = tid; s < c_blks; s += 32) {
+    int tmpi = (s << 2);
+    tmpi = (tmpi + 4 > c_count) ? c_count - 4 : tmpi;
+    maxs = max(maxs, vload4(0, input + offset + tmpi));
+  }
+  maxs.x = max(maxs.x, maxs.y);
+  maxs.z = max(maxs.z, maxs.w);
+  maxs.x = max(maxs.x, maxs.z);
+  tmpx1[tid] = (float)maxs.x;
+
+  barrier(CLK_LOCAL_MEM_FENCE);
+
+  float maximum;
+  float4 maxx4;
+  if (tid == 0) {
+    maxx4 = max(tmp[0], tmp[1]);
+    maxx4 = max(maxx4, tmp[2]);
+    maxx4 = max(maxx4, tmp[3]);
+    maxx4 = max(maxx4, tmp[4]);
+    maxx4 = max(maxx4, tmp[5]);
+    maxx4 = max(maxx4, tmp[6]);
+    maxx4 = max(maxx4, tmp[7]);
+    maximum = max(maxx4.x, maxx4.y);
+    maximum = max(maximum, maxx4.z);
+    maximum = max(maximum, maxx4.w);
+    tmpx1[0] = maximum;
+  }
+  barrier(CLK_LOCAL_MEM_FENCE);
+  maximum = tmpx1[0];
+
+  // Compute Exp Sum
+  float sum = 0.f;
+  for (int s = tid; s < c_blks; s += 32) {
+    for (int i = 0; i < 4; i++) {
+      int tmpi = (s << 2) + i;
+      sum +=
+          (tmpi < c_count) ? exp((float)input[offset + tmpi] - maximum) : 0.f;
+    }
+  }
+  barrier(CLK_LOCAL_MEM_FENCE);
+  tmpx1[tid] = sum;
+
+  barrier(CLK_LOCAL_MEM_FENCE);
+  if (tid == 0) {
+    sum = dot((float4)(1.0f), tmp[0]);
+    sum += dot((float4)(1.0f), tmp[1]);
+    sum += dot((float4)(1.0f), tmp[2]);
+    sum += dot((float4)(1.0f), tmp[3]);
+    sum += dot((float4)(1.0f), tmp[4]);
+    sum += dot((float4)(1.0f), tmp[5]);
+    sum += dot((float4)(1.0f), tmp[6]);
+    sum += dot((float4)(1.0f), tmp[7]);
+    tmpx1[0] = 1.0f / sum;
+  }
+  barrier(CLK_LOCAL_MEM_FENCE);
+  sum = tmpx1[0];
+
+  // Compute Result
+  if (c_blk_idx < c_blks) {
+    int c_offset = (c_blk_idx << 2);
+    c_offset = c_offset + 4 > c_count ? c_count - 4 : c_offset;
+    CL_DTYPE4 src = vload4(0, input + offset + c_offset) - (CL_DTYPE4)maximum;
+#ifdef CL_DTYPE_half
+    CL_DTYPE4 res = convert_half4(exp(convert_float4(src)) * sum);
+#else
+    CL_DTYPE4 res = exp(src) * sum;
+#endif
+    vstore4(res, 0, output + offset + c_offset);
+  }
+}
+
+__kernel void softmax_common_buffer(__global const CL_DTYPE* input,
+                                    __global CL_DTYPE* output,
+                                    __private const int pre_dim,
+                                    __private const int select_range,
+                                    __private const int select_dim) {
+  int prefix = get_global_id(0);
+  int suffix = get_global_id(1);
+
+  int offset = prefix * pre_dim + suffix;
+
+  /*Compute Max */
+  CL_DTYPE max_value = input[offset];
+  for (int i = 1; i < select_range; i++) {
+    max_value = max(max_value, input[offset + i * select_dim]);
+  }
+
+  /*Compute Exp Sum*/
+  float sum_value = 0.0f;
+  for (int i = 0; i < select_range; i++) {
+    sum_value += exp((float)(input[offset + i * select_dim] - max_value));
+  }
+
+  /*Compute Result */
+  for (int i = 0; i < select_range; i++) {
+    CL_DTYPE value = CONVERT_TYPE_TO(
+        exp((float)(input[offset + i * select_dim] - max_value)) /
+            (float)sum_value,
+        CL_DTYPE);
+    output[offset + i * select_dim] = value;
+  }
+}
\ No newline at end of file
diff --git a/lite/kernels/opencl/softmax_buffer_compute.cc b/lite/kernels/opencl/softmax_buffer_compute.cc
index 28fe45c1648..3da55715655 100644
--- a/lite/kernels/opencl/softmax_buffer_compute.cc
+++ b/lite/kernels/opencl/softmax_buffer_compute.cc
@@ -25,9 +25,7 @@ class SoftmaxComputeBuffer
  public:
   using param_t = operators::SoftmaxParam;
 
-  std::string doc() const override {
-    return "Softmax using cl::Image2D, kFP16";
-  }
+  std::string doc() const override { return "Softmax using cl::Buffer, kFP16"; }
 
   void PrepareForRun() override {
     softmax_param_ = param_.get_mutable<param_t>();
@@ -35,24 +33,24 @@ class SoftmaxComputeBuffer
     int axis = softmax_param_->axis;
     VLOG(4) << "x_dims: " << x_dims;
     VLOG(4) << "axis: " << axis;
-    if (x_dims.size() > 1) {
-      axis = axis < 0 ? x_dims.size() + axis : axis;
-      axis_ = 4 - x_dims.size() + axis;
-    } else {  // for dim 1
-      axis_ = 1;  // process width as channel for folder format
-    }
-    VLOG(4) << "axis_: " << axis_;
-    if (x_dims.size() == 2 && axis_ == 3) {
+    auto extend_in_dims = ExtendInputDims(x_dims);
+    axis = axis < 0 ? x_dims.size() + axis : axis;
+    axis_ = 4 - x_dims.size() + axis;
+
+    if (extend_in_dims[3] < 4) {
+      small_w_flag_ = true;
+      kernel_func_name_ = "softmax_common_buffer";
+    } else if ((x_dims.size() == 2 || x_dims.size() == 1) && axis_ == 3) {
       onexone_flag_ = true;
-      kernel_func_name_ = "softmax_1x1";
+      kernel_func_name_ = "softmax_1x1_buffer";
     } else if (axis_ == 3) {
       kernel_func_name_ = "softmax_width_buffer";
     } else if (axis_ == 2) {
       kernel_func_name_ = "softmax_height_buffer";
     } else if (axis_ == 1) {
-      kernel_func_name_ = "softmax_channel";
+      kernel_func_name_ = "softmax_channel_buffer";
     } else if (axis_ == 0) {
-      kernel_func_name_ = "softmax_batch";
+      kernel_func_name_ = "softmax_batch_buffer";
     } else {
       LOG(FATAL) << "do not support this axis value!"
                  << "axis value is: " << axis_;
@@ -95,20 +93,48 @@ class SoftmaxComputeBuffer
     CL_CHECK_FATAL(status);
     status = kernel.setArg(1, *out_buf);
     CL_CHECK_FATAL(status);
-    status = kernel.setArg(2, static_cast<int>(extend_in_dims[0]));
-    CL_CHECK_FATAL(status);
-    status = kernel.setArg(3, static_cast<int>(extend_in_dims[1]));
-    CL_CHECK_FATAL(status);
-    status = kernel.setArg(4, static_cast<int>(extend_in_dims[2]));
-    CL_CHECK_FATAL(status);
-    status = kernel.setArg(5, static_cast<int>(extend_in_dims[3]));
-    CL_CHECK_FATAL(status);
+
+    if (small_w_flag_) {
+      int select_dim = 1;
+      for (int i = extend_in_dims.size() - 1; i >= 0; i--) {
+        if (i > axis_) {
+          select_dim *= extend_in_dims[i];
+        }
+      }
+      int pre_dim = extend_in_dims[axis_] * select_dim;
+      status = kernel.setArg(2, static_cast<int>(pre_dim));
+      CL_CHECK_FATAL(status);
+      status = kernel.setArg(3, static_cast<int>(extend_in_dims[axis_]));
+      CL_CHECK_FATAL(status);
+      status = kernel.setArg(4, static_cast<int>(select_dim));
+      CL_CHECK_FATAL(status);
+    } else if (onexone_flag_) {
+      status = kernel.setArg(2, static_cast<int>(extend_in_dims[3]));
+      CL_CHECK_FATAL(status);
+      status = kernel.setArg(3, UP_DIV(static_cast<int>(extend_in_dims[3]), 4));
+      CL_CHECK_FATAL(status);
+    } else {
+      status = kernel.setArg(2, static_cast<int>(extend_in_dims[0]));
+      CL_CHECK_FATAL(status);
+      status = kernel.setArg(3, static_cast<int>(extend_in_dims[1]));
+      CL_CHECK_FATAL(status);
+      status = kernel.setArg(4, static_cast<int>(extend_in_dims[2]));
+      CL_CHECK_FATAL(status);
+      status = kernel.setArg(5, static_cast<int>(extend_in_dims[3]));
+      CL_CHECK_FATAL(status);
+      if (axis_ == 3) {
+        auto mask_v = GetMask4(extend_in_dims[3]);
+        cl_float4 mask = {mask_v[0], mask_v[1], mask_v[2], mask_v[3]};
+        status = kernel.setArg(6, mask);
+        CL_CHECK_FATAL(status);
+      }
+    }
 
     status = EnqueueNDRangeKernel(context,
                                   kernel,
                                   cl::NullRange,
                                   global_work_size_,
-                                  cl::NullRange,
+                                  local_work_size_,
                                   nullptr,
                                   event_);
     CL_CHECK_FATAL(status);
@@ -126,47 +152,63 @@ class SoftmaxComputeBuffer
   void SetGlobalLocal() {
     auto x_dims = softmax_param_->x->dims();
-    int c = x_dims[1];
-    int w_blk = (x_dims[3] + 3) / 4;
-    // int w = x_dims[3];
-    int bh = x_dims[0] * x_dims[2];
-    if (axis_ == 3) {  // for width
+    auto extend_in_dims = ExtendInputDims(x_dims);
+    int n = extend_in_dims[0];
+    int c = extend_in_dims[1];
+    int h = extend_in_dims[2];
+    int w = extend_in_dims[3];
+    int w_blk = (w + 3) / 4;
+    int bh = n * h;
+
+    if (small_w_flag_) {
+      int suffix_num = 1;
+      int prefix_num = 1;
+      for (int i = 0; i < extend_in_dims.size(); i++) {
+        if (i < axis_) {
+          prefix_num *= extend_in_dims[i];
+        } else if (i > axis_) {
+          suffix_num *= extend_in_dims[i];
+        }
+      }
+      global_work_size_ = cl::NDRange(prefix_num, suffix_num, 1);
+    } else if (onexone_flag_) {
+      local_work_size_ = cl::NDRange(32, 1, 1);
+      global_work_size_ =
+          cl::NDRange(ROUND_UP(UP_DIV(w, 4), local_work_size_[0]), h, 1);
+    } else if (axis_ == 3) {  // for width
       global_work_size_ = cl::NDRange{static_cast<cl::size_type>(c),
                                       static_cast<cl::size_type>(bh),
                                       static_cast<cl::size_type>(1)};
     } else if (axis_ == 2) {  // for height
-      global_work_size_ =
-          cl::NDRange{static_cast<cl::size_type>(c * w_blk),
-                      static_cast<cl::size_type>(last_x_dims_[0]),
-                      static_cast<cl::size_type>(1)};
+      global_work_size_ = cl::NDRange{static_cast<cl::size_type>(c * w_blk),
+                                      static_cast<cl::size_type>(n),
+                                      static_cast<cl::size_type>(1)};
+    } else if (axis_ == 1) {  // for channel
+      global_work_size_ = cl::NDRange{static_cast<cl::size_type>(h * w_blk),
+                                      static_cast<cl::size_type>(n),
+                                      static_cast<cl::size_type>(1)};
+    } else {  // for batch
+      global_work_size_ = cl::NDRange{static_cast<cl::size_type>(h * w_blk),
+                                      static_cast<cl::size_type>(c),
+                                      static_cast<cl::size_type>(1)};
     }
     VLOG(4) << "gws: " << global_work_size_[0] << ", " << global_work_size_[1]
             << ", " << global_work_size_[2];
   }
 
-  const std::vector<float> GetChannelMask(int channels) {
+  const std::vector<float> GetMask4(int total_count) {
     std::vector<float> mask{0.0f, 0.0f, 0.0f, 0.0f};
-    const int reminder = channels % 4 == 0 ? 4 : channels % 4;
+    const int reminder = total_count % 4 == 0 ? 4 : total_count % 4;
     for (int i = 0; i < reminder; ++i) {
-      mask[i] = 1.0f;
+      mask[3 - i] = 1.0f;
     }
     return mask;
   }
 
   const DDim ExtendInputDims(const DDim& in_dims) {
     auto extend_dims = std::vector<int64_t>{1, 1, 1, 1};
-    if (onexone_flag_) {
-      extend_dims[0] = in_dims[0];
-      extend_dims[1] = in_dims[1];
-    } else {
-      for (int i = 0; i < in_dims.size(); i++) {
-        extend_dims[4 - in_dims.size() + i] = in_dims[i];
-      }
-      if (in_dims.size() ==
-          1) {  // transform dim_w to dim_c for dim1 folder case
-        extend_dims[1] = in_dims[0];
-        extend_dims[3] = 1;
-      }
+    for (int i = 0; i < in_dims.size(); i++) {
+      extend_dims[4 - in_dims.size() + i] = in_dims[i];
     }
     return DDim(extend_dims);
   }
@@ -182,6 +224,7 @@ class SoftmaxComputeBuffer
   DDim last_x_dims_;
   int axis_;
   bool onexone_flag_{false};
+  bool small_w_flag_{false};
   DDim out_img_shape_ = DDim(std::vector<DDim::value_type>(
       {static_cast<DDim::value_type>(1), static_cast<DDim::value_type>(1)}));
   cl::NDRange global_work_size_;
diff --git a/lite/kernels/opencl/softmax_buffer_compute_test.cc b/lite/kernels/opencl/softmax_buffer_compute_test.cc
new file mode 100644
index 00000000000..26caede5ede
--- /dev/null
+++ b/lite/kernels/opencl/softmax_buffer_compute_test.cc
@@ -0,0 +1,216 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "lite/backends/opencl/cl_image_converter.h"
+#include "lite/backends/opencl/target_wrapper.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+#include "lite/kernels/opencl/test_helper.h"
+#include "lite/tests/utils/fill_data.h"
+
+#define FP32_ABS_DIFF (1e-6)
+#define FP32_RELATIVE_DIFF (1e-6)
+#define FP16_ABS_DIFF (1e-3)
+#define FP16_RELATIVE_DIFF (1e-3)
+
+namespace paddle {
+namespace lite {
+
+void softmax_baseline(const float *x_data,
+                      float *out_data,
+                      const DDim x_dims,
+                      int axis) {
+  auto x_rank = x_dims.size();
+  if (axis < 0) {
+    axis += x_rank;
+  }
+  int axis_size = x_dims[axis];
+  int outer_num = x_dims.Slice(0, axis).production();
+  int inner_num = x_dims.Slice(axis + 1, x_rank).production();
+  int compute_size = outer_num * inner_num;
+  for (int i = 0; i < compute_size; i++) {
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int start = idx_outer * inner_num + idx_inner;
+    int offset;
+
+    offset = start;
+    float max_data = std::numeric_limits<float>::lowest();
+    for (int j = 0; j < axis_size; j++) {
+      max_data = x_data[offset] > max_data ? x_data[offset] : max_data;
+      offset += inner_num;
+    }
+
+    offset = start;
+    float sum_data = 0.f;
+    for (int j = 0; j < axis_size; j++) {
+      out_data[offset] = exp(x_data[offset] - max_data);
+      sum_data += out_data[offset];
+      offset += inner_num;
+    }
+
+    offset = start;
+    for (int j = 0; j < axis_size; j++) {
+      out_data[offset] /= sum_data;
+      offset += inner_num;
+    }
+  }
+}
+
+void test(const lite_api::CLPrecisionType p,
+          const DDim &x_dim,
+          const int axis) {
+  std::unique_ptr<KernelContext> context(new KernelContext);
+  context->As<OpenCLContext>().InitOnce();
+  CLRuntime::Global()->set_precision(p);
+  const bool fp16_flag = (p == lite_api::CLPrecisionType::CL_PRECISION_FP16);
+  LOG(INFO) << "\n\t[ START ] Test Precision="
+            << lite_api::CLPrecisionTypeToStr(p) << " x_dim=" << x_dim
+            << " axis=" << axis;
+
+  auto kernels = KernelRegistry::Global().Create(
+      "softmax", TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNCHW));
+  ASSERT_FALSE(kernels.empty());
+  auto kernel = std::move(kernels.front());
+
+  lite::Tensor x, out;
+  operators::SoftmaxParam param;
+  param.x = &x;
+  param.output = &out;
+  param.axis = axis;
+
+  kernel->SetParam(param);
+  kernel->SetContext(std::move(context));
+
+  DDim out_dim = x_dim;
+  x.Resize(x_dim);
+  out.Resize(out_dim);
+
+  std::vector<float> x_cpu(x_dim.production());
+  std::vector<half_t> x_cpu_half(x_dim.production());
+  std::vector<float> out_from_cpu(out_dim.production());
+  std::vector<float> out_from_gpu(out_dim.production());
+
+  // fill random input
+  float *ptr_float = x_cpu.data();
+  half_t *ptr_half = x_cpu_half.data();
+  fill_data_rand(ptr_float, -1.f, 1.f, x_dim.production());
+  if (fp16_flag) {
+    for (int i = 0; i < x_dim.production(); i++) {
+      ptr_half[i] = Float2Half(ptr_float[i]);
+    }
+  }
+
+  // x data
+  auto *x_data = x.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
+  size_t elemSize = fp16_flag ? sizeof(half_t) : sizeof(float);
+  const void *src_ptr = fp16_flag ? reinterpret_cast<const void *>(ptr_half)
+                                  : reinterpret_cast<const void *>(ptr_float);
+  TargetWrapperCL::MemcpySync(
+      x_data, src_ptr, x_dim.production() * elemSize, IoDirection::HtoD);
+
+  // run cpu ref
+  softmax_baseline(ptr_float, out_from_cpu.data(), x_dim, axis);
+
+  // run opencl kernel
+  kernel->Launch();
+
+#ifdef LITE_WITH_PROFILE
+  profile::OpCharacter opchar;
+  kernel->SetProfileRuntimeKernelInfo(&opchar);
+  double timeInMS = CLRuntime::Global()->GetCommandTime(opchar.cl_event);
+  LOG(INFO) << "x_dim=" << x_dim << ", kernel=" << opchar.kernel_func_name
+            << ": time cost=" << timeInMS;
+#endif
+
+  CLRuntime::Global()->command_queue().finish();
+
+  // output
+  auto *out_data = fp16_flag ? out.mutable_data<half_t, cl::Buffer>()
+                             : out.mutable_data<float, cl::Buffer>();
+  void *out_gpu = out_from_gpu.data();
+  TargetWrapperCL::MemcpySync(
+      out_gpu, out_data, out_dim.production() * elemSize, IoDirection::DtoH);
+  half_t *out_from_gpu_half = reinterpret_cast<half_t *>(out_gpu);
+
+  VLOG(4) << "output_data vs output_ref_data";
+  auto relative_diff_thres =
+      fp16_flag ? FP16_RELATIVE_DIFF : FP32_RELATIVE_DIFF;
+  auto abs_diff_thres = fp16_flag ? FP16_ABS_DIFF : FP32_ABS_DIFF;
+  uint32_t diff_cnt = 0;
+  for (int i = 0; i < out_dim.production(); i++) {
+    float gpu_value =
+        fp16_flag ? Half2Float(out_from_gpu_half[i]) : out_from_gpu[i];
+    auto relative_diff = COMPUTE_RELATIVE_DIFF(gpu_value, out_from_cpu[i]);
+    auto abs_diff = COMPUTE_ABS_DIFF(gpu_value, out_from_cpu[i]);
+    EXPECT_FALSE(relative_diff > relative_diff_thres &&
+                 abs_diff > abs_diff_thres);
+    if (relative_diff > relative_diff_thres && abs_diff > abs_diff_thres) {
+      LOG(WARNING) << lite_api::CLPrecisionTypeToStr(p) << " err idx: " << i
+                   << " abs_diff: " << abs_diff
+                   << "\t relative_diff: " << relative_diff
+                   << "\t out_ins: " << gpu_value
+                   << "\t out_ref: " << out_from_cpu[i];
+      diff_cnt++;
+    }
+  }
+  if (diff_cnt != 0) {
+    LOG(FATAL) << "\n\t[ FAILED ] "
+               << " Test Precision=" << lite_api::CLPrecisionTypeToStr(p)
+               << " x_dim=" << x_dim << " axis=" << axis
+               << "; diff_cnt= " << diff_cnt << "/" << out_dim.production();
+  } else {
+    LOG(INFO) << "\n\t[ PASSED ] "
+              << " Test Precision=" << lite_api::CLPrecisionTypeToStr(p)
+              << " x_dim=" << x_dim << " axis=" << axis;
+  }
+}
+
+TEST(softmax, compute_basic) {
+  for (const auto precision_type :
+       {lite_api::CLPrecisionType::CL_PRECISION_FP32,
+        lite_api::CLPrecisionType::CL_PRECISION_FP16}) {
+    for (const auto x_dim : std::vector<std::vector<int64_t>>{
+             {31, 11, 53, 831},
+             {11, 53, 831},
+             {53, 1000},
+             {3, 2, 53, 1},
+             {2, 53, 1},
+             {53, 1},
+             {3, 2, 53, 3},
+             {2, 53, 3},
+             {53, 3},
+             {3, 2, 53, 7},
+             {2, 53, 7},
+             {53, 7},
+             {1, 2, 3, 4},
+             {2, 3, 4},
+             {3, 4},
+         }) {
+      int ndims = x_dim.size();
+      for (int axis = -1; axis < ndims; axis++) {
+        test(precision_type, DDim(x_dim), axis);
+      }
+    }
+    // Special case, such as large num
+    const auto x_dims = std::vector<int64_t>{64, 1001};
+    test(precision_type, DDim(x_dims), 1);
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(softmax, kOpenCL, kFP16, kNCHW, def);