Weighted AUROC to omit empty classes #376

Merged
merged 27 commits into from
Jul 26, 2021
Changes from 22 commits
1a471f4
Skipping empty classes in weighted auroc.py
BeyondTheProof Jul 15, 2021
bb86efe
formatting
BeyondTheProof Jul 15, 2021
cfecc54
Added test to test_auroc.py
BeyondTheProof Jul 15, 2021
7969598
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2021
0740a29
Removed comment, reformatted prediction matrix
BeyondTheProof Jul 15, 2021
67ac9b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2021
8f2b5c4
fmt
Borda Jul 16, 2021
2c0447a
Merge branch 'master' into master
mergify[bot] Jul 16, 2021
a6b5a2b
Merge branch 'master' into master
mergify[bot] Jul 16, 2021
bc3e84a
Merge branch 'master' into master
mergify[bot] Jul 16, 2021
064eef2
Merge branch 'master' into master
mergify[bot] Jul 16, 2021
c37d28b
Merge branch 'master' into master
mergify[bot] Jul 19, 2021
cd30d00
debug -- fixing binary case
BeyondTheProof Jul 22, 2021
996ec8b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2021
2df5539
Updated tests
BeyondTheProof Jul 23, 2021
200e5c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 23, 2021
20f84cd
updated changelog
BeyondTheProof Jul 23, 2021
da050d7
debug
BeyondTheProof Jul 23, 2021
376bfcd
final bug
BeyondTheProof Jul 23, 2021
d0ac692
Update auroc.py
BeyondTheProof Jul 23, 2021
a3cf8d0
removed f-string
BeyondTheProof Jul 23, 2021
f96e4f2
Docstring; removed f-string
BeyondTheProof Jul 23, 2021
d33e2b2
Merge branch 'master' into master
mergify[bot] Jul 24, 2021
28294c7
Merge branch 'master' into master
mergify[bot] Jul 24, 2021
e933050
Merge branch 'master' into master
mergify[bot] Jul 24, 2021
99caaa2
Merge branch 'master' into master
mergify[bot] Jul 24, 2021
8af688d
Merge branch 'master' into master
mergify[bot] Jul 24, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -46,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed `weighted`, `multi-class` AUROC computation to allow for 0 observations of some class, as contribution to final AUROC is 0 ([#348](https://github.com/PyTorchLightning/metrics/issues/348))


## [0.4.1] - 2021-07-05
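
The changelog entry rests on a simple property of a support-weighted mean: a class with zero observations has weight zero, so it cannot move the result. A minimal sketch of that reasoning follows. It assumes the per-class score of an empty class is undefined (represented here as NaN); `weighted_average` and the sample numbers are hypothetical and are not taken from the library.

```python
import math


def weighted_average(per_class_scores, supports):
    """Support-weighted mean that skips classes with zero observations."""
    total = sum(supports)
    if total == 0:
        raise ValueError("no observations in any class")
    return sum(
        score * support / total
        for score, support in zip(per_class_scores, supports)
        if support > 0  # a zero-weight class is dropped before it is multiplied
    )


# Hypothetical per-class AUROC values; class 2 is empty, so its score is undefined.
scores = [0.9, 0.8, math.nan, 0.7]
supports = [1, 2, 0, 2]

# Multiplying NaN by a zero weight still yields NaN, which poisons the sum.
naive = sum(s * n for s, n in zip(scores, supports)) / sum(supports)

# Skipping the empty class gives the mean over the observed classes only.
skipped = weighted_average(scores, supports)
```

Here `naive` is NaN while `skipped` is a finite number, which is why omitting the empty class up front is safer than relying on its zero weight.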
28 changes: 28 additions & 0 deletions tests/classification/test_auroc.py
@@ -195,3 +195,31 @@ def test_error_multiclass_no_num_classes():
        ValueError, match="Detected input to `multiclass` but you did not provide `num_classes` argument"
    ):
        _ = auroc(torch.randn(20, 3).softmax(dim=-1), torch.randint(3, (20, )))


def test_weighted_with_empty_classes():
    """ Tests that weighted multiclass AUROC calculation yields the same results if a new
    but empty class exists. Tests that the proper warnings and errors are raised
    """
    preds = torch.tensor([
        [0.90, 0.05, 0.05],
        [0.05, 0.90, 0.05],
        [0.05, 0.05, 0.90],
        [0.85, 0.05, 0.10],
        [0.10, 0.10, 0.80],
    ])
    target = torch.tensor([0, 1, 1, 2, 2])
    num_classes = 3
    _auroc = auroc(preds, target, average="weighted", num_classes=num_classes)

    # Add in a class with zero observations at second to last index
    preds = torch.cat((preds[:, :num_classes - 1], torch.rand_like(preds[:, 0:1]), preds[:, num_classes - 1:]), axis=1)
    # Last class (2) gets moved to 3
    target[target == num_classes - 1] = num_classes
    with pytest.warns(UserWarning, match='Class 2 had 0 observations, omitted from AUROC calculation'):
        _auroc_empty_class = auroc(preds, target, average="weighted", num_classes=num_classes + 1)
    assert _auroc == _auroc_empty_class

    target = torch.zeros_like(target)
    with pytest.raises(ValueError, match='Found 1 non-empty class in `multiclass` AUROC calculation'):
        _ = auroc(preds, target, average="weighted", num_classes=num_classes + 1)
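
The invariance that this test checks can also be reproduced without torch. The sketch below is an illustration only and is not how the library computes AUROC: `binary_auroc` and `weighted_ovr_auroc` are hypothetical helpers that use the pairwise-ranking form of AUROC (ties count half) in a one-vs-rest scheme, and the inserted column uses a fixed 0.5 instead of random scores.

```python
def binary_auroc(scores, labels):
    """AUROC as the fraction of (positive, negative) pairs ranked correctly; ties count half."""
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def weighted_ovr_auroc(preds, target, num_classes):
    """One-vs-rest AUROC weighted by class support, skipping empty classes."""
    total = 0.0
    for c in range(num_classes):
        support = sum(1 for t in target if t == c)
        if support == 0:
            continue  # weight is zero, so the class cannot affect the result
        column = [row[c] for row in preds]
        labels = [t == c for t in target]
        total += binary_auroc(column, labels) * support / len(target)
    return total


# Same predictions and targets as in the test above, as plain lists.
preds = [
    [0.90, 0.05, 0.05],
    [0.05, 0.90, 0.05],
    [0.05, 0.05, 0.90],
    [0.85, 0.05, 0.10],
    [0.10, 0.10, 0.80],
]
target = [0, 1, 1, 2, 2]
baseline = weighted_ovr_auroc(preds, target, num_classes=3)

# Insert an arbitrary score column at index 2 and move the old class 2 to index 3.
preds_extra = [row[:2] + [0.5] + row[2:] for row in preds]
target_extra = [3 if t == 2 else t for t in target]
with_empty = weighted_ovr_auroc(preds_extra, target_extra, num_classes=4)
```

Because the empty class is skipped before any arithmetic touches it, `baseline` and `with_empty` come from the same sequence of operations. Note that `binary_auroc` divides by zero when only one class is observed, which corresponds to the case the test expects to raise a `ValueError`.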
20 changes: 18 additions & 2 deletions torchmetrics/functional/classification/auroc.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional, Sequence, Tuple

import torch
@@ -87,8 +88,23 @@ def _auroc_compute(
        else:
            raise ValueError('Detected input to be `multilabel` but you did not provide `num_classes` argument')
    else:
        if mode != DataType.BINARY and num_classes is None:
            raise ValueError('Detected input to `multiclass` but you did not provide `num_classes` argument')
        if mode != DataType.BINARY:
            if num_classes is None:
                raise ValueError("Detected input to `multiclass` but you did not provide `num_classes` argument")
            if average == AverageMethod.WEIGHTED and len(torch.unique(target)) < num_classes:
                # If one or more classes has 0 observations, we should exclude them, as its weight will be 0
                target_bool_mat = torch.zeros((len(target), num_classes), dtype=bool)
                target_bool_mat[torch.arange(len(target)), target.long()] = 1
                class_observed = target_bool_mat.sum(axis=0) > 0
                for c in range(num_classes):
                    if not class_observed[c]:
                        warnings.warn(f'Class {c} had 0 observations, omitted from AUROC calculation', UserWarning)
                preds = preds[:, class_observed]
                target = target_bool_mat[:, class_observed]
                target = torch.where(target)[1]
                num_classes = class_observed.sum()
                if num_classes == 1:
                    raise ValueError('Found 1 non-empty class in `multiclass` AUROC calculation')
        fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights)

    # calculate standard roc auc score
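
The filtering step added in this hunk does three things: it finds which classes were observed, drops the score columns of the rest, and relabels the targets so they index the remaining columns contiguously. A rough plain-Python sketch of the same idea follows, assuming list inputs instead of tensors. `drop_empty_classes` is a hypothetical name, and only the warning and error messages are copied from the diff.

```python
import warnings


def drop_empty_classes(preds, target, num_classes):
    """Remove score columns of unobserved classes and relabel targets contiguously."""
    observed = [any(t == c for t in target) for c in range(num_classes)]
    for c, seen in enumerate(observed):
        if not seen:
            warnings.warn(f"Class {c} had 0 observations, omitted from AUROC calculation", UserWarning)
    kept = [c for c, seen in enumerate(observed) if seen]
    if len(kept) == 1:
        raise ValueError("Found 1 non-empty class in `multiclass` AUROC calculation")
    # Map each surviving class id to its new column position.
    new_index = {old: new for new, old in enumerate(kept)}
    new_preds = [[row[c] for c in kept] for row in preds]
    new_target = [new_index[t] for t in target]
    return new_preds, new_target, len(kept)


# Class 2 never appears in the targets, so its column is dropped and class 3 becomes 2.
preds = [[0.7, 0.1, 0.1, 0.1], [0.2, 0.5, 0.2, 0.1], [0.1, 0.1, 0.2, 0.6]]
target = [0, 1, 3]
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    new_preds, new_target, n = drop_empty_classes(preds, target, num_classes=4)
```

The relabeling is what the `torch.where(target)[1]` line achieves in the diff: after the boolean target matrix loses the empty columns, the column index of each row's single `True` entry is the new class id.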