Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace RandomGenerator with SeedGenerator #2150

Merged
merged 17 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions benchmarks/vectorized_jittered_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers import JitteredResize
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -258,8 +259,8 @@ def test_consistency_with_old_impl(self):

# makes offsets fixed to (0.5, 0.5)
with unittest.mock.patch.object(
layer._random_generator,
"random_uniform",
random,
"uniform",
return_value=tf.convert_to_tensor([[0.5, 0.5]]),
):
output = layer(image)
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/vectorized_mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers import Mosaic
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -101,7 +102,7 @@ def _batch_augment(self, inputs):
minval=0,
maxval=batch_size,
dtype=tf.int32,
seed=self._random_generator.make_legacy_seed(),
seed=random.make_seed(seed=self._seed_generator),
)
# concatenate the batches with permutation order to get all 4 images of
# the mosaic
Expand Down
5 changes: 4 additions & 1 deletion benchmarks/vectorized_random_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers import RandomCrop
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -72,7 +73,9 @@ def get_random_transformation(self, image=None, **kwargs):
h_diff = image_shape[H_AXIS] - self.height
w_diff = image_shape[W_AXIS] - self.width
dtype = image_shape.dtype
rands = self._random_generator.random_uniform([2], 0, dtype.max, dtype)
rands = random.uniform(
[2], 0, dtype.max, dtype, seed=self._seed_generator
)
h_start = rands[0] % (h_diff + 1)
w_start = rands[1] % (w_diff + 1)
return {"top": h_start, "left": w_start}
Expand Down
13 changes: 7 additions & 6 deletions benchmarks/vectorized_random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers import RandomFlip
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -102,11 +103,11 @@ def get_random_transformation(self, **kwargs):
flip_vertical = False
if self.horizontal:
flip_horizontal = (
self._random_generator.random_uniform(shape=[]) > 0.5
random.uniform(shape=[], seed=self._seed_generator) > 0.5
)
if self.vertical:
flip_vertical = (
self._random_generator.random_uniform(shape=[]) > 0.5
random.uniform(shape=[], seed=self._seed_generator) > 0.5
)
return {
"flip_horizontal": tf.cast(flip_horizontal, dtype=tf.bool),
Expand Down Expand Up @@ -236,14 +237,14 @@ def test_consistency_with_old_impl(self):
)

with unittest.mock.patch.object(
layer._random_generator,
"random_uniform",
random,
"uniform",
return_value=tf.convert_to_tensor([[0.6]]),
):
output = layer(image)
with unittest.mock.patch.object(
old_layer._random_generator,
"random_uniform",
random,
"uniform",
return_value=tf.convert_to_tensor(0.6),
):
old_output = old_layer(image)
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/vectorized_random_hue.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, factor, value_range, seed=None, **kwargs):
self.seed = seed

def get_random_transformation(self, **kwargs):
invert = preprocessing_utils.random_inversion(self._random_generator)
invert = preprocessing_utils.random_inversion(self._seed_generator)
# We must scale self.factor() to the range [-0.5, 0.5]. This is because
# the tf.image operation performs rotation on the hue saturation value
# orientation. This can be thought of as an angle in the range
Expand Down
8 changes: 6 additions & 2 deletions benchmarks/vectorized_random_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import random
from keras_cv.layers import RandomRotation
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -122,8 +123,11 @@ def __init__(
def get_random_transformation(self, **kwargs):
min_angle = self.lower * 2.0 * np.pi
max_angle = self.upper * 2.0 * np.pi
angle = self._random_generator.random_uniform(
shape=[1], minval=min_angle, maxval=max_angle
angle = random.uniform(
shape=[1],
minval=min_angle,
maxval=max_angle,
seed=self._seed_generator,
)
return {"angle": angle}

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/vectorized_random_shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _get_shear_amount(self, constraint):
if constraint is None:
return None

invert = preprocessing.random_inversion(self._random_generator)
invert = preprocessing.random_inversion(self._seed_generator)
return invert * constraint()

def augment_image(self, image, transformation=None, **kwargs):
Expand Down
7 changes: 5 additions & 2 deletions benchmarks/vectorized_random_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from keras import backend
from tensorflow import keras

from keras_cv.backend import random
from keras_cv.layers import RandomTranslation
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -217,17 +218,19 @@ def augment_image(self, image, transformation, **kwargs):

def get_random_transformation(self, image=None, **kwargs):
batch_size = 1
height_translation = self._random_generator.random_uniform(
height_translation = random.uniform(
shape=[batch_size, 1],
minval=self.height_lower,
maxval=self.height_upper,
dtype=tf.float32,
seed=self._seed_generator,
)
width_translation = self._random_generator.random_uniform(
width_translation = random.uniform(
shape=[batch_size, 1],
minval=self.width_lower,
maxval=self.width_upper,
dtype=tf.float32,
seed=self._seed_generator,
)
return {
"height_translation": height_translation,
Expand Down
7 changes: 5 additions & 2 deletions benchmarks/vectorized_random_zoom.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from keras import backend
from tensorflow import keras

from keras_cv.backend import random
from keras_cv.layers import RandomZoom
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -143,16 +144,18 @@ def __init__(
self.seed = seed

def get_random_transformation(self, image=None, **kwargs):
height_zoom = self._random_generator.random_uniform(
height_zoom = random.uniform(
shape=[1, 1],
minval=1.0 + self.height_lower,
maxval=1.0 + self.height_upper,
seed=self._seed_generator,
)
if self.width_factor is not None:
width_zoom = self._random_generator.random_uniform(
width_zoom = random.uniform(
shape=[1, 1],
minval=1.0 + self.width_lower,
maxval=1.0 + self.width_upper,
seed=self._seed_generator,
)
else:
width_zoom = height_zoom
Expand Down
7 changes: 5 additions & 2 deletions benchmarks/vectorized_randomly_zoomed_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tensorflow import keras

from keras_cv import core
from keras_cv.backend import random
from keras_cv.layers import RandomlyZoomedCrop
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
Expand Down Expand Up @@ -109,18 +110,20 @@ def get_random_transformation(

new_width = crop_size[1] * tf.sqrt(aspect_ratio)

height_offset = self._random_generator.random_uniform(
height_offset = random.uniform(
(),
minval=tf.minimum(0.0, original_height - new_height),
maxval=tf.maximum(0.0, original_height - new_height),
dtype=tf.float32,
seed=self._seed_generator,
)

width_offset = self._random_generator.random_uniform(
width_offset = random.uniform(
(),
minval=tf.minimum(0.0, original_width - new_width),
maxval=tf.maximum(0.0, original_width - new_width),
dtype=tf.float32,
seed=self._seed_generator,
)

new_height = new_height / original_height
Expand Down
36 changes: 15 additions & 21 deletions keras_cv/backend/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,30 @@ def __init__(self, seed=None, **kwargs):
seed=seed, **kwargs
)
else:
self._current_seed = [0, seed]
self._current_seed = [seed, 0]

def next(self, ordered=True):
if keras_3():
return self._seed_generator.next(ordered=ordered)
else:
self._current_seed[0] += 1
self._current_seed[1] += 1
return self._current_seed[:]


def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
def make_seed(seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
seed_0, seed_1 = seed.next()
if seed_0 is None:
init_seed = seed_1
else:
init_seed = seed_0 + seed_1
else:
init_seed = seed
return init_seed


def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
init_seed = make_seed(seed)
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
Expand All @@ -68,11 +75,7 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):


def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed
init_seed = make_seed(seed)
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
Expand All @@ -97,12 +100,7 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):


def shuffle(x, axis=0, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed

init_seed = make_seed(seed)
if keras_3():
return keras.random.shuffle(x=x, axis=axis, seed=init_seed)
else:
Expand All @@ -112,11 +110,7 @@ def shuffle(x, axis=0, seed=None):


def categorical(logits, num_samples, dtype=None, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed
init_seed = make_seed(seed)
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
Expand Down
30 changes: 20 additions & 10 deletions keras_cv/layers/preprocessing/aug_mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from keras_cv import layers
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
Expand Down Expand Up @@ -106,32 +107,41 @@ def _sample_from_dirichlet(self, alpha):
gamma_sample = tf.random.gamma(
shape=(),
alpha=alpha,
seed=self._random_generator.make_legacy_seed(),
seed=random.make_seed(seed=self._seed_generator),
)
return gamma_sample / tf.reduce_sum(
gamma_sample, axis=-1, keepdims=True
)

def _sample_from_beta(self, alpha, beta):
sample_alpha = tf.random.gamma(
(), alpha=alpha, seed=self._random_generator.make_legacy_seed()
(),
alpha=alpha,
seed=random.make_seed(seed=self._seed_generator),
)
sample_beta = tf.random.gamma(
(), alpha=beta, seed=self._random_generator.make_legacy_seed()
(),
alpha=beta,
seed=random.make_seed(seed=self._seed_generator),
)
return sample_alpha / (sample_alpha + sample_beta)

def _sample_depth(self):
return self._random_generator.random_uniform(
return random.uniform(
shape=(),
minval=self.chain_depth[0],
maxval=self.chain_depth[1] + 1,
dtype=tf.int32,
seed=self._seed_generator,
)

def _loop_on_depth(self, depth_level, image_aug):
op_index = self._random_generator.random_uniform(
shape=(), minval=0, maxval=8, dtype=tf.int32
op_index = random.uniform(
shape=(),
minval=0,
maxval=8,
dtype=tf.int32,
seed=self._seed_generator,
)
image_aug = self._apply_op(image_aug, op_index)
depth_level += 1
Expand Down Expand Up @@ -204,7 +214,7 @@ def _solarize(self, image):

def _shear_x(self, image):
x = tf.cast(self.severity_factor() * 0.3, tf.float32)
x *= preprocessing.random_inversion(self._random_generator)
x *= preprocessing.random_inversion(self._seed_generator)
transform_x = layers.RandomShear._format_transform(
[1.0, x, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
)
Expand All @@ -214,7 +224,7 @@ def _shear_x(self, image):

def _shear_y(self, image):
y = tf.cast(self.severity_factor() * 0.3, tf.float32)
y *= preprocessing.random_inversion(self._random_generator)
y *= preprocessing.random_inversion(self._seed_generator)
transform_x = self._format_random_shear_transform(
[1.0, 0.0, 0.0, y, 1.0, 0.0, 0.0, 0.0]
)
Expand All @@ -231,7 +241,7 @@ def _translate_x(self, image):
shape = tf.cast(tf.shape(image), tf.float32)
x = tf.cast(self.severity_factor() * shape[1] / 3, tf.float32)
x = tf.expand_dims(tf.expand_dims(x, axis=0), axis=0)
x *= preprocessing.random_inversion(self._random_generator)
x *= preprocessing.random_inversion(self._seed_generator)
x = tf.cast(x, tf.int32)

translations = tf.cast(
Expand All @@ -246,7 +256,7 @@ def _translate_y(self, image):
shape = tf.cast(tf.shape(image), tf.float32)
y = tf.cast(self.severity_factor() * shape[0] / 3, tf.float32)
y = tf.expand_dims(tf.expand_dims(y, axis=0), axis=0)
y *= preprocessing.random_inversion(self._random_generator)
y *= preprocessing.random_inversion(self._seed_generator)
y = tf.cast(y, tf.int32)

translations = tf.cast(
Expand Down
Loading