Replace torchvision.datasets.utils with functionality from torchdata #6060
base: main
Conversation
@@ -1,22 +1,26 @@
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
I only made changes to four datasets to cover all cases, i.e.
- HTTP download (CIFAR)
- GDrive download (Caltech)
- Kaggle download (FER2013)
- Manual download (ImageNet)
The others are commented out so that torchvision can still be imported.
@@ -41,11 +37,6 @@ def _info() -> Dict[str, Any]:
    return dict(categories=categories, wnids=wnids)


class ImageNetResource(ManualDownloadResource):
Instead of subclassing just to parametrize, I followed the approach in #6052.
@@ -42,7 +42,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
        yield from self._dp

    @abc.abstractmethod
-   def _resources(self) -> List[OnlineResource]:
+   def _resources(self) -> Sequence[OnlineResource]:
This gets rid of the annoying mypy warning when returning List[ManualDownloadResource] instead of List[OnlineResource].
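For context, here is a minimal sketch of the variance issue behind that warning; the two resource classes are bare stubs standing in for the real ones:

```python
from typing import List, Sequence


class OnlineResource:  # stub for the real OnlineResource
    ...


class ManualDownloadResource(OnlineResource):  # stub subclass
    ...


def resources_as_list() -> List[OnlineResource]:
    resources: List[ManualDownloadResource] = [ManualDownloadResource()]
    # mypy: Incompatible return value type, because List is invariant
    return resources


def resources_as_sequence() -> Sequence[OnlineResource]:
    resources: List[ManualDownloadResource] = [ManualDownloadResource()]
    # OK: Sequence is covariant, so a List[ManualDownloadResource] is accepted
    return resources
```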
    @classmethod
    def from_http(cls, url: str, *, file_name: Optional[str] = None, **kwargs: Any) -> OnlineResource:
        return cls(url, file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs)

    @classmethod
    def from_gdrive(cls, id: str, **kwargs: Any) -> OnlineResource:
        return cls(f"https://drive.google.com/uc?export=download&id={id}", **kwargs)
By using the OnlineReader datapipe, we no longer need different download functionality for HTTP or GDrive. Thus, I was able to remove the subclasses and replace them with these thin instantiation class methods. If we switch to iopath in the future, we can probably remove them as well.
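For illustration, this is roughly how the two constructors above could be used; the URL, GDrive id, file name, and sha256 values are placeholders, and I'm assuming OnlineResource is directly instantiable here and still accepts a sha256 keyword:

```python
# Placeholder values throughout; these are not real dataset resources.
http_resource = OnlineResource.from_http(
    "https://example.com/data/train.tar.gz",
    sha256="<placeholder>",
)
# file_name is inferred from the URL path, here "train.tar.gz"

gdrive_resource = OnlineResource.from_gdrive(
    "<gdrive-file-id>",
    file_name="annotations.zip",
    sha256="<placeholder>",
)
```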
        stream = list(dp)[0][1]

        with open(file, "wb") as fh, tqdm() as progress_bar:
            for chunk in iter(lambda: stream.read(1024 * 1024), b""):  # type: ignore[no-any-return]
                # filter out keep-alive new chunks
                if not chunk:
                    continue

                fh.write(chunk)
                progress_bar.update(len(chunk))
There are two reasons I didn't do this with datapipes:

- We should probably have a progress bar for downloads, since they usually are quite large and a long-running process without output is always suspicious. By importing tqdm from torch.hub we get vanilla tqdm if it is installed and a minimal port if it is not.
- AFAIK, torchdata has no builtin option to read chunks from a stream. We need that, because we otherwise load the whole file into memory before saving it.
- I agree having a progress bar for downloads is a good idea. Maybe torchdata can have a dependency on tqdm.
- I suppose adding an option to return r.iter_content() instead of r.raw would work?
- Apart from tqdm, rich also added support for progress bars recently. That is what newer versions of pip are using. Just leaving this here to assess the options before deciding to depend on tqdm.
- I would avoid having a flag that changes the output types. Having something like a ChunkReader datapipe can be useful in general, not just for HTTP requests, right?

```python
from torchdata.datapipes.iter import IterDataPipe


class ChunkReader(IterDataPipe):
    def __init__(self, datapipe, *, chunk_size=32 * 1024 * 1024):
        self.datapipe = datapipe
        self.chunk_size = chunk_size

    def __iter__(self):
        for path, stream in self.datapipe:
            for chunk in iter(lambda: stream.read(self.chunk_size), b""):
                # filter out keep-alive new chunks from HTTP streams
                if not chunk:
                    continue
                yield path, chunk
```

A utility like that is needed for Saver anyway. AFAIK, there is currently no builtin functionality to read data from a stream, correct? I'm guessing that is why the Saver example uses bytes as input.
If there is more to discuss, let's migrate this to a proper issue on the torchdata repository.
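To make the idea concrete, here is a rough sketch of how the proposed ChunkReader could be combined with the builtin HttpReader and Saver to stream a download to disk; the URL and target path are placeholders, and I'm assuming Saver with mode="ab" appends each incoming chunk:

```python
from torchdata.datapipes.iter import HttpReader, IterableWrapper, Saver

dp = IterableWrapper(["https://example.com/archive.tar.gz"])
dp = HttpReader(dp)    # yields (url, stream) pairs
dp = ChunkReader(dp)   # yields (url, bytes) chunks, as defined above
dp = Saver(dp, mode="ab", filepath_fn=lambda url: "/tmp/archive.tar.gz")
list(dp)  # drives the pipeline; chunks are appended to the target file
```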
I will open an issue in torchdata to get more thoughts from the community and see if there is anything that we may be missing.

IIUC, since _get_response_from_http returns r.raw, rather than, say, r.iter_content(), it loads the whole response into memory at once. Is that correct?
> I will open an issue in torchdata to get more thoughts from the community and see if there is anything that we may be missing.

See pytorch/data#439. It seems we missed StreamReader, which does exactly that. Still, could you open an issue regarding the progress bar?

> IIUC, since _get_response_from_http returns r.raw, rather than, say, r.iter_content(), it loads the whole response into memory at once. Is that correct?

r.raw just gives you the stream and is somewhat like BinaryIO, whereas r.iter_content() reads chunks from the stream and thus is an Iterator[bytes]. IMO returning r.raw is the right choice, since attaching a StreamReader to an HttpReader will give us the iter_content() behavior.
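As a quick illustration of that last point (the URL is a placeholder and the chunk size is arbitrary):

```python
from torchdata.datapipes.iter import HttpReader, IterableWrapper, StreamReader

dp = IterableWrapper(["https://example.com/file.bin"])
dp = HttpReader(dp)                       # (url, stream), i.e. the r.raw behavior
dp = StreamReader(dp, chunk=1024 * 1024)  # (url, bytes) chunks, i.e. iter_content() behavior
for url, chunk in dp:
    ...  # write the chunk, update a progress bar, etc.
```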
        dp = HashChecker(dp, {str(file): self.sha256}, hash_type="sha256")
        list(dp)

        return file
Since we opted to only perform the preprocessing after the download in #5990, we can probably also move it here.
        dp = IterableWrapper([self.url])
        dp = OnlineReader(dp)
        stream = list(dp)[0][1]
I think that can be a useful functionality. We probably want to add it as a new DataPipe rather than modifying the HttpReader.
> OnlineReader has no support for redirects.

Agree, we can add that to it and to HttpReader.

> GDriveReader is missing functionality that we support. For example, we can deal with Virus scan warnings that popped up recently.

We should add that.

> For that we need automatic detection of the archive type.

Is it possible to extend Decompressor to do more of that? Or will you need something different?
Nope. As the name implies, Decompressor only handles decompression; for loading archives we need something along these lines:

```python
import os
import tarfile
import zipfile

from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.utils import StreamWrapper


class ZipArchiveLoader(IterDataPipe):
    def __init__(self, datapipe):
        self.datapipe = datapipe

    @staticmethod
    def _load(path, stream):
        for info in zipfile.ZipFile(stream).infolist():
            extracted_path = ...
            extracted_stream = ...
            yield os.path.join(path, extracted_path), StreamWrapper(extracted_stream)

    def __iter__(self):
        for path, stream in self.datapipe:
            yield from self._load(path, stream)


class TarArchiveLoader(IterDataPipe):
    def __init__(self, datapipe):
        self.datapipe = datapipe

    @staticmethod
    def _load(path, stream):
        for info in tarfile.open(fileobj=stream):
            extracted_path = ...
            extracted_stream = ...
            yield os.path.join(path, extracted_path), StreamWrapper(extracted_stream)

    def __iter__(self):
        for path, stream in self.datapipe:
            yield from self._load(path, stream)


class ArchiveLoader(IterDataPipe):
    def __init__(self, datapipe):
        self.datapipe = datapipe

    _LOADERS = {
        ".zip": ZipArchiveLoader._load,
        ".tar": TarArchiveLoader._load,
    }

    def __iter__(self):
        for path, stream in self.datapipe:
            # dispatch on the file extension to pick the right archive loader
            load = self._LOADERS[os.path.splitext(path)[1]]
            yield from load(path, stream)
```

With this we can extract any archive with

```python
dp = IterableWrapper([path_to_any_kind_of_archive])
dp = ArchiveLoader(dp)
dp = ChunkReader(dp)  # see comment above
dp = Saver(dp, ...)
list(dp)
```

Alternatively, I could also live with a …
@NivekT @ejguan While working on the next iteration of this, I encountered two issues documented as TODO below and tracked in pytorch/data#451 and pytorch/data#452.
D = TypeVar("D")


class ProgressBar(IterDataPipe[D]):
This is my naive approach to adding a progress bar as a datapipe. It is sufficient for our use case, but needs more work for general use. For example, the implementation relies on the fact that all items in the datapipe are handled by the same progress bar. This works for us, since we get chunks of a single file. But if we processed multiple files with the pipeline, the progress bar would aggregate everything together.
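For reference, here is a naive sketch of what such a datapipe could look like; this is not the exact implementation from this PR, and it shares the same limitation that all items feed a single progress bar:

```python
from typing import Iterator, Tuple

from torch.hub import tqdm  # vanilla tqdm if installed, minimal port otherwise
from torchdata.datapipes.iter import IterDataPipe


class ProgressBar(IterDataPipe[Tuple[str, bytes]]):
    def __init__(self, datapipe: IterDataPipe[Tuple[str, bytes]]) -> None:
        self.datapipe = datapipe

    def __iter__(self) -> Iterator[Tuple[str, bytes]]:
        # one shared progress bar for everything flowing through the pipe
        with tqdm() as progress_bar:
            for path, chunk in self.datapipe:
                progress_bar.update(len(chunk))
                yield path, chunk
```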
Related, regarding GDrive and other cloud downloads: pytorch/pytorch#73466
As the title implies, we want to get rid of as much functionality as possible from torchvision.datasets.utils to avoid duplication between the libraries. This PR adds an initial draft of what can already be achieved and what has to be ported to torchdata before we can depend on it. Apart from my inline comments below, there are a number of other points:

- OnlineReader has no support for redirects. That is crucial, since quite a few datasets have permanent redirects to keep the original URL reported in the paper alive while being able to host elsewhere. We have support for resolving redirects by performing head requests until the returned URL matches the one we started with (vision/torchvision/datasets/utils.py, lines 88 to 101 in ac56f52; see the sketch after this list). That could be integrated into OnlineReader.
- GDriveReader is missing functionality that we support. For example, we can deal with virus scan warnings that popped up recently. In addition, improve error handling for GDrive downloads #5704 adds more common error handling that would also be nice for users. There should be no issue porting this to torchdata.
- The way we are currently loading archives is very similar to loading a folder of files: enumerate all files and return the path-file pairs. For that we need automatic detection of the archive type (vision/torchvision/prototype/datasets/utils/_resource.py, lines 72 to 85 in ac56f52). I've requested that quite some time ago on the old private torchdata repository and I remember another discussion with @ejguan on a PR, but we never got anywhere. My reasoning is that for our purpose we don't care about the actual type, we just want the path-file tuples. This is aligned with OnlineReader, for which we also don't care whether the data comes from HTTP or GDrive. Plus, there is also the Decompressor datapipe that does the same for different compression types.
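Here is a rough sketch of the redirect resolution mentioned in the first bullet; it is not the exact torchvision implementation, just the idea of following HEAD requests until the URL stops changing:

```python
import requests


def resolve_redirects(url: str, max_hops: int = 3) -> str:
    for _ in range(max_hops + 1):
        # ask only for headers and handle the redirect manually
        response = requests.head(url, allow_redirects=False, timeout=10)
        location = response.headers.get("Location")
        if location is None or location == url:
            return url
        url = location
    raise RecursionError(f"Request exceeded {max_hops} redirects.")
```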