Add MCTest and MCTACO (#1197)
* add mctest and mctaco

* Update mctaco.py

* add task to supported tasks
HaokunLiu committed Oct 29, 2020
1 parent 76e2826 commit c00360f
Showing 5 changed files with 246 additions and 1 deletion.
2 changes: 2 additions & 0 deletions guides/tasks/supported_tasks.md
@@ -4,6 +4,8 @@

| Name | `task_name` | `jiant` | Downloader | `jiant_task_name` | Misc |
|---|---|:---:|:---:|---|---|
| MCTACO | mctaco | ✅ | | mctaco | |
| MCTest | mctest160 or mctest500 | ✅ | | mctest | |
| [Argument Reasoning Comprehension](https://arxiv.org/abs/1708.01425) | arct | ✅ | | arct | [Github](https://github.com/UKPLab/argument-reasoning-comprehension-task) |
| Abductive NLI | abductive_nli | ✅ | ✅ | abductive_nli | |
| SuperGLUE Winogender Diagnostic | superglue_axg | ✅ | ✅ | superglue_axg | SuperGLUE |
47 changes: 46 additions & 1 deletion jiant/tasks/evaluate/core.py
@@ -73,7 +73,7 @@ def update(self, batch_logits, batch_loss, batch, batch_metadata):
        self.logits_list.append(batch_logits)
        batch_guid = batch_metadata.get("guid")
        if batch_guid is not None:
-            self.guid_list.append(batch_guid)
+            self.guid_list.extend(batch_guid)

    def get_guids(self):
        if self.guid_list:
@@ -261,6 +261,48 @@ def compute_metrics_from_preds_and_labels(cls, preds, labels):
        return Metrics(major=acc, minor={"acc": acc})


class MCTACOEvaluationScheme(BaseLogitsEvaluationScheme):
    @classmethod
    def get_preds_from_accumulator(cls, task, accumulator):
        logits = accumulator.get_accumulated()
        preds = np.argmax(logits, axis=1)
        guids = accumulator.guid_list
        return guids, preds

    @classmethod
    def compute_metrics_from_accumulator(cls, task, accumulator, tokenizer, labels) -> Metrics:
        guids, preds = cls.get_preds_from_accumulator(task=task, accumulator=accumulator)

        # Group labels and predictions by question; guids look like "<split>-q<question>-<example>".
        label_pred_by_question = {}
        for one_guid, one_pred, one_label in zip(guids, preds, labels):
            split, question_id, example_id = one_guid.split("-")
            if question_id not in label_pred_by_question:
                label_pred_by_question[question_id] = ([], [])
            label_pred_by_question[question_id][0].append(one_label)
            label_pred_by_question[question_id][1].append(one_pred)

        # Exact match: a question scores 1.0 only if every candidate answer is classified correctly.
        em_ls = [
            float(group_label == group_pred)
            for group_label, group_pred in label_pred_by_question.values()
        ]
        # F1 is computed within each question, then averaged across questions.
        f1_ls = [
            f1_score(y_true=group_label, y_pred=group_pred)
            for group_label, group_pred in label_pred_by_question.values()
        ]

        em = sum(em_ls) / len(em_ls)
        f1 = sum(f1_ls) / len(f1_ls)
        minor = {
            "em": em,
            "f1": f1,
            "f1_em": (f1 + em) / 2,
        }
        return Metrics(major=minor["f1_em"], minor=minor)


class MultiLabelAccAndF1EvaluationScheme(BaseLogitsEvaluationScheme):
    def get_labels_from_cache_and_examples(self, task, cache, examples):
        return get_multi_label_ids_from_cache(cache=cache)
@@ -935,6 +977,8 @@ def get_evaluation_scheme_for_task(task) -> BaseEvaluationScheme:
        ),
    ):
        return SimpleAccuracyEvaluationScheme()
    elif isinstance(task, tasks.MCTACOTask):
        return MCTACOEvaluationScheme()
    elif isinstance(task, tasks.CCGTask):
        return CCGEvaluationScheme()
    elif isinstance(task, tasks.CommitmentBankTask):
@@ -953,6 +997,7 @@ def get_evaluation_scheme_for_task(task) -> BaseEvaluationScheme:
            tasks.MutualTask,
            tasks.MutualPlusTask,
            tasks.SocialIQATask,
            tasks.MCTestTask,
        ),
    ):
        return MultipleChoiceAccuracyEvaluationScheme()
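To make the metric concrete, here is a minimal, self-contained sketch of the per-question grouping that MCTACOEvaluationScheme performs, on invented toy data (it uses sklearn's f1_score, as the scheme above does; the guids follow the "<split>-q<question>-<example>" pattern produced by MCTACOTask._create_examples below):

from sklearn.metrics import f1_score

# Two questions, two candidate answers each (toy values, not real MCTACO data).
guids = ["val-q0-0", "val-q0-1", "val-q1-2", "val-q1-3"]
labels = [1, 0, 1, 1]
preds = [1, 0, 0, 1]

by_question = {}
for guid, label, pred in zip(guids, labels, preds):
    _, question_id, _ = guid.split("-")
    by_question.setdefault(question_id, ([], []))
    by_question[question_id][0].append(label)
    by_question[question_id][1].append(pred)

# EM scores a question 1.0 only if every candidate answer is classified correctly;
# F1 is computed within each question; both are then averaged across questions.
em = sum(float(ls == ps) for (ls, ps) in by_question.values()) / len(by_question)
f1 = sum(f1_score(y_true=ls, y_pred=ps) for (ls, ps) in by_question.values()) / len(by_question)
print(em, f1)  # 0.5 0.8333...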
116 changes: 116 additions & 0 deletions jiant/tasks/lib/mctaco.py
@@ -0,0 +1,116 @@
import numpy as np
import torch
from dataclasses import dataclass
from typing import List

from jiant.tasks.core import (
    BaseExample,
    BaseTokenizedExample,
    BaseDataRow,
    BatchMixin,
    Task,
    TaskTypes,
)
from jiant.tasks.lib.templates.shared import double_sentence_featurize, labels_to_bimap
from jiant.utils.python.io import read_file_lines


@dataclass
class Example(BaseExample):
    guid: str
    sentence_question: str
    answer: str
    label: str

    def tokenize(self, tokenizer):
        return TokenizedExample(
            guid=self.guid,
            sentence_question=tokenizer.tokenize(self.sentence_question),
            answer=tokenizer.tokenize(self.answer),
            label_id=MCTACOTask.LABEL_TO_ID[self.label],
        )


@dataclass
class TokenizedExample(BaseTokenizedExample):
    guid: str
    sentence_question: List
    answer: List
    label_id: int

    def featurize(self, tokenizer, feat_spec):
        return double_sentence_featurize(
            guid=self.guid,
            input_tokens_a=self.sentence_question,
            input_tokens_b=self.answer,
            label_id=self.label_id,
            tokenizer=tokenizer,
            feat_spec=feat_spec,
            data_row_class=DataRow,
        )


@dataclass
class DataRow(BaseDataRow):
    guid: str
    input_ids: np.ndarray
    input_mask: np.ndarray
    segment_ids: np.ndarray
    label_id: int
    tokens: list


@dataclass
class Batch(BatchMixin):
    input_ids: torch.LongTensor
    input_mask: torch.LongTensor
    segment_ids: torch.LongTensor
    label_id: torch.LongTensor
    tokens: list


class MCTACOTask(Task):
    Example = Example
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    TASK_TYPE = TaskTypes.CLASSIFICATION
    LABELS = ["yes", "no"]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)

    def get_train_examples(self):
        return self._create_examples(
            lines=read_file_lines(self.train_path, strip_lines=True), set_type="train"
        )

    def get_val_examples(self):
        return self._create_examples(
            lines=read_file_lines(self.val_path, strip_lines=True), set_type="val"
        )

    def get_test_examples(self):
        return self._create_examples(
            lines=read_file_lines(self.test_path, strip_lines=True), set_type="test"
        )

    @classmethod
    def _create_examples(cls, lines, set_type):
        # noinspection DuplicatedCode
        examples = []
        last_question = ""
        question_count = -1
        for (i, line) in enumerate(lines):
            sentence, question, answer, label, category = line.split("\t")
            # Consecutive rows sharing a question get the same question id, which the
            # guid encodes and MCTACOEvaluationScheme later groups on.
            if last_question != question:
                question_count += 1
                last_question = question
            examples.append(
                Example(
                    guid="%s-q%s-%s" % (set_type, question_count, i),
                    # Join the context sentence and the question with a separating space.
                    sentence_question=sentence + " " + question,
                    answer=answer,
                    # Test labels are hidden, so fill in a placeholder label.
                    label=label if set_type != "test" else cls.LABELS[-1],
                )
            )
        return examples
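For reference, a sketch of how _create_examples assigns question ids: MCTACO files are tab-separated (sentence, question, answer, label, category), one candidate answer per row, and consecutive rows sharing a question share a question id in the guid. The two sample rows below are invented for illustration:

rows = [
    "He ate dinner.\tHow long did dinner take?\t30 minutes\tyes\tEvent Duration",
    "He ate dinner.\tHow long did dinner take?\t3 months\tno\tEvent Duration",
]
last_question, question_count = "", -1
for i, row in enumerate(rows):
    sentence, question, answer, label, category = row.split("\t")
    if question != last_question:
        question_count += 1
        last_question = question
    print("val-q%s-%s" % (question_count, i), "->", answer, label)
# val-q0-0 -> 30 minutes yes
# val-q0-1 -> 3 months no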
78 changes: 78 additions & 0 deletions jiant/tasks/lib/mctest.py
@@ -0,0 +1,78 @@
from dataclasses import dataclass

from jiant.tasks.lib.templates.shared import labels_to_bimap
from jiant.tasks.lib.templates import multiple_choice as mc_template
from jiant.utils.python.io import read_file_lines


@dataclass
class Example(mc_template.Example):
    @property
    def task(self):
        return MCTestTask


@dataclass
class TokenizedExample(mc_template.TokenizedExample):
    pass


@dataclass
class DataRow(mc_template.DataRow):
    pass


@dataclass
class Batch(mc_template.Batch):
    pass


class MCTestTask(mc_template.AbstractMultipleChoiceTask):
    Example = Example
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    CHOICE_KEYS = ["A", "B", "C", "D"]
    CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS)
    NUM_CHOICES = len(CHOICE_KEYS)

    def get_train_examples(self):
        return self._create_examples(
            lines=read_file_lines(self.train_path, strip_lines=True),
            ans_lines=read_file_lines(self.path_dict["train_ans"], strip_lines=True),
            set_type="train",
        )

    def get_val_examples(self):
        return self._create_examples(
            lines=read_file_lines(self.val_path, strip_lines=True),
            ans_lines=read_file_lines(self.path_dict["val_ans"], strip_lines=True),
            set_type="val",
        )

    def get_test_examples(self):
        return self._create_examples(
            lines=read_file_lines(self.test_path, strip_lines=True),
            ans_lines=None,
            set_type="test",
        )

    @classmethod
    def _create_examples(cls, lines, ans_lines, set_type):
        examples = []
        if ans_lines is None:
            # Test answers are unavailable; fill in a placeholder choice for every question.
            ans_lines = ["\t".join([cls.CHOICE_KEYS[-1]] * 4) for _ in lines]
        for i, (line, ans) in enumerate(zip(lines, ans_lines)):
            line = line.split("\t")
            ans = ans.split("\t")
            # Each line holds one story (column 2) followed by four question blocks of
            # five columns each: the question, then its four answer choices.
            for j in range(4):
                examples.append(
                    Example(
                        guid="%s-%s" % (set_type, i * 4 + j),
                        prompt=line[2].replace("\\newline", " ") + " " + line[3 + j * 5],
                        choice_list=line[4 + j * 5 : 8 + j * 5],
                        label=ans[j],
                    )
                )
        return examples
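The index arithmetic in _create_examples follows the MCTest .tsv layout: columns 0-1 hold metadata, column 2 the story, and each of the four questions occupies a block of five columns (the question, then its four choices). A toy line demonstrating the offsets (field contents invented):

fields = ["mc160.dev.0", "meta"] + ["Once upon a time there was a dog."] + sum(
    [["Question %d?" % j, "choice A", "choice B", "choice C", "choice D"] for j in range(4)], []
)
for j in range(4):
    question = fields[3 + j * 5]
    choices = fields[4 + j * 5 : 8 + j * 5]
    print(question, choices)
# Question 0? ['choice A', 'choice B', 'choice C', 'choice D']
# ... and likewise for questions 1-3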
4 changes: 4 additions & 0 deletions jiant/tasks/retrieval.py
@@ -20,6 +20,8 @@
from jiant.tasks.lib.edge_probing.dpr import DprTask
from jiant.tasks.lib.glue_diagnostics import GlueDiagnosticsTask
from jiant.tasks.lib.hellaswag import HellaSwagTask
from jiant.tasks.lib.mctaco import MCTACOTask
from jiant.tasks.lib.mctest import MCTestTask
from jiant.tasks.lib.mlm_simple import MLMSimpleTask
from jiant.tasks.lib.mlm_premasked import MLMPremaskedTask
from jiant.tasks.lib.mlm_pretokenized import MLMPretokenizedTask
@@ -94,6 +96,8 @@
"dpr": DprTask,
"glue_diagnostics": GlueDiagnosticsTask,
"hellaswag": HellaSwagTask,
"mctaco": MCTACOTask,
"mctest": MCTestTask,
"mlm_simple": MLMSimpleTask,
"mlm_premasked": MLMPremaskedTask,
"mlm_pretokenized": MLMPretokenizedTask,
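With the registry entries in place, the new tasks can be constructed by name from a task config. A hedged sketch: the JSON layout below follows jiant's usual task-config pattern (a task name plus a paths dict mirroring train_path/val_path/test_path), and it assumes the create_task_from_config_path helper in jiant/tasks/retrieval.py; the file paths are placeholders.

import json

from jiant.tasks.retrieval import create_task_from_config_path

config = {
    "task": "mctaco",
    "name": "mctaco",
    "paths": {
        "train": "/path/to/mctaco/train.tsv",
        "val": "/path/to/mctaco/dev.tsv",
        "test": "/path/to/mctaco/test.tsv",
    },
}
with open("/tmp/mctaco_config.json", "w") as f:
    json.dump(config, f)

task = create_task_from_config_path("/tmp/mctaco_config.json", verbose=True)
print(type(task).__name__)  # MCTACOTask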
