Skip to content

Commit f9a26a2

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Move Harmonic embedding to core pytorch3d
Summary: Moved the `HarmonicEmbedding` module into core PyTorch3D. The next diff will update the NeRF project. Reviewed By: bottler Differential Revision: D32833808 fbshipit-source-id: 0a12ccd1627c0ce024463c796544c91eb8d4d122
1 parent d67662d commit f9a26a2

File tree

4 files changed

+179
-1
lines changed

4 files changed

+179
-1
lines changed

pytorch3d/renderer/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
VolumeSampler,
3838
ray_bundle_to_ray_points,
3939
ray_bundle_variables_to_ray_points,
40+
HarmonicEmbedding,
4041
)
4142
from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular
4243
from .materials import Materials

pytorch3d/renderer/implicit/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from .harmonic_embedding import HarmonicEmbedding
78
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
89
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
910
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
@@ -13,5 +14,4 @@
1314
ray_bundle_variables_to_ray_points,
1415
)
1516

16-
1717
__all__ = [k for k in globals().keys() if not k.startswith("_")]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
9+
10+
class HarmonicEmbedding(torch.nn.Module):
    """
    Layer that maps every scalar of the last dimension of its input to a
    set of sine and cosine features evaluated at fixed frequencies.
    """

    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
        """
        Given an input tensor `x` of shape [minibatch, ... , dim],
        the harmonic embedding layer converts each feature
        (i.e. vector along the last dimension) in `x`
        into a series of harmonic features `embedding`,
        where for each i in range(dim) the following are present
        in embedding[...]:
            ```
            [
                sin(f_1*x[..., i]),
                sin(f_2*x[..., i]),
                ...
                sin(f_N * x[..., i]),
                cos(f_1*x[..., i]),
                cos(f_2*x[..., i]),
                ...
                cos(f_N * x[..., i]),
                x[..., i],              # only present if append_input is True.
            ]
            ```
        where N is `n_harmonic_functions`, and f_i is a scalar
        denoting the i-th frequency of the harmonic embedding.

        If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
        powers of 2:
            `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`

        If `logspace==False`, frequencies are linearly spaced between
        `1.0` and `2**(n_harmonic_functions-1)`:
            `f_1, ..., f_N = torch.linspace(
                1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
            )`

        The frequencies are additionally scaled by the base frequency
        `omega_0`, which is equivalent to premultiplying `x` by `omega_0`
        before evaluating the harmonic functions.

        Args:
            n_harmonic_functions: int, number of harmonic
                features
            omega_0: float, base frequency
            logspace: bool, Whether to space the frequencies in
                logspace or linear space
            append_input: bool, whether to concat the original
                input to the harmonic embedding. If true the
                output is of the form (embed.sin(), embed.cos(), x)
        """
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                n_harmonic_functions,
                dtype=torch.float32,
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )

        # Registered as a non-persistent buffer: it follows the module across
        # devices but is not written to the state dict.
        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
        self.append_input = append_input

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
        Returns:
            embedding: a harmonic embedding of `x`
                of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]

        Along the last dimension the output holds all sine features first,
        then all cosine features, then (optionally) `x` itself. Within the
        sine and cosine groups, the features of x[..., 0] at every frequency
        come first, followed by those of x[..., 1], and so on.
        """
        # `reshape` gives the same result as `view` whenever `view` is valid,
        # and additionally copes with products whose memory layout cannot be
        # viewed (e.g. when `x` is a transposed / non-contiguous tensor).
        embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
        features = [embed.sin(), embed.cos()]
        if self.append_input:
            features.append(x)
        return torch.cat(features, dim=-1)

    @staticmethod
    def get_output_dim_static(
        input_dims: int,
        n_harmonic_functions: int,
        append_input: bool,
    ) -> int:
        """
        Utility to help predict the shape of the output of `forward`.

        Args:
            input_dims: length of the last dimension of the input tensor
            n_harmonic_functions: number of embedding frequencies
            append_input: whether or not to concat the original
                input to the harmonic embedding
        Returns:
            int: the length of the last dimension of the output tensor
        """
        return input_dims * (2 * n_harmonic_functions + int(append_input))

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as `get_output_dim_static`, using the number of frequencies and
        the `append_input` setting of this instance. The default for
        input_dims is 3 for 3D applications which use harmonic embedding for
        positional encoding, so the input might be xyz.
        """
        return self.get_output_dim_static(
            input_dims, len(self._frequencies), self.append_input
        )

tests/test_harmonic_embedding.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from common_testing import TestCaseMixin
11+
from pytorch3d.renderer.implicit import HarmonicEmbedding
12+
13+
14+
class TestHarmonicEmbedding(TestCaseMixin, unittest.TestCase):
    """Unit tests for the `HarmonicEmbedding` layer."""

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed so the random input used below is reproducible.
        torch.manual_seed(1)

    def test_correct_output_dim(self):
        """The static and instance output-dim helpers agree with the formula."""
        embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
        # input_dims * (2 * n_harmonic_functions + int(append_input))
        output_dim = 3 * (2 * 2 + int(False))
        self.assertEqual(
            output_dim,
            embed_fun.get_output_dim_static(
                input_dims=3, n_harmonic_functions=2, append_input=False
            ),
        )
        # get_output_dim defaults to input_dims=3.
        self.assertEqual(output_dim, embed_fun.get_output_dim())

    def test_correct_frequency_range(self):
        """Log-spaced and linearly spaced frequencies cover [1, 2**(n-1)]."""
        embed_fun_log = HarmonicEmbedding(n_harmonic_functions=3)
        embed_fun_lin = HarmonicEmbedding(n_harmonic_functions=3, logspace=False)
        self.assertClose(embed_fun_log._frequencies, torch.tensor((1.0, 2.0, 4.0)))
        self.assertClose(embed_fun_lin._frequencies, torch.tensor((1.0, 2.5, 4.0)))

    def test_correct_embed_out(self):
        """The embedding has the expected shape, sin/cos pairing and input tail."""
        embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
        x = torch.randn((1, 5))
        D = 5 * 4
        embed_out = embed_fun(x)
        self.assertEqual(embed_out.shape, (1, D))
        # The first half holds the sines and the second half the matching
        # cosines, so sin**2 + cos**2 must be 1 for every (feature, frequency).
        sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2
        self.assertClose(sum_squares, torch.ones((D // 2)))
        embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=True)
        embed_out = embed_fun(x)
        # Shapes are exact integers: compare them with assertEqual, as above.
        self.assertEqual(embed_out.shape, (1, 5 * 5))
        # Last plane in output is the input
        self.assertClose(embed_out[..., -5:], x)

0 commit comments

Comments
 (0)