Skip to content

Commit

Permalink
Adds Fever NLI task and data downloader (#1215)
Browse files Browse the repository at this point in the history
* Implement task and add downloader for Fever NLI

Co-authored-by: jeswan <57466294+jeswan@users.noreply.github.com>
  • Loading branch information
angie-chen55 and jeswan authored Nov 5, 2020
1 parent 0cc8cbb commit e7eefc6
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 0 deletions.
1 change: 1 addition & 0 deletions guides/tasks/supported_tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
| Cosmos QA | cosmosqa ||| cosmosqa | |
| EP-UD | dep || | dep | Edge-Probing |
| EP-DPR | dpr || | dpr | Edge-Probing |
| Fever NLI | fever_nli ||| fever_nli | |
| GLUE Diagnostic | glue_diagnostics ||| glue_diagnostics | GLUE |
| HellaSwag | hellaswag ||| hellaswag | |
| [MCScript2.0](https://arxiv.org/pdf/1905.09531.pdf) | mcscript || | mcscript | [data](https://my.hidrive.com/share/wdnind8pp5#$/) |
Expand Down
1 change: 1 addition & 0 deletions jiant/scripts/download_data/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
}
OTHER_DOWNLOAD_TASKS = {
"abductive_nli",
"fever_nli",
"swag",
"qamr",
"qasrl",
Expand Down
65 changes: 65 additions & 0 deletions jiant/scripts/download_data/dl_datasets/files_tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging
import os
import pandas as pd
import re
Expand Down Expand Up @@ -28,6 +29,10 @@ def download_task_data_and_write_config(task_name: str, task_data_path: str, tas
download_abductive_nli_data_and_write_config(
task_name=task_name, task_data_path=task_data_path, task_config_path=task_config_path
)
elif task_name == "fever_nli":
download_fever_nli_data_and_write_config(
task_name=task_name, task_data_path=task_data_path, task_config_path=task_config_path
)
elif task_name == "swag":
download_swag_data_and_write_config(
task_name=task_name, task_data_path=task_data_path, task_config_path=task_config_path
Expand Down Expand Up @@ -157,6 +162,66 @@ def download_abductive_nli_data_and_write_config(
)


def download_fever_nli_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    """Fetch the FEVER NLI data, label its dev split, and write the task config.

    Steps, in order:
      1. Download and unzip the FEVER NLI archive into ``task_data_path``
         (the files are subsequently read from its ``nli_fever`` subdirectory).
      2. Download the original FEVER shared-task dev file to a temporary path,
         build an ``id -> label`` lookup from it, then delete the temporary file.
      3. Copy every FEVER NLI dev record whose ``cid`` matches an original ID
         into ``val.jsonl`` with a ``label`` field added; unmatched or
         malformed records are skipped with a warning.
      4. Move the train/test files next to ``val.jsonl`` and delete the
         unzipped ``nli_fever`` directory.
      5. Write a JSON config listing the task name and the three split paths.

    Args:
        task_name: Name stored under both ``"task"`` and ``"name"`` in the config.
        task_data_path: Directory that receives ``train.jsonl``, ``val.jsonl``
            and ``test.jsonl``; created if it does not exist.
        task_config_path: Destination path of the JSON task config.
    """
    os.makedirs(task_data_path, exist_ok=True)
    unzipped_dir = os.path.join(task_data_path, "nli_fever")
    split_paths = {
        "train": os.path.join(task_data_path, "train.jsonl"),
        "val": os.path.join(task_data_path, "val.jsonl"),
        "test": os.path.join(task_data_path, "test.jsonl"),
    }

    download_utils.download_and_unzip(
        "https://www.dropbox.com/s/hylbuaovqwo2zav/nli_fever.zip?dl=1", task_data_path
    )

    # The FEVER NLI dev split ships without labels, so the original FEVER dev
    # set is fetched as well and its labels are joined in via the example ID
    # (called "cid" on the FEVER NLI side).
    fever_dev_tmp = os.path.join(task_data_path, "fever-dev-temp.jsonl")
    download_utils.download_file(
        "https://s3-eu-west-1.amazonaws.com/fever.public/shared_task_dev.jsonl", fever_dev_tmp
    )
    label_by_id = {}
    for record in py_io.read_jsonl(fever_dev_tmp):
        if "id" not in record:
            logging.warning("FEVER dev dataset is missing ID.")
        elif "label" not in record:
            logging.warning("FEVER dev dataset is missing label.")
        else:
            label_by_id[record["id"]] = record["label"]
    os.remove(fever_dev_tmp)

    nli_dev_path = os.path.join(unzipped_dir, "dev_fitems.jsonl")
    labeled_dev = []
    for record in py_io.read_jsonl(nli_dev_path):
        if "cid" not in record:
            logging.warning("Data in {} is missing CID.".format(nli_dev_path))
            continue
        # The lookup is keyed by the original dataset's IDs, so the CID is
        # converted to int before matching.
        cid = int(record["cid"])
        if cid not in label_by_id:
            logging.warning("Could not match CID {} to dev data.".format(record["cid"]))
            continue
        record["label"] = label_by_id[cid]
        labeled_dev.append(record)
    py_io.write_jsonl(labeled_dev, split_paths["val"])
    os.remove(nli_dev_path)

    # Train and test need no relabeling; just move them into place.
    for phase in ("train", "test"):
        os.rename(os.path.join(unzipped_dir, f"{phase}_fitems.jsonl"), split_paths[phase])
    shutil.rmtree(unzipped_dir)

    task_config = {
        "task": task_name,
        "paths": split_paths,
        "name": task_name,
    }
    py_io.write_json(data=task_config, path=task_config_path)


def download_swag_data_and_write_config(task_name: str, task_data_path: str, task_config_path: str):
os.makedirs(task_data_path, exist_ok=True)
download_utils.download_and_unzip(
Expand Down
1 change: 1 addition & 0 deletions jiant/tasks/evaluate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,7 @@ def get_evaluation_scheme_for_task(task) -> BaseEvaluationScheme:
tasks.AcceptabilityDefinitenessTask,
tasks.BoolQTask,
tasks.CopaTask,
tasks.FeverNliTask,
tasks.MnliTask,
tasks.PawsXTask,
tasks.QnliTask,
Expand Down
104 changes: 104 additions & 0 deletions jiant/tasks/lib/fever_nli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import numpy as np
import torch
from dataclasses import dataclass
from typing import List

from jiant.tasks.core import (
BaseExample,
BaseTokenizedExample,
BaseDataRow,
BatchMixin,
Task,
TaskTypes,
)
from jiant.tasks.lib.templates.shared import double_sentence_featurize, labels_to_bimap
from jiant.utils.python.io import read_jsonl


@dataclass
class Example(BaseExample):
    """One raw FEVER NLI example: a premise/hypothesis pair with a string label."""

    guid: str
    premise: str
    hypothesis: str
    label: str

    def tokenize(self, tokenizer):
        """Tokenize both texts and convert the label string to its integer ID.

        Args:
            tokenizer: Object exposing ``tokenize(text)``.

        Returns:
            A ``TokenizedExample`` carrying the same ``guid``.

        Raises:
            KeyError: If ``label`` is not a key of ``FeverNliTask.LABEL_TO_ID``.
        """
        premise_tokens = tokenizer.tokenize(self.premise)
        hypothesis_tokens = tokenizer.tokenize(self.hypothesis)
        label_id = FeverNliTask.LABEL_TO_ID[self.label]
        return TokenizedExample(
            guid=self.guid,
            premise=premise_tokens,
            hypothesis=hypothesis_tokens,
            label_id=label_id,
        )


@dataclass
class TokenizedExample(BaseTokenizedExample):
    """A FEVER NLI example after tokenization, with the label as an integer ID."""

    guid: str
    premise: List
    hypothesis: List
    label_id: int

    def featurize(self, tokenizer, feat_spec):
        """Build a ``DataRow`` via the shared two-sentence featurizer.

        The premise is passed as sentence A and the hypothesis as sentence B.

        Args:
            tokenizer: Forwarded unchanged to ``double_sentence_featurize``.
            feat_spec: Forwarded unchanged to ``double_sentence_featurize``.

        Returns:
            Whatever ``double_sentence_featurize`` returns when given
            ``data_row_class=DataRow``.
        """
        featurize_kwargs = dict(
            guid=self.guid,
            input_tokens_a=self.premise,
            input_tokens_b=self.hypothesis,
            label_id=self.label_id,
            tokenizer=tokenizer,
            feat_spec=feat_spec,
            data_row_class=DataRow,
        )
        return double_sentence_featurize(**featurize_kwargs)


@dataclass
class DataRow(BaseDataRow):
    """Featurized form of one FEVER NLI example.

    This class is handed to ``double_sentence_featurize`` as ``data_row_class``.
    """

    # Identifier carried over from the originating example.
    guid: str
    # Model inputs; presumably equal-length 1-D arrays -- confirm against
    # double_sentence_featurize.
    input_ids: np.ndarray
    input_mask: np.ndarray
    segment_ids: np.ndarray
    # Integer class ID of the example's label.
    label_id: int
    # Token sequence kept alongside the numeric features.
    tokens: list


@dataclass
class Batch(BatchMixin):
    """Batch container for FEVER NLI.

    The fields mirror ``DataRow`` except that ``guid`` is not included.
    """

    # Tensor counterparts of the same-named DataRow fields; presumably stacked
    # along a leading batch dimension -- confirm against BatchMixin.
    input_ids: torch.LongTensor
    input_mask: torch.LongTensor
    segment_ids: torch.LongTensor
    label_id: torch.LongTensor
    # Per-example tokens, kept as a plain Python list.
    tokens: list


class FeverNliTask(Task):
    """Three-way classification task over FEVER NLI premise/hypothesis pairs.

    Each split is read from a JSONL file whose records provide ``"context"``
    (used as the premise), ``"query"`` (used as the hypothesis) and, for
    non-test splits, ``"label"`` (one of ``LABELS``).
    """

    Example = Example
    # Must point at the tokenized class, not the raw Example class.
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    TASK_TYPE = TaskTypes.CLASSIFICATION
    LABELS = ["REFUTES", "SUPPORTS", "NOT ENOUGH INFO"]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)

    def get_train_examples(self):
        """Return the examples read from ``self.train_path``."""
        return self._create_examples(lines=read_jsonl(self.train_path), set_type="train")

    def get_val_examples(self):
        """Return the examples read from ``self.val_path``."""
        return self._create_examples(lines=read_jsonl(self.val_path), set_type="val")

    def get_test_examples(self):
        """Return the examples read from ``self.test_path`` (placeholder labels)."""
        return self._create_examples(lines=read_jsonl(self.test_path), set_type="test")

    @classmethod
    def _create_examples(cls, lines, set_type):
        """Convert parsed JSONL records into ``Example`` objects.

        Args:
            lines: Iterable of dict-like records with ``"context"`` and
                ``"query"`` keys, plus ``"label"`` unless ``set_type`` is
                ``"test"``.
            set_type: Split name; used as the guid prefix and to decide
                whether the record's own label is read.

        Returns:
            A list of ``Example`` with guids of the form ``"<set_type>-<index>"``.

        Raises:
            KeyError: If a record lacks a required key.
        """
        # The record's "label" is never read for the test split; the last
        # entry of LABELS is used as a placeholder so that tokenization can
        # still map it to an ID.
        use_placeholder = set_type == "test"
        placeholder_label = cls.LABELS[-1]
        return [
            Example(
                guid="%s-%s" % (set_type, i),
                premise=line["context"],
                hypothesis=line["query"],
                label=placeholder_label if use_placeholder else line["label"],
            )
            for i, line in enumerate(lines)
        ]
2 changes: 2 additions & 0 deletions jiant/tasks/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from jiant.tasks.lib.cosmosqa import CosmosQATask
from jiant.tasks.lib.edge_probing.dep import DepTask
from jiant.tasks.lib.edge_probing.dpr import DprTask
from jiant.tasks.lib.fever_nli import FeverNliTask
from jiant.tasks.lib.glue_diagnostics import GlueDiagnosticsTask
from jiant.tasks.lib.hellaswag import HellaSwagTask
from jiant.tasks.lib.mctaco import MCTACOTask
Expand Down Expand Up @@ -95,6 +96,7 @@
"cosmosqa": CosmosQATask,
"dep": DepTask,
"dpr": DprTask,
"fever_nli": FeverNliTask,
"glue_diagnostics": GlueDiagnosticsTask,
"hellaswag": HellaSwagTask,
"mctaco": MCTACOTask,
Expand Down

0 comments on commit e7eefc6

Please sign in to comment.