Modify some contents for elementwise op impl #32414

Merged: 2 commits, Apr 22, 2021
5 changes: 3 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
@@ -34,7 +33,9 @@ namespace operators {
 */
 template <typename T>
 struct CudaAddFunctor {
-  inline HOSTDEVICE T operator()(T args[]) const { return args[0] + args[1]; }
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return args[0] + args[1];
+  }
 };

 template <typename T>
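Note: the functor is only ever invoked from device code, so the new qualifier drops the host side of inline HOSTDEVICE in favor of __device__ __forceinline__, and the arguments arrive through const T* rather than a mutable array. As a minimal sketch of how such a functor plugs into a generic binary elementwise kernel (the kernel below and its names are illustrative, not code from this PR):

// Illustrative launch pattern for a binary device functor.
template <typename T, typename Functor>
__global__ void BinaryKernelSketch(const T *in0, const T *in1, T *out, int n,
                                   Functor func) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    T args[2] = {in0[i], in1[i]};
    out[i] = func(args);  // CudaAddFunctor<T> would return args[0] + args[1]
  }
}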
32 changes: 20 additions & 12 deletions paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -13,6 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once

+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/float16.h"
+
+#ifdef __HIPCC__
+#define ELEMENTWISE_BLOCK_SIZE 256
+#else
+#define ELEMENTWISE_BLOCK_SIZE 512
+#endif
+
 namespace paddle {
 namespace operators {

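Note: ELEMENTWISE_BLOCK_SIZE replaces the generic PADDLE_CUDA_THREAD_SIZE used below. The smaller value under __HIPCC__ is presumably an occupancy concession to AMD hardware, whose wavefronts are 64 threads wide; NVIDIA builds keep 512. A sketch of how the macro typically feeds a launch (SomeElementwiseKernel and n are illustrative):

int block_size = ELEMENTWISE_BLOCK_SIZE;  // 256 under HIP, 512 under CUDA
int grid_size = (n + block_size - 1) / block_size;  // ceil(n / block_size)
SomeElementwiseKernel<<<grid_size, block_size>>>(/* args */);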
@@ -90,8 +101,7 @@ struct ElementwiseDataWrapper {

 template <ElementwiseType ET, int VecSize, typename T, typename Functor>
 __device__ void VectorizedKernelImpl(
-    ElementwiseDataWrapper<ET, VecSize, T> data, int size, Functor func,
-    int tid) {
+    ElementwiseDataWrapper<ET, VecSize, T> data, Functor func, int tid) {
   using VecType = CudaAlignedVector<T, VecSize>;
   VecType ins_vec[ET];
   VecType out_vec;
@@ -121,10 +131,9 @@ __device__ void VectorizedKernelImpl(
   data.store_vector(out_vec, tid);
 }

-template <ElementwiseType ET, typename T, typename Functor>
-__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, 1, T> data,
-                                 int size, Functor func, int start,
-                                 int remain) {
+template <ElementwiseType ET, int VecSize, typename T, typename Functor>
+__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
+                                 Functor func, int start, int remain) {
   T ins[ET];
   T out;

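Note: both device helpers drop the size parameter, which neither body used; bounds information travels entirely through remain and the callers. ScalarKernelImpl also gains a VecSize template parameter so it can accept the same ElementwiseDataWrapper<ET, VecSize, T> that the caller now builds once for both paths (see the next hunk). For reference, a self-contained sketch of the vectorized load/apply/store pattern that VectorizedKernelImpl implements, assuming an aligned vector type in the spirit of CudaAlignedVector (all names below are illustrative):

// Illustrative sketch of the vectorized elementwise pattern.
template <typename T, int VecSize>
struct AlignedVectorSketch {
  alignas(sizeof(T) * VecSize) T val[VecSize];
};

template <typename T, int VecSize, typename Functor>
__device__ void VectorizedApplySketch(const T *in0, const T *in1, T *out,
                                      int tid, Functor func) {
  using VecType = AlignedVectorSketch<T, VecSize>;
  // One aligned vector load per operand covers VecSize scalar elements.
  VecType a = reinterpret_cast<const VecType *>(in0)[tid];
  VecType b = reinterpret_cast<const VecType *>(in1)[tid];
  VecType c;
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    T args[2] = {a.val[i], b.val[i]};
    c.val[i] = func(args);  // the functor receives a pointer to packed args
  }
  reinterpret_cast<VecType *>(out)[tid] = c;  // one vector store
}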
@@ -146,12 +155,11 @@ __global__ void VectorizedKernel(const T *__restrict__ in0,
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = size - VecSize * tid;
   remain = remain > 0 ? remain : 0;
+  auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
   if (remain >= VecSize) {
-    auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
-    VectorizedKernelImpl(data, size, func, tid);
+    VectorizedKernelImpl(data, func, tid);
   } else {
-    auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
-    ScalarKernelImpl(data, size, func, tid * VecSize, remain);
+    ScalarKernelImpl(data, func, tid * VecSize, remain);
   }
 }

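Note: hoisting the wrapper construction out of the branch is what the new VecSize parameter on ScalarKernelImpl enables; both calls now receive the same data object. The tail arithmetic: with, say, size = 1003 and VecSize = 4, threads 0 through 249 see remain >= 4 and take the vectorized path (elements 0..999), thread 250 sees remain = 1003 - 4 * 250 = 3 and handles elements 1000..1002 through the scalar path, and any thread beyond that clamps remain to 0 and does nothing.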
@@ -162,7 +170,7 @@ __global__ void ScalarKernel(const T *__restrict__ in0,
   auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = tid < size ? 1 : 0;
-  ScalarKernelImpl(data, size, func, tid, remain);
+  ScalarKernelImpl(data, func, tid, remain);
 }

 template <ElementwiseType ET, typename T, typename Functor>
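Note: ScalarKernel is the all-scalar fallback used when no vectorized access path is available, so each in-range thread processes exactly one element (remain is 1 for tid < size, otherwise 0). The body of ScalarKernelImpl is elided from this diff; it presumably loops remain times starting at start, along these lines (load_scalar and store_scalar are hypothetical stand-ins for the wrapper's scalar accessors):

// Assumed shape of the scalar loop; helper names are hypothetical.
for (int i = 0; i < remain; ++i) {
  int idx = start + i;
  data.load_scalar(ins, idx);   // fill ins[0..ET) from the inputs
  out = func(ins);
  data.store_scalar(out, idx);  // write one result element
}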
@@ -173,7 +181,7 @@ void LaunchElementwiseCudaKernel(
   // calculate the max vec_size for all ins and outs
   auto size = ins[0]->numel();
   int vec_size = GetVectorizedSize<T>(ins, *outs);
-  int block_size = PADDLE_CUDA_THREAD_SIZE;
+  int block_size = ELEMENTWISE_BLOCK_SIZE;
   int grid_size =
       ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
   const T *in0 = ins[0]->data<T>();
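Note: the grid size rounds up twice, first dividing the element count by vec_size to get the number of per-thread work items, then dividing that by the block size. A worked instance of the arithmetic (values illustrative):

// size = 1,000,003 elements, vec_size = 4, block_size = 512 (CUDA build)
//   work items: (1000003 + 4 - 1) / 4    = 250001   (integer division)
//   grid_size:  (250001 + 512 - 1) / 512 = 489 blocks
// 489 * 512 = 250368 threads >= 250001 items, so every element is covered;
// threads past the end clamp remain to 0 or take the scalar tail.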