
Advanced Batching Logic with CombinedStreamingDataset #434

Open
schopra8 opened this issue Dec 13, 2024 · 9 comments · May be fixed by #438
Labels
enhancement New feature or request

Comments

@schopra8

🚀 Feature

When we use CombinedStreamingDataset, samples are drawn across all datasets within a single batch. Other packages like MosaicML's StreamingDataset allow you to specify advanced batching strategies.

The current implementation maps to the random option from MosaicML. It would be great if we could support other batching methods -- especially Per Stream, wherein each batch is drawn from a single dataset.

Motivation

If you're training over K datasets, where each dataset yields samples of different shapes, we currently have to split each random batch into microbatches, since samples from different datasets cannot be stacked together. This leads to a 10-20% slowdown in training, because we're not fully utilizing the GPU on every forward and backward pass.

Pitch

Allow additional batching options, e.g. "per stream", which dictate how batches are yielded when using CombinedStreamingDataset.
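To make the difference concrete, here is a minimal stand-alone sketch of the two sampling behaviors being discussed, using only plain Python; the function names are illustrative and not part of litdata's API.

```python
import random

def stratified_batches(num_datasets, batch_size, num_batches, rng):
    """Current behavior: each sample in a batch may come from any dataset."""
    for _ in range(num_batches):
        yield [rng.randrange(num_datasets) for _ in range(batch_size)]

def per_stream_batches(num_datasets, batch_size, num_batches, rng):
    """Proposed behavior: every sample in a batch comes from one dataset."""
    for _ in range(num_batches):
        idx = rng.randrange(num_datasets)
        yield [idx] * batch_size

rng = random.Random(0)
for batch in per_stream_batches(num_datasets=3, batch_size=4, num_batches=2, rng=rng):
    # All samples in a per-stream batch share one dataset index,
    # so they can be stacked without microbatching.
    assert len(set(batch)) == 1
```

With per-stream batches, samples of a batch always share a shape, so the collate step never has to split the batch into microbatches.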

Alternatives

Additional context

@schopra8 schopra8 added the enhancement New feature or request label Dec 13, 2024
@tchaton
Collaborator

tchaton commented Dec 14, 2024

Hey @schopra8, it should be quite simple to implement.

In CombinedStreamingDataset, we need to allow passing a sampling argument and, if it is per stream, draw a batch worth of data from a single dataset instead of randomly sampling across the datasets.

Would you want to give it a try and make a PR contribution? We will help you if you get it started.

Best,
T.C

@tchaton
Collaborator

tchaton commented Dec 14, 2024

@schopra8 Let me know if you are interested in making a PR. It is a bit harder for me to justify spending time on LitData if a user hasn't started the PR themselves.

@schopra8
Author

@tchaton Definitely interested in helping push out a PR here. Can you give me a pointer on where to start? I've been an avid user of the library but haven't dug into much of the internals.

@schopra8
Author

Am I correct in saying that I need to update the _get_dataset_index method of the _CombinedDatasetIterator object?

  • I can add a new parameter to the CombinedStreamingDataset constructor called construct_batch_from_single_dataset and then pass this to the constructor for _CombinedDatasetIterator.
  • If construct_batch_from_single_dataset = True, then the _CombinedDatasetIterator's _get_dataset_index method will return the same dataset index for every sample in a batch.
  • If construct_batch_from_single_dataset = False, then the _CombinedDatasetIterator's _get_dataset_index method will return a randomly selected dataset index for every sample. This is the current behavior.

Is there an easy way to know whether a sample is being drawn for an existing batch (i.e., reuse the same dataset index) or for a new batch (i.e., randomly select a dataset index again)?
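The plan in the bullets above could be sketched roughly as follows. This is a hypothetical standalone model, not litdata's actual iterator: the class name, the construct_batch_from_single_dataset flag, and the sample-counting trick (one possible answer to the batch-boundary question) are all assumptions for illustration.

```python
import random

class CombinedIndexSampler:
    """Toy model of _get_dataset_index, assuming the batch size is known."""

    def __init__(self, num_datasets, batch_size, single_dataset_per_batch, seed=0):
        self._rng = random.Random(seed)
        self._num_datasets = num_datasets
        self._batch_size = batch_size
        self._single = single_dataset_per_batch
        self._samples_in_batch = 0   # counts samples drawn for the current batch
        self._cur_index = None

    def _get_dataset_index(self):
        if self._cur_index is None or self._samples_in_batch == self._batch_size:
            # New batch: always draw a fresh dataset index.
            self._cur_index = self._rng.randrange(self._num_datasets)
            self._samples_in_batch = 0
        elif not self._single:
            # Stratified (current behavior): re-draw for every sample.
            self._cur_index = self._rng.randrange(self._num_datasets)
        self._samples_in_batch += 1
        return self._cur_index
```

Counting samples against a known batch_size is one way to detect batch boundaries without any signal from the dataloader; the thread below discusses where that batch size would come from.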

@tchaton
Collaborator

tchaton commented Dec 17, 2024

Hey @schopra8,

Here is how I would do it. When using the StreamingDataLoader, the CombinedStreamingDataset receives the batch size. So we just need to collect samples from the same dataset until we have accumulated batch_size items.

Let's have a nicer argument: batching_method="per_stream|stratified" ;)

@schopra8
Author

Looking at the code, I'm still a bit confused.

Where does the iterator request batch_size samples? I'm struggling to figure out where I can tell the underlying _CombinedDatasetIterator whether to select a new dataset or reuse the existing one (i.e., how to know whether a batch is still accumulating samples or we're starting a new batch).

Thanks in advance @tchaton!

@schopra8
Author

I took a stab at this nonetheless :) Would love some feedback, @tchaton, so I can iterate on it.

Init Implementation

  • Added a batching_method parameter to CombinedStreamingDataset and _CombinedDatasetIterator. The default is stratified.
  • Updated the _get_dataset_index method. If stratified, we invoke a _set_new_dataset_index method to randomly select a dataset index. If per_stream, we reuse the previous dataset index, which is stored within the _CombinedDatasetIterator as self._cur_dataset_index.
  • Within the StreamingDataLoader, we manually trigger a new dataset index selection every batch when dealing with a CombinedStreamingDataset. This ensures that the dataset index is updated on batch boundaries (necessary for the per_stream batching option).

@tchaton
Collaborator

tchaton commented Dec 21, 2024

Hey @schopra8, open a PR to LitData and I will put my reviews there directly.

@schopra8 schopra8 linked a pull request Dec 22, 2024 that will close this issue
@schopra8
Author

@tchaton Awesome! I just opened the PR.
