Skip to content

Commit

Permalink
[Keras Ops] Add Histogram Operation (#20316)
Browse files Browse the repository at this point in the history
* add histogram operation to keras.ops

* update docstrings

* extract values from torch op
  • Loading branch information
DavidLandup0 authored Oct 2, 2024
1 parent 084b7e1 commit 08910e2
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from keras.src.ops.math import extract_sequences
from keras.src.ops.math import fft
from keras.src.ops.math import fft2
from keras.src.ops.math import histogram
from keras.src.ops.math import in_top_k
from keras.src.ops.math import irfft
from keras.src.ops.math import istft
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from keras.src.ops.math import extract_sequences
from keras.src.ops.math import fft
from keras.src.ops.math import fft2
from keras.src.ops.math import histogram
from keras.src.ops.math import in_top_k
from keras.src.ops.math import irfft
from keras.src.ops.math import istft
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,7 @@ def logdet(x):
# `np.log(np.linalg.det(x))`. See
# https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html
return slogdet(x)[1]


def histogram(x, bins, range):
return jnp.histogram(x, bins=bins, range=range)
4 changes: 4 additions & 0 deletions keras/src/backend/numpy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,7 @@ def logdet(x):
# In NumPy slogdet is more stable than `np.log(np.linalg.det(x))`. See
# https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html
return slogdet(x)[1]


def histogram(x, bins, range):
return np.histogram(x, bins=bins, range=range)
28 changes: 28 additions & 0 deletions keras/src/backend/tensorflow/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,31 @@ def norm(x, ord=None, axis=None, keepdims=False):
def logdet(x):
x = convert_to_tensor(x)
return tf.linalg.logdet(x)


def histogram(x, bins, range):
"""
Computes a histogram of the data tensor `x` using TensorFlow.
The `tf.histogram_fixed_width()` and `tf.histogram_fixed_width_bins()`
methods yielded slight numerical differences on some edge cases.
"""

x = tf.convert_to_tensor(x, dtype=x.dtype)

# Handle the range argument
if range is None:
min_val = tf.reduce_min(x)
max_val = tf.reduce_max(x)
else:
min_val, max_val = range

x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val))
bin_edges = tf.linspace(min_val, max_val, bins + 1)
bin_edges_list = bin_edges.numpy().tolist()
bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=bin_edges_list[1:-1])

bin_counts = tf.math.bincount(
bin_indices, minlength=bins, maxlength=bins, dtype=x.dtype
)

return bin_counts, bin_edges
5 changes: 5 additions & 0 deletions keras/src/backend/torch/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,8 @@ def norm(x, ord=None, axis=None, keepdims=False):
def logdet(x):
x = convert_to_tensor(x)
return torch.logdet(x)


def histogram(x, bins, range):
hist_result = torch.histogram(x, bins=bins, range=range)
return hist_result.hist, hist_result.bin_edges
86 changes: 86 additions & 0 deletions keras/src/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,3 +971,89 @@ def logdet(x):
if any_symbolic_tensors((x,)):
return Logdet().symbolic_call(x)
return backend.math.logdet(x)


class Histogram(Operation):
def __init__(self, bins=10, range=None):
super().__init__()

if not isinstance(bins, int):
raise TypeError("bins must be of type `int`")
if bins < 0:
raise ValueError("`bins` should be a non-negative integer")

if range:
if len(range) < 2 or not isinstance(range, tuple):
raise ValueError("range must be a tuple of two elements")

if range[1] < range[0]:
raise ValueError(
"The second element of range must be greater than the first"
)

self.bins = bins
self.range = range

def call(self, x):
x = backend.convert_to_tensor(x)
if len(x.shape) > 1:
raise ValueError("Input tensor must be 1-dimensional")
return backend.math.histogram(x, bins=self.bins, range=self.range)

def compute_output_spec(self, x):
return (
KerasTensor(shape=(self.bins,), dtype=x.dtype),
KerasTensor(shape=(self.bins + 1,), dtype=x.dtype),
)


@keras_export("keras.ops.histogram")
def histogram(x, bins=10, range=None):
"""Computes a histogram of the data tensor `x`.
Args:
x: Input tensor.
bins: An integer representing the number of histogram bins.
Defaults to 10.
range: A tuple representing the lower and upper range of the bins.
If not specified, it will use the min and max of `x`.
Returns:
A tuple containing:
- A tensor representing the counts of elements in each bin.
- A tensor representing the bin edges.
Example:
```
>>> nput_tensor = np.random.rand(8)
>>> keras.ops.histogram(input_tensor)
(array([1, 1, 1, 0, 0, 1, 2, 1, 0, 1], dtype=int32),
array([0.0189519 , 0.10294958, 0.18694726, 0.27094494, 0.35494262,
0.43894029, 0.52293797, 0.60693565, 0.69093333, 0.77493101,
0.85892869]))
```
"""

if not isinstance(bins, int):
raise TypeError("bins must be of type `int`")
if bins < 0:
raise ValueError("`bins` should be a non-negative integer")

if range:
if len(range) < 2 or not isinstance(range, tuple):
raise ValueError("range must be a tuple of two elements")

if range[1] < range[0]:
raise ValueError(
"The second element of range must be greater than the first"
)

if any_symbolic_tensors((x,)):
return Histogram(bins=bins, range=range).symbolic_call(x)

x = backend.convert_to_tensor(x)
if len(x.shape) > 1:
raise ValueError("Input tensor must be 1-dimensional")
return backend.math.histogram(x, bins=bins, range=range)
112 changes: 112 additions & 0 deletions keras/src/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,3 +1468,115 @@ def test_istft_invalid_window_shape_2D_inputs(self):
fft_length,
window=incorrect_window,
)


class HistogramTest(testing.TestCase):
def test_histogram_default_args(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(8)

# Expected output
expected_counts, expected_edges = np.histogram(input_tensor)

counts, edges = hist_op(input_tensor)

self.assertEqual(counts.shape, expected_counts.shape)
self.assertAllClose(counts, expected_counts)
self.assertEqual(edges.shape, expected_edges.shape)
self.assertAllClose(edges, expected_edges)

def test_histogram_custom_bins(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(8)
bins = 5

# Expected output
expected_counts, expected_edges = np.histogram(input_tensor, bins=bins)

counts, edges = hist_op(input_tensor, bins=bins)

self.assertEqual(counts.shape, expected_counts.shape)
self.assertAllClose(counts, expected_counts)
self.assertEqual(edges.shape, expected_edges.shape)
self.assertAllClose(edges, expected_edges)

def test_histogram_custom_range(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(10)
range_specified = (2, 8)

# Expected output
expected_counts, expected_edges = np.histogram(
input_tensor, range=range_specified
)

counts, edges = hist_op(input_tensor, range=range_specified)

self.assertEqual(counts.shape, expected_counts.shape)
self.assertAllClose(counts, expected_counts)
self.assertEqual(edges.shape, expected_edges.shape)
self.assertAllClose(edges, expected_edges)

def test_histogram_symbolic_input(self):
hist_op = kmath.histogram
input_tensor = KerasTensor(shape=(None,), dtype="float32")

counts, edges = hist_op(input_tensor)

self.assertEqual(counts.shape, (10,))
self.assertEqual(edges.shape, (11,))

def test_histogram_non_integer_bins_raises_error(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(8)

with self.assertRaisesRegex(
ValueError, "`bins` should be a non-negative integer"
):
hist_op(input_tensor, bins=-5)

def test_histogram_range_validation(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(8)

with self.assertRaisesRegex(
ValueError, "range must be a tuple of two elements"
):
hist_op(input_tensor, range=(1,))

with self.assertRaisesRegex(
ValueError,
"The second element of range must be greater than the first",
):
hist_op(input_tensor, range=(5, 1))

def test_histogram_large_values(self):
hist_op = kmath.histogram
input_tensor = np.array([1e10, 2e10, 3e10, 4e10, 5e10])

counts, edges = hist_op(input_tensor, bins=5)

expected_counts, expected_edges = np.histogram(input_tensor, bins=5)

self.assertAllClose(counts, expected_counts)
self.assertAllClose(edges, expected_edges)

def test_histogram_float_input(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(8)

counts, edges = hist_op(input_tensor, bins=5)

expected_counts, expected_edges = np.histogram(input_tensor, bins=5)

self.assertAllClose(counts, expected_counts)
self.assertAllClose(edges, expected_edges)

def test_histogram_high_dimensional_input(self):
hist_op = kmath.histogram
input_tensor = np.random.rand(3, 4, 5)

with self.assertRaisesRegex(
ValueError, "Input tensor must be 1-dimensional"
):
hist_op(input_tensor)

0 comments on commit 08910e2

Please sign in to comment.