
[metrics] AUROC Metric can't handle 0 observations of a class with multiclass classifier #348

Closed
BeyondTheProof opened this issue Jul 1, 2021 · 15 comments · Fixed by #376
Labels
bug / fix Something isn't working waiting on author
Comments

@BeyondTheProof
Contributor

I'm attempting to calculate AUROC for a multiclass problem where some classes are very rare and occasionally never seen, and I'm getting the following error: raise ValueError("No positive samples in targets, true positive value should be meaningless")

In the case of 0 observations, I feel average='weighted' should work, since the contribution of an unseen class to the final AUROC should be 0 regardless. One can also think of scenarios with a very high number of classes, some of which will happen to not appear in a given dataset.
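The intuition above can be sketched with plain numbers (hypothetical per-class scores and supports, not torchmetrics output):

```python
# Hypothetical per-class AUROC scores and class supports (label counts).
# Class 2 is never observed, so its AUROC is undefined and its weight is 0.
class_auroc = [0.90, 0.75, None]  # None: undefined for the unseen class
supports = [40, 10, 0]

total = sum(supports)
# "weighted" averaging weights each class score by its support, so an
# unseen class (support 0) can simply be skipped without changing the result.
weighted = sum((s / total) * a for s, a in zip(supports, class_auroc) if s > 0)
print(weighted)  # ≈ 0.87
```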

Originally posted by @BeyondTheProof in Lightning-AI/pytorch-lightning#2210 (comment)

@BeyondTheProof BeyondTheProof changed the title [metrics] AUROC Metric cant handle 0 observations of a class with multiclass classifier [metrics] AUROC Metric can't handle 0 observations of a class with multiclass classifier Jul 1, 2021
@BeyondTheProof
Contributor Author

BeyondTheProof commented Jul 1, 2021

I've found a hack for this by subclassing AUROC:

class FixedAUC(torchmetrics.AUROC):
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        num_classes = preds.shape[1]
        # True for classes that actually occur at least once in this batch's target
        zero_obs_mask = torch.tensor([(target == c).sum() > 0 for c in range(num_classes)])
        # drop the columns of classes with zero observations
        preds = preds[:, zero_obs_mask]
        target = target[:, zero_obs_mask]
        self.num_classes = zero_obs_mask.sum().int()
        super().update(torch.softmax(preds, dim=-1).data, target.argmax(dim=-1).int().data)

I will be making a cleaner fix in torchmetrics and submitting a PR.

@SkafteNicki SkafteNicki transferred this issue from Lightning-AI/pytorch-lightning Jul 2, 2021
@github-actions

github-actions bot commented Jul 2, 2021

Hi! Thanks for your contribution, great first issue!

@SkafteNicki
Member

Hi @BeyondTheProof, I transferred the issue to the torchmetrics repo.
Can I kindly ask if you are calling the forward method or the update method of the metric when trying to update the metric states?

@BeyondTheProof
Contributor Author

Hi @SkafteNicki, thank you for the response and transferring the issue! I am just calling the update method.

@SkafteNicki
Member

Could you provide an example of how you are using the metric?

@Borda Borda added bug / fix Something isn't working waiting on author labels Jul 7, 2021
@BeyondTheProof
Contributor Author

In my subclassed lightning module, I have metrics that are calculated at the end of each epoch:

self.auroc = torchmetrics.AUROC(num_classes=15, average="weighted", compute_on_step=False)
self.accuracy = pl.metrics.Accuracy(num_classes=15, average="weighted", compute_on_step=False)
self.metrics = [self.auroc, self.accuracy]

When I calculate all my losses, I also calculate some other metrics (if there are any bugs, it is only because I simplified the code as much as possible here):

for metric, name in zip(self.metrics, ["AUROC", "ACC"]):
  self.log(
      f"{stage}_{name}",
      metric,
      on_step=False,
      on_epoch=True,
      prog_bar=False,
      logger=True,
  )

I hope this helps. If there's anything else you need, please let me know!

@BeyondTheProof
Contributor Author

BeyondTheProof commented Jul 9, 2021

I've found a hack for this by subclassing AUROC:

class FixedAUC(torchmetrics.AUROC):
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        num_classes = preds.shape[1]
        # True for classes that actually occur at least once in this batch's target
        zero_obs_mask = torch.tensor([(target == c).sum() > 0 for c in range(num_classes)])
        # drop the columns of classes with zero observations
        preds = preds[:, zero_obs_mask]
        target = target[:, zero_obs_mask]
        self.num_classes = zero_obs_mask.sum().int()
        super().update(torch.softmax(preds, dim=-1).data, target.argmax(dim=-1).int().data)

I will be making a cleaner fix in torchmetrics and submitting a PR.

@SkafteNicki In this subclass, an error arises when some batches contain only two classes and others contain more. I implemented a much cleaner version by subclassing Metric directly:

from typing import Optional

import torch
from torchmetrics import Metric
from torchmetrics.functional.classification.auroc import _auroc_compute
from torchmetrics.utilities.data import dim_zero_cat


class WeightedAUROC(Metric):
    """
    This is used for when the target is not strictly one of K classes, but a probability
    distribution over all K classes
    """

    def __init__(
        self,
        num_classes: Optional[int] = None,
        pos_label: Optional[int] = None,
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )
        self.num_classes = num_classes
        if self.num_classes > 2:
            self.mode = "multiclass"
        else:
            self.mode = "binary"
        self.pos_label = pos_label
        self.average = "weighted"

        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert target.ndim == 2, f"WeightedAUROC expects a 2D target tensor, got {target.ndim} dims"
        self.preds.append(torch.softmax(preds, dim=-1).data)
        self.target.append(target.data)

    def compute(self):
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)
        zero_obs_mask = torch.tensor([(target == c).sum() > 0 for c in range(self.num_classes)])
        preds = preds[:, zero_obs_mask]
        target = target[:, zero_obs_mask].int()
        num_classes = zero_obs_mask.sum().int()
        return _auroc_compute(preds, target, self.mode, num_classes=num_classes, average=self.average)
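The masking step in compute() can be illustrated without torchmetrics (a pure-Python sketch using lists of lists in place of tensors, mirroring zero_obs_mask above):

```python
# One-hot style target of shape (obs, classes); class 2 never occurs.
target = [
    [1, 0, 0],
    [0, 1, 0],
    [1, 0, 0],
]
num_classes = 3

# Keep column c only if class c appears at least once in the target.
observed = [any(row[c] > 0 for row in target) for c in range(num_classes)]
masked = [[row[c] for c in range(num_classes) if observed[c]] for row in target]

print(observed)        # [True, True, False]
print(len(masked[0]))  # 2 effective classes remain
```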

@KimDaeUng

@BeyondTheProof I think it would be better to add the line process_group: Optional[Any] = None, since it is passed to super().__init__() but never defined:



from typing import Any, Optional

import torch
from torchmetrics import Metric
from torchmetrics.functional.classification.auroc import _auroc_compute
from torchmetrics.utilities.data import dim_zero_cat


class WeightedAUROC(Metric):
    """
    This is used for when the target is not strictly one of K classes, but a probability
    distribution over all K classes
    """

    def __init__(
        self,
        num_classes: Optional[int] = None,
        pos_label: Optional[int] = None,
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None, # It seems that this line is missing.
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )
        self.num_classes = num_classes
        if self.num_classes > 2:
            self.mode = "multiclass"
        else:
            self.mode = "binary"
        self.pos_label = pos_label
        self.average = "weighted"

        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert target.ndim == 2, f"WeightedAUROC expects a 2D target tensor, got {target.ndim} dims"
        self.preds.append(torch.softmax(preds, dim=-1).data)
        self.target.append(target.data)

    def compute(self):
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)
        zero_obs_mask = torch.tensor([(target == c).sum() > 0 for c in range(self.num_classes)])
        preds = preds[:, zero_obs_mask]
        target = target[:, zero_obs_mask].int()
        num_classes = zero_obs_mask.sum().int()
        return _auroc_compute(preds, target, self.mode, num_classes=num_classes, average=self.average)

@SkafteNicki
Member

Hi @BeyondTheProof, thanks for getting back to me. It does indeed seem to be a problem, even though I would consider it a very rare corner case. Would you be up for sending a fix to our current AUROC implementation?
IMO the only thing missing from your implementation is that the user should be warned when a class is removed.
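A minimal sketch of such a warning (a hypothetical helper, not part of torchmetrics), taking the observed-class mask computed in compute():

```python
import warnings


def warn_on_dropped_classes(observed, num_classes):
    """Warn if any class had no positive samples and was excluded."""
    dropped = [c for c in range(num_classes) if not observed[c]]
    if dropped:
        warnings.warn(
            f"Classes {dropped} had 0 observations in the target and were "
            "excluded from the weighted AUROC computation."
        )

# e.g. warn_on_dropped_classes([True, True, False], 3) warns about class 2
```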

@BeyondTheProof
Contributor Author

Hi @SkafteNicki, good point about the user warning. Also, in the implementation above, the target only allows 2D tensors, i.e. an (obs, classes) binary matrix. I will implement a solution for that as well.

@BeyondTheProof
Contributor Author

@BeyondTheProof I think it would be better to add the line process_group: Optional[Any] = None,.

Thanks for the catch @KimDaeUng, you're totally right!

@BeyondTheProof
Contributor Author

@SkafteNicki Submitted a PR here: #376

Thanks!

@BeyondTheProof
Contributor Author

Hi @SkafteNicki, just following up on this. I have an approval from Borda, but still need one more :)

@maximsch2
Contributor

The tests are failing in the PR though?

@BeyondTheProof
Contributor Author

BeyondTheProof commented Jul 21, 2021

Ah, sorry, I didn't catch that. Will fix. Thank you!

@Borda Borda added this to the v0.5 milestone Aug 18, 2021