Add partial support for second-order derivatives (for grid.h's input) #69
Conversation
Hi Jianfei, thank you very much for this PR -- I think this is great functionality to have, even if implemented just for the encoding!
This is also a great implementation; I have very little to criticize. Here are just a few small requests before I feel comfortable merging.
Also, I just wanted to say that I very much appreciate the detailed derivations and explanations in the PR. This is really something to point other people to as an example of how it should be done!
Thanks for the nice tips and kind words :) I resolved some of the suggestions now. A few remain, mostly about the computations, which I'm not sure about.
Hi @Tom94,

Wow, this is incredible, thank you so much for adding this! By the way: do you want to add your test script to the repository?
forward.dy_dx.data(),
dL_dy_rm,
// outputs
dL_ddLdoutput->pitched_ptr()
Could you replace this pitched pointer with ->view() in order to also support row-major dL_ddLdoutput?
Turns out that was the only nitpick I could find in the code -- so I quickly made the change myself and merged. Thanks again for adding all this. I can't stress enough how much I appreciate this sort of high-quality code contribution!
@ventusff Thank you so much for your work on this. Are you still planning to do the double backward for fully_fused_mlp.cu? I would really like to test it on my thesis project.
Hi,
As discussed in issue #58, a backward_backward functionality is needed. For now, I managed to add partial support for second-order derivatives, only for grid.h, and only for d(dL_dinput)_d(...). This satisfies the need of backpropagating the gradients of nablas toward the grid parameters and the downstream modules' parameters.
edit: now supports d(dL_dinput)_d(input)
I tried my best to follow the code style of your original project, and I have run various tests to ensure that the code behaves correctly. I kept the changes as small as I could, and compilation passes with no errors.
Code tests
I provide three testing tools for the newly added backward_backward_input functionality, collected in this gist: https://gist.github.com/ventusff/57f47588eaff5f8b77a382260e7da8a3. Among them is a numerical check with torch.autograd.gradcheck.
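A minimal sketch of that gradcheck pattern (a reconstruction, not the gist's actual code; the stand-in function below replaces the real grid encoding): the function under test returns the first-order input gradient, so gradcheck's numerical Jacobian probes exactly the second-order derivatives d(dL_dinput)_d(input) and d(dL_dinput)_d(dL_doutput).

```python
import torch

def input_gradient(x, dL_dy):
    # Stand-in for the grid encoding; any smooth function works here.
    y = (x * x).sin()
    # create_graph=True keeps the backward graph, so the returned gradient
    # is itself differentiable; for the real encoding, this is the step
    # that would invoke the new backward_backward_input path.
    (dL_dx,) = torch.autograd.grad(y, x, grad_outputs=dL_dy, create_graph=True)
    return dL_dx

# gradcheck wants double precision; real half-precision kernels need
# relaxed tolerances instead.
x = torch.rand(8, 3, dtype=torch.double, requires_grad=True)
dL_dy = torch.rand(8, 3, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(input_gradient, (x, dL_dy))
```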
Toy model compute graph (figure: toy_model):
You can see that the gradients of nablas are backpropagated towards grid_params (which needs d(dL_dx)_dgrid) and decoder.xxx (which needs d(dL_dx)_d(dL_doutput)) through backwardBackward.
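A hedged code sketch of this toy graph, assuming PyTorch bindings in the spirit of the tinycudann module (module name, Encoding signature, and config are illustrative; the gist uses its own bindings):

```python
import torch
import tinycudann as tcnn

# Hypothetical minimal config; real HashGrid setups need more parameters.
encoding = tcnn.Encoding(n_input_dims=3, encoding_config={"otype": "HashGrid"})
decoder = torch.nn.Linear(encoding.n_output_dims, 1).cuda()

x = torch.rand(1024, 3, device="cuda", requires_grad=True)
y = decoder(encoding(x).float())

# First-order input gradient (the "nablas"); create_graph=True keeps it
# differentiable so a loss on it can itself be backpropagated.
nablas = torch.autograd.grad(y.sum(), x, create_graph=True)[0]

# Any loss on nablas exercises backwardBackward of the encoding:
# d(dL_dx)_dgrid fills the grid-parameter gradients, and
# d(dL_dx)_d(dL_doutput) routes gradients on to the decoder weights.
loss = (nablas.norm(dim=-1) - 1.0).square().mean()  # e.g. an eikonal term
loss.backward()
```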
Theoretical derivation
From the gradients of dL_dx to the gradients of dL_dgrid and dL_d(dL_doutput).
edit: added theoretical derivations of d(dy_dx)_dx
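The detailed derivations themselves are not preserved in this extraction; what follows is a reconstructed sketch of the chain rule behind the three terms, with notation assumed rather than copied: y = f(x; θ) is the encoding, L the first loss, n = dL_dx the nablas, and v the incoming gradient of a second loss G defined on n.

```latex
\begin{align*}
n_i &= \frac{\partial L}{\partial x_i}
     = \sum_k \frac{\partial y_k}{\partial x_i}\,\frac{\partial L}{\partial y_k},
  \qquad v_i = \frac{\partial G}{\partial n_i} \\
\frac{\partial G}{\partial \theta_j}
    &= \sum_{i,k} v_i\,\frac{\partial^2 y_k}{\partial x_i\,\partial \theta_j}\,
       \frac{\partial L}{\partial y_k}
    && \text{(\texttt{d(dL\_dx)\_dgrid})} \\
\frac{\partial G}{\partial (\partial L/\partial y_k)}
    &= \sum_i v_i\,\frac{\partial y_k}{\partial x_i}
    && \text{(\texttt{d(dL\_dx)\_d(dL\_doutput)})} \\
\frac{\partial G}{\partial x_j}
    &= \sum_{i,k} v_i\,\frac{\partial^2 y_k}{\partial x_i\,\partial x_j}\,
       \frac{\partial L}{\partial y_k}
    && \text{(via \texttt{d(dy\_dx)\_dx})}
\end{align*}
```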
1. Since I added a virtual function backward_backward_input in object.h:DifferentiableObject, I had to put a dummy error function in network.h:Network and encoding.h:Encoding, which inherit from DifferentiableObject, to pass compilation. Currently this function just throws a NotImplementedError. I hope this won't hurt any current mechanisms, and you can remove this dummy function later once the remaining backward_backward support is finished.