Skip to content

Commit

Permalink
Fix train_mlp_nerf and save the model at the end of training (#177)
Browse files Browse the repository at this point in the history
* Fix train_mlp_nerf

* Fix black and isort
  • Loading branch information
97littleleaf11 committed Mar 27, 2023
1 parent 17e28de commit 9a1c0ab
Showing 1 changed file with 79 additions and 52 deletions.
131 changes: 79 additions & 52 deletions examples/train_mlp_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
import torch.nn.functional as F
import tqdm
from radiance_fields.mlp import VanillaNeRFRadianceField
from utils import render_image, set_random_seed
from utils import (
MIPNERF360_UNBOUNDED_SCENES,
NERF_SYNTHETIC_SCENES,
render_image,
set_random_seed,
)

from nerfacc import ContractionType, OccupancyGrid

Expand All @@ -34,23 +39,17 @@
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--model_path",
type=str,
default=None,
help="the path of the pretrained model",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
# nerf synthetic
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
"ship",
# mipnerf360 unbounded
"garden",
],
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
parser.add_argument(
Expand All @@ -74,11 +73,47 @@

render_n_samples = 1024

# setup the scene bounding box.
# setup the dataset
train_dataset_kwargs = {}
test_dataset_kwargs = {}

if args.scene in MIPNERF360_UNBOUNDED_SCENES:
from datasets.nerf_360_v2 import SubjectLoader

print("Using unbounded rendering")
target_sample_batch_size = 1 << 16
train_dataset_kwargs["color_bkgd_aug"] = "random"
train_dataset_kwargs["factor"] = 4
test_dataset_kwargs["factor"] = 4
grid_resolution = 128

elif args.scene in NERF_SYNTHETIC_SCENES:
from datasets.nerf_synthetic import SubjectLoader

target_sample_batch_size = 1 << 16
grid_resolution = 128

train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples,
device=device,
**train_dataset_kwargs,
)

test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
device=device,
**test_dataset_kwargs,
)

if args.unbounded:
print("Using unbounded rendering")
contraction_type = ContractionType.UN_BOUNDED_SPHERE
# contraction_type = ContractionType.UN_BOUNDED_TANH
scene_aabb = None
near_plane = 0.2
far_plane = 1e4
Expand Down Expand Up @@ -110,44 +145,22 @@
gamma=0.33,
)

# setup the dataset
train_dataset_kwargs = {}
test_dataset_kwargs = {}
if args.scene == "garden":
from datasets.nerf_360_v2 import SubjectLoader

target_sample_batch_size = 1 << 16
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
grid_resolution = 128
else:
from datasets.nerf_synthetic import SubjectLoader

target_sample_batch_size = 1 << 16
grid_resolution = 128

train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples,
**train_dataset_kwargs,
)

test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
**test_dataset_kwargs,
)

occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb,
resolution=grid_resolution,
contraction_type=contraction_type,
).to(device)

if args.model_path is not None:
checkpoint = torch.load(args.model_path)
radiance_field.load_state_dict(checkpoint["radiance_field_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
occupancy_grid.load_state_dict(checkpoint["occupancy_grid_state_dict"])
step = checkpoint["step"]
else:
step = 0

# training
step = 0
tic = time.time()
Expand Down Expand Up @@ -204,14 +217,28 @@
if step % 5000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
psnr = -10.0 * torch.log(loss) / np.log(10.0)
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
f"psnr={psnr:.2f}"
)

if step > 0 and step % max_steps == 0:
model_save_path = str(pathlib.Path.cwd() / f"mlp_nerf_{step}")
torch.save(
{
"step": step,
"radiance_field_state_dict": radiance_field.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"occupancy_grid_state_dict": occupancy_grid.state_dict(),
},
model_save_path,
)

# evaluation
radiance_field.eval()

Expand All @@ -230,8 +257,8 @@
rays,
scene_aabb,
# rendering options
near_plane=None,
far_plane=None,
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
Expand All @@ -246,7 +273,7 @@
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_test.png",
# f"rgb_test_{i}.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# break
Expand Down

0 comments on commit 9a1c0ab

Please sign in to comment.