Add RMSE option to MSE code #249

Merged: 8 commits, May 19, 2021
Conversation

@johannespitz (Contributor) commented May 14, 2021

What does this PR do?

Fixes #250.

PR review

I'm open to any kind of feedback. This is a very simple fix, and I think this functionality might be helpful for others.

@SkafteNicki (Member)
Hi @johannespitz,
As RMSE is just sqrt(MSE), the implementation needs to happen in the compute function instead of in the update. Essentially, it should be:

return sum_squared_error / n_obs if squared else torch.sqrt(sum_squared_error / n_obs)

instead of

return sum_squared_error / n_obs
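The suggested change can be sketched in plain Python (the function name and signature here are illustrative, mirroring the shape of the torchmetrics compute step rather than its actual implementation):

```python
import math

def mean_squared_error(preds, target, squared=True):
    # Illustrative sketch of the proposed compute step:
    # accumulate the sum of squared errors and the observation count,
    # then optionally take the square root of the aggregate.
    n_obs = len(preds)
    sum_squared_error = sum((p - t) ** 2 for p, t in zip(preds, target))
    mse = sum_squared_error / n_obs
    return mse if squared else math.sqrt(mse)
```

With squared=False this returns sqrt(MSE), i.e. the RMSE as usually defined.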

@maximsch2 (Contributor)

Should this be just a CompositeMetric(MSE, torch.sqrt)?

@SkafteNicki SkafteNicki added the enhancement New feature or request label May 16, 2021
@SkafteNicki SkafteNicki added this to the v0.4 milestone May 16, 2021
@johannespitz (Contributor, Author) commented May 16, 2021

Hi,

unfortunately, RMSE is not simply the square root of the MSE. Therefore, one cannot compose the metric with simple arithmetic.

import torch

targets = torch.tensor([1.0, 2, 3, 4, -1, -2])
predictions = torch.tensor([0, 1, 5, 2, 0, -2])

length = targets.shape[0]

mse = ((targets - predictions) ** 2).sum()
rmse = (((targets - predictions) ** 2) ** 0.5).sum()

print(f"mse {mse / length}")
print(f"rmse {rmse / length}")
print(f"mse**0.5 {mse ** 0.5 / length}")

The square root needs to be taken before summing up the individual errors.

@maximsch2 (Contributor)

Is that actually true? Wikipedia gives RMSE = sqrt(MSE): https://en.wikipedia.org/wiki/Root-mean-square_deviation
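A quick numeric check on the example values from the previous comment (plain Python, standard library only) confirms the Wikipedia definition, where the square root is taken after aggregating:

```python
import math

targets = [1.0, 2, 3, 4, -1, -2]
predictions = [0, 1, 5, 2, 0, -2]
n = len(targets)

# MSE first, then the square root of the aggregate: RMSE = sqrt(MSE)
mse = sum((t - p) ** 2 for t, p in zip(targets, predictions)) / n
rmse = math.sqrt(mse)
```

Taking the root per element before averaging would give a different (and here smaller-or-equal) value, which is the quantity the earlier comment computed.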

@johannespitz (Contributor, Author) commented May 17, 2021

Oh, my bad. scikit-learn agrees with that definition.
In this case, I agree that there is no need to modify the existing code to implement RMSE.
I still believe that the metric I implemented is more useful for anyone working with physical distances, but I understand 100% if you prefer not to clutter your API with such (apparently niche) additions.
Therefore, please feel free to close the PR and the issue.

In case you want to pull the addition I'd suggest we call it the Average Euclidean Distance.
https://www.quora.com/What-is-the-difference-between-RMSE-and-Average-Euclidean-Distance

@SkafteNicki (Member)

I am actually in favour of adding this argument, as RMSE is also considered a standard machine learning metric.
The reason I think we should not just do CompositionalMetric(MSE, torch.sqrt), as @maximsch2 proposes, is that CompositionalMetric is really not meant to be used by users (it is not in the documentation) but rather as a tool to define metric arithmetic.
@johannespitz would you still be up for making the required changes?

@johannespitz (Contributor, Author)

Sorry about all those commits. But I have tested it, and now I'm getting the expected result.
However, torchmetrics.MeanSquaredError() ** 0.5 is documented and produces the same output...

codecov bot commented May 17, 2021

Codecov Report

Merging #249 (362dd2c) into master (dfc0895) will increase coverage by 0.01%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master     #249      +/-   ##
==========================================
+ Coverage   96.81%   96.83%   +0.01%     
==========================================
  Files          92      184      +92     
  Lines        3012     6026    +3014     
==========================================
+ Hits         2916     5835    +2919     
- Misses         96      191      +95     
| Flag | Coverage | Δ |
| --- | --- | --- |
| Linux | 78.95% <100.00%> | (+<0.01%) ⬆️ |
| Windows | 78.95% <100.00%> | (+<0.01%) ⬆️ |
| cpu | 96.81% <100.00%> | (+<0.01%) ⬆️ |
| gpu | 96.84% <ø> | (?) |
| macOS | 96.81% <100.00%> | (+<0.01%) ⬆️ |
| pytest | 96.83% <100.00%> | (+0.01%) ⬆️ |
| python3.6 | 95.75% <100.00%> | (+<0.01%) ⬆️ |
| python3.8 | 96.81% <100.00%> | (+<0.01%) ⬆️ |
| python3.9 | 96.71% <100.00%> | (+<0.01%) ⬆️ |
| torch1.3.1 | 95.75% <100.00%> | (+<0.01%) ⬆️ |
| torch1.4.0 | 95.88% <100.00%> | (+<0.01%) ⬆️ |
| torch1.8.1 | 96.71% <100.00%> | (+<0.01%) ⬆️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage | Δ |
| --- | --- | --- |
| ...etrics/functional/regression/mean_squared_error.py | 100.00% <100.00%> | (ø) |
| torchmetrics/regression/mean_squared_error.py | 100.00% <100.00%> | (ø) |
| ...rics/functional/classification/confusion_matrix.py | 100.00% <0.00%> | (ø) |
| ...chmetrics/functional/classification/cohen_kappa.py | 100.00% <0.00%> | (ø) |
| ...trics/functional/regression/mean_relative_error.py | 100.00% <0.00%> | (ø) |
| __w/3/s/torchmetrics/regression/pearson.py | 100.00% <0.00%> | (ø) |
| ...etrics/functional/regression/mean_squared_error.py | 100.00% <0.00%> | (ø) |
| __w/3/s/torchmetrics/utilities/prints.py | 100.00% <0.00%> | (ø) |
| ...s/torchmetrics/functional/classification/f_beta.py | 100.00% <0.00%> | (ø) |
| __w/3/s/torchmetrics/functional/nlp.py | 100.00% <0.00%> | (ø) |

... and 84 more

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@Borda Borda enabled auto-merge (squash) May 17, 2021 16:34
@maximsch2 (Contributor)

LGTM as well, but I agree the extra value might be slim compared to just metric ** 0.5.

I see your point about taking the root of each observation. We have been getting a few metrics like that recently, and it might be worth trying to combine/generalize them. The pattern is that we want to have

sum(metric(pred_i, target_i))/N
And so far we have these choices of metric:

  • MSE: (pred_i - target_i)**2
  • MAE: torch.abs(pred_i-target_i)
  • What you asked for: torch.sqrt((pred_i - target_i)**2) - if you look at it, it's actually exactly MAE; we already have it and you can use it! @johannespitz
  • MeanPercentageError: torch.abs((pred_i-target_i)/target_i)

I do wonder whether we will exhaust all of those, or whether we should make a generic metric implementation that computes the average "distance" between pred and target for a pluggable distance function.
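The pluggable-distance idea could be sketched like this (plain Python; mean_pointwise_metric and the distance helpers are hypothetical names, not torchmetrics API):

```python
def mean_pointwise_metric(preds, target, distance_fn):
    # Hypothetical generic helper: averages a pluggable per-element distance,
    # i.e. sum(distance_fn(pred_i, target_i)) / N
    return sum(distance_fn(p, t) for p, t in zip(preds, target)) / len(preds)

# The metrics listed above become one-line distance functions:
squared_error = lambda p, t: (p - t) ** 2   # MSE
absolute_error = lambda p, t: abs(p - t)    # MAE
```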

@johannespitz (Contributor, Author)

@maximsch2 I like the idea of a generic class that can handle all metrics fitting the sum(metric(pred_i, target_i))/N pattern. All the current metrics could simply inherit from this class and specify the function that takes individual predictions and targets.

class torchmetrics.MeanSquaredError(AverageMetric):
    def __init__(self):
        super().__init__(
            individual_metric_fn=lambda pred, targ: (pred - targ) ** 2
        )

This would also allow me to implement what I originally wanted (but, as you pointed out above, didn't actually implement): individual_metric_fn=lambda pred, targ: torch.norm(pred - targ, dim=-1)
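The inheritance idea could be fleshed out roughly as follows (AverageMetric is hypothetical and not part of torchmetrics; plain Python stands in for the torchmetrics Metric state machinery):

```python
class AverageMetric:
    """Hypothetical base class accumulating sum(metric(pred_i, target_i)) and N."""

    def __init__(self, individual_metric_fn):
        self.individual_metric_fn = individual_metric_fn
        self.total = 0.0
        self.count = 0

    def update(self, preds, target):
        # Accumulate per-element metric values and the observation count
        self.total += sum(self.individual_metric_fn(p, t) for p, t in zip(preds, target))
        self.count += len(preds)

    def compute(self):
        # Average over all observations seen so far
        return self.total / self.count


class MeanSquaredError(AverageMetric):
    def __init__(self):
        super().__init__(individual_metric_fn=lambda pred, targ: (pred - targ) ** 2)
```

Each concrete metric then reduces to choosing its per-element function, as the comment above suggests.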

@Borda Borda merged commit 233f0eb into Lightning-AI:master May 19, 2021
Labels
enhancement New feature or request ready

Successfully merging this pull request may close these issues.

Add option to compute root_mean_squared_error
4 participants