Advanced Batching Logic with CombinedStreamingDataset #434
Comments
Hey @schopra8, it should be quite simple to implement. In the CombinedStreamingDataset, we need to allow passing the sampling argument and, if per stream, get a batch's worth of data instead of randomly sampling across the datasets. Would you want to give it a try and make a PR contribution? We will help you if you get it started. Best,
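For illustration, the constructor-side change being suggested could look roughly like this; the argument name batching_method (floated later in the thread) and the validation are assumptions, not LitData's current API:

```python
from typing import Any, List, Optional, Sequence

class CombinedStreamingDataset:
    """Simplified stand-in for litdata's CombinedStreamingDataset, not the real class."""

    def __init__(
        self,
        datasets: List[Any],
        seed: int = 42,
        weights: Optional[Sequence[float]] = None,
        batching_method: str = "stratified",  # proposed argument; "stratified" keeps today's behaviour
    ) -> None:
        if batching_method not in ("stratified", "per_stream"):
            raise ValueError(f"Unknown batching_method: {batching_method!r}")
        self._datasets = datasets
        self._seed = seed
        self._weights = weights
        # The iterator would consult this flag to decide whether to re-draw
        # the dataset index per sample ("stratified") or per batch ("per_stream").
        self._batching_method = batching_method
```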
@schopra8 Let me know if you have an interest in making a PR. It is a bit harder for me to justify spending time on LitData if a user hasn't started the PR themselves.
@tchaton Definitely interested in helping push out a PR here. Can you give a pointer on where to start? I've been an avid user of the library but haven't dug into much of the internals. |
Am I correct in saying, I need to update the
Is there an easy way to know whether a sample is being drawn for an existing batch (i.e., reuse the same dataset index) or from a new batch (i.e., randomly select a dataset index again)?
Hey @schopra8, here is how I would do it. When using the StreamingDataLoader, the CombinedStreamingDataset receives the batch size. So we just need to collect samples from the same dataset until we have accumulated batch_size items. Let's have a nicer argument: batching_method="per_stream|stratified" ;)
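A minimal, self-contained sketch of that idea in plain Python (not LitData's actual iterator; the function name and signature are illustrative):

```python
import random
from typing import Iterable, Iterator, Optional, Sequence

def per_stream_samples(
    datasets: Sequence[Iterable],
    batch_size: int,
    weights: Optional[Sequence[float]] = None,
    seed: int = 42,
) -> Iterator:
    """Yield samples so that every `batch_size` consecutive samples come from
    a single dataset; the dataset index is re-drawn once per batch."""
    rng = random.Random(seed)
    iterators = [iter(d) for d in datasets]
    while True:
        # "per_stream": pick one dataset for the whole upcoming batch,
        # instead of re-drawing the index for every sample ("stratified").
        dataset_index = rng.choices(range(len(datasets)), weights=weights, k=1)[0]
        for _ in range(batch_size):
            try:
                yield next(iterators[dataset_index])
            except StopIteration:
                return  # stop once the chosen dataset is exhausted
```

Since the StreamingDataLoader already hands the batch size to the combined dataset, the per-batch counter would live in the combined dataset's iterator and reset every batch_size samples.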
Looking at the code, I'm still a bit confused. Where does the iterator request the batch size? Thanks in advance @tchaton!
I took a stab at this nonetheless :) Would love some feedback @tchaton so I can iterate on it.
Hey @schopra8, open a PR to LitData and I will put my reviews there directly.
@tchaton Awesome! I just opened the PR |
🚀 Feature
When we use CombinedStreamingDataset, samples are drawn across all datasets within a single batch. Other packages like MosaicML's StreamingDataset allow you to specify advanced strategies for batching. The current implementation maps to the random option from MosaicML. It would be great if we could support other batching methods, especially Per Stream, wherein each batch is taken from a single dataset.

Motivation
If you're training over K datasets, where each dataset yields samples of different shapes, we currently have to split the random batch into microbatches, since the samples from the different datasets cannot be stacked together. This leads to a 10-20% slowdown in training, because we're not fully utilizing the GPU on every forward and backward pass.

Pitch
Allow additional options for batching, e.g. "per stream", which dictate how batches are yielded when using CombinedStreamingDataset.
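For illustration, usage could look roughly like this; batching_method and its values are the names proposed in this thread rather than an existing option, and the dataset paths are placeholders:

```python
from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

# Two streams whose samples have different shapes (paths are placeholders).
images = StreamingDataset("s3://my-bucket/images_optimized")
text = StreamingDataset("s3://my-bucket/text_optimized")

combined = CombinedStreamingDataset(
    [images, text],
    batching_method="per_stream",  # proposed: each batch drawn from a single dataset
)

# Every batch would then contain samples from one dataset only, so they stack
# without being split into microbatches.
loader = StreamingDataLoader(combined, batch_size=64)
```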
Alternatives
Additional context