Skip to content

Commit

Permalink
Fix take_along_axis import for jax backend
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolaspi committed Aug 10, 2022
1 parent 8d20563 commit 4002d8b
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions tensorflow_probability/python/stats/sample_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
import numpy as np
import tensorflow.compat.v2 as tf

if JAX_MODE or NUMPY_MODE:
numpy_ops = np
if NUMPY_MODE:
take_along_axis = np.take_along_axis
elif JAX_MODE:
from jax.numpy import take_along_axis
else:
from tensorflow.python.ops import numpy_ops
from tensorflow.python.ops.numpy_ops import take_along_axis

from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import distribution_util
Expand Down Expand Up @@ -800,10 +802,10 @@ def windowed_variance(
def index_for_cumulative(indices):
return tf.maximum(indices - 1, 0)
cum_sums = tf.cumsum(x, axis=axis)
sums = numpy_ops.take_along_axis(
sums = take_along_axis(
cum_sums, index_for_cumulative(indices), axis=axis)
cum_variances = cumulative_variance(x, sample_axis=axis)
variances = numpy_ops.take_along_axis(
variances = take_along_axis(
cum_variances, index_for_cumulative(indices), axis=axis)

# This formula is the binary accurate variance merge from [1],
Expand Down Expand Up @@ -904,8 +906,7 @@ def windowed_mean(
paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32),
(rank, 2))
cum_sums = ps.pad(raw_cumsum, paddings)
sums = numpy_ops.take_along_axis(cum_sums, indices,
axis=axis)
sums = take_along_axis(cum_sums, indices, axis=axis)
counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype)
return tf.math.divide_no_nan(sums[1] - sums[0], counts)

Expand Down

0 comments on commit 4002d8b

Please sign in to comment.