|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

from typing import Optional, Tuple

import torch

logger = logging.getLogger(__name__)


class MLPWithInputSkips(torch.nn.Module):
    """
    Implements the multi-layer perceptron architecture of the Neural Radiance Field.

    As such, `MLPWithInputSkips` is a multi-layer perceptron consisting
    of a sequence of linear layers with ReLU activations.

    Additionally, for a set of predefined layers `input_skips`, the forward pass
    appends a skip tensor `z` to the output of the preceding layer.

    Note that this follows the architecture described in the Supplementary
    Material (Fig. 7) of [1].

    References:
        [1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
            and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
            NeRF: Representing Scenes as Neural Radiance Fields for View
            Synthesis, ECCV2020
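
    Example (illustrative; the shapes follow the default constructor arguments)::

        >>> mlp = MLPWithInputSkips(input_dim=39, skip_dim=39, input_skips=(5,))
        >>> x = torch.rand(2, 1024, 39)  # e.g. harmonic-embedded ray points
        >>> y = mlp(x)  # the skip tensor `z` defaults to `x`; y.shape == (2, 1024, 256)
        >>> mlp_affine = MLPWithInputSkips(skip_affine_trans=True)
        >>> y_affine = mlp_affine(x)  # same shape; affine skip instead of concatenation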
    """

    def _make_affine_layer(self, input_dim, hidden_dim):
        # Builds the small MLP that predicts a per-point (mu, log_std) pair
        # from the skip tensor; used only when `skip_affine_trans=True`.
        l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
        l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
        _xavier_init(l1)
        _xavier_init(l2)
        return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)

    def _apply_affine_layer(self, layer, x, z):
        # Affine skip connection: shift `x` by the predicted mean and scale it
        # by a positive (softplus-ed) standard deviation predicted from `z`.
        mu_log_std = layer(z)
        mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
        std = torch.nn.functional.softplus(log_std)
        return (x - mu) * std

    def __init__(
        self,
        n_layers: int = 8,
        input_dim: int = 39,
        output_dim: int = 256,
        skip_dim: int = 39,
        hidden_dim: int = 256,
        input_skips: Tuple[int, ...] = (5,),
        skip_affine_trans: bool = False,
        no_last_relu=False,
    ):
        """
        Args:
            n_layers: The number of linear layers of the MLP.
            input_dim: The number of channels of the input tensor.
            output_dim: The number of channels of the output.
            skip_dim: The number of channels of the tensor `z` appended when
                evaluating the skip layers.
            hidden_dim: The number of hidden units of the MLP.
            input_skips: The list of layer indices at which we append the skip
                tensor `z`.
            skip_affine_trans: If `True`, the skip layers transform the hidden
                activations with an affine function of `z` (see
                `_apply_affine_layer`) instead of concatenating `z` to them.
            no_last_relu: If `True`, the final linear layer is not followed by
                a ReLU activation.
        """
        super().__init__()
        layers = []
        skip_affine_layers = []
        for layeri in range(n_layers):
            dimin = hidden_dim if layeri > 0 else input_dim
            dimout = hidden_dim if layeri + 1 < n_layers else output_dim

            if layeri > 0 and layeri in input_skips:
                if skip_affine_trans:
                    skip_affine_layers.append(
                        self._make_affine_layer(skip_dim, hidden_dim)
                    )
                else:
                    dimin = hidden_dim + skip_dim

            linear = torch.nn.Linear(dimin, dimout)
            _xavier_init(linear)
            layers.append(
                torch.nn.Sequential(linear, torch.nn.ReLU(True))
                if not no_last_relu or layeri + 1 < n_layers
                else linear
            )
        self.mlp = torch.nn.ModuleList(layers)
        if skip_affine_trans:
            self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
        self._input_skips = set(input_skips)
        self._skip_affine_trans = skip_affine_trans

    def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
        """
        Args:
            x: The input tensor of shape `(..., input_dim)`.
            z: The input skip tensor of shape `(..., skip_dim)` which is appended
                to layers whose indices are specified by `input_skips`.
        Returns:
            y: The output tensor of shape `(..., output_dim)`.
        """
        y = x
        if z is None:
            # if the skip tensor is None, we use `x` instead.
            z = x
        skipi = 0
        for li, layer in enumerate(self.mlp):
            if li in self._input_skips:
                if self._skip_affine_trans:
                    y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
                else:
                    y = torch.cat((y, z), dim=-1)
                skipi += 1
            y = layer(y)
        return y


class TransformerWithInputSkips(torch.nn.Module):
    def __init__(
        self,
        n_layers: int = 8,
        input_dim: int = 39,
        output_dim: int = 256,
        skip_dim: int = 39,
        hidden_dim: int = 64,
        input_skips: Tuple[int, ...] = (5,),
        dim_down_factor: float = 1,
    ):
        """
        Args:
            n_layers: The number of linear layers of the MLP.
            input_dim: The number of channels of the input tensor.
            output_dim: The number of channels of the output.
            skip_dim: The number of channels of the tensor `z` appended when
                evaluating the skip layers.
            hidden_dim: The number of hidden units of the MLP.
            input_skips: The list of layer indices at which we append the skip
                tensor `z`.
            dim_down_factor: The factor by which the hidden dimension is
                divided after each transformer layer.
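
        Example (illustrative; shapes follow the default constructor arguments
        and the 5D input layout expected by `forward`)::

            >>> net = TransformerWithInputSkips()
            >>> x = torch.rand(2, 3, 8, 16, 39)
            >>> # x: (minibatch, n_pooled_feats, n_rays, n_ray_pts, input_dim)
            >>> y = net(x)  # y.shape == (2, 8, 16, 256)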
        """
        super().__init__()

        self.first = torch.nn.Linear(input_dim, hidden_dim)
        _xavier_init(self.first)

        self.skip_linear = torch.nn.ModuleList()

        layers_pool, layers_ray = [], []
        dimout = 0
        for layeri in range(n_layers):
            dimin = int(round(hidden_dim / (dim_down_factor**layeri)))
            dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
            logger.info(f"Tr: {dimin} -> {dimout}")
            for _i, l in enumerate((layers_pool, layers_ray)):
                l.append(
                    TransformerEncoderLayer(
                        d_model=[dimin, dimout][_i],
                        nhead=4,
                        dim_feedforward=hidden_dim,
                        dropout=0.0,
                        d_model_out=dimout,
                    )
                )

            if layeri in input_skips:
                self.skip_linear.append(torch.nn.Linear(input_dim, dimin))

        self.last = torch.nn.Linear(dimout, output_dim)
        _xavier_init(self.last)

        # pyre-fixme[8]: Attribute has type `Tuple[ModuleList, ModuleList]`; used as
        # `ModuleList`.
        self.layers_pool, self.layers_ray = (
            torch.nn.ModuleList(layers_pool),
            torch.nn.ModuleList(layers_ray),
        )
        self._input_skips = set(input_skips)

    def forward(
        self,
        x: torch.Tensor,
        z: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x: The input tensor of shape
                `(minibatch, n_pooled_feats, n_rays, n_ray_pts, input_dim)`.
            z: The input skip tensor of shape
                `(minibatch, n_pooled_feats, n_rays, n_ray_pts, skip_dim)`
                which is appended to layers whose indices are specified by `input_skips`.
        Returns:
            y: The output tensor of shape
                `(minibatch, n_rays, n_ray_pts, output_dim)`.
        """

        if z is None:
            # if the skip tensor is None, we use `x` instead.
            z = x

        y = self.first(x)

        B, n_pool, n_rays, n_pts, dim = y.shape

        # y_p is of shape (n_pool, n_pts, B, n_rays, dim)
        y_p = y.permute(1, 3, 0, 2, 4)

        skipi = 0
        dimh = dim
        for li, (layer_pool, layer_ray) in enumerate(
            zip(self.layers_pool, self.layers_ray)
        ):
            y_pool_attn = y_p.reshape(n_pool, n_pts * B * n_rays, dimh)
            if li in self._input_skips:
                z_skip = self.skip_linear[skipi](z)
                y_pool_attn = y_pool_attn + z_skip.permute(1, 3, 0, 2, 4).reshape(
                    n_pool, n_pts * B * n_rays, dimh
                )
                skipi += 1
            # attention over the pooled-feature dimension:
            # n_pool x B*n_rays*n_pts x dim
            y_pool_attn, pool_attn = layer_pool(y_pool_attn, src_key_padding_mask=None)
            dimh = y_pool_attn.shape[-1]

            y_ray_attn = (
                y_pool_attn.view(n_pool, n_pts, B * n_rays, dimh)
                .permute(1, 0, 2, 3)
                .reshape(n_pts, n_pool * B * n_rays, dimh)
            )
            # attention over the points along each ray:
            # n_pts x n_pool*B*n_rays x dim
            y_ray_attn, ray_attn = layer_ray(
                y_ray_attn,
                src_key_padding_mask=None,
            )

            y_p = y_ray_attn.view(n_pts, n_pool, B * n_rays, dimh).permute(1, 0, 2, 3)

        y = y_p.view(n_pool, n_pts, B, n_rays, dimh).permute(2, 0, 3, 1, 4)

        # pool over the pooled-feature dimension (dim=1) with softmax weights
        # predicted by the first feature channel
        W = torch.softmax(y[..., :1], dim=1)
        y = (y * W).sum(dim=1)
        y = self.last(y)

        return y


class TransformerEncoderLayer(torch.nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        d_model_out: the number of output features; when non-positive, it defaults
            to `d_model` (default=-1).

    The activation of the intermediate feedforward layer is fixed to relu.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out, attn = encoder_layer(src)
    """

    def __init__(
        self, d_model, nhead, dim_feedforward=2048, dropout=0.1, d_model_out=-1
    ):
        super().__init__()
        self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = torch.nn.Linear(d_model, dim_feedforward)
        self.dropout = torch.nn.Dropout(dropout)
        d_model_out = d_model if d_model_out <= 0 else d_model_out
        self.linear2 = torch.nn.Linear(dim_feedforward, d_model_out)
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model_out)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        self.activation = torch.nn.functional.relu

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Returns:
            The encoded sequence of shape `(seq_len, batch, d_model_out)` and
            the self-attention weights.
        """
        src2, attn = self.self_attn(
            src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        # if d_model_out < d_model, the residual connection keeps only the
        # first d_model_out channels of the attention output
        d_out = src2.shape[-1]
        src = src[..., :d_out] + self.dropout2(src2)[..., :d_out]
        src = self.norm2(src)
        return src, attn


def _xavier_init(linear) -> None:
    """
    Performs the Xavier weight initialization of the linear layer `linear`.
    """
    torch.nn.init.xavier_uniform_(linear.weight.data)
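

# A minimal usage sketch (illustrative only): it assumes the default
# hyper-parameters defined above and the 5D input layout documented in
# `TransformerWithInputSkips.forward`.
if __name__ == "__main__":
    mlp = MLPWithInputSkips()
    pts = torch.rand(2, 1024, 39)  # e.g. harmonic-embedded ray points
    print(mlp(pts).shape)  # torch.Size([2, 1024, 256])

    transformer = TransformerWithInputSkips()
    x = torch.rand(2, 3, 8, 16, 39)  # (B, n_pooled_feats, n_rays, n_ray_pts, input_dim)
    print(transformer(x).shape)  # torch.Size([2, 8, 16, 256])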