Replace rvs_to_total_sizes mapping by RescaledRandomVariables
ricardoV94 committed Feb 14, 2023
1 parent 2eb911e commit 05e7547
Showing 13 changed files with 152 additions and 200 deletions.
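For context: `total_size` lets a model fit on a minibatch weight its log-likelihood as if the full dataset had been observed, by multiplying the minibatch logp by `total_size / batch_size`. Before this commit that factor was computed outside the graph by `_get_scaling` and applied in `joint_logp` through the `rvs_to_total_sizes` mapping; after it, the factor is baked into the observed variable itself when it is registered on the model. A minimal numeric sketch of the scaling rule in plain NumPy (illustrative only, not PyMC internals):

import numpy as np

def minibatch_scaled_logp(logp_batch: np.ndarray, total_size: int) -> float:
    # One logp term per minibatch row; scaling the batch sum by
    # total_size / batch_size gives an unbiased estimate of the
    # full-data log-likelihood.
    batch_size = logp_batch.shape[0]
    return (total_size / batch_size) * logp_batch.sum()

# 5 rows standing in for a 1000-row dataset: each term counts 200x.
print(minibatch_scaled_logp(np.full(5, -1.2), total_size=1000))  # -1200.0
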
87 changes: 4 additions & 83 deletions pymc/logprob/joint_logprob.py
@@ -39,7 +39,6 @@
from collections import deque
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import pytensor
import pytensor.tensor as pt

@@ -55,7 +54,6 @@
from pymc.logprob.rewriting import construct_ir_fgraph
from pymc.logprob.transforms import RVTransform, TransformValuesRewrite
from pymc.logprob.utils import rvs_to_value_vars
from pymc.pytensorf import floatX


def logp(rv: TensorVariable, value) -> TensorVariable:
@@ -248,77 +246,6 @@ def factorized_joint_logprob(
return logprob_vars


TOTAL_SIZE = Union[int, Sequence[int], None]


def _get_scaling(total_size: TOTAL_SIZE, shape, ndim: int) -> TensorVariable:
"""
Gets scaling constant for logp.
Parameters
----------
total_size: Optional[int|List[int]]
size of a fully observed data without minibatching,
`None` means data is fully observed
shape: shape
shape of an observed data
ndim: int
ndim hint
Returns
-------
scalar
"""
if total_size is None:
coef = 1.0
elif isinstance(total_size, int):
if ndim >= 1:
denom = shape[0]
else:
denom = 1
coef = floatX(total_size) / floatX(denom)
elif isinstance(total_size, (list, tuple)):
if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
raise TypeError(
"Unrecognized `total_size` type, expected "
"int or list of ints, got %r" % total_size
)
if Ellipsis in total_size:
sep = total_size.index(Ellipsis)
begin = total_size[:sep]
end = total_size[sep + 1 :]
if Ellipsis in end:
raise ValueError(
"Double Ellipsis in `total_size` is restricted, got %r" % total_size
)
else:
begin = total_size
end = []
if (len(begin) + len(end)) > ndim:
raise ValueError(
"Length of `total_size` is too big, "
"number of scalings is bigger that ndim, got %r" % total_size
)
elif (len(begin) + len(end)) == 0:
coef = 1.0
if len(end) > 0:
shp_end = shape[-len(end) :]
else:
shp_end = np.asarray([])
shp_begin = shape[: len(begin)]
begin_coef = [
floatX(t) / floatX(shp_begin[i]) for i, t in enumerate(begin) if t is not None
]
end_coef = [floatX(t) / floatX(shp_end[i]) for i, t in enumerate(end) if t is not None]
coefs = begin_coef + end_coef
coef = pt.prod(coefs)
else:
raise TypeError(
"Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size
)
return pt.as_tensor(coef, dtype=pytensor.config.floatX)
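
For reference, worked instances of the coefficient this removed helper computed, matching the `test_get_scaling` cases deleted further down. Entries of `total_size` before an `Ellipsis` pair with leading axes of `shape`, entries after it pair with trailing axes, and each pair contributes `total_size[i] / shape[i]` to the product (the per-axis float ratios are rearranged here into one exact fraction):

# coef = prod(total_size[i] / shape[i]) over the paired axes
assert 45 / 2 == 22.5                                        # _get_scaling(45, (2, 3), ndim=1)
assert (4 * 5 * 9) / (2 * 3 * 2) == 15.0                     # [4, 5, 9, Ellipsis] on shape (2, 3, 2)
assert (4 * 5 * 9 * 32 * 12) / (2 * 3 * 2 * 3 * 2) == 960.0  # [4, 5, 9, Ellipsis, 32, 12]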


def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
# Raise if there are unexpected RandomVariables in the logp graph
# Only SimulatorRVs and MinibatchIndexRVs are allowed
@@ -348,7 +275,6 @@ def joint_logp(
rvs_to_values: Dict[TensorVariable, TensorVariable],
rvs_to_transforms: Dict[TensorVariable, RVTransform],
jacobian: bool = True,
rvs_to_total_sizes: Dict[TensorVariable, TOTAL_SIZE],
**kwargs,
) -> List[TensorVariable]:
"""Thin wrapper around pymc.logprob.factorized_joint_logprob, extended with Model
@@ -371,18 +297,13 @@
**kwargs,
)

# The function returns the logp for every single value term we provided to it. This
# includes the extra values we plugged in above, so we filter those we actually
# wanted in the same order they were given in.
# The function returns the logp for every single value term we provided to it.
# This includes the extra values we plugged in above, so we filter those we
# actually wanted in the same order they were given in.
logp_terms = {}
for rv in rvs:
value_var = rvs_to_values[rv]
logp_term = temp_logp_terms[value_var]
total_size = rvs_to_total_sizes.get(rv, None)
if total_size is not None:
scaling = _get_scaling(total_size, value_var.shape, value_var.ndim)
logp_term *= scaling
logp_terms[value_var] = logp_term
logp_terms[value_var] = temp_logp_terms[value_var]

_check_no_rvs(list(logp_terms.values()))
return list(logp_terms.values())
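
The net effect on callers is one less argument; a before/after sketch matching the updated tests below:

import pymc as pm
from pymc.logprob.joint_logprob import joint_logp

with pm.Model() as m:
    x = pm.Normal("x")

logp_terms = joint_logp(
    (x,),
    rvs_to_values=m.rvs_to_values,
    rvs_to_transforms=m.rvs_to_transforms,
    # rvs_to_total_sizes={},  # no longer accepted after this commit
)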
31 changes: 19 additions & 12 deletions pymc/model.py
@@ -563,7 +563,6 @@ def __init__(
self.values_to_rvs = treedict(parent=self.parent.values_to_rvs)
self.rvs_to_values = treedict(parent=self.parent.rvs_to_values)
self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms)
self.rvs_to_total_sizes = treedict(parent=self.parent.rvs_to_total_sizes)
self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values)
self.free_RVs = treelist(parent=self.parent.free_RVs)
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
@@ -577,7 +576,6 @@ def __init__(
self.values_to_rvs = treedict()
self.rvs_to_values = treedict()
self.rvs_to_transforms = treedict()
self.rvs_to_total_sizes = treedict()
self.rvs_to_initial_values = treedict()
self.free_RVs = treelist()
self.observed_RVs = treelist()
@@ -758,7 +756,6 @@ def logp(
rvs=rvs,
rvs_to_values=self.rvs_to_values,
rvs_to_transforms=self.rvs_to_transforms,
rvs_to_total_sizes=self.rvs_to_total_sizes,
jacobian=jacobian,
)
assert isinstance(rv_logps, list)
@@ -1310,8 +1307,6 @@ def register_rv(
name = self.name_for(name)
rv_var.name = name
_add_future_warning_tag(rv_var)
rv_var.tag.total_size = total_size
self.rvs_to_total_sizes[rv_var] = total_size

# Associate previously unknown dimension names with
# the length of the corresponding RV dimension.
@@ -1323,6 +1318,8 @@
self.add_coord(dname, values=None, length=rv_var.shape[d])

if observed is None:
if total_size is not None:
raise ValueError("total_size can only be passed to observed RVs")
self.free_RVs.append(rv_var)
self.create_value_var(rv_var, transform)
self.add_named_variable(rv_var, dims)
@@ -1347,12 +1344,17 @@

# `rv_var` is potentially changed by `make_obs_var`,
# for example into a new graph for imputation of missing data.
rv_var = self.make_obs_var(rv_var, observed, dims, transform)
rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)

return rv_var

def make_obs_var(
self, rv_var: TensorVariable, data: np.ndarray, dims, transform: Optional[Any]
self,
rv_var: TensorVariable,
data: np.ndarray,
dims,
transform: Union[Any, None],
total_size: Union[int, None],
) -> TensorVariable:
"""Create a `TensorVariable` for an observed random variable.
@@ -1388,18 +1390,16 @@

mask = getattr(data, "mask", None)
if mask is not None:
if mask.all():
# If there are no observed values, this variable isn't really
# observed.
return rv_var

impute_message = (
f"Data in {rv_var} contains missing values and"
" will be automatically imputed from the"
" sampling distribution."
)
warnings.warn(impute_message, ImputationWarning)

if total_size is not None:
raise ValueError("total_size is not compatible with imputed variables")

if not isinstance(rv_var.owner.op, RandomVariable):
raise NotImplementedError(
"Automatic inputation is only supported for univariate RandomVariables."
@@ -1467,6 +1467,13 @@ def make_obs_var(
data = sparse.basic.as_sparse(data, name=name)
else:
data = at.as_tensor_variable(data, name=name)

if total_size:
from pymc.variational.rescaled_rv import create_rescaled_rv

rv_var = create_rescaled_rv(rv_var, total_size)
rv_var.name = name

rv_var.tag.observations = data
self.create_value_var(rv_var, transform=None, value_var=data)
self.add_named_variable(rv_var, dims)
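A sketch of the new user-facing behavior in `register_rv`/`make_obs_var`: `total_size` is now only legal on observed, non-imputed variables, and the scaling is applied by wrapping the RV rather than by bookkeeping in a mapping (data and sizes below are illustrative):

import numpy as np
import pymc as pm

batch = np.random.normal(size=5)  # 5-row minibatch of a 1000-row dataset

with pm.Model():
    mu = pm.Normal("mu")
    # OK: observed RV; its logp is rescaled by 1000 / 5 internally.
    pm.Normal("y", mu=mu, observed=batch, total_size=1000)

with pm.Model():
    mu = pm.Normal("mu")
    # Raises ValueError: "total_size can only be passed to observed RVs"
    # pm.Normal("y", mu=mu, total_size=1000)
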
4 changes: 0 additions & 4 deletions pymc/tests/distributions/test_transform.py
@@ -312,7 +312,6 @@ def check_transform_elementwise_logp(self, model):
(x,),
rvs_to_values={x: x_val_transf},
rvs_to_transforms={x: transform},
rvs_to_total_sizes={},
jacobian=False,
)[0]
.sum()
@@ -323,7 +322,6 @@
(x,),
rvs_to_values={x: x_val_untransf},
rvs_to_transforms={},
rvs_to_total_sizes={},
)[0]
.sum()
.eval({x_val_untransf: test_array_untransf})
@@ -362,7 +360,6 @@ def check_vectortransform_elementwise_logp(self, model):
(x,),
rvs_to_values={x: x_val_transf},
rvs_to_transforms={x: transform},
rvs_to_total_sizes={},
jacobian=False,
)[0]
.sum()
@@ -373,7 +370,6 @@
(x,),
rvs_to_values={x: x_val_untransf},
rvs_to_transforms={},
rvs_to_total_sizes={},
)[0]
.sum()
.eval({x_val_untransf: test_array_untransf})
1 change: 0 additions & 1 deletion pymc/tests/distributions/util.py
@@ -600,7 +600,6 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
(model["x"],),
rvs_to_values={model["x"]: at.constant(moment)},
rvs_to_transforms={},
rvs_to_total_sizes={},
)[0]
.sum()
.eval()
54 changes: 1 addition & 53 deletions pymc/tests/logprob/test_joint_logprob.py
@@ -56,11 +56,7 @@
import pymc as pm

from pymc.logprob.abstract import logprob
from pymc.logprob.joint_logprob import (
_get_scaling,
factorized_joint_logprob,
joint_logp,
)
from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logp
from pymc.logprob.utils import rvs_to_value_vars, walk_model
from pymc.tests.helpers import assert_no_rvs
from pymc.tests.logprob.utils import joint_logprob
@@ -281,52 +277,6 @@ def test_multiple_rvs_to_same_value_raises():
joint_logprob({x_rv1: x, x_rv2: x})


def test_get_scaling():
assert _get_scaling(None, (2, 3), 2).eval() == 1
# ndim >=1 & ndim<1
assert _get_scaling(45, (2, 3), 1).eval() == 22.5
assert _get_scaling(45, (2, 3), 0).eval() == 45

# list or tuple tests
# total_size contains other than Ellipsis, None and Int
with pytest.raises(TypeError, match="Unrecognized `total_size` type"):
_get_scaling([2, 4, 5, 9, 11.5], (2, 3), 2)
# check with Ellipsis
with pytest.raises(ValueError, match="Double Ellipsis in `total_size` is restricted"):
_get_scaling([1, 2, 5, Ellipsis, Ellipsis], (2, 3), 2)
with pytest.raises(
ValueError,
match="Length of `total_size` is too big, number of scalings is bigger that ndim",
):
_get_scaling([1, 2, 5, Ellipsis], (2, 3), 2)

assert _get_scaling([Ellipsis], (2, 3), 2).eval() == 1

assert _get_scaling([4, 5, 9, Ellipsis, 32, 12], (2, 3, 2), 5).eval() == 960
assert _get_scaling([4, 5, 9, Ellipsis], (2, 3, 2), 5).eval() == 15
# total_size with no Ellipsis (end = [ ])
with pytest.raises(
ValueError,
match="Length of `total_size` is too big, number of scalings is bigger that ndim",
):
_get_scaling([1, 2, 5], (2, 3), 2)

assert _get_scaling([], (2, 3), 2).eval() == 1
assert _get_scaling((), (2, 3), 2).eval() == 1
# total_size invalid type
with pytest.raises(
TypeError,
match="Unrecognized `total_size` type, expected int or list of ints, got {1, 2, 5}",
):
_get_scaling({1, 2, 5}, (2, 3), 2)

# test with rvar from model graph
with pm.Model() as m2:
rv_var = pm.Uniform("a", 0.0, 1.0)
total_size = []
assert _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim).eval() == 1.0


def test_joint_logp_basic():
"""Make sure we can compute a log-likelihood for a hierarchical model with transforms."""

@@ -348,7 +298,6 @@ def test_joint_logp_basic():
(b,),
rvs_to_values=m.rvs_to_values,
rvs_to_transforms=m.rvs_to_transforms,
rvs_to_total_sizes={},
)

# There shouldn't be any `RandomVariable`s in the resulting graph
@@ -394,7 +343,6 @@ def test_joint_logp_incsubtensor(indices, size):
(a_idx,),
rvs_to_values={a_idx: a_value_var},
rvs_to_transforms={},
rvs_to_total_sizes={},
)

logp_vals = a_idx_logp[0].eval({a_value_var: a_val})
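The deleted `test_get_scaling` suite exercised the helper removed from `joint_logprob.py`; the equivalent end-to-end path now runs through `total_size` on an observed RV, typically together with minibatched data. A sketch under the assumption of the PyMC 5 `Minibatch` API:

import numpy as np
import pymc as pm

data = np.random.normal(loc=1.0, size=(1000,))

with pm.Model():
    mb = pm.Minibatch(data, batch_size=50)  # 50 random rows per evaluation
    mu = pm.Normal("mu")
    # The 50-row batch logp is rescaled by 1000 / 50 via the rescaled RV.
    pm.Normal("obs", mu=mu, sigma=1.0, observed=mb, total_size=1000)
    approx = pm.fit(n=100)  # minibatch ADVI is where total_size matters
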
3 changes: 0 additions & 3 deletions pymc/tests/logprob/test_utils.py
@@ -233,7 +233,6 @@ def custom_logp(value, x):
[x],
rvs_to_values={x: value},
rvs_to_transforms={},
rvs_to_total_sizes={},
)

with pm.Model():
@@ -248,7 +247,6 @@ def custom_logp(value, x):
[y],
rvs_to_values={y: y.type()},
rvs_to_transforms={},
rvs_to_total_sizes={},
)

# The above warning should go away with ignore_logprob.
@@ -261,5 +259,4 @@ def custom_logp(value, x):
[y],
rvs_to_values={y: y.type()},
rvs_to_transforms={},
rvs_to_total_sizes={},
)
2 changes: 0 additions & 2 deletions pymc/tests/logprob/utils.py
Original file line number Diff line number Diff line change
@@ -246,7 +246,6 @@ def logp(value, x):
[y],
rvs_to_values={y: y.type()},
rvs_to_transforms={},
rvs_to_total_sizes={},
)

# The above warning should go away with ignore_logprob.
@@ -259,5 +258,4 @@ def logp(value, x):
[y],
rvs_to_values={y: y.type()},
rvs_to_transforms={},
rvs_to_total_sizes={},
)