Skip to content

Commit

Permalink
make fill defaultdict an implementation detail (#7258)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Feb 16, 2023
1 parent b7892d3 commit efd6bc0
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 38 deletions.
5 changes: 3 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def __init__(
self.crop_height = size[0]
self.crop_width = size[1]

self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)

self.padding_mode = padding_mode

Expand Down Expand Up @@ -118,7 +119,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
)

if params["needs_pad"]:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)

return inpt
43 changes: 23 additions & 20 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
params = super()._extract_params_for_v1_transform()

if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError(
f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
)
raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")

return params

Expand All @@ -276,11 +274,12 @@ def __init__(
if not isinstance(padding, int):
padding = list(padding)
self.padding = padding
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]


Expand All @@ -293,7 +292,8 @@ def __init__(
) -> None:
super().__init__(p=p)

self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)

_check_sequence_input(side_range, "side_range", req_sizes=(2,))

Expand All @@ -318,7 +318,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(padding=padding)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
return F.pad(inpt, **params, fill=fill)


Expand All @@ -338,7 +338,8 @@ def __init__(
self.interpolation = _check_interpolation(interpolation)
self.expand = expand

self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)

if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
Expand All @@ -350,7 +351,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(angle=angle)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
return F.rotate(
inpt,
**params,
Expand Down Expand Up @@ -395,7 +396,8 @@ def __init__(
self.shear = shear

self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)

if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
Expand Down Expand Up @@ -430,7 +432,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(angle=angle, translate=translate, scale=scale, shear=shear)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
return F.affine(
inpt,
**params,
Expand All @@ -447,9 +449,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
params = super()._extract_params_for_v1_transform()

if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError(
f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
)
raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")

padding = self.padding
if padding is not None:
Expand Down Expand Up @@ -478,7 +478,8 @@ def __init__(

self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type]
self.pad_if_needed = pad_if_needed
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -541,7 +542,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)

if params["needs_crop"]:
Expand All @@ -567,7 +568,8 @@ def __init__(

self.distortion_scale = distortion_scale
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
Expand Down Expand Up @@ -600,7 +602,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(coefficients=perspective_coeffs)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
return F.perspective(
inpt,
None,
Expand All @@ -626,7 +628,8 @@ def __init__(
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)

self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
size = list(query_spatial_size(flat_inputs))
Expand All @@ -652,7 +655,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(displacement=displacement)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
return F.elastic(
inpt,
**params,
Expand Down
19 changes: 3 additions & 16 deletions torchvision/transforms/v2/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,30 +108,17 @@ def __init_subclass__(cls) -> None:

def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
# v2 transform instance. It does two things:
# 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general
# 2. If available handle the `fill` attribute for v1 compatibility (see below for details)
# v2 transform instance. It extracts all available public attributes that are specific to that transform and
# not `nn.Module` in general.
# Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
# if the v2 transform introduced new parameters that are not supported by the v1 transform.
common_attrs = nn.Module().__dict__.keys()
params = {
return {
attr: value
for attr, value in self.__dict__.items()
if not attr.startswith("_") and attr not in common_attrs
}

# transforms v2 has a more complex handling for the `fill` parameter than v1. By default, the input is parsed
# with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value
# for the different datapoint types. Below we extract the value for tensors and return that together with the
# other params.
# This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and
# `RandomRotation`
if "fill" in params:
fill_type_defaultdict = params.pop("fill")
params["fill"] = fill_type_defaultdict[torch.Tensor]

return params

def __prepare_scriptable__(self) -> nn.Module:
# This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return
# value is used for scripting over the original object that should have been scripted. Since the v1 transforms
Expand Down

0 comments on commit efd6bc0

Please sign in to comment.