Skip to content

Commit

Permalink
【Hackathon No.34】优化 poisson op (#45160)
Browse files Browse the repository at this point in the history
* 【Hackathon No.34】优化 poisson op

* [poisson] code style fix

* modify code style

* prevent from big number

* modify code style

* modify code style

* modify import

* modify import

* modify code style
  • Loading branch information
Rayman96 authored Aug 24, 2022
1 parent a012d42 commit 3c14b09
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 26 deletions.
8 changes: 8 additions & 0 deletions paddle/phi/backends/gpu/gpu_launch_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ inline GpuLaunchConfig GetGpuLaunchConfig3D(const phi::GPUContext& context,
return config;
}

template <typename Context>
void LimitGridDim(const Context& ctx, dim3* grid_dim) {
  // Clamp each dimension of *grid_dim to the device's maximum grid size so a
  // launch configuration derived from a huge element count stays valid.
  // NOTE(review): assumes ctx is (or wraps) a phi::GPUContext — the cast is
  // unchecked, matching the surrounding code's convention.
  auto max_grid_dim =
      reinterpret_cast<const phi::GPUContext&>(ctx).GetCUDAMaxGridDimSize();
  if (grid_dim->x > max_grid_dim[0]) grid_dim->x = max_grid_dim[0];
  if (grid_dim->y > max_grid_dim[1]) grid_dim->y = max_grid_dim[1];
  if (grid_dim->z > max_grid_dim[2]) grid_dim->z = max_grid_dim[2];
}
} // namespace gpu
} // namespace backends
} // namespace phi
Expand Down
43 changes: 17 additions & 26 deletions paddle/phi/kernels/gpu/poisson_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,55 +20,46 @@ limitations under the License. */
#endif

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/poisson_kernel.h"

namespace phi {

// Draws one Poisson sample per element: out[i] ~ Poisson(in[i]).
//
// Launch layout: 1-D grid, 1-D blocks; CUDA_KERNEL_LOOP_TYPE iterates idx
// over [0, N) with a 64-bit index type, so the kernel is correct for any
// grid size (including grids clamped below N / blockDim).
//
// Each element gets its own Philox4_32_10 counter-based RNG state seeded
// with (seed, subsequence = idx, offset), so results are deterministic for
// a fixed (seed, offset) pair and independent of the launch configuration.
//
// in:     device pointer, N rates (lambda values) — presumably non-negative;
//         curand/hiprand behavior for negative rates is unspecified here.
// out:    device pointer, N sampled counts (written as T).
// N:      element count.
// seed:   generator seed from the framework's RNG.
// offset: generator offset from the framework's RNG.
template <typename T>
__global__ void GetPoisson(
    const T* in, T* out, const int N, unsigned int seed, unsigned int offset) {
  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
#ifdef __NVCC__
    curandStatePhilox4_32_10_t state;
    curand_init(seed, idx, offset, &state);
    out[idx] = static_cast<T>(curand_poisson(&state, in[idx]));
#elif __HIPCC__
    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed, idx, offset, &state);
    out[idx] = static_cast<T>(hiprand_poisson(&state, in[idx]));
#endif
  }
}

// Fills `out` with element-wise Poisson samples whose rates are given by `x`.
//
// Launch configuration: up to 256 threads per block (capped by the device
// limit), a grid sized by ceil-div over the element count, then clamped to
// the device's maximum grid dimensions — the kernel's 64-bit grid-stride
// loop covers any remainder.
template <typename T, typename Context>
void PoissonKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  const T* x_data = x.data<T>();
  T* out_data = ctx.template Alloc<T>(out);
  // NOTE(review): numel() is 64-bit; this truncates for tensors with more
  // than INT_MAX elements. The kernel parameter is `int N`, so widening
  // requires a matching kernel-signature change — TODO confirm and widen.
  const int size = x.numel();
  const int kMaxBlockDim = 256;

  int block_size = std::min(kMaxBlockDim, ctx.GetMaxThreadsPerBlock());
  dim3 dim_block(block_size);
  dim3 dim_grid((size + block_size - 1) / block_size);  // ceil-div
  phi::backends::gpu::LimitGridDim(ctx, &dim_grid);

  // Reserve 20 RNG increments from the framework generator so repeated calls
  // produce fresh, reproducible streams.
  auto gen_cuda = ctx.GetGenerator();
  auto seed_offset = gen_cuda->IncrementOffset(20);
  uint64_t seed = seed_offset.first;
  uint64_t offset = seed_offset.second;

  // BUGFIX: launch on the context's stream instead of the default stream —
  // otherwise this kernel is not ordered with other work enqueued on ctx.
  GetPoisson<T><<<dim_grid, dim_block, 0, ctx.stream()>>>(
      x_data, out_data, size, seed, offset);
}

} // namespace phi
Expand Down

0 comments on commit 3c14b09

Please sign in to comment.