Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NOMRG] TransformsV2 questions / comments #7092

Closed
wants to merge 15 commits into from
7 changes: 7 additions & 0 deletions test/test_prototype_datapoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def test_to_wrapping():


def test_to_feature_reference():
# is "feature" now "datapoint" - or is it something else?
# A: yes, TODO: update name
Comment on lines +32 to +33
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #7117.

tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)

Expand All @@ -46,6 +48,10 @@ def test_clone_wrapping():

assert type(label_clone) is datapoints.Label
assert label_clone.data_ptr() != label.data_ptr()
# Is this expected?
# Does this meta-data-preserving behaviour occur for all meta-data attached to datapoints?
# Can things go wrong? i.e. if label_clone changes its metadata inplace, that of label gets changed too?
# TODO: N and P should discuss this more
assert label_clone.categories is label.categories


Expand All @@ -60,6 +66,7 @@ def test_requires_grad__wrapping():
assert type(label_requires_grad) is datapoints.Label
assert label.requires_grad
assert label_requires_grad.requires_grad
assert label_requires_grad is label


def test_other_op_no_wrapping():
Expand Down
12 changes: 12 additions & 0 deletions torchvision/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
T = TypeVar("T", bound=enum.Enum)



class StrEnumMeta(enum.EnumMeta):
auto = enum.auto

Expand All @@ -16,10 +17,21 @@ def from_str(self: Type[T], member: str) -> T: # type: ignore[misc]
raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None


# TODO: As agreed with Philip, we can just get rid of this class for now.
# We can just define enums as
# class BoundingBoxFormat(Enum):
# XYXY = "XYXY"
# ...
# and replace from_str with BoundingBoxFormat["XYXY"].
#
# We won't have a super nice error message as with add_suggestion, but this is
# something we can think about when this is more critically needed (i.e. after
# the migration)
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
pass



def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if not seq:
return ""
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/datapoints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT # Need to be public? TODO: no
from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._label import Label, OneHotLabel
from ._mask import Mask
Expand Down
29 changes: 28 additions & 1 deletion torchvision/prototype/datapoints/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,35 @@ class BoundingBoxFormat(StrEnum):
CXCYWH = StrEnum.auto()


# What if... we just removed the format and spatial_size meta-data?
# A: We could, but it comes with trade-offs. For the format, this wouldn't lead
# to much of a difference, except that users would have to convert to XYXY
# before doing anything. All of the current stable ops expect XYXY already so
# it's not much of a change. Worth noting as well that a few BBox transforms
# only have an implementation for the XYXY format, and they convert / re-convert
# internally (see e.g. affine_bounding_box, but there are others)
# Removing spatial_size however would make the dispatcher-level more cluncky for
# users. It wouldn't change much of the tranforms classes as long as they're
# called with their respective image e.g.
# T(image, bbox)
# because the spatial_size can be known from the image param. But in a mid-level
# dispatcher which only accept 1 kind of input like
# dispatcher(bbox)
# there's no way to know the spatial_size unless it's passed as a parameter.
# Users would also need to keep track of it since some transforms actually
# change it:
# bbox, sz = resize(bbox, spatial_size=sz)
# This also means the mid-level dispatchers:
# - need to accept as input anything that was a meta-data 9in this case
# spatial_size
# - need to return them as well; which means they need to return either a single
# image, a single video, or a tuple of (bbox, spatial_size),
# TL;DR: things would get messy for users and for us.

class BoundingBox(Datapoint):
format: BoundingBoxFormat
format: BoundingBoxFormat # TODO: do not use a builtin?
# TODO: This is the size of the image, not the box. Maybe make this explicit in the name?
# Note: if this isn't user-facing, the TODO is not critical at all
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is user facing. In general, the metadata of the datapoints is considered public.

Copy link
Collaborator

@vfdev-5 vfdev-5 Jan 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spatial_size was renamed from image_size once we added support to Videos if I recall correctly. #6736

But I agree it can be unclear if spatial_size refers to bbox or something else...

spatial_size: Tuple[int, int]

@classmethod
Expand Down
36 changes: 35 additions & 1 deletion torchvision/prototype/datapoints/_datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
FillTypeJIT = Union[int, float, List[float], None]


# TODO: provide a few examples of when the Datapoint type is preserved vs when it's not
# test_prototype_datapoints.py is a good starting point
class Datapoint(torch.Tensor):
__F: Optional[ModuleType] = None

Expand All @@ -41,13 +43,17 @@ def __new__(
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return tensor.as_subclass(Datapoint)

# Is this still needed, considering we won't be releasing the prototype datasets anytime soon?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #7154.

@classmethod
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
# FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
# this method should be made abstract
# raise NotImplementedError
return tensor.as_subclass(cls)

# Can things go wrong with having to maintain a different set of funcs that
# need special care? What if we forget to handle one properly?
# A: This is probably fine. These are the only ones encountered so far.
_NO_WRAPPING_EXCEPTIONS = {
torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
Expand Down Expand Up @@ -81,14 +87,19 @@ def __torch_function__(
For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
"""
# Is this something we've bubbled up to the core team as a potential feature request?
# Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
# need to reimplement the functionality.

# Curious in which cases this can be hit?
# A: Still don't really know, but this comes from the parent
# Tensor.__torch_function__ (which we need to re-write here)
if not all(issubclass(cls, t) for t in types):
return NotImplemented

with DisableTorchFunctionSubclass():
output = func(*args, **kwargs or dict())
# TODO: maybe we can exit the CM here?

wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
# Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
Expand All @@ -98,21 +109,31 @@ def __torch_function__(
# `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `datapoints.Image`.
if wrapper and isinstance(args[0], cls):
# TODO: figure out whether
# arbitrary_tensor.to(some_img)
# should be an Image or a Tensor
return wrapper(cls, args[0], output) # type: ignore[no-any-return]

# Does that mean that DisableTorchFunctionSubclass is ignored for `.inpace_()` functions?
# Or maybe I'm misunderstanding what DisableTorchFunctionSubclass is supposed to do.
# TODO: figure out with torch core whether this is a bug or not

# Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
# will retain the input type. Thus, we need to unwrap here.
if isinstance(output, cls):
return output.as_subclass(torch.Tensor)

return output

# Is this used?
def _make_repr(self, **kwargs: Any) -> str:
# This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
# If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
return f"{super().__repr__()[:-1]}, {extra_repr})"

# What cyclic import is this solving? Can it be avoided somehow?
# A: this is because the dispatchers call the methods which in turn call back into the functional namespace
@property
def _F(self) -> ModuleType:
# This implements a lazy import of the functional to get around the cyclic import. This import is deferred
Expand All @@ -125,6 +146,15 @@ def _F(self) -> ModuleType:
Datapoint.__F = functional
return Datapoint.__F


# doing some_tensor.dtype would go through __torch_function__?
# - is this because it is implemented as a @property method?
# A: yes, or something like that
# - why do we want to bypass __torch_function__ in these case?
# A: for optimization. TODO: should this be part of core already?
# Also, what happens if users access another attribute that we haven't
# overridden here, e.g. image.data or image.some_new_attribute_in_pytorch3.0?

# Add properties for common attributes like shape, dtype, device, ndim etc
# this way we return the result without passing into __torch_function__
@property
Expand All @@ -147,6 +177,10 @@ def dtype(self) -> _dtype: # type: ignore[override]
with DisableTorchFunctionSubclass():
return super().dtype

# Are these the "no-op fallbacks"?
# A: yes, fallback from the dispatchers. These exist in anticipation of
# allowing user-defined transforms.
# TODO: figure out design / tradeoffs
def horizontal_flip(self) -> Datapoint:
return self

Expand Down Expand Up @@ -268,4 +302,4 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N


InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
InputTypeJIT = torch.Tensor
InputTypeJIT = torch.Tensor # why alias it?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To have an easier time looking up what the actual type should be. Meaning if you see *JIT as annotation, you can simply look up * to see what the actual type should be. If we would use InputType and torch.Tensor, this relation is gone.

8 changes: 7 additions & 1 deletion torchvision/prototype/datapoints/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace:


class Image(Datapoint):
# Where is this used / changes apart from in ConvertColorSpace()?
# A: For now, this is somewhat redundant with number of channels.
# TODO: decide whether we want to keep it?
color_space: ColorSpace

@classmethod
Expand Down Expand Up @@ -117,7 +120,10 @@ def horizontal_flip(self) -> Image:
def vertical_flip(self) -> Image:
output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)


# Do we want to keep these public?
# This is probalby related to internal customer needs. TODO for N: figure that out
# This is also related to allow user-defined subclasses and transforms (in anticipation of)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition, resize is the most problematic one, since it actually overrides the (deprecated) tensor method with a completely different behavior.

def resize( # type: ignore[override]
self,
size: List[int],
Expand Down
7 changes: 7 additions & 0 deletions torchvision/prototype/datapoints/_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
L = TypeVar("L", bound="_LabelBase")


# Do we have transforms that change the categories?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No

# Why do we need the labels to be datapoints?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because some transformations need this information. Examples are

# In these label classes, what is strictly needed vs something that was
# historically designed with the joint dataset / transforms revamp in mind?
# (asking because the dataset revamp is on indefinite pause, so perhaps some
# things are now obsolete?)
Comment on lines +16 to +19
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything related to categories was designed with the datasets v2 in mind. None of that is relevant to the transforms and is just a nice to have for the users.

class _LabelBase(Datapoint):
categories: Optional[Sequence[str]]

Expand Down Expand Up @@ -58,6 +64,7 @@ def to_categories(self) -> Any:
return tree_map(lambda idx: self.categories[idx], self.tolist())


# Remind me where this is used?
class OneHotLabel(_LabelBase):
def __new__(
cls,
Expand Down
31 changes: 31 additions & 0 deletions torchvision/prototype/transforms/Migration.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
Needs to be done in order to migrate to stable
----------------------------------------------

(Some of the items really just mean "Nicolas needs to understand this better")

* Figure out criticality of JIT compat for classes. Is this absolutely needed, by whom, potential workarounds, etc.
* Write Docs
* Figure out dataset <-> transformsV2 layer (including HF or other external datasets)
* address TODOs below and in code, code review etc.
* pytorch 2.0 compat?
* Video transforms
* Figure out path to user-defined transforms and sub-classes
* Does Mask support Optical Flow masks?
* Handle strides, e.g. https://github.com/pytorch/vision/issues/7090 ? Looks like it's a non-issue?
* Figure out what transformsV2 mean for inference presets
* Figure out logistics of migration into stable area (extra .v2 namespace, stuff like that)



TODOs
-----

- Those in https://github.com/pytorch/vision/pull/7092 and
https://github.com/pytorch/vision/pull/7082 (There is overlap!)
They're not all critical.
- Document (internally, not as user-facing docs) the `self.as_subclass(torch.Tensor)` perf hack

Done
----

* Figure out what to do about get_params() static methods (See https://github.com/pytorch/vision/pull/7092)
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
if params["v"] is not None:
if params["v"] is not None: # What is this?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a value tensor or None used to erase the image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically it is the replacement that is put in the "erased" area. In v1, in case we didn't find an area to erase, we return the bounding box of the whole image as well as the image

return 0, 0, img_h, img_w, img

With that we call F.erase unconditionally, which ultimately leads to replacing every value in the original image with itself:

if not inplace:
img = img.clone()
img[..., i : i + h, j : j + w] = v
return img

Since that is quite nonsensical, we opted to also allow None as a return value and use it as a sentinel to do nothing. I think the previous implementation came from a time were JIT didn't support Union (or Optional for that matter) and thus we couldn't return Optional[torch.Tensor].

inpt = F.erase(inpt, **params, inplace=self.inplace)

return inpt
Expand Down
3 changes: 3 additions & 0 deletions torchvision/prototype/transforms/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from .utils import is_simple_tensor, query_chw


# Considering the widespread usage of ToTensor, could this be potentially disruptive?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup. But this deprecation is not really related to v2. Although we never followed through before, deprecating ToTensor is a long standing "issue". See for example #2060 (comment).

# What is the cost of keeping it around as an alias for Compose(...)?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not much TBH. I can send a draft PR with a design.

class ToTensor(Transform):
_transformed_types = (PIL.Image.Image, np.ndarray)

Expand Down Expand Up @@ -41,6 +43,7 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
f"The transform `Grayscale(num_output_channels={num_output_channels})` "
f"is deprecated and will be removed in a future release."
)
# the name seems to be ConvertColorSpace, not ConvertImageColorSpace
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! This was renamed and we forgot to update the warning.

if num_output_channels == 1:
replacement_msg = (
"transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)"
Expand Down
10 changes: 9 additions & 1 deletion torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,21 @@ class Resize(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
# Unrelated to V2, just interested in your thoughts on this:
# It'd be nice to just be able to do Resize(mode="bilinear"), i.e. accept strings.
# Having to import the Enum class can be inconvenient when prototyping,
# and the discrepancy between torchvision and torch core (which accept
# strings) also trips users.
# I don't remember if this was supported in the past and then removed?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug no it wasn't supported. In this PR we moved from PIL resampling int values to enums. It was also proposed to choose between : 1) str values or 2) Enums. I do not quite remember why we went to Enum instead of str. Most probably for the following reasons:

using enums instead of raw values has multiple benefits (less mistakes, clearer API, easier to search usages etc)

# Would there be any technical challenge in supporting it? (e.g.
# torchscript issues?)
Comment on lines +49 to +50
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll look into that.

interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> None:
super().__init__()

self.size = (
self.size = ( # Isn't this check duplicated in _compute_resized_output_size()?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope. I fell for that too. See #6514 (comment). TL;DR Resize(224) and Resize((224, 224)) is something different while the latter would be enforced if we used _setup_size unconditionally.

[size]
if isinstance(size, int)
else _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
Expand Down
4 changes: 4 additions & 0 deletions torchvision/prototype/transforms/_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
__all__ = ["StereoMatching"]


# I assume this file is actually not part of the transformsV2 work - it's just
# that StereoMatching is still protytpe?
# A: yes

class StereoMatching(torch.nn.Module):
def __init__(
self,
Expand Down
6 changes: 6 additions & 0 deletions torchvision/prototype/transforms/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import PIL.Image
import torch
from torch import nn
# I remember that the core team is open to making it public. Is still private?
# A: yes but core is committed to full BC
Comment on lines +7 to +8
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype.transforms.utils import check_type
from torchvision.utils import _log_api_usage_once
Expand All @@ -14,6 +16,8 @@ class Transform(nn.Module):
# Class attribute defining transformed types. Other types are passed-through without any transformation
# We support both Types and callables that are able to do further checks on the type of the input.
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
# Q: Apart from Tensors (or datapoints) and PIL images, what are the other
# types that one may pass to a transform?
Comment on lines +19 to +20
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, the most relevant will be str for example a path for the image file. We need to support the following use case:

transform = ...
dataset = ...
sample = dataset[0]
transform(sample)

Even without datasets v2, sample can contain basically anything. COCODetection is a good example. It returns a PIL.Image.Image as well a list of dicts with floats, ints, bools, and strings.


def __init__(self) -> None:
super().__init__()
Expand All @@ -35,6 +39,8 @@ def forward(self, *inputs: Any) -> Any:

params = self._get_params(flat_inputs)

# TODO: right now, any tensor or datapoint passed to forward() will be transformed
# the rest is bypassed
Comment on lines +42 to +43
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what the TODO is here? Yes, this behavior is by design. What other behavior were you expecting? Even if we go through with #7028, it would still be limited to tensors (and in turn datapoints since they are tensors) and PIL images. Any other type would fail at the first dispatcher call.

Exceptions might be the type conversion transforms like ToImageTensor which also accepts numpy arrays.

flat_outputs = [
self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs
]
Expand Down
3 changes: 3 additions & 0 deletions torchvision/prototype/transforms/_type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def _transform(
return F.to_image_pil(inpt, mode=self.mode)


# What is the "new naming scheme"?
# I was going to point out that "ToTensorImage" and "ToPILImage" would look more
# familiar
Comment on lines +65 to +67
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For kernels we chose {op_name}_{datapoint_name}_{backend_name}, e.g.

  • resize_image_tensor
  • resize_image_pil
  • resize_bounding_box

(backend_name only applies to images and is either tensor or pil)

Following that scheme for transforms we have {OpName}{DatapointName} and thus forming ToImagePIL rather than ToPILImage.

Still, since we didn't want to cause any disruption there (at least not in the beginning), we just silently alias the two.

# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ToPILImage = ToImagePIL
Loading