Skip to content

Commit

Permalink
Add GaussianNoise transforms (#8381)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
  • Loading branch information
Richienb and NicolasHug authored May 31, 2024
1 parent b0f9f7b commit 6e18cea
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ Color
v2.RGB
v2.RandomGrayscale
v2.GaussianBlur
v2.GaussianNoise
v2.RandomInvert
v2.RandomPosterize
v2.RandomSolarize
Expand All @@ -368,6 +369,7 @@ Functionals
v2.functional.grayscale_to_rgb
v2.functional.to_grayscale
v2.functional.gaussian_blur
v2.functional.gaussian_noise
v2.functional.invert
v2.functional.posterize
v2.functional.solarize
Expand Down
78 changes: 76 additions & 2 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@ def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs):

input = input.as_subclass(torch.Tensor)
with ignore_jit_no_profile_information_warning():
actual = kernel_scripted(input, *args, **kwargs)
expected = kernel(input, *args, **kwargs)
with freeze_rng_state():
actual = kernel_scripted(input, *args, **kwargs)
with freeze_rng_state():
expected = kernel(input, *args, **kwargs)

assert_close(actual, expected, rtol=rtol, atol=atol)

Expand Down Expand Up @@ -3238,6 +3240,78 @@ def test_functional_image_correctness(self, dimensions, kernel_size, sigma, dtyp
torch.testing.assert_close(actual, expected, rtol=0, atol=1)


class TestGaussianNoise:
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_video],
)
def test_kernel(self, make_input):
check_kernel(
F.gaussian_noise,
make_input(dtype=torch.float32),
# This cannot pass because the noise on a batch in not per-image
check_batched_vs_unbatched=False,
)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_video],
)
def test_functional(self, make_input):
check_functional(F.gaussian_noise, make_input(dtype=torch.float32))

@pytest.mark.parametrize(
("kernel", "input_type"),
[
(F.gaussian_noise, torch.Tensor),
(F.gaussian_noise_image, tv_tensors.Image),
(F.gaussian_noise_video, tv_tensors.Video),
],
)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.gaussian_noise, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_video],
)
def test_transform(self, make_input):
def adapter(_, input, __):
# This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32
# Same for PIL images
for key, value in input.items():
if isinstance(value, torch.Tensor) and not value.is_floating_point():
input[key] = value.to(torch.float32)
if isinstance(value, PIL.Image.Image):
input[key] = F.pil_to_tensor(value).to(torch.float32)
return input

check_transform(transforms.GaussianNoise(), make_input(dtype=torch.float32), check_sample_input=adapter)

def test_bad_input(self):
with pytest.raises(ValueError, match="Gaussian Noise is not implemented for PIL images."):
F.gaussian_noise(make_image_pil())
with pytest.raises(ValueError, match="Input tensor is expected to be in float dtype"):
F.gaussian_noise(make_image(dtype=torch.uint8))
with pytest.raises(ValueError, match="sigma shouldn't be negative"):
F.gaussian_noise(make_image(dtype=torch.float32), sigma=-1)

def test_clip(self):
img = make_image(dtype=torch.float32)

out = F.gaussian_noise(img, mean=100, clip=False)
assert out.min() > 50

out = F.gaussian_noise(img, mean=100, clip=True)
assert (out == 1).all()

out = F.gaussian_noise(img, mean=-100, clip=False)
assert out.min() < -50

out = F.gaussian_noise(img, mean=-100, clip=True)
assert (out == 0).all()


class TestAutoAugmentTransforms:
# These transforms have a lot of branches in their `forward()` passes which are conditioned on random sampling.
# It's typically very hard to test the effect on some parameters without heavy mocking logic.
Expand Down
1 change: 1 addition & 0 deletions torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from ._misc import (
ConvertImageDtype,
GaussianBlur,
GaussianNoise,
Identity,
Lambda,
LinearTransformation,
Expand Down
27 changes: 27 additions & 0 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,33 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params)


class GaussianNoise(Transform):
"""Add gaussian noise to images or videos.
The input tensor is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
Each image or frame in a batch will be transformed independently i.e. the
noise added to each image will be different.
The input tensor is also expected to be of float dtype in ``[0, 1]``.
This transform does not support PIL images.
Args:
mean (float): Mean of the sampled normal distribution. Default is 0.
sigma (float): Standard deviation of the sampled normal distribution. Default is 0.1.
clip (bool, optional): Whether to clip the values in ``[0, 1]`` after adding noise. Default is True.
"""

def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None:
super().__init__()
self.mean = mean
self.sigma = sigma
self.clip = clip

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, sigma=self.sigma, clip=self.clip)


class ToDtype(Transform):
"""Converts the input to a specific dtype, optionally scaling the values for images or videos.
Expand Down
3 changes: 3 additions & 0 deletions torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@
gaussian_blur,
gaussian_blur_image,
gaussian_blur_video,
gaussian_noise,
gaussian_noise_image,
gaussian_noise_video,
normalize,
normalize_image,
normalize_video,
Expand Down
38 changes: 38 additions & 0 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,44 @@ def gaussian_blur_video(
return gaussian_blur_image(video, kernel_size, sigma)


def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
"""See :class:`~torchvision.transforms.v2.GaussianNoise`"""
if torch.jit.is_scripting():
return gaussian_noise_image(inpt, mean=mean, sigma=sigma)

_log_api_usage_once(gaussian_noise)

kernel = _get_kernel(gaussian_noise, type(inpt))
return kernel(inpt, mean=mean, sigma=sigma, clip=clip)


@_register_kernel_internal(gaussian_noise, torch.Tensor)
@_register_kernel_internal(gaussian_noise, tv_tensors.Image)
def gaussian_noise_image(image: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
if not image.is_floating_point():
raise ValueError(f"Input tensor is expected to be in float dtype, got dtype={image.dtype}")
if sigma < 0:
raise ValueError(f"sigma shouldn't be negative. Got {sigma}")

noise = mean + torch.randn_like(image) * sigma
out = image + noise
if clip:
out = torch.clamp(out, 0, 1)
return out


@_register_kernel_internal(gaussian_noise, tv_tensors.Video)
def gaussian_noise_video(video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
return gaussian_noise_image(video, mean=mean, sigma=sigma, clip=clip)


@_register_kernel_internal(gaussian_noise, PIL.Image.Image)
def _gaussian_noise_pil(
video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True
) -> PIL.Image.Image:
raise ValueError("Gaussian Noise is not implemented for PIL images.")


def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
"""See :func:`~torchvision.transforms.v2.ToDtype` for details."""
if torch.jit.is_scripting():
Expand Down

0 comments on commit 6e18cea

Please sign in to comment.