Dataloading Revamp #3216

Open: wants to merge 141 commits into base: main

Commits (141)
0471543
initial debugging and testing works
AntonioMacaronio Jun 11, 2024
c6dde7d
pwais changes with RayBatchStream to alleviate training
AntonioMacaronio Jun 12, 2024
a09ea0c
Merge branch 'main' into dataloading-revamp
AntonioMacaronio Jun 12, 2024
78453cd
few bugs to iron out with multiprocessing, specifically pickled colla…
AntonioMacaronio Jun 12, 2024
f2bd96f
working version of RayBatchStream
AntonioMacaronio Jun 13, 2024
d8b7430
additional docstrings
AntonioMacaronio Jun 13, 2024
a5425d4
cleanup
AntonioMacaronio Jun 13, 2024
604f734
much more documentation
AntonioMacaronio Jun 13, 2024
0143803
successfully trained AEA-script2_seq2 closed_loop without OOM
AntonioMacaronio Jun 13, 2024
d3527e2
porting over aria dataset-size feature
AntonioMacaronio Jun 13, 2024
25f5f27
added logic to handle eviction of a worker's cached_collated_batch
AntonioMacaronio Jun 14, 2024
3a8b63b
antonio's implementation of stream batches
AntonioMacaronio Jun 15, 2024
536c6ca
training on a dataset with 4000 images works!
AntonioMacaronio Jun 15, 2024
43a0061
some configuration speedups, loops aren't actually needed!
AntonioMacaronio Jun 15, 2024
fa7cf30
quick fix adjustment to aria
AntonioMacaronio Jun 15, 2024
927cb6a
removed unnecessary looping
AntonioMacaronio Jun 16, 2024
814f2c2
much faster training when adding i variable to collate every 5 ray bu…
AntonioMacaronio Jun 25, 2024
247ac3e
cleanup unnecessary variables in Dataloader
AntonioMacaronio Jul 7, 2024
55d0803
further cleanup
AntonioMacaronio Jul 11, 2024
b6979a4
adding caching of compressed images to RAM to reduce disk bottleneck
AntonioMacaronio Jul 20, 2024
81dbf7c
added caching to RAM for masks
AntonioMacaronio Jul 22, 2024
55ca71d
found fast way to collate - many tricks applied
AntonioMacaronio Jul 26, 2024
3b4f091
quick update to aria to test on different datasets
AntonioMacaronio Jul 26, 2024
7de1922
cleaned up the accelerated pil_to_numpy function
AntonioMacaronio Jul 26, 2024
9ceaad1
cleaning up PR
AntonioMacaronio Jul 26, 2024
4147a6a
this commit was used to generate the time metrics and profiling metrics
AntonioMacaronio Jul 26, 2024
5a55b7a
REAL commit used to run tests
AntonioMacaronio Jul 26, 2024
78f02e6
testing with nerfacto-big
AntonioMacaronio Aug 15, 2024
19bc4b5
generated RayBundle collate and converting images from uint8s to floa…
AntonioMacaronio Aug 15, 2024
9245d05
updating nerfacto to support uint8 easily, will need to figure out a …
AntonioMacaronio Aug 20, 2024
3124c14
datamanager updates, both splat and nerf
AntonioMacaronio Aug 20, 2024
afb0612
must use writeable arrays because torch requires them
AntonioMacaronio Aug 20, 2024
288a740
cleaned up base_dataset, added pickle to utils, more code in full_ima…
AntonioMacaronio Aug 22, 2024
2fd0862
lots of progress on a parallel FullImageDatamanager
AntonioMacaronio Aug 23, 2024
846e2f3
can train big splats with pre-assertion hack or ROI hack and 0 workers
AntonioMacaronio Aug 24, 2024
8fb0b4d
fixed all undistortion issues with ParallelImageDatamanager
AntonioMacaronio Aug 27, 2024
ce3f83f
adding some downsampling and parallel tests with splatfacto!
AntonioMacaronio Aug 31, 2024
8ab9963
deleted commented code in dataloaders.py and added bugfix to shuffling
AntonioMacaronio Aug 31, 2024
c9e16bf
testing splatfacto-big
AntonioMacaronio Sep 1, 2024
ddac38d
cleaned up base_pipeline.py
AntonioMacaronio Sep 1, 2024
443719a
cleaned up base_pipeline.py ACTUALLY THIS TIME, forgot to save last time
AntonioMacaronio Sep 1, 2024
d16e519
cleaned up a lot of code
AntonioMacaronio Sep 1, 2024
367d512
process_project_aria back to main branch and some cleanup in full_ima…
AntonioMacaronio Sep 1, 2024
d3d99b4
clarifying docstrings
AntonioMacaronio Sep 1, 2024
6f763dc
further PR cleanup
AntonioMacaronio Sep 3, 2024
a5191bd
updating models
AntonioMacaronio Sep 9, 2024
7db70dc
further cleanup
AntonioMacaronio Sep 9, 2024
5c3262b
removed caching of images into bytestrings
AntonioMacaronio Sep 9, 2024
ff2bda1
adding caching of compressed images to RAM, forgot that hardware matters
AntonioMacaronio Sep 9, 2024
f6dd7dd
removing oom methods, adding the ability to add a flag to dataloading
AntonioMacaronio Sep 15, 2024
a6602c7
removed CacheDataloader, moved RayBatchStream to dataloaders.py, new …
AntonioMacaronio Sep 15, 2024
3dc2031
fixing base_pipelines, deleting a weird datamanager_configs file that …
AntonioMacaronio Sep 15, 2024
89f3d98
cleaning up next_train
AntonioMacaronio Sep 15, 2024
14e60e5
replaced parallel datamanager with new datamanager
AntonioMacaronio Sep 19, 2024
204dfb2
reverted the original base_datamanager.py, new datamanager replaced p…
AntonioMacaronio Sep 19, 2024
5864bc9
modified VanillaConfig, but VanillaDataManager is the same as before
AntonioMacaronio Sep 19, 2024
6d97de3
cleaning up, 2 datamanagers now - original and new parallel one
AntonioMacaronio Sep 19, 2024
1f34017
able to train with new nerfstudio dataloader now
AntonioMacaronio Sep 19, 2024
99cf86a
side by side datamanagers, moved tons of logic into dataloaders.py an…
AntonioMacaronio Sep 23, 2024
4ebad85
added custom ray processing API to support implementations like LERF,…
AntonioMacaronio Sep 23, 2024
87921be
adding functionality for ns-eval by adding FixedIndicesEvalDataloader…
AntonioMacaronio Sep 24, 2024
b628c7c
adding both ray API and image-view API to datamanagers for custom par…
AntonioMacaronio Sep 27, 2024
d2785d1
updating splatfacto config for 4k tests
AntonioMacaronio Sep 30, 2024
436af9d
updating docstrings to be more descriptive
AntonioMacaronio Sep 30, 2024
dd4daaa
new datamanager API breaks when setup_eval() has multiple workers, no…
AntonioMacaronio Sep 30, 2024
43c66ae
adding custom_view_processor to ImageBatchStream
AntonioMacaronio Sep 30, 2024
ba81e11
merging with main!
AntonioMacaronio Sep 30, 2024
1922566
reverting full_images_datamanager to main branch
AntonioMacaronio Oct 1, 2024
beb74be
removing nn.Module inheritance from Datamanager class
AntonioMacaronio Oct 1, 2024
087cff0
don't need to move datamanager to device anymore since Datamanager is …
AntonioMacaronio Oct 1, 2024
48e6d15
finished integration test with nerfacto
AntonioMacaronio Oct 4, 2024
3f1799b
simplified config variables, integrated the parallelism/disk-data-loa…
AntonioMacaronio Oct 25, 2024
f46aa42
updated the splatfacto config to be simpler with the dataloading and …
AntonioMacaronio Oct 25, 2024
5aa51fb
style checks and some cleanup
AntonioMacaronio Oct 25, 2024
ec3c12a
new splatfacto test, cleaning up nerfacto integration test
AntonioMacaronio Oct 25, 2024
82bc5b2
removing redundant parallel_full_images_datamanager, as the OG full_i…
AntonioMacaronio Oct 26, 2024
377a56a
Merge branch 'main' into dataloading-revamp
AntonioMacaronio Oct 28, 2024
bbb5473
ruff linting and pyright fixing
AntonioMacaronio Oct 28, 2024
2e64120
further pyright fixing
AntonioMacaronio Oct 28, 2024
e9c2fd6
another pyright fixing
AntonioMacaronio Oct 28, 2024
e4dc9f9
fixing pyright error, camera optimization no longer part of datamanager
AntonioMacaronio Nov 1, 2024
8b0ec8e
fixing one pyright
AntonioMacaronio Nov 22, 2024
6349852
fixing dataloading error when camera is not undistorted with dataloader
AntonioMacaronio Dec 13, 2024
ad6b090
fixing comments and updating style
AntonioMacaronio Dec 21, 2024
8c678ee
undoing a style change i made
AntonioMacaronio Dec 21, 2024
64edabb
undoing another style change i made by accident
AntonioMacaronio Dec 21, 2024
cc63585
Merge branch 'main' into dataloading-revamp
AntonioMacaronio Dec 22, 2024
1e40aad
fixing slow runtime
AntonioMacaronio Dec 25, 2024
0012017
fixing a more general camera undistortion bug
AntonioMacaronio Dec 31, 2024
2fdba24
move images to device properly
kerrj Jan 2, 2025
d8ec706
minor improvements
kerrj Jan 2, 2025
51fc984
Merge branch 'main' into dataloading-revamp
kerrj Jan 3, 2025
36df9b3
add print statement about >500 images, cleanup method configs
kerrj Jan 3, 2025
a3fb46f
make method configs consistent across nerfacto models
kerrj Jan 3, 2025
3e06221
adding description comments
AntonioMacaronio Jan 4, 2025
c6f8094
Merge branch 'main' into dataloading-revamp
kerrj Jan 6, 2025
e72dd78
updating description
AntonioMacaronio Jan 6, 2025
f5024fc
Merge branch 'dataloading-revamp' of https://github.com/AntonioMacaro…
AntonioMacaronio Jan 6, 2025
d2af513
resolving some pyright issues with export.py, explained in PR desc
AntonioMacaronio Jan 6, 2025
1a02133
fixing pyright issues in base_pipeline.py
AntonioMacaronio Jan 6, 2025
b7bcb13
ran pyright on exporter and base_pipeline.py without issues
AntonioMacaronio Jan 6, 2025
603a5db
adding a git ignore to a clearly checked pyright issue
AntonioMacaronio Jan 6, 2025
eedda79
typo
kerrj Jan 7, 2025
3c5ab8e
merge
kerrj Jan 7, 2025
2f90812
fixing most ns-dev-test cases
AntonioMacaronio Jan 7, 2025
3a82351
Merge branch 'dataloading-revamp' of https://github.com/AntonioMacaro…
AntonioMacaronio Jan 7, 2025
4091694
cleanup, passing final ns-dev-test
AntonioMacaronio Jan 7, 2025
e7c99e4
oops, accidentally pushed the deletion of a docstring, undoing that
AntonioMacaronio Jan 7, 2025
bd6d1ae
another cleanup
AntonioMacaronio Jan 7, 2025
deb4d7f
some fixes to eval pipeline
kerrj Jan 7, 2025
a5f62aa
lint
kerrj Jan 7, 2025
97629a7
Merge branch 'main' into dataloading-revamp
kerrj Jan 8, 2025
e13525e
add asserts for spawn
kerrj Jan 8, 2025
94afc0b
Merge branch 'dataloading-revamp' of https://github.com/AntonioMacaro…
kerrj Jan 8, 2025
c316a7b
lint
kerrj Jan 8, 2025
b8da37d
cleaning up import statements in parallel_datamanager.py
AntonioMacaronio Jan 9, 2025
3fafbc2
Merge branch 'dataloading-revamp' of https://github.com/AntonioMacaro…
AntonioMacaronio Jan 9, 2025
e4a7661
adding new developer documentation if users would like to migrate the…
AntonioMacaronio Jan 9, 2025
1b37fc4
removing unnecessary to_device no-op
AntonioMacaronio Jan 11, 2025
e1764a9
further updates to documentation
AntonioMacaronio Jan 11, 2025
a280ebc
Merge branch 'main' into dataloading-revamp
kerrj Jan 13, 2025
70d965e
lint
kerrj Jan 13, 2025
f4b0f28
more docs
kerrj Jan 13, 2025
7617a79
docs
kerrj Jan 13, 2025
b0fc764
remove comment
kerrj Jan 13, 2025
37f4ca4
add docs, fix depth dataset with parallel datamanager, fix mask sampl…
kerrj Jan 13, 2025
6270e6b
remove profiling
kerrj Jan 13, 2025
9972d21
more profile removal
kerrj Jan 13, 2025
14246c6
custom_view_processor->custom_image_processor
kerrj Jan 13, 2025
a1af58a
doc clarification
kerrj Jan 13, 2025
974200c
Merge branch 'main' of https://github.com/AntonioMacaronio/nerfstudio…
AntonioMacaronio Jan 14, 2025
c2e8a8b
Merge branch 'dataloading-revamp' of https://github.com/AntonioMacaro…
AntonioMacaronio Jan 15, 2025
ac1d4a8
datamanager doc nit
brentyi Jan 16, 2025
3092805
whitespace
brentyi Jan 16, 2025
8d42154
nits
brentyi Jan 16, 2025
cce5e47
Merge branch 'dataloading-revamp' of https://github.com/AntonioMacaro…
AntonioMacaronio Jan 16, 2025
01f40b4
remove stuff from __post_init__, tune num workers more, add random of…
kerrj Jan 16, 2025
61ef730
Merge branch 'dataloading-revamp' of https://github.com/AntonioMacaro…
AntonioMacaronio Jan 16, 2025
3e532f5
Merge branch 'main' into dataloading-revamp
kerrj Jan 17, 2025
c038632
Merge branch 'dataloading-revamp' of https://github.com/AntonioMacaro…
AntonioMacaronio Jan 18, 2025
7c4762e
removing unnecessary assertion, updating docstring because DataManage…
AntonioMacaronio Jan 19, 2025
110 changes: 107 additions & 3 deletions docs/developer_guides/pipelines/datamanagers.md
@@ -14,19 +14,28 @@

## What is a DataManager?

The DataManager returns RayBundle and RayGT objects. Let's first take a look at the most important abstract methods required by the DataManager.
The DataManager batches and returns two components from an input dataset:

1. A representation of viewpoint (either cameras or rays).
- For splatting methods (`FullImageDataManager`): a `Cameras` object.
- For ray sampling methods (`VanillaDataManager`): a `RayBundle` object.
2. A dictionary of ground truth data.
- For splatting methods (`FullImageDataManager`): dictionary contains complete images.
- For ray sampling methods (`VanillaDataManager`): dictionary contains per-ray information.

Behaviors are defined by implementing the abstract methods required by the DataManager:

```python
class DataManager(nn.Module):
"""Generic data manager's abstract class
"""

@abstractmethod
def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
def next_train(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
"""Returns the next batch of data for train."""

@abstractmethod
def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
def next_eval(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
"""Returns the next batch of data for eval."""

@abstractmethod
@@ -94,3 +103,98 @@ See the code!
## Creating Your Own

We currently don't have other implementations because most papers follow the VanillaDataManager implementation. However, it should be straightforward to subclass VanillaDataManager with logic that progressively adds cameras, for instance by relying on the step and modifying the RayBundle and RayGT generation logic, as sketched below.
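
As an illustration, here is a minimal sketch of that idea. The class name, the unlock schedule, and the filtering step are hypothetical and not part of nerfstudio; it reuses the same helpers (`iter_train_image_dataloader`, `train_pixel_sampler`, `train_ray_generator`) that appear in the LERF example below.

```python
from typing import Dict, Tuple

import torch

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager


class ProgressiveDataManager(VanillaDataManager):
    """Hypothetical sketch: only sample rays from cameras unlocked so far."""

    def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
        self.train_count += 1
        # Assumed schedule: unlock one additional camera every 500 steps.
        num_unlocked = min(len(self.train_dataset), 1 + step // 500)
        image_batch = next(self.iter_train_image_dataloader)
        assert self.train_pixel_sampler is not None
        batch = self.train_pixel_sampler.sample(image_batch)
        # batch["indices"] has shape (num_rays, 3): (camera index, row, col).
        keep = batch["indices"][:, 0] < num_unlocked
        batch = {k: v[keep] if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        ray_bundle = self.train_ray_generator(batch["indices"])
        return ray_bundle, batch
```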

## Disk Caching for Large Datasets

As of January 2025, the FullImageDatamanager and ParallelImageDatamanager implementations support parallelized dataloading and loading images from disk, which avoids out-of-memory (OOM) errors on very large datasets. To train a NeRF-based method on a dataset that does not fit in memory, add the `load_from_disk` flag to your `ns-train` command. For example, with nerfacto:
```bash
ns-train nerfacto --data {PROCESSED_DATA_DIR} --pipeline.datamanager.load-from-disk
```

To train splatfacto on a large dataset that does not fit in memory, set the `cache-images` option to `"disk"`. For example:
```bash
ns-train splatfacto --data {PROCESSED_DATA_DIR} --pipeline.datamanager.cache-images disk
```
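
These flags map onto datamanager config fields added in this PR (visible in the `VanillaDataManagerConfig` diff further down). A hedged sketch of setting them programmatically, with illustrative values:

```python
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig

datamanager_config = VanillaDataManagerConfig(
    load_from_disk=True,            # stream images from disk instead of caching them all in RAM
    dataloader_num_workers=4,       # background workers for collating, pixel sampling, ray generation
    prefetch_factor=10,             # batches each worker may load ahead of the training loop
    cache_compressed_images=False,  # or keep compressed bytes in RAM to reduce disk reads
)
```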

## Migrating Your DataManager to the New DataManager

Many methods subclass a DataManager and add extra data to it. If you would like your custom datamanager to support the new parallel features, you can migrate any custom dataloading logic to the new `custom_ray_processor()` API. This function takes in a full training batch (either image or ray bundle) and allows the user to modify or add to it, providing an interface to attach new information to the RayBundle (for ray-based methods), the Cameras object (for splatting-based methods), or the ground truth dictionary. It runs in a background process if disk caching is enabled; otherwise it runs in the main process. Let's take a look at an example for the LERF method, which was built on Nerfstudio's VanillaDataManager.

Naively transferring code to `custom_ray_processor` may still OOM on very large datasets if your initialization code computes something over the whole dataset. To take full advantage of parallelization, make sure your subclassed datamanager computes new information inside `custom_ray_processor`, or caches only a subset of the whole dataset. Dataloading can also remain slow if pre-computation requires GPU-heavy steps on the same GPU used for training.

**Note**: Because the parallel DataManager uses background processes, any member of the DataManager needs to be *picklable* to be used inside `custom_ray_processor`.
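
As an illustration (with hypothetical attribute names), the pattern below keeps a subclassed datamanager picklable by storing plain data rather than process-local handles:

```python
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager


class MyDataManager(ParallelDataManager):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Not picklable: lambdas and open file handles cannot be sent to the
        # background dataloading workers.
        # self.score_fn = lambda x: 0.5 * x
        # self.log_file = open("stats.txt", "w")

        # Picklable: plain values and paths survive the trip to worker
        # processes; open handles lazily inside custom_ray_processor instead.
        self.scale = 0.5
        self.log_path = "stats.txt"
```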

```python
class LERFDataManager(VanillaDataManager):
"""Subclass VanillaDataManager to add extra data processing

Args:
config: the DataManagerConfig used to instantiate class
"""

config: LERFDataManagerConfig

def __init__(
self,
config: LERFDataManagerConfig,
device: Union[torch.device, str] = "cpu",
test_mode: Literal["test", "val", "inference"] = "val",
world_size: int = 1,
local_rank: int = 0,
**kwargs,
):
super().__init__(
config=config, device=device, test_mode=test_mode, world_size=world_size, local_rank=local_rank, **kwargs
)
# Some code to initialize all the CLIP and DINO feature encoders.
self.image_encoder: BaseImageEncoder = kwargs["image_encoder"]
self.dino_dataloader = ...
self.clip_interpolator = ...

def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
"""Returns the next batch of data from the train dataloader.

In this custom DataManager we need to add on the data that LERF needs, namely CLIP and DINO features.
"""
self.train_count += 1
image_batch = next(self.iter_train_image_dataloader)
assert self.train_pixel_sampler is not None
batch = self.train_pixel_sampler.sample(image_batch)
ray_indices = batch["indices"]
ray_bundle = self.train_ray_generator(ray_indices)
batch["clip"], clip_scale = self.clip_interpolator(ray_indices)
batch["dino"] = self.dino_dataloader(ray_indices)
ray_bundle.metadata["clip_scales"] = clip_scale
# assume all cameras have the same focal length and image width
ray_bundle.metadata["fx"] = self.train_dataset.cameras[0].fx.item()
ray_bundle.metadata["width"] = self.train_dataset.cameras[0].width.item()
ray_bundle.metadata["fy"] = self.train_dataset.cameras[0].fy.item()
ray_bundle.metadata["height"] = self.train_dataset.cameras[0].height.item()
return ray_bundle, batch
```

To migrate this custom datamanager to the new datamanager, we'll subclass the new ParallelDataManager and shift the data customization from `next_train()` to `custom_ray_processor()`. The function `custom_ray_processor()` is called with a fully populated ray bundle and ground truth batch, just like the subclassed `next_train` above; unlike `next_train`, however, it runs in a background process when disk caching is enabled.

```python
class LERFDataManager(ParallelDataManager, Generic[TDataset]):
"""
__init__ stays the same
"""

...

def custom_ray_processor(
self, ray_bundle: RayBundle, batch: Dict
) -> Tuple[RayBundle, Dict]:
"""An API to add latents, metadata, or other further customization to the RayBundle dataloading process that is parallelized."""
ray_indices = batch["indices"]
batch["clip"], clip_scale = self.clip_interpolator(ray_indices)
batch["dino"] = self.dino_dataloader(ray_indices)
ray_bundle.metadata["clip_scales"] = clip_scale

# Assume all cameras have the same focal length and image dimensions.
ray_bundle.metadata["fx"] = self.train_dataset.cameras[0].fx.item()
ray_bundle.metadata["width"] = self.train_dataset.cameras[0].width.item()
ray_bundle.metadata["fy"] = self.train_dataset.cameras[0].fy.item()
ray_bundle.metadata["height"] = self.train_dataset.cameras[0].height.item()
return ray_bundle, batch
```
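
Splatting-based methods get an analogous image-view hook, `custom_image_processor` (renamed from `custom_view_processor` in this PR's history). A minimal sketch, assuming its signature mirrors `custom_ray_processor` with a `Cameras` object in place of the `RayBundle`; the exact signature and the `embed_fn` helper are assumptions, not confirmed by the diff above:

```python
from typing import Dict, Tuple

from nerfstudio.cameras.cameras import Cameras
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager


class MySplatDataManager(FullImageDatamanager):
    """Hypothetical sketch: attach per-image data for a splatting method."""

    def custom_image_processor(self, camera: Cameras, data: Dict) -> Tuple[Cameras, Dict]:
        # Runs once per training view; with cache_images="disk" this executes
        # in a background worker, so any members it touches must be picklable.
        data["embedding"] = self.embed_fn(data["image"])  # embed_fn: assumed picklable helper
        return camera, data
```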
18 changes: 11 additions & 7 deletions nerfstudio/configs/method_configs.py
@@ -28,7 +28,7 @@
from nerfstudio.configs.external_methods import ExternalMethodDummyTrainerConfig, get_external_methods
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManagerConfig
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager
from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig
@@ -37,6 +37,7 @@
from nerfstudio.data.dataparsers.phototourism_dataparser import PhototourismDataParserConfig
from nerfstudio.data.dataparsers.sdfstudio_dataparser import SDFStudioDataParserConfig
from nerfstudio.data.dataparsers.sitcoms3d_dataparser import Sitcoms3DDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.datasets.depth_dataset import DepthDataset
from nerfstudio.data.datasets.sdf_dataset import SDFDataset
from nerfstudio.data.datasets.semantic_dataset import SemanticDataset
@@ -91,7 +92,8 @@
max_num_iterations=30000,
mixed_precision=True,
pipeline=VanillaPipelineConfig(
datamanager=ParallelDataManagerConfig(
datamanager=VanillaDataManagerConfig(
_target=ParallelDataManager[InputDataset],
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=4096,
eval_num_rays_per_batch=4096,
@@ -127,7 +129,8 @@
max_num_iterations=100000,
mixed_precision=True,
pipeline=VanillaPipelineConfig(
datamanager=ParallelDataManagerConfig(
datamanager=VanillaDataManagerConfig(
_target=ParallelDataManager[InputDataset],
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=8192,
eval_num_rays_per_batch=4096,
@@ -171,7 +174,8 @@
max_num_iterations=100000,
mixed_precision=True,
pipeline=VanillaPipelineConfig(
datamanager=ParallelDataManagerConfig(
datamanager=VanillaDataManagerConfig(
_target=ParallelDataManager[InputDataset],
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=16384,
eval_num_rays_per_batch=4096,
@@ -220,7 +224,7 @@
mixed_precision=True,
pipeline=VanillaPipelineConfig(
datamanager=VanillaDataManagerConfig(
_target=VanillaDataManager[DepthDataset],
_target=ParallelDataManager[DepthDataset],
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=4096,
eval_num_rays_per_batch=4096,
@@ -302,7 +306,7 @@
method_configs["mipnerf"] = TrainerConfig(
method_name="mipnerf",
pipeline=VanillaPipelineConfig(
datamanager=ParallelDataManagerConfig(dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=1024),
datamanager=VanillaDataManagerConfig(dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=1024),
model=VanillaModelConfig(
_target=MipNerfModel,
loss_coefficients={"rgb_loss_coarse": 0.1, "rgb_loss_fine": 1.0},
@@ -375,7 +379,7 @@
max_num_iterations=30000,
mixed_precision=False,
pipeline=VanillaPipelineConfig(
datamanager=ParallelDataManagerConfig(
datamanager=VanillaDataManagerConfig(
dataparser=BlenderDataParserConfig(),
train_num_rays_per_batch=4096,
eval_num_rays_per_batch=4096,
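To summarize the pattern in the config diff above: `ParallelDataManagerConfig` is removed, and methods opt into parallel loading by pointing the vanilla config's `_target` at the generic `ParallelDataManager`. A hedged sketch of the new registration, mirroring the nerfacto entry:

```python
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager
from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset

datamanager = VanillaDataManagerConfig(
    _target=ParallelDataManager[InputDataset],  # generic over the dataset type, e.g. DepthDataset
    dataparser=NerfstudioDataParserConfig(),
    train_num_rays_per_batch=4096,
)
```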
75 changes: 32 additions & 43 deletions nerfstudio/data/datamanagers/base_datamanager.py
@@ -19,7 +19,6 @@
from __future__ import annotations

from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
@@ -42,7 +41,6 @@

import torch
import tyro
from torch import nn
from torch.nn import Parameter
from torch.utils.data.distributed import DistributedSampler
from typing_extensions import TypeVar
@@ -56,44 +54,19 @@
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig
from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader
from nerfstudio.data.utils.dataloaders import (
CacheDataloader,
FixedIndicesEvalDataloader,
RandIndicesEvalDataloader,
variable_res_collate,
)
from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
from nerfstudio.model_components.ray_generators import RayGenerator
from nerfstudio.utils.misc import IterableWrapper, get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE


def variable_res_collate(batch: List[Dict]) -> Dict:
"""Default collate function for the cached dataloader.
Args:
batch: Batch of samples from the dataset.
Returns:
Collated batch.
"""
images = []
imgdata_lists = defaultdict(list)
for data in batch:
image = data.pop("image")
images.append(image)
topop = []
for key, val in data.items():
if isinstance(val, torch.Tensor):
# if the value has same height and width as the image, assume that it should be collated accordingly.
if len(val.shape) >= 2 and val.shape[:2] == image.shape[:2]:
imgdata_lists[key].append(val)
topop.append(key)
# now that iteration is complete, the image data items can be removed from the batch
for key in topop:
del data[key]

new_batch = nerfstudio_collate(batch)
new_batch["image"] = images
new_batch.update(imgdata_lists)

return new_batch


@dataclass
class DataManagerConfig(InstantiateConfig):
"""Configuration for data manager instantiation; DataManager is in charge of keeping the train/eval dataparsers;
Expand All @@ -111,7 +84,7 @@ class DataManagerConfig(InstantiateConfig):
"""Process images on GPU for speed at the expense of memory, if True."""


class DataManager(nn.Module):
class DataManager:
"""Generic data manager's abstract class

This version of the data manager is designed be a monolithic way to load data and latents,
@@ -164,16 +137,16 @@ class DataManager(nn.Module):
train_sampler: Optional[DistributedSampler] = None
eval_sampler: Optional[DistributedSampler] = None
includes_time: bool = False
test_mode: Literal["test", "val", "inference"] = "val"

def __init__(self):
"""Constructor for the DataManager class.

Subclassed DataManagers will likely need to override this constructor.

If you aren't manually calling the setup_train and setup_eval functions from an overriden
constructor, that you call super().__init__() BEFORE you initialize any
nn.Modules or nn.Parameters, but AFTER you've already set all the attributes you need
for the setup functions."""
If you aren't manually calling the setup_train() and setup_eval() functions from an overridden
constructor, please call super().__init__() in your subclass' __init__() method after
you've already set all the attributes you need for the setup functions."""
super().__init__()
self.train_count = 0
self.eval_count = 0
@@ -311,13 +284,17 @@ class VanillaDataManagerConfig(DataManagerConfig):
"""Target class to instantiate."""
dataparser: AnnotatedDataParserUnion = field(default_factory=BlenderDataParserConfig)
"""Specifies the dataparser used to unpack the data."""
cache_images_type: Literal["uint8", "float32"] = "float32"
"""The image type returned from manager, caching images in uint8 saves memory"""
train_num_rays_per_batch: int = 1024
"""Number of rays per batch to use per training iteration."""
train_num_images_to_sample_from: int = -1
train_num_images_to_sample_from: int = 50
"""Number of images to sample during training iteration."""
train_num_times_to_repeat_images: int = -1
train_num_times_to_repeat_images: int = 10
"""When not training on all images, number of iterations before picking new
images. If -1, never pick new images."""
images. If -1, never pick new images.
Note: decreasing train_num_images_to_sample_from and increasing train_num_times_to_repeat_images alleviates CPU bottleneck.
"""
eval_num_rays_per_batch: int = 1024
"""Number of rays per batch to use per eval iteration."""
eval_num_images_to_sample_from: int = -1
@@ -331,10 +308,20 @@ class VanillaDataManagerConfig(DataManagerConfig):
"""Specifies the collate function to use for the train and eval dataloaders."""
camera_res_scale_factor: float = 1.0
"""The scale factor for scaling spatial data such as images, mask, semantics
along with relevant information about camera intrinsics
"""
along with relevant information about camera intrinsics"""
patch_size: int = 1
"""Size of patch to sample from. If > 1, patch-based sampling will be used."""
load_from_disk: bool = False
"""If True, conserves RAM memory by loading images from disk.
If False, caches all the images as tensors to RAM and loads from RAM."""
dataloader_num_workers: int = 4
"""The number of workers performing the dataloading from either disk/RAM, which
includes collating, pixel sampling, unprojecting, ray generation etc."""
prefetch_factor: int = 10
"""The limit number of batches a worker will start loading once an iterator is created.
More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
cache_compressed_images: bool = False
"""If True, cache raw image files as byte strings to RAM."""

# tyro.conf.Suppress prevents us from creating CLI arguments for this field.
camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None)
@@ -451,13 +438,15 @@ def create_train_dataset(self) -> TDataset:
return self.dataset_type(
dataparser_outputs=self.train_dataparser_outputs,
scale_factor=self.config.camera_res_scale_factor,
cache_compressed_images=self.config.cache_compressed_images,
)

def create_eval_dataset(self) -> TDataset:
"""Sets up the data loaders for evaluation"""
return self.dataset_type(
dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split),
scale_factor=self.config.camera_res_scale_factor,
cache_compressed_images=self.config.cache_compressed_images,
)

def _get_pixel_sampler(self, dataset: TDataset, num_rays_per_batch: int) -> PixelSampler: