Skip to content

Commit

Permalink
Revert "Replace RandomGenerator with SeedGenerator (keras-team#2150
Browse files Browse the repository at this point in the history
…)"

This reverts commit 365a675.
  • Loading branch information
sampathweb authored and divyashreepathihalli committed Nov 18, 2023
1 parent dba9049 commit b416e26
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 8 deletions.
116 changes: 116 additions & 0 deletions keras_cv/backend/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,119 @@
from keras.random import * # noqa: F403, F401
else:
from keras_core.random import * # noqa: F403, F401
class SeedGenerator:
def __init__(self, seed=None, **kwargs):
if keras_3():
self._seed_generator = keras.random.SeedGenerator(
seed=seed, **kwargs
)
else:
self._current_seed = [0, seed]

def next(self, ordered=True):
if keras_3():
return self._seed_generator.next(ordered=ordered)
else:
self._current_seed[0] += 1
return self._current_seed[:]


def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed

kwargs = {}
if dtype:
kwargs["dtype"] = dtype
if keras_3():
return keras.random.normal(
shape,
mean=mean,
stddev=stddev,
seed=init_seed,
**kwargs,
)
else:
import tensorflow as tf

return tf.random.normal(
shape,
mean=mean,
stddev=stddev,
seed=init_seed,
**kwargs,
)


def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
if keras_3():
return keras.random.uniform(
shape,
minval=minval,
maxval=maxval,
seed=init_seed,
**kwargs,
)
else:
import tensorflow as tf

return tf.random.uniform(
shape,
minval=minval,
maxval=maxval,
seed=init_seed,
**kwargs,
)


def shuffle(x, axis=0, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed

if keras_3():
return keras.random.shuffle(x=x, axis=axis, seed=init_seed)
else:
import tensorflow as tf

return tf.random.shuffle(x=x, axis=axis, seed=init_seed)


def categorical(logits, num_samples, dtype=None, seed=None):
if isinstance(seed, SeedGenerator):
seed = seed.next()
init_seed = seed[0] + seed[1]
else:
init_seed = seed
kwargs = {}
if dtype:
kwargs["dtype"] = dtype
if keras_3():
return keras.random.categorical(
logits=logits,
num_samples=num_samples,
seed=init_seed,
**kwargs,
)
else:
import tensorflow as tf

return tf.random.categorical(
logits=logits,
num_samples=num_samples,
seed=init_seed,
**kwargs,
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@

from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import scope
from keras_cv.backend import keras, scope
from keras_cv.backend.config import multi_backend
from keras_cv.backend.random import RandomGenerator
from keras_cv.utils import preprocessing

# In order to support both unbatched and batched inputs, the horizontal
Expand Down Expand Up @@ -132,7 +130,7 @@ def augment_image(self, image, transformation):

def __init__(self, seed=None, **kwargs):
force_generator = kwargs.pop("force_generator", False)
self._random_generator = RandomGenerator(
self._random_generator = keras_backend.RandomGenerator(
seed=seed, force_generator=force_generator
)
super().__init__(**kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@

from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import scope
from keras_cv.backend import keras, scope
from keras_cv.backend.config import multi_backend
from keras_cv.backend.random import RandomGenerator
from keras_cv.utils import preprocessing

H_AXIS = -3
Expand Down Expand Up @@ -111,7 +109,7 @@ def __init__(self):

def __init__(self, seed=None, **kwargs):
force_generator = kwargs.pop("force_generator", False)
self._random_generator = RandomGenerator(
self._random_generator = keras_backend.RandomGenerator(
seed=seed, force_generator=force_generator
)
super().__init__(**kwargs)
Expand Down

0 comments on commit b416e26

Please sign in to comment.