Skip to content

Commit

Permalink
Merge branch 'main' into please_dont_modify_this_branch_unless_you_ar…
Browse files Browse the repository at this point in the history
…e_just_merging_with_main__
  • Loading branch information
NicolasHug authored Sep 23, 2024
2 parents 69e6dea + 6d7851b commit 99406ef
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ versions.
| `torch` | `torchvision` | Python |
| ------------------ | ------------------ | ------------------- |
| `main` / `nightly` | `main` / `nightly` | `>=3.9`, `<=3.12` |
| `2.5` | `0.20` | `>=3.9`, `<=3.12` |
| `2.4` | `0.19` | `>=3.8`, `<=3.12` |
| `2.3` | `0.18` | `>=3.8`, `<=3.12` |
| `2.2` | `0.17` | `>=3.8`, `<=3.11` |
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def make_image_extension():
else:
warnings.warn("Building torchvision without AVIF support")

if USE_NVJPEG and torch.cuda.is_available():
if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA):
nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()

if nvjpeg_found:
Expand All @@ -376,6 +376,8 @@ def make_image_extension():
Extension = CUDAExtension
else:
warnings.warn("Building torchvision without NVJPEG support")
elif USE_NVJPEG:
warnings.warn("Building torchvision without NVJPEG support")

return Extension(
name="torchvision.image",
Expand Down
10 changes: 10 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import re
import sys
from contextlib import nullcontext
from pathlib import Path

import numpy as np
Expand All @@ -13,6 +14,7 @@
import torchvision.transforms.v2.functional as F
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision._internally_replaced_utils import IN_FBCODE
from torchvision.io.image import (
_decode_avif,
_decode_heic,
Expand Down Expand Up @@ -1076,5 +1078,13 @@ def test_mode_str():
assert decode_image(path, mode="RGBA").shape[0] == 4


def test_avif_heic_fbcode():
cm = nullcontext() if IN_FBCODE else pytest.raises(ImportError, match="cannot import")
with cm:
from torchvision.io import decode_heic # noqa
with cm:
from torchvision.io import decode_avif # noqa


if __name__ == "__main__":
pytest.main([__file__])
1 change: 1 addition & 0 deletions torchvision/_internally_replaced_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

_HOME = os.path.join(_get_torch_home(), "datasets", "vision")
_USE_SHARDED_DATASETS = False
IN_FBCODE = False


def _download_file_from_remote_location(fpath: str, url: str) -> None:
Expand Down
7 changes: 7 additions & 0 deletions torchvision/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,10 @@
"Video",
"VideoReader",
]

from .._internally_replaced_utils import IN_FBCODE

if IN_FBCODE:
from .image import _decode_avif as decode_avif, _decode_heic as decode_heic

__all__ += ["decode_avif", "decode_heic"]
4 changes: 2 additions & 2 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,9 +604,9 @@ def _parse_colors(
f"Number of colors must be equal or larger than the number of objects, but got {len(colors)} < {num_objects}."
)
elif not isinstance(colors, (tuple, str)):
raise ValueError("`colors` must be a tuple or a string, or a list thereof, but got {colors}.")
raise ValueError(f"`colors` must be a tuple or a string, or a list thereof, but got {colors}.")
elif isinstance(colors, tuple) and len(colors) != 3:
raise ValueError("If passed as tuple, colors should be an RGB triplet, but got {colors}.")
raise ValueError(f"If passed as tuple, colors should be an RGB triplet, but got {colors}.")
else: # colors specifies a single color for all objects
colors = [colors] * num_objects

Expand Down

0 comments on commit 99406ef

Please sign in to comment.