Abductive NLI (aNLI) (#922)
* anli

* anli fix

* Adding aNLI link, additional test/dev warning
zphang authored Oct 7, 2019
1 parent 65a3d19 commit efd946d
Showing 1 changed file with 114 additions and 0 deletions.
jiant/tasks/tasks.py (114 additions, 0 deletions)
@@ -3023,3 +3023,117 @@ def count_examples(self, splits=["train", "val", "test"]):
        for split in splits:
            st = self.get_split_text(split)
            self.example_counts[split] = len(st)


@register_task("anli", rel_path="aNLI")
class AlphaNLITask(MultipleChoiceTask):
"""
Task class for Abductive Natural Language Inference.
Paper: https://arxiv.org/abs/1908.05739
Website: https://leaderboard.allenai.org/anli/submissions/get-started
"""

def __init__(self, path, max_seq_len, name, **kw):
super().__init__(name, **kw)
self.path = path
self.max_seq_len = max_seq_len

self.train_data_text = None
self.val_data_text = None
self.test_data_text = None

self.scorer1 = CategoricalAccuracy()
self.scorers = [self.scorer1]
self.val_metric = "%s_accuracy" % name
self.val_metric_decreases = False
self.n_choices = 2

def load_data(self):
""" Process the dataset located at path. """

def _load_split(inputs_file, labels_file):
obs1, hyp1, hyp2, obs2 = [], [], [], []
with open(inputs_file, encoding="utf-8") as f:
for line in f:
row = json.loads(line)
obs1.append(
tokenize_and_truncate(self._tokenizer_name, row["obs1"], self.max_seq_len)
)
hyp1.append(
tokenize_and_truncate(self._tokenizer_name, row["hyp1"], self.max_seq_len)
)
hyp2.append(
tokenize_and_truncate(self._tokenizer_name, row["hyp2"], self.max_seq_len)
)
obs2.append(
tokenize_and_truncate(self._tokenizer_name, row["obs2"], self.max_seq_len)
)
with open(labels_file) as f:
labels = [int(i) - 1 for i in f.read().split()] # -1 to get {0, 1} labels
return [obs1, hyp1, hyp2, obs2, labels]

self.train_data_text = _load_split(
inputs_file=os.path.join(self.path, "train.jsonl"),
labels_file=os.path.join(self.path, "train-labels.lst"),
)
self.val_data_text = _load_split(
inputs_file=os.path.join(self.path, "dev.jsonl"),
labels_file=os.path.join(self.path, "dev-labels.lst"),
)

log.warning("aNLI has no public test set, so we reuse the dev set as a stand-in")
self.test_data_text = self.val_data_text
self.sentences = (
self.train_data_text[0]
+ self.train_data_text[1]
+ self.train_data_text[2]
+ self.train_data_text[3]
+ self.val_data_text[0]
+ self.val_data_text[1]
+ self.val_data_text[2]
+ self.val_data_text[3]
)
log.info("\tFinished loading aNLI data.")

def process_split(
self, split, indexers, model_preprocessing_interface
) -> Iterable[Type[Instance]]:
""" Process split text into a list of AllenNLP Instances. """

def _make_instance(obs1, hyp1, hyp2, obs2, label, idx):
d = {}
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
# We're combining obs1 and obs2 in a potentially suboptimal way here
d["question"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(obs1 + obs2), indexers
)
d["question_str"] = MetadataField(" ".join(obs1 + obs2))
for hyp_idx, hyp in enumerate([hyp1, hyp2]):
d["choice%d" % hyp_idx] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(hyp), indexers
)
d["choice%d_str" % hyp_idx] = MetadataField(" ".join(hyp))
else:
for hyp_idx, hyp in enumerate([hyp1, hyp2]):
inp = (
model_preprocessing_interface.boundary_token_fn(obs1 + hyp, obs2)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(hyp)
)
d["choice%d" % hyp_idx] = sentence_to_text_field(inp, indexers)
d["choice%d_str" % hyp_idx] = MetadataField(" ".join(inp))
d["label"] = LabelField(label, label_namespace="labels", skip_indexing=True)
d["idx"] = LabelField(idx, label_namespace="idxs_tags", skip_indexing=True)
return Instance(d)

split = list(split)
if len(split) < 6:
split.append(itertools.count())
instances = map(_make_instance, *split)
return instances

def get_metrics(self, reset=False):
"""Get metrics specific to the task"""
acc = self.scorer1.get_metric(reset)
return {"accuracy": acc}
