Unify the implementation of AlignedVector and simplify the codes of dropout and cast. (#35373)
Xreki authored Sep 3, 2021
1 parent a9dfebb commit c171eca
Showing 8 changed files with 143 additions and 172 deletions.
38 changes: 11 additions & 27 deletions paddle/fluid/operators/cast_op.cu
@@ -13,47 +13,31 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/cast_op.h"
+#include "paddle/fluid/platform/aligned_vector.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/gpu_launch_config.h"

 namespace paddle {
 namespace operators {

-// aligned vector generates vectorized load/store on CUDA
-template <typename T, int Size>
-struct alignas(sizeof(T) * Size) AlignedVector {
-  T val[Size];
-};
-
-template <typename T>
-inline int VectorizedSize(const T* pointer) {
-  uint64_t address = reinterpret_cast<uint64_t>(pointer);
-  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
-  if (address % vec4 == 0) {
-    return 4;
-  }
-  return 1;
-}
-
 template <typename InT, typename OutT, int VecSize>
 __global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
+  using LoadT = platform::AlignedVector<InT, VecSize>;
+  using StoreT = platform::AlignedVector<OutT, VecSize>;
+
   int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
-  using LoadT = AlignedVector<InT, VecSize>;
-  using StoreT = AlignedVector<OutT, VecSize>;
   for (int64_t i = idx * VecSize; i < N;
        i += blockDim.x * gridDim.x * VecSize) {
-    InT in_vec[VecSize];
-    LoadT* in_value = reinterpret_cast<LoadT*>(&in_vec);
-    *in_value = *reinterpret_cast<const LoadT*>(&in[i]);
+    LoadT in_val;
+    platform::Load<InT, VecSize>(&in[i], &in_val);

-    OutT out_vec[VecSize];
+    StoreT out_val;
 #pragma unroll
-    for (int ii = 0; ii < VecSize; ii++) {
-      out_vec[ii] = static_cast<OutT>(in_vec[ii]);
+    for (int j = 0; j < VecSize; j++) {
+      out_val[j] = static_cast<OutT>(in_val[j]);
     }

-    *(reinterpret_cast<StoreT*>(&out[i])) =
-        *reinterpret_cast<StoreT*>(&out_vec[0]);
+    platform::Store<OutT, VecSize>(out_val, &out[i]);
   }
 }

@@ -78,7 +62,7 @@ struct CastOpFunctor<platform::CUDADeviceContext, InT> {
     auto* out = out_->mutable_data<OutT>(ctx_.GetPlace());
     platform::GpuLaunchConfig config =
         platform::GetGpuLaunchConfig1D(ctx_, size);
-    int vec_size = VectorizedSize<OutT>(out);
+    int vec_size = platform::GetVectorizedSize<OutT>(out);
     if (!std::is_same<InT, OutT>::value && vec_size == 4 && size % 4 == 0) {
       VecCastCUDAKernel<InT, OutT, 4><<<
           config.block_per_grid, config.thread_per_block, 0, ctx_.stream()>>>(
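Note: the shared header paddle/fluid/platform/aligned_vector.h that these files now include is referenced but not shown in this diff. The block below is a minimal sketch of what it presumably provides, inferred from the local AlignedVector/VectorizedSize definitions removed above and from the platform::Load / platform::Store / operator[] call sites in the new kernels; the exact member functions and the HOSTDEVICE shorthand are assumptions, not the file's verbatim contents.

// Sketch only: inferred shape of paddle/fluid/platform/aligned_vector.h.
#pragma once

#if defined(__CUDACC__) || defined(__HIPCC__)
#define HOSTDEVICE __host__ __device__  // assumption: Paddle's usual shorthand
#else
#define HOSTDEVICE
#endif

namespace paddle {
namespace platform {

// alignas(sizeof(T) * Size) is what lets the compiler emit one vectorized
// (e.g. 128-bit) load/store instead of Size scalar memory transactions.
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
  T val[Size];

  HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
  HOSTDEVICE inline T& operator[](int i) { return val[i]; }
};

// Reads Size consecutive elements starting at addr into *vec with a single
// vectorized load; addr is expected to be aligned to sizeof(T) * Size.
template <typename T, int Size>
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
  const AlignedVector<T, Size>* addr_vec =
      reinterpret_cast<const AlignedVector<T, Size>*>(addr);
  *vec = *addr_vec;
}

// Writes the Size elements held in vec back to addr with a single
// vectorized store.
template <typename T, int Size>
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
  AlignedVector<T, Size>* addr_vec =
      reinterpret_cast<AlignedVector<T, Size>*>(addr);
  *addr_vec = vec;
}

}  // namespace platform
}  // namespace paddle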
70 changes: 30 additions & 40 deletions paddle/fluid/operators/dropout_op.cu
@@ -38,7 +38,7 @@ namespace operators {
 template <typename T, typename MaskType>
 __global__ void RandomGenerator(const size_t n, uint64_t seed,
                                 const float dropout_prob, const T* src,
-                                MaskType* mask_data, T* dst,
+                                MaskType* mask, T* dst,
                                 bool is_upscale_in_train, uint64_t increment) {
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
 #ifdef PADDLE_WITH_HIP
@@ -49,36 +49,36 @@ __global__ void RandomGenerator(const size_t n, uint64_t seed,
   curand_init(seed, idx, increment, &state);
 #endif

-  MaskType mask;
-  T dest;
+  MaskType mask_val;
+  T dst_val;
+  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
   for (; idx < n; idx += blockDim.x * gridDim.x) {
-    T s = src[idx];
+    T src_val = src[idx];
 #ifdef PADDLE_WITH_HIP
     if (hiprand_uniform(&state) < dropout_prob) {
 #else
     if (curand_uniform(&state) < dropout_prob) {
 #endif
-      mask = 0;
-      dest = 0;
+      mask_val = 0;
+      dst_val = 0;
     } else {
-      mask = 1;
-      if (is_upscale_in_train) {
-        dest = s / static_cast<T>(1.0f - dropout_prob);
-      } else {
-        dest = s;
-      }
+      mask_val = 1;
+      dst_val = is_upscale_in_train ? src_val * factor : src_val;
     }
-    mask_data[idx] = mask;
-    dst[idx] = dest;
+    mask[idx] = mask_val;
+    dst[idx] = dst_val;
   }
 }

 template <typename T, typename MaskType, int VecSize>
 __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                           const float dropout_prob,
-                                          const T* src, MaskType* mask_data,
-                                          T* dst, bool is_upscale_in_train,
+                                          const T* src, MaskType* mask, T* dst,
+                                          bool is_upscale_in_train,
                                           uint64_t increment) {
+  using LoadT = platform::AlignedVector<T, VecSize>;
+  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
+
 #ifdef PADDLE_WITH_HIP
   int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
   hiprandStatePhilox4_32_10_t state;
@@ -89,43 +89,33 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
   curand_init(seed, idx, increment, &state);
 #endif

-  MaskType mask;
-  T dest;
-  using LoadT = AlignedVector<T, VecSize>;
-  using MaskLoadT = AlignedVector<MaskType, VecSize>;
   T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
   for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
-    T src_vec[VecSize];
-    LoadT* value = reinterpret_cast<LoadT*>(&src_vec);
-    *value = *reinterpret_cast<const LoadT*>(&src[i]);
+    LoadT src_val;
+    platform::Load<T, VecSize>(&src[i], &src_val);

 #ifdef PADDLE_WITH_HIP
     float4 rand = hiprand_uniform4(&state);
 #else
     float4 rand = curand_uniform4(&state);
 #endif

-    T dest_vec[VecSize];
-    MaskType mask_vec[VecSize];
+    LoadT dst_val;
+    MaskLoadT mask_val;

 #pragma unroll
-    for (int ii = 0; ii < VecSize; ii++) {
-      if ((&rand.x)[ii] < dropout_prob) {
-        dest_vec[ii] = 0;
-        mask_vec[ii] = 0;
+    for (int j = 0; j < VecSize; j++) {
+      if ((&rand.x)[j] < dropout_prob) {
+        dst_val[j] = 0;
+        mask_val[j] = 0;
       } else {
-        if (is_upscale_in_train) {
-          dest_vec[ii] = src_vec[ii] * factor;
-        } else {
-          dest_vec[ii] = src_vec[ii];
-        }
-        mask_vec[ii] = 1;
+        dst_val[j] = is_upscale_in_train ? src_val[j] * factor : src_val[j];
+        mask_val[j] = 1;
       }
     }

-    *(reinterpret_cast<LoadT*>(&dst[i])) =
-        *reinterpret_cast<LoadT*>(&dest_vec[0]);
-    *(reinterpret_cast<MaskLoadT*>(&mask_data[i])) =
-        *reinterpret_cast<MaskLoadT*>(&mask_vec[0]);
+    platform::Store<T, VecSize>(dst_val, &dst[i]);
+    platform::Store<MaskType, VecSize>(mask_val, &mask[i]);
   }
 }

@@ -185,7 +175,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
     // same as the previous calls.
     uint64_t seed_data;
     uint64_t increment;
-    int vec_size = VectorizedSize<T>(x_data);
+    int vec_size = platform::GetVectorizedSize<T>(x_data);
     auto offset = ((x_numel - 1) / (config.block_per_grid.x *
                                     config.thread_per_block.x * vec_size) +
                    1) *
44 changes: 13 additions & 31 deletions paddle/fluid/operators/dropout_op.h
@@ -21,54 +21,36 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/aligned_vector.h"
 #include "paddle/fluid/platform/gpu_launch_config.h"

 namespace paddle {
 namespace operators {

-// aligned vector generates vectorized load/store on CUDA
-template <typename T, int Size>
-struct alignas(sizeof(T) * Size) AlignedVector {
-  T val[Size];
-};
-
-template <typename T>
-inline int VectorizedSize(const T* pointer) {
-  uint64_t address = reinterpret_cast<uint64_t>(pointer);
-  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
-  if (address % vec4 == 0) {
-    return 4;
-  }
-  return 1;
-}
-
 #if defined(__NVCC__) || defined(__HIPCC__)
 template <typename T, typename MaskType, int VecSize>
 __global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
                                       const T factor, const int64_t size,
                                       T* dx) {
-  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
-
-  using LoadT = AlignedVector<T, VecSize>;
-  using MaskLoadT = AlignedVector<MaskType, VecSize>;
+  using LoadT = platform::AlignedVector<T, VecSize>;
+  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;

+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
   for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
-    T dout_vec[VecSize];
-    LoadT* dout_value = reinterpret_cast<LoadT*>(&dout_vec);
-    *dout_value = *reinterpret_cast<const LoadT*>(&dout[i]);
+    LoadT dout_val;
+    platform::Load<T, VecSize>(&dout[i], &dout_val);

-    MaskType mask_vec[VecSize];
-    MaskLoadT* mask_value = reinterpret_cast<MaskLoadT*>(&mask_vec);
-    *mask_value = *reinterpret_cast<const MaskLoadT*>(&mask[i]);
+    MaskLoadT mask_val;
+    platform::Load<MaskType, VecSize>(&mask[i], &mask_val);

-    T dx_vec[VecSize];
+    LoadT dx_val;

 #pragma unroll
-    for (int ii = 0; ii < VecSize; ii++) {
-      dx_vec[ii] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
+    for (int j = 0; j < VecSize; j++) {
+      dx_val[j] = dout_val[j] * static_cast<T>(mask_val[j]) * factor;
     }

-    *(reinterpret_cast<LoadT*>(&dx[i])) = *reinterpret_cast<LoadT*>(&dx_vec[0]);
+    platform::Store<T, VecSize>(dx_val, &dx[i]);
   }
 }
 #endif
@@ -187,7 +169,7 @@ class DropoutGradKernel : public framework::OpKernel<T> {
     if (dropout_prob == 1.0f) {
       dX.device(place) = static_cast<T>(0) * dY;
     } else {
-      int vec_size = VectorizedSize<T>(grad_y->data<T>());
+      int vec_size = platform::GetVectorizedSize<T>(grad_y->data<T>());
       if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
           size % 4 == 0) {
 #if defined(__NVCC__) || defined(__HIPCC__)
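The branch guarded by vec_size == 4 presumably launches DropoutGradCUDAKernel<T, MaskType, 4>, following the same dispatch pattern as the cast_op.cu hunk earlier in this diff. A hypothetical sketch of such a launch (the uint8_t mask type, the variable names, the config/stream plumbing, and the upscale_in_train factor are illustrative assumptions, not the collapsed lines of dropout_op.h):

// Hypothetical dispatch sketch; mirrors the vec_size == 4 branch shown in
// cast_op.cu above, not the hidden code of this file.
auto& dev_ctx = context.cuda_device_context();
const T factor =
    static_cast<T>(1.0f / (1.0f - dropout_prob));  // assumes upscale_in_train
platform::GpuLaunchConfig config =
    platform::GetGpuLaunchConfig1D(dev_ctx, size);
DropoutGradCUDAKernel<T, uint8_t, 4><<<
    config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
    grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
    grad_x->mutable_data<T>(context.GetPlace()));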
@@ -199,8 +199,8 @@ struct StridesCalculation {
 template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
           int VecSize, int kDims>
 struct BroadcastArgsWrapper {
-  using InVecType = platform::CudaAlignedVector<InT, VecSize>;
-  using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;
+  using InVecType = platform::AlignedVector<InT, VecSize>;
+  using OutVecType = platform::AlignedVector<OutT, VecSize>;

   OutT *out_data;
   OutVecType *vec_out_data;
@@ -320,7 +320,7 @@ template <typename InT, typename OutT, typename BroadcastArgsWrapper,
           ElementwiseType ET, int VecSize>
 __device__ inline void VectorizedBroadcastKernelImpl(
     BroadcastArgsWrapper broadcast_wrapper, int tid) {
-  using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;
+  using OutVecType = platform::AlignedVector<OutT, VecSize>;
   OutVecType args_out;
   InT ins[ET];
   InT args[ET][VecSize];
8 changes: 4 additions & 4 deletions paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -69,8 +69,8 @@ int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins,

 template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
 struct ElementwiseDataWrapper {
-  using InVecType = platform::CudaAlignedVector<InT, VecSize>;
-  using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;
+  using InVecType = platform::AlignedVector<InT, VecSize>;
+  using OutVecType = platform::AlignedVector<OutT, VecSize>;

   const InT *__restrict__ in_data[ET];
   OutT *out_data;
@@ -117,8 +117,8 @@ template <ElementwiseType ET, int VecSize, typename ElementwiseWrapper,
           typename InT, typename OutT, typename Functor>
 __device__ inline void VectorizedKernelImpl(ElementwiseWrapper data,
                                             Functor func, int tid) {
-  using InVecType = platform::CudaAlignedVector<InT, VecSize>;
-  using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;
+  using InVecType = platform::AlignedVector<InT, VecSize>;
+  using OutVecType = platform::AlignedVector<OutT, VecSize>;
   InVecType ins_vec[ET];
   OutVecType out_vec;
   InT *ins_ptr[ET];
29 changes: 3 additions & 26 deletions paddle/fluid/operators/fused/attn_bias_add.cu.h
@@ -96,36 +96,13 @@ __global__ void BroadcastKernelBinary(
   kernel_primitives::WriteData<OutT, VecSize, 1, 1>(out + fix, result, num);
 }

-template <typename T>
-int GetVectorizedSizeImpl(const T* pointer) {
-  constexpr int max_load_bits = 128;
-  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
-  uint64_t address = reinterpret_cast<uint64_t>(pointer);
-  constexpr int vec8 =
-      std::alignment_of<platform::CudaAlignedVector<T, 8>>::value;  // NOLINT
-  constexpr int vec4 =
-      std::alignment_of<platform::CudaAlignedVector<T, 4>>::value;  // NOLINT
-  constexpr int vec2 =
-      std::alignment_of<platform::CudaAlignedVector<T, 2>>::value;  // NOLINT
-  if (address % vec8 == 0) {
-    // Note: this line can change from 4 to 8 if it can improve the performance.
-    return std::min(4, valid_vec_size);
-  } else if (address % vec4 == 0) {
-    return std::min(4, valid_vec_size);
-  } else if (address % vec2 == 0) {
-    return std::min(2, valid_vec_size);
-  } else {
-    return 1;
-  }
-}
-
 // bias add forward impl for "[m, n] + [n] = [m, n]"
 template <typename T>
 void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
                            const T* in0, const T* in1, T* out) {
-  int in_vec_size =
-      std::min(GetVectorizedSizeImpl<T>(in0), GetVectorizedSizeImpl<T>(in1));
-  int out_vec_size = std::min(4, GetVectorizedSizeImpl<T>(out));
+  int in_vec_size = std::min(platform::GetVectorizedSize<T>(in0),
+                             platform::GetVectorizedSize<T>(in1));
+  int out_vec_size = std::min(4, platform::GetVectorizedSize<T>(out));
   int vec_size = std::min(out_vec_size, in_vec_size);

   int numel = m * n;
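The unified platform::GetVectorizedSize that replaces this local GetVectorizedSizeImpl is not shown on this page. The sketch below assumes it mirrors the removed code above, with AlignedVector substituted for CudaAlignedVector; the cap at 4 and the 128-bit limit are taken from the removed lines, not verified against the new header.

// Sketch only: presumed helper in paddle/fluid/platform/aligned_vector.h.
// Picks the widest vector width whose alignment the pointer satisfies,
// capped so that a single access never exceeds 128 bits.
// Requires <climits> for CHAR_BIT and <type_traits> for std::alignment_of.
template <typename T>
int GetVectorizedSize(const T* pointer) {
  constexpr int max_load_bits = 128;
  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr int vec8 = std::alignment_of<AlignedVector<T, 8>>::value;  // NOLINT
  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
  constexpr int vec2 = std::alignment_of<AlignedVector<T, 2>>::value;  // NOLINT
  if (address % vec8 == 0) {
    // Still capped at 4; could return 8 if that proves faster.
    return std::min(4, valid_vec_size);
  } else if (address % vec4 == 0) {
    return std::min(4, valid_vec_size);
  } else if (address % vec2 == 0) {
    return std::min(2, valid_vec_size);
  }
  return 1;
}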