Merge pull request PaddlePaddle#32 from sabreshao/0.15.0_float16
Fix functions in cuda_device_function.h
carlushuang authored Oct 9, 2018
2 parents f6df82e + 4fc6ced commit 119022a
Showing 2 changed files with 42 additions and 15 deletions.
1 change: 1 addition & 0 deletions cmake/external/rocprim.cmake
@@ -16,6 +16,7 @@ ExternalProject_Add(
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hcc
CMAKE_ARGS -DONLY_INSTALL=ON
CMAKE_ARGS -DBUILD_TEST=OFF
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${ROCPRIM_INSTALL_DIR}

INSTALL_DIR ${ROCPRIM_INSTALL_DIR}
56 changes: 41 additions & 15 deletions paddle/fluid/platform/cuda_device_function.h
@@ -16,6 +16,7 @@ limitations under the License. */

#include "hip/hip_fp16.h"
#include "hip/hip_runtime.h"
#include <type_traits>

namespace paddle {
namespace platform {
@@ -34,19 +35,6 @@ static __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
return __shfl(val, src_line, width);
}

template <>
static __forceinline__ __device__ double CudaShuffleDownSync(unsigned mask, double val,
int delta, int width) {
return (float)__shfl_down((float)val, delta, width);
}

template <>
static __forceinline__ __device__ double CudaShuffleSync(unsigned mask, double val, int src_line,
int width) {
return (float)__shfl((float)val, src_line, width);
}


#if 0
template <typename T>
HOSTDEVICE T Infinity() {
@@ -55,15 +43,53 @@ HOSTDEVICE T Infinity() {
#endif

template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
typename std::enable_if<!std::is_integral<T>::value, T>::type
__device__ reduceSum(T val_in, int tid, int len) {
// NOTE(zcd): The warp size should be taken from the
// parameters of the GPU rather than hard-coded as 32.
// To make reduceSum more efficient, this uses
// warp-level parallelism and assumes a warp size of 32,
// which may differ between GPUs but holds for most cards.
const int warpSize = 32;
__shared__ float shm[warpSize];
float val = val_in;
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);

for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::CudaShuffleDownSync(mask, val, offset);

if (tid < warpSize) shm[tid] = 0;
__syncthreads();

if (tid % warpSize == 0) {
shm[tid / warpSize] = val;
}
__syncthreads();

CREATE_SHFL_MASK(mask, tid < warpSize);

if (tid < warpSize) {
val = shm[tid];
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::CudaShuffleDownSync(mask, val, offset);
}
return val;
}

template <typename T>
typename std::enable_if<std::is_integral<T>::value, T>::type
__device__ reduceSum(T val_in, int tid, int len) {
// NOTE(zcd): The warp size should be taken from the
// parameters of the GPU rather than hard-coded as 32.
// To make reduceSum more efficient, this uses
// warp-level parallelism and assumes a warp size of 32,
// which may differ between GPUs but holds for most cards.
const int warpSize = 32;
__shared__ T shm[warpSize];
__shared__ int shm[warpSize];
int val = val_in;
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);

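For context on the header change above: the commit removes the double specializations of CudaShuffleDownSync and CudaShuffleSync (which routed double values through a float __shfl, losing precision) and splits reduceSum into two std::enable_if-gated overloads, so floating-point types accumulate through a float shared-memory buffer and integral types through an int buffer. The sketch below is a hypothetical usage example, not part of the commit: the kernel name, the single-block assumption, and the launch configuration are illustrative.

// Hypothetical usage sketch (not part of this commit). Assumes the helpers
// above -- paddle::platform::reduceSum and the shuffle wrappers -- are
// visible through the include, and that the kernel runs as a single block
// of at most 1024 threads (warpSize * warpSize warp sums fit in shm).
#include "hip/hip_runtime.h"
#include "paddle/fluid/platform/cuda_device_function.h"

// Sums the `len` floats owned by the first `len` threads of the block.
// Overload resolution picks the !std::is_integral reduceSum, so partial
// sums go through the float shared-memory path introduced by this commit.
__global__ void BlockSumKernel(const float* in, float* out, int len) {
  const int tid = threadIdx.x;
  float v = (tid < len) ? in[tid] : 0.0f;
  v = paddle::platform::reduceSum(v, tid, len);
  if (tid == 0) {
    *out = v;  // thread 0 ends up holding the block-wide sum
  }
}

A host-side launch such as hipLaunchKernelGGL(BlockSumKernel, dim3(1), dim3(256), 0, 0, d_in, d_out, n) would then leave the sum of the first n elements (for n no larger than the block size) in *d_out; here d_in and d_out are assumed to be device pointers allocated elsewhere.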
