2023-08-05 nightly release (84db2ac)
pytorchbot committed Aug 5, 2023
1 parent ffdb719 · commit db4f879
Showing 13 changed files with 385 additions and 66 deletions.
@@ -0,0 +1,125 @@
""" | ||
===================================== | ||
How to write your own Datapoint class | ||
===================================== | ||
This guide is intended for downstream library maintainers. We explain how to | ||
write your own datapoint class, and how to make it compatible with the built-in | ||
Torchvision v2 transforms. Before continuing, make sure you have read | ||
:ref:`sphx_glr_auto_examples_plot_datapoints.py`. | ||
""" | ||
|
||
# %%
import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby
# acknowledging that some APIs may slightly change in the future.
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms import v2
# %%
# We will create a very simple class that just inherits from the base
# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover
# what you need to know to implement your more elaborate use-cases. If you need
# to create a class that carries meta-data, take a look at how the
# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_.
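# %%
# As a purely hypothetical illustration (the ``TimestampedImage`` name and its
# ``timestamp`` attribute are ours, not torchvision's), such a metadata-carrying
# datapoint could attach its extra attributes in ``__new__``:


class TimestampedImage(datapoints.Datapoint):
    timestamp: float

    def __new__(cls, data, *, timestamp):
        # Build a tensor, view it as our subclass, then attach the metadata.
        tensor = torch.as_tensor(data)
        instance = tensor.as_subclass(cls)
        instance.timestamp = timestamp
        return instance


# %%
# For the rest of this guide, we stick with the minimal class: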
class MyDatapoint(datapoints.Datapoint):
    pass


my_dp = MyDatapoint([1, 2, 3])
my_dp
# %%
# Now that we have defined our custom Datapoint class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
# operation of our MyDatapoint class, and register it to the functional API.

from torchvision.transforms.v2 import functional as F


@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return MyDatapoint.wrap_like(my_dp, out)
# %%
# To understand why ``wrap_like`` is used, see
# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now;
# we will explain them below in :ref:`param_forwarding`.
#
# .. note::
#
#    In our call to ``register_kernel`` above we used a string
#    ``dispatcher="hflip"`` to refer to the functional we want to hook into. We
#    could also have used the functional *itself*, i.e.
#    ``@register_kernel(dispatcher=F.hflip, ...)``.
#
# The functionals that you can hook into are the ones in
# ``torchvision.transforms.v2.functional`` and they are documented in
# :ref:`functional_transforms`.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance:

my_dp = MyDatapoint(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
# %%
# And we can also use the
# :class:`~torchvision.transforms.v2.RandomHorizontalFlip` transform, since it
# relies on :func:`~torchvision.transforms.v2.functional.hflip` internally:

t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
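# %%
# The same registration pattern works for any other functional. As a
# hypothetical sketch (``resize_my_datapoint`` and its unwrapping strategy are
# our own illustration, assuming the dispatcher forwards the parameters of the
# functional to the kernel as shown):


@F.register_kernel(dispatcher="resize", datapoint_cls=MyDatapoint)
def resize_my_datapoint(my_dp, size, *args, **kwargs):
    # Unwrap to a plain tensor so the call below dispatches to the built-in
    # image kernel, then re-wrap the result as a MyDatapoint.
    out = F.resize(my_dp.as_subclass(torch.Tensor), size, *args, **kwargs)
    return MyDatapoint.wrap_like(my_dp, out)


out = F.resize(my_dp, size=(64, 64), antialias=True)
print(type(out), out.shape)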
# %%
# .. note::
#
#    We cannot register a kernel for a transform class; we can only register a
#    kernel for a **functional**. The reason we can't register a transform
#    class is that one transform may internally rely on more than one
#    functional, so in general we can't register a single kernel for a given
#    class.
#
# .. _param_forwarding:
#
# Parameter forwarding, and ensuring future compatibility of your kernels
# ------------------------------------------------------------------------
#
# The functional API that you're hooking into is public and therefore
# **backward** compatible: we guarantee that the parameters of these functionals
# won't be removed or renamed without a proper deprecation cycle. However, we
# don't guarantee **forward** compatibility, and we may add new parameters in
# the future.
#
# Imagine that in a future version, Torchvision adds a new ``inplace`` parameter
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# had already defined and registered your own kernel as
def hflip_my_datapoint(my_dp):  # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return MyDatapoint.wrap_like(my_dp, out)
# %%
# then calling ``F.hflip(my_dp)`` will **fail**, because ``hflip`` will try to
# pass the new ``inplace`` parameter to your kernel, but your kernel doesn't
# accept it.
#
# For this reason, we recommend always defining your kernels with
# ``*args, **kwargs`` in their signature, as done above. This way, your kernel
# will be able to accept any new parameter that we may add in the future.
# (Technically, adding only ``**kwargs`` should be enough.)
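# %%
# For instance, this minimal variant (our own sketch) stays compatible with
# future parameters instead of raising a ``TypeError``:


def hflip_my_datapoint_robust(my_dp, **kwargs):  # noqa
    # Unknown future parameters, e.g. a hypothetical ``inplace``, simply land
    # in ``kwargs`` and are ignored here.
    out = my_dp.flip(-1)
    return MyDatapoint.wrap_like(my_dp, out)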
@@ -0,0 +1,123 @@
""" | ||
=================================== | ||
How to write your own v2 transforms | ||
=================================== | ||
This guide explains how to write transforms that are compatible with the | ||
torchvision transforms V2 API. | ||
""" | ||
|
||
# %%
import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby
# acknowledging that some APIs may slightly change in the future.
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms import v2
# %%
# Just create a ``nn.Module`` and override the ``forward`` method
# ===============================================================
#
# In most cases, this is all you're going to need, as long as you already know
# the structure of the input that your transform will expect. For example, if
# you're just doing image classification, your transform will typically accept
# a single image as input, or an ``(img, label)`` pair. So you can simply
# hard-code your ``forward`` method to accept just that, e.g.
#
# .. code:: python
#
#     class MyCustomTransform(torch.nn.Module):
#         def forward(self, img, label):
#             # Do some transformations
#             return new_img, new_label
#
# .. note::
#
#    This means that if you have a custom transform that is already compatible
#    with the V1 transforms (those in ``torchvision.transforms``), it will
#    still work with the V2 transforms without any change!
#
# We will illustrate this more completely below with a typical detection case,
# where our samples are just images, bounding boxes and labels:
class MyCustomTransform(torch.nn.Module):
    def forward(self, img, bboxes, label):  # we assume inputs are always structured like this
        print(
            f"I'm transforming an image of shape {img.shape} "
            f"with bboxes = {bboxes}\n{label = }"
        )
        # Do some transformations. Here, we're just passing through the input
        return img, bboxes, label


transforms = v2.Compose([
    MyCustomTransform(),
    v2.RandomResizedCrop((224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=1),
    v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])
H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = datapoints.BoundingBoxes(
    torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
    format="XYXY",
    canvas_size=(H, W)
)
label = 3

out_img, out_bboxes, out_label = transforms(img, bboxes, label)
# %%
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
#
#    While working with datapoint classes in your code, make sure to
#    familiarize yourself with this section:
#    :ref:`datapoint_unwrapping_behaviour`
#
# Supporting arbitrary input structures
# =====================================
#
# In the section above, we have assumed that you already know the structure of
# your inputs and that you're OK with hard-coding this expected structure in
# your code. If you want your custom transforms to be as flexible as possible,
# this can be a bit limiting.
#
# A key feature of the built-in Torchvision V2 transforms is that they can
# accept arbitrary input structures and return the same structure as output
# (with transformed entries). For example, transforms can accept a single
# image, or a tuple of ``(img, label)``, or an arbitrary nested dictionary as
# input:
structured_input = {
    "img": img,
    "annotations": (bboxes, label),
    "something_that_will_be_ignored": (1, "hello")
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something_that_will_be_ignored"] == (1, "hello")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
# %%
# If you want to reproduce this behavior in your own transform, we invite you to
# look at our `code
# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
# and adapt it to your needs.
#
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all datapoints are
# tensor subclasses), plus some custom logic that is out of scope here; check
# the code for details. The (potentially transformed) entries are then repacked
# and returned, in the same structure as the input.
#
# We do not provide public dev-facing tools to achieve that at this time, but if
# this is something that would be valuable to you, please let us know by opening
# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
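# %%
# As a rough illustration of that flatten / transform / repack cycle (our own
# sketch, not the actual implementation: it relies on ``torch.utils._pytree``,
# a private utility, and skips the extra custom logic mentioned above):

from torch.utils import _pytree as pytree
from torchvision.transforms.v2 import functional as F


def transform_datapoints(structure, fn):
    # Flatten the arbitrarily nested input into a list of leaves, plus a spec
    # that remembers the original structure.
    flat, spec = pytree.tree_flatten(structure)
    # Transform only the entries whose class marks them as transformable. The
    # real transforms also handle plain tensors and PIL images, among others.
    flat = [fn(leaf) if isinstance(leaf, datapoints.Datapoint) else leaf for leaf in flat]
    # Repack the (potentially transformed) leaves into the original structure.
    return pytree.tree_unflatten(flat, spec)


out = transform_datapoints(structured_input, F.hflip)
print(out["annotations"][0])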