This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

AUC metrics (#3751)
* added rates

* added AUC and updated rates slightly (not sure if we'll be using them at the end, but need to add tests)

* linting

* fixed a bug

* Still needs to be updated, but will deal with this later....

* the only dependency on sklearn.metrics should be auc, but it still needs to be tested for accuracy

* cleaned up/deleted unnecessary statements

* might have messed something up

* last commit before adding tests & changing the structure so that it is recorded at the very end

* found a bug

* added 3 tests + a couple small changes to torch classifier agent (have not updated to only calculate at the end yet)

* fixed a small thing that would affect testing metrics

* fixed things so the tests pass... but I also don't know why it likes the fpr and tpr reversed???

* passed all current auc tests.... but I will add more later for checking

* added a hopefully faster merge way

* finished testing inside test_metrics... now gonna work on the actual side

* changed some stuff so it is only run in eval_model, but there is some weirdness going on with auc

* testing snapshot

* removed extra prints

* fixed something

* removed an extra commit

* fixed another thing....

* another bug

* removed slots and added better class comment

* linting

* why was linting not added before...

* final commit hopefully

* oops forgot to update the tests

* fixed some max_bucket_dec_places problems

* updated sort and counter typing

* fixed another typing

* fixed an import thing

* fixed a small typing thing

* added comments + changed -auc in eval_model

* added comment about classifier agent

* linted

* fixed optional

* added to documentation

* fixed comments + counter import

* fixed wording :)
liliarose authored Jul 8, 2021
1 parent bd9ac8f commit 36004c9
Showing 4 changed files with 557 additions and 6 deletions.
4 changes: 4 additions & 0 deletions parlai/core/metrics.py
@@ -52,6 +52,10 @@ class MetricDisplayData(NamedTuple):

METRICS_DISPLAY_DATA = {
"accuracy": MetricDisplayData("Accuracy", "Exact match text accuracy"),
'auc': MetricDisplayData(
'AUC',
"Area Under the Receiver Operating Characteristic Curve (true positive rate vs false positive rate curve)",
),
"bleu-4": MetricDisplayData(
"BLEU-4",
"BLEU-4 of the generation, under a standardized (model-independent) tokenizer",
202 changes: 198 additions & 4 deletions parlai/core/torch_classifier_agent.py
@@ -15,14 +15,16 @@
from parlai.core.torch_agent import TorchAgent, Output
from parlai.utils.misc import round_sigfigs, warn_once
from parlai.core.metrics import Metric, AverageMetric
from typing import List, Optional, Tuple, Dict
from typing import List, Optional, Tuple, Dict, Union
from typing import Counter
from parlai.utils.typing import TScalar
from parlai.utils.io import PathManager
import parlai.utils.logging as logging

from sklearn.metrics import auc

import torch
import parlai.utils.logging as logging
import torch.nn.functional as F
import torch
import math


class ConfusionMatrixMetric(Metric):
@@ -168,6 +170,158 @@ def value(self) -> float:
return numer / denom


class AUCMetrics(Metric):
"""
Computes Area Under ROC Curve (AUC) metrics.
Does so by keeping track of positives' and negatives' probability score counts in
Counters or dictionaries. Note the introduction of `max_bucket_dec_places`; this
integer number determines the number of digits to save for the probability scores. A
higher `max_bucket_dec_places` will a more accurate estimate of the exact AUC
metric, but may also use more space.
NOTE: currently only used for classifiers in the `eval_model` script; to use,
add the argument `-auc <max_bucket_dec_places>` when calling `eval_model` script
"""

@property
def macro_average(self) -> bool:
"""
Indicates whether this metric should be macro-averaged when globally reported.
"""
return False

@classmethod
def raw_data_to_auc(
cls,
true_labels: List[Union[int, str]],
pos_probs: List[float],
class_name,
max_bucket_dec_places: int = 3,
):
auc_object = cls(class_name, max_bucket_dec_places=max_bucket_dec_places)
auc_object.update_raw(
true_labels=true_labels, pos_probs=pos_probs, class_name=class_name
)
return auc_object

def __init__(
self,
class_name: Union[int, str],
max_bucket_dec_places: int = 3,
pos_dict: Optional[Counter[float]] = None,
neg_dict: Optional[Counter[float]] = None,
):
# `_pos_dict` keeps track of the probabilities of the positive class
self._pos_dict = pos_dict if pos_dict else Counter()
# `_neg_dict` keeps track of the probabilities of the negative class
self._neg_dict = neg_dict if neg_dict else Counter()
self._class_name = class_name
self._max_bucket_dec_places = max_bucket_dec_places

def update_raw(
self, true_labels: List[Union[int, str]], pos_probs: List[float], class_name
):
"""
given the true/golden labels and the probabilities of the positive class, we
will update our bucket dictionaries of positive and negatives (based on the
class_name); `max_bucket_dec_places` is also used here to round the
probabilities and possibly.
"""
assert self._class_name == class_name
assert len(true_labels) == len(pos_probs)

TO_INT_FACTOR = 10 ** self._max_bucket_dec_places
# add the upper and lower bound of the values
for label, prob in zip(true_labels, pos_probs):
# calculate the upper and lower bound of the values
prob_down = math.floor(prob * TO_INT_FACTOR) / TO_INT_FACTOR
if label == self._class_name:
interested_dict = self._pos_dict
else:
interested_dict = self._neg_dict
if interested_dict.get(prob_down):
interested_dict[prob_down] += 1
else:
interested_dict[prob_down] = 1

def __add__(self, other: Optional['AUCMetrics']) -> 'AUCMetrics':
if other is None:
return self
assert isinstance(other, AUCMetrics)
assert other._class_name == self._class_name
all_pos_dict = self._pos_dict + other._pos_dict
all_neg_dict = self._neg_dict + other._neg_dict

return AUCMetrics(
self._class_name, pos_dict=all_pos_dict, neg_dict=all_neg_dict
)
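
    # Editorial note (not part of the commit): because only the bucket Counters
    # are summed, AUCMetrics objects accumulated on different batches or workers
    # can be merged with `+` (e.g. `combined = auc_a + auc_b`), and
    # `combined.value()` is then computed over the union of the buckets.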

def _calc_fp_tp(self) -> List[Tuple[int]]:
"""
Calculates the False Positives and True positives; returned as a list of pairs:
`[(fp, tp)]`
"""
all_thresholds = sorted(
set(list(self._pos_dict.keys()) + list(self._neg_dict.keys()))
)
        # sorted in ascending order,
        # so add an upper bound so that its (fp, tp) is (0, 0)
all_thresholds.append(all_thresholds[-1] + 1)
L = len(all_thresholds)
# false positives, true positives
fp_tp = [(0, 0)]

# the biggest one is always (0,0), so skip that one
for i in range(L - 2, -1, -1):
fp, tp = fp_tp[-1]
thres = all_thresholds[i]
# false positives
fp += self._neg_dict.get(thres, 0)
# true positives
tp += self._pos_dict.get(thres, 0)
fp_tp.append((fp, tp))
return fp_tp
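
    # Editorial worked example (not part of the commit): with buckets
    #   _pos_dict = {0.9: 1, 0.4: 1} and _neg_dict = {0.8: 1, 0.3: 1},
    # sweeping the thresholds from highest to lowest gives
    #   [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2)],
    # i.e. (false positives, true positives) accumulate as the threshold drops.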

def _calc_fpr_tpr(self) -> Tuple[Union[List[int], int]]:
"""
        Calculates the false positive rates and true positive rates. Also returns the
        total number of positives and negatives; returned as a list of pairs and two
        integers: `([(fpr, tpr)], positives, negatives)`. Note that if the total number
        of negatives/positives is 0, then 0 is returned for the corresponding fpr/tpr
        (instead of raising an error).
"""
_tot_pos = sum(self._pos_dict.values())
_tot_neg = sum(self._neg_dict.values())
fp_tp = self._calc_fp_tp()
fps, tps = list(zip(*fp_tp))
if _tot_neg == 0:
fpr = [0] * len(fps)
else:
fpr = [fp / _tot_neg for fp in fps]

if _tot_pos == 0:
tpr = [0] * len(tps)
else:
tpr = [tp / _tot_pos for tp in tps]

return (list(zip(fpr, tpr)), _tot_pos, _tot_neg)
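
    # Editorial continuation of the example above (not part of the commit): with
    # 2 positives and 2 negatives, those pairs become
    #   [(0.0, 0.0), (0.0, 0.5), (0.5, 0.5), (0.5, 1.0), (1.0, 1.0)]
    # as (fpr, tpr); the trapezoidal area under that curve (see value()) is 0.75.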

def value(self) -> float:
fpr_tpr, _tot_pos, _tot_neg = self._calc_fpr_tpr()

if _tot_pos == 0 and _tot_neg == 0:
return 0

# auc needs x-axis to be sorted
fpr_tpr.sort()
fpr, tpr = list(zip(*fpr_tpr))
return auc(fpr, tpr)


class WeightedF1Metric(Metric):
"""
Class that represents the weighted f1 from ClassificationF1Metric.
@@ -341,6 +495,26 @@ def __init__(self, opt: Opt, shared=None):
else:
self.threshold = None

# set up calculating auc
self.calc_auc = opt.get('area_under_curve_digits', -1) > 0

if self.calc_auc:
self.auc_bucket_decimal_size = opt.get('area_under_curve_digits')
if opt.get('area_under_curve_class') is None:
# self.auc_class_ind
interested_classes = self.class_list
else:
interested_classes = opt.get('area_under_curve_class')
try:
self.auc_class_indices = [
self.class_dict[class_name] for class_name in interested_classes
]
except Exception:
raise RuntimeError(
f'The inputted classes for auc were probably invalid.\n Current class names: {self.class_list} \n Names of AUC classes passed in: {interested_classes}'
)
self.reset_auc()

# set up model and optimizers
states = {}
if shared:
@@ -459,6 +633,12 @@ def _format_interactive_output(self, probs, prediction_id):
)
return preds

def _update_aucs(self, batch, probs):
probs_arr = probs.detach().cpu().numpy()
for index, curr_auc in zip(self.auc_class_indices, self.aucs):
class_probs = probs_arr[:, index]
curr_auc.update_raw(batch.labels, class_probs, self.class_list[index])

def train_step(self, batch):
"""
Train on a single batch of examples.
@@ -494,6 +674,10 @@ def eval_step(self, batch):
self.model.eval()
scores = self.score(batch)
probs = F.softmax(scores, dim=1)

if self.calc_auc:
self._update_aucs(batch, probs)

if self.threshold is None:
_, prediction_id = torch.max(probs.cpu(), 1)
else:
@@ -528,3 +712,13 @@ def score(self, batch):
class.
"""
raise NotImplementedError('Abstract class: user must implement score()')

def reset_auc(self):
if self.calc_auc:
self.aucs = [
AUCMetrics(
class_name=self.class_list[index],
max_bucket_dec_places=self.auc_bucket_decimal_size,
)
for index in self.auc_class_indices
]
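
Editor's note: a minimal sanity check of the new AUCMetrics class (not part of the commit), assuming scikit-learn is installed; `raw_data_to_auc` and `value` are the methods added above, and sklearn's `roc_auc_score` is used only as an exact reference for comparison.

# editorial example: bucketed AUC vs. sklearn's exact ROC AUC on a toy batch
from sklearn.metrics import roc_auc_score
from parlai.core.torch_classifier_agent import AUCMetrics

labels = ['pos', 'neg', 'pos', 'neg']  # gold labels
pos_probs = [0.9, 0.8, 0.4, 0.3]  # model probabilities of the 'pos' class

metric = AUCMetrics.raw_data_to_auc(labels, pos_probs, 'pos', max_bucket_dec_places=3)
print(metric.value())  # 0.75
print(roc_auc_score([1, 0, 1, 0], pos_probs))  # 0.75, the exact reference value
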
30 changes: 29 additions & 1 deletion parlai/scripts/eval_model.py
@@ -42,6 +42,9 @@
get_rank,
)

# the index to access classifier agent's output in the world
CLASSIFIER_AGENT = 1


def setup_args(parser=None):
if parser is None:
@@ -69,6 +72,21 @@ def setup_args(parser=None):
default='conversations',
choices=['conversations', 'parlai'],
)
parser.add_argument(
'--area-under-curve-digits',
'-auc',
type=int,
default=-1,
        help='a positive number enables calculating the area under the ROC curve and determines how many decimal digits of the predictions to keep (higher numbers -> more precise); non-positive values disable the AUC metric',
)
parser.add_argument(
'--area-under-curve-class',
'-auclass',
type=str,
default=None,
nargs='*',
help='the name(s) of the class to calculate the auc for',
)
parser.add_argument('-ne', '--num-examples', type=int, default=-1)
parser.add_argument('-d', '--display-examples', type='bool', default=False)
parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
@@ -182,8 +200,18 @@ def _eval_single_world(opt, agent, task):
world_logger.write(outfile, world, file_format=opt['save_format'])

report = aggregate_unnamed_reports(all_gather_list(world.report()))
world.reset()

if isinstance(world.agents, list) and len(world.agents) > 1:
classifier_agent = world.agents[CLASSIFIER_AGENT]
if hasattr(classifier_agent, 'calc_auc') and classifier_agent.calc_auc:
for class_indices, curr_auc in zip(
classifier_agent.auc_class_indices, classifier_agent.aucs
):
report[f'AUC_{classifier_agent.class_list[class_indices]}'] = curr_auc
classifier_agent.reset_auc()
# for safety measures
agent.reset_auc()
world.reset()
return report
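
Editor's note (not part of the commit): with these changes, a hypothetical evaluation run that records AUC for a class named `pos`, using 3-decimal buckets, would look something like

    parlai eval_model -mf <model_file> -t <task> -auc 3 -auclass pos

where `-auc` maps to `--area-under-curve-digits`, `-auclass` maps to `--area-under-curve-class`, and the resulting report gains an `AUC_pos` entry.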

