Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix istft and add class TestMathErrors in ops/math_test.py #19594

Merged
2 changes: 1 addition & 1 deletion keras/api/_tf_keras/keras/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from keras.src.losses import get
from keras.src.losses import serialize
from keras.src.losses.loss import Loss
from keras.src.losses.losses import CTC
from keras.src.losses.losses import BinaryCrossentropy
from keras.src.losses.losses import BinaryFocalCrossentropy
from keras.src.losses.losses import CTC
from keras.src.losses.losses import CategoricalCrossentropy
from keras.src.losses.losses import CategoricalFocalCrossentropy
from keras.src.losses.losses import CategoricalHinge
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ def istft(
x = _get_complex_tensor_from_tuple(x)
dtype = jnp.real(x).dtype

if len(x.shape) < 2:
raise ValueError(
f"Input `x` must have at least 2 dimensions. "
f"Received shape: {x.shape}"
)

expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1)
l_pad = (fft_length - sequence_length) // 2
r_pad = fft_length - sequence_length - l_pad
Expand Down
9 changes: 9 additions & 0 deletions keras/src/ops/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ def test_qr(self):
self.assertEqual(q.shape, qref_shape)
self.assertEqual(r.shape, rref_shape)

def test_qr_invalid_mode(self):
# backend agnostic error message
x = np.array([[1, 2], [3, 4]])
invalid_mode = "invalid_mode"
with self.assertRaisesRegex(
ValueError, "Expected one of {'reduced', 'complete'}."
):
linalg.qr(x, mode=invalid_mode)

def test_solve(self):
a = KerasTensor([None, 20, 20])
b = KerasTensor([None, 20, 5])
Expand Down
88 changes: 88 additions & 0 deletions keras/src/ops/math_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math

import jax.numpy as jnp
import numpy as np
import pytest
import scipy.signal
Expand Down Expand Up @@ -1256,3 +1257,90 @@ def test_undefined_fft_length_and_last_dimension(self):
expected_shape = real_part.shape[:-1] + (None,)

self.assertEqual(output_spec.shape, expected_shape)


class TestMathErrors(testing.TestCase):

@pytest.mark.skipif(
backend.backend() != "jax", reason="Testing Jax errors only"
)
def test_segment_sum_no_num_segments(self):
data = jnp.array([1, 2, 3, 4])
segment_ids = jnp.array([0, 0, 1, 1])
with self.assertRaisesRegex(
ValueError,
"Argument `num_segments` must be set when using the JAX backend.",
):
kmath.segment_sum(data, segment_ids)

@pytest.mark.skipif(
backend.backend() != "jax", reason="Testing Jax errors only"
)
def test_segment_max_no_num_segments(self):
data = jnp.array([1, 2, 3, 4])
segment_ids = jnp.array([0, 0, 1, 1])
with self.assertRaisesRegex(
ValueError,
"Argument `num_segments` must be set when using the JAX backend.",
):
kmath.segment_max(data, segment_ids)

def test_stft_invalid_input_type(self):
# backend agnostic error message
x = np.array([1, 2, 3, 4])
sequence_length = 2
sequence_stride = 1
fft_length = 4
with self.assertRaisesRegex(TypeError, "`float32` or `float64`"):
kmath.stft(x, sequence_length, sequence_stride, fft_length)

def test_invalid_fft_length(self):
# backend agnostic error message
x = np.array([1.0, 2.0, 3.0, 4.0])
sequence_length = 4
sequence_stride = 1
fft_length = 2
with self.assertRaisesRegex(ValueError, "`fft_length` must equal or"):
kmath.stft(x, sequence_length, sequence_stride, fft_length)

def test_stft_invalid_window(self):
# backend agnostic error message
x = np.array([1.0, 2.0, 3.0, 4.0])
sequence_length = 2
sequence_stride = 1
fft_length = 4
window = "invalid_window"
with self.assertRaisesRegex(ValueError, "If a string is passed to"):
kmath.stft(
x, sequence_length, sequence_stride, fft_length, window=window
)

def test_stft_invalid_window_shape(self):
# backend agnostic error message
x = np.array([1.0, 2.0, 3.0, 4.0])
sequence_length = 2
sequence_stride = 1
fft_length = 4
window = np.ones((sequence_length + 1))
with self.assertRaisesRegex(ValueError, "The shape of `window` must"):
kmath.stft(
x, sequence_length, sequence_stride, fft_length, window=window
)

def test_istft_invalid_window_shape_2D_inputs(self):
# backend agnostic error message
x = (np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]]))
sequence_length = 2
sequence_stride = 1
fft_length = 4
incorrect_window = np.ones((sequence_length + 1,))
with self.assertRaisesRegex(
ValueError, "The shape of `window` must be equal to"
):
kmath.istft(
x,
sequence_length,
sequence_stride,
fft_length,
window=incorrect_window,
)