diff --git a/parlai/tasks/md_gender/yelp.py b/parlai/tasks/md_gender/yelp.py index 33fe710a181..6c2ec8e9274 100644 --- a/parlai/tasks/md_gender/yelp.py +++ b/parlai/tasks/md_gender/yelp.py @@ -189,9 +189,7 @@ def _setup_data(self, opt): extra_data = [] if self.add_unknown_classes: # load about data (unknown but inferred) - extra_data = gend_utils.get_inferred_about_data( - self.opt['task'], self.opt['datatype'] - ) + extra_data = gend_utils.get_inferred_about_data(self.opt['task'], self.opt) # now create partner/TO data: true neutral for ex in data: diff --git a/projects/md_gender/bert_ranker_classifier/agents.py b/projects/md_gender/bert_ranker_classifier/agents.py index 3a1fe3b07b3..e00ce532b6e 100644 --- a/projects/md_gender/bert_ranker_classifier/agents.py +++ b/projects/md_gender/bert_ranker_classifier/agents.py @@ -17,28 +17,18 @@ class BertRankerClassifierAgent(ClassificationMixin, BiEncoderRankerAgent): Bert BiEncoder that computes classification metrics (F1, precision, recall) """ - def get_labels_field(self, observations): - if 'labels' in observations[0]: - labels_field = 'labels' - elif 'eval_labels' in observations[0]: - labels_field = 'eval_labels' - else: - labels_field = None - return labels_field - def train_step(self, batch): output = super().train_step(batch) preds = output['text'] - labels_field = self.get_labels_field(batch['observations']) - labels_lst = self._get_labels(batch['observations'], labels_field) - self._update_confusion_matrix(preds, labels_lst) + self._update_confusion_matrix(preds, batch.labels) return output def eval_step(self, batch): + if batch.text_vec is None: + return output = super().eval_step(batch) preds = output['text'] - labels_field = self.get_labels_field(batch['observations']) - if labels_field is not None: - labels_lst = self._get_labels(batch['observations'], labels_field) - self._update_confusion_matrix(preds, labels_lst) + labels = batch.labels + if preds is not None and labels is not None: + self._update_confusion_matrix(preds, labels) return output