diff --git a/paddle/fluid/operators/dirichlet_op.cc b/paddle/fluid/operators/dirichlet_op.cc
new file mode 100644
index 0000000000000..f981660165717
--- /dev/null
+++ b/paddle/fluid/operators/dirichlet_op.cc
@@ -0,0 +1,125 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/dirichlet_op.h"
+
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
+
+namespace paddle {
+namespace operators {
+template <typename T, typename UniformSamplerT, typename NormalSamplerT>
+struct GammaCPUFunctor {
+  GammaCPUFunctor(const T* alpha, T* gamma,
+                  BaseSampler<T, UniformSamplerT> uniform,
+                  BaseSampler<T, NormalSamplerT> normal)
+      : alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {}
+
+  HOST void operator()(int64_t index) {
+    auto sample = sample_gamma<T, T, UniformSamplerT, NormalSamplerT>(
+        alpha_[index], uniform_, normal_);
+    gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
+  }
+
+  const T* alpha_;
+  T* gamma_;
+  BaseSampler<T, UniformSamplerT> uniform_;
+  BaseSampler<T, NormalSamplerT> normal_;
+};
+
+template <typename T>
+struct DirichletSampler<platform::CPUDeviceContext, T> {
+  void operator()(const framework::ExecutionContext& ctx, const Tensor* alpha,
+                  Tensor* out) {
+    auto& dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
+
+    auto p_gen = framework::DefaultCPUGenerator();
+    auto generator = p_gen->GetCPUEngine();
+
+    auto uniform = [&generator]() -> T {
+      std::uniform_real_distribution<T> u(0.0, 1.0);
+      return u(*generator);
+    };
+    BaseSampler<T, decltype(uniform)> standard_uniform(uniform);
+
+    auto normal = [&generator]() {
+      std::normal_distribution<T> n(0.0, 1.0);
+      return n(*generator);
+    };
+    BaseSampler<T, decltype(normal)> standard_normal(normal);
+
+    // sample from K gamma distributions, where K=alpha.numel()
+    framework::Tensor gamma_samples;
+    gamma_samples.mutable_data<T>(alpha->dims(), dev_ctx.GetPlace());
+    GammaCPUFunctor<T, decltype(uniform), decltype(normal)> gamma_functor(
+        alpha->data<T>(), gamma_samples.data<T>(), standard_uniform,
+        standard_normal);
+    platform::ForRange<platform::CPUDeviceContext> for_range(dev_ctx,
+                                                             alpha->numel());
+    for_range(gamma_functor);
+
+    // normalize them into a simplex, along the last axis
+    framework::Tensor gamma_sum;
+    auto new_shape = gamma_samples.dims();
+    new_shape[new_shape.size() - 1] = 1;
+    gamma_sum.mutable_data<T>(new_shape, dev_ctx.GetPlace());
+
+    ReduceKernelFunctor<platform::CPUDeviceContext, T, SumFunctor>(
+        &gamma_samples, &gamma_sum, {new_shape.size() - 1}, true, false, ctx)
+        .template apply<T>();
+    ElementwiseComputeEx<DivFunctor<T>, platform::CPUDeviceContext, T, T>(
+        ctx, &gamma_samples, &gamma_sum, -1, DivFunctor<T>(), out);
+  }
+};
+
+class DirichletOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Alpha",
+             "(Tensor), The Alpha (concentration) parameter of the Dirichlet "
+             "distribution.");
+    AddOutput("Out", "(Tensor), The sampled output tensor.");
+    AddComment(R"DOC(Sample random data from the Dirichlet distribution.)DOC");
+  }
+};
+
+class DirichletOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Alpha"), "Input", "Alpha", "dirichlet");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "dirichlet");
+    const auto alpha_dim = ctx->GetInputDim("Alpha");
+    PADDLE_ENFORCE_GE(alpha_dim.size(), 1,
+                      platform::errors::InvalidArgument(
+                          "ShapeError: The number of dimensions of 'Alpha' "
+                          "must be greater than or equal to 1. "
+                          "But received Alpha's dimensions = %d.",
+                          alpha_dim.size()));
+    ctx->ShareDim("Alpha", /*->*/ "Out");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP_WITHOUT_GRADIENT(dirichlet, paddle::operators::DirichletOp,
+                             paddle::operators::DirichletOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    dirichlet,
+    paddle::operators::DirichletKernel<paddle::platform::CPUDeviceContext,
+                                       float>,
+    paddle::operators::DirichletKernel<paddle::platform::CPUDeviceContext,
+                                       double>);
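The CPU sampler uses the standard two-step construction: draw g_k ~ Gamma(alpha_k, 1) independently for every element, then normalize along the last axis, since (g_1, ..., g_K) / sum_k g_k is exactly Dirichlet(alpha_1, ..., alpha_K). The clamp to std::numeric_limits<T>::min() keeps underflowed gamma draws strictly positive, so the division never produces NaNs. A minimal NumPy sketch of the same construction (illustration only; `dirichlet_via_gamma` is a hypothetical helper, not part of this patch):

    import numpy as np

    def dirichlet_via_gamma(alpha, sample_shape, rng=np.random.default_rng()):
        # step 1: one Gamma(alpha_k, 1) draw per element
        gamma = rng.gamma(shape=np.broadcast_to(alpha, sample_shape))
        # step 2: clamp away exact zeros, then normalize onto the simplex
        gamma = np.maximum(gamma, np.finfo(gamma.dtype).tiny)
        return gamma / gamma.sum(axis=-1, keepdims=True)

    samples = dirichlet_via_gamma(np.array([1.0, 2.0]), (100000, 2))
    assert np.allclose(samples.sum(axis=-1), 1.0)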
diff --git a/paddle/fluid/operators/dirichlet_op.cu b/paddle/fluid/operators/dirichlet_op.cu
new file mode 100644
index 0000000000000..3e1d523ae0e15
--- /dev/null
+++ b/paddle/fluid/operators/dirichlet_op.cu
@@ -0,0 +1,115 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/operators/dirichlet_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
+#include "paddle/fluid/platform/for_range.h"
+
+#ifdef PADDLE_WITH_CUDA
+#include <curand_kernel.h>
+#endif
+#ifdef PADDLE_WITH_HIP
+#include <hiprand_kernel.h>
+#endif
+
+#if defined(PADDLE_WITH_CUDA)
+using COMPAT_RANDSTATEPHILOX4_32_10_T = curandStatePhilox4_32_10_t;
+#define COMPAT_RAND_INIT curand_init
+#define COMPAT_RAND_UNIFORM curand_uniform
+#define COMPAT_RAND_NORMAL curand_normal
+#elif defined(PADDLE_WITH_HIP)
+using COMPAT_RANDSTATEPHILOX4_32_10_T = hiprandStatePhilox4_32_10_t;
+#define COMPAT_RAND_INIT hiprand_init
+#define COMPAT_RAND_UNIFORM hiprand_uniform
+#define COMPAT_RAND_NORMAL hiprand_normal
+#endif
+
+namespace paddle {
+namespace operators {
+template <typename T>
+struct GammaCUDAFunctor {
+  GammaCUDAFunctor(const T* alpha, T* gamma, uint64_t seed, uint64_t offset)
+      : alpha_(alpha), gamma_(gamma), seed_(seed), offset_(offset) {}
+
+  DEVICE void operator()(int64_t index) {
+    // curand initialization
+    COMPAT_RANDSTATEPHILOX4_32_10_T state;
+    COMPAT_RAND_INIT(/*seed=*/seed_, /*subsequence=*/index, /*offset=*/offset_,
+                     &state);
+
+    // sample
+    auto uniform_lambda = [&state]() { return COMPAT_RAND_UNIFORM(&state); };
+    BaseSampler<T, decltype(uniform_lambda)> standard_uniform(uniform_lambda);
+    auto normal_lambda = [&state]() { return COMPAT_RAND_NORMAL(&state); };
+    BaseSampler<T, decltype(normal_lambda)> standard_normal(normal_lambda);
+
+    auto sample =
+        sample_gamma<T, T, decltype(uniform_lambda), decltype(normal_lambda)>(
+            alpha_[index], standard_uniform, standard_normal);
+    gamma_[index] = std::max(std::numeric_limits<T>::min(), sample);
+  }
+
+  const T* alpha_;
+  T* gamma_;
+  const uint64_t seed_;
+  const uint64_t offset_;
+};
+
+template <typename T>
+struct DirichletSampler<platform::CUDADeviceContext, T> {
+  void operator()(const framework::ExecutionContext& ctx,
+                  const framework::Tensor* alpha, framework::Tensor* out) {
+    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
+
+    // init state, seed & offset for all threads
+    int device_id =
+        BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
+    auto p_gen = framework::GetDefaultCUDAGenerator(device_id);
+    auto seed_and_offset = p_gen->IncrementOffset(10);  // hard-coded offset
+    auto seed = seed_and_offset.first;
+    auto offset = seed_and_offset.second;
+
+    // sample from K gamma distributions, where K=alpha.numel()
+    framework::Tensor gamma_samples;
+    gamma_samples.mutable_data<T>(alpha->dims(), dev_ctx.GetPlace());
+    GammaCUDAFunctor<T> gamma_functor(alpha->data<T>(),
+                                      gamma_samples.data<T>(), seed, offset);
+    platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
+                                                              out->numel());
+    for_range(gamma_functor);
+
+    // normalize them into a simplex, along the last axis
+    framework::Tensor gamma_sum;
+    auto new_shape = gamma_samples.dims();
+    new_shape[new_shape.size() - 1] = 1;
+    gamma_sum.mutable_data<T>(new_shape, dev_ctx.GetPlace());
+
+    ReduceKernelFunctor<platform::CUDADeviceContext, T, SumFunctor>(
+        &gamma_samples, &gamma_sum, {new_shape.size() - 1}, true, false, ctx)
+        .template apply<T>();
+    ElementwiseComputeEx<DivFunctor<T>, platform::CUDADeviceContext, T, T>(
+        ctx, &gamma_samples, &gamma_sum, -1, DivFunctor<T>(), out);
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    dirichlet,
+    ops::DirichletKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::DirichletKernel<paddle::platform::CUDADeviceContext, double>);
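On the GPU path, every element index gets its own Philox substream: curand_init receives the shared seed, the element index as the subsequence, and the offset returned by IncrementOffset(10). Since the rejection loop consumes a variable but small number of uniform/normal draws per element, advancing the global offset by a fixed budget per launch (the "hard-coded offset" noted in the code) keeps subsequent launches from replaying generator state. A loose NumPy analogue of per-index independent streams, with SeedSequence/Philox standing in for curand (illustration only, not part of the patch):

    import numpy as np

    seed, n_elems = 1234, 4
    # one independent substream per element index, in the spirit of
    # curand_init(seed, /*subsequence=*/index, offset, &state)
    children = np.random.SeedSequence(seed).spawn(n_elems)
    streams = [np.random.Generator(np.random.Philox(s)) for s in children]
    draws = [(g.random(), g.standard_normal()) for g in streams]
    print(draws)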
diff --git a/paddle/fluid/operators/dirichlet_op.h b/paddle/fluid/operators/dirichlet_op.h
new file mode 100644
index 0000000000000..540acad423aa3
--- /dev/null
+++ b/paddle/fluid/operators/dirichlet_op.h
@@ -0,0 +1,129 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include <random>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/for_range.h"
+
+// ROCM hcc doesn't work well with using std:: in kernel functions
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#define COMPAT_EXP exp
+#define COMPAT_CEIL ceil
+#define COMPAT_FLOOR floor
+#define COMPAT_LOG log
+#define COMPAT_POW pow
+#define COMPAT_SQRT sqrt
+#define COMPAT_TAN tan
+#define COMPAT_ABS abs
+#define COMPAT_LOG1P log1p
+#else
+#define COMPAT_EXP std::exp
+#define COMPAT_CEIL std::ceil
+#define COMPAT_FLOOR std::floor
+#define COMPAT_LOG std::log
+#define COMPAT_POW std::pow
+#define COMPAT_SQRT std::sqrt
+#define COMPAT_TAN std::tan
+#define COMPAT_ABS std::abs
+#define COMPAT_LOG1P std::log1p
+#endif
+
+namespace paddle {
+namespace operators {
+template <typename DeviceContext, typename T>
+struct DirichletSampler;
+
+template <typename ScalarT, typename SamplerT>
+struct BaseSampler {
+  SamplerT sampler_;
+  HOSTDEVICE BaseSampler(const SamplerT& sampler) : sampler_(sampler) {}
+  HOSTDEVICE ScalarT sample() { return sampler_(); }
+};
+
+// `sample_gamma` is derived from Numpy's distributions.c, with support added
+// for paddle data types and code style.
+// Source MIT licensed:
+/* Copyright 2005 Robert Kern (robert.kern@gmail.com)
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+template <typename ScalarT, typename AccscalarT, typename UniformSamplerT,
+          typename NormalSamplerT>
+HOSTDEVICE ScalarT sample_gamma(
+    ScalarT alpha, BaseSampler<AccscalarT, UniformSamplerT> standard_uniform,
+    BaseSampler<AccscalarT, NormalSamplerT> standard_normal) {
+  AccscalarT scale = 1.0f;
+
+  // Boost alpha for higher acceptance probability.
+  if (alpha < 1.0f) {
+    if (alpha == 0.f) return 0.f;
+    scale *= COMPAT_POW(1 - standard_uniform.sample(), 1.0f / alpha);
+    alpha += 1.0f;
+  }
+
+  // This implements the acceptance-rejection method of Marsaglia and Tsang
+  // (2000)
+  // doi:10.1145/358407.358414
+  const AccscalarT d = alpha - 1.0f / 3.0f;
+  const AccscalarT c = 1.0f / COMPAT_SQRT(9.0f * d);
+  for (;;) {
+    AccscalarT x, y;
+    do {
+      x = standard_normal.sample();
+      y = 1.0f + c * x;
+    } while (y <= 0);
+    const AccscalarT v = y * y * y;
+    const AccscalarT u = 1 - standard_uniform.sample();
+    const AccscalarT xx = x * x;
+    if (u < 1.0f - 0.0331f * xx * xx)
+      return static_cast<ScalarT>(scale * d * v);
+    if (COMPAT_LOG(u) < 0.5f * xx + d * (1.0f - v + COMPAT_LOG(v)))
+      return static_cast<ScalarT>(scale * d * v);
+  }
+}
+
+template <typename DeviceContext, typename T>
+class DirichletKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto* alpha = ctx.Input<framework::Tensor>("Alpha");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(ctx.GetPlace());
+
+    DirichletSampler<DeviceContext, T> sampler;
+    sampler(ctx, alpha, out);
+  }
+};
+}  // namespace operators
+}  // namespace paddle
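The rejection loop in `sample_gamma` is the Marsaglia and Tsang (2000) method: a fast "squeeze" acceptance on u < 1 - 0.0331 x^4, an exact log-comparison fallback, and the alpha < 1 case boosted via the identity Gamma(a) = Gamma(a + 1) * U^(1/a). A direct Python transcription, handy for checking the math against NumPy (illustration only; not part of the patch):

    import numpy as np

    def sample_gamma(alpha, rng=np.random.default_rng()):
        # mirrors the C++ `sample_gamma` above
        scale = 1.0
        if alpha < 1.0:  # boost: Gamma(a) == Gamma(a + 1) * U**(1/a)
            if alpha == 0.0:
                return 0.0
            scale *= (1.0 - rng.uniform()) ** (1.0 / alpha)
            alpha += 1.0
        d = alpha - 1.0 / 3.0
        c = 1.0 / np.sqrt(9.0 * d)
        while True:
            x = rng.standard_normal()
            y = 1.0 + c * x
            if y <= 0.0:
                continue
            v = y ** 3
            u = 1.0 - rng.uniform()
            if u < 1.0 - 0.0331 * x ** 4:  # fast acceptance ("squeeze" test)
                return scale * d * v
            if np.log(u) < 0.5 * x * x + d * (1.0 - v + np.log(v)):
                return scale * d * v

    # E[Gamma(a, 1)] = a, so the empirical mean should sit near 0.5
    print(np.mean([sample_gamma(0.5) for _ in range(100000)]))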
diff --git a/python/paddle/fluid/tests/unittests/distribution/test_dirichlet_op.py b/python/paddle/fluid/tests/unittests/distribution/test_dirichlet_op.py
new file mode 100644
index 0000000000000..3e7662b573e0d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/distribution/test_dirichlet_op.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import re
+import sys
+import unittest
+
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+import paddle.fluid.dygraph as dg
+import paddle.static as static
+import scipy.stats
+from numpy.random import random as rand
+sys.path.append("../")
+from op_test import OpTest
+from paddle.fluid import Program, program_guard
+
+paddle.enable_static()
+
+
+class TestDirichletOp(OpTest):
+    # Because dirichlet random sampling has no gradient, we skip the
+    # gradient check.
+    no_need_check_grad = True
+
+    def setUp(self):
+        self.op_type = "dirichlet"
+        self.alpha = np.array((1., 2.))
+        self.sample_shape = (100000, 2)
+
+        self.inputs = {'Alpha': np.broadcast_to(self.alpha, self.sample_shape)}
+        self.attrs = {}
+        self.outputs = {'Out': np.zeros(self.sample_shape)}
+
+    def test_check_output(self):
+        self.check_output_customized(self._hypothesis_testing)
+
+    def _hypothesis_testing(self, outs):
+        self.assertEqual(outs[0].shape, self.sample_shape)
+        self.assertTrue(np.all(outs[0] > 0.0))
+        self.assertLess(
+            scipy.stats.kstest(
+                outs[0][:, 0],
+                # scipy's dirichlet has no cdf; use the beta marginal instead.
+                scipy.stats.beta(
+                    a=self.alpha[0], b=self.alpha[1]).cdf)[0],
+            0.01)
+
+
+if __name__ == '__main__':
+    unittest.main()
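The hypothesis test relies on the fact that the first coordinate of a Dirichlet(alpha_1, alpha_2) sample is marginally Beta(alpha_1, alpha_2); since scipy.stats.dirichlet exposes no cdf, the Beta cdf stands in for the Kolmogorov-Smirnov comparison. The same statistic can be sanity-checked against NumPy's own Dirichlet sampler (illustration only):

    import numpy as np
    import scipy.stats

    rng = np.random.default_rng(0)
    x = rng.dirichlet((1.0, 2.0), size=100000)[:, 0]
    stat, p = scipy.stats.kstest(x, scipy.stats.beta(a=1.0, b=2.0).cdf)
    print(stat, p)  # the statistic typically lands well below the 0.01 bound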