Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RandomBrightness preprocessing layer. #122

Closed
wants to merge 9 commits into from
50 changes: 50 additions & 0 deletions examples/layers/preprocessing/random_brightness_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""random_brightness_demo.py shows how to use the RandomBrightness preprocessing layer.

Operates on the oxford_flowers102 dataset. In this script the flowers
are loaded, then are passed through the preprocessing layers.
Finally, they are shown using matplotlib.
"""
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

from keras_cv.layers import preprocessing

IMG_SIZE = (224, 224)
BATCH_SIZE = 64


def resize(image, label, num_classes=10):
image = tf.image.resize(image, IMG_SIZE)
label = tf.one_hot(label, num_classes)
return image, label


def main():
data, ds_info = tfds.load("oxford_flowers102", with_info=True, as_supervised=True)
train_ds = data["train"]

num_classes = ds_info.features["label"].num_classes

train_ds = train_ds.map(lambda x, y: resize(x, y, num_classes=num_classes)).batch(
BATCH_SIZE
)
random_brightness = preprocessing.RandomBrightness(
scale=(-0.5, 0.5),
)
train_ds = train_ds.map(
lambda x, y: (random_brightness(x, training=True), y),
num_parallel_calls=tf.data.AUTOTUNE,
)

for images, labels in train_ds.take(1):
plt.figure(figsize=(8, 8))
for i in range(9):
plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.axis("off")
plt.show()


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions keras_cv/layers/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@

from keras_cv.layers.preprocessing.cut_mix import CutMix
from keras_cv.layers.preprocessing.mix_up import MixUp
from keras_cv.layers.preprocessing.random_brightness import RandomBrightness
from keras_cv.layers.preprocessing.random_cutout import RandomCutout
139 changes: 139 additions & 0 deletions keras_cv/layers/preprocessing/random_brightness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

_SCALE_VALIDATION_ERROR = (
"The `scale` should be number or a list of two numbers "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"The scale argument should be a number (or a list of two numbers) "
"in the range [-1.0, 1.0]. "

"that ranged between [-1.0, 1.0]. "
)


class RandomBrightness(tf.keras.layers.Layer):
"""Randomly adjust brightness for the a RGB image.

This layer will randomly increase/reduce the brightness for the input RGB image.
During inference time, the output will be identical to input. Call the layer with
training=True to adjust brightness of the input.

Note that different brightness adjustment will be apply to each the images in the
batch.

Args:
scale: Float or a list/tuple of 2 floats between -1.0 and 1.0. The scale is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For naming consistency with other preprocessing layers, we should use factor, see e.g. https://keras.io/api/layers/preprocessing_layers/image_augmentation/random_rotation/

used to determine the lower bound and upper bound of the brightness
adjustment. A float value will be choose randomly between the limits.
When -1 is chosen, the output image will be black, and when 1.0 is
chosen, the image will be fully white. When only one float is provided,
eg, 0.2, then -0.2 will be used for lower bound and 0.2 will be used for
upper bound.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also support integers? Interpreted as a range in absolute levels, e.g. 32 would be interpreted as a +-32 shift of the channels

seed: integer, for fixed RNG behavior.

Inputs:
3D (HWC) or 4D (NHWC) tensor, with float or int dtype. The value should be
ranged between [0, 255].

Output:
3D (HWC) or 4D (NHWC) tensor with brightness adjusted based on the `scale`.
The output will have same dtypes as the input image.

Sample usage:
```
random_bright = keras_cv.layers.RandomBrightness(scale=0.2)
# An image with shape [2, 2, 3]
image = [[[1, 2, 3], [4 ,5 ,6]],
[[7, 8, 9], [10, 11, 12]]]
# Assume we randomly select the scale to be 0.1, then it will apply 0.1 * 255 to
# all the channel
output = random_bright(image, training=True)
# output will be int64 with 25.5 added to each channel and round down.
tf.Tensor(
[[[26 27 28]
[29 30 31]]
[[32 33 34]
[35 36 37]]], shape=(2, 2, 3), dtype=int64)
```
"""

def __init__(self, scale, seed=None, **kwargs):
super().__init__(**kwargs)
self._set_scale(scale)
self._seed = seed

def _set_scale(self, scale):
if isinstance(scale, (tuple, list)):
if len(scale) != 2:
raise ValueError(_SCALE_VALIDATION_ERROR + f"Got {scale}")
self._check_scale_range(scale[0])
self._check_scale_range(scale[1])
self._scale = sorted(scale)
elif isinstance(scale, (int, float)):
self._check_scale_range(scale)
scale = abs(scale)
self._scale = [-scale, scale]
else:
raise ValueError(_SCALE_VALIDATION_ERROR + f"Got {scale}")

@staticmethod
def _check_scale_range(input_number):
if input_number > 1.0 or input_number < -1.0:
raise ValueError(_SCALE_VALIDATION_ERROR + f"Got {input_number}")

def call(self, inputs, training=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Augmentation preprocessing layers should default to training=True (default should be to apply the transformation). You can assume that training is a Python bool in what follows.

if training is None:
training = tf.keras.backend.learning_phase()
return tf.__internal__.smart_cond.smart_cond(
training,
true_fn=lambda: self._brightness_adjust(inputs),
false_fn=lambda: inputs,
)

def _brightness_adjust(self, image):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to standardize image or images?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

rank = image.shape.rank
if rank == 3:
rgb_delta_shape = (3,)
elif rank == 4:
# Skip the width and height, but keep the batch and channel.
# This will ensure to have same adjustment for one channel, but different
# across the images.
rgb_delta_shape = [tf.shape(image)[0], 1, 1, 3]
else:
raise ValueError(
f"Expect the input image to be rank 3 or 4. Got {image.shape}"
)
if self._seed is not None:
rgb_delta = tf.random.stateless_uniform(
shape=rgb_delta_shape,
seed=[0, self._seed],
minval=self._scale[0],
maxval=self._scale[1],
)
else:
rgb_delta = tf.random.uniform(
shape=rgb_delta_shape, minval=self._scale[0], maxval=self._scale[1]
)
rgb_delta = rgb_delta * 255.0
input_dtype = image.dtype
image = tf.cast(image, tf.float32)
image += rgb_delta
image = tf.clip_by_value(image, 0.0, 255.0)
return tf.cast(image, input_dtype)

def get_config(self):
config = {
"scale": self._scale,
"seed": self._seed,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
128 changes: 128 additions & 0 deletions keras_cv/layers/preprocessing/random_brightness_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import tensorflow as tf

from keras_cv.layers import preprocessing


class RandomBrightnessTest(tf.test.TestCase):
def test_scale_input_validation(self):
with self.assertRaisesRegexp(ValueError, "ranged between \[-1.0, 1.0\]"):
preprocessing.RandomBrightness(2.0)

with self.assertRaisesRegexp(ValueError, "list of two numbers"):
preprocessing.RandomBrightness([1.0])

with self.assertRaisesRegexp(ValueError, "should be number"):
preprocessing.RandomBrightness("one")

def test_scale_normalize(self):
layer = preprocessing.RandomBrightness(1.0)
self.assertEqual(layer._scale, [-1.0, 1.0])

layer = preprocessing.RandomBrightness((0.5, 0.3))
self.assertEqual(layer._scale, [0.3, 0.5])

layer = preprocessing.RandomBrightness(-0.2)
self.assertEqual(layer._scale, [-0.2, 0.2])

def test_output_value_range(self):
# Always scale up to 255
layer = preprocessing.RandomBrightness([1.0, 1.0])
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs, training=True)
output_min = tf.math.reduce_min(output)
output_max = tf.math.reduce_max(output)
self.assertEqual(output_min, 255)
self.assertEqual(output_max, 255)

# Always scale down to 0
layer = preprocessing.RandomBrightness([-1.0, -1.0])
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs, training=True)
output_min = tf.math.reduce_min(output)
output_max = tf.math.reduce_max(output)
self.assertEqual(output_min, 0)
self.assertEqual(output_max, 0)

def test_output(self):
# Always scale up, but randomly between 0 ~ 255
layer = preprocessing.RandomBrightness([0, 1.0])
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs, training=True)
diff = output - inputs
self.assertGreaterEqual(tf.math.reduce_min(diff), 0)
self.assertGreater(tf.math.reduce_mean(diff), 0)

# Always scale down, but randomly between 0 ~ 255
layer = preprocessing.RandomBrightness([-1.0, 0.0])
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs, training=True)
diff = output - inputs
self.assertLessEqual(tf.math.reduce_max(diff), 0)
self.assertLess(tf.math.reduce_mean(diff), 0)

def test_different_adjustment_within_batch(self):
layer = preprocessing.RandomBrightness([0.2, 0.3])
inputs = np.zeros(shape=(2, 10, 10, 3)) # 2 images with all zeros
output = layer(inputs, training=True)
diff = output - inputs
# Make sure two images gets the same adjustment
self.assertNotAllClose(diff[0], diff[1])
# Make sure all the pixel are the same with the same image
image1 = output[0]
# The reduced mean pixel value among width and height are the same as
# any of the pixel in the image.
self.assertAllClose(
tf.reduce_mean(image1, axis=[0, 1]), image1[0, 0], rtol=1e-5, atol=1e-5
)

def test_inference(self):
layer = preprocessing.RandomBrightness([0, 1.0])
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs)
self.assertAllClose(inputs, output)

output = layer(inputs, training=False)
self.assertAllClose(inputs, output)

def test_dtype(self):
layer = preprocessing.RandomBrightness([0, 1.0])
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs, training=True)
self.assertEqual(output.dtype, tf.int64)

inputs = tf.cast(inputs, tf.float32)
output = layer(inputs, training=True)
self.assertEqual(output.dtype, tf.float32)

def test_seed(self):
layer = preprocessing.RandomBrightness([0, 1.0], seed=1337)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output_1 = layer(inputs, training=True)
output_2 = layer(inputs, training=True)

self.assertAllClose(output_1, output_2)

def test_config(self):
layer = preprocessing.RandomBrightness([0, 1.0], seed=1337)
config = layer.get_config()
self.assertEqual(config["scale"], [0.0, 1.0])
self.assertEqual(config["seed"], 1337)

reconstructed_layer = preprocessing.RandomBrightness.from_config(config)
self.assertEqual(reconstructed_layer._scale, layer._scale)
self.assertEqual(reconstructed_layer._seed, layer._seed)