Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
from keras.src.ops.numpy import dot as dot
from keras.src.ops.numpy import einsum as einsum
from keras.src.ops.numpy import empty as empty
from keras.src.ops.numpy import empty_like as empty_like
from keras.src.ops.numpy import equal as equal
from keras.src.ops.numpy import exp as exp
from keras.src.ops.numpy import exp2 as exp2
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from keras.src.ops.numpy import dot as dot
from keras.src.ops.numpy import einsum as einsum
from keras.src.ops.numpy import empty as empty
from keras.src.ops.numpy import empty_like as empty_like
from keras.src.ops.numpy import equal as equal
from keras.src.ops.numpy import exp as exp
from keras.src.ops.numpy import exp2 as exp2
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
from keras.src.ops.numpy import dot as dot
from keras.src.ops.numpy import einsum as einsum
from keras.src.ops.numpy import empty as empty
from keras.src.ops.numpy import empty_like as empty_like
from keras.src.ops.numpy import equal as equal
from keras.src.ops.numpy import exp as exp
from keras.src.ops.numpy import exp2 as exp2
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from keras.src.ops.numpy import dot as dot
from keras.src.ops.numpy import einsum as einsum
from keras.src.ops.numpy import empty as empty
from keras.src.ops.numpy import empty_like as empty_like
from keras.src.ops.numpy import equal as equal
from keras.src.ops.numpy import exp as exp
from keras.src.ops.numpy import exp2 as exp2
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,10 @@ def empty(shape, dtype=None):
return jnp.empty(shape, dtype=dtype)


def empty_like(x, dtype=None):
return jnp.empty_like(x, dtype=dtype)


def equal(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,10 @@ def empty(shape, dtype=None):
return np.empty(shape, dtype=dtype)


def empty_like(x, dtype=None):
return np.empty_like(x, dtype=dtype)


def equal(x1, x2):
return np.equal(x1, x2)

Expand Down
2 changes: 2 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ NumpyDtypeTest::test_argpartition
NumpyDtypeTest::test_array
NumpyDtypeTest::test_bartlett
NumpyDtypeTest::test_blackman
NumpyDtypeTest::test_empty_like
NumpyDtypeTest::test_gcd
NumpyDtypeTest::test_hamming
NumpyDtypeTest::test_hanning
Expand Down Expand Up @@ -113,6 +114,7 @@ NumpyTwoInputOpsCorrectnessTest::test_vdot
NumpyOneInputOpsDynamicShapeTest::test_angle
NumpyOneInputOpsDynamicShapeTest::test_bartlett
NumpyOneInputOpsDynamicShapeTest::test_blackman
NumpyOneInputOpsDynamicShapeTest::test_empty_like
NumpyOneInputOpsDynamicShapeTest::test_cbrt
NumpyOneInputOpsDynamicShapeTest::test_corrcoef
NumpyOneInputOpsDynamicShapeTest::test_hamming
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,12 @@ def empty(shape, dtype=None):
return OpenVINOKerasTensor(empty_tensor)


def empty_like(x, dtype=None):
raise NotImplementedError(
"`empty_like` is not supported with openvino backend"
)


def equal(x1, x2):
element_type = None
if isinstance(x1, OpenVINOKerasTensor):
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,10 @@ def empty(shape, dtype=None):
return tf.zeros(shape, dtype=dtype)


def empty_like(x, dtype=None):
return tf.zeros_like(x, dtype=dtype)


def equal(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,12 @@ def empty(shape, dtype=None):
return torch.empty(size=shape, dtype=dtype, device=get_device())


def empty_like(x, dtype=None):
x = convert_to_tensor(x)
dtype = to_torch_dtype(dtype or x.dtype)
return torch.empty_like(x, dtype=dtype, device=get_device())


def equal(x1, x2):
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
return torch.eq(x1, x2)
Expand Down
42 changes: 42 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3114,6 +3114,48 @@ def empty(shape, dtype=None):
return backend.numpy.empty(shape, dtype=dtype)


class EmptyLike(Operation):
def __init__(self, dtype=None, *, name=None):
super().__init__(name=name)
self.dtype = None if dtype is None else backend.standardize_dtype(dtype)

def call(self, x):
return backend.numpy.empty_like(x, dtype=self.dtype)

def compute_output_spec(self, x):
dtype = (
backend.standardize_dtype(x.dtype)
if self.dtype is None
else self.dtype
)
return KerasTensor(x.shape, dtype=dtype)


@keras_export(["keras.ops.empty_like", "keras.ops.numpy.empty_like"])
def empty_like(x, dtype=None):
"""Return a new uninitialized tensor with the same shape and dtype as `x`.

Args:
x: Input tensor to mimic shape and dtype.
dtype: Optional data type. If None, uses `x.dtype`.

Returns:
A tensor with the same shape and dtype as `x`, with arbitrary contents.

Example:
>>> from keras import ops
>>> x = ops.ones((2, 3), dtype="float32")
>>> y = ops.empty_like(x)
>>> y.shape
(2, 3)
>>> y.dtype
dtype('float32')
"""
if any_symbolic_tensors((x,)):
return EmptyLike(dtype=dtype).symbolic_call(x)
return backend.numpy.empty_like(x, dtype=dtype)


class Equal(Operation):
def call(self, x1, x2):
return backend.numpy.equal(x1, x2)
Expand Down
26 changes: 26 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,11 @@ def test_dot(self):
y = KerasTensor((5,))
self.assertEqual(knp.dot(x, y).shape, ())

def test_empty_like(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.empty_like(x).shape, (None, 3))
self.assertEqual(knp.empty_like(x).dtype, x.dtype)

def test_exp(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.exp(x).shape, (None, 3))
Expand Down Expand Up @@ -2139,6 +2144,11 @@ def test_dot(self):
y = KerasTensor((2, 3))
knp.dot(x, y)

def test_empty_like(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.empty_like(x).shape, (2, 3))
self.assertEqual(knp.empty_like(x).dtype, x.dtype)

def test_exp(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.exp(x).shape, (2, 3))
Expand Down Expand Up @@ -7328,6 +7338,22 @@ def test_empty(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_empty_like(self, dtype):
import jax.numpy as jnp

x = jnp.empty([2, 3, 4], dtype=dtype)
expected_dtype = standardize_dtype(jnp.empty_like(x, dtype=dtype).dtype)

self.assertEqual(
standardize_dtype(knp.empty_like(x, dtype=dtype).dtype),
expected_dtype,
)
self.assertEqual(
standardize_dtype(knp.EmptyLike().symbolic_call(x).dtype),
expected_dtype,
)

@parameterized.named_parameters(
named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
)
Expand Down