Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce ValuedVariable #78

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aeppl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from aeppl.logprob import logprob # isort: split

from aeppl.joint_logprob import conditional_logprob, joint_logprob
from aeppl.joint_logprob import DensityNotFound, conditional_logprob, joint_logprob
from aeppl.printing import latex_pprint, pprint

# isort: off
Expand Down
55 changes: 55 additions & 0 deletions aeppl/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
from functools import singledispatch
from typing import Callable, List

import aesara.tensor as at
from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import Op
from aesara.graph.utils import MetaType
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.type import TensorType


class MeasurableVariable(abc.ABC):
Expand Down Expand Up @@ -124,3 +127,55 @@ class MeasurableElemwise(Elemwise):


MeasurableVariable.register(MeasurableElemwise)


class ValuedVariable(Op):
r"""Represents the association of a measurable variable and its value.

A `ValuedVariable` node represents the pair :math:`(Y, y)`, where
:math:`Y` is a random variable and :math:`y \sim Y`.

Log-probability (densities) are functions over these pairs, which makes
these nodes in a graph an intermediate form that serves to construct a
log-probability from a model graph.

This intermediate form can be used as the target for rewrites that
otherwise wouldn't make sense to apply to--say--a random variable node
directly. An example is `BroadcastTo` lifting through `RandomVariable`\s.
"""

default_output = 0
view_map = {0: [0]}

def make_node(self, rv, value):

assert isinstance(rv.type, TensorType)
out_rv = rv.type()

vv = at.as_tensor_variable(value)
assert isinstance(vv.type, TensorType)

# TODO: We should probably check the `Type`s of `out_rv` and `vv`
if vv.type.dtype != rv.type.dtype:
raise TypeError(
f"Value type {vv.type} does not match random variable type {out_rv.type}"
)

return Apply(self, [rv, vv], [out_rv])

def perform(self, node, inputs, out):
out[0][0] = inputs[0]

def grad(self, inputs, outputs):
return [
grad_undefined(self, k, inp, "No gradient defined for `ValuedVariable`")
for k, inp in enumerate(inputs)
]

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]


MeasurableVariable.register(ValuedVariable)

valued_variable = ValuedVariable()
14 changes: 7 additions & 7 deletions aeppl/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
from aeppl.abstract import (
MeasurableElemwise,
MeasurableVariable,
ValuedVariable,
assign_custom_measurable_outputs,
)
from aeppl.logprob import CheckParameterValue, _logcdf, _logprob, logdiffexp
from aeppl.rewriting import measurable_ir_rewrites_db

if TYPE_CHECKING:
from aesara.graph.basic import Op, Variable
from aesara.graph.basic import Variable
from aesara.graph.op import Op


class MeasurableClip(MeasurableElemwise):
Expand All @@ -37,8 +39,7 @@ def find_measurable_clips(
) -> Optional[List["Variable"]]:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
if isinstance(node.op, MeasurableClip):
return None # pragma: no cover

clipped_var = node.outputs[0]
Expand All @@ -47,7 +48,7 @@ def find_measurable_clips(
if not (
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and base_var not in rv_map_feature.rv_values
and not isinstance(base_var, ValuedVariable)
):
return None

Expand Down Expand Up @@ -190,8 +191,7 @@ def construct_measurable_rounding(
fgraph: FunctionGraph, node: Node, rounded_op: "Op"
) -> Optional[List["Variable"]]:

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
if isinstance(node.op, MeasurableRound):
return None # pragma: no cover

(rounded_var,) = node.outputs
Expand All @@ -200,7 +200,7 @@ def construct_measurable_rounding(
if not (
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and base_var not in rv_map_feature.rv_values
and not isinstance(base_var, ValuedVariable)
# Rounding only makes sense for continuous variables
and base_var.dtype.startswith("float")
):
Expand Down
17 changes: 7 additions & 10 deletions aeppl/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor.extra_ops import CumOp

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
from aeppl.abstract import (
MeasurableVariable,
ValuedVariable,
assign_custom_measurable_outputs,
)
from aeppl.logprob import _logprob, logprob
from aeppl.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from aeppl.rewriting import measurable_ir_rewrites_db


class MeasurableCumsum(CumOp):
Expand Down Expand Up @@ -50,20 +54,13 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]:
if isinstance(node.op, MeasurableCumsum):
return None # pragma: no cover

rv_map_feature: Optional[PreserveRVMappings] = getattr(
fgraph, "preserve_rv_mappings", None
)

if rv_map_feature is None:
return None # pragma: no cover

rv = node.outputs[0]

base_rv = node.inputs[0]
if not (
base_rv.owner
and isinstance(base_rv.owner.op, MeasurableVariable)
and base_rv not in rv_map_feature.rv_values
and not isinstance(base_rv, ValuedVariable)
):
return None # pragma: no cover

Expand Down
Loading