Moving to LitData: Refactoring data pipeline #80

Open · 3 of 6 tasks

gitttt-1234 opened this issue Sep 5, 2024 · 0 comments · May be fixed by #90, #91, #92 or #94

The primary bottleneck in our training pipeline is dataloader performance: training time per epoch is currently very high due to the IterDataPipe-based dataloader, which is 2-3x slower than the TensorFlow pipeline used in SLEAP.

We evaluated LitData, an API designed to optimize data pipelines by enabling multi-machine and distributed processing, along with built-in cloud-scalable solutions. With its support for efficient parallelization and memory handling, LitData accelerates data-intensive training. We benchmarked the performance of LitData across all four model pipelines (single-instance, centroid, centered-instance, and bottom-up) and achieved nearly on-par performance with TensorFlow in all cases, except for single-instance, which remains 1.5x slower than TensorFlow.

This issue details the plan for refactoring our current data pipeline.

MVP:

PR1:

  • Break down all operations in the __iter__ methods across the data modules into individual, well-defined functions.
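    As a hypothetical illustration of this refactor (the function name and exact behavior are assumptions, not the final API), a step that previously lived inline in __iter__ becomes a standalone, testable function:

    import torch

    def normalize(image: torch.Tensor) -> torch.Tensor:
        """Convert a uint8 image tensor to float32 in [0, 1] (simplified; the
        real version would also handle RGB/grayscale conversion)."""
        return image.to(torch.float32) / 255.0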

PR2:

  • Implement the get_chunks() method for each model pipeline. This method runs all the data preprocessing functions (except augmentation, resizing/pad_to_stride, and confidence map (or PAF) generation) to extract dictionaries from the .slp file and save them as .bin files.
    • For the centroid model, the centroids are computed inside get_chunks() (see the sketch after this list).
    • For the centered-instance model, the crops are generated with the crop size enlarged by crop_hw * (np.sqrt(2) - 1) to account for the blacking of edges when applying rotation augmentation. The images are re-cropped to crop_hw in the litdata.StreamingDataset.__getitem__() method.
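    A minimal sketch of the centroid computation referenced above, assuming instances is an (n_instances, n_nodes, 2) tensor with NaNs for invisible nodes (the function name is hypothetical):

    import torch

    def compute_centroids(instances: torch.Tensor) -> torch.Tensor:
        # Mean over the nodes axis, ignoring NaN (invisible) points -> (n_instances, 2).
        return torch.nanmean(instances, dim=1)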

PR3:

  • Implement a custom litdata.StreamingDataset class for each model type. Apply augmentation, resizing, pad_to_stride, and confidence map generation (plus part affinity fields for the bottom-up model) in the litdata.StreamingDataset.__getitem__() method (see the example below).

PR4:

  • Integrate with the training.model_trainer.ModelTrainer class. In _create_data_loaders(), use ld.optimize(fn=get_chunks) to generate the .bin files, pass the .bin directory path to the litdata.StreamingDataset class, and ensure the .bin files are deleted after training. A rough sketch follows.
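    A rough sketch of what the _create_data_loaders() integration could look like (the chunks path, attribute names, and batch size are assumptions; single_instance_get_chunks and SingleInstanceDataset are from the examples below):

    import shutil
    import litdata as ld

    def _create_data_loaders(self):
        chunks_dir = "./train_chunks/"  # hypothetical temporary location for the .bin files
        ld.optimize(
            fn=single_instance_get_chunks,
            inputs=list(self.train_labels),  # assumed attribute holding the training labels
            output_dir=chunks_dir,
            num_workers=2,
            chunk_size=100,
        )
        self.train_dataset = SingleInstanceDataset(input_dir=chunks_dir)
        self.train_data_loader = ld.StreamingDataLoader(self.train_dataset, batch_size=4)
        # After training finishes, the temporary chunks should be removed:
        # shutil.rmtree(chunks_dir)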

Example

get_chunks() function

import litdata as ld
import sleap_io as sio

def single_instance_get_chunks(lf: sio.LabeledFrame):
    # Extract image and instances from the labeled frame and convert to torch.Tensors.
    image, instances = get_img_inst_from_lf(lf)

    orig_size = image.shape[-2:]  # (height, width) before resizing

    image = normalize(image)  # includes converting to/from RGB from/to grayscale

    image, instances = resize(image, instances)

    image = pad_to_stride(image, max_stride)  # max_stride: model-specific constant

    ex = {
        "image": image,
        "instances": instances,
        "orig_size": orig_size,
        "frame_idx": lf.frame_idx,
        "video_idx": video_idx,  # index of lf.video in labels.videos
    }

    return ex

labels = sio.load_slp("test.pkg.slp")
ld.optimize(
    fn=single_instance_get_chunks,
    inputs=[lf for lf in labels],
    output_dir="./single_instance_chunks/",
    num_workers=2,
    chunk_size=100,
)
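ld.optimize() writes the returned dictionaries as chunked binary files (plus an index) under output_dir; the StreamingDataset below streams samples back from that directory.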

Custom StreamingDataset

class SingleInstanceDataset(ld.StreamingDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getitem__(self, index):
        # Load the preprocessed sample from the .bin chunks.
        ex = super().__getitem__(index)

        image, instances = ex["image"], ex["instances"]

        # Apply augmentation on the fly so each epoch sees different transforms.
        image, instances = augmentation(image, instances)

        # Generate confidence maps at the (possibly augmented) image resolution.
        confidence_maps = get_confmaps(instances, image.shape[-2:])

        sample = {
            "image": image,
            "video_idx": ex["video_idx"],
            "frame_idx": ex["frame_idx"],
            "orig_size": ex["orig_size"],
            "instances": instances,
            "confidence_maps": confidence_maps,
        }

        return sample
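
A minimal usage sketch, assuming the chunks were generated by the ld.optimize() call above (batch size and worker count are arbitrary):

dataset = SingleInstanceDataset(input_dir="./single_instance_chunks/", shuffle=True)
loader = ld.StreamingDataLoader(dataset, batch_size=4, num_workers=2)
batch = next(iter(loader))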

Next steps:

  • Need to re-implement Cycler.
  • Speed up the torch.exp function? (See the confidence map sketch below for where it dominates.)
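
For context on the torch.exp question, here is a minimal sketch of Gaussian confidence map generation, where a single large torch.exp over an (n_nodes, height, width) grid dominates the cost (the function name and sigma default are assumptions):

import torch

def make_confmaps(points: torch.Tensor, h: int, w: int, sigma: float = 5.0) -> torch.Tensor:
    # points: (n_nodes, 2) tensor of (x, y) coordinates for one instance.
    xx = torch.arange(w, dtype=torch.float32).reshape(1, 1, w)
    yy = torch.arange(h, dtype=torch.float32).reshape(1, h, 1)
    x = points[:, 0].reshape(-1, 1, 1)
    y = points[:, 1].reshape(-1, 1, 1)
    # One torch.exp call over an (n_nodes, h, w) grid -- the hot spot in question.
    return torch.exp(-((xx - x) ** 2 + (yy - y) ** 2) / (2 * sigma**2))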