-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NOMRG] TransformsV2 questions / comments #7092
Changes from 9 commits
40d7c83
9edfd42
e09d85b
7fbab6c
c7accb6
7aa77b7
a2bce37
a20e581
03ef398
d4b1a3e
91cf82e
c2db00a
a0f2b80
385af4b
0af3f37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,8 +15,35 @@ class BoundingBoxFormat(StrEnum): | |
CXCYWH = StrEnum.auto() | ||
|
||
|
||
# What if... we just removed the format and spatial_size meta-data? | ||
# A: We could, but it comes with trade-offs. For the format, this wouldn't lead | ||
# to much of a difference, except that users would have to convert to XYXY | ||
# before doing anything. All of the current stable ops expect XYXY already so | ||
# it's not much of a change. Worth noting as well that a few BBox transforms | ||
# only have an implementation for the XYXY format, and they convert / re-convert | ||
# internally (see e.g. affine_bounding_box, but there are others) | ||
# Removing spatial_size however would make the dispatcher-level more cluncky for | ||
# users. It wouldn't change much of the tranforms classes as long as they're | ||
# called with their respective image e.g. | ||
# T(image, bbox) | ||
# because the spatial_size can be known from the image param. But in a mid-level | ||
# dispatcher which only accept 1 kind of input like | ||
# dispatcher(bbox) | ||
# there's no way to know the spatial_size unless it's passed as a parameter. | ||
# Users would also need to keep track of it since some transforms actually | ||
# change it: | ||
# bbox, sz = resize(bbox, spatial_size=sz) | ||
# This also means the mid-level dispatchers: | ||
# - need to accept as input anything that was a meta-data 9in this case | ||
# spatial_size | ||
# - need to return them as well; which means they need to return either a single | ||
# image, a single video, or a tuple of (bbox, spatial_size), | ||
# TL;DR: things would get messy for users and for us. | ||
|
||
class BoundingBox(Datapoint): | ||
format: BoundingBoxFormat | ||
format: BoundingBoxFormat # TODO: do not use a builtin? | ||
# TODO: This is the size of the image, not the box. Maybe make this explicit in the name? | ||
# Note: if this isn't user-facing, the TODO is not critical at all | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is user facing. In general, the metadata of the datapoints is considered public. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
But I agree it can be unclear if |
||
spatial_size: Tuple[int, int] | ||
|
||
@classmethod | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,8 @@ | |
FillTypeJIT = Union[int, float, List[float], None] | ||
|
||
|
||
# TODO: provide a few examples of when the Datapoint type is preserved vs when it's not | ||
# test_prototype_datapoints.py is a good starting point | ||
class Datapoint(torch.Tensor): | ||
__F: Optional[ModuleType] = None | ||
|
||
|
@@ -41,13 +43,17 @@ def __new__( | |
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) | ||
return tensor.as_subclass(Datapoint) | ||
|
||
# Is this still needed, considering we won't be releasing the prototype datasets anytime soon? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See #7154. |
||
@classmethod | ||
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: | ||
# FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved, | ||
# this method should be made abstract | ||
# raise NotImplementedError | ||
return tensor.as_subclass(cls) | ||
|
||
# Can things go wrong with having to maintain a different set of funcs that | ||
# need special care? What if we forget to handle one properly? | ||
# A: This is probably fine. These are the only ones encountered so far. | ||
_NO_WRAPPING_EXCEPTIONS = { | ||
torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output), | ||
torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output), | ||
|
@@ -81,14 +87,19 @@ def __torch_function__( | |
For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are | ||
listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS` | ||
""" | ||
# Is this something we've bubbled up to the core team as a potential feature request? | ||
# Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we | ||
# need to reimplement the functionality. | ||
|
||
# Curious in which cases this can be hit? | ||
# A: Still don't really know, but this comes from the parent | ||
# Tensor.__torch_function__ (which we need to re-write here) | ||
if not all(issubclass(cls, t) for t in types): | ||
return NotImplemented | ||
|
||
with DisableTorchFunctionSubclass(): | ||
output = func(*args, **kwargs or dict()) | ||
# TODO: maybe we can exit the CM here? | ||
|
||
wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func) | ||
# Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be | ||
|
@@ -98,21 +109,31 @@ def __torch_function__( | |
# `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would | ||
# be wrapped into a `datapoints.Image`. | ||
if wrapper and isinstance(args[0], cls): | ||
# TODO: figure out whether | ||
# arbitrary_tensor.to(some_img) | ||
# should be an Image or a Tensor | ||
return wrapper(cls, args[0], output) # type: ignore[no-any-return] | ||
|
||
# Does that mean that DisableTorchFunctionSubclass is ignored for `.inpace_()` functions? | ||
# Or maybe I'm misunderstanding what DisableTorchFunctionSubclass is supposed to do. | ||
# TODO: figure out with torch core whether this is a bug or not | ||
|
||
# Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`, | ||
# will retain the input type. Thus, we need to unwrap here. | ||
if isinstance(output, cls): | ||
return output.as_subclass(torch.Tensor) | ||
|
||
return output | ||
|
||
# Is this used? | ||
def _make_repr(self, **kwargs: Any) -> str: | ||
# This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532. | ||
# If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class. | ||
extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items()) | ||
return f"{super().__repr__()[:-1]}, {extra_repr})" | ||
|
||
# What cyclic import is this solving? Can it be avoided somehow? | ||
# A: this is because the dispatchers call the methods which in turn call back into the functional namespace | ||
@property | ||
def _F(self) -> ModuleType: | ||
# This implements a lazy import of the functional to get around the cyclic import. This import is deferred | ||
|
@@ -125,6 +146,15 @@ def _F(self) -> ModuleType: | |
Datapoint.__F = functional | ||
return Datapoint.__F | ||
|
||
|
||
# doing some_tensor.dtype would go through __torch_function__? | ||
# - is this because it is implemented as a @property method? | ||
# A: yes, or something like that | ||
# - why do we want to bypass __torch_function__ in these case? | ||
# A: for optimization. TODO: should this be part of core already? | ||
# Also, what happens if users access another attribute that we haven't | ||
# overridden here, e.g. image.data or image.some_new_attribute_in_pytorch3.0? | ||
|
||
# Add properties for common attributes like shape, dtype, device, ndim etc | ||
# this way we return the result without passing into __torch_function__ | ||
@property | ||
|
@@ -147,6 +177,10 @@ def dtype(self) -> _dtype: # type: ignore[override] | |
with DisableTorchFunctionSubclass(): | ||
return super().dtype | ||
|
||
# Are these the "no-op fallbacks"? | ||
# A: yes, fallback from the dispatchers. These exist in anticipation of | ||
# allowing user-defined transforms. | ||
# TODO: figure out design / tradeoffs | ||
def horizontal_flip(self) -> Datapoint: | ||
return self | ||
|
||
|
@@ -268,4 +302,4 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N | |
|
||
|
||
InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint] | ||
InputTypeJIT = torch.Tensor | ||
InputTypeJIT = torch.Tensor # why alias it? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To have an easier time looking up what the actual type should be. Meaning if you see |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,6 +58,9 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace: | |
|
||
|
||
class Image(Datapoint): | ||
# Where is this used / changes apart from in ConvertColorSpace()? | ||
# A: For now, this is somewhat redundant with number of channels. | ||
# TODO: decide whether we want to keep it? | ||
color_space: ColorSpace | ||
|
||
@classmethod | ||
|
@@ -117,7 +120,10 @@ def horizontal_flip(self) -> Image: | |
def vertical_flip(self) -> Image: | ||
output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor)) | ||
return Image.wrap_like(self, output) | ||
|
||
|
||
# Do we want to keep these public? | ||
# This is probalby related to internal customer needs. TODO for N: figure that out | ||
# This is also related to allow user-defined subclasses and transforms (in anticipation of) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In addition, |
||
def resize( # type: ignore[override] | ||
self, | ||
size: List[int], | ||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -11,6 +11,12 @@ | |||||||
L = TypeVar("L", bound="_LabelBase") | ||||||||
|
||||||||
|
||||||||
# Do we have transforms that change the categories? | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No |
||||||||
# Why do we need the labels to be datapoints? | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because some transformations need this information. Examples are
|
||||||||
# In these label classes, what is strictly needed vs something that was | ||||||||
# historically designed with the joint dataset / transforms revamp in mind? | ||||||||
# (asking because the dataset revamp is on indefinite pause, so perhaps some | ||||||||
# things are now obsolete?) | ||||||||
Comment on lines
+16
to
+19
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Everything related to |
||||||||
class _LabelBase(Datapoint): | ||||||||
categories: Optional[Sequence[str]] | ||||||||
|
||||||||
|
@@ -58,6 +64,7 @@ def to_categories(self) -> Any: | |||||||
return tree_map(lambda idx: self.categories[idx], self.tolist()) | ||||||||
|
||||||||
|
||||||||
# Remind me where this is used? | ||||||||
class OneHotLabel(_LabelBase): | ||||||||
def __new__( | ||||||||
cls, | ||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
Needs to be done in order to migrate to stable | ||
---------------------------------------------- | ||
|
||
(Some of the items really just mean "Nicolas needs to understand this better") | ||
|
||
* Figure out criticality of JIT compat for classes. Is this absolutely needed, by whom, potential workarounds, etc. | ||
* Write Docs | ||
* Figure out dataset <-> transformsV2 layer (including HF or other external datasets) | ||
* address TODOs below and in code, code review etc. | ||
* pytorch 2.0 compat? | ||
* Video transforms | ||
* Figure out path to user-defined transforms and sub-classes | ||
* Does Mask support Optical Flow masks? | ||
* Handle strides, e.g. https://github.com/pytorch/vision/issues/7090 ? Looks like it's a non-issue? | ||
* Figure out what transformsV2 mean for inference presets | ||
* Figure out logistics of migration into stable area (extra .v2 namespace, stuff like that) | ||
|
||
|
||
|
||
TODOs | ||
----- | ||
|
||
- Those in https://github.com/pytorch/vision/pull/7092 and | ||
https://github.com/pytorch/vision/pull/7082 (There is overlap!) | ||
They're not all critical. | ||
- Document (internally, not as user-facing docs) the `self.as_subclass(torch.Tensor)` perf hack | ||
|
||
Done | ||
---- | ||
|
||
* Figure out what to do about get_params() static methods (See https://github.com/pytorch/vision/pull/7092) |
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -94,7 +94,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: | |||||||||||||
def _transform( | ||||||||||||||
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] | ||||||||||||||
) -> Union[datapoints.ImageType, datapoints.VideoType]: | ||||||||||||||
if params["v"] is not None: | ||||||||||||||
if params["v"] is not None: # What is this? | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a value tensor or None used to erase the image There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Basically it is the replacement that is put in the "erased" area. In v1, in case we didn't find an area to erase, we return the bounding box of the whole image as well as the image vision/torchvision/transforms/transforms.py Line 1702 in 01d138d
With that we call vision/torchvision/transforms/functional_tensor.py Lines 928 to 932 in 01d138d
Since that is quite nonsensical, we opted to also allow |
||||||||||||||
inpt = F.erase(inpt, **params, inplace=self.inplace) | ||||||||||||||
|
||||||||||||||
return inpt | ||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ | |
from .utils import is_simple_tensor, query_chw | ||
|
||
|
||
# Considering the widespread usage of ToTensor, could this be potentially disruptive? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yup. But this deprecation is not really related to v2. Although we never followed through before, deprecating |
||
# What is the cost of keeping it around as an alias for Compose(...)? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not much TBH. I can send a draft PR with a design. |
||
class ToTensor(Transform): | ||
_transformed_types = (PIL.Image.Image, np.ndarray) | ||
|
||
|
@@ -41,6 +43,7 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: | |
f"The transform `Grayscale(num_output_channels={num_output_channels})` " | ||
f"is deprecated and will be removed in a future release." | ||
) | ||
# the name seems to be ConvertColorSpace, not ConvertImageColorSpace | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch! This was renamed and we forgot to update the warning. |
||
if num_output_channels == 1: | ||
replacement_msg = ( | ||
"transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,13 +40,21 @@ class Resize(Transform): | |
def __init__( | ||
self, | ||
size: Union[int, Sequence[int]], | ||
# Unrelated to V2, just interested in your thoughts on this: | ||
# It'd be nice to just be able to do Resize(mode="bilinear"), i.e. accept strings. | ||
# Having to import the Enum class can be inconvenient when prototyping, | ||
# and the discrepancy between torchvision and torch core (which accept | ||
# strings) also trips users. | ||
# I don't remember if this was supported in the past and then removed? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @vfdev-5 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @NicolasHug no it wasn't supported. In this PR we moved from PIL resampling int values to enums. It was also proposed to choose between : 1)
|
||
# Would there be any technical challenge in supporting it? (e.g. | ||
# torchscript issues?) | ||
Comment on lines
+49
to
+50
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll look into that. |
||
interpolation: InterpolationMode = InterpolationMode.BILINEAR, | ||
max_size: Optional[int] = None, | ||
antialias: Optional[bool] = None, | ||
) -> None: | ||
super().__init__() | ||
|
||
self.size = ( | ||
self.size = ( # Isn't this check duplicated in _compute_resized_output_size()? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nope. I fell for that too. See #6514 (comment). TL;DR |
||
[size] | ||
if isinstance(size, int) | ||
else _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,8 @@ | |
import PIL.Image | ||
import torch | ||
from torch import nn | ||
# I remember that the core team is open to making it public. Is still private? | ||
# A: yes but core is committed to full BC | ||
Comment on lines
+7
to
+8
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See also pytorch/pytorch#65761 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @srossross |
||
from torch.utils._pytree import tree_flatten, tree_unflatten | ||
from torchvision.prototype.transforms.utils import check_type | ||
from torchvision.utils import _log_api_usage_once | ||
|
@@ -14,6 +16,8 @@ class Transform(nn.Module): | |
# Class attribute defining transformed types. Other types are passed-through without any transformation | ||
# We support both Types and callables that are able to do further checks on the type of the input. | ||
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image) | ||
# Q: Apart from Tensors (or datapoints) and PIL images, what are the other | ||
# types that one may pass to a transform? | ||
Comment on lines
+19
to
+20
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO, the most relevant will be transform = ...
dataset = ...
sample = dataset[0]
transform(sample) Even without datasets v2, |
||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
|
@@ -35,6 +39,8 @@ def forward(self, *inputs: Any) -> Any: | |
|
||
params = self._get_params(flat_inputs) | ||
|
||
# TODO: right now, any tensor or datapoint passed to forward() will be transformed | ||
# the rest is bypassed | ||
Comment on lines
+42
to
+43
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure what the TODO is here? Yes, this behavior is by design. What other behavior were you expecting? Even if we go through with #7028, it would still be limited to tensors (and in turn datapoints since they are tensors) and PIL images. Any other type would fail at the first dispatcher call. Exceptions might be the type conversion transforms like |
||
flat_outputs = [ | ||
self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs | ||
] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,6 +62,9 @@ def _transform( | |
return F.to_image_pil(inpt, mode=self.mode) | ||
|
||
|
||
# What is the "new naming scheme"? | ||
# I was going to point out that "ToTensorImage" and "ToPILImage" would look more | ||
# familiar | ||
Comment on lines
+65
to
+67
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For kernels we chose
( Following that scheme for transforms we have Still, since we didn't want to cause any disruption there (at least not in the beginning), we just silently alias the two. |
||
# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is | ||
# prevalent and well understood. Thus, we just alias it without deprecating the old name. | ||
ToPILImage = ToImagePIL |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See #7117.