Skip to content

Commit bfde12b

Browse files
authored
Implement empty_like function in keras.ops (#21840)
* Add empty_like function * Update empty_like function for openvino * correct function based on gemini advice * Update the code based on review
1 parent 4d30a7f commit bfde12b

File tree

12 files changed

+98
-0
lines changed

12 files changed

+98
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@
179179
from keras.src.ops.numpy import dot as dot
180180
from keras.src.ops.numpy import einsum as einsum
181181
from keras.src.ops.numpy import empty as empty
182+
from keras.src.ops.numpy import empty_like as empty_like
182183
from keras.src.ops.numpy import equal as equal
183184
from keras.src.ops.numpy import exp as exp
184185
from keras.src.ops.numpy import exp2 as exp2

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from keras.src.ops.numpy import dot as dot
6666
from keras.src.ops.numpy import einsum as einsum
6767
from keras.src.ops.numpy import empty as empty
68+
from keras.src.ops.numpy import empty_like as empty_like
6869
from keras.src.ops.numpy import equal as equal
6970
from keras.src.ops.numpy import exp as exp
7071
from keras.src.ops.numpy import exp2 as exp2

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@
179179
from keras.src.ops.numpy import dot as dot
180180
from keras.src.ops.numpy import einsum as einsum
181181
from keras.src.ops.numpy import empty as empty
182+
from keras.src.ops.numpy import empty_like as empty_like
182183
from keras.src.ops.numpy import equal as equal
183184
from keras.src.ops.numpy import exp as exp
184185
from keras.src.ops.numpy import exp2 as exp2

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from keras.src.ops.numpy import dot as dot
6666
from keras.src.ops.numpy import einsum as einsum
6767
from keras.src.ops.numpy import empty as empty
68+
from keras.src.ops.numpy import empty_like as empty_like
6869
from keras.src.ops.numpy import equal as equal
6970
from keras.src.ops.numpy import exp as exp
7071
from keras.src.ops.numpy import exp2 as exp2

keras/src/backend/jax/numpy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,10 @@ def empty(shape, dtype=None):
678678
return jnp.empty(shape, dtype=dtype)
679679

680680

681+
def empty_like(x, dtype=None):
682+
return jnp.empty_like(x, dtype=dtype)
683+
684+
681685
def equal(x1, x2):
682686
x1 = convert_to_tensor(x1)
683687
x2 = convert_to_tensor(x2)

keras/src/backend/numpy/numpy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,10 @@ def empty(shape, dtype=None):
612612
return np.empty(shape, dtype=dtype)
613613

614614

615+
def empty_like(x, dtype=None):
616+
return np.empty_like(x, dtype=dtype)
617+
618+
615619
def equal(x1, x2):
616620
return np.equal(x1, x2)
617621

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ NumpyDtypeTest::test_argpartition
55
NumpyDtypeTest::test_array
66
NumpyDtypeTest::test_bartlett
77
NumpyDtypeTest::test_blackman
8+
NumpyDtypeTest::test_empty_like
89
NumpyDtypeTest::test_gcd
910
NumpyDtypeTest::test_hamming
1011
NumpyDtypeTest::test_hanning
@@ -113,6 +114,7 @@ NumpyTwoInputOpsCorrectnessTest::test_vdot
113114
NumpyOneInputOpsDynamicShapeTest::test_angle
114115
NumpyOneInputOpsDynamicShapeTest::test_bartlett
115116
NumpyOneInputOpsDynamicShapeTest::test_blackman
117+
NumpyOneInputOpsDynamicShapeTest::test_empty_like
116118
NumpyOneInputOpsDynamicShapeTest::test_cbrt
117119
NumpyOneInputOpsDynamicShapeTest::test_corrcoef
118120
NumpyOneInputOpsDynamicShapeTest::test_hamming

keras/src/backend/openvino/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,12 @@ def empty(shape, dtype=None):
896896
return OpenVINOKerasTensor(empty_tensor)
897897

898898

899+
def empty_like(x, dtype=None):
900+
raise NotImplementedError(
901+
"`empty_like` is not supported with openvino backend"
902+
)
903+
904+
899905
def equal(x1, x2):
900906
element_type = None
901907
if isinstance(x1, OpenVINOKerasTensor):

keras/src/backend/tensorflow/numpy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,6 +1489,10 @@ def empty(shape, dtype=None):
14891489
return tf.zeros(shape, dtype=dtype)
14901490

14911491

1492+
def empty_like(x, dtype=None):
1493+
return tf.zeros_like(x, dtype=dtype)
1494+
1495+
14921496
def equal(x1, x2):
14931497
x1 = convert_to_tensor(x1)
14941498
x2 = convert_to_tensor(x2)

keras/src/backend/torch/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,12 @@ def empty(shape, dtype=None):
770770
return torch.empty(size=shape, dtype=dtype, device=get_device())
771771

772772

773+
def empty_like(x, dtype=None):
774+
x = convert_to_tensor(x)
775+
dtype = to_torch_dtype(dtype or x.dtype)
776+
return torch.empty_like(x, dtype=dtype, device=get_device())
777+
778+
773779
def equal(x1, x2):
774780
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
775781
return torch.eq(x1, x2)

0 commit comments

Comments
 (0)