CosmosQA (#952)
* misc run scripts

* cosmosqa

* cosmosqa

* cosmosqa

* cosmosqa run

* cleaned up repo

* cleaned up repo

* reformatted
phu-pmh authored and Yada Pruksachatkun committed Nov 9, 2019
1 parent d769338 commit 8af068d
Showing 1 changed file with 98 additions and 0 deletions.
98 changes: 98 additions & 0 deletions jiant/tasks/qa.py
@@ -701,3 +701,101 @@ def get_metrics(self, reset=False):
"""Get metrics specific to the task"""
acc = self.scorer1.get_metric(reset)
return {"accuracy": acc}


@register_task("cosmosqa", rel_path="cosmosqa/")
class CosmosQATask(MultipleChoiceTask):
""" Task class for CosmosQA Task.
adaptation of preprocessing from
https://github.com/wilburOne/cosmosqa """

def __init__(self, path, max_seq_len, name, **kw):
super().__init__(name, **kw)
self.path = path
self.max_seq_len = max_seq_len

self.train_data_text = None
self.val_data_text = None
self.test_data_text = None

self.scorer1 = CategoricalAccuracy()
self.scorers = [self.scorer1]
self.val_metric = "%s_accuracy" % name
self.val_metric_decreases = False
self.n_choices = 4

def load_data(self):
""" Process the dataset located at path. """
self.train_data_text = self._load_csv(os.path.join(self.path, "train.csv"))
self.val_data_text = self._load_csv(os.path.join(self.path, "valid.csv"))
self.test_data_text = self._load_csv(os.path.join(self.path, "test_no_label.csv"))
self.sentences = (
self.train_data_text[0]
+ self.val_data_text[0]
+ [choice for choices in self.train_data_text[1] for choice in choices]
+ [choice for choices in self.val_data_text[1] for choice in choices]
)
log.info("\tFinished loading CosmosQA data.")

def _load_csv(self, input_file):
import csv

with open(input_file, "r") as csv_file:
reader = csv.DictReader(csv_file)
            records = list(reader)

contexts, choices, targs, id_str = [], [], [], []
for record in records:
question = record["question"]

ans_choices = [record["answer" + str(i)] for i in range(self.n_choices)]
qa_tok_choices = [
tokenize_and_truncate(
self._tokenizer_name, question + " " + ans_choices[i], self.max_seq_len
)
for i in range(len(ans_choices))
]
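            # Reserve space for the longest tokenized choice so that the
            # truncated context plus any choice fits within max_seq_len.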
            max_ans_len = max(len(tok) for tok in qa_tok_choices)
context = tokenize_and_truncate(
self._tokenizer_name, record["context"], self.max_seq_len - max_ans_len
)
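            # The test split (test_no_label.csv) carries no gold labels;
            # default to 0 so downstream code still gets an int.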
targ = int(record["label"]) if "label" in record else 0
idx = record["id"]
contexts.append(context)
choices.append(qa_tok_choices)
targs.append(targ)
id_str.append(idx)
return [contexts, choices, targs, id_str]

def process_split(
self, split, indexers, model_preprocessing_interface
) -> Iterable[Type[Instance]]:
""" Process split text into a list of AllenNLP Instances. """

def _make_instance(context, choices, label, id_str):
d = {}
d["context_str"] = MetadataField(" ".join(context))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["context"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(context), indexers
)
for choice_idx, choice in enumerate(choices):
inp = (
model_preprocessing_interface.boundary_token_fn(context, choice)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
d["choice%d" % choice_idx] = sentence_to_text_field(inp, indexers)
d["choice%d_str" % choice_idx] = MetadataField(" ".join(choice))
d["label"] = LabelField(label, label_namespace="labels", skip_indexing=True)
d["id_str"] = MetadataField(id_str)
return Instance(d)

split = list(split)
instances = map(_make_instance, *split)
return instances

def get_metrics(self, reset=False):
"""Get metrics specific to the task"""
acc = self.scorer1.get_metric(reset)
return {"accuracy": acc}

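The length budgeting in `_load_csv` reserves room for the longest tokenized choice so that the truncated context plus any choice stays within `max_seq_len`. A minimal sketch of that arithmetic, with whitespace splitting standing in for `tokenize_and_truncate`:

max_seq_len = 128
choices = ["Because of the good weather .", "Because it was raining ."]
qa_tok_choices = [c.split()[:max_seq_len] for c in choices]  # toy tokenizer
max_ans_len = max(len(tok) for tok in qa_tok_choices)        # 6 tokens here
context_budget = max_seq_len - max_ans_len
print(context_budget)  # 122: the context is truncated to this many tokens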