Support automatic imputation for multivariate and symbolic distributions #6797

Merged · 3 commits · Jun 30, 2023
5 changes: 5 additions & 0 deletions pymc/data.py
@@ -423,6 +423,11 @@ def Data(
# `convert_observed_data` takes care of parameter `value` and
# transforms it to something digestible for PyTensor.
arr = convert_observed_data(value)
if isinstance(arr, np.ma.MaskedArray):
Member Author commented: This would have failed in the call to `as_tensor_variable` or `shared` anyway.
raise NotImplementedError(
"Masked arrays or arrays with `nan` entries are not supported. "
"Pass them directly to `observed` if you want to trigger auto-imputation"
)

if mutable is None:
warnings.warn(
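A minimal illustrative sketch (not part of the diff; the data and variable names are made up) of what this guard enforces: masked arrays can no longer go through `pm.Data`, but passing them to `observed` still triggers auto-imputation.

```python
import numpy as np
import pymc as pm

masked = np.ma.masked_values([1.0, -1.0, 3.0], value=-1.0)

with pm.Model():
    # pm.Data("x", masked)  # now raises NotImplementedError

    # Passing the masked array to `observed` triggers auto-imputation instead
    mu = pm.Normal("mu", 0.0, 1.0)
    y = pm.Normal("y", mu=mu, sigma=1.0, observed=masked)  # emits ImputationWarning
```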
146 changes: 145 additions & 1 deletion pymc/distributions/distribution.py
@@ -25,13 +25,14 @@

from pytensor import tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import node_rewriter
from pytensor.graph import FunctionGraph, node_rewriter
from pytensor.graph.basic import Node, Variable
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out
from pytensor.graph.utils import MetaType
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.var import TensorVariable
from typing_extensions import TypeAlias
@@ -49,6 +50,7 @@
)
from pymc.exceptions import BlockModelAccessError
from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
from pymc.logprob.basic import logp
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.model import BlockModelAccess
from pymc.printing import str_for_dist
@@ -1148,3 +1150,145 @@ def logcdf(value, c):
-np.inf,
0,
)


class PartialObservedRV(SymbolicRandomVariable):
"""RandomVariable with partially observed subspace, as indicated by a boolean mask.

See `create_partial_observed_rv` for more details.
"""


def create_partial_observed_rv(
rv: TensorVariable,
mask: Union[np.ndarray, TensorVariable],
) -> Tuple[
Tuple[TensorVariable, TensorVariable], Tuple[TensorVariable, TensorVariable], TensorVariable
]:
"""Separate observed and unobserved components of a RandomVariable.

This function may return two independent RandomVariables or, if not possible,
two variables from a common `PartialObservedRV` node.

Parameters
----------
rv : TensorVariable
mask : tensor_like
Constant or variable boolean mask. True entries correspond to components of the variable that are not observed.

Returns
-------
observed_rv and mask : Tuple of TensorVariable
The observed component of the RV and respective indexing mask
unobserved_rv and mask : Tuple of TensorVariable
The unobserved component of the RV and respective indexing mask
joined_rv : TensorVariable
The symbolic join of the observed and unobserved components.
"""
if not mask.dtype == "bool":
raise ValueError(
f"mask must be an array or tensor of boolean dtype, got dtype: {mask.dtype}"
)

if mask.ndim > rv.ndim:
raise ValueError(f"mask can't have more dims than rv, got ndim: {mask.ndim}")

antimask = ~mask

can_rewrite = False
# Only pure RVs can be rewritten
if isinstance(rv.owner.op, RandomVariable):
ndim_supp = rv.owner.op.ndim_supp

# All univariate RVs can be rewritten
if ndim_supp == 0:
can_rewrite = True

# Multivariate RVs can be rewritten if masking does not split within support dimensions
else:
batch_dims = rv.type.ndim - ndim_supp
constant_mask = getattr(as_tensor_variable(mask), "data", None)

# Indexing does not overlap with core dimensions
if mask.ndim <= batch_dims:
can_rewrite = True

# Try to handle special case where mask is constant across support dimensions,
# TODO: This could be done by the rewrite itself
elif constant_mask is not None:
# We check if a constant_mask that only keeps the first entry of each support dim
# is equivalent to the original one after re-expanding.
trimmed_mask = constant_mask[(...,) + (0,) * ndim_supp]
expanded_mask = np.broadcast_to(
np.expand_dims(trimmed_mask, axis=tuple(range(-ndim_supp, 0))),
shape=constant_mask.shape,
)
if np.array_equal(constant_mask, expanded_mask):
mask = trimmed_mask
antimask = ~trimmed_mask
can_rewrite = True

if can_rewrite:
# Rewrite doesn't work with boolean masks. Should be fixed after https://github.com/pymc-devs/pytensor/pull/329
mask, antimask = mask.nonzero(), antimask.nonzero()

masked_rv = rv[mask]
fgraph = FunctionGraph(outputs=[masked_rv], clone=False)
[unobserved_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)

antimasked_rv = rv[antimask]
fgraph = FunctionGraph(outputs=[antimasked_rv], clone=False)
[observed_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)

# Make a clone of the observedRV, with a distinct rng so that observed and
# unobserved are never treated as equivalent (and mergeable) nodes by pytensor.
_, size, _, *inps = observed_rv.owner.inputs
observed_rv = observed_rv.owner.op(*inps, size=size)

# For all other cases use the more general PartialObservedRV
else:
# The symbolic graph simply splits the observed and unobserved components,
# so they can be given separate values.
dist_, mask_ = rv.type(), as_tensor_variable(mask).type()
observed_rv_, unobserved_rv_ = dist_[~mask_], dist_[mask_]

observed_rv, unobserved_rv = PartialObservedRV(
inputs=[dist_, mask_],
outputs=[observed_rv_, unobserved_rv_],
ndim_supp=rv.owner.op.ndim_supp,
)(rv, mask)

joined_rv = pt.empty(rv.shape, dtype=rv.type.dtype)
joined_rv = pt.set_subtensor(joined_rv[mask], unobserved_rv)
joined_rv = pt.set_subtensor(joined_rv[antimask], observed_rv)

return (observed_rv, antimask), (unobserved_rv, mask), joined_rv


@_logprob.register(PartialObservedRV)
def partial_observed_rv_logprob(op, values, dist, mask, **kwargs):
# For the logp, simply join the values
[obs_value, unobs_value] = values
antimask = ~mask
joined_value = pt.empty_like(dist)
joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
joined_value = pt.set_subtensor(joined_value[antimask], obs_value)
joined_logp = logp(dist, joined_value)

# If we have a univariate RV we can split apart the logp terms
if op.ndim_supp == 0:
return joined_logp[antimask], joined_logp[mask]
# Otherwise, we can't (always/easily) split apart logp terms.
# We return the full logp for the observed value, and a 0-nd array for the unobserved value
else:
return joined_logp.ravel(), pt.zeros((0,), dtype=joined_logp.type.dtype)


@_moment.register(PartialObservedRV)
def partial_observed_rv_moment(op, partial_obs_rv, rv, mask):
# Unobserved output
if partial_obs_rv.owner.outputs.index(partial_obs_rv) == 1:
return moment(rv)[mask]
# Observed output
else:
return moment(rv)[~mask]
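An illustrative sketch (not part of the diff; the distribution and mask are made up) of how `create_partial_observed_rv` splits a multivariate RV when the mask does not cut through a support dimension:

```python
import numpy as np
import pymc as pm
from pymc.distributions.distribution import create_partial_observed_rv

# Batch of three bivariate normals; the last row is fully unobserved.
rv = pm.MvNormal.dist(mu=np.zeros(2), cov=np.eye(2), shape=(3, 2))
mask = np.array([[False, False], [False, False], [True, True]])

# The mask is constant across the support dimension, so the rewrite path
# applies and two independent MvNormal RVs are returned.
(obs_rv, antimask), (unobs_rv, sub_mask), joined_rv = create_partial_observed_rv(rv, mask)

# A mask such as [[False, True], ...] that cuts through a support dimension
# would instead fall back to the more general PartialObservedRV node.
```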
74 changes: 16 additions & 58 deletions pymc/model.py
@@ -44,11 +44,9 @@
from pytensor.compile import DeepCopyOp, get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant, Variable, graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar import Cast
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.sharedvar import ScalarSharedVariable
from pytensor.tensor.var import TensorConstant, TensorVariable
@@ -1409,67 +1407,27 @@ def make_obs_var(
if total_size is not None:
raise ValueError("total_size is not compatible with imputed variables")

if not isinstance(rv_var.owner.op, RandomVariable):
raise NotImplementedError(
"Automatic inputation is only supported for univariate RandomVariables."
f" {rv_var} of type {type(rv_var.owner.op)} is not supported."
)

if rv_var.owner.op.ndim_supp > 0:
raise NotImplementedError(
f"Automatic inputation is only supported for univariate "
f"RandomVariables, but {rv_var} is multivariate"
)
from pymc.distributions.distribution import create_partial_observed_rv

# We can get a random variable comprised of only the unobserved
# entries by lifting the indices through the `RandomVariable` `Op`.
(
(observed_rv, observed_mask),
(unobserved_rv, _),
joined_rv,
) = create_partial_observed_rv(rv_var, mask)
observed_data = pt.as_tensor(data.data[observed_mask])

masked_rv_var = rv_var[mask.nonzero()]

fgraph = FunctionGraph(
[i for i in graph_inputs((masked_rv_var,)) if not isinstance(i, Constant)],
[masked_rv_var],
clone=False,
)
# Register ObservedRV corresponding to observed component
observed_rv.name = f"{name}_observed"
self.create_value_var(observed_rv, transform=None, value_var=observed_data)
self.add_named_variable(observed_rv)
self.observed_RVs.append(observed_rv)

(missing_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
# Register FreeRV corresponding to unobserved components
self.register_rv(unobserved_rv, f"{name}_unobserved", transform=transform)

self.register_rv(missing_rv_var, f"{name}_missing", transform=transform)

# Now, we lift the non-missing observed values and produce a new
# `rv_var` that contains only those.
#
# The end result is two disjoint distributions: one for the missing
# values, and another for the non-missing values.

antimask_idx = (~mask).nonzero()
nonmissing_data = pt.as_tensor_variable(data[antimask_idx].data)
unmasked_rv_var = rv_var[antimask_idx]
unmasked_rv_var = unmasked_rv_var.owner.clone().default_output()

fgraph = FunctionGraph(
[i for i in graph_inputs((unmasked_rv_var,)) if not isinstance(i, Constant)],
[unmasked_rv_var],
clone=False,
)
(observed_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
# Make a clone of the RV, but let it create a new rng so that observed and
# missing are not treated as equivalent nodes by pytensor. This would happen
# if the size of the masked and unmasked array happened to coincide
_, size, _, *inps = observed_rv_var.owner.inputs
observed_rv_var = observed_rv_var.owner.op(*inps, size=size, name=f"{name}_observed")
observed_rv_var.tag.observations = nonmissing_data

self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
self.add_named_variable(observed_rv_var)
self.observed_RVs.append(observed_rv_var)

# Create deterministic that combines observed and missing
# Register Deterministic that combines observed and missing
# Note: This can widely increase memory consumption during sampling for large datasets
rv_var = pt.empty(data.shape, dtype=observed_rv_var.type.dtype)
rv_var = pt.set_subtensor(rv_var[mask.nonzero()], missing_rv_var)
rv_var = pt.set_subtensor(rv_var[antimask_idx], observed_rv_var)
rv_var = Deterministic(name, rv_var, self, dims)
rv_var = Deterministic(name, joined_rv, self, dims)

else:
if sps.issparse(data):
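An illustrative sketch (not part of the diff; data values are made up) of the user-facing effect: a partially observed multivariate likelihood now yields `*_observed` and `*_unobserved` components plus a Deterministic that joins them.

```python
import numpy as np
import pymc as pm

data = np.ma.masked_values([[1.0, 2.0], [2.0, 2.0], [-1.0, 4.0]], value=-1.0)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0, size=2)
    y = pm.MvNormal("y", mu=mu, cov=np.eye(2), observed=data)  # emits ImputationWarning

# make_obs_var now registers three variables:
#   "y_observed"   -> observed component (an observed RV)
#   "y_unobserved" -> imputed component (a free RV)
#   "y"            -> Deterministic joining the two
assert {"y", "y_observed", "y_unobserved"} <= set(model.named_vars)
```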
21 changes: 13 additions & 8 deletions tests/backends/test_arviz.py
@@ -338,10 +338,10 @@ def test_missing_data_model(self):
)

# make sure that data is really missing
assert "y_missing" in model.named_vars
assert "y_unobserved" in model.named_vars

test_dict = {
"posterior": ["x", "y_missing"],
"posterior": ["x", "y_unobserved"],
"observed_data": ["y_observed"],
"log_likelihood": ["y_observed"],
}
@@ -352,7 +352,6 @@
# See https://github.com/pymc-devs/pymc/issues/5255
assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3)

@pytest.mark.xfail(reason="Multivariate partial observed RVs not implemented for V4")
def test_mv_missing_data_model(self):
data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1)

@@ -361,19 +360,25 @@
mu = pm.Normal("mu", 0, 1, size=2)
sd_dist = pm.HalfNormal.dist(1.0, size=2)
# pylint: disable=unpacking-non-sequence
chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist, compute_corr=True)
chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist)
# pylint: enable=unpacking-non-sequence
with pytest.warns(ImputationWarning):
y = pm.MvNormal("y", mu=mu, chol=chol, observed=data)
inference_data = pm.sample(100, chains=2, return_inferencedata=True)
inference_data = pm.sample(
tune=10,
draws=10,
chains=2,
step=pm.Metropolis(),
idata_kwargs=dict(log_likelihood=True),
)

# make sure that data is really missing
assert isinstance(y.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1))
assert isinstance(y.owner.inputs[0].owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1))

test_dict = {
"posterior": ["mu", "chol_cov"],
"observed_data": ["y"],
"log_likelihood": ["y"],
"observed_data": ["y_observed"],
"log_likelihood": ["y_observed"],
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails