Skip to content

Commit

Permalink
updated docstring, added checks for valid input
Browse files Browse the repository at this point in the history
  • Loading branch information
bmitzkus committed Oct 30, 2019
1 parent ab4159b commit 4acffd8
Showing 1 changed file with 42 additions and 17 deletions.
59 changes: 42 additions & 17 deletions imagecorruptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,56 @@
corruption_tuple}


def corrupt(x, severity=1, corruption_name=None, corruption_number=-1):
"""
:param x: image to corrupt; a 224x224x3 numpy array in [0, 255]
:param severity: strength with which to corrupt x; an integer in [0, 5]
:param corruption_name: specifies which corruption function to call;
must be one of 'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression',
'speckle_noise', 'gaussian_blur', 'spatter', 'saturate';
the last four are validation functions
:param corruption_number: the position of the corruption_name in the above list;
an integer in [0, 18]; useful for easy looping; 15, 16, 17, 18 are validation corruption numbers
:return: the image x corrupted by a corruption function at the given severity; same shape as input
def corrupt(image, severity=1, corruption_name=None, corruption_number=-1):
"""This function returns a corrupted version of the given image.
Args:
image (numpy.ndarray): image to corrupt; a numpy array in [0, 255], expected datatype is np.uint8
expected shape is either (height x width x channels) or (height x width), channels must be 1 or 3
severity (int): strength with which to corrupt the image; an integer in [1, 5]
corruption_name (str): specifies which corruption function to call, must be one of
'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression',
'speckle_noise', 'gaussian_blur', 'spatter', 'saturate';
the last four are validation corruptions
corruption_number (int): the position of the corruption_name in the above list; an integer in [0, 18];
useful for easy looping; 15, 16, 17, 18 are validation corruption numbers
Returns:
numpy.ndarray: the image corrupted by a corruption function at the given severity; same shape as input
"""

if corruption_name:
x_corrupted = corruption_dict[corruption_name](Image.fromarray(x),
if not isinstance(image, np.ndarray):
raise AttributeError('Expecting type(image) to be numpy.ndarray')
if not (image.dtype.type is np.uint8):
raise AttributeError('Expecting image.dtype.type to be numpy.uint8')

if not (image.ndim in [2,3]):
raise AttributeError('Expecting image.shape to be either (width x height) or (width x height x channels)')
if image.ndim == 2:
image = np.stack((image,)*3, axis=-1)

height, width, channels = image.shape

if not (channels in [1,3]):
raise AttributeError('Expecting image to have either 1 or 3 channels (last dimension)')

if channels == 1:
image = np.stack((np.squeeze(image),)*3, axis=-1)

if not severity in [1,2,3,4,5]:
raise AttributeError('Severity must be an integer in [1, 5]')

if not (corruption_name is None):
image_corrupted = corruption_dict[corruption_name](Image.fromarray(image),
severity)
elif corruption_number != -1:
x_corrupted = corruption_tuple[corruption_number](Image.fromarray(x),
image_corrupted = corruption_tuple[corruption_number](Image.fromarray(image),
severity)
else:
raise ValueError("Either corruption_name or corruption_number must be passed")

return np.uint8(x_corrupted)
return np.uint8(image_corrupted)

def get_corruption_names(subset='common'):
if subset == 'common':
Expand Down

0 comments on commit 4acffd8

Please sign in to comment.