-
Notifications
You must be signed in to change notification settings - Fork 7
/
embedder.py
91 lines (72 loc) · 3.19 KB
/
embedder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import numpy as np
class Embedder:
    """NeRF-style positional encoding.

    Adapted from https://github.com/bmild/nerf.

    The encoder is configured entirely through keyword arguments:

    Args:
        input_dims: Size of the last dimension of the tensors to embed.
        include_input: If true, the raw input is the first output feature.
        max_freq_log2: log2 of the highest frequency.
        num_freqs: Number of frequency bands.
        log_sampling: If true, frequencies are ``2 ** linspace(0, max_freq_log2)``
            (evenly spaced exponents); otherwise they are evenly spaced between
            ``2 ** 0`` and ``2 ** max_freq_log2``.
        periodic_fns: Callables (e.g. ``torch.sin``, ``torch.cos``) applied to
            ``x * freq`` for every frequency band. Only read when there is at
            least one frequency band.

    Attributes:
        kwargs: The configuration dictionary as passed in.
        embed_fns: List of callables, one per output feature group.
        out_dim: ``input_dims * len(embed_fns)``.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build ``self.embed_fns`` and ``self.out_dim`` from ``self.kwargs``."""
        cfg = self.kwargs
        input_dims = cfg['input_dims']

        embed_fns = []
        if cfg['include_input']:
            embed_fns.append(lambda x: x)

        max_freq = cfg['max_freq_log2']
        num_freqs = cfg['num_freqs']
        if cfg['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, num_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, num_freqs)

        for freq in freq_bands:
            for p_fn in cfg['periodic_fns']:
                # Bind p_fn/freq as defaults so each lambda keeps its own pair
                # instead of the loop's final values (late-binding closures).
                embed_fns.append(
                    lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))

        self.embed_fns = embed_fns
        # Every embedding function contributes one group of input_dims values.
        self.out_dim = input_dims * len(embed_fns)

    def embed(self, inputs):
        """Apply every embedding function and concatenate along the last dim.

        Args:
            inputs: Tensor to embed. Presumably its last dimension has size
                ``input_dims`` (not checked here).

        Returns:
            The concatenation of all embedding outputs along the last
            dimension. When no embedding function is configured
            (``out_dim == 0``) a tensor with a zero-sized last dimension is
            returned.
        """
        if not self.embed_fns:
            # torch.cat rejects an empty list; return a zero-width result that
            # matches out_dim == 0 instead of crashing.
            return inputs.new_empty((*inputs.shape[:-1], 0))
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, input_dims=3):
    """Create a log-sampled sin/cos positional encoder.

    Args:
        multires: Number of frequency bands. The bands are the powers of two
            from ``2 ** 0`` up to ``2 ** (multires - 1)``.
        input_dims: Size of the last dimension of the tensors to embed.
            Defaults to 3.

    Returns:
        A tuple ``(embed, out_dim)``: ``embed(x)`` returns the encoded tensor
        (raw input followed by sin/cos features, concatenated on the last
        dimension) and ``out_dim`` is the embedder's reported output size.
    """
    embed_kwargs = {
        'include_input': True,
        'input_dims': input_dims,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }
    embedder_obj = Embedder(**embed_kwargs)

    # The embedder is bound as a default argument so the returned function
    # carries its own reference to it.
    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim
class FourierFeatureTransform(torch.nn.Module):
    """Gaussian Fourier feature mapping (random positional encoding).

    "Fourier Features Let Networks Learn High Frequency Functions in Low
    Dimensional Domains":
        https://arxiv.org/abs/2006.10739
        https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html

    Given an input of size [batches, num_input_channels], returns a tensor of
    size [batches, num_input_channels + 2 * mapping_size]: the raw input
    followed by ``sin(2 * pi * x @ B)`` and ``cos(2 * pi * x @ B)``, where
    ``B`` is a random Gaussian matrix of shape
    [num_input_channels, mapping_size] multiplied by ``scale``.

    Args:
        num_input_channels: Number of columns expected in the input.
        mapping_size: Number of random projections (columns of ``B``).
        scale: Multiplier applied to the standard-normal entries of ``B``.
    """

    def __init__(self, num_input_channels=3, mapping_size=256, scale=10):
        super().__init__()
        self._num_input_channels = num_input_channels
        self._mapping_size = mapping_size
        B = torch.randn((num_input_channels, mapping_size)) * scale
        # Iterating a 2D tensor yields its rows, so this orders the
        # num_input_channels rows of B by their L2 norm.
        # NOTE(review): for SAPE one would presumably order the per-frequency
        # columns instead -- confirm that sorting rows is intended.
        B_sort = sorted(B, key=lambda x: torch.norm(x, p=2))
        # Stored as a plain tensor attribute (not a Parameter or buffer), so it
        # is not trained and is not part of the module's state_dict.
        self._B = torch.stack(B_sort)  # for sape

    def forward(self, x):
        """Map ``x`` to ``[x, sin(2*pi*x@B), cos(2*pi*x@B)]`` along dim 1.

        Args:
            x: Tensor of shape [batches, num_input_channels].

        Returns:
            Tensor of shape [batches, num_input_channels + 2 * mapping_size].

        Raises:
            ValueError: If ``x`` is not 2-dimensional.
            AssertionError: If ``x`` has the wrong number of channels.
        """
        if x.dim() != 2:
            raise ValueError(
                'Expected 2D input [batches, channels] (got {}D input)'.format(x.dim()))
        channels = x.shape[1]
        assert channels == self._num_input_channels, \
            "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels)

        # Follow the input's device and, for floating inputs, its dtype so that
        # float64/half inputs do not hit a matmul dtype mismatch.
        proj_dtype = x.dtype if x.is_floating_point() else self._B.dtype
        B = self._B.to(device=x.device, dtype=proj_dtype)

        res = 2 * np.pi * (x @ B)
        return torch.cat([x, torch.sin(res), torch.cos(res)], dim=1)