Improve the use of AlignedVector, in the same way as PaddlePaddle#35373.
limin2021 committed Sep 13, 2021 · 1 parent 189ac39 · commit 8c73637
Showing 1 changed file: paddle/fluid/operators/dropout_impl.cu.h (164 additions, 51 deletions)
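The change below swaps the kernels' raw reinterpret_cast vector loads/stores for the platform::AlignedVector Load/Store helpers. As background, here is a minimal, self-contained sketch of that pattern in plain CUDA; the AlignedVector, Load, Store, and Scale names below are simplified stand-ins written for this note, not the actual implementation in paddle/fluid/platform/aligned_vector.h:

// Simplified stand-in for platform::AlignedVector<T, Size>: a chunk of Size
// elements whose alignment lets the compiler emit one wide (e.g. 128-bit)
// memory transaction. Size * sizeof(T) should be a power of two.
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
  T val[Size];
  __host__ __device__ T& operator[](int i) { return val[i]; }
  __host__ __device__ const T& operator[](int i) const { return val[i]; }
};

// One vectorized load: reads Size consecutive elements in a single access.
template <typename T, int Size>
__device__ void Load(const T* addr, AlignedVector<T, Size>* vec) {
  *vec = *reinterpret_cast<const AlignedVector<T, Size>*>(addr);
}

// One vectorized store: writes Size consecutive elements in a single access.
template <typename T, int Size>
__device__ void Store(const AlignedVector<T, Size>& vec, T* addr) {
  *reinterpret_cast<AlignedVector<T, Size>*>(addr) = vec;
}

// Grid-stride kernel that scales src by factor, VecSize elements per thread.
template <typename T, int VecSize>
__global__ void Scale(const T* src, T* dst, T factor, int n) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
    AlignedVector<T, VecSize> v;
    Load<T, VecSize>(&src[i], &v);
#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      v[j] *= factor;
    }
    Store<T, VecSize>(v, &dst[i]);
  }
}

A launch such as Scale<float, 4><<<256, 256>>>(src, dst, 2.0f, n) assumes n is a multiple of 4 and that src and dst are 16-byte aligned; real callers fall back to a scalar kernel when either condition fails, which is exactly the split between RandomGenerator and VectorizedRandomGenerator in the diff below.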
limitations under the License. */
namespace paddle {
namespace operators {

// template <typename T, typename MaskType>
// __global__ void RandomGenerator(const size_t n, uint64_t seed,
// const float dropout_prob, const T* src,
// MaskType* mask_data, T* dst,
// bool is_upscale_in_train, uint64_t increment) {
// int idx = blockDim.x * blockIdx.x + threadIdx.x;
// #ifdef PADDLE_WITH_HIP
// hiprandStatePhilox4_32_10_t state;
// hiprand_init(seed, idx, increment, &state);
// #else
// curandStatePhilox4_32_10_t state;
// curand_init(seed, idx, increment, &state);
// #endif

// MaskType mask;
// T dest;
// for (; idx < n; idx += blockDim.x * gridDim.x) {
// T s = src[idx];
// #ifdef PADDLE_WITH_HIP
// if (hiprand_uniform(&state) < dropout_prob) {
// #else
// if (curand_uniform(&state) < dropout_prob) {
// #endif
// mask = 0;
// dest = 0;
// } else {
// mask = 1;
// if (is_upscale_in_train) {
// dest = s / static_cast<T>(1.0f - dropout_prob);
// } else {
// dest = s;
// }
// }
// mask_data[idx] = mask;
// dst[idx] = dest;
// }
// }

// template <typename T, typename MaskType, int VecSize>
// __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
// const float dropout_prob,
// const T* src, MaskType* mask_data,
// T* dst, bool is_upscale_in_train,
// uint64_t increment) {
// #ifdef PADDLE_WITH_HIP
// int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
// hiprandStatePhilox4_32_10_t state;
// hiprand_init(seed, idx, increment, &state);
// #else
// int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
// curandStatePhilox4_32_10_t state;
// curand_init(seed, idx, increment, &state);
// #endif

// MaskType mask;
// T dest;
// using LoadT = platform::AlignedVector<T, VecSize>;
// using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
// T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
// for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
// T src_vec[VecSize];
// LoadT* value = reinterpret_cast<LoadT*>(&src_vec);
// *value = *reinterpret_cast<const LoadT*>(&src[i]);
// #ifdef PADDLE_WITH_HIP
// float4 rand = hiprand_uniform4(&state);
// #else
// float4 rand = curand_uniform4(&state);
// #endif

// T dest_vec[VecSize];
// MaskType mask_vec[VecSize];

// #pragma unroll
// for (int ii = 0; ii < VecSize; ii++) {
// if ((&rand.x)[ii] < dropout_prob) {
// dest_vec[ii] = 0;
// mask_vec[ii] = 0;
// } else {
// if (is_upscale_in_train) {
// dest_vec[ii] = src_vec[ii] * factor;
// } else {
// dest_vec[ii] = src_vec[ii];
// }
// mask_vec[ii] = 1;
// }
// }

// *(reinterpret_cast<LoadT*>(&dst[i])) =
// *reinterpret_cast<LoadT*>(&dest_vec[0]);
// *(reinterpret_cast<MaskLoadT*>(&mask_data[i])) =
// *reinterpret_cast<MaskLoadT*>(&mask_vec[0]);
// }
// }

template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, uint64_t seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask, T* dst,
                                bool is_upscale_in_train, uint64_t increment) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, increment, &state);
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
#endif

  MaskType mask_val;
  T dst_val;
  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    T src_val = src[idx];
#ifdef PADDLE_WITH_HIP
    if (hiprand_uniform(&state) < dropout_prob) {
#else
    if (curand_uniform(&state) < dropout_prob) {
#endif
      mask_val = 0;
      dst_val = 0;
    } else {
      mask_val = 1;
      dst_val = is_upscale_in_train ? src_val * factor : src_val;
    }
    mask[idx] = mask_val;
    dst[idx] = dst_val;
  }
}

template <typename T, typename MaskType, int VecSize>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                          const float dropout_prob,
                                          const T* src, MaskType* mask, T* dst,
                                          bool is_upscale_in_train,
                                          uint64_t increment) {
  using LoadT = platform::AlignedVector<T, VecSize>;
  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;

#ifdef PADDLE_WITH_HIP
  int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, increment, &state);
#else
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
#endif

  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
  for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
    LoadT src_val;
    platform::Load<T, VecSize>(&src[i], &src_val);

#ifdef PADDLE_WITH_HIP
    float4 rand = hiprand_uniform4(&state);
#else
    float4 rand = curand_uniform4(&state);
#endif

    LoadT dst_val;
    MaskLoadT mask_val;

#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      if ((&rand.x)[j] < dropout_prob) {
        dst_val[j] = 0;
        mask_val[j] = 0;
      } else {
        dst_val[j] = is_upscale_in_train ? src_val[j] * factor : src_val[j];
        mask_val[j] = 1;
      }
    }

    platform::Store<T, VecSize>(dst_val, &dst[i]);
    platform::Store<MaskType, VecSize>(mask_val, &mask[i]);
  }
}
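For context on how these two kernels are typically selected: the vectorized path is only safe when the element count is a multiple of VecSize and all three pointers are aligned to the vector width. A hypothetical host-side dispatcher, written for this note (the helper name, block size, and alignment checks are illustrative, not the driver logic from this file):

#include <cstdint>  // uintptr_t (assumed available in this context)

// Hypothetical dispatcher: vectorize only when size and alignment allow it.
template <typename T, typename MaskType>
void LaunchDropoutFw(const T* src, MaskType* mask, T* dst, size_t n,
                     float dropout_prob, bool is_upscale_in_train,
                     uint64_t seed, uint64_t increment, cudaStream_t stream) {
  if (n == 0) return;
  constexpr int kVecSize = 4;
  const int threads = 256;
  const bool can_vectorize =
      n % kVecSize == 0 &&
      reinterpret_cast<uintptr_t>(src) % (sizeof(T) * kVecSize) == 0 &&
      reinterpret_cast<uintptr_t>(dst) % (sizeof(T) * kVecSize) == 0 &&
      reinterpret_cast<uintptr_t>(mask) % (sizeof(MaskType) * kVecSize) == 0;
  if (can_vectorize) {
    const int blocks =
        static_cast<int>((n / kVecSize + threads - 1) / threads);
    VectorizedRandomGenerator<T, MaskType, kVecSize>
        <<<blocks, threads, 0, stream>>>(n, seed, dropout_prob, src, mask,
                                         dst, is_upscale_in_train, increment);
  } else {
    const int blocks = static_cast<int>((n + threads - 1) / threads);
    RandomGenerator<T, MaskType><<<blocks, threads, 0, stream>>>(
        n, seed, dropout_prob, src, mask, dst, is_upscale_in_train, increment);
  }
}

Here increment is the Philox offset that the host advances between launches, so successive dropout calls draw fresh random numbers from the same (seed, thread index) streams.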

// template <typename T, typename MaskType, int VecSize>
// __global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
// const T factor, const int64_t size,
// T* dx) {
// int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;

// using LoadT = platform::AlignedVector<T, VecSize>;
// using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;

// for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
// T dout_vec[VecSize];
// LoadT* dout_value = reinterpret_cast<LoadT*>(&dout_vec);
// *dout_value = *reinterpret_cast<const LoadT*>(&dout[i]);

// MaskType mask_vec[VecSize];
// MaskLoadT* mask_value = reinterpret_cast<MaskLoadT*>(&mask_vec);
// *mask_value = *reinterpret_cast<const MaskLoadT*>(&mask[i]);

// T dx_vec[VecSize];

// #pragma unroll
// for (int ii = 0; ii < VecSize; ii++) {
// dx_vec[ii] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
// }

// *(reinterpret_cast<LoadT*>(&dx[i])) =
// *reinterpret_cast<LoadT*>(&dx_vec[0]);
// }
// }

template <typename T, typename MaskType, int VecSize>
__global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
                                      const T factor, const int64_t size,
                                      T* dx) {
  using LoadT = platform::AlignedVector<T, VecSize>;
  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;

  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
    LoadT dout_val;
    platform::Load<T, VecSize>(&dout[i], &dout_val);

    MaskLoadT mask_val;
    platform::Load<MaskType, VecSize>(&mask[i], &mask_val);

    LoadT dx_val;

#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      dx_val[j] = dout_val[j] * static_cast<T>(mask_val[j]) * factor;
    }

    platform::Store<T, VecSize>(dx_val, &dx[i]);
  }
}
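The backward scale mirrors the forward mode: with upscale_in_train the surviving activations were multiplied by 1 / (1 - dropout_prob) in the forward pass, so the gradient uses the same factor, while in downscale-in-infer mode the factor is simply 1. A hypothetical launch helper for illustration (the name, types, and sizing are assumptions for this note, not code from this file):

// Hypothetical backward launch: the factor matches the forward dropout mode.
// Assumes size % 4 == 0 and pointers aligned to the vector width.
void LaunchDropoutGrad(const float* dout, const uint8_t* mask, float* dx,
                       int64_t size, float dropout_prob,
                       bool is_upscale_in_train, cudaStream_t stream) {
  constexpr int kVecSize = 4;
  const float factor =
      is_upscale_in_train ? 1.0f / (1.0f - dropout_prob) : 1.0f;
  const int threads = 256;
  const int blocks =
      static_cast<int>((size / kVecSize + threads - 1) / threads);
  DropoutGradCUDAKernel<float, uint8_t, kVecSize>
      <<<blocks, threads, 0, stream>>>(dout, mask, factor, size, dx);
}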
