From 4e61ff581adccbcc0acfe00b40d4e4651b44002b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 15 Feb 2023 15:39:29 +0100 Subject: [PATCH 1/2] make fill defaultdict an implementation detail --- torchvision/prototype/transforms/_geometry.py | 48 ++++++++++--------- .../prototype/transforms/_transform.py | 14 +----- 2 files changed, 27 insertions(+), 35 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 69238760be5..513d7987412 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -256,9 +256,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: params = super()._extract_params_for_v1_transform() if not (params["fill"] is None or isinstance(params["fill"], (int, float))): - raise ValueError( - f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images." - ) + raise ValueError(f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill}.") return params @@ -277,11 +275,12 @@ def __init__( if not isinstance(padding, int): padding = list(padding) self.padding = padding - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) self.padding_mode = padding_mode def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] @@ -294,7 +293,8 @@ def __init__( ) -> None: super().__init__(p=p) - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) _check_sequence_input(side_range, "side_range", req_sizes=(2,)) @@ -319,7 +319,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(padding=padding) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] return F.pad(inpt, **params, fill=fill) @@ -339,7 +339,8 @@ def __init__( self.interpolation = _check_interpolation(interpolation) self.expand = expand - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) if center is not None: _check_sequence_input(center, "center", req_sizes=(2,)) @@ -351,7 +352,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(angle=angle) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] return F.rotate( inpt, **params, @@ -396,7 +397,8 @@ def __init__( self.shear = shear self.interpolation = _check_interpolation(interpolation) - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) if center is not None: _check_sequence_input(center, "center", req_sizes=(2,)) @@ -431,7 +433,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(angle=angle, translate=translate, scale=scale, shear=shear) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] return F.affine( inpt, **params, @@ -448,9 +450,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: params = super()._extract_params_for_v1_transform() if not (params["fill"] is None or isinstance(params["fill"], (int, float))): - raise ValueError( - f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images." - ) + raise ValueError(f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill}.") padding = self.padding if padding is not None: @@ -479,7 +479,8 @@ def __init__( self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type] self.pad_if_needed = pad_if_needed - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) self.padding_mode = padding_mode def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: @@ -542,7 +543,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_pad"]: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) if params["needs_crop"]: @@ -568,7 +569,8 @@ def __init__( self.distortion_scale = distortion_scale self.interpolation = _check_interpolation(interpolation) - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_spatial_size(flat_inputs) @@ -601,7 +603,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(coefficients=perspective_coeffs) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] return F.perspective( inpt, None, @@ -627,7 +629,8 @@ def __init__( self.sigma = _setup_float_or_seq(sigma, "sigma", 2) self.interpolation = _check_interpolation(interpolation) - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: size = list(query_spatial_size(flat_inputs)) @@ -653,7 +656,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(displacement=displacement) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] return F.elastic( inpt, **params, @@ -838,7 +841,8 @@ def __init__( self.crop_height = size[0] self.crop_width = size[1] - self.fill = _setup_fill_arg(fill) + self.fill = fill + self._fill = _setup_fill_arg(fill) self.padding_mode = padding_mode @@ -936,7 +940,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: ) if params["needs_pad"]: - fill = self.fill[type(inpt)] + fill = self._fill[type(inpt)] inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) return inpt diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 7f3c03d5e67..e96799f18af 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -114,24 +114,12 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: # Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen # if the v2 transform introduced new parameters that are not support by the v1 transform. common_attrs = nn.Module().__dict__.keys() - params = { + return { attr: value for attr, value in self.__dict__.items() if not attr.startswith("_") and attr not in common_attrs } - # transforms v2 has a more complex handling for the `fill` parameter than v1. By default, the input is parsed - # with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value - # for the different datapoint types. Below we extract the value for tensors and return that together with the - # other params. - # This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and - # `RandomRotation` - if "fill" in params: - fill_type_defaultdict = params.pop("fill") - params["fill"] = fill_type_defaultdict[torch.Tensor] - - return params - def __prepare_scriptable__(self) -> nn.Module: # This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return # value is used for scripting over the original object that should have been scripted. Since the v1 transforms From c27381e86ff6aeb59926c8ceae628cb4589a181a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 15 Feb 2023 15:54:37 +0100 Subject: [PATCH 2/2] cleanup --- torchvision/prototype/transforms/_geometry.py | 4 ++-- torchvision/prototype/transforms/_transform.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 513d7987412..48c737c979e 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -256,7 +256,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: params = super()._extract_params_for_v1_transform() if not (params["fill"] is None or isinstance(params["fill"], (int, float))): - raise ValueError(f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill}.") + raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.") return params @@ -450,7 +450,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: params = super()._extract_params_for_v1_transform() if not (params["fill"] is None or isinstance(params["fill"], (int, float))): - raise ValueError(f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill}.") + raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.") padding = self.padding if padding is not None: diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index e96799f18af..277241a0085 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -108,9 +108,8 @@ def __init_subclass__(cls) -> None: def _extract_params_for_v1_transform(self) -> Dict[str, Any]: # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current - # v2 transform instance. It does two things: - # 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general - # 2. If available handle the `fill` attribute for v1 compatibility (see below for details) + # v2 transform instance. It extracts all available public attributes that are specific to that transform and + # not `nn.Module` in general. # Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen # if the v2 transform introduced new parameters that are not support by the v1 transform. common_attrs = nn.Module().__dict__.keys()