@@ -125,8 +125,9 @@ def jax_sample_fn_generic(op):
125125
126126 def sample_fn (rng , size , dtype , * parameters ):
127127 rng_key = rng ["jax_state" ]
128- sample = jax_op (rng_key , * parameters , shape = size , dtype = dtype )
129- rng ["jax_state" ] = jax .random .split (rng_key , num = 1 )[0 ]
128+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
129+ sample = jax_op (sampling_key , * parameters , shape = size , dtype = dtype )
130+ rng ["jax_state" ] = rng_key
130131 return (rng , sample )
131132
132133 return sample_fn
@@ -151,9 +152,10 @@ def jax_sample_fn_loc_scale(op):
151152
152153 def sample_fn (rng , size , dtype , * parameters ):
153154 rng_key = rng ["jax_state" ]
155+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
154156 loc , scale = parameters
155- sample = loc + jax_op (rng_key , size , dtype ) * scale
156- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
157+ sample = loc + jax_op (sampling_key , size , dtype ) * scale
158+ rng ["jax_state" ] = rng_key
157159 return (rng , sample )
158160
159161 return sample_fn
@@ -168,8 +170,9 @@ def jax_sample_fn_no_dtype(op):
168170
169171 def sample_fn (rng , size , dtype , * parameters ):
170172 rng_key = rng ["jax_state" ]
171- sample = jax_op (rng_key , * parameters , shape = size )
172- rng ["jax_state" ] = jax .random .split (rng_key , num = 1 )[0 ]
173+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
174+ sample = jax_op (sampling_key , * parameters , shape = size )
175+ rng ["jax_state" ] = rng_key
173176 return (rng , sample )
174177
175178 return sample_fn
@@ -189,9 +192,12 @@ def jax_sample_fn_uniform(op):
189192
190193 def sample_fn (rng , size , dtype , * parameters ):
191194 rng_key = rng ["jax_state" ]
195+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
192196 minval , maxval = parameters
193- sample = jax_op (rng_key , shape = size , dtype = dtype , minval = minval , maxval = maxval )
194- rng ["jax_state" ] = jax .random .split (rng_key , num = 1 )[0 ]
197+ sample = jax_op (
198+ sampling_key , shape = size , dtype = dtype , minval = minval , maxval = maxval
199+ )
200+ rng ["jax_state" ] = rng_key
195201 return (rng , sample )
196202
197203 return sample_fn
@@ -211,9 +217,10 @@ def jax_sample_fn_shape_rate(op):
211217
212218 def sample_fn (rng , size , dtype , * parameters ):
213219 rng_key = rng ["jax_state" ]
220+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
214221 (shape , rate ) = parameters
215- sample = jax_op (rng_key , shape , size , dtype ) / rate
216- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
222+ sample = jax_op (sampling_key , shape , size , dtype ) / rate
223+ rng ["jax_state" ] = rng_key
217224 return (rng , sample )
218225
219226 return sample_fn
@@ -225,9 +232,10 @@ def jax_sample_fn_exponential(op):
225232
226233 def sample_fn (rng , size , dtype , * parameters ):
227234 rng_key = rng ["jax_state" ]
235+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
228236 (scale ,) = parameters
229- sample = jax .random .exponential (rng_key , size , dtype ) * scale
230- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
237+ sample = jax .random .exponential (sampling_key , size , dtype ) * scale
238+ rng ["jax_state" ] = rng_key
231239 return (rng , sample )
232240
233241 return sample_fn
@@ -239,13 +247,14 @@ def jax_sample_fn_t(op):
239247
240248 def sample_fn (rng , size , dtype , * parameters ):
241249 rng_key = rng ["jax_state" ]
250+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
242251 (
243252 df ,
244253 loc ,
245254 scale ,
246255 ) = parameters
247- sample = loc + jax .random .t (rng_key , df , size , dtype ) * scale
248- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
256+ sample = loc + jax .random .t (sampling_key , df , size , dtype ) * scale
257+ rng ["jax_state" ] = rng_key
249258 return (rng , sample )
250259
251260 return sample_fn
@@ -257,9 +266,10 @@ def jax_funcify_choice(op):
257266
258267 def sample_fn (rng , size , dtype , * parameters ):
259268 rng_key = rng ["jax_state" ]
269+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
260270 (a , p , replace ) = parameters
261- smpl_value = jax .random .choice (rng_key , a , size , replace , p )
262- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
271+ smpl_value = jax .random .choice (sampling_key , a , size , replace , p )
272+ rng ["jax_state" ] = rng_key
263273 return (rng , smpl_value )
264274
265275 return sample_fn
@@ -271,9 +281,10 @@ def jax_sample_fn_permutation(op):
271281
272282 def sample_fn (rng , size , dtype , * parameters ):
273283 rng_key = rng ["jax_state" ]
284+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
274285 (x ,) = parameters
275- sample = jax .random .permutation (rng_key , x )
276- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
286+ sample = jax .random .permutation (sampling_key , x )
287+ rng ["jax_state" ] = rng_key
277288 return (rng , sample )
278289
279290 return sample_fn
@@ -285,10 +296,11 @@ def jax_sample_fn_lognormal(op):
285296
286297 def sample_fn (rng , size , dtype , * parameters ):
287298 rng_key = rng ["jax_state" ]
299+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
288300 loc , scale = parameters
289- sample = loc + jax .random .normal (rng_key , size , dtype ) * scale
301+ sample = loc + jax .random .normal (sampling_key , size , dtype ) * scale
290302 sample_exp = jax .numpy .exp (sample )
291- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
303+ rng ["jax_state" ] = rng_key
292304 return (rng , sample_exp )
293305
294306 return sample_fn
0 commit comments