diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index 361eeee8917..4119f744e1a 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -204,6 +204,12 @@ def istft( x = _get_complex_tensor_from_tuple(x) dtype = jnp.real(x).dtype + if len(x.shape) < 2: + raise ValueError( + f"Input `x` must have at least 2 dimensions. " + f"Received shape: {x.shape}" + ) + expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) l_pad = (fft_length - sequence_length) // 2 r_pad = fft_length - sequence_length - l_pad diff --git a/keras/src/ops/linalg_test.py b/keras/src/ops/linalg_test.py index e1f0decf64b..a2fd0c61aad 100644 --- a/keras/src/ops/linalg_test.py +++ b/keras/src/ops/linalg_test.py @@ -101,6 +101,15 @@ def test_qr(self): self.assertEqual(q.shape, qref_shape) self.assertEqual(r.shape, rref_shape) + def test_qr_invalid_mode(self): + # backend agnostic error message + x = np.array([[1, 2], [3, 4]]) + invalid_mode = "invalid_mode" + with self.assertRaisesRegex( + ValueError, "Expected one of {'reduced', 'complete'}." + ): + linalg.qr(x, mode=invalid_mode) + def test_solve(self): a = KerasTensor([None, 20, 20]) b = KerasTensor([None, 20, 5]) diff --git a/keras/src/ops/math_test.py b/keras/src/ops/math_test.py index 60db9fc70f6..86e3c70a78e 100644 --- a/keras/src/ops/math_test.py +++ b/keras/src/ops/math_test.py @@ -1,5 +1,6 @@ import math +import jax.numpy as jnp import numpy as np import pytest import scipy.signal @@ -1256,3 +1257,90 @@ def test_undefined_fft_length_and_last_dimension(self): expected_shape = real_part.shape[:-1] + (None,) self.assertEqual(output_spec.shape, expected_shape) + + +class TestMathErrors(testing.TestCase): + + @pytest.mark.skipif( + backend.backend() != "jax", reason="Testing Jax errors only" + ) + def test_segment_sum_no_num_segments(self): + data = jnp.array([1, 2, 3, 4]) + segment_ids = jnp.array([0, 0, 1, 1]) + with self.assertRaisesRegex( + ValueError, + "Argument `num_segments` must be set when using the JAX backend.", + ): + kmath.segment_sum(data, segment_ids) + + @pytest.mark.skipif( + backend.backend() != "jax", reason="Testing Jax errors only" + ) + def test_segment_max_no_num_segments(self): + data = jnp.array([1, 2, 3, 4]) + segment_ids = jnp.array([0, 0, 1, 1]) + with self.assertRaisesRegex( + ValueError, + "Argument `num_segments` must be set when using the JAX backend.", + ): + kmath.segment_max(data, segment_ids) + + def test_stft_invalid_input_type(self): + # backend agnostic error message + x = np.array([1, 2, 3, 4]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + with self.assertRaisesRegex(TypeError, "`float32` or `float64`"): + kmath.stft(x, sequence_length, sequence_stride, fft_length) + + def test_invalid_fft_length(self): + # backend agnostic error message + x = np.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 4 + sequence_stride = 1 + fft_length = 2 + with self.assertRaisesRegex(ValueError, "`fft_length` must equal or"): + kmath.stft(x, sequence_length, sequence_stride, fft_length) + + def test_stft_invalid_window(self): + # backend agnostic error message + x = np.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + window = "invalid_window" + with self.assertRaisesRegex(ValueError, "If a string is passed to"): + kmath.stft( + x, sequence_length, sequence_stride, fft_length, window=window + ) + + def test_stft_invalid_window_shape(self): + # backend agnostic error message + x = np.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + window = np.ones((sequence_length + 1)) + with self.assertRaisesRegex(ValueError, "The shape of `window` must"): + kmath.stft( + x, sequence_length, sequence_stride, fft_length, window=window + ) + + def test_istft_invalid_window_shape_2D_inputs(self): + # backend agnostic error message + x = (np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]])) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + incorrect_window = np.ones((sequence_length + 1,)) + with self.assertRaisesRegex( + ValueError, "The shape of `window` must be equal to" + ): + kmath.istft( + x, + sequence_length, + sequence_stride, + fft_length, + window=incorrect_window, + )