diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index 174a8bb262..0f1bc0e65a 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -21,10 +21,12 @@ import numpy as np import tensorflow.compat.v2 as tf -if JAX_MODE or NUMPY_MODE: - numpy_ops = np +if NUMPY_MODE: + take_along_axis = np.take_along_axis +elif JAX_MODE: + from jax.numpy import take_along_axis else: - from tensorflow.python.ops import numpy_ops + from tensorflow.python.ops.numpy_ops import take_along_axis from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import distribution_util @@ -800,10 +802,10 @@ def windowed_variance( def index_for_cumulative(indices): return tf.maximum(indices - 1, 0) cum_sums = tf.cumsum(x, axis=axis) - sums = numpy_ops.take_along_axis( + sums = take_along_axis( cum_sums, index_for_cumulative(indices), axis=axis) cum_variances = cumulative_variance(x, sample_axis=axis) - variances = numpy_ops.take_along_axis( + variances = take_along_axis( cum_variances, index_for_cumulative(indices), axis=axis) # This formula is the binary accurate variance merge from [1], @@ -904,8 +906,7 @@ def windowed_mean( paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32), (rank, 2)) cum_sums = ps.pad(raw_cumsum, paddings) - sums = numpy_ops.take_along_axis(cum_sums, indices, - axis=axis) + sums = take_along_axis(cum_sums, indices, axis=axis) counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype) return tf.math.divide_no_nan(sums[1] - sums[0], counts)