Distortion loss was proposed in Mip-NeRF 360. It encourages the volume rendering weights to be compact and sparse, which alleviates floater and background-collapse artifacts. In our DVGOv2 report, we show that distortion loss also helps point-based queries: it speeds up our training and gives better quantitative results.
PyTorch pseudo-code for the distortion loss:
def original_distloss(w, m, interval):
    '''
    Original O(N^2) realization of distortion loss.
    There are B rays each with N sampled points.
    w:        Float tensor in shape [B,N]. Volume rendering weights of each point.
    m:        Float tensor in shape [B,N]. Midpoint distance to camera of each point.
    interval: Scalar or float tensor in shape [B,N]. The query interval of each point.
    '''
    # Uni-lateral term: penalizes the weight spread within each interval.
    loss_uni = (1/3) * (interval * w.pow(2)).sum(-1).mean()
    # Bi-lateral term: penalizes weight pairs that are far apart along the ray.
    ww = w.unsqueeze(-1) * w.unsqueeze(-2)          # [B,N,N]
    mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs()  # [B,N,N]
    loss_bi = (ww * mm).sum((-1,-2)).mean()
    return loss_uni + loss_bi
Unfortunately, the straightforward implementation has O(N^2) space and time complexity for N sampled points on a ray. In this package, we provide the O(N) realization presented in the DVGOv2 report.
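To illustrate how the pairwise term can be reduced to linear cost, here is a minimal sketch based on prefix sums. It assumes that m is sorted in ascending order along each ray, so the absolute value can be dropped for pairs with j < i. The function name prefix_sum_distloss is hypothetical and the code is not necessarily identical to what the package ships.

```python
import torch

def prefix_sum_distloss(w, m, interval):
    """O(N) distortion loss sketch. Assumes m is sorted ascending along the last dim."""
    # Uni-lateral term: same as in the O(N^2) version.
    loss_uni = (1 / 3) * (interval * w.pow(2)).sum(-1).mean()
    # For sorted m:
    #   sum_{i,j} w_i w_j |m_i - m_j|
    #     = 2 * sum_i ( w_i m_i * sum_{j<i} w_j  -  w_i * sum_{j<i} w_j m_j )
    wm = w * m
    w_prefix = w.cumsum(-1)[..., :-1]    # sum_{j<i} w_j     for i = 1..N-1
    wm_prefix = wm.cumsum(-1)[..., :-1]  # sum_{j<i} w_j m_j for i = 1..N-1
    loss_bi = 2 * (wm[..., 1:] * w_prefix - w[..., 1:] * wm_prefix).sum(-1).mean()
    return loss_uni + loss_bi

# Small CPU example in float64 (sizes are made up for illustration).
torch.manual_seed(0)
B, N = 4, 16
w = torch.rand(B, N, dtype=torch.float64)
w = w / w.sum(-1, keepdim=True)
s = torch.linspace(0, 1, N + 1, dtype=torch.float64)
m = ((s[1:] + s[:-1]) * 0.5)[None].repeat(B, 1)
loss_fast = prefix_sum_distloss(w, m, 1 / N)
print('Loss', loss_fast)
```

Only [B,N] tensors are allocated here, so memory grows linearly with N instead of quadratically.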
Please cite Mip-NeRF 360 if you find this repo helpful. We would be happy if you also cite DVGOv2.
@inproceedings{BarronMVSH22,
author = {Jonathan T. Barron and
Ben Mildenhall and
Dor Verbin and
Pratul P. Srinivasan and
Peter Hedman},
title = {Mip-NeRF 360: Unbounded Anti-Aliased Neural Radiance Fields},
booktitle = {CVPR},
year = {2022},
}
@article{SunSC22_2,
author = {Cheng Sun and
Min Sun and
Hwann{-}Tzong Chen},
title = {Improved Direct Voxel Grid Optimization for Radiance Fields Reconstruction},
journal = {arxiv cs.GR 2206.05085},
year = {2022},
}
pip install torch_efficient_distloss
This assumes that PyTorch and NumPy are already installed.
All functions run in O(N) and are numerically equivalent to the original distortion loss.
import torch
from torch_efficient_distloss import eff_distloss, eff_distloss_native, flatten_eff_distloss
# A toy example
B = 8192 # number of rays
N = 128 # number of points on a ray
w = torch.rand(B, N).cuda()
w = w / w.sum(-1, keepdim=True)
w = w.clone().requires_grad_()
s = torch.linspace(0, 1, N+1).cuda()
m = (s[1:] + s[:-1]) * 0.5
m = m[None].repeat(B,1)
interval = 1/N
loss = 0.01 * eff_distloss(w, m, interval)
loss.backward()
print('Loss', loss)
print('Gradient', w.grad)
- eff_distloss_native:
  - Uses built-in PyTorch operations to implement the O(N) distortion loss.
  - Input:
    - w: Float tensor in shape [B,N]. Volume rendering weights of each point.
    - m: Float tensor in shape [B,N]. Midpoint distance to camera of each point.
    - interval: Scalar or float tensor in shape [B,N]. The query interval of each point.
- eff_distloss:
  - The same as eff_distloss_native. Slightly faster, but consumes slightly more GPU memory.
- flatten_eff_distloss:
  - Supports a varying number of sampled points on each ray.
  - All input tensors should be flattened.
  - An additional flattened Long tensor ray_id must be provided to specify the ray index of each point. ray_id should be non-decreasing (i.e., ray_id[i-1]<=ray_id[i]) and range from 0 to N-1 (N here being the number of rays).
- A loss weight of around 0.01 to 0.001 is recommended.
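The flattened layout described above can be sketched as follows. The example builds w, m, interval and ray_id for three rays with different sample counts, and evaluates a brute-force per-ray reference loss. The helper flat_reference_distloss is hypothetical and not part of the package; averaging over rays is an assumption made here to mirror original_distloss.

```python
import torch

def flat_reference_distloss(w, m, interval, ray_id):
    """Brute-force distortion loss on flattened inputs (hypothetical helper)."""
    # ray_id must be non-decreasing so that each ray's points are contiguous.
    assert (ray_id[1:] >= ray_id[:-1]).all()
    n_rays = int(ray_id.max()) + 1
    total = w.new_zeros(())
    for r in range(n_rays):
        mask = ray_id == r
        wr, mr, ir = w[mask], m[mask], interval[mask]
        loss_uni = (1 / 3) * (ir * wr.pow(2)).sum()
        loss_bi = (wr[:, None] * wr[None, :] * (mr[:, None] - mr[None, :]).abs()).sum()
        total = total + loss_uni + loss_bi
    # Assumption: average over rays, as original_distloss does with .mean().
    return total / n_rays

# Three rays with 2, 4 and 3 samples (made-up sizes for illustration).
torch.manual_seed(0)
counts = torch.tensor([2, 4, 3])
ray_id = torch.repeat_interleave(torch.arange(len(counts)), counts)
# Per-ray midpoints of a uniform partition of [0, 1], and the matching interval.
m = torch.cat([(torch.arange(n) + 0.5) / n for n in counts.tolist()])
interval = torch.cat([torch.full((n,), 1.0 / n) for n in counts.tolist()])
# Random weights, normalized so that each ray sums to one.
w = torch.rand(len(ray_id))
ray_sum = torch.zeros(len(counts)).index_add_(0, ray_id, w)
w = w / ray_sum[ray_id]
flat_loss = flat_reference_distloss(w, m, interval, ray_id)
print('ray_id', ray_id.tolist())
print('Loss', flat_loss)
```

Every input is a 1D tensor with one entry per sampled point, and ray_id is the only piece of information that groups points back into rays.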
Run python test.py. All our implementations are numerically equivalent to the O(N^2) original_distloss.
Run python test_time_mem.py. We use a batch of B=8192 rays. Below are the results on my RTX 2080Ti GPU.
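- Peak GPU memory (MB)

| # of pts N           | 32  | 64  | 128  | 256  | 384 | 512 |
|----------------------|-----|-----|------|------|-----|-----|
| original_distloss    | 102 | 396 | 1560 | 6192 | OOM | OOM |
| eff_distloss_native  | 12  | 24  | 48   | 96   | 144 | 192 |
| eff_distloss         | 14  | 28  | 56   | 112  | 168 | 224 |
| flatten_eff_distloss | 13  | 26  | 52   | 104  | 156 | 208 |

- Run time accumulated over 100 runs (sec)

| # of pts N           | 32  | 64  | 128 | 256  | 384 | 512 |
|----------------------|-----|-----|-----|------|-----|-----|
| original_distloss    | 0.2 | 0.8 | 2.4 | 17.9 | OOM | OOM |
| eff_distloss_native  | 0.1 | 0.1 | 0.1 | 0.2  | 0.3 | 0.3 |
| eff_distloss         | 0.1 | 0.1 | 0.1 | 0.1  | 0.2 | 0.2 |
| flatten_eff_distloss | 0.1 | 0.1 | 0.1 | 0.2  | 0.2 | 0.3 |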