Skip to content

Commit

Permalink
CommonsenseQA+hellaswag (#942)
Browse files Browse the repository at this point in the history
* add commonsenseqa task

* add hellaswag task

* dabug

* from #928

* add special tokens to CommensenseQA input

* format

* revert irrelevant change

* Typo fix

* delete

* rename stuff

* Update qa.py

* black
  • Loading branch information
HaokunLiu authored and Yada Pruksachatkun committed Oct 26, 2019
1 parent 7390fae commit 0c57c96
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 25 deletions.
1 change: 0 additions & 1 deletion jiant/modules/attn_pair_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def __init__(
initializer(self)

def forward(self, s1, s2, s1_mask, s2_mask): # pylint: disable=arguments-differ
""" """
# Similarity matrix
# Shape: (batch_size, s2_length, s1_length)
similarity_mat = self._matrix_attention(s2, s1)
Expand Down
105 changes: 99 additions & 6 deletions jiant/tasks/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
from typing import Iterable, Sequence, Type

import torch
from allennlp.training.metrics import Average, F1Measure
import logging as log
from allennlp.training.metrics import Average, F1Measure, CategoricalAccuracy
from allennlp.data.fields import LabelField, MetadataField
from allennlp.data import Instance
from jiant.allennlp_mods.numeric_field import NumericField
from jiant.allennlp_mods.span_metrics import SpanF1Measure

from jiant.utils.data_loaders import tokenize_and_truncate

from jiant.tasks.tasks import Task, SpanPredictionTask
from jiant.tasks.tasks import Task, SpanPredictionTask, MultipleChoiceTask
from jiant.tasks.tasks import sentence_to_text_field
from jiant.tasks.registry import register_task
from ..utils.retokenize import get_aligner_fn
Expand Down Expand Up @@ -80,7 +81,6 @@ class MultiRCTask(Task):
See paper at https://cogcomp.org/multirc/ """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super().__init__(name, **kw)
self.scorer1 = F1Measure(positive_label=1)
self.scorer2 = Average() # to delete
Expand Down Expand Up @@ -247,7 +247,6 @@ class ReCoRDTask(Task):
See paper at https://sheng-z.github.io/ReCoRD-explorer """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super().__init__(name, **kw)
self.val_metric = "%s_avg" % self.name
self.val_metric_decreases = False
Expand Down Expand Up @@ -500,8 +499,6 @@ def get_sentences(self) -> Iterable[Sequence[str]]:
def process_split(
self, split, indexers, model_preprocessing_interface
) -> Iterable[Type[Instance]]:
is_using_pytorch_transformers = "pytorch_transformers_wpm_pretokenized" in indexers

def _make_instance(sentence_tokens, question_tokens, answer_span, idx):
d = dict()

Expand Down Expand Up @@ -608,3 +605,99 @@ def preprocess_qasrl_datum(cls, datum):
for verb_idx, verb_entry in datum["verbEntries"].items()
],
}


@register_task("commonsenseqa", rel_path="CommonsenseQA/")
@register_task("commonsenseqa-easy", rel_path="CommonsenseQA/", easy=True)
class CommonsenseQATask(MultipleChoiceTask):
""" Task class for CommonsenseQA Task. """

def __init__(self, path, max_seq_len, name, easy=False, **kw):
super().__init__(name, **kw)
self.path = path
self.max_seq_len = max_seq_len

self.easy = easy
self.train_data_text = None
self.val_data_text = None
self.test_data_text = None

self.scorer1 = CategoricalAccuracy()
self.scorers = [self.scorer1]
self.val_metric = "%s_accuracy" % name
self.val_metric_decreases = False
self.n_choices = 5
self.label2choice_idx = {"A": 0, "B": 1, "C": 2, "D": 3, "E": 4}
self.choice_idx2label = ["A", "B", "C", "D", "E"]

def load_data(self):
""" Process the dataset located at path. """

def _load_split(data_file):
questions, choices, targs, id_str = [], [], [], []
data = [json.loads(l) for l in open(data_file, encoding="utf-8")]
for example in data:
question = tokenize_and_truncate(
self._tokenizer_name, "Q:" + example["question"]["stem"], self.max_seq_len
)
choices_dict = {
a_choice["label"]: tokenize_and_truncate(
self._tokenizer_name, "A:" + a_choice["text"], self.max_seq_len
)
for a_choice in example["question"]["choices"]
}
multiple_choices = [choices_dict[label] for label in self.choice_idx2label]
targ = self.label2choice_idx[example["answerKey"]] if "answerKey" in example else 0
id_str = example["id"]
questions.append(question)
choices.append(multiple_choices)
targs.append(targ)
id_str.append(id_str)
return [questions, choices, targs, id_str]

train_file = "train_rand_split_EASY.jsonl" if self.easy else "train_rand_split.jsonl"
val_file = "dev_rand_split_EASY.jsonl" if self.easy else "dev_rand_split.jsonl"
test_file = "test_rand_split_no_answers.jsonl"
self.train_data_text = _load_split(os.path.join(self.path, train_file))
self.val_data_text = _load_split(os.path.join(self.path, val_file))
self.test_data_text = _load_split(os.path.join(self.path, test_file))
self.sentences = (
self.train_data_text[0]
+ self.val_data_text[0]
+ [choice for choices in self.train_data_text[1] for choice in choices]
+ [choice for choices in self.val_data_text[1] for choice in choices]
)
log.info("\tFinished loading CommonsenseQA data.")

def process_split(
self, split, indexers, model_preprocessing_interface
) -> Iterable[Type[Instance]]:
""" Process split text into a list of AllenNLP Instances. """

def _make_instance(question, choices, label, id_str):
d = {}
d["question_str"] = MetadataField(" ".join(question))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(question), indexers
)
for choice_idx, choice in enumerate(choices):
inp = (
model_preprocessing_interface.boundary_token_fn(question, choice)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
d["choice%d" % choice_idx] = sentence_to_text_field(inp, indexers)
d["choice%d_str" % choice_idx] = MetadataField(" ".join(choice))
d["label"] = LabelField(label, label_namespace="labels", skip_indexing=True)
d["id_str"] = MetadataField(id_str)
return Instance(d)

split = list(split)
instances = map(_make_instance, *split)
return instances

def get_metrics(self, reset=False):
"""Get metrics specific to the task"""
acc = self.scorer1.get_metric(reset)
return {"accuracy": acc}
7 changes: 0 additions & 7 deletions jiant/tasks/senteval_probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class SEProbingSentenceLengthTask(SingleClassificationTask):
""" Sentence length task """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super(SEProbingSentenceLengthTask, self).__init__(name, n_classes=7, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -66,7 +65,6 @@ class SEProbingBigramShiftTask(SingleClassificationTask):
""" Bigram shift task """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super(SEProbingBigramShiftTask, self).__init__(name, n_classes=2, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -117,7 +115,6 @@ class SEProbingPastPresentTask(SingleClassificationTask):
""" Past Present Task """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super(SEProbingPastPresentTask, self).__init__(name, n_classes=2, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -168,7 +165,6 @@ class SEProbingOddManOutTask(SingleClassificationTask):
""" Odd man out task """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super(SEProbingOddManOutTask, self).__init__(name, n_classes=2, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -219,7 +215,6 @@ class SEProbingCoordinationInversionTask(SingleClassificationTask):
""" Coordination Inversion task. """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super(SEProbingCoordinationInversionTask, self).__init__(name, n_classes=2, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -311,7 +306,6 @@ class SEProbingTreeDepthTask(SingleClassificationTask):
""" Tree Depth Task """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super(SEProbingTreeDepthTask, self).__init__(name, n_classes=8, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -354,7 +348,6 @@ class SEProbingTopConstituentsTask(SingleClassificationTask):
""" Top Constituents task """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super(SEProbingTopConstituentsTask, self).__init__(name, n_classes=20, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down
1 change: 0 additions & 1 deletion jiant/tasks/seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class Seq2SeqTask(SequenceGenerationTask):
"""Sequence-to-sequence Task"""

def __init__(self, path, max_seq_len, max_targ_v_size, name, **kw):
""" """
super().__init__(name, **kw)
self.scorer2 = BooleanAccuracy()
self.scorers.append(self.scorer2)
Expand Down
103 changes: 93 additions & 10 deletions jiant/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,6 @@ class SSTTask(SingleClassificationTask):
""" Task class for Stanford Sentiment Treebank. """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super(SSTTask, self).__init__(name, n_classes=2, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -549,7 +548,6 @@ class CoLANPITask(SingleClassificationTask):
Note: Used for an NYU seminar, data not yet public"""

def __init__(self, path, max_seq_len, name, **kw):
""" """
super(CoLANPITask, self).__init__(name, n_classes=2, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -610,7 +608,6 @@ class CoLATask(SingleClassificationTask):
"""Class for Warstdadt acceptability task"""

def __init__(self, path, max_seq_len, name, **kw):
""" """
super(CoLATask, self).__init__(name, n_classes=2, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -964,7 +961,6 @@ class STSBTask(PairRegressionTask):
""" Task class for Sentence Textual Similarity Benchmark. """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super(STSBTask, self).__init__(name, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -1724,7 +1720,6 @@ class RTETask(PairClassificationTask):
""" Task class for Recognizing Textual Entailment 1, 2, 3, 5 """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super().__init__(name, n_classes=2, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -2771,7 +2766,7 @@ def _load_split(data_file):
def process_split(
self, split, indexers, model_preprocessing_interface
) -> Iterable[Type[Instance]]:
""" Process split text into a list of AlleNNLP Instances. """
""" Process split text into a list of AllenNLP Instances. """

def _make_instance(context, choices, question, label, idx):
d = {}
Expand Down Expand Up @@ -2815,7 +2810,6 @@ class COPATask(MultipleChoiceTask):
""" Task class for Choice of Plausible Alternatives Task. """

def __init__(self, path, max_seq_len, name, **kw):
""" """
super().__init__(name, **kw)
self.path = path
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -2875,7 +2869,7 @@ def _load_split(data_file):
def process_split(
self, split, indexers, model_preprocessing_interface
) -> Iterable[Type[Instance]]:
""" Process split text into a list of AlleNNLP Instances. """
""" Process split text into a list of AllenNLP Instances. """

def _make_instance(context, choices, question, label, idx):
d = {}
Expand Down Expand Up @@ -2961,7 +2955,7 @@ def _load_split(data_file):
def process_split(
self, split, indexers, model_preprocessing_interface
) -> Iterable[Type[Instance]]:
""" Process split text into a list of AlleNNLP Instances. """
""" Process split text into a list of AllenNLP Instances. """

def _make_instance(question, choices, label, idx):
d = {}
Expand Down Expand Up @@ -2994,6 +2988,95 @@ def get_metrics(self, reset=False):
return {"accuracy": acc}


@register_task("hellaswag", rel_path="HellaSwag/")
class HellaSwagTask(MultipleChoiceTask):
""" Task class for HellaSwag. """

def __init__(self, path, max_seq_len, name, **kw):
super().__init__(name, **kw)
self.path = path
self.max_seq_len = max_seq_len

self.train_data_text = None
self.val_data_text = None
self.test_data_text = None

self.scorer1 = CategoricalAccuracy()
self.scorers = [self.scorer1]
self.val_metric = "%s_accuracy" % name
self.val_metric_decreases = False
self.n_choices = 4

def load_data(self):
""" Process the dataset located at path. """

def _load_split(data_file):
questions, choicess, targs, idxs = [], [], [], []
data = [json.loads(l) for l in open(data_file, encoding="utf-8")]
for example in data:
sent1 = tokenize_and_truncate(
self._tokenizer_name, example["ctx_a"], self.max_seq_len
)
questions.append(sent1)
sent2_prefix = example["ctx_b"]
choices = [
tokenize_and_truncate(
self._tokenizer_name, sent2_prefix + " " + ending, self.max_seq_len
)
for ending in example["endings"]
]
choicess.append(choices)
targ = example["label"] if "label" in example else 0
idx = example["ind"]
targs.append(targ)
idxs.append(idx)
return [questions, choicess, targs, idxs]

self.train_data_text = _load_split(os.path.join(self.path, "hellaswag_train.jsonl"))
self.val_data_text = _load_split(os.path.join(self.path, "hellaswag_val.jsonl"))
self.test_data_text = _load_split(os.path.join(self.path, "hellaswag_test.jsonl"))
self.sentences = (
self.train_data_text[0]
+ self.val_data_text[0]
+ [choice for choices in self.train_data_text[1] for choice in choices]
+ [choice for choices in self.val_data_text[1] for choice in choices]
)
log.info("\tFinished loading HellaSwag data.")

def process_split(
self, split, indexers, model_preprocessing_interface
) -> Iterable[Type[Instance]]:
""" Process split text into a list of AllenNLP Instances. """

def _make_instance(question, choices, label, idx):
d = {}
d["question_str"] = MetadataField(" ".join(question))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(question), indexers
)
for choice_idx, choice in enumerate(choices):
inp = (
model_preprocessing_interface.boundary_token_fn(question, choice)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
d["choice%d" % choice_idx] = sentence_to_text_field(inp, indexers)
d["choice%d_str" % choice_idx] = MetadataField(" ".join(choice))
d["label"] = LabelField(label, label_namespace="labels", skip_indexing=True)
d["idx"] = LabelField(idx, label_namespace="idxs_tags", skip_indexing=True)
return Instance(d)

split = list(split)
instances = map(_make_instance, *split)
return instances

def get_metrics(self, reset=False):
"""Get metrics specific to the task"""
acc = self.scorer1.get_metric(reset)
return {"accuracy": acc}


@register_task("winograd-coreference", rel_path="WSC")
class WinogradCoreferenceTask(SpanClassificationTask):
def __init__(self, path, **kw):
Expand Down Expand Up @@ -3083,7 +3166,7 @@ def _load_jsonl(data_file):
def process_split(
self, split, indexers, model_preprocessing_interface
) -> Iterable[Type[Instance]]:
""" Process split text into a list of AlleNNLP Instances. """
""" Process split text into a list of AllenNLP Instances. """

def _make_instance(d, idx):
new_d = {}
Expand Down

0 comments on commit 0c57c96

Please sign in to comment.