diff --git a/pymc/model.py b/pymc/model.py index 5b580201be..55303752aa 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -474,7 +474,7 @@ def __init__(self, mean=0, sigma=1, name=''): # 3) you can create variables with Var method self.Var('v1', Normal.dist(mu=mean, sigma=sd)) - # this will create variable named like '{prefix/}v1' + # this will create variable named like '{prefix::}v1' # and assign attribute 'v1' to instance created # variable can be accessed with self.v1 or self['v1'] @@ -516,7 +516,7 @@ def __init__(self, mean=0, sigma=1, name=''): CustomModel(mean=1, name='first') CustomModel(mean=2, name='second') - # variables inside both scopes will be named like `first/*`, `second/*` + # variables inside both scopes will be named like `first::*`, `second::*` """ @@ -538,6 +538,12 @@ def __new__(cls, *args, **kwargs): instance._aesara_config = kwargs.get("aesara_config", {}) return instance + @staticmethod + def _validate_name(name): + if name.endswith(":"): + raise KeyError("name should not end with `:`") + return name + def __init__( self, name="", @@ -545,7 +551,7 @@ def __init__( check_bounds=True, rng_seeder: Optional[Union[int, np.random.RandomState]] = None, ): - self.name = name + self.name = self._validate_name(name) self.check_bounds = check_bounds if rng_seeder is None: @@ -1462,14 +1468,15 @@ def prefix(self) -> str: if self.isroot or not self.parent.prefix: name = self.name else: - name = f"{self.parent.prefix}/{self.name}" - return name.strip("/") + name = f"{self.parent.prefix}::{self.name}" + return name def name_for(self, name): """Checks if name has prefix and adds if needed""" + name = self._validate_name(name) if self.prefix: if not name.startswith(self.prefix): - return f"{self.prefix}/{name}" + return f"{self.prefix}::{name}" else: return name else: @@ -1477,10 +1484,11 @@ def name_for(self, name): def name_of(self, name): """Checks if name has prefix and deletes if needed""" + name = self._validate_name(name) if not self.prefix or not name: return name - elif name.startswith(self.prefix + "/"): - return name[len(self.prefix) + 1 :] + elif name.startswith(self.prefix + "::"): + return name[len(self.prefix) + 2 :] else: return name diff --git a/pymc/tests/test_data_container.py b/pymc/tests/test_data_container.py index c2bf69e22b..98460c6ac2 100644 --- a/pymc/tests/test_data_container.py +++ b/pymc/tests/test_data_container.py @@ -399,8 +399,8 @@ def test_data_naming(): with pm.Model("named_model") as model: x = pm.ConstantData("x", [1.0, 2.0, 3.0]) y = pm.Normal("y") - assert y.name == "named_model/y" - assert x.name == "named_model/x" + assert y.name == "named_model::y" + assert x.name == "named_model::x" def test_get_data(): diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index d22c2ffaa4..70b3d1ea6e 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -18,6 +18,7 @@ import aesara import aesara.sparse as sparse import aesara.tensor as at +import arviz as az import cloudpickle import numpy as np import numpy.ma as ma @@ -94,20 +95,20 @@ def test_context_passes_vars_to_parent_model(self): usermodel2.register_rv(pm.Normal.dist(), "v3") pm.Normal("v4") # this variable is created in parent model too - assert "another/v2" in model.named_vars - assert "another/v3" in model.named_vars - assert "another/v3" in usermodel2.named_vars - assert "another/v4" in model.named_vars - assert "another/v4" in usermodel2.named_vars + assert "another::v2" in model.named_vars + assert "another::v3" in model.named_vars + assert "another::v3" in usermodel2.named_vars + assert "another::v4" in model.named_vars + assert "another::v4" in usermodel2.named_vars assert hasattr(usermodel2, "v3") assert hasattr(usermodel2, "v2") assert hasattr(usermodel2, "v4") # When you create a class based model you should follow some rules with model: m = NewModel("one_more") - assert m.d is model["one_more/d"] - assert m["d"] is model["one_more/d"] - assert m["one_more/d"] is model["one_more/d"] + assert m.d is model["one_more::d"] + assert m["d"] is model["one_more::d"] + assert m["one_more::d"] is model["one_more::d"] class TestNested: @@ -123,8 +124,8 @@ def test_nest_context_works(self): def test_named_context(self): with pm.Model() as m: NewModel(name="new") - assert "new/v1" in m.named_vars - assert "new/v2" in m.named_vars + assert "new::v1" in m.named_vars + assert "new::v2" in m.named_vars def test_docstring_example1(self): usage1 = DocstringModel() @@ -137,10 +138,10 @@ def test_docstring_example1(self): def test_docstring_example2(self): with pm.Model() as model: DocstringModel(name="prefix") - assert "prefix/v1" in model.named_vars - assert "prefix/v2" in model.named_vars - assert "prefix/v3" in model.named_vars - assert "prefix/v3_sq" in model.named_vars + assert "prefix::v1" in model.named_vars + assert "prefix::v2" in model.named_vars + assert "prefix::v3" in model.named_vars + assert "prefix::v3_sq" in model.named_vars assert len(model.potentials), 1 def test_duplicates_detection(self): @@ -160,14 +161,30 @@ def test_nested_named_model_repeated(self): b = pm.Normal("var") with pm.Model("sub"): b = pm.Normal("var") - assert {"sub/var", "sub/sub/var"} == set(model.named_vars.keys()) + assert {"sub::var", "sub::sub::var"} == set(model.named_vars.keys()) def test_nested_named_model(self): with pm.Model("sub1") as model: b = pm.Normal("var") with pm.Model("sub2"): b = pm.Normal("var") - assert {"sub1/var", "sub1/sub2/var"} == set(model.named_vars.keys()) + assert {"sub1::var", "sub1::sub2::var"} == set(model.named_vars.keys()) + + def test_nested_model_to_netcdf(self, tmp_path): + with pm.Model("scope") as model: + b = pm.Normal("var") + trace = pm.sample(100, tune=0) + az.to_netcdf(trace, tmp_path / "trace.nc") + trace1 = az.from_netcdf(tmp_path / "trace.nc") + assert "scope::var" in trace1.posterior + + def test_bad_name(self): + with pm.Model() as model: + with pytest.raises(KeyError): + b = pm.Normal("var::") + with pytest.raises(KeyError): + with pm.Model("scope::") as model: + b = pm.Normal("v") class TestObserved: diff --git a/pymc/tests/test_smc.py b/pymc/tests/test_smc.py index 6711dfd873..d9e729f85b 100644 --- a/pymc/tests/test_smc.py +++ b/pymc/tests/test_smc.py @@ -534,9 +534,9 @@ def test_named_model(self): s = pm.Simulator("s", self.normal_sim, a, b, observed=self.data) trace = pm.sample_smc(draws=10, chains=2, return_inferencedata=False) - assert f"{name}/a" in trace.varnames - assert f"{name}/b" in trace.varnames - assert f"{name}/b_log__" in trace.varnames + assert f"{name}::a" in trace.varnames + assert f"{name}::b" in trace.varnames + assert f"{name}::b_log__" in trace.varnames class TestMHKernel(SeededTest):