Add Brier metrics to motion forecasting evaluation module #44
Conversation
Raises:
    ValueError: If the number of forecasted trajectories and probabilities don't match.
    ValueError: If normalize=False and `forecast_probabilities` contains values outside of the range [0, 1].
Since these are "probabilities", we should raise the out of range of error regardless of normalize flag.
I was flip-flopping on whether to call this arg `weights`, `likelihoods`, or `probabilities`, but settled on `probabilities` because I thought it was the most intuitive.
That being said, I can see use-cases where users might want to directly pass in weights and have them normalized, so it might be nice to perform the range check for sanity afterwards.
Can we call them `weights` in that case? It isn't evident from the function name or docstring that weights are acceptable.
@@ -56,3 +55,85 @@ def compute_is_missed_prediction(
    fde = compute_fde(forecasted_trajectories, gt_trajectory)
    is_missed_prediction = fde > miss_threshold_m  # type: ignore
    return is_missed_prediction


def compute_brier_ade(
I see that all the metric functions here work on a single sample. For a batch, we might have to call these functions for individual samples. Wouldn't that be slower because no batch computation will be used?
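To make the concern concrete, here is a minimal sketch of what a batched FDE could look like. The function name `compute_fde_batched` and the `(batch, K, timesteps, 2)` shape convention are assumptions for illustration, not part of this PR:

```python
import numpy as np

def compute_fde_batched(forecasted: np.ndarray, gt: np.ndarray) -> np.ndarray:
    """Hypothetical batched FDE sketch.

    forecasted: (B, K, T, 2) array of K forecasts per sample.
    gt: (B, T, 2) array of ground-truth trajectories.
    Returns: (B, K) array of final-displacement errors.
    """
    # Compare only the final timestep of each trajectory; broadcasting over
    # the K forecasts avoids a per-sample Python loop.
    diff = forecasted[:, :, -1, :] - gt[:, None, -1, :]
    return np.linalg.norm(diff, axis=-1)

# Example: 2 samples, 3 forecasts each, 4 timesteps.
fde = compute_fde_batched(np.zeros((2, 3, 4, 2)), np.ones((2, 4, 2)))
print(fde.shape)  # (2, 3)
```

A vectorized version like this computes all samples in one numpy call rather than K×B Python-level calls.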
That's a good point - we can certainly convert all these metric functions into batched equivalents in a follow-up PR. :-)
# Validate that all forecast probabilities are in the range [0, 1]
if np.logical_or(forecast_probabilities < 0.0, forecast_probabilities > 1.0).any():
    raise ValueError("At least one forecast probability falls outside the range [0, 1].")
The two functions differ by just 1 line. Maybe move most of the stuff to a common function.
# Compute FDE with Brier score component
fde_vector = compute_fde(forecasted_trajectories, gt_trajectory)
brier_score = np.square((1 - forecast_probabilities))
print(fde_vector)
remove print?
uniform_probabilities_k6: NDArrayFloat = np.ones((6,)) / 6
confident_probabilities_k6: NDArrayFloat = np.array([0.9, 0.02, 0.02, 0.02, 0.02, 0.02])
non_normalized_probabilities_k6: NDArrayFloat = confident_probabilities_k6 * 100
wrong_shape_probabilities_k6: NDArrayFloat = np.ones((5,)) / 5
[nit] Add one unit test for out-of-range probabilities.
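Such a test could look like the sketch below. The fixture name `out_of_range_probabilities_k6` follows the existing fixture naming, but the stub implementation and its name are hypothetical, included only so the sketch is self-contained:

```python
import numpy as np
import pytest

def compute_brier_fde_stub(forecasted, gt, probs, normalize=False):
    """Hypothetical minimal stand-in mirroring the validation in the diff."""
    if normalize:
        probs = probs / probs.sum()
    # Range check runs after normalization, per the review discussion.
    if np.logical_or(probs < 0.0, probs > 1.0).any():
        raise ValueError("At least one forecast probability falls outside the range [0, 1].")
    fde = np.linalg.norm(forecasted[:, -1] - gt[-1], axis=-1)
    return fde + np.square(1 - probs)

def test_out_of_range_probabilities_raise() -> None:
    """Probabilities outside [0, 1] with normalize=False should raise."""
    out_of_range_probabilities_k6 = np.array([1.5, 0.1, 0.1, 0.1, 0.1, 0.1])
    with pytest.raises(ValueError):
        compute_brier_fde_stub(
            np.zeros((6, 10, 2)), np.zeros((10, 2)),
            out_of_range_probabilities_k6, normalize=False,
        )
```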
Thanks for the review @jagjeet-singh! Updated the PR to address feedback.
PR Summary
This PR adds Brier score-based variants of ADE and FDE to the motion forecasting evaluation module.
These metrics are implemented identically to their counterparts in the AV1 repo and will be used as scoring metrics in the AV2 MF challenge.
Testing
All functions added in this PR have been unit tested.
In order to ensure this PR works as intended, it is:
Compliance with Standards
As the author, I certify that this PR conforms to the following standards: