From 8300f164de18c38d661a873f1671d8b8898b85d9 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sat, 5 Nov 2022 12:47:26 +0100 Subject: [PATCH] Move sampling code into sampling submodule This is a follow-up to #6257 where we split the `sampling.py` into two files. --- .github/workflows/tests.yml | 14 ++++---- docs/source/api/samplers.rst | 4 +-- docs/source/contributing/build_docs.md | 2 +- pymc/__init__.py | 1 - pymc/sampling/__init__.py | 16 +++++++++ .../forward.py} | 0 pymc/{sampling_jax.py => sampling/jax.py} | 2 +- pymc/{sampling.py => sampling/mcmc.py} | 4 +-- .../parallel.py} | 0 pymc/smc/kernels.py | 2 +- pymc/smc/sampling.py | 2 +- pymc/tests/distributions/test_mixture.py | 4 +-- pymc/tests/distributions/test_multivariate.py | 2 +- pymc/tests/distributions/test_timeseries.py | 4 +-- pymc/tests/sampler_fixtures.py | 2 +- .../test_forward.py} | 2 +- .../test_jax.py} | 2 +- .../test_mcmc.py} | 34 +++++++++---------- .../test_parallel.py} | 2 +- pymc/variational/opvi.py | 5 +-- scripts/run_mypy.py | 12 +++---- 21 files changed, 66 insertions(+), 50 deletions(-) create mode 100644 pymc/sampling/__init__.py rename pymc/{sampling_forward.py => sampling/forward.py} (100%) rename pymc/{sampling_jax.py => sampling/jax.py} (99%) rename pymc/{sampling.py => sampling/mcmc.py} (99%) rename pymc/{parallel_sampling.py => sampling/parallel.py} (100%) rename pymc/tests/{test_sampling_forward.py => sampling/test_forward.py} (99%) rename pymc/tests/{test_sampling_jax.py => sampling/test_jax.py} (99%) rename pymc/tests/{test_sampling.py => sampling/test_mcmc.py} (97%) rename pymc/tests/{test_parallel_sampling.py => sampling/test_parallel.py} (99%) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 31ee9a085f6..33c6225451b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -59,16 +59,16 @@ jobs: pymc/tests/distributions/test_censored.py pymc/tests/distributions/test_simulator.py pymc/tests/distributions/test_truncated.py - pymc/tests/test_sampling_forward.py + pymc/tests/sampling/forward.py pymc/tests/stats/test_convergence.py - | pymc/tests/tuning/test_scaling.py pymc/tests/tuning/test_starting.py - pymc/tests/test_sampling.py pymc/tests/distributions/test_dist_math.py pymc/tests/distributions/test_transform.py - pymc/tests/test_parallel_sampling.py + pymc/tests/sampling/mcmc.py + pymc/tests/sampling/parallel.py pymc/tests/test_printing.py - | @@ -150,8 +150,8 @@ jobs: test-subset: - pymc/tests/variational/test_approximations.py pymc/tests/variational/test_callbacks.py pymc/tests/variational/test_inference.py pymc/tests/variational/test_opvi.py pymc/tests/test_initial_point.py - pymc/tests/test_model.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py - - pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/smc/test_smc.py pymc/tests/test_parallel_sampling.py - - pymc/tests/test_sampling.py pymc/tests/step_methods/test_metropolis.py pymc/tests/step_methods/test_slicer.py pymc/tests/step_methods/hmc/test_nuts.py + - pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/smc/test_smc.py pymc/tests/sampling/test_parallel.py + - pymc/tests/sampling/test_mcmc.py pymc/tests/step_methods/test_metropolis.py pymc/tests/step_methods/test_slicer.py pymc/tests/step_methods/hmc/test_nuts.py fail-fast: false runs-on: ${{ matrix.os }} @@ -221,7 +221,7 @@ jobs: python-version: ["3.9"] test-subset: - | - pymc/tests/test_parallel_sampling.py + pymc/tests/sampling/parallel.py pymc/tests/test_data.py pymc/tests/test_model.py @@ -294,7 +294,7 @@ jobs: floatx: [float64] python-version: ["3.9"] test-subset: - - pymc/tests/test_sampling_jax.py + - pymc/tests/sampling/test_sampling_jax.py fail-fast: false runs-on: ${{ matrix.os }} env: diff --git a/docs/source/api/samplers.rst b/docs/source/api/samplers.rst index 0acbc9df286..265462e10f4 100644 --- a/docs/source/api/samplers.rst +++ b/docs/source/api/samplers.rst @@ -13,8 +13,8 @@ This submodule contains functions for MCMC and forward sampling. sample_prior_predictive sample_posterior_predictive sample_posterior_predictive_w - sampling_jax.sample_blackjax_nuts - sampling_jax.sample_numpyro_nuts + sampling.jax.sample_blackjax_nuts + sampling.jax.sample_numpyro_nuts iter_sample init_nuts draw diff --git a/docs/source/contributing/build_docs.md b/docs/source/contributing/build_docs.md index 09e37328f4d..d0d4044ec87 100644 --- a/docs/source/contributing/build_docs.md +++ b/docs/source/contributing/build_docs.md @@ -9,7 +9,7 @@ To build the docs, run these commands at PyMC repository root: ```bash pip install -r requirements-dev.txt # Make sure the dev requirements are installed -pip install numpyro # Make sure `sampling_jax` docs can be built +pip install numpyro # Make sure `sampling/jax` docs can be built pip install -e . # Install local pymc version as installable package make clean # clean built docs from previous runs and intermediate outputs make html # Build docs diff --git a/pymc/__init__.py b/pymc/__init__.py index 27cdf6e2bb8..09314aa5c30 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -68,7 +68,6 @@ def __set_compiler_flags(): from pymc.plots import * from pymc.printing import * from pymc.sampling import * -from pymc.sampling_forward import * from pymc.smc import * from pymc.stats import * from pymc.step_methods import * diff --git a/pymc/sampling/__init__.py b/pymc/sampling/__init__.py new file mode 100644 index 00000000000..4a5d2a57a8c --- /dev/null +++ b/pymc/sampling/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pymc.sampling.forward import * +from pymc.sampling.mcmc import * diff --git a/pymc/sampling_forward.py b/pymc/sampling/forward.py similarity index 100% rename from pymc/sampling_forward.py rename to pymc/sampling/forward.py diff --git a/pymc/sampling_jax.py b/pymc/sampling/jax.py similarity index 99% rename from pymc/sampling_jax.py rename to pymc/sampling/jax.py index 3c435d3d276..c284b2fba1c 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling/jax.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union from pymc.initial_point import StartDict -from pymc.sampling import _init_jitter +from pymc.sampling.mcmc import _init_jitter xla_flags = os.getenv("XLA_FLAGS", "") xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split() diff --git a/pymc/sampling.py b/pymc/sampling/mcmc.py similarity index 99% rename from pymc/sampling.py rename to pymc/sampling/mcmc.py index 9b9b59c7ca3..4a497ff4681 100644 --- a/pymc/sampling.py +++ b/pymc/sampling/mcmc.py @@ -45,7 +45,7 @@ make_initial_point_fns_per_chain, ) from pymc.model import Model, modelcontext -from pymc.parallel_sampling import Draw, _cpu_count +from pymc.sampling.parallel import Draw, _cpu_count from pymc.stats.convergence import SamplerWarning, log_warning, run_convergence_checks from pymc.step_methods import NUTS, CompoundStep, DEMetropolis from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared @@ -1404,7 +1404,7 @@ def _mp_sample( mtrace : pymc.backends.base.MultiTrace A ``MultiTrace`` object that contains the samples for all chains. """ - import pymc.parallel_sampling as ps + import pymc.sampling.parallel as ps # We did draws += tune in pm.sample draws -= tune diff --git a/pymc/parallel_sampling.py b/pymc/sampling/parallel.py similarity index 100% rename from pymc/parallel_sampling.py rename to pymc/sampling/parallel.py diff --git a/pymc/smc/kernels.py b/pymc/smc/kernels.py index 8059ec54b46..43d060da507 100644 --- a/pymc/smc/kernels.py +++ b/pymc/smc/kernels.py @@ -35,7 +35,7 @@ from pymc.backends.ndarray import NDArray from pymc.blocking import DictToArrayBijection from pymc.model import Point, modelcontext -from pymc.sampling_forward import sample_prior_predictive +from pymc.sampling.forward import sample_prior_predictive from pymc.step_methods.metropolis import MultivariateNormalProposal from pymc.vartypes import discrete_types diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 87d198dcd4c..ddeed25a165 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -32,7 +32,7 @@ from pymc.backends.arviz import dict_to_dataset, to_inference_data from pymc.backends.base import MultiTrace from pymc.model import modelcontext -from pymc.parallel_sampling import _cpu_count +from pymc.sampling.parallel import _cpu_count from pymc.smc.kernels import IMH diff --git a/pymc/tests/distributions/test_mixture.py b/pymc/tests/distributions/test_mixture.py index 113f3e37b1d..940aa219112 100644 --- a/pymc/tests/distributions/test_mixture.py +++ b/pymc/tests/distributions/test_mixture.py @@ -55,12 +55,12 @@ from pymc.distributions.transforms import _default_transform from pymc.math import expand_packed_triangular from pymc.model import Model -from pymc.sampling import sample -from pymc.sampling_forward import ( +from pymc.sampling.forward import ( draw, sample_posterior_predictive, sample_prior_predictive, ) +from pymc.sampling.mcmc import sample from pymc.step_methods import Metropolis from pymc.tests.distributions.util import ( Domain, diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index d023912ce2c..e7cb8286963 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -41,7 +41,7 @@ ) from pymc.distributions.shape_utils import change_dist_size, to_tuple from pymc.math import kronecker -from pymc.sampling_forward import draw +from pymc.sampling.forward import draw from pymc.tests.distributions.util import ( BaseTestDistributionRandom, Domain, diff --git a/pymc/tests/distributions/test_timeseries.py b/pymc/tests/distributions/test_timeseries.py index 735b84b7076..31e5f07acc3 100644 --- a/pymc/tests/distributions/test_timeseries.py +++ b/pymc/tests/distributions/test_timeseries.py @@ -42,8 +42,8 @@ RandomWalk, ) from pymc.model import Model -from pymc.sampling import sample -from pymc.sampling_forward import draw, sample_posterior_predictive +from pymc.sampling.forward import draw, sample_posterior_predictive +from pymc.sampling.mcmc import sample from pymc.tests.distributions.util import assert_moment_is_expected from pymc.tests.helpers import select_by_precision diff --git a/pymc/tests/sampler_fixtures.py b/pymc/tests/sampler_fixtures.py index dabb3466e64..db66784adc8 100644 --- a/pymc/tests/sampler_fixtures.py +++ b/pymc/tests/sampler_fixtures.py @@ -178,7 +178,7 @@ def make_step(cls): if hasattr(cls, "step_args"): args.update(cls.step_args) if "scaling" not in args: - _, step = pm.sampling.init_nuts(n_init=10000, **args) + _, step = pm.sampling.mcmc.init_nuts(n_init=10000, **args) else: step = pm.NUTS(**args) return step diff --git a/pymc/tests/test_sampling_forward.py b/pymc/tests/sampling/test_forward.py similarity index 99% rename from pymc/tests/test_sampling_forward.py rename to pymc/tests/sampling/test_forward.py index d02d4df1f34..28a33c418d9 100644 --- a/pymc/tests/test_sampling_forward.py +++ b/pymc/tests/sampling/test_forward.py @@ -35,7 +35,7 @@ from pymc.aesaraf import compile_pymc from pymc.backends.base import MultiTrace -from pymc.sampling_forward import ( +from pymc.sampling.forward import ( compile_forward_sampling_function, get_vars_in_point_list, ) diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/sampling/test_jax.py similarity index 99% rename from pymc/tests/test_sampling_jax.py rename to pymc/tests/sampling/test_jax.py index 9b0c0ab909d..8bf23af2efa 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/sampling/test_jax.py @@ -17,7 +17,7 @@ import pymc as pm with pytest.warns(UserWarning, match="module is experimental"): - from pymc.sampling_jax import ( + from pymc.sampling.jax import ( _get_batched_jittered_initial_points, _get_log_likelihood, _numpyro_nuts_defaults, diff --git a/pymc/tests/test_sampling.py b/pymc/tests/sampling/test_mcmc.py similarity index 97% rename from pymc/tests/test_sampling.py rename to pymc/tests/sampling/test_mcmc.py index 64fdb2a40ff..f801c17f4c4 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/sampling/test_mcmc.py @@ -35,7 +35,7 @@ from pymc.backends.ndarray import NDArray from pymc.distributions import transforms from pymc.exceptions import SamplingError -from pymc.sampling import assign_step_methods +from pymc.sampling.mcmc import assign_step_methods from pymc.stats.convergence import SamplerWarning, WarningType from pymc.step_methods import ( NUTS, @@ -57,7 +57,7 @@ def setup_method(self): def test_checks_seeds_kwarg(self): with self.model: with pytest.raises(ValueError, match="Number of seeds"): - pm.sampling.init_nuts(chains=2, random_seed=[1]) + pm.sampling.mcmc.init_nuts(chains=2, random_seed=[1]) class TestSample(SeededTest): @@ -208,7 +208,7 @@ def test_sample_args(self): def test_iter_sample(self): with self.model: - samps = pm.sampling.iter_sample( + samps = pm.sampling.mcmc.iter_sample( draws=5, step=self.step, start=self.start, @@ -255,7 +255,7 @@ def test_reset_tuning(self): with self.model: tune = 50 chains = 2 - start, step = pm.sampling.init_nuts(chains=chains, random_seed=[1, 2]) + start, step = pm.sampling.mcmc.init_nuts(chains=chains, random_seed=[1, 2]) with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) pm.sample(draws=2, tune=tune, chains=chains, step=step, initvals=start, cores=1) @@ -346,11 +346,11 @@ def test_sampler_stat_tune(self, cores): ) def test_sample_start_bad_shape(self, start, error): with pytest.raises(error): - pm.sampling._check_start_shape(self.model, start) + pm.sampling.mcmc._check_start_shape(self.model, start) @pytest.mark.parametrize("start", [{"x": np.array([1, 1])}, {"x": [10, 10]}, {"x": [-10, -10]}]) def test_sample_start_good_shape(self, start): - pm.sampling._check_start_shape(self.model, start) + pm.sampling.mcmc._check_start_shape(self.model, start) def test_sample_callback(self): callback = mock.Mock() @@ -515,7 +515,7 @@ def test_choose_chains(n_points, tune, expected_length, expected_n_traces): trace_1.record({"a": 0}) for _ in range(n_points[2]): trace_2.record({"a": 0}) - traces, length = pm.sampling._choose_chains([trace_0, trace_1, trace_2], tune=tune) + traces, length = pm.sampling.mcmc._choose_chains([trace_0, trace_1, trace_2], tune=tune) assert length == expected_length assert expected_n_traces == len(traces) @@ -575,29 +575,29 @@ def test_constant_named(self): class TestChooseBackend: def test_choose_backend_none(self): - with mock.patch("pymc.sampling.NDArray") as nd: - pm.sampling._choose_backend(None) + with mock.patch("pymc.backends.ndarray.NDArray") as nd: + pm.sampling.mcmc._choose_backend(None) assert nd.called def test_choose_backend_list_of_variables(self): - with mock.patch("pymc.sampling.NDArray") as nd: - pm.sampling._choose_backend(["var1", "var2"]) + with mock.patch("pymc.backends.ndarray.NDArray") as nd: + pm.sampling.mcmc._choose_backend(["var1", "var2"]) nd.assert_called_with(vars=["var1", "var2"]) def test_errors_and_warnings(self): with pm.Model(): A = pm.Normal("A") B = pm.Uniform("B") - strace = pm.sampling.NDArray(vars=[A, B]) + strace = pm.backends.ndarray.NDArray(vars=[A, B]) strace.setup(10, 0) with pytest.raises(ValueError, match="from existing MultiTrace"): - pm.sampling._choose_backend(trace=MultiTrace([strace])) + pm.sampling.mcmc._choose_backend(trace=MultiTrace([strace])) strace.record({"A": 2, "B_interval__": 0.1}) assert len(strace) == 1 with pytest.raises(ValueError, match="Continuation of traces"): - pm.sampling._choose_backend(trace=strace) + pm.sampling.mcmc._choose_backend(trace=strace) def check_exec_nuts_init(method): @@ -657,7 +657,7 @@ def test_init_jitter(initval, jitter_max_retries, expectation): # Starting value is negative (invalid) when np.random.rand returns 0 (jitter = -1) # and positive (valid) when it returns 1 (jitter = 1) with mock.patch("numpy.random.Generator.uniform", side_effect=[-1, -1, -1, 1, -1]): - start = pm.sampling._init_jitter( + start = pm.sampling.mcmc._init_jitter( model=m, initvals=None, seeds=[1], @@ -704,7 +704,7 @@ def test_log_warning_stats(caplog): stats = [s1, s2] with caplog.at_level(logging.WARNING): - pm.sampling.log_warning_stats(stats) + pm.sampling.mcmc.log_warning_stats(stats) # We have a list of stats dicts, because there might be several samplers involved. assert "too low" in caplog.records[0].message @@ -716,7 +716,7 @@ def test_log_warning_stats_knows_SamplerWarning(caplog): stats = [dict(warning=SamplerWarning(WarningType.BAD_ENERGY, "Not that interesting", "debug"))] with caplog.at_level(logging.DEBUG, logger="pymc"): - pm.sampling.log_warning_stats(stats) + pm.sampling.mcmc.log_warning_stats(stats) assert "Not that interesting" in caplog.records[0].message diff --git a/pymc/tests/test_parallel_sampling.py b/pymc/tests/sampling/test_parallel.py similarity index 99% rename from pymc/tests/test_parallel_sampling.py rename to pymc/tests/sampling/test_parallel.py index 2883acd297f..77ba9f48b7a 100644 --- a/pymc/tests/test_parallel_sampling.py +++ b/pymc/tests/sampling/test_parallel.py @@ -27,7 +27,7 @@ from aesara.tensor.type import TensorType import pymc as pm -import pymc.parallel_sampling as ps +import pymc.sampling.parallel as ps from pymc.aesaraf import floatX diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index d3ec0b8b2b9..be9e4304623 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -66,7 +66,8 @@ reseed_rngs, rvs_to_value_vars, ) -from pymc.backends import NDArray +from pymc.backends.base import MultiTrace +from pymc.backends.ndarray import NDArray from pymc.blocking import DictToArrayBijection from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext @@ -1477,7 +1478,7 @@ def sample( finally: trace.close() - trace = pm.sampling.MultiTrace([trace]) + trace = MultiTrace([trace]) if not return_inferencedata: return trace else: diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 2b55afdddf3..c16cfaa9c39 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -49,10 +49,10 @@ pymc/ode/__init__.py pymc/ode/ode.py pymc/ode/utils.py -pymc/parallel_sampling.py pymc/plots/__init__.py -pymc/sampling.py -pymc/sampling_forward.py +pymc/sampling/forward.py +pymc/sampling/mcmc.py +pymc/sampling/parallel.py pymc/smc/__init__.py pymc/smc/sampling.py pymc/smc/kernels.py @@ -167,10 +167,10 @@ def check_no_unexpected_results(mypy_lines: Iterator[str]): print("You can run `python scripts/run_mypy.py --verbose` to reproduce this test locally.") sys.exit(1) - if unexpected_passing == {"pymc/sampling_jax.py"}: - print("Letting you know that 'pymc/sampling_jax.py' unexpectedly passed.") + if unexpected_passing == {"pymc/sampling/jax.py"}: + print("Letting you know that 'pymc/sampling/jax.py' unexpectedly passed.") print("But this file is known to sometimes pass and sometimes not.") - print("Unless you tried to resolve problems in sampling_jax.py just ignore this message.") + print("Unless you tried to resolve problems in sampling/jax.py just ignore this message.") elif unexpected_passing: print("!!!!!!!!!") print(f"{len(unexpected_passing)} files unexpectedly passed the type checks:")