Skip to content

Commit

Permalink
merge ET
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTrainingG committed Sep 3, 2021
1 parent 2af5eab commit a68841c
Showing 1 changed file with 42 additions and 109 deletions.
151 changes: 42 additions & 109 deletions paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,102 +69,51 @@ int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins,
return vec_size;
}

// NOTE(review): this span is a rendered commit diff whose +/- markers were
// stripped, so removed and added lines are interleaved and the text does not
// parse as-is. Lines that use `in0` / `in1` / `is_reminder` appear to belong
// to the removed ElementVectorizedUnary / ElementVectorizedBinary kernels;
// lines that use `in[i]`, `args[ET][VecSize]` and `IsBoundary` appear to
// belong to the added DealSegment helper -- confirm against the real file.
//
// DealSegment (added lines): works on the segment that starts at
// VecSize * blockIdx.x * blockDim.x. For each of the ET inputs it fills a
// VecSize-element buffer with InT(1) via kps::Init and then calls
// kps::ReadData on `in[i] + data_offset` with `num`. It applies `func`
// through kps::ElementwiseUnary / ElementwiseBinary / ElementwiseTernary,
// selected by ET, and hands the result to kps::WriteData at
// `out + data_offset`. IsBoundary is forwarded to ReadData / WriteData as
// their last template argument.
template <int VecSize, typename InT, typename OutT, typename Functor>
__global__ void ElementVectorizedUnary(const InT *__restrict__ in0, OutT *out,
int size, Functor func) {
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor, bool IsBoundary>
__device__ void DealSegment(
const framework::Array<const InT *__restrict__, ET> &in, OutT *out, int num,
Functor func) {
int data_offset = VecSize * blockIdx.x * blockDim.x;
// data offset of this block
int num = size - data_offset;
num = (VecSize * blockDim.x) > num ? num : VecSize * blockDim.x;
// element count for this block: min(size - data_offset, VecSize * blockDim.x)
InT args[VecSize];
InT args[ET][VecSize];
OutT result[VecSize];
// `is_reminder` (sic, "remainder") is passed as the last template argument of
// ReadData / WriteData in the remainder branch below.
const bool is_reminder = true;
if (VecSize * blockDim.x > num) { // remainder segment
kps::Init<InT, VecSize>(&args[0], static_cast<InT>(1.0f));
kps::ReadData<InT, VecSize, 1, 1, is_reminder>(args, in0 + data_offset,
num);
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args,
func);
kps::WriteData<OutT, VecSize, 1, 1, is_reminder>(out + data_offset, result,
num);
} else { // complete segment
kps::ReadData<InT, VecSize, 1, 1>(args, in0 + data_offset, num);
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args,
func);
kps::WriteData<OutT, VecSize, 1, 1>(out + data_offset, result, num);
// load data
#pragma unroll
for (int i = 0; i < ET; i++) {
// Pre-fill with InT(1) before reading; presumably so that slots not covered
// by `num` hold a defined value when IsBoundary is true -- TODO confirm
// against kps::ReadData.
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset,
num);
}
}

template <int VecSize, typename InT, typename OutT, typename Functor>
__global__ void ElementVectorizedBinary(const InT *__restrict__ in0,
const InT *__restrict__ in1, OutT *out,
int size, Functor func) {
int data_offset = VecSize * blockIdx.x * blockDim.x;
// data offset of this block
int num = size - data_offset;
num = (VecSize * blockDim.x) > num ? num : VecSize * blockDim.x;
// element count for this block: min(size - data_offset, VecSize * blockDim.x)
InT arg0[VecSize];
InT arg1[VecSize];
OutT result[VecSize];

const bool is_reminder = true;
if (VecSize * blockDim.x > num) { // remainder segment
kps::Init<InT, VecSize>(&arg0[0], static_cast<InT>(1.0f));
kps::Init<InT, VecSize>(&arg1[0], static_cast<InT>(1.0f));
kps::ReadData<InT, VecSize, 1, 1, is_reminder>(&arg0[0], in0 + data_offset,
num);
kps::ReadData<InT, VecSize, 1, 1, is_reminder>(&arg1[0], in1 + data_offset,
num);
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, &arg0[0],
&arg1[0], func);
kps::WriteData<OutT, VecSize, 1, 1, is_reminder>(out + data_offset, result,
num);
} else { // complete segment
kps::ReadData<InT, VecSize, 1, 1>(&arg0[0], in0 + data_offset, num);
kps::ReadData<InT, VecSize, 1, 1>(&arg1[0], in1 + data_offset, num);
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, &arg0[0],
&arg1[0], func);
kps::WriteData<OutT, VecSize, 1, 1>(out + data_offset, result, num);
// compute: ET is a template parameter, so the helper is chosen per
// instantiation; any value other than kUnary / kBinary takes the ternary path
if (ET == kUnary) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
func);
} else if (ET == kBinary) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
args[1], func);
} else {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
}

// store
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + data_offset, result,
num);
}

// NOTE(review): this span is a rendered commit diff whose +/- markers were
// stripped, so removed and added lines are interleaved and the text does not
// parse as-is. Lines that use `in0` / `in1` / `in2`, `is_reminder` and
// `InT args[3][VecSize]` appear to belong to the removed
// ElementVectorizedTernary kernel; the remaining lines appear to form the
// added ElementVectorizeKernel -- confirm against the real file.
//
// ElementVectorizeKernel (added lines): computes this block's data offset
// (VecSize * blockIdx.x * blockDim.x) and its element count
// (min(size - data_offset, VecSize * blockDim.x)), then calls DealSegment
// with IsBoundary = true when fewer than VecSize * blockDim.x elements are
// left for this block, and with IsBoundary = false otherwise.
template <int VecSize, typename InT, typename OutT, typename Functor>
__global__ void ElementVectorizedTernary(const InT *__restrict__ in0,
const InT *__restrict__ in1,
const InT *__restrict__ in2, OutT *out,
int size, Functor func) {
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor>
__global__ void ElementVectorizeKernel(
framework::Array<const InT *__restrict__, ET> in, OutT *out, int size,
Functor func) {
int data_offset = VecSize * blockIdx.x * blockDim.x;
// data offset of this block
int num = size - data_offset;
num = (VecSize * blockDim.x) > num ? num : VecSize * blockDim.x;
// element count for this block: min(size - data_offset, VecSize * blockDim.x)
InT args[3][VecSize];
OutT result[VecSize];

// `is_reminder` (sic, "remainder") is passed as the last template argument of
// ReadData / WriteData in the removed remainder branch below.
const bool is_reminder = true;
if (VecSize * blockDim.x > num) { // remainder segment
kps::Init<InT, VecSize>(args[0], static_cast<InT>(1.0f));
kps::Init<InT, VecSize>(args[1], static_cast<InT>(1.0f));
kps::Init<InT, VecSize>(args[2], static_cast<InT>(1.0f));
kps::ReadData<InT, VecSize, 1, 1, is_reminder>(args[0], in0 + data_offset,
num);
kps::ReadData<InT, VecSize, 1, 1, is_reminder>(args[1], in1 + data_offset,
num);
kps::ReadData<InT, VecSize, 1, 1, is_reminder>(args[2], in2 + data_offset,
num);
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
kps::WriteData<OutT, VecSize, 1, 1, is_reminder>(out + data_offset, result,
num);
} else {
kps::ReadData<InT, VecSize, 1, 1>(args[0], in0 + data_offset, num);
kps::ReadData<InT, VecSize, 1, 1>(args[1], in1 + data_offset, num);
kps::ReadData<InT, VecSize, 1, 1>(args[2], in2 + data_offset, num);
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
kps::WriteData<OutT, VecSize, 1, 1>(out + data_offset, result, num);
// partial block: IsBoundary = true
DealSegment<ET, VecSize, InT, OutT, Functor, true>(in, out, num, func);
} else { // complete segment
// full block of VecSize * blockDim.x elements: IsBoundary = false
DealSegment<ET, VecSize, InT, OutT, Functor, false>(in, out, num, func);
}
}

Expand All @@ -178,32 +127,16 @@ void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx,
int block_size = GetThreadsConfig(ctx, numel, VecSize);
int grid_size =
((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
const InT *in0 = ins[0]->data<InT>();
OutT *out = (*outs)[0]->data<OutT>();
// cuda kernel

auto stream = ctx.stream();
switch (ET) {
case ElementwiseType::kTernary:
ElementVectorizedTernary<VecSize, InT, OutT,
Functor><<<grid_size, block_size, 0, stream>>>(
in0, ins[1]->data<InT>(), ins[2]->data<InT>(), out, numel, func);
break;
case ElementwiseType::kBinary:
ElementVectorizedBinary<VecSize, InT, OutT,
Functor><<<grid_size, block_size, 0, stream>>>(
in0, ins[1]->data<InT>(), out, numel, func);
break;
case ElementwiseType::kUnary:
ElementVectorizedUnary<VecSize, InT, OutT,
Functor><<<grid_size, block_size, 0, stream>>>(
in0, out, numel, func);
break;
default: {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported this ElementwiseType : %d !", ET));
break;
}
OutT *out = (*outs)[0]->data<OutT>();
framework::Array<const InT *__restrict__, ET> in;
for (int i = 0; i < ET; i++) {
in[i] = ins[i]->data<InT>();
}
ElementVectorizeKernel<ET, VecSize, InT, OutT,
Functor><<<grid_size, block_size, 0, stream>>>(
in, out, numel, func);
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
Expand Down

0 comments on commit a68841c

Please sign in to comment.