Skip to content

Commit 252a4b5

Browse files
committed
Add the standard gamma RandomVariable
1 parent bf2de51 commit 252a4b5

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

aesara/tensor/random/basic.py

+40
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,45 @@ def rng_fn_scipy(cls, rng, shape, scale, size):
431431
gamma = GammaRV()
432432

433433

434+
class StandardGammaRV(GammaRV):
435+
r"""A standard gamma continuous random variable.
436+
437+
The probability density function for `standard_gamma` in terms of its shape
438+
parameters :math:`\alpha` is:
439+
440+
.. math::
441+
442+
f(x) = \frac{1}{\Gamma(\alpha)}x^{\alpha-1}e^{-x}
443+
444+
for :math:`x \geq 0`, :math:`\alpha > 0`. :math:`\Gamma` is the gamma
445+
function:
446+
447+
.. math::
448+
449+
\Gamma(x) = \int_0^{\infty} t^{x-1} e^{-t} \mathrm{d}t
450+
451+
"""
452+
453+
def __call__(self, shape, size=None, **kwargs):
454+
"""Draw samples from a standard gamma distribution.
455+
456+
Parameters
457+
----------
458+
shape
459+
The shape :math:`\alpha` of the gamma distribution. Must be positive.
460+
size
461+
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
462+
independent, identically distributed random variables are
463+
returned. Default is `None` in which case a single random variable
464+
is returned.
465+
466+
"""
467+
return super().__call__(shape, rate=1.0, size=size, **kwargs)
468+
469+
470+
standard_gamma = StandardGammaRV()
471+
472+
434473
class ChiSquareRV(RandomVariable):
435474
r"""A chi square continuous random variable.
436475
@@ -2012,6 +2051,7 @@ def __call__(self, x, **kwargs):
20122051
"uniform",
20132052
"standard_cauchy",
20142053
"standard_exponential",
2054+
"standard_gamma",
20152055
"standard_normal",
20162056
"standard_t",
20172057
"negative_binomial",

tests/tensor/random/test_basic.py

+26
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
randint,
5050
standard_cauchy,
5151
standard_exponential,
52+
standard_gamma,
5253
standard_normal,
5354
standard_t,
5455
t,
@@ -368,6 +369,31 @@ def test_fn(shape, rate, **kwargs):
368369
)
369370

370371

372+
@pytest.mark.parametrize(
373+
"a, size",
374+
[
375+
(np.array(0.5, dtype=config.floatX), None),
376+
(np.array(0.5, dtype=config.floatX), []),
377+
(
378+
np.full((1, 2), 0.5, dtype=config.floatX),
379+
None,
380+
),
381+
],
382+
)
383+
def test_standard_gamma_samples(a, size):
384+
gamma_test_fn = fixed_scipy_rvs("gamma")
385+
386+
def test_fn(shape, **kwargs):
387+
return gamma_test_fn(shape, scale=1.0, **kwargs)
388+
389+
compare_sample_values(
390+
standard_gamma,
391+
a,
392+
size=size,
393+
test_fn=test_fn,
394+
)
395+
396+
371397
@pytest.mark.parametrize(
372398
"df, size",
373399
[

0 commit comments

Comments
 (0)