diff --git a/dpnp/backend.pxd b/dpnp/backend.pxd index f0bca00deb8b..b813d0468b81 100644 --- a/dpnp/backend.pxd +++ b/dpnp/backend.pxd @@ -66,6 +66,7 @@ cdef extern from "backend/backend_iface_fptr.hpp" namespace "DPNPFuncName": # n DPNP_FN_FLOOR DPNP_FN_FLOOR_DIVIDE DPNP_FN_FMOD + DPNP_FN_GAMMA DPNP_FN_GAUSSIAN DPNP_FN_HYPOT DPNP_FN_INVERT diff --git a/dpnp/backend/backend_iface_fptr.hpp b/dpnp/backend/backend_iface_fptr.hpp index 0ad87d8af3eb..9b1b0cd970e3 100644 --- a/dpnp/backend/backend_iface_fptr.hpp +++ b/dpnp/backend/backend_iface_fptr.hpp @@ -95,6 +95,7 @@ enum class DPNPFuncName : size_t DPNP_FN_FLOOR, /**< Used in numpy.floor() implementation */ DPNP_FN_FLOOR_DIVIDE, /**< Used in numpy.floor_divide() implementation */ DPNP_FN_FMOD, /**< Used in numpy.fmod() implementation */ + DPNP_FN_GAMMA, /**< Used in numpy.random.gamma() implementation */ DPNP_FN_GAUSSIAN, /**< Used in numpy.random.randn() implementation */ DPNP_FN_HYPOT, /**< Used in numpy.hypot() implementation */ DPNP_FN_INVERT, /**< Used in numpy.invert() implementation */ diff --git a/dpnp/backend/custom_kernels_random.cpp b/dpnp/backend/custom_kernels_random.cpp index b5ddf84be546..ffb319bff7a9 100644 --- a/dpnp/backend/custom_kernels_random.cpp +++ b/dpnp/backend/custom_kernels_random.cpp @@ -64,6 +64,25 @@ void custom_rng_exponential_c(void* result, _DataType beta, size_t size) event_out.wait(); } +template +void custom_rng_gamma_c(void* result, _DataType shape, _DataType scale, size_t size) +{ + if (!size) + { + return; + } + + // set displacement a + const _DataType a = (_DataType(0.0)); + + _DataType* result1 = reinterpret_cast<_DataType*>(result); + + mkl_rng::gamma<_DataType> distribution(shape, a, scale); + // perform generation + auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1); + event_out.wait(); +} + template void custom_rng_gaussian_c(void* result, _DataType mean, _DataType stddev, size_t size) { @@ -107,6 +126,9 @@ void func_map_init_random(func_map_t& fmap) fmap[DPNPFuncName::DPNP_FN_EXPONENTIAL][eft_DBL][eft_DBL] = {eft_DBL, (void*)custom_rng_exponential_c}; fmap[DPNPFuncName::DPNP_FN_EXPONENTIAL][eft_FLT][eft_FLT] = {eft_FLT, (void*)custom_rng_exponential_c}; + fmap[DPNPFuncName::DPNP_FN_GAMMA][eft_DBL][eft_DBL] = {eft_DBL, (void*)custom_rng_gamma_c}; + fmap[DPNPFuncName::DPNP_FN_GAMMA][eft_FLT][eft_FLT] = {eft_FLT, (void*)custom_rng_gamma_c}; + fmap[DPNPFuncName::DPNP_FN_GAUSSIAN][eft_DBL][eft_DBL] = {eft_DBL, (void*)custom_rng_gaussian_c}; fmap[DPNPFuncName::DPNP_FN_GAUSSIAN][eft_FLT][eft_FLT] = {eft_FLT, (void*)custom_rng_gaussian_c}; diff --git a/dpnp/random/_random.pyx b/dpnp/random/_random.pyx index 611bd676f61f..488b5a722569 100644 --- a/dpnp/random/_random.pyx +++ b/dpnp/random/_random.pyx @@ -45,6 +45,7 @@ cimport numpy __all__ = [ "dpnp_chisquare", "dpnp_exponential", + "dpnp_gamma", "dpnp_randn", "dpnp_random", "dpnp_srand", @@ -54,6 +55,7 @@ __all__ = [ ctypedef void(*fptr_custom_rng_chi_square_c_1out_t)(void *, int, size_t) ctypedef void(*fptr_custom_rng_exponential_c_1out_t)(void *, double, size_t) +ctypedef void(*fptr_custom_rng_gamma_c_1out_t)(void *, double, double, size_t) ctypedef void(*fptr_custom_rng_gaussian_c_1out_t)(void *, double, double, size_t) ctypedef void(*fptr_custom_rng_uniform_c_1out_t)(void *, long, long, size_t) @@ -111,6 +113,42 @@ cpdef dparray dpnp_exponential(double beta, size): return result +cpdef dparray dpnp_gamma(double shape, double scale, size): + """ + Returns an array populated with samples from gamma distribution. + + `dpnp_gamma` generates a matrix filled with random floats sampled from a + univariate gamma distribution of `shape` and `scale`. + + """ + + dtype = numpy.float64 + cdef dparray result + cdef DPNPFuncType param1_type + cdef DPNPFuncData kernel_data + cdef fptr_custom_rng_gamma_c_1out_t func + + if shape == 0.0 or scale==0.0: + result = dparray(size, dtype=dtype) + result.fill(0.0) + else: + # convert string type names (dparray.dtype) to C enum DPNPFuncType + param1_type = dpnp_dtype_to_DPNPFuncType(dtype) + + # get the FPTR data structure + kernel_data = get_dpnp_function_ptr(DPNP_FN_GAMMA, param1_type, param1_type) + + result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type) + # ceate result array with type given by FPTR data + result = dparray(size, dtype=result_type) + + func = kernel_data.ptr + # call FPTR function + func(result.get_data(), shape, scale, result.size) + + return result + + cpdef dparray dpnp_randn(dims): """ Returns an array populated with samples from standard normal distribution. diff --git a/dpnp/random/dpnp_iface_random.py b/dpnp/random/dpnp_iface_random.py index ce19a018751f..1a8229a479c6 100644 --- a/dpnp/random/dpnp_iface_random.py +++ b/dpnp/random/dpnp_iface_random.py @@ -1,545 +1,621 @@ -# cython: language_level=3 -# -*- coding: utf-8 -*- -# ***************************************************************************** -# Copyright (c) 2016-2020, Intel Corporation -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# - Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# - Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -# THE POSSIBILITY OF SUCH DAMAGE. -# ***************************************************************************** - -""" -Module Intel NumPy Random - -Set of functions to implement NumPy random module API - - .. seealso:: :meth:`numpy.random` - -""" - - -import dpnp -import numpy - -from dpnp.dparray import dparray -from dpnp.dpnp_utils import * -from dpnp.random._random import * - - -__all__ = [ - 'chisquare', - 'exponential', - 'rand', - 'ranf', - 'randint', - 'randn', - 'random', - 'random_integers', - 'random_sample', - 'seed', - 'sample', - 'uniform' -] - - -def chisquare(df, size=None): - """ - chisquare(df, size=None) - - Draw samples from a chi-square distribution. - - When `df` independent random variables, each with standard normal - distributions (mean 0, variance 1), are squared and summed, the - resulting distribution is chi-square (see Notes). This distribution - is often used in hypothesis testing. - - Parameters - ---------- - df : float - Number of degrees of freedom, must be > 0. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``df`` is a scalar. Otherwise, - ``np.array(df).size`` samples are drawn. - - Returns - ------- - out : ndarray or scalar - Drawn samples from the parameterized chi-square distribution. - - Raises - ------ - ValueError - When `df` <= 0 or when an inappropriate `size` (e.g. ``size=-1``) - is given. - - Examples - -------- - >>> dpnp.random.chisquare(2,4) - array([ 1.89920014, 9.00867716, 3.13710533, 5.62318272]) # random - - """ - - if not use_origin_backend(df): - if size is None: - size = 1 - elif isinstance(size, tuple): - for dim in size: - if not isinstance(dim, int): - checker_throw_value_error("chisquare", "type(dim)", type(dim), int) - elif not isinstance(size, int): - checker_throw_value_error("chisquare", "type(size)", type(size), int) - - # TODO: - # array_like of floats for `df` - # add check for df array like, after adding array-like interface for df param - if df <= 0: - checker_throw_value_error("chisquare", "df", df, "positive") - # TODO: - # float to int, safe - return dpnp_chisquare(int(df), size) - - return call_origin(numpy.random.chisquare, df, size) - - -def exponential(scale=1.0, size=None): - """Exponential distribution. - - Draw samples from an exponential distribution. - - Its probability density function is - - .. math:: f(x; \\frac{1}{\\beta}) = \\frac{1}{\\beta} \\exp(-\\frac{x}{\\beta}), - - for ``x > 0`` and 0 elsewhere. :math:`\\beta` is the scale parameter, - which is the inverse of the rate parameter :math:`\\lambda = 1/\\beta`. - The rate parameter is an alternative, widely used parameterization - of the exponential distribution [3]_. - - The exponential distribution is a continuous analogue of the - geometric distribution. It describes many common situations, such as - the size of raindrops measured over many rainstorms [1]_, or the time - between page requests to Wikipedia [2]_. - - .. note:: - New code should use the ``exponential`` method of a ``default_rng()`` - instance instead; please see the :ref:`random-quick-start`. - - Parameters - ---------- - scale : float - The scale parameter, :math:`\\beta = 1/\\lambda`. Must be - non-negative. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``scale`` is a scalar. Otherwise, - ``np.array(scale).size`` samples are drawn. - - Returns - ------- - out : dparray - Drawn samples from the parameterized exponential distribution. - - References - ---------- - .. [1] Peyton Z. Peebles Jr., "Probability, Random Variables and - Random Signal Principles", 4th ed, 2001, p. 57. - .. [2] Wikipedia, "Poisson process", - https://en.wikipedia.org/wiki/Poisson_process - .. [3] Wikipedia, "Exponential distribution", - https://en.wikipedia.org/wiki/Exponential_distribution - - """ - - if not use_origin_backend(scale): - if size is None: - size = 1 - elif isinstance(size, tuple): - for dim in size: - if not isinstance(dim, int): - checker_throw_value_error("exponential", "type(dim)", type(dim), int) - elif not isinstance(size, int): - checker_throw_value_error("exponential", "type(size)", type(size), int) - - if scale < 0: - checker_throw_value_error("exponential", "scale", scale, "non-negative") - - return dpnp_exponential(scale, size) - - return call_origin(numpy.random.exponential, scale, size) - - -def rand(d0, *dn): - """ - Create an array of the given shape and populate it - with random samples from a uniform distribution over [0, 1). - - Parameters - ---------- - d0, d1, …, dn : The dimensions of the returned array, must be non-negative. - - Returns - ------- - out : Random values. - - See Also - -------- - random - - """ - - if not use_origin_backend(d0): - dims = tuple([d0, *dn]) - - for dim in dims: - if not isinstance(dim, int): - checker_throw_value_error("rand", "type(dim)", type(dim), int) - return dpnp_random(dims) - - return call_origin(numpy.random.rand, d0, *dn) - - -def ranf(size): - """ - Return random floats in the half-open interval [0.0, 1.0). - This is an alias of random_sample. - - Parameters - ---------- - size : Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. - - Returns - ------- - out : Array of random floats of shape size. - - See Also - -------- - random - - """ - - if not use_origin_backend(size): - for dim in size: - if not isinstance(dim, int): - checker_throw_value_error("ranf", "type(dim)", type(dim), int) - return dpnp_random(size) - - return call_origin(numpy.random.ranf, size) - - -def randint(low, high=None, size=None, dtype=int): - """ - randint(low, high=None, size=None, dtype=int) - - Return random integers from `low` (inclusive) to `high` (exclusive). - Return random integers from the "discrete uniform" distribution of - the specified dtype in the "half-open" interval [`low`, `high`). If - `high` is None (the default), then results are from [0, `low`). - - Parameters - ---------- - low : int - Lowest (signed) integer to be drawn from the distribution (unless - ``high=None``, in which case this parameter is one above the - *highest* such integer). - high : int, optional - If provided, one above the largest (signed) integer to be drawn - from the distribution. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. Default is None, in which case a - single value is returned. - dtype : dtype, optional - Desired dtype of the result. Byteorder must be native. - The default value is int. - Returns - ------- - out : array of random ints - `size`-shaped array of random integers from the appropriate - distribution, or a single such random int if `size` not provided. - See Also - -------- - random_integers : similar to `randint`, only for the closed - interval [`low`, `high`], and 1 is the lowest value if `high` is - omitted. - - """ - - if not use_origin_backend(low): - if size is None: - size = 1 - elif isinstance(size, tuple): - for dim in size: - if not isinstance(dim, int): - checker_throw_value_error("randint", "type(dim)", type(dim), int) - elif not isinstance(size, int): - checker_throw_value_error("randint", "type(size)", type(size), int) - - if high is None: - high = low - low = 0 - - low = int(low) - high = int(high) - - if (low >= high): - checker_throw_value_error("randint", "low", low, high) - - _dtype = numpy.dtype(dtype) - - # TODO: - # supported only int32 - # or just raise error when dtype != numpy.int32 - if _dtype == numpy.int32 or _dtype == numpy.int64: - _dtype = numpy.int32 - else: - raise TypeError('Unsupported dtype %r for randint' % dtype) - return dpnp_uniform(low, high, size, _dtype) - - return call_origin(numpy.random.randint, low, high, size, dtype) - - -def randn(d0, *dn): - """ - If positive int_like arguments are provided, randn generates an array of shape (d0, d1, ..., dn), - filled with random floats sampled from a univariate “normal” (Gaussian) distribution of mean 0 and variance 1. - - Parameters - ---------- - d0, d1, …, dn : The dimensions of the returned array, must be non-negative. - - Returns - ------- - out : (d0, d1, ..., dn)-shaped array of floating-point samples from the standard normal distribution. - - See Also - -------- - standard_normal - normal - - """ - - if not use_origin_backend(d0): - dims = tuple([d0, *dn]) - - for dim in dims: - if not isinstance(dim, int): - checker_throw_value_error("randn", "type(dim)", type(dim), int) - return dpnp_randn(dims) - - return call_origin(numpy.random.randn, d0, *dn) - - -def random(size): - """ - Return random floats in the half-open interval [0.0, 1.0). - Alias for random_sample. - - Parameters - ---------- - size : Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. - - Returns - ------- - out : Array of random floats of shape size. - - See Also - -------- - random - - """ - - if not use_origin_backend(size): - for dim in size: - if not isinstance(dim, int): - checker_throw_value_error("random", "type(dim)", type(dim), int) - return dpnp_random(size) - - return call_origin(numpy.random.random, size) - - -def random_integers(low, high=None, size=None): - """ - random_integers(low, high=None, size=None) - - Random integers between `low` and `high`, inclusive. - Return random integers from the "discrete uniform" distribution in - the closed interval [`low`, `high`]. If `high` is - None (the default), then results are from [1, `low`]. - - Parameters - ---------- - low : int - Lowest (signed) integer to be drawn from the distribution (unless - ``high=None``, in which case this parameter is the *highest* such - integer). - high : int, optional - If provided, the largest (signed) integer to be drawn from the - distribution (see above for behavior if ``high=None``). - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. Default is None, in which case a - single value is returned. - Returns - ------- - out : array of random ints - `size`-shaped array of random integers from the appropriate - distribution, or a single such random int if `size` not provided. - See Also - -------- - randint - - """ - - if not use_origin_backend(low): - if high is None: - high = low - low = 1 - return randint(low, int(high) + 1, size=size) - - return call_origin(numpy.random.random_integers, low, high, size) - - -def random_sample(size): - """ - Return random floats in the half-open interval [0.0, 1.0). - - Parameters - ---------- - size : Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. - - Returns - ------- - out : Array of random floats of shape size. - - See Also - -------- - random - - """ - - if not use_origin_backend(size): - for dim in size: - if not isinstance(dim, int): - checker_throw_value_error("random_sample", "type(dim)", type(dim), int) - return dpnp_random(size) - - return call_origin(numpy.random.random_sample, size) - - -def seed(seed=None): - """ - Reseed a legacy philox4x32x10 random number generator engine - - Parameters - ---------- - seed : {None, int}, optional - - """ - if not use_origin_backend(seed): - # TODO: - # implement seed default value as is in numpy - if seed is None: - seed = 1 - elif not isinstance(seed, int): - checker_throw_value_error("seed", "type(seed)", type(seed), int) - elif seed < 0: - checker_throw_value_error("seed", "seed", seed, "non-negative") - return dpnp_srand(seed) - - return call_origin(numpy.random.seed, seed) - - -def sample(size): - """ - Return random floats in the half-open interval [0.0, 1.0). - This is an alias of random_sample. - - Parameters - ---------- - size : Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. - - Returns - ------- - out : Array of random floats of shape size. - - See Also - -------- - random - - """ - - if not use_origin_backend(size): - for dim in size: - if not isinstance(dim, int): - checker_throw_value_error("sample", "type(dim)", type(dim), int) - return dpnp_random(size) - - return call_origin(numpy.random.sample, size) - - -def uniform(low=0.0, high=1.0, size=None): - """ - uniform(low=0.0, high=1.0, size=None) - - Draw samples from a uniform distribution. - Samples are uniformly distributed over the half-open interval - ``[low, high)`` (includes low, but excludes high). In other words, - any value within the given interval is equally likely to be drawn - by `uniform`. - Parameters - ---------- - low : float, optional - Lower boundary of the output interval. All values generated will be - greater than or equal to low. The default value is 0. - high : float - Upper boundary of the output interval. All values generated will be - less than high. The default value is 1.0. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``low`` and ``high`` are both scalars. - Returns - ------- - out : array or scalar - Drawn samples from the parameterized uniform distribution. - See Also - -------- - random : Floats uniformly distributed over ``[0, 1)``. - - """ - - if not use_origin_backend(low): - if size is None: - size = 1 - if low == high: - # TODO: - # currently dparray.full is not implemented - # return dpnp.dparray.dparray.full(size, low, dtype=numpy.float64) - message = "`low` equal to `high`, should return an array, filled with `low` value." - message += " Currently not supported. See: numpy.full TODO" - checker_throw_runtime_error("uniform", message) - elif low > high: - low, high = high, low - return dpnp_uniform(low, high, size, dtype=numpy.float64) - - return call_origin(numpy.random.uniform, low, high, size) +# cython: language_level=3 +# -*- coding: utf-8 -*- +# ***************************************************************************** +# Copyright (c) 2016-2020, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +""" +Module Intel NumPy Random + +Set of functions to implement NumPy random module API + + .. seealso:: :meth:`numpy.random` + +""" + + +import dpnp +import numpy + +from dpnp.dparray import dparray +from dpnp.dpnp_utils import * +from dpnp.random._random import * + + +__all__ = [ + 'chisquare', + 'exponential', + 'gamma', + 'rand', + 'ranf', + 'randint', + 'randn', + 'random', + 'random_integers', + 'random_sample', + 'seed', + 'sample', + 'uniform' +] + + +def chisquare(df, size=None): + """ + chisquare(df, size=None) + + Draw samples from a chi-square distribution. + + When `df` independent random variables, each with standard normal + distributions (mean 0, variance 1), are squared and summed, the + resulting distribution is chi-square (see Notes). This distribution + is often used in hypothesis testing. + + Parameters + ---------- + df : float + Number of degrees of freedom, must be > 0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``df`` is a scalar. Otherwise, + ``np.array(df).size`` samples are drawn. + + Returns + ------- + out : ndarray or scalar + Drawn samples from the parameterized chi-square distribution. + + Raises + ------ + ValueError + When `df` <= 0 or when an inappropriate `size` (e.g. ``size=-1``) + is given. + + Examples + -------- + >>> dpnp.random.chisquare(2,4) + array([ 1.89920014, 9.00867716, 3.13710533, 5.62318272]) # random + + """ + + if not use_origin_backend(df): + if size is None: + size = 1 + elif isinstance(size, tuple): + for dim in size: + if not isinstance(dim, int): + checker_throw_value_error("chisquare", "type(dim)", type(dim), int) + elif not isinstance(size, int): + checker_throw_value_error("chisquare", "type(size)", type(size), int) + + # TODO: + # array_like of floats for `df` + # add check for df array like, after adding array-like interface for df param + if df <= 0: + checker_throw_value_error("chisquare", "df", df, "positive") + # TODO: + # float to int, safe + return dpnp_chisquare(int(df), size) + + return call_origin(numpy.random.chisquare, df, size) + + +def exponential(scale=1.0, size=None): + """Exponential distribution. + + Draw samples from an exponential distribution. + + Its probability density function is + + .. math:: f(x; \\frac{1}{\\beta}) = \\frac{1}{\\beta} \\exp(-\\frac{x}{\\beta}), + + for ``x > 0`` and 0 elsewhere. :math:`\\beta` is the scale parameter, + which is the inverse of the rate parameter :math:`\\lambda = 1/\\beta`. + The rate parameter is an alternative, widely used parameterization + of the exponential distribution [3]_. + + The exponential distribution is a continuous analogue of the + geometric distribution. It describes many common situations, such as + the size of raindrops measured over many rainstorms [1]_, or the time + between page requests to Wikipedia [2]_. + + .. note:: + New code should use the ``exponential`` method of a ``default_rng()`` + instance instead; please see the :ref:`random-quick-start`. + + Parameters + ---------- + scale : float + The scale parameter, :math:`\\beta = 1/\\lambda`. Must be + non-negative. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``scale`` is a scalar. Otherwise, + ``np.array(scale).size`` samples are drawn. + + Returns + ------- + out : dparray + Drawn samples from the parameterized exponential distribution. + + References + ---------- + .. [1] Peyton Z. Peebles Jr., "Probability, Random Variables and + Random Signal Principles", 4th ed, 2001, p. 57. + .. [2] Wikipedia, "Poisson process", + https://en.wikipedia.org/wiki/Poisson_process + .. [3] Wikipedia, "Exponential distribution", + https://en.wikipedia.org/wiki/Exponential_distribution + + """ + + if not use_origin_backend(scale): + if size is None: + size = 1 + elif isinstance(size, tuple): + for dim in size: + if not isinstance(dim, int): + checker_throw_value_error("exponential", "type(dim)", type(dim), int) + elif not isinstance(size, int): + checker_throw_value_error("exponential", "type(size)", type(size), int) + + if scale < 0: + checker_throw_value_error("exponential", "scale", scale, "non-negative") + + return dpnp_exponential(scale, size) + + return call_origin(numpy.random.exponential, scale, size) + + +def gamma(shape, scale=1.0, size=None): + """Gamma distribution. + + Draw samples from a Gamma distribution. + + Samples are drawn from a Gamma distribution with specified parameters, + `shape` (sometimes designated "k") and `scale` (sometimes designated + "theta"), where both parameters are > 0. + + .. note:: + New code should use the ``gamma`` method of a ``default_rng()`` + instance instead; please see the :ref:`random-quick-start`. + + Parameters + ---------- + shape : float or array_like of floats + The shape of the gamma distribution. Must be non-negative. + scale : float or array_like of floats, optional + The scale of the gamma distribution. Must be non-negative. + Default is equal to 1. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``shape`` and ``scale`` are both scalars. + + Returns + ------- + out : dparray + Drawn samples from the parameterized gamma distribution. + + Notes + ----- + The probability density for the Gamma distribution is + + .. math:: p(x) = x^{k-1}\\frac{e^{-x/\\theta}}{\\theta^k\\Gamma(k)}, + + where :math:`k` is the shape and :math:`\\theta` the scale, + and :math:`\\Gamma` is the Gamma function. + + The Gamma distribution is often used to model the times to failure of + electronic components, and arises naturally in processes for which the + waiting times between Poisson distributed events are relevant. + + References + ---------- + .. [1] Weisstein, Eric W. "Gamma Distribution." From MathWorld--A + Wolfram Web Resource. + http://mathworld.wolfram.com/GammaDistribution.html + .. [2] Wikipedia, "Gamma distribution", + https://en.wikipedia.org/wiki/Gamma_distribution + + """ + + # TODO: + # array_like of floats for `scale` and `shape` + if not use_origin_backend(scale): + if size is None: + size = 1 + elif isinstance(size, tuple): + for dim in size: + if not isinstance(dim, int): + checker_throw_value_error("gamma", "type(dim)", type(dim), int) + elif not isinstance(size, int): + checker_throw_value_error("gamma", "type(size)", type(size), int) + + if scale < 0: + checker_throw_value_error("gamma", "scale", scale, "non-negative") + if shape < 0: + checker_throw_value_error("gamma", "shape", shape, "non-negative") + + return dpnp_gamma(shape, scale, size) + + return call_origin(numpy.random.gamma, shape, scale, size) + + +def rand(d0, *dn): + """ + Create an array of the given shape and populate it + with random samples from a uniform distribution over [0, 1). + + Parameters + ---------- + d0, d1, …, dn : The dimensions of the returned array, must be non-negative. + + Returns + ------- + out : Random values. + + See Also + -------- + random + + """ + + if not use_origin_backend(d0): + dims = tuple([d0, *dn]) + + for dim in dims: + if not isinstance(dim, int): + checker_throw_value_error("rand", "type(dim)", type(dim), int) + return dpnp_random(dims) + + return call_origin(numpy.random.rand, d0, *dn) + + +def ranf(size): + """ + Return random floats in the half-open interval [0.0, 1.0). + This is an alias of random_sample. + + Parameters + ---------- + size : Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. + + Returns + ------- + out : Array of random floats of shape size. + + See Also + -------- + random + + """ + + if not use_origin_backend(size): + for dim in size: + if not isinstance(dim, int): + checker_throw_value_error("ranf", "type(dim)", type(dim), int) + return dpnp_random(size) + + return call_origin(numpy.random.ranf, size) + + +def randint(low, high=None, size=None, dtype=int): + """ + randint(low, high=None, size=None, dtype=int) + + Return random integers from `low` (inclusive) to `high` (exclusive). + Return random integers from the "discrete uniform" distribution of + the specified dtype in the "half-open" interval [`low`, `high`). If + `high` is None (the default), then results are from [0, `low`). + + Parameters + ---------- + low : int + Lowest (signed) integer to be drawn from the distribution (unless + ``high=None``, in which case this parameter is one above the + *highest* such integer). + high : int, optional + If provided, one above the largest (signed) integer to be drawn + from the distribution. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + dtype : dtype, optional + Desired dtype of the result. Byteorder must be native. + The default value is int. + Returns + ------- + out : array of random ints + `size`-shaped array of random integers from the appropriate + distribution, or a single such random int if `size` not provided. + See Also + -------- + random_integers : similar to `randint`, only for the closed + interval [`low`, `high`], and 1 is the lowest value if `high` is + omitted. + + """ + + if not use_origin_backend(low): + if size is None: + size = 1 + elif isinstance(size, tuple): + for dim in size: + if not isinstance(dim, int): + checker_throw_value_error("randint", "type(dim)", type(dim), int) + elif not isinstance(size, int): + checker_throw_value_error("randint", "type(size)", type(size), int) + + if high is None: + high = low + low = 0 + + low = int(low) + high = int(high) + + if (low >= high): + checker_throw_value_error("randint", "low", low, high) + + _dtype = numpy.dtype(dtype) + + # TODO: + # supported only int32 + # or just raise error when dtype != numpy.int32 + if _dtype == numpy.int32 or _dtype == numpy.int64: + _dtype = numpy.int32 + else: + raise TypeError('Unsupported dtype %r for randint' % dtype) + return dpnp_uniform(low, high, size, _dtype) + + return call_origin(numpy.random.randint, low, high, size, dtype) + + +def randn(d0, *dn): + """ + If positive int_like arguments are provided, randn generates an array of shape (d0, d1, ..., dn), + filled with random floats sampled from a univariate “normal” (Gaussian) distribution of mean 0 and variance 1. + + Parameters + ---------- + d0, d1, …, dn : The dimensions of the returned array, must be non-negative. + + Returns + ------- + out : (d0, d1, ..., dn)-shaped array of floating-point samples from the standard normal distribution. + + See Also + -------- + standard_normal + normal + + """ + + if not use_origin_backend(d0): + dims = tuple([d0, *dn]) + + for dim in dims: + if not isinstance(dim, int): + checker_throw_value_error("randn", "type(dim)", type(dim), int) + return dpnp_randn(dims) + + return call_origin(numpy.random.randn, d0, *dn) + + +def random(size): + """ + Return random floats in the half-open interval [0.0, 1.0). + Alias for random_sample. + + Parameters + ---------- + size : Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. + + Returns + ------- + out : Array of random floats of shape size. + + See Also + -------- + random + + """ + + if not use_origin_backend(size): + for dim in size: + if not isinstance(dim, int): + checker_throw_value_error("random", "type(dim)", type(dim), int) + return dpnp_random(size) + + return call_origin(numpy.random.random, size) + + +def random_integers(low, high=None, size=None): + """ + random_integers(low, high=None, size=None) + + Random integers between `low` and `high`, inclusive. + Return random integers from the "discrete uniform" distribution in + the closed interval [`low`, `high`]. If `high` is + None (the default), then results are from [1, `low`]. + + Parameters + ---------- + low : int + Lowest (signed) integer to be drawn from the distribution (unless + ``high=None``, in which case this parameter is the *highest* such + integer). + high : int, optional + If provided, the largest (signed) integer to be drawn from the + distribution (see above for behavior if ``high=None``). + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + Returns + ------- + out : array of random ints + `size`-shaped array of random integers from the appropriate + distribution, or a single such random int if `size` not provided. + See Also + -------- + randint + + """ + + if not use_origin_backend(low): + if high is None: + high = low + low = 1 + return randint(low, int(high) + 1, size=size) + + return call_origin(numpy.random.random_integers, low, high, size) + + +def random_sample(size): + """ + Return random floats in the half-open interval [0.0, 1.0). + + Parameters + ---------- + size : Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. + + Returns + ------- + out : Array of random floats of shape size. + + See Also + -------- + random + + """ + + if not use_origin_backend(size): + for dim in size: + if not isinstance(dim, int): + checker_throw_value_error("random_sample", "type(dim)", type(dim), int) + return dpnp_random(size) + + return call_origin(numpy.random.random_sample, size) + + +def seed(seed=None): + """ + Reseed a legacy philox4x32x10 random number generator engine + + Parameters + ---------- + seed : {None, int}, optional + + """ + if not use_origin_backend(seed): + # TODO: + # implement seed default value as is in numpy + if seed is None: + seed = 1 + elif not isinstance(seed, int): + checker_throw_value_error("seed", "type(seed)", type(seed), int) + elif seed < 0: + checker_throw_value_error("seed", "seed", seed, "non-negative") + return dpnp_srand(seed) + + return call_origin(numpy.random.seed, seed) + + +def sample(size): + """ + Return random floats in the half-open interval [0.0, 1.0). + This is an alias of random_sample. + + Parameters + ---------- + size : Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. + + Returns + ------- + out : Array of random floats of shape size. + + See Also + -------- + random + + """ + + if not use_origin_backend(size): + for dim in size: + if not isinstance(dim, int): + checker_throw_value_error("sample", "type(dim)", type(dim), int) + return dpnp_random(size) + + return call_origin(numpy.random.sample, size) + + +def uniform(low=0.0, high=1.0, size=None): + """ + uniform(low=0.0, high=1.0, size=None) + + Draw samples from a uniform distribution. + Samples are uniformly distributed over the half-open interval + ``[low, high)`` (includes low, but excludes high). In other words, + any value within the given interval is equally likely to be drawn + by `uniform`. + Parameters + ---------- + low : float, optional + Lower boundary of the output interval. All values generated will be + greater than or equal to low. The default value is 0. + high : float + Upper boundary of the output interval. All values generated will be + less than high. The default value is 1.0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``low`` and ``high`` are both scalars. + Returns + ------- + out : array or scalar + Drawn samples from the parameterized uniform distribution. + See Also + -------- + random : Floats uniformly distributed over ``[0, 1)``. + + """ + + if not use_origin_backend(low): + if size is None: + size = 1 + if low == high: + # TODO: + # currently dparray.full is not implemented + # return dpnp.dparray.dparray.full(size, low, dtype=numpy.float64) + message = "`low` equal to `high`, should return an array, filled with `low` value." + message += " Currently not supported. See: numpy.full TODO" + checker_throw_runtime_error("uniform", message) + elif low > high: + low, high = high, low + return dpnp_uniform(low, high, size, dtype=numpy.float64) + + return call_origin(numpy.random.uniform, low, high, size) diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index 45b830408061..bb4b8a56ab05 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -7,6 +7,10 @@ tests/test_random.py::test_random_input_shape[chisquare] tests/test_random.py::test_random_input_size[chisquare] tests/test_random.py::test_randn_normal_distribution tests/test_random.py::test_random_seed_chisquare +tests/test_random.py::test_random_seed_gamma +tests/test_random.py::test_invalid_args_chisquare +tests/test_random.py::test_invalid_args_gamma +tests/test_random.py::test_check_moments_gamma tests/test_statistics.py::test_median[2-float32] tests/test_statistics.py::test_median[2-float64] tests/third_party/cupy/binary_tests/test_elementwise.py::TestElementwise::test_bitwise_and diff --git a/tests/test_random.py b/tests/test_random.py index d9e1003dea0f..aa90f418f910 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -4,6 +4,7 @@ import numpy # from scipy import stats from numpy.testing import assert_allclose +import math @pytest.mark.parametrize("func", @@ -131,6 +132,18 @@ def test_random_seed_exponential(): assert_allclose(a1, a2, rtol=1e-07, atol=0) +def test_random_seed_gamma(): + seed = 28041990 + size = 100 + shape = 3.0 # shape param for gamma distr + + dpnp.random.seed(seed) + a1 = dpnp.random.gamma(shape=shape, size=size) + dpnp.random.seed(seed) + a2 = dpnp.random.gamma(shape=shape, size=size) + assert_allclose(a1, a2, rtol=1e-07, atol=0) + + def test_invalid_args_chisquare(): size = 10 df = -1 # positive `df` is expected @@ -143,3 +156,27 @@ def test_invalid_args_exponential(): scale = -1 # non-negative `scale` is expected with pytest.raises(ValueError): dpnp.random.exponential(scale, size) + + +def test_invalid_args_gamma(): + size = 10 + shape = -1 # non-negative `shape` is expected + with pytest.raises(ValueError): + dpnp.random.gamma(shape=shape, size=size) + shape = 1.0 # OK + scale = -1.0 # non-negative `shape` is expected + with pytest.raises(ValueError): + dpnp.random.gamma(shape, scale, size) + + +def test_check_moments_gamma(): + seed = 28041990 + dpnp.random.seed(seed) + shape = 2.56 + scale = 0.8 + expected_mean = shape * scale + expected_var = shape * scale * scale + var = numpy.var(dpnp.random.gamma(shape=shape, scale=scale, size=10**6)) + mean = numpy.mean(dpnp.random.gamma(shape=shape, scale=scale, size=10**6)) + assert math.isclose(var, expected_var, abs_tol=0.003) + assert math.isclose(mean, expected_mean, abs_tol=0.003)