Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finish restructuring the tests to follow the structure of the code #6125

Merged
merged 12 commits into from
Oct 1, 2022
15 changes: 7 additions & 8 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
pymc/tests/test_aesaraf.py
pymc/tests/test_math.py
pymc/tests/backends/test_ndarray.py
pymc/tests/test_hmc.py
pymc/tests/step_methods/hmc/test_hmc.py
pymc/tests/test_func_utils.py
pymc/tests/distributions/test_shape_utils.py
pymc/tests/distributions/test_mixture.py
Expand All @@ -63,8 +63,7 @@ jobs:
- |
pymc/tests/tuning/test_scaling.py
pymc/tests/tuning/test_starting.py
pymc/tests/test_shared.py
pymc/tests/test_types.py
pymc/tests/test_sampling.py
pymc/tests/distributions/test_dist_math.py
pymc/tests/distributions/test_transform.py
pymc/tests/test_parallel_sampling.py
Expand Down Expand Up @@ -147,10 +146,10 @@ jobs:
floatx: [float64]
python-version: ["3.8"]
test-subset:
- pymc/tests/test_variational_inference.py pymc/tests/test_initial_point.py
- pymc/tests/test_model.py pymc/tests/test_step.py
- pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/test_smc.py pymc/tests/test_parallel_sampling.py
- pymc/tests/test_sampling.py pymc/tests/test_posteriors.py
- pymc/tests/variational/test_approximations.py pymc/tests/variational/test_callbacks.py pymc/tests/variational/test_inference.py pymc/tests/variational/test_opvi.py pymc/tests/test_initial_point.py
- pymc/tests/test_model.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py
- pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/smc/test_smc.py pymc/tests/test_parallel_sampling.py
- pymc/tests/test_sampling.py pymc/tests/step_methods/test_metropolis.py pymc/tests/step_methods/test_slicer.py pymc/tests/step_methods/hmc/test_nuts.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it enough to do -pymc/tests/step_methods/ ?

Copy link
Member Author

@Armavica Armavica Sep 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is, but here I selected only the few files where tests from test_posteriors ended up, to not increase too much the runtime of the tests. The rest is already tested in other platforms. I am not sure how much of a difference that makes, though. If we test the whole of step_methods I will also check that the check_all_tests_are_covered.py script understands what happens.


fail-fast: false
runs-on: ${{ matrix.os }}
Expand Down Expand Up @@ -222,7 +221,7 @@ jobs:
- |
pymc/tests/test_parallel_sampling.py
pymc/tests/test_data.py
pymc/tests/test_missing.py
pymc/tests/test_model.py

- |
pymc/tests/test_sampling.py
Expand Down
70 changes: 70 additions & 0 deletions pymc/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.

import contextlib
import shutil
import tempfile
import warnings

from logging.handlers import BufferingHandler

Expand All @@ -24,7 +27,11 @@
from aesara.graph.rewriting.basic import in2out
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream

import pymc as pm

from pymc.aesaraf import at_rng, local_check_parameter_to_ninf_switch, set_at_rng
from pymc.tests.checks import close_to
from pymc.tests.models import mv_simple, mv_simple_coarse


class SeededTest:
Expand Down Expand Up @@ -148,3 +155,66 @@ def assert_random_state_equal(state1, state2):
(in2out(local_check_parameter_to_ninf_switch), -1)
)
)


class StepMethodTester:
def setup_class(self):
self.temp_dir = tempfile.mkdtemp()

def teardown_class(self):
shutil.rmtree(self.temp_dir)

def check_stat(self, check, idata, name):
group = idata.posterior
for (var, stat, value, bound) in check:
s = stat(group[var].sel(chain=0), axis=0)
close_to(s, value, bound, name)

def check_stat_dtype(self, step, idata):
# TODO: This check does not confirm the announced dtypes are correct as the
# sampling machinery will convert them automatically.
for stats_dtypes in getattr(step, "stats_dtypes", []):
for stat, dtype in stats_dtypes.items():
if stat == "tune":
continue
assert idata.sample_stats[stat].dtype == np.dtype(dtype)

def step_continuous(self, step_fn, draws):
start, model, (mu, C) = mv_simple()
unc = np.diag(C) ** 0.5
check = (("x", np.mean, mu, unc / 10), ("x", np.std, unc, unc / 10))
_, model_coarse, _ = mv_simple_coarse()
with model:
step = step_fn(C, model_coarse)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "More chains .* than draws .*", UserWarning)
idata = pm.sample(
tune=1000,
draws=draws,
chains=1,
step=step,
initvals=start,
model=model,
random_seed=1,
)
self.check_stat(check, idata, step.__class__.__name__)
self.check_stat_dtype(idata, step)


class RVsAssignmentStepsTester:
"""
Test that step methods convert input RVs to respective value vars
Step methods are tested with one and two variables to cover compound
the special branches in `BlockedStep.__new__`
"""

def continuous_steps(self, step, step_kwargs):
with pm.Model() as m:
c1 = pm.HalfNormal("c1")
c2 = pm.HalfNormal("c2")

with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
assert [m.rvs_to_values[c1]] == step([c1], **step_kwargs).vars
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(
step([c1, c2], **step_kwargs).vars
)
7 changes: 7 additions & 0 deletions pymc/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ def simple_model():
return model.initial_point(), model, (mu, tau**-0.5)


def another_simple_model():
_, _model, _ = simple_model()
with _model:
pm.Potential("pot", at.ones((10, 10)))
return _model


def simple_categorical():
p = floatX_array([0.1, 0.2, 0.3, 0.4])
v = floatX_array([0.0, 1.0, 2.0, 3.0])
Expand Down
Empty file added pymc/tests/smc/__init__.py
Empty file.
File renamed without changes.
33 changes: 26 additions & 7 deletions pymc/tests/test_hmc.py → pymc/tests/step_methods/hmc/test_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,39 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import warnings

import numpy as np
import numpy.testing as npt
import pytest

import pymc
import pymc as pm

from pymc.aesaraf import floatX
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.step_methods.hmc import HamiltonianMC
from pymc.step_methods.hmc.base_hmc import BaseHMC
from pymc.tests import models
from pymc.tests.helpers import RVsAssignmentStepsTester, StepMethodTester


class TestStepHamiltonianMC(StepMethodTester):
@pytest.mark.parametrize(
"step_fn, draws",
[
(lambda C, _: HamiltonianMC(scaling=C, is_cov=True, blocked=False), 1000),
(lambda C, _: HamiltonianMC(scaling=C, is_cov=True), 1000),
],
)
def test_step_continuous(self, step_fn, draws):
self.step_continuous(step_fn, draws)


logger = logging.getLogger("pymc")
class TestRVsAssignmentHamiltonianMC(RVsAssignmentStepsTester):
@pytest.mark.parametrize("step, step_kwargs", [(HamiltonianMC, {})])
def test_continuous_steps(self, step, step_kwargs):
self.continuous_steps(step, step_kwargs)


def test_leapfrog_reversible():
Expand Down Expand Up @@ -57,12 +76,12 @@ def _hamiltonian_step(self, *args, **kwargs):


def test_nuts_tuning():
with pymc.Model():
pymc.Normal("mu", mu=0, sigma=1)
step = pymc.NUTS()
with pm.Model():
pm.Normal("mu", mu=0, sigma=1)
step = pm.NUTS()
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
idata = pymc.sample(
idata = pm.sample(
10, step=step, tune=5, discard_tuned_samples=False, progressbar=False, chains=1
)

Expand Down
Loading