Dataloading Revamp #3216

Open
wants to merge 141 commits into
base: main

AntonioMacaronio
Contributor

@AntonioMacaronio AntonioMacaronio commented Jun 12, 2024

Problems and Background

  • With a sufficiently large dataset, the current parallel_datamanager.py tries to cache the entire dataset in RAM, which leads to an out-of-memory (OOM) error.
  • parallel_datamanager.py uses only one worker to generate ray bundles. Steps such as unprojection during ray generation, or pixel sampling within a custom mask, can be CPU-intensive, so they are good candidates for parallelization. parallel_datamanager.py can be configured with multiple workers, but each worker then caches the entire dataset in RAM, which leaves duplicate copies of the dataset in memory and rules out massive datasets. It also implements parallelism from scratch and is not easy to build on.
  • Additionally, both VanillaDataManager and ParallelDataManager rely on CacheDataloader, which subclasses torch.utils.data.DataLoader. This is an unusual coding practice and serves no particular purpose in the current nerfstudio implementation.
  • The same problem applies to full_images_datamanager.py: the current implementation loads the entire dataset into the FullImageDataloader's cached_train attribute, which fails when the dataset does not fit in RAM. To handle this efficiently, we need multiprocess parallelization to load and undistort images quickly enough to keep up with the GPU's forward and backward passes of the model.

Overview of Changes

  • Replacing CacheDataloader with RayBatchStream, which subclasses torch.utils.data.IterableDataset. The goal of this class is to generate ray bundles directly without caching all images in RAM. This is done by collating a sampled batch of images and sampling rays from that batch. A new ParallelDatamanager class is written which is available side-by-side but can completely replace the original VanillaDatamanager.
  • Adding an ImageBatchStream to create a parallel, OOM-resistant version of FullImageDataManager. This can be configured to load from disk by setting the cache_images config variable to disk.
  • A new pil_to_numpy() function is added. This function reads a PIL.Image's data buffer and fills an empty numpy array while reading, which speeds up the conversion and removes an extra memory allocation. It is the fastest way to get from a PIL Image to a PyTorch tensor, averaging ~2.5ms for a 1080x1920 image (~40% faster).
  • A new flag called cache_compressed_imgs caches your images in RAM in their compressed form and relies on parallelized CPU dataloading to efficiently decode them into PyTorch tensors for training.
  • Resolving some pyright issues: Within nerfstudio/scripts/exporter.py, the assertions for ExportPointCloud and ExportPoissonMesh were modified because these exporters are only used on NeRFs, so the cases for splats (which have their own export method) and RandomCameraDatamanger (outdated) were removed. Similarly, some "# type: ignore" comments were added at various runtime-checked locations that pyright could not verify. These are in base_pipeline.py and trainer.py.
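
To make the cache_compressed_imgs idea concrete, here is a minimal sketch. It is not the PR's implementation: the class name CompressedImageCache is hypothetical, and zlib stands in for a real image codec such as JPEG or PNG so that the example runs with only the standard library. The point it illustrates is that RAM holds only the compressed bytes, while decoding happens on demand (in the PR, that decoding is described as happening in parallel dataloader workers).

```python
import zlib
from typing import Dict


class CompressedImageCache:
    """Keep images in RAM as compressed bytes; decode only when requested.

    Hypothetical illustration: zlib stands in for a real image codec.
    """

    def __init__(self) -> None:
        self._blobs: Dict[int, bytes] = {}

    def add(self, image_id: int, raw_pixels: bytes) -> None:
        # Only the compressed form is retained in memory.
        self._blobs[image_id] = zlib.compress(raw_pixels)

    def decode(self, image_id: int) -> bytes:
        # Decompression is the CPU-bound step a worker would perform.
        return zlib.decompress(self._blobs[image_id])

    def cached_bytes(self) -> int:
        # Total size of everything held in the cache.
        return sum(len(blob) for blob in self._blobs.values())


# A fake 64x64 RGB image filled with a single color (highly compressible).
raw = bytes([128, 64, 32]) * (64 * 64)
cache = CompressedImageCache()
cache.add(0, raw)
```

The trade-off is extra CPU work per access in exchange for a much smaller resident set, which is why the flag is paired with parallel CPU dataloading.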

Impact

  • Check out these comparisons! The left was trained on 200 images of a 4k video, while the right was trained on 2000 images of the same 4k video.

Additional Diagrams
NOTE: THESE TESTS WERE CONDUCTED ON A 24GB 3090Ti
image

Contributor

@pwais pwais left a comment

nice progress! sorry its not fast but i think i know why:

i think the main reason this is slower than expected is because _get_collated_batch() gets called per raybundle and sadly _get_collated_batch() is AFAIK needlessly slow.

  • take note of how the current CachedDataloader avoids doing _get_collated_batch() per raybundle. it would have been nice for the author to have left some notes about how slow _get_collated_batch() is, but evidently that author found it necessary to not collate images per raybundle.
  • in my impl, I just call _get_collated_batch() once on a small set of images and keep that batch cached. the main problem I saw is that _get_collated_batch() on thousands of images seemed to use 2x or 3x as much RAM as actually needed and thus caused many minutes of swapping and stuff

Even if you only call _get_collated_batch() once tho, you might need a bigger prefetch factor and/or more workers depending on the model.

IMO it's worth trying to find a way to get the result of nerfstudio_collate on cameras (I think the cameras do need to be collated because they can be ragged? i could be wrong and they don't need collation) but on images just have the worker read image files / buffers and never call collate on those tensors.

Just to be clear, this is the line where collate on images can go nuts and start taking forever to allocate 200GB or more of RAM for many images in code in main:

storage = elem.storage()._new_shared(numel, device=elem.device)

So! If a worker is just emitting raybundles then the images never need to be in shared tensor memory then eh? Thus should be able to save some RAM and CPU by skipping that line for images. Still need to think about the cost of reading the images themselves, but collate is definitely a troublemaker.
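
The "collate once on a small set of images and keep that batch cached" approach described above might look roughly like the sketch below. Everything here is hypothetical and for illustration only: RepeatedBatchSampler, fake_collate, and the parameter names are not nerfstudio APIs, and plain integers stand in for image tensors and rays. It only shows the control flow of reusing one collated batch for several ray batches before collating again.

```python
import random
from typing import Callable, List, Sequence


class RepeatedBatchSampler:
    """Collate a small image batch once, then draw many ray batches from it."""

    def __init__(
        self,
        collate_fn: Callable[[List[int]], List[int]],
        image_ids: Sequence[int],
        images_per_batch: int,
        times_to_repeat: int,
        seed: int = 0,
    ) -> None:
        self.collate_fn = collate_fn
        self.image_ids = list(image_ids)
        self.images_per_batch = images_per_batch
        self.times_to_repeat = times_to_repeat
        self.collate_calls = 0
        self._rng = random.Random(seed)
        self._cached: List[int] = []
        self._uses = 0

    def next_ray_batch(self, rays_per_batch: int) -> List[int]:
        # Re-collate only when there is no cached batch or it has been reused enough.
        if not self._cached or self._uses >= self.times_to_repeat:
            chosen = self._rng.sample(self.image_ids, self.images_per_batch)
            self._cached = self.collate_fn(chosen)
            self.collate_calls += 1
            self._uses = 0
        self._uses += 1
        # Stand-in for pixel sampling: pick entries from the cached batch.
        return [self._rng.choice(self._cached) for _ in range(rays_per_batch)]


def fake_collate(ids: List[int]) -> List[int]:
    # Stand-in for an expensive collate; returns the ids unchanged.
    return list(ids)


sampler = RepeatedBatchSampler(fake_collate, range(100), images_per_batch=4, times_to_repeat=10)
batches = [sampler.next_ray_batch(rays_per_batch=8) for _ in range(25)]
```

With times_to_repeat=10, the 25 ray batches above trigger only three collate calls instead of 25, which is the saving being argued for.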

Contributor

@pwais pwais left a comment

just took a quick look (can't do a full review right now), so cool to see this coming along!!

Sounds like this change will target the case that uncompressed image tensors can't fit in RAM, but the raw image files (typically jpeg) do fit in RAM. In that case I guess we do want each worker to literally load the file bytes into Python RAM (as implemented) versus let the OS disk cache work, because the idea is that the uncompressed image tensors will otherwise blow out the disk cache.

I think it would be important to test, in the end, a case where the user has only limited RAM (say 16GB) and e.g. an 8GB laptop graphics card. In that case I think there are moderate or larger image datasets where the whole thing would OOM when using the current cache impl. It would then be helpful to have some way to disable the cache, or just communicate to the user that their machine is simply too weak for the dataset (e.g. just a CONSOLE.print("[bold yellow]Warning ...") in the line where the workers start reading image files into RAM).
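
A quick back-of-the-envelope calculation shows why uncompressed tensors are the problem here. This is only a sketch: it assumes "4k" means 3840x2160 RGB frames, uses the 200- and 2000-image dataset sizes from the PR description, and the helper name uncompressed_gib is hypothetical.

```python
def uncompressed_gib(
    num_images: int,
    width: int,
    height: int,
    channels: int = 3,
    bytes_per_value: int = 1,
) -> float:
    """Size in GiB of a dataset held as dense, uncompressed image arrays."""
    total_bytes = num_images * width * height * channels * bytes_per_value
    return total_bytes / 2**30


# Assumption: "4k" is 3840x2160 RGB.
for num_images in (200, 2000):
    as_uint8 = uncompressed_gib(num_images, 3840, 2160, bytes_per_value=1)
    as_float32 = uncompressed_gib(num_images, 3840, 2160, bytes_per_value=4)
    print(f"{num_images} images: {as_uint8:.1f} GiB as uint8, {as_float32:.1f} GiB as float32")
```

Under these assumptions, 2000 frames need roughly 46 GiB as uint8 and roughly 185 GiB as float32, far beyond a 16GB machine, even though the compressed files may be much smaller.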

@pwais
Contributor

pwais commented Jan 9, 2025

Very cool to see this moving along!! Congrats @AntonioMacaronio !! What sort of datasets sizes have you tested so far, like 500 4k images for nerfacto / depth-nerfacto as well as splatfacto? I would be curious to test now that it's closer to launch!

@kerrj
Collaborator

kerrj commented Jan 13, 2025

#3569 will fix the failing tests

@kerrj
Collaborator

kerrj commented Jan 13, 2025

> Very cool to see this moving along!! Congrats @AntonioMacaronio !! What sort of datasets sizes have you tested so far, like 500 4k images for nerfacto / depth-nerfacto as well as splatfacto? I would be curious to test now that it's closer to launch!

@pwais Anthony made this comparison of 200 vs 2000 4k images.

Collaborator

@brentyi brentyi left a comment

left some minor comments. overall looks reasonable to me!

self.train_num_times_to_repeat_images = (
    10 if self.train_num_times_to_repeat_images == -1 else self.train_num_times_to_repeat_images
)
self.prefetch_factor = self.train_num_times_to_repeat_images if self.use_parallel_dataloader else None
Collaborator

is it possible to avoid these mutations? it seems simpler if we keep the config objects as simple data structures with minimal logic (if possible)

for setting values to 50 and 10, can we just set these as default values for the config field?

Contributor Author

I agree with this! will change

Contributor Author

@AntonioMacaronio AntonioMacaronio Jan 16, 2025

So after a little more investigation, it may not be possible to avoid this. The default value of -1 does serve a purpose: it covers the case where the user wants to load the entire dataset's images into CPU RAM. When train_num_times_to_repeat_images==-1, the RayBatchStream will assign each worker an even partition, and together these partitions will form the entire dataset.
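
As an illustration of what an "even partition" per worker could mean, here is a minimal sketch. The function partition_for_worker is hypothetical, and the PR's actual partitioning scheme (contiguous or strided, for instance) is not shown in this thread. Inside a torch.utils.data.IterableDataset, the worker id and worker count would typically come from torch.utils.data.get_worker_info(); here they are passed in directly so the example runs without PyTorch.

```python
from typing import List


def partition_for_worker(num_images: int, num_workers: int, worker_id: int) -> List[int]:
    """Return the contiguous slice of image indices owned by one worker."""
    base, extra = divmod(num_images, num_workers)
    # The first `extra` workers each take one additional image.
    start = worker_id * base + min(worker_id, extra)
    stop = start + base + (1 if worker_id < extra else 0)
    return list(range(start, stop))


# Ten images split across four workers.
parts = [partition_for_worker(10, 4, worker_id) for worker_id in range(4)]
```

Each index lands in exactly one partition and partition sizes differ by at most one, so no image is cached by more than one worker.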

As for prefetch_factor, its default value of None represents the case where the user doesn't want to use a parallel dataloader. It has to be None in that case to avoid an error when initializing PyTorch's DataLoader.

Edit: I think it's actually okay to put these outside of a post-init method, but then method_configs.py will need to have them set to -1 and None if we want to keep the original behavior of caching all images into RAM (which is the fastest, and how all nerfstudio users are using it now)
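
One possible way to keep the config a plain data structure while still respecting the prefetch_factor constraint is to derive the DataLoader arguments in a small helper at construction time. This is only a sketch: the function dataloader_kwargs and its parameters are hypothetical, not part of nerfstudio. It relies on the behavior that recent PyTorch versions raise a ValueError if prefetch_factor is set while num_workers is 0.

```python
from typing import Any, Dict, Optional


def dataloader_kwargs(
    use_parallel_dataloader: bool,
    num_workers: int,
    prefetch_factor: Optional[int],
) -> Dict[str, Any]:
    """Build keyword arguments for torch.utils.data.DataLoader without mutating a config."""
    if not use_parallel_dataloader or num_workers == 0:
        # Single-process loading: prefetch_factor must stay None.
        return {"num_workers": 0, "prefetch_factor": None}
    # Multi-process loading: pass the user's values through unchanged.
    return {"num_workers": num_workers, "prefetch_factor": prefetch_factor}


serial = dataloader_kwargs(use_parallel_dataloader=False, num_workers=4, prefetch_factor=10)
parallel = dataloader_kwargs(use_parallel_dataloader=True, num_workers=4, prefetch_factor=10)
```

The resulting dictionaries could then be unpacked into the DataLoader constructor, leaving the config fields themselves untouched.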


if self.use_parallel_dataloader:
    try:
        torch.multiprocessing.set_start_method("spawn")
Collaborator

it seems a little bit odd to set this kind of global flag in the config object's __post_init__... is there a better place? maybe right before wherever the processes or dataloader is created?

Collaborator

Yea I agree, maybe in the __init__ of the datamanager?
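
Wherever the call ends up living, a small guard keeps it from failing when the start method has already been chosen. The sketch below is an assumption about how that could look, not code from the PR. It uses the standard library's multiprocessing module (whose set_start_method and get_start_method functions torch.multiprocessing also exposes) so that it runs without PyTorch, and the helper name ensure_spawn_start_method is hypothetical.

```python
import multiprocessing


def ensure_spawn_start_method() -> str:
    """Request the "spawn" start method once; return whichever method is in effect."""
    current = multiprocessing.get_start_method(allow_none=True)
    if current is None:
        try:
            multiprocessing.set_start_method("spawn")
        except RuntimeError:
            # The start method was fixed elsewhere in the meantime; keep it.
            pass
    return multiprocessing.get_start_method()


# Safe to call more than once, e.g. from several datamanager constructors.
effective_method = ensure_spawn_start_method()
```

Calling this right before the workers or dataloader are created keeps the global side effect next to the code that needs it, rather than in a config object.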

# If a user would like to load from disk, we pre-emptively set the number of
# workers and prefetch factor to parallelize the dataloading process.
try:
    torch.multiprocessing.set_start_method("spawn")
Collaborator

same comment as above!

7 participants