[proto] Fixed RandAug and all AA consistency tests #6519
Changes from all commits: 43ac6db, ad0a553, 9732814, 768928c, 45102d2
```diff
@@ -69,27 +69,46 @@ def _apply_image_transform(
         interpolation: InterpolationMode,
         fill: Union[int, float, Sequence[int], Sequence[float]],
     ) -> Any:
+        # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
+        # So, we have to put fill as None if fill == 0
+        fill_: Optional[Union[int, float, Sequence[int], Sequence[float]]]
+        if isinstance(fill, int) and fill == 0:
+            fill_ = None
+        else:
+            fill_ = fill
+
         if transform_id == "Identity":
             return image
         elif transform_id == "ShearX":
+            # magnitude should be arctan(magnitude)
+            # official autoaug: (1, level, 0, 0, 1, 0)
+            # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
+            # compared to
+            # torchvision: (1, tan(level), 0, 0, 1, 0)
+            # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
             return F.affine(
                 image,
                 angle=0.0,
                 translate=[0, 0],
                 scale=1.0,
-                shear=[math.degrees(magnitude), 0.0],
+                shear=[math.degrees(math.atan(magnitude)), 0.0],
                 interpolation=interpolation,
-                fill=fill,
+                fill=fill_,
                 center=[0, 0],
             )
         elif transform_id == "ShearY":
+            # magnitude should be arctan(magnitude)
+            # See above
             return F.affine(
                 image,
                 angle=0.0,
                 translate=[0, 0],
                 scale=1.0,
-                shear=[0.0, math.degrees(magnitude)],
+                shear=[0.0, math.degrees(math.atan(magnitude))],
                 interpolation=interpolation,
-                fill=fill,
+                fill=fill_,
                 center=[0, 0],
             )
         elif transform_id == "TranslateX":
             return F.affine(
```
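A quick standalone sanity check (not part of the diff) of the arctan conversion introduced above: torchvision builds its shear coefficient as `tan(shear)` from the angle it receives in degrees, so passing `math.degrees(math.atan(level))` recovers the reference AutoAugment matrix `(1, level, 0, 0, 1, 0)`:

```python
import math

level = 0.3  # arbitrary shear magnitude ("level") for illustration

# What the patched ShearX branch now passes to F.affine's `shear` argument.
shear_deg = math.degrees(math.atan(level))

# torchvision converts the angle back into a matrix coefficient via tan(),
# so the effective coefficient equals the raw level, as in the official
# AutoAugment implementation.
coeff = math.tan(math.radians(shear_deg))
assert abs(coeff - level) < 1e-12
```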
```diff
@@ -99,7 +118,7 @@ def _apply_image_transform(
                 scale=1.0,
                 shear=[0.0, 0.0],
                 interpolation=interpolation,
-                fill=fill,
+                fill=fill_,
             )
         elif transform_id == "TranslateY":
             return F.affine(
```
```diff
@@ -109,10 +128,10 @@ def _apply_image_transform(
                 scale=1.0,
                 shear=[0.0, 0.0],
                 interpolation=interpolation,
-                fill=fill,
+                fill=fill_,
             )
         elif transform_id == "Rotate":
-            return F.rotate(image, angle=magnitude)
+            return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
         elif transform_id == "Brightness":
             return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
         elif transform_id == "Color":
```
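For readability, here is the `fill` normalization from the first hunk as a standalone sketch; the helper name is hypothetical and not part of the PR, which keeps the logic inline:

```python
from typing import Optional, Sequence, Union

FillType = Union[int, float, Sequence[int], Sequence[float]]


def normalize_fill(fill: FillType) -> Optional[FillType]:
    """Hypothetical helper mirroring the inline logic added in the diff."""
    # An explicit integer 0 is not equivalent to None (pytorch/vision#6517),
    # so it is mapped to None before calling the mid-level kernels.
    if isinstance(fill, int) and fill == 0:
        return None
    return fill
```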
```diff
@@ -340,19 +359,17 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         id, image = self._extract_image(sample)
-        num_channels, height, width = get_chw(image)
+        _, height, width = get_chw(image)

         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

             magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
             if magnitudes is not None:
-                magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
+                magnitude = float(magnitudes[self.magnitude])
                 if signed and torch.rand(()) <= 0.5:
                     magnitude *= -1
             else:
                 magnitude = 0.0

             image = self._apply_image_transform(
                 image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )
```

Review comment (on the `_, height, width = get_chw(image)` change): Good catch, but weird that …
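The forward hunk above changes how RandAugment picks a magnitude. Before the fix a random bin was drawn for every op; after the fix the user-selected bin is always used and only the sign stays random, matching the stable implementation. A minimal sketch of the difference (bin values are illustrative only):

```python
import torch

num_magnitude_bins = 31
magnitudes = torch.linspace(0.0, 0.99, num_magnitude_bins)  # illustrative bins

# Old (buggy) behaviour: a fresh random bin per op, TrivialAugment-style.
old_magnitude = float(magnitudes[int(torch.randint(num_magnitude_bins, ()))])

# New behaviour: always use the configured bin (RandAugment's `magnitude`
# argument); only the sign of the value is still random.
magnitude_bin = 9
new_magnitude = float(magnitudes[magnitude_bin])
if torch.rand(()) <= 0.5:
    new_magnitude *= -1
```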
```diff
@@ -397,7 +414,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         id, image = self._extract_image(sample)
-        num_channels, height, width = get_chw(image)
+        _, height, width = get_chw(image)

         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
```
```diff
@@ -467,7 +484,7 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
         id, orig_image = self._extract_image(sample)
-        num_channels, height, width = get_chw(orig_image)
+        _, height, width = get_chw(orig_image)

         if isinstance(orig_image, torch.Tensor):
             image = orig_image
```
```diff
@@ -379,8 +379,12 @@ def affine_segmentation_mask(


 def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> Optional[List[float]]:
+    # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
+    # So, we can't reassign fill to 0
+    # if fill is None:
+    #     fill = 0
     if fill is None:
-        fill = 0
+        return fill

     # This cast does Sequence -> List[float] to please mypy and torch.jit.script
     if not isinstance(fill, (int, float)):
```

Review comment: Don't we need to change this in the stable transforms as well? I would feel a lot more confident in this patch if the stable CI confirms this.

Review reply: Stable is using …
```diff
@@ -549,8 +549,8 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L…

     # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
     if fill is not None:
-        dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
-        img = torch.cat((img, dummy), dim=1)
+        mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
+        img = torch.cat((img, mask), dim=1)

     img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
```

Review comment (on the `mask` line): Thanks, …
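The hunk above only renames `dummy` to `mask`, but the trick it touches is worth spelling out: warping an all-ones channel together with the image lets a single `grid_sample` call report which output pixels fell outside the source, so a custom fill colour can be blended in afterwards. A rough standalone sketch of that idea (not the library code):

```python
import torch
from torch.nn.functional import grid_sample


def warp_with_fill(img: torch.Tensor, grid: torch.Tensor, fill: list) -> torch.Tensor:
    # img: (N, C, H, W) float tensor; grid: (N, H_out, W_out, 2) sampling grid.
    # Append an all-ones "mask" channel so one grid_sample pass also tells us
    # where the output sampled outside the input (mask drops below 1 there).
    mask = torch.ones_like(img[:, :1])
    out = grid_sample(
        torch.cat([img, mask], dim=1), grid,
        mode="bilinear", padding_mode="zeros", align_corners=False,
    )
    out, mask = out[:, :-1], out[:, -1:]

    # Blend the requested fill colour into the out-of-bounds regions.
    fill_img = torch.tensor(fill, dtype=out.dtype, device=out.device).view(1, -1, 1, 1)
    return out * mask + fill_img * (1.0 - mask)
```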
Review comment: Shouldn't we use `assert_equal`, i.e. `assert_close(..., rtol=0, atol=0)`, for the consistency checks? Any deviation means there is a difference that we don't want. Since we only compare tensor to tensor and PIL to PIL, the usual tolerances are probably not needed here. In the framework I've added for consistency tests, I also added an assertion helper (vision/test/test_prototype_transforms_consistency.py, line 28 in 0563b18) that handles PIL and tensor images simultaneously.
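For reference, a small illustration of the equivalence the comment alludes to, using `torch.testing`: zero tolerances make `assert_close` demand exact equality, which is what tensor-vs-tensor and PIL-vs-PIL consistency checks want.

```python
import torch
from torch.testing import assert_close

expected = torch.tensor([1.0, 2.0, 3.0])
actual = expected.clone()

# With rtol=0 and atol=0, assert_close only passes on exact equality,
# i.e. it behaves like an assert_equal.
assert_close(actual, expected, rtol=0, atol=0)
```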