Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accelerate Instant-NGP inference #197

Merged
merged 22 commits into from
May 3, 2023

Conversation

Linyou
Copy link
Contributor

@Linyou Linyou commented Apr 7, 2023

This PR enhances Nerfacc's Instant-NGP inference performance by implementing the following API changes:

  1. The traverse_grids function has been modified to support both train and test modes.
  2. A new function called mark_invisible_cells has been added to the occ_grid module in order to prevent rendering artifacts in unseen spaces.

Comment on lines 274 to 279
num_steps += tid;
continuous_resume += tid;
t_starts += tid * N_samples;
t_ends += tid * N_samples;
valid_mask += tid * N_samples;
ray_indices += tid * N_samples;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't use += anymore because of the forloop:

for (int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n_rays; tid += blockDim.x * gridDim.x)

It would be wrong if += is executed twice within the forloop.

use something like valid_mask[tid * Nsamples] to read/write the data

@liruilong940607
Copy link
Collaborator

Thanks for implementing this! It is pretty nice to have that support

@liruilong940607
Copy link
Collaborator

On the high level, the two ways of ray marching are pretty similar to each other: The "train" way of marching is to take N rays and march all the steps for each ray. The "test" ways of marching is to take all rays and march N steps for each ray (and iterative).

I feel it should be not that hard to unify the API (as well as the implementation) for the two.

To be more concrete, implementation differences between the two are:

  • "test" way needs to take in an argument "max_per_ray_samples", for which the "train" way could simply set it to "inf".
  • "test" way would want to pre-compute the "{t_sorted, t_indices, hits}" so that they are not computed multiple times. As you have already done, we should do this in python instead of C. So that we can make these three as optional arguments for the API that allows for passing in precomputed values.
  • "test" way needs to pass in a mask with shape (n_rays,) to skip rays. (e.g., the alive_indices you were using). "train" way can have an all ones mask.

So maybe we can unify them into the same "traverse_grid" function, with extra arguments (max_per_ray_samples=inf, masks=None, t_sorted=None, t_indices=None, hits=None). And for "traverse_grid_test", your can just call that function with an updated "near_planes" at every iteration of marching.

In this case, I think it makes sense to let the CUDA kernel return an extra tensor (n_rays,) that indicates the termination distance during grid traversal, which is essentially the "near_planes" for the next iteration of "traverse_grid_test". (the near_planes you are returning has a confusing name, which I think it should be termination_planes or something like that.

@Linyou
Copy link
Contributor Author

Linyou commented Apr 11, 2023

Sound good! I think we could unify the API using extra arguments "(max_per_ray_samples=inf, masks=None, t_sorted=None, t_indices=None, hits=None)". Nice idea, BTW!

We also need to unify the return values. I suggest using the data structure defined in "data_spect.h" to store t_start and t_end, instead of creating torch::Tensor directly as I am currently doing. It may be helpful to add new methods for allocating memory in "data_spect.h" specifically for t_start and `t_end", since they are pre-allocated in the "test" way. What do you think?

As for near_planes, I already tweaked the code so we don't need to return it, we can just update it inside the "traverse_grid" kernel.

@liruilong940607
Copy link
Collaborator

We also need to unify the return values. I suggest using the data structure defined in "data_spect.h" to store t_start and t_end, instead of creating torch::Tensor directly as I am currently doing. It may be helpful to add new methods for allocating memory in "data_spect.h" specifically for t_start and `t_end", since they are pre-allocated in the "test" way. What do you think?

I think you can use the RaySegmentsSpec just like what is being used in the traverse_grid function. And you can get t_starts and t_ends by:

https://github.com/KAIR-BAIR/nerfacc/blob/8340e19daad4bafe24125150a8c56161838086fa/tests/test_grid.py#L60-L61

As for near_planes, I already tweaked the code so we don't need to return it, we can just update it inside the "traverse_grid" kernel.

Do you mean that you inplace change the value of it? I would suggest against doing inplace modification as it is not quite user-friendly.

@Linyou
Copy link
Contributor Author

Linyou commented Apr 13, 2023

I have unified the "traverse_grid" API, and now both "train" and "test" can use the same Python function. On the low level, we still need to call separate C functions to launch the CUDA kernel.

Note that the "traverse_grid" function now returns three objects (intervals, samples, termination_planes), and "termination_planes" will be just None when "ray_mask_id" is not provided.

@Linyou Linyou changed the title Adding Instant-NGP inference rendering Accelerate Instant-NGP Rendering & Add GUI in Nerfacc Apr 13, 2023
@Linyou Linyou changed the title Accelerate Instant-NGP Rendering & Add GUI in Nerfacc Accelerate Instant-NGP Rendering & Add GUI Apr 13, 2023
@Linyou Linyou changed the title Accelerate Instant-NGP Rendering & Add GUI Accelerate Instant-NGP inference Apr 19, 2023
add test mode for traverse_grids
examples/utils.py Outdated Show resolved Hide resolved
examples/gui.py Outdated Show resolved Hide resolved
examples/taichi_kernel.py Outdated Show resolved Hide resolved
examples/train_ngp_nerf_occ.py Outdated Show resolved Hide resolved
examples/train_ngp_nerf_occ.py Outdated Show resolved Hide resolved
nerfacc/cuda/csrc/grid.cu Outdated Show resolved Hide resolved
examples/utils.py Outdated Show resolved Hide resolved
nerfacc/estimators/occ_grid.py Outdated Show resolved Hide resolved
nerfacc/estimators/occ_grid.py Outdated Show resolved Hide resolved
nerfacc/cuda/csrc/nerfacc.cpp Outdated Show resolved Hide resolved
nerfacc/cuda/csrc/grid.cu Outdated Show resolved Hide resolved
nerfacc/cuda/csrc/grid.cu Outdated Show resolved Hide resolved
@liruilong940607
Copy link
Collaborator

@Linyou The latest commit should resolve the memory concerns we had before. The test is also updated to match with the actual use case. Lmk what do you think.

@Linyou
Copy link
Contributor Author

Linyou commented Apr 30, 2023

Thanks! I believe that the current API design is now highly usable for test mode rendering, thanks to the latest commit.

BTW, after this PR is merged, I will create a new one for ngp test mode rendering in the examples.

@liruilong940607
Copy link
Collaborator

@Linyou I also did some cleanups for mark_invisible_cells() and changed the API a tiny bit (the K). Now I'm happy to merge it if you think it's all good.

nerfacc/estimators/occ_grid.py Show resolved Hide resolved
nerfacc/estimators/occ_grid.py Show resolved Hide resolved
nerfacc/estimators/occ_grid.py Show resolved Hide resolved
@liruilong940607
Copy link
Collaborator

Comments addressed. Ready to Go? @Linyou

@Linyou
Copy link
Contributor Author

Linyou commented May 3, 2023

@liruilong940607 Yeah! All good!

@liruilong940607 liruilong940607 merged commit 1031504 into nerfstudio-project:master May 3, 2023
@liruilong940607
Copy link
Collaborator

Thanks for the patience!! Shipped!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants