Skip to content

Commit

Permalink
Assert ndim and number of dims match
Browse files Browse the repository at this point in the history
Fixes a bug where an invalid number of dims was passed to the Deterministic created by OrderedLogistic/OrderedProbit
  • Loading branch information
ricardoV94 committed Jun 25, 2024
1 parent 29eef08 commit fde8233
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 19 deletions.
11 changes: 2 additions & 9 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,13 +1185,6 @@ def logp(value, p):
)


class _OrderedLogistic(Categorical):
r"""
Underlying class for ordered logistic distributions.
See docs for the OrderedLogistic wrapper class for more details on how to use it in models.
"""


class OrderedLogistic:
R"""Ordered Logistic distribution.
Expand Down Expand Up @@ -1263,7 +1256,7 @@ class OrderedLogistic:
def __new__(cls, name, eta, cutpoints, compute_p=True, **kwargs):
p = cls.compute_p(eta, cutpoints)
if compute_p:
p = pm.Deterministic(f"{name}_probs", p, dims=kwargs.get("dims"))
p = pm.Deterministic(f"{name}_probs", p)
out_rv = Categorical(name, p=p, **kwargs)
return out_rv

Expand Down Expand Up @@ -1367,7 +1360,7 @@ class OrderedProbit:
def __new__(cls, name, eta, cutpoints, sigma=1, compute_p=True, **kwargs):
p = cls.compute_p(eta, cutpoints, sigma)
if compute_p:
p = pm.Deterministic(f"{name}_probs", p, dims=kwargs.get("dims"))
p = pm.Deterministic(f"{name}_probs", p)
out_rv = Categorical(name, p=p, **kwargs)
return out_rv

Expand Down
5 changes: 5 additions & 0 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,6 +1532,11 @@ def add_named_variable(self, var, dims: tuple[str | None, ...] | None = None):
raise ValueError(f"Dimension {dim} is not specified in `coords`.")
if any(var.name == dim for dim in dims if dim is not None):
raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")
if (var_ndim := getattr(var, "ndim", None)) is not None:
if var_ndim != len(dims):
raise ValueError(
f"{var} has {var_ndim} dims but {len(dims)} dim labels were provided."
)
self.named_vars_to_dims[var.name] = dims

self.named_vars[var.name] = var
Expand Down
16 changes: 11 additions & 5 deletions tests/distributions/test_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,19 +897,25 @@ def test_shape_inputs(self, eta, cutpoints, expected):
assert p_shape == expected

def test_compute_p(self):
with pm.Model() as m:
pm.OrderedLogistic("ol_p", cutpoints=np.array([-2, 0, 2]), eta=0)
pm.OrderedLogistic("ol_no_p", cutpoints=np.array([-2, 0, 2]), eta=0, compute_p=False)
with pm.Model(coords={"test_dim": [0]}) as m:
pm.OrderedLogistic("ol_p", cutpoints=np.array([-2, 0, 2]), eta=0, dims="test_dim")
pm.OrderedLogistic(
"ol_no_p", cutpoints=np.array([-2, 0, 2]), eta=0, compute_p=False, dims="test_dim"
)
assert len(m.deterministics) == 1

x = pm.OrderedLogistic.dist(cutpoints=np.array([-2, 0, 2]), eta=0)
assert isinstance(x, TensorVariable)

# Test it works with auto-imputation
with pm.Model() as m:
with pm.Model(coords={"test_dim": [0, 1, 2]}) as m:
with pytest.warns(ImputationWarning):
pm.OrderedLogistic(
"ol", cutpoints=np.array([-2, 0, 2]), eta=0, observed=[0, np.nan, 1]
"ol",
cutpoints=np.array([[-2, 0, 2]]),
eta=0,
observed=[0, np.nan, 1],
dims=["test_dim"],
)
assert len(m.deterministics) == 2 # One from the auto-imputation, the other from compute_p

Expand Down
25 changes: 20 additions & 5 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,20 @@ def test_add_named_variable_checks_dim_name(self):
rv2.name = "yumyum"
pmodel.add_named_variable(rv2, dims=("nomnom", None))

def test_dims_type_check(self):
def test_add_named_variable_checks_number_of_dims(self):
match = "dim labels were provided"
with pm.Model(coords={"bad": range(6)}) as m:
with pytest.raises(ValueError, match=match):
m.add_named_variable(pt.random.normal(size=(6, 6, 6), name="a"), dims=("bad",))

with pytest.raises(ValueError, match=match):
m.add_named_variable(pt.random.normal(size=(6, 6, 6), name="b"), dims="bad")

# For variables without ndim we can't check
m.add_named_variable(pytensor.as_symbolic(None, name="c"), dims=("bad",))
assert m.named_vars_to_dims == {"c": ("bad",)}

def test_rv_dims_type_check(self):
with pm.Model(coords={"a": range(5)}) as m:
with pytest.raises(TypeError, match="Dims must be string"):
x = pm.Normal("x", shape=(10, 5), dims=(None, "a"))
Expand All @@ -899,12 +912,14 @@ def test_none_coords_autonumbering(self):
# TODO: Either allow dims without coords everywhere or nowhere
with pm.Model() as m:
m.add_coord(name="a", values=None, length=3)
m.add_coord(name="b", values=range(5))
x = pm.Normal("x", dims=("a", "b"))
m.add_coord(name="b", values=range(-5, 0))
m.add_coord(name="c", values=None, length=7)
x = pm.Normal("x", dims=("a", "b", "c"))
prior = pm.sample_prior_predictive(draws=2).prior
assert prior["x"].shape == (1, 2, 3, 5)
assert prior["x"].shape == (1, 2, 3, 5, 7)
assert list(prior.coords["a"].values) == list(range(3))
assert list(prior.coords["b"].values) == list(range(5))
assert list(prior.coords["b"].values) == list(range(-5, 0))
assert list(prior.coords["c"].values) == list(range(7))

def test_set_data_indirect_resize_without_coords(self):
with pm.Model() as pmodel:
Expand Down

0 comments on commit fde8233

Please sign in to comment.