Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring the ChiSquared distribution #4695

Merged
merged 13 commits into from
Jun 7, 2021
2 changes: 1 addition & 1 deletion docs/source/api/distributions/continuous.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Continuous
InverseGamma
Weibull
Lognormal
ChiSquared
ChiSquare
Wald
Pareto
ExGaussian
Expand Down
4 changes: 2 additions & 2 deletions pymc3/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
AsymmetricLaplace,
Beta,
Cauchy,
ChiSquared,
ChiSquare,
ExGaussian,
Exponential,
Flat,
Expand Down Expand Up @@ -125,7 +125,7 @@
"Bound",
"Lognormal",
"HalfStudentT",
"ChiSquared",
"ChiSquare",
"HalfNormal",
"Wald",
"Pareto",
Expand Down
11 changes: 7 additions & 4 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
BetaRV,
WeibullRV,
cauchy,
chisquare,
exponential,
gamma,
gumbel,
Expand Down Expand Up @@ -87,7 +88,7 @@
"Weibull",
"HalfStudentT",
"Lognormal",
"ChiSquared",
"ChiSquare",
"HalfNormal",
"Wald",
"Pareto",
Expand Down Expand Up @@ -2548,7 +2549,7 @@ def logcdf(value, alpha, beta):
)


class ChiSquared(Gamma):
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
class ChiSquare(Gamma):
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
r"""
:math:`\chi^2` log-likelihood.

Expand Down Expand Up @@ -2586,9 +2587,11 @@ class ChiSquared(Gamma):
nu: int
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
Degrees of freedom (nu > 0).
"""
rv_op = chisquare

def __init__(self, nu, *args, **kwargs):
self.nu = nu = at.as_tensor_variable(floatX(nu))
@classmethod
def dist(cls, nu, *args, **kwargs):
nu = at.as_tensor_variable(floatX(nu))
super().__init__(alpha=nu / 2.0, beta=0.5, *args, **kwargs)
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved


Expand Down
4 changes: 2 additions & 2 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

from pymc3.aesaraf import floatX, intX
from pymc3.distributions import transforms
from pymc3.distributions.continuous import ChiSquared, Normal
from pymc3.distributions.continuous import ChiSquare, Normal
from pymc3.distributions.dist_math import bound, factln, logpow
from pymc3.distributions.distribution import Continuous, Discrete
from pymc3.distributions.special import gammaln, multigammaln
Expand Down Expand Up @@ -905,7 +905,7 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, testv
tril_testval = None

c = at.sqrt(
ChiSquared("%s_c" % name, nu - np.arange(2, 2 + n_diag), shape=n_diag, testval=diag_testval)
ChiSquare("%s_c" % name, nu - np.arange(2, 2 + n_diag), shape=n_diag, testval=diag_testval)
)
pm._log.info("Added new variable %s_c to model diagonal of Wishart." % name)
z = Normal("%s_z" % name, 0.0, 1.0, shape=n_tril, testval=tril_testval)
Expand Down
2 changes: 1 addition & 1 deletion pymc3/gp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def _build_prior(self, name, X, reparameterize=True, **kwargs):
cov = stabilize(self.cov_func(X))
shape = infer_shape(X, kwargs.pop("shape", None))
if reparameterize:
chi2 = pm.ChiSquared(name + "_chi2_", self.nu)
chi2 = pm.ChiSquare(name + "_chi2_", self.nu)
v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=shape, **kwargs)
f = pm.Deterministic(name, (at.sqrt(self.nu) / chi2) * (mu + cholesky(cov).dot(v)))
else:
Expand Down
5 changes: 2 additions & 3 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
Bound,
Categorical,
Cauchy,
ChiSquared,
ChiSquare,
Constant,
DensityDist,
Dirichlet,
Expand Down Expand Up @@ -1030,10 +1030,9 @@ def test_half_normal(self):
lambda value, sigma: sp.halfnorm.logcdf(value, scale=sigma),
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_chi_squared(self):
self.check_logp(
ChiSquared,
ChiSquare,
Rplus,
{"nu": Rplusdunif},
lambda value, nu: sp.chi2.logpdf(value, df=nu),
Expand Down
13 changes: 10 additions & 3 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,17 @@ class TestAsymmetricLaplace(BaseTestCases.BaseTestCase):
params = {"kappa": 1.0, "b": 1.0, "mu": 0.0}


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestChiSquared(BaseTestCases.BaseTestCase):
distribution = pm.ChiSquared
class TestChiSquare(BaseTestCases.BaseTestCase):
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
distribution = pm.ChiSquare
params = {"nu": 2.0}
expected_rv_op_params = {"nu": 2.0}
reference_dist_params = {"df": 2.0}
reference_dist = seeded_numpy_distribution_builder("chisquare")
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
]


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
Expand Down