Notes on shuffling, sharding, and batchsize #302
Thank you for adding the detailed notes.
I totally understand this problem. This is one thing we want to solve during this half.
For each ..., the Dataset produces 8 batches with size 18, while the IterDataPipe produces 96 batches, with size 1 at the last iteration. There must be some discrepancy in the sampling mechanism of the DataLoader. Will take a look |
I went through the example again and found the reason for this discrepancy.
cc: @VitalyFedyunin |
Do you think adding a |
@ejguan I think ensuring that each image is processed only once per training epoch is a very common and reasonable assumption that we should maintain. Note that there are techniques such as Repeated Augmentation that rely on agreeing to use a single subset of data across processes and then generating augmentations on top of it. Though there might be a way to work around this with your new API/approach, this technique shows that using different seeds across processes has a material effect on accuracy, so the results can't be assumed to be necessarily similar.
@NicolasHug This is a very reasonable guess, especially since you have shown that many batches get a sample size of 1. This is very likely to make the mini-batch statistics extremely noisy and break training. There is a way to confirm that this is the problem: you can replace BN with SyncBatchNorm to synchronize the statistics across the different batches. It's going to be slower (so it's not a viable solution to your problem), but it will give you evidence on whether the BN is to blame here. Not sure how useful this would be for you though, as I agree with you that in any case we should consider aligning the behaviour with the old DataLoader. |
@datumbox Thanks for your insights on the effect of the random seed. I agree the ... Here is my thinking about the random seed: whatever random operations happen before sharding should use the same random seed, to make sure the whole dataset is the same across processes.
|
@ejguan Thanks for getting back to me so promptly. I would expect that shard1 and shard2 have different augmentation strategies, similar to what the current DataLoader is doing. As Nicolas said earlier, having the same augmentations systematically across batches can lead to the introduction of biases and thus less accurate results. I think many of the assumptions of the current DataLoader have become common assumptions in research, and moving away from them can lead to unexpected results. |
@datumbox Hi! I'm an external user and was just following this thread as I am also very interested in sharding/shuffling behavior with DataPipes. I noticed your comment above and just wanted to clarify one thing: the assumption that "each image is processed only once per training epoch" is not guaranteed to be true when DDP is used along with ... You can see the codepath in ...
Devices 2 and 3 have to repeat some samples on step 2 in order for PyTorch DDP to work. So within a single epoch, a total of 12 samples have been processed, not 10. Note that if only 1 or 2 GPUs had been used, then exactly 10 samples would have been processed. Anyway, just wanted to share; the difference in behavior is pretty small, especially since with shuffle=True the 2 repeated samples are drawn randomly each epoch. But basically I think it would be OK for DataPipes to also repeat some samples per device in DDP mode, as we never had that guarantee to begin with. |
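For anyone following along, here is a simplified, self-contained sketch of the padding behaviour described above (it is not the actual DistributedSampler source, and it ignores shuffling): with 10 samples and 4 replicas, 2 samples are repeated so that every rank receives the same number of samples.

import math

def padded_shards(num_samples, num_replicas):
    per_rank = math.ceil(num_samples / num_replicas)  # 3 when num_samples=10, num_replicas=4
    total = per_rank * num_replicas                   # 12
    indices = list(range(num_samples))
    indices += indices[: total - len(indices)]        # pad by repeating the first few samples
    # interleaved assignment: rank r gets indices[r], indices[r + num_replicas], ...
    return [indices[rank:total:num_replicas] for rank in range(num_replicas)]

print(padded_shards(10, 4))
# [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
# -> samples 0 and 1 are processed twice in the epoch (12 samples total, not 10)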
@abhi-mosaic Thanks for the input. Though I agree with the technical details you provided, it's important to note how the two points being made here differ:
So the assumption that each image is processed only once per training epoch is, in practice, true at the moment, with the exception of the batch-padding corner case listed above, and that's what I'm advocating we should maintain. If we don't keep the seed the same, we risk being unable to reproduce past results with the new API, introducing large systematic biases into training, and being unable to implement techniques similar to Repeated Augmentation properly.
Agreed and that's why I think we should make sure that the new solution behaves similarly and predictably to the old implementation. |
@ejguan I think your suggestion of keeping the seed constant across workers up until the sharding is reasonable. As @datumbox reiterated, we need different seeds after that. Let's discuss this more in depth in today's meeting, along with a robust, user-friendly API, following up on #352 (comment) |
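To make the ordering being agreed on here concrete, a minimal sketch (random_augment is a placeholder transform, and sharding_filter stands in for whatever sharding datapipe ends up being used):

from torch.utils.data.datapipes.iter import IterableWrapper

def random_augment(x):
    return x  # placeholder for a real random transform

dp = IterableWrapper(range(100))
dp = dp.shuffle()            # must see the full dataset; needs the seed shared across workers
dp = dp.sharding_filter()    # each worker/process keeps only its own shard
dp = dp.map(random_augment)  # per-sample randomness; should differ across workers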
The user-friendly API would be welcome 😄 My understanding was that ... I'm trying to implement a quite simple datapipe (which involves only shuffling and mapping) and it doesn't really work.

from torch.utils.data.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader

def dummy_dp(n=10):
    dp = IterableWrapper(range(n))
    dp = dp.shuffle()
    return dp

dp = dummy_dp(10)
dl = DataLoader(dp, batch_size=None, shuffle=False, num_workers=1, pin_memory=False, drop_last=False)
print([e for e in dl])

would give ...
I would really appreciate a simple working example of the above, that is, with shuffling in the datapipe only and num_workers > 1. Or an explanation of what to do before DataLoaderV2. Thank you 😄 |
Answering my own question: setting ... What is non-intuitive to me is the false separation between the datapipe and the dataloader. The dataloader still has control over the shuffling! I leave this only as an anecdote; you know the reasons behind this choice way better than me. |
Find here one example showing how things are not yet solved 😢

from typing import Iterator, Any
from torch.utils.data.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader
import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe
import torch.distributed as dist
class SharderDataPipe(torch.utils.data.datapipes.iter.grouping.ShardingFilterIterDataPipe):
    def __init__(self, source_datapipe: IterDataPipe) -> None:
        super().__init__(source_datapipe)
        self.rank = 0
        self.world_size = 1
        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        self.apply_sharding(self.world_size, self.rank)

    def __iter__(self) -> Iterator[Any]:
        num_workers = self.world_size
        worker_id = self.rank
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_id + worker_info.id * num_workers
            num_workers *= worker_info.num_workers
        self.apply_sharding(num_workers, worker_id)
        yield from super().__iter__()
def dummy_dp_0(n, num_workers):
    dp = IterableWrapper(range(n))
    dp = dp.shuffle()
    dp = SharderDataPipe(dp)
    return dp

def dummy_dp_1(n, num_workers):
    dp = IterableWrapper(range(n))
    dp = SharderDataPipe(dp)
    dp = dp.shuffle()
    return dp

def dummy_dp_2(n, num_workers):
    dp = IterableWrapper(range(n))
    dp = dp.shuffle(buffer_size=n)
    dp = SharderDataPipe(dp)
    dp = dp.shuffle()
    return dp

num_workers = 2
for dp in [dummy_dp_0(n=10, num_workers=num_workers),
           dummy_dp_1(n=10, num_workers=num_workers),
           dummy_dp_2(n=10, num_workers=num_workers)]:
    torch.manual_seed(0)
    dl = DataLoader(dp, batch_size=None, shuffle=None, num_workers=num_workers, pin_memory=False, drop_last=False)
    print([e for e in dl])

gives ...
What seems intuitive to me is that each worker has access to a different local seed and to a global seed. Then in |
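A minimal sketch of that "global seed plus per-worker local seed" idea (the names and the seed-derivation rule below are my own, not an agreed API): keep two separate RNG objects per worker, one seeded identically everywhere to drive shuffling, and one seeded per worker to drive augmentations.

import random
import torch

def make_worker_rngs(global_seed):
    info = torch.utils.data.get_worker_info()
    worker_id = info.id if info is not None else 0
    shared_rng = random.Random(global_seed)                     # identical in every worker
    local_rng = random.Random(global_seed * 1000 + worker_id)   # differs per worker
    return shared_rng, local_rng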
@franchesoni
Yeah. Either setting |
As a workaround for now, you have to provide ... For ...
This is the result if you do ...
And, the ultimate solution should be something like the following (this needs DataLoader2 to land in TorchData):

dp = IterableWrapper(range(n))
dp = dp.shuffle()
dp = SharderDataPipe(dp)
dl = DataLoader2(dp, ...)
for epoch in range(2):
    dl.set_seed(0)
    print(list(dl))

There are a few features that need to be done:
|
Thank you for the explanation. If I understand well, in the future, a ... Do we have any solution today that solves both shuffling and augmentation? A modification of your example ... should work well. But where (or when) can I access all these new cool functionalities? |
@franchesoni

dp = IterableWrapper(range(n))
dp = dp.shuffle()
dp = SharderDataPipe(dp)
dp = dp.map(lambda x: random_augment(x))  # franchesoni added this line
dp = dp.random_op(rng=random.Random())  # the API is not determined yet.
dl = DataLoader2(dp, ...)
for epoch in range(2):
    dl.set_seed(0)
    print(list(dl))

LMK if this design works for you. And, for ...:

dp = IterableWrapper(range(n))
dp = dp.shuffle().set_seed(0)
dp = SharderDataPipe(dp)
dp = dp.map(lambda x: random_augment(x))  # franchesoni added this line
dp = dp.random_op(rng=random.Random())  # the API is not determined yet.
dl = DataLoader2(dp, ...)
for epoch in range(2):
    dl.set_seed(0)
    print(list(dl))
Adding isolated RNG and API would be my high-pri next week. |
This might be ambiguous with all the other types of RNGs that are at play here. It's not clear from the API that this only affects the shuffling seed. Perhaps |
The seed for |
But, this requires users to add this |
Is there a way for the DataLoader2 to be aware of a specific shuffling seed? In the case of map-style datasets, all these details are properly handled by the sampler or by the DataLoader, and users don't really need to worry about the RNG at all (except maybe with |
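For reference, this is how the map-style path mentioned above keeps things simple today: the DataLoader's generator argument drives the RandomSampler, so a single seed controls the shuffling. A minimal example with a dummy dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))
g = torch.Generator()
g.manual_seed(0)
dl = DataLoader(ds, batch_size=2, shuffle=True, generator=g)
print([batch[0].tolist() for batch in dl])  # same order every time the generator is re-seeded with 0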
I am not sure about it. Users still need to specify seeds for all global RNGs using ...

def seed_worker(worker_id):
    random.seed(worker_id)
    ...

ds = Dataset()  # Dummy map-style Dataset
for epoch in range(2):
    torch.manual_seed(epoch)
    dl = DataLoader(ds, shuffle=True, worker_init_fn=seed_worker, num_workers=2)

For ...:

dp = IterableWrapper(range(n))
dp = dp.shuffle()
dp = SharderDataPipe(dp)
dl = DataLoader2(dp, ...)
for epoch in range(2):
    # dl.set_seed(0)
    print(list(dl))  # dl will use time as the seed in the main process, and generate the seed for shuffle

WDYT? |
I don't believe so. In the torchvision training references for example (which currently use map-style datasets), we don't need to worry about seeds at all. And yet:
Ideally the same simplicity would be available with datapipes.
Are you sure? I believe that the 2 points above are still verified even if we were to call |
This is not the mechanism for Map-style
You mean users don't need to call ...? If not specified, the seed for the ... If you don't do a manual seed to ...
I am not 100% sure about how the TorchVision augmentations work. But, if all random augmentations rely on the global Torch RNG, then to control reproducibility you have to specify ... If users don't want to preserve reproducibility, they don't need to specify any of ... |
I agree with everything you wrote @ejguan , but I think we're not entirely talking about the same thing. My main concern here is to provide an API for datapipes that is just as simple as what we currently have for map-style datasets.
Now considering the reproducible use-case, e.g. when users call
In both of these use-cases with map-style datasets, things are kept extremely simple and hidden from the user. At no point do they need to worry about a specific seed for shuffling, or a specific seed for transforms. They can just call ... Perhaps I'm missing some important detail? |
Great summary! And thanks for bringing up this discussion; it triggers some thoughts about BC for me. @NicolasHug
For reproducible use case
You can see that the only difference from the user's perspective is when they want to have an isolated RNG per random op (...). LMK if this makes more sense. |
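Until a random_op(rng=...)-style API exists (the snippets above explicitly mark it as undetermined), one way to get an isolated RNG per random op today is to give the map function its own random.Random instance; RandomJitter below is just a placeholder augmentation:

import random
from torch.utils.data.datapipes.iter import IterableWrapper

class RandomJitter:
    """Placeholder augmentation carrying its own isolated RNG (picklable for workers)."""
    def __init__(self, seed):
        self.rng = random.Random(seed)

    def __call__(self, x):
        return x + self.rng.random()

dp = IterableWrapper(range(10)).map(RandomJitter(seed=0))
print(list(dp))

The caveat is the one discussed throughout this thread: with num_workers > 1 every worker gets a copy seeded the same way, so per-worker and per-epoch variation still has to come from somewhere else (e.g. the future set_seed hooks).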
@ejguan and I just had a chat offline where we discussed some of the points above. Here's a summary of our discussion thus far. The points below are either re-hashing, or updating / correcting the ones above. @ejguan please feel free to edit / correct if this isn't accurate. And thanks again for your time and all your work on this!
|
Yeah. It will be one TODO in the proposal. I will make the RNG attached to the shuffler seedable by ... |
Sorry for the delay. You figured things out already in a nice way. Given that ... If I understood well, the last proposal

dp = IterableWrapper(range(n))
dp = dp.shuffle().set_seed(0)
dp = SharderDataPipe(dp)
dp = dp.map(lambda x: random_augment(x))  # franchesoni added this line
dp = dp.random_op(rng=random.Random())  # the API is not determined yet.
dl = DataLoader2(dp, ...)
for epoch in range(2):
    dl.set_seed(0)
    print(list(dl))

changed and it is now

n = 10
dp = IterableWrapper(range(n))
dp = dp.shuffle()  # one per worker using shared seed
dp = SharderDataPipe(dp)  # data subsampling depending on worker num (function defined way above)
dp = dp.map(lambda x: random_augment(x))  # one per worker using worker seed
dl = DataLoader2(dp, ...)
for epoch in range(2):  # will print the same correctly shuffled range twice
    torch.manual_seed(0)
    print(list(dl))

1. Is this right?

I came here because my code said

assert (
    num_workers <= 1
), "this should be 1 until https://github.com/pytorch/data/issues/302 is solved (check if it is)"

Happy to see it has advanced. Let me know when I can remove this assertion! 😝 |
(I'm writing this down here to have a written trace, but I'm looking forward to discussing this with you all in our upcoming meetings :) )
I spent some time porting the torchvision training recipes to use datapipes, and I noticed that the model I trained on ImageNet with DPs was much less accurate than the one with regular datasets. After a lot of digging I came to the following conclusion:
Details below. Note: for sharding, I used this custom torchvision sharder which takes DDP and dataloader workers into account, + the TakerIterDataPipe below it.
Shuffle before shard
First, some quick results (training a resnext50_32x4d for 5 epochs with 8 GPUs and 12 workers per GPU):
Shuffle before shard: Acc@1 = 47% -- this is on par with the regular indexable dataset version (phew!!)
Shuffle after shard: Acc@1 = 2%
One way to explain this is that if we shuffle after we shard, then only sub-parts of the dataset get shuffled. Namely, each of the 8 * 12 = 96 dataloader workers receives ~1/96th of the dataset, and each of these parts gets shuffled. But that means that the shuffling is far from uniform, and for datasets whose layout is all_samples_from_class1, all_samples_from_class2, ... all_samples_from_classN, it's possible that some class i is never in the same batch as class j. So it looks like we need to shuffle before we shard.
Now, if we shuffle before sharding, we still need to make sure that all of the 96 workers shuffle the dataset with the same RNG. Otherwise we risk sampling a given sample in more than one worker, or not at all. For that to happen, one can set a random seed in worker_init_fn, but that causes a second problem: the random transformations of each worker will also be the same, which will lead to slightly less accurate results; on top of that, all epochs will start with the same seed, so the shuffling is the same across all epochs. I do not know how to solve this problem yet.
Note that TF shuffles the dataset before storing it. We might do something similar, but that would still not solve the issue for custom user datasets.
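For concreteness, a sketch of the worker_init_fn approach mentioned above, with its two drawbacks spelled out in the comments (which global RNG the shuffler actually draws from depends on the version, so both are seeded here):

import random
import torch

SHARED_SEED = 0  # the same value in every worker and every epoch

def worker_init_fn(worker_id):
    # Same seed everywhere -> all workers shuffle in the same order, so the
    # shards form a clean partition of the (shuffled) dataset...
    torch.manual_seed(SHARED_SEED)
    random.seed(SHARED_SEED)
    # ...but any random transform using these global RNGs is now identical
    # across workers, and the shuffle order repeats across epochs.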
Size of the batches at the end of an epoch
Some quick results (same experiment as above):
with drop_last=True: Acc@1 = 47%
with drop_last=False: Acc@1 = 11%
Near the end of the epoch, the dataloader with DP will produce a lot of batches with size 1 if drop_last is False. See the last batches of an epoch on indices from [0, len(imagenet)) with a requested batch size of 32: https://pastebin.com/wjS7YC90. In contrast, this does not happen when using an indexable dataset: https://pastebin.com/Rje0U8Dx.
I'm not too sure why this has such a dramatic impact, but it's possible that it has to do with batch norm, as @fmassa pointed out offline. Using drop_last will make sure that the 1-sized batches are eliminated, producing a much better accuracy.
I guess the conclusion here is that it's worth unifying the behaviour of the DataLoader for both DPs and regular indexable datasets regarding the batch size, because with indexable datasets and drop_last=False we still get ~47% acc.
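A quick back-of-the-envelope check of where those size-1 batches come from, assuming the usual ImageNet-1k train-set size and that each of the 96 workers batches its own shard independently:

num_samples = 1_281_167   # ImageNet-1k train set (assumed)
num_workers = 8 * 12      # 8 GPUs x 12 dataloader workers = 96 shards
batch_size = 32

per_worker = num_samples // num_workers  # 13345 (a few workers get one extra sample)
last_batch = per_worker % batch_size     # 1
print(per_worker, last_batch)
# Each worker's final batch has only 1-2 samples, so roughly 96 tiny batches
# show up at the end of every epoch unless drop_last=True.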