Bivariate von Mises Distribution (#2821)
OlaRonning authored May 19, 2021
1 parent f13ea00 commit c340831
Showing 8 changed files with 438 additions and 14 deletions.
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
@@ -316,6 +316,13 @@ Rejector
:undoc-members:
:show-inheritance:

SineBivariateVonMises
---------------------
.. autoclass:: pyro.distributions.SineBivariateVonMises
:members:
:undoc-members:
:show-inheritance:

SoftLaplace
-------------
.. autoclass:: pyro.distributions.SoftLaplace
2 changes: 2 additions & 0 deletions pyro/distributions/__init__.py
@@ -58,6 +58,7 @@
RelaxedBernoulliStraightThrough,
RelaxedOneHotCategoricalStraightThrough,
)
from pyro.distributions.sine_bivariate_von_mises import SineBivariateVonMises
from pyro.distributions.softlaplace import SoftLaplace
from pyro.distributions.spanning_tree import SpanningTree
from pyro.distributions.stable import Stable
@@ -128,6 +129,7 @@
"Rejector",
"RelaxedBernoulliStraightThrough",
"RelaxedOneHotCategoricalStraightThrough",
"SineBivariateVonMises",
"SoftLaplace",
"SpanningTree",
"Stable",
228 changes: 228 additions & 0 deletions pyro/distributions/sine_bivariate_von_mises.py
@@ -0,0 +1,228 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
import warnings
from math import pi

import torch
from torch.distributions import VonMises
from torch.distributions.utils import broadcast_all, lazy_property

from pyro.distributions import constraints
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape
from pyro.ops.special import log_I1


class SineBivariateVonMises(TorchDistribution):
r""" Unimodal distribution of two dependent angles on the 2-torus (S^1 ⨂ S^1) given by
.. math::
C^{-1}\exp(\kappa_1\cos(x-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2))
and
.. math::
C = (2\pi)^2 \sum_{i=0} {2i \choose i}
\left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2),
where I_i(\cdot) is the modified bessel function of first kind, mu's are the locations of the distribution,
kappa's are the concentration and rho gives the correlation between angles x_1 and x_2.
This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains.
To infer parameters, use :class:`~pyro.infer.NUTS` or :class:`~pyro.infer.HMC` with priors that
avoid parameterizations where the distribution becomes bimodal; see note below.
.. note:: Sample efficiency drops as

    .. math::

        \frac{\rho^2}{\kappa_1\kappa_2} \rightarrow 1

    because the distribution becomes increasingly bimodal; bimodality sets in once
    \rho^2 > \kappa_1\kappa_2, matching the validation check in ``__init__``.
.. note:: The correlation and weighted_correlation params are mutually exclusive.
.. note:: In the context of :class:`~pyro.infer.SVI`, this distribution can be used as a likelihood but not for
latent variables.
**References:**

1. Probabilistic model for two dependent circular variables,
   Singh, H., Hnizdo, V., and Demchuk, E. (2002)
:param torch.Tensor phi_loc: location of first angle
:param torch.Tensor psi_loc: location of second angle
:param torch.Tensor phi_concentration: concentration of first angle
:param torch.Tensor psi_concentration: concentration of second angle
:param torch.Tensor correlation: correlation between the two angles
:param torch.Tensor weighted_correlation: sets the correlation to weighted_corr * sqrt(phi_conc * psi_conc)
    to avoid bimodality (see note).
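
Example (an illustrative usage sketch added for this writeup; the parameter values are
hypothetical, not from the original commit)::

    import torch
    import pyro.distributions as dist

    bvm = dist.SineBivariateVonMises(phi_loc=torch.tensor(0.),
                                     psi_loc=torch.tensor(0.),
                                     phi_concentration=torch.tensor(5.),
                                     psi_concentration=torch.tensor(6.),
                                     weighted_correlation=torch.tensor(0.2))
    angles = bvm.sample()       # one (phi, psi) pair, shape (2,)
    lp = bvm.log_prob(angles)   # log density at the sampled angles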
"""

arg_constraints = {'phi_loc': constraints.real, 'psi_loc': constraints.real,
'phi_concentration': constraints.positive, 'psi_concentration': constraints.positive,
'correlation': constraints.real}
support = constraints.independent(constraints.real, 1)
max_sample_iter = 1000

def __init__(self, phi_loc, psi_loc, phi_concentration, psi_concentration, correlation=None,
weighted_correlation=None, validate_args=None):

assert (correlation is None) != (weighted_correlation is None)

if weighted_correlation is not None:
sqrt_ = torch.sqrt if isinstance(phi_concentration, torch.Tensor) else math.sqrt
correlation = weighted_correlation * sqrt_(phi_concentration * psi_concentration) + 1e-8

phi_loc, psi_loc, phi_concentration, psi_concentration, correlation = broadcast_all(phi_loc, psi_loc,
phi_concentration,
psi_concentration,
correlation)
self.phi_loc = phi_loc
self.psi_loc = psi_loc
self.phi_concentration = phi_concentration
self.psi_concentration = psi_concentration
self.correlation = correlation
event_shape = torch.Size([2])
batch_shape = phi_loc.shape

super().__init__(batch_shape, event_shape, validate_args)

if self._validate_args and torch.any(phi_concentration * psi_concentration <= correlation ** 2):
warnings.warn(
f'{self.__class__.__name__} bimodal due to concentration-correlation relation, '
f'sampling will likely fail.', UserWarning)

@lazy_property
def norm_const(self):
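# Evaluate the truncated (50-term) series for the normalizer C in log space,
# shifting by the max term before logsumexp for numerical stability.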
corr = self.correlation.view(1, -1) + 1e-8
conc = torch.stack((self.phi_concentration, self.psi_concentration), dim=-1).view(-1, 2)
m = torch.arange(50, device=self.phi_loc.device).view(-1, 1)
fs = SineBivariateVonMises._lbinoms(m.max() + 1).view(-1, 1) + 2 * m * torch.log(corr) - m * torch.log(
4 * torch.prod(conc, dim=-1))
fs += log_I1(m.max(), conc, 51).sum(-1)
mfs = fs.max()
norm_const = 2 * torch.log(torch.tensor(2 * pi)) + mfs + (fs - mfs).logsumexp(0)
return norm_const.reshape(self.phi_loc.shape)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
indv = self.phi_concentration * torch.cos(value[..., 0] - self.phi_loc) + self.psi_concentration * torch.cos(
value[..., 1] - self.psi_loc)
corr = self.correlation * torch.sin(value[..., 0] - self.phi_loc) * torch.sin(value[..., 1] - self.psi_loc)
return indv + corr - self.norm_const

def sample(self, sample_shape=torch.Size()):
"""
**References:**

1. A New Unified Approach for the Simulation of a Wide Class of Directional Distributions,
   John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018)
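
Example (a minimal sketch assuming scalar parameters; the values are illustrative)::

    bvm = SineBivariateVonMises(0., 0., 5., 6., weighted_correlation=0.2)
    draws = bvm.sample(torch.Size([1000]))  # shape (1000, 2), angles in [-pi, pi)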
"""
assert not torch._C._get_tracing_state(), "jit not supported"
sample_shape = torch.Size(sample_shape)

corr = self.correlation
conc = torch.stack((self.phi_concentration, self.psi_concentration))

eig = 0.5 * (conc[0] - corr ** 2 / conc[1])
eig = torch.stack((torch.zeros_like(eig), eig))
eigmin = torch.where(eig[1] < 0, eig[1], torch.zeros_like(eig[1], dtype=eig.dtype))
eig = eig - eigmin
b0 = self._bfind(eig)

total = sample_shape.numel()
missing = total * torch.ones((self.batch_shape.numel(),), dtype=torch.int, device=conc.device)
start = torch.zeros_like(missing, device=conc.device)
phi = torch.empty((2, *missing.shape, total), dtype=corr.dtype, device=conc.device)

max_iter = SineBivariateVonMises.max_sample_iter

# flatten batch_shape
conc = conc.view(2, -1, 1)
eigmin = eigmin.view(-1, 1)
corr = corr.reshape(-1, 1)
eig = eig.view(2, -1)
b0 = b0.view(-1)
phi_den = log_I1(0, conc[1]).view(-1, 1)
lengths = torch.arange(total, device=conc.device).view(1, -1)

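# Rejection-sample the marginal angle phi: propose from an angular central Gaussian
# envelope and accept/reject (Kent, Ganeiber & Mardia, 2018), refilling only the
# batch entries that still have missing samples.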
while torch.any(missing > 0) and max_iter:
curr_conc = conc[:, missing > 0, :]
curr_corr = corr[missing > 0]
curr_eig = eig[:, missing > 0]
curr_b0 = b0[missing > 0]

x = torch.distributions.Normal(0., torch.sqrt(1 + 2 * curr_eig / curr_b0)).sample(
(missing[missing > 0].min(),)).view(2, -1, missing[missing > 0].min())
x /= x.norm(dim=0)[None, ...] # Angular Central Gaussian distribution

lf = curr_conc[0] * (x[0] - 1) + eigmin[missing > 0] + log_I1(0, torch.sqrt(
curr_conc[1] ** 2 + (curr_corr * x[1]) ** 2)).squeeze(0) - phi_den[missing > 0]
assert lf.shape == ((missing > 0).sum(), missing[missing > 0].min())

lg_inv = 1. - curr_b0.view(-1, 1) / 2 + torch.log(
curr_b0.view(-1, 1) / 2 + (curr_eig.view(2, -1, 1) * x ** 2).sum(0))
assert lg_inv.shape == lf.shape

accepted = torch.distributions.Uniform(0., torch.ones((), device=conc.device)).sample(lf.shape) < (
lf + lg_inv).exp()

phi_mask = torch.zeros((*missing.shape, total), dtype=torch.bool, device=conc.device)
phi_mask[missing > 0] = torch.logical_and(lengths < (start[missing > 0] + accepted.sum(-1)).view(-1, 1),
lengths >= start[missing > 0].view(-1, 1))

phi[:, phi_mask] = x[:, accepted]

start[missing > 0] += accepted.sum(-1)
missing[missing > 0] -= accepted.sum(-1)
max_iter -= 1

if max_iter == 0 or torch.any(missing > 0):
raise ValueError("maximum number of iterations exceeded; "
"try increasing `SineBivariateVonMises.max_sample_iter`")

phi = torch.atan2(phi[1], phi[0])

alpha = torch.sqrt(conc[1] ** 2 + (corr * torch.sin(phi)) ** 2)
beta = torch.atan(corr / conc[1] * torch.sin(phi))

psi = VonMises(beta, alpha).sample()

phi_psi = torch.stack(((phi + self.phi_loc.reshape((-1, 1)) + pi) % (2 * pi) - pi,
(psi + self.psi_loc.reshape((-1, 1)) + pi) % (2 * pi) - pi), dim=-1).permute(1, 0, 2)
return phi_psi.reshape(*sample_shape, *self.batch_shape, *self.event_shape)

@property
def mean(self):
return torch.stack((self.phi_loc, self.psi_loc), dim=-1)

@classmethod
def infer_shapes(cls, **arg_shapes):
batch_shape = torch.Size(broadcast_shape(*arg_shapes.values()))
return batch_shape, torch.Size([2])

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(SineBivariateVonMises, _instance)
batch_shape = torch.Size(batch_shape)
for k in SineBivariateVonMises.arg_constraints.keys():
setattr(new, k, getattr(self, k).expand(batch_shape))
new.norm_const = self.norm_const.expand(batch_shape)
super(SineBivariateVonMises, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def _bfind(self, eig):
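# Newton-style update from the initial guess b = dim/2 for the envelope parameter
# used by the ACG rejection sampler (assumed to follow the b-finding step in
# Kent et al., 2018); kept at the initial guess when eig is identically zero.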
b = eig.shape[0] / 2 * torch.ones(self.batch_shape, dtype=eig.dtype, device=eig.device)
g1 = torch.sum(1 / (b + 2 * eig) ** 2, dim=0)
g2 = torch.sum(-2 / (b + 2 * eig) ** 3, dim=0)
return torch.where(eig.norm(0) != 0, b - g1 / g2, b)

@staticmethod
def _lbinoms(n):
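# Log central binomial coefficients: log C(2n, n) = lgamma(2n + 1) - 2 * lgamma(n + 1),
# i.e. the {2i choose i} factors in the series for the normalizer C.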
ns = torch.arange(n, device=n.device)
num = torch.lgamma(2 * ns + 1)
den = torch.lgamma(ns + 1)
return num - 2 * den
43 changes: 43 additions & 0 deletions pyro/ops/special.py
@@ -99,3 +99,46 @@ def log_binomial(n, k, tol=0.):
return n_plus_1.lgamma() - (k + 1).lgamma() - (n_plus_1 - k).lgamma()

return -n_plus_1.log() - log_beta(k + 1, n_plus_1 - k, tol=tol)


def log_I1(orders: int, value: torch.Tensor, terms=250):
r""" Compute first n log modified bessel function of first kind
.. math ::
\log(I_v(z)) = v*\log(z/2) + \log(\sum_{k=0}^\inf \exp\left[2*k*\log(z/2) - \sum_kk^k log(kk)
- \lgamma(v + k + 1)\right])
:param orders: orders of the log modified bessel function.
:param value: values to compute modified bessel function for
:param terms: truncation of summation
:return: 0 to orders modified bessel function
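
Example (a sanity-check sketch; the comparison against ``torch.i0`` at order 0 is an
illustration added here, not part of the original commit)::

    value = torch.tensor([0.5, 1.0, 2.0])
    assert torch.allclose(log_I1(0, value)[0], torch.i0(value).log(), atol=1e-5)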
"""
orders = orders + 1
if len(value.size()) == 0:
vshape = torch.Size([1])
else:
vshape = value.shape
value = value.reshape(-1, 1)

k = torch.arange(terms, device=value.device)
lgammas_all = torch.lgamma(torch.arange(1, terms + orders + 1, device=value.device))
assert lgammas_all.shape == (orders + terms,) # lgamma(0) = inf => start from 1

lvalues = torch.log(value / 2) * k.view(1, -1)
assert lvalues.shape == (vshape.numel(), terms)

lfactorials = lgammas_all[:terms]
assert lfactorials.shape == (terms,)

lgammas = lgammas_all.repeat(orders).view(orders, -1)
assert lgammas.shape == (orders, terms + orders) # lgamma(0) = inf => start from 1

indices = k[:orders].view(-1, 1) + k.view(1, -1)
assert indices.shape == (orders, terms)

seqs = (2 * lvalues[None, :, :] - lfactorials[None, None, :] - lgammas.gather(1, indices)[:, None, :]).logsumexp(-1)
assert seqs.shape == (orders, vshape.numel())

i1s = lvalues[..., :orders].T + seqs
assert i1s.shape == (orders, vshape.numel())
return i1s.view(-1, *vshape)
36 changes: 24 additions & 12 deletions tests/distributions/conftest.py
@@ -187,7 +187,7 @@ def __init__(self, rate, *, validate_args=None):
],
# This hack seems to be the best option right now, as 'scale' is not handled well by get_scipy_batch_logpdf
scipy_arg_fn=lambda loc, covariance_matrix=None:
((), {"mean": np.array(loc), "cov": np.array([[1.0, 0.5], [0.5, 1.0]])}),
((), {"mean": np.array(loc), "cov": np.array([[1.0, 0.5], [0.5, 1.0]])}),
prec=0.01,
min_samples=500000),
Fixture(pyro_dist=dist.LowRankMultivariateNormal,
Expand All @@ -197,7 +197,7 @@ def __init__(self, rate, *, validate_args=None):
'test_data': [[2.0, 1.0], [9.0, 3.4]]},
],
scipy_arg_fn=lambda loc, cov_diag=None, cov_factor=None:
((), {"mean": np.array(loc), "cov": np.array([[1.5, 0.5], [0.5, 0.75]])}),
((), {"mean": np.array(loc), "cov": np.array([[1.5, 0.5], [0.5, 0.75]])}),
prec=0.01,
min_samples=500000),
Fixture(pyro_dist=FoldedNormal,
@@ -280,12 +280,12 @@ def __init__(self, rate, *, validate_args=None):
Fixture(pyro_dist=dist.LKJ,
examples=[
{'dim': 3, 'concentration': 1., 'test_data':
[[[1.0000, -0.8221, 0.7655], [-0.8221, 1.0000, -0.5293], [0.7655, -0.5293, 1.0000]],
[[1.0000, -0.5345, -0.5459], [-0.5345, 1.0000, -0.0333], [-0.5459, -0.0333, 1.0000]],
[[1.0000, -0.3758, -0.2409], [-0.3758, 1.0000, 0.4653], [-0.2409, 0.4653, 1.0000]],
[[1.0000, -0.8800, -0.9493], [-0.8800, 1.0000, 0.9088], [-0.9493, 0.9088, 1.0000]],
[[1.0000, 0.2284, -0.1283], [0.2284, 1.0000, 0.0146], [-0.1283, 0.0146, 1.0000]]]},
]),
Fixture(pyro_dist=dist.LKJCholesky,
examples=[
{
@@ -305,19 +305,31 @@
examples=[
{'stability': [1.5], 'skew': 0.1, 'test_data': [-10.]},
{'stability': [1.5], 'skew': 0.1, 'scale': 2.0, 'loc': -2.0, 'test_data': [10.]},
]),
Fixture(pyro_dist=dist.MultivariateStudentT,
examples=[
{'df': 1.5, 'loc': [0.2, 0.3], 'scale_tril': [[0.8, 0.0], [1.3, 0.4]],
'test_data': [-3., 2]},
]),
Fixture(pyro_dist=dist.ProjectedNormal,
examples=[
{'concentration': [0., 0.], 'test_data': [1., 0.]},
{'concentration': [2., 3.], 'test_data': [0., 1.]},
{'concentration': [0., 0., 0.], 'test_data': [1., 0., 0.]},
{'concentration': [-1., 2., 3.], 'test_data': [0., 0., 1.]},
]),
Fixture(pyro_dist=dist.SineBivariateVonMises,
examples=[
{'phi_loc': [0.], 'psi_loc': [0.], 'phi_concentration': [5.], 'psi_concentration': [6.],
'correlation': [2.], 'test_data': [[0., 0.]]},
{'phi_loc': [3.003], 'psi_loc': [-1.343], 'phi_concentration': [5.], 'psi_concentration': [6.],
'correlation': [2.], 'test_data': [[0., 1.]]},
{'phi_loc': [-math.pi / 3], 'psi_loc': -1., 'phi_concentration': .5, 'psi_concentration': 10.,
'correlation': .9, 'test_data': [[1., 0.555]]},
{'phi_loc': [math.pi - .2, 1.], 'psi_loc': [0., 1.],
'phi_concentration': [5., 5.], 'psi_concentration': [7., .5],
'weighted_correlation': [.5, .1], 'test_data': [[[1., -3.], [1., 59.]]]},
]),
Fixture(pyro_dist=dist.SoftLaplace,
examples=[
{'loc': [2.0], 'scale': [4.0],
Expand All @@ -328,7 +340,7 @@ def __init__(self, rate, *, validate_args=None):
'test_data': [[[2.0]]]},
{'loc': [2.0, 50.0], 'scale': [4.0, 100.0],
'test_data': [[2.0, 50.0], [2.0, 50.0]]},
]),
]

discrete_dists = [
2 changes: 1 addition & 1 deletion tests/distributions/test_distributions.py
@@ -152,7 +152,7 @@ def test_gof(continuous_dist):

def test_mean(continuous_dist):
Dist = continuous_dist.pyro_dist
if Dist.__name__ in ["Cauchy", "HalfCauchy", "SineBivariateVonMises", "VonMises", "ProjectedNormal"]:
pytest.xfail(reason="Euclidean mean is not defined")
for i in range(continuous_dist.get_num_test_data()):
d = Dist(**continuous_dist.get_dist_params(i))