Skip to content

Commit d33cb4e

Browse files
Address review feedback: improve error messages and add PRNG key handling comments
1 parent ac871f1 commit d33cb4e

File tree

1 file changed

+44
-16
lines changed

1 file changed

+44
-16
lines changed

keras/src/backend/jax/distribution_lib.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -272,24 +272,45 @@ def _distribute_initializer(
272272
Raises:
273273
ValueError: If init_func or seed is None.
274274
If init_func.func is not a supported random function.
275+
Supported jax.random func: normal, truncated_normal, uniform
275276
TypeError: If init_func is not a functools.partial object.
276277
"""
277278
import warnings
278279
from functools import partial
280+
281+
# Create SeedGenerator to ensure backend variable exists
282+
# For future state tracking for distributed keys, add
283+
# attributes for base/split keys and number of devices sharded.
284+
if isinstance(seed, jax.Array):
285+
seed_gen = seed_generator.SeedGenerator(seed=int(seed[0]))
286+
elif isinstance(seed, int):
287+
seed_gen = seed_generator.SeedGenerator(seed=seed)
288+
elif isinstance(seed, seed_generator.SeedGenerator):
289+
seed_gen = seed
290+
else:
291+
raise ValueError(f"seed must be int, JAX array, or SeedGenerator, got {type(seed)}")
279292

280-
# Validate all required arguments
281-
if seed is None:
282-
raise ValueError("seed cannot be None. Use keras.random.SeedGenerator.")
293+
# Extract the state value as JAX array
294+
jax_seed = seed_gen.state.value
295+
296+
# Convert to JAX PRNG key format (swap counter and seed value)
297+
jax_compatible_seed = jax.numpy.array(
298+
[jax_seed[1], jax_seed[0]], dtype=jax.numpy.uint32
299+
)
283300

284-
if init_func is None:
301+
# Validate all required arguments
302+
if init_func is None or init_func.func.__name__ not in ['normal', 'truncated_normal', 'uniform']:
285303
raise ValueError(
286-
"init_func cannot be None. Shape and dtype info are required."
304+
"init_func cannot be None or Unsupported initializer: {init_func.func.__name__}."
305+
"only JAX-compatible random initializers are supported. "
306+
"Supported jax.random funcs: normal, truncated_normal, uniform"
287307
)
288308

289309
# Ensure init_func is a partial
290310
if not isinstance(init_func, partial):
291311
raise TypeError(
292312
f"init_func must be functools.partial object, got {type(init_func)}"
313+
"init_func is a jax.random.* function with shape and dtype bound via partial"
293314
)
294315

295316
# Shard based on tensor layout
@@ -301,12 +322,24 @@ def _distribute_initializer(
301322
else:
302323
sharding = _to_backend_layout(layout)
303324

304-
# The init_func has static arguments baked in as per initializer.
305-
compiled_init = jax.jit(
306-
lambda seed: init_func(seed), out_shardings=sharding
307-
)
325+
# JAX PRNG key handling within JIT:
326+
# The key is passed directly to jax.random.* functions which are
327+
# JIT-compatible and functional. JAX automatically ensures different
328+
# random values per shard when out_shardings is specified.
329+
try:
330+
compiled_init = jax.jit(
331+
lambda jax_compatible_seed: init_func(jax_compatible_seed), out_shardings=sharding
332+
)
333+
sample = compiled_init(jax_compatible_seed)
334+
except RuntimeError as e:
335+
warnings.warn(f"Sharding failed due to: {e}, falling back to single device")
336+
compiled_init = jax.jit(
337+
lambda jax_compatible_seed: init_func(jax_compatible_seed), out_shardings=None
338+
)
339+
sample = compiled_init(jax_compatible_seed)
308340

309-
sample = compiled_init(seed)
341+
# Store the SeedGenerator for state tracking
342+
seed = seed_gen.next()
310343

311344
# Apply mean/stddev only for distributions where it makes sense
312345
if init_func.func in (jax.random.normal, jax.random.truncated_normal):
@@ -317,9 +350,4 @@ def _distribute_initializer(
317350
warnings.warn(
318351
"mean and stddev are ignored for uniform distribution"
319352
)
320-
return sample
321-
else:
322-
raise ValueError(
323-
f"Unsupported initializer: {init_func.func.__name__}. "
324-
f"Supported: normal, truncated_normal, uniform"
325-
)
353+
return sample

0 commit comments

Comments
 (0)