Skip to content

Commit 4be8bc0

Browse files
committed
Add the Student's t RandomVariable
1 parent 1a3ec8d commit 4be8bc0

File tree

3 files changed

+101
-0
lines changed

3 files changed

+101
-0
lines changed

aesara/tensor/random/basic.py

+55
Original file line numberDiff line numberDiff line change
@@ -1370,6 +1370,60 @@ def rng_fn_scipy(cls, rng, b, loc, scale, size):
13701370
truncexpon = TruncExponentialRV()
13711371

13721372

1373+
class StudentTRV(ScipyRandomVariable):
1374+
r"""A Student's t continuous random variable.
1375+
1376+
The probability density function for `t` in terms of its degrees of freedom
1377+
parameter :math:`\nu`, location parameter :math:`\mu` and scale
1378+
parameter :math:`\sigma` is:
1379+
1380+
.. math::
1381+
1382+
f(x; \nu, \alpha, \beta) = \frac{\Gamma(\frac{\nu + 1}{2})}{\Gamma(\frac{\nu}{2})} \left(\frac{1}{\pi\nu\sigma}\right)^{\frac{1}{2}} \left[1+\frac{(x-\mu)^2}{\nu\sigma}\right]^{-\frac{\nu+1}{2}}
1383+
1384+
for :math:`\nu > 0`, :math:`\sigma > 0`.
1385+
1386+
"""
1387+
name = "t"
1388+
ndim_supp = 0
1389+
ndims_params = [0, 0, 0]
1390+
dtype = "floatX"
1391+
_print_name = ("StudentT", "\\operatorname{StudentT}")
1392+
1393+
def __call__(self, df, loc=0.0, scale=1.0, size=None, **kwargs):
1394+
r"""Draw samples from a Student's t distribution.
1395+
1396+
Signature
1397+
---------
1398+
1399+
`(), (), () -> ()`
1400+
1401+
Parameters
1402+
----------
1403+
df
1404+
Degrees of freedom parameter :math:`\nu` of the distribution. Must be
1405+
positive.
1406+
loc
1407+
Location parameter :math:`\mu` of the distribution.
1408+
scale
1409+
Scale parameter :math:`\sigma` of the distribution. Must be
1410+
positive.
1411+
size
1412+
Sample shape. If the given size is `(m, n, k)`, then `m * n * k`
1413+
independent, identically distributed samples are returned. Default is
1414+
`None` in which case a single sample is returned.
1415+
1416+
"""
1417+
return super().__call__(df, loc, scale, size=size, **kwargs)
1418+
1419+
@classmethod
1420+
def rng_fn_scipy(cls, rng, df, loc, scale, size):
1421+
return stats.t.rvs(df, loc=loc, scale=scale, size=size, random_state=rng)
1422+
1423+
1424+
t = StudentTRV()
1425+
1426+
13731427
class BernoulliRV(ScipyRandomVariable):
13741428
r"""A Bernoulli discrete random variable.
13751429
@@ -2071,4 +2125,5 @@ def __call__(self, x, **kwargs):
20712125
"standard_normal",
20722126
"negative_binomial",
20732127
"gengamma",
2128+
"t",
20742129
]

doc/library/tensor/random/basic.rst

+3
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ Aesara can produce :class:`RandomVariable`\s that draw samples from many differe
145145
.. autoclass:: aesara.tensor.random.basic.StandardNormalRV
146146
:members: __call__
147147

148+
.. autoclass:: aesara.tensor.random.basic.StudentTRV
149+
:members: __call__
150+
148151
.. autoclass:: aesara.tensor.random.basic.TriangularRV
149152
:members: __call__
150153

tests/tensor/random/test_basic.py

+43
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
poisson,
4949
randint,
5050
standard_normal,
51+
t,
5152
triangular,
5253
truncexpon,
5354
uniform,
@@ -926,6 +927,48 @@ def test_truncexpon_samples(b, loc, scale, size):
926927
)
927928

928929

930+
@pytest.mark.parametrize(
931+
"df, loc, scale, size",
932+
[
933+
(
934+
np.array(2, dtype=config.floatX),
935+
np.array(0, dtype=config.floatX),
936+
np.array(1, dtype=config.floatX),
937+
None,
938+
),
939+
(
940+
np.array(2, dtype=config.floatX),
941+
np.array(0, dtype=config.floatX),
942+
np.array(1, dtype=config.floatX),
943+
[],
944+
),
945+
(
946+
np.array(2, dtype=config.floatX),
947+
np.array(0, dtype=config.floatX),
948+
np.array(1, dtype=config.floatX),
949+
[2, 3],
950+
),
951+
(
952+
np.full((1, 2), 5, dtype=config.floatX),
953+
np.array(0, dtype=config.floatX),
954+
np.array(1, dtype=config.floatX),
955+
None,
956+
),
957+
],
958+
)
959+
def test_t_samples(df, loc, scale, size):
960+
compare_sample_values(
961+
t,
962+
df,
963+
loc,
964+
scale,
965+
size=size,
966+
test_fn=lambda *args, size=None, random_state=None, **kwargs: t.rng_fn(
967+
random_state, *(args + (size,))
968+
),
969+
)
970+
971+
929972
@pytest.mark.parametrize(
930973
"p, size",
931974
[

0 commit comments

Comments
 (0)