From 71b65692c43cb580f4812c27a077cbd716c51a68 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 29 Mar 2024 17:43:26 +0100 Subject: [PATCH] Make default STEP_METHODS a list that can be modified --- pymc/step_methods/__init__.py | 7 ++++--- tests/sampling/test_mcmc.py | 16 +++++++++++----- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/pymc/step_methods/__init__.py b/pymc/step_methods/__init__.py index 3413609514a..5f44acc728c 100644 --- a/pymc/step_methods/__init__.py +++ b/pymc/step_methods/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pymc.step_methods.compound import CompoundStep +from pymc.step_methods.compound import BlockedStep, CompoundStep from pymc.step_methods.hmc import NUTS, HamiltonianMC from pymc.step_methods.metropolis import ( BinaryGibbsMetropolis, @@ -30,7 +30,8 @@ ) from pymc.step_methods.slicer import Slice -STEP_METHODS = ( +# Other step methods can be added by appending to this list +STEP_METHODS: list[type[BlockedStep]] = [ NUTS, HamiltonianMC, Metropolis, @@ -38,4 +39,4 @@ BinaryGibbsMetropolis, Slice, CategoricalGibbsMetropolis, -) +] diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index a18430818d9..3f676d08466 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -762,12 +762,18 @@ def kill_grad(x): steps = assign_step_methods(model, []) assert isinstance(steps, Slice) - def test_modify_step_methods(self): + @pytest.fixture + def step_methods(self): + """Make sure we reset the STEP_METHODS after the test is done.""" + methods_copy = pm.STEP_METHODS.copy() + yield pm.STEP_METHODS + pm.STEP_METHODS.clear() + for method in methods_copy: + pm.STEP_METHODS.append(method) + + def test_modify_step_methods(self, step_methods): """Test step methods can be changed""" - # remove nuts from step_methods - step_methods = list(pm.STEP_METHODS) step_methods.remove(NUTS) - pm.STEP_METHODS = step_methods with pm.Model() as model: pm.Normal("x", 0, 1) @@ -776,7 +782,7 @@ def test_modify_step_methods(self): assert not isinstance(steps, NUTS) # add back nuts - pm.STEP_METHODS = [*step_methods, NUTS] + step_methods.append(NUTS) with pm.Model() as model: pm.Normal("x", 0, 1)