Replace torchvision.datasets.utils with functionality from torchdata #6060

Draft
pmeier wants to merge 2 commits into main

Conversation

@pmeier (Collaborator) commented May 20, 2022

As the title implies, we want to get rid of as much functionality as possible from torchvision.datasets.utils to avoid duplication between the libraries. This PR adds an initial draft showing what can already be achieved and what still has to be ported to torchdata before we can depend on it.

Apart from my inline comments below, there are a number of other points:

  • OnlineReader has no support for redirects. That is crucial, since quite a few datasets use permanent redirects to keep the original URL reported in the paper alive while hosting the files elsewhere. We have support for resolving redirects by performing HEAD requests until the returned URL matches the one we started with:

    def _get_redirect_url(url: str, max_hops: int = 3) -> str:
        initial_url = url
        headers = {"Method": "HEAD", "User-Agent": USER_AGENT}

        for _ in range(max_hops + 1):
            with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
                if response.url == url or response.url is None:
                    return url

                url = response.url
        else:
            raise RecursionError(
                f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
            )

    That could be integrated into OnlineReader. A possible interim workaround on the user side is sketched after this list.

  • GDriveReader is missing functionality that we support. For example, we can deal with the virus scan warnings that popped up recently. In addition, improve error handling for GDrive downloads #5704 adds more common error handling that would also be nice for users. There should be no issue porting this to torchdata.

  • The way we are currently loading archives is very similar to loading a folder of files: enumerate all files and return the path-file pairs. For that we need automatic detection of the archive type:

    _ARCHIVE_LOADERS = {
        ".tar": TarArchiveLoader,
        ".zip": ZipArchiveLoader,
        ".rar": RarArchiveLoader,
    }

    def _guess_archive_loader(
        self, path: pathlib.Path
    ) -> Optional[Callable[[IterDataPipe[Tuple[str, IO]]], IterDataPipe[Tuple[str, IO]]]]:
        try:
            _, archive_type, _ = _detect_file_type(path.name)
        except RuntimeError:
            return None
        return self._ARCHIVE_LOADERS.get(archive_type)  # type: ignore[arg-type]

    I requested this quite some time ago on the old private torchdata repository, and I remember another discussion with @ejguan on a PR, but we never got anywhere. My reasoning is that for our purpose we don't care about the actual archive type, we just want the path-file tuples. This is aligned with OnlineReader, for which we also don't care whether the data is loaded via HTTP or from GDrive. Plus, there is also the Decompressor datapipe that does the same for different compression types. A small usage sketch follows below.
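
    For illustration, a minimal sketch of how the guessed loader could be applied; the names resource, file, and dp are assumptions for the example, not code from this PR:

    loader_cls = resource._guess_archive_loader(pathlib.Path(file))
    if loader_cls is not None:
        # e.g. ZipArchiveLoader: enumerate the archive and yield (path, file) pairs
        dp = loader_cls(dp)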
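
Coming back to the first point (redirects): until OnlineReader resolves redirects itself, one possible interim workaround is to map the URLs through a resolver before the reader. This is only a sketch, assuming the _get_redirect_url helper from above is available and url stands for the dataset URL; it is not part of torchdata:

from torchdata.datapipes.iter import IterableWrapper, OnlineReader

dp = IterableWrapper([url]).map(_get_redirect_url)  # resolve permanent redirects up front
dp = OnlineReader(dp)  # the reader then only sees the final URL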

@@ -1,22 +1,26 @@
from .caltech import Caltech101, Caltech256
from .celeba import CelebA

pmeier (Collaborator, Author) commented:

I only made changes to four datasets to cover all cases, i.e.

  • HTTP download (CIFAR)
  • GDrive download (Caltech)
  • Kaggle download (FER2013)
  • Manual download (ImageNet)

The others are commented out so that torchvision can still be imported.

@@ -41,11 +37,6 @@ def _info() -> Dict[str, Any]:
return dict(categories=categories, wnids=wnids)


class ImageNetResource(ManualDownloadResource):
pmeier (Collaborator, Author) commented:

Instead of subclassing just to parametrize, I followed the approach in #6052.

@@ -42,7 +42,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
yield from self._dp

@abc.abstractmethod
def _resources(self) -> List[OnlineResource]:
def _resources(self) -> Sequence[OnlineResource]:
pmeier (Collaborator, Author) commented:

This gets rid of the annoying mypy warning when returning List[ManualDownloadResource] instead of List[OnlineResource].
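
For context, a minimal self-contained example of the mypy behavior in question; the stub classes below merely mirror the torchvision class hierarchy:

from typing import List, Sequence


class OnlineResource:
    ...


class ManualDownloadResource(OnlineResource):
    ...


def as_list() -> List[OnlineResource]:
    resources: List[ManualDownloadResource] = [ManualDownloadResource()]
    return resources  # error: List is invariant, so List[ManualDownloadResource] is not a List[OnlineResource]


def as_sequence() -> Sequence[OnlineResource]:
    resources: List[ManualDownloadResource] = [ManualDownloadResource()]
    return resources  # ok: Sequence is covariant (read-only), so the subtype is accepted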

Comment on lines +57 to +63
@classmethod
def from_http(cls, url: str, *, file_name: Optional[str] = None, **kwargs: Any) -> OnlineResource:
return cls(url, file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs)

@classmethod
def from_gdrive(cls, id: str, **kwargs: Any) -> OnlineResource:
return cls(f"https://drive.google.com/uc?export=download&id={id}", **kwargs)
pmeier (Collaborator, Author) commented:

By using the OnlineReader datapipe, we no longer need different download functionality for HTTP or GDrive. Thus, I was able to remove the subclasses and replace them with these thin instantiation class methods.

If we switch to iopath in the future, we can probably remove them as well.
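
Purely illustrative usage of the two classmethods above; the CIFAR-10 URL is the one commonly used by torchvision, while the GDrive id, file name, and any extra keyword arguments (e.g. checksums) are placeholders:

cifar10 = OnlineResource.from_http("https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz")
# file_name is inferred from the URL path: "cifar-10-python.tar.gz"

caltech = OnlineResource.from_gdrive("<gdrive-file-id>", file_name="archive.tar.gz")
# placeholder id and name; the GDrive URL itself carries no useful file name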

Comment on lines 73 to 82
stream = list(dp)[0][1]

with open(file, "wb") as fh, tqdm() as progress_bar:
for chunk in iter(lambda: stream.read(1024 * 1024), b""): # type: ignore[no-any-return]
# filter out keep-alive new chunks
if not chunk:
continue

fh.write(chunk)
progress_bar.update(len(chunk))
pmeier (Collaborator, Author) commented:

There are two reasons I didn't do this with datapipes:

  1. We should probably have a progress bar for downloads, since they are usually quite large and a long-running process without output is always suspicious. By importing tqdm from torch.hub we get vanilla tqdm if it is installed and a minimal port if it is not.
  2. AFAIK, torchdata has no builtin option to read chunks from a stream. We need that, because otherwise we load the whole file into memory before saving it.

cc @ejguan @NivekT

Contributor commented:

  1. I agree having a progress bar for downloads is a good idea. Maybe torchdata can have a dependency on tqdm.
  2. I suppose adding an option to return r.iter_content() instead of r.raw would work?

pmeier (Collaborator, Author) commented:

  1. Apart from tqdm, rich also added support for progress bars recently. That is what newer versions of pip are using. Just leaving this here to assess the options before deciding to depend on tqdm.

  2. I would avoid having a flag that changes the output type. Having something like a ChunkReader datapipe can be useful in general, not just for HTTP requests, right?

    from torchdata.datapipes.iter import IterDataPipe
    
    
    class ChunkReader(IterDataPipe):
        def __init__(self, datapipe, *, chunk_size=32 * 1024 * 1024):
            self.datapipe = datapipe
            self.chunk_size = chunk_size
    
        def __iter__(self):
            for path, stream in self.datapipe:
                for chunk in iter(lambda: stream.read(self.chunk_size), b""):
                    # filter out keep-alive new chunks from HTTP streams
                    if not chunk:
                        continue
    
                    yield path, chunk

    A utility like that is needed for Saver anyway. AFAIK, there is currently no builtin functionality to read data from a stream, correct? I'm guessing that is why the Saver example uses bytes as input.

If there is more to discuss, let's migrate this to a proper issue on the torchdata repository.

Contributor commented:

I will open an issue in torchdata to get more thoughts from the community and see if there is anything that we may be missing.

IIUC, since _get_response_from_http returns r.raw, rather than say r.iter_content(), it loads the whole response into memory at once. Is that correct?

https://github.com/pytorch/data/blob/12cfaf8899b1337981cd4edf9deef127f925f1bd/torchdata/datapipes/iter/load/online.py#L46

pmeier (Collaborator, Author) commented:

I will open an issue in torchdata to get more thoughts from the community and see if there is anything that we may be missing.

See pytorch/data#439. It seems we missed StreamReader, which does exactly that. Still, could you open an issue regarding the progress bar?

IIUC, since _get_response_from_http returns r.raw, rather than say r.iter_content(), it loads the whole response into memory at once. Is that correct?

r.raw just gives you the stream and is somewhat like a BinaryIO, whereas r.iter_content() reads chunks from the stream and thus is an Iterator[bytes]. IMO returning r.raw is the right choice, since attaching a StreamReader to an HttpReader will give us the iter_content() behavior.
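
To make this concrete, a short sketch of that combination (assuming current torchdata names; the chunk size and URL are arbitrary choices for the example):

from torchdata.datapipes.iter import HttpReader, IterableWrapper, StreamReader

dp = IterableWrapper(["https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"])
dp = HttpReader(dp)                       # yields (url, stream) pairs, i.e. r.raw
dp = StreamReader(dp, chunk=1024 * 1024)  # yields (url, bytes) chunks read from the stream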

dp = HashChecker(dp, {str(file): self.sha256}, hash_type="sha256")
list(dp)

return file
pmeier (Collaborator, Author) commented:

Since we opted to only perform the preprocessing after the download in #5990, we can probably also move it here.

Comment on lines 71 to 73
dp = IterableWrapper([self.url])
dp = OnlineReader(dp)
stream = list(dp)[0][1]
pmeier (Collaborator, Author) commented:

We have primitive support for mirrors in our current utilities. One possibility would be to accept multiple URLs here and keep going through them until we get a proper response. AFAIK, nothing like that is currently possible with torchdata.

cc @ejguan @NivekT

Contributor commented:

I think that can be a useful functionality. We probably want to add it as a new DataPipe rather than modifying the HttpReader.
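
A rough sketch of what such a standalone datapipe could look like; the name MirrorFallback, the HEAD-based probing, and the error handling are all assumptions, not an existing torchdata API:

import urllib.request

from torchdata.datapipes.iter import IterDataPipe


class MirrorFallback(IterDataPipe):
    """Takes a sequence of mirror URLs per item and yields the first one that responds."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for mirrors in self.datapipe:
            for url in mirrors:
                try:
                    # probe the mirror with a HEAD request before handing it to a reader
                    with urllib.request.urlopen(urllib.request.Request(url, method="HEAD")):
                        yield url
                        break
                except OSError:
                    continue
            else:
                raise RuntimeError(f"None of the mirrors responded: {mirrors}")

Downstream, OnlineReader(MirrorFallback(IterableWrapper([mirror_urls]))) would then only ever see a working URL.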

@NivekT (Contributor) left a comment

OnlineReader has no support for redirects.

Agreed, we can add that to OnlineReader and HttpReader.

GDriveReader is missing functionality that we support. For example, we can deal with Virus scan warnings that popped up recently.

We should add that.

For that we need automatic detection of the archive type.

Is it possible to extend Decompressor to do more of that? Or will you need something different?


@pmeier (Collaborator, Author) commented May 23, 2022

For that we need automatic detection of the archive type.

Is it possible to extend Decompressor to do more of that? Or will you need something different?

Nope. As the name implies, Decompressor just decompresses files, but does not "extract" archives. We need something similar to what OnlineReader is to HttpReader and GDriveReader, i.e. an ArchiveLoader. Something along these lines:

import os
import tarfile
import zipfile

from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.utils import StreamWrapper


class ZipArchiveLoader(IterDataPipe):
    def __init__(self, datapipe):
        self.datapipe = datapipe

    @staticmethod
    def _load(path, stream):
        for info in zipfile.ZipFile(stream).infolist():
            extracted_path = ...
            extracted_stream = ...
            yield os.path.join(path, extracted_path), StreamWrapper(extracted_stream)

    def __iter__(self):
        for path, stream in self.datapipe:
            yield from self._load(path, stream)


class TarArchiveLoader(IterDataPipe):
    def __init__(self, datapipe):
        self.datapipe = datapipe

    @staticmethod
    def _load(path, stream):
        for info in tarfile.open(fileobj=stream):
            extracted_path = ...
            extracted_stream = ...
            yield os.path.join(path, extracted_path), StreamWrapper(extracted_stream)

    def __iter__(self):
        for path, stream in self.datapipe:
            yield from self._load(path, stream)


class ArchiveLoader(IterDataPipe):
    def __init__(self, datapipe):
        self.datapipe = datapipe

    _LOADERS = {
        ".zip": ZipArchiveLoader._load,
        ".tar": TarArchiveLoader._load,
    }

    def __iter__(self):
        for path, stream in self.datapipe:
            # dispatch on the file extension and delegate to the matching loader
            load = self._LOADERS[os.path.splitext(path)[1]]
            yield from load(path, stream)

With this we can extract any archive with

dp = IterableWrapper([path_to_any_kind_of_archive])
dp = ArchiveLoader(dp)
dp = ChunkReader(dp)  # see comment above
dp = Saver(dp, ...)
list(dp)

Alternatively, I could also live with a select_archive_loader function to which I can pass the path of an arbitrary archive and get back the correct loader.
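
For illustration, a minimal sketch of what such a helper could look like, reusing the loader classes sketched above; the name and the error handling are assumptions:

def select_archive_loader(path):
    ext = os.path.splitext(path)[1]
    try:
        # more archive types could be registered here
        return {".zip": ZipArchiveLoader, ".tar": TarArchiveLoader}[ext]
    except KeyError:
        raise RuntimeError(f"No archive loader registered for '{ext}' files ({path}).") from None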

@pmeier (Collaborator, Author) left a comment

@NivekT @ejguan While working on the next iteration of this, I encountered two issues documented as TODO below and tracked in pytorch/data#451 and pytorch/data#452.

D = TypeVar("D")


class ProgressBar(IterDataPipe[D]):
pmeier (Collaborator, Author) commented:

This is my naive approach to adding a progress bar as a datapipe. It is sufficient for our use case, but needs more work for general use. For example, the implementation relies on all items in the datapipe being handled by the same progress bar. This works for us, since we get chunks of a single file. But if we processed multiple files with the pipeline, the progress bar would aggregate everything together.
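
A minimal sketch of the kind of naive ProgressBar datapipe described here, not necessarily the exact code in this PR; it assumes the items are (path, byte-chunk) pairs produced by something like the ChunkReader above:

from typing import Iterator, Tuple

from torch.hub import tqdm  # vanilla tqdm if installed, minimal fallback otherwise
from torchdata.datapipes.iter import IterDataPipe


class ProgressBar(IterDataPipe[Tuple[str, bytes]]):
    def __init__(self, datapipe: IterDataPipe[Tuple[str, bytes]]) -> None:
        self.datapipe = datapipe

    def __iter__(self) -> Iterator[Tuple[str, bytes]]:
        # one bar for the whole pipeline: fine for a single file, but multiple
        # files would be aggregated into the same bar
        with tqdm() as progress_bar:
            for path, chunk in self.datapipe:
                yield path, chunk
                progress_bar.update(len(chunk))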

@vadimkantorov commented:

Related on GDrive and other cloud downloads: pytorch/pytorch#73466
