
Commit

add test
bwengals committed Oct 20, 2023
1 parent f8dc8d0 commit 310f7e1
Showing 2 changed files with 41 additions and 20 deletions.
34 changes: 15 additions & 19 deletions pymc_experimental/distributions/continuous.py
@@ -23,21 +23,22 @@
 
 import numpy as np
 import pytensor.tensor as pt
-from pymc.distributions.dist_math import check_parameters
-from pymc.distributions.distribution import Continuous
-from pymc.distributions.shape_utils import rv_size_is_none
 from pymc.distributions.continuous import (
-    check_parameters, DIST_PARAMETER_TYPES, PositiveContinuous
+    DIST_PARAMETER_TYPES,
+    PositiveContinuous,
+    check_parameters,
 )
+from pymc.distributions.distribution import Continuous
+from pymc.distributions.shape_utils import rv_size_is_none
 from pymc.pytensorf import floatX
-from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.op import RandomVariable
 from scipy import stats
 
 from pymc_experimental.distributions.dist_math import (
-    studentt_kld_distance,
+    pc_prior_studentt_kld_dist_inv_op,
     pc_prior_studentt_logp,
-    pc_prior_studentt_kld_dist_inv_op,
+    studentt_kld_distance,
 )
 
 
@@ -226,7 +227,7 @@ def moment(rv, size, mu, sigma, xi):
         mode = pt.full(size, mode)
         return mode
 
-
+
 class PCPriorStudentT_dof_RV(RandomVariable):
     name = "pc_prior_studentt_dof"
     ndim_supp = 0
@@ -236,9 +237,9 @@ class PCPriorStudentT_dof_RV(RandomVariable):
 
     @classmethod
     def rng_fn(cls, rng, lam, size=None) -> np.ndarray:
-        return pc_prior_studentt_kld_dist_inv_op.spline(
-            rng.exponential(scale=1.0 / lam, size=size)
-        )
+        return pc_prior_studentt_kld_dist_inv_op.spline(rng.exponential(scale=1.0 / lam, size=size))
+
+
 pc_prior_studentt_dof = PCPriorStudentT_dof_RV()
 
 
@@ -264,12 +265,11 @@ def moment(rv, size, lam):
         mean = pt.full(size, mean)
         return mean
 
-
     @classmethod
     def get_lam(cls, alpha=None, U=None, lam=None):
         if (alpha is not None) and (U is not None):
             return -np.log(alpha) / studentt_kld_distance(U)
-        elif (lam is not None):
+        elif lam is not None:
             return lam
         else:
             raise ValueError(
@@ -280,12 +280,8 @@ def get_lam(cls, alpha=None, U=None, lam=None):
     def logp(value, lam):
         res = pc_prior_studentt_logp(value, lam)
         res = pt.switch(
-            pt.lt(value, 2 + 1e-6), # 2 + 1e-6 smallest value for nu
+            pt.lt(value, 2 + 1e-6),  # 2 + 1e-6 smallest value for nu
             -np.inf,
             res,
         )
-        return check_parameters(
-            res,
-            lam > 0,
-            msg="lam > 0"
-        )
+        return check_parameters(res, lam > 0, msg="lam > 0")
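
For orientation (not part of the commit): get_lam above converts a tail statement (U, alpha) into the exponential rate lam = -log(alpha) / studentt_kld_distance(U), and logp clips the support at nu = 2 + 1e-6. A minimal sketch of how this surfaces at the user level, assuming the PCPriorStudentT_dof distribution defined in this file and the import path used by the new test:

import pymc as pm
from pymc_experimental.distributions import PCPriorStudentT_dof

# Build the prior on the Student-T degrees of freedom from the tail statement (U, alpha)
# rather than passing lam directly; get_lam handles the conversion internally.
d = PCPriorStudentT_dof.dist(U=30, alpha=0.5)

# Evaluate the log-probability at nu = 5; values below the 2 + 1e-6 cutoff come back as -inf.
print(pm.logp(d, 5.0).eval())  # roughly -4.79, per the INLA reference value in the test below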
27 changes: 26 additions & 1 deletion pymc_experimental/tests/distributions/test_continuous.py
@@ -35,7 +35,32 @@
 )
 
 # the distributions to be tested
-from pymc_experimental.distributions import GenExtreme
+from pymc_experimental.distributions import GenExtreme, PCPriorStudentT_dof
+
+
+class TestPCPriorStudentT_dof:
+    """The test compares the result to what's implemented in INLA. Since it's a specialized
+    distribution, the user shouldn't ever draw random samples from it, calculate the logcdf, or
+    any of that. The log-probability won't match INLA exactly: INLA uses a numeric
+    approximation, while this implementation uses an exact solution in the relevant domain and a
+    numerical approximation out in the tail.
+    """
+
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            {"U": 30, "alpha": 0.5, "dof": 5, "inla_result": -4.792407},
+            {"U": 30, "alpha": 0.5, "dof": 5000, "inla_result": -14.03713},
+            {"U": 30, "alpha": 0.5, "dof": 1, "inla_result": -np.inf},  # INLA actually throws an error here
+            {"U": 30, "alpha": 0.1, "dof": 5, "inla_result": -15.25691},
+            {"U": 30, "alpha": 0.9, "dof": 5, "inla_result": -2.416043},
+            {"U": 5, "alpha": 0.99, "dof": 5, "inla_result": -5.992945},
+            {"U": 5, "alpha": 0.01, "dof": 5, "inla_result": -4.460736},
+        ],
+    )
+    def test_logp(self, test_case):
+        d = PCPriorStudentT_dof.dist(U=test_case["U"], alpha=test_case["alpha"])
+        npt.assert_allclose(pm.logp(d, test_case["dof"]).eval(), test_case["inla_result"], rtol=0.1)
 
 
 class TestGenExtremeClass:
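
As a usage note (an assumption about intended use; the commit itself only adds the logp test above): a PC prior on the Student-T degrees of freedom would typically sit inside a model, with the same (U, alpha) tail statement, along these lines:

import pymc as pm
from pymc_experimental.distributions import PCPriorStudentT_dof

with pm.Model():
    # Roughly speaking, lam is derived so the prior puts probability alpha on nu falling
    # below U (via the exponential-on-KLD-distance construction in continuous.py above).
    nu = PCPriorStudentT_dof("nu", U=30, alpha=0.5)
    y = pm.StudentT("y", nu=nu, mu=0.0, sigma=1.0)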
