Enable batch support for windowed_mean|variance #1600

Open · nicolaspi wants to merge 14 commits into base: main

Conversation

nicolaspi

This PR makes the functions windowed_mean and windowed_variance accept indices with batch dimensions.

Example:

import numpy as np
import tensorflow_probability as tfp

x = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype=np.float32)
low_indices = [[0, 0, 0], [1, 0, 0], [2, 2, 0]]
high_indices = [[3, 3, 3], [1, 2, 3], [3, 2, 1]]
tfp.stats.windowed_mean(x, low_indices=low_indices, high_indices=high_indices, axis=1)

Now gives:

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[2. , 2. , 2. ],
       [0. , 1.5, 2. ],
       [3. , 0. , 1. ]], dtype=float32)>

Was previously failing with:

tensorflow.python.framework.errors_impl.InvalidArgumentError: required broadcastable shapes [Op:SelectV2]
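
For reference, the batched semantics above can be stated as a minimal NumPy oracle: output entry (i, j) is the mean of x[i, low_indices[i][j]:high_indices[i][j]], with 0 for an empty window. This is an illustrative sketch, not the PR's implementation.

import numpy as np

def batched_windowed_mean_oracle(x, low_indices, high_indices):
  # out[i, j] = mean of x[i, lo:hi]; empty windows yield 0, matching
  # the example output shown above.
  out = np.zeros(x.shape, dtype=x.dtype)
  for i in range(x.shape[0]):
    for j in range(x.shape[1]):
      window = x[i, low_indices[i][j]:high_indices[i][j]]
      out[i, j] = window.mean() if window.size else 0.
  return out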

@nicolaspi (Author)

@axch I made changes in the code you authored, could you kindly have a look at this PR?
Thanks

@axch (Contributor) commented Aug 9, 2022

@nicolaspi thanks for the contribution! I am no longer an active maintainer of TFP, so I'm not really in a position to review your PR in detail (@jburnim please suggest someone?). On a quick look, though, I see a couple potential code style issues:

  • Do we need the dependency on tf.experimental.numpy?
  • Do we need the special case for rank-1 indices? Could we define a more uniform behavior instead?
  • I'm guessing some of the shape munging already has relevant helpers defined elsewhere in TFP, but I don't remember off-hand
  • TFP generally tries to handle static and dynamic TF shapes uniformly using prefer_static and tensorshape_util (see the sketch below for the prefer_static idiom).
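
A rough illustration of that last point (a sketch of the general prefer_static idiom, not code from this PR): prefer_static functions return concrete values when shapes are statically known and fall back to graph ops otherwise, so one code path serves both cases.

import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal import prefer_static as ps

def leading_batch_shape(x, axis):
  # Static shapes: returns a concrete (NumPy) value at trace time.
  # Dynamic shapes: returns a Tensor evaluated at graph run time.
  return ps.shape(x)[:axis]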

@nicolaspi (Author)

Thanks for your feedback!

  • Do we need the dependency on tf.experimental.numpy?

We specifically need the take_along_axis function, which allows gathering the slices along each batch dimension. I replaced the 'experimental' import path with from tensorflow.python.ops import numpy_ops.
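
For context, a small illustration of the behavior in question (assuming TF >= 2.4, where take_along_axis is exposed under tf.experimental.numpy):

import tensorflow.compat.v2 as tf
import tensorflow.experimental.numpy as tnp

x = tf.constant([[10., 20., 30.],
                 [40., 50., 60.]])
idx = tf.constant([[2, 0, 1],
                   [0, 0, 2]])
# result[i, j] = x[i, idx[i, j]]:
tnp.take_along_axis(x, idx, axis=1)
# [[30. 10. 20.]
#  [40. 40. 60.]]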

  • Do we need the special case for rank-1 indices? Could we define a more uniform behavior instead?

There are two motivations for this case. First, backward compatibility: it is equivalent to the legacy non-batched usage. Second, it is the only case I can think of where the broadcast is unambiguous when rank(indices) < rank(x).

  • I'm guessing some of the shape munging already has relevant helpers defined elsewhere in TFP, but I don't remember off-hand

In any case, I modified the unit tests to test against non-static shapes.

I made use of prefer_static wherever possible.

@nicolaspi (Author)

@jburnim can you please suggest a reviewer?
CC @axch

@SiegeLordEx (Member)

I'll take a look at this.

@SiegeLordEx SiegeLordEx self-assigned this Sep 20, 2022
@SiegeLordEx SiegeLordEx self-requested a review September 20, 2022 07:41
@SiegeLordEx SiegeLordEx removed their assignment Sep 20, 2022
# Dependency imports
import numpy as np
import tensorflow.compat.v2 as tf

if NUMPY_MODE:
Member (review comment)

We'll need to do something different about take_along_axis.

  1. (preferred) Somehow rewrite the logic using tf.gather/tf.gather_nd (see the sketch below)
  2. Expose tf.experimental.numpy.take_along_axis in https://github.com/tensorflow/probability/tree/main/tensorflow_probability/python/internal/backend/numpy

As is, this is problematic since we really dislike using JAX_/NUMPY_MODE in library code.
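
For option 1, a hedged sketch of how tf.gather can emulate take_along_axis when the index array shares x's leading batch dimensions (the PR's general broadcasting case would need extra shape handling):

import tensorflow.compat.v2 as tf

x = tf.constant([[10., 20., 30.],
                 [40., 50., 60.]])
idx = tf.constant([[2, 0, 1],
                   [0, 0, 2]])
# With batch_dims=1, each row of idx gathers from the matching row of x,
# i.e. result[i, j] = x[i, idx[i, j]].
tf.gather(x, idx, axis=1, batch_dims=1)
# [[30. 10. 20.]
#  [40. 40. 60.]]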

Author (reply)

Thanks for the review!

  1. I don't feel comfortable rewriting take_along_axis, as it would duplicate existing logic; I feel it would create an unnecessary maintenance burden.
  2. What about mapping tensorflow.experimental.numpy to the numpy and jax.numpy backends?
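
(For what it's worth, plain NumPy and jax.numpy both expose take_along_axis with the same signature, which is what would make such a mapping straightforward. A minimal check in NumPy:)

import numpy as np

a = np.array([[10., 20., 30.],
              [40., 50., 60.]])
i = np.array([[2, 0, 1],
              [0, 0, 2]])
# Same semantics as tf.experimental.numpy.take_along_axis above.
np.take_along_axis(a, i, axis=1)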

tensorflow_probability/python/stats/sample_stats.py (outdated review thread, resolved)
must be between 0 and N+1, and the shape of the output will be
`Bx + [M] + E`. Batch shape in the indices is not currently supported.
Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`.
Member (review comment)

What is F? Why isn't it a scalar?

Author (reply)

Please check my comment below.


The shape `Bi + [1] + F` must be broadcastable with the shape of `x`.

If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded
Member (review comment)

I don't think this paragraph adds anything, it's just an implementation detail.

Author (reply)

We specify the implicit rules we use for broadcasting. I updated the formulation.

Then each element of `low_indices` and `high_indices` must be
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.

The shape `Bi + [1] + F` must be broadcastable with the shape of `x`.
Member (review comment)

This contradicts the next paragraph, no?

In general, consider the non-batched version of this:

x shape: [N] + E
idx shape: [M]
output shape: [M] + E

The batching would introduce a batch dimension on the left of those shapes:

x shape: Bx + [N] + E
idx shape: Bi + [M]
output shape: broadcast(Bx, Bi) + [M] + E

Thus, the only broadcasting requirements are that Bx and Bi broadcast. I don't know where F came from.
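
(A quick check of that rule with hypothetical shapes, say Bx = [4, 1], Bi = [5], N = 7, M = 3, E = [2]; np.broadcast_shapes requires NumPy >= 1.20:)

import numpy as np

Bx, Bi, E = (4, 1), (5,), (2,)
x_shape = Bx + (7,) + E    # Bx + [N] + E
idx_shape = Bi + (3,)      # Bi + [M]
out_shape = np.broadcast_shapes(Bx, Bi) + (3,) + E
print(out_shape)           # (4, 5, 3, 2) == broadcast(Bx, Bi) + [M] + E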

Author (reply)

This contradicts the next paragraph, no?

Yes, I reformulated.

The batching would introduce a batch dimension on the left of those shapes:
Thus, the only broadcasting requirements are that Bx and Bi broadcast. I don't know where F came from.

Maybe the term 'batch' is not the right one. This contribution adds support for the more general case where the idx shape is Bi + [M] + F. F could be seen as 'inner batch dimensions', but here 'batch' carries a different semantics than the standard machine learning one, where batches are the outer dims.

tensorflow_probability/python/stats/sample_stats_test.py (3 outdated review threads, resolved)
@test_util.test_all_tf_execution_regimes
class WindowedStatsTest(test_util.TestCase):

def _maybe_expand_dims_to_make_broadcastable(self, x, shape, axis):
Member (review comment)

These two functions are as complex as the thing we're testing. Is there any way we can write this via np.vectorize?

Author (reply)

I refactored using np.vectorize, but I am not sure it is easier to read.
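
(For readers following the thread, a sketch of what an np.vectorize-based oracle can look like; the names are hypothetical and this is not the PR's actual test code:)

import numpy as np

def _window_mean(x_row, lo, hi):
  w = x_row[lo:hi]
  return w.mean() if w.size else 0.

# Core signature: one data row plus a scalar index pair -> one scalar;
# np.vectorize broadcasts over the remaining (loop) dimensions.
oracle = np.vectorize(_window_mean, signature='(n),(),()->()')

x = np.array([[1., 2., 3.], [1., 2., 3.]])
lo = np.array([[0, 1, 2], [0, 0, 0]])
hi = np.array([[3, 3, 3], [1, 2, 3]])
oracle(x[:, None, :], lo, hi)
# [[2.  2.5 3. ]
#  [1.  1.5 2. ]]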

@nicolaspi (Author)

Hi @SiegeLordEx, I have addressed your comments, could you have a look? Thanks
