From 175bc03fecc7290685f9449e421873852a416025 Mon Sep 17 00:00:00 2001 From: arendu Date: Wed, 16 Jun 2021 13:21:34 -0700 Subject: [PATCH] updated weighted_f1 to not assume binary classification --- parlai/core/torch_classifier_agent.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/parlai/core/torch_classifier_agent.py b/parlai/core/torch_classifier_agent.py index 410b5e30eac..4c08dfc0ea9 100644 --- a/parlai/core/torch_classifier_agent.py +++ b/parlai/core/torch_classifier_agent.py @@ -199,11 +199,8 @@ def value(self) -> float: values = list(self._values.values()) if len(values) == 0: return weighted_f1 - total_examples = ( - values[0]._true_positives - + values[0]._true_negatives - + values[0]._false_positives - + values[0]._false_negatives + total_examples = sum( + [each._true_positives + each._false_negatives for each in values] ) for each in values: actual_positive = each._true_positives + each._false_negatives