First Normalizing flows Tutorial #2542

Merged
martinjankowiak merged 25 commits into pyro-ppl:dev from the normalizing-flows-tutorial1 branch on Jul 19, 2020

Conversation

stefanwebb
Contributor

@stefanwebb stefanwebb commented Jun 28, 2020

I've added a tutorial that explains the basics of doing normalizing flows. In the process I discovered there is some change I've made to Spline that stops autodiff happening across steps (see the error message in the notebook)... I wanted to post the tutorial here before I've fixed this bug in case I can get some early feedback on the tutorial itself :)

@stefanwebb stefanwebb marked this pull request as ready for review July 7, 2020 20:59
@stefanwebb stefanwebb added Examples and removed WIP labels Jul 7, 2020
@stefanwebb
Contributor Author

There is one last bug I have to fix that is stopping sampling from the ConditionalSpline transform. The rest of the tutorial is ready for review!

@fritzo
Member

fritzo commented Jul 8, 2020

First review pass:

  • Generally looks like a great intro tutorial!
  • I generally try to avoid referencing future work (e.g. "(Part 1)", "this series of tutorials", "two forthcoming tutorials"). My reasoning is that projects are often interrupted (due to emergencies, changed priorities, life events, etc.). When a project referencing future work is interrupted, it appears incomplete, ideally temporarily but often forever. To avoid this unfinished appearance, you can write the tutorial as if it were stand-alone, then in subsequent PRs add text connecting multiple tutorials. This is especially important if we release when only, say, 2 of your 4 tutorials are currently completed. The one place I have broken this rule is in file naming (e.g. forecasting_i.ipynb) because post-publication changes to filenames would break users' links. I also find it helpful to organize sequences of tutorials by creating master issues with checklists, but your preferences may differ 🙂
  • Could you add a sentence at the top listing other tutorials recommended as prerequisites, as I have attempted to do in other tutorials? While your tutorial is independent of much of Pyro, I think at least the tensor shapes tutorial would help users understand event_dim. You might even mention that "This tutorial is independent of much of Pyro, but users may want to read about distribution shapes in the tensor shapes tutorial" or similar.
  • Could you link to docs each time you reference a documented class or function? This includes both Pyro and PyTorch classes. Sometimes I get lazy and link only the first occurrence of a documented noun in a given section or paragraph. Note sphinx does not allow backticks inside of link text, so you'll need to drop the backticks from the linked text, e.g.
    - live in the `pyro.distributions.transforms` module
    + live in the [pyro.distributions.transforms](http://docs.pyro.ai/en/stable/distributions.html#transforms) module
    and
    - The class `ExpTransform`
    + The class [ExpTransform](https://pytorch.org/docs/stable/distributions.html#torch.distributions.transforms.ExpTransform)
  • I'd also recommend omitting full qualifiers like pyro.distributions. for names like TransformedDistribution whose hyperlink clarifies its location. I think this will especially improve readability later in the tutorial e.g. pyro.distributions.transforms.ConditionalSpline (users can always hover or click to disambiguate!).
  • Cell 8: Do you really want to make dataset requires_grad? This seems pretty weird, and you end up needing to detach it below.
  • Cell 8: To make it cheap to run on CI, could you add a line smoke_test = ('CI' in os.environ) in the imports section (cell 1) then change steps to
    steps = 1 if smoke_test else 1001
    Our tutorial runner then smoke tests the tutorials with only a single iteration.
    (Note I usually use round_number+1 so that the final loss is printed by the step % 100 conditional, but feel free to ignore)
  • Cell 9: nice example and plots!
  • Cell 9: nit: In the scatterplots, could you set alpha=0.8 or 0.5 to improve readability? (here and in subsequent cells)
  • Multivariate transforms: consider linking to the tensor shapes tutorial.
  • Multivariate transforms: again I'd recommend removing reference to future work "in the next tutorial", and instead adding that text in a subsequent PR (with link).
  • Cells 19, 23: maybe print only every 200 or 500 iterations?
  • Cell 23: Again, it looks like you have unnecessary uses of .detach().
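
The smoke-test and logging suggestions above (Cell 8 and Cells 19, 23) can be sketched together as follows. This is a minimal illustration, not the notebook's actual code: `train_step` and `log_every` are hypothetical names, and the loss value is a placeholder rather than a real training loss.

```python
import os

# True when the notebook runs under a CI runner that sets the CI variable.
smoke_test = ('CI' in os.environ)

# A single step is enough to check that the notebook executes;
# 1001 (round number + 1) lets the final step, 1000, be printed too.
steps = 1 if smoke_test else 1001
log_every = 200  # hypothetical reporting interval


def train_step(step):
    """Hypothetical stand-in for one optimizer step; returns a fake loss."""
    return 1.0 / (step + 1)


logged = []
for step in range(steps):
    loss = train_step(step)
    if step % log_every == 0:
        logged.append(step)
        print(f"step: {step}, loss: {loss:.4f}")
```

With `steps = 1001` and an interval that divides 1000, the last printed line reports step 1000, which is the reason for the `round_number+1` convention mentioned above.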

@fritzo fritzo added this to the 1.4 release milestone Jul 10, 2020
@martinjankowiak
Collaborator

martinjankowiak commented Jul 10, 2020

looks great! some comments:

  • nit: can we replace ln -> log?
  • did you mean for "Calculus reference" to point to a url? i guess there are other missing refs too
  • first mention of "hypernetwork": define/elaborate?

@stefanwebb
Contributor Author

stefanwebb commented Jul 18, 2020

I think the <em> tags are due to a bug or artifact in how Jupyter notebooks are rendered in the GitHub viewer.

Those formulae seem to be fine on the nbviewer website, although I can't get it to update to my latest version... I think it will look fine when it's been converted to HTML by the doc system

@stefanwebb
Contributor Author

@fritzo I couldn't get the learning loop to work without using .detach() on the input to .log_prob(). Any ideas on this?

@stefanwebb
Contributor Author

stefanwebb commented Jul 18, 2020

Okay, I've solved the bug! It was caused by conditioning on a minibatch of size 1000 and then not drawing a sample with a minibatch size that broadcasts over that... (This may be indicative of an API flaw)
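
For readers unfamiliar with the term, the broadcasting rule referred to here can be sketched in plain Python. This is only an illustration of the general NumPy/PyTorch rule, not a reproduction of the notebook bug: the helper `broadcast_shapes` is written for this example, and the shapes below are assumed, since the exact shapes in the notebook are not stated in this thread.

```python
def broadcast_shapes(a, b):
    """Return the broadcast of two shapes, or raise ValueError.

    Rule: align the shapes from the right; each pair of sizes must be
    equal, or one of the two must be 1.
    """
    n = max(len(a), len(b))
    # Pad the shorter shape with leading 1s so both have n dimensions.
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    result = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f"shapes {a} and {b} do not broadcast")
    return tuple(result)


# Assumed shapes: a context minibatch of 1000 rows.
context_batch = (1000,)
same = broadcast_shapes(context_batch, (1000,))  # matching batch size
single = broadcast_shapes(context_batch, (1,))   # size 1 expands to 1000

try:
    broadcast_shapes(context_batch, (500,))      # 500 vs 1000: incompatible
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```

Under these assumptions, a sample batch of 1000 or 1 is compatible with the conditioning batch, while a batch of 500 is not.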

@stefanwebb
Contributor Author

Should be good to go! 🥳

@martinjankowiak
Collaborator

lgtm!

  • typo "bivariation"
  • typo " as the decomposition as the product"
  • typo "can not"

what errors do you get when you remove the detach statements?

Member

@fritzo fritzo left a comment

Needs fixes to .detach() logic, I'll try to push to this branch...

@fritzo
Member

fritzo commented Jul 18, 2020

OK, I've pushed three changes:

  • fixed f-strings

  • replaced %%timeit -> %%time

  • replaced .detach() with a new set of methods .clear_cache().

    What was happening: in contrast to most Pyro models, this notebook creates TransformedDistributions directly and runs PyTorch optimizers on them. In usual Pyro models, the transform caches live only for the duration of a single model invocation; indeed Pyro tries to be purely functional and avoid mutation wherever possible, which makes it easier to reason about caching. But in this tutorial the TransformModules are dangerously reused even after being updated via optimizer.step(). After being updated, their caches are invalid, and this was causing loud errors "cannot backward through graph a second time" (which is lucky: you could have been stuck with silent errors if you had specified create_graph=True). The hacky workaround of .detach()ing arguments to TransformedDistribution.log_prob() did bypass the cache, but the invalid old cached values were still dangerously lying around. So instead I've created some new .clear_cache() methods and explicitly called them after optimizer.step(). Note these shouldn't be needed in Pyro programs that create flyweight Transforms. I'm not sure what to do with proper TransformModules; I couldn't get cache clearing to work in a backward hook. Filed as bug "TransformModule cache is invalid after optimizer step" #2564.
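
The stale-cache problem described above can be illustrated with a toy in plain Python. This is a sketch under stated assumptions, not Pyro's or PyTorch's implementation: `CachedScale` is a hypothetical class, its one-entry cache is keyed on object identity, and a direct assignment to `weight` stands in for `optimizer.step()`.

```python
class CachedScale:
    """Toy transform y = weight * x with a one-entry cache (hypothetical)."""

    def __init__(self, weight):
        self.weight = weight
        self._cached = None  # (x, y) pair from the most recent call

    def __call__(self, x):
        # Reuse the cached output when called again with the same object.
        if self._cached is not None and self._cached[0] is x:
            return self._cached[1]
        y = [self.weight * v for v in x]
        self._cached = (x, y)
        return y

    def clear_cache(self):
        # Drop the cached pair so the next call recomputes the output.
        self._cached = None


x = [1.0, 2.0]
transform = CachedScale(2.0)

before = transform(x)    # computed with weight 2.0
transform.weight = 3.0   # stands in for optimizer.step()
stale = transform(x)     # cache hit: still the weight-2.0 result

transform.clear_cache()  # the fix: invalidate after the parameter update
fresh = transform(x)     # recomputed with weight 3.0
```

In this toy the stale value is returned silently. In the notebook the cached tensors carried an autograd graph that had already been consumed, which is why the failure there surfaced as a backward error instead.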

fritzo
fritzo previously approved these changes Jul 18, 2020
Member

@fritzo fritzo left a comment

@martinjankowiak does the .clear_cache() logic look ok to you?

fritzo
fritzo previously approved these changes Jul 19, 2020
@martinjankowiak martinjankowiak merged commit 741813d into pyro-ppl:dev Jul 19, 2020
@stefanwebb stefanwebb deleted the normalizing-flows-tutorial1 branch July 23, 2020 20:03