diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 9a032609606..d824045c7ef 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -33,6 +33,9 @@ def __init__(self, methods): for method in self.methods: if method.generates_stats: self.stats_dtypes.extend(method.stats_dtypes) + self.name = ( + f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]" + ) def step(self, point): if self.generates_stats: diff --git a/pymc/tests/test_step.py b/pymc/tests/test_step.py index 06b14c70cb2..9df55f57b6c 100644 --- a/pymc/tests/test_step.py +++ b/pymc/tests/test_step.py @@ -221,6 +221,16 @@ def test_blocked(self): assert not isinstance(sampler_instance, CompoundStep) assert isinstance(sampler_instance, sampler) + def test_name(self): + with Model() as m: + c1 = HalfNormal("c1") + c2 = HalfNormal("c2") + + step1 = NUTS([c1]) + step2 = Slice([c2]) + step = CompoundStep([step1, step2]) + assert step.name == "Compound[nuts, slice]" + class TestAssignStepMethods: def test_bernoulli(self):