add dirichlet random sample op in cpu and gpu kernel (PaddlePaddle#38244)

* add dirichlet sample op and cpu backend kernel

* add Dirichlet op cuda kernel  (#6)

* add dirichlet op hip kernel

Co-authored-by: Feiyu Chan <chenfeiyu@baidu.com>
cxxly and Feiyu Chan authored Dec 30, 2021
1 parent cc83c95 commit c5bf09b
Showing 4 changed files with 429 additions and 0 deletions.
125 changes: 125 additions & 0 deletions paddle/fluid/operators/dirichlet_op.cc
@@ -0,0 +1,125 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/dirichlet_op.h"

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"

namespace paddle {
namespace operators {
template <typename T, typename UniformSamplerT, typename NormalSamplerT>
struct GammaCPUFunctor {
GammaCPUFunctor(const T* alpha, T* gamma,
BaseSampler<T, UniformSamplerT> uniform,
BaseSampler<T, NormalSamplerT> normal)
: alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {}

HOST void operator()(int64_t index) {
auto sample = sample_gamma<T, T, UniformSamplerT, NormalSamplerT>(
alpha_[index], uniform_, normal_);
gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
}

const T* alpha_;
T* gamma_;
BaseSampler<T, UniformSamplerT> uniform_;
BaseSampler<T, NormalSamplerT> normal_;
};

template <typename T>
struct DirichletSampler<platform::CPUDeviceContext, T> {
void operator()(const framework::ExecutionContext& ctx, const Tensor* alpha,
Tensor* out) {
auto& dev_ctx = ctx.device_context<platform::CPUDeviceContext>();

auto p_gen = framework::DefaultCPUGenerator();
auto generator = p_gen->GetCPUEngine();

auto uniform = [&generator]() -> T {
std::uniform_real_distribution<T> u(0.0, 1.0);
return u(*generator);
};
BaseSampler<T, decltype(uniform)> standard_uniform(uniform);

auto normal = [&generator]() {
std::normal_distribution<T> n(0.0, 1.0);
return n(*generator);
};
BaseSampler<T, decltype(normal)> standard_normal(normal);

// sample from K gamma distributions, where K=alpha.numel()
framework::Tensor gamma_samples;
gamma_samples.mutable_data<T>(alpha->dims(), dev_ctx.GetPlace());
GammaCPUFunctor<T, decltype(uniform), decltype(normal)> gamma_functor(
alpha->data<T>(), gamma_samples.data<T>(), standard_uniform,
standard_normal);
platform::ForRange<platform::CPUDeviceContext> for_range(dev_ctx,
alpha->numel());
for_range(gamma_functor);

// normalize them into a simplex, along the last axis
framework::Tensor gamma_sum;
auto new_shape = gamma_samples.dims();
new_shape[new_shape.size() - 1] = 1;
gamma_sum.mutable_data<T>(new_shape, dev_ctx.GetPlace());

ReduceKernelFunctor<platform::CPUDeviceContext, T, SumFunctor>(
&gamma_samples, &gamma_sum, {new_shape.size() - 1}, true, false, ctx)
.template apply<T>();
ElementwiseComputeEx<DivFunctor<T>, platform::CPUDeviceContext, T, T>(
ctx, &gamma_samples, &gamma_sum, -1, DivFunctor<T>(), out);
}
};

class DirichletOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Alpha", "(Tensor), The dirichlet Alpha parameter");
AddOutput("Out", "(Tensor), The output tensor of sample");
AddComment(R"DOC(Sample random data from dirichlet distribution.)DOC");
}
};

class DirichletOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Alpha"), "Input", "Alpha", "dirichlet");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "dirichlet");
const auto alpha_dim = ctx->GetInputDim("Alpha");
PADDLE_ENFORCE_GE(alpha_dim.size(), 1,
platform::errors::InvalidArgument(
"ShapeError: The number of dimensions of 'Alpha' "
"must be greater than or euqal to 1. "
"But received Alpha's dimensions = %d,",
alpha_dim.size()));
ctx->ShareDim("Alpha", /*->*/ "Out");
}
};

} // namespace operators
} // namespace paddle

REGISTER_OP_WITHOUT_GRADIENT(dirichlet, paddle::operators::DirichletOp,
paddle::operators::DirichletOpMaker);
REGISTER_OP_CPU_KERNEL(
dirichlet,
paddle::operators::DirichletKernel<paddle::platform::CPUDeviceContext,
float>,
paddle::operators::DirichletKernel<paddle::platform::CPUDeviceContext,
double>);
115 changes: 115 additions & 0 deletions paddle/fluid/operators/dirichlet_op.cu
@@ -0,0 +1,115 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/dirichlet_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include "paddle/fluid/platform/for_range.h"

#ifdef PADDLE_WITH_CUDA
#include <curand_kernel.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hiprand_kernel.h>
#endif

#if defined(PADDLE_WITH_CUDA)
using COMPAT_RANDSTATEPHILOX4_32_10_T = curandStatePhilox4_32_10_t;
#define COMPAT_RAND_INIT curand_init
#define COMPAT_RAND_UNIFORM curand_uniform
#define COMPAT_RAND_NORMAL curand_normal
#elif defined(PADDLE_WITH_HIP)
using COMPAT_RANDSTATEPHILOX4_32_10_T = hiprandStatePhilox4_32_10_t;
#define COMPAT_RAND_INIT hiprand_init
#define COMPAT_RAND_UNIFORM hiprand_uniform
#define COMPAT_RAND_NORMAL hiprand_normal
#endif

namespace paddle {
namespace operators {
template <typename T>
struct GammaCUDAFunctor {
GammaCUDAFunctor(const T* alpha, T* gamma, uint64_t seed, uint64_t offset)
: alpha_(alpha), gamma_(gamma), seed_(seed), offset_(offset) {}

DEVICE void operator()(int64_t index) {
// curand initialization
COMPAT_RANDSTATEPHILOX4_32_10_T state;
COMPAT_RAND_INIT(/*seed=*/seed_, /*subsequence=*/index, /*offset=*/offset_,
&state);

// sample
auto uniform_lambda = [&state]() { return COMPAT_RAND_UNIFORM(&state); };
BaseSampler<T, decltype(uniform_lambda)> standard_uniform(uniform_lambda);
auto normal_lambda = [&state]() { return COMPAT_RAND_NORMAL(&state); };
BaseSampler<T, decltype(normal_lambda)> standard_normal(normal_lambda);

auto sample =
sample_gamma<T, T, decltype(uniform_lambda), decltype(normal_lambda)>(
alpha_[index], standard_uniform, standard_normal);
gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
}

const T* alpha_;
T* gamma_;
const uint64_t seed_;
const uint64_t offset_;
};

template <typename T>
struct DirichletSampler<platform::CUDADeviceContext, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* alpha, framework::Tensor* out) {
auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();

// init state, seed & offset for all threads
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
auto p_gen = framework::GetDefaultCUDAGenerator(device_id);
auto seed_and_offset = p_gen->IncrementOffset(10); // hard-coded offset
auto seed = seed_and_offset.first;
auto offset = seed_and_offset.second;

// sample from K gamma distributions, where K=alpha.numel()
framework::Tensor gamma_samples;
gamma_samples.mutable_data<T>(alpha->dims(), dev_ctx.GetPlace());
GammaCUDAFunctor<T> gamma_functor(alpha->data<T>(), gamma_samples.data<T>(),
seed, offset);
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
out->numel());
for_range(gamma_functor);

// normalize them into a simplex, along the last axis
framework::Tensor gamma_sum;
auto new_shape = gamma_samples.dims();
new_shape[new_shape.size() - 1] = 1;
gamma_sum.mutable_data<T>(new_shape, dev_ctx.GetPlace());

ReduceKernelFunctor<platform::CUDADeviceContext, T, SumFunctor>(
&gamma_samples, &gamma_sum, {new_shape.size() - 1}, true, false, ctx)
.template apply<T>();
ElementwiseComputeEx<DivFunctor<T>, platform::CUDADeviceContext, T, T>(
ctx, &gamma_samples, &gamma_sum, -1, DivFunctor<T>(), out);
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
dirichlet, ops::DirichletKernel<paddle::platform::CUDADeviceContext, float>,
ops::DirichletKernel<paddle::platform::CUDADeviceContext, double>);
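
The CUDA kernel above relies on a counter-based Philox RNG: each element seeds its own `curandStatePhilox4_32_10_t` with a shared seed and its index as the subsequence, so every element gets an independent stream with no shared mutable state. Below is a minimal standalone sketch of that pattern, assuming only a CUDA toolchain with cuRAND; the file name and kernel are illustrative and not part of this commit.

// Standalone sketch of the per-thread Philox pattern used by
// GammaCUDAFunctor: one counter-based state per thread, keyed by a shared
// seed and a per-thread subsequence. Build with: nvcc philox_demo.cu
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>
#include <curand_kernel.h>

__global__ void uniform_per_thread(float* out, int n, uint64_t seed,
                                   uint64_t offset) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, /*subsequence=*/i, /*offset=*/offset, &state);
  out[i] = curand_uniform(&state);  // one draw from thread i's stream
}

int main() {
  const int n = 8;
  float h_out[n];
  float* d_out;
  cudaMalloc(&d_out, n * sizeof(float));
  uniform_per_thread<<<1, n>>>(d_out, n, /*seed=*/42, /*offset=*/0);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%f\n", h_out[i]);
  cudaFree(d_out);
  return 0;
}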
129 changes: 129 additions & 0 deletions paddle/fluid/operators/dirichlet_op.h
@@ -0,0 +1,129 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cmath>
#include <random>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"

// ROCm hcc doesn't work well with std:: in kernel functions
#if defined(PADDLE_WITH_CUDA)
#define COMPAT_EXP exp
#define COMPAT_CEIL ceil
#define COMPAT_FLOOR floor
#define COMPAT_LOG log
#define COMPAT_POW pow
#define COMPAT_SQRT sqrt
#define COMPAT_TAN tan
#define COMPAT_ABS abs
#define COMPAT_LOG1P log1p
#else
#define COMPAT_EXP std::exp
#define COMPAT_CEIL std::ceil
#define COMPAT_FLOOR std::floor
#define COMPAT_LOG std::log
#define COMPAT_POW std::pow
#define COMPAT_SQRT std::sqrt
#define COMPAT_TAN std::tan
#define COMPAT_ABS std::abs
#define COMPAT_LOG1P std::log1p
#endif

namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
struct DirichletSampler;

template <typename ScalarT, typename SamplerT>
struct BaseSampler {
SamplerT sampler_;
HOSTDEVICE BaseSampler(const SamplerT& sampler) : sampler_(sampler) {}
HOSTDEVICE ScalarT sample() { return sampler_(); }
};

// `sample_gamma` is adapted from NumPy's distributions.c, with support
// added for paddle data types and code style.
// Source MIT licensed:
/* Copyright 2005 Robert Kern (robert.kern@gmail.com)
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

template <typename ScalarT, typename AccscalarT, typename UniformSamplerT,
typename NormalSamplerT>
HOSTDEVICE ScalarT sample_gamma(
ScalarT alpha, BaseSampler<AccscalarT, UniformSamplerT> standard_uniform,
BaseSampler<AccscalarT, NormalSamplerT> standard_normal) {
AccscalarT scale = 1.0f;

// Boost alpha for higher acceptance probability.
if (alpha < 1.0f) {
if (alpha == 0.f) return 0.f;
scale *= COMPAT_POW(1 - standard_uniform.sample(), 1.0f / alpha);
alpha += 1.0f;
}

// This implements the acceptance-rejection method of Marsaglia and Tsang
// (2000)
// doi:10.1145/358407.358414
const AccscalarT d = alpha - 1.0f / 3.0f;
const AccscalarT c = 1.0f / COMPAT_SQRT(9.0f * d);
for (;;) {
AccscalarT x, y;
do {
x = standard_normal.sample();
y = 1.0f + c * x;
} while (y <= 0);
const AccscalarT v = y * y * y;
const AccscalarT u = 1 - standard_uniform.sample();
const AccscalarT xx = x * x;
if (u < 1.0f - 0.0331f * xx * xx)
return static_cast<ScalarT>(scale * d * v);
if (COMPAT_LOG(u) < 0.5f * xx + d * (1.0f - v + COMPAT_LOG(v)))
return static_cast<ScalarT>(scale * d * v);
}
}

template <typename DeviceContext, typename T>
class DirichletKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* alpha = ctx.Input<framework::Tensor>("Alpha");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());

DirichletSampler<DeviceContext, T> sampler;
sampler(ctx, alpha, out);
}
};
} // namespace operators
} // namespace paddle
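
For reference, the accept/reject loop in `sample_gamma` above is Marsaglia and Tsang's (2000) method. In the code's notation:

% Marsaglia-Tsang rejection sampler for Gamma(alpha, 1), alpha >= 1
d = \alpha - \tfrac{1}{3}, \qquad c = \frac{1}{\sqrt{9d}}, \qquad
x \sim \mathcal{N}(0, 1), \quad v = (1 + c\,x)^3
\ \ (\text{resampling } x \text{ until } 1 + c\,x > 0), \quad
u \sim \mathcal{U}(0, 1)

The sample \mathrm{scale} \cdot d\,v is accepted when either the cheap squeeze or the exact log test passes:

u < 1 - 0.0331\,x^4
\qquad\text{or}\qquad
\log u < \tfrac{1}{2}x^2 + d\,(1 - v + \log v)

For \alpha < 1 the code first boosts the shape, using the identity \mathrm{Gamma}(\alpha) \overset{d}{=} \mathrm{Gamma}(\alpha + 1)\cdot U^{1/\alpha}; this is the `scale` factor applied at the top of the function.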