Performance for Large Models #200

Open
5 tasks
cscherrer opened this issue Sep 4, 2020 · 2 comments

Comments

@cscherrer
Owner

@DilumAluthge suggested in #161 that we add a Bayesian neural net example. There's a lot we'll need to do to get good performance for models like that. Let's discuss it here, then create a meta-issue for these tasks. We can close the current issue once we feel the discussion is done and we've added the meta-issue.

So far, we've mostly focused on DynamicHMC with the default ForwardDiff. This is fast for small models, but for higher dimensions we really need reverse-mode AD. The obvious choice here is Zygote, but I'm also really impressed with the performance benchmarks of [Yota.jl](https://github.com/dfdx/Yota.jl). Currently, Zygote uses ChainRules.jl, while Yota doesn't (yet) and instead uses its own rule-writing system. Yota's system looks very nice, but it would require us to write rules for all of the distributions, which is probably too much. Zygote's big win here comes from crowdsourcing.

Still, we'll need to

  • Check that Zygote is set up to work with Soss models
  • Make sure it's easy to use
  • Ideally, have some sensible default that switches between ForwardDiff and Zygote based on parameter dimensionality
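
To make the last item concrete, here is a minimal sketch of a dimension-based switch. The helper name `auto_gradient`, the constant `AD_DIM_THRESHOLD`, and the cutoff of 20 are all hypothetical and not part of Soss; the right cutoff would have to come from benchmarking. `logdensity` is just a stand-in for a model's log-density.

```julia
using ForwardDiff, Zygote

# Hypothetical cutoff; the right value would need benchmarking.
const AD_DIM_THRESHOLD = 20

# Hypothetical helper: forward mode for few parameters, reverse mode otherwise.
function auto_gradient(f, x::AbstractVector{<:Real}; threshold = AD_DIM_THRESHOLD)
    if length(x) <= threshold
        return ForwardDiff.gradient(f, x)
    else
        # Zygote.gradient returns a tuple with one entry per argument
        return first(Zygote.gradient(f, x))
    end
end

# Stand-in log-density: standard normal, up to an additive constant
logdensity(θ) = -sum(abs2, θ) / 2

g_small = auto_gradient(logdensity, randn(5))     # takes the ForwardDiff path
g_large = auto_gradient(logdensity, randn(1000))  # takes the Zygote path
```

A real default would probably also let the user override the backend explicitly, since the crossover point depends on the model.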

Soss models are typically small at the top level, though a given node could be large (e.g., a neural net). In the long term, we should be able to leverage this structure, with something like:

  • Allow Zygote gradient information on a node to propagate to top-level model

Since Zygote uses ChainRules, most of the gradient work is already done. But we've added lots of distribution combinators, and those will need ChainRules rrules of their own.

  • Add ChainRules.rrules for For, iid, Mix, MarkovChain
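
As a rough idea of what such a rule looks like, here is a sketch of a hand-written `rrule` using ChainRulesCore. The function `iid_normal_logpdf` is a hypothetical stand-in for the log-density of an iid Normal node; it is not the actual Soss `iid` combinator, and the real rules would be written against the Soss types. The pullback uses the closed-form derivatives of the Normal log-density.

```julia
using ChainRulesCore

# Hypothetical stand-in: log-density of n iid Normal(μ, σ) observations
function iid_normal_logpdf(μ::Real, σ::Real, x::AbstractVector{<:Real})
    n = length(x)
    return -n * log(σ) - n * log(2π) / 2 - sum(xi -> (xi - μ)^2, x) / (2σ^2)
end

function ChainRulesCore.rrule(::typeof(iid_normal_logpdf),
                              μ::Real, σ::Real, x::AbstractVector{<:Real})
    y = iid_normal_logpdf(μ, σ, x)
    function iid_normal_logpdf_pullback(ȳ)
        r = x .- μ                                      # residuals
        ∂μ = ȳ * sum(r) / σ^2                           # ∂/∂μ
        ∂σ = ȳ * (sum(abs2, r) / σ^3 - length(x) / σ)   # ∂/∂σ
        ∂x = (-ȳ / σ^2) .* r                            # ∂/∂x, elementwise
        # NoTangent() is the tangent for the function object itself
        return NoTangent(), ∂μ, ∂σ, ∂x
    end
    return y, iid_normal_logpdf_pullback
end
```

With a rule like this defined, Zygote should pick it up through ChainRules instead of differentiating through the body of the function.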

There's probably more; please add other concerns here so we can roll them into the meta-issue.

cc: @millerjoey

@cscherrer
Owner Author

Oh, and we'll need

  • LoopVectorization.jl for logpdf and its gradient, wherever possible
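
A minimal sketch of the kind of kernel this means, assuming a plain vector of Float64 observations. The function name `normal_logpdf_sum` is hypothetical, and `@turbo` is the macro name in recent LoopVectorization releases (older releases called it `@avx`).

```julia
using LoopVectorization

# Hypothetical kernel: summed Normal(μ, σ) log-density over a vector
function normal_logpdf_sum(μ::Float64, σ::Float64, x::Vector{Float64})
    s = 0.0
    # @turbo vectorizes the loop, including the sum reduction into `s`
    @turbo for i in eachindex(x)
        z = (x[i] - μ) / σ
        s += z * z
    end
    # The terms that do not depend on x[i] are added once, outside the loop
    return -0.5 * s - length(x) * (log(σ) + 0.5 * log(2π))
end
```

Hoisting the constant terms out of the loop keeps the vectorized body down to a subtract, a divide, and a multiply-add per element.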

@cscherrer
Owner Author

This could be an easier way to get us there:
https://github.com/mcabbott/Tullio.jl
