Transforms V2 proposal: Enabling reproducible workflows via local RNGs #7027
Comments
Also, an important thing is the propagation of passed generators to all components of the transform pipeline (there's an additional complexity in that not all transforms need these generators, but those that can accept them may need the generator to be propagated). In my own code, I implemented my own containers and transform classes partly because of this. It is possible, but at the very least there should be reusable staticmethod ways of sampling the transform arguments that can accept an rng.
@vadimkantorov I updated my post to include a description of how one would pass a generator through `Compose`.
@rsokl Yes, I implemented something similar in my own code. For that, all transforms must accept a generator argument.
Although somewhat niche and low priority, #3001 (comment) also shows an example of why good RNG support is needed. In case a user wants to use the same random parameters at different points in time, there are currently only two solutions:
Thus, in such a situation it would be really beneficial to just reset a generator and pass it again.
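(For reference, a minimal sketch of the "reset and replay" idea with a plain `torch.Generator` - illustrative only, not tied to any particular transform API:)

```python
import torch

rng = torch.Generator().manual_seed(0)
state = rng.get_state()            # remember the generator's state...

a = torch.rand(3, generator=rng)   # first use of the random parameters

rng.set_state(state)               # ...restore it later to replay the same draws
b = torch.rand(3, generator=rng)

assert torch.equal(a, b)
```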
Generalizing those get_params / sample_params to optionally accept an rng/generator could be a first step towards easier reproducibility (if not done yet...).
Passing the generator in `forward()`:

```python
from torch import Generator

rng = Generator().manual_seed(0)
trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)
img, bboxes, labels = trans(img, bboxes, labels, generator=rng)
```

All 3 transforms will be using the same `rng`. What if I want to "freeze" the RNG of one transform in `Compose`, while preserving maximal entropy for the rest of the transforms? I can't re-seed the shared generator for just one of them.

In the limit, the only benefit I can see from allowing generators in `forward()` ...

I'm curious if either @rsokl or @vadimkantorov have found this to be a limitation in practice? (Instead of passing generators in `forward()`, they could be passed to the transforms' constructor.)
Oh, I think we can rule out ...
I spent a bit more time thinking about it and I implemented a toy solution to check what happens when a transform holds a `torch.Generator` (passed in the constructor) and is used with a `DataLoader`. For now I can't think of a clean way to handle all this without requiring users to understand inner details of the DataLoader, so that's a bummer. But I'm curious what you all think.

```python
# %%
import torch
from torch.utils.data import DataLoader


class MyTransform(torch.nn.Module):
    def __init__(self, rng):
        super().__init__()
        self.rng = rng

    def forward(self):
        return torch.randint(0, 1000, size=(1,), generator=self.rng).item()


class Dataset:
    def __init__(self, transform):
        self.transform = transform

    def __getitem__(self, _):
        return self.transform()  # no input to the transform, we don't care.

    def __len__(self):
        return 1000


rng = torch.Generator()
t = MyTransform(rng)
ds = Dataset(t)

# %%
# Dataset only, so far so good
for x, _ in zip(ds, range(4)):
    print(x)
# 710
# 284
# 837
# 820

# %%
# Things break with DataLoader(num_workers > 0).
# The generator is duplicated across workers when we fork.
# Oopsies.
# Note: this is actually documented! https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading
dl = DataLoader(ds, num_workers=2)
for x, _ in zip(dl, range(10)):
    print(x)
# tensor([299])
# tensor([299])
# tensor([754])
# tensor([754])
# tensor([334])
# tensor([334])
# tensor([739])
# tensor([739])
# tensor([609])
# tensor([609])

# %%
# Only way to make it work is to set a per-worker seed: https://pytorch.org/docs/stable/notes/faq.html#my-data-loader-workers-return-identical-random-numbers
# And BTW the only reason things "work" by default is because torch does that
# for us already https://github.com/pytorch/pytorch/blob/1a661639f77a172df5d1ccd6987049292c6f3440/torch/utils/data/_utils/worker.py#L223-L225
def worker_init_fn(worker_id):
    rng.manual_seed(worker_id)


prev_state = rng.get_state()  # surprise surprise, see cell below

dl = DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)
for x, _ in zip(dl, range(10)):
    print(x)
# tensor([44])
# tensor([845])
# tensor([239])
# tensor([139])
# tensor([933])
# tensor([124])
# tensor([760])
# tensor([368])
# tensor([963])
# tensor([263])

# %%
# Oh and on top of that, the RNG from the main process is never consumed!
# So we get the exact same RNG across epochs.
# EDIT: As Philip pointed out, this is actually already the case even when the global RNG is used everywhere.
# Things only work OK because the global RNG gets consumed elsewhere e.g. by the RandomSampler. Ew.
assert (rng.get_state() == prev_state).all()
```
In my own practice, I implemented my own Compose/functions for passing down the RNG when I need it (as I'm mainly using the pure tensor functions). Supporting it as a forward argument is a worthwhile thing, even if it's not passed down by the default Compose. Your solution of passing it in the constructor is also not bad!
In my own code, when I needed to control the RNG / transforms (sometimes when I had to apply dependent versions of transforms to several images in the batch), I usually implemented my own custom dataset code and my own custom samplers. At least to simplify life for advanced use cases, it's much better to have an explicit option (be it with the RNG as a field or via forward; both can easily be supported - either take the forward-arg-provided RNG or fall back to the RNG stored in the field).
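(A minimal sketch of that "support both" idea - purely illustrative, not torchvision code; the class and argument names are made up:)

```python
from typing import Optional

import torch
from torch import nn


class MyRandomTransform(nn.Module):
    def __init__(self, rng: Optional[torch.Generator] = None) -> None:
        super().__init__()
        self.rng = rng  # RNG stored as a field (may be None)

    def forward(self, x: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        # Prefer the generator passed to forward(); fall back to the field,
        # and to the global RNG if neither is provided.
        gen = generator if generator is not None else self.rng
        noise = torch.rand(x.shape, generator=gen)
        return x + noise
```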
Another fun fact: I just realized that ...
A problem is that we may still want different generators for different worker processes, so the generator should not just be cloned, but also seeded with the worker-id, or somehow depend on the example-id. And the match of worker-id and example-id might not be guaranteed even if all of that was satisfied. I would propose that the most useful thing is to support passing an rng to forward / get_params, to simplify advanced cases (e.g. the advanced user may, in the Dataset's `__getitem__`, re-seed and pass the rng themselves; see the sketch below).
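(A sketch of that kind of advanced usage - seeding a local generator from the epoch and example index inside `__getitem__`, so the draws don't depend on worker scheduling. This assumes a transform that accepts `generator=` in `forward()`, i.e. the API proposed here; `base_seed` and `epoch` are made-up names:)

```python
import torch
from torch.utils.data import Dataset


class SeededDataset(Dataset):
    def __init__(self, transform, base_seed: int = 0) -> None:
        self.transform = transform
        self.base_seed = base_seed
        self.epoch = 0  # bumped externally at the start of each epoch

    def __len__(self) -> int:
        return 1000

    def __getitem__(self, idx: int):
        # The seed depends only on (base_seed, epoch, idx), not on which
        # worker process happens to handle this example.
        rng = torch.Generator()
        rng.manual_seed(self.base_seed * 1_000_003 + self.epoch * 100_003 + idx)
        img = torch.rand(3, 8, 8, generator=rng)   # stand-in for real data loading
        return self.transform(img, generator=rng)  # proposed generator= forward arg
```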
@NicolasHug This prototype #7445 by @pmeier is actually similar to what I rolled for my own code.
I am familiar with #7445. I have explained in a few of my comments above why passing RNGs to `forward()` is problematic.

Let me know if there is anything I can clarify. Otherwise, let's please reduce the noise on this issue and focus on the more urgent matter described in #7027 (comment). I am in a tight loop with torch core to hopefully get it resolved. If you're still keen on debating the forward vs init issue, let's please do so in another issue.
I was not sure if you were responding to my arguments only or also to the prototype in this linked PR; I linked it because I had missed it up in the thread and because it's more complete and concrete than my words.

I respectfully don't consider your technical arguments correct in this case. As I outlined, being able to pass the aug transforms and rng down to the worker from the main thread will not likely solve reproducibility, and it kills the point of augmentations: for RNG-based augs to be truly deterministic yet meaningful/random enough at the same time, they need to be correlated with things like epoch-id and example-id to control for thread-scheduling randomness. So I consider the "technically impossible" argument incorrect, and the impossibility of pickling not very relevant for the actual goal of reproducibility (of course it would be nice if Generators could be pickled, but it doesn't seem very relevant for these reasons). You seem to have ignored this argument.

This means that the goal of "not adding any changes to Datasets/DataLoaders and still achieving meaningful reproducibility" is indeed unachievable, but passing down the generator is also not a good solution! For the field method to be meaningful, users would still need to reset these RNG objects across the whole aug pipeline in the Dataset's `__getitem__`.

But, as I see now, the decision is taken and the PR in question has been evaluated as well; I of course recognize that, and I see no point arguing anymore in this or any other issues in this repo (for that matter). I will reduce my noise/feedback/conversation/issues in this repo to zero from now on, sorry for the bother.
Your message above is mixing up 2 orthogonal issues:
They're orthogonal. The first one is solved, and I am working with torch core to address the second. Respectfully, I don't think we have a shared understanding on this topic. Let's leave it at that please.
🚀 The feature
(This was originally pitched in this long feedback thread. It was recommended that I open a separate issue).
Enable the new transforms API to support the use of local generators to control RNG via the modified API:
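(Roughly, the idea is that `forward()` accepts an optional `torch.Generator` and threads it down to `_get_params()`. A minimal sketch - the `Transform`/`_get_params` names follow the transforms v2 prototype, but the plumbing here is an assumption, not the actual implementation:)

```python
from typing import Any, Dict, List, Optional

import torch
from torch import nn


class Transform(nn.Module):
    def forward(self, *inputs: Any, generator: Optional[torch.Generator] = None) -> Any:
        flat_inputs = list(inputs)  # placeholder for the real input-flattening logic
        params = self._get_params(flat_inputs, generator=generator)
        return self._transform(inputs[0] if len(inputs) == 1 else inputs, params)

    def _get_params(self, flat_inputs: List[Any], generator: Optional[torch.Generator] = None) -> Dict[str, Any]:
        return {}  # stochastic transforms override this and draw from `generator`

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return inpt  # actual transforms apply `params` here
```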
Thus, transforms that implement `_get_params` would replace calls that draw from the global RNG with calls that draw from the passed-in generator, e.g.:
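(An illustrative before/after - the sampling below is `RandomRotation`-flavoured but made up; it is not the actual torchvision code:)

```python
import torch

# Before: parameter sampling consumes the global RNG
def _get_params_before(degrees=(-30.0, 30.0)):
    angle = (degrees[0] + (degrees[1] - degrees[0]) * torch.rand(1)).item()
    return dict(angle=angle)

# After: parameter sampling consumes the passed-in generator
# (generator=None falls back to the global RNG)
def _get_params_after(degrees=(-30.0, 30.0), generator=None):
    angle = (degrees[0] + (degrees[1] - degrees[0]) * torch.rand(1, generator=generator)).item()
    return dict(angle=angle)
```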
A transform like `Compose` would have to be modified as well. Currently, it supports a sequence of callables that are assumed to accept a single positional argument. It could be assumed that only instances of `Transform` involve stochasticity and will be passed the random generator. In this case, `Compose` would look like:
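(Again, a hedged sketch rather than the real implementation, reusing the `Transform` base class sketched above:)

```python
from typing import Any, Callable, Optional, Sequence

import torch
from torch import nn


class Compose(nn.Module):
    def __init__(self, transforms: Sequence[Callable]) -> None:
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any, generator: Optional[torch.Generator] = None) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            # Only Transform instances (the base class sketched above) are assumed
            # to be stochastic and receive the generator; plain callables,
            # e.g. old-style nn.Module transforms, are called as before.
            if isinstance(transform, Transform):
                sample = transform(sample, generator=generator)
            else:
                sample = transform(sample)
        return sample
```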
It would be straightforward to document this behavior to users – that only instances of `Transform` are passed the generator – so that they know how to opt in to having the generator passed to their custom transforms. And, again, this would be compatible with the old `nn.Module` transforms.

An example of this in practice would be:
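(Hypothetical usage under the proposed API - `T` stands for the v2 transforms namespace, and the `generator=` keyword is exactly the feature being proposed, so this does not run against current torchvision:)

```python
import torch
import torchvision.transforms.v2 as T

rng = torch.Generator().manual_seed(42)

trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)

img = torch.rand(3, 512, 512)
# All random parameters are drawn from `rng`; re-seeding `rng` replays the
# exact same augmentations regardless of any other use of the global RNG.
out = trans(img, generator=rng)
```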
Another nice thing about this is that specific fail cases that occur during training/testing can be reproduced in an isolated way; `_get_params(dummy_img, generator=rng)` can be used to iterate the generator's state to "replay" a sequence of transformations without having to redo all of the compute. This would not work if the model and the transforms both affect and derive from global state.
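(A toy sketch of that replay trick - the transform class below is a made-up stand-in with the proposed `_get_params(..., generator=...)` hook:)

```python
import torch


class RandomRotationLike(torch.nn.Module):
    # Made-up stand-in: only the parameter sampling matters for the replay.
    def _get_params(self, flat_inputs, generator=None):
        angle = (torch.rand(1, generator=generator).item() * 60.0) - 30.0
        return dict(angle=angle)


trans = RandomRotationLike()
rng = torch.Generator().manual_seed(0)
dummy_img = torch.zeros(3, 32, 32)

# Fast-forward the generator through the first k parameter draws without
# redoing any image compute, then recover the parameters of step k.
k = 17
for _ in range(k):
    trans._get_params([dummy_img], generator=rng)
print(trans._get_params([dummy_img], generator=rng))
```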
Motivation, pitch

In recent years, NumPy has completely revised its PRNG API to avoid global random state (here is a great post on good practices with NumPy's generators). JAX avoids mutable RNG objects altogether. PyTorch provides `torch.Generator` to users to make randomness local and "non-spooky", but many libraries prevent users from utilizing this capability.

I am proposing that `Transform` enable users to optionally pass a `Generator` to the forward pass so that torchvision transform pipelines can be isolated from global entropy and thus support more reproducible workflows. This reproducibility is especially useful in the context of testing & evaluation – the specific sequence of data transformations performed should be able to be isolated from whether or not a model is using dropout in its forward pass.

Alternatives
No response
Additional context
@pmeier already provided (positive) feedback on this proposal here
cc @vfdev-5 @datumbox @bjuncek @pmeier