|
1 | 1 | import functools |
2 | 2 | from collections.abc import Sequence |
3 | | -from typing import Any, Callable, Optional, Union |
| 3 | +from typing import Any, Callable, Optional, TYPE_CHECKING, Union |
4 | 4 |
|
5 | 5 | import torch |
6 | 6 | from torchvision import tv_tensors |
| 7 | +from torchvision.transforms.functional import InterpolationMode |
| 8 | + |
| 9 | +if TYPE_CHECKING: |
| 10 | + import cvcuda # type: ignore[import-not-found] |
7 | 11 |
|
8 | 12 | _FillType = Union[int, float, Sequence[int], Sequence[float], None] |
9 | 13 | _FillTypeJIT = Optional[list[float]] |
@@ -177,3 +181,45 @@ def _is_cvcuda_tensor(inpt: Any) -> bool: |
177 | 181 | return isinstance(inpt, cvcuda.Tensor) |
178 | 182 | except ImportError: |
179 | 183 | return False |
| 184 | + |
| 185 | + |
| 186 | +_interpolation_mode_to_cvcuda_interp: dict[InterpolationMode | str | int, "cvcuda.Interp"] = {} |
| 187 | + |
| 188 | + |
| 189 | +def _populate_interpolation_mode_to_cvcuda_interp(): |
| 190 | + cvcuda = _import_cvcuda() |
| 191 | + |
| 192 | + global _interpolation_mode_to_cvcuda_interp |
| 193 | + |
| 194 | + _interpolation_mode_to_cvcuda_interp = { |
| 195 | + InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR, |
| 196 | + "bilinear": cvcuda.Interp.LINEAR, |
| 197 | + "linear": cvcuda.Interp.LINEAR, |
| 198 | + 2: cvcuda.Interp.LINEAR, |
| 199 | + InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC, |
| 200 | + "bicubic": cvcuda.Interp.CUBIC, |
| 201 | + 3: cvcuda.Interp.CUBIC, |
| 202 | + InterpolationMode.NEAREST: cvcuda.Interp.NEAREST, |
| 203 | + "nearest": cvcuda.Interp.NEAREST, |
| 204 | + 0: cvcuda.Interp.NEAREST, |
| 205 | + InterpolationMode.BOX: cvcuda.Interp.BOX, |
| 206 | + "box": cvcuda.Interp.BOX, |
| 207 | + 4: cvcuda.Interp.BOX, |
| 208 | + InterpolationMode.HAMMING: cvcuda.Interp.HAMMING, |
| 209 | + "hamming": cvcuda.Interp.HAMMING, |
| 210 | + 5: cvcuda.Interp.HAMMING, |
| 211 | + InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS, |
| 212 | + "lanczos": cvcuda.Interp.LANCZOS, |
| 213 | + 1: cvcuda.Interp.LANCZOS, |
| 214 | + } |
| 215 | + |
| 216 | + |
| 217 | +def _get_cvcuda_interp(interpolation: InterpolationMode | str | int) -> "cvcuda.Interp": |
| 218 | + if len(_interpolation_mode_to_cvcuda_interp) == 0: |
| 219 | + _populate_interpolation_mode_to_cvcuda_interp() |
| 220 | + |
| 221 | + interp = _interpolation_mode_to_cvcuda_interp.get(interpolation) |
| 222 | + if interp is None: |
| 223 | + raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA") |
| 224 | + |
| 225 | + return interp |
0 commit comments