Skip to content

Commit

Permalink
remove MLP updates
Browse files Browse the repository at this point in the history
  • Loading branch information
liruilong940607 committed Apr 25, 2023
1 parent 73ce8c8 commit c37d199
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 115 deletions.
2 changes: 1 addition & 1 deletion benchmarks/kplanes
2 changes: 1 addition & 1 deletion benchmarks/tensorf
Submodule tensorf updated 2 files
+4 −2 script.sh
+286 −119 train.py
2 changes: 1 addition & 1 deletion benchmarks/tineuvox
112 changes: 0 additions & 112 deletions examples/radiance_fields/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,115 +281,3 @@ def forward(self, x, t, condition=None):
torch.cat([self.posi_encoder(x), self.time_encoder(t)], dim=-1)
)
return self.nerf(x, condition=condition)


class NDRTNeRFRadianceField(nn.Module):
    """Invertible deformation field from NDR (https://arxiv.org/pdf/2206.15258.pdf).

    Points are warped through three stacked invertible stages — each
    conditioned on an encoded timestamp — into a canonical frame, where a
    `VanillaNeRFRadianceField` is queried for density and color.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE: module-creation order is kept stable so parameter
        # initialization consumes the RNG identically across runs.
        self.time_encoder = SinusoidalEncoder(1, 0, 4, True)
        self.warp_layers_1 = nn.ModuleList()
        self.time_layers_1 = nn.ModuleList()
        self.warp_layers_2 = nn.ModuleList()
        self.time_layers_2 = nn.ModuleList()
        self.posi_encoder_1 = SinusoidalEncoder(2, 0, 4, True)
        self.posi_encoder_2 = SinusoidalEncoder(1, 0, 4, True)
        # Near-zero output init keeps each warp stage close to identity
        # at the start of training.
        near_identity_init = functools.partial(torch.nn.init.uniform_, b=1e-4)
        for _ in range(3):
            # Stage A: predicts a delta for the third coordinate from the
            # first two coordinates plus a 64-d time feature.
            self.warp_layers_1.append(
                MLP(
                    input_dim=self.posi_encoder_1.latent_dim + 64,
                    output_dim=1,
                    net_depth=2,
                    net_width=128,
                    skip_layer=None,
                    output_init=near_identity_init,
                )
            )
            # Stage B: predicts a 2D rotation angle plus a 2D translation
            # from the updated third coordinate plus a time feature.
            self.warp_layers_2.append(
                MLP(
                    input_dim=self.posi_encoder_2.latent_dim + 64,
                    output_dim=1 + 2,
                    net_depth=1,
                    net_width=128,
                    skip_layer=None,
                    output_init=near_identity_init,
                )
            )
            # Per-stage projections of the encoded time to 64-d features.
            self.time_layers_1.append(
                DenseLayer(
                    input_dim=self.time_encoder.latent_dim,
                    output_dim=64,
                )
            )
            self.time_layers_2.append(
                DenseLayer(
                    input_dim=self.time_encoder.latent_dim,
                    output_dim=64,
                )
            )

        # Canonical-space radiance field.
        self.nerf = VanillaNeRFRadianceField()

    def _warp(self, x, t_enc, i_layer):
        """Apply one invertible warp stage to ``x`` (last dim is 3)."""
        uv = x[:, :2]
        w = x[:, 2:]
        # Shift the third coordinate conditioned on (uv, time).
        time_feat_a = self.time_layers_1[i_layer](t_enc)
        w = w + self.warp_layers_1[i_layer](
            torch.cat([self.posi_encoder_1(uv), time_feat_a], dim=-1)
        )
        # Predict a rigid 2D transform for uv conditioned on (w, time).
        time_feat_b = self.time_layers_2[i_layer](t_enc)
        rot_trans = self.warp_layers_2[i_layer](
            torch.cat([self.posi_encoder_2(w), time_feat_b], dim=-1)
        )
        rot = self._euler2rot_2dinv(rot_trans[:, :1])
        trans = rot_trans[:, 1:]
        uv = torch.bmm(rot, (uv - trans)[..., None]).squeeze(-1)
        return torch.cat([uv, w], dim=-1)

    def warp(self, x, t):
        """Warp points ``x`` at timestamps ``t`` into the canonical frame."""
        t_enc = self.time_encoder(t)
        # Cycle which coordinate plays the "shifted" role between stages.
        for i_layer, perm in enumerate(([1, 2, 0], [2, 0, 1], None)):
            x = self._warp(x, t_enc, i_layer)
            if perm is not None:
                x = x[..., perm]
        return x

    def query_opacity(self, x, timestamps, step_size):
        """Estimate per-point opacity for randomly sampled timestamps."""
        # Sample one timestamp per point uniformly at random.
        sampled = torch.randint(
            0, len(timestamps), (x.shape[0],), device=x.device
        )
        density = self.query_density(x, timestamps[sampled])
        # For small densities this linearization matches
        # 1.0 - exp(-density * step_size).
        return density * step_size

    def query_density(self, x, t):
        """Density of the canonical field at the warped positions."""
        return self.nerf.query_density(self.warp(x, t))

    def forward(self, x, t, condition=None):
        """Radiance/density of the canonical field at the warped positions."""
        return self.nerf(self.warp(x, t), condition=condition)

    def _euler2rot_2dinv(self, euler_angle):
        """Map (B, 1) angles to (B, 2, 2) inverse (transposed) 2D rotations."""
        theta = euler_angle.reshape(-1, 1, 1)
        cos, sin = theta.cos(), theta.sin()
        left_col = torch.cat((cos, -sin), 1)
        right_col = torch.cat((sin, cos), 1)
        return torch.cat((left_col, right_col), 2)

0 comments on commit c37d199

Please sign in to comment.