Porting kroneckernormal distribution to v4 (#4774)
Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>

kc611 and ricardoV94 authored Jun 27, 2021
1 parent c9a2b40 commit 77ce200
Showing 4 changed files with 112 additions and 200 deletions.
201 changes: 74 additions & 127 deletions pymc3/distributions/multivariate.py
@@ -17,6 +17,8 @@

import warnings

from functools import reduce

import aesara
import aesara.tensor as at
import numpy as np
@@ -45,7 +47,7 @@
from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support
from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln
from pymc3.distributions.distribution import Continuous, Discrete
from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker
from pymc3.math import kron_diag, kron_dot

__all__ = [
"MvNormal",
@@ -1702,6 +1704,32 @@ def _distr_parameters_for_repr(self):
return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]]


class KroneckerNormalRV(RandomVariable):
name = "kroneckernormal"
ndim_supp = 2
ndims_params = [1, 0, 2]
dtype = "floatX"
_print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}")

def _shape_from_params(self, dist_params, rep_param_idx=0, param_shapes=None):
return default_shape_from_params(1, dist_params, rep_param_idx, param_shapes)

def rng_fn(self, rng, mu, sigma, *covs, size=None):
size = size if size else covs[-1]
covs = covs[:-1] if covs[-1] == size else covs

cov = reduce(linalg.kron, covs)

if sigma:
cov = cov + sigma ** 2 * np.eye(cov.shape[0])

x = multivariate_normal.rng_fn(rng=rng, mean=mu, cov=cov, size=size)
return x


kroneckernormal = KroneckerNormalRV()
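
For context (not part of the diff): the new rng_fn draws from a dense multivariate normal whose covariance is the Kronecker product of the factor covariances, plus an optional sigma**2 white-noise term. A minimal NumPy/SciPy sketch of that computation, assuming two small covariance factors and skipping the RandomVariable size bookkeeping (where size may arrive as the trailing element of *covs):

from functools import reduce

import numpy as np
from scipy import linalg, stats

rng = np.random.default_rng(5)
K1 = np.array([[1.0, 0.3], [0.3, 1.0]])
K2 = np.array([[2.0, 0.5], [0.5, 1.5]])
sigma = 0.1

cov = reduce(linalg.kron, [K1, K2])            # dense 4x4 Kronecker covariance
cov = cov + sigma ** 2 * np.eye(cov.shape[0])  # optional white-noise term
# scipy's rvs stands in for the multivariate_normal.rng_fn call used in the diff
draw = stats.multivariate_normal.rvs(mean=np.zeros(4), cov=cov, random_state=rng)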


class KroneckerNormal(Continuous):
r"""
Multivariate normal log-likelihood with Kronecker-structured covariance.
@@ -1790,160 +1818,79 @@ class KroneckerNormal(Continuous):
----------
.. [1] Saatchi, Y. (2011). "Scalable inference for structured Gaussian process models"
"""
rv_op = kroneckernormal

def __init__(self, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs):
self._setup(covs, chols, evds, sigma)
super().__init__(*args, **kwargs)
self.mu = at.as_tensor_variable(mu)
self.mean = self.median = self.mode = self.mu
@classmethod
def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs):

def _setup(self, covs, chols, evds, sigma):
self.cholesky = Cholesky(lower=True, on_error="raise")
if len([i for i in [covs, chols, evds] if i is not None]) != 1:
raise ValueError(
"Incompatible parameterization. Specify exactly one of covs, chols, or evds."
)
self._isEVD = False
self.sigma = sigma
self.is_noisy = self.sigma is not None and self.sigma != 0
if covs is not None:
self._cov_type = "cov"
self.covs = covs
if self.is_noisy:
# Noise requires eigendecomposition
eigh_map = map(eigh, covs)
self._setup_evd(eigh_map)
else:
# Otherwise use cholesky as usual
self.chols = list(map(self.cholesky, self.covs))
self.chol_diags = list(map(at.diag, self.chols))
self.sizes = at.as_tensor_variable([chol.shape[0] for chol in self.chols])
self.N = at.prod(self.sizes)
elif chols is not None:
self._cov_type = "chol"
if self.is_noisy: # A strange case...
# Noise requires eigendecomposition
covs = [at.dot(chol, chol.T) for chol in chols]
eigh_map = map(eigh, covs)
self._setup_evd(eigh_map)
else:
self.chols = chols
self.chol_diags = list(map(at.diag, self.chols))
self.sizes = at.as_tensor_variable([chol.shape[0] for chol in self.chols])
self.N = at.prod(self.sizes)
else:
self._cov_type = "evd"
self._setup_evd(evds)

def _setup_evd(self, eigh_iterable):
self._isEVD = True
eigs_sep, Qs = zip(*eigh_iterable) # Unzip
self.Qs = list(map(at.as_tensor_variable, Qs))
self.QTs = list(map(at.transpose, self.Qs))

self.eigs_sep = list(map(at.as_tensor_variable, eigs_sep))
self.eigs = kron_diag(*self.eigs_sep) # Combine separate eigs
if self.is_noisy:
self.eigs += self.sigma ** 2
self.N = self.eigs.shape[0]

def _setup_random(self):
if not hasattr(self, "mv_params"):
self.mv_params = {"mu": self.mu}
if self._cov_type == "cov":
cov = kronecker(*self.covs)
if self.is_noisy:
cov = cov + self.sigma ** 2 * at.identity_like(cov)
self.mv_params["cov"] = cov
elif self._cov_type == "chol":
if self.is_noisy:
covs = []
for eig, Q in zip(self.eigs_sep, self.Qs):
cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T))
covs.append(cov_i)
cov = kronecker(*covs)
if self.is_noisy:
cov = cov + self.sigma ** 2 * at.identity_like(cov)
self.mv_params["chol"] = self.cholesky(cov)
else:
self.mv_params["chol"] = kronecker(*self.chols)
elif self._cov_type == "evd":
covs = []
for eig, Q in zip(self.eigs_sep, self.Qs):
cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T))
covs.append(cov_i)
cov = kronecker(*covs)
if self.is_noisy:
cov = cov + self.sigma ** 2 * at.identity_like(cov)
self.mv_params["cov"] = cov
sigma = sigma if sigma else 0

def random(self, point=None, size=None):
if chols is not None:
covs = [chol.dot(chol.T) for chol in chols]
elif evds is not None:
eigh_iterable = evds
covs = []
eigs_sep, Qs = zip(*eigh_iterable) # Unzip
for eig, Q in zip(eigs_sep, Qs):
cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T))
covs.append(cov_i)

mu = at.as_tensor_variable(mu)

# mean = median = mode = mu
return super().dist([mu, sigma, *covs], **kwargs)
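
For context (not part of the diff): dist() normalizes the three equivalent parameterizations to plain covariance factors before handing them to the RandomVariable. A NumPy sketch of that equivalence, with np.linalg standing in for the at.* graph ops:

import numpy as np

K = np.array([[2.0, 0.5], [0.5, 1.5]])

chol = np.linalg.cholesky(K)               # "chols" parameterization
eig, Q = np.linalg.eigh(K)                 # "evds" parameterization

K_from_chol = chol.dot(chol.T)             # mirrors covs = [chol.dot(chol.T) for chol in chols]
K_from_evd = Q.dot(np.diag(eig)).dot(Q.T)  # mirrors at.dot(Q, at.dot(at.diag(eig), Q.T))

assert np.allclose(K_from_chol, K)
assert np.allclose(K_from_evd, K)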

def logp(value, mu, sigma, *covs):
"""
Draw random values from Multivariate Normal distribution
with Kronecker-structured covariance.
Calculate log-probability of Multivariate Normal distribution
with Kronecker-structured covariance at specified value.
Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).
value: numeric
Value for which log-probability is calculated.
Returns
-------
array
TensorVariable
"""
# Expand params into terms MvNormal can understand to force consistency
self._setup_random()
self.mv_params["shape"] = self.shape
dist = MvNormal.dist(**self.mv_params)
return dist.random(point, size)

def _quaddist(self, value):
"""Computes the quadratic (x-mu)^T @ K^-1 @ (x-mu) and log(det(K))"""
# Computes the quadratic (x-mu)^T @ K^-1 @ (x-mu) and log(det(K))
if value.ndim > 2 or value.ndim == 0:
raise ValueError("Invalid dimension for value: %s" % value.ndim)
raise ValueError(f"Invalid dimension for value: {value.ndim}")
if value.ndim == 1:
onedim = True
value = value[None, :]
else:
onedim = False

delta = value - self.mu
if self._isEVD:
sqrt_quad = kron_dot(self.QTs, delta.T)
sqrt_quad = sqrt_quad / at.sqrt(self.eigs[:, None])
logdet = at.sum(at.log(self.eigs))
else:
sqrt_quad = kron_solve_lower(self.chols, delta.T)
logdet = 0
for chol_size, chol_diag in zip(self.sizes, self.chol_diags):
logchol = at.log(chol_diag) * self.N / chol_size
logdet += at.sum(2 * logchol)
delta = value - mu

eigh_iterable = map(eigh, covs)
eigs_sep, Qs = zip(*eigh_iterable) # Unzip
Qs = list(map(at.as_tensor_variable, Qs))
QTs = list(map(at.transpose, Qs))

eigs_sep = list(map(at.as_tensor_variable, eigs_sep))
eigs = kron_diag(*eigs_sep) # Combine separate eigs
eigs += sigma ** 2
N = eigs.shape[0]

sqrt_quad = kron_dot(QTs, delta.T)
sqrt_quad = sqrt_quad / at.sqrt(eigs[:, None])
logdet = at.sum(at.log(eigs))

# Square each sample
quad = at.batched_dot(sqrt_quad.T, sqrt_quad.T)
if onedim:
quad = quad[0]
return quad, logdet

def logp(self, value):
"""
Calculate log-probability of Multivariate Normal distribution
with Kronecker-structured covariance at specified value.
Parameters
----------
value: numeric
Value for which log-probability is calculated.
Returns
-------
TensorVariable
"""
quad, logdet = self._quaddist(value)
return -(quad + logdet + self.N * at.log(2 * np.pi)) / 2.0
a = -(quad + logdet + N * at.log(2 * np.pi)) / 2.0
return a
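
For context (not part of the diff): the logp above relies on the identity kron(K1, K2) + sigma**2 * I = (Q1 kron Q2) diag(kron(e1, e2) + sigma**2) (Q1 kron Q2).T, so the log-determinant is sum(log(eigs)) and the quadratic form only needs the per-factor eigenvectors. A NumPy check against the dense density, with an explicit Kronecker product standing in for kron_dot/kron_diag:

import numpy as np
import scipy.stats

K1 = np.array([[1.0, 0.3], [0.3, 1.0]])
K2 = np.array([[2.0, 0.5], [0.5, 1.5]])
sigma = 0.5
mu = np.zeros(4)
value = np.array([0.1, -0.2, 0.3, 0.0])

e1, Q1 = np.linalg.eigh(K1)
e2, Q2 = np.linalg.eigh(K2)
eigs = np.kron(e1, e2) + sigma ** 2        # mirrors kron_diag(*eigs_sep) + sigma**2
Q = np.kron(Q1, Q2)

delta = value - mu
sqrt_quad = Q.T.dot(delta) / np.sqrt(eigs) # dense stand-in for kron_dot(QTs, delta.T)
logp = -0.5 * (sqrt_quad.dot(sqrt_quad) + np.sum(np.log(eigs)) + eigs.size * np.log(2 * np.pi))

dense_cov = np.kron(K1, K2) + sigma ** 2 * np.eye(4)
assert np.isclose(logp, scipy.stats.multivariate_normal.logpdf(value, mu, dense_cov))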

def _distr_parameters_for_repr(self):
return ["mu"]
12 changes: 7 additions & 5 deletions pymc3/tests/test_distributions.py
@@ -388,19 +388,19 @@ def matrix_normal_logpdf_chol(value, mu, rowchol, colchol):
)


def kron_normal_logpdf_cov(value, mu, covs, sigma):
def kron_normal_logpdf_cov(value, mu, covs, sigma, size=None):
cov = kronecker(*covs).eval()
if sigma is not None:
cov += sigma ** 2 * np.eye(*cov.shape)
return scipy.stats.multivariate_normal.logpdf(value, mu, cov).sum()
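
For context (not part of the diff): the helper above needs Aesara to evaluate kronecker(*covs); an equivalent pure-NumPy reference, with np.kron standing in for pymc3.math.kronecker and the (unused) size argument kept for signature parity:

from functools import reduce

import numpy as np
import scipy.stats

def kron_normal_logpdf_cov_np(value, mu, covs, sigma, size=None):
    cov = reduce(np.kron, covs)
    if sigma is not None:
        cov += sigma ** 2 * np.eye(*cov.shape)
    return scipy.stats.multivariate_normal.logpdf(value, mu, cov).sum()

K1 = np.array([[1.0, 0.3], [0.3, 1.0]])
K2 = np.array([[2.0, 0.5], [0.5, 1.5]])
kron_normal_logpdf_cov_np(np.zeros(4), np.zeros(4), [K1, K2], sigma=1.0)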


def kron_normal_logpdf_chol(value, mu, chols, sigma):
def kron_normal_logpdf_chol(value, mu, chols, sigma, size=None):
covs = [np.dot(chol, chol.T) for chol in chols]
return kron_normal_logpdf_cov(value, mu, covs, sigma=sigma)


def kron_normal_logpdf_evd(value, mu, evds, sigma):
def kron_normal_logpdf_evd(value, mu, evds, sigma, size=None):
covs = []
for eigs, Q in evds:
try:
@@ -1943,8 +1943,7 @@ def test_matrixnormal(self, n):

@pytest.mark.parametrize("n", [2, 3])
@pytest.mark.parametrize("m", [3])
@pytest.mark.parametrize("sigma", [None, 1.0])
@pytest.mark.xfail(reason="Distribution not refactored yet")
@pytest.mark.parametrize("sigma", [None, 1])
def test_kroneckernormal(self, n, m, sigma):
np.random.seed(5)
N = n * m
@@ -1990,6 +1989,9 @@ def test_kroneckernormal(self, n, m, sigma):
)

dom = Domain([np.random.randn(2, N) * 0.1], edges=(None, None), shape=(2, N))
cov_args["size"] = 2
chol_args["size"] = 2
evd_args["size"] = 2

self.check_logp(
KroneckerNormal,