Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

AUC metrics #3751

Merged
merged 40 commits into from
Jul 8, 2021
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
3bf960f
added rates
liliarose Jun 21, 2021
89d977d
added AUC and updated rates slightly (not sure if we'll be using them…
liliarose Jun 21, 2021
84fc014
linting
liliarose Jun 21, 2021
e5cb9c7
fixed a bug
liliarose Jun 21, 2021
db9f716
Still needs to be updated, but will deal with this later....
liliarose Jun 22, 2021
1ea17d9
only dependency on sklearn.metrics should be auc, but still need to b…
liliarose Jun 22, 2021
da844d1
cleaned up/deleted unnecessary statements
liliarose Jun 22, 2021
5408789
might have messed sth up
liliarose Jun 22, 2021
6c6755d
last commit before adding tests & changing the structure so that it i…
liliarose Jun 22, 2021
70bbd34
found a bug
liliarose Jun 22, 2021
3668d76
added 3 tests + a couple small changes to torch classifier agent (hav…
liliarose Jun 22, 2021
6ab2fdf
fixed a small thing that would affect testing metrics
liliarose Jun 22, 2021
4cd745d
fixed things so the tests pass... but also don't know why it likes th…
liliarose Jun 22, 2021
7e27792
passed all current auc tests.... but I will add more later for checking
liliarose Jun 22, 2021
43483dd
added a hopefully faster merge way
liliarose Jun 23, 2021
69a56bd
finished testing inside test_metrics... now gonna work on the actual …
liliarose Jun 23, 2021
44e59ed
changed some stuff so it only is run on eval_model, but there are som…
liliarose Jun 24, 2021
e4fb497
testing snapshot
liliarose Jun 24, 2021
ab25b30
removed extra prints
liliarose Jun 24, 2021
0588976
fixed something
liliarose Jun 24, 2021
afc720f
removed an extra commit
liliarose Jun 24, 2021
4a7beef
fixed another thing....
liliarose Jun 24, 2021
560edf7
another bug
liliarose Jun 25, 2021
e60e028
removed slots and added better class comment
liliarose Jun 28, 2021
d400258
linting
liliarose Jun 28, 2021
c91bba5
why was linting not added before...
liliarose Jun 28, 2021
5d2763f
final commit hopefully
liliarose Jun 30, 2021
25e7961
oops forgot to update the tests
liliarose Jun 30, 2021
acbd0e4
fixed some max_bucket_dec_places problems
liliarose Jul 1, 2021
40abafd
updated sort and counter typing
liliarose Jul 1, 2021
2375cc0
fixed another typing
liliarose Jul 1, 2021
9645728
fixed an import thing
liliarose Jul 1, 2021
3672c6d
fixed a small typing thing
liliarose Jul 1, 2021
5c8c724
added comments + changed -auc in eval_model
liliarose Jul 6, 2021
6e7f1af
added comment about classifier agent
liliarose Jul 6, 2021
22b4c01
linted
liliarose Jul 6, 2021
065c9e5
fixed optional
liliarose Jul 6, 2021
a3bb11f
added to documentation
liliarose Jul 6, 2021
60b7953
fixed comments + counter import
liliarose Jul 7, 2021
40750f6
fixed wording :)
liliarose Jul 7, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 152 additions & 4 deletions parlai/core/torch_classifier_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
from parlai.core.torch_agent import TorchAgent, Output
from parlai.utils.misc import round_sigfigs, warn_once
from parlai.core.metrics import Metric, AverageMetric
from typing import List, Optional, Tuple, Dict
from typing import List, Optional, Tuple, Dict, Union
from parlai.utils.typing import TScalar
from parlai.utils.io import PathManager
import parlai.utils.logging as logging

from sklearn.metrics import auc

import torch
import parlai.utils.logging as logging
import torch.nn.functional as F
from collections import Counter
liliarose marked this conversation as resolved.
Show resolved Hide resolved
import torch
import math


class ConfusionMatrixMetric(Metric):
Expand Down Expand Up @@ -168,6 +170,127 @@ def value(self) -> float:
return numer / denom


class AUCMetrics(Metric):
"""
Class that calculates the area under the roc curve from list of labels and its true
liliarose marked this conversation as resolved.
Show resolved Hide resolved
probabilities; expecting values to be (false positives, true positives)
"""

__slots__ = (
'_pos_dict',
'_tot_pos',
'_neg_dict',
'_tot_neg',
'_class_name',
'_max_bucket_dec_places',
liliarose marked this conversation as resolved.
Show resolved Hide resolved
)

@property
def macro_average(self) -> bool:
"""
Indicates whether this metric should be macro-averaged when globally reported.
"""
return False

@classmethod
def raw_data_to_auc(
cls,
true_labels: List[int],
pos_probs: List[float],
class_name,
max_bucket_dec_places: float = 5,
liliarose marked this conversation as resolved.
Show resolved Hide resolved
):
auc_object = cls(class_name, max_bucket_dec_places=max_bucket_dec_places)
auc_object.update_raw(
true_labels=true_labels, pos_probs=pos_probs, class_name=class_name
)
return auc_object

def __init__(
self,
class_name: Union[int, str],
pos_dict: Dict[float, int] = None,
liliarose marked this conversation as resolved.
Show resolved Hide resolved
neg_dict: Dict[float, int] = None,
max_bucket_dec_places: float = 3,
liliarose marked this conversation as resolved.
Show resolved Hide resolved
):
self._pos_dict = pos_dict if pos_dict else Counter()
self._neg_dict = neg_dict if neg_dict else Counter()
self._class_name = class_name
self._max_bucket_dec_places = max_bucket_dec_places

def update_raw(self, true_labels: List[int], pos_probs: List[float], class_name):
assert self._class_name == class_name
EricMichaelSmith marked this conversation as resolved.
Show resolved Hide resolved
assert len(true_labels) == len(pos_probs)

TO_INT_FACTOR = 10 ** self._max_bucket_dec_places
# add the upper and lower bound of the values
for label, prob in zip(true_labels, pos_probs):
# calculate the upper and lower bound of the values
prob_down = math.floor(prob * TO_INT_FACTOR) / TO_INT_FACTOR
if label == self._class_name:
interested_dict = self._pos_dict
else:
interested_dict = self._neg_dict
if interested_dict.get(prob_down):
interested_dict[prob_down] += 1
else:
interested_dict[prob_down] = 1

def __add__(self, other: Optional['AUCMetrics']) -> 'AUCMetrics':
if other is None:
return self
assert isinstance(other, AUCMetrics)
assert other._class_name == self._class_name
all_pos_dict = self._pos_dict + other._pos_dict
all_neg_dict = self._neg_dict + other._neg_dict

return AUCMetrics(
self._class_name, pos_dict=all_pos_dict, neg_dict=all_neg_dict
)

def _calc_fp_tp(self) -> List[Tuple[int]]:
liliarose marked this conversation as resolved.
Show resolved Hide resolved
all_thresholds = sorted(
set(list(self._pos_dict.keys()) + list(self._neg_dict.keys()))
)
# sorted in ascending order,
# so adding a upper bound so that its tp, fp is (0, 0)
all_thresholds.append(all_thresholds[-1] + 1)
L = len(all_thresholds)
# false positives, true positives
fp_tp = [(0, 0)]

# the biggest one is always (0,0), so skip that one
for i in range(L - 2, -1, -1):
fp, tp = fp_tp[-1]
thres = all_thresholds[i]
# false positives
fp += self._neg_dict.get(thres, 0)
# true positives
tp += self._pos_dict.get(thres, 0)
fp_tp.append((fp, tp))
return fp_tp

def value(self) -> float:
_tot_pos = sum(self._pos_dict.values())
_tot_neg = sum(self._neg_dict.values())
if _tot_pos == 0 and _tot_neg == 0:
return 0
fp_tp = self._calc_fp_tp()
fp_tp.sort(key=lambda x: x[0])
liliarose marked this conversation as resolved.
Show resolved Hide resolved
fps, tps = list(zip(*fp_tp))
if _tot_neg == 0:
fpr = [0] * len(fps)
else:
fpr = [fp / _tot_neg for fp in fps]

if _tot_pos == 0:
tpr = [0] * len(tps)
else:
tpr = [tp / _tot_pos for tp in tps]

return auc(tpr, fpr)


class WeightedF1Metric(Metric):
"""
Class that represents the weighted f1 from ClassificationF1Metric.
Expand Down Expand Up @@ -344,6 +467,16 @@ def __init__(self, opt: Opt, shared=None):
else:
self.threshold = None

# set up calculating auc, only used in binary classification
if len(self.class_list) == 2:
self.calc_auc = opt.get('area_under_curve', False)
else:
self.calc_auc = False

if self.calc_auc:
self.auc_class_ind = 0
self.auc = AUCMetrics(class_name=self.class_list[self.auc_class_ind])

# set up model and optimizers
states = {}
if shared:
Expand Down Expand Up @@ -462,6 +595,13 @@ def _format_interactive_output(self, probs, prediction_id):
)
return preds

def _update_aucs(self, batch, probs):
probs_arr = probs.detach().cpu().numpy()
class_probs = probs_arr[:, self.auc_class_ind]
self.auc.update_raw(
batch.labels, class_probs, self.class_list[self.auc_class_ind]
)

def train_step(self, batch):
"""
Train on a single batch of examples.
Expand Down Expand Up @@ -497,6 +637,10 @@ def eval_step(self, batch):
self.model.eval()
scores = self.score(batch)
probs = F.softmax(scores, dim=1)

if self.calc_auc:
self._update_aucs(batch, probs)

if self.threshold is None:
_, prediction_id = torch.max(probs.cpu(), 1)
else:
Expand Down Expand Up @@ -531,3 +675,7 @@ def score(self, batch):
class.
"""
raise NotImplementedError('Abstract class: user must implement score()')

def reset_auc(self):
if self.calc_auc:
self.auc = AUCMetrics(class_name=self.class_list[self.auc_class_ind])
19 changes: 18 additions & 1 deletion parlai/scripts/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
get_rank,
)

CLASSIFIER_AGENT = 1
liliarose marked this conversation as resolved.
Show resolved Hide resolved


def setup_args(parser=None):
if parser is None:
Expand Down Expand Up @@ -69,6 +71,14 @@ def setup_args(parser=None):
default='conversations',
choices=['conversations', 'parlai'],
)
parser.add_argument(
'--area-under-curve',
liliarose marked this conversation as resolved.
Show resolved Hide resolved
'-auc',
type='bool',
default=False,
help='whether to also calculate the area under the roc curve; '
'only for binary classification',
)
parser.add_argument('-ne', '--num-examples', type=int, default=-1)
parser.add_argument('-d', '--display-examples', type='bool', default=False)
parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
Expand Down Expand Up @@ -168,8 +178,15 @@ def _eval_single_world(opt, agent, task):
world_logger.write(outfile, world, file_format=opt['save_format'])

report = aggregate_unnamed_reports(all_gather_list(world.report()))
world.reset()

if isinstance(world.agents, list) and len(world.agents) > 1:
classifier_agent = world.agents[CLASSIFIER_AGENT]
if hasattr(classifier_agent, 'calc_auc') and classifier_agent.calc_auc:
report['AUC'] = classifier_agent.auc
classifier_agent.reset_auc()
# for safety measures
agent.reset_auc()
world.reset()
return report


Expand Down
Loading