Skip to content

Commit

Permalink
Implement default transform for Mixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 28, 2022
1 parent 0b9f9cb commit fa015e3
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 0 deletions.
81 changes: 81 additions & 0 deletions pymc/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@

from aeppl.abstract import MeasurableVariable, _get_measurable_outputs
from aeppl.logprob import _logcdf, _logprob
from aeppl.transforms import IntervalTransform
from aesara.compile.builders import OpFromGraph
from aesara.graph.basic import equal_computations
from aesara.tensor import TensorVariable
from aesara.tensor.random.op import RandomVariable

from pymc.aesaraf import change_rv_size
from pymc.distributions import transforms
from pymc.distributions.continuous import Normal, get_tau_sigma
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import (
Expand All @@ -35,6 +38,7 @@
)
from pymc.distributions.logprob import logcdf, logp
from pymc.distributions.shape_utils import to_tuple
from pymc.distributions.transforms import _default_transform
from pymc.util import check_dist_not_registered
from pymc.vartypes import discrete_types

Expand Down Expand Up @@ -461,6 +465,83 @@ def marginal_mixture_moment(op, rv, rng, weights, *components):
return mix_moment


# Default transform types that a Mixture may reuse from its components, either
# because they need no special handling or because custom logic in the default
# transform dispatch below enables them. Whenever a new default transform is
# implemented, both this tuple and that dispatch function should be revisited.
allowed_default_mixture_transforms = (
    transforms.CholeskyCovPacked,
    transforms.CircularTransform,
    transforms.IntervalTransform,
    transforms.LogTransform,
    transforms.LogExpM1,
    transforms.LogOddsTransform,
    transforms.Ordered,
    transforms.Simplex,
    transforms.SumTo1,
)


class MixtureTransformWarning(UserWarning):
    """Warning issued when no safe default transform is found for a Mixture."""


@_default_transform.register(MarginalMixtureRV)
def marginal_mixture_default_transform(op, rv):
    """Derive the default transform of a Mixture from its components.

    A transform is only returned when every component resolves to the same
    type of default transform and that type is listed in
    ``allowed_default_mixture_transforms``. Interval transforms must
    additionally produce equivalent backward graphs for all components.

    Parameters
    ----------
    op
        The mixture Op this dispatch is registered for (not used directly).
    rv
        The mixture random variable. ``rv.owner.inputs`` is unpacked as
        ``(rng, weights, *components)``.

    Returns
    -------
    transform or None
        ``None`` when the components need no transform or when no safe common
        transform exists; a new ``IntervalTransform`` that accepts the
        mixture inputs when the components share an interval transform;
        otherwise the default transform of the first component.

    Warns
    -----
    MixtureTransformWarning
        When the components have different types of default transforms, a
        transform type that is not allowed, or non-equivalent interval
        transforms.
    """

    def transform_warning():
        warnings.warn(
            f"No safe default transform found for Mixture distribution {rv}. This can "
            "happen when components have different supports or default transforms.\n"
            "If appropriate, you can specify a custom transform for more efficient sampling.",
            MixtureTransformWarning,
            stacklevel=2,
        )

    rng, weights, *components = rv.owner.inputs

    default_transforms = [
        _default_transform(component.owner.op, component) for component in components
    ]

    # If there is more than one type of default transform, we do not apply any
    if len({type(transform) for transform in default_transforms}) != 1:
        transform_warning()
        return None

    default_transform = default_transforms[0]

    # All components agree that no transform is needed: the mixture needs none
    # either, and there is nothing to warn the user about.
    if default_transform is None:
        return None

    if not isinstance(default_transform, allowed_default_mixture_transforms):
        transform_warning()
        return None

    if not isinstance(default_transform, IntervalTransform):
        return default_transform

    # If there is more than one component, we need to check that the
    # IntervalTransforms of the components are actually equivalent (e.g., we
    # don't have an Interval(0, 1) and an Interval(0, 2)).
    if len(default_transforms) > 1:
        value = rv.type()
        backward_expressions = [
            transform.backward(value, *component.owner.inputs)
            for transform, component in zip(default_transforms, components)
        ]
        # Comparing neighbouring pairs is enough to establish that all are equal
        for expr1, expr2 in zip(backward_expressions[:-1], backward_expressions[1:]):
            if not equal_computations([expr1], [expr2]):
                transform_warning()
                return None

    # We need to create a new IntervalTransform that expects the Mixture inputs
    args_fn = default_transform.args_fn

    def mixture_args_fn(rng, weights, *components):
        # We checked that the interval transforms of each component are equivalent,
        # so we can just pass the inputs of the first component
        return args_fn(*components[0].owner.inputs)

    return IntervalTransform(args_fn=mixture_args_fn)


class NormalMixture:
R"""
Normal mixture log-likelihood
Expand Down
96 changes: 96 additions & 0 deletions pymc/tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
import pytest
import scipy.stats as st

from aeppl.transforms import IntervalTransform, LogTransform
from aeppl.transforms import Simplex as SimplexTransform
from aesara import tensor as at
from aesara.tensor import TensorVariable
from aesara.tensor.random.op import RandomVariable
from numpy.testing import assert_allclose
from scipy.special import logsumexp
Expand All @@ -32,6 +35,7 @@
Exponential,
Gamma,
HalfNormal,
HalfStudentT,
LKJCholeskyCov,
LogNormal,
Mixture,
Expand All @@ -40,9 +44,14 @@
Normal,
NormalMixture,
Poisson,
StickBreakingWeights,
Triangular,
Uniform,
)
from pymc.distributions.logprob import logp
from pymc.distributions.mixture import MixtureTransformWarning
from pymc.distributions.shape_utils import to_tuple
from pymc.distributions.transforms import _default_transform
from pymc.math import expand_packed_triangular
from pymc.model import Model
from pymc.sampling import (
Expand Down Expand Up @@ -1216,3 +1225,90 @@ def test_list_multivariate_components(self, weights, comp_dists, size, expected)
with Model() as model:
Mixture("x", weights, comp_dists, size=size)
assert_moment_is_expected(model, expected, check_finite_logp=False)


class TestMixtureDefaultTransforms:
    """Tests for the default transform that Mixture derives from its components."""

    @pytest.mark.parametrize(
        "comp_dists, expected_transform_type",
        [
            (Poisson.dist(1, size=2), type(None)),
            (Normal.dist(size=2), type(None)),
            (Uniform.dist(size=2), IntervalTransform),
            (HalfNormal.dist(size=2), LogTransform),
            ([HalfNormal.dist(), Normal.dist()], type(None)),
            ([HalfNormal.dist(1), Exponential.dist(1), HalfStudentT.dist(4, 1)], LogTransform),
            ([Dirichlet.dist([1, 2, 3, 4]), StickBreakingWeights.dist(1, K=3)], SimplexTransform),
            ([Uniform.dist(0, 1), Uniform.dist(0, 1), Triangular.dist(0, 1)], IntervalTransform),
            ([Uniform.dist(0, 1), Uniform.dist(0, 2)], type(None)),
        ],
    )
    def test_expected(self, comp_dists, expected_transform_type):
        """The default transform of the mixture has the expected type."""
        # A single component variable is always built with size=2 above, so it
        # takes two weights; a list takes one weight per component.
        if isinstance(comp_dists, TensorVariable):
            weights = np.ones(2) / 2
        else:
            weights = np.ones(len(comp_dists)) / len(comp_dists)
        mix = Mixture.dist(weights, comp_dists)
        assert isinstance(_default_transform(mix.owner.op, mix), expected_transform_type)

    def test_hierarchical_interval_transform(self):
        """Interval bounds that depend on other model variables are handled."""
        with Model() as model:
            lower = Normal("lower", 0.5)
            upper = Uniform("upper", 0, 1)
            uniform = Uniform("uniform", -at.abs(lower), at.abs(upper), transform=None)
            triangular = Triangular(
                "triangular", -at.abs(lower), at.abs(upper), c=0.25, transform=None
            )
            comp_dists = [
                Uniform.dist(-at.abs(lower), at.abs(upper)),
                Triangular.dist(-at.abs(lower), at.abs(upper), c=0.25),
            ]
            mix1 = Mixture("mix1", [0.3, 0.7], comp_dists)
            mix2 = Mixture("mix2", [0.3, 0.7][::-1], comp_dists[::-1])

        ip = model.compute_initial_point()
        # We want an informative moment, other than zero
        assert ip["mix1_interval__"] != 0

        # NOTE(review): the fixed bounds (-0.5, 0.5) presumably correspond to the
        # initial points of `lower` and `upper` both being 0.5 -- confirm.
        expected_mix_ip = (
            IntervalTransform(args_fn=lambda *args: (-0.5, 0.5))
            .forward(0.3 * ip["uniform"] + 0.7 * ip["triangular"])
            .eval()
        )
        # Reversing weights and components together must not change the result
        assert np.isclose(ip["mix1_interval__"], ip["mix2_interval__"])
        assert np.isclose(ip["mix1_interval__"], expected_mix_ip)

    def test_logp(self):
        """A mixture of identical HalfNormals matches a plain HalfNormal logp."""
        with Model() as m:
            halfnorm = HalfNormal("halfnorm")
            comp_dists = [HalfNormal.dist(), HalfNormal.dist()]
            mix_transf = Mixture("mix_transf", w=[0.5, 0.5], comp_dists=comp_dists)
            mix = Mixture("mix", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)

        logp_fn = m.compile_logp(vars=[halfnorm, mix_transf, mix], sum=False)
        # The untransformed variable is evaluated at exp(1), i.e. the same point
        # as the log-transformed ones, which are evaluated at 1.
        test_point = {"halfnorm_log__": 1, "mix_transf_log__": 1, "mix": np.exp(1)}
        logp_halfnorm, logp_mix_transf, logp_mix = logp_fn(test_point)
        assert np.isclose(logp_halfnorm, logp_mix_transf)
        # The transformed logp differs from the untransformed one by 1 at this
        # point (presumably the log-transform Jacobian term at log-value 1).
        assert np.isclose(logp_halfnorm, logp_mix + 1)

    def test_warning(self):
        """MixtureTransformWarning is raised only when no safe transform exists."""
        # Local import: `pytest.warns(None)` is deprecated since pytest 7 and
        # rejected by pytest 8, so "no warning" is asserted by turning every
        # warning into an error inside `warnings.catch_warnings()`.
        import warnings

        with Model() as m:
            comp_dists = [HalfNormal.dist(), Exponential.dist(1)]
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                Mixture("mix1", w=[0.5, 0.5], comp_dists=comp_dists)

            comp_dists = [Uniform.dist(0, 1), Uniform.dist(0, 2)]
            with pytest.warns(MixtureTransformWarning):
                Mixture("mix2", w=[0.5, 0.5], comp_dists=comp_dists)

            comp_dists = [Normal.dist(), HalfNormal.dist()]
            with pytest.warns(MixtureTransformWarning):
                Mixture("mix3", w=[0.5, 0.5], comp_dists=comp_dists)

            # An explicit `transform=None` skips the default transform lookup
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)

            # Observed variables are not transformed, so no warning is expected
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                Mixture("mix5", w=[0.5, 0.5], comp_dists=comp_dists, observed=1)

0 comments on commit fa015e3

Please sign in to comment.