-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable batch support for windowed_mean|variance
#1600
base: main
Are you sure you want to change the base?
Conversation
@axch I made changes in the code you authored, could you kindly have a look at this PR? |
@nicolaspi thanks for the contribution! I am no longer an active maintainer of TFP, so I'm not really in a position to review your PR in detail (@jburnim please suggest someone?). On a quick look, though, I see a couple potential code style issues:
|
Thanks for your feedback!
We need specifically the
There is two motivations for this case. First, for backward compatibility, it is equivalent to the legacy non batched usage. Second, it is the only case I can think of where the broadcast is unambiguous when
In any case, I modified the unit tests to test against non static shapes.
I made usage of |
I'll take a look at this. |
# Dependency imports | ||
import numpy as np | ||
import tensorflow.compat.v2 as tf | ||
|
||
if NUMPY_MODE: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll need to do something different about take_along_axis.
- (preferred) Somehow rewrite the logic using
tf.gather
/tf.gather_nd
- Expose tf.experimental.numpy.take_along_axis in https://github.com/tensorflow/probability/tree/main/tensorflow_probability/python/internal/backend/numpy
As is, this is problematic since we really dislike using JAX_/NUMPY_MODE in library code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review!
- I don't feel comfortable rewriting
take_along_axis
as it would duplicate already existing logics, I feel like it would produce unnecessary maintenance burden. - What about mapping
tensorflow.experimental.numpy
tonumpy
andjax.numpy
backends?
must be between 0 and N+1, and the shape of the output will be | ||
`Bx + [M] + E`. Batch shape in the indices is not currently supported. | ||
Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices` | ||
have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is F? Why isn't it a scalar?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please check my comment below.
|
||
The shape `Bi + [1] + F` must be broadcastable with the shape of `x`. | ||
|
||
If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this paragraph adds anything, it's just an implementation detail.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We specify the implicit rules we uses for broadcasting. I updated the formulation.
Then each element of `low_indices` and `high_indices` must be | ||
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. | ||
|
||
The shape `Bi + [1] + F` must be broadcastable with the shape of `x`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This contradicts the next paragraph, no?
In general, consider the non-batched version of this:
x shape: [N] + E
idx shape: [M]
output shape: [M] + E
The batching would introduce a batch dimension on the left of those shapes:
x shape: Bx + [N] + E
idx shape: Bi + [M]
output shape: broadcast(Bx, Bi) + [M] + E
Thus, the only broadcasting requirements are that Bx and Bi broadcast. I don't know where F came from.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This contradicts the next paragraph, no?
Yes, I reformulated.
The batching would introduce a batch dimension on the left of those shapes:
Thus, the only broadcasting requirements are that Bx and Bi broadcast. I don't know where F came from.
Maybe the term 'batch' is not proper. This contribution adds the possibility to have the more general case where
idx shape is Bi + [M] + F
. F could be seen as 'inner batch dimensions', but here 'batch' carries a different semantic than the standard machine learning one where it is represented by outer dims.
@test_util.test_all_tf_execution_regimes | ||
class WindowedStatsTest(test_util.TestCase): | ||
|
||
def _maybe_expand_dims_to_make_broadcastable(self, x, shape, axis): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These two functions are as complex as the thing we're testing. Is there any way we can write this via np.vectorize?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I refactored using np.vectorize, but I am not sure it is easier to read.
Add test cases
Some `tensorflow` to `prefer_static` replacement
Parametrize tests
4002d8b
to
c90e961
Compare
Hi @SiegeLordEx, I have assessed your comments, can you have a look? Thanks |
This PR makes functions
windowed_mean
andwindowed_variance
to accept indices with batch dimensions.Example:
Now gives:
Was previously failing with: