Skip to content

Commit

Permalink
Add pathlib.Path support for download utils (#8196)
Browse files Browse the repository at this point in the history
Co-authored-by: Ahmad Sharif <ahmads@fb.com>
Co-authored-by: Brizar <1500595+bmmtstb@users.noreply.github.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
  • Loading branch information
4 people authored Jan 9, 2024
1 parent 2afb7fa commit d234307
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
23 changes: 18 additions & 5 deletions test/test_internet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,48 @@
"""

import os
import pathlib
from urllib.error import URLError

import pytest
import torchvision.datasets.utils as utils


class TestDatasetUtils:
def test_download_url(self, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_download_url(self, tmpdir, use_pathlib):
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
url = "http://github.com/pytorch/vision/archive/master.zip"
try:
utils.download_url(url, tmpdir)
assert len(os.listdir(tmpdir)) != 0
except URLError:
pytest.skip(f"could not download test file '{url}'")

def test_download_url_retry_http(self, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_download_url_retry_http(self, tmpdir, use_pathlib):
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
url = "https://github.com/pytorch/vision/archive/master.zip"
try:
utils.download_url(url, tmpdir)
assert len(os.listdir(tmpdir)) != 0
except URLError:
pytest.skip(f"could not download test file '{url}'")

def test_download_url_dont_exist(self, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_download_url_dont_exist(self, tmpdir, use_pathlib):
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
with pytest.raises(URLError):
utils.download_url(url, tmpdir)

def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir, use_pathlib):
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view"

id = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV"
Expand All @@ -44,7 +57,7 @@ def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir):
mocked = mocker.patch("torchvision.datasets.utils.download_file_from_google_drive")
utils.download_url(url, tmpdir, filename, md5)

mocked.assert_called_once_with(id, tmpdir, filename, md5)
mocked.assert_called_once_with(id, os.path.expanduser(tmpdir), filename, md5)


if __name__ == "__main__":
Expand Down
16 changes: 11 additions & 5 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import urllib.request
import warnings
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
from urllib.parse import urlparse

import numpy as np
Expand Down Expand Up @@ -104,7 +104,11 @@ def _get_google_drive_file_id(url: str) -> Optional[str]:


def download_url(
url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
url: str,
root: Union[str, pathlib.Path],
filename: Optional[str] = None,
md5: Optional[str] = None,
max_redirect_hops: int = 3,
) -> None:
"""Download a file from a url and place it in root.
Expand All @@ -118,7 +122,7 @@ def download_url(
root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename)
fpath = os.fspath(os.path.join(root, filename))

os.makedirs(root, exist_ok=True)

Expand Down Expand Up @@ -203,7 +207,9 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple
return api_response, content


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
def download_file_from_google_drive(
file_id: str, root: Union[str, pathlib.Path], filename: Optional[str] = None, md5: Optional[str] = None
):
"""Download a Google Drive file from and place it in root.
Args:
Expand All @@ -217,7 +223,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)
fpath = os.fspath(os.path.join(root, filename))

os.makedirs(root, exist_ok=True)

Expand Down

0 comments on commit d234307

Please sign in to comment.