5 changes: 1 addition & 4 deletions keras/src/backend/common/variables.py
@@ -414,10 +414,7 @@ def _initialize(self, value):
raise NotImplementedError

def _initialize_with_initializer(self, initializer):
value = self._convert_to_tensor(
initializer(self._shape, dtype=self._dtype)
)
self._initialize(value)
raise NotImplementedError

def _convert_to_tensor(self, value, dtype=None):
raise NotImplementedError
51 changes: 40 additions & 11 deletions keras/src/backend/jax/core.py
@@ -30,22 +30,51 @@ def __init__(self, *args, layout=None, **kwargs):
self._layout = layout
super().__init__(*args, **kwargs)

def _initialize(self, value):
# Note that variable.shape is needed by distribution_lib
self._shape = self._validate_shape(value.shape)
def set_tensor_layout(self):
# We can't import the keras/distribution/distribution_lib
# due to circular dependency.
distribution = global_state.get_global_attribute("distribution")
if self._layout is None and distribution is not None:
tensor_layout = distribution.get_variable_layout(self)
from keras.src.distribution import TensorLayout
if self._layout is None:
distribution = global_state.get_global_attribute("distribution")
if distribution is not None:
tensor_layout = distribution.get_variable_layout(self)
from keras.src.distribution import TensorLayout

if isinstance(tensor_layout, TensorLayout):
self._layout = tensor_layout.backend_layout
else:
self._layout = tensor_layout

if isinstance(tensor_layout, TensorLayout):
self._layout = tensor_layout.backend_layout
else:
self._layout = tensor_layout
def _initialize(self, value):
# Note that variable.shape is needed by distribution_lib
self._shape = self._validate_shape(value.shape)
self.set_tensor_layout()
self._direct_assign(value)

def check_distributed_init(self, initializer):
# Check if 'layout' parameter is supported in the initializer call
import inspect

sig = inspect.signature(initializer.__call__)
layout_supported = "layout" in sig.parameters
# Check if PartitionSpec has any non-None values
spec = getattr(self._layout, "spec", None)
partition_spec = spec if spec is not None else ()
is_partitioned = any(dim is not None for dim in partition_spec)
return layout_supported and is_partitioned

def _initialize_with_initializer(self, initializer):
self.set_tensor_layout()
# Use layout-aware initialization for distributed embeddings
if self.check_distributed_init(initializer):
value = self._convert_to_tensor(
initializer(self._shape, dtype=self._dtype, layout=self._layout)
)
else:
value = self._convert_to_tensor(
initializer(self._shape, dtype=self._dtype)
)
self._initialize(value)

def _direct_assign(self, value):
if self._layout is not None:
value = distribution_lib.distribute_variable(value, self._layout)
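
For reference, below is a minimal sketch (not part of this diff) of the initializer-side contract that `check_distributed_init` probes for: an initializer whose `__call__` accepts an optional `layout` keyword. The class name `ShardedRandomNormal` and its wiring are illustrative assumptions, not existing Keras API.

```python
# Hypothetical, illustrative initializer: the key point is the `layout`
# kwarg in __call__, which check_distributed_init detects via
# inspect.signature before _initialize_with_initializer passes the
# variable's layout through.
from functools import partial

import jax


class ShardedRandomNormal:
    def __init__(self, mean=0.0, stddev=0.05, seed=None):
        self.mean = mean
        self.stddev = stddev
        self.seed = seed

    def __call__(self, shape, dtype=None, layout=None):
        # Lazy import mirrors the backend's circular-import avoidance.
        from keras.src.backend.jax import distribution_lib

        # Bind shape/dtype via partial, as _distribute_initializer expects.
        init_func = partial(
            jax.random.normal, shape=shape, dtype=dtype or "float32"
        )
        return distribution_lib._distribute_initializer(
            init_func=init_func,
            mean=self.mean,
            stddev=self.stddev,
            seed=self.seed,
            layout=layout,
        )
```
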
221 changes: 221 additions & 0 deletions keras/src/backend/jax/distribution_lib.py
@@ -9,6 +9,110 @@
from keras.src.utils import rng_utils


def _distribute_initializer(
init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
):
"""
Distribution-aware initializer for JAX backend.
This function will create a Jax random array and
distribute it according to the current layout.
Args:
init_func: A functools.partial-wrapped object that takes the seed
as argument and returns a jax.Array. Must have shape and dtype
already bound via partial.
mean: Mean of distribution (applied to normal/truncated_normal).
stddev: Standard deviation of the distribution.
seed: JAX compatible seed array, if None use the Seed generator.
layout: TensorLayout for the distributed tensor.
Returns:
A distributed jax array.
Raises:
ValueError: If init_func or seed is None.
If init_func.func is not a supported random function.
Supported jax.random func: normal, truncated_normal, uniform
TypeError: If init_func is not a functools.partial object
or seed is not a Jax array.

"""
import warnings
from functools import partial

# Draw seed from the seed generator if seed is not a Jax Array
if seed is None or not isinstance(seed, jax.Array):
jax_compatible_seed = seed_generator.draw_seed(None)
# Convert to JAX PRNG key format (swap counter and seed value)
seed = jax_compatible_seed[::-1]

# Validate all required arguments
if init_func is None or init_func.func.__name__ not in [
"normal",
"truncated_normal",
"uniform",
]:
        raise ValueError(
            "init_func cannot be None, and only JAX-compatible random "
            f"initializers are supported, got: {init_func}. "
            "Supported jax.random funcs: normal, truncated_normal, uniform"
        )

# Ensure init_func is a partial
if not isinstance(init_func, partial):
        raise TypeError(
            "init_func must be a functools.partial object, got "
            f"{type(init_func)}. Expected a jax.random.* function with "
            "shape and dtype bound via functools.partial."
        )

# Shard based on tensor layout
if layout is None:
        warnings.warn(
            "layout is None; sharding will default to a single device"
        )

sharding = None
else:
if not isinstance(layout, jax.sharding.NamedSharding):
from keras.src.distribution import TensorLayout

if isinstance(layout, TensorLayout):
layout = _to_backend_layout(layout)
else:
raise TypeError(
f"layout must be Keras TensorLayout or "
f"jax.sharding.NamedSharding, got {type(layout)}"
)
sharding = layout

# JAX PRNG key handling within JIT:
# The key is passed directly to jax.random.* functions which are
# JIT-compatible and functional. JAX automatically ensures different
# random values per shard when out_shardings is specified.
try:
compiled_init = jax.jit(
lambda seed: init_func(seed),
out_shardings=sharding,
)
sample = compiled_init(seed)

except RuntimeError as e:
warnings.warn(
f"Sharding at initialization failed due to: {e}, "
f"falling back to single device"
)
compiled_init = jax.jit(
lambda seed: init_func(seed),
out_shardings=None,
)
sample = compiled_init(seed)

# Apply mean/stddev only for distributions where it makes sense
if init_func.func in (jax.random.normal, jax.random.truncated_normal):
return sample * stddev + mean
elif init_func.func == jax.random.uniform:
return sample
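
A minimal usage sketch for the function above (assumes 8 local devices arranged as a 4x2 "batch"/"model" mesh; shapes, axis names, and the sharding spec are illustrative):

```python
from functools import partial

import jax
import numpy as np

from keras.src.backend.jax.distribution_lib import _distribute_initializer

# Build a backend sharding directly; a Keras TensorLayout is also accepted.
devices = np.array(jax.devices()).reshape(4, 2)  # assumes 8 devices
mesh = jax.sharding.Mesh(devices, axis_names=("batch", "model"))
layout = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec(None, "model")
)

# Shape and dtype are bound via partial; the PRNG key is supplied later.
init_func = partial(jax.random.normal, shape=(1024, 512), dtype="float32")

value = _distribute_initializer(
    init_func=init_func,
    mean=0.0,
    stddev=0.02,
    seed=jax.random.PRNGKey(0),
    layout=layout,
)
# `value` is a jax.Array whose second dimension is sharded over "model".
```
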


def list_devices(device_type=None):
"""Return all the available devices based on the device type.

@@ -260,3 +364,120 @@ def _to_backend_layout(tensor_layout):
partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes)
jax_mesh = tensor_layout.device_mesh.backend_mesh
return jax.sharding.NamedSharding(jax_mesh, partition_spec)
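
For context, a small sketch of how a Keras `TensorLayout` maps to a backend sharding through `_to_backend_layout` (assumes a 1-D "model" mesh over all local devices; names and shapes are illustrative):

```python
import jax

from keras.src.backend.jax.distribution_lib import _to_backend_layout
from keras.src.distribution import DeviceMesh, TensorLayout

# A 1-D device mesh over all local devices, with a single "model" axis.
mesh = DeviceMesh(shape=(len(jax.devices()),), axis_names=("model",))
# Shard the second dimension of a (vocab, hidden) table over "model".
layout = TensorLayout(axes=(None, "model"), device_mesh=mesh)

sharding = _to_backend_layout(layout)
assert isinstance(sharding, jax.sharding.NamedSharding)
```
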


def _distribute_initializer(
init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
):
"""
Distribution-aware token embedding initializer for JAX backend.

This function will create a Jax random array and
distribute it according to the current token embedding layout.

Args:
init_func: A functools.partial-wrapped object that takes the seed
as argument and returns a jax.Array. Must have shape and dtype
already bound via partial.
mean: Mean of distribution (applied to normal/truncated_normal).
stddev: Standard deviation of the distribution.
        seed: Random seed for initialization; an int, a jax.Array key, or
            a keras SeedGenerator.
layout: TensorLayout for the distributed tensor.

Returns:
A distributed jax array.

Raises:
        ValueError: If init_func is None, if init_func.func is not a
            supported random function (supported jax.random funcs:
            normal, truncated_normal, uniform), or if seed is not an
            int, jax.Array, or SeedGenerator.
        TypeError: If init_func is not a functools.partial object.
"""
import warnings
from functools import partial

    # Create a SeedGenerator so the backend seed-state variable exists.
    # For future state tracking of distributed keys, attributes for
    # base/split keys and the number of device shards could be added here.
if isinstance(seed, jax.Array):
seed_gen = seed_generator.SeedGenerator(seed=int(seed[0]))
elif isinstance(seed, int):
seed_gen = seed_generator.SeedGenerator(seed=seed)
elif isinstance(seed, seed_generator.SeedGenerator):
seed_gen = seed
else:
raise ValueError(
f"seed must be int, JAX array, or SeedGenerator, got {type(seed)}"
)

# Extract the state value as JAX array
jax_seed = seed_gen.state.value

# Convert to JAX PRNG key format (swap counter and seed value)
jax_compatible_seed = jax.numpy.array(
[jax_seed[1], jax_seed[0]], dtype=jax.numpy.uint32
)

# Validate all required arguments
if init_func is None or init_func.func.__name__ not in [
"normal",
"truncated_normal",
"uniform",
]:
        raise ValueError(
            "init_func cannot be None, and only JAX-compatible random "
            f"initializers are supported, got: {init_func}. "
            "Supported jax.random funcs: normal, truncated_normal, uniform"
        )

# Ensure init_func is a partial
if not isinstance(init_func, partial):
        raise TypeError(
            "init_func must be a functools.partial object, got "
            f"{type(init_func)}. Expected a jax.random.* function with "
            "shape and dtype bound via functools.partial."
        )

# Shard based on tensor layout
if layout is None:
        warnings.warn(
            "layout is None; sharding will default to a single device"
        )
sharding = None
else:
sharding = _to_backend_layout(layout)

# JAX PRNG key handling within JIT:
# The key is passed directly to jax.random.* functions which are
# JIT-compatible and functional. JAX automatically ensures different
# random values per shard when out_shardings is specified.
try:
compiled_init = jax.jit(
lambda jax_compatible_seed: init_func(jax_compatible_seed),
out_shardings=sharding,
)
sample = compiled_init(jax_compatible_seed)
except RuntimeError as e:
warnings.warn(
f"Sharding failed due to: {e}, falling back to single device"
)
compiled_init = jax.jit(
lambda jax_compatible_seed: init_func(jax_compatible_seed),
out_shardings=None,
)
sample = compiled_init(jax_compatible_seed)

    # Advance the SeedGenerator so subsequent initializations draw new keys
    seed = seed_gen.next()

# Apply mean/stddev only for distributions where it makes sense
if init_func.func in (jax.random.normal, jax.random.truncated_normal):
return sample * stddev + mean
elif init_func.func == jax.random.uniform:
# Uniform doesn't use mean/stddev - warn
if mean != 0.0 or stddev != 1.0:
warnings.warn(
"mean and stddev are ignored for uniform distribution"
)
return sample
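
A minimal usage sketch for this SeedGenerator-based variant (the mesh construction mirrors the earlier sketch; vocabulary and embedding sizes are illustrative):

```python
from functools import partial

import jax

from keras.src.backend.jax.distribution_lib import _distribute_initializer
from keras.src.distribution import DeviceMesh, TensorLayout
from keras.src.random import seed_generator

mesh = DeviceMesh(shape=(len(jax.devices()),), axis_names=("model",))
layout = TensorLayout(axes=(None, "model"), device_mesh=mesh)

# truncated_normal also needs its lower/upper bounds bound via partial.
init_func = partial(
    jax.random.truncated_normal,
    lower=-2.0,
    upper=2.0,
    shape=(30000, 256),
    dtype="float32",
)

value = _distribute_initializer(
    init_func=init_func,
    mean=0.0,
    stddev=0.05,
    seed=seed_generator.SeedGenerator(seed=42),
    layout=layout,
)
```
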