[fbsync] refactor Datapoint dispatch mechanism (#7747)
Summary: Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

Reviewed By: matteobettini

Differential Revision: D48642281

fbshipit-source-id: 33a1dcba4bbc254a26ae091452a61609bb80f663
NicolasHug authored and facebook-github-bot committed Aug 25, 2023
1 parent db56d55 commit 8bd2151
Showing 24 changed files with 1,215 additions and 1,428 deletions.
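
The test changes below (only four of the 24 changed files are reproduced here) track the new dispatch mechanism: instead of each Datapoint subclass exposing transform methods such as `Image.resized_crop`, the dispatchers in `torchvision.transforms.v2.functional` now look their kernels up in a `_KERNEL_REGISTRY` keyed by dispatcher and input type. A minimal sketch of that idea; only the `_KERNEL_REGISTRY` name is taken from the diff, while `register_kernel`, `resize`, and `resize_image` are simplified stand-ins, not the actual torchvision implementation:

```python
import functools

import torch

# Hypothetical registry shaped like the _KERNEL_REGISTRY the tests import:
# {dispatcher: {input_type: kernel}}
_KERNEL_REGISTRY = {}


def register_kernel(dispatcher, input_type):
    # Illustrative decorator; the real registration machinery may differ.
    def decorator(kernel):
        @functools.wraps(kernel)  # exposes __wrapped__, which the test checks
        def wrapper(inpt, *args, **kwargs):
            return kernel(inpt, *args, **kwargs)

        _KERNEL_REGISTRY.setdefault(dispatcher, {})[input_type] = wrapper
        return kernel

    return decorator


def resize(inpt, size):
    # Toy dispatcher: pick the kernel registered for the input's exact type.
    kernel = _KERNEL_REGISTRY[resize][type(inpt)]
    return kernel(inpt, size)


@register_kernel(resize, torch.Tensor)
def resize_image(image, size):
    return torch.nn.functional.interpolate(image[None], size=size)[0]
```

With the registry shaped this way, `test_dispatch_datapoint` below can patch a single `{input_type: kernel}` entry with a spy and assert that exactly one kernel ran.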
4 changes: 4 additions & 0 deletions test/common_utils.py
@@ -829,6 +829,10 @@ def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))


def make_video_tensor(*args, **kwargs):
return make_video(*args, **kwargs).as_subclass(torch.Tensor)


def make_video_loader(
size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
*,
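
The new `make_video_tensor` helper relies on `Tensor.as_subclass`, which reinterprets the same storage under another class, so round-tripping through `torch.Tensor` strips the `Video` wrapper without copying. A small usage sketch, assuming an arbitrary example shape:

```python
import torch
from torchvision import datapoints

# (num_frames, channels, height, width); arbitrary example shape
video = datapoints.Video(torch.rand(4, 3, 32, 32))
plain = video.as_subclass(torch.Tensor)

assert isinstance(video, datapoints.Video)
assert type(plain) is torch.Tensor  # Video subclass information is gone
assert plain.data_ptr() == video.data_ptr()  # same underlying storage
```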
6 changes: 4 additions & 2 deletions test/datasets_utils.py
@@ -567,7 +567,7 @@ def test_transforms(self, config):

@test_all_configs
def test_transforms_v2_wrapper(self, config):
from torchvision.datapoints._datapoint import Datapoint
from torchvision import datapoints
from torchvision.datasets import wrap_dataset_for_transforms_v2

try:
@@ -588,7 +588,9 @@ def test_transforms_v2_wrapper(self, config):
assert len(wrapped_dataset) == info["num_examples"]

wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
assert tree_any(
lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample
)
except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
if str(error).startswith(msg):
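
The assertion now goes through the public `datapoints.Datapoint` base class rather than the private `torchvision.datapoints._datapoint` module. A sketch of what the `tree_any` check amounts to, written against `torch.utils._pytree` directly; `contains_v2_type` is a hypothetical helper, not part of the test suite:

```python
import PIL.Image
from torch.utils._pytree import tree_flatten
from torchvision import datapoints


def contains_v2_type(wrapped_sample) -> bool:
    # Flatten the (possibly nested) sample and look for at least one leaf
    # that is either a Datapoint subclass instance or a PIL image.
    leaves, _ = tree_flatten(wrapped_sample)
    return any(
        isinstance(leaf, (datapoints.Datapoint, PIL.Image.Image)) for leaf in leaves
    )
```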
10 changes: 5 additions & 5 deletions test/test_transforms_v2.py
@@ -1344,12 +1344,12 @@ def test_antialias_warning():
transforms.RandomResize(10, 20)(tensor_img)

with pytest.warns(UserWarning, match=match):
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20))
F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20))

with pytest.warns(UserWarning, match=match):
datapoints.Video(tensor_video).resize((20, 20))
F.resize(datapoints.Video(tensor_video), (20, 20))
with pytest.warns(UserWarning, match=match):
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20))
F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20))

with warnings.catch_warnings():
warnings.simplefilter("error")
@@ -1363,8 +1363,8 @@ def test_antialias_warning():
transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
transforms.RandomResize(10, 20, antialias=True)(tensor_img)

datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
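
These call sites show the user-facing half of the refactor: the antialias warning tests no longer invoke methods on the datapoint but route through the functional namespace, which dispatches on the input type. A sketch under the same assumptions as the test, using a random float image of arbitrary size:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

img = datapoints.Image(torch.rand(3, 32, 32))

# Passing antialias=True silences the UserWarning the test matches against.
out = F.resized_crop(img, top=0, left=0, height=10, width=10, size=(20, 20), antialias=True)

assert isinstance(out, datapoints.Image)  # dispatch preserves the datapoint type
assert tuple(out.shape[-2:]) == (20, 20)
```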
52 changes: 19 additions & 33 deletions test/test_transforms_v2_functional.py
@@ -2,13 +2,11 @@
import math
import os
import re

from typing import get_type_hints
from unittest import mock

import numpy as np
import PIL.Image
import pytest

import torch

from common_utils import (
@@ -27,6 +25,7 @@
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
from torchvision.transforms.v2.utils import is_simple_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -424,12 +423,18 @@ def test_pil_output_type(self, info, args_kwargs):
def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
(datapoint, *other_args), kwargs = args_kwargs.load()

method_name = info.id
method = getattr(datapoint, method_name)
datapoint_type = type(datapoint)
spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")
input_type = type(datapoint)

wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type]

info.dispatcher(datapoint, *other_args, **kwargs)
# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
# proper kernel was wrapped
if hasattr(wrapped_kernel, "__wrapped__"):
assert wrapped_kernel.__wrapped__ is info.kernels[input_type]

spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}):
info.dispatcher(datapoint, *other_args, **kwargs)

spy.assert_called_once()

@@ -462,9 +467,12 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi
kernel_params = list(kernel_signature.parameters.values())[1:]

# We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
# explicit passed to the kernel.
datapoint_type_metadata = datapoint_type.__annotations__.keys()
kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata]
# explicitly passed to the kernel.
input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
explicit_metadata = {
datapoints.BoundingBoxes: {"format", "canvas_size"},
}
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

dispatcher_params = iter(dispatcher_params)
for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
Expand All @@ -481,28 +489,6 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi

assert dispatcher_param == kernel_param

@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_dispatcher_datapoint_signatures_consistency(self, info):
try:
datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id)
except AttributeError:
pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.")

dispatcher_signature = inspect.signature(info.dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

datapoint_signature = inspect.signature(datapoint_method)
datapoint_params = list(datapoint_signature.parameters.values())[1:]

# Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
# defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
# natively concrete dispatcher annotations.
datapoint_annotations = get_type_hints(datapoint_method)
for param in datapoint_params:
param._annotation = datapoint_annotations[param.name]

assert dispatcher_params == datapoint_params

@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_unkown_type(self, info):
unkown_input = object()
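
For reference, the spying pattern that replaces the old method spy in `test_dispatch_datapoint`: `mock.MagicMock(wraps=...)` forwards calls to the real kernel while recording them, and `mock.patch.dict` swaps the registry entry only inside the `with` block. A self-contained toy version, where `REGISTRY` and `double` stand in for `_KERNEL_REGISTRY` and a real kernel:

```python
from unittest import mock


def double(x):
    return 2 * x


# Stand-in for _KERNEL_REGISTRY: {dispatcher_name: {input_type: kernel}}
REGISTRY = {"double": {int: double}}

spy = mock.MagicMock(wraps=double, name="double")
with mock.patch.dict(REGISTRY["double"], values={int: spy}):
    assert REGISTRY["double"][int](21) == 42  # spy forwards to the real kernel

spy.assert_called_once()
assert REGISTRY["double"][int] is double  # patch.dict restored the original entry
```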