
merge master #20741

Merged: 8 commits, Jan 9, 2025
8 changes: 6 additions & 2 deletions README.md
@@ -1,6 +1,6 @@
# Keras 3: Deep Learning for Humans

Keras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, and PyTorch.
Keras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, PyTorch, and OpenVINO (inference-only).
Effortlessly build and train models for computer vision, natural language processing, audio processing,
timeseries forecasting, recommender systems, etc.

@@ -73,7 +73,7 @@ python pip_build.py --install
## Configuring your backend

You can export the environment variable `KERAS_BACKEND` or you can edit your local config file at `~/.keras/keras.json`
to configure your backend. Available backend options are: `"tensorflow"`, `"jax"`, `"torch"`. Example:
to configure your backend. Available backend options are: `"tensorflow"`, `"jax"`, `"torch"`, `"openvino"`. Example:

```
export KERAS_BACKEND="jax"
@@ -91,6 +91,10 @@ import keras
**Note:** The backend must be configured before importing `keras`, and the backend cannot be changed after
the package has been imported.

**Note:** The OpenVINO backend is an inference-only backend, meaning it is designed only for running model
predictions with the `model.predict()` method.
To use the `openvino` backend, install the required dependencies from the `requirements-openvino.txt` file.

## Backwards compatibility

Keras 3 is intended to work as a drop-in replacement for `tf.keras` (when using the TensorFlow backend). Just take your
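To make the inference-only note above concrete, here is a minimal, hypothetical sketch; the model file name and input shape are placeholders, and it assumes the packages from `requirements-openvino.txt` are installed:

```
import os

# The backend must be selected before keras is imported.
os.environ["KERAS_BACKEND"] = "openvino"

import numpy as np
import keras

# Hypothetical model file, saved earlier with any backend. Per the note above,
# only model.predict() is supported on this backend.
model = keras.saving.load_model("model.keras")
x = np.random.rand(1, 224, 224, 3).astype("float32")  # shape assumed to match the model
preds = model.predict(x)
```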
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
@@ -152,6 +152,9 @@
MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
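For context, a brief sketch of what the newly exported `RandAugment` layer enables once this import is wired up (default constructor arguments are assumed here; see the layer docstring for its exact signature):

```
import numpy as np
import keras

# RandAugment becomes importable from keras.layers via the export added here.
augment = keras.layers.RandAugment()  # defaults assumed

# A batch of 4 random images with values in [0, 255].
images = np.random.randint(0, 256, size=(4, 32, 32, 3)).astype("float32")
augmented = augment(images, training=True)
print(augmented.shape)  # (4, 32, 32, 3) -- augmentation preserves the shape
```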
3 changes: 3 additions & 0 deletions keras/api/layers/__init__.py
@@ -152,6 +152,9 @@
MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
12 changes: 8 additions & 4 deletions keras/src/backend/jax/distribution_lib.py
@@ -100,7 +100,7 @@ def distribute_tensor(tensor, layout):
return global_value


def distribute_data_input(per_process_batch, layout):
def distribute_data_input(per_process_batch, layout, batch_dim_name):
"""Distribute the input data with the corresponding layout.

Note that the input here is a local worker batch. Within the local worker,
@@ -117,9 +117,13 @@ def distribute_data_input(per_process_batch, layout):
if not isinstance(layout, jax.sharding.Sharding):
layout = _to_jax_layout(layout)

mesh_shape = list(layout.mesh.shape.values())
num_model_replicas_total = mesh_shape[0] # batch dimension of the mesh
mesh_model_dim_size = mesh_shape[1] if len(mesh_shape) > 1 else 1
num_model_replicas_total = layout.mesh.shape[batch_dim_name]

mesh_model_dim_size = 1
for name, dim_size in layout.mesh.shape.items():
if not name == batch_dim_name:
mesh_model_dim_size *= dim_size

num_model_replicas_per_process = num_model_replicas_total / num_processes()
per_process_batch_size = per_process_batch.shape[0]

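The replica-count arithmetic above is easy to check by hand. A toy, pure-Python sketch (the mesh axis sizes are made up; in the real code they come from `layout.mesh.shape`):

```
# Hypothetical 8-device mesh: 4-way data parallel x 2-way model parallel.
mesh_shape = {"batch": 4, "model": 2}  # stands in for layout.mesh.shape
batch_dim_name = "batch"

# The number of model replicas is the size of the named batch axis...
num_model_replicas_total = mesh_shape[batch_dim_name]  # 4

# ...and the model-parallel size is the product of every other axis, rather
# than assuming the batch axis is always the first mesh dimension.
mesh_model_dim_size = 1
for name, dim_size in mesh_shape.items():
    if name != batch_dim_name:
        mesh_model_dim_size *= dim_size  # 2

print(num_model_replicas_total, mesh_model_dim_size)  # 4 2
```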
4 changes: 3 additions & 1 deletion keras/src/backend/jax/distribution_lib_test.py
@@ -337,7 +337,9 @@ def test_distribute_data_input(self):
mesh, jax.sharding.PartitionSpec("batch", None)
)

result = backend_dlib.distribute_data_input(per_process_batch, layout)
result = backend_dlib.distribute_data_input(
per_process_batch, layout, "batch"
)

# Check the shape of the global batch array
self.assertEqual(
2 changes: 1 addition & 1 deletion keras/src/backend/jax/export.py
@@ -119,7 +119,7 @@ def stateful_fn(*args, **kwargs):
self._tf_trackable.non_trainable_variables,
non_trainable_variables,
):
var.assign(new_value)
var.assign(tf.cast(new_value, var.dtype))
return output

stateful_fn.__signature__ = inspect.Signature(
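The one-line change above guards against dtype mismatches when restoring non-trainable variables during export. A small, self-contained TensorFlow sketch of the pattern (the values and shapes are made up):

```
import tensorflow as tf

var = tf.Variable(tf.zeros([2], dtype=tf.float32))
new_value = tf.constant([1.0, 2.0], dtype=tf.float64)  # e.g. produced at a higher precision

# A direct var.assign(new_value) can fail when the dtypes differ; casting to
# the variable's own dtype first keeps the assignment safe.
var.assign(tf.cast(new_value, var.dtype))
print(var.dtype, var.numpy())  # float32, [1. 2.]
```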
8 changes: 6 additions & 2 deletions keras/src/backend/jax/trainer.py
@@ -1,5 +1,6 @@
import collections
import itertools
from functools import partial

import jax
import numpy as np
@@ -988,15 +989,18 @@ def _get_jax_state(

def _distribute_data(data, layouts=None):
distribution = distribution_lib.distribution()

if distribution is not None:
if layouts is None:
layouts = tree.map_structure(
lambda d: distribution.get_data_layout(d.shape),
data,
)
return tree.map_structure(
jax_distribution_lib.distribute_data_input, data, layouts
jax_dist_data_input = partial(
jax_distribution_lib.distribute_data_input,
batch_dim_name=distribution.batch_dim_name,
)
return tree.map_structure(jax_dist_data_input, data, layouts)

return tree.map_structure(jax.device_put, data)

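The trainer change binds the new `batch_dim_name` keyword up front so that `tree.map_structure` can keep mapping a two-argument callable over `data` and `layouts`. A toy sketch of that binding pattern (the function and arguments below are stand-ins, not the real Keras call):

```
from functools import partial

def distribute_data_input(per_process_batch, layout, batch_dim_name):
    # Stand-in for jax_distribution_lib.distribute_data_input.
    return (per_process_batch, layout, batch_dim_name)

# Bind the keyword once...
bound = partial(distribute_data_input, batch_dim_name="batch")

# ...so the callable again takes exactly (batch, layout), as map_structure expects.
print(bound("local_batch", "layout_spec"))  # ('local_batch', 'layout_spec', 'batch')
```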
20 changes: 12 additions & 8 deletions keras/src/distribution/distribution_lib.py
@@ -287,8 +287,9 @@ class Distribution:
device_mesh: A `DeviceMesh` instance.
"""

def __init__(self, device_mesh):
def __init__(self, device_mesh, batch_dim_name=None):
self._device_mesh = device_mesh
self._batch_dim_name = batch_dim_name

def get_data_layout(self, data_shape):
"""Retrieve the `TensorLayout` for the input data.
@@ -341,6 +342,10 @@ def scope(self):
def device_mesh(self):
return self._device_mesh

@property
def batch_dim_name(self):
return self._batch_dim_name

def distribute_dataset(self, dataset):
"""Create a distributed dataset instance from the original user dataset.

@@ -395,7 +400,6 @@ def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True):
else:
self._initialize_mesh_from_list_devices()

self._batch_dim_name = self.device_mesh.axis_names[0]
# Those following attributes might get convert to public methods.
self._num_process = distribution_lib.num_processes()
self._process_id = distribution_lib.process_id()
@@ -408,7 +412,7 @@ def _initialize_with_device_mesh(self, device_mesh):
"Expect `mesh` to be an instance of `DeviceMesh`. "
f"Received: mesh={device_mesh} (of type {type(device_mesh)})"
)
super().__init__(device_mesh)
super().__init__(device_mesh, device_mesh.axis_names[0])
if self.device_mesh.devices.ndim != 1:
warnings.warn(
"Expect the input mesh to be 1D, but received "
@@ -424,7 +428,7 @@ def _initialize_mesh_from_devices(self, devices):
axis_names=[DEFAULT_BATCH_DIM_NAME],
devices=devices,
)
super().__init__(device_mesh)
super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)

def _initialize_mesh_from_list_devices(self):
devices = np.array(list_devices())
@@ -433,11 +437,11 @@ def _initialize_mesh_from_list_devices(self):
axis_names=[DEFAULT_BATCH_DIM_NAME],
devices=devices,
)
super().__init__(device_mesh)
super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)

def get_data_layout(self, data_shape):
data_shard_spec = [None] * len(data_shape)
data_shard_spec[0] = self._batch_dim_name # Shard on the first dim
data_shard_spec[0] = self.batch_dim_name # Shard on the first dim
return TensorLayout(data_shard_spec, self.device_mesh)

def get_variable_layout(self, variable):
@@ -590,7 +594,7 @@ def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):

def get_data_layout(self, data_shape):
data_shard_spec = [None] * len(data_shape)
data_shard_spec[0] = self._batch_dim_name # Shard on the first dim
data_shard_spec[0] = self.batch_dim_name # Shard on the first dim
return TensorLayout(data_shard_spec, self.device_mesh)

def get_variable_layout(self, variable):
@@ -631,7 +635,7 @@ def distribute_dataset(self, dataset):
# Note that this might be smaller than one if model replicas are sharded
# across multiple processes.
mesh_batch_dim_index = self.device_mesh.axis_names.index(
self._batch_dim_name
self.batch_dim_name
)
num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index]
if num_model_replicas == 1:
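With `batch_dim_name` now stored on the base `Distribution` and exposed as a property, callers no longer need to reach for the private `_batch_dim_name` attribute. A brief sketch (assumes the JAX backend, where `keras.distribution` is currently implemented, and at least one visible device):

```
import keras

# DataParallel builds a 1D mesh over the available devices and now passes the
# batch axis name up to Distribution.__init__.
dist = keras.distribution.DataParallel()
print(dist.batch_dim_name)  # "batch" (the default batch axis name)
```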
6 changes: 3 additions & 3 deletions keras/src/distribution/distribution_lib_test.py
@@ -186,7 +186,7 @@ def test_create_with_device_mesh(self):
device_mesh = distribution.device_mesh
self.assertEqual(len(device_mesh.devices), 8)
self.assertEqual(device_mesh.axis_names, ["data"])
self.assertEqual(distribution._batch_dim_name, "data")
self.assertEqual(distribution.batch_dim_name, "data")

self.assertFalse(distribution._is_multi_process)
self.assertEqual(distribution._process_id, 0)
Expand All @@ -197,7 +197,7 @@ def test_create_with_devices(self):
device_mesh = distribution.device_mesh
self.assertEqual(len(device_mesh.devices), 8)
self.assertEqual(device_mesh.axis_names, ["batch"])
self.assertEqual(distribution._batch_dim_name, "batch")
self.assertEqual(distribution.batch_dim_name, "batch")

@mock.patch.object(
distribution_lib,
@@ -211,7 +211,7 @@ def test_create_with_list_devices(self, mock_list_devices):
device_mesh = distribution.device_mesh
self.assertEqual(len(device_mesh.devices), 8)
self.assertEqual(device_mesh.axis_names, ["batch"])
self.assertEqual(distribution._batch_dim_name, "batch")
self.assertEqual(distribution.batch_dim_name, "batch")

def test_get_data_layout(self):
distribution = distribution_lib.DataParallel(
3 changes: 3 additions & 0 deletions keras/src/layers/__init__.py
@@ -96,6 +96,9 @@
MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
10 changes: 4 additions & 6 deletions keras/src/layers/layer.py
@@ -85,12 +85,10 @@ class Layer(BackendLayer, Operation, KerasSaveable):
trainable: Boolean, whether the layer's variables should be trainable.
name: String name of the layer.
dtype: The dtype of the layer's computations and weights. Can also be a
`keras.DTypePolicy`,
which allows the computation and
weight dtype to differ. Defaults to `None`. `None` means to use
`keras.config.dtype_policy()`,
which is a `float32` policy unless set to different value
(via `keras.config.set_dtype_policy()`).
`keras.DTypePolicy`, which allows the computation and weight dtype
to differ. Defaults to `None`. `None` means to use
`keras.config.dtype_policy()`, which is a `float32` policy unless
set to a different value (via `keras.config.set_dtype_policy()`).

Attributes:
name: The name of the layer (string).
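The reflowed docstring above describes dtype policies; a short illustration of the behaviour it documents (the layer type and sizes are arbitrary):

```
import keras

# With a mixed policy, weights stay float32 while computations run in float16.
layer = keras.layers.Dense(4, dtype="mixed_float16")
print(layer.dtype_policy.name)  # "mixed_float16"
print(layer.compute_dtype)      # "float16"
print(layer.variable_dtype)     # "float32"
```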