Add partial support for second-order derivatives (for grid.h's input) #69

Merged: 9 commits merged into NVlabs:master on Mar 25, 2022
Conversation

@ventusff (Contributor) commented on Mar 18, 2022

Hi,

As discussed in issue #58, backward_backward functionality is needed.

For now, I have added partial support for second-order derivatives, only for grid.h and only for d(dL_dinput)_d(...).
This covers the case of backpropagating the gradients of nablas to the grid parameters and to the parameters of downstream modules.

Edit: d(dL_dinput)_d(input) is now supported as well.

# NOTE: currently support:
#  ✓ d(dL_dinput)_d(dL_doutput)  ->  kernel_grid_backward_input_backward_dLdoutput
#  ✓ d(dL_dinput)_d(params)      ->  kernel_grid_backward_input_backward_grid
#  ✓ d(dL_dinput)_d(input)       ->  kernel_grid_backward_input_backward_input
#  x d(dL_dparam)_d(...)
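
For context, here is a minimal sketch (not part of the PR; the encoding options and the scalar objective are illustrative only) of the double-backward pattern these kernels serve, using the tinycudann Python bindings:

import torch
import tinycudann as tcnn

enc = tcnn.Encoding(3, {
    "otype": "HashGrid", "n_levels": 4, "n_features_per_level": 2,
    "log2_hashmap_size": 15, "base_resolution": 16, "per_level_scale": 1.5
})
x = torch.rand(128, 3, device="cuda", requires_grad=True)
y = enc(x).to(dtype=torch.float)

# First backward: dL_dinput, kept differentiable via create_graph=True.
dL_dy = torch.ones_like(y)
dL_dx = torch.autograd.grad(y, x, dL_dy, create_graph=True)[0]

# Second backward: a scalar objective on dL_dx exercises the
# d(dL_dinput)_d(params) and d(dL_dinput)_d(input) paths listed above.
loss = dL_dx.norm(dim=-1).mean()
loss.backward()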

I tried my best to follow the code style of your original project and ran various tests to ensure the code works correctly.
The changes are as small as I could make them, and compilation passes with no errors.

Code tests

I provide three testing tools for the newly added backward_backward_input functionality:
https://gist.github.com/ventusff/57f47588eaff5f8b77a382260e7da8a3

  1. ✔️ test_train(): train a toy SDF model with an eikonal term.
  2. ✔️ grad_check(): check the numerical correctness of backward_backward via torch.autograd.gradcheck (a minimal sketch of such a check follows below).
  3. ✔️ vis_graph(): visualize the torch compute graph.
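
Here is a minimal sketch of what such a gradient check might look like (not the gist itself; the batch size and tolerances are assumptions, and loose tolerances are used because the encoding runs in reduced precision internally; the gist linked above is the authoritative version):

import torch
import tinycudann as tcnn

enc = tcnn.Encoding(3, {
    "otype": "HashGrid", "n_levels": 2, "n_features_per_level": 2,
    "log2_hashmap_size": 15, "base_resolution": 16, "per_level_scale": 1.5
})

def f(x):
    # Cast to float so autograd checks a full-precision output.
    return enc(x).to(dtype=torch.float)

x = torch.rand(8, 3, device="cuda", requires_grad=True)
torch.autograd.gradcheck(f, (x,), eps=1e-3, atol=1e-2, rtol=1e-2, nondet_tol=1e-3)
torch.autograd.gradgradcheck(f, (x,), eps=1e-3, atol=1e-2, rtol=1e-2, nondet_tol=1e-3)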

Toy model compute graph:

Toy model:

import torch
import torch.nn as nn
from torch import autograd
import tinycudann as tcnn

class SDF(nn.Module):
    def __init__(self, hash=True, n_levels=12, log2_hashmap_size=15, base_resolution=16) -> None:
        super().__init__()
        self.encoder = tcnn.Encoding(3, {
            "otype": "HashGrid" if hash else "DenseGrid",
            "n_levels": n_levels,
            "n_features_per_level": 2,
            "log2_hashmap_size": log2_hashmap_size,
            "base_resolution": base_resolution,
            "per_level_scale": 1.5
        })
        self.decoder = nn.Sequential(
            nn.Linear(self.encoder.n_output_dims, 64),
            nn.ReLU(True),
            nn.Linear(64, 1)
        )
    
    def forward(self, x):
        encoded = self.encoder(x).to(dtype=torch.float)
        sdf = self.decoder(encoded)
        return sdf
    
    def forward_with_nablas(self, x):
        with torch.enable_grad():
            x = x.requires_grad_(True)
            sdf = self.forward(x)
            nablas = autograd.grad(
                sdf,
                x,
                torch.ones_like(sdf, device=x.device),
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return sdf, nablas

You can see that the gradients of nablas are backpropagated to grid_params (which needs d(dL_dx)_d(grid)) and to decoder.xxx (which needs d(dL_dx)_d(dL_doutput)) through backwardBackward.

[Image: compute graph of the toy model, attached in the original PR]
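
As a usage illustration (not part of the PR; the optimizer settings, the data term, and the eikonal weight below are made up), training the toy model with an eikonal term drives gradients through both of these second-order paths:

model = SDF().cuda()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(100):
    x = torch.rand(4096, 3, device="cuda")
    sdf, nablas = model.forward_with_nablas(x)
    eikonal_loss = ((nablas.norm(dim=-1) - 1.0) ** 2).mean()
    loss = sdf.abs().mean() + 0.1 * eikonal_loss  # placeholder data term
    optim.zero_grad()
    loss.backward()  # triggers d(dL_dx)_d(grid) and d(dL_dx)_d(dL_doutput)
    optim.step()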

Theoretical derivation

From the gradients of dL_dx to the gradients of dL_dgrid and dL_d(dL_doutput):

[Image: derivation of the second-order gradients, attached in the original PR]

Edit: added the theoretical derivation of d(dy_dx)_dx:

[Image: derivation of d(dy_dx)_dx, attached in the original PR]
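
For readers without the attached images, here is a compressed index-notation restatement of the chain rule involved (my own notation, not copied from the attachments). With encoding output y_j = E_j(x, θ), grid parameters θ, downstream loss L, and a scalar objective G that depends on dL/dx (with v_i = ∂G/∂(∂L/∂x_i)):

\[
\frac{\partial L}{\partial x_i} = \sum_j \frac{\partial y_j}{\partial x_i}\,\frac{\partial L}{\partial y_j},
\]
\[
\frac{\partial G}{\partial\!\left(\frac{\partial L}{\partial y_j}\right)} = \sum_i v_i\,\frac{\partial y_j}{\partial x_i}, \qquad
\frac{\partial G}{\partial \theta_k} = \sum_{i,j} v_i\,\frac{\partial^2 y_j}{\partial x_i\,\partial \theta_k}\,\frac{\partial L}{\partial y_j}, \qquad
\frac{\partial G}{\partial x_m} = \sum_{i,j} v_i\,\frac{\partial^2 y_j}{\partial x_i\,\partial x_m}\,\frac{\partial L}{\partial y_j}.
\]

These three terms correspond to kernel_grid_backward_input_backward_dLdoutput, kernel_grid_backward_input_backward_grid, and kernel_grid_backward_input_backward_input, respectively.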

1. Since I added a virtual function backward_backward_input to object.h:DifferentiableObject, I had to put a dummy error function in network.h:Network and encoding.h:Encoding (which inherit from DifferentiableObject) to make compilation pass.
Currently this function just throws a not-implemented error. I hope this won't hurt any current mechanisms, and you can finish the remaining support for backward_backward later and remove this dummy function.

@Tom94 (Collaborator) left a review:

Hi Jianfei, thank you very much for this PR -- I think this is great functionality to have, even if implemented just for the encoding!

This is also a great implementation; I have very little to criticize. Here are just a few small requests before I feel comfortable merging.

Review threads (all resolved):
  • bindings/torch/tinycudann/modules.py (4 threads)
  • include/tiny-cuda-nn/encoding.h
  • include/tiny-cuda-nn/encodings/grid.h
@Tom94 (Collaborator) commented on Mar 18, 2022

Also just wanted to say that I very much appreciate the detailed derivations and explanations in the PR. This is really something to point other people to as an example of how it should be done!

@ventusff (Author) commented
Thanks for the nice tips and kind words :) I have resolved some of the suggestions now. A few remain, mostly about computation details I'm not sure about.

@ventusff (Author) commented on Mar 25, 2022

Hi @Tom94,
I've finished the following implementation and tests, and I think it's probably ready to merge now 😃

  • updated to your latest code and data conventions, including MatrixView and the xxx_impl functions
  • as you suggested, I added an implementation of d(dL_dx)_dx, which requires:
    • an implementation of d(dy_dx)_dx
    • an overload of pos_fract for second-order derivatives, plus second-order derivatives for Smoothstep (see the short note after this list)
    • an additional kernel, kernel_grid_backward_input_backward_input
    • updates to the bindings and modules
  • passed additional gradcheck and gradgradcheck for the newly added d(dL_dx)_dx calculation; numerical correctness is checked under both Linear and Smoothstep interpolation (the testing gist linked above has also been updated)
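
A short note on the Smoothstep part (this is the textbook smoothstep and its derivatives, stated here for reference rather than copied from the PR's code):

\[
S(t) = 3t^2 - 2t^3, \qquad S'(t) = 6t(1 - t), \qquad S''(t) = 6 - 12t .
\]

The second derivative S''(t) is what a second-order overload of pos_fract needs to supply for the interpolation weights.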

I will clean up and upload the theoretical derivations for d(dy_dx)_dx soon. In the meantime, you can review my changes :)
Edit: the theoretical derivations have been added to the top PR comment.

@Tom94 (Collaborator) commented on Mar 25, 2022

Wow, this is incredible, thank you so much for adding d(dL_dx)_dx! I haven't had the time to get back to this PR during the week and did not expect the follow-up. Will go through it now. :)

By the way: do you want to add your test script to the scripts folder? It would surely come in handy in the future to ensure nothing regresses. (E.g. when support for non-contiguous inputs from torch is added.)

@Tom94 (Collaborator) commented on the following snippet:

    forward.dy_dx.data(),
    dL_dy_rm,
    // outputs
    dL_ddLdoutput->pitched_ptr()

Could you replace this pitched pointer with ->view() in order to also support row-major dL_ddLdoutput?

@Tom94 merged commit b35e196 into NVlabs:master on Mar 25, 2022
@Tom94 (Collaborator) commented on Mar 25, 2022

Turns out that was the only nitpick I could find in the code -- so I quickly made the change myself and merged.

Thanks again for adding all this. I can't stress enough how much I appreciate this sort of high-quality code contribution!

@juuso-oskari commented
@ventusff Thank you so much for your work on this. Are you still planning to implement double backward for fully_fused_mlp.cu? I would really like to test it on my thesis project.
