Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Make MD Gender classifier compatible with Batch refactor #3533

Merged
merged 3 commits into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions parlai/tasks/md_gender/yelp.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,7 @@ def _setup_data(self, opt):
extra_data = []
if self.add_unknown_classes:
# load about data (unknown but inferred)
extra_data = gend_utils.get_inferred_about_data(
self.opt['task'], self.opt['datatype']
)
extra_data = gend_utils.get_inferred_about_data(self.opt['task'], self.opt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this a bug?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm I imagine that the function signature must have changed somewhere along the way?


# now create partner/TO data: true neutral
for ex in data:
Expand Down
22 changes: 6 additions & 16 deletions projects/md_gender/bert_ranker_classifier/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,18 @@ class BertRankerClassifierAgent(ClassificationMixin, BiEncoderRankerAgent):
Bert BiEncoder that computes classification metrics (F1, precision, recall)
"""

def get_labels_field(self, observations):
if 'labels' in observations[0]:
labels_field = 'labels'
elif 'eval_labels' in observations[0]:
labels_field = 'eval_labels'
else:
labels_field = None
return labels_field

def train_step(self, batch):
output = super().train_step(batch)
preds = output['text']
labels_field = self.get_labels_field(batch['observations'])
labels_lst = self._get_labels(batch['observations'], labels_field)
self._update_confusion_matrix(preds, labels_lst)
self._update_confusion_matrix(preds, batch.labels)
return output

def eval_step(self, batch):
if batch.text_vec is None:
return
output = super().eval_step(batch)
preds = output['text']
labels_field = self.get_labels_field(batch['observations'])
if labels_field is not None:
labels_lst = self._get_labels(batch['observations'], labels_field)
self._update_confusion_matrix(preds, labels_lst)
labels = batch.labels
if preds is not None and labels is not None:
self._update_confusion_matrix(preds, labels)
return output