diff --git a/keras/src/backend/common/variables.py b/keras/src/backend/common/variables.py
index 33a2cc3c5160..6ef3c9059d0c 100644
--- a/keras/src/backend/common/variables.py
+++ b/keras/src/backend/common/variables.py
@@ -414,10 +414,7 @@ def _initialize(self, value):
         raise NotImplementedError
 
     def _initialize_with_initializer(self, initializer):
-        value = self._convert_to_tensor(
-            initializer(self._shape, dtype=self._dtype)
-        )
-        self._initialize(value)
+        raise NotImplementedError
 
     def _convert_to_tensor(self, value, dtype=None):
         raise NotImplementedError
diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py
index 7dc5a98fb8d5..f32ef3694251 100644
--- a/keras/src/backend/jax/core.py
+++ b/keras/src/backend/jax/core.py
@@ -30,22 +30,51 @@ def __init__(self, *args, layout=None, **kwargs):
         self._layout = layout
         super().__init__(*args, **kwargs)
 
-    def _initialize(self, value):
-        # Note that variable.shape is needed by distribution_lib
-        self._shape = self._validate_shape(value.shape)
+    def set_tensor_layout(self):
         # We can't import the keras/distribution/distribution_lib
         # due to circular dependency.
-        distribution = global_state.get_global_attribute("distribution")
-        if self._layout is None and distribution is not None:
-            tensor_layout = distribution.get_variable_layout(self)
-            from keras.src.distribution import TensorLayout
+        if self._layout is None:
+            distribution = global_state.get_global_attribute("distribution")
+            if distribution is not None:
+                tensor_layout = distribution.get_variable_layout(self)
+                from keras.src.distribution import TensorLayout
+
+                if isinstance(tensor_layout, TensorLayout):
+                    self._layout = tensor_layout.backend_layout
+                else:
+                    self._layout = tensor_layout
 
-            if isinstance(tensor_layout, TensorLayout):
-                self._layout = tensor_layout.backend_layout
-            else:
-                self._layout = tensor_layout
+    def _initialize(self, value):
+        # Note that variable.shape is needed by distribution_lib
+        self._shape = self._validate_shape(value.shape)
+        self.set_tensor_layout()
         self._direct_assign(value)
 
+    def check_distributed_init(self, initializer):
+        # Check if 'layout' parameter is supported in the initializer call
+        import inspect
+
+        sig = inspect.signature(initializer.__call__)
+        layout_supported = "layout" in sig.parameters
+        # Check if PartitionSpec has any non-None values
+        spec = getattr(self._layout, "spec", None)
+        partition_spec = spec if spec is not None else ()
+        is_partitioned = any(dim is not None for dim in partition_spec)
+        return layout_supported and is_partitioned
+
+    def _initialize_with_initializer(self, initializer):
+        self.set_tensor_layout()
+        # Use layout-aware initialization for distributed embeddings
+        if self.check_distributed_init(initializer):
+            value = self._convert_to_tensor(
+                initializer(self._shape, dtype=self._dtype, layout=self._layout)
+            )
+        else:
+            value = self._convert_to_tensor(
+                initializer(self._shape, dtype=self._dtype)
+            )
+        self._initialize(value)
+
     def _direct_assign(self, value):
         if self._layout is not None:
             value = distribution_lib.distribute_variable(value, self._layout)
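For context, this is roughly how the layout-aware path above gets exercised end to end on the JAX backend. This is a minimal sketch, not part of the patch: the mesh shape, axis names, and the `"embeddings"` layout rule are illustrative, and `ModelParallel` is assumed to accept `layout_map`/`batch_dim_name` keyword arguments as in current Keras.

```python
import keras
from keras import distribution

# Illustrative 1 x N device mesh; variables matched by the layout map get
# sharded along the "model" axis, everything else is replicated.
devices = distribution.list_devices()
mesh = distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)
layout_map = distribution.LayoutMap(mesh)
layout_map["embeddings"] = (None, "model")  # shard embedding columns

distribution.set_distribution(
    distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
)

# Building the layer calls Variable._initialize_with_initializer. Because the
# resolved layout has a partitioned dimension and the initializer accepts
# `layout`, the weight is created directly in its sharded form.
layer = keras.layers.Embedding(input_dim=50_000, output_dim=512, name="embeddings")
layer.build((None,))
```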
diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py
index 1407c008910e..e6782a13aa5b 100644
--- a/keras/src/backend/jax/distribution_lib.py
+++ b/keras/src/backend/jax/distribution_lib.py
@@ -9,6 +9,110 @@
 from keras.src.utils import rng_utils
 
 
+def _distribute_initializer(
+    init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
+):
+    """
+    Distribution-aware initializer for the JAX backend.
+
+    This function creates a JAX random array and distributes it
+    according to the current layout.
+
+    Args:
+        init_func: A functools.partial-wrapped object that takes the seed
+            as argument and returns a jax.Array. Must have shape and dtype
+            already bound via partial.
+        mean: Mean of distribution (applied to normal/truncated_normal).
+        stddev: Standard deviation of the distribution.
+        seed: JAX-compatible seed array; if None, use the seed generator.
+        layout: TensorLayout for the distributed tensor.
+
+    Returns:
+        A distributed jax array.
+
+    Raises:
+        ValueError: If init_func is None or init_func.func is not a
+            supported random function. Supported jax.random funcs:
+            normal, truncated_normal, uniform.
+        TypeError: If init_func is not a functools.partial object, or if
+            layout is neither a Keras TensorLayout nor a NamedSharding.
+    """
+    import warnings
+    from functools import partial
+
+    # Draw seed from the seed generator if seed is not a JAX array
+    if seed is None or not isinstance(seed, jax.Array):
+        jax_compatible_seed = seed_generator.draw_seed(None)
+        # Convert to JAX PRNG key format (swap counter and seed value)
+        seed = jax_compatible_seed[::-1]
+
+    # Validate all required arguments
+    if init_func is None or init_func.func.__name__ not in [
+        "normal",
+        "truncated_normal",
+        "uniform",
+    ]:
+        raise ValueError(
+            "init_func cannot be None and must wrap a supported "
+            f"jax.random function, got: {init_func!r}. "
+            "Supported jax.random funcs: normal, truncated_normal, uniform"
+        )
+
+    # Ensure init_func is a partial
+    if not isinstance(init_func, partial):
+        raise TypeError(
+            f"init_func must be a functools.partial object, got "
+            f"{type(init_func)}. Expected a jax.random.* function with "
+            "shape and dtype bound via partial."
+        )
+
+    # Shard based on tensor layout
+    if layout is None:
+        warnings.warn(
+            f"The layout is {layout}, sharding will default to single device"
+        )
+
+        sharding = None
+    else:
+        if not isinstance(layout, jax.sharding.NamedSharding):
+            from keras.src.distribution import TensorLayout
+
+            if isinstance(layout, TensorLayout):
+                layout = _to_backend_layout(layout)
+            else:
+                raise TypeError(
+                    f"layout must be Keras TensorLayout or "
+                    f"jax.sharding.NamedSharding, got {type(layout)}"
+                )
+        sharding = layout
+
+    # JAX PRNG key handling within JIT:
+    # The key is passed directly to jax.random.* functions which are
+    # JIT-compatible and functional. JAX automatically ensures different
+    # random values per shard when out_shardings is specified.
+    try:
+        compiled_init = jax.jit(
+            lambda seed: init_func(seed),
+            out_shardings=sharding,
+        )
+        sample = compiled_init(seed)
+
+    except RuntimeError as e:
+        warnings.warn(
+            f"Sharding at initialization failed due to: {e}, "
+            f"falling back to single device"
+        )
+        compiled_init = jax.jit(
+            lambda seed: init_func(seed),
+            out_shardings=None,
+        )
+        sample = compiled_init(seed)
+
+    # Apply mean/stddev only for distributions where it makes sense
+    if init_func.func in (jax.random.normal, jax.random.truncated_normal):
+        return sample * stddev + mean
+    elif init_func.func == jax.random.uniform:
+        return sample
+
+
 def list_devices(device_type=None):
     """Return all the available devices based on the device type.
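A hypothetical direct call to the helper above, to make the calling convention concrete: shape and dtype are pre-bound via `functools.partial`, only the PRNG key is left free. The mesh axis name and shapes are made up for illustration; in real use the layout comes from the active distribution.

```python
from functools import partial

import jax

from keras import distribution
from keras.src.backend.jax import distribution_lib

# Only the PRNG key remains unbound, which is what _distribute_initializer
# passes through jax.jit.
init_func = partial(jax.random.normal, shape=(1024, 512), dtype="float32")

devices = distribution.list_devices()
mesh = distribution.DeviceMesh(
    shape=(len(devices),), axis_names=("model",), devices=devices
)
layout = distribution.TensorLayout(axes=("model", None), device_mesh=mesh)

value = distribution_lib._distribute_initializer(
    init_func=init_func,
    mean=0.0,
    stddev=0.02,
    seed=jax.random.PRNGKey(0),
    layout=layout,
)
print(value.sharding)  # NamedSharding placing rows across the "model" axis
```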
@@ -260,3 +364,120 @@ def _to_backend_layout(tensor_layout):
     partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes)
     jax_mesh = tensor_layout.device_mesh.backend_mesh
     return jax.sharding.NamedSharding(jax_mesh, partition_spec)
+
+
+def _distribute_initializer(
+    init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
+):
+    """
+    Distribution-aware token embedding initializer for the JAX backend.
+
+    This function creates a JAX random array and distributes it
+    according to the current token embedding layout.
+
+    Args:
+        init_func: A functools.partial-wrapped object that takes the seed
+            as argument and returns a jax.Array. Must have shape and dtype
+            already bound via partial.
+        mean: Mean of distribution (applied to normal/truncated_normal).
+        stddev: Standard deviation of the distribution.
+        seed: Random seed for initialization.
+        layout: TensorLayout for the distributed tensor.
+
+    Returns:
+        A distributed jax array.
+
+    Raises:
+        ValueError: If init_func is None, if init_func.func is not a
+            supported random function, or if seed is not an int, JAX
+            array, or SeedGenerator. Supported jax.random funcs:
+            normal, truncated_normal, uniform.
+        TypeError: If init_func is not a functools.partial object.
+    """
+    import warnings
+    from functools import partial
+
+    # Create SeedGenerator to ensure backend variable exists.
+    # For future state tracking for distributed keys, add
+    # attributes for base/split keys and number of devices sharded.
+    if isinstance(seed, jax.Array):
+        seed_gen = seed_generator.SeedGenerator(seed=int(seed[0]))
+    elif isinstance(seed, int):
+        seed_gen = seed_generator.SeedGenerator(seed=seed)
+    elif isinstance(seed, seed_generator.SeedGenerator):
+        seed_gen = seed
+    else:
+        raise ValueError(
+            f"seed must be int, JAX array, or SeedGenerator, got {type(seed)}"
+        )
+
+    # Extract the state value as JAX array
+    jax_seed = seed_gen.state.value
+
+    # Convert to JAX PRNG key format (swap counter and seed value)
+    jax_compatible_seed = jax.numpy.array(
+        [jax_seed[1], jax_seed[0]], dtype=jax.numpy.uint32
+    )
+
+    # Validate all required arguments
+    if init_func is None or init_func.func.__name__ not in [
+        "normal",
+        "truncated_normal",
+        "uniform",
+    ]:
+        raise ValueError(
+            "init_func cannot be None and must wrap a supported "
+            f"jax.random function, got: {init_func!r}. "
+            "Supported jax.random funcs: normal, truncated_normal, uniform"
+        )
+
+    # Ensure init_func is a partial
+    if not isinstance(init_func, partial):
+        raise TypeError(
+            f"init_func must be a functools.partial object, got "
+            f"{type(init_func)}. Expected a jax.random.* function with "
+            "shape and dtype bound via partial."
+        )
+
+    # Shard based on tensor layout
+    if layout is None:
+        warnings.warn(
+            f"The layout is {layout}, sharding will default to single device"
+        )
+        sharding = None
+    else:
+        sharding = _to_backend_layout(layout)
+
+    # JAX PRNG key handling within JIT:
+    # The key is passed directly to jax.random.* functions which are
+    # JIT-compatible and functional. JAX automatically ensures different
+    # random values per shard when out_shardings is specified.
+    try:
+        compiled_init = jax.jit(
+            lambda jax_compatible_seed: init_func(jax_compatible_seed),
+            out_shardings=sharding,
+        )
+        sample = compiled_init(jax_compatible_seed)
+    except RuntimeError as e:
+        warnings.warn(
+            f"Sharding failed due to: {e}, falling back to single device"
+        )
+        compiled_init = jax.jit(
+            lambda jax_compatible_seed: init_func(jax_compatible_seed),
+            out_shardings=None,
+        )
+        sample = compiled_init(jax_compatible_seed)
+
+    # Advance the SeedGenerator state so subsequent draws differ
+    seed = seed_gen.next()
+
+    # Apply mean/stddev only for distributions where it makes sense
+    if init_func.func in (jax.random.normal, jax.random.truncated_normal):
+        return sample * stddev + mean
+    elif init_func.func == jax.random.uniform:
+        # Uniform doesn't use mean/stddev - warn
+        if mean != 0.0 or stddev != 1.0:
+            warnings.warn(
+                "mean and stddev are ignored for uniform distribution"
+            )
+        return sample
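A small sketch of the key-format conversion performed above: the Keras `SeedGenerator` stores its state as `[seed, counter]`, while `jax.random` expects a raw `uint32[2]` key, so the two entries are swapped before use. This assumes the JAX backend is active.

```python
import jax
import jax.numpy as jnp

from keras.src.random import seed_generator

gen = seed_generator.SeedGenerator(seed=42)
state = gen.state.value  # [seed, counter], e.g. [42, 0]

# Same swap as above: build a raw uint32[2] PRNG key as [counter, seed].
key = jnp.array([state[1], state[0]], dtype=jnp.uint32)
sample = jax.random.normal(key, shape=(4,))
print(sample.shape)  # (4,)
```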
diff --git a/keras/src/backend/jax/random.py b/keras/src/backend/jax/random.py
index 79901696339f..2c4de424292b 100644
--- a/keras/src/backend/jax/random.py
+++ b/keras/src/backend/jax/random.py
@@ -1,3 +1,5 @@
+from functools import partial
+
 import jax
 
 from keras.src.backend.config import floatx
@@ -7,25 +9,61 @@
 
 
 def jax_draw_seed(seed):
+    # Convert to JAX PRNG key format (swap counter and seed value)
     if isinstance(seed, jax.Array):
-        return seed
+        return seed[::-1]
     else:
-        return draw_seed(seed)
+        seed_array = draw_seed(seed)
+        return seed_array[::-1]
 
 
-def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed = jax_draw_seed(seed)
-    sample = jax.random.normal(seed, shape=shape, dtype=dtype)
-    return sample * stddev + mean
+    if layout is not None:
+        from keras.src.backend import distribution_lib
+
+        init_func = partial(
+            jax.random.normal,
+            shape=shape,
+            dtype=dtype,
+        )
+        return distribution_lib._distribute_initializer(
+            init_func=init_func,
+            mean=mean,
+            stddev=stddev,
+            seed=seed,
+            layout=layout,
+        )
+    else:
+        sample = jax.random.normal(seed, shape=shape, dtype=dtype)
+        return sample * stddev + mean
 
 
-def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
+def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed = jax_draw_seed(seed)
-    return jax.random.uniform(
-        seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval
-    )
+    if layout is not None:
+        from keras.src.backend import distribution_lib
+
+        init_func = partial(
+            jax.random.uniform,
+            shape=shape,
+            dtype=dtype,
+            minval=minval,
+            maxval=maxval,
+        )
+        return distribution_lib._distribute_initializer(
+            init_func=init_func,
+            mean=None,
+            stddev=None,
+            seed=seed,
+            layout=layout,
+        )
+    else:
+        return jax.random.uniform(
+            seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval
+        )
 
 
 def categorical(logits, num_samples, dtype="int32", seed=None):
@@ -46,13 +84,33 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
     )
 
 
-def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def truncated_normal(
+    shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
+):
     dtype = dtype or floatx()
     seed = jax_draw_seed(seed)
-    sample = jax.random.truncated_normal(
-        seed, shape=shape, lower=-2.0, upper=2.0, dtype=dtype
-    )
-    return sample * stddev + mean
+    if layout is not None:
+        from keras.src.backend import distribution_lib
+
+        init_func = partial(
+            jax.random.truncated_normal,
+            shape=shape,
+            dtype=dtype,
+            lower=-2.0,
+            upper=2.0,
+        )
+        return distribution_lib._distribute_initializer(
+            init_func=init_func,
+            mean=mean,
+            stddev=stddev,
+            seed=seed,
+            layout=layout,
+        )
+    else:
+        sample = jax.random.truncated_normal(
+            seed, shape=shape, lower=-2.0, upper=2.0, dtype=dtype
+        )
+        return sample * stddev + mean
 
 
 def _get_concrete_noise_shape(inputs, noise_shape):
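An illustrative call into the JAX backend RNG with and without a layout, reusing the kind of `TensorLayout` built in the earlier sketch. The shapes and axis name are arbitrary; in real use the layout value is supplied by the active distribution rather than constructed by hand.

```python
from keras import distribution
from keras.src.backend.jax import random as jax_random

devices = distribution.list_devices()
mesh = distribution.DeviceMesh(
    shape=(len(devices),), axis_names=("model",), devices=devices
)
layout = distribution.TensorLayout(axes=("model", None), device_mesh=mesh)

# With a layout, the call is routed through _distribute_initializer and the
# sample comes back sharded; without one, it is the usual single-device path.
sharded = jax_random.normal((4096, 256), stddev=0.02, seed=1337, layout=layout)
replicated = jax_random.normal((4096, 256), stddev=0.02, seed=1337)
```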
diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py
index 16b2303e5e43..da0b5b4d153c 100644
--- a/keras/src/backend/numpy/core.py
+++ b/keras/src/backend/numpy/core.py
@@ -23,6 +23,12 @@
 class Variable(KerasVariable):
     def _initialize(self, value):
         self._value = value
 
+    def _initialize_with_initializer(self, initializer):
+        value = self._convert_to_tensor(
+            initializer(self._shape, dtype=self._dtype)
+        )
+        self._initialize(value)
+
     def _direct_assign(self, value):
         self._value = np.array(value, dtype=self._dtype)
diff --git a/keras/src/backend/numpy/random.py b/keras/src/backend/numpy/random.py
index f8fd65aa38ba..28aa5c64f243 100644
--- a/keras/src/backend/numpy/random.py
+++ b/keras/src/backend/numpy/random.py
@@ -7,14 +7,14 @@
 from keras.src.random.seed_generator import make_default_seed
 
 
-def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed = draw_seed(seed)
     rng = np.random.default_rng(seed)
     return rng.normal(size=shape, loc=mean, scale=stddev).astype(dtype)
 
 
-def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
+def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed = draw_seed(seed)
     rng = np.random.default_rng(seed)
@@ -40,7 +40,9 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
     return output
 
 
-def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def truncated_normal(
+    shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
+):
     dtype = dtype or floatx()
     seed = draw_seed(seed)
     rng = np.random.default_rng(seed)
diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py
index 93f9f5819c8b..a52eed6d9d38 100644
--- a/keras/src/backend/openvino/core.py
+++ b/keras/src/backend/openvino/core.py
@@ -572,6 +572,12 @@ def _initialize(self, value):
         )
         self._value = OpenVINOKerasTensor(value_const.output(0))
 
+    def _initialize_with_initializer(self, initializer):
+        value = self._convert_to_tensor(
+            initializer(self._shape, dtype=self._dtype)
+        )
+        self._initialize(value)
+
     def _direct_assign(self, value):
         self._value = value
diff --git a/keras/src/backend/openvino/random.py b/keras/src/backend/openvino/random.py
index 38de21294677..9d40a93181b5 100644
--- a/keras/src/backend/openvino/random.py
+++ b/keras/src/backend/openvino/random.py
@@ -12,7 +12,7 @@
 from keras.src.random.seed_generator import make_default_seed
 
 
-def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed = draw_seed(seed)
     rng = np.random.default_rng(seed.data)
@@ -20,7 +20,7 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
     return OpenVINOKerasTensor(ov_opset.constant(normal_const).output(0))
 
 
-def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
+def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None):
     dtype = dtype or floatx()
     seed_val = draw_seed(seed)
     if isinstance(seed_val, OpenVINOKerasTensor):
@@ -96,7 +96,9 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
     )
 
 
-def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
+def truncated_normal(
+    shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None
+):
     dtype = dtype or floatx()
     seed = draw_seed(seed)
     rng = np.random.default_rng(seed.data)
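The per-backend copies above all follow the same contract that previously lived in the base `Variable` class. A sketch of what any additional backend would implement (the class name here is hypothetical); only the JAX backend replaces this with the layout-aware version shown earlier.

```python
from keras.src.backend.common.variables import Variable as KerasVariable


class MyBackendVariable(KerasVariable):  # hypothetical backend
    def _initialize_with_initializer(self, initializer):
        # Single-device fallback: materialize the initializer's output as a
        # backend tensor, then hand it to _initialize().
        value = self._convert_to_tensor(
            initializer(self._shape, dtype=self._dtype)
        )
        self._initialize(value)
```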
maxval, dtype="int32", seed=None): ) -def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): +def truncated_normal( + shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None +): dtype = dtype or floatx() seed = draw_seed(seed) rng = np.random.default_rng(seed.data) diff --git a/keras/src/backend/tensorflow/random.py b/keras/src/backend/tensorflow/random.py index e807b0de9aab..4b935faf8027 100644 --- a/keras/src/backend/tensorflow/random.py +++ b/keras/src/backend/tensorflow/random.py @@ -20,7 +20,7 @@ def _cast_seed(seed): return seed -def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): +def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None): dtype = dtype or floatx() seed = _cast_seed(draw_seed(seed)) return tf.random.stateless_normal( @@ -28,7 +28,7 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): ) -def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): +def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None, layout=None): dtype = dtype or floatx() seed = _cast_seed(draw_seed(seed)) return tf.random.stateless_uniform( @@ -61,7 +61,9 @@ def randint(shape, minval, maxval, dtype="int32", seed=None): return tf.cast(output, dtype) -def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): +def truncated_normal( + shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None +): dtype = dtype or floatx() seed = _cast_seed(draw_seed(seed)) return tf.random.stateless_truncated_normal( diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 877dc6909ea1..a10c11dab959 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -109,6 +109,12 @@ def _initialize(self, value): requires_grad=self.trainable, ).to(get_device()) + def _initialize_with_initializer(self, initializer): + value = self._convert_to_tensor( + initializer(self._shape, dtype=self._dtype) + ) + self._initialize(value) + def _direct_assign(self, value): with torch.no_grad(): self.value.copy_(value) diff --git a/keras/src/backend/torch/random.py b/keras/src/backend/torch/random.py index e080731952e6..1413c1e795b3 100644 --- a/keras/src/backend/torch/random.py +++ b/keras/src/backend/torch/random.py @@ -25,7 +25,7 @@ def torch_seed_generator(seed): return generator -def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): +def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None, layout=None): dtype = dtype or floatx() dtype = to_torch_dtype(dtype) # Do not use generator during symbolic execution. 
diff --git a/keras/src/initializers/random_initializers.py b/keras/src/initializers/random_initializers.py
index ad1123e2a18f..0e8ef11d4868 100644
--- a/keras/src/initializers/random_initializers.py
+++ b/keras/src/initializers/random_initializers.py
@@ -2,6 +2,7 @@
 
 from keras.src import ops
 from keras.src.api_export import keras_export
+from keras.src.backend import backend
 from keras.src.backend import random
 from keras.src.initializers.initializer import Initializer
 from keras.src.saving import serialization_lib
@@ -10,7 +11,9 @@
 class RandomInitializer(Initializer):
     def __init__(self, seed=None):
         self._init_seed = seed
-        if seed is None:
+        if seed is None and backend() == "jax":
+            seed = int(random.draw_seed(None)[0])
+        elif seed is None:
             seed = random.make_default_seed()
         elif isinstance(seed, dict):
             seed = serialization_lib.deserialize_keras_object(seed)
@@ -68,13 +71,14 @@ def __init__(self, mean=0.0, stddev=0.05, seed=None):
         self.stddev = stddev
         super().__init__(seed=seed)
 
-    def __call__(self, shape, dtype=None):
+    def __call__(self, shape, dtype=None, layout=None):
         return random.normal(
             shape=shape,
             mean=self.mean,
             stddev=self.stddev,
             seed=self.seed,
             dtype=dtype,
+            layout=layout,
         )
 
     def get_config(self):
@@ -127,13 +131,14 @@ def __init__(self, mean=0.0, stddev=0.05, seed=None):
         self.stddev = stddev
         super().__init__(seed=seed)
 
-    def __call__(self, shape, dtype=None):
+    def __call__(self, shape, dtype=None, layout=None):
         return random.truncated_normal(
             shape=shape,
             mean=self.mean,
             stddev=self.stddev,
             seed=self.seed,
             dtype=dtype,
+            layout=layout,
         )
 
     def get_config(self):
@@ -183,13 +188,14 @@ def __init__(self, minval=-0.05, maxval=0.05, seed=None):
         self.maxval = maxval
         super().__init__(seed=seed)
 
-    def __call__(self, shape, dtype=None):
+    def __call__(self, shape, dtype=None, layout=None):
         return random.uniform(
             shape=shape,
             minval=self.minval,
             maxval=self.maxval,
             seed=self.seed,
             dtype=dtype,
+            layout=layout,
         )
 
     def get_config(self):
@@ -282,7 +288,7 @@ def __init__(
         self.distribution = distribution
         super().__init__(seed=seed)
 
-    def __call__(self, shape, dtype=None):
+    def __call__(self, shape, dtype=None, layout=None):
         scale = self.scale
         fan_in, fan_out = compute_fans(shape)
         if self.mode == "fan_in":
@@ -291,20 +297,36 @@ def __call__(self, shape, dtype=None):
             scale /= max(1.0, fan_out)
         else:
             scale /= max(1.0, (fan_in + fan_out) / 2.0)
+
         if self.distribution == "truncated_normal":
             stddev = math.sqrt(scale) / 0.87962566103423978
             return random.truncated_normal(
-                shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed
+                shape,
+                mean=0.0,
+                stddev=stddev,
+                dtype=dtype,
+                seed=self.seed,
+                layout=layout,
             )
         elif self.distribution == "untruncated_normal":
             stddev = math.sqrt(scale)
             return random.normal(
-                shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed
+                shape,
+                mean=0.0,
+                stddev=stddev,
+                dtype=dtype,
+                seed=self.seed,
+                layout=layout,
             )
         else:
             limit = math.sqrt(3.0 * scale)
             return random.uniform(
-                shape, minval=-limit, maxval=limit, dtype=dtype, seed=self.seed
+                shape,
+                minval=-limit,
+                maxval=limit,
+                dtype=dtype,
+                seed=self.seed,
+                layout=layout,
             )
 
     def get_config(self):
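With the pass-through above, every built-in random initializer can forward a layout to the backend RNG. A brief sketch: the explicit `layout=None` here only demonstrates the new keyword; in real use the value is supplied by `Variable._initialize_with_initializer` when a distribution is active.

```python
from keras import initializers

# VarianceScaling routes the layout to truncated_normal / normal / uniform
# depending on its `distribution` argument.
init = initializers.VarianceScaling(scale=2.0, mode="fan_in", seed=7)
kernel = init(shape=(256, 1024), dtype="float32", layout=None)

emb_init = initializers.RandomNormal(stddev=0.02, seed=7)
table = emb_init(shape=(50_000, 512), layout=None)
```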
diff --git a/keras/src/random/seed_generator.py b/keras/src/random/seed_generator.py
index dd2adbc13bbe..6c664b5587bb 100644
--- a/keras/src/random/seed_generator.py
+++ b/keras/src/random/seed_generator.py
@@ -134,9 +134,13 @@ def global_seed_generator():
             "```"
         )
     gen = global_state.get_global_attribute("global_seed_generator")
-    if gen is None:
+    global_seed = global_state.get_global_attribute("global_random_seed")
+    if gen is None and global_seed is None:
         gen = SeedGenerator()
         global_state.set_global_attribute("global_seed_generator", gen)
+    elif gen is None and global_seed is not None:
+        gen = SeedGenerator(global_seed)
+        global_state.set_global_attribute("global_seed_generator", gen)
     return gen
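A sketch of the new fallback in `global_seed_generator()`: if a `"global_random_seed"` attribute has been stored in global state (which utility sets it is outside this diff; it is set manually here purely for illustration), the lazily created global generator is seeded from it instead of starting unseeded.

```python
from keras.src.backend.common import global_state
from keras.src.random import seed_generator

global_state.set_global_attribute("global_random_seed", 1234)

gen = seed_generator.global_seed_generator()
print(gen.state.value)  # deterministic: derived from 1234 rather than random
```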