-
Notifications
You must be signed in to change notification settings - Fork 112
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Ruilong Li
committed
Jan 3, 2024
1 parent
9f90842
commit 88a6aec
Showing
1 changed file
with
299 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,299 @@ | ||
import math | ||
from typing import Callable, List, Optional, Tuple, Union | ||
|
||
import torch | ||
from ..grid import _enlarge_aabb | ||
from ..volrend import ( | ||
render_visibility_from_alpha, | ||
render_visibility_from_density, | ||
) | ||
from .base import AbstractEstimator | ||
from torch import Tensor | ||
|
||
try: | ||
import svox | ||
except ImportError: | ||
raise ImportError( | ||
"Please install this forked version of svox: " | ||
"pip install git+https://github.com/liruilong940607/svox.git" | ||
) | ||
|
||
|
||
class N3TreeEstimator(AbstractEstimator):
    """Use N3Tree to implement Occupancy Grid.

    This allows more flexible topologies than the cascaded grid. However, it is
    slower to create samples from the tree than the cascaded grid. By default,
    it has the same topology as the cascaded grid but `self.tree` can be
    modified to have different topologies.
    """

    def __init__(
        self,
        roi_aabb: Union[List[int], Tensor],
        resolution: Union[int, List[int], Tensor] = 128,
        levels: int = 1,
        **kwargs,
    ) -> None:
        """Build a tree-backed occupancy estimator.

        Args:
            roi_aabb: Region of interest as {xmin, ymin, zmin, xmax, ymax,
                zmax}, given as a flat list/tuple or a (6,) tensor.
            resolution: Grid resolution. The N3Tree backend only supports a
                uniform (int) resolution. Default: 128.
            levels: Number of cascade levels. Each level doubles the aabb
                extent, mirroring the cascaded occupancy grid. Default: 1.

        Raises:
            ValueError: If the removed `contraction_type` argument is passed.
        """
        super().__init__()

        if "contraction_type" in kwargs:
            raise ValueError(
                "`contraction_type` is not supported anymore for nerfacc >= 0.4.0."
            )

        # check the resolution is legal
        assert isinstance(resolution, int), "N3Tree only supports uniform resolution!"

        # check the roi_aabb is legal
        if isinstance(roi_aabb, (list, tuple)):
            roi_aabb = torch.tensor(roi_aabb, dtype=torch.float32)
        assert isinstance(roi_aabb, Tensor), f"Invalid type: {roi_aabb}!"
        assert roi_aabb.shape[0] == 6, f"Invalid shape: {roi_aabb}!"
        roi_aabb = roi_aabb.cpu()

        # to be compatible with the OccupancyGrid: one aabb per cascade level,
        # each enlarged by a factor of 2**i around the roi center.
        aabbs = torch.stack(
            [_enlarge_aabb(roi_aabb, 2**i) for i in range(levels)], dim=0
        )
        self.register_buffer("aabbs", aabbs)  # [n_aabbs, 6]

        center = (roi_aabb[:3] + roi_aabb[3:]) / 2.0
        radius = (roi_aabb[3:] - roi_aabb[:3]) / 2.0 * 2 ** (levels - 1)

        # With branching factor N=2, `init_refine` subdivisions yield
        # 2 ** (init_refine + 1) leaves per side, i.e. `resolution`.
        tree_depth = int(math.log2(resolution)) - 1
        self.tree = svox.N3Tree(
            N=2,
            data_dim=1,
            init_refine=tree_depth,
            depth_limit=20,
            radius=radius.tolist(),
            center=center.tolist(),
        )
        # Refine the inner levels (outermost first) so leaves inside each
        # smaller aabb receive one extra subdivision per level, matching the
        # cascaded-grid topology.
        _aabbs = [_enlarge_aabb(roi_aabb, 2**i) for i in range(levels - 1)]
        for aabb in _aabbs[::-1]:
            leaf_c = self.tree.corners + self.tree.lengths * 0.5
            sel = ((leaf_c > aabb[:3]) & (leaf_c < aabb[3:])).all(dim=-1)
            self.tree[sel].refine()

        # Occupancy threshold: cells with value below this are skipped during
        # sampling and not counted as occupied. Updated by `_update`.
        self.thresh = 0.0

    @torch.no_grad()
    def sampling(
        self,
        # rays
        rays_o: Tensor,  # [n_rays, 3]
        rays_d: Tensor,  # [n_rays, 3]
        # sigma/alpha function for skipping invisible space
        sigma_fn: Optional[Callable] = None,
        alpha_fn: Optional[Callable] = None,
        near_plane: float = 0.0,
        far_plane: float = 1e10,
        t_min: Optional[Tensor] = None,  # [n_rays]
        t_max: Optional[Tensor] = None,  # [n_rays]
        # rendering options
        render_step_size: float = 1e-3,
        early_stop_eps: float = 1e-4,
        alpha_thre: float = 0.0,
        stratified: bool = False,
        cone_angle: float = 0.0,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Sampling with spatial skipping.

        Note:
            This function is not differentiable to any inputs.

        Args:
            rays_o: Ray origins of shape (n_rays, 3).
            rays_d: Normalized ray directions of shape (n_rays, 3).
            sigma_fn: Optional. If provided, the marching will skip the invisible space
                by evaluating the density along the ray with `sigma_fn`. It should be a
                function that takes in samples {t_starts (N,), t_ends (N,),
                ray indices (N,)} and returns the post-activation density values (N,).
                You should only provide either `sigma_fn` or `alpha_fn`.
            alpha_fn: Optional. If provided, the marching will skip the invisible space
                by evaluating the density along the ray with `alpha_fn`. It should be a
                function that takes in samples {t_starts (N,), t_ends (N,),
                ray indices (N,)} and returns the post-activation opacity values (N,).
                You should only provide either `sigma_fn` or `alpha_fn`.
            near_plane: Optional. Near plane distance. Default: 0.0.
            far_plane: Optional. Far plane distance. Default: 1e10.
            t_min: Optional. Per-ray minimum distance. Tensor with shape (n_rays).
                Not supported by this estimator; must be None.
            t_max: Optional. Per-ray maximum distance. Tensor with shape (n_rays).
                Not supported by this estimator; must be None.
            render_step_size: Step size for marching. Default: 1e-3.
            early_stop_eps: Early stop threshold for skipping invisible space. Default: 1e-4.
            alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
            stratified: Whether to use stratified sampling. Default: False.
            cone_angle: Cone angle for linearly-increased step size. 0. means
                constant step size. Default: 0.0.

        Returns:
            A tuple of {LongTensor, Tensor, Tensor}:

            - **ray_indices**: Ray index of each sample. IntTensor with shape (n_samples).
            - **t_starts**: Per-sample start distance. Tensor with shape (n_samples,).
            - **t_ends**: Per-sample end distance. Tensor with shape (n_samples,).

        Examples:

        .. code-block:: python

            >>> ray_indices, t_starts, t_ends = grid.sampling(
            >>>     rays_o, rays_d, render_step_size=1e-3)
            >>> t_mid = (t_starts + t_ends) / 2.0
            >>> sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]

        """

        assert t_min is None and t_max is None, (
            "Per-ray t_min/t_max are not supported. "
            "Please use near_plane and far_plane instead."
        )
        if stratified:
            # Jitter the start of marching by up to one step to decorrelate
            # sample positions across training iterations.
            near_plane += torch.rand(()).item() * render_step_size

        # March rays through the tree on the GPU; cells whose stored value is
        # below `self.thresh` are skipped entirely.
        t_starts, t_ends, packed_info, ray_indices = svox.volume_sample(
            self.tree,
            thresh=self.thresh,
            rays=svox.Rays(
                rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous()
            ),
            step_size=render_step_size,
            cone_angle=cone_angle,
            near_plane=near_plane,
            far_plane=far_plane,
        )
        packed_info = packed_info.long()
        ray_indices = ray_indices.long()

        # skip invisible space
        if (alpha_thre > 0.0 or early_stop_eps > 0.0) and (
            sigma_fn is not None or alpha_fn is not None
        ):
            # The tree already culled cells below `self.thresh`, so a larger
            # alpha threshold would be inconsistent with the marching above.
            alpha_thre = min(alpha_thre, self.thresh)

            # Compute visibility of the samples, and filter out invisible samples
            if sigma_fn is not None:
                if t_starts.shape[0] != 0:
                    sigmas = sigma_fn(t_starts, t_ends, ray_indices)
                else:
                    sigmas = torch.empty((0,), device=t_starts.device)
                assert (
                    sigmas.shape == t_starts.shape
                ), "sigmas must have shape of (N,)! Got {}".format(sigmas.shape)
                masks = render_visibility_from_density(
                    t_starts=t_starts,
                    t_ends=t_ends,
                    sigmas=sigmas,
                    ray_indices=ray_indices,
                    n_rays=len(rays_o),
                    early_stop_eps=early_stop_eps,
                    alpha_thre=alpha_thre,
                )
            elif alpha_fn is not None:
                if t_starts.shape[0] != 0:
                    alphas = alpha_fn(t_starts, t_ends, ray_indices)
                else:
                    alphas = torch.empty((0,), device=t_starts.device)
                assert (
                    alphas.shape == t_starts.shape
                ), "alphas must have shape of (N,)! Got {}".format(alphas.shape)
                masks = render_visibility_from_alpha(
                    alphas=alphas,
                    ray_indices=ray_indices,
                    n_rays=len(rays_o),
                    early_stop_eps=early_stop_eps,
                    alpha_thre=alpha_thre,
                )
            ray_indices, t_starts, t_ends = (
                ray_indices[masks],
                t_starts[masks],
                t_ends[masks],
            )
        return ray_indices, t_starts, t_ends

    @torch.no_grad()
    def update_every_n_steps(
        self,
        step: int,
        occ_eval_fn: Callable,
        occ_thre: float = 1e-2,
        ema_decay: float = 0.95,
        warmup_steps: int = 256,
        n: int = 16,
    ) -> None:
        """Update the estimator every n steps during training.

        Args:
            step: Current training step.
            occ_eval_fn: A function that takes in sample locations :math:`(N, 3)` and
                returns the occupancy values :math:`(N, 1)` at those locations.
            occ_thre: Threshold used to binarize the occupancy grid. Default: 1e-2.
            ema_decay: The decay rate for EMA updates. Default: 0.95.
            warmup_steps: Sample all cells during the warmup stage. After the warmup
                stage we change the sampling strategy to 1/4 uniformly sampled cells
                together with 1/4 occupied cells. Default: 256.
            n: Update the grid every n steps. Default: 16.

        Raises:
            RuntimeError: If called while the module is in eval mode.
        """
        if not self.training:
            raise RuntimeError(
                "You should only call this function during training. "
                "Please call _update() directly if you want to update the "
                "field during inference."
            )
        # `self.training` is guaranteed True here; only the step gate remains.
        if step % n == 0:
            self._update(
                step=step,
                occ_eval_fn=occ_eval_fn,
                occ_thre=occ_thre,
                ema_decay=ema_decay,
                warmup_steps=warmup_steps,
            )

    @torch.no_grad()
    def _sample_uniform_and_occupied_cells(self, n: int) -> Tensor:
        """Sample up to n uniform and n occupied leaf cells.

        Returns the concatenated leaf indices (shape (m,), where
        n <= m <= 2 * n, since fewer than n cells may be occupied).
        """
        uniform_indices = torch.randint(len(self.tree), (n,), device=self.device)
        occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[:, 0]
        if n < len(occupied_indices):
            # Subsample (with replacement) to keep at most n occupied cells.
            selector = torch.randint(len(occupied_indices), (n,), device=self.device)
            occupied_indices = occupied_indices[selector]
        indices = torch.cat([uniform_indices, occupied_indices], dim=0)
        return indices

    @torch.no_grad()
    def _update(
        self,
        step: int,
        occ_eval_fn: Callable,
        occ_thre: float = 0.01,
        ema_decay: float = 0.95,
        warmup_steps: int = 256,
    ) -> None:
        """Update the occ field in the EMA way."""
        if step < warmup_steps:
            # Warmup: evaluate one random location inside every leaf and take
            # the EMA maximum against the stored value.
            x = self.tree.sample(1).squeeze(1)
            occ = occ_eval_fn(x).squeeze(-1)
            sel = (*self.tree._all_leaves().T,)
            self.tree.data.data[sel] = torch.maximum(
                self.tree.data.data[sel] * ema_decay, occ[:, None]
            )
        else:
            # After warmup: only re-evaluate 1/4 uniform + 1/4 occupied cells.
            N = len(self.tree) // 4
            indices = self._sample_uniform_and_occupied_cells(N)
            x = self.tree[indices].sample(1).squeeze(1)
            occ = occ_eval_fn(x).squeeze(-1)
            self.tree[indices] = torch.maximum(
                self.tree[indices].values * ema_decay, occ[:, None]
            )
        # Keep the skip threshold no higher than the mean occupancy so that
        # mostly-empty scenes are not skipped entirely early in training.
        self.thresh = min(occ_thre, self.tree[:].values.mean().item())
|
||
|
||
if __name__ == "__main__":
    # Smoke test: build an estimator over the unit cube and drive a single
    # occupancy update with randomly generated values.
    unit_cube = [-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]
    estimator = N3TreeEstimator(unit_cube, 128, 4)

    def random_occ(x):
        # One random occupancy value in [0, 1) per sample location.
        return torch.rand(len(x), 1)

    estimator.update_every_n_steps(0, random_occ, occ_thre=0.5)