Moving to LitData: Refactoring data pipeline #80
Closed
The primary bottleneck in our training pipeline is dataloader performance: training time per epoch is currently very high because of the `IterDataPipe`-based dataloader, which is 2-3 times slower than the TensorFlow pipeline used in SLEAP.
We evaluated LitData, an API designed to optimize data pipelines by enabling multi-machine and distributed processing, along with built-in cloud-scalable solutions. With its support for efficient parallelization and memory handling, LitData accelerates data-intensive training. We benchmarked the performance of LitData across all data pipelines (single-instance, centroid, centered-instance, and bottom-up) and achieved nearly on-par performance with TensorFlow in all cases, except for single-instance, which remains 1.5x slower than TensorFlow.
This issue details the plan for refactoring our current data pipeline.
MVP:
PR1: Refactor the `__iter__` methods across the data modules into individual, well-defined functions.

PR2: Implement a `get_chunks()` method for each model pipeline. This method runs all the data preprocessing functions (except augmentation, resizing/`pad_to_stride`, and confidence map (or PAF) generation) to extract dictionaries from the `.slp` file and save them as `.bin` files. Apply cropping (`crop_hw`) in the `litdata.StreamingDataset.__getitem__()` method rather than in `get_chunks()`.

PR3: Implement a custom `litdata.StreamingDataset` class for each model type. Apply augmentation, resizing, `pad_to_stride`, and confidence map generation (plus part affinity fields for the bottom-up model) in the `litdata.StreamingDataset.__getitem__()` method.

PR4: Update the `training.model_trainer.ModelTrainer` class. In `_create_data_loaders()`, use `ld.optimize(fn=get_chunks)` to generate the `.bin` files, pass the `.bin` directory path to the `litdata.StreamingDataset` class, and ensure the `.bin` files are deleted after training.

Example
get_chunks() function
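The collapsed example is not visible in this text, so here is a minimal sketch of what `get_chunks()` could look like. The dict keys, the `frame_to_sample` helper, and the `sleap_io` calls are illustrative assumptions, not the final API; only the division of labor (no augmentation, resizing/`pad_to_stride`, or confidence-map generation at chunking time) comes from the plan above.

```python
# Sketch only: key names, the frame_to_sample helper, and the sleap_io
# usage are assumptions about the eventual implementation.
import numpy as np


def frame_to_sample(image, instances, video_idx: int, frame_idx: int) -> dict:
    """Package one labeled frame as the dict that gets chunked to .bin files.

    No augmentation, resizing/pad_to_stride, or confidence-map generation
    happens here; those run later in StreamingDataset.__getitem__().
    """
    return {
        "image": np.asarray(image, dtype=np.uint8),            # (H, W, C)
        "instances": np.asarray(instances, dtype=np.float32),  # (n_inst, n_nodes, 2)
        "video_idx": int(video_idx),
        "frame_idx": int(frame_idx),
    }


def get_chunks(slp_path: str):
    """Yield one preprocessed sample dict per labeled frame of a .slp file."""
    import sleap_io as sio  # assumed .slp reader; imported lazily

    labels = sio.load_slp(slp_path)
    for lf in labels.labeled_frames:
        # Stack per-instance keypoints into (n_instances, n_nodes, 2).
        points = np.stack([inst.numpy() for inst in lf.instances])
        yield frame_to_sample(
            lf.image, points, labels.videos.index(lf.video), lf.frame_idx
        )
```

In `_create_data_loaders()`, something like `ld.optimize(fn=get_chunks, inputs=[...], output_dir=...)` would then consume this generator and write the yielded dicts out as `.bin` chunk files, per the PR4 plan.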
Custom StreamingDataset
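Likewise, a minimal sketch of a per-model custom `StreamingDataset`, assuming LitData is installed. The class name, sample keys, and `make_confmaps` helper are hypothetical; the point is that the deferred transforms (augmentation, resizing, `pad_to_stride`, confidence maps) run in `__getitem__()` on chunked samples. The try/except fallback just keeps the sketch importable without `litdata`.

```python
# Sketch: SingleInstanceStreamingDataset and make_confmaps are illustrative
# stand-ins, not the actual sleap-nn implementation.
import numpy as np

try:
    import litdata as ld
    _Base = ld.StreamingDataset
except ImportError:  # litdata not installed; keep the sketch importable
    _Base = object


def make_confmaps(points: np.ndarray, shape, sigma: float = 2.5) -> np.ndarray:
    """Gaussian confidence maps, one channel per node: (n_nodes, H, W)."""
    h, w = shape
    yy, xx = np.mgrid[0:h, 0:w]
    cms = np.zeros((points.shape[0], h, w), dtype=np.float32)
    for i, (x, y) in enumerate(points):
        if not (np.isnan(x) or np.isnan(y)):  # skip missing keypoints
            cms[i] = np.exp(-((xx - x) ** 2 + (yy - y) ** 2) / (2.0 * sigma**2))
    return cms


class SingleInstanceStreamingDataset(_Base):
    """Reads chunked samples and applies the deferred transforms on the fly."""

    def __init__(self, input_dir: str, augment: bool = False, **kwargs):
        super().__init__(input_dir=input_dir, **kwargs)
        self.augment = augment

    def __getitem__(self, index):
        sample = super().__getitem__(index)
        image = sample["image"]
        points = sample["instances"][0]  # single-instance pipeline
        # Augmentation, resizing, and pad_to_stride would be applied here,
        # before the confidence maps are generated.
        sample["confidence_maps"] = make_confmaps(points, image.shape[:2])
        return sample
```

A bottom-up variant would additionally generate part affinity fields in `__getitem__()`; the per-model classes differ only in which of these transforms they apply.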
Next steps:
- `Cycler`
- `torch.exp` function?