-
-
Notifications
You must be signed in to change notification settings - Fork 985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sine Skewed Toridial distribution #2826
Merged
Merged
Changes from all commits
Commits
Show all changes
42 commits
Select commit
Hold shift + click to select a range
585beb9
Bump to version 1.5.2 (#2755)
fritzo 260d05b
Merge branch 'dev'
fritzo 09bcbc0
Added sine skewed distribution and tests.
OlaRonning b7ae4d1
Added repr.
OlaRonning 85e352c
Fixed shape tests and minor fixes to docstring.
OlaRonning 7ee6643
Fixed lint.
OlaRonning 789f550
Updated docstring with uniform prior.
OlaRonning be91d9a
Fixed skewness shape assertion.
OlaRonning 2a200d3
ensure `SineSkewed` is on the torus.
OlaRonning e7a1a74
Reverted `infer_shapes` in `sine_skewed` and `# isort: split` in `dis…
OlaRonning 1b2e1ca
Merge branch 'feature/ss_dist' of github.com:aleatory-science/pyro in…
OlaRonning e802aa4
Sketched `SineSkewed.expand`
OlaRonning d1801b9
Fixed `SineSkewed.log_prob`.
OlaRonning 3b44ebe
Added pep exception to `distributions.__init__`
OlaRonning 84ac72e
Fixed `SineSkewed` on cuda.
OlaRonning c92ef62
Restricted `event_dim=2`
OlaRonning 906211a
Fixed doc_string and updated tests.
OlaRonning e237531
fixed linting
OlaRonning 5e4020a
fixed arg_constraints
OlaRonning bd93a2b
cleaned __repr__
OlaRonning 4646ce6
Fixed comments.
OlaRonning ffe50e6
Fixed `n_dim=1` and updated `test_sine_skewed`; missing updated fixtu…
OlaRonning d935f74
Added fixture.
OlaRonning 9ca8a45
Fixed tests.
OlaRonning e98bb2b
Merge branch 'feature/ss_fix_dim' into feature/ss_dist
OlaRonning 51f6365
Merge branch 'dev' of github.com:pyro-ppl/pyro into feature/ss_dist
OlaRonning dd461fd
removed deprecated add_stylesheet
OlaRonning 5cfee34
reverted to `add_stylesheet`
OlaRonning 6d79eb3
Removed raise from sine_skewed.py
OlaRonning ff26ce9
Added equation references.
OlaRonning 4427e37
Fixed sampling bound in `SineSkewed`.
OlaRonning a7bf5fe
Fixed prior on `SineSkewed` to avoid `AffineTransform`.
OlaRonning 8d5d684
Merge remote-tracking branch 'origin/feature/ss_dist' into feature/ss…
OlaRonning c9ede43
Merged origin.
OlaRonning b1e1408
Merge branch 'master' of github.com:pyro-ppl/pyro into feature/ss_dist
OlaRonning fac5864
removed import all pyro distributions
OlaRonning 21af6e5
Merged upstream and fixed docstring for `SineSkewed`.
OlaRonning 9fc36ff
Fixed tests for SineSkewed with wrapper class.
OlaRonning 9b78a7a
Removed unused import from conftest.py
OlaRonning 0b5af73
Removed xfail int test_cuda for `SineSkewed`
OlaRonning 4d44aa0
Fixed DocString example.
OlaRonning 3143b45
Fixed psi_phi name in docstring.
OlaRonning File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import warnings | ||
from math import pi | ||
|
||
import torch | ||
from torch import broadcast_shapes | ||
from torch.distributions import Uniform | ||
|
||
from pyro.distributions import constraints | ||
|
||
from .torch_distribution import TorchDistribution | ||
|
||
|
||
class SineSkewed(TorchDistribution): | ||
"""Sine Skewing [1] is a procedure for producing a distribution that breaks pointwise symmetry on a torus | ||
distribution. The new distribution is called the Sine Skewed X distribution, where X is the name of the (symmetric) | ||
base distribution. | ||
|
||
Torus distributions are distributions with support on products of circles | ||
(i.e., ⨂^d S^1 where S^1=[-pi,pi) ). So, a 0-torus is a point, the 1-torus is a circle, | ||
and the 2-torus is commonly associated with the donut shape. | ||
|
||
The Sine Skewed X distribution is parameterized by a weight parameter for each dimension of the event of X. | ||
For example with a von Mises distribution over a circle (1-torus), the Sine Skewed von Mises Distribution has one | ||
skew parameter. The skewness parameters can be inferred using :class:`~pyro.infer.HMC` or :class:`~pyro.infer.NUTS`. | ||
For example, the following will produce a uniform prior over skewness for the 2-torus,:: | ||
|
||
def model(obs): | ||
# Sine priors | ||
phi_loc = pyro.sample('phi_loc', VonMises(pi, 2.)) | ||
psi_loc = pyro.sample('psi_loc', VonMises(-pi / 2, 2.)) | ||
phi_conc = pyro.sample('phi_conc', Beta(halpha_phi, beta_prec_phi - halpha_phi)) | ||
psi_conc = pyro.sample('psi_conc', Beta(halpha_psi, beta_prec_psi - halpha_psi)) | ||
corr_scale = pyro.sample('corr_scale', Beta(2., 5.)) | ||
|
||
# SS prior | ||
skew_phi = pyro.sample('skew_phi', Uniform(-1., 1.)) | ||
psi_bound = 1 - skew_phi.abs() | ||
skew_psi = pyro.sample('skew_psi', Uniform(-1., 1.)) | ||
skewness = torch.stack((skew_phi, psi_bound * skew_psi), dim=-1) | ||
assert skewness.shape == (num_mix_comp, 2) | ||
|
||
with pyro.plate('obs_plate'): | ||
sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc, | ||
phi_concentration=1000 * phi_conc, | ||
psi_concentration=1000 * psi_conc, | ||
weighted_correlation=corr_scale) | ||
return pyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs) | ||
|
||
To ensure the skewing does not alter the normalization constant of the (Sine Bivaraite von Mises) base | ||
distribution the skewness parameters are constraint. The constraint requires the sum of the absolute values of | ||
skewness to be less than or equal to one. | ||
So for the above snippet it must hold that:: | ||
|
||
skew_phi.abs()+skew_psi.abs() <= 1 | ||
|
||
We handle this in the prior by computing psi_bound and use it to scale skew_psi. | ||
We do **not** use psi_bound as:: | ||
|
||
skew_psi = pyro.sample('skew_psi', Uniform(-psi_bound, psi_bound)) | ||
|
||
as it would make the support for the Uniform distribution dynamic. | ||
|
||
In the context of :class:`~pyro.infer.SVI`, this distribution can freely be used as a likelihood, but use as | ||
latent variables it will lead to slow inference for 2 and higher dim toruses. This is because the base_dist | ||
cannot be reparameterized. | ||
|
||
.. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d,). | ||
|
||
.. note:: For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event | ||
must be less than or equal to one. See eq. 2.1 in [1]. | ||
|
||
** References: ** | ||
1. Sine-skewed toroidal distributions and their application in protein bioinformatics | ||
Ameijeiras-Alonso, J., Ley, C. (2019) | ||
|
||
:param torch.distributions.Distribution base_dist: base density on a d-dimensional torus. Supported base | ||
distributions include: 1D :class:`~pyro.distributions.VonMises`, | ||
:class:`~pyro.distributions.SineBivariateVonMises`, 1D :class:`~pyro.distributions.ProjectedNormal`, and | ||
:class:`~pyro.distributions.Uniform` (-pi, pi). | ||
:param torch.tensor skewness: skewness of the distribution. | ||
""" | ||
arg_constraints = {'skewness': constraints.independent(constraints.interval(-1., 1.), 1)} | ||
OlaRonning marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
support = constraints.independent(constraints.real, 1) | ||
|
||
def __init__(self, base_dist: TorchDistribution, skewness, validate_args=None): | ||
assert base_dist.event_shape == skewness.shape[-1:], \ | ||
'Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`.' | ||
|
||
OlaRonning marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (skewness.abs().sum(-1) > 1.).any(): | ||
warnings.warn("Total skewness weight shouldn't exceed one.", UserWarning) | ||
|
||
batch_shape = broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1]) | ||
OlaRonning marked this conversation as resolved.
Show resolved
Hide resolved
|
||
event_shape = skewness.shape[-1:] | ||
self.skewness = skewness.broadcast_to(batch_shape + event_shape) | ||
self.base_dist = base_dist.expand(batch_shape) | ||
super().__init__(batch_shape, event_shape, validate_args=validate_args) | ||
|
||
if self._validate_args and base_dist.mean.device != skewness.device: | ||
raise ValueError(f"base_density: {base_dist.__class__.__name__} and SineSkewed " | ||
f"must be on same device.") | ||
|
||
def __repr__(self): | ||
args_string = ', '.join(['{}: {}'.format(p, getattr(self, p) | ||
if getattr(self, p).numel() == 1 | ||
else getattr(self, p).size()) for p in self.arg_constraints.keys()]) | ||
return self.__class__.__name__ + '(' + f'base_density: {str(self.base_dist)}, ' + args_string + ')' | ||
|
||
def sample(self, sample_shape=torch.Size()): | ||
bd = self.base_dist | ||
ys = bd.sample(sample_shape) | ||
u = Uniform(0., self.skewness.new_ones(())).sample(sample_shape + self.batch_shape) | ||
|
||
# Section 2.3 step 3 in [1] | ||
mask = u <= .5 + .5 * (self.skewness * torch.sin((ys - bd.mean) % (2 * pi))).sum(-1) | ||
mask = mask[..., None] | ||
samples = (torch.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi | ||
OlaRonning marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return samples | ||
|
||
def log_prob(self, value): | ||
if self._validate_args: | ||
self._validate_sample(value) | ||
|
||
# Eq. 2.1 in [1] | ||
skew_prob = torch.log1p((self.skewness * torch.sin((value - self.base_dist.mean) % (2 * pi))).sum(-1)) | ||
return self.base_dist.log_prob(value) + skew_prob | ||
|
||
def expand(self, batch_shape, _instance=None): | ||
batch_shape = torch.Size(batch_shape) | ||
new = self._get_checked_instance(SineSkewed, _instance) | ||
base_dist = self.base_dist.expand(batch_shape) | ||
new.base_dist = base_dist | ||
new.skewness = self.skewness.expand(batch_shape + (-1,)) | ||
super(SineSkewed, new).__init__(batch_shape, self.event_shape, validate_args=False) | ||
new._validate_args = self._validate_args | ||
return new |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from math import pi | ||
|
||
import pytest | ||
import torch | ||
|
||
import pyro | ||
from pyro.distributions import Normal, SineSkewed, Uniform, VonMises, constraints | ||
from pyro.infer import SVI, Trace_ELBO | ||
from pyro.optim import Adam | ||
from tests.common import assert_equal | ||
|
||
BASE_DISTS = [(Uniform, [-pi, pi]), (VonMises, (0., 1.))] | ||
|
||
|
||
def _skewness(event_shape): | ||
skewness = torch.zeros(event_shape.numel()) | ||
done = False | ||
while not done: | ||
for i in range(event_shape.numel()): | ||
max_ = 1. - skewness.abs().sum(-1) | ||
if torch.any(max_ < 1e-15): | ||
break | ||
skewness[i] = Uniform(-max_, max_).sample() | ||
done = not torch.any(max_ < 1e-15) | ||
|
||
if event_shape == tuple(): | ||
skewness = skewness.reshape(event_shape) | ||
else: | ||
skewness = skewness.view(event_shape) | ||
return skewness | ||
|
||
|
||
@pytest.mark.parametrize('expand_shape', | ||
[(1,), (2,), (4,), (1, 1), (1, 2), (10, 10), (1, 3, 1), (10, 1, 5), (1, 1, 1), (3, 2, 3)]) | ||
@pytest.mark.parametrize('dist', BASE_DISTS) | ||
def test_ss_multidim_log_prob(expand_shape, dist): | ||
base_dist = dist[0](*(torch.tensor(param).expand(expand_shape) for param in dist[1])).to_event(1) | ||
|
||
loc = base_dist.sample((10,)) + Normal(0., 1e-3).sample() | ||
|
||
base_prob = base_dist.log_prob(loc) | ||
skewness = _skewness(base_dist.event_shape) | ||
|
||
ss = SineSkewed(base_dist, skewness) | ||
assert_equal(base_prob.shape, ss.log_prob(loc).shape) | ||
assert_equal(ss.sample().shape, torch.Size(expand_shape)) | ||
|
||
|
||
@pytest.mark.parametrize('dist', BASE_DISTS) | ||
@pytest.mark.parametrize('dim', [1, 2]) | ||
def test_ss_mle(dim, dist): | ||
base_dist = dist[0](*(torch.tensor(param).expand((dim,)) for param in dist[1])).to_event(1) | ||
|
||
skewness_tar = _skewness(base_dist.event_shape) | ||
data = SineSkewed(base_dist, skewness_tar).sample((1000,)) | ||
|
||
def model(data, batch_shape): | ||
skews = [] | ||
for i in range(dim): | ||
skews.append(pyro.param(f'skew{i}', .5 * torch.ones(batch_shape), constraint=constraints.interval(-1, 1))) | ||
|
||
skewness = torch.stack(skews, dim=-1) | ||
with pyro.plate("data", data.size(-len(data.size()))): | ||
pyro.sample('obs', SineSkewed(base_dist, skewness), obs=data) | ||
|
||
def guide(data, batch_shape): | ||
pass | ||
|
||
pyro.clear_param_store() | ||
adam = Adam({"lr": .1}) | ||
svi = SVI(model, guide, adam, loss=Trace_ELBO()) | ||
|
||
losses = [] | ||
steps = 80 | ||
for step in range(steps): | ||
losses.append(svi.step(data, base_dist.batch_shape)) | ||
|
||
act_skewness = torch.stack([v for k, v in pyro.get_param_store().items() if 'skew' in k], dim=-1) | ||
assert_equal(act_skewness, skewness_tar, 1e-1) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to have a constraint+transform for this (in a follow-up PR). I believe we can use signed stick-breaking transform here. This way users can define distributions over
skewness
(or just simply usepyro.param
with correct constraint). Without that, it is a bit cumbersome for users to define correctskewness
over general d-torus.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With that, we can have correct
constraints
in the distribution definition. :DThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would be neat; I'll add it to the backlog.