Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First Normalizing flows Tutorial #2542

Merged
merged 25 commits into from
Jul 19, 2020
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d37c6f7
Up to first example
stefanwebb May 23, 2020
c2161e5
Finished first code section
stefanwebb May 23, 2020
5421112
Section on learning univariate distributions
stefanwebb May 25, 2020
9bfe831
Some progress on first tutorial
stefanwebb May 28, 2020
9b62754
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into normalizi…
stefanwebb Jun 27, 2020
e11e052
Section on multivariate distributions
stefanwebb Jun 27, 2020
a84526a
Additional sections
stefanwebb Jun 27, 2020
c2fccdf
Conditional transforms
stefanwebb Jun 27, 2020
a626e56
Removed debug changes to Spline
stefanwebb Jun 28, 2020
ed38e86
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into normalizi…
stefanwebb Jun 28, 2020
01dfe36
Ran code in tutorial
stefanwebb Jun 28, 2020
7c0e925
PEP8
stefanwebb Jun 28, 2020
bbbf69e
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into normalizi…
stefanwebb Jul 7, 2020
aee8759
Tested learning, added references to tutorial
stefanwebb Jul 7, 2020
2731140
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into normalizi…
stefanwebb Jul 16, 2020
13ee988
All of the editing changes, minus fixing the bug in the last Pyro exa…
stefanwebb Jul 17, 2020
f78124b
More edits
stefanwebb Jul 17, 2020
8da36dd
Fixed bug in example and conclusion section
stefanwebb Jul 18, 2020
a9806fd
Martin's changes
stefanwebb Jul 18, 2020
25b580a
Fix .detach() via .clear_cache()
fritzo Jul 18, 2020
0eaa816
Merge branch 'normalizing-flows-tutorial1' of github.com:stefanwebb/p…
fritzo Jul 18, 2020
ecd267e
Fix bugs in notebook
fritzo Jul 18, 2020
16c6320
Remove commented code
fritzo Jul 19, 2020
a4db4bb
Add comment about .clear_cache() usage
fritzo Jul 19, 2020
d975e10
Fix grammar
fritzo Jul 19, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyro/distributions/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def __init__(self, transform):
def condition(self, context):
return self.transform

def clear_cache(self):
self.transform.clear_cache()


class ConditionalTransformedDistribution(ConditionalDistribution):
def __init__(self, base_dist, transforms):
Expand All @@ -66,3 +69,6 @@ def condition(self, context):
base_dist = self.base_dist.condition(context)
transforms = [t.condition(context) for t in self.transforms]
return TransformedDistribution(base_dist, transforms)

def clear_cache(self):
pass
14 changes: 14 additions & 0 deletions pyro/distributions/torch_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,20 @@ def _Transform__getstate__(self):
return attrs


# TODO move upstream
@patch_dependency('torch.distributions.transforms.Transform.clear_cache')
def _Transform_clear_cache(self):
if self._cache_size == 1:
self._cached_x_y = None, None


# TODO move upstream
@patch_dependency('torch.distributions.TransformedDistribution.clear_cache')
def _TransformedDistribution_clear_cache(self):
for t in self.transforms:
t.clear_cache()


# Fixes a shape error in Multinomial.support with inhomogeneous .total_count
@patch_dependency('torch.distributions.Multinomial.support')
@torch.distributions.constraints.dependent_property
Expand Down
2 changes: 1 addition & 1 deletion pyro/distributions/transforms/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _monotonic_rational_spline(inputs,

# Get the index of the bin that each input is in
# bin_idx ~ (batch_dim, input_dim, 1)
bin_idx = _searchsorted(cumheights + eps if inverse else cumwidths + eps, inputs)[..., None]
bin_idx = _searchsorted(cumheights + eps if inverse else cumwidths + eps, inputs).unsqueeze(-1)

# Select the value for the relevant bin for the variables used in the main calculation
input_widths = _select_bins(widths, bin_idx)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
'visdom>=0.1.4',
# 'biopython>=1.54', # requires Python 3.6
'pandas',
'scikit-learn',
'seaborn',
'wget',
]
Expand Down
1 change: 1 addition & 0 deletions tutorial/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Welcome to Pyro Examples and Tutorials!
minipyro
effect_handlers
modules
normalizing_flows_i

.. toctree::
:maxdepth: 2
Expand Down
868 changes: 868 additions & 0 deletions tutorial/source/normalizing_flows_i.ipynb

Large diffs are not rendered by default.