Skip to content

Commit

Permalink
Merge pull request #16 from koaning/margin-doubt
Browse files Browse the repository at this point in the history
Add Margin Reason
  • Loading branch information
koaning authored Nov 23, 2021
2 parents 7c008af + a7733b1 commit 9a2f667
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 1 deletion.
43 changes: 43 additions & 0 deletions doubtlab/reason.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,49 @@ def __call__(self, X, y):
return np.where(confidences > self.threshold, confidences, 0)


class MarginConfidenceReason:
"""
Assign doubt when a the difference between the top two most confident classes is too large.
Throws an error when there are only two classes.
Arguments:
model: scikit-learn classifier
threshold: confidence threshold for doubt assignment
Usage:
```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from doubtlab.ensemble import DoubtEnsemble
from doubtlab.reason import MarginConfidenceReason
X, y = load_iris(return_X_y=True)
model = LogisticRegression(max_iter=1_000)
model.fit(X, y)
doubt = DoubtEnsemble(reason = MarginConfidenceReason(model=model))
indices = doubt.get_indices(X, y)
```
"""

def __init__(self, model, threshold=0.2):
self.model = model
self.threshold = threshold

def _calc_margin(self, probas):
sorted = np.sort(probas, axis=1)
return sorted[:, -1] - sorted[:, -2]

def __call__(self, X, y):
probas = self.model.predict_proba(X)
margin = self._calc_margin(probas)
return np.where(margin > self.threshold, margin, 0)


class ShortConfidenceReason:
"""
Assign doubt when the correct class gains too little confidence.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

setup(
name="doubtlab",
version="0.1.1",
version="0.1.2",
author="Vincent D. Warmerdam",
packages=find_packages(exclude=["notebooks", "docs"]),
description="Don't Blindly Trust Your Labels",
Expand Down
2 changes: 2 additions & 0 deletions tests/test_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
DisagreeReason,
LongConfidenceReason,
ShortConfidenceReason,
MarginConfidenceReason,
WrongPredictionReason,
AbsoluteDifferenceReason,
RelativeDifferenceReason,
Expand All @@ -22,6 +23,7 @@
DisagreeReason,
LongConfidenceReason,
ShortConfidenceReason,
MarginConfidenceReason,
WrongPredictionReason,
AbsoluteDifferenceReason,
RelativeDifferenceReason,
Expand Down
2 changes: 2 additions & 0 deletions tests/test_general_reason.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ProbaReason,
OutlierReason,
DisagreeReason,
MarginConfidenceReason,
LongConfidenceReason,
ShortConfidenceReason,
WrongPredictionReason,
Expand All @@ -22,6 +23,7 @@
ProbaReason,
LongConfidenceReason,
ShortConfidenceReason,
MarginConfidenceReason,
WrongPredictionReason,
CleanlabReason,
]
Expand Down
17 changes: 17 additions & 0 deletions tests/test_reason/test_margin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

from doubtlab.reason import MarginConfidenceReason


def test_margin_confidence_margin():
"""Ensures margin is calculated correctly."""
X, y = load_iris(return_X_y=True)
model = LogisticRegression(max_iter=1_000)
model.fit(X, y)

reason = MarginConfidenceReason(model=model)
probas = np.eye(3)
margin = reason._calc_margin(probas=probas)
assert np.all(np.isclose(margin, np.ones(3)))

0 comments on commit 9a2f667

Please sign in to comment.