Skip to content

Commit

Permalink
Default moment for CustomDist provided with a dist function (#6873
Browse files Browse the repository at this point in the history
)

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
  • Loading branch information
aerubanov and ricardoV94 authored Nov 13, 2023
1 parent d7415de commit ad450a6
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 15 deletions.
101 changes: 86 additions & 15 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@

from pytensor import tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import FunctionGraph, node_rewriter
from pytensor.graph.basic import Node, Variable
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out
from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
from pytensor.graph.basic import Node, Variable, io_toposort
from pytensor.graph.features import ReplaceValidate
from pytensor.graph.rewriting.basic import GraphRewriter, in2out
from pytensor.graph.utils import MetaType
from pytensor.scan.op import Scan
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.variable import TensorVariable
Expand Down Expand Up @@ -83,6 +85,59 @@
PLATFORM = sys.platform


class MomentRewrite(GraphRewriter):
def rewrite_moment_scan_node(self, node):
if not isinstance(node.op, Scan):
return

node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
op = node.op

local_fgraph_topo = io_toposort(node_inputs, node_outputs)

replace_with_moment = []
to_replace_set = set()

for nd in local_fgraph_topo:
if nd not in to_replace_set and isinstance(
nd.op, (RandomVariable, SymbolicRandomVariable)
):
replace_with_moment.append(nd.out)
to_replace_set.add(nd)
givens = {}
if len(replace_with_moment) > 0:
for item in replace_with_moment:
givens[item] = moment(item)
else:
return
op_outs = clone_replace(node_outputs, replace=givens)

nwScan = Scan(
node_inputs,
op_outs,
op.info,
mode=op.mode,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
name=op.name,
allow_gc=op.allow_gc,
)
nw_node = nwScan(*(node.inputs), return_list=True)[0].owner
return nw_node

def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())

def apply(self, fgraph):
for node in fgraph.toposort():
if isinstance(node.op, (RandomVariable, SymbolicRandomVariable)):
fgraph.replace(node.out, moment(node.out))
elif isinstance(node.op, Scan):
new_node = self.rewrite_moment_scan_node(node)
if new_node is not None:
fgraph.replace_all(tuple(zip(node.outputs, new_node.outputs)))


class _Unpickling:
pass

Expand Down Expand Up @@ -601,6 +656,20 @@ def update(self, node: Node):
return updates


@_moment.register(CustomSymbolicDistRV)
def dist_moment(op, rv, *args):
node = rv.owner
rv_out_idx = node.outputs.index(rv)

fgraph = op.fgraph.clone()
replace_moments = MomentRewrite()
replace_moments.rewrite(fgraph)
# Replace dummy inner inputs by outer inputs
fgraph.replace_all(tuple(zip(op.inner_inputs, args)), import_missing=True)
moment = fgraph.outputs[rv_out_idx]
return moment


class _CustomSymbolicDist(Distribution):
rv_type = CustomSymbolicDistRV

Expand All @@ -622,14 +691,6 @@ def dist(
if logcdf is None:
logcdf = default_not_implemented(class_name, "logcdf")

if moment is None:
moment = functools.partial(
default_moment,
rv_name=class_name,
has_fallback=True,
ndim_supp=ndim_supp,
)

return super().dist(
dist_params,
class_name=class_name,
Expand Down Expand Up @@ -685,9 +746,19 @@ def custom_dist_logp(op, values, size, *params, **kwargs):
def custom_dist_logcdf(op, value, size, *params, **kwargs):
return logcdf(value, *params[: len(dist_params)])

@_moment.register(rv_type)
def custom_dist_get_moment(op, rv, size, *params):
return moment(rv, size, *params[: len(params)])
if moment is not None:

@_moment.register(rv_type)
def custom_dist_get_moment(op, rv, size, *params):
return moment(
rv,
size,
*[
p
for p in params
if not isinstance(p.type, (RandomType, RandomGeneratorType))
],
)

@_change_dist_size.register(rv_type)
def change_custom_symbolic_dist_size(op, rv, new_size, expand):
Expand Down
98 changes: 98 additions & 0 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,104 @@ def custom_dist(mu, sigma, size):
ip = m.initial_point()
np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip))

@pytest.mark.parametrize(
"dist_params, size, expected, dist_fn",
[
(
(5, 1),
None,
np.exp(5),
lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size)),
),
(
(2, np.ones(5)),
None,
np.exp([2, 2, 2, 2, 2] + np.ones(5)),
lambda mu, sigma, size: pt.exp(
pm.Normal.dist(mu, sigma, size=size) + pt.ones(size)
),
),
(
(1, 2),
None,
np.sqrt(np.exp(1 + 0.5 * 2**2)),
lambda mu, sigma, size: pt.sqrt(pm.LogNormal.dist(mu, sigma, size=size)),
),
(
(4,),
(3,),
np.log([4, 4, 4]),
lambda nu, size: pt.log(pm.ChiSquared.dist(nu, size=size)),
),
(
(12, 1),
None,
12,
lambda mu1, sigma, size: pm.Normal.dist(mu1, sigma, size=size),
),
],
)
def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn):
with Model() as model:
CustomDist("x", *dist_params, dist=dist_fn, size=size)
assert_moment_is_expected(model, expected)

def test_custom_dist_default_moment_scan(self):
def scan_step(left, right):
x = pm.Uniform.dist(left, right)
x_update = collect_default_updates([x])
return x, x_update

def dist(size):
xs, updates = scan(
fn=scan_step,
sequences=[
pt.as_tensor_variable(np.array([-4, -3])),
pt.as_tensor_variable(np.array([-2, -1])),
],
name="xs",
)
return xs

with Model() as model:
CustomDist("x", dist=dist)
assert_moment_is_expected(model, np.array([-3, -2]))

def test_custom_dist_default_moment_scan_recurring(self):
def scan_step(xtm1):
x = pm.Normal.dist(xtm1 + 1)
x_update = collect_default_updates([x])
return x, x_update

def dist(size):
xs, _ = scan(
fn=scan_step,
outputs_info=pt.as_tensor_variable(np.array([0])).astype(float),
n_steps=3,
name="xs",
)
return xs

with Model() as model:
CustomDist("x", dist=dist)
assert_moment_is_expected(model, np.array([[1], [2], [3]]))

@pytest.mark.parametrize(
"left, right, size, expected",
[
(-1, 1, None, 0 + 5),
(-3, -1, None, -2 + 5),
(-3, 1, (3,), np.array([-1 + 5, -1 + 5, -1 + 5])),
],
)
def test_custom_dist_default_moment_nested(self, left, right, size, expected):
def dist_fn(left, right, size):
return pm.Truncated.dist(pm.Normal.dist(0, 1), left, right, size=size) + 5

with Model() as model:
CustomDist("x", left, right, size=size, dist=dist_fn)
assert_moment_is_expected(model, expected)

def test_logcdf_inference(self):
def custom_dist(mu, sigma, size):
return pt.exp(pm.Normal.dist(mu, sigma, size=size))
Expand Down

0 comments on commit ad450a6

Please sign in to comment.