implement batch_normalization (#19543)
lkarthee authored Apr 18, 2024
1 parent fb6244b commit 47c032d
Showing 5 changed files with 29 additions and 7 deletions.
3 changes: 3 additions & 0 deletions keras/backend/exports.py
@@ -7,6 +7,9 @@
 elif backend.backend() == "jax":
     BackendVariable = backend.jax.core.Variable
     backend_name_scope = backend.common.name_scope.name_scope
+elif backend.backend() == "mlx":
+    BackendVariable = backend.mlx.core.Variable
+    backend_name_scope = backend.common.name_scope.name_scope
+elif backend.backend() == "torch":
 elif backend.backend() == "torch":
     BackendVariable = backend.torch.core.Variable
     backend_name_scope = backend.common.name_scope.name_scope
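
With this branch in place, the MLX backend is selected the same way as the
existing ones. Assuming the standard KERAS_BACKEND selection mechanism (not
shown in this diff), usage would look roughly like:

    import os

    # Must be set before keras is imported; "mlx" joins the existing
    # "tensorflow", "jax", and "torch" branches handled in exports.py.
    os.environ["KERAS_BACKEND"] = "mlx"

    import keras  # BackendVariable now resolves to backend.mlx.core.Variable
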
4 changes: 2 additions & 2 deletions keras/backend/mlx/core.py
@@ -1,12 +1,12 @@
 import mlx.core as mx
 import numpy as np
-import tree
+from keras.utils import tree
 
 from keras.backend.common import KerasVariable
 from keras.backend.common import standardize_dtype
 from keras.backend.common.keras_tensor import KerasTensor
 from keras.backend.common.stateless_scope import StatelessScope
-from keras.utils.nest import pack_sequence_as
+from keras.utils.tree import pack_sequence_as
 
 SUPPORTS_SPARSE_TENSORS = False
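
The import swap replaces the third-party tree package (dm-tree) with the
in-repo keras.utils.tree module. Its nested-structure helpers behave roughly
like this sketch, which assumes the module exposes flatten alongside
pack_sequence_as (structure and values are illustrative):

    from keras.utils.tree import flatten, pack_sequence_as

    structure = {"a": 1, "b": (2, 3)}
    flat = flatten(structure)  # [1, 2, 3]
    # Rebuild the original nesting around the transformed leaves.
    rebuilt = pack_sequence_as(structure, [v * 10 for v in flat])
    # {"a": 10, "b": (20, 30)}
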
19 changes: 16 additions & 3 deletions keras/backend/mlx/nn.py
@@ -304,9 +304,22 @@ def moments(x, axes, keepdims=False, synchronized=False):
 def batch_normalization(
     x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
 ):
-    raise NotImplementedError(
-        "MLX backend doesn't support batch normalization yet."
-    )
+    shape = [1] * len(x.shape)
+    shape[axis] = mean.shape[0]
+    mean = mx.reshape(mean, shape)
+    variance = mx.reshape(variance, shape)
+
+    inv = mx.rsqrt(variance + epsilon)
+    if scale is not None:
+        scale = mx.reshape(scale, shape)
+        inv = inv * scale
+
+    res = -mean * inv
+    if offset is not None:
+        offset = mx.reshape(offset, shape)
+        res = res + offset
+
+    return mx.add(x * inv, res)
 
 
 def ctc_loss(
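
The new implementation folds the textbook batch-norm formula
scale * (x - mean) / sqrt(variance + epsilon) + offset into a single
multiply-add per element: inv = scale / sqrt(variance + epsilon) and
res = offset - mean * inv, so the output is x * inv + res. A minimal NumPy
sketch of the equivalence (input values are illustrative):

    import numpy as np

    x = np.random.randn(4, 3).astype("float32")  # 4 samples, 3 features
    mean = x.mean(axis=0)
    variance = x.var(axis=0)
    scale = np.array([1.0, 2.0, 0.5], dtype="float32")
    offset = np.array([0.1, -0.2, 0.3], dtype="float32")
    epsilon = 1e-3

    # Textbook form: normalize, then apply the affine transform.
    reference = scale * (x - mean) / np.sqrt(variance + epsilon) + offset

    # Refactored form from the diff: one multiply and one add.
    inv = scale / np.sqrt(variance + epsilon)
    res = offset - mean * inv
    assert np.allclose(reference, x * inv + res, atol=1e-5)
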
6 changes: 6 additions & 0 deletions keras/backend/mlx/numpy.py
@@ -900,6 +900,12 @@ def divide(x1, x2):
     return mx.divide(x1, x2)
 
 
+def divide_no_nan(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return mx.where(x2 == 0, 0, mx.divide(x1, x2))
+
+
 def true_divide(x1, x2):
     return divide(x1, x2)
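
divide_no_nan returns 0 wherever the denominator is 0, instead of inf or nan,
matching the helper's behavior on the other backends. A NumPy analogue of the
same masking trick (values are illustrative):

    import numpy as np

    x1 = np.array([1.0, 2.0, 3.0])
    x2 = np.array([2.0, 0.0, 4.0])

    # np.where evaluates both branches, so silence the divide-by-zero
    # warning; the zero-denominator lanes are masked out anyway.
    with np.errstate(divide="ignore", invalid="ignore"):
        out = np.where(x2 == 0, 0.0, np.divide(x1, x2))

    print(out)  # [0.5  0.   0.75]
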
4 changes: 2 additions & 2 deletions keras/backend/mlx/trainer.py
@@ -1,6 +1,6 @@
 import mlx.core as mx
 import numpy as np
-import tree
+from keras.utils import tree
 
 from keras import backend
 from keras import callbacks as callbacks_module
@@ -141,7 +141,7 @@ def compute_loss_and_updates(
         # Note that this is needed for the regularization loss, which need
         # the latest value of train/non-trainable variables.
         loss = self.compute_loss(
-            x, y, y_pred, sample_weight, allow_empty=True
+            x, y, y_pred, sample_weight
         )
         if losses:
             loss += ops.sum(losses)
