From 9a1c0ab56806a46162d4d088a490fe862b6c3e4e Mon Sep 17 00:00:00 2001
From: Jingchen Ye <11172084+97littleleaf11@users.noreply.github.com>
Date: Mon, 27 Mar 2023 13:19:22 +0800
Subject: [PATCH] Fix train_mlp_nerf and save the model at the end of training
 (#177)

* Fix train_mlp_nerf

* Fix black and isort
---
 examples/train_mlp_nerf.py | 132 ++++++++++++++++++++++---------------
 1 file changed, 79 insertions(+), 53 deletions(-)

diff --git a/examples/train_mlp_nerf.py b/examples/train_mlp_nerf.py
index 5aad78d7..ee7ea089 100644
--- a/examples/train_mlp_nerf.py
+++ b/examples/train_mlp_nerf.py
@@ -13,7 +13,12 @@
 import torch.nn.functional as F
 import tqdm
 from radiance_fields.mlp import VanillaNeRFRadianceField
-from utils import render_image, set_random_seed
+from utils import (
+    MIPNERF360_UNBOUNDED_SCENES,
+    NERF_SYNTHETIC_SCENES,
+    render_image,
+    set_random_seed,
+)
 
 from nerfacc import ContractionType, OccupancyGrid
 
@@ -34,23 +39,17 @@
     choices=["train", "trainval"],
     help="which train split to use",
 )
+parser.add_argument(
+    "--model_path",
+    type=str,
+    default=None,
+    help="the path of the pretrained model",
+)
 parser.add_argument(
     "--scene",
     type=str,
     default="lego",
-    choices=[
-        # nerf synthetic
-        "chair",
-        "drums",
-        "ficus",
-        "hotdog",
-        "lego",
-        "materials",
-        "mic",
-        "ship",
-        # mipnerf360 unbounded
-        "garden",
-    ],
+    choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
     help="which scene to use",
 )
 parser.add_argument(
@@ -74,11 +73,47 @@
 
 render_n_samples = 1024
 
-# setup the scene bounding box.
+# setup the dataset
+train_dataset_kwargs = {}
+test_dataset_kwargs = {}
+
+if args.scene in MIPNERF360_UNBOUNDED_SCENES:
+    from datasets.nerf_360_v2 import SubjectLoader
+
+    print("Using unbounded rendering")
+    target_sample_batch_size = 1 << 16
+    train_dataset_kwargs["color_bkgd_aug"] = "random"
+    train_dataset_kwargs["factor"] = 4
+    test_dataset_kwargs["factor"] = 4
+    grid_resolution = 128
+
+elif args.scene in NERF_SYNTHETIC_SCENES:
+    from datasets.nerf_synthetic import SubjectLoader
+
+    target_sample_batch_size = 1 << 16
+    grid_resolution = 128
+
+train_dataset = SubjectLoader(
+    subject_id=args.scene,
+    root_fp=args.data_root,
+    split=args.train_split,
+    num_rays=target_sample_batch_size // render_n_samples,
+    device=device,
+    **train_dataset_kwargs,
+)
+
+test_dataset = SubjectLoader(
+    subject_id=args.scene,
+    root_fp=args.data_root,
+    split="test",
+    num_rays=None,
+    device=device,
+    **test_dataset_kwargs,
+)
+
 if args.unbounded:
     print("Using unbounded rendering")
     contraction_type = ContractionType.UN_BOUNDED_SPHERE
-    # contraction_type = ContractionType.UN_BOUNDED_TANH
     scene_aabb = None
     near_plane = 0.2
     far_plane = 1e4
@@ -110,44 +145,21 @@
     gamma=0.33,
 )
 
-# setup the dataset
-train_dataset_kwargs = {}
-test_dataset_kwargs = {}
-if args.scene == "garden":
-    from datasets.nerf_360_v2 import SubjectLoader
-
-    target_sample_batch_size = 1 << 16
-    train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
-    test_dataset_kwargs = {"factor": 4}
-    grid_resolution = 128
-else:
-    from datasets.nerf_synthetic import SubjectLoader
-
-    target_sample_batch_size = 1 << 16
-    grid_resolution = 128
-
-train_dataset = SubjectLoader(
-    subject_id=args.scene,
-    root_fp=args.data_root,
-    split=args.train_split,
-    num_rays=target_sample_batch_size // render_n_samples,
-    **train_dataset_kwargs,
-)
-
-test_dataset = SubjectLoader(
-    subject_id=args.scene,
-    root_fp=args.data_root,
-    split="test",
-    num_rays=None,
-    **test_dataset_kwargs,
-)
-
 occupancy_grid = OccupancyGrid(
     roi_aabb=args.aabb,
     resolution=grid_resolution,
     contraction_type=contraction_type,
 ).to(device)
 
+if args.model_path is not None:
+    checkpoint = torch.load(args.model_path)
+    radiance_field.load_state_dict(checkpoint["radiance_field_state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+    occupancy_grid.load_state_dict(checkpoint["occupancy_grid_state_dict"])
+    step = checkpoint["step"]
+else:
+    step = 0
+
 # training
-step = 0
 tic = time.time()
@@ -204,14 +216,28 @@
         if step % 5000 == 0:
             elapsed_time = time.time() - tic
             loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
+            psnr = -10.0 * torch.log(loss) / np.log(10.0)
             print(
                 f"elapsed_time={elapsed_time:.2f}s | step={step} | "
                 f"loss={loss:.5f} | "
                 f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
-                f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
+                f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
+                f"psnr={psnr:.2f}"
             )
 
         if step > 0 and step % max_steps == 0:
+            model_save_path = str(pathlib.Path.cwd() / f"mlp_nerf_{step}")
+            torch.save(
+                {
+                    "step": step,
+                    "radiance_field_state_dict": radiance_field.state_dict(),
+                    "optimizer_state_dict": optimizer.state_dict(),
+                    "scheduler_state_dict": scheduler.state_dict(),
+                    "occupancy_grid_state_dict": occupancy_grid.state_dict(),
+                },
+                model_save_path,
+            )
+
             # evaluation
             radiance_field.eval()
 
@@ -230,8 +256,8 @@
                         rays,
                         scene_aabb,
                         # rendering options
-                        near_plane=None,
-                        far_plane=None,
+                        near_plane=near_plane,
+                        far_plane=far_plane,
                         render_step_size=render_step_size,
                         render_bkgd=render_bkgd,
                         cone_angle=args.cone_angle,
@@ -246,7 +272,7 @@
                     #     ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
                     # )
                     # imageio.imwrite(
-                    #     "rgb_test.png",
+                    #     f"rgb_test_{i}.png",
                    #     (rgb.cpu().numpy() * 255).astype(np.uint8),
                     # )
                     # break
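
Note (not part of the patch): a minimal sketch of how the checkpoint that train_mlp_nerf.py now writes at the end of training could be loaded back for inspection or resumption. The file name "mlp_nerf_50000" is an assumption, since it depends on the script's max_steps value, which does not appear in this diff; substitute the mlp_nerf_<step> file actually written to the working directory.

    import torch

    # Load the checkpoint dict saved by torch.save() at the final step.
    # "mlp_nerf_50000" is a hypothetical name (max_steps is not shown here).
    checkpoint = torch.load("mlp_nerf_50000", map_location="cpu")

    print(checkpoint["step"])  # training step at which the model was saved
    print(checkpoint["radiance_field_state_dict"].keys())  # MLP parameter names

    # The same file can be passed back to the script via the new flag to
    # restore the radiance field, optimizer, scheduler, and occupancy grid:
    #   python examples/train_mlp_nerf.py --scene lego --model_path mlp_nerf_50000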