Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GaussianNoise #8381

Merged
merged 6 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ Color
v2.RGB
v2.RandomGrayscale
v2.GaussianBlur
v2.GaussianNoise
v2.RandomInvert
v2.RandomPosterize
v2.RandomSolarize
Expand All @@ -368,6 +369,7 @@ Functionals
v2.functional.grayscale_to_rgb
v2.functional.to_grayscale
v2.functional.gaussian_blur
v2.functional.gaussian_noise
v2.functional.invert
v2.functional.posterize
v2.functional.solarize
Expand Down
78 changes: 76 additions & 2 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@ def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs):

input = input.as_subclass(torch.Tensor)
with ignore_jit_no_profile_information_warning():
actual = kernel_scripted(input, *args, **kwargs)
expected = kernel(input, *args, **kwargs)
with freeze_rng_state():
actual = kernel_scripted(input, *args, **kwargs)
with freeze_rng_state():
expected = kernel(input, *args, **kwargs)

assert_close(actual, expected, rtol=rtol, atol=atol)

Expand Down Expand Up @@ -3238,6 +3240,78 @@ def test_functional_image_correctness(self, dimensions, kernel_size, sigma, dtyp
torch.testing.assert_close(actual, expected, rtol=0, atol=1)


class TestGaussianNoise:
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_video],
)
def test_kernel(self, make_input):
check_kernel(
F.gaussian_noise,
make_input(dtype=torch.float32),
# This cannot pass because the noise on a batch in not per-image
check_batched_vs_unbatched=False,
)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_video],
)
def test_functional(self, make_input):
check_functional(F.gaussian_noise, make_input(dtype=torch.float32))

@pytest.mark.parametrize(
("kernel", "input_type"),
[
(F.gaussian_noise, torch.Tensor),
(F.gaussian_noise_image, tv_tensors.Image),
(F.gaussian_noise_video, tv_tensors.Video),
],
)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.gaussian_noise, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_video],
)
def test_transform(self, make_input):
def adapter(_, input, __):
# This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32
# Same for PIL images
for key, value in input.items():
if isinstance(value, torch.Tensor) and not value.is_floating_point():
input[key] = value.to(torch.float32)
if isinstance(value, PIL.Image.Image):
input[key] = F.pil_to_tensor(value).to(torch.float32)
return input

check_transform(transforms.GaussianNoise(), make_input(dtype=torch.float32), check_sample_input=adapter)

def test_bad_input(self):
with pytest.raises(ValueError, match="Gaussian Noise is not implemented for PIL images."):
F.gaussian_noise(make_image_pil())
with pytest.raises(ValueError, match="Input tensor is expected to be in float dtype"):
F.gaussian_noise(make_image(dtype=torch.uint8))
with pytest.raises(ValueError, match="sigma shouldn't be negative"):
F.gaussian_noise(make_image(dtype=torch.float32), sigma=-1)

def test_clip(self):
img = make_image(dtype=torch.float32)

out = F.gaussian_noise(img, mean=100, clip=False)
assert out.min() > 50

out = F.gaussian_noise(img, mean=100, clip=True)
assert (out == 1).all()

out = F.gaussian_noise(img, mean=-100, clip=False)
assert out.min() < -50

out = F.gaussian_noise(img, mean=-100, clip=True)
assert (out == 0).all()


class TestAutoAugmentTransforms:
# These transforms have a lot of branches in their `forward()` passes which are conditioned on random sampling.
# It's typically very hard to test the effect on some parameters without heavy mocking logic.
Expand Down
1 change: 1 addition & 0 deletions torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from ._misc import (
ConvertImageDtype,
GaussianBlur,
GaussianNoise,
Identity,
Lambda,
LinearTransformation,
Expand Down
25 changes: 25 additions & 0 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,31 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params)


class GaussianNoise(Transform):
"""Add gaussian noise to the image.

The input tensor is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.

The input tensor is also expected to be of float dtype in ``[0, 1]``.
This transform does not support PIL images.

Args:
mean (float): Mean of the sampled normal distribution. Default is 0.
sigma (float): Standard deviation of the sampled normal distribution. Default is 0.1.
clip (bool, optional): Whether to clip the values in ``[0, 1]`` after adding noise. Default is True.
"""

def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think sigma (σ) should instead be called var for "variance". For example, mean is not called mu (μ).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fair @Richienb , I can't claim mean and sigma are the best name combination, but unfortunately for consistency this is probably what we'll need to use, as these are the names that are already used by GaussianBlur.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, let's save this for a future breaking change.

super().__init__()
self.mean = mean
self.sigma = sigma
self.clip = clip

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, sigma=self.sigma, clip=self.clip)


class ToDtype(Transform):
"""Converts the input to a specific dtype, optionally scaling the values for images or videos.

Expand Down
3 changes: 3 additions & 0 deletions torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@
gaussian_blur,
gaussian_blur_image,
gaussian_blur_video,
gaussian_noise,
gaussian_noise_image,
gaussian_noise_video,
normalize,
normalize_image,
normalize_video,
Expand Down
39 changes: 39 additions & 0 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch.nn.functional import conv2d, pad as torch_pad

from torchvision import tv_tensors
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _max_value
from torchvision.transforms.functional import pil_to_tensor, to_pil_image

Expand Down Expand Up @@ -181,6 +182,44 @@ def gaussian_blur_video(
return gaussian_blur_image(video, kernel_size, sigma)


def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
"""See :class:`~torchvision.transforms.v2.GaussianNoise`"""
if torch.jit.is_scripting():
return gaussian_noise_image(inpt, mean=mean, sigma=sigma)

_log_api_usage_once(gaussian_noise)

kernel = _get_kernel(gaussian_noise, type(inpt))
return kernel(inpt, mean=mean, sigma=sigma, clip=clip)


@_register_kernel_internal(gaussian_noise, torch.Tensor)
@_register_kernel_internal(gaussian_noise, tv_tensors.Image)
def gaussian_noise_image(image: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
if not image.is_floating_point():
raise ValueError(f"Input tensor is expected to be in float dtype, got dtype={image.dtype}")
if sigma < 0:
raise ValueError(f"sigma shouldn't be negative. Got {sigma}")

noise = mean + torch.randn_like(image) * sigma
out = image + noise
if clip:
out = torch.clamp(out, 0, 1)
return out


@_register_kernel_internal(gaussian_noise, tv_tensors.Video)
def gaussian_noise_video(video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
return gaussian_noise_image(video, mean=mean, sigma=sigma, clip=clip)


@_register_kernel_internal(gaussian_noise, PIL.Image.Image)
def _gaussian_noise_pil(
video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True
) -> PIL.Image.Image:
raise ValueError("Gaussian Noise is not implemented for PIL images.")


def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
"""See :func:`~torchvision.transforms.v2.ToDtype` for details."""
if torch.jit.is_scripting():
Expand Down
Loading