
Commit

Merge branch 'main' of github.com:pytorch/vision into compile-kernel-ci
NicolasHug committed Jan 17, 2024
2 parents 671dfd1 + 1de7a74 commit 5db34b4
Showing 3 changed files with 77 additions and 38 deletions.
35 changes: 29 additions & 6 deletions test/test_datasets_utils.py
@@ -58,8 +58,11 @@ def test_get_redirect_url_max_hops_exceeded(self, mocker):
         assert mock.call_count == 1
         assert mock.call_args[0][0].full_url == url
 
-    def test_check_md5(self):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_check_md5(self, use_pathlib):
         fpath = TEST_FILE
+        if use_pathlib:
+            fpath = pathlib.Path(fpath)
         correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
         false_md5 = ""
         assert utils.check_md5(fpath, correct_md5)
@@ -116,7 +119,8 @@ def test_detect_file_type_incompatible(self, file):
             utils._detect_file_type(file)
 
     @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
-    def test_decompress(self, extension, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_decompress(self, extension, tmpdir, use_pathlib):
         def create_compressed(root, content="this is the content"):
             file = os.path.join(root, "file")
             compressed = f"{file}{extension}"
@@ -128,6 +132,8 @@ def create_compressed(root, content="this is the content"):
             return compressed, file, content
 
         compressed, file, content = create_compressed(tmpdir)
+        if use_pathlib:
+            compressed = pathlib.Path(compressed)
 
         utils._decompress(compressed)
 
@@ -140,7 +146,8 @@ def test_decompress_no_compression(self):
         with pytest.raises(RuntimeError):
             utils._decompress("foo.tar")
 
-    def test_decompress_remove_finished(self, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_decompress_remove_finished(self, tmpdir, use_pathlib):
         def create_compressed(root, content="this is the content"):
             file = os.path.join(root, "file")
             compressed = f"{file}.gz"
@@ -151,10 +158,20 @@ def create_compressed(root, content="this is the content"):
             return compressed, file, content
 
         compressed, file, content = create_compressed(tmpdir)
+        print(f"{type(compressed)=}")
+        if use_pathlib:
+            compressed = pathlib.Path(compressed)
+            tmpdir = pathlib.Path(tmpdir)
 
-        utils.extract_archive(compressed, tmpdir, remove_finished=True)
+        extracted_dir = utils.extract_archive(compressed, tmpdir, remove_finished=True)
 
         assert not os.path.exists(compressed)
+        if use_pathlib:
+            assert isinstance(extracted_dir, pathlib.Path)
+            assert isinstance(compressed, pathlib.Path)
+        else:
+            assert isinstance(extracted_dir, str)
+            assert isinstance(compressed, str)
 
     @pytest.mark.parametrize("extension", [".gz", ".xz"])
     @pytest.mark.parametrize("remove_finished", [True, False])
@@ -167,7 +184,8 @@ def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
 
         mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)
 
-    def test_extract_zip(self, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_extract_zip(self, tmpdir, use_pathlib):
         def create_archive(root, content="this is the content"):
             file = os.path.join(root, "dst.txt")
             archive = os.path.join(root, "archive.zip")
@@ -177,6 +195,8 @@ def create_archive(root, content="this is the content"):
 
             return archive, file, content
 
+        if use_pathlib:
+            tmpdir = pathlib.Path(tmpdir)
         archive, file, content = create_archive(tmpdir)
 
         utils.extract_archive(archive, tmpdir)
@@ -189,7 +209,8 @@ def create_archive(root, content="this is the content"):
     @pytest.mark.parametrize(
         "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
     )
-    def test_extract_tar(self, extension, mode, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_extract_tar(self, extension, mode, tmpdir, use_pathlib):
         def create_archive(root, extension, mode, content="this is the content"):
             src = os.path.join(root, "src.txt")
             dst = os.path.join(root, "dst.txt")
@@ -203,6 +224,8 @@ def create_archive(root, extension, mode, content="this is the content"):
 
             return archive, dst, content
 
+        if use_pathlib:
+            tmpdir = pathlib.Path(tmpdir)
         archive, file, content = create_archive(tmpdir, extension, mode)
 
         utils.extract_archive(archive, tmpdir)
11 changes: 2 additions & 9 deletions torchvision/csrc/io/decoder/stream.cpp
@@ -63,15 +63,8 @@ int Stream::openCodec(std::vector<DecoderMetadata>* metadata, int num_threads) {
     codecCtx_->thread_count = num_threads;
   } else {
     // otherwise set sensible defaults
-    // with the special case for the different MPEG4 codecs
-    // that don't have threading context functions
-    if (codecCtx_->codec->capabilities & AV_CODEC_CAP_INTRA_ONLY) {
-      codecCtx_->thread_type = FF_THREAD_FRAME;
-      codecCtx_->thread_count = 2;
-    } else {
-      codecCtx_->thread_count = 8;
-      codecCtx_->thread_type = FF_THREAD_SLICE;
-    }
+    codecCtx_->thread_count = 8;
+    codecCtx_->thread_type = FF_THREAD_SLICE;
   }
 
   int ret;
69 changes: 46 additions & 23 deletions torchvision/datasets/utils.py
@@ -30,7 +30,7 @@
 
 def _save_response_content(
     content: Iterator[bytes],
-    destination: str,
+    destination: Union[str, pathlib.Path],
     length: Optional[int] = None,
 ) -> None:
     with open(destination, "wb") as fh, tqdm(total=length) as pbar:
@@ -43,12 +43,12 @@ def _save_response_content(
             pbar.update(len(chunk))
 
 
-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None:
     with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
         _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
 
 
-def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
+def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str:
     # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
     # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
     # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
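
The comment above concerns FIPS-restricted environments; a minimal sketch of the hashing pattern it refers to, assuming Python 3.9+ for the `usedforsecurity` keyword (the version gate is illustrative, not necessarily the exact torchvision code):

import hashlib
import sys


def fips_friendly_md5(data: bytes) -> str:
    # Declare that MD5 is used as a checksum, not for cryptography, so
    # FIPS-restricted OpenSSL builds do not reject it (keyword added in 3.9).
    if sys.version_info >= (3, 9):
        md5 = hashlib.md5(usedforsecurity=False)
    else:
        md5 = hashlib.md5()
    md5.update(data)
    return md5.hexdigest()
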
@@ -62,11 +62,11 @@ def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
     return md5.hexdigest()
 
 
-def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
+def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:
     return md5 == calculate_md5(fpath, **kwargs)
 
 
-def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
+def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool:
     if not os.path.isfile(fpath):
         return False
     if md5 is None:
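
With the widened signatures, the integrity helpers accept either a str or a pathlib.Path; a usage sketch (the file path and checksum are hypothetical):

import pathlib

from torchvision.datasets import utils

fpath = pathlib.Path("data") / "archive.tar.gz"  # hypothetical path
# False if the file is missing; otherwise verifies the optional MD5.
ok = utils.check_integrity(fpath, md5="d41d8cd98f00b204e9800998ecf8427e")
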
@@ -106,7 +106,7 @@ def _get_google_drive_file_id(url: str) -> Optional[str]:
 def download_url(
     url: str,
     root: Union[str, pathlib.Path],
-    filename: Optional[str] = None,
+    filename: Optional[Union[str, pathlib.Path]] = None,
     md5: Optional[str] = None,
     max_redirect_hops: int = 3,
 ) -> None:
@@ -159,7 +159,7 @@ def download_url(
         raise RuntimeError("File not found or corrupted.")
 
 
-def list_dir(root: str, prefix: bool = False) -> List[str]:
+def list_dir(root: Union[str, pathlib.Path], prefix: bool = False) -> List[str]:
     """List all directories at a given root
 
     Args:
@@ -174,7 +174,7 @@ def list_dir(root: str, prefix: bool = False) -> List[str]:
     return directories
 
 
-def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
+def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False) -> List[str]:
     """List all files ending with a suffix at a given root
 
     Args:
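
A usage sketch for the two listing helpers (the directory layout is hypothetical); with prefix=True the returned entries are joined with root:

import pathlib

from torchvision.datasets import utils

root = pathlib.Path("datasets") / "raw"  # hypothetical directory
subdirs = utils.list_dir(root, prefix=True)  # full paths of immediate subdirectories
txt_files = utils.list_files(root, ".txt", prefix=True)  # only files ending in ".txt"
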
@@ -208,7 +208,10 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple
 
 
 def download_file_from_google_drive(
-    file_id: str, root: Union[str, pathlib.Path], filename: Optional[str] = None, md5: Optional[str] = None
+    file_id: str,
+    root: Union[str, pathlib.Path],
+    filename: Optional[Union[str, pathlib.Path]] = None,
+    md5: Optional[str] = None,
 ):
     """Download a Google Drive file from and place it in root.
 
@@ -278,7 +281,9 @@ def download_file_from_google_drive(
     )
 
 
-def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
+def _extract_tar(
+    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
+) -> None:
     with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
         tar.extractall(to_path)
 
@@ -289,14 +294,16 @@ def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
 }
 
 
-def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
+def _extract_zip(
+    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
+) -> None:
     with zipfile.ZipFile(
         from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
     ) as zip:
         zip.extractall(to_path)
 
 
-_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
+_ARCHIVE_EXTRACTORS: Dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = {
     ".tar": _extract_tar,
     ".zip": _extract_zip,
 }
@@ -312,7 +319,7 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
 }
 
 
-def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
+def _detect_file_type(file: Union[str, pathlib.Path]) -> Tuple[str, Optional[str], Optional[str]]:
     """Detect the archive type and/or compression of a file.
 
     Args:
raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
def _decompress(
from_path: Union[str, pathlib.Path],
to_path: Optional[Union[str, pathlib.Path]] = None,
remove_finished: bool = False,
) -> pathlib.Path:
r"""Decompress a file.
The compression is automatically detected from the file name.
@@ -373,7 +384,7 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
         raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")
 
     if to_path is None:
-        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
+        to_path = pathlib.Path(os.fspath(from_path).replace(suffix, archive_type if archive_type is not None else ""))
 
     # We don't need to check for a missing key here, since this was already done in _detect_file_type()
     compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
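
The default to_path above strips only the compression suffix and keeps any archive suffix. Illustrative input/output pairs (file names are hypothetical; _decompress is a private helper):

from torchvision.datasets.utils import _decompress

# "sample.csv.gz" -> suffix ".gz", no archive type -> "sample.csv"
# "sample.tar.xz" -> suffix ".tar.xz", archive type ".tar" -> "sample.tar"
out = _decompress("sample.csv.gz")  # returns pathlib.Path("sample.csv")
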
@@ -384,10 +395,14 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
     if remove_finished:
         os.remove(from_path)
 
-    return to_path
+    return pathlib.Path(to_path)
 
 
-def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+def extract_archive(
+    from_path: Union[str, pathlib.Path],
+    to_path: Optional[Union[str, pathlib.Path]] = None,
+    remove_finished: bool = False,
+) -> Union[str, pathlib.Path]:
     """Extract an archive.
 
     The archive type and a possible compression is automatically detected from the file name. If the file is compressed
@@ -402,16 +417,24 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
     Returns:
         (str): Path to the directory the file was extracted to.
     """
+
+    def path_or_str(ret_path: pathlib.Path) -> Union[str, pathlib.Path]:
+        if isinstance(from_path, str):
+            return os.fspath(ret_path)
+        else:
+            return ret_path
+
     if to_path is None:
         to_path = os.path.dirname(from_path)
 
     suffix, archive_type, compression = _detect_file_type(from_path)
     if not archive_type:
-        return _decompress(
+        ret_path = _decompress(
             from_path,
             os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
             remove_finished=remove_finished,
         )
+        return path_or_str(ret_path)
 
     # We don't need to check for a missing key here, since this was already done in _detect_file_type()
     extractor = _ARCHIVE_EXTRACTORS[archive_type]
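
The path_or_str helper makes the return type mirror the type of from_path, which is exactly the contract the new tests assert. A usage sketch (archive paths are hypothetical):

import pathlib

from torchvision.datasets import utils

out = utils.extract_archive("archive.tar.gz", "out_dir")  # str in -> str out
out = utils.extract_archive(pathlib.Path("archive.tar.gz"))  # Path in -> Path out
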
@@ -420,14 +443,14 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
     if remove_finished:
         os.remove(from_path)
 
-    return to_path
+    return path_or_str(pathlib.Path(to_path))
 
 
 def download_and_extract_archive(
     url: str,
-    download_root: str,
-    extract_root: Optional[str] = None,
-    filename: Optional[str] = None,
+    download_root: Union[str, pathlib.Path],
+    extract_root: Optional[Union[str, pathlib.Path]] = None,
+    filename: Optional[Union[str, pathlib.Path]] = None,
     md5: Optional[str] = None,
     remove_finished: bool = False,
 ) -> None:
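
A usage sketch of the widened download_and_extract_archive signature with pathlib roots (the URL and paths are placeholders, not a real dataset):

import pathlib

from torchvision.datasets.utils import download_and_extract_archive

download_and_extract_archive(
    url="https://example.com/dataset.tar.gz",  # placeholder URL
    download_root=pathlib.Path("downloads"),
    extract_root=pathlib.Path("data"),
    remove_finished=True,
)
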
@@ -479,7 +502,7 @@ def verify_str_arg(
     return value
 
 
-def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
+def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray:
     """Read file in .pfm format. Might contain either 1 or 3 channels of data.
 
     Args:
