Skip to content

Commit

Permalink
Make default STEP_METHODS a list that can be modified
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 29, 2024
1 parent e442348 commit 71b6569
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
7 changes: 4 additions & 3 deletions pymc/step_methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pymc.step_methods.compound import CompoundStep
from pymc.step_methods.compound import BlockedStep, CompoundStep
from pymc.step_methods.hmc import NUTS, HamiltonianMC
from pymc.step_methods.metropolis import (
BinaryGibbsMetropolis,
Expand All @@ -30,12 +30,13 @@
)
from pymc.step_methods.slicer import Slice

STEP_METHODS = (
# Other step methods can be added by appending to this list
STEP_METHODS: list[type[BlockedStep]] = [
NUTS,
HamiltonianMC,
Metropolis,
BinaryMetropolis,
BinaryGibbsMetropolis,
Slice,
CategoricalGibbsMetropolis,
)
]
16 changes: 11 additions & 5 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,12 +762,18 @@ def kill_grad(x):
steps = assign_step_methods(model, [])
assert isinstance(steps, Slice)

def test_modify_step_methods(self):
@pytest.fixture
def step_methods(self):
"""Make sure we reset the STEP_METHODS after the test is done."""
methods_copy = pm.STEP_METHODS.copy()
yield pm.STEP_METHODS
pm.STEP_METHODS.clear()
for method in methods_copy:
pm.STEP_METHODS.append(method)

def test_modify_step_methods(self, step_methods):
"""Test step methods can be changed"""
# remove nuts from step_methods
step_methods = list(pm.STEP_METHODS)
step_methods.remove(NUTS)
pm.STEP_METHODS = step_methods

with pm.Model() as model:
pm.Normal("x", 0, 1)
Expand All @@ -776,7 +782,7 @@ def test_modify_step_methods(self):
assert not isinstance(steps, NUTS)

# add back nuts
pm.STEP_METHODS = [*step_methods, NUTS]
step_methods.append(NUTS)

with pm.Model() as model:
pm.Normal("x", 0, 1)
Expand Down

0 comments on commit 71b6569

Please sign in to comment.