Skip to content

Commit

Permalink
Move sampling code into sampling submodule
Browse files Browse the repository at this point in the history
This is a follow-up to pymc-devs#6257 where we split the `sampling.py` into two files.
  • Loading branch information
michaelosthege committed Nov 5, 2022
1 parent 9c313cb commit 8300f16
Show file tree
Hide file tree
Showing 21 changed files with 66 additions and 50 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ jobs:
pymc/tests/distributions/test_censored.py
pymc/tests/distributions/test_simulator.py
pymc/tests/distributions/test_truncated.py
pymc/tests/test_sampling_forward.py
pymc/tests/sampling/forward.py
pymc/tests/stats/test_convergence.py
- |
pymc/tests/tuning/test_scaling.py
pymc/tests/tuning/test_starting.py
pymc/tests/test_sampling.py
pymc/tests/distributions/test_dist_math.py
pymc/tests/distributions/test_transform.py
pymc/tests/test_parallel_sampling.py
pymc/tests/sampling/mcmc.py
pymc/tests/sampling/parallel.py
pymc/tests/test_printing.py
- |
Expand Down Expand Up @@ -150,8 +150,8 @@ jobs:
test-subset:
- pymc/tests/variational/test_approximations.py pymc/tests/variational/test_callbacks.py pymc/tests/variational/test_inference.py pymc/tests/variational/test_opvi.py pymc/tests/test_initial_point.py
- pymc/tests/test_model.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py
- pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/smc/test_smc.py pymc/tests/test_parallel_sampling.py
- pymc/tests/test_sampling.py pymc/tests/step_methods/test_metropolis.py pymc/tests/step_methods/test_slicer.py pymc/tests/step_methods/hmc/test_nuts.py
- pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/smc/test_smc.py pymc/tests/sampling/test_parallel.py
- pymc/tests/sampling/test_mcmc.py pymc/tests/step_methods/test_metropolis.py pymc/tests/step_methods/test_slicer.py pymc/tests/step_methods/hmc/test_nuts.py

fail-fast: false
runs-on: ${{ matrix.os }}
Expand Down Expand Up @@ -221,7 +221,7 @@ jobs:
python-version: ["3.9"]
test-subset:
- |
pymc/tests/test_parallel_sampling.py
pymc/tests/sampling/parallel.py
pymc/tests/test_data.py
pymc/tests/test_model.py
Expand Down Expand Up @@ -294,7 +294,7 @@ jobs:
floatx: [float64]
python-version: ["3.9"]
test-subset:
- pymc/tests/test_sampling_jax.py
- pymc/tests/sampling/test_sampling_jax.py
fail-fast: false
runs-on: ${{ matrix.os }}
env:
Expand Down
4 changes: 2 additions & 2 deletions docs/source/api/samplers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ This submodule contains functions for MCMC and forward sampling.
sample_prior_predictive
sample_posterior_predictive
sample_posterior_predictive_w
sampling_jax.sample_blackjax_nuts
sampling_jax.sample_numpyro_nuts
sampling.jax.sample_blackjax_nuts
sampling.jax.sample_numpyro_nuts
iter_sample
init_nuts
draw
Expand Down
2 changes: 1 addition & 1 deletion docs/source/contributing/build_docs.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ To build the docs, run these commands at PyMC repository root:

```bash
pip install -r requirements-dev.txt # Make sure the dev requirements are installed
pip install numpyro # Make sure `sampling_jax` docs can be built
pip install numpyro # Make sure `sampling/jax` docs can be built
pip install -e . # Install local pymc version as installable package
make clean # clean built docs from previous runs and intermediate outputs
make html # Build docs
Expand Down
1 change: 0 additions & 1 deletion pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def __set_compiler_flags():
from pymc.plots import *
from pymc.printing import *
from pymc.sampling import *
from pymc.sampling_forward import *
from pymc.smc import *
from pymc.stats import *
from pymc.step_methods import *
Expand Down
16 changes: 16 additions & 0 deletions pymc/sampling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2022 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pymc.sampling.forward import *
from pymc.sampling.mcmc import *
File renamed without changes.
2 changes: 1 addition & 1 deletion pymc/sampling_jax.py → pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

from pymc.initial_point import StartDict
from pymc.sampling import _init_jitter
from pymc.sampling.mcmc import _init_jitter

xla_flags = os.getenv("XLA_FLAGS", "")
xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
Expand Down
4 changes: 2 additions & 2 deletions pymc/sampling.py → pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
make_initial_point_fns_per_chain,
)
from pymc.model import Model, modelcontext
from pymc.parallel_sampling import Draw, _cpu_count
from pymc.sampling.parallel import Draw, _cpu_count
from pymc.stats.convergence import SamplerWarning, log_warning, run_convergence_checks
from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
Expand Down Expand Up @@ -1404,7 +1404,7 @@ def _mp_sample(
mtrace : pymc.backends.base.MultiTrace
A ``MultiTrace`` object that contains the samples for all chains.
"""
import pymc.parallel_sampling as ps
import pymc.sampling.parallel as ps

# We did draws += tune in pm.sample
draws -= tune
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion pymc/smc/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from pymc.backends.ndarray import NDArray
from pymc.blocking import DictToArrayBijection
from pymc.model import Point, modelcontext
from pymc.sampling_forward import sample_prior_predictive
from pymc.sampling.forward import sample_prior_predictive
from pymc.step_methods.metropolis import MultivariateNormalProposal
from pymc.vartypes import discrete_types

Expand Down
2 changes: 1 addition & 1 deletion pymc/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from pymc.backends.arviz import dict_to_dataset, to_inference_data
from pymc.backends.base import MultiTrace
from pymc.model import modelcontext
from pymc.parallel_sampling import _cpu_count
from pymc.sampling.parallel import _cpu_count
from pymc.smc.kernels import IMH


Expand Down
4 changes: 2 additions & 2 deletions pymc/tests/distributions/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@
from pymc.distributions.transforms import _default_transform
from pymc.math import expand_packed_triangular
from pymc.model import Model
from pymc.sampling import sample
from pymc.sampling_forward import (
from pymc.sampling.forward import (
draw,
sample_posterior_predictive,
sample_prior_predictive,
)
from pymc.sampling.mcmc import sample
from pymc.step_methods import Metropolis
from pymc.tests.distributions.util import (
Domain,
Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
)
from pymc.distributions.shape_utils import change_dist_size, to_tuple
from pymc.math import kronecker
from pymc.sampling_forward import draw
from pymc.sampling.forward import draw
from pymc.tests.distributions.util import (
BaseTestDistributionRandom,
Domain,
Expand Down
4 changes: 2 additions & 2 deletions pymc/tests/distributions/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
RandomWalk,
)
from pymc.model import Model
from pymc.sampling import sample
from pymc.sampling_forward import draw, sample_posterior_predictive
from pymc.sampling.forward import draw, sample_posterior_predictive
from pymc.sampling.mcmc import sample
from pymc.tests.distributions.util import assert_moment_is_expected
from pymc.tests.helpers import select_by_precision

Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/sampler_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def make_step(cls):
if hasattr(cls, "step_args"):
args.update(cls.step_args)
if "scaling" not in args:
_, step = pm.sampling.init_nuts(n_init=10000, **args)
_, step = pm.sampling.mcmc.init_nuts(n_init=10000, **args)
else:
step = pm.NUTS(**args)
return step
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from pymc.aesaraf import compile_pymc
from pymc.backends.base import MultiTrace
from pymc.sampling_forward import (
from pymc.sampling.forward import (
compile_forward_sampling_function,
get_vars_in_point_list,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pymc as pm

with pytest.warns(UserWarning, match="module is experimental"):
from pymc.sampling_jax import (
from pymc.sampling.jax import (
_get_batched_jittered_initial_points,
_get_log_likelihood,
_numpyro_nuts_defaults,
Expand Down
34 changes: 17 additions & 17 deletions pymc/tests/test_sampling.py → pymc/tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from pymc.backends.ndarray import NDArray
from pymc.distributions import transforms
from pymc.exceptions import SamplingError
from pymc.sampling import assign_step_methods
from pymc.sampling.mcmc import assign_step_methods
from pymc.stats.convergence import SamplerWarning, WarningType
from pymc.step_methods import (
NUTS,
Expand All @@ -57,7 +57,7 @@ def setup_method(self):
def test_checks_seeds_kwarg(self):
with self.model:
with pytest.raises(ValueError, match="Number of seeds"):
pm.sampling.init_nuts(chains=2, random_seed=[1])
pm.sampling.mcmc.init_nuts(chains=2, random_seed=[1])


class TestSample(SeededTest):
Expand Down Expand Up @@ -208,7 +208,7 @@ def test_sample_args(self):

def test_iter_sample(self):
with self.model:
samps = pm.sampling.iter_sample(
samps = pm.sampling.mcmc.iter_sample(
draws=5,
step=self.step,
start=self.start,
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_reset_tuning(self):
with self.model:
tune = 50
chains = 2
start, step = pm.sampling.init_nuts(chains=chains, random_seed=[1, 2])
start, step = pm.sampling.mcmc.init_nuts(chains=chains, random_seed=[1, 2])
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
pm.sample(draws=2, tune=tune, chains=chains, step=step, initvals=start, cores=1)
Expand Down Expand Up @@ -346,11 +346,11 @@ def test_sampler_stat_tune(self, cores):
)
def test_sample_start_bad_shape(self, start, error):
with pytest.raises(error):
pm.sampling._check_start_shape(self.model, start)
pm.sampling.mcmc._check_start_shape(self.model, start)

@pytest.mark.parametrize("start", [{"x": np.array([1, 1])}, {"x": [10, 10]}, {"x": [-10, -10]}])
def test_sample_start_good_shape(self, start):
pm.sampling._check_start_shape(self.model, start)
pm.sampling.mcmc._check_start_shape(self.model, start)

def test_sample_callback(self):
callback = mock.Mock()
Expand Down Expand Up @@ -515,7 +515,7 @@ def test_choose_chains(n_points, tune, expected_length, expected_n_traces):
trace_1.record({"a": 0})
for _ in range(n_points[2]):
trace_2.record({"a": 0})
traces, length = pm.sampling._choose_chains([trace_0, trace_1, trace_2], tune=tune)
traces, length = pm.sampling.mcmc._choose_chains([trace_0, trace_1, trace_2], tune=tune)
assert length == expected_length
assert expected_n_traces == len(traces)

Expand Down Expand Up @@ -575,29 +575,29 @@ def test_constant_named(self):

class TestChooseBackend:
def test_choose_backend_none(self):
with mock.patch("pymc.sampling.NDArray") as nd:
pm.sampling._choose_backend(None)
with mock.patch("pymc.backends.ndarray.NDArray") as nd:
pm.sampling.mcmc._choose_backend(None)
assert nd.called

def test_choose_backend_list_of_variables(self):
with mock.patch("pymc.sampling.NDArray") as nd:
pm.sampling._choose_backend(["var1", "var2"])
with mock.patch("pymc.backends.ndarray.NDArray") as nd:
pm.sampling.mcmc._choose_backend(["var1", "var2"])
nd.assert_called_with(vars=["var1", "var2"])

def test_errors_and_warnings(self):
with pm.Model():
A = pm.Normal("A")
B = pm.Uniform("B")
strace = pm.sampling.NDArray(vars=[A, B])
strace = pm.backends.ndarray.NDArray(vars=[A, B])
strace.setup(10, 0)

with pytest.raises(ValueError, match="from existing MultiTrace"):
pm.sampling._choose_backend(trace=MultiTrace([strace]))
pm.sampling.mcmc._choose_backend(trace=MultiTrace([strace]))

strace.record({"A": 2, "B_interval__": 0.1})
assert len(strace) == 1
with pytest.raises(ValueError, match="Continuation of traces"):
pm.sampling._choose_backend(trace=strace)
pm.sampling.mcmc._choose_backend(trace=strace)


def check_exec_nuts_init(method):
Expand Down Expand Up @@ -657,7 +657,7 @@ def test_init_jitter(initval, jitter_max_retries, expectation):
# Starting value is negative (invalid) when np.random.rand returns 0 (jitter = -1)
# and positive (valid) when it returns 1 (jitter = 1)
with mock.patch("numpy.random.Generator.uniform", side_effect=[-1, -1, -1, 1, -1]):
start = pm.sampling._init_jitter(
start = pm.sampling.mcmc._init_jitter(
model=m,
initvals=None,
seeds=[1],
Expand Down Expand Up @@ -704,7 +704,7 @@ def test_log_warning_stats(caplog):
stats = [s1, s2]

with caplog.at_level(logging.WARNING):
pm.sampling.log_warning_stats(stats)
pm.sampling.mcmc.log_warning_stats(stats)

# We have a list of stats dicts, because there might be several samplers involved.
assert "too low" in caplog.records[0].message
Expand All @@ -716,7 +716,7 @@ def test_log_warning_stats_knows_SamplerWarning(caplog):
stats = [dict(warning=SamplerWarning(WarningType.BAD_ENERGY, "Not that interesting", "debug"))]

with caplog.at_level(logging.DEBUG, logger="pymc"):
pm.sampling.log_warning_stats(stats)
pm.sampling.mcmc.log_warning_stats(stats)

assert "Not that interesting" in caplog.records[0].message

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from aesara.tensor.type import TensorType

import pymc as pm
import pymc.parallel_sampling as ps
import pymc.sampling.parallel as ps

from pymc.aesaraf import floatX

Expand Down
5 changes: 3 additions & 2 deletions pymc/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@
reseed_rngs,
rvs_to_value_vars,
)
from pymc.backends import NDArray
from pymc.backends.base import MultiTrace
from pymc.backends.ndarray import NDArray
from pymc.blocking import DictToArrayBijection
from pymc.initial_point import make_initial_point_fn
from pymc.model import modelcontext
Expand Down Expand Up @@ -1477,7 +1478,7 @@ def sample(
finally:
trace.close()

trace = pm.sampling.MultiTrace([trace])
trace = MultiTrace([trace])
if not return_inferencedata:
return trace
else:
Expand Down
12 changes: 6 additions & 6 deletions scripts/run_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@
pymc/ode/__init__.py
pymc/ode/ode.py
pymc/ode/utils.py
pymc/parallel_sampling.py
pymc/plots/__init__.py
pymc/sampling.py
pymc/sampling_forward.py
pymc/sampling/forward.py
pymc/sampling/mcmc.py
pymc/sampling/parallel.py
pymc/smc/__init__.py
pymc/smc/sampling.py
pymc/smc/kernels.py
Expand Down Expand Up @@ -167,10 +167,10 @@ def check_no_unexpected_results(mypy_lines: Iterator[str]):
print("You can run `python scripts/run_mypy.py --verbose` to reproduce this test locally.")
sys.exit(1)

if unexpected_passing == {"pymc/sampling_jax.py"}:
print("Letting you know that 'pymc/sampling_jax.py' unexpectedly passed.")
if unexpected_passing == {"pymc/sampling/jax.py"}:
print("Letting you know that 'pymc/sampling/jax.py' unexpectedly passed.")
print("But this file is known to sometimes pass and sometimes not.")
print("Unless you tried to resolve problems in sampling_jax.py just ignore this message.")
print("Unless you tried to resolve problems in sampling/jax.py just ignore this message.")
elif unexpected_passing:
print("!!!!!!!!!")
print(f"{len(unexpected_passing)} files unexpectedly passed the type checks:")
Expand Down

0 comments on commit 8300f16

Please sign in to comment.