
Possible improvements for Accuracy #1089

Open

Yura52 opened this issue May 31, 2020 · 10 comments

@Yura52

Yura52 commented May 31, 2020

The feature request is described in full detail here; below is a quick recap.

There are two inconveniences I experience with the current interface of Accuracy.

1. Inconsistent input format for binary classification and multiclass problems

In the first case, Accuracy expects labels as input, whilst in the second case it expects probabilities/logits. I am a somewhat experienced Ignite user and I still get confused by this behavior.

2. No shortcuts for saying "I want to pass logits/probabilities as input"

In practice, I have never used Accuracy in the following manner for binary classification:

accuracy = Accuracy()

Instead, I always do one of the following:

# output_transform receives the (y_pred, y) pair produced by the engine
accuracy = Accuracy(output_transform=lambda out: (torch.round(torch.sigmoid(out[0])), out[1]))
# either
accuracy = Accuracy(output_transform=lambda out: (torch.round(out[0]), out[1]))

Suggested solution for both problems: let the user explicitly say in which form input will be passed:

import enum
class Accuracy(...):
    class Mode(enum.Enum):
        LABELS = enum.auto()
        PROBABILITIES = enum.auto()
        LOGITS = enum.auto()

    def __init__(self, mode=Mode.LABELS, ...):
        ...

The suggested interface can also be extended to support custom thresholds by adding a __call__ method to the Mode class.
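
A minimal sketch of what such a callable Mode could look like (the callable-with-threshold idea is only hinted at above; the threshold parameter and the top-level placement, rather than nesting inside Accuracy, are illustrative assumptions, not a final API):

import enum

import torch

class Mode(enum.Enum):
    LABELS = enum.auto()
    PROBABILITIES = enum.auto()
    LOGITS = enum.auto()

    def __call__(self, y_pred: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        # Convert raw y_pred into hard 0/1 labels according to the mode.
        if self is Mode.LABELS:
            return y_pred
        if self is Mode.LOGITS:
            y_pred = torch.sigmoid(y_pred)
        return (y_pred >= threshold).long()

# usage sketch: Mode.LOGITS(raw_scores, threshold=0.7) -> tensor of 0s and 1s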

@sdesrozis
Contributor

sdesrozis commented Jun 4, 2020

@WeirdKeksButtonK I really appreciate this API! Thank you very much 👍🏻

@vcarpani
Contributor

Hello, I believe I could be assigned to this issue, since I have a PR for it.

@vfdev-5
Collaborator

vfdev-5 commented Oct 12, 2020

@vcarpani sure! On GitHub we cannot assign arbitrary users to an issue, only project members or people who have already participated in the conversation here.

@sallycaoyu
Contributor

sallycaoyu commented Feb 25, 2023

Hi everyone, I would like to try working on this issue.

@vfdev-5
Collaborator

vfdev-5 commented Feb 25, 2023

Sure @sallycaoyu, please also check all related PRs and mentions.

@sallycaoyu
Contributor

sallycaoyu commented Feb 27, 2023

For now, I am trying to finish implementing a binary_mode for binary and multilabel types that transforms probabilities and logits into 0s and 1s, as this PR has done. If that works well, I can then consider how to add more flexibility for multiclass, as issue #822 suggests.

Does that sound like a good plan? Or would you like Ignite to have a mode similar to what this issue suggests, i.e., mode being one of [binary, multiclass, multilabel] instead of one of [unchanged, probabilities, logits]? The former would require more modifications to what we have right now, such as removing is_multilabel and replacing it with mode for Accuracy, Precision, Recall, and ClassificationReport, because multilabel would then be one option of mode.
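
To make the two options concrete, a rough sketch of the constructor calls being compared (argument names are illustrative, nothing is settled):

# Option A (output-form mode): mode describes the form of y_pred, is_multilabel stays
acc = Accuracy(mode='logits', is_multilabel=True)

# Option B (data-type mode): mode describes the data type and replaces is_multilabel
acc = Accuracy(mode='multilabel')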

@vfdev-5
Collaborator

vfdev-5 commented Feb 27, 2023

@sallycaoyu thanks for the update! I think we can continue with mode as [unchanged, probabilities, logits, labels?].
Could you please sketch the new API usage with code snippets, emphasizing "before" and "after"? For example:

### before
acc = Accuracy()
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits 

### after
acc = Accuracy(mode=Accuracy.LOGITS)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C), (N, )

acc = Accuracy(mode=Accuracy.LABELS)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass labels  : (N, ), (N, )

etc

@sallycaoyu
Contributor

Sure! Suppose we have:

class Accuracy:
    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        is_multilabel: bool = False,
        device: Union[str, torch.device] = torch.device("cpu"),
        mode: str = 'unchanged',
        threshold: Union[float, int] = 0.5,
    ):
        ...

Then, for binary and multilabel data:

### before
acc = Accuracy()
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary labels (0s and 1s) : (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel labels (0s and 1s) : (N, C, ...), (N, C, ...)



### after
# LOGITS MODE
acc = Accuracy(mode='logits', threshold = 3.25)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary logits (float in [-inf, inf]): (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(mode='logits', threshold = 3.25, is_multilabel = True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel logits (float in [-inf, inf]): (N, C, ...), (N, C, ...) 

# in this case, Accuracy will transform any value < 3.25 to 0 and any value >= 3.25 to 1
# if no threshold is passed, Accuracy will apply a sigmoid to the logits and then transform any value < 0.5 to 0 and >= 0.5 to 1



# PROBABILITIES MODE
acc = Accuracy(mode='probabilities', threshold = 0.6)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary probabilities (float in [0, 1]): (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(mode='probabilities', threshold = 0.6, is_multilabel = True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel probabilities (float in [0, 1]): (N, C, ...), (N, C, ...)

# in this case, Accuracy will transform any value < 0.6 to 0 and any value >= 0.6 to 1
# if no threshold is passed, Accuracy will transform any value < 0.5 to 0 and >= 0.5 to 1



# LABELS MODE
acc = Accuracy(mode='labels', threshold = 5)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary labels (int in [0, inf]): (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(mode='labels', threshold = 5, is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel labels (int in [0, inf]): (N, C, ...), (N, C, ...)

# in this case, Accuracy will transform any value < 5 to 0 and any value >= 5 to 1
# a threshold must be specified for labels mode




# UNCHANGED MODE
acc = Accuracy(mode='unchanged')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary labels (0s and 1s): (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(mode='unchanged', is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel labels (0s and 1s): (N, C, ...), (N, C, ...)

# works like before: raises an error when any value is not 0 or 1
# a threshold should not be specified for unchanged mode
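
A rough sketch of how the binary/multilabel thresholding described above could be implemented internally (the helper name and structure are purely illustrative, not the actual ignite code):

from typing import Optional

import torch

def _to_binary_labels(y_pred: torch.Tensor, mode: str, threshold: Optional[float] = None) -> torch.Tensor:
    # Illustrative helper: maps raw y_pred to hard 0/1 predictions for binary/multilabel data.
    if mode == 'logits':
        if threshold is None:
            # default: sigmoid, then threshold at 0.5
            return (torch.sigmoid(y_pred) >= 0.5).long()
        # an explicit threshold is applied to the raw logits
        return (y_pred >= threshold).long()
    if mode == 'probabilities':
        return (y_pred >= (0.5 if threshold is None else threshold)).long()
    if mode == 'labels':
        if threshold is None:
            raise ValueError("labels mode requires an explicit threshold")
        return (y_pred >= threshold).long()
    # 'unchanged': keep the current behaviour, values must already be 0 or 1
    return y_pred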

For multiclass data:

### before
acc = Accuracy()
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C, ...), (N, ...)



### after: should not apply threshold to multiclass data
# LABELS MODE
acc = Accuracy(mode='labels')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass labels : (N, ...), (N, ...)
# conflict with _check_type(), since we use y.ndimension() + 1 == y_pred.ndimension() to check for multiclass data


# For now, the following multiclass modes will work like before (argmax):
# PROBABILITIES MODE
acc = Accuracy(mode='probabilities')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass probabilities : (N, C, ...), (N, ...)


# LOGITS MODE
acc = Accuracy(mode='logits')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C, ...), (N, ...) 


# UNCHANGED MODE
acc = Accuracy(mode='unchanged')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C, ...), (N, ...)

@vfdev-5
Collaborator

vfdev-5 commented Mar 1, 2023

Thanks a lot for the snippet @sallycaoyu !

I have a few thoughts about that:

  • would it make sense to introduce mode=multilabels or some other appropriate name to hint at multiclass data?

  • I'm not sure about the usefulness of the threshold arg. If we want to threshold logits/probas into labels, we can use output_transform with any rule/threshold we want:

# binary data: output_transform receives the (y_pred, y) pair
acc = Accuracy(mode='logits', output_transform=lambda out: ((out[0] > 0).to(dtype=torch.long), out[1]))
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary logits (float in [-inf, inf]): (N, ...), (N, ...), or (N, 1, ...), (N, ...)
  • Maybe we can also drop the unchanged mode?

What do you think ?

@sallycaoyu
Contributor

sallycaoyu commented Mar 1, 2023

@vfdev-5 Thank you very much for the comments!

I agree that we can drop the unchanged mode. I also agree that output_transform gives users more flexibility than threshold, so threshold is not really necessary. Then, by default (see the sketch after this list), for:

  • probabilities mode:
    • for binary and multilabel, we can round data to 0 or 1 and compare with y_true
    • for multiclass, we can take argmax for now
  • logits mode:
    • for binary and multilabel, we can do sigmoid then round to 0 or 1 and compare with y_true
    • for multiclass, we can take argmax for now
  • labels mode: I am actually not so sure how to handle this case.
    • for multilabel, we can one-hot y_pred and y_true, then compare them
    • for multiclass, we can directly compare y_pred and y_true if they are of the same shape (N, ...), (N, ...) or (N, 1, ...), (N, ...)
    • for binary, how do we map class 0 - class N to 0 or 1?
    • Maybe we should not enable this mode for binary data. I also don't think multilabels is a good name for this mode when applied to multiclass data, because it should be differentiated from the results of multilabel classification. Would nonbinary_labels be a better name?
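
To make the defaults above concrete, a small sketch of the per-mode decision rule (purely illustrative names, not from the codebase; labels mode is left out since its handling is still open):

import torch

def _default_decision(y_pred: torch.Tensor, mode: str, data_type: str) -> torch.Tensor:
    # Illustrative only: default mapping from raw model output to hard predictions.
    if data_type == 'multiclass':
        # both probabilities and logits reduce to argmax over the class dimension
        return y_pred.argmax(dim=1)
    # binary / multilabel
    if mode == 'logits':
        return torch.round(torch.sigmoid(y_pred)).long()
    if mode == 'probabilities':
        return torch.round(y_pred).long()
    raise ValueError(f"unsupported mode for binary/multilabel data: {mode}")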
