From 26752108c67927fd17c3d96f862fabca0b008d26 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 8 Aug 2022 20:00:28 +0200 Subject: [PATCH 01/14] Enable batch support for `windowed_mean|variance` --- .../python/stats/sample_stats.py | 110 +++++++++--------- .../python/stats/sample_stats_test.py | 81 ++++++++++++- 2 files changed, 135 insertions(+), 56 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index 1c9e82166f..94c8a87cd2 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -14,10 +14,18 @@ # ============================================================================ """Functions for computing statistics of samples.""" +JAX_MODE = False +NUMPY_MODE = False + # Dependency imports import numpy as np import tensorflow.compat.v2 as tf +if JAX_MODE or NUMPY_MODE: + tnp = np +else: + import tensorflow.experimental.numpy as tnp + from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import dtype_util @@ -712,7 +720,7 @@ def windowed_variance( Computes variances among data in the Tensor `x` along the given windows: - result[i] = variance(x[low_indices[i]:high_indices[i]+1]) + result[i] = variance(x[low_indices[i]:high_indices[i]]) accurately and efficiently. To wit, if K is the size of `low_indices` and `high_indices`, and `N` is the size of `x` along @@ -727,10 +735,9 @@ def windowed_variance( last half of an MCMC chain. Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has - rank `axis`, and `low_indices` and `high_indices` broadcast to shape - `[M]`. Then each element of `low_indices` and `high_indices` - must be between 0 and N+1, and the shape of the output will be - `Bx + [M] + E`. Batch shape in the indices is not currently supported. + rank `axis`, and `low_indices` and `high_indices` broadcast to `x`. + Then each element of `low_indices` and `high_indices` must be + between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. 
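A minimal NumPy sketch of the window semantics described above; the array values and index choices here are illustrative and are not part of the patch:

```python
import numpy as np

# result[i] is the statistic of the half-open slice x[low[i]:high[i]];
# empty windows come out as zero, matching the library's convention.
x = np.array([1., 2., 4., 8., 16.])
low = np.array([0, 1, 2])
high = np.array([3, 4, 5])
result = np.array([
    np.var(x[lo:hi]) if hi > lo else 0.
    for lo, hi in zip(low, high)
])
print(result)  # [var(x[0:3]), var(x[1:4]), var(x[2:5])]
```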
The default windows are `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...` @@ -769,7 +776,7 @@ def windowed_variance( """ with tf.name_scope(name or 'windowed_variance'): x = tf.convert_to_tensor(x) - low_indices, high_indices, low_counts, high_counts = _prepare_window_args( + x, indices, axis = _prepare_window_args( x, low_indices, high_indices, axis) # We have a problem with indexing: the standard convention demands @@ -786,15 +793,11 @@ def windowed_variance( def index_for_cumulative(indices): return tf.maximum(indices - 1, 0) cum_sums = tf.cumsum(x, axis=axis) - low_sums = tf.gather( - cum_sums, index_for_cumulative(low_indices), axis=axis) - high_sums = tf.gather( - cum_sums, index_for_cumulative(high_indices), axis=axis) + sums = tnp.take_along_axis( + cum_sums, index_for_cumulative(indices), axis=axis) cum_variances = cumulative_variance(x, sample_axis=axis) - low_variances = tf.gather( - cum_variances, index_for_cumulative(low_indices), axis=axis) - high_variances = tf.gather( - cum_variances, index_for_cumulative(high_indices), axis=axis) + variances = tnp.take_along_axis( + cum_variances, index_for_cumulative(indices), axis=axis) # This formula is the binary accurate variance merge from [1], # adapted to subtract and batched across the indexed counts, sums, @@ -812,15 +815,18 @@ def index_for_cumulative(indices): # This formula can also be read as implementing the above variance # computation by "unioning" A u B with a notional "negative B" # multiset. - counts = high_counts - low_counts # |A| - discrepancies = ( - _safe_average(high_sums, high_counts) - - _safe_average(low_sums, low_counts))**2 # (mean(A u B) - mean(B))**2 - adjustments = high_counts * (-low_counts) / counts # |A u B| * -|B| / |A| - residuals = (high_variances * high_counts - - low_variances * low_counts + + bounds = ps.cast(indices, sums.dtype) + counts = bounds[1] - bounds[0] # |A| + sum_averages = tf.math.divide_no_nan(sums, bounds) + # (mean(A u B) - mean(B))**2 + discrepancies = tf.square(sum_averages[1] - sum_averages[0]) + # |A u B| * -|B| / |A| + adjustments = tf.math.divide_no_nan(bounds[1] * (-bounds[0]), counts) + variances_scaled = variances * bounds + residuals = (variances_scaled[1] - + variances_scaled[0] + adjustments * discrepancies) - return _safe_average(residuals, counts) + return tf.math.divide_no_nan(residuals, counts) def windowed_mean( @@ -829,7 +835,7 @@ def windowed_mean( Computes means among data in the Tensor `x` along the given windows: - result[i] = mean(x[low_indices[i]:high_indices[i]+1]) + result[i] = mean(x[low_indices[i]:high_indices[i]]) efficiently. To wit, if K is the size of `low_indices` and `high_indices`, and `N` is the size of `x` along the given `axis`, @@ -842,10 +848,9 @@ def windowed_mean( last half of an MCMC chain. Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has - rank `axis`, and `low_indices` and `high_indices` broadcast to shape - `[M]`. Then each element of `low_indices` and `high_indices` - must be between 0 and N+1, and the shape of the output will be - `Bx + [M] + E`. Batch shape in the indices is not currently supported. + rank `axis`, and `low_indices` and `high_indices` broadcast to `x`. + Then each element of `low_indices` and `high_indices` must be + between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. 
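The windowed means further down are computed from a cumulative sum with a leading zero rather than an explicit loop; a NumPy sketch of that prefix-sum identity, restricted to the 1-D case with illustrative values:

```python
import numpy as np

x = np.array([1., 2., 3., 4., 5.])
low = np.array([0, 1, 2])
high = np.array([3, 4, 5])

# With a leading zero, sum(x[lo:hi]) == cum_sums[hi] - cum_sums[lo],
# so every window costs O(1) once the cumulative sum is built.
cum_sums = np.concatenate([[0.], np.cumsum(x)])
sums = cum_sums[high] - cum_sums[low]
counts = (high - low).astype(x.dtype)
means = np.divide(sums, counts, out=np.zeros_like(sums), where=counts != 0)
print(means)  # [2., 3., 4.] == [mean(x[0:3]), mean(x[1:4]), mean(x[2:5])]
```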
The default windows are `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...` @@ -878,18 +883,17 @@ def windowed_mean( """ with tf.name_scope(name or 'windowed_mean'): x = tf.convert_to_tensor(x) - low_indices, high_indices, low_counts, high_counts = _prepare_window_args( - x, low_indices, high_indices, axis) + x, indices, axis = _prepare_window_args(x, low_indices, high_indices, axis) raw_cumsum = tf.cumsum(x, axis=axis) - cum_sums = tf.concat( - [tf.zeros_like(tf.gather(raw_cumsum, [0], axis=axis)), raw_cumsum], - axis=axis) - low_sums = tf.gather(cum_sums, low_indices, axis=axis) - high_sums = tf.gather(cum_sums, high_indices, axis=axis) - - counts = high_counts - low_counts - return _safe_average(high_sums - low_sums, counts) + rank = ps.rank(x) + paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32), + (rank, 2)) + cum_sums = ps.pad(raw_cumsum, paddings) + sums = tnp.take_along_axis(cum_sums, indices, + axis=axis) + counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype) + return tf.math.divide_no_nan(sums[1] - sums[0], counts) def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0): @@ -905,24 +909,20 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0): # Broadcast indices together. high_indices = high_indices + tf.zeros_like(low_indices) low_indices = low_indices + tf.zeros_like(high_indices) - - # TODO(axch): Support batch low and high indices. That would - # complicate this shape munging (though tf.gather should work - # fine). - - # We want to place `low_counts` and `high_counts` at the `axis` - # position, so we reshape them to shape `[1, 1, ..., 1, N, 1, ..., - # 1]`, where the `N` is at `axis`. The `counts_shp`, below, - # is this shape. - size = ps.size(high_indices) - counts_shp = ps.one_hot( - axis, depth=ps.rank(x), on_value=size, off_value=1) - - low_counts = tf.reshape(tf.cast(low_indices, dtype=x.dtype), - shape=counts_shp) - high_counts = tf.reshape(tf.cast(high_indices, dtype=x.dtype), - shape=counts_shp) - return low_indices, high_indices, low_counts, high_counts + indices = ps.stack([low_indices, high_indices], axis=0) + x = tf.expand_dims(x, axis=0) + axis += 1 + + if ps.rank(indices) != ps.rank(x) and ps.rank(indices) == 2: + # legacy usage, kept for backward compatibility + size = ps.size(indices) // 2 + bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size, + off_value=1) + bc_shape = ps.concat([[2], bc_shape[1:]], axis=0) + indices = ps.reshape(indices, bc_shape) + # `take_along_axis` requires the type to be int32 + indices = ps.cast(indices, dtype=tf.int32) + return x, indices, axis def _safe_average(totals, counts): diff --git a/tensorflow_probability/python/stats/sample_stats_test.py b/tensorflow_probability/python/stats/sample_stats_test.py index 235e32a014..da109752cf 100644 --- a/tensorflow_probability/python/stats/sample_stats_test.py +++ b/tensorflow_probability/python/stats/sample_stats_test.py @@ -15,7 +15,7 @@ """Tests for Sample Stats Ops.""" # Dependency imports - +import functools import numpy as np import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf @@ -679,6 +679,85 @@ def test_windowed_mean_corner_cases(self): self.evaluate(sample_stats.windowed_mean(y))) +@test_util.test_all_tf_execution_regimes +class WindowedStatsTest(test_util.TestCase): + def apply_slice_along_axis(self, func, arr, low, high, axis): + """Applies `func` over slices of `arr` along `axis`. Slices intervals are + specified through `low` and `high`. Support broadcasting. 
+ """ + np.testing.assert_equal(low.shape, high.shape) + ni, _, nk = arr.shape[:axis], arr.shape[axis], arr.shape[axis + 1:] + si, j, sk = low.shape[:axis], low.shape[axis], low.shape[axis + 1:] + mk = max(nk, sk) + mi = max(ni, si) + out = np.empty(mi + (j,) + mk) + for ki in np.ndindex(ni): + for kk in np.ndindex(mk): + ak = tuple(np.mod(kk, nk)) + ik = tuple(np.mod(kk, sk)) + ai = tuple(np.mod(ki, ni)) + ii = tuple(np.mod(ki, si)) + a_1d = arr[ai + np.s_[:, ] + ak] + out_1d = out[ki + np.s_[:, ] + kk] + low_1d = low[ii + np.s_[:, ] + ik] + high_1d = high[ii + np.s_[:, ] + ik] + + for r in range(j): + out_1d[r] = func(a_1d[low_1d[r]:high_1d[r]]) + return out + def check_gaussian_windowed(self, shape, indice_shape, axis, + window_func, np_func): + stat_shape = np.array(shape).astype(np.int32) + stat_shape[axis] = 1 + loc = np.arange(np.prod(stat_shape)).reshape(stat_shape) + scale = 0.1 * np.arange(np.prod(stat_shape)).reshape(stat_shape) + rng = test_util.test_np_rng() + x = rng.normal(loc=loc, scale=scale, size=shape) + indice_shape = [2] + list(indice_shape) + indices = rng.randint(shape[axis] + 1, size=indice_shape) + indices = np.sort(indices, axis=0) + low_indices, high_indices = indices[0], indices[1] + a = window_func(x, low_indices=low_indices, + high_indices=high_indices, axis=axis) + b = self.apply_slice_along_axis(np_func, x, low_indices, high_indices, + axis=axis) + b[np.isnan(b)] = 0 # We treat stats computed on empty sets as zeros + self.assertAllClose(a, b) + + def check_windowed(self, func, numpy_func): + check_fn = functools.partial(self.check_gaussian_windowed, + window_func=func, np_func=numpy_func) + check_fn((64, 4, 8), (128, 1, 1), axis=0) + check_fn((64, 4, 8), (32, 1, 1), axis=0) + check_fn((64, 4, 8), (32, 4, 1), axis=0) + check_fn((64, 4, 8), (32, 4, 8), axis=0) + check_fn((64, 4, 8), (64, 64, 1), axis=1) + check_fn((64, 4, 8), (1, 64, 1), axis=1) + check_fn((64, 4, 8), (64, 2, 8), axis=1) + check_fn((64, 4, 8), (64, 4, 64), axis=2) + check_fn((64, 4, 8), (1, 1, 64), axis=2) + check_fn((64, 4, 8), (64, 4, 4), axis=2) + check_fn((64, 4, 8), (1, 1, 4), axis=2) + + with self.assertRaises(Exception): + # Non broadcastable shapes + check_fn((64, 4, 8), (4, 1, 4), axis=2) + + def test_windowed_mean(self): + self.check_windowed(func=tfp.stats.windowed_mean, numpy_func=np.mean) + + def test_windowed_mean_graph(self): + func = tf.function(tfp.stats.windowed_mean) + self.check_windowed(func=func, numpy_func=np.mean) + + def test_windowed_variance(self): + self.check_windowed(func=tfp.stats.windowed_variance, numpy_func=np.var) + + def test_windowed_variance_graph(self): + func = tf.function(tfp.stats.windowed_variance) + self.check_windowed(func=func, numpy_func=np.var) + + @test_util.test_all_tf_execution_regimes class LogAverageProbsTest(test_util.TestCase): From d48cdfcd6e1a62734294d02fe55fdd80780b6ad2 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 8 Aug 2022 20:52:56 +0200 Subject: [PATCH 02/14] Remove unused function Add test cases --- tensorflow_probability/python/stats/sample_stats.py | 7 ------- tensorflow_probability/python/stats/sample_stats_test.py | 3 +++ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index 94c8a87cd2..2f3aeda03d 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -925,13 +925,6 @@ def _prepare_window_args(x, low_indices=None, 
high_indices=None, axis=0): return x, indices, axis -def _safe_average(totals, counts): - # This tf.where protects `totals` from getting a gradient signal - # when `counts` is 0. - safe_totals = tf.where(~tf.equal(counts, 0), totals, 0) - return tf.where(~tf.equal(counts, 0), safe_totals / counts, 0) - - def log_average_probs(logits, sample_axis=0, event_axis=None, keepdims=False, validate_args=False, name=None): """Computes `log(average(to_probs(logits)))` in a numerically stable manner. diff --git a/tensorflow_probability/python/stats/sample_stats_test.py b/tensorflow_probability/python/stats/sample_stats_test.py index da109752cf..d9f4c6ed6d 100644 --- a/tensorflow_probability/python/stats/sample_stats_test.py +++ b/tensorflow_probability/python/stats/sample_stats_test.py @@ -731,13 +731,16 @@ def check_windowed(self, func, numpy_func): check_fn((64, 4, 8), (32, 1, 1), axis=0) check_fn((64, 4, 8), (32, 4, 1), axis=0) check_fn((64, 4, 8), (32, 4, 8), axis=0) + check_fn((64, 4, 8), (64, 4, 8), axis=0) check_fn((64, 4, 8), (64, 64, 1), axis=1) check_fn((64, 4, 8), (1, 64, 1), axis=1) check_fn((64, 4, 8), (64, 2, 8), axis=1) + check_fn((64, 4, 8), (64, 4, 8), axis=1) check_fn((64, 4, 8), (64, 4, 64), axis=2) check_fn((64, 4, 8), (1, 1, 64), axis=2) check_fn((64, 4, 8), (64, 4, 4), axis=2) check_fn((64, 4, 8), (1, 1, 4), axis=2) + check_fn((64, 4, 8), (64, 4, 8), axis=2) with self.assertRaises(Exception): # Non broadcastable shapes From e02054300db7e66cc5afe71f2fc6f20eb7cfc860 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 8 Aug 2022 22:27:22 +0200 Subject: [PATCH 03/14] Doc fix Replace `**2` with `tf.square` --- .../python/stats/sample_stats.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index 2f3aeda03d..df5381fdaf 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -702,8 +702,8 @@ def cumulative_variance(x, sample_axis=0, name=None): excl_counts = tf.reshape(tf.range(size, dtype=x.dtype), shape=counts_shp) incl_counts = excl_counts + 1 excl_sums = tf.cumsum(x, axis=sample_axis, exclusive=True) - discrepancies = (excl_sums / excl_counts - x)**2 - discrepancies = tf.where(excl_counts == 0, x**2, discrepancies) + discrepancies = tf.math.square(excl_sums / excl_counts - x) + discrepancies = tf.where(excl_counts == 0, tf.math.square(x), discrepancies) adjustments = excl_counts / incl_counts # The zeroth item's residual contribution is 0, because it has no # other items to vary from. The preceding expressions, however, @@ -734,8 +734,10 @@ def windowed_variance( trailing-window estimators from some iterative process, such as the last half of an MCMC chain. - Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has - rank `axis`, and `low_indices` and `high_indices` broadcast to `x`. + Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices` + have shape `Bi + [M] + F`, such that: + - `rank(Bx) = rank(Bi) = axis`, + - `Bi + [1] + F` broadcasts to `Bx + [N] + E`. Then each element of `low_indices` and `high_indices` must be between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. @@ -847,8 +849,10 @@ def windowed_mean( trailing-window estimators from some iterative process, such as the last half of an MCMC chain. 
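The "last half of an MCMC chain" use case above is exactly what the default windows give, since `low_indices` defaults to `high_indices // 2`; a short NumPy illustration with made-up values:

```python
import numpy as np

# With no indices supplied, high_indices defaults to [1, ..., N] and
# low_indices to high_indices // 2, so entry i is the statistic of the
# trailing half of x[:i + 1].
x = np.arange(8.)
high = np.arange(1, x.size + 1)
low = high // 2
trailing_half_means = np.array(
    [x[lo:hi].mean() for lo, hi in zip(low, high)])
print(trailing_half_means)  # last entry is mean(x[4:8]) = 5.5
```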
- Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has - rank `axis`, and `low_indices` and `high_indices` broadcast to `x`. + Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices` + have shape `Bi + [M] + F`, such that: + - `rank(Bx) = rank(Bi) = axis`, + - `Bi + [1] + F` broadcasts to `Bx + [N] + E`. Then each element of `low_indices` and `high_indices` must be between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. From c28faa5f758ef13344bb1be5aafa11a29e57b79d Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 8 Aug 2022 23:46:53 +0200 Subject: [PATCH 04/14] Allow lower rank indices --- .../python/stats/sample_stats.py | 51 +++++++++++++------ .../python/stats/sample_stats_test.py | 32 ++++++++++++ 2 files changed, 68 insertions(+), 15 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index df5381fdaf..6ae297efea 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -735,12 +735,18 @@ def windowed_variance( last half of an MCMC chain. Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices` - have shape `Bi + [M] + F`, such that: - - `rank(Bx) = rank(Bi) = axis`, - - `Bi + [1] + F` broadcasts to `Bx + [N] + E`. + have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`. Then each element of `low_indices` and `high_indices` must be between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. + The shape of indices must be broadcastable with `x` unless the rank is lower + than the rank of `x`, then the shape is expanded with extra inner dimensions + to match the rank of `x`. + + In the special case where the rank of indices is one, i.e when + `rank(Bi) = rank(F) = 0`, the indices are reshaped to + `[1] * rank(Bx) + [M] + [1] * rank(E)`. + The default windows are `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...` This corresponds to analyzing `x` as though it were streaming, for @@ -850,12 +856,18 @@ def windowed_mean( last half of an MCMC chain. Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices` - have shape `Bi + [M] + F`, such that: - - `rank(Bx) = rank(Bi) = axis`, - - `Bi + [1] + F` broadcasts to `Bx + [N] + E`. + have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`. Then each element of `low_indices` and `high_indices` must be between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. + The shape of indices must be broadcastable with `x` unless the rank is lower + than the rank of `x`, then the shape is expanded with extra inner dimensions + to match the rank of `x`. + + In the special case where the rank of indices is one, i.e when + `rank(Bi) = rank(F) = 0`, the indices are reshaped to + `[1] * rank(Bx) + [M] + [1] * rank(E)`. + The default windows are `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...` This corresponds to analyzing `x` as though it were streaming, for @@ -913,17 +925,26 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0): # Broadcast indices together. 
high_indices = high_indices + tf.zeros_like(low_indices) low_indices = low_indices + tf.zeros_like(high_indices) - indices = ps.stack([low_indices, high_indices], axis=0) + + indices_shape = ps.shape(low_indices) + if ps.rank(low_indices) < ps.rank(x): + if ps.rank(low_indices) == 1: + size = ps.size(low_indices) + bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size, + off_value=1) + else: + # we assume the first dimensions are broadcastable with `x`, + # we add trailing dimensions + extra_dims = ps.rank(x) - ps.rank(low_indices) + bc_shape = ps.concat([indices_shape, [1]*extra_dims], axis=0) + else: + bc_shape = indices_shape + + bc_shape = ps.concat([[2], bc_shape], axis=0) + indices = tf.stack([low_indices, high_indices], axis=0) + indices = ps.reshape(indices, bc_shape) x = tf.expand_dims(x, axis=0) axis += 1 - - if ps.rank(indices) != ps.rank(x) and ps.rank(indices) == 2: - # legacy usage, kept for backward compatibility - size = ps.size(indices) // 2 - bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size, - off_value=1) - bc_shape = ps.concat([[2], bc_shape[1:]], axis=0) - indices = ps.reshape(indices, bc_shape) # `take_along_axis` requires the type to be int32 indices = ps.cast(indices, dtype=tf.int32) return x, indices, axis diff --git a/tensorflow_probability/python/stats/sample_stats_test.py b/tensorflow_probability/python/stats/sample_stats_test.py index d9f4c6ed6d..1e83a2ca21 100644 --- a/tensorflow_probability/python/stats/sample_stats_test.py +++ b/tensorflow_probability/python/stats/sample_stats_test.py @@ -681,6 +681,19 @@ def test_windowed_mean_corner_cases(self): @test_util.test_all_tf_execution_regimes class WindowedStatsTest(test_util.TestCase): + + def _maybe_expand_dims_to_make_broadcastable(self, x, shape, axis): + if len(shape) > len(x.shape): + if len(x.shape) == 1: + bc_shape = np.ones(len(shape), dtype=np.int32) + bc_shape[axis] = x.shape[0] + return x.reshape(bc_shape) + else: + extra_dims = len(shape) - len(x.shape) + bc_shape = x.shape + (1,) * extra_dims + return x.reshape(bc_shape) + return x + def apply_slice_along_axis(self, func, arr, low, high, axis): """Applies `func` over slices of `arr` along `axis`. Slices intervals are specified through `low` and `high`. Support broadcasting. 
@@ -705,6 +718,7 @@ def apply_slice_along_axis(self, func, arr, low, high, axis): for r in range(j): out_1d[r] = func(a_1d[low_1d[r]:high_1d[r]]) return out + def check_gaussian_windowed(self, shape, indice_shape, axis, window_func, np_func): stat_shape = np.array(shape).astype(np.int32) @@ -717,6 +731,10 @@ def check_gaussian_windowed(self, shape, indice_shape, axis, indices = rng.randint(shape[axis] + 1, size=indice_shape) indices = np.sort(indices, axis=0) low_indices, high_indices = indices[0], indices[1] + low_indices = self._maybe_expand_dims_to_make_broadcastable( + low_indices, x.shape, axis) + high_indices = self._maybe_expand_dims_to_make_broadcastable( + high_indices, x.shape, axis) a = window_func(x, low_indices=low_indices, high_indices=high_indices, axis=axis) b = self.apply_slice_along_axis(np_func, x, low_indices, high_indices, @@ -732,20 +750,34 @@ def check_windowed(self, func, numpy_func): check_fn((64, 4, 8), (32, 4, 1), axis=0) check_fn((64, 4, 8), (32, 4, 8), axis=0) check_fn((64, 4, 8), (64, 4, 8), axis=0) + check_fn((64, 4, 8), (128, 1), axis=0) + check_fn((64, 4, 8), (32,), axis=0) + check_fn((64, 4, 8), (32, 4), axis=0) + check_fn((64, 4, 8), (64, 64, 1), axis=1) check_fn((64, 4, 8), (1, 64, 1), axis=1) check_fn((64, 4, 8), (64, 2, 8), axis=1) check_fn((64, 4, 8), (64, 4, 8), axis=1) + check_fn((64, 4, 8), (16,), axis=1) + check_fn((64, 4, 8), (1, 64), axis=1) + check_fn((64, 4, 8), (64, 4, 64), axis=2) check_fn((64, 4, 8), (1, 1, 64), axis=2) check_fn((64, 4, 8), (64, 4, 4), axis=2) check_fn((64, 4, 8), (1, 1, 4), axis=2) check_fn((64, 4, 8), (64, 4, 8), axis=2) + check_fn((64, 4, 8), (16,), axis=2) + check_fn((64, 4, 8), (1, 4), axis=2) + check_fn((64, 4, 8), (64, 4), axis=2) with self.assertRaises(Exception): # Non broadcastable shapes check_fn((64, 4, 8), (4, 1, 4), axis=2) + with self.assertRaises(Exception): + # Non broadcastable shapes + check_fn((64, 4, 8), (2, 4), axis=2) + def test_windowed_mean(self): self.check_windowed(func=tfp.stats.windowed_mean, numpy_func=np.mean) From 4446b60dd1b448f72357ba922f205a7331c627c5 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Wed, 10 Aug 2022 10:20:23 +0200 Subject: [PATCH 05/14] Test against tensors with dynamic shapes Some `tensorflow` to `prefer_static` replacement --- .../python/stats/sample_stats.py | 36 +++++++++---------- .../python/stats/sample_stats_test.py | 13 +++++-- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index 6ae297efea..174a8bb262 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -22,9 +22,9 @@ import tensorflow.compat.v2 as tf if JAX_MODE or NUMPY_MODE: - tnp = np + numpy_ops = np else: - import tensorflow.experimental.numpy as tnp + from tensorflow.python.ops import numpy_ops from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import distribution_util @@ -739,13 +739,12 @@ def windowed_variance( Then each element of `low_indices` and `high_indices` must be between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. - The shape of indices must be broadcastable with `x` unless the rank is lower - than the rank of `x`, then the shape is expanded with extra inner dimensions - to match the rank of `x`. + The shape `Bi + [1] + F` must be broadcastable with the shape of `x`. 
- In the special case where the rank of indices is one, i.e when - `rank(Bi) = rank(F) = 0`, the indices are reshaped to - `[1] * rank(Bx) + [M] + [1] * rank(E)`. + If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded + with extra inner dimensions to match the rank of `x`. In the special + case where the rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`, + the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`. The default windows are `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...` @@ -801,10 +800,10 @@ def windowed_variance( def index_for_cumulative(indices): return tf.maximum(indices - 1, 0) cum_sums = tf.cumsum(x, axis=axis) - sums = tnp.take_along_axis( + sums = numpy_ops.take_along_axis( cum_sums, index_for_cumulative(indices), axis=axis) cum_variances = cumulative_variance(x, sample_axis=axis) - variances = tnp.take_along_axis( + variances = numpy_ops.take_along_axis( cum_variances, index_for_cumulative(indices), axis=axis) # This formula is the binary accurate variance merge from [1], @@ -860,13 +859,12 @@ def windowed_mean( Then each element of `low_indices` and `high_indices` must be between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. - The shape of indices must be broadcastable with `x` unless the rank is lower - than the rank of `x`, then the shape is expanded with extra inner dimensions - to match the rank of `x`. + The shape `Bi + [1] + F` must be broadcastable with the shape of `x`. - In the special case where the rank of indices is one, i.e when - `rank(Bi) = rank(F) = 0`, the indices are reshaped to - `[1] * rank(Bx) + [M] + [1] * rank(E)`. + If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded + with extra inner dimensions to match the rank of `x`. In the special + case where the rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`, + the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`. 
The default windows are `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...` @@ -906,7 +904,7 @@ def windowed_mean( paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32), (rank, 2)) cum_sums = ps.pad(raw_cumsum, paddings) - sums = tnp.take_along_axis(cum_sums, indices, + sums = numpy_ops.take_along_axis(cum_sums, indices, axis=axis) counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype) return tf.math.divide_no_nan(sums[1] - sums[0], counts) @@ -915,7 +913,7 @@ def windowed_mean( def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0): """Common argument defaulting logic for windowed statistics.""" if high_indices is None: - high_indices = tf.range(ps.shape(x)[axis]) + 1 + high_indices = ps.range(ps.shape(x)[axis]) + 1 else: high_indices = tf.convert_to_tensor(high_indices) if low_indices is None: @@ -941,7 +939,7 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0): bc_shape = indices_shape bc_shape = ps.concat([[2], bc_shape], axis=0) - indices = tf.stack([low_indices, high_indices], axis=0) + indices = ps.stack([low_indices, high_indices], axis=0) indices = ps.reshape(indices, bc_shape) x = tf.expand_dims(x, axis=0) axis += 1 diff --git a/tensorflow_probability/python/stats/sample_stats_test.py b/tensorflow_probability/python/stats/sample_stats_test.py index 1e83a2ca21..9e72a1a8a0 100644 --- a/tensorflow_probability/python/stats/sample_stats_test.py +++ b/tensorflow_probability/python/stats/sample_stats_test.py @@ -731,17 +731,26 @@ def check_gaussian_windowed(self, shape, indice_shape, axis, indices = rng.randint(shape[axis] + 1, size=indice_shape) indices = np.sort(indices, axis=0) low_indices, high_indices = indices[0], indices[1] + + tf_low_indices = self._make_dynamic_shape(low_indices) + tf_high_indices = self._make_dynamic_shape(high_indices) + tf_x = self._make_dynamic_shape(x) + + a = window_func(tf_x, low_indices=tf_low_indices, + high_indices=tf_high_indices, axis=axis) + low_indices = self._maybe_expand_dims_to_make_broadcastable( low_indices, x.shape, axis) high_indices = self._maybe_expand_dims_to_make_broadcastable( high_indices, x.shape, axis) - a = window_func(x, low_indices=low_indices, - high_indices=high_indices, axis=axis) b = self.apply_slice_along_axis(np_func, x, low_indices, high_indices, axis=axis) b[np.isnan(b)] = 0 # We treat stats computed on empty sets as zeros self.assertAllClose(a, b) + def _make_dynamic_shape(self, x): + return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape)) + def check_windowed(self, func, numpy_func): check_fn = functools.partial(self.check_gaussian_windowed, window_func=func, np_func=numpy_func) From 56c5c16fdebcc7c7ccb5fc403f078311463b012f Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Wed, 10 Aug 2022 17:00:32 +0200 Subject: [PATCH 06/14] Fix `take_along_axis` import for jax backend --- .../python/stats/sample_stats.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index 174a8bb262..0f1bc0e65a 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -21,10 +21,12 @@ import numpy as np import tensorflow.compat.v2 as tf -if JAX_MODE or NUMPY_MODE: - numpy_ops = np +if NUMPY_MODE: + take_along_axis = np.take_along_axis +elif JAX_MODE: + from jax.numpy import take_along_axis else: - from tensorflow.python.ops import numpy_ops + from 
tensorflow.python.ops.numpy_ops import take_along_axis from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import distribution_util @@ -800,10 +802,10 @@ def windowed_variance( def index_for_cumulative(indices): return tf.maximum(indices - 1, 0) cum_sums = tf.cumsum(x, axis=axis) - sums = numpy_ops.take_along_axis( + sums = take_along_axis( cum_sums, index_for_cumulative(indices), axis=axis) cum_variances = cumulative_variance(x, sample_axis=axis) - variances = numpy_ops.take_along_axis( + variances = take_along_axis( cum_variances, index_for_cumulative(indices), axis=axis) # This formula is the binary accurate variance merge from [1], @@ -904,8 +906,7 @@ def windowed_mean( paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32), (rank, 2)) cum_sums = ps.pad(raw_cumsum, paddings) - sums = numpy_ops.take_along_axis(cum_sums, indices, - axis=axis) + sums = take_along_axis(cum_sums, indices, axis=axis) counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype) return tf.math.divide_no_nan(sums[1] - sums[0], counts) From 26f4f121a2ac4b2b14f15af38ebce944ba3f0fa8 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 26 Sep 2022 14:31:15 +0000 Subject: [PATCH 07/14] Expose `tensorflow.experimental.numpy` API to numpy and jax backends --- .../python/internal/backend/jax/rewrite.py | 2 ++ .../python/stats/sample_stats.py | 14 ++++---------- .../python/stats/sample_stats_test.py | 8 ++++---- tensorflow_probability/substrates/meta/rewrite.py | 1 + 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/tensorflow_probability/python/internal/backend/jax/rewrite.py b/tensorflow_probability/python/internal/backend/jax/rewrite.py index 68efbd20f5..f7ac15a370 100644 --- a/tensorflow_probability/python/internal/backend/jax/rewrite.py +++ b/tensorflow_probability/python/internal/backend/jax/rewrite.py @@ -45,6 +45,8 @@ def main(argv): if FLAGS.rewrite_numpy_import: contents = contents.replace('\nimport numpy as np', '\nimport numpy as onp; import jax.numpy as np') + contents = contents.replace('\nimport numpy as tnp', + '\nimport jax.numpy as tnp') else: contents = contents.replace('\nimport numpy as np', '\nimport numpy as np; onp = np') diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index 0f1bc0e65a..5773574564 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -20,13 +20,7 @@ # Dependency imports import numpy as np import tensorflow.compat.v2 as tf - -if NUMPY_MODE: - take_along_axis = np.take_along_axis -elif JAX_MODE: - from jax.numpy import take_along_axis -else: - from tensorflow.python.ops.numpy_ops import take_along_axis +import tensorflow.experimental.numpy as tnp from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import distribution_util @@ -802,10 +796,10 @@ def windowed_variance( def index_for_cumulative(indices): return tf.maximum(indices - 1, 0) cum_sums = tf.cumsum(x, axis=axis) - sums = take_along_axis( + sums = tnp.take_along_axis( cum_sums, index_for_cumulative(indices), axis=axis) cum_variances = cumulative_variance(x, sample_axis=axis) - variances = take_along_axis( + variances = tnp.take_along_axis( cum_variances, index_for_cumulative(indices), axis=axis) # This formula is the binary accurate variance merge from [1], @@ -906,7 +900,7 @@ def windowed_mean( paddings = ps.reshape(ps.one_hot(2*axis, 
depth=2*rank, dtype=tf.int32), (rank, 2)) cum_sums = ps.pad(raw_cumsum, paddings) - sums = take_along_axis(cum_sums, indices, axis=axis) + sums = tnp.take_along_axis(cum_sums, indices, axis=axis) counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype) return tf.math.divide_no_nan(sums[1] - sums[0], counts) diff --git a/tensorflow_probability/python/stats/sample_stats_test.py b/tensorflow_probability/python/stats/sample_stats_test.py index 9e72a1a8a0..de1e4bfb96 100644 --- a/tensorflow_probability/python/stats/sample_stats_test.py +++ b/tensorflow_probability/python/stats/sample_stats_test.py @@ -788,17 +788,17 @@ def check_windowed(self, func, numpy_func): check_fn((64, 4, 8), (2, 4), axis=2) def test_windowed_mean(self): - self.check_windowed(func=tfp.stats.windowed_mean, numpy_func=np.mean) + self.check_windowed(func=sample_stats.windowed_mean, numpy_func=np.mean) def test_windowed_mean_graph(self): - func = tf.function(tfp.stats.windowed_mean) + func = tf.function(sample_stats.windowed_mean) self.check_windowed(func=func, numpy_func=np.mean) def test_windowed_variance(self): - self.check_windowed(func=tfp.stats.windowed_variance, numpy_func=np.var) + self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var) def test_windowed_variance_graph(self): - func = tf.function(tfp.stats.windowed_variance) + func = tf.function(sample_stats.windowed_variance) self.check_windowed(func=func, numpy_func=np.var) diff --git a/tensorflow_probability/substrates/meta/rewrite.py b/tensorflow_probability/substrates/meta/rewrite.py index e71a75bf4e..8cbca19144 100644 --- a/tensorflow_probability/substrates/meta/rewrite.py +++ b/tensorflow_probability/substrates/meta/rewrite.py @@ -29,6 +29,7 @@ TF_REPLACEMENTS = { 'import tensorflow ': 'from tensorflow_probability.python.internal.backend import numpy ', + 'import tensorflow.experimental.numpy as tnp': 'import numpy as tnp', 'import tensorflow.compat.v1': 'from tensorflow_probability.python.internal.backend.numpy.compat ' 'import v1', From 169f7f5e3b2ce954e119ac195cb4abe347dbaa34 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 26 Sep 2022 17:50:36 +0000 Subject: [PATCH 08/14] Rewrite `apply_slice_along_axis` using `np.vectorize` --- .../python/stats/sample_stats_test.py | 42 +++++++++---------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats_test.py b/tensorflow_probability/python/stats/sample_stats_test.py index de1e4bfb96..ce0d91b357 100644 --- a/tensorflow_probability/python/stats/sample_stats_test.py +++ b/tensorflow_probability/python/stats/sample_stats_test.py @@ -699,24 +699,26 @@ def apply_slice_along_axis(self, func, arr, low, high, axis): specified through `low` and `high`. Support broadcasting. 
""" np.testing.assert_equal(low.shape, high.shape) - ni, _, nk = arr.shape[:axis], arr.shape[axis], arr.shape[axis + 1:] - si, j, sk = low.shape[:axis], low.shape[axis], low.shape[axis + 1:] - mk = max(nk, sk) - mi = max(ni, si) - out = np.empty(mi + (j,) + mk) - for ki in np.ndindex(ni): - for kk in np.ndindex(mk): - ak = tuple(np.mod(kk, nk)) - ik = tuple(np.mod(kk, sk)) - ai = tuple(np.mod(ki, ni)) - ii = tuple(np.mod(ki, si)) - a_1d = arr[ai + np.s_[:, ] + ak] - out_1d = out[ki + np.s_[:, ] + kk] - low_1d = low[ii + np.s_[:, ] + ik] - high_1d = high[ii + np.s_[:, ] + ik] - - for r in range(j): - out_1d[r] = func(a_1d[low_1d[r]:high_1d[r]]) + + def apply_func(vector, l, h): + return func(vector[l:h]) + + apply_func_1d = np.vectorize(apply_func, signature='(n), (), ()->()') + vectorized_func = np.vectorize(apply_func_1d, + signature='(n), (k), (k)->(m)') + + # Put `axis` at the innermost dimension + dims = list(range(arr.ndim)) + dims[-1] = axis + dims[axis] = arr.ndim - 1 + t_arr = np.transpose(arr, axes=dims) + t_low = np.transpose(low, axes=dims) + t_high = np.transpose(high, axes=dims) + + t_out = vectorized_func(t_arr, t_low, t_high) + + # Replace `axis` at its place + out = np.transpose(t_out, axes=dims) return out def check_gaussian_windowed(self, shape, indice_shape, axis, @@ -797,10 +799,6 @@ def test_windowed_mean_graph(self): def test_windowed_variance(self): self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var) - def test_windowed_variance_graph(self): - func = tf.function(sample_stats.windowed_variance) - self.check_windowed(func=func, numpy_func=np.var) - @test_util.test_all_tf_execution_regimes class LogAverageProbsTest(test_util.TestCase): From c90e9619e0dbb6b7ecd9587e7d72a7f732fc12c1 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 26 Sep 2022 19:38:25 +0000 Subject: [PATCH 09/14] Check for statically known rank Parametrize tests --- .../python/stats/sample_stats.py | 6 + .../python/stats/sample_stats_test.py | 104 ++++++++++-------- 2 files changed, 63 insertions(+), 47 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index 5773574564..637b64024d 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -915,6 +915,12 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0): low_indices = high_indices // 2 else: low_indices = tf.convert_to_tensor(low_indices) + + indices_rank = tf.get_static_value(ps.rank(low_indices)) + x_rank = tf.get_static_value(ps.rank(x)) + if indices_rank is None or x_rank is None: + raise ValueError("`indices` and `x` ranks must be statically known.") + # Broadcast indices together. 
high_indices = high_indices + tf.zeros_like(low_indices) low_indices = low_indices + tf.zeros_like(high_indices) diff --git a/tensorflow_probability/python/stats/sample_stats_test.py b/tensorflow_probability/python/stats/sample_stats_test.py index ce0d91b357..b7cd7dff38 100644 --- a/tensorflow_probability/python/stats/sample_stats_test.py +++ b/tensorflow_probability/python/stats/sample_stats_test.py @@ -15,10 +15,14 @@ """Tests for Sample Stats Ops.""" # Dependency imports -import functools +import itertools + import numpy as np import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf +from absl.testing import parameterized +from tensorflow.python.framework.errors_impl import InvalidArgumentError + from tensorflow_probability.python.internal import test_util from tensorflow_probability.python.stats import sample_stats @@ -721,7 +725,8 @@ def apply_func(vector, l, h): out = np.transpose(t_out, axes=dims) return out - def check_gaussian_windowed(self, shape, indice_shape, axis, + + def check_gaussian_windowed_func(self, shape, indice_shape, axis, window_func, np_func): stat_shape = np.array(shape).astype(np.int32) stat_shape[axis] = 1 @@ -753,51 +758,56 @@ def check_gaussian_windowed(self, shape, indice_shape, axis, def _make_dynamic_shape(self, x): return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape)) - def check_windowed(self, func, numpy_func): - check_fn = functools.partial(self.check_gaussian_windowed, - window_func=func, np_func=numpy_func) - check_fn((64, 4, 8), (128, 1, 1), axis=0) - check_fn((64, 4, 8), (32, 1, 1), axis=0) - check_fn((64, 4, 8), (32, 4, 1), axis=0) - check_fn((64, 4, 8), (32, 4, 8), axis=0) - check_fn((64, 4, 8), (64, 4, 8), axis=0) - check_fn((64, 4, 8), (128, 1), axis=0) - check_fn((64, 4, 8), (32,), axis=0) - check_fn((64, 4, 8), (32, 4), axis=0) - - check_fn((64, 4, 8), (64, 64, 1), axis=1) - check_fn((64, 4, 8), (1, 64, 1), axis=1) - check_fn((64, 4, 8), (64, 2, 8), axis=1) - check_fn((64, 4, 8), (64, 4, 8), axis=1) - check_fn((64, 4, 8), (16,), axis=1) - check_fn((64, 4, 8), (1, 64), axis=1) - - check_fn((64, 4, 8), (64, 4, 64), axis=2) - check_fn((64, 4, 8), (1, 1, 64), axis=2) - check_fn((64, 4, 8), (64, 4, 4), axis=2) - check_fn((64, 4, 8), (1, 1, 4), axis=2) - check_fn((64, 4, 8), (64, 4, 8), axis=2) - check_fn((64, 4, 8), (16,), axis=2) - check_fn((64, 4, 8), (1, 4), axis=2) - check_fn((64, 4, 8), (64, 4), axis=2) - - with self.assertRaises(Exception): - # Non broadcastable shapes - check_fn((64, 4, 8), (4, 1, 4), axis=2) - - with self.assertRaises(Exception): - # Non broadcastable shapes - check_fn((64, 4, 8), (2, 4), axis=2) - - def test_windowed_mean(self): - self.check_windowed(func=sample_stats.windowed_mean, numpy_func=np.mean) - - def test_windowed_mean_graph(self): - func = tf.function(sample_stats.windowed_mean) - self.check_windowed(func=func, numpy_func=np.mean) - - def test_windowed_variance(self): - self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var) + @parameterized.named_parameters(*[( + f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis, + tf_func, np_func) for a, (b, axis), (tf_func, np_func) in + itertools.product([(64, 4, 8), ], + [((128, 1, 1), 0), + ((32, 1, 1), 0), + ((32, 4, 1), 0), + ((32, 4, 8), 0), + ((64, 4, 8), 0), + ((128, 1), 0), + ((32,), 0), + ((32, 4), 0), + + ((64, 64, 1), 1), + ((1, 64, 1), 1), + ((64, 2, 8), 1), + ((64, 4, 8), 1), + ((16,), 1), + ((1, 64), 1), + + ((64, 4, 64), 2), + ((1, 1, 64), 2), + ((64, 4, 4), 2), + ((1, 1, 4), 2), + ((64, 4, 
8), 2), + ((16,), 2), + ((1, 4), 2), + ((64, 4), 2)], + [ + (sample_stats.windowed_mean, np.mean), + (sample_stats.windowed_variance, np.var) + ])]) + def test_windowed(self, shape, indice_shape, axis, window_func, np_func): + self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func, + np_func) + + + @parameterized.named_parameters(*[( + f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis, + tf_func, np_func) for a, (b, axis), (tf_func, np_func) in + itertools.product([(64, 4, 8), ], + [((4, 1, 4), 2), ((2, 4), 2)], + [(sample_stats.windowed_mean, np.mean), + (sample_stats.windowed_variance, np.var)])]) + def test_non_broadcastable_shapes(self, shape, indice_shape, axis, + window_func, np_func): + with self.assertRaisesRegexp((IndexError, ValueError, InvalidArgumentError), + '^shape mismatch|Incompatible shapes'): + self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func, + np_func) @test_util.test_all_tf_execution_regimes From 45dabfa802e9f682a34c7856fa610f5b31794845 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 26 Sep 2022 20:17:48 +0000 Subject: [PATCH 10/14] Documentation --- .../python/stats/sample_stats.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index 637b64024d..b3cdeca20e 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -735,11 +735,12 @@ def windowed_variance( Then each element of `low_indices` and `high_indices` must be between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. - The shape `Bi + [1] + F` must be broadcastable with the shape of `x`. + The shape `Bi + [1] + F` must be implicitly broadcastable with the + shape of `x`, the following implicit broadcasting rules are applied: If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded - with extra inner dimensions to match the rank of `x`. In the special - case where the rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`, + with extra inner dimensions to match the rank of `x`. + If rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`, the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`. The default windows are @@ -855,11 +856,12 @@ def windowed_mean( Then each element of `low_indices` and `high_indices` must be between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. - The shape `Bi + [1] + F` must be broadcastable with the shape of `x`. + The shape `Bi + [1] + F` must be implicitly broadcastable with the + shape of `x`, the following implicit broadcasting rules are applied: If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded - with extra inner dimensions to match the rank of `x`. In the special - case where the rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`, + with extra inner dimensions to match the rank of `x`. + If rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`, the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`. 
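A plain-Python mirror of the expansion rules stated above, useful for checking how an index shape is promoted before broadcasting; the shapes are illustrative and the helper is hypothetical, not part of the library:

```python
def expanded_indices_shape(indices_shape, x_rank, axis):
  """Mirrors the documented index-rank promotion, outside of TF."""
  if len(indices_shape) == 1:
    # Rank-1 indices become [1] * axis + [M] + [1] * (x_rank - axis - 1).
    return (1,) * axis + tuple(indices_shape) + (1,) * (x_rank - axis - 1)
  # Lower-rank indices are padded with trailing singleton dimensions.
  return tuple(indices_shape) + (1,) * (x_rank - len(indices_shape))

# Suppose x has shape Bx + [N] + E = (3, 10, 2) and axis = 1.
print(expanded_indices_shape((5,), 3, axis=1))       # (1, 5, 1)
print(expanded_indices_shape((3, 5), 3, axis=1))     # (3, 5, 1)
print(expanded_indices_shape((3, 5, 2), 3, axis=1))  # (3, 5, 2)
```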
The default windows are From df117ee0e29d3a4ce14e0c05fa77f5cb9e080da4 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 26 Sep 2022 20:50:19 +0000 Subject: [PATCH 11/14] Documentation --- tensorflow_probability/python/stats/sample_stats.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index b3cdeca20e..29ecc4394b 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -733,7 +733,8 @@ def windowed_variance( Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices` have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`. Then each element of `low_indices` and `high_indices` must be - between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. + between 0 and N+1, and the shape of the output will be + `broadcast(Bx, Bi) + [M] + broadcast(E, F)`. The shape `Bi + [1] + F` must be implicitly broadcastable with the shape of `x`, the following implicit broadcasting rules are applied: @@ -854,7 +855,8 @@ def windowed_mean( Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices` have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`. Then each element of `low_indices` and `high_indices` must be - between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. + between 0 and N+1, and the shape of the output will be + `broadcast(Bx, Bi) + [M] + broadcast(E, F)`. The shape `Bi + [1] + F` must be implicitly broadcastable with the shape of `x`, the following implicit broadcasting rules are applied: From a679e265659d737ea6d5018c4da5b65cef9905c2 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 26 Sep 2022 21:50:56 +0000 Subject: [PATCH 12/14] Style --- .../python/stats/sample_stats_test.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats_test.py b/tensorflow_probability/python/stats/sample_stats_test.py index b7cd7dff38..a274926848 100644 --- a/tensorflow_probability/python/stats/sample_stats_test.py +++ b/tensorflow_probability/python/stats/sample_stats_test.py @@ -725,9 +725,8 @@ def apply_func(vector, l, h): out = np.transpose(t_out, axes=dims) return out - def check_gaussian_windowed_func(self, shape, indice_shape, axis, - window_func, np_func): + window_func, np_func): stat_shape = np.array(shape).astype(np.int32) stat_shape[axis] = 1 loc = np.arange(np.prod(stat_shape)).reshape(stat_shape) @@ -759,9 +758,9 @@ def _make_dynamic_shape(self, x): return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape)) @parameterized.named_parameters(*[( - f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis, - tf_func, np_func) for a, (b, axis), (tf_func, np_func) in - itertools.product([(64, 4, 8), ], + f"{np_func.__name__} shape={s} indices_shape={i} axis={axis}", s, i, axis, + tf_func, np_func) for s, (i, axis), (tf_func, np_func) in + itertools.product([(64, 4, 8)], [((128, 1, 1), 0), ((32, 1, 1), 0), ((32, 4, 1), 0), @@ -786,22 +785,19 @@ def _make_dynamic_shape(self, x): ((16,), 2), ((1, 4), 2), ((64, 4), 2)], - [ - (sample_stats.windowed_mean, np.mean), - (sample_stats.windowed_variance, np.var) - ])]) + [(sample_stats.windowed_mean, np.mean), + (sample_stats.windowed_variance, np.var)])]) def test_windowed(self, shape, indice_shape, axis, window_func, np_func): self.check_gaussian_windowed_func(shape, 
indice_shape, axis, window_func, np_func) - @parameterized.named_parameters(*[( - f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis, - tf_func, np_func) for a, (b, axis), (tf_func, np_func) in - itertools.product([(64, 4, 8), ], - [((4, 1, 4), 2), ((2, 4), 2)], - [(sample_stats.windowed_mean, np.mean), - (sample_stats.windowed_variance, np.var)])]) + f"{np_func.__name__} shape={s} indices_shape={i} axis={axis}", s, i, axis, + tf_func, np_func) for s, (i, axis), (tf_func, np_func) in + itertools.product([(64, 4, 8)], + [((4, 1, 4), 2), ((2, 4), 2)], + [(sample_stats.windowed_mean, np.mean), + (sample_stats.windowed_variance, np.var)])]) def test_non_broadcastable_shapes(self, shape, indice_shape, axis, window_func, np_func): with self.assertRaisesRegexp((IndexError, ValueError, InvalidArgumentError), From 385cff7efb90dc63cd62cfe2dff923e3e1960072 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 26 Sep 2022 23:58:30 +0000 Subject: [PATCH 13/14] Notation --- .../python/stats/sample_stats.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index 29ecc4394b..d430018ece 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -718,9 +718,9 @@ def windowed_variance( result[i] = variance(x[low_indices[i]:high_indices[i]]) - accurately and efficiently. To wit, if K is the size of - `low_indices` and `high_indices`, and `N` is the size of `x` along - the given `axis`, the computation takes O(K + N) work, O(log(N)) + accurately and efficiently. To wit, if `m` is the size of + `low_indices` and `high_indices`, and `n` is the size of `x` along + the given `axis`, the computation takes O(n + m) work, O(log(n)) depth (the length of the longest series of operations that are performed sequentially), and only uses O(1) TensorFlow kernel invocations. The underlying algorithm is an adaptation of the @@ -730,19 +730,19 @@ def windowed_variance( trailing-window estimators from some iterative process, such as the last half of an MCMC chain. - Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices` - have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`. + Suppose `x` has shape `Bx + [n] + E`, `low_indices` and `high_indices` + have shape `Bi + [m] + F`, such that `rank(Bx) = rank(Bi) = axis`. Then each element of `low_indices` and `high_indices` must be - between 0 and N+1, and the shape of the output will be - `broadcast(Bx, Bi) + [M] + broadcast(E, F)`. + between 0 and `n+1`, and the shape of the output will be + `broadcast(Bx, Bi) + [m] + broadcast(E, F)`. The shape `Bi + [1] + F` must be implicitly broadcastable with the shape of `x`, the following implicit broadcasting rules are applied: - If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded + If `rank(Bi + [m] + F) < rank(x)`, then the indices are expanded with extra inner dimensions to match the rank of `x`. If rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`, - the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`. + the indices are reshaped to `[1] * rank(Bx) + [m] + [1] * rank(E)`. The default windows are `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...` @@ -751,14 +751,14 @@ def windowed_variance( in the variance of the last half of the data at each point. 
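The `O(n + m)` claim above rests on combining cumulative sums and cumulative variances through the variance-merge formula of [1], adapted to subtraction as noted in the implementation comments; a NumPy check of that identity on a single window, with illustrative values:

```python
import numpy as np

x = np.array([2., 3., 5., 7., 11., 13.])
low, high = 2, 5                      # target window x[2:5]

cum_mean = lambda k: x[:k].mean()     # mean of the first k items
cum_var = lambda k: x[:k].var()       # population variance of the first k

n_union, n_prefix = float(high), float(low)   # |A u B| and |B|
count = n_union - n_prefix                    # |A|
discrepancy = (cum_mean(high) - cum_mean(low)) ** 2
adjustment = n_union * (-n_prefix) / count
residual = (cum_var(high) * n_union
            - cum_var(low) * n_prefix
            + adjustment * discrepancy)
print(residual / count, np.var(x[low:high]))  # both ~ 6.2222
```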
Args: - x: A numeric `Tensor` holding `N` samples along the given `axis`, + x: A numeric `Tensor` holding `n` samples along the given `axis`, whose windowed variances are desired. low_indices: An integer `Tensor` defining the lower boundary (inclusive) of each window. Default: elementwise half of `high_indices`. high_indices: An integer `Tensor` defining the upper boundary (exclusive) of each window. Must be broadcast-compatible with - `low_indices`. Default: `tf.range(1, N+1)`, i.e., N windows + `low_indices`. Default: `tf.range(1, n+1)`, i.e., n windows that each end in the corresponding datum from `x` (inclusive)`. axis: Scalar `Tensor` designating the axis holding samples. This is the axis of `x` along which we take windows, and therefore @@ -842,9 +842,9 @@ def windowed_mean( result[i] = mean(x[low_indices[i]:high_indices[i]]) - efficiently. To wit, if K is the size of `low_indices` and - `high_indices`, and `N` is the size of `x` along the given `axis`, - the computation takes O(K + N) work, O(log(N)) depth (the length of + efficiently. To wit, if `m` is the size of `low_indices` and + `high_indices`, and `n` is the size of `x` along the given `axis`, + the computation takes O(m + n) work, O(log(n)) depth (the length of the longest series of operations that are performed sequentially), and only uses O(1) TensorFlow kernel invocations. @@ -852,19 +852,19 @@ def windowed_mean( trailing-window estimators from some iterative process, such as the last half of an MCMC chain. - Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices` - have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`. + Suppose `x` has shape `Bx + [n] + E`, `low_indices` and `high_indices` + have shape `Bi + [m] + F`, such that `rank(Bx) = rank(Bi) = axis`. Then each element of `low_indices` and `high_indices` must be - between 0 and N+1, and the shape of the output will be - `broadcast(Bx, Bi) + [M] + broadcast(E, F)`. + between 0 and `n+1`, and the shape of the output will be + `broadcast(Bx, Bi) + [m] + broadcast(E, F)`. The shape `Bi + [1] + F` must be implicitly broadcastable with the shape of `x`, the following implicit broadcasting rules are applied: - If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded + If `rank(Bi + [m] + F) < rank(x)`, then the indices are expanded with extra inner dimensions to match the rank of `x`. If rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`, - the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`. + the indices are reshaped to `[1] * rank(Bx) + [m] + [1] * rank(E)`. The default windows are `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...` @@ -873,14 +873,14 @@ def windowed_mean( in the variance of the last half of the data at each point. Args: - x: A numeric `Tensor` holding `N` samples along the given `axis`, + x: A numeric `Tensor` holding `n` samples along the given `axis`, whose windowed means are desired. low_indices: An integer `Tensor` defining the lower boundary (inclusive) of each window. Default: elementwise half of `high_indices`. high_indices: An integer `Tensor` defining the upper boundary (exclusive) of each window. Must be broadcast-compatible with - `low_indices`. Default: `tf.range(1, N+1)`, i.e., N windows + `low_indices`. Default: `tf.range(1, n+1)`, i.e., n windows that each end in the corresponding datum from `x` (inclusive). axis: Scalar `Tensor` designating the axis holding samples. 
This is the axis of `x` along which we take windows, and therefore From 92d7143aad048bbb46e120ebeb46724b74907948 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Sun, 9 Oct 2022 10:46:47 +0000 Subject: [PATCH 14/14] Remove extra JAX_MODE and NUMPY_MODE setting --- tensorflow_probability/python/stats/sample_stats.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index d430018ece..fecf5c2547 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -14,9 +14,6 @@ # ============================================================================ """Functions for computing statistics of samples.""" -JAX_MODE = False -NUMPY_MODE = False - # Dependency imports import numpy as np import tensorflow.compat.v2 as tf
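Taken together, the series makes batched index windows work end to end. A hedged usage sketch follows; the values are illustrative, and the call assumes the behaviour added by these patches is present in the installed `tensorflow_probability`:

```python
import numpy as np
import tensorflow_probability as tfp

# Each row of `low`/`high` selects its own windows along axis=1.
x = np.arange(12., dtype=np.float64).reshape(2, 6)
low = np.array([[0, 2], [1, 3]])
high = np.array([[3, 6], [4, 6]])

batched_means = tfp.stats.windowed_mean(
    x, low_indices=low, high_indices=high, axis=1)

# Reference loop with the same half-open window convention.
expected = np.array([
    [x[b, lo:hi].mean() for lo, hi in zip(low[b], high[b])]
    for b in range(x.shape[0])])
print(np.asarray(batched_means))  # ~ [[1. , 3.5], [8. , 10. ]]
print(expected)                   # same values
```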