Skip to content

Commit

Permalink
use fake self namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jan 20, 2023
1 parent 35860ff commit 21946f7
Showing 1 changed file with 15 additions and 36 deletions.
51 changes: 15 additions & 36 deletions torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import numbers
import warnings
from types import SimpleNamespace
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import PIL.Image
Expand All @@ -12,7 +13,7 @@
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform

from ._transform import _RandomApplyTransform
from .utils import get_dimensions, has_any, is_simple_tensor, query_chw, query_spatial_size
from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size


class RandomErasing(_RandomApplyTransform):
Expand Down Expand Up @@ -53,24 +54,19 @@ def __init__(

self._log_ratio = torch.log(torch.tensor(self.ratio))

@staticmethod
def _get_params_internal(
img_c: int,
img_h: int,
img_w: int,
scale: Tuple[float, float],
log_ratio: torch.Tensor,
value: Optional[List[float]] = None,
) -> Tuple[int, int, int, int, Optional[torch.Tensor]]:
if value is not None and not (len(value) in (1, img_c)):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(flat_inputs)

if self.value is not None and not (len(self.value) in (1, img_c)):
raise ValueError(
f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
)

area = img_h * img_w

log_ratio = self._log_ratio
for _ in range(10):
erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
Expand All @@ -83,34 +79,18 @@ def _get_params_internal(
if not (h < img_h and w < img_w):
continue

if value is None:
if self.value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(value)[:, None, None]
v = torch.tensor(self.value)[:, None, None]

i = int(torch.randint(0, img_h - h + 1, size=(1,)))
j = int(torch.randint(0, img_w - w + 1, size=(1,)))
break
else:
i, j, h, w, v = 0, 0, img_h, img_w, None

return i, j, h, w, v

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(flat_inputs)
return dict(
zip(
"ijhwv",
self._get_params_internal(
img_c,
img_h,
img_w,
self.scale,
self._log_ratio,
self.value, # type: ignore[arg-type]
),
)
)
return dict(i=i, j=j, h=h, w=w, v=v)

@staticmethod
def get_params(
Expand All @@ -119,13 +99,12 @@ def get_params(
ratio: Tuple[float, float],
value: Optional[List[float]] = None,
) -> Tuple[int, int, int, int, torch.Tensor]:
img_c, img_h, img_w = get_dimensions(image)
i, j, h, w, v = RandomErasing._get_params_internal(
img_c, img_h, img_w, scale, torch.log(torch.tensor(ratio)), value
)
self = SimpleNamespace(scale=scale, _log_ratio=torch.log(torch.tensor(ratio)), value=value)
params = RandomErasing._get_params(self, [image]) # type: ignore[arg-type]
v = params["v"]
if v is None:
v = image
return i, j, h, w, v
return params["i"], params["j"], params["h"], params["w"], v

def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
Expand Down

0 comments on commit 21946f7

Please sign in to comment.