@@ -272,24 +272,45 @@ def _distribute_initializer(
272272 Raises:
273273 ValueError: If init_func or seed is None.
274274 If init_func.func is not a supported random function.
275+ Supported jax.random funcs: normal, truncated_normal, uniform
275276 TypeError: If init_func is not a functools.partial object.
276277 """
277278 import warnings
278279 from functools import partial
280+
281+ # Create SeedGenerator to ensure backend variable exists
282+ # For future state tracking of distributed keys, attributes for the
283+ # base/split keys and the number of sharded devices could be added here.
284+ if isinstance(seed, jax.Array):
285+ seed_gen = seed_generator.SeedGenerator(seed=int(seed[0]))
286+ elif isinstance(seed, int):
287+ seed_gen = seed_generator.SeedGenerator(seed=seed)
288+ elif isinstance(seed, seed_generator.SeedGenerator):
289+ seed_gen = seed
290+ else:
291+ raise ValueError(f"seed must be int, JAX array, or SeedGenerator, got {type(seed)}")
279292
280- # Validate all required arguments
281- if seed is None:
282- raise ValueError("seed cannot be None. Use keras.random.SeedGenerator.")
293+ # Extract the state value as JAX array
294+ jax_seed = seed_gen.state.value
295+
296+ # Convert to JAX PRNG key format (swap counter and seed value)
297+ jax_compatible_seed = jax.numpy.array(
298+ [jax_seed[1], jax_seed[0]], dtype=jax.numpy.uint32
299+ )
283300
284- if init_func is None:
301+ # Validate all required arguments
302+ if init_func is None or init_func.func.__name__ not in ['normal', 'truncated_normal', 'uniform']:
285303 raise ValueError(
286- "init_func cannot be None. Shape and dtype info are required."
304+ f"init_func cannot be None and must wrap a supported initializer, got {init_func!r}. "
305+ "Only JAX-compatible random initializers are supported. "
306+ "Supported jax.random funcs: normal, truncated_normal, uniform"
287307 )
288308
289309 # Ensure init_func is a partial
290310 if not isinstance(init_func, partial):
291311 raise TypeError(
292312 f"init_func must be a functools.partial object, got {type(init_func)}. "
313+ "Expected a jax.random.* function with shape and dtype bound via partial."
293314 )
294315
295316 # Shard based on tensor layout
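For illustration only (not part of the diff): the checks in this hunk expect seed to be an int, a JAX array, or a keras SeedGenerator, and init_func to be a functools.partial wrapping a supported jax.random function with shape and dtype already bound. A minimal sketch of well-formed inputs, with arbitrary shape/dtype values and assuming seed_generator here resolves to keras.src.random.seed_generator:

    from functools import partial
    import jax
    from keras.src.random import seed_generator

    # shape and dtype are baked into the partial; only the PRNG key is passed at call time
    init_func = partial(jax.random.normal, shape=(1024, 1024), dtype=jax.numpy.float32)
    seed = seed_generator.SeedGenerator(seed=42)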
@@ -301,12 +322,24 @@ def _distribute_initializer(
301322 else:
302323 sharding = _to_backend_layout(layout)
303324
304- # The init_func has static arguments baked in as per initializer.
305- compiled_init = jax.jit(
306- lambda seed: init_func(seed), out_shardings=sharding
307- )
325+ # JAX PRNG key handling within JIT:
326+ # The key is passed directly to the jax.random.* function, which is
327+ # functional and JIT-compatible. When out_shardings is specified, the
328+ # globally computed random array is laid out so each device holds its own shard.
329+ try:
330+ compiled_init = jax.jit(
331+ lambda jax_compatible_seed: init_func(jax_compatible_seed), out_shardings=sharding
332+ )
333+ sample = compiled_init(jax_compatible_seed)
334+ except RuntimeError as e:
335+ warnings.warn(f"Sharding failed due to: {e}, falling back to single device")
336+ compiled_init = jax.jit(
337+ lambda jax_compatible_seed: init_func(jax_compatible_seed), out_shardings=None
338+ )
339+ sample = compiled_init(jax_compatible_seed)
308340
309- sample = compiled_init(seed)
341+ # Advance the SeedGenerator state so subsequent draws produce new keys
342+ seed = seed_gen.next()
310343
311344 # Apply mean/stddev only for distributions where it makes sense
312345 if init_func.func in (jax.random.normal, jax.random.truncated_normal):
@@ -317,9 +350,4 @@ def _distribute_initializer(
317350 warnings.warn(
318351 "mean and stddev are ignored for uniform distribution"
319352 )
320- return sample
321- else:
322- raise ValueError(
323- f"Unsupported initializer: {init_func.func.__name__}. "
324- f"Supported: normal, truncated_normal, uniform"
325- )
353+ return sample
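For reference, a standalone sketch of the seed-to-key conversion used above, assuming the JAX backend (a fresh SeedGenerator stores state [seed, counter]; swapping the two entries reproduces the raw uint32 layout of jax.random.PRNGKey):

    import jax
    import jax.numpy as jnp
    from keras.src.random import seed_generator

    seed_gen = seed_generator.SeedGenerator(seed=42)
    state = seed_gen.state.value                               # [42, 0] -> (seed, counter)
    key = jnp.array([state[1], state[0]], dtype=jnp.uint32)    # [0, 42]
    assert (key == jax.random.PRNGKey(42)).all()               # same raw uint32 key layout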