diff --git a/parlai/core/metrics.py b/parlai/core/metrics.py
index a46aa377e2b..f4413f733ac 100644
--- a/parlai/core/metrics.py
+++ b/parlai/core/metrics.py
@@ -52,6 +52,10 @@ class MetricDisplayData(NamedTuple):
 
 METRICS_DISPLAY_DATA = {
     "accuracy": MetricDisplayData("Accuracy", "Exact match text accuracy"),
+    'auc': MetricDisplayData(
+        'AUC',
+        "Area Under the Receiver Operating Characteristic Curve (true positive rate vs false positive rate curve)",
+    ),
     "bleu-4": MetricDisplayData(
         "BLEU-4",
         "BLEU-4 of the generation, under a standardized (model-independent) tokenizer",
diff --git a/parlai/core/torch_classifier_agent.py b/parlai/core/torch_classifier_agent.py
index 410b5e30eac..3c881df8781 100644
--- a/parlai/core/torch_classifier_agent.py
+++ b/parlai/core/torch_classifier_agent.py
@@ -15,14 +15,16 @@
 from parlai.core.torch_agent import TorchAgent, Output
 from parlai.utils.misc import round_sigfigs, warn_once
 from parlai.core.metrics import Metric, AverageMetric
-from typing import List, Optional, Tuple, Dict
+from typing import List, Optional, Tuple, Dict, Union
+from typing import Counter
 from parlai.utils.typing import TScalar
 from parlai.utils.io import PathManager
-import parlai.utils.logging as logging
-
+from sklearn.metrics import auc
-import torch
+import parlai.utils.logging as logging
 import torch.nn.functional as F
+import torch
+import math
 
 
 class ConfusionMatrixMetric(Metric):
@@ -168,6 +170,158 @@ def value(self) -> float:
         return numer / denom
 
 
+class AUCMetrics(Metric):
+    """
+    Computes Area Under ROC Curve (AUC) metrics.
+
+    Does so by keeping track of positives' and negatives' probability score counts in
+    Counters or dictionaries. Note the introduction of `max_bucket_dec_places`; this
+    integer determines the number of decimal digits to save for the probability scores.
+    A higher `max_bucket_dec_places` will give a more accurate estimate of the exact
+    AUC metric, but may also use more space.
+
+    NOTE: currently only used for classifiers in the `eval_model` script; to use, add
+    the argument `-auc <num decimal places>` when calling the `eval_model` script.
+    """
+
+    @property
+    def macro_average(self) -> bool:
+        """
+        Indicates whether this metric should be macro-averaged when globally reported.
+        """
+        return False
+
+    @classmethod
+    def raw_data_to_auc(
+        cls,
+        true_labels: List[Union[int, str]],
+        pos_probs: List[float],
+        class_name,
+        max_bucket_dec_places: int = 3,
+    ):
+        auc_object = cls(class_name, max_bucket_dec_places=max_bucket_dec_places)
+        auc_object.update_raw(
+            true_labels=true_labels, pos_probs=pos_probs, class_name=class_name
+        )
+        return auc_object
+
+    def __init__(
+        self,
+        class_name: Union[int, str],
+        max_bucket_dec_places: int = 3,
+        pos_dict: Optional[Counter[float]] = None,
+        neg_dict: Optional[Counter[float]] = None,
+    ):
+        # `_pos_dict` keeps track of the probabilities of the positive class
+        self._pos_dict = pos_dict if pos_dict else Counter()
+        # `_neg_dict` keeps track of the probabilities of the negative class
+        self._neg_dict = neg_dict if neg_dict else Counter()
+        self._class_name = class_name
+        self._max_bucket_dec_places = max_bucket_dec_places
+
+    def update_raw(
+        self, true_labels: List[Union[int, str]], pos_probs: List[float], class_name
+    ):
+        """
+        Given the true/golden labels and the probabilities of the positive class,
+        update the bucket dictionaries of positives and negatives (based on the
+        class_name); `max_bucket_dec_places` is used here to round the probabilities
+        down into their buckets.
+        """
+        assert self._class_name == class_name
+        assert len(true_labels) == len(pos_probs)
+
+        TO_INT_FACTOR = 10 ** self._max_bucket_dec_places
+        for label, prob in zip(true_labels, pos_probs):
+            # bucket the probability by flooring it to `max_bucket_dec_places` digits
+            prob_down = math.floor(prob * TO_INT_FACTOR) / TO_INT_FACTOR
+            if label == self._class_name:
+                interested_dict = self._pos_dict
+            else:
+                interested_dict = self._neg_dict
+            if interested_dict.get(prob_down):
+                interested_dict[prob_down] += 1
+            else:
+                interested_dict[prob_down] = 1
+
+    def __add__(self, other: Optional['AUCMetrics']) -> 'AUCMetrics':
+        if other is None:
+            return self
+        assert isinstance(other, AUCMetrics)
+        assert other._class_name == self._class_name
+        all_pos_dict = self._pos_dict + other._pos_dict
+        all_neg_dict = self._neg_dict + other._neg_dict
+
+        return AUCMetrics(
+            self._class_name, pos_dict=all_pos_dict, neg_dict=all_neg_dict
+        )
+
+    def _calc_fp_tp(self) -> List[Tuple[int, int]]:
+        """
+        Calculates the false positives and true positives; returned as a list of pairs:
+
+        `[(fp, tp)]`
+        """
+        all_thresholds = sorted(
+            set(list(self._pos_dict.keys()) + list(self._neg_dict.keys()))
+        )
+        # thresholds are sorted in ascending order, so append an upper bound whose
+        # (fp, tp) is (0, 0)
+        all_thresholds.append(all_thresholds[-1] + 1)
+        L = len(all_thresholds)
+        # false positives, true positives
+        fp_tp = [(0, 0)]
+
+        # the biggest threshold is always (0, 0), so skip it
+        for i in range(L - 2, -1, -1):
+            fp, tp = fp_tp[-1]
+            thres = all_thresholds[i]
+            # false positives
+            fp += self._neg_dict.get(thres, 0)
+            # true positives
+            tp += self._pos_dict.get(thres, 0)
+            fp_tp.append((fp, tp))
+        return fp_tp
+
+    def _calc_fpr_tpr(self) -> Tuple[List[Tuple[float, float]], int, int]:
+        """
+        Calculates the false positive rates and true positive rates. Also returns the
+        total number of positives and negatives; returned as a list of pairs and two
+        integers:
+
+        `([(fpr, tpr)], positives, negatives)`; note that if the total
+        negatives/positives is 0, then 0 is returned for fpr/tpr (respectively) instead
+        of raising an error.
+        """
+        _tot_pos = sum(self._pos_dict.values())
+        _tot_neg = sum(self._neg_dict.values())
+        fp_tp = self._calc_fp_tp()
+        fps, tps = list(zip(*fp_tp))
+        if _tot_neg == 0:
+            fpr = [0] * len(fps)
+        else:
+            fpr = [fp / _tot_neg for fp in fps]
+
+        if _tot_pos == 0:
+            tpr = [0] * len(tps)
+        else:
+            tpr = [tp / _tot_pos for tp in tps]
+
+        return (list(zip(fpr, tpr)), _tot_pos, _tot_neg)
+
+    def value(self) -> float:
+        fpr_tpr, _tot_pos, _tot_neg = self._calc_fpr_tpr()
+
+        if _tot_pos == 0 and _tot_neg == 0:
+            return 0
+
+        # sklearn's auc needs the x-axis (fpr) to be sorted
+        fpr_tpr.sort()
+        fpr, tpr = list(zip(*fpr_tpr))
+        return auc(fpr, tpr)
+
+
 class WeightedF1Metric(Metric):
     """
     Class that represents the weighted f1 from ClassificationF1Metric.
@@ -344,6 +498,26 @@ def __init__(self, opt: Opt, shared=None):
         else:
             self.threshold = None
 
+        # set up calculating auc
+        self.calc_auc = opt.get('area_under_curve_digits', -1) > 0
+
+        if self.calc_auc:
+            self.auc_bucket_decimal_size = opt.get('area_under_curve_digits')
+            if opt.get('area_under_curve_class') is None:
+                # default to calculating the AUC for every class
+                interested_classes = self.class_list
+            else:
+                interested_classes = opt.get('area_under_curve_class')
+            try:
+                self.auc_class_indices = [
+                    self.class_dict[class_name] for class_name in interested_classes
+                ]
+            except Exception:
+                raise RuntimeError(
+                    f'The classes provided for AUC were probably invalid.\n Current class names: {self.class_list} \n Names of AUC classes passed in: {interested_classes}'
+                )
+            self.reset_auc()
+
         # set up model and optimizers
         states = {}
 
         if shared:
@@ -462,6 +636,12 @@ def _format_interactive_output(self, probs, prediction_id):
             )
         return preds
 
+    def _update_aucs(self, batch, probs):
+        probs_arr = probs.detach().cpu().numpy()
+        for index, curr_auc in zip(self.auc_class_indices, self.aucs):
+            class_probs = probs_arr[:, index]
+            curr_auc.update_raw(batch.labels, class_probs, self.class_list[index])
+
     def train_step(self, batch):
         """
         Train on a single batch of examples.
@@ -497,6 +677,10 @@ def eval_step(self, batch):
         self.model.eval()
         scores = self.score(batch)
         probs = F.softmax(scores, dim=1)
+
+        if self.calc_auc:
+            self._update_aucs(batch, probs)
+
         if self.threshold is None:
             _, prediction_id = torch.max(probs.cpu(), 1)
         else:
@@ -531,3 +715,13 @@ def score(self, batch):
         class.
         """
         raise NotImplementedError('Abstract class: user must implement score()')
+
+    def reset_auc(self):
+        if self.calc_auc:
+            self.aucs = [
+                AUCMetrics(
+                    class_name=self.class_list[index],
+                    max_bucket_dec_places=self.auc_bucket_decimal_size,
+                )
+                for index in self.auc_class_indices
+            ]
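
For reference, the new `AUCMetrics` class above can also be exercised on its own, outside of any agent. The following is a minimal sketch (not part of the patch; the labels and probabilities are borrowed from the unit tests further below) showing the intended flow: bucket raw predictions with `raw_data_to_auc`, merge partial results with `+`, and read the estimate via `value()`.

    from parlai.core.torch_classifier_agent import AUCMetrics

    # gold labels and the model's probabilities for the positive class
    true_labels = ['class_ok', 'class_ok', 'class_notok', 'class_notok']
    pos_probs = [0.1, 0.4, 0.35, 0.8]

    # bucket the probabilities to 3 decimal places and estimate the ROC AUC
    metric = AUCMetrics.raw_data_to_auc(
        true_labels, pos_probs, 'class_notok', max_bucket_dec_places=3
    )
    print(metric.value())  # approximate AUC for 'class_notok'

    # partial metrics (e.g. from different batches or workers) merge with `+`
    more = AUCMetrics.raw_data_to_auc(
        ['class_ok', 'class_notok'], [0.2, 0.6], 'class_notok'
    )
    print((metric + more).value())
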
diff --git a/parlai/scripts/eval_model.py b/parlai/scripts/eval_model.py
index 289dc636372..c09def3263d 100644
--- a/parlai/scripts/eval_model.py
+++ b/parlai/scripts/eval_model.py
@@ -42,6 +42,9 @@
     get_rank,
 )
 
+# the index used to access the classifier agent in the world
+CLASSIFIER_AGENT = 1
+
 
 def setup_args(parser=None):
     if parser is None:
@@ -69,6 +72,21 @@ def setup_args(parser=None):
         default='conversations',
         choices=['conversations', 'parlai'],
     )
+    parser.add_argument(
+        '--area-under-curve-digits',
+        '-auc',
+        type=int,
+        default=-1,
+        help='a positive number enables calculating the area under the ROC curve and sets how many decimal digits of the predictions to keep (higher numbers -> more precise); the default of -1 disables the AUC metric',
+    )
+    parser.add_argument(
+        '--area-under-curve-class',
+        '-auclass',
+        type=str,
+        default=None,
+        nargs='*',
+        help='the name(s) of the class(es) to calculate the AUC for; defaults to all classes',
+    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
@@ -168,8 +186,18 @@ def _eval_single_world(opt, agent, task):
         world_logger.write(outfile, world, file_format=opt['save_format'])
 
     report = aggregate_unnamed_reports(all_gather_list(world.report()))
-    world.reset()
+    if isinstance(world.agents, list) and len(world.agents) > 1:
+        classifier_agent = world.agents[CLASSIFIER_AGENT]
+        if hasattr(classifier_agent, 'calc_auc') and classifier_agent.calc_auc:
+            for class_indices, curr_auc in zip(
+                classifier_agent.auc_class_indices, classifier_agent.aucs
+            ):
+                report[f'AUC_{classifier_agent.class_list[class_indices]}'] = curr_auc
+            classifier_agent.reset_auc()
+            # also reset the top-level agent's AUC state, for safety
+            agent.reset_auc()
+    world.reset()
 
     return report
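
With the flags above wired into `eval_model`, AUC reporting is opt-in from the command line. A hypothetical invocation, assuming the standard `parlai eval_model` entry point (the model file, task, and class name below are placeholders for whatever classifier setup is being evaluated), might look like:

    # keep 4 decimal digits per probability bucket; report AUC only for 'class_notok'
    parlai eval_model -mf /path/to/classifier_model -t my_classification_task \
        -auc 4 -auclass class_notok

The report then gains an `AUC_class_notok` entry; leaving `-auclass` unset computes AUC for every class in the model's class list, and omitting `-auc` (default -1) skips the metric entirely.
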
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index c0818c18b37..563cae706c8 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -23,7 +23,11 @@
     IntraDistinctMetric,
     FairseqBleuMetric,
 )
-from parlai.core.torch_classifier_agent import ConfusionMatrixMetric, WeightedF1Metric
+from parlai.core.torch_classifier_agent import (
+    ConfusionMatrixMetric,
+    WeightedF1Metric,
+    AUCMetrics,
+)
 import parlai.utils.testing as testing_utils
 
 
@@ -314,6 +318,327 @@ def test_micro_aggregation(self):
         assert agg['b/fixed'] == 4
         assert 'b/global_avg' not in agg
 
+    def test_auc_metrics(self):
+        class_name = 'class_notok'
+        class_to_int = {'class_notok': 1, 'class_ok': 0}
+        decimal_place = 3
+        # task 1; borrowing the example from scikit-learn
+        task1_probabilities = [0.1, 0.4, 0.35, 0.8]
+        task1_gold_labels = ['class_ok', 'class_ok', 'class_notok', 'class_notok']
+        task1_pos_buckets = {0.35: 1, 0.8: 1}
+        task1_neg_buckets = {0.1: 1, 0.4: 1}
+        task1_exp_fp_tp = {
+            # thres: (false positives, true positives)
+            0.1: (2, 2),
+            0.35: (1, 2),
+            0.4: (1, 1),
+            0.8: (0, 1),
+            '_': (0, 0),
+        }
+
+        # task 2; checking with an odd number of examples
+        task2_probabilities = [0.05, 0.2, 0.6]
+        task2_gold_labels = ['class_ok', 'class_ok', 'class_notok']
+        task2_pos_buckets = {0.6: 1}
+        task2_neg_buckets = {0.05: 1, 0.2: 1}
+        task2_exp_fp_tp = {0.05: (2, 1), 0.2: (1, 1), 0.6: (0, 1), 1.5: (0, 0)}
+
+        # task 3: combining task 1 and task 2
+        task3_probabilities = task1_probabilities + task2_probabilities
+        task3_gold_labels = task1_gold_labels + task2_gold_labels
+        task3_pos_buckets = {0.35: 1, 0.8: 1, 0.6: 1}
+        task3_neg_buckets = {0.1: 1, 0.4: 1, 0.05: 1, 0.2: 1}
+        task3_exp_fp_tp = {
+            # threshold: (FP, TP)
+            0.05: (4, 3),
+            0.1: (3, 3),
+            0.2: (2, 3),
+            0.35: (1, 3),
+            0.4: (1, 2),
+            0.6: (0, 2),
+            0.8: (0, 1),
+            '_': (0, 0),
+        }
+
+        # task 4: testing when there are multiple examples in the same bucket
+        task4_probabilities = [0.1, 0.400001, 0.4, 0.359, 0.35, 0.900001, 0.9]
+        task4_gold_labels = [
+            'class_ok',
+            'class_ok',
+            'class_ok',
+            'class_notok',
+            'class_notok',
+            'class_notok',
+            'class_notok',
+        ]
+        task4_neg_buckets = {0.1: 1, 0.4: 2}
+        task4_pos_buckets = {0.35: 1, 0.359: 1, 0.9: 2}
+        task4_exp_fp_tp = {
+            # thres: (false positives, true positives)
+            0.1: (3, 4),
+            0.35: (2, 4),
+            0.359: (2, 3),
+            0.4: (2, 2),
+            0.9: (0, 2),
+            '_': (0, 0),
+        }
+
+        # task 5: similar to task 4 (more values falling in the same bucket), but also
+        # checking that the rounding/flooring is correct and covering the edge cases
+        # 0.0 and 1.0
+        task5_probabilities = [0, 0.8, 0.4009, 0.400, 0.359, 0.35, 0.9999, 0.999, 1]
+        # 4 okay, 5 not okay
+        task5_gold_labels = [
+            'class_ok',
+            'class_ok',
+            'class_ok',
+            'class_ok',
+            'class_notok',
+            'class_notok',
+            'class_notok',
+            'class_notok',
+            'class_notok',
+        ]
+        task5_neg_buckets = {0: 1, 0.8: 1, 0.4: 2}
+        task5_pos_buckets = {0.35: 1, 0.359: 1, 0.999: 2, 1: 1}
+        task5_exp_fp_tp = {
+            # thres: (false positives, true positives)
+            0: (4, 5),
+            0.35: (3, 5),
+            0.359: (3, 4),
+            0.4: (3, 3),
+            0.8: (1, 3),
+            0.9: (0, 3),
+            1.0: (0, 1),
+            '_': (0, 0),
+        }
+
+        # task 6: combining task 4 + task 5 (merging buckets with the same keys)
+        task6_probabilities = task4_probabilities + task5_probabilities
+        task6_gold_labels = task4_gold_labels + task5_gold_labels
+        task6_neg_buckets = {0: 1, 0.8: 1, 0.4: 4, 0.1: 1}
+        task6_pos_buckets = {0.35: 2, 0.359: 2, 0.9: 2, 0.999: 2, 1: 1}
+        task6_exp_fp_tp = {
+            # threshold: (FP, TP)
+            0: (7, 9),
+            0.1: (6, 9),
+            0.35: (5, 9),
+            0.359: (5, 7),
+            0.4: (5, 5),
+            0.8: (1, 5),
+            0.9: (0, 5),
+            0.999: (0, 3),
+            1: (0, 1),
+            '_': (0, 0),
+        }
+
+        # run and check the TPs and FPs for the single tasks
+        task1_result = AUCMetrics.raw_data_to_auc(
+            task1_gold_labels,
+            task1_probabilities,
+            class_name,
+            max_bucket_dec_places=decimal_place,
+        )
+        task2_result = AUCMetrics.raw_data_to_auc(
+            task2_gold_labels,
+            task2_probabilities,
+            class_name,
+            max_bucket_dec_places=decimal_place,
+        )
+        task3_result = AUCMetrics.raw_data_to_auc(
+            task3_gold_labels,
+            task3_probabilities,
+            class_name,
+            max_bucket_dec_places=decimal_place,
+        )
+        task4_result = AUCMetrics.raw_data_to_auc(
+            task4_gold_labels,
+            task4_probabilities,
+            class_name,
+            max_bucket_dec_places=decimal_place,
+        )
+        task5_result = AUCMetrics.raw_data_to_auc(
+            task5_gold_labels,
+            task5_probabilities,
+            class_name,
+            max_bucket_dec_places=decimal_place,
+        )
+        task6_result = AUCMetrics.raw_data_to_auc(
+            task6_gold_labels,
+            task6_probabilities,
+            class_name,
+            max_bucket_dec_places=decimal_place,
+        )
+
+        # check the buckets first
+        self.assertEqual(task1_result._pos_dict, task1_pos_buckets)
+        self.assertEqual(task1_result._neg_dict, task1_neg_buckets)
+        self.assertEqual(task2_result._pos_dict, task2_pos_buckets)
+        self.assertEqual(task2_result._neg_dict, task2_neg_buckets)
+        self.assertEqual(task3_result._pos_dict, task3_pos_buckets)
+        self.assertEqual(task3_result._neg_dict, task3_neg_buckets)
+        self.assertEqual(task4_result._pos_dict, task4_pos_buckets)
+        self.assertEqual(task4_result._neg_dict, task4_neg_buckets)
+        self.assertEqual(task5_result._pos_dict, task5_pos_buckets)
+        self.assertEqual(task5_result._neg_dict, task5_neg_buckets)
+        self.assertEqual(task6_result._pos_dict, task6_pos_buckets)
+        self.assertEqual(task6_result._neg_dict, task6_neg_buckets)
+
+        # then check fp, tp
+        self.assertEqual(set(task1_result._calc_fp_tp()), set(task1_exp_fp_tp.values()))
+        self.assertEqual(set(task2_result._calc_fp_tp()), set(task2_exp_fp_tp.values()))
+        self.assertEqual(set(task3_result._calc_fp_tp()), set(task3_exp_fp_tp.values()))
+        self.assertEqual(set(task4_result._calc_fp_tp()), set(task4_exp_fp_tp.values()))
+        self.assertEqual(set(task5_result._calc_fp_tp()), set(task5_exp_fp_tp.values()))
+        self.assertEqual(set(task6_result._calc_fp_tp()), set(task6_exp_fp_tp.values()))
+
+        # check that merging also produces the same results
+        task3_result = task1_result + task2_result
+        self.assertEqual(task3_result._pos_dict, task3_pos_buckets)
+        self.assertEqual(task3_result._neg_dict, task3_neg_buckets)
+        self.assertEqual(set(task3_result._calc_fp_tp()), set(task3_exp_fp_tp.values()))
+
+        task6_result = task4_result + task5_result
+        self.assertEqual(task6_result._pos_dict, task6_pos_buckets)
+        self.assertEqual(task6_result._neg_dict, task6_neg_buckets)
+        self.assertEqual(set(task6_result._calc_fp_tp()), set(task6_exp_fp_tp.values()))
+
+        # now actually test the area under the curve
+        from sklearn.metrics import roc_auc_score
+
+        task1_labels_int = [
+            class_to_int[gold_label] for gold_label in task1_gold_labels
+        ]
+        task2_labels_int = [
+            class_to_int[gold_label] for gold_label in task2_gold_labels
+        ]
+        task3_labels_int = [
+            class_to_int[gold_label] for gold_label in task3_gold_labels
+        ]
+        task4_labels_int = [
+            class_to_int[gold_label] for gold_label in task4_gold_labels
+        ]
+        task5_labels_int = [
+            class_to_int[gold_label] for gold_label in task5_gold_labels
+        ]
+        task6_labels_int = [
+            class_to_int[gold_label] for gold_label in task6_gold_labels
+        ]
+
+        self.assertAlmostEqual(
+            roc_auc_score(task1_labels_int, task1_probabilities), task1_result.value()
+        )
+        self.assertAlmostEqual(
+            roc_auc_score(task2_labels_int, task2_probabilities), task2_result.value()
+        )
+        self.assertAlmostEqual(
+            roc_auc_score(task3_labels_int, task3_probabilities), task3_result.value()
+        )
+        self.assertAlmostEqual(
+            roc_auc_score(task4_labels_int, task4_probabilities), task4_result.value()
+        )
+        self.assertAlmostEqual(
+            roc_auc_score(task5_labels_int, task5_probabilities), task5_result.value()
+        )
+        self.assertAlmostEqual(
+            roc_auc_score(task6_labels_int, task6_probabilities), task6_result.value()
+        )
+
+        # last task: adding everything together; uses tasks 3 & 6,
+        # and only checks the roc scores
+        task_all_gold_labels = task3_gold_labels + task6_gold_labels
+        task_all_labels_int = task3_labels_int + task6_labels_int
+        task_all_probabilities = task3_probabilities + task6_probabilities
+
+        task_all_result = task3_result + task6_result
+
+        self.assertAlmostEqual(
+            roc_auc_score(task_all_labels_int, task_all_probabilities),
+            task_all_result.value(),
+        )
+
+        task_all_result2 = AUCMetrics.raw_data_to_auc(
+            task_all_gold_labels, task_all_probabilities, class_name
+        )
+        self.assertAlmostEqual(
+            roc_auc_score(task_all_labels_int, task_all_probabilities),
+            task_all_result2.value(),
+        )
+
+        ### now reuse the tests for the other class, checking only the roc scores;
+        ### for binary classes, the AUC of the two classes should be the same
+        class_name = 'class_ok'
+        task1_probabilities = [1 - curr_prob for curr_prob in task1_probabilities]
+        task2_probabilities = [1 - curr_prob for curr_prob in task2_probabilities]
+        task3_probabilities = [1 - curr_prob for curr_prob in task3_probabilities]
+        task4_probabilities = [1 - curr_prob for curr_prob in task4_probabilities]
+        task5_probabilities = [1 - curr_prob for curr_prob in task5_probabilities]
+        task6_probabilities = [1 - curr_prob for curr_prob in task6_probabilities]
+
+        task1_labels_int = [1 - curr for curr in task1_labels_int]
+        task2_labels_int = [1 - curr for curr in task2_labels_int]
+        task3_labels_int = [1 - curr for curr in task3_labels_int]
+        task4_labels_int = [1 - curr for curr in task4_labels_int]
+        task5_labels_int = [1 - curr for curr in task5_labels_int]
+        task6_labels_int = [1 - curr for curr in task6_labels_int]
+
+        # get the results
+        task1_result = AUCMetrics.raw_data_to_auc(
+            task1_gold_labels,
+            task1_probabilities,
+            class_name,
+            max_bucket_dec_places=decimal_place,
+        )
+        task2_result = AUCMetrics.raw_data_to_auc(
+            task2_gold_labels,
+            task2_probabilities,
+            class_name,
+            max_bucket_dec_places=decimal_place,
+        )
+        task3_result = AUCMetrics.raw_data_to_auc(
+            task3_gold_labels,
+            task3_probabilities,
+            class_name,
+            max_bucket_dec_places=decimal_place,
+        )
+        task4_result = AUCMetrics.raw_data_to_auc(
+            task4_gold_labels,
+            task4_probabilities,
+            class_name,
+            max_bucket_dec_places=decimal_place,
+        )
+        task5_result = AUCMetrics.raw_data_to_auc(
+            task5_gold_labels,
+            task5_probabilities,
+            class_name,
+            max_bucket_dec_places=decimal_place,
+        )
+        task6_result = AUCMetrics.raw_data_to_auc(
+            task6_gold_labels,
+            task6_probabilities,
+            class_name,
+            max_bucket_dec_places=decimal_place,
+        )
+
+        # check against roc_auc_score
+        self.assertAlmostEqual(
+            roc_auc_score(task1_labels_int, task1_probabilities), task1_result.value()
+        )
+        self.assertAlmostEqual(
+            roc_auc_score(task2_labels_int, task2_probabilities), task2_result.value()
+        )
+        self.assertAlmostEqual(
+            roc_auc_score(task3_labels_int, task3_probabilities), task3_result.value()
+        )
+        self.assertAlmostEqual(
+            roc_auc_score(task4_labels_int, task4_probabilities), task4_result.value()
+        )
+        self.assertAlmostEqual(
+            roc_auc_score(task5_labels_int, task5_probabilities), task5_result.value()
+        )
+        self.assertAlmostEqual(
+            roc_auc_score(task6_labels_int, task6_probabilities), task6_result.value()
+        )
+
     def test_classifier_metrics(self):
         # We assume a batch of 16 samples, binary classification case, from 2 tasks.
         # task 1