Skip to content

Commit

Permalink
Split sampling into three modules
Browse files Browse the repository at this point in the history
Closes #6141
  • Loading branch information
michaelosthege committed Nov 1, 2022
1 parent e57d1d7 commit 2c684d7
Show file tree
Hide file tree
Showing 15 changed files with 2,582 additions and 2,440 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ jobs:
pymc/tests/distributions/test_censored.py
pymc/tests/distributions/test_simulator.py
pymc/tests/distributions/test_truncated.py
pymc/tests/test_sampling_predictive.py
pymc/tests/stats/test_convergence.py
- |
pymc/tests/tuning/test_scaling.py
Expand Down Expand Up @@ -147,7 +149,7 @@ jobs:
python-version: ["3.8"]
test-subset:
- pymc/tests/variational/test_approximations.py pymc/tests/variational/test_callbacks.py pymc/tests/variational/test_inference.py pymc/tests/variational/test_opvi.py pymc/tests/test_initial_point.py
- pymc/tests/test_model.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py
- pymc/tests/test_model.py pymc/tests/test_sampling_utils.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py
- pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/smc/test_smc.py pymc/tests/test_parallel_sampling.py
- pymc/tests/test_sampling.py pymc/tests/step_methods/test_metropolis.py pymc/tests/step_methods/test_slicer.py pymc/tests/step_methods/hmc/test_nuts.py

Expand Down
2 changes: 2 additions & 0 deletions pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def __set_compiler_flags():
from pymc.plots import *
from pymc.printing import *
from pymc.sampling import *
from pymc.sampling_predictive import *
from pymc.sampling_utils import *
from pymc.smc import *
from pymc.stats import *
from pymc.step_methods import *
Expand Down
769 changes: 9 additions & 760 deletions pymc/sampling.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pymc/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

from pymc.initial_point import StartDict
from pymc.sampling import RandomSeed, _get_seeds_per_chain, _init_jitter
from pymc.sampling import _init_jitter
from pymc.sampling_utils import RandomSeed, _get_seeds_per_chain

xla_flags = os.getenv("XLA_FLAGS", "")
xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
Expand Down
Loading

0 comments on commit 2c684d7

Please sign in to comment.