@@ -153,7 +153,7 @@ def __init__(
153
153
"""
154
154
seed1 , seed2 = random_seed .get_seed (seed )
155
155
# If op level seed is not set, use whatever graph level seed is returned
156
- np .random .seed (seed1 if seed is None else seed2 )
156
+ self . _rng = np .random .default_rng (seed1 if seed is None else seed2 )
157
157
dtype = dtypes .as_dtype (dtype ).base_dtype
158
158
if dtype not in (dtypes .uint8 , dtypes .float32 ):
159
159
raise TypeError ("Invalid image dtype %r, expected uint8 or float32" % dtype )
@@ -211,7 +211,7 @@ def next_batch(self, batch_size, fake_data=False, shuffle=True):
211
211
# Shuffle for the first epoch
212
212
if self ._epochs_completed == 0 and start == 0 and shuffle :
213
213
perm0 = np .arange (self ._num_examples )
214
- np . random .shuffle (perm0 )
214
+ self . _rng .shuffle (perm0 )
215
215
self ._images = self .images [perm0 ]
216
216
self ._labels = self .labels [perm0 ]
217
217
# Go to the next epoch
@@ -225,7 +225,7 @@ def next_batch(self, batch_size, fake_data=False, shuffle=True):
225
225
# Shuffle the data
226
226
if shuffle :
227
227
perm = np .arange (self ._num_examples )
228
- np . random .shuffle (perm )
228
+ self . _rng .shuffle (perm )
229
229
self ._images = self .images [perm ]
230
230
self ._labels = self .labels [perm ]
231
231
# Start next epoch
0 commit comments