diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 15a0d67a422..e22715971d6 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -179,6 +179,7 @@ from keras.src.ops.numpy import dot as dot from keras.src.ops.numpy import einsum as einsum from keras.src.ops.numpy import empty as empty +from keras.src.ops.numpy import empty_like as empty_like from keras.src.ops.numpy import equal as equal from keras.src.ops.numpy import exp as exp from keras.src.ops.numpy import exp2 as exp2 diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py index 9a1d473cac0..82b6b6dff36 100644 --- a/keras/api/_tf_keras/keras/ops/numpy/__init__.py +++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py @@ -65,6 +65,7 @@ from keras.src.ops.numpy import dot as dot from keras.src.ops.numpy import einsum as einsum from keras.src.ops.numpy import empty as empty +from keras.src.ops.numpy import empty_like as empty_like from keras.src.ops.numpy import equal as equal from keras.src.ops.numpy import exp as exp from keras.src.ops.numpy import exp2 as exp2 diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 15a0d67a422..e22715971d6 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -179,6 +179,7 @@ from keras.src.ops.numpy import dot as dot from keras.src.ops.numpy import einsum as einsum from keras.src.ops.numpy import empty as empty +from keras.src.ops.numpy import empty_like as empty_like from keras.src.ops.numpy import equal as equal from keras.src.ops.numpy import exp as exp from keras.src.ops.numpy import exp2 as exp2 diff --git a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py index 9a1d473cac0..82b6b6dff36 100644 --- a/keras/api/ops/numpy/__init__.py +++ b/keras/api/ops/numpy/__init__.py @@ -65,6 +65,7 @@ from keras.src.ops.numpy import dot as dot from keras.src.ops.numpy import einsum as einsum from keras.src.ops.numpy import empty as empty +from keras.src.ops.numpy import empty_like as empty_like from keras.src.ops.numpy import equal as equal from keras.src.ops.numpy import exp as exp from keras.src.ops.numpy import exp2 as exp2 diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py index 6162f98f07e..15df1b7696f 100644 --- a/keras/src/backend/jax/numpy.py +++ b/keras/src/backend/jax/numpy.py @@ -678,6 +678,10 @@ def empty(shape, dtype=None): return jnp.empty(shape, dtype=dtype) +def empty_like(x, dtype=None): + return jnp.empty_like(x, dtype=dtype) + + def equal(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py index e5f4284b3db..eb8b7111ed8 100644 --- a/keras/src/backend/numpy/numpy.py +++ b/keras/src/backend/numpy/numpy.py @@ -612,6 +612,10 @@ def empty(shape, dtype=None): return np.empty(shape, dtype=dtype) +def empty_like(x, dtype=None): + return np.empty_like(x, dtype=dtype) + + def equal(x1, x2): return np.equal(x1, x2) diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index e14d190b829..7c5d295a44d 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -5,6 +5,7 @@ NumpyDtypeTest::test_argpartition NumpyDtypeTest::test_array NumpyDtypeTest::test_bartlett NumpyDtypeTest::test_blackman +NumpyDtypeTest::test_empty_like NumpyDtypeTest::test_gcd NumpyDtypeTest::test_hamming NumpyDtypeTest::test_hanning @@ -113,6 +114,7 @@ NumpyTwoInputOpsCorrectnessTest::test_vdot NumpyOneInputOpsDynamicShapeTest::test_angle NumpyOneInputOpsDynamicShapeTest::test_bartlett NumpyOneInputOpsDynamicShapeTest::test_blackman +NumpyOneInputOpsDynamicShapeTest::test_empty_like NumpyOneInputOpsDynamicShapeTest::test_cbrt NumpyOneInputOpsDynamicShapeTest::test_corrcoef NumpyOneInputOpsDynamicShapeTest::test_hamming diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 2ef9dc7bdda..445ddb7b1fd 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -896,6 +896,12 @@ def empty(shape, dtype=None): return OpenVINOKerasTensor(empty_tensor) +def empty_like(x, dtype=None): + raise NotImplementedError( + "`empty_like` is not supported with openvino backend" + ) + + def equal(x1, x2): element_type = None if isinstance(x1, OpenVINOKerasTensor): diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index 523389bb024..c1f4e8066e3 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -1489,6 +1489,10 @@ def empty(shape, dtype=None): return tf.zeros(shape, dtype=dtype) +def empty_like(x, dtype=None): + return tf.zeros_like(x, dtype=dtype) + + def equal(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py index cfd844f24b6..bd2a5cea2ac 100644 --- a/keras/src/backend/torch/numpy.py +++ b/keras/src/backend/torch/numpy.py @@ -770,6 +770,12 @@ def empty(shape, dtype=None): return torch.empty(size=shape, dtype=dtype, device=get_device()) +def empty_like(x, dtype=None): + x = convert_to_tensor(x) + dtype = to_torch_dtype(dtype or x.dtype) + return torch.empty_like(x, dtype=dtype, device=get_device()) + + def equal(x1, x2): x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) return torch.eq(x1, x2) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index 3abff5d93b6..5190ff2cd80 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -3114,6 +3114,48 @@ def empty(shape, dtype=None): return backend.numpy.empty(shape, dtype=dtype) +class EmptyLike(Operation): + def __init__(self, dtype=None, *, name=None): + super().__init__(name=name) + self.dtype = None if dtype is None else backend.standardize_dtype(dtype) + + def call(self, x): + return backend.numpy.empty_like(x, dtype=self.dtype) + + def compute_output_spec(self, x): + dtype = ( + backend.standardize_dtype(x.dtype) + if self.dtype is None + else self.dtype + ) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.empty_like", "keras.ops.numpy.empty_like"]) +def empty_like(x, dtype=None): + """Return a new uninitialized tensor with the same shape and dtype as `x`. + + Args: + x: Input tensor to mimic shape and dtype. + dtype: Optional data type. If None, uses `x.dtype`. + + Returns: + A tensor with the same shape and dtype as `x`, with arbitrary contents. + + Example: + >>> from keras import ops + >>> x = ops.ones((2, 3), dtype="float32") + >>> y = ops.empty_like(x) + >>> y.shape + (2, 3) + >>> y.dtype + dtype('float32') + """ + if any_symbolic_tensors((x,)): + return EmptyLike(dtype=dtype).symbolic_call(x) + return backend.numpy.empty_like(x, dtype=dtype) + + class Equal(Operation): def call(self, x1, x2): return backend.numpy.equal(x1, x2) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index d6ff3a9456c..42a8c37b49e 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -1495,6 +1495,11 @@ def test_dot(self): y = KerasTensor((5,)) self.assertEqual(knp.dot(x, y).shape, ()) + def test_empty_like(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.empty_like(x).shape, (None, 3)) + self.assertEqual(knp.empty_like(x).dtype, x.dtype) + def test_exp(self): x = KerasTensor((None, 3)) self.assertEqual(knp.exp(x).shape, (None, 3)) @@ -2139,6 +2144,11 @@ def test_dot(self): y = KerasTensor((2, 3)) knp.dot(x, y) + def test_empty_like(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.empty_like(x).shape, (2, 3)) + self.assertEqual(knp.empty_like(x).dtype, x.dtype) + def test_exp(self): x = KerasTensor((2, 3)) self.assertEqual(knp.exp(x).shape, (2, 3)) @@ -7328,6 +7338,22 @@ def test_empty(self, dtype): expected_dtype, ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_empty_like(self, dtype): + import jax.numpy as jnp + + x = jnp.empty([2, 3, 4], dtype=dtype) + expected_dtype = standardize_dtype(jnp.empty_like(x, dtype=dtype).dtype) + + self.assertEqual( + standardize_dtype(knp.empty_like(x, dtype=dtype).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.EmptyLike().symbolic_call(x).dtype), + expected_dtype, + ) + @parameterized.named_parameters( named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) )