Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gaussian noise transform #6192 #6233

Closed
Closed
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
9be5e13
adds gaussian noise transform
parth-shastri Jul 3, 2022
42952f6
adds gaussian noise transform
parth-shastri Jul 3, 2022
db9756a
Update torchvision/transforms/transforms.py
parth-shastri Jul 4, 2022
05a52af
Update torchvision/transforms/transforms.py
parth-shastri Jul 4, 2022
c380281
Delete _C.pyd
parth-shastri Jul 4, 2022
431a7e0
added GaussianNoise transform
parth-shastri Jul 4, 2022
5fc7c85
adds the GaussianNoise transform
parth-shastri Jul 4, 2022
179908b
adds GaussianNoise transform
parth-shastri Jul 4, 2022
396abba
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Jul 6, 2022
5d8b0f7
fixes on the lint tests and the plot_transforms
parth-shastri Jul 12, 2022
98e4e98
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Jul 12, 2022
aa9b2e7
test
parth-shastri Jul 13, 2022
223074f
test
parth-shastri Jul 13, 2022
a26ed67
Merge branch 'add-gaussian-noise-transform' of https://github.com/par…
parth-shastri Jul 13, 2022
6d57443
fixes the plot_transforms bug
parth-shastri Jul 13, 2022
1d1fbcd
Update torchvision/transforms/transforms.py
parth-shastri Jul 27, 2022
ff80571
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Jul 27, 2022
533e76f
Update gallery/plot_transforms.py
parth-shastri Jul 27, 2022
b8d98d7
Update test_transforms.py
parth-shastri Jul 27, 2022
35ac3c9
update
parth-shastri Jul 27, 2022
42d49ae
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Jul 27, 2022
74f92b1
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Jul 31, 2022
54234f9
Update test_transforms.py
parth-shastri Jul 31, 2022
28fbd4b
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Aug 3, 2022
6a95453
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Aug 10, 2022
24804f6
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Aug 12, 2022
b1bb81f
lint
parth-shastri Aug 12, 2022
e08c9c1
lint updated
parth-shastri Aug 12, 2022
7fc04aa
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Sep 7, 2022
8a500fe
adds functional transforms, fixed sigma
parth-shastri Sep 7, 2022
8b560fb
updated lint, adds functional transforms, fixed sigma
parth-shastri Sep 7, 2022
ac15585
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Oct 5, 2022
3a85c34
Update torchvision/transforms/functional_tensor.py
parth-shastri Oct 5, 2022
5892695
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Oct 6, 2022
e6b4e45
suggested changes
parth-shastri Oct 6, 2022
58d525d
update
parth-shastri Oct 6, 2022
021ecba
fixed docs
parth-shastri Oct 6, 2022
5956088
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Oct 6, 2022
92d024f
Merge branch 'main' into add-gaussian-noise-transform
datumbox Oct 27, 2022
bfde863
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Nov 9, 2022
dbc3e1a
Merge branch 'main' into add-gaussian-noise-transform
datumbox Nov 10, 2022
d18195b
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Nov 27, 2022
a8d8137
fixes for random calls in functional transforms
parth-shastri Nov 27, 2022
2f7f558
Merge branch 'main' into add-gaussian-noise-transform
parth-shastri Feb 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ Transforms on PIL Image and torch.\*Tensor
Resize
TenCrop
GaussianBlur
GaussianNoise
RandomInvert
RandomPosterize
RandomSolarize
Expand Down
9 changes: 9 additions & 0 deletions gallery/plot_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
blurred_imgs = [blurrer(orig_img) for _ in range(4)]
plot(blurred_imgs)

####################################
# GaussianNoise
# ~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.GaussianNoise` transform
# perturbs the input image with gaussian noise.
noisy = T.GaussianNoise(mean=0, sigma=(0.1, 2.0))
noisy_imgs = [noisy(orig_img) for _ in range(2)]
plot(noisy_imgs)

####################################
# RandomPerspective
# ~~~~~~~~~~~~~~~~~
Expand Down
19 changes: 19 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,25 @@ def test_gaussian_blur_asserts():
transforms.GaussianBlur(3, "sigma_string")


def test_gaussian_noise():
np_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
img = F.to_pil_image(np_img, "RGB")
out = transforms.GaussianNoise(2.0, (0.1, 2.0))(img)
assert isinstance(out, PIL.Image.Image)
parth-shastri marked this conversation as resolved.
Show resolved Hide resolved

with pytest.raises(TypeError, match="Tensor is not a torch image"):
out = transforms.GaussianNoise(2.0, (0.1, 2.0))(torch.ones((4)))

with pytest.raises(ValueError, match="Mean should be a positive number"):
transforms.GaussianNoise(-1)

with pytest.raises(ValueError, match="If sigma is a single number, it must be positive."):
transforms.GaussianNoise(2.0, -1)

with pytest.raises(ValueError, match="sigma should be a single number or a list/tuple with length 2."):
transforms.GaussianNoise(2.0, (1, 2, 3))


def test_lambda():
trans = transforms.Lambda(lambda x: x.add(10))
x = torch.randn(10)
Expand Down
83 changes: 83 additions & 0 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"RandomPerspective",
"RandomErasing",
"GaussianBlur",
"GaussianNoise",
parth-shastri marked this conversation as resolved.
Show resolved Hide resolved
"InterpolationMode",
"RandomInvert",
"RandomPosterize",
Expand Down Expand Up @@ -1837,6 +1838,88 @@ def __repr__(self) -> str:
return s


class GaussianNoise(torch.nn.Module):
"""Adds Gaussian noise to the image with specified mean and standard deviation.
If the image is torch Tensor, it is expected
to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.

Args:
mean (float or sequence): Mean of the sampling gaussian distribution .
sigma (float or tuple of float (min, max)): Standard deviation to be used for
sampling the gaussian noise. If float, sigma is fixed. If it is tuple
of float (min, max), sigma is chosen uniformly at random to lie in the
given range.

Returns:
PIL Image or Tensor: Input image perturbed with Gaussian Noise.

"""

def __init__(self, mean, sigma = (0.1, 2.0)):
super().__init__()
_log_api_usage_once(self)

if mean < 0:
raise ValueError("Mean should be a positive number")

if isinstance(sigma, numbers.Number):
if sigma <= 0:
raise ValueError("If sigma is a single number, it must be positive.")
sigma = (sigma, sigma)
elif isinstance(sigma, Sequence) and len(sigma) == 2:
if not 0.0 < sigma[0] <= sigma[1]:
raise ValueError("sigma values should be positive and of the form (min, max).")
else:
raise ValueError("sigma should be a single number or a list/tuple with length 2.")

self.mean = mean
self.sigma = sigma

@staticmethod
def get_params(sigma_min: float, sigma_max: float) -> float:
return torch.empty(1).uniform_(sigma_min, sigma_max).item()

def forward(self, image: Tensor) -> Tensor:
"""
Args:
image (PIL Image or Tensor): image to be perturbed with gaussian noise.

Returns:
PIL Image or Tensor: Image added with gaussian noise.
"""
sigma = self.get_params(self.sigma[0], self.sigma[1])
if not isinstance(image, torch.Tensor):
Copy link
Contributor

@oke-aditya oke-aditya Jul 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm bit hesitant about this implementation. Why not create F.gaussian_noise that would work well on PIL Images and tensors which we can then use here?

See https://github.com/pytorch/vision/blob/main/torchvision/transforms/functional.py#L1338

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, let's create a functional op F.gaussian_noise and at first iteration we can have tensor implementation. If input is PIL, we can convert it to tensor, apply the op and get back PIL image type for the output.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@parth-shastri please create a functional op F.gaussian_noise

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@parth-shastri please address this comment

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes working on that

parth-shastri marked this conversation as resolved.
Show resolved Hide resolved
if not F._is_pil_image(image):
raise TypeError(f"image should be PIL Image or Tensor. Got {type(image)}")

t_image = F.pil_to_tensor(image)
else:
t_image = image

if not t_image.ndim >= 2:
raise TypeError("Tensor is not a torch image.")

dtype = t_image.dtype

if not t_image.is_floating_point():
t_image = t_image.to(torch.float32)

gaussian_noise = sigma * torch.randn_like(t_image) + self.mean
output = t_image + gaussian_noise

if output.dtype != dtype:
output = output.to(dtype)

if not isinstance(image, torch.Tensor):
output = F.to_pil_image(output)

return output

def __repr__(self) -> str:
s = f"{self.__class__.__name__}(mean={self.mean}, sigma={self.sigma})"
return s


def _setup_size(size, error_msg):
if isinstance(size, numbers.Number):
return int(size), int(size)
Expand Down