
Commit 898ba5c

Darijan Gudelj authored and facebook-github-bot committed
Moved MLP and Transformer
Summary: Moved the MLP and transformer from nerf to a new file to be reused.

Reviewed By: bottler

Differential Revision: D38828150

fbshipit-source-id: 8ff77b18b3aeeda398d90758a7bcb2482edce66f
1 parent edee25a commit 898ba5c

File tree

2 files changed: +321 -302 lines changed

@@ -0,0 +1,315 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

from typing import Optional, Tuple

import torch

logger = logging.getLogger(__name__)


class MLPWithInputSkips(torch.nn.Module):
    """
    Implements the multi-layer perceptron architecture of the Neural Radiance Field.

    As such, `MLPWithInputSkips` is a multi layer perceptron consisting
    of a sequence of linear layers with ReLU activations.

    Additionally, for a set of predefined layers `input_skips`, the forward pass
    appends a skip tensor `z` to the output of the preceding layer.

    Note that this follows the architecture described in the Supplementary
    Material (Fig. 7) of [1].

    References:
        [1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
            and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
            NeRF: Representing Scenes as Neural Radiance Fields for View
            Synthesis, ECCV2020
    """

    def _make_affine_layer(self, input_dim, hidden_dim):
        l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
        l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
        _xavier_init(l1)
        _xavier_init(l2)
        return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)

    def _apply_affine_layer(self, layer, x, z):
        mu_log_std = layer(z)
        mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
        std = torch.nn.functional.softplus(log_std)
        return (x - mu) * std

    def __init__(
        self,
        n_layers: int = 8,
        input_dim: int = 39,
        output_dim: int = 256,
        skip_dim: int = 39,
        hidden_dim: int = 256,
        input_skips: Tuple[int, ...] = (5,),
        skip_affine_trans: bool = False,
        no_last_relu=False,
    ):
        """
        Args:
            n_layers: The number of linear layers of the MLP.
            input_dim: The number of channels of the input tensor.
            output_dim: The number of channels of the output.
            skip_dim: The number of channels of the tensor `z` appended when
                evaluating the skip layers.
            hidden_dim: The number of hidden units of the MLP.
            input_skips: The list of layer indices at which we append the skip
                tensor `z`.
        """
        super().__init__()
        layers = []
        skip_affine_layers = []
        for layeri in range(n_layers):
            dimin = hidden_dim if layeri > 0 else input_dim
            dimout = hidden_dim if layeri + 1 < n_layers else output_dim

            if layeri > 0 and layeri in input_skips:
                if skip_affine_trans:
                    skip_affine_layers.append(
                        self._make_affine_layer(skip_dim, hidden_dim)
                    )
                else:
                    dimin = hidden_dim + skip_dim

            linear = torch.nn.Linear(dimin, dimout)
            _xavier_init(linear)
            layers.append(
                torch.nn.Sequential(linear, torch.nn.ReLU(True))
                if not no_last_relu or layeri + 1 < n_layers
                else linear
            )
        self.mlp = torch.nn.ModuleList(layers)
        if skip_affine_trans:
            self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
        self._input_skips = set(input_skips)
        self._skip_affine_trans = skip_affine_trans

    def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
        """
        Args:
            x: The input tensor of shape `(..., input_dim)`.
            z: The input skip tensor of shape `(..., skip_dim)` which is appended
                to layers whose indices are specified by `input_skips`.
        Returns:
            y: The output tensor of shape `(..., output_dim)`.
        """
        y = x
        if z is None:
            # if the skip tensor is None, we use `x` instead.
            z = x
        skipi = 0
        for li, layer in enumerate(self.mlp):
            if li in self._input_skips:
                if self._skip_affine_trans:
                    y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
                else:
                    y = torch.cat((y, z), dim=-1)
                skipi += 1
            y = layer(y)
        return y

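# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the committed file). With the
# defaults above (n_layers=8, input_dim=skip_dim=39, hidden_dim=output_dim=256,
# input_skips=(5,)), the module maps `(..., 39)` inputs to `(..., 256)` outputs,
# concatenating the skip tensor to the activations entering layer 5; when `z`
# is None, `x` itself is reused as the skip tensor. The sizes are hypothetical
# and the snippet assumes the whole module (including `_xavier_init`, defined
# at the end of the file) is importable.
#
#     >>> mlp = MLPWithInputSkips()
#     >>> x = torch.randn(2, 1024, 39)   # e.g. harmonically embedded ray points
#     >>> y = mlp(x)                     # x doubles as the skip tensor `z`
#     >>> y.shape
#     torch.Size([2, 1024, 256])
# ---------------------------------------------------------------------------
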
class TransformerWithInputSkips(torch.nn.Module):
    def __init__(
        self,
        n_layers: int = 8,
        input_dim: int = 39,
        output_dim: int = 256,
        skip_dim: int = 39,
        hidden_dim: int = 64,
        input_skips: Tuple[int, ...] = (5,),
        dim_down_factor: float = 1,
    ):
        """
        Args:
            n_layers: The number of linear layers of the MLP.
            input_dim: The number of channels of the input tensor.
            output_dim: The number of channels of the output.
            skip_dim: The number of channels of the tensor `z` appended when
                evaluating the skip layers.
            hidden_dim: The number of hidden units of the MLP.
            input_skips: The list of layer indices at which we append the skip
                tensor `z`.
        """
        super().__init__()

        self.first = torch.nn.Linear(input_dim, hidden_dim)
        _xavier_init(self.first)

        self.skip_linear = torch.nn.ModuleList()

        layers_pool, layers_ray = [], []
        dimout = 0
        for layeri in range(n_layers):
            dimin = int(round(hidden_dim / (dim_down_factor**layeri)))
            dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
            logger.info(f"Tr: {dimin} -> {dimout}")
            for _i, l in enumerate((layers_pool, layers_ray)):
                l.append(
                    TransformerEncoderLayer(
                        d_model=[dimin, dimout][_i],
                        nhead=4,
                        dim_feedforward=hidden_dim,
                        dropout=0.0,
                        d_model_out=dimout,
                    )
                )

            if layeri in input_skips:
                self.skip_linear.append(torch.nn.Linear(input_dim, dimin))

        self.last = torch.nn.Linear(dimout, output_dim)
        _xavier_init(self.last)

        # pyre-fixme[8]: Attribute has type `Tuple[ModuleList, ModuleList]`; used as
        #  `ModuleList`.
        self.layers_pool, self.layers_ray = (
            torch.nn.ModuleList(layers_pool),
            torch.nn.ModuleList(layers_ray),
        )
        self._input_skips = set(input_skips)

    def forward(
        self,
        x: torch.Tensor,
        z: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x: The input tensor of shape
                `(minibatch, n_pooled_feats, ..., n_ray_pts, input_dim)`.
            z: The input skip tensor of shape
                `(minibatch, n_pooled_feats, ..., n_ray_pts, skip_dim)`
                which is appended to layers whose indices are specified by `input_skips`.
        Returns:
            y: The output tensor of shape
                `(minibatch, 1, ..., n_ray_pts, input_dim)`.
        """

        if z is None:
            # if the skip tensor is None, we use `x` instead.
            z = x

        y = self.first(x)

        B, n_pool, n_rays, n_pts, dim = y.shape

        # y_p in n_pool, n_pts, B x n_rays x dim
        y_p = y.permute(1, 3, 0, 2, 4)

        skipi = 0
        dimh = dim
        for li, (layer_pool, layer_ray) in enumerate(
            zip(self.layers_pool, self.layers_ray)
        ):
            y_pool_attn = y_p.reshape(n_pool, n_pts * B * n_rays, dimh)
            if li in self._input_skips:
                z_skip = self.skip_linear[skipi](z)
                y_pool_attn = y_pool_attn + z_skip.permute(1, 3, 0, 2, 4).reshape(
                    n_pool, n_pts * B * n_rays, dimh
                )
                skipi += 1
            # n_pool x B*n_rays*n_pts x dim
            y_pool_attn, pool_attn = layer_pool(y_pool_attn, src_key_padding_mask=None)
            dimh = y_pool_attn.shape[-1]

            y_ray_attn = (
                y_pool_attn.view(n_pool, n_pts, B * n_rays, dimh)
                .permute(1, 0, 2, 3)
                .reshape(n_pts, n_pool * B * n_rays, dimh)
            )
            # n_pts x n_pool*B*n_rays x dim
            y_ray_attn, ray_attn = layer_ray(
                y_ray_attn,
                src_key_padding_mask=None,
            )

            y_p = y_ray_attn.view(n_pts, n_pool, B * n_rays, dimh).permute(1, 0, 2, 3)

        y = y_p.view(n_pool, n_pts, B, n_rays, dimh).permute(2, 0, 3, 1, 4)

        W = torch.softmax(y[..., :1], dim=1)
        y = (y * W).sum(dim=1)
        y = self.last(y)

        return y

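# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the committed file).
# `TransformerWithInputSkips` expects a 5D input of shape
# `(minibatch, n_pooled_feats, n_rays, n_ray_pts, input_dim)`; each layer
# attends first across the pooled-feature dimension and then across the
# ray-point dimension, and the pooled-feature dimension is finally collapsed
# by a softmax-weighted sum (taken without keepdim) before the last linear
# layer. The sizes below are hypothetical.
#
#     >>> net = TransformerWithInputSkips(
#     ...     n_layers=2, input_dim=39, output_dim=256, hidden_dim=64
#     ... )
#     >>> x = torch.randn(2, 4, 8, 16, 39)   # (batch, n_pooled_feats, n_rays, n_ray_pts, 39)
#     >>> y = net(x)
#     >>> y.shape                            # pooled-feature dimension reduced away
#     torch.Size([2, 8, 16, 256])
# ---------------------------------------------------------------------------
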
class TransformerEncoderLayer(torch.nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(
        self, d_model, nhead, dim_feedforward=2048, dropout=0.1, d_model_out=-1
    ):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = torch.nn.Linear(d_model, dim_feedforward)
        self.dropout = torch.nn.Dropout(dropout)
        d_model_out = d_model if d_model_out <= 0 else d_model_out
        self.linear2 = torch.nn.Linear(dim_feedforward, d_model_out)
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model_out)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        self.activation = torch.nn.functional.relu

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src2, attn = self.self_attn(
            src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        d_out = src2.shape[-1]
        src = src[..., :d_out] + self.dropout2(src2)[..., :d_out]
        src = self.norm2(src)
        return src, attn

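# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the committed file). Unlike
# `torch.nn.TransformerEncoderLayer`, this variant accepts `d_model_out`,
# which lets the output width differ from the input width, and it returns the
# attention weights alongside the output. The sizes below are hypothetical.
#
#     >>> layer = TransformerEncoderLayer(
#     ...     d_model=64, nhead=4, dim_feedforward=64, d_model_out=32
#     ... )
#     >>> src = torch.rand(10, 32, 64)   # (seq_len, batch, d_model)
#     >>> out, attn = layer(src)         # attn holds the (averaged) attention weights
#     >>> out.shape
#     torch.Size([10, 32, 32])
# ---------------------------------------------------------------------------
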
def _xavier_init(linear) -> None:
    """
    Performs the Xavier weight initialization of the linear layer `linear`.
    """
    torch.nn.init.xavier_uniform_(linear.weight.data)

0 commit comments