Skip to content

Commit

Permalink
Refactoring the ChiSquared distribution (#4695)
Browse files Browse the repository at this point in the history
* Refactoring ChiSquared distribution

* Refactoring ChiSquared (minor edit)

* Refactoring Chisquared (another one-line change)

* Trying to rebase/merge my branch with updated upstream v4

* Using aesara chisquare op (r.f. PR #414) and renamed ChiSquared to ChiSquare

* Added logpdf & logcdf to the ChiSquare class

* Corrected function name

* Updating branch

* Refactoring ChiSquared: bug fixed, tests work locally

* Minor fix: removed float32 specification

* ☀️ underflow to -inf seems normal in float32

* Minor fix in documentation
  • Loading branch information
larryshamalama authored Jun 7, 2021
1 parent 4dfc8e0 commit b4a67d3
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 14 deletions.
46 changes: 41 additions & 5 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
BetaRV,
WeibullRV,
cauchy,
chisquare,
exponential,
gamma,
gumbel,
Expand Down Expand Up @@ -2562,7 +2563,7 @@ def logcdf(value, alpha, beta):
)


class ChiSquared(Gamma):
class ChiSquared(PositiveContinuous):
r"""
:math:`\chi^2` log-likelihood.
Expand Down Expand Up @@ -2597,13 +2598,48 @@ class ChiSquared(Gamma):
Parameters
----------
nu: int
nu: float
Degrees of freedom (nu > 0).
"""
rv_op = chisquare

def __init__(self, nu, *args, **kwargs):
self.nu = nu = at.as_tensor_variable(floatX(nu))
super().__init__(alpha=nu / 2.0, beta=0.5, *args, **kwargs)
@classmethod
def dist(cls, nu, *args, **kwargs):
nu = at.as_tensor_variable(floatX(nu))
return super().dist([nu], *args, **kwargs)

def logp(value, nu):
"""
Calculate log-probability of ChiSquared distribution at specified value.
Parameters
----------
value: numeric
Value(s) for which log-probability is calculated. If the log probabilities for multiple
values are desired the values must be provided in a numpy array or Aesara tensor
Returns
-------
TensorVariable
"""
return Gamma.logp(value, nu / 2, 2)

def logcdf(value, nu):
"""
Compute the log of the cumulative distribution function for ChiSquared distribution
at the specified value.
Parameters
----------
value: numeric or np.ndarray or `TensorVariable`
Value(s) for which log CDF is calculated. If the log CDF for
multiple values are desired the values must be provided in a numpy
array or `TensorVariable`.
Returns
-------
TensorVariable
"""
return Gamma.logcdf(value, nu / 2, 2)


# TODO: Remove this once logpt for multiplication is working!
Expand Down
17 changes: 14 additions & 3 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,15 +1048,26 @@ def test_half_normal(self):
lambda value, sigma: sp.halfnorm.logcdf(value, scale=sigma),
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_chi_squared(self):
def test_chisquared_logp(self):
self.check_logp(
ChiSquared,
Rplus,
{"nu": Rplusdunif},
{"nu": Rplus},
lambda value, nu: sp.chi2.logpdf(value, df=nu),
)

@pytest.mark.xfail(
condition=(aesara.config.floatX == "float32"),
reason="Fails on float32 due to numerical issues",
)
def test_chisquared_logcdf(self):
self.check_logcdf(
ChiSquared,
Rplus,
{"nu": Rplus},
lambda value, nu: sp.chi2.logcdf(value, df=nu),
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_wald_logp(self):
self.check_logp(
Expand Down
19 changes: 13 additions & 6 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,6 @@ class TestAsymmetricLaplace(BaseTestCases.BaseTestCase):
params = {"kappa": 1.0, "b": 1.0, "mu": 0.0}


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestChiSquared(BaseTestCases.BaseTestCase):
distribution = pm.ChiSquared
params = {"nu": 2.0}


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestExGaussian(BaseTestCases.BaseTestCase):
distribution = pm.ExGaussian
Expand Down Expand Up @@ -782,6 +776,19 @@ class TestInverseGammaMuSigma(BaseTestDistribution):
tests_to_run = ["check_pymc_params_match_rv_op"]


class TestChiSquared(BaseTestDistribution):
pymc_dist = pm.ChiSquared
pymc_dist_params = {"nu": 2.0}
expected_rv_op_params = {"nu": 2.0}
reference_dist_params = {"df": 2.0}
reference_dist = seeded_numpy_distribution_builder("chisquare")
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
]


class TestBinomial(BaseTestDistribution):
pymc_dist = pm.Binomial
pymc_dist_params = {"n": 100, "p": 0.33}
Expand Down

0 comments on commit b4a67d3

Please sign in to comment.