From a04d5d174e565cc8f90815c1a8e55cb9ec38923f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=89=91=E5=8C=A3?= Date: Tue, 12 Jul 2022 20:50:33 +0800 Subject: [PATCH 1/2] add an einsum_distloss implementation --- README.md | 43 +++++++++++++++++++++++++++++++++++++++++++ test.py | 16 +++++++++++++--- test_time_mem.py | 38 ++++++++++++++++++++++++++++++-------- 3 files changed, 86 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index cc5519a..0cbf2d3 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,49 @@ def original_distloss(w, m, interval): return loss_uni + loss_bi ``` +**Updates**: Provide a python-interface implementation for distortion loss, no cuda kernels needed. + +```python +def einsum_distloss(w, m, interval): + ''' + Einsum realization of distortion loss. + There are B rays each with N sampled points. + w: Float tensor in shape [B,N]. Volume rendering weights of each point. + m: Float tensor in shape [N]. Midpoint distance to camera of each point. + interval: Scalar or float tensor in shape [B,N]. The query interval of each point. + Note: + The first term of distortion could be experssed as `(w @ mm @ w.T).diagonal()`, which + could be further accelerated by einsum function `torch.einsum('bq, qp, bp->b', w, mm, w)` + ''' + mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs() # [N,N] + loss = torch.einsum('bq, qp, bp->b', w, mm, w) + loss += (w*w*interval).sum(-1)/3. + return loss.mean() +``` +## Testing +### Numerical equivalent +Run `python test.py`. All our implementation is numerical equivalent to the `O(N^2)` `original_distloss`. + +### Speed and memeory benchmark +Run `python test_time_mem.py`. We use a batch of `B=8192` rays. Below is the results on my `Tesla V100` GPU. (We don't have `2080Ti` to test) +- Peak GPU memory (MB) + | \# of pts `N` | 32 | 64 | 128 | 256 | 384 | 512 | 1024| + |:------------:|:--:|:--:|:---:|:---:|:---:|:---:|:---:| + |`original_distloss` |102|396|1560|6192|OOM|OOM|OOM| + |`eff_distloss_native` |12|24|48|96|144|192|384| + |`eff_distloss` |14|28|56|112|168|224|448| + |`flatten_eff_distloss`|13|26|52|104|156|208|416| + |`einsum_distloss` |9|18|36|72|109|145|292| +- Run time accumulated over 100 runs (sec) + | \# of pts `N` | 32 | 64 | 128 | 256 | 384 | 512 |1024 | + |:------------:|:--:|:--:|:---:|:---:|:---:|:---:|:---:| + |`original_distloss` |0.4|0.6|3.3|14.9|OOM|OOM|OOM| + |`eff_distloss_native` |0.2|0.2|0.2|0.4|0.4|0.5|0.8| + |`eff_distloss` |0.2|0.2|0.2|0.3|0.5|0.6|0.9| + |`flatten_eff_distloss`|0.2|0.2|0.2|0.3|0.5|0.5|0.8| + |`einsum_distloss` |0.1|0.1|0.1|0.2|0.3|0.4|0.7| +----- + Unfortunately, the straightforward implementation results in `O(N^2)` space time complexity for N sampled points on a ray. In this package, we provide our `O(N)` realization presnted in the DVGOv2 report. Please cite mip-nerf-360 if you find this repo helpful. We will be happy if you also cite DVGOv2. diff --git a/test.py b/test.py index 0a25b9f..9e85fe4 100644 --- a/test.py +++ b/test.py @@ -26,6 +26,11 @@ def check(func_name, my_forward_val, ans_forward, ans_backward, w): ret = 'PASS' if torch.isclose(ans_backward, my_backward_grad).all() else 'FAIL' print(f'Test {func_name} backward:', ret) +def einsum_distloss(w, m, interval): + mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs() + loss = torch.einsum('bq, qp, bp->b', w, mm, w) + loss += (w*w*interval).sum(-1)/3. + return loss.mean() if __name__ == '__main__': # B rays N points @@ -35,8 +40,8 @@ def check(func_name, my_forward_val, ans_forward, ans_backward, w): w = w / w.sum(-1, keepdim=True) w = w.clone().requires_grad_() s = torch.linspace(0, 1, N+1).cuda() - m = (s[1:] + s[:-1]) * 0.5 - m = m[None].repeat(B,1) + m_ = (s[1:] + s[:-1]) * 0.5 + m = m_[None].repeat(B,1) interval = 1/N # Compute forward & backward answer @@ -72,7 +77,12 @@ def check(func_name, my_forward_val, ans_forward, ans_backward, w): 'eff_distloss array interval', eff_distloss(w, m, interval), ans_forward, ans_backward, w) - + + # check einsum implementation + interval = 1/N + check('einsum_distloss', + einsum_distloss(w, m_, interval), + ans_forward, ans_backward, w) # irregular shape, scalar interval interval = 1/N diff --git a/test_time_mem.py b/test_time_mem.py index eaeb6ef..83f3dd6 100644 --- a/test_time_mem.py +++ b/test_time_mem.py @@ -20,15 +20,31 @@ def original_distloss(w, m, interval): return loss_uni + loss_bi +def einsum_distloss(w, m, interval): + ''' + Einsum realization of distortion loss. + There are B rays each with N sampled points. + w: Float tensor in shape [B,N]. Volume rendering weights of each point. + m: Float tensor in shape [N]. Midpoint distance to camera of each point. + interval: Scalar or float tensor in shape [B,N]. The query interval of each point. + Note: + The first term of distortion could be experssed as `(w @ mm @ w.T).diagonal()`, which + could be further accelerated by einsum function `torch.einsum('bq, qp, bp->b', w, mm, w)` + ''' + mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs() # [N,N] + loss = torch.einsum('bq, qp, bp->b', w, mm, w) + loss += (w*w*interval).sum(-1)/3. + return loss.mean() + def gen_example(B, N): w = torch.rand(B, N).cuda() w = w / w.sum(-1, keepdim=True) w = w.clone().requires_grad_() s = torch.linspace(0, 1, N+1).cuda() - m = (s[1:] + s[:-1]) * 0.5 - m = m[None].repeat(B,1) + m_ = (s[1:] + s[:-1]) * 0.5 + m = m_[None].repeat(B,1) interval = 1/N - return w, m, interval + return w, m, interval, m_ def spec(f, NTIMES, *args): @@ -85,16 +101,16 @@ def spec(f, NTIMES, *args): B = 8192 NTIMES = 100 - for N in [32, 64, 128, 256, 384, 512]: + for N in [32, 64, 128, 256, 384, 512, 1024]: print(f' B={B}; N={N} '.center(50, '=')) - w, m, interval = gen_example(B, N) + w, m, interval, m_ = gen_example(B, N) ray_id = torch.arange(len(w))[:,None].repeat(1,N).cuda() try: - print(' original_distloss '.center(50, '.')) - spec(original_distloss, NTIMES, w, m, interval) + print(' original_distloss '.center(50, '.')) + spec(original_distloss, NTIMES, w, m, interval) except RuntimeError as e: - print(e) + print(e) try: print(' eff_distloss_native '.center(50, '.')) @@ -113,3 +129,9 @@ def spec(f, NTIMES, *args): spec(flatten_eff_distloss, NTIMES, w.flatten(), m.flatten(), interval, ray_id.flatten()) except RuntimeError as e: print(e) + + try: + print('einsum_distloss'.center(50, '.')) + spec(einsum_distloss, NTIMES, w, m_, interval) + except RuntimeError as e: + print(e) From a9f451ad1b5d200b4b1b4674bc5814d3dc878ea6 Mon Sep 17 00:00:00 2001 From: Spark001 Date: Tue, 12 Jul 2022 21:02:40 +0800 Subject: [PATCH 2/2] add an einsum_distloss implementation --- README.md | 43 +++++++++++++++++++++++++++++++++++++++++++ test.py | 16 +++++++++++++--- test_time_mem.py | 38 ++++++++++++++++++++++++++++++-------- 3 files changed, 86 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index cc5519a..0cbf2d3 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,49 @@ def original_distloss(w, m, interval): return loss_uni + loss_bi ``` +**Updates**: Provide a python-interface implementation for distortion loss, no cuda kernels needed. + +```python +def einsum_distloss(w, m, interval): + ''' + Einsum realization of distortion loss. + There are B rays each with N sampled points. + w: Float tensor in shape [B,N]. Volume rendering weights of each point. + m: Float tensor in shape [N]. Midpoint distance to camera of each point. + interval: Scalar or float tensor in shape [B,N]. The query interval of each point. + Note: + The first term of distortion could be experssed as `(w @ mm @ w.T).diagonal()`, which + could be further accelerated by einsum function `torch.einsum('bq, qp, bp->b', w, mm, w)` + ''' + mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs() # [N,N] + loss = torch.einsum('bq, qp, bp->b', w, mm, w) + loss += (w*w*interval).sum(-1)/3. + return loss.mean() +``` +## Testing +### Numerical equivalent +Run `python test.py`. All our implementation is numerical equivalent to the `O(N^2)` `original_distloss`. + +### Speed and memeory benchmark +Run `python test_time_mem.py`. We use a batch of `B=8192` rays. Below is the results on my `Tesla V100` GPU. (We don't have `2080Ti` to test) +- Peak GPU memory (MB) + | \# of pts `N` | 32 | 64 | 128 | 256 | 384 | 512 | 1024| + |:------------:|:--:|:--:|:---:|:---:|:---:|:---:|:---:| + |`original_distloss` |102|396|1560|6192|OOM|OOM|OOM| + |`eff_distloss_native` |12|24|48|96|144|192|384| + |`eff_distloss` |14|28|56|112|168|224|448| + |`flatten_eff_distloss`|13|26|52|104|156|208|416| + |`einsum_distloss` |9|18|36|72|109|145|292| +- Run time accumulated over 100 runs (sec) + | \# of pts `N` | 32 | 64 | 128 | 256 | 384 | 512 |1024 | + |:------------:|:--:|:--:|:---:|:---:|:---:|:---:|:---:| + |`original_distloss` |0.4|0.6|3.3|14.9|OOM|OOM|OOM| + |`eff_distloss_native` |0.2|0.2|0.2|0.4|0.4|0.5|0.8| + |`eff_distloss` |0.2|0.2|0.2|0.3|0.5|0.6|0.9| + |`flatten_eff_distloss`|0.2|0.2|0.2|0.3|0.5|0.5|0.8| + |`einsum_distloss` |0.1|0.1|0.1|0.2|0.3|0.4|0.7| +----- + Unfortunately, the straightforward implementation results in `O(N^2)` space time complexity for N sampled points on a ray. In this package, we provide our `O(N)` realization presnted in the DVGOv2 report. Please cite mip-nerf-360 if you find this repo helpful. We will be happy if you also cite DVGOv2. diff --git a/test.py b/test.py index 0a25b9f..9e85fe4 100644 --- a/test.py +++ b/test.py @@ -26,6 +26,11 @@ def check(func_name, my_forward_val, ans_forward, ans_backward, w): ret = 'PASS' if torch.isclose(ans_backward, my_backward_grad).all() else 'FAIL' print(f'Test {func_name} backward:', ret) +def einsum_distloss(w, m, interval): + mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs() + loss = torch.einsum('bq, qp, bp->b', w, mm, w) + loss += (w*w*interval).sum(-1)/3. + return loss.mean() if __name__ == '__main__': # B rays N points @@ -35,8 +40,8 @@ def check(func_name, my_forward_val, ans_forward, ans_backward, w): w = w / w.sum(-1, keepdim=True) w = w.clone().requires_grad_() s = torch.linspace(0, 1, N+1).cuda() - m = (s[1:] + s[:-1]) * 0.5 - m = m[None].repeat(B,1) + m_ = (s[1:] + s[:-1]) * 0.5 + m = m_[None].repeat(B,1) interval = 1/N # Compute forward & backward answer @@ -72,7 +77,12 @@ def check(func_name, my_forward_val, ans_forward, ans_backward, w): 'eff_distloss array interval', eff_distloss(w, m, interval), ans_forward, ans_backward, w) - + + # check einsum implementation + interval = 1/N + check('einsum_distloss', + einsum_distloss(w, m_, interval), + ans_forward, ans_backward, w) # irregular shape, scalar interval interval = 1/N diff --git a/test_time_mem.py b/test_time_mem.py index eaeb6ef..83f3dd6 100644 --- a/test_time_mem.py +++ b/test_time_mem.py @@ -20,15 +20,31 @@ def original_distloss(w, m, interval): return loss_uni + loss_bi +def einsum_distloss(w, m, interval): + ''' + Einsum realization of distortion loss. + There are B rays each with N sampled points. + w: Float tensor in shape [B,N]. Volume rendering weights of each point. + m: Float tensor in shape [N]. Midpoint distance to camera of each point. + interval: Scalar or float tensor in shape [B,N]. The query interval of each point. + Note: + The first term of distortion could be experssed as `(w @ mm @ w.T).diagonal()`, which + could be further accelerated by einsum function `torch.einsum('bq, qp, bp->b', w, mm, w)` + ''' + mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs() # [N,N] + loss = torch.einsum('bq, qp, bp->b', w, mm, w) + loss += (w*w*interval).sum(-1)/3. + return loss.mean() + def gen_example(B, N): w = torch.rand(B, N).cuda() w = w / w.sum(-1, keepdim=True) w = w.clone().requires_grad_() s = torch.linspace(0, 1, N+1).cuda() - m = (s[1:] + s[:-1]) * 0.5 - m = m[None].repeat(B,1) + m_ = (s[1:] + s[:-1]) * 0.5 + m = m_[None].repeat(B,1) interval = 1/N - return w, m, interval + return w, m, interval, m_ def spec(f, NTIMES, *args): @@ -85,16 +101,16 @@ def spec(f, NTIMES, *args): B = 8192 NTIMES = 100 - for N in [32, 64, 128, 256, 384, 512]: + for N in [32, 64, 128, 256, 384, 512, 1024]: print(f' B={B}; N={N} '.center(50, '=')) - w, m, interval = gen_example(B, N) + w, m, interval, m_ = gen_example(B, N) ray_id = torch.arange(len(w))[:,None].repeat(1,N).cuda() try: - print(' original_distloss '.center(50, '.')) - spec(original_distloss, NTIMES, w, m, interval) + print(' original_distloss '.center(50, '.')) + spec(original_distloss, NTIMES, w, m, interval) except RuntimeError as e: - print(e) + print(e) try: print(' eff_distloss_native '.center(50, '.')) @@ -113,3 +129,9 @@ def spec(f, NTIMES, *args): spec(flatten_eff_distloss, NTIMES, w.flatten(), m.flatten(), interval, ray_id.flatten()) except RuntimeError as e: print(e) + + try: + print('einsum_distloss'.center(50, '.')) + spec(einsum_distloss, NTIMES, w, m_, interval) + except RuntimeError as e: + print(e)