add an einsum_distloss implementation #2

Open · wants to merge 3 commits into `main`
43 changes: 43 additions & 0 deletions README.md
@@ -20,6 +20,49 @@ def original_distloss(w, m, interval):
    return loss_uni + loss_bi
```

**Updates**: We provide a Python-interface implementation of the distortion loss; no CUDA kernels are needed.

```python
def einsum_distloss(w, m, interval):
    '''
    Einsum realization of distortion loss.
    There are B rays, each with N sampled points.
    w: Float tensor in shape [B,N]. Volume rendering weight of each point.
    m: Float tensor in shape [N]. Midpoint distance to camera of each point.
    interval: Scalar or float tensor in shape [B,N]. The query interval of each point.
    Note:
        The first term of the distortion loss can be expressed as `(w @ mm @ w.T).diagonal()`,
        which is computed more efficiently by the einsum `torch.einsum('bq, qp, bp->b', w, mm, w)`.
    '''
    mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs()  # [N,N] pairwise |m_i - m_j|
    loss = torch.einsum('bq, qp, bp->b', w, mm, w)
    loss += (w*w*interval).sum(-1)/3.
    return loss.mean()
```
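
For reference, here is a minimal usage sketch (toy sizes for illustration; it assumes the `einsum_distloss` defined above is in scope):

```python
import torch

B, N = 4, 128                               # illustrative batch of rays and samples per ray
w = torch.rand(B, N)
w = w / w.sum(-1, keepdim=True)             # normalized volume-rendering weights, shape [B,N]
w = w.clone().requires_grad_()
s = torch.linspace(0, 1, N+1)
m = (s[1:] + s[:-1]) * 0.5                  # midpoints shared by all rays, shape [N]
interval = 1 / N                            # scalar query interval

loss = einsum_distloss(w, m, interval)
loss.backward()                             # gradients flow back to w

# The naive quadratic form gives the same value but materializes a [B,B] product:
# mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs()
# naive = ((w @ mm @ w.T).diagonal() + (w*w*interval).sum(-1)/3.).mean()
```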
## Testing
### Numerical equivalence
Run `python test.py`. All our implementations are numerically equivalent to the `O(N^2)` `original_distloss`.

### Speed and memory benchmark
Run `python test_time_mem.py`. We use a batch of `B=8192` rays. Below are the results on my `Tesla V100` GPU. (We don't have a `2080Ti` to test on.)
- Peak GPU memory (MB)
| \# of pts `N` | 32 | 64 | 128 | 256 | 384 | 512 | 1024|
|:------------:|:--:|:--:|:---:|:---:|:---:|:---:|:---:|
|`original_distloss` |102|396|1560|6192|OOM|OOM|OOM|
|`eff_distloss_native` |12|24|48|96|144|192|384|
|`eff_distloss` |14|28|56|112|168|224|448|
|`flatten_eff_distloss`|13|26|52|104|156|208|416|
|`einsum_distloss` |9|18|36|72|109|145|292|
- Run time accumulated over 100 runs (sec)
| \# of pts `N` | 32 | 64 | 128 | 256 | 384 | 512 |1024 |
|:------------:|:--:|:--:|:---:|:---:|:---:|:---:|:---:|
|`original_distloss` |0.4|0.6|3.3|14.9|OOM|OOM|OOM|
|`eff_distloss_native` |0.2|0.2|0.2|0.4|0.4|0.5|0.8|
|`eff_distloss` |0.2|0.2|0.2|0.3|0.5|0.6|0.9|
|`flatten_eff_distloss`|0.2|0.2|0.2|0.3|0.5|0.5|0.8|
|`einsum_distloss` |0.1|0.1|0.1|0.2|0.3|0.4|0.7|
-----

Unfortunately, the straightforward implementation has `O(N^2)` space and time complexity for `N` sampled points on a ray. In this package, we provide our `O(N)` realization presented in the DVGOv2 report.
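
For intuition, the pairwise term can be rearranged with cumulative sums when the midpoints are sorted along the ray, which is the idea behind the `O(N)` variants. Below is a rough sketch of that idea (an illustration under the sorted-midpoints assumption, not the package's exact `eff_distloss` code):

```python
def prefix_sum_distloss(w, m, interval):
    # Sketch of the O(N) idea. Assumes m is sorted in ascending order along the last dim.
    # sum_{i,j} w_i w_j |m_i - m_j| = 2 * sum_i w_i * (m_i * W_i - S_i),
    # where W_i = sum_{j<=i} w_j and S_i = sum_{j<=i} w_j m_j are inclusive prefix sums
    # (the j == i contribution cancels), so no [N,N] matrix is ever built.
    wm = w * m
    w_prefix = torch.cumsum(w, dim=-1)
    wm_prefix = torch.cumsum(wm, dim=-1)
    loss_bi = 2 * (wm * w_prefix - w * wm_prefix).sum(-1)
    loss_uni = (w * w * interval).sum(-1) / 3.
    return (loss_bi + loss_uni).mean()
```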

Please cite mip-nerf-360 if you find this repo helpful. We will be happy if you also cite DVGOv2.
16 changes: 13 additions & 3 deletions test.py
@@ -26,6 +26,11 @@ def check(func_name, my_forward_val, ans_forward, ans_backward, w):
    ret = 'PASS' if torch.isclose(ans_backward, my_backward_grad).all() else 'FAIL'
    print(f'Test {func_name} backward:', ret)

def einsum_distloss(w, m, interval):
    mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs()
    loss = torch.einsum('bq, qp, bp->b', w, mm, w)
    loss += (w*w*interval).sum(-1)/3.
    return loss.mean()

if __name__ == '__main__':
    # B rays N points
@@ -35,8 +40,8 @@ def check(func_name, my_forward_val, ans_forward, ans_backward, w):
    w = w / w.sum(-1, keepdim=True)
    w = w.clone().requires_grad_()
    s = torch.linspace(0, 1, N+1).cuda()
    m_ = (s[1:] + s[:-1]) * 0.5
    m = m_[None].repeat(B,1)
    interval = 1/N

    # Compute forward & backward answer
@@ -72,7 +77,12 @@ def check(func_name, my_forward_val, ans_forward, ans_backward, w):
        'eff_distloss array interval',
        eff_distloss(w, m, interval),
        ans_forward, ans_backward, w)


    # check the einsum implementation (einsum_distloss expects midpoints of shape [N], hence m_)
    interval = 1/N
    check('einsum_distloss',
        einsum_distloss(w, m_, interval),
        ans_forward, ans_backward, w)

    # irregular shape, scalar interval
    interval = 1/N
38 changes: 30 additions & 8 deletions test_time_mem.py
@@ -20,15 +20,31 @@ def original_distloss(w, m, interval):
    return loss_uni + loss_bi


def einsum_distloss(w, m, interval):
    '''
    Einsum realization of distortion loss.
    There are B rays, each with N sampled points.
    w: Float tensor in shape [B,N]. Volume rendering weight of each point.
    m: Float tensor in shape [N]. Midpoint distance to camera of each point.
    interval: Scalar or float tensor in shape [B,N]. The query interval of each point.
    Note:
        The first term of the distortion loss can be expressed as `(w @ mm @ w.T).diagonal()`,
        which is computed more efficiently by the einsum `torch.einsum('bq, qp, bp->b', w, mm, w)`.
    '''
    mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs()  # [N,N] pairwise |m_i - m_j|
    loss = torch.einsum('bq, qp, bp->b', w, mm, w)
    loss += (w*w*interval).sum(-1)/3.
    return loss.mean()

def gen_example(B, N):
    w = torch.rand(B, N).cuda()
    w = w / w.sum(-1, keepdim=True)
    w = w.clone().requires_grad_()
    s = torch.linspace(0, 1, N+1).cuda()
    m_ = (s[1:] + s[:-1]) * 0.5   # midpoints shared by all rays, shape [N] (used by einsum_distloss)
    m = m_[None].repeat(B,1)
    interval = 1/N
    return w, m, interval, m_


def spec(f, NTIMES, *args):
@@ -85,16 +101,16 @@ def spec(f, NTIMES, *args):
B = 8192
NTIMES = 100

for N in [32, 64, 128, 256, 384, 512, 1024]:
    print(f' B={B}; N={N} '.center(50, '='))
    w, m, interval, m_ = gen_example(B, N)
    ray_id = torch.arange(len(w))[:,None].repeat(1,N).cuda()

    try:
        print(' original_distloss '.center(50, '.'))
        spec(original_distloss, NTIMES, w, m, interval)
    except RuntimeError as e:
        print(e)

    try:
        print(' eff_distloss_native '.center(50, '.'))
@@ -113,3 +129,9 @@ def spec(f, NTIMES, *args):
        spec(flatten_eff_distloss, NTIMES, w.flatten(), m.flatten(), interval, ray_id.flatten())
    except RuntimeError as e:
        print(e)

    try:
        print('einsum_distloss'.center(50, '.'))
        spec(einsum_distloss, NTIMES, w, m_, interval)
    except RuntimeError as e:
        print(e)