Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scope separator for netcdf #5663

Merged
merged 7 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def __init__(self, mean=0, sigma=1, name=''):

# 3) you can create variables with Var method
self.Var('v1', Normal.dist(mu=mean, sigma=sd))
# this will create variable named like '{prefix/}v1'
# this will create variable named like '{prefix::}v1'
# and assign attribute 'v1' to instance created
# variable can be accessed with self.v1 or self['v1']

Expand Down Expand Up @@ -516,7 +516,7 @@ def __init__(self, mean=0, sigma=1, name=''):
CustomModel(mean=1, name='first')
CustomModel(mean=2, name='second')

# variables inside both scopes will be named like `first/*`, `second/*`
# variables inside both scopes will be named like `first::*`, `second::*`

"""

Expand All @@ -538,14 +538,20 @@ def __new__(cls, *args, **kwargs):
instance._aesara_config = kwargs.get("aesara_config", {})
return instance

@staticmethod
def _validate_name(name):
if name.endswith(":"):
raise KeyError("name should not end with `:`")
return name

def __init__(
self,
name="",
coords=None,
check_bounds=True,
rng_seeder: Optional[Union[int, np.random.RandomState]] = None,
):
self.name = name
self.name = self._validate_name(name)
self.check_bounds = check_bounds

if rng_seeder is None:
Expand Down Expand Up @@ -1462,25 +1468,27 @@ def prefix(self) -> str:
if self.isroot or not self.parent.prefix:
name = self.name
else:
name = f"{self.parent.prefix}/{self.name}"
return name.strip("/")
name = f"{self.parent.prefix}::{self.name}"
return name

def name_for(self, name):
"""Checks if name has prefix and adds if needed"""
name = self._validate_name(name)
if self.prefix:
if not name.startswith(self.prefix):
return f"{self.prefix}/{name}"
return f"{self.prefix}::{name}"
else:
return name
else:
return name

def name_of(self, name):
"""Checks if name has prefix and deletes if needed"""
name = self._validate_name(name)
if not self.prefix or not name:
return name
elif name.startswith(self.prefix + "/"):
return name[len(self.prefix) + 1 :]
elif name.startswith(self.prefix + "::"):
return name[len(self.prefix) + 2 :]
else:
return name

Expand Down
4 changes: 2 additions & 2 deletions pymc/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,8 @@ def test_data_naming():
with pm.Model("named_model") as model:
x = pm.ConstantData("x", [1.0, 2.0, 3.0])
y = pm.Normal("y")
assert y.name == "named_model/y"
assert x.name == "named_model/x"
assert y.name == "named_model::y"
assert x.name == "named_model::x"


def test_get_data():
Expand Down
49 changes: 33 additions & 16 deletions pymc/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import aesara
import aesara.sparse as sparse
import aesara.tensor as at
import arviz as az
import cloudpickle
import numpy as np
import numpy.ma as ma
Expand Down Expand Up @@ -94,20 +95,20 @@ def test_context_passes_vars_to_parent_model(self):
usermodel2.register_rv(pm.Normal.dist(), "v3")
pm.Normal("v4")
# this variable is created in parent model too
assert "another/v2" in model.named_vars
assert "another/v3" in model.named_vars
assert "another/v3" in usermodel2.named_vars
assert "another/v4" in model.named_vars
assert "another/v4" in usermodel2.named_vars
assert "another::v2" in model.named_vars
assert "another::v3" in model.named_vars
assert "another::v3" in usermodel2.named_vars
assert "another::v4" in model.named_vars
assert "another::v4" in usermodel2.named_vars
assert hasattr(usermodel2, "v3")
assert hasattr(usermodel2, "v2")
assert hasattr(usermodel2, "v4")
# When you create a class based model you should follow some rules
with model:
m = NewModel("one_more")
assert m.d is model["one_more/d"]
assert m["d"] is model["one_more/d"]
assert m["one_more/d"] is model["one_more/d"]
assert m.d is model["one_more::d"]
assert m["d"] is model["one_more::d"]
assert m["one_more::d"] is model["one_more::d"]


class TestNested:
Expand All @@ -123,8 +124,8 @@ def test_nest_context_works(self):
def test_named_context(self):
with pm.Model() as m:
NewModel(name="new")
assert "new/v1" in m.named_vars
assert "new/v2" in m.named_vars
assert "new::v1" in m.named_vars
assert "new::v2" in m.named_vars

def test_docstring_example1(self):
usage1 = DocstringModel()
Expand All @@ -137,10 +138,10 @@ def test_docstring_example1(self):
def test_docstring_example2(self):
with pm.Model() as model:
DocstringModel(name="prefix")
assert "prefix/v1" in model.named_vars
assert "prefix/v2" in model.named_vars
assert "prefix/v3" in model.named_vars
assert "prefix/v3_sq" in model.named_vars
assert "prefix::v1" in model.named_vars
assert "prefix::v2" in model.named_vars
assert "prefix::v3" in model.named_vars
assert "prefix::v3_sq" in model.named_vars
assert len(model.potentials), 1

def test_duplicates_detection(self):
Expand All @@ -160,14 +161,30 @@ def test_nested_named_model_repeated(self):
b = pm.Normal("var")
with pm.Model("sub"):
b = pm.Normal("var")
assert {"sub/var", "sub/sub/var"} == set(model.named_vars.keys())
assert {"sub::var", "sub::sub::var"} == set(model.named_vars.keys())

def test_nested_named_model(self):
with pm.Model("sub1") as model:
b = pm.Normal("var")
with pm.Model("sub2"):
b = pm.Normal("var")
assert {"sub1/var", "sub1/sub2/var"} == set(model.named_vars.keys())
assert {"sub1::var", "sub1::sub2::var"} == set(model.named_vars.keys())

def test_nested_model_to_netcdf(self, tmp_path):
with pm.Model("scope") as model:
b = pm.Normal("var")
trace = pm.sample(100, tune=0)
az.to_netcdf(trace, tmp_path / "trace.nc")
trace1 = az.from_netcdf(tmp_path / "trace.nc")
assert "scope::var" in trace1.posterior

def test_bad_name(self):
with pm.Model() as model:
with pytest.raises(KeyError):
b = pm.Normal("var::")
with pytest.raises(KeyError):
with pm.Model("scope::") as model:
b = pm.Normal("v")


class TestObserved:
Expand Down
6 changes: 3 additions & 3 deletions pymc/tests/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,9 @@ def test_named_model(self):
s = pm.Simulator("s", self.normal_sim, a, b, observed=self.data)

trace = pm.sample_smc(draws=10, chains=2, return_inferencedata=False)
assert f"{name}/a" in trace.varnames
assert f"{name}/b" in trace.varnames
assert f"{name}/b_log__" in trace.varnames
assert f"{name}::a" in trace.varnames
assert f"{name}::b" in trace.varnames
assert f"{name}::b_log__" in trace.varnames


class TestMHKernel(SeededTest):
Expand Down