[Metrics] Add multiclass auroc #4236
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -817,13 +817,14 @@ def new_func(*args, **kwargs) -> torch.Tensor: | |
|
||
def multiclass_auc_decorator(reorder: bool = True) -> Callable: | ||
def wrapper(func_to_decorate: Callable) -> Callable: | ||
@wraps(func_to_decorate) | ||
def new_func(*args, **kwargs) -> torch.Tensor: | ||
results = [] | ||
for class_result in func_to_decorate(*args, **kwargs): | ||
x, y = class_result[:2] | ||
results.append(auc(x, y, reorder=reorder)) | ||
|
||
- return torch.cat(results) | ||
+ return torch.stack(results) | ||
|
||
return new_func | ||
|
||
|
@@ -858,7 +859,7 @@ def auroc( | |
if any(target > 1): | ||
raise ValueError('AUROC metric is meant for binary classification, but' | ||
' target tensor contains value different from 0 and 1.' | ||
- ' Multiclass is currently not supported.') | ||
+ ' Use `multiclass_auroc` for multi class classification.') | ||
|
||
@auc_decorator(reorder=True) | ||
def _auroc(pred, target, sample_weight, pos_label): | ||
|
@@ -867,6 +868,62 @@ def _auroc(pred, target, sample_weight, pos_label): | |
return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) | ||
|
||
|
||
def multiclass_auroc( | ||
pred: torch.Tensor, | ||
target: torch.Tensor, | ||
sample_weight: Optional[Sequence] = None, | ||
num_classes: Optional[int] = None, | ||
) -> torch.Tensor: | ||
""" | ||
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from multiclass | ||
prediction scores | ||
|
||
Args: | ||
pred: estimated probabilities, with shape [N, C] | ||
target: ground-truth labels, with shape [N,] | ||
sample_weight: sample weights | ||
num_classes: number of classes (default: None, computes automatically from data) | ||
|
||
Return: | ||
Tensor containing ROC AUC score | ||
|
||
Example: | ||
|
||
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], | ||
... [0.05, 0.85, 0.05, 0.05], | ||
... [0.05, 0.05, 0.85, 0.05], | ||
... [0.05, 0.05, 0.05, 0.85]]) | ||
>>> target = torch.tensor([0, 1, 3, 2]) | ||
>>> multiclass_auroc(pred, target) # doctest: +NORMALIZE_WHITESPACE | ||
tensor(0.6667) | ||
""" | ||
if not torch.allclose(pred.sum(dim=1), torch.tensor(1.0)): | ||
raise ValueError( | ||
"Multiclass AUROC metric expects the target scores to be" | ||
" probabilities, i.e. they should sum up to 1.0 over classes") | ||
|
||
if torch.unique(target).size(0) != pred.size(1): | ||
Shouldn't it be

Or we could have a get_num_classes utils too.

Well the metric is undefined when E.g., for If a target label is not present in the As for

Makes sense!
||
raise ValueError( | ||
f"Number of classes found in 'target' ({torch.unique(target).size(0)})" | ||
f" does not equal the number of columns in 'pred' ({pred.size(1)})." | ||
" Multiclass AUROC is not defined when all of the classes do not" | ||
" occur in the target labels.") | ||
|
||
if num_classes is not None and num_classes != pred.size(1): | ||
raise ValueError( | ||
f"Number of classes deduced from 'pred' ({pred.size(1)}) does not equal" | ||
f" the number of classes passed in 'num_classes' ({num_classes}).") | ||
|
||
@multiclass_auc_decorator(reorder=False) | ||
def _multiclass_auroc(pred, target, sample_weight, num_classes): | ||
return multiclass_roc(pred, target, sample_weight, num_classes) | ||
|
||
class_aurocs = _multiclass_auroc(pred=pred, target=target, | ||
sample_weight=sample_weight, | ||
num_classes=num_classes) | ||
return torch.mean(class_aurocs) | ||
Comment on lines +917 to +924:

I've implemented this using the
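To make the averaging step above concrete, here is a minimal pure-Python sketch of one-vs-rest AUROC followed by an unweighted mean over classes. It is not the PR's code path: the PR goes through `multiclass_roc` and `auc`, whereas this sketch uses the pairwise-ranking formulation (the fraction of positive/negative pairs ranked correctly, with ties counted as 0.5). The names `binary_auroc` and `one_vs_rest_auroc` are hypothetical and exist only for illustration. On the docstring example it gives the same 0.6667, and it fails with a `ValueError` when a class never occurs in `target`, which is the situation the error message in the diff describes.

```python
from typing import List, Sequence


def binary_auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """AUROC as the share of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a win.
    Hypothetical helper for illustration, not part of the PR.
    """
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    if not positives or not negatives:
        # With no positives or no negatives there are no pairs to rank.
        raise ValueError("AUROC is undefined without both positive and negative samples.")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


def one_vs_rest_auroc(pred: List[List[float]], target: List[int]) -> float:
    """Unweighted mean of per-class one-vs-rest AUROC scores (illustrative sketch)."""
    num_classes = len(pred[0])
    per_class = []
    for c in range(num_classes):
        scores = [row[c] for row in pred]               # column c of pred
        labels = [1 if t == c else 0 for t in target]   # class c against the rest
        per_class.append(binary_auroc(scores, labels))
    return sum(per_class) / num_classes


# Same data as the docstring example in the diff.
pred = [[0.85, 0.05, 0.05, 0.05],
        [0.05, 0.85, 0.05, 0.05],
        [0.05, 0.05, 0.85, 0.05],
        [0.05, 0.05, 0.05, 0.85]]
target = [0, 1, 3, 2]
print(round(one_vs_rest_auroc(pred, target), 4))  # 0.6667
```

Classes 0 and 1 are ranked perfectly (score 1.0 each), while classes 2 and 3 each score 1/3 because their only positive sample ties with two negatives and loses to the third, so the mean is (1 + 1 + 1/3 + 1/3) / 4.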
|
||
|
||
|
||
def average_precision( | ||
pred: torch.Tensor, | ||
target: torch.Tensor, | ||
|
Also, we should check pred.size(0) == target.size(0)

I can add that of course, but this check is not done in any other metric implementation. So if it's done here, it should probably be done everywhere. If that's desired, I could add a helper to `classification.py`. Would that work? Then this helper could be used in each metric instead of copy-pasting the `if` clause and the exception.

As we are slowly unifying the functional and class-based interfaces, we are doing more shape checks, so this will come in a future PR :]
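For reference, the shared helper discussed in this thread could look roughly like the sketch below. This is an assumption about its shape, not code from the PR or from the library: the name `check_same_num_samples` is hypothetical, and it uses `len()` so that it runs on plain sequences as well. For a `torch.Tensor` with at least one dimension, `len(t)` returns the size of the first dimension, the same value as `t.size(0)`.

```python
from typing import Sized


def check_same_num_samples(pred: Sized, target: Sized) -> None:
    """Raise ValueError when pred and target disagree on the number of samples.

    Hypothetical helper, not part of the PR. It compares first-dimension
    sizes via len(), which matches .size(0) for tensors with ndim >= 1.
    """
    if len(pred) != len(target):
        raise ValueError(
            f"'pred' has {len(pred)} samples but 'target' has {len(target)};"
            " both must share the same first dimension."
        )


# Matching first dimensions pass silently.
check_same_num_samples([[0.9, 0.1], [0.2, 0.8]], [0, 1])

# A mismatch raises, so each metric can call the helper instead of
# repeating the `if` clause and the exception.
try:
    check_same_num_samples([[0.9, 0.1]], [0, 1])
except ValueError as err:
    print(err)
```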