Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Ternary ops in elementwise and broadcast #33976

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_add_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License. */
#include <utility>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

Expand Down
125 changes: 62 additions & 63 deletions paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ struct DimensionsTransform {

struct StridesCalculation {
std::vector<std::vector<uint32_t>> strides;
std::vector<FastDivMod> divmoders;
std::vector<platform::FastDivMod> divmoders;

private:
// To calculate the strides of each input_tensor.
Expand All @@ -190,29 +190,29 @@ struct StridesCalculation {
strides.resize(N, std::vector<uint32_t>(dim_size, 1));

for (int i = 0; i < dim_size; ++i) {
divmoders[i] = FastDivMod(out_dims[i]);
divmoders[i] = platform::FastDivMod(out_dims[i]);
}
CalculateStrides(N, dim_size, in_dims);
}
};

template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
int VecSize, int kDims>
struct BroadcastArgsWarpper {
using InVecType = CudaAlignedVector<InT, VecSize>;
using OutVecType = CudaAlignedVector<OutT, VecSize>;
struct BroadcastArgsWrapper {
using InVecType = platform::CudaAlignedVector<InT, VecSize>;
using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;

OutT *out_data;
OutVecType *vec_out_data;
const InT *__restrict__ in_data[ET];
const InVecType *__restrict__ vec_in_data[ET];
bool no_broadcast[ET];
FastDivMod divmoders[kDims];
platform::FastDivMod divmoders[kDims];
uint32_t strides[ET][framework::DDim::kMaxRank];
uint32_t scalar_cal_offset;
Functor func;

HOSTDEVICE BroadcastArgsWarpper(
HOSTDEVICE BroadcastArgsWrapper(
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
int scalar_cal_offset, Functor func,
const StridesCalculation &offset_calculator)
Expand All @@ -227,7 +227,7 @@ struct BroadcastArgsWarpper {
out_data = out->data<OutT>();
vec_out_data = reinterpret_cast<OutVecType *>(out_data);
memcpy(divmoders, offset_calculator.divmoders.data(),
kDims * sizeof(FastDivMod));
kDims * sizeof(platform::FastDivMod));
}

__device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
Expand Down Expand Up @@ -302,61 +302,60 @@ struct BroadcastArgsWarpper {
}
};

template <typename InT, typename OutT, typename BroadcastArgsWarpper,
template <typename InT, typename OutT, typename BroadcastArgsWrapper,
ElementwiseType ET>
__device__ inline void ScalarizedBroadcastKernelImpl(
BroadcastArgsWarpper broadcast_warpper, int tid) {
BroadcastArgsWrapper broadcast_wrapper, int tid) {
InT args[ET];
OutT args_out;
broadcast_warpper.LoadScalarizedData(args, tid);
broadcast_wrapper.LoadScalarizedData(args, tid);

#pragma unroll(ET)
for (int j = 1; j < ET; ++j) {
args_out = broadcast_warpper.func(args);
}
broadcast_warpper.StoreScalarizedData(args_out, tid);
// Calculation of the in_tensor data.
args_out = broadcast_wrapper.func(args);

broadcast_wrapper.StoreScalarizedData(args_out, tid);
}

template <typename InT, typename OutT, typename BroadcastArgsWarpper,
template <typename InT, typename OutT, typename BroadcastArgsWrapper,
ElementwiseType ET, int VecSize>
__device__ inline void VectorizedBroadcastKernelImpl(
BroadcastArgsWarpper broadcast_warpper, int tid) {
using OutVecType = CudaAlignedVector<OutT, VecSize>;
BroadcastArgsWrapper broadcast_wrapper, int tid) {
using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;
OutVecType args_out;
InT ins[ET];
InT args[ET][VecSize];
broadcast_warpper.LoadVectorizedData(args, tid);
broadcast_wrapper.LoadVectorizedData(args, tid);

#pragma unroll(VecSize)
for (int i = 0; i < VecSize; ++i) {
#pragma unroll(ET)
for (int j = 0; j < ET; ++j) {
ins[j] = args[j][i];
}
args_out.val[i] = broadcast_warpper.func(ins);
args_out.val[i] = broadcast_wrapper.func(ins);
}
broadcast_warpper.StoreVectorizedData(args_out, tid);
broadcast_wrapper.StoreVectorizedData(args_out, tid);
}

template <typename InT, typename OutT, typename BroadcastArgsWarpper,
template <typename InT, typename OutT, typename BroadcastArgsWrapper,
ElementwiseType ET, int VecSize>
__global__ void ElementwiseBroadcastKernel(
BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) {
BroadcastArgsWrapper broadcast_wrapper, int main_tid, int tail_tid) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
// Vectorized calculation of major data whose length is the max multiple of
// VecSize,
// eg: Calculating the front 1024-length data in total 1027 data once VecSize
// is 4.
if (tid < main_tid) {
VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET, VecSize>(
broadcast_warpper, tid);
VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET, VecSize>(
broadcast_wrapper, tid);
}
// Scalarized calculation of rest data whose length cannot fulfill VecSize.
// eg: Calculating the rest 3-length data in total 1027 data once VecSize is
// 4.
if (tid < tail_tid) {
ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET>(
broadcast_warpper, tid);
ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET>(
broadcast_wrapper, tid);
}
}

Expand All @@ -367,7 +366,7 @@ void LaunchBroadcastKernelForDifferentDimSize(
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
int axis, Functor func) {
int numel = out->numel();
const int threads = 256;
int threads = GetThreadsConfig(ctx, numel, VecSize);
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
int main_tid = numel / VecSize;
int tail_tid = numel % VecSize;
Expand All @@ -380,75 +379,75 @@ void LaunchBroadcastKernelForDifferentDimSize(

switch (merge_dims.dim_size) {
case 1: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 1>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 1>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 2: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 2>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 2>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 3: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 3>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 3>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 4: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 4>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 4>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 5: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 5>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 5>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 6: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 6>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 6>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 7: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 7>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 7>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 8: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 8>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 8>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
default: {
Expand All @@ -473,11 +472,11 @@ void LaunchBroadcastElementwiseCudaKernel(
int in_vec_size = 4;
framework::Tensor *out = (*outs)[0];
for (auto *in : ins) {
auto temp_size = GetVectorizedSizeImpl<InT>(in->data<InT>());
auto temp_size = platform::GetVectorizedSize<InT>(in->data<InT>());
in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
: in_vec_size;
}
int out_vec_size = GetVectorizedSizeImpl<OutT>(out->data<OutT>());
int out_vec_size = platform::GetVectorizedSize<OutT>(out->data<OutT>());
int vec_size = std::min(out_vec_size, in_vec_size);

switch (vec_size) {
Expand Down
Loading