Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move content of distributions.special into distributions.dist_math #4760

Merged
merged 12 commits into from
Jun 10, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ jobs:
--ignore=pymc3/tests/test_minibatches.py
--ignore=pymc3/tests/test_pickling.py
--ignore=pymc3/tests/test_plots.py
--ignore=pymc3/tests/test_special_functions.py
--ignore=pymc3/tests/test_updates.py
--ignore=pymc3/tests/test_examples.py
--ignore=pymc3/tests/test_gp.py
Expand All @@ -67,7 +66,6 @@ jobs:
pymc3/tests/test_minibatches.py
pymc3/tests/test_pickling.py
pymc3/tests/test_plots.py
pymc3/tests/test_special_functions.py
pymc3/tests/test_updates.py

- |
Expand Down
2 changes: 1 addition & 1 deletion pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@
gammaln,
i0e,
incomplete_beta,
log_i0,
log_normal,
logpow,
normal_lccdf,
normal_lcdf,
zvalue,
)
from pymc3.distributions.distribution import Continuous
from pymc3.distributions.special import log_i0
from pymc3.math import log1mexp, log1pexp, logdiffexp, logit
from pymc3.util import UNSET

Expand Down
40 changes: 39 additions & 1 deletion pymc3/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
from aesara.scalar import UnaryScalarOp, upgrade_to_float_no_complex
from aesara.scan import until
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import gammaln
themrzmaster marked this conversation as resolved.
Show resolved Hide resolved
from aesara.tensor.slinalg import Cholesky, Solve

from pymc3.aesaraf import floatX
from pymc3.distributions.shape_utils import to_tuple
from pymc3.distributions.special import gammaln

f = floatX
c = -0.5 * np.log(2.0 * np.pi)
Expand Down Expand Up @@ -634,3 +634,41 @@ def clipped_beta_rvs(a, b, size=None, random_state=None, dtype="float64"):
out = scipy.stats.beta.rvs(a, b, size=size, random_state=random_state).astype(dtype)
lower, upper = _beta_clip_values[dtype]
return np.maximum(np.minimum(out, upper), lower)


def multigammaln(a, p):
"""Multivariate Log Gamma

Parameters
----------
a: tensor like
p: int
degrees of freedom. p > 0
"""
i = at.arange(1, p + 1)
return p * (p - 1) * at.log(np.pi) / 4.0 + at.sum(gammaln(a + (1.0 - i) / 2.0), axis=0)


def log_i0(x):
"""
Calculates the logarithm of the 0 order modified Bessel function of the first kind""
"""
return at.switch(
at.lt(x, 5),
at.log1p(
x ** 2.0 / 4.0
+ x ** 4.0 / 64.0
+ x ** 6.0 / 2304.0
+ x ** 8.0 / 147456.0
+ x ** 10.0 / 14745600.0
+ x ** 12.0 / 2123366400.0
),
x
- 0.5 * at.log(2.0 * np.pi * x)
+ at.log1p(
1.0 / (8.0 * x)
+ 9.0 / (128.0 * x ** 2.0)
+ 225.0 / (3072.0 * x ** 3.0)
+ 11025.0 / (98304.0 * x ** 4.0)
),
)
3 changes: 1 addition & 2 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@
from pymc3.aesaraf import floatX, intX
from pymc3.distributions import transforms
from pymc3.distributions.continuous import ChiSquared, Normal
from pymc3.distributions.dist_math import bound, factln, logpow
from pymc3.distributions.dist_math import bound, factln, gammaln, logpow, multigammaln
from pymc3.distributions.distribution import Continuous, Discrete
from pymc3.distributions.special import gammaln, multigammaln
from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker

__all__ = [
Expand Down
58 changes: 0 additions & 58 deletions pymc3/distributions/special.py

This file was deleted.

26 changes: 26 additions & 0 deletions pymc3/tests/test_dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import numpy as np
import numpy.testing as npt
import pytest
import scipy.special as ss
themrzmaster marked this conversation as resolved.
Show resolved Hide resolved

from aesara import config, function
from aesara.tensor.random.basic import multinomial
from scipy import interpolate, stats

Expand All @@ -32,7 +34,9 @@
clipped_beta_rvs,
factln,
i0e,
multigammaln,
)
from pymc3.tests.checks import close_to
from pymc3.tests.helpers import verify_grad


Expand Down Expand Up @@ -236,3 +240,25 @@ def test_clipped_beta_rvs(dtype):
# equal to zero or one (issue #3898)
values = clipped_beta_rvs(0.01, 0.01, size=1000000, dtype=dtype)
assert not (np.any(values == 0) or np.any(values == 1))


def check_vals(fn1, fn2, *args):
v = fn1(*args)
close_to(v, fn2(*args), 1e-6 if v.dtype == np.float64 else 1e-4)


def test_multigamma():
x = at.vector("x")
p = at.scalar("p")

xvals = [np.array([v], dtype=config.floatX) for v in [0.1, 2, 5, 10, 50, 100]]

multigammaln_ = function([x, p], multigammaln(x, p), mode="FAST_COMPILE")

def ssmultigammaln(a, b):
return np.array(ss.multigammaln(a[0], b), config.floatX)

for p in [0, 1, 2, 3, 4, 100]:
for x in xvals:
if np.all(x > 0.5 * (p - 1)):
check_vals(multigammaln_, ssmultigammaln, x, p)
3 changes: 3 additions & 0 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,13 +1356,16 @@ def test_t(self):
lambda value, nu, mu, lam: sp.t.logcdf(value, nu, mu, lam ** -0.5),
n_samples=10, # relies on slow incomplete beta
)
# TODO: reenable when PR #4736 is merged
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A reminder issue would be good. Or @ricardoV94 you rebase the other PR after we merge it and then you can enable it right there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll do that. Linking it here for now #4736

self.check_logcdf(
StudentT,
R,
{"nu": Rplus, "mu": R, "sigma": Rplus},
lambda value, nu, mu, sigma: sp.t.logcdf(value, nu, mu, sigma),
n_samples=5, # Just testing alternative parametrization
)
"""

def test_cauchy(self):
self.check_logp(
Expand Down
6 changes: 3 additions & 3 deletions pymc3/tests/test_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,9 @@ def system(y, t, p):
ode_model = DifferentialEquation(func=system, t0=0, times=times, n_states=2, n_theta=2)

with pm.Model() as model:
beta = pm.HalfCauchy("beta", 1)
gamma = pm.HalfCauchy("gamma", 1)
sigma = pm.HalfCauchy("sigma", 1, shape=2)
beta = pm.HalfCauchy("beta", 1, initval=1)
gamma = pm.HalfCauchy("gamma", 1, initval=1)
sigma = pm.HalfCauchy("sigma", 1, shape=2, initval=[1, 1])
forward = ode_model(theta=[beta, gamma], y0=[0.99, 0.01])
y = pm.Lognormal("y", mu=pm.math.log(forward), sd=sigma, observed=yobs)

Expand Down
45 changes: 0 additions & 45 deletions pymc3/tests/test_special_functions.py

This file was deleted.