diff --git a/jiant/tasks/tasks.py b/jiant/tasks/tasks.py
index a4337df81..393065e6c 100644
--- a/jiant/tasks/tasks.py
+++ b/jiant/tasks/tasks.py
@@ -1097,6 +1097,101 @@ def load_data(self):
         log.info("\tFinished loading SNLI data.")
 
 
+@register_task("adversarial_nli_a1", rel_path="AdversarialNLI/", datasets=["R1"])
+@register_task("adversarial_nli_a2", rel_path="AdversarialNLI/", datasets=["R2"])
+@register_task("adversarial_nli_a3", rel_path="AdversarialNLI/", datasets=["R3"])
+@register_task("adversarial_nli", rel_path="AdversarialNLI/", datasets=["R1", "R2", "R3"])
+class AdversarialNLITask(PairClassificationTask):
+    """Task class for use with the Adversarial Natural Language Inference (ANLI) dataset.
+
+    Configures a 3-class PairClassificationTask using Adversarial NLI data.
+    Requires the original ANLI dataset file structure under the relative path.
+    Data: https://dl.fbaipublicfiles.com/anli/anli_v0.1.zip
+    Paper: https://arxiv.org/abs/1910.14599
+
+    Attributes:
+        path (str): AdversarialNLI path relative to JIANT_DATA_DIR
+        max_seq_len (int): max tokens allowed in a sequence
+        train_data_text (list[list[str], list[str], list[int]]):
+            list of lists of context, hypothesis, and target training data
+        val_data_text (list[list[str], list[str], list[int]]):
+            list of lists of context, hypothesis, and target val data
+        test_data_text (list[list[str], list[str], list[int]]):
+            list of lists of context, hypothesis, and target test data
+        datasets (list[str]): list of sub-datasets used in the task (e.g., R1)
+        sentences (list): list of all (tokenized) context and hypothesis
+            texts from the train and val data.
+    """
+
+    def __init__(self, path, max_seq_len, name, datasets, **kw):
+        """Initialize an AdversarialNLITask task.
+
+        Args:
+            path (str): AdversarialNLI path relative to the data dir
+            max_seq_len (int): max tokens allowed in a sequence
+            name (str): task name, specified in @register_task
+            datasets (list[str]): list of ANLI sub-datasets used in the task
+        """
+        super(AdversarialNLITask, self).__init__(name, n_classes=3, **kw)
+        self.path = path
+        self.max_seq_len = max_seq_len
+        self.train_data_text = None
+        self.val_data_text = None
+        self.test_data_text = None
+        self.datasets = datasets
+
+    def _read_data(self, path: str) -> pd.DataFrame:
+        """Read jsonl, tokenize text, encode labels as ints, and return a dataframe."""
+        df = pd.read_json(path_or_buf=path, encoding="UTF-8", lines=True)
+        # For the ANLI datasets: n=neutral, e=entailment, c=contradiction.
+        df["target"] = df["label"].map({"n": 0, "e": 1, "c": 2})
+        tokenizer = get_tokenizer(self._tokenizer_name)
+        df["context"] = df["context"].apply(tokenizer.tokenize)
+        df["hypothesis"] = df["hypothesis"].apply(tokenizer.tokenize)
+        return df[["context", "hypothesis", "target"]]
+
+    def load_data(self):
+        """Read, preprocess, and load data into an AdversarialNLITask.
+
+        Assumes the original dataset file structure under `self.path`.
+        Loads only the sub-datasets (e.g., "R1") specified in the `datasets` attr.
+        Populates the task's train_, val_, and test_data_text and `sentences` attrs.
+        """
+        train_dfs, val_dfs, test_dfs = [], [], []
+        for dataset in self.datasets:
+            train_dfs.append(self._read_data(os.path.join(self.path, dataset, "train.jsonl")))
+            val_dfs.append(self._read_data(os.path.join(self.path, dataset, "dev.jsonl")))
+            test_dfs.append(self._read_data(os.path.join(self.path, dataset, "test.jsonl")))
+        train_df = pd.concat(train_dfs, axis=0, ignore_index=True)
+        val_df = pd.concat(val_dfs, axis=0, ignore_index=True)
+        test_df = pd.concat(test_dfs, axis=0, ignore_index=True)
+
+        self.train_data_text = [
+            train_df["context"].tolist(),
+            train_df["hypothesis"].tolist(),
+            train_df["target"].tolist(),
+        ]
+        self.val_data_text = [
+            val_df["context"].tolist(),
+            val_df["hypothesis"].tolist(),
+            val_df["target"].tolist(),
+        ]
+        self.test_data_text = [
+            test_df["context"].tolist(),
+            test_df["hypothesis"].tolist(),
+            test_df["target"].tolist(),
+        ]
+
+        self.sentences = (
+            train_df["context"].tolist()
+            + train_df["hypothesis"].tolist()
+            + val_df["context"].tolist()
+            + val_df["hypothesis"].tolist()
+        )
+
+        log.info("\tFinished loading ANLI data: " + self.name)
+
+
 @register_task("mnli", rel_path="MNLI/")
 # second copy for different params
 @register_task("mnli-alt", rel_path="MNLI/")