Use zygote2differential to wrap chainrules inputs #1057
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
@mzgubic implemented zygote2differential as a better version of wrap_chainrules_inputs and added it to use in the code for
rrule_via_ad
.But it was not added to the normal path for when Zygote uses ChainRules.
I guess because it requires keeping the primal values in memory.
Which is probably a lot?
Anyway this would give us more consistent chainrules types.
No more
Tangent{Any}
ornothings
that are hidden with-in arrays.We probably do not want to merge this as is because of the extra memory use.
or maybe it is not too bad. Do we have a benchmark for it?
But hopefully this will fix the problems in TuringLang/DistributionsAD.jl#197
cc @devmotion .
If it does we can look at reworking
zygote2differential
to not have to store so much.We learnt a lot about doing that for
ProjectTo
same techniques can be applied here.
NB: I am putting this PR up at 9:30 at night, and I have not even run it locally.
Might have typos etc and just not work.
It also has no tests, yet.