From 08910e26bf1e3973e22ca29fbc8f2b700bbddd0c Mon Sep 17 00:00:00 2001 From: David Landup <60978046+DavidLandup0@users.noreply.github.com> Date: Wed, 2 Oct 2024 12:31:53 +0900 Subject: [PATCH] [Keras Ops] Add Histogram Operation (#20316) * add histogram operation to keras.ops * update docstrings * extract values from torch op --- keras/api/_tf_keras/keras/ops/__init__.py | 1 + keras/api/ops/__init__.py | 1 + keras/src/backend/jax/math.py | 4 + keras/src/backend/numpy/math.py | 4 + keras/src/backend/tensorflow/math.py | 28 ++++++ keras/src/backend/torch/math.py | 5 + keras/src/ops/math.py | 86 +++++++++++++++++ keras/src/ops/math_test.py | 112 ++++++++++++++++++++++ 8 files changed, 241 insertions(+) diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 0a4a04bedd3..a1730855e11 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -47,6 +47,7 @@ from keras.src.ops.math import extract_sequences from keras.src.ops.math import fft from keras.src.ops.math import fft2 +from keras.src.ops.math import histogram from keras.src.ops.math import in_top_k from keras.src.ops.math import irfft from keras.src.ops.math import istft diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 0a4a04bedd3..a1730855e11 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -47,6 +47,7 @@ from keras.src.ops.math import extract_sequences from keras.src.ops.math import fft from keras.src.ops.math import fft2 +from keras.src.ops.math import histogram from keras.src.ops.math import in_top_k from keras.src.ops.math import irfft from keras.src.ops.math import istft diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index 18ba91862a9..11c96086ced 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -294,3 +294,7 @@ def logdet(x): # `np.log(np.linalg.det(x))`. See # https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html return slogdet(x)[1] + + +def histogram(x, bins, range): + return jnp.histogram(x, bins=bins, range=range) diff --git a/keras/src/backend/numpy/math.py b/keras/src/backend/numpy/math.py index f9448c92b93..a40cd569578 100644 --- a/keras/src/backend/numpy/math.py +++ b/keras/src/backend/numpy/math.py @@ -316,3 +316,7 @@ def logdet(x): # In NumPy slogdet is more stable than `np.log(np.linalg.det(x))`. See # https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html return slogdet(x)[1] + + +def histogram(x, bins, range): + return np.histogram(x, bins=bins, range=range) diff --git a/keras/src/backend/tensorflow/math.py b/keras/src/backend/tensorflow/math.py index f034cf429e1..3e50fbc9773 100644 --- a/keras/src/backend/tensorflow/math.py +++ b/keras/src/backend/tensorflow/math.py @@ -370,3 +370,31 @@ def norm(x, ord=None, axis=None, keepdims=False): def logdet(x): x = convert_to_tensor(x) return tf.linalg.logdet(x) + + +def histogram(x, bins, range): + """ + Computes a histogram of the data tensor `x` using TensorFlow. + The `tf.histogram_fixed_width()` and `tf.histogram_fixed_width_bins()` + methods yielded slight numerical differences on some edge cases. + """ + + x = tf.convert_to_tensor(x, dtype=x.dtype) + + # Handle the range argument + if range is None: + min_val = tf.reduce_min(x) + max_val = tf.reduce_max(x) + else: + min_val, max_val = range + + x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val)) + bin_edges = tf.linspace(min_val, max_val, bins + 1) + bin_edges_list = bin_edges.numpy().tolist() + bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=bin_edges_list[1:-1]) + + bin_counts = tf.math.bincount( + bin_indices, minlength=bins, maxlength=bins, dtype=x.dtype + ) + + return bin_counts, bin_edges diff --git a/keras/src/backend/torch/math.py b/keras/src/backend/torch/math.py index e2e80e9358c..e05d358e901 100644 --- a/keras/src/backend/torch/math.py +++ b/keras/src/backend/torch/math.py @@ -419,3 +419,8 @@ def norm(x, ord=None, axis=None, keepdims=False): def logdet(x): x = convert_to_tensor(x) return torch.logdet(x) + + +def histogram(x, bins, range): + hist_result = torch.histogram(x, bins=bins, range=range) + return hist_result.hist, hist_result.bin_edges diff --git a/keras/src/ops/math.py b/keras/src/ops/math.py index fd0a41d5177..749142b6a6d 100644 --- a/keras/src/ops/math.py +++ b/keras/src/ops/math.py @@ -971,3 +971,89 @@ def logdet(x): if any_symbolic_tensors((x,)): return Logdet().symbolic_call(x) return backend.math.logdet(x) + + +class Histogram(Operation): + def __init__(self, bins=10, range=None): + super().__init__() + + if not isinstance(bins, int): + raise TypeError("bins must be of type `int`") + if bins < 0: + raise ValueError("`bins` should be a non-negative integer") + + if range: + if len(range) < 2 or not isinstance(range, tuple): + raise ValueError("range must be a tuple of two elements") + + if range[1] < range[0]: + raise ValueError( + "The second element of range must be greater than the first" + ) + + self.bins = bins + self.range = range + + def call(self, x): + x = backend.convert_to_tensor(x) + if len(x.shape) > 1: + raise ValueError("Input tensor must be 1-dimensional") + return backend.math.histogram(x, bins=self.bins, range=self.range) + + def compute_output_spec(self, x): + return ( + KerasTensor(shape=(self.bins,), dtype=x.dtype), + KerasTensor(shape=(self.bins + 1,), dtype=x.dtype), + ) + + +@keras_export("keras.ops.histogram") +def histogram(x, bins=10, range=None): + """Computes a histogram of the data tensor `x`. + + Args: + x: Input tensor. + bins: An integer representing the number of histogram bins. + Defaults to 10. + range: A tuple representing the lower and upper range of the bins. + If not specified, it will use the min and max of `x`. + + Returns: + A tuple containing: + - A tensor representing the counts of elements in each bin. + - A tensor representing the bin edges. + + Example: + + ``` + >>> nput_tensor = np.random.rand(8) + >>> keras.ops.histogram(input_tensor) + (array([1, 1, 1, 0, 0, 1, 2, 1, 0, 1], dtype=int32), + array([0.0189519 , 0.10294958, 0.18694726, 0.27094494, 0.35494262, + 0.43894029, 0.52293797, 0.60693565, 0.69093333, 0.77493101, + 0.85892869])) + ``` + + """ + + if not isinstance(bins, int): + raise TypeError("bins must be of type `int`") + if bins < 0: + raise ValueError("`bins` should be a non-negative integer") + + if range: + if len(range) < 2 or not isinstance(range, tuple): + raise ValueError("range must be a tuple of two elements") + + if range[1] < range[0]: + raise ValueError( + "The second element of range must be greater than the first" + ) + + if any_symbolic_tensors((x,)): + return Histogram(bins=bins, range=range).symbolic_call(x) + + x = backend.convert_to_tensor(x) + if len(x.shape) > 1: + raise ValueError("Input tensor must be 1-dimensional") + return backend.math.histogram(x, bins=bins, range=range) diff --git a/keras/src/ops/math_test.py b/keras/src/ops/math_test.py index 09c87514c78..09bcb9503fd 100644 --- a/keras/src/ops/math_test.py +++ b/keras/src/ops/math_test.py @@ -1468,3 +1468,115 @@ def test_istft_invalid_window_shape_2D_inputs(self): fft_length, window=incorrect_window, ) + + +class HistogramTest(testing.TestCase): + def test_histogram_default_args(self): + hist_op = kmath.histogram + input_tensor = np.random.rand(8) + + # Expected output + expected_counts, expected_edges = np.histogram(input_tensor) + + counts, edges = hist_op(input_tensor) + + self.assertEqual(counts.shape, expected_counts.shape) + self.assertAllClose(counts, expected_counts) + self.assertEqual(edges.shape, expected_edges.shape) + self.assertAllClose(edges, expected_edges) + + def test_histogram_custom_bins(self): + hist_op = kmath.histogram + input_tensor = np.random.rand(8) + bins = 5 + + # Expected output + expected_counts, expected_edges = np.histogram(input_tensor, bins=bins) + + counts, edges = hist_op(input_tensor, bins=bins) + + self.assertEqual(counts.shape, expected_counts.shape) + self.assertAllClose(counts, expected_counts) + self.assertEqual(edges.shape, expected_edges.shape) + self.assertAllClose(edges, expected_edges) + + def test_histogram_custom_range(self): + hist_op = kmath.histogram + input_tensor = np.random.rand(10) + range_specified = (2, 8) + + # Expected output + expected_counts, expected_edges = np.histogram( + input_tensor, range=range_specified + ) + + counts, edges = hist_op(input_tensor, range=range_specified) + + self.assertEqual(counts.shape, expected_counts.shape) + self.assertAllClose(counts, expected_counts) + self.assertEqual(edges.shape, expected_edges.shape) + self.assertAllClose(edges, expected_edges) + + def test_histogram_symbolic_input(self): + hist_op = kmath.histogram + input_tensor = KerasTensor(shape=(None,), dtype="float32") + + counts, edges = hist_op(input_tensor) + + self.assertEqual(counts.shape, (10,)) + self.assertEqual(edges.shape, (11,)) + + def test_histogram_non_integer_bins_raises_error(self): + hist_op = kmath.histogram + input_tensor = np.random.rand(8) + + with self.assertRaisesRegex( + ValueError, "`bins` should be a non-negative integer" + ): + hist_op(input_tensor, bins=-5) + + def test_histogram_range_validation(self): + hist_op = kmath.histogram + input_tensor = np.random.rand(8) + + with self.assertRaisesRegex( + ValueError, "range must be a tuple of two elements" + ): + hist_op(input_tensor, range=(1,)) + + with self.assertRaisesRegex( + ValueError, + "The second element of range must be greater than the first", + ): + hist_op(input_tensor, range=(5, 1)) + + def test_histogram_large_values(self): + hist_op = kmath.histogram + input_tensor = np.array([1e10, 2e10, 3e10, 4e10, 5e10]) + + counts, edges = hist_op(input_tensor, bins=5) + + expected_counts, expected_edges = np.histogram(input_tensor, bins=5) + + self.assertAllClose(counts, expected_counts) + self.assertAllClose(edges, expected_edges) + + def test_histogram_float_input(self): + hist_op = kmath.histogram + input_tensor = np.random.rand(8) + + counts, edges = hist_op(input_tensor, bins=5) + + expected_counts, expected_edges = np.histogram(input_tensor, bins=5) + + self.assertAllClose(counts, expected_counts) + self.assertAllClose(edges, expected_edges) + + def test_histogram_high_dimensional_input(self): + hist_op = kmath.histogram + input_tensor = np.random.rand(3, 4, 5) + + with self.assertRaisesRegex( + ValueError, "Input tensor must be 1-dimensional" + ): + hist_op(input_tensor)