Skip to content

Commit b8d0030

Browse files
authored
Merge branch 'main' into ljanflajnfljanfe
2 parents 1e7f83a + edde825 commit b8d0030

File tree

10 files changed

+64
-69
lines changed

10 files changed

+64
-69
lines changed

gallery/plot_transforms_v2_e2e.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
"""
1111

1212
import pathlib
13-
from collections import defaultdict
1413

1514
import PIL.Image
1615

@@ -99,9 +98,7 @@ def load_example_coco_detection_dataset(**kwargs):
9998
transform = transforms.Compose(
10099
[
101100
transforms.RandomPhotometricDistort(),
102-
transforms.RandomZoomOut(
103-
fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
104-
),
101+
transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
105102
transforms.RandomIoUCrop(),
106103
transforms.RandomHorizontalFlip(),
107104
transforms.ToImageTensor(),

references/segmentation/presets.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from collections import defaultdict
2-
31
import torch
42

53

@@ -48,7 +46,7 @@ def __init__(
4846
if use_v2:
4947
# We need a custom pad transform here, since the padding we want to perform here is fundamentally
5048
# different from the padding in `RandomCrop` if `pad_if_needed=True`.
51-
transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]
49+
transforms += [v2_extras.PadIfSmaller(crop_size, fill={datapoints.Mask: 255, "others": 0})]
5250

5351
transforms += [T.RandomCrop(crop_size)]
5452

references/segmentation/v2_extras.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class PadIfSmaller(v2.Transform):
88
def __init__(self, size, fill=0):
99
super().__init__()
1010
self.size = size
11-
self.fill = v2._geometry._setup_fill_arg(fill)
11+
self.fill = v2._utils._setup_fill_arg(fill)
1212

1313
def _get_params(self, sample):
1414
_, height, width = v2.utils.query_chw(sample)
@@ -20,7 +20,7 @@ def _transform(self, inpt, params):
2020
if not params["needs_padding"]:
2121
return inpt
2222

23-
fill = self.fill[type(inpt)]
23+
fill = v2._utils._get_fill(self.fill, type(inpt))
2424
fill = v2._utils._convert_fill_arg(fill)
2525

2626
return v2.functional.pad(inpt, padding=params["padding"], fill=fill)

test/test_transforms_v2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import random
44
import textwrap
55
import warnings
6-
from collections import defaultdict
76

87
import numpy as np
98

@@ -1475,7 +1474,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
14751474
elif data_augmentation == "ssd":
14761475
t = [
14771476
transforms.RandomPhotometricDistort(p=1),
1478-
transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0}), p=1),
1477+
transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), datapoints.Mask: 0}, p=1),
14791478
transforms.RandomIoUCrop(),
14801479
transforms.RandomHorizontalFlip(p=1),
14811480
to_tensor,

test/test_transforms_v2_consistency.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import inspect
55
import random
66
import re
7-
from collections import defaultdict
87
from pathlib import Path
98

109
import numpy as np
@@ -30,6 +29,7 @@
3029

3130
from torchvision.transforms import functional as legacy_F
3231
from torchvision.transforms.v2 import functional as prototype_F
32+
from torchvision.transforms.v2._utils import _get_fill
3333
from torchvision.transforms.v2.functional import to_image_pil
3434
from torchvision.transforms.v2.utils import query_size
3535

@@ -1181,7 +1181,7 @@ def _transform(self, inpt, params):
11811181
if not params["needs_padding"]:
11821182
return inpt
11831183

1184-
fill = self.fill[type(inpt)]
1184+
fill = _get_fill(self.fill, type(inpt))
11851185
return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
11861186

11871187

@@ -1243,7 +1243,7 @@ def check(self, t, t_ref, data_kwargs=None):
12431243
seg_transforms.RandomCrop(size=480),
12441244
v2_transforms.Compose(
12451245
[
1246-
PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
1246+
PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
12471247
v2_transforms.RandomCrop(size=480),
12481248
]
12491249
),

torchvision/prototype/transforms/_geometry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
from torchvision import datapoints
77
from torchvision.prototype.datapoints import Label, OneHotLabel
88
from torchvision.transforms.v2 import functional as F, Transform
9-
from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size
9+
from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
1010
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size
1111

1212

1313
class FixedSizeCrop(Transform):
1414
def __init__(
1515
self,
1616
size: Union[int, Sequence[int]],
17-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
17+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
1818
padding_mode: str = "constant",
1919
) -> None:
2020
super().__init__()
@@ -119,7 +119,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
119119
)
120120

121121
if params["needs_pad"]:
122-
fill = self._fill[type(inpt)]
122+
fill = _get_fill(self._fill, type(inpt))
123123
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
124124

125125
return inpt

torchvision/prototype/transforms/_misc.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,29 @@
1+
import functools
12
import warnings
2-
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
3+
from collections import defaultdict
4+
from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union
35

46
import torch
57

68
from torchvision import datapoints
79
from torchvision.transforms.v2 import Transform
810

9-
from torchvision.transforms.v2._utils import _get_defaultdict
1011
from torchvision.transforms.v2.utils import is_simple_tensor
1112

1213

14+
T = TypeVar("T")
15+
16+
17+
def _default_arg(value: T) -> T:
18+
return value
19+
20+
21+
def _get_defaultdict(default: T) -> Dict[Any, T]:
22+
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
23+
# If it were possible, we could replace this with `defaultdict(lambda: default)`
24+
return defaultdict(functools.partial(_default_arg, default))
25+
26+
1327
class PermuteDimensions(Transform):
1428
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
1529

torchvision/transforms/v2/_auto_augment.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torchvision.transforms.v2.functional._geometry import _check_interpolation
1212
from torchvision.transforms.v2.functional._meta import get_size
1313

14-
from ._utils import _setup_fill_arg
14+
from ._utils import _get_fill, _setup_fill_arg
1515
from .utils import check_type, is_simple_tensor
1616

1717

@@ -20,7 +20,7 @@ def __init__(
2020
self,
2121
*,
2222
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
23-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
23+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
2424
) -> None:
2525
super().__init__()
2626
self.interpolation = _check_interpolation(interpolation)
@@ -80,9 +80,9 @@ def _apply_image_or_video_transform(
8080
transform_id: str,
8181
magnitude: float,
8282
interpolation: Union[InterpolationMode, int],
83-
fill: Dict[Type, datapoints._FillTypeJIT],
83+
fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
8484
) -> Union[datapoints._ImageType, datapoints._VideoType]:
85-
fill_ = fill[type(image)]
85+
fill_ = _get_fill(fill, type(image))
8686

8787
if transform_id == "Identity":
8888
return image
@@ -214,7 +214,7 @@ def __init__(
214214
self,
215215
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
216216
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
217-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
217+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
218218
) -> None:
219219
super().__init__(interpolation=interpolation, fill=fill)
220220
self.policy = policy
@@ -394,7 +394,7 @@ def __init__(
394394
magnitude: int = 9,
395395
num_magnitude_bins: int = 31,
396396
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
397-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
397+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
398398
) -> None:
399399
super().__init__(interpolation=interpolation, fill=fill)
400400
self.num_ops = num_ops
@@ -467,7 +467,7 @@ def __init__(
467467
self,
468468
num_magnitude_bins: int = 31,
469469
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
470-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
470+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
471471
):
472472
super().__init__(interpolation=interpolation, fill=fill)
473473
self.num_magnitude_bins = num_magnitude_bins
@@ -550,7 +550,7 @@ def __init__(
550550
alpha: float = 1.0,
551551
all_ops: bool = True,
552552
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
553-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
553+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
554554
) -> None:
555555
super().__init__(interpolation=interpolation, fill=fill)
556556
self._PARAMETER_MAX = 10

torchvision/transforms/v2/_geometry.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
_check_padding_arg,
1818
_check_padding_mode_arg,
1919
_check_sequence_input,
20+
_get_fill,
2021
_setup_angle,
2122
_setup_fill_arg,
2223
_setup_float_or_seq,
@@ -487,7 +488,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
487488
def __init__(
488489
self,
489490
padding: Union[int, Sequence[int]],
490-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
491+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
491492
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
492493
) -> None:
493494
super().__init__()
@@ -504,7 +505,7 @@ def __init__(
504505
self.padding_mode = padding_mode
505506

506507
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
507-
fill = self._fill[type(inpt)]
508+
fill = _get_fill(self._fill, type(inpt))
508509
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
509510

510511

@@ -542,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform):
542543

543544
def __init__(
544545
self,
545-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
546+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
546547
side_range: Sequence[float] = (1.0, 4.0),
547548
p: float = 0.5,
548549
) -> None:
@@ -574,7 +575,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
574575
return dict(padding=padding)
575576

576577
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
577-
fill = self._fill[type(inpt)]
578+
fill = _get_fill(self._fill, type(inpt))
578579
return F.pad(inpt, **params, fill=fill)
579580

580581

@@ -620,7 +621,7 @@ def __init__(
620621
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
621622
expand: bool = False,
622623
center: Optional[List[float]] = None,
623-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
624+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
624625
) -> None:
625626
super().__init__()
626627
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
@@ -640,7 +641,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
640641
return dict(angle=angle)
641642

642643
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
643-
fill = self._fill[type(inpt)]
644+
fill = _get_fill(self._fill, type(inpt))
644645
return F.rotate(
645646
inpt,
646647
**params,
@@ -702,7 +703,7 @@ def __init__(
702703
scale: Optional[Sequence[float]] = None,
703704
shear: Optional[Union[int, float, Sequence[float]]] = None,
704705
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
705-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
706+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
706707
center: Optional[List[float]] = None,
707708
) -> None:
708709
super().__init__()
@@ -762,7 +763,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
762763
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
763764

764765
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
765-
fill = self._fill[type(inpt)]
766+
fill = _get_fill(self._fill, type(inpt))
766767
return F.affine(
767768
inpt,
768769
**params,
@@ -840,7 +841,7 @@ def __init__(
840841
size: Union[int, Sequence[int]],
841842
padding: Optional[Union[int, Sequence[int]]] = None,
842843
pad_if_needed: bool = False,
843-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
844+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
844845
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
845846
) -> None:
846847
super().__init__()
@@ -918,7 +919,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
918919

919920
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
920921
if params["needs_pad"]:
921-
fill = self._fill[type(inpt)]
922+
fill = _get_fill(self._fill, type(inpt))
922923
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
923924

924925
if params["needs_crop"]:
@@ -959,7 +960,7 @@ def __init__(
959960
distortion_scale: float = 0.5,
960961
p: float = 0.5,
961962
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
962-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
963+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
963964
) -> None:
964965
super().__init__(p=p)
965966

@@ -1002,7 +1003,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
10021003
return dict(coefficients=perspective_coeffs)
10031004

10041005
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
1005-
fill = self._fill[type(inpt)]
1006+
fill = _get_fill(self._fill, type(inpt))
10061007
return F.perspective(
10071008
inpt,
10081009
None,
@@ -1061,7 +1062,7 @@ def __init__(
10611062
alpha: Union[float, Sequence[float]] = 50.0,
10621063
sigma: Union[float, Sequence[float]] = 5.0,
10631064
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
1064-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
1065+
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
10651066
) -> None:
10661067
super().__init__()
10671068
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
@@ -1095,7 +1096,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
10951096
return dict(displacement=displacement)
10961097

10971098
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
1098-
fill = self._fill[type(inpt)]
1099+
fill = _get_fill(self._fill, type(inpt))
10991100
return F.elastic(
11001101
inpt,
11011102
**params,

0 commit comments

Comments
 (0)