Refactored Wishart and MatrixNormal distribution
kc611 committed Jul 2, 2021
1 parent c405c8d commit 65b49fc
Showing 3 changed files with 201 additions and 252 deletions.
259 changes: 104 additions & 155 deletions pymc3/distributions/multivariate.py
@@ -47,6 +47,7 @@
from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support
from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln
from pymc3.distributions.distribution import Continuous, Discrete
from pymc3.distributions.shape_utils import broadcast_dist_samples_to, to_tuple
from pymc3.math import kron_diag, kron_dot

__all__ = [
@@ -739,6 +740,26 @@ def __str__(self):
matrix_pos_def = PosDefMatrix()


class WishartRV(RandomVariable):
name = "wishart"
ndim_supp = 2
ndims_params = [0, 2]
dtype = "floatX"
_print_name = ("Wishart", "\\operatorname{Wishart}")

def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
# The shape of second parameter `V` defines the shape of the output.
return dist_params[1].shape

@classmethod
def rng_fn(cls, rng, nu, V, size=None):
size = size if size else 1 # Default size for Scipy's wishart.rvs is 1
return stats.wishart.rvs(int(nu), V, size=size, random_state=rng)


wishart = WishartRV()
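
For illustration, here is a minimal sketch (not part of the diff) of what the new RV op computes: rng_fn defers directly to SciPy's Wishart sampler. The degrees of freedom and scale matrix below are assumed example values.

import numpy as np
from scipy import stats

rng = np.random.default_rng(42)
nu = 5               # degrees of freedom (assumed example value)
V = np.eye(3)        # 3x3 scale matrix (assumed example value)

# Equivalent to WishartRV.rng_fn(rng, nu, V, size=2)
draws = stats.wishart.rvs(int(nu), V, size=2, random_state=rng)
print(draws.shape)   # (2, 3, 3): two 3x3 positive-definite draws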


class Wishart(Continuous):
r"""
Wishart log-likelihood.
Expand Down Expand Up @@ -775,46 +796,19 @@ class Wishart(Continuous):
This distribution is unusable in a PyMC3 model. You should instead
use LKJCholeskyCov or LKJCorr.
"""
rv_op = wishart

def __init__(self, nu, V, *args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn(
"The Wishart distribution can currently not be used "
"for MCMC sampling. The probability of sampling a "
"symmetric matrix is basically zero. Instead, please "
"use LKJCholeskyCov or LKJCorr. For more information "
"on the issues surrounding the Wishart see here: "
"https://github.com/pymc-devs/pymc3/issues/538.",
UserWarning,
)
self.nu = nu = at.as_tensor_variable(nu)
self.p = p = at.as_tensor_variable(V.shape[0])
self.V = V = at.as_tensor_variable(V)
self.mean = nu * V
self.mode = at.switch(at.ge(nu, p + 1), (nu - p - 1) * V, np.nan)

def random(self, point=None, size=None):
"""
Draw random values from Wishart distribution.
Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).
@classmethod
def dist(cls, nu, V, *args, **kwargs):
nu = at.as_tensor_variable(intX(nu))
V = at.as_tensor_variable(floatX(V))

Returns
-------
array
"""
# nu, V = draw_values([self.nu, self.V], point=point, size=size)
# size = 1 if size is None else size
# return generate_samples(stats.wishart.rvs, nu.item(), V, broadcast_shape=(size,))
# mean = nu * V
# p = V.shape[0]
# mode = at.switch(at.ge(nu, p + 1), (nu - p - 1) * V, np.nan)
return super().dist([nu, V], *args, **kwargs)

def logp(self, X):
def logp(X, nu, V):
"""
Calculate log-probability of Wishart distribution
at specified value.
@@ -828,9 +822,8 @@ def logp(self, X):
-------
TensorVariable
"""
nu = self.nu
p = self.p
V = self.V

p = V.shape[0]

IVI = det(V)
IXI = det(X)
@@ -1445,6 +1438,36 @@ def _distr_parameters_for_repr(self):
return ["eta", "n"]


class MatrixNormalRV(RandomVariable):
name = "matrixnormal"
ndim_supp = 2
ndims_params = [2, 2, 2]
dtype = "floatX"
_print_name = ("MatrixNormal", "\\operatorname{MatrixNormal}")

@classmethod
def rng_fn(cls, rng, mu, rowchol, colchol, size=None):

size = to_tuple(size)
dist_shape = to_tuple([rowchol.shape[0], colchol.shape[0]])
output_shape = size + dist_shape

# Broadcasting all parameters
(mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)
rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:])

colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:])
colchol = np.swapaxes(colchol, -1, -2) # Take transpose

standard_normal = rng.standard_normal(output_shape)
samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol))

return samples


matrixnormal = MatrixNormalRV()
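
A minimal NumPy sketch (not part of the diff; shapes and covariances are assumed example values) of the sampling rule implemented by rng_fn above: X = mu + L_row @ Z @ L_col.T, where Z is a matrix of standard normals and L_row, L_col are Cholesky factors of the row and column covariances.

import numpy as np

rng = np.random.default_rng(0)
m, n = 3, 4                                  # assumed example dimensions
mu = np.zeros((m, n))
rowcov = np.diag([1.0, 2.0, 3.0])            # m x m row covariance
colcov = np.eye(n)                           # n x n column covariance

Z = rng.standard_normal((m, n))
X = mu + np.linalg.cholesky(rowcov) @ Z @ np.linalg.cholesky(colcov).T
print(X.shape)                               # (3, 4)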


class MatrixNormal(Continuous):
r"""
Matrix-valued normal log-likelihood.
@@ -1533,175 +1556,101 @@ class MatrixNormal(Continuous):
vals = pm.MatrixNormal('vals', mu=mu, colchol=colchol, rowcov=rowcov,
observed=data, shape=(m, n))
"""
rv_op = matrixnormal

def __init__(
self,
mu=0,
@classmethod
def dist(
cls,
mu,
rowcov=None,
rowchol=None,
rowtau=None,
colcov=None,
colchol=None,
coltau=None,
shape=None,
*args,
**kwargs,
):
self._setup_matrices(colcov, colchol, coltau, rowcov, rowchol, rowtau)
if shape is None:
raise TypeError("shape is a required argument")
assert len(shape) == 2, "shape must have length 2: mxn"
self.shape = shape
super().__init__(shape=shape, *args, **kwargs)
self.mu = at.as_tensor_variable(mu)
self.mean = self.median = self.mode = self.mu
self.solve_lower = solve_lower_triangular
self.solve_upper = solve_upper_triangular

def _setup_matrices(self, colcov, colchol, coltau, rowcov, rowchol, rowtau):

cholesky = Cholesky(lower=True, on_error="raise")

if mu.ndim == 1:
raise ValueError(
"1x1 Matrix was provided. Please use Normal distribution for such cases."
)

# Among-row matrices
if len([i for i in [rowtau, rowcov, rowchol] if i is not None]) != 1:
if len([i for i in [rowcov, rowchol] if i is not None]) != 1:
raise ValueError(
"Incompatible parameterization. "
"Specify exactly one of rowtau, rowcov, "
"or rowchol."
"Incompatible parameterization. Specify exactly one of rowcov, or rowchol."
)
if rowcov is not None:
self.m = rowcov.shape[0]
self._rowcov_type = "cov"
rowcov = at.as_tensor_variable(rowcov)
if rowcov.ndim != 2:
raise ValueError("rowcov must be two dimensional.")
self.rowchol_cov = cholesky(rowcov)
self.rowcov = rowcov
elif rowtau is not None:
raise ValueError("rowtau not supported at this time")
self.m = rowtau.shape[0]
self._rowcov_type = "tau"
rowtau = at.as_tensor_variable(rowtau)
if rowtau.ndim != 2:
raise ValueError("rowtau must be two dimensional.")
self.rowchol_tau = cholesky(rowtau)
self.rowtau = rowtau
rowchol_cov = cholesky(rowcov)
else:
self.m = rowchol.shape[0]
self._rowcov_type = "chol"
if rowchol.ndim != 2:
raise ValueError("rowchol must be two dimensional.")
self.rowchol_cov = at.as_tensor_variable(rowchol)
rowchol_cov = at.as_tensor_variable(rowchol)

# Among-column matrices
if len([i for i in [coltau, colcov, colchol] if i is not None]) != 1:
if len([i for i in [colcov, colchol] if i is not None]) != 1:
raise ValueError(
"Incompatible parameterization. "
"Specify exactly one of coltau, colcov, "
"or colchol."
"Incompatible parameterization. Specify exactly one of colcov, or colchol."
)
if colcov is not None:
self.n = colcov.shape[0]
self._colcov_type = "cov"
colcov = at.as_tensor_variable(colcov)
if colcov.ndim != 2:
raise ValueError("colcov must be two dimensional.")
self.colchol_cov = cholesky(colcov)
self.colcov = colcov
elif coltau is not None:
raise ValueError("coltau not supported at this time")
self.n = coltau.shape[0]
self._colcov_type = "tau"
coltau = at.as_tensor_variable(coltau)
if coltau.ndim != 2:
raise ValueError("coltau must be two dimensional.")
self.colchol_tau = cholesky(coltau)
self.coltau = coltau
colchol_cov = cholesky(colcov)
else:
self.n = colchol.shape[0]
self._colcov_type = "chol"
if colchol.ndim != 2:
raise ValueError("colchol must be two dimensional.")
self.colchol_cov = at.as_tensor_variable(colchol)
colchol_cov = at.as_tensor_variable(colchol)

def random(self, point=None, size=None):
mu = at.as_tensor_variable(floatX(mu))
# mean = median = mode = mu

return super().dist([mu, rowchol_cov, colchol_cov], **kwargs)
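
A hedged usage sketch (assumed, not shown in the diff): after the refactor, dist accepts exactly one of rowcov/rowchol and exactly one of colcov/colchol, and the tau parameterizations are dropped.

import numpy as np
import pymc3 as pm

mu = np.zeros((3, 4))
rowcov = np.eye(3)                           # or pass rowchol instead
colchol = np.linalg.cholesky(np.eye(4))      # or pass colcov instead

# Creates the RV without registering it in a model (assumed call pattern)
vals = pm.MatrixNormal.dist(mu=mu, rowcov=rowcov, colchol=colchol)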

def logp(value, mu, rowchol, colchol):
"""
Draw random values from Matrix-valued Normal distribution.
Calculate log-probability of Matrix-valued Normal distribution
at specified value.
Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).
value: numeric
Value for which log-probability is calculated.
Returns
-------
array
TensorVariable
"""
# mu, colchol, rowchol = draw_values(
# [self.mu, self.colchol_cov, self.rowchol_cov], point=point, size=size
# )
# size = to_tuple(size)
# dist_shape = to_tuple(self.shape)
# output_shape = size + dist_shape
#
# # Broadcasting all parameters
# (mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)
# rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:])
#
# colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:])
# colchol = np.swapaxes(colchol, -1, -2) # Take transpose
#
# standard_normal = np.random.standard_normal(output_shape)
# samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol))
# return samples

def _trquaddist(self, value):
"""Compute Tr[colcov^-1 @ (x - mu).T @ rowcov^-1 @ (x - mu)] and
the logdet of colcov and rowcov."""

delta = value - self.mu
rowchol_cov = self.rowchol_cov
colchol_cov = self.colchol_cov
# Compute Tr[colcov^-1 @ (x - mu).T @ rowcov^-1 @ (x - mu)] and
# the logdet of colcov and rowcov.
delta = value - mu

# Find exponent piece by piece
right_quaddist = self.solve_lower(rowchol_cov, delta)
right_quaddist = solve_lower_triangular(rowchol, delta)
quaddist = at.nlinalg.matrix_dot(right_quaddist.T, right_quaddist)
quaddist = self.solve_lower(colchol_cov, quaddist)
quaddist = self.solve_upper(colchol_cov.T, quaddist)
quaddist = solve_lower_triangular(colchol, quaddist)
quaddist = solve_upper_triangular(colchol.T, quaddist)
trquaddist = at.nlinalg.trace(quaddist)

coldiag = at.diag(colchol_cov)
rowdiag = at.diag(rowchol_cov)
coldiag = at.diag(colchol)
rowdiag = at.diag(rowchol)
half_collogdet = at.sum(at.log(coldiag)) # logdet(M) = 2*Tr(log(L))
half_rowlogdet = at.sum(at.log(rowdiag)) # Using Cholesky: M = L L^T
return trquaddist, half_collogdet, half_rowlogdet

def logp(self, value):
"""
Calculate log-probability of Matrix-valued Normal distribution
at specified value.

Parameters
----------
value: numeric
Value for which log-probability is calculated.
m = rowchol.shape[0]
n = colchol.shape[0]

Returns
-------
TensorVariable
"""
trquaddist, half_collogdet, half_rowlogdet = self._trquaddist(value)
m = self.m
n = self.n
norm = -0.5 * m * n * pm.floatX(np.log(2 * np.pi))
return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet
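
As a quick numeric cross-check (not part of the diff; values are assumed examples), the closed-form density assembled above should agree with SciPy's matrix normal log-pdf for the same mean and row/column covariances.

import numpy as np
from scipy import stats

m, n = 3, 2
M = np.zeros((m, n))
rowcov = np.diag([1.0, 2.0, 3.0])
colcov = np.array([[1.0, 0.3],
                   [0.3, 1.0]])
X = np.ones((m, n))

# Reference for logp(X, mu=M, rowchol=chol(rowcov), colchol=chol(colcov))
print(stats.matrix_normal.logpdf(X, mean=M, rowcov=rowcov, colcov=colcov))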

def _distr_parameters_for_repr(self):
mapping = {"tau": "tau", "cov": "cov", "chol": "chol_cov"}
return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]]
return ["mu"]


class KroneckerNormalRV(RandomVariable):