diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py
index d036b5db1de..5fdf08cab24 100644
--- a/test/test_prototype_datapoints.py
+++ b/test/test_prototype_datapoints.py
@@ -29,6 +29,8 @@ def test_to_wrapping():
 def test_to_feature_reference():
+    # is "feature" now "datapoint" - or is it something else?
+    # A: yes, TODO: update name
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
     label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
@@ -46,6 +48,10 @@ def test_clone_wrapping():
     assert type(label_clone) is datapoints.Label
     assert label_clone.data_ptr() != label.data_ptr()
+    # Is this expected?
+    # Does this meta-data-preserving behaviour occur for all meta-data attached to datapoints?
+    # Can things go wrong? i.e. if label_clone changes its metadata inplace, that of label gets changed too?
+    # TODO: N and P should discuss this more
     assert label_clone.categories is label.categories
@@ -60,6 +66,7 @@ def test_requires_grad__wrapping():
     assert type(label_requires_grad) is datapoints.Label
     assert label.requires_grad
     assert label_requires_grad.requires_grad
+    assert label_requires_grad is label
 def test_other_op_no_wrapping():
diff --git a/torchvision/_utils.py b/torchvision/_utils.py
index b739ef0966e..382429f6bce 100644
--- a/torchvision/_utils.py
+++ b/torchvision/_utils.py
@@ -4,6 +4,7 @@
 T = TypeVar("T", bound=enum.Enum)
+
 class StrEnumMeta(enum.EnumMeta):
     auto = enum.auto
@@ -16,10 +17,21 @@ def from_str(self: Type[T], member: str) -> T:  # type: ignore[misc]
             raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None
+# TODO: As agreed with Philip, we can just get rid of this class for now.
+# We can just define enums as
+#     class BoundingBoxFormat(Enum):
+#         XYXY = "XYXY"
+#         ...
+# and replace from_str with BoundingBoxFormat["XYXY"].
+#
+# We won't have a super nice error message as with add_suggestion, but this is
+# something we can think about when this is more critically needed (i.e. after
+# the migration)
 class StrEnum(enum.Enum, metaclass=StrEnumMeta):
     pass
+
 def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
     if not seq:
         return ""
diff --git a/torchvision/prototype/datapoints/__init__.py b/torchvision/prototype/datapoints/__init__.py
index 92f345e20bd..f077101c3b4 100644
--- a/torchvision/prototype/datapoints/__init__.py
+++ b/torchvision/prototype/datapoints/__init__.py
@@ -1,5 +1,5 @@
 from ._bounding_box import BoundingBox, BoundingBoxFormat
-from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
+from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT  # Need to be public? TODO: no
 from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
 from ._label import Label, OneHotLabel
 from ._mask import Mask
diff --git a/torchvision/prototype/datapoints/_bounding_box.py b/torchvision/prototype/datapoints/_bounding_box.py
index 398770cbf6a..9cf314d7576 100644
--- a/torchvision/prototype/datapoints/_bounding_box.py
+++ b/torchvision/prototype/datapoints/_bounding_box.py
@@ -15,8 +15,35 @@ class BoundingBoxFormat(StrEnum):
     CXCYWH = StrEnum.auto()
+# What if... we just removed the format and spatial_size meta-data?
+# A: We could, but it comes with trade-offs. For the format, this wouldn't lead
+# to much of a difference, except that users would have to convert to XYXY
+# before doing anything. All of the current stable ops expect XYXY already so
+# it's not much of a change.
+# Worth noting as well that a few BBox transforms
+# only have an implementation for the XYXY format, and they convert / re-convert
+# internally (see e.g. affine_bounding_box, but there are others).
+# Removing spatial_size however would make the dispatcher-level more clunky for
+# users. It wouldn't change much of the transforms classes as long as they're
+# called with their respective image, e.g.
+#     T(image, bbox)
+# because the spatial_size can be known from the image param. But in a mid-level
+# dispatcher which only accepts one kind of input, like
+#     dispatcher(bbox)
+# there's no way to know the spatial_size unless it's passed as a parameter.
+# Users would also need to keep track of it since some transforms actually
+# change it:
+#     bbox, sz = resize(bbox, spatial_size=sz)
+# This also means the mid-level dispatchers:
+# - need to accept as input anything that was a meta-data (in this case
+#   spatial_size)
+# - need to return them as well; which means they need to return either a single
+#   image, a single video, or a tuple of (bbox, spatial_size).
+# TL;DR: things would get messy for users and for us.
 class BoundingBox(Datapoint):
-    format: BoundingBoxFormat
+    format: BoundingBoxFormat  # TODO: do not use a builtin?
+    # TODO: This is the size of the image, not the box. Maybe make this explicit in the name?
+    # Note: if this isn't user-facing, the TODO is not critical at all
     spatial_size: Tuple[int, int]
     @classmethod
diff --git a/torchvision/prototype/datapoints/_datapoint.py b/torchvision/prototype/datapoints/_datapoint.py
index 659d4e958cc..91bf1d53c63 100644
--- a/torchvision/prototype/datapoints/_datapoint.py
+++ b/torchvision/prototype/datapoints/_datapoint.py
@@ -15,6 +15,8 @@
 FillTypeJIT = Union[int, float, List[float], None]
+# TODO: provide a few examples of when the Datapoint type is preserved vs when it's not
+# test_prototype_datapoints.py is a good starting point
 class Datapoint(torch.Tensor):
     __F: Optional[ModuleType] = None
@@ -41,6 +43,7 @@ def __new__(
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         return tensor.as_subclass(Datapoint)
+    # Is this still needed, considering we won't be releasing the prototype datasets anytime soon?
     @classmethod
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
         # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
@@ -48,6 +51,9 @@ def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
         #  raise NotImplementedError
         return tensor.as_subclass(cls)
+    # Can things go wrong with having to maintain a different set of funcs that
+    # need special care? What if we forget to handle one properly?
+    # A: This is probably fine. These are the only ones encountered so far.
     _NO_WRAPPING_EXCEPTIONS = {
         torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
         torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
@@ -81,14 +87,19 @@ def __torch_function__(
         For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
         listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
         """
+        # Is this something we've bubbled up to the core team as a potential feature request?
         # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
         # need to reimplement the functionality.
+        # Curious in which cases this can be hit?
+        # A: Still don't really know, but this comes from the parent
+        # Tensor.__torch_function__ (which we need to re-write here)
         if not all(issubclass(cls, t) for t in types):
             return NotImplemented
         with DisableTorchFunctionSubclass():
             output = func(*args, **kwargs or dict())
+            # TODO: maybe we can exit the CM here?
             wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
             # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
@@ -98,8 +109,15 @@ def __torch_function__(
             # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
             # be wrapped into a `datapoints.Image`.
             if wrapper and isinstance(args[0], cls):
+                # TODO: figure out whether
+                #     arbitrary_tensor.to(some_img)
+                # should be an Image or a Tensor
                 return wrapper(cls, args[0], output)  # type: ignore[no-any-return]
+
+            # Does that mean that DisableTorchFunctionSubclass is ignored for in-place (`.add_()`-style) functions?
+            # Or maybe I'm misunderstanding what DisableTorchFunctionSubclass is supposed to do.
+            # TODO: figure out with torch core whether this is a bug or not
             # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
             # will retain the input type. Thus, we need to unwrap here.
             if isinstance(output, cls):
@@ -107,12 +125,15 @@ def __torch_function__(
             return output
+    # Is this used?
     def _make_repr(self, **kwargs: Any) -> str:
         # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
         # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
         extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
         return f"{super().__repr__()[:-1]}, {extra_repr})"
+    # What cyclic import is this solving? Can it be avoided somehow?
+    # A: this is because the dispatchers call the methods which in turn call back into the functional namespace
     @property
     def _F(self) -> ModuleType:
         # This implements a lazy import of the functional to get around the cyclic import. This import is deferred
@@ -125,6 +146,15 @@ def _F(self) -> ModuleType:
             Datapoint.__F = functional
         return Datapoint.__F
+
+    # doing some_tensor.dtype would go through __torch_function__?
+    # - is this because it is implemented as a @property method?
+    #   A: yes, or something like that
+    # - why do we want to bypass __torch_function__ in these cases?
+    #   A: for optimization. TODO: should this be part of core already?
+    # Also, what happens if users access another attribute that we haven't
+    # overridden here, e.g. image.data or image.some_new_attribute_in_pytorch3.0?
+
     # Add properties for common attributes like shape, dtype, device, ndim etc
     # this way we return the result without passing into __torch_function__
     @property
@@ -147,6 +177,10 @@ def dtype(self) -> _dtype:  # type: ignore[override]
         with DisableTorchFunctionSubclass():
             return super().dtype
+    # Are these the "no-op fallbacks"?
+    # A: yes, fallback from the dispatchers. These exist in anticipation of
+    # allowing user-defined transforms.
+    # TODO: figure out design / tradeoffs
     def horizontal_flip(self) -> Datapoint:
         return self
@@ -268,4 +302,4 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N
 InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
-InputTypeJIT = torch.Tensor
+InputTypeJIT = torch.Tensor  # why alias it?
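(A quick sketch for the TODO above asking for examples of when the Datapoint type is preserved vs when it is not. It is derived from test_prototype_datapoints.py and the _NO_WRAPPING_EXCEPTIONS table; treat it as an illustration to double-check against the tests, not as a spec:)

    import torch
    from torchvision.prototype import datapoints

    label = datapoints.Label(torch.tensor([0, 1, 0]), categories=["foo", "bar"])

    # Ops in _NO_WRAPPING_EXCEPTIONS (e.g. .clone(), .to()) re-wrap their output,
    # so the subclass and its meta-data survive:
    assert type(label.to(torch.int32)) is datapoints.Label
    assert label.clone().categories is label.categories

    # Any other op falls back to a plain tensor:
    assert type(label * 2) is torch.Tensor

    # In-place ops modify the data of `label`, but their return value is unwrapped:
    output = label.add_(1)
    assert type(output) is torch.Tensor
    assert type(label) is datapoints.Label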
diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py
index fc20691100f..5f8e38f1fef 100644
--- a/torchvision/prototype/datapoints/_image.py
+++ b/torchvision/prototype/datapoints/_image.py
@@ -58,6 +58,9 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace:
 class Image(Datapoint):
+    # Where is this used / changed apart from in ConvertColorSpace()?
+    # A: For now, this is somewhat redundant with the number of channels.
+    # TODO: decide whether we want to keep it?
     color_space: ColorSpace
     @classmethod
@@ -117,7 +120,10 @@ def horizontal_flip(self) -> Image:
     def vertical_flip(self) -> Image:
         output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
         return Image.wrap_like(self, output)
-
+
+    # Do we want to keep these public?
+    # This is probably related to internal customer needs. TODO for N: figure that out
+    # This is also related to allowing user-defined subclasses and transforms (in anticipation of those)
     def resize(  # type: ignore[override]
         self,
         size: List[int],
diff --git a/torchvision/prototype/datapoints/_label.py b/torchvision/prototype/datapoints/_label.py
index 54915493390..8fce6d735a7 100644
--- a/torchvision/prototype/datapoints/_label.py
+++ b/torchvision/prototype/datapoints/_label.py
@@ -11,6 +11,12 @@
 L = TypeVar("L", bound="_LabelBase")
+# Do we have transforms that change the categories?
+# Why do we need the labels to be datapoints?
+# In these label classes, what is strictly needed vs something that was
+# historically designed with the joint dataset / transforms revamp in mind?
+# (asking because the dataset revamp is on indefinite pause, so perhaps some
+# things are now obsolete?)
 class _LabelBase(Datapoint):
     categories: Optional[Sequence[str]]
@@ -58,6 +64,7 @@ def to_categories(self) -> Any:
         return tree_map(lambda idx: self.categories[idx], self.tolist())
+# Remind me where this is used?
 class OneHotLabel(_LabelBase):
     def __new__(
         cls,
diff --git a/torchvision/prototype/transforms/Migration.md b/torchvision/prototype/transforms/Migration.md
new file mode 100644
index 00000000000..a9dda03f927
--- /dev/null
+++ b/torchvision/prototype/transforms/Migration.md
@@ -0,0 +1,59 @@
+Needs to be done in order to migrate to Beta
+--------------------------------------------
+
+(Some of the items really just mean "Nicolas needs to understand this better")
+
+* (P V N) Figure out logistics of migration (extra .v2 namespace, programmatic "opt-in",
+  stuff like that): tracked in https://github.com/pytorch/vision/issues/7097
+* (P and N) Figure out dataset <-> transformsV2 layer (including HF or other external
+  datasets): tracked in https://github.com/pytorch/vision/pull/6663
+* (N) Figure out internal video partners and what they actually need. Some of the
+  Video transforms like `uniform_temporal_subsample()` are outliers (awkward
+  support, doesn't fit well into the current API). Same for `PermuteDimensions`
+  and `TransposeDimension`, which break underlying assumptions about dimension
+  order.
+* Address critical TODOs below and in code, code review etc.
+* Write Docs
+* Philip: (EDIT: submit dummy PR) Polish tests - make sure they are at least functionally equivalent to the v1
+  tests. This requires individually checking them.
+* (P V N) Bikeshed on a good name and location for wrap_dataset_for_transforms_v2
+* (P V N) make `Image(some_pil_image)` work - https://github.com/pytorch/vision/pull/6663#discussion_r1093121007
+
+Needs to be done before migrating to stable
+-------------------------------------------
+
+* Address rest of TODOs below and in code, code review etc.
+* Look into pytorch 2.0 compat? (**Should this be bumped up??**)
+* Figure out path to user-defined transforms and sub-classes
+* Add support for Optical Flow transforms (e.g. vflip needs special handling for
+  flow masks)
+* Handle strides, e.g. https://github.com/pytorch/vision/issues/7090 ? Looks like it's a non-issue?
+* Figure out what transformsV2 means for inference presets
+
+
+TODOs
+-----
+
+- Those in https://github.com/pytorch/vision/pull/7092 and
+  https://github.com/pytorch/vision/pull/7082 (There is overlap!)
+  They're not all critical.
+- Document (internally, not as user-facing docs) the `self.as_subclass(torch.Tensor)` perf hack
+
+Done
+----
+
+* Figure out what to do about get_params() static methods (See https://github.com/pytorch/vision/pull/7092).
+  A: we want them back - tracked in https://github.com/pytorch/vision/pull/7153
+* Avoid inconsistent output type: Let Normalize() and RandomPhotometricDistort
+  return datapoints instead of tensors
+  (https://github.com/pytorch/vision/pull/7113)
+* Figure out criticality of JIT compat for classes. Is this absolutely needed,
+  by whom, potential workarounds, etc.
+  * Done: Philip found an elegant way to support JIT as long as the v1 transforms
+    are still around: https://github.com/pytorch/vision/pull/7135
+* Figure out whether `__torch_dispatch__` is preferable to `__torch_function__`.
+  * After chat with Alban, there's no reason to use `__torch_dispatch__`.
+    Everything should work as expected with `__torch_function__`, including
+    AutoGrad.
+* Simplify interface and Image meta-data: Remove color_space metadata and
+  ConvertColorSpace() transform (https://github.com/pytorch/vision/pull/7120)
\ No newline at end of file
diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index 3160770a09d..806ca40bbf7 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -94,7 +94,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
     def _transform(
         self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
     ) -> Union[datapoints.ImageType, datapoints.VideoType]:
-        if params["v"] is not None:
+        if params["v"] is not None:  # What is this?
             inpt = F.erase(inpt, **params, inplace=self.inplace)
         return inpt
diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py
index 3247a8051a3..4f1cd9bf175 100644
--- a/torchvision/prototype/transforms/_deprecated.py
+++ b/torchvision/prototype/transforms/_deprecated.py
@@ -14,6 +14,8 @@
 from .utils import is_simple_tensor, query_chw
+# Considering the widespread usage of ToTensor, could this be potentially disruptive?
+# What is the cost of keeping it around as an alias for Compose(...)?
 class ToTensor(Transform):
     _transformed_types = (PIL.Image.Image, np.ndarray)
@@ -41,6 +43,7 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
             f"The transform `Grayscale(num_output_channels={num_output_channels})` "
             f"is deprecated and will be removed in a future release."
         )
+        # the name seems to be ConvertColorSpace, not ConvertImageColorSpace
         if num_output_channels == 1:
             replacement_msg = (
                 "transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)"
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 1cbf02d5ae2..581703fdf79 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -40,13 +40,21 @@ class Resize(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
+        # Unrelated to V2, just interested in your thoughts on this:
+        # It'd be nice to just be able to do Resize(mode="bilinear"), i.e. accept strings.
+        # Having to import the Enum class can be inconvenient when prototyping,
+        # and the discrepancy between torchvision and torch core (which accepts
+        # strings) also trips users up.
+        # I don't remember if this was supported in the past and then removed?
+        # Would there be any technical challenge in supporting it? (e.g.
+        # torchscript issues?)
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         max_size: Optional[int] = None,
         antialias: Optional[bool] = None,
     ) -> None:
         super().__init__()
-        self.size = (
+        self.size = (  # Isn't this check duplicated in _compute_resized_output_size()?
             [size]
             if isinstance(size, int)
             else _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py
index a6980f3e135..42763687d4f 100644
--- a/torchvision/prototype/transforms/_presets.py
+++ b/torchvision/prototype/transforms/_presets.py
@@ -14,6 +14,10 @@
 __all__ = ["StereoMatching"]
+# I assume this file is actually not part of the transformsV2 work - it's just
+# that StereoMatching is still prototype?
+# A: yes
+
 class StereoMatching(torch.nn.Module):
     def __init__(
         self,
diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py
index 43224cabd38..b4e56a99b6f 100644
--- a/torchvision/prototype/transforms/_transform.py
+++ b/torchvision/prototype/transforms/_transform.py
@@ -4,6 +4,8 @@
 import PIL.Image
 import torch
 from torch import nn
+# I remember that the core team is open to making it public. Is it still private?
+# A: yes, but core is committed to full BC
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype.transforms.utils import check_type
 from torchvision.utils import _log_api_usage_once
@@ -14,6 +16,8 @@ class Transform(nn.Module):
     # Class attribute defining transformed types. Other types are passed-through without any transformation
     # We support both Types and callables that are able to do further checks on the type of the input.
     _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
+    # Q: Apart from Tensors (or datapoints) and PIL images, what are the other
+    # types that one may pass to a transform?
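+    # (Illustrative sketch, not from this PR: entries of _transformed_types may be
+    # plain types, checked with isinstance(), or callables that are called on the
+    # input. A subclass could e.g. set
+    #     _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image)
+    # where is_simple_tensor() accepts plain tensors but rejects datapoints.)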
     def __init__(self) -> None:
         super().__init__()
@@ -35,6 +39,8 @@ def forward(self, *inputs: Any) -> Any:
         params = self._get_params(flat_inputs)
+        # TODO: right now, any tensor or datapoint passed to forward() will be transformed
+        # the rest is bypassed
         flat_outputs = [
             self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs
         ]
diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py
index 01908650fb4..dd62cac3051 100644
--- a/torchvision/prototype/transforms/_type_conversion.py
+++ b/torchvision/prototype/transforms/_type_conversion.py
@@ -62,6 +62,9 @@ def _transform(
         return F.to_image_pil(inpt, mode=self.mode)
+# What is the "new naming scheme"?
+# I was going to point out that "ToTensorImage" and "ToPILImage" would look more
+# familiar
 # We changed the name to align them with the new naming scheme. Still, `ToPILImage` is
 # prevalent and well understood. Thus, we just alias it without deprecating the old name.
 ToPILImage = ToImagePIL
diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py
index cbf8992300e..7e1234e313b 100644
--- a/torchvision/prototype/transforms/_utils.py
+++ b/torchvision/prototype/transforms/_utils.py
@@ -22,7 +22,7 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
             raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")
     if isinstance(arg, float):
-        arg = [float(arg), float(arg)]
+        arg = [float(arg), float(arg)]  # Nit: arg is already a float
     if isinstance(arg, (list, tuple)) and len(arg) == 1:
         arg = [arg[0], arg[0]]
     return arg
@@ -68,6 +68,9 @@ def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT:
     return fill
+# It is fairly obscure what this does or why this is needed - perhaps a quick
+# note summarizing the "fill" situation could be useful? (maybe it exists and
+# I'm not seeing it!)
 def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]:
     _check_fill_arg(fill)
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index ec2da6ee518..efee0cb6fa2 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -1,5 +1,8 @@
 # TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators
+# Do we want all these low-level kernels to be public?
+# Is it mostly because they're the only ones that are fully jit-compatible, or are there other reasons?
+
 from torchvision.transforms import InterpolationMode  # usort: skip
 from ._meta import (
     clamp_bounding_box,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index ba417a0ce84..1360ad7ead1 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -140,6 +140,8 @@ def _compute_resized_output_size(
     return __compute_resized_output_size(spatial_size, size=size, max_size=max_size)
+# IIUC this is mostly a rewrite of torchvision.transforms.functional_tensor.resize()?
+# Was this new version made for perf improvements, or some other reason?
 def resize_image_tensor(
     image: torch.Tensor,
     size: List[int],
@@ -182,6 +184,8 @@ def resize_image_tensor(
     return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
+# This looks like a shallow wrapper around the original _FP.resize()
+# Are these wrappers common or is this a one-off?
 @torch.jit.unused
 def resize_image_pil(
     image: PIL.Image.Image,
@@ -232,6 +236,7 @@ def resize_video(
     return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
+# Is the if/elif/else/raise structure the same in (almost?) all dispatchers?
 def resize(
     inpt: datapoints.InputTypeJIT,
     size: List[int],
@@ -244,11 +249,16 @@ def resize(
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
+        # TODO: Figure out whether this cond could just be:
+        #     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
+        # IIUC we're calling the method on the datapoint object, which in turn
+        # will go back into this file and call one of the low-level kernels above.
+        # According to offline chat this is to enable extensibility in the future?
         return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
     elif isinstance(inpt, PIL.Image.Image):
-        if antialias is not None and not antialias:
+        if antialias is not None and not antialias:  # just check `is False`?
             warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
         return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
     else:
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 28de0536978..6184f183047 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -190,6 +190,7 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
     return xyxy
+# TODO: Maybe make this available as a class transform as well?
 def convert_format_bounding_box(
     bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
 ) -> torch.Tensor:
@@ -437,6 +438,10 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -
     return convert_dtype_image_tensor(video, dtype)
+# TODO: this doesn't just change the dtype, it also changes the value range.
+# This name relies on the implicit assumption that the value range is determined
+# by the dtype. Maybe think of a more descriptive name if we can (once and for
+# all).
 def convert_dtype(
     inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], dtype: torch.dtype = torch.float
 ) -> torch.Tensor:
diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py
index bc9408d0e2c..002fd918d3f 100644
--- a/torchvision/prototype/transforms/functional/_misc.py
+++ b/torchvision/prototype/transforms/functional/_misc.py
@@ -69,6 +69,10 @@ def normalize(
             f"but got {type(inpt)} instead."
         )
+    # Saw this in the [Feedback] thread (https://github.com/pytorch/vision/issues/6753#issuecomment-1308295943)
+    # Not sure I understand the need to return a Tensor instead of an Image
+    # Is the inconsistency "worth" it? How can we make sure this isn't too unexpected?
+
     # Image or Video type should not be retained after normalization due to unknown data range
     # Thus we return Tensor for input Image
     return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
diff --git a/torchvision/prototype/transforms/utils.py b/torchvision/prototype/transforms/utils.py
index 9ab2ed2602b..b9751543cd3 100644
--- a/torchvision/prototype/transforms/utils.py
+++ b/torchvision/prototype/transforms/utils.py
@@ -38,6 +38,12 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
     return c, h, w
+# Are we deprecating get_image_size()?
+# A: yes, it's currently reporting W,H
+#
+# Also, this seems like a subset of query_chw():
+# - can we just have query_chw()?
+# - if not, can we share code between the 2?
 def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
     sizes = {
         tuple(get_spatial_size(inpt))
@@ -50,18 +56,25 @@ def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
     if not sizes:
         raise TypeError("No image, video, mask or bounding box was found in the sample")
     elif len(sizes) > 1:
+        # (This comment applies to the other query_*() utils as well):
+        # Does this happen?
+        # What can go terribly wrong if we don't raise an error?
+        # Should we just document that this returns the size of the first inpt to avoid an error?
         raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
     h, w = sizes.pop()
     return h, w
+# when can types_or_checks be a callable?
 def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
     for type_or_check in types_or_checks:
+        # That messed with my brain for a bit - add parentheses?
         if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
             return True
     return False
+# Are these public because they are needed for users to implement custom transforms?
 def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
     for inpt in flat_inputs:
         if check_type(inpt, types_or_checks):
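(Re: the question above about callables in `types_or_checks` - a small sketch of how check_type() treats the two kinds of entries. is_simple_tensor is the callable used elsewhere in the prototype transforms; its exact import path is assumed here:)

    import PIL.Image
    import torch
    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms.utils import check_type, is_simple_tensor  # import path assumed

    img = datapoints.Image(torch.rand(3, 8, 8))

    check_type(img, (datapoints.Image, PIL.Image.Image))  # True: plain isinstance() check
    check_type(img, (is_simple_tensor,))                   # False: the callable is invoked, and an Image is a Datapoint
    check_type(torch.rand(3, 8, 8), (is_simple_tensor,))   # True: a plain tensor passes the callable check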