Skip to content

Commit

Permalink
Support VDB based Estimator (#285)
Browse files Browse the repository at this point in the history
* fvdb integrate

* vdb in examples

* format
  • Loading branch information
liruilong940607 committed Feb 7, 2024
1 parent bfbf027 commit 32273f8
Show file tree
Hide file tree
Showing 5 changed files with 542 additions and 27 deletions.
28 changes: 25 additions & 3 deletions examples/train_ngp_nerf_occ.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,26 @@ def run(args):
**test_dataset_kwargs,
)

estimator = OccGridEstimator(
roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
).to(device)
if args.vdb:
from fvdb import sparse_grid_from_dense

from nerfacc.estimators.vdb import VDBEstimator

assert grid_nlvl == 1, "VDBEstimator only supports grid_nlvl=1"
voxel_sizes = (aabb[3:] - aabb[:3]) / grid_resolution
origins = aabb[:3] + voxel_sizes / 2
grid = sparse_grid_from_dense(
1,
(grid_resolution, grid_resolution, grid_resolution),
voxel_sizes=voxel_sizes,
origins=origins,
)
estimator = VDBEstimator(grid).to(device)
estimator.aabbs = [aabb]
else:
estimator = OccGridEstimator(
roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
).to(device)

# setup the radiance field we want to train.
grad_scaler = torch.cuda.amp.GradScaler(2**10)
Expand Down Expand Up @@ -278,6 +295,11 @@ def occ_eval_fn(x):
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
parser.add_argument(
"--vdb",
action="store_true",
help="use VDBEstimator instead of OccGridEstimator",
)
args = parser.parse_args()

run(args)
59 changes: 35 additions & 24 deletions examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,35 +92,46 @@ def render_image_with_occgrid(
rays_d = chunk_rays.viewdirs

def sigma_fn(t_starts, t_ends, ray_indices):
t_origins = rays_o[ray_indices]
t_dirs = rays_d[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
sigmas = radiance_field.query_density(positions, t)
if t_starts.shape[0] == 0:
sigmas = torch.empty((0, 1), device=t_starts.device)
else:
sigmas = radiance_field.query_density(positions)
t_origins = rays_o[ray_indices]
t_dirs = rays_d[ray_indices]
positions = (
t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
)
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
sigmas = radiance_field.query_density(positions, t)
else:
sigmas = radiance_field.query_density(positions)
return sigmas.squeeze(-1)

def rgb_sigma_fn(t_starts, t_ends, ray_indices):
t_origins = rays_o[ray_indices]
t_dirs = rays_d[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
rgbs, sigmas = radiance_field(positions, t, t_dirs)
if t_starts.shape[0] == 0:
rgbs = torch.empty((0, 3), device=t_starts.device)
sigmas = torch.empty((0, 1), device=t_starts.device)
else:
rgbs, sigmas = radiance_field(positions, t_dirs)
t_origins = rays_o[ray_indices]
t_dirs = rays_d[ray_indices]
positions = (
t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
)
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
rgbs, sigmas = radiance_field(positions, t, t_dirs)
else:
rgbs, sigmas = radiance_field(positions, t_dirs)
return rgbs, sigmas.squeeze(-1)

ray_indices, t_starts, t_ends = estimator.sampling(
Expand Down
4 changes: 4 additions & 0 deletions nerfacc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

from .data_specs import RayIntervals, RaySamples
from .estimators.occ_grid import OccGridEstimator
from .estimators.prop_net import PropNetEstimator
from .estimators.vdb import VDBEstimator, traverse_vdbs
from .grid import ray_aabb_intersect, traverse_grids
from .losses import distortion
from .pack import pack_info
Expand Down Expand Up @@ -46,7 +48,9 @@
"RaySamples",
"ray_aabb_intersect",
"traverse_grids",
"traverse_vdbs",
"OccGridEstimator",
"PropNetEstimator",
"VDBEstimator",
"distortion",
]
Loading

0 comments on commit 32273f8

Please sign in to comment.