Enable batch support for windowed_mean|variance #1600

Open · wants to merge 14 commits into base: main
2 changes: 2 additions & 0 deletions tensorflow_probability/python/internal/backend/jax/rewrite.py
@@ -45,6 +45,8 @@ def main(argv):
if FLAGS.rewrite_numpy_import:
contents = contents.replace('\nimport numpy as np',
'\nimport numpy as onp; import jax.numpy as np')
contents = contents.replace('\nimport numpy as tnp',
'\nimport jax.numpy as tnp')
else:
contents = contents.replace('\nimport numpy as np',
'\nimport numpy as np; onp = np')
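For reference, a minimal sketch of what the new rewrite rule does (the input string here is hypothetical, not the actual TFP source):

```python
# Under the JAX backend, a `numpy as tnp` alias is redirected to
# `jax.numpy`, mirroring the existing `numpy as np` rewrite just above.
contents = "import tensorflow.compat.v2 as tf\nimport numpy as tnp\n"
contents = contents.replace('\nimport numpy as tnp',
                            '\nimport jax.numpy as tnp')
print(contents)
# import tensorflow.compat.v2 as tf
# import jax.numpy as tnp
```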
170 changes: 94 additions & 76 deletions tensorflow_probability/python/stats/sample_stats.py
@@ -17,6 +17,7 @@
# Dependency imports
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow.experimental.numpy as tnp

from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import distribution_util
@@ -694,8 +695,8 @@ def cumulative_variance(x, sample_axis=0, name=None):
excl_counts = tf.reshape(tf.range(size, dtype=x.dtype), shape=counts_shp)
incl_counts = excl_counts + 1
excl_sums = tf.cumsum(x, axis=sample_axis, exclusive=True)
discrepancies = (excl_sums / excl_counts - x)**2
discrepancies = tf.where(excl_counts == 0, x**2, discrepancies)
discrepancies = tf.math.square(excl_sums / excl_counts - x)
discrepancies = tf.where(excl_counts == 0, tf.math.square(x), discrepancies)
adjustments = excl_counts / incl_counts
# The zeroth item's residual contribution is 0, because it has no
# other items to vary from. The preceding expressions, however,
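For intuition, the hunk above vectorizes the Welford-style recurrence: item `k` (0-indexed) contributes `(mean(x[:k]) - x[k])**2 * k / (k + 1)` to the running sum of squared deviations. A minimal NumPy cross-check of that identity (hypothetical data, not part of this diff):

```python
import numpy as np

x = np.array([1.0, 4.0, 2.0, 8.0, 5.0])
n = x.size

# Exclusive counts and sums: how many items precede position k, and their sum.
excl_counts = np.arange(n)
excl_sums = np.concatenate([[0.0], np.cumsum(x)[:-1]])
excl_means = np.divide(excl_sums, excl_counts,
                       out=np.zeros_like(x), where=excl_counts > 0)
# Item k's residual contribution; the zeroth item has nothing to vary from.
residuals = np.where(excl_counts == 0, 0.0,
                     (excl_means - x)**2 * excl_counts / (excl_counts + 1))
cum_var = np.cumsum(residuals) / (excl_counts + 1)

# Matches the naive population variance of every prefix.
expected = np.array([x[:k + 1].var() for k in range(n)])
assert np.allclose(cum_var, expected)
```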
@@ -712,11 +713,11 @@ def windowed_variance(

Computes variances among data in the Tensor `x` along the given windows:

result[i] = variance(x[low_indices[i]:high_indices[i]+1])
result[i] = variance(x[low_indices[i]:high_indices[i]])

accurately and efficiently. To wit, if K is the size of
`low_indices` and `high_indices`, and `N` is the size of `x` along
the given `axis`, the computation takes O(K + N) work, O(log(N))
accurately and efficiently. To wit, if `m` is the size of
`low_indices` and `high_indices`, and `n` is the size of `x` along
the given `axis`, the computation takes O(n + m) work, O(log(n))
depth (the length of the longest series of operations that are
performed sequentially), and only uses O(1) TensorFlow kernel
invocations. The underlying algorithm is an adaptation of the
@@ -726,11 +727,19 @@ def windowed_variance(
trailing-window estimators from some iterative process, such as the
last half of an MCMC chain.

Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
rank `axis`, and `low_indices` and `high_indices` broadcast to shape
`[M]`. Then each element of `low_indices` and `high_indices`
must be between 0 and N+1, and the shape of the output will be
`Bx + [M] + E`. Batch shape in the indices is not currently supported.
Suppose `x` has shape `Bx + [n] + E`, and `low_indices` and `high_indices`
have shape `Bi + [m] + F`, such that `rank(Bx) = rank(Bi) = axis`.
Then each element of `low_indices` and `high_indices` must be
between 0 and `n+1`, and the shape of the output will be
`broadcast(Bx, Bi) + [m] + broadcast(E, F)`.

The shape `Bi + [1] + F` must be implicitly broadcastable with the
shape of `x`. The following implicit broadcasting rules are applied:

If `rank(Bi + [m] + F) < rank(x)`, the indices are expanded
with extra inner dimensions to match the rank of `x`.
If the rank of the indices is one, i.e. when `rank(Bi) = rank(F) = 0`,
the indices are reshaped to `[1] * rank(Bx) + [m] + [1] * rank(E)`.

The default windows are
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
@@ -739,14 +748,14 @@
in the variance of the last half of the data at each point.

Args:
x: A numeric `Tensor` holding `N` samples along the given `axis`,
x: A numeric `Tensor` holding `n` samples along the given `axis`,
whose windowed variances are desired.
low_indices: An integer `Tensor` defining the lower boundary
(inclusive) of each window. Default: elementwise half of
`high_indices`.
high_indices: An integer `Tensor` defining the upper boundary
(exclusive) of each window. Must be broadcast-compatible with
`low_indices`. Default: `tf.range(1, N+1)`, i.e., N windows
`low_indices`. Default: `tf.range(1, n+1)`, i.e., n windows
that each end in the corresponding datum from `x` (inclusive).
axis: Scalar `Tensor` designating the axis holding samples. This
is the axis of `x` along which we take windows, and therefore
@@ -769,7 +778,7 @@ def windowed_variance(
"""
with tf.name_scope(name or 'windowed_variance'):
x = tf.convert_to_tensor(x)
low_indices, high_indices, low_counts, high_counts = _prepare_window_args(
x, indices, axis = _prepare_window_args(
x, low_indices, high_indices, axis)

# We have a problem with indexing: the standard convention demands
@@ -786,15 +795,11 @@
def index_for_cumulative(indices):
return tf.maximum(indices - 1, 0)
cum_sums = tf.cumsum(x, axis=axis)
low_sums = tf.gather(
cum_sums, index_for_cumulative(low_indices), axis=axis)
high_sums = tf.gather(
cum_sums, index_for_cumulative(high_indices), axis=axis)
sums = tnp.take_along_axis(
cum_sums, index_for_cumulative(indices), axis=axis)
cum_variances = cumulative_variance(x, sample_axis=axis)
low_variances = tf.gather(
cum_variances, index_for_cumulative(low_indices), axis=axis)
high_variances = tf.gather(
cum_variances, index_for_cumulative(high_indices), axis=axis)
variances = tnp.take_along_axis(
cum_variances, index_for_cumulative(indices), axis=axis)

# This formula is the binary accurate variance merge from [1],
# adapted to subtract and batched across the indexed counts, sums,
@@ -812,15 +817,18 @@ def index_for_cumulative(indices):
# This formula can also be read as implementing the above variance
# computation by "unioning" A u B with a notional "negative B"
# multiset.
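# Concretely, write n_b for the low count, n_ab for the high count, and
# n_a = n_ab - n_b. The merge identity from [1],
#   M2(A u B) = M2(A) + M2(B) + (mean(A) - mean(B))**2 * n_a * n_b / n_ab,
# combined with mean(A u B) - mean(B) = (n_a / n_ab) * (mean(A) - mean(B)),
# solves for M2(A) = n_a * var(A) as
#   n_a * var(A) = n_ab * var(A u B) - n_b * var(B)
#                  - (mean(A u B) - mean(B))**2 * n_ab * n_b / n_a,
# which is the `residuals` expression computed below (note the negated
# product in `adjustments`).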
counts = high_counts - low_counts # |A|
discrepancies = (
_safe_average(high_sums, high_counts) -
_safe_average(low_sums, low_counts))**2 # (mean(A u B) - mean(B))**2
adjustments = high_counts * (-low_counts) / counts # |A u B| * -|B| / |A|
residuals = (high_variances * high_counts -
low_variances * low_counts +
bounds = ps.cast(indices, sums.dtype)
counts = bounds[1] - bounds[0] # |A|
sum_averages = tf.math.divide_no_nan(sums, bounds)
# (mean(A u B) - mean(B))**2
discrepancies = tf.square(sum_averages[1] - sum_averages[0])
# |A u B| * -|B| / |A|
adjustments = tf.math.divide_no_nan(bounds[1] * (-bounds[0]), counts)
variances_scaled = variances * bounds
residuals = (variances_scaled[1] -
variances_scaled[0] +
adjustments * discrepancies)
return _safe_average(residuals, counts)
return tf.math.divide_no_nan(residuals, counts)
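To illustrate the new behavior end to end, a minimal usage sketch (hypothetical shapes and values; per-chain window bounds like these are exactly the batch-shaped indices this PR enables):

```python
import tensorflow as tf
import tensorflow_probability as tfp

# n = 100 samples for each of 3 chains: shape Bx + [n] = [3, 100].
x = tf.random.stateless_normal([3, 100], seed=[0, 0])

# Per-chain window bounds of shape Bi + [m] = [3, 2]: m = 2 windows whose
# endpoints differ across the 3 chains.
low = tf.constant([[10, 50], [20, 60], [30, 70]])
high = tf.constant([[20, 100], [30, 100], [40, 100]])

result = tfp.stats.windowed_variance(x, low, high, axis=1)
print(result.shape)  # (3, 2), i.e. broadcast(Bx, Bi) + [m]

# Each entry matches the naive slice-based computation, e.g.
# result[1, 0] ~= tf.math.reduce_variance(x[1, 20:30]).
```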


def windowed_mean(
@@ -829,23 +837,31 @@

Computes means among data in the Tensor `x` along the given windows:

result[i] = mean(x[low_indices[i]:high_indices[i]+1])
result[i] = mean(x[low_indices[i]:high_indices[i]])

efficiently. To wit, if K is the size of `low_indices` and
`high_indices`, and `N` is the size of `x` along the given `axis`,
the computation takes O(K + N) work, O(log(N)) depth (the length of
efficiently. To wit, if `m` is the size of `low_indices` and
`high_indices`, and `n` is the size of `x` along the given `axis`,
the computation takes O(m + n) work, O(log(n)) depth (the length of
the longest series of operations that are performed sequentially),
and only uses O(1) TensorFlow kernel invocations.

This function can be useful for assessing the behavior over time of
trailing-window estimators from some iterative process, such as the
last half of an MCMC chain.

Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
rank `axis`, and `low_indices` and `high_indices` broadcast to shape
`[M]`. Then each element of `low_indices` and `high_indices`
must be between 0 and N+1, and the shape of the output will be
`Bx + [M] + E`. Batch shape in the indices is not currently supported.
Suppose `x` has shape `Bx + [n] + E`, and `low_indices` and `high_indices`
have shape `Bi + [m] + F`, such that `rank(Bx) = rank(Bi) = axis`.
Then each element of `low_indices` and `high_indices` must be
between 0 and `n+1`, and the shape of the output will be
`broadcast(Bx, Bi) + [m] + broadcast(E, F)`.

The shape `Bi + [1] + F` must be implicitly broadcastable with the
shape of `x`. The following implicit broadcasting rules are applied:

If `rank(Bi + [m] + F) < rank(x)`, the indices are expanded
with extra inner dimensions to match the rank of `x`.
If the rank of the indices is one, i.e. when `rank(Bi) = rank(F) = 0`,
the indices are reshaped to `[1] * rank(Bx) + [m] + [1] * rank(E)`.

The default windows are
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
@@ -854,14 +870,14 @@
in the variance of the last half of the data at each point.

Args:
x: A numeric `Tensor` holding `N` samples along the given `axis`,
x: A numeric `Tensor` holding `n` samples along the given `axis`,
whose windowed means are desired.
low_indices: An integer `Tensor` defining the lower boundary
(inclusive) of each window. Default: elementwise half of
`high_indices`.
high_indices: An integer `Tensor` defining the upper boundary
(exclusive) of each window. Must be broadcast-compatible with
`low_indices`. Default: `tf.range(1, N+1)`, i.e., N windows
`low_indices`. Default: `tf.range(1, n+1)`, i.e., n windows
that each end in the corresponding datum from `x` (inclusive).
axis: Scalar `Tensor` designating the axis holding samples. This
is the axis of `x` along which we take windows, and therefore
@@ -878,58 +894,60 @@
"""
with tf.name_scope(name or 'windowed_mean'):
x = tf.convert_to_tensor(x)
low_indices, high_indices, low_counts, high_counts = _prepare_window_args(
x, low_indices, high_indices, axis)
x, indices, axis = _prepare_window_args(x, low_indices, high_indices, axis)

raw_cumsum = tf.cumsum(x, axis=axis)
cum_sums = tf.concat(
[tf.zeros_like(tf.gather(raw_cumsum, [0], axis=axis)), raw_cumsum],
axis=axis)
low_sums = tf.gather(cum_sums, low_indices, axis=axis)
high_sums = tf.gather(cum_sums, high_indices, axis=axis)

counts = high_counts - low_counts
return _safe_average(high_sums - low_sums, counts)
rank = ps.rank(x)
# Build paddings with a single 1 in the `(axis, 0)` slot, i.e. prepend one
# zero along `axis`, so a window with `low_index == 0` sums from the start.
paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32),
(rank, 2))
cum_sums = ps.pad(raw_cumsum, paddings)
sums = tnp.take_along_axis(cum_sums, indices, axis=axis)
counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype)
return tf.math.divide_no_nan(sums[1] - sums[0], counts)
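A corresponding sketch for `windowed_mean`, this time with rank-1 indices, which the implicit broadcasting rules reshape to `[1, m]` so the same windows apply to every batch member (hypothetical values):

```python
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 12])  # Bx = [2], n = 12

low = tf.constant([0, 4, 8])    # rank-1: reshaped internally to [1, 3]
high = tf.constant([4, 8, 12])  # three windows [0, 4), [4, 8), [8, 12)

means = tfp.stats.windowed_mean(x, low, high, axis=1)
print(means)
# [[ 1.5  5.5  9.5]
#  [13.5 17.5 21.5]]
```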


def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
"""Common argument defaulting logic for windowed statistics."""
if high_indices is None:
high_indices = tf.range(ps.shape(x)[axis]) + 1
high_indices = ps.range(ps.shape(x)[axis]) + 1
else:
high_indices = tf.convert_to_tensor(high_indices)
if low_indices is None:
low_indices = high_indices // 2
else:
low_indices = tf.convert_to_tensor(low_indices)

indices_rank = tf.get_static_value(ps.rank(low_indices))
x_rank = tf.get_static_value(ps.rank(x))
if indices_rank is None or x_rank is None:
raise ValueError("`indices` and `x` ranks must be statically known.")

# Broadcast indices together.
high_indices = high_indices + tf.zeros_like(low_indices)
low_indices = low_indices + tf.zeros_like(high_indices)

# TODO(axch): Support batch low and high indices. That would
# complicate this shape munging (though tf.gather should work
# fine).

# We want to place `low_counts` and `high_counts` at the `axis`
# position, so we reshape them to shape `[1, 1, ..., 1, N, 1, ...,
# 1]`, where the `N` is at `axis`. The `counts_shp`, below,
# is this shape.
size = ps.size(high_indices)
counts_shp = ps.one_hot(
axis, depth=ps.rank(x), on_value=size, off_value=1)

low_counts = tf.reshape(tf.cast(low_indices, dtype=x.dtype),
shape=counts_shp)
high_counts = tf.reshape(tf.cast(high_indices, dtype=x.dtype),
shape=counts_shp)
return low_indices, high_indices, low_counts, high_counts


def _safe_average(totals, counts):
# This tf.where protects `totals` from getting a gradient signal
# when `counts` is 0.
safe_totals = tf.where(~tf.equal(counts, 0), totals, 0)
return tf.where(~tf.equal(counts, 0), safe_totals / counts, 0)
indices_shape = ps.shape(low_indices)
if ps.rank(low_indices) < ps.rank(x):
if ps.rank(low_indices) == 1:
size = ps.size(low_indices)
bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size,
off_value=1)
else:
# We assume the leading dimensions broadcast with `x`, and add trailing
# dimensions to match its rank.
extra_dims = ps.rank(x) - ps.rank(low_indices)
bc_shape = ps.concat([indices_shape, [1]*extra_dims], axis=0)
else:
bc_shape = indices_shape

# Stack the low and high indices along a new leading dimension of size 2
# so a single `take_along_axis` call gathers both; `x` gets a matching
# leading dimension and `axis` shifts right by one.
bc_shape = ps.concat([[2], bc_shape], axis=0)
indices = ps.stack([low_indices, high_indices], axis=0)
indices = ps.reshape(indices, bc_shape)
x = tf.expand_dims(x, axis=0)
axis += 1
# `take_along_axis` requires the indices to be int32.
indices = ps.cast(indices, dtype=tf.int32)
return x, indices, axis
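The helper now routes both gathers through `tnp.take_along_axis`, which, unlike the `tf.gather` calls it replaces, pairs index batch members with data batch members. A tiny standalone demonstration of that primitive (hypothetical values):

```python
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

data = tf.constant([[10., 20., 30.],
                    [40., 50., 60.]])  # shape [2, 3]
idx = tf.constant([[0, 2],
                   [1, 0]])            # shape [2, 2]: per-row indices

# output[i, j] = data[i, idx[i, j]]
print(tnp.take_along_axis(data, idx, axis=1))
# [[10. 30.]
#  [50. 40.]]
```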


def log_average_probs(logits, sample_axis=0, event_axis=None, keepdims=False,