
Classification refactor #1001

Closed
SkafteNicki opened this issue May 3, 2022 · 11 comments · Fixed by #1195

@SkafteNicki
Member

The classification package is long overdue for a refactor, as we are seeing a rising number of issues that either request new features that are hard to implement in the current codebase or stem from a disagreement between what users expect the metrics to do and what they actually do.

A full list of issues that should be taken care of by the refactor can be found here

The refactor hopes to address the following problems:

  • Maintainability: The core of the classification metrics was written by a contributor back when torchmetrics was still pl.metrics. We should perhaps have been more thorough in the review phase, because the code has been hard to maintain. This refactor should help address that by lowering code complexity.
  • Consistency: The classification package has a number of consistency issues. This issue gives an overview of some of them, but essentially num_classes=2 sometimes means binary classification and sometimes means multiclass classification (which differ in their definition), depending on which metric you are using.
  • Expectations: There are a number of cases where the current default arguments do not match user expectations, because we differ slightly from how sklearn handles some cases. This refactor will address these differences.
  • Performance: We have received feedback that our implementations are very slow. While these comparisons are not always fair (comparing a one-line implementation of accuracy against our much more general modular implementation), it is fair to say that some improvements can be made.

Proposed solution

The proposed solution is to split each metric into three separate metric classes:

  • BinaryMetricName
  • MultiClassMetricName
  • MultiLabelMetricName

For example, Accuracy will be split into BinaryAccuracy, MultiClassAccuracy, and MultiLabelAccuracy (a usage sketch follows the list below). This solution directly solves a number of problems:

  1. While most metrics indeed support all three tasks, some support only one or two. Renaming and splitting all metrics will make it much clearer which metrics support which tasks.
  2. Some performance is currently lost to extensive input validation. By requiring the user to specify the task beforehand, we can reduce the amount of input validation needed.
  3. Code complexity: instead of every metric essentially containing if-else statements based on the task being solved, each metric will have a much clearer computational path.
  4. Input arguments: take threshold and topk as examples, which are current arguments to the Accuracy, F1, etc. metrics. threshold should only be set for binary and multilabel, and topk should only be specified for multiclass. Dividing into separate metrics helps communicate which arguments influence the computation.
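
To make point 4 concrete, here is a minimal functional sketch of the intended semantics (simplified stand-ins written for this issue, not the actual torchmetrics implementation):

    import torch

    def binary_accuracy(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        # threshold only makes sense for binary (and multilabel) inputs
        return ((preds >= threshold).long() == target).float().mean()

    def multiclass_accuracy(preds: torch.Tensor, target: torch.Tensor, top_k: int = 1) -> torch.Tensor:
        # top_k only makes sense for multiclass inputs
        topk_idx = preds.topk(top_k, dim=-1).indices
        return (topk_idx == target.unsqueeze(-1)).any(dim=-1).float().mean()

    # binary: one probability per sample
    print(binary_accuracy(torch.tensor([0.2, 0.8, 0.6, 0.4]), torch.tensor([0, 1, 1, 0])))  # tensor(1.)
    # multiclass: per-class probabilities per sample
    print(multiclass_accuracy(torch.tensor([[0.1, 0.8, 0.1], [0.6, 0.3, 0.1]]), torch.tensor([1, 0])))  # tensor(1.)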

Alternatives

  1. We keep everything in one class but introduce a new (required) argument:

    from typing import Literal

    class Accuracy(Metric):
        def __init__(
            self,
            mode: Literal["binary", "multiclass", "multilabel"],
            ...
        ):
            ...

    This alternative is not directly in opposition to the main proposed solution. If requested by users, we could still provide a single class that wraps the three individual metric classes into one (see the dispatch sketch after this list).

  2. We keep the outer API exactly the same and try to clean up the internals. This would most likely only address some of the current problems.
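
For illustration, dispatch in such a wrapper could be done in __new__, along these lines (a sketch assuming the three task-specific classes from the main proposal already exist; here they are borrowed from the torchmetrics.classification namespace as it eventually shipped, so the spelling is MulticlassAccuracy/MultilabelAccuracy):

    from typing import Literal
    from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy

    class Accuracy:
        """Thin wrapper that dispatches to the task-specific metric class."""

        def __new__(cls, task: Literal["binary", "multiclass", "multilabel"], **kwargs):
            if task == "binary":
                return BinaryAccuracy(**kwargs)
            if task == "multiclass":
                return MulticlassAccuracy(**kwargs)
            if task == "multilabel":
                return MultilabelAccuracy(**kwargs)
            raise ValueError(f"Unsupported task: {task!r}")

    acc = Accuracy(task="multiclass", num_classes=3)  # returns a MulticlassAccuracy instance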

Integration

  • PL: should not really be a problem, as core PL does not depend on the classification package, only on the base torchmetrics.Metric class, which should not need to be touched during the refactor. Some examples may need to be updated.
  • Flash: some changes would need to be made. We still expect these to be minimal, as the initialization of a Flash Task should already contain all the information necessary to determine which class should be used. cc: @ethanwharris

Deprecation

The goal is to have the whole classification package refactored/cleaned up as the major work in 0.10. All current classification metrics will be given a deprecation warning, and users will have until 0.11 to refactor their code to use the new classes.

While we are developing the new package we will have a freeze on new metrics in the classification package. We will still happily accept new metrics for other domains.

Documentation impact

Up until v0.8, this change would have made our documentation very annoying to scroll through, as everything was on one central page; it would essentially have made the classification documentation three times harder to navigate.

However, from v0.8 we have one page per metric. For this refactor we would keep one page per core metric, e.g. Accuracy, Precision, Recall, etc., and each page would then list every version of the metric.

Development

The development can essentially be divided into 3 phases:

  1. Development of generalized StatScores and ConfusionMatrix classes for all three tasks. Many classification metrics can be calculated from these statistics (see the sketch after this list).
  2. Subclassing the generalized classes into specific metrics.
  3. Dealing with metrics that do not fit either of the two generalized classes (currently about one third of the classification metrics).
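
As a simplified illustration of phases 1 and 2 (binary case only, written for this issue; not the torchmetrics internals), many metrics reduce to small formulas on top of the four stat scores:

    import torch

    def binary_stat_scores(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5):
        """Return (tp, fp, tn, fn) for hard binary predictions."""
        hard = (preds >= threshold).long()
        tp = ((hard == 1) & (target == 1)).sum()
        fp = ((hard == 1) & (target == 0)).sum()
        tn = ((hard == 0) & (target == 0)).sum()
        fn = ((hard == 0) & (target == 1)).sum()
        return tp, fp, tn, fn

    # metrics derived from the shared statistics
    def accuracy(tp, fp, tn, fn):
        return (tp + tn) / (tp + fp + tn + fn)

    def precision(tp, fp, tn, fn):
        return tp / (tp + fp)

    def recall(tp, fp, tn, fn):
        return tp / (tp + fn)

    stats = binary_stat_scores(torch.tensor([0.9, 0.4, 0.7, 0.1]), torch.tensor([1, 1, 0, 0]))
    print(accuracy(*stats), precision(*stats), recall(*stats))  # 0.5, 0.5, 0.5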

The main part of the refactor will be done by @SkafteNicki and @justusschock, with support from the rest of the core metrics team. We may be open to contributions for step 2, as it should be fairly simple subclassing and copy-paste work. Development should start within two weeks.

Any feedback is appreciated :)

@SkafteNicki SkafteNicki added Important milestonish refactoring refactoring and code health labels May 3, 2022
@Borda Borda added this to the v0.10 milestone May 3, 2022
@awaelchli
Contributor

I like that we are thinking of simplifying this for the users. The classification module has grown organically - many more options have been added over time. I agree a refactor is appropriate here.

I like alternative approach 1, where we keep the current basic classes as light wrappers. These are easy to remember and frequently used (Accuracy, F1, Precision, ...) by a large user base. If we go with the wrapper class and do validation inside of it, perhaps we can recommend the specific classes BinaryMetricName, MultiClassMetricName, MultiLabelMetricName depending on the error and use case.

@ethanwharris
Member

Sounds good. I agree with @awaelchli that keeping the base Accuracy etc. classes with e.g. the mode argument seems reasonable. Flash can be updated to work with either API 😃

@Yura52

Yura52 commented Jun 1, 2022

While it is not directly related to the discussion, this issue may also be relevant, since it highlights one more aspect of accuracy-like metrics that could be taken into account during the refactoring.

@adamjstewart
Contributor

All current classification metrics will be given a deprecation warning and users will have until 0.11 to refactor their code to use the new classes.

It seems like this plan was aborted at some point? Can someone point me to the discussion where this happened? I want to clarify whether the old metrics are discouraged or whether the plan is to support both old and new for the indefinite future.

@awaelchli
Contributor

@adamjstewart The PR that deprecated things was #1195. That was two years ago; the deprecation messages are in there, and the classes marked for removal have since been removed.

I guess you may be referring to the class flavors like Accuracy, MultiClassAccuracy, BinaryAccuracy, etc.? I am not aware of any plans to deprecate those. For example, BinaryAccuracy is simply equivalent to Accuracy(task="binary"), and MultiClassAccuracy is equivalent to Accuracy(task="multiclass"). Doing it this way is what the discussion here converged on, and I am not aware of any change in plans.
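
A quick check of that equivalence (assuming a recent torchmetrics release; note the shipped spelling is MulticlassAccuracy):

    import torch
    from torchmetrics import Accuracy
    from torchmetrics.classification import BinaryAccuracy

    preds = torch.tensor([0.1, 0.9, 0.7])
    target = torch.tensor([0, 1, 1])

    wrapped = Accuracy(task="binary")
    direct = BinaryAccuracy()

    assert isinstance(wrapped, BinaryAccuracy)  # the wrapper dispatches to the task-specific class
    assert wrapped(preds, target) == direct(preds, target)  # tensor(1.) in both cases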

@adamjstewart
Contributor

Ah, gotcha. I interpreted this as deprecating Accuracy and replacing it with {Binary,Multi{class,label}}Accuracy. Good to know that both will be supported!

@Borda
Member

Borda commented Aug 4, 2024

Ah, gotcha. I interpreted this as deprecating Accuracy and replacing it with {Binary,Multi{class,label}}Accuracy. Good to know that both will be supported!

This is correct; in the past we had many users confused about which task type was used, especially with some edge cases within the first batch...
Then we had to keep it in the codebase due to compatibility 🐰

@Borda
Member

Borda commented Aug 5, 2024

@lantiga @SkafteNicki are we considering dropping the old metrics in the future 2.0?

@adamjstewart
Contributor

I also see references like "Legacy Example" in https://lightning.ai/docs/torchmetrics/stable/classification/accuracy.html that suggest that the previous style is discouraged. I personally think the old style is very convenient, especially for generic LightningModules like torchgeo.trainers.SemanticSegmentationTask, which could support binary/multiclass/multilabel with only minor changes.
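
For instance, a generic module can build its metric from config through a single code path (a hypothetical sketch; torchgeo's actual trainer code differs):

    from typing import Optional
    from torchmetrics import Accuracy

    def build_accuracy(task: str, num_classes: Optional[int] = None, num_labels: Optional[int] = None):
        # one code path covers binary, multiclass, and multilabel
        return Accuracy(task=task, num_classes=num_classes, num_labels=num_labels)

    binary_acc = build_accuracy("binary")
    multiclass_acc = build_accuracy("multiclass", num_classes=5)
    multilabel_acc = build_accuracy("multilabel", num_labels=3)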

@Borda
Member

Borda commented Aug 6, 2024

I personally think the old style is very convenient, especially for generic LightningModules like

I totally agree, but with the luxury of not needing to think about the task came the frustration of incorrect results... For example, you have a classification task and the shuffling is bad: the first few batches come with only label 0, so the metric is set up as binary, but the task is really multilabel and the other labels come later...
As you can see in the PR, it resolved a handful of issues, but for compatibility we kept the old API 🐰

@adamjstewart
Contributor

few first batches come with labels 0 so the metrics are set up as binary

Oh, I mean auto-detection is bad, but asking the user to specify a task for a single class is nice.
