Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sine Skewed Toroidal distribution #2826

Merged
merged 42 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
585beb9
Bump to version 1.5.2 (#2755)
fritzo Feb 1, 2021
260d05b
Merge branch 'dev'
fritzo Mar 4, 2021
09bcbc0
Added sine skewed distribution and tests.
OlaRonning Apr 27, 2021
b7ae4d1
Added repr.
OlaRonning Apr 27, 2021
85e352c
Fixed shape tests and minor fixes to docstring.
OlaRonning Apr 27, 2021
7ee6643
Fixed lint.
OlaRonning Apr 27, 2021
789f550
Updated docstring with uniform prior.
OlaRonning Apr 28, 2021
be91d9a
Fixed skewness shape assertion.
OlaRonning Apr 28, 2021
2a200d3
ensure `SineSkewed` is on the torus.
OlaRonning Apr 30, 2021
e7a1a74
Reverted `infer_shapes` in `sine_skewed` and `# isort: split` in `dis…
OlaRonning May 1, 2021
1b2e1ca
Merge branch 'feature/ss_dist' of github.com:aleatory-science/pyro in…
OlaRonning May 1, 2021
e802aa4
Sketched `SineSkewed.expand`
OlaRonning May 1, 2021
d1801b9
Fixed `SineSkewed.log_prob`.
OlaRonning May 2, 2021
3b44ebe
Added pep exception to `distributions.__init__`
OlaRonning May 2, 2021
84ac72e
Fixed `SineSkewed` on cuda.
OlaRonning May 3, 2021
c92ef62
Restricted `event_dim=2`
OlaRonning May 5, 2021
906211a
Fixed doc_string and updated tests.
OlaRonning May 5, 2021
e237531
fixed linting
OlaRonning May 5, 2021
5e4020a
fixed arg_constraints
OlaRonning May 5, 2021
bd93a2b
cleaned __repr__
OlaRonning May 5, 2021
4646ce6
Fixed comments.
OlaRonning May 5, 2021
ffe50e6
Fixed `n_dim=1` and updated `test_sine_skewed`; missing updated fixtu…
OlaRonning May 7, 2021
d935f74
Added fixture.
OlaRonning May 7, 2021
9ca8a45
Fixed tests.
OlaRonning May 9, 2021
e98bb2b
Merge branch 'feature/ss_fix_dim' into feature/ss_dist
OlaRonning May 9, 2021
51f6365
Merge branch 'dev' of github.com:pyro-ppl/pyro into feature/ss_dist
OlaRonning May 9, 2021
dd461fd
removed deprecated add_stylesheet
OlaRonning May 9, 2021
5cfee34
reverted to `add_stylesheet`
OlaRonning May 9, 2021
6d79eb3
Removed raise from sine_skewed.py
OlaRonning May 10, 2021
ff26ce9
Added equation references.
OlaRonning May 10, 2021
4427e37
Fixed sampling bound in `SineSkewed`.
OlaRonning May 10, 2021
a7bf5fe
Fixed prior on `SineSkewed` to avoid `AffineTransform`.
OlaRonning May 10, 2021
8d5d684
Merge remote-tracking branch 'origin/feature/ss_dist' into feature/ss…
OlaRonning May 10, 2021
c9ede43
Merged origin.
OlaRonning May 10, 2021
b1e1408
Merge branch 'master' of github.com:pyro-ppl/pyro into feature/ss_dist
OlaRonning Jun 2, 2021
fac5864
removed import all pyro distributions
OlaRonning Jun 2, 2021
21af6e5
Merged upstream and fixed docstring for `SineSkewed`.
OlaRonning Jun 2, 2021
9fc36ff
Fixed tests for SineSkewed with wrapper class.
OlaRonning Jun 2, 2021
9b78a7a
Removed unused import from conftest.py
OlaRonning Jun 2, 2021
0b5af73
Removed xfail int test_cuda for `SineSkewed`
OlaRonning Jun 2, 2021
4d44aa0
Fixed DocString example.
OlaRonning Jun 4, 2021
3143b45
Fixed psi_phi name in docstring.
OlaRonning Jun 7, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,13 @@ SineBivariateVonMises
:undoc-members:
:show-inheritance:

SineSkewed
----------
.. autoclass:: pyro.distributions.SineSkewed
:members:
:undoc-members:
:show-inheritance:

SoftLaplace
-------------
.. autoclass:: pyro.distributions.SoftLaplace
Expand Down
2 changes: 2 additions & 0 deletions pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
RelaxedOneHotCategoricalStraightThrough,
)
from pyro.distributions.sine_bivariate_von_mises import SineBivariateVonMises
from pyro.distributions.sine_skewed import SineSkewed
from pyro.distributions.softlaplace import SoftLaplace
from pyro.distributions.spanning_tree import SpanningTree
from pyro.distributions.stable import Stable
Expand Down Expand Up @@ -132,6 +133,7 @@
"RelaxedBernoulliStraightThrough",
"RelaxedOneHotCategoricalStraightThrough",
"SineBivariateVonMises",
"SineSkewed",
"SoftLaplace",
"SpanningTree",
"Stable",
Expand Down
136 changes: 136 additions & 0 deletions pyro/distributions/sine_skewed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import warnings
from math import pi

import torch
from torch import broadcast_shapes
from torch.distributions import Uniform

from pyro.distributions import constraints

from .torch_distribution import TorchDistribution


class SineSkewed(TorchDistribution):
"""Sine Skewing [1] is a procedure for producing a distribution that breaks pointwise symmetry on a torus
distribution. The new distribution is called the Sine Skewed X distribution, where X is the name of the (symmetric)
base distribution.

Torus distributions are distributions with support on products of circles
(i.e., ⨂^d S^1 where S^1=[-pi,pi) ). So, a 0-torus is a point, the 1-torus is a circle,
and the 2-torus is commonly associated with the donut shape.

The Sine Skewed X distribution is parameterized by a weight parameter for each dimension of the event of X.
For example with a von Mises distribution over a circle (1-torus), the Sine Skewed von Mises Distribution has one
skew parameter. The skewness parameters can be inferred using :class:`~pyro.infer.HMC` or :class:`~pyro.infer.NUTS`.
For example, the following will produce a uniform prior over skewness for the 2-torus,::

def model(obs):
# Sine priors
phi_loc = pyro.sample('phi_loc', VonMises(pi, 2.))
psi_loc = pyro.sample('psi_loc', VonMises(-pi / 2, 2.))
phi_conc = pyro.sample('phi_conc', Beta(halpha_phi, beta_prec_phi - halpha_phi))
psi_conc = pyro.sample('psi_conc', Beta(halpha_psi, beta_prec_psi - halpha_psi))
corr_scale = pyro.sample('corr_scale', Beta(2., 5.))

# SS prior
skew_phi = pyro.sample('skew_phi', Uniform(-1., 1.))
psi_bound = 1 - skew_phi.abs()
skew_psi = pyro.sample('skew_psi', Uniform(-1., 1.))
skewness = torch.stack((skew_phi, psi_bound * skew_psi), dim=-1)
assert skewness.shape == (num_mix_comp, 2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have a constraint+transform for this (in a follow-up PR). I believe we can use signed stick-breaking transform here. This way users can define distributions over skewness (or just simply use pyro.param with correct constraint). Without that, it is a bit cumbersome for users to define correct skewness over general d-torus.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With that, we can have correct constraints in the distribution definition. :D

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be neat; I'll add it to the backlog.


with pyro.plate('obs_plate'):
sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc,
phi_concentration=1000 * phi_conc,
psi_concentration=1000 * psi_conc,
weighted_correlation=corr_scale)
return pyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)

To ensure the skewing does not alter the normalization constant of the (Sine Bivariate von Mises) base
distribution, the skewness parameters are constrained. The constraint requires the sum of the absolute values of
skewness to be less than or equal to one.
So for the above snippet it must hold that::

skew_phi.abs()+skew_psi.abs() <= 1

We handle this in the prior by computing psi_bound and using it to scale skew_psi.
We do **not** use psi_bound as::

skew_psi = pyro.sample('skew_psi', Uniform(-psi_bound, psi_bound))

as it would make the support for the Uniform distribution dynamic.

In the context of :class:`~pyro.infer.SVI`, this distribution can freely be used as a likelihood, but using it for
latent variables will lead to slow inference on tori of dimension 2 and higher. This is because the base_dist
cannot be reparameterized.

.. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d,).

.. note:: For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event
must be less than or equal to one. See eq. 2.1 in [1].

**References:**
1. Sine-skewed toroidal distributions and their application in protein bioinformatics
Ameijeiras-Alonso, J., Ley, C. (2019)

:param torch.distributions.Distribution base_dist: base density on a d-dimensional torus. Supported base
distributions include: 1D :class:`~pyro.distributions.VonMises`,
:class:`~pyro.distributions.SineBivariateVonMises`, 1D :class:`~pyro.distributions.ProjectedNormal`, and
:class:`~pyro.distributions.Uniform` (-pi, pi).
:param torch.tensor skewness: skewness of the distribution.
"""
arg_constraints = {'skewness': constraints.independent(constraints.interval(-1., 1.), 1)}
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved

# NOTE(review): declared as unconstrained reals over one event dimension, even though
# `sample` wraps its output into [-pi, pi) -- confirm whether a tighter constraint is wanted.
support = constraints.independent(constraints.real, 1)

def __init__(self, base_dist: TorchDistribution, skewness, validate_args=None):
    """Build a sine-skewed version of ``base_dist``.

    :param base_dist: base distribution; its ``event_shape`` must equal ``skewness.shape[-1:]``.
    :param skewness: tensor of skewing weights, one per event dimension of ``base_dist``.
    :param validate_args: forwarded to the parent constructor.
    :raises AssertionError: if the event shape of ``base_dist`` and the trailing dimension of
        ``skewness`` disagree.
    :raises ValueError: if validation is enabled and ``base_dist.mean`` and ``skewness`` live on
        different devices.

    A ``UserWarning`` is emitted when the absolute skewness weights of any event sum to more than one.
    """
    event_shape = skewness.shape[-1:]
    assert base_dist.event_shape == event_shape, \
        'Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`.'

    total_weight = skewness.abs().sum(-1)
    if (total_weight > 1.).any():
        warnings.warn("Total skewness weight shouldn't exceed one.", UserWarning)

    # Both the base distribution and the skewness are brought to the common broadcast batch shape.
    batch_shape = broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1])
    self.skewness = skewness.broadcast_to(batch_shape + event_shape)
    self.base_dist = base_dist.expand(batch_shape)
    super().__init__(batch_shape, event_shape, validate_args=validate_args)

    # `base_dist.mean` is only touched when validation is on.
    if self._validate_args:
        if base_dist.mean.device != skewness.device:
            raise ValueError(f"base_density: {base_dist.__class__.__name__} and SineSkewed "
                             f"must be on same device.")

def __repr__(self):
    """Return ``ClassName(base_density: <base_dist>, <param>: <value or size>, ...)``.

    Each parameter named in ``arg_constraints`` is shown by value when it holds a single
    element and by its size otherwise.
    """
    fields = []
    for name in self.arg_constraints:
        param = getattr(self, name)
        shown = param if param.numel() == 1 else param.size()
        fields.append(f'{name}: {shown}')
    joined = ', '.join(fields)
    return f'{self.__class__.__name__}(base_density: {self.base_dist!s}, {joined})'

def sample(self, sample_shape=torch.Size()):
    """Draw samples by reflecting base-distribution draws about the base mean.

    A draw from ``base_dist`` is kept when a uniform variate falls at or below
    ``0.5 + 0.5 * sum(skewness * sin(draw - mean))``; otherwise it is replaced by its
    reflection ``2 * mean - draw``. The result is wrapped into ``[-pi, pi)``.

    :param sample_shape: leading shape of the requested sample.
    :return: tensor of samples.
    """
    base = self.base_dist
    proposal = base.sample(sample_shape)
    unit_uniform = Uniform(0., self.skewness.new_ones(()))
    u = unit_uniform.sample(sample_shape + self.batch_shape)

    # Section 2.3 step 3 in [1]
    offset = (proposal - base.mean) % (2 * pi)
    threshold = .5 + .5 * (self.skewness * torch.sin(offset)).sum(-1)
    keep = (u <= threshold)[..., None]
    reflected = -proposal + 2 * base.mean
    chosen = torch.where(keep, proposal, reflected)
    return (chosen + pi) % (2 * pi) - pi

def log_prob(self, value):
    """Return the log-density of ``value``.

    The result is the base distribution's log-density plus
    ``log1p(sum(skewness * sin(value - base_dist.mean)))``, summed over the last dimension.

    :param value: tensor of points at which to evaluate the density.
    :return: tensor of log-densities.
    """
    if self._validate_args:
        self._validate_sample(value)

    # Eq. 2.1 in [1]
    offset = (value - self.base_dist.mean) % (2 * pi)
    skew_term = (self.skewness * torch.sin(offset)).sum(-1)
    log_skew = torch.log1p(skew_term)
    return self.base_dist.log_prob(value) + log_skew

def expand(self, batch_shape, _instance=None):
    """Return a copy of this distribution expanded to ``batch_shape``.

    :param batch_shape: the target batch shape.
    :param _instance: optional pre-allocated instance, passed through to ``_get_checked_instance``.
    :return: a ``SineSkewed`` whose base distribution and skewness are expanded to ``batch_shape``.
    """
    target_shape = torch.Size(batch_shape)
    expanded = self._get_checked_instance(SineSkewed, _instance)
    expanded.base_dist = self.base_dist.expand(target_shape)
    # The trailing -1 keeps the event dimension of the skewness unchanged.
    expanded.skewness = self.skewness.expand(target_shape + (-1,))
    # The parent constructor is run without validation; the original flag is restored afterwards.
    super(SineSkewed, expanded).__init__(target_shape, self.event_shape, validate_args=False)
    expanded._validate_args = self._validate_args
    return expanded
31 changes: 31 additions & 0 deletions tests/distributions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import math
from math import pi

import numpy as np
import pytest
Expand Down Expand Up @@ -38,6 +39,18 @@ def __init__(self, rate, *, validate_args=None):
super().__init__(rate, is_sparse=True, validate_args=validate_args)


class SineSkewedUniform(dist.SineSkewed):
    """Test wrapper: a ``SineSkewed`` over a ``Uniform(lower, upper)`` base built from plain parameters."""

    def __init__(self, lower, upper, skewness, *args, **kwargs):
        uniform = dist.Uniform(lower, upper)
        # All dimensions of `lower` are reinterpreted as event dimensions.
        torus_base = uniform.to_event(lower.ndim)
        super().__init__(torus_base, skewness, *args, **kwargs)


class SineSkewedVonMises(dist.SineSkewed):
    """Test wrapper: a ``SineSkewed`` over a ``VonMises(von_loc, von_conc)`` base built from plain parameters."""

    def __init__(self, von_loc, von_conc, skewness):
        von_mises = dist.VonMises(von_loc, von_conc)
        # All dimensions of `von_loc` are reinterpreted as event dimensions.
        torus_base = von_mises.to_event(von_loc.ndim)
        super().__init__(torus_base, skewness)


continuous_dists = [
Fixture(pyro_dist=dist.Uniform,
scipy_dist=sp.uniform,
Expand Down Expand Up @@ -341,6 +354,24 @@ def __init__(self, rate, *, validate_args=None):
{'loc': [2.0, 50.0], 'scale': [4.0, 100.0],
'test_data': [[2.0, 50.0], [2.0, 50.0]]},
]),
Fixture(pyro_dist=SineSkewedUniform,
examples=[
{'lower': [-pi, -pi],
'upper':[pi, pi],
'skewness': [-pi / 4, .1],
'test_data': [pi / 2, -2 * pi / 3]}
]),
Fixture(pyro_dist=SineSkewedVonMises,
examples=[
{'von_loc': [0.],
'von_conc': [1.],
'skewness': [.342355],
'test_data': [.1]},
{'von_loc': [0., -1.234],
'von_conc': [1., 10.],
'skewness': [[.342355, -.0001], [.91, 0.09]],
'test_data': [[.1, -3.2], [-2., 0.]]}
]),
Fixture(pyro_dist=dist.AsymmetricLaplace,
examples=[
{'loc': [1.0], 'left_scale': [1.0], 'right_scale': [4.0],
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_support_shape(dist):


def test_infer_shapes(dist):
if "LKJ" in dist.pyro_dist.__name__:
if "LKJ" in dist.pyro_dist.__name__ or "SineSkewed" in dist.pyro_dist.__name__:
pytest.xfail(reason="cannot statically compute shape")
for idx in range(dist.get_num_test_data()):
dist_params = dist.get_dist_params(idx)
Expand Down
79 changes: 79 additions & 0 deletions tests/distributions/test_sine_skewed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from math import pi

import pytest
import torch

import pyro
from pyro.distributions import Normal, SineSkewed, Uniform, VonMises, constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from tests.common import assert_equal

# (distribution class, positional constructor arguments) pairs; the tests below expand each
# argument to the requested shape and use the resulting distribution as the SineSkewed base.
BASE_DISTS = [(Uniform, [-pi, pi]), (VonMises, (0., 1.))]


def _skewness(event_shape):
    """Draw a random skewness tensor whose absolute values sum to at most one.

    Components are drawn one after another, each uniformly within the budget left over by the
    components drawn before it. If the budget runs out before every component has been drawn,
    the whole draw is restarted from zeros.

    :param event_shape: shape of the result; must provide ``numel()`` (e.g. ``torch.Size``).
    :return: tensor of shape ``event_shape``.
    """
    n_components = event_shape.numel()
    while True:
        # Start every attempt from zeros: retrying on an exhausted budget without resetting
        # would hit the same exhausted budget again and never terminate.
        skewness = torch.zeros(n_components)
        for i in range(n_components):
            budget = 1. - skewness.abs().sum(-1)
            if torch.any(budget < 1e-15):
                break  # nothing left for the remaining components; retry
            skewness[i] = Uniform(-budget, budget).sample()
        else:
            # `reshape` covers both the scalar-event and the general case.
            return skewness.reshape(event_shape)


@pytest.mark.parametrize('expand_shape',
                         [(1,), (2,), (4,), (1, 1), (1, 2), (10, 10), (1, 3, 1), (10, 1, 5), (1, 1, 1), (3, 2, 3)])
@pytest.mark.parametrize('dist', BASE_DISTS)
def test_ss_multidim_log_prob(expand_shape, dist):
    """Check that log_prob and sample of SineSkewed have the shapes implied by the base distribution."""
    dist_cls, dist_params = dist
    expanded_params = [torch.tensor(param).expand(expand_shape) for param in dist_params]
    base_dist = dist_cls(*expanded_params).to_event(1)

    # Evaluation points: base samples shifted by a single small Normal draw.
    loc = base_dist.sample((10,))
    loc = loc + Normal(0., 1e-3).sample()

    base_prob = base_dist.log_prob(loc)
    skewness = _skewness(base_dist.event_shape)

    ss = SineSkewed(base_dist, skewness)
    ss_prob = ss.log_prob(loc)
    assert_equal(base_prob.shape, ss_prob.shape)
    assert_equal(ss.sample().shape, torch.Size(expand_shape))


@pytest.mark.parametrize('dist', BASE_DISTS)
@pytest.mark.parametrize('dim', [1, 2])
def test_ss_mle(dim, dist):
    """Fit the skewness by maximum likelihood with SVI and compare against the value used to simulate data."""
    dist_cls, dist_params = dist
    base_dist = dist_cls(*(torch.tensor(param).expand((dim,)) for param in dist_params)).to_event(1)

    skewness_tar = _skewness(base_dist.event_shape)
    data = SineSkewed(base_dist, skewness_tar).sample((1000,))

    def model(data, batch_shape):
        # One learnable parameter per event dimension, each initialised at 0.5.
        skewness = torch.stack(
            [pyro.param(f'skew{i}', .5 * torch.ones(batch_shape), constraint=constraints.interval(-1, 1))
             for i in range(dim)],
            dim=-1)
        with pyro.plate("data", data.size(0)):
            pyro.sample('obs', SineSkewed(base_dist, skewness), obs=data)

    def guide(data, batch_shape):
        # Empty guide: the model has no latent sample sites, only parameters.
        pass

    pyro.clear_param_store()
    optimizer = Adam({"lr": .1})
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    n_steps = 80
    for _ in range(n_steps):
        svi.step(data, base_dist.batch_shape)

    fitted = [value for name, value in pyro.get_param_store().items() if 'skew' in name]
    act_skewness = torch.stack(fitted, dim=-1)
    assert_equal(act_skewness, skewness_tar, 1e-1)