Multiclass Classification: assert num_classes >=2 #2205

Open
robmarkcole opened this issue Aug 1, 2024 · 4 comments · May be fixed by #2219
Labels
trainers PyTorch Lightning trainers

Comments

@robmarkcole
Contributor

Summary

Both segmentation and object detection require that the background be included, and there is currently a note on these args: "num_classes: Number of prediction classes (including the background)". Since every dataset must have at least one class, the minimum value of num_classes is 2. I propose adding an assertion to prevent people (like myself!) from forgetting this and setting num_classes=1 for datasets with a single class.

Rationale

This config error has happened to me several times, and it can pass silently.

Implementation

I suppose we could add validation to the BaseTask init.
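A minimal sketch of what such a check might look like. The class below is a hypothetical stand-in and not the real BaseTask; the name `SketchTask` is an assumption, and it raises `ValueError` rather than using a bare `assert` so the check still runs under `python -O`.

```python
class SketchTask:
    """Hypothetical stand-in for a trainer base class (not the real BaseTask)."""

    def __init__(self, num_classes: int) -> None:
        # num_classes counts the background, so a dataset with a single
        # foreground class still needs num_classes=2.
        if num_classes < 2:
            raise ValueError(
                "num_classes must be >= 2 (including the background), "
                f"got {num_classes}"
            )
        self.num_classes = num_classes
```

With this in place, `SketchTask(num_classes=1)` fails at construction time instead of passing silently.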

Alternatives

No response

Additional information

No response

@adamjstewart added the trainers (PyTorch Lightning trainers) label on Aug 4, 2024
@adamjstewart
Collaborator

Not to completely derail what should otherwise be a simple fix, but...

This brings up the question of how we want to handle different forms of classification/semantic segmentation:

  • Binary
  • Multiclass
  • Multilabel

Torchmetrics originally had a single class for Accuracy. In Lightning-AI/torchmetrics#1001, they proposed and implemented separate classes for each of the 3 above types of classification (BinaryAccuracy, etc.). The original plan was to deprecate and remove the old single class, but it seems that plan was aborted at some point.

We should decide whether we want BinaryClassificationTask, etc. or whether we want to add a task='binary', etc. parameter to ClassificationTask.

We could definitely still add such an assertion for now and change it to assert num_classes > 1 if task != 'binary' later if needed.
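As a rough sketch, the task-dependent version of that assertion could look like the following. The helper name `check_num_classes` and the `Task` alias are hypothetical, and the `task` parameter does not exist on the trainers yet.

```python
from typing import Literal

# Hypothetical alias for the three kinds of classification listed above.
Task = Literal["binary", "multiclass", "multilabel"]


def check_num_classes(num_classes: int, task: Task = "multiclass") -> None:
    """Hypothetical helper: require num_classes > 1 unless the task is binary."""
    if task != "binary":
        assert num_classes > 1, (
            f"num_classes must be > 1 for task={task!r}, got {num_classes}"
        )
```

Here `check_num_classes(1, "binary")` passes, while `check_num_classes(1, "multiclass")` raises an `AssertionError`.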

@robmarkcole
Contributor Author

As you point out, binary etc. are task values that torchmetrics accepts, so I think it makes sense to support this functionality in the existing task.

@adamjstewart
Collaborator

Just waiting for clarity on whether torchmetrics is planning on supporting the old metrics forever before deciding, but I was leaning towards that too.

@adamjstewart
Collaborator

Looks like I misinterpreted: both are supported.

Is there anything special we need to do in our trainers to support binary and multilabel, or do we literally just need to pass different task values to torchmetrics? If the former, we may want to split, but if the latter, I agree we should just keep the current classes and add a task parameter.
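For the "just pass different task values" case, a rough sketch of the mapping might look like this. The helper `metric_kwargs` and its `num_outputs` parameter are hypothetical. The keyword names follow the torchmetrics wrapper metrics, which take `task` plus `num_classes` for multiclass or `num_labels` for multilabel. It does not settle whether the trainers need other changes.

```python
from typing import Any, Dict, Optional


def metric_kwargs(task: str, num_outputs: Optional[int] = None) -> Dict[str, Any]:
    """Hypothetical helper: build keyword arguments for a torchmetrics metric."""
    if task == "binary":
        # Binary metrics need no class count.
        return {"task": "binary"}
    if task == "multiclass":
        return {"task": "multiclass", "num_classes": num_outputs}
    if task == "multilabel":
        return {"task": "multilabel", "num_labels": num_outputs}
    raise ValueError(f"unknown task: {task!r}")


# Intended use (assumes torchmetrics is installed):
#     torchmetrics.Accuracy(**metric_kwargs("multiclass", 5))
```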

@adamjstewart changed the title from "Assert num_classes >=2" to "Multiclass Classification: assert num_classes >=2" on Aug 7, 2024
2 participants