
Question about triplane prediction & optimization #29

Open
rfeinman opened this issue Nov 7, 2023 · 3 comments

rfeinman commented Nov 7, 2023

Thank you so much for an awesome code library!

I am trying to train a neural network to predict triplane codes from a reference image view of an object. I am using your triplane-nerf library for the rendering, and it works pretty well, but I am seeing some odd pixelation and artifacts even after training to convergence. Below is a brief sketch of the optimization procedure I follow during training; the parameters of decoder and predictor_net are optimized. Am I doing anything wrong here? I've included a visualization of the predicted (rendered) image vs. the target image at the bottom of this message.

I noticed that the density_bitfield output of nerf.get_density does not carry grad. Don't we need gradients to flow through the density MLP for proper training? Is there a way to enable this?

from lib.models.autodecoders.base_nerf import BaseNeRF
from lib.models.decoders.triplane_decoder import TriPlaneDecoder
from lib.core.utils.nerf_utils import get_cam_rays


decoder = TriPlaneDecoder(
    base_layers=[3 * 6, 64], 
    density_layers=[64, 1],
    color_layers=[64, 3],
    dir_layers=[16, 64],
)

nerf = BaseNeRF(code_size=(3, 6, 64, 64), grid_size=64)


def render(code, density_bitfield, h, w, intrinsics, poses):
    rays_o, rays_d = get_cam_rays(poses, intrinsics, h=h, w=w)

    batch_size, height, width, channels = rays_o.shape
    rays_o = rays_o.view(batch_size, height * width, channels)
    rays_d = rays_d.view(batch_size, height * width, channels)        

    outputs = decoder(rays_o, rays_d, code, density_bitfield, nerf.grid_size)

    image = outputs['image'] + nerf.bg_color * (1 - outputs['weights_sum'].unsqueeze(-1))

    return image.reshape(batch_size, h, w, 3)


for _ in range(iterations):
    # reference_img is size 128 x 128
    triplane_code = predictor_net(reference_img, reference_intrinsics, reference_poses)

    _, density_bitfield = nerf.get_density(
        decoder, triplane_code, cfg=dict(density_thresh=0.1, density_step=16))

    pred_img = render(
        triplane_code, density_bitfield, h=128, w=128, 
        intrinsics=target_intrinsics, poses=target_poses)
    
    loss = (pred_img - target_img).pow(2).mean()
    loss.backward()
    #optimizer.step() ... etc

prediction vs. target:

[image: prediction] [image: target]

Lakonik (Owner) commented Nov 7, 2023

Hi! The density grid is for occupancy-based pruning, which is not part of the gradient graph by design. Since you are using triplanes with a resolution of 64, it is possible that there's not enough capacity to capture the full details of the target.
However, this problem could be mitigated by using LPIPS loss, which will be a new feature of this codebase in an upcoming release.
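To illustrate the first point, here is a minimal sanity check with toy tensors (the shapes and the toy "render" are stand-ins for the real decoder, purely illustrative): gradients still reach the predicted code even though the uint8 bitfield sits outside the autograd graph.

```python
import torch

# Hypothetical stand-in for the real pipeline: the triplane code is a
# differentiable leaf, while the bitfield is an integer occupancy mask
# that autograd never tracks.
code = torch.randn(1, 3, 6, 64, 64, requires_grad=True)
bitfield = torch.full((64 ** 3 // 8,), 255, dtype=torch.uint8)  # no grad, by design

# Toy "render": the code is used differentiably; the bitfield only gates
# which samples are taken (here: all of them, so the gate is 1.0).
pred = (code * (bitfield.float().mean() / 255.0)).mean()
pred.backward()

assert code.grad is not None        # gradients flow through the code
assert not bitfield.requires_grad   # the pruning mask stays out of the graph
```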

rfeinman (Author) commented Dec 8, 2023

Hi @Lakonik - thanks for the helpful reply to my question! It seems like occupancy-based pruning is designed primarily for the discrete NeRF problem where you have some finite set of scenes and you can maintain a grid state for each (the density_bitfield) that is updated at some interval.

I'm wondering: what is your suggested approach for raymarching in a setting where the number of NeRFs is unbounded? For example, a setting where we predict a triplane NeRF from an image, as in LRM. Should the density_bitfield be computed from scratch for each prediction? Or should we just use some dummy value for the bitfield?

Lakonik (Owner) commented Dec 9, 2023

density_bitfield can be computed from scratch (by updating for multiple steps at once, which costs some time) or simply be filled with 255 (no pruning). If you render multiple views at once, computing the density_bitfield first could be faster; otherwise, just use the filled dummy.
