increase test time speed, reduce memory
kwea123 committed May 6, 2020
1 parent da37b5b commit 44db824
Showing 4 changed files with 114 additions and 66 deletions.

eval.py: 5 changes (3 additions & 2 deletions)
@@ -41,7 +41,7 @@ def get_opts():
                         help='number of additional fine samples')
     parser.add_argument('--use_disp', default=False, action="store_true",
                         help='use disparity depth sampling')
-    parser.add_argument('--chunk', type=int, default=32*1024,
+    parser.add_argument('--chunk', type=int, default=32*1024*4,
                         help='chunk size to split the input to avoid OOM')
 
     parser.add_argument('--ckpt_path', type=str, required=True,
@@ -73,7 +73,8 @@ def batched_inference(models, embeddings,
                         0,
                         N_importance,
                         chunk,
-                        dataset.white_back)
+                        dataset.white_back,
+                        test_time=True)
 
     for k, v in rendered_ray_chunks.items():
         results[k] += [v]
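Note on this change: at test time no gradients are kept, so each chunk's activations can be freed as soon as that chunk is rendered, which is what makes the 4x larger --chunk default safe. A minimal sketch of the pattern around this call (hypothetical render_fn and names, not the repository's exact code):

    import torch

    @torch.no_grad()  # eval-time: per-chunk activations are freed, so big chunks fit in memory
    def batched_inference_sketch(render_fn, rays, chunk=32*1024*4):
        """Render `rays` (N, 8) in chunks and concatenate the per-chunk outputs."""
        results = {}
        for i in range(0, rays.shape[0], chunk):
            rendered_ray_chunks = render_fn(rays[i:i+chunk])  # dict of tensors
            for k, v in rendered_ray_chunks.items():
                results.setdefault(k, []).append(v)
        return {k: torch.cat(v, 0) for k, v in results.items()}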

models/nerf.py: 21 changes (16 additions & 5 deletions)
@@ -80,20 +80,28 @@ def __init__(self,
                         nn.Linear(W//2, 3),
                         nn.Sigmoid())
 
-    def forward(self, x):
+    def forward(self, x, sigma_only=False):
         """
         Encodes input (xyz+dir) to rgb+sigma (not ready to render yet).
         For rendering this ray, please see rendering.py
 
         Inputs:
-            x: (B, self.in_channels_xyz+self.in_channels_dir)
+            x: (B, self.in_channels_xyz(+self.in_channels_dir))
                the embedded vector of position and direction
+            sigma_only: whether to infer sigma only. If True,
+                        x is of shape (B, self.in_channels_xyz)
 
         Outputs:
-            out: (B, 4), rgb and sigma
+            if sigma_only:
+                sigma: (B, 1) sigma
+            else:
+                out: (B, 4), rgb and sigma
         """
-        input_xyz, input_dir = \
-            torch.split(x, [self.in_channels_xyz, self.in_channels_dir], dim=-1)
+        if not sigma_only:
+            input_xyz, input_dir = \
+                torch.split(x, [self.in_channels_xyz, self.in_channels_dir], dim=-1)
+        else:
+            input_xyz = x
 
         xyz_ = input_xyz
         for i in range(self.D):
@@ -102,6 +110,9 @@ def forward(self, x):
             xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_)
 
         sigma = self.sigma(xyz_)
+        if sigma_only:
+            return sigma
 
         xyz_encoding_final = self.xyz_encoding_final(xyz_)
 
         dir_encoding_input = torch.cat([xyz_encoding_final, input_dir], -1)
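With sigma_only=True the forward pass stops right after the density head: the direction branch and rgb layers never run, which is what the test-time coarse pass in rendering.py exploits. A short usage sketch, assuming this model's default embedding sizes (63 xyz channels, 27 direction channels; the random tensors are placeholders):

    import torch
    from models.nerf import NeRF

    model = NeRF()                    # assumed defaults: in_channels_xyz=63, in_channels_dir=27
    xyz_embedded = torch.rand(4, 63)  # embedded positions only
    dir_embedded = torch.rand(4, 27)  # embedded view directions

    sigma = model(xyz_embedded, sigma_only=True)                   # (4, 1): density only
    rgbsigma = model(torch.cat([xyz_embedded, dir_embedded], -1))  # (4, 4): rgb + sigma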

models/rendering.py: 85 changes (54 additions & 31 deletions)
@@ -64,7 +64,8 @@ def render_rays(models,
                 noise_std=1,
                 N_importance=0,
                 chunk=1024*32,
-                white_back=False
+                white_back=False,
+                test_time=False
                 ):
     """
     Render rays by computing the output of @model applied on @rays
@@ -80,54 +81,66 @@
         N_importance: number of fine samples per ray
         chunk: the chunk size in batched inference
         white_back: whether the background is white (dataset dependent)
+        test_time: whether it is test (inference only) or not. If True, it will not do inference
+                   on coarse rgb to save time
 
     Outputs:
         result: dictionary containing final rgb and depth maps for coarse and fine models
     """
 
-    def inference(model, embedding_xyz, embedding_dir, xyz_, dir_, z_vals):
+    def inference(model, embedding_xyz, xyz_, dir_, dir_embedded, z_vals, weights_only=False):
         """
         Helper function that performs model inference.
 
         Inputs:
             model: NeRF model (coarse or fine)
             embedding_xyz: embedding module for xyz
-            embedding_dir: embedding module for dir
             xyz_: (N_rays, N_samples_, 3) sampled positions
                   N_samples_ is the number of sampled points in each ray;
                              = N_samples for coarse model
                              = N_samples+N_importance for fine model
-            dir_: (N_rays, 3) normalized directions
+            dir_: (N_rays, 3) ray directions
+            dir_embedded: (N_rays, embed_dir_channels) embedded directions
             z_vals: (N_rays, N_samples_) depths of the sampled positions
+            weights_only: do inference on sigma only or not
 
         Outputs:
-            rgb_final: (N_rays, 3) the final rgb image
-            depth_final: (N_rays) depth map
-            weights: (N_rays, N_samples_): weights fo each sample
+            if weights_only:
+                weights: (N_rays, N_samples_): weights of each sample
+            else:
+                rgb_final: (N_rays, 3) the final rgb image
+                depth_final: (N_rays) depth map
+                weights: (N_rays, N_samples_): weights of each sample
         """
         N_samples_ = xyz_.shape[1]
-        # Embed positions and directions
+        # Embed directions
         xyz_ = xyz_.view(-1, 3) # (N_rays*N_samples_, 3)
-        xyz_embedded = embedding_xyz(xyz_) # (N_rays*N_samples_, embed_xyz_channels)
-        dir_embedded = embedding_dir(dir_) # (N_rays, embed_dir_channels)
-        dir_embedded = torch.repeat_interleave(dir_embedded, repeats=N_samples_, dim=0)
-                       # (N_rays*N_samples_, embed_dir_channels)
+        if not weights_only:
+            dir_embedded = torch.repeat_interleave(dir_embedded, repeats=N_samples_, dim=0)
+                           # (N_rays*N_samples_, embed_dir_channels)
 
         # Perform model inference to get rgb and raw sigma
         B = xyz_.shape[0]
         out_chunks = []
         for i in range(0, B, chunk):
-            xyzdir_embedded = torch.cat([xyz_embedded[i:i+chunk],
-                                         dir_embedded[i:i+chunk]], 1)
-            out_chunks += [model(xyzdir_embedded)]
-
-        rgbsigma = torch.cat(out_chunks, 0)
-        rgbsigma = rgbsigma.view(N_rays, N_samples_, 4)
+            # Embed positions by chunk
+            xyz_embedded = embedding_xyz(xyz_[i:i+chunk])
+            if not weights_only:
+                xyzdir_embedded = torch.cat([xyz_embedded,
+                                             dir_embedded[i:i+chunk]], 1)
+            else:
+                xyzdir_embedded = xyz_embedded
+            out_chunks += [model(xyzdir_embedded, sigma_only=weights_only)]
+
+        out = torch.cat(out_chunks, 0)
+        if weights_only:
+            sigmas = out.view(N_rays, N_samples_)
+        else:
+            rgbsigma = out.view(N_rays, N_samples_, 4)
+            rgbs = rgbsigma[..., :3] # (N_rays, N_samples_, 3)
+            sigmas = rgbsigma[..., 3] # (N_rays, N_samples_)
 
         # Convert these values using volume rendering (Section 4)
-        rgbs = rgbsigma[..., :3] # (N_rays, N_samples_, 3)
-        sigmas = rgbsigma[..., 3] # (N_rays, N_samples_)
         deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples_-1)
         delta_inf = 1e10 * torch.ones_like(deltas[:, :1]) # (N_rays, 1) the last delta is infinity
         deltas = torch.cat([deltas, delta_inf], -1) # (N_rays, N_samples_)
@@ -146,6 +159,8 @@ def inference(model, embedding_xyz, embedding_dir, xyz_, dir_, z_vals):
             alphas * torch.cumprod(alphas_shifted, -1)[:, :-1] # (N_rays, N_samples_)
         weights_sum = weights.sum(1) # (N_rays), the accumulated opacity along the rays
                                      # equals "1 - (1-a1)(1-a2)...(1-an)" mathematically
+        if weights_only:
+            return weights
 
         # compute final weighted outputs
         rgb_final = torch.sum(weights.unsqueeze(-1)*rgbs, -2) # (N_rays, 3)
@@ -167,6 +182,9 @@ def inference(model, embedding_xyz, embedding_dir, xyz_, dir_, z_vals):
     rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)
     near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1)
 
+    # Embed direction
+    dir_embedded = embedding_dir(rays_d) # (N_rays, embed_dir_channels)
+
     # Sample depth points
     z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples)
     if not use_disp: # use linear sampling in depth space
@@ -188,14 +206,19 @@ def inference(model, embedding_xyz, embedding_dir, xyz_, dir_, z_vals):
     xyz_coarse_sampled = rays_o.unsqueeze(1) + \
                          rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3)
 
-    rgb_coarse, depth_coarse, weights_coarse = \
-        inference(model_coarse, embedding_xyz, embedding_dir,
-                  xyz_coarse_sampled, rays_d, z_vals)
-
-    result = {'rgb_coarse': rgb_coarse,
-              'depth_coarse': depth_coarse,
-              'opacity_coarse': weights_coarse.sum(1)
-             }
+    if test_time:
+        weights_coarse = \
+            inference(model_coarse, embedding_xyz, xyz_coarse_sampled, rays_d,
+                      dir_embedded, z_vals, weights_only=True)
+        result = {'opacity_coarse': weights_coarse.sum(1)}
+    else:
+        rgb_coarse, depth_coarse, weights_coarse = \
+            inference(model_coarse, embedding_xyz, xyz_coarse_sampled, rays_d,
+                      dir_embedded, z_vals, weights_only=False)
+        result = {'rgb_coarse': rgb_coarse,
+                  'depth_coarse': depth_coarse,
+                  'opacity_coarse': weights_coarse.sum(1)
+                 }
 
     if N_importance > 0: # sample points for fine model
         z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) # (N_rays, N_samples-1) interval mid points
@@ -212,8 +235,8 @@ def inference(model, embedding_xyz, embedding_dir, xyz_, dir_, z_vals):
 
         model_fine = models[1]
         rgb_fine, depth_fine, weights_fine = \
-            inference(model_fine, embedding_xyz, embedding_dir,
-                      xyz_fine_sampled, rays_d, z_vals)
+            inference(model_fine, embedding_xyz, xyz_fine_sampled, rays_d,
+                      dir_embedded, z_vals, weights_only=False)
 
         result['rgb_fine'] = rgb_fine
         result['depth_fine'] = depth_fine
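Why dropping the coarse rgb is sound: at test time the coarse pass exists only to produce the weights that drive importance sampling for the fine pass, and those weights depend on sigma alone, via alpha_i = 1 - exp(-sigma_i * delta_i) and weight_i = alpha_i * (1-alpha_1)...(1-alpha_{i-1}); their sum gives the "1 - (1-a1)(1-a2)...(1-an)" identity noted in the code comment. A self-contained sketch of that computation, mirroring (not copying) the lines above:

    import torch

    def weights_from_sigmas(sigmas, z_vals):
        """sigmas, z_vals: (N_rays, N_samples) -> per-sample compositing weights."""
        deltas = z_vals[:, 1:] - z_vals[:, :-1]               # (N_rays, N_samples-1)
        delta_inf = 1e10 * torch.ones_like(deltas[:, :1])     # last interval treated as infinite
        deltas = torch.cat([deltas, delta_inf], -1)           # (N_rays, N_samples)
        alphas = 1 - torch.exp(-deltas * torch.relu(sigmas))  # opacity of each segment
        # transmittance: cumulative product of (1 - alpha) over the preceding samples
        alphas_shifted = torch.cat([torch.ones_like(alphas[:, :1]), 1 - alphas + 1e-10], -1)
        return alphas * torch.cumprod(alphas_shifted, -1)[:, :-1]  # (N_rays, N_samples)

    weights = weights_from_sigmas(torch.rand(2, 64), torch.linspace(2, 6, 64).expand(2, 64))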

test.ipynb: 69 changes (41 additions & 28 deletions)

Large diffs are not rendered by default.
