From 57654dc5ef3bfb81e60cc6db2724f124d0ba3788 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 20 Jan 2022 19:25:14 +0100 Subject: [PATCH] Add name to CompountStep --- pymc/step_methods/compound.py | 3 +++ pymc/tests/test_step.py | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 9a03260960..d824045c7e 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -33,6 +33,9 @@ def __init__(self, methods): for method in self.methods: if method.generates_stats: self.stats_dtypes.extend(method.stats_dtypes) + self.name = ( + f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]" + ) def step(self, point): if self.generates_stats: diff --git a/pymc/tests/test_step.py b/pymc/tests/test_step.py index 06b14c70cb..9df55f57b6 100644 --- a/pymc/tests/test_step.py +++ b/pymc/tests/test_step.py @@ -221,6 +221,16 @@ def test_blocked(self): assert not isinstance(sampler_instance, CompoundStep) assert isinstance(sampler_instance, sampler) + def test_name(self): + with Model() as m: + c1 = HalfNormal("c1") + c2 = HalfNormal("c2") + + step1 = NUTS([c1]) + step2 = Slice([c2]) + step = CompoundStep([step1, step2]) + assert step.name == "Compound[nuts, slice]" + class TestAssignStepMethods: def test_bernoulli(self):