diff --git a/jiant/models.py b/jiant/models.py
index 60417ff22..fa54a4140 100644
--- a/jiant/models.py
+++ b/jiant/models.py
@@ -31,6 +31,7 @@
     PairClassifier,
     NullPhraseLayer,
     TokenMultiProjectionEncoder,
+    SOPClassifier,
 )
 from jiant.modules.attn_pair_encoder import AttnPairEncoder
 from jiant.modules.sentence_encoder import SentenceEncoder
@@ -65,6 +66,7 @@
     WiCTask,
     MRPCTask,
     QQPTask,
+    SentenceOrderTask,
 )
 from jiant.utils import config
 from jiant.utils.utils import (
@@ -687,6 +689,34 @@ def build_single_sentence_module(task, d_inp: int, project_before_pooling: bool,
     return module
 
 
+def build_sop(task, d_inp, model, params):
+    """
+    Build and load the pretrained head for the sentence order prediction (SOP) task.
+    Currently, only ALBERT is supported.
+
+    Parameters
+    ----------
+    task: Task,
+    d_inp: int,
+    model: MultiTaskModel,
+    params: Params
+
+    Returns
+    -------
+    module: SOPClassifier, loaded with the pretrained weights from ALBERT's SOP pretraining.
+
+    """
+    input_module = model.sent_encoder._text_field_embedder.input_module
+    assert (
+        "albert" in input_module
+    ), "SOP is only supported for ALBERT; please set input_module to an ALBERT model"
+    module = SOPClassifier(d_inp, task.n_classes, params)
+    # The Huggingface implementation exposes the pretrained projection layer for the SOP task,
+    # which we use here. See https://github.com/huggingface/transformers/issues/2671 for details.
+    module.pooler.project = model.sent_encoder._text_field_embedder.model.pooler
+    return module
+
+
 def build_pair_sentence_module(task, d_inp, model, params):
     """ Build a pair classifier, shared if necessary """
@@ -745,6 +775,8 @@ def build_pair_attn(d_in, d_hid_attn):
     d_out = d_out + d_inp if isinstance(task, WiCTask) else d_out
     classifier = Classifier.from_params(4 * d_out, n_classes, params)
     module = PairClassifier(pooler, classifier, pair_attn)
+    if isinstance(task, SentenceOrderTask):
+        module = build_sop(task, d_inp, model, params)
     return module
 
 
@@ -899,6 +931,8 @@ def forward(self, task, batch, predict=False):
             out = self._span_forward(batch, task, predict)
         elif isinstance(task, SpanPredictionTask):
            out = self._span_prediction_forward(batch, task, predict)
+        elif isinstance(task, SentenceOrderTask):
+            out = self._sop_forward(batch, task, predict)
         else:
             raise ValueError("Task-specific components not found!")
         return out
diff --git a/jiant/modules/simple_modules.py b/jiant/modules/simple_modules.py
index 42166af33..eba58ce4d 100644
--- a/jiant/modules/simple_modules.py
+++ b/jiant/modules/simple_modules.py
@@ -94,6 +94,30 @@ def from_params(cls, d_inp, n_classes, params):
         )
 
 
+class SOPClassifier(nn.Module):
+    """
+    Task head for the sentence order prediction task. We reproduce ALBERT's pooled output
+    with a linear layer followed by a Tanh activation, and feed the result into the final
+    classification linear layer.
+    """
+
+    def __init__(self, d_inp, n_classes, params):
+        super(SOPClassifier, self).__init__()
+        self.activation = nn.Tanh()
+        self.pooler = Pooler(d_inp=d_inp, d_proj=d_inp, pool_type=params["pool_type"])
+        assert params["cls_type"] == "log_reg", (
+            "The ALBERT implementation of the SOP task "
+            "takes the final layer from the pooled "
+            "output. Please set cls_type = log_reg."
+        )
+        self.classifier = Classifier.from_params(d_inp, n_classes, params)
+
+    def forward(self, seq_emb, mask):
+        seq_emb = self.activation(self.pooler(seq_emb, mask))
+        logits = self.classifier(seq_emb)
+        return logits
+
+
 class SingleClassifier(nn.Module):
     """ Thin wrapper around a set of modules.
     For single-sentence classification. """
diff --git a/jiant/tasks/tasks.py b/jiant/tasks/tasks.py
index c4203974a..d7a74e125 100644
--- a/jiant/tasks/tasks.py
+++ b/jiant/tasks/tasks.py
@@ -4,6 +4,7 @@
 import logging as log
 import os
 from typing import Any, Dict, Iterable, List, Sequence, Type, Union, Generator
+import random
 
 import numpy as np
 import pandas as pd
@@ -3768,3 +3769,175 @@ def get_metrics(self, reset=False):
         """Get metrics specific to the task"""
         acc = self.scorer1.get_metric(reset)
         return {"accuracy": acc}
+
+
+@register_task("wikipedia_corpus_sop", rel_path="wikipedia_sop_small")
+class SentenceOrderTask(PairClassificationTask):
+    """ Task class for Sentence Order Prediction (SOP). See the ALBERT paper for details on SOP:
+    https://arxiv.org/abs/1909.11942.
+    We currently use a preprocessed version of the Wikipedia corpus (specifically, the
+    2020-03-01 Wikipedia dump) consisting of roughly 5% of the full data. You can generate
+    the data by following the instructions in scripts/sop.
+    Note that our ALBERT SOP implementation does not load pretrained weights for the SOP head,
+    because they are unavailable in Huggingface. We only use the pretrained weights of the
+    ALBERT linear layer that produces the pooled output used by SOP.
+    """
+
+    def __init__(self, path, max_seq_len, name, **kw):
+        super(SentenceOrderTask, self).__init__(name, n_classes=2, **kw)
+        self.path = path
+        self.max_seq_len = max_seq_len
+        self.train_data_text = None
+        self.val_data_text = None
+        self.test_data_text = None
+        self.files_by_split = {
+            "train": os.path.join(path, "train.txt"),
+            "val": os.path.join(path, "valid.txt"),
+            "test": os.path.join(path, "test.txt"),
+        }
+        self._label_namespace = self.name + "_labels"
+
+    def get_target_seq_length(self):
+        target_is_max = random.random() > 0.1
+        max_seq_len = self.max_seq_len - 3  # exclude [CLS], [SEP], and [SEP]
+        if target_is_max:
+            target_seq_length = max_seq_len
+        else:
+            target_seq_length = random.randint(2, max_seq_len)
+        return target_seq_length
+
+    def get_data_iter(self, path):
+        """Load the data file and tokenize the text. We override this function, and every
+        function that calls it, because reading in the data for SOP differs from other
+        PairClassificationTasks.
+
+        ALBERT constructs SOP examples from each document as follows. For each example, we
+        first fetch as many sentences as possible that cumulatively contain target_seq_length
+        tokens:
+        - 90% of the time, target_seq_length equals max_seq_length; 10% of the time, it is
+          set to a random number of tokens between 2 and max_seq_length.
+        - Given the sampled sentences, we randomly pick N such that the first N sentences of
+          the sampled chunk go to the first segment and the rest go to the second.
+        - 50% of the time, the first and second segments are swapped.
+
+        Args:
+            path: (str) data file path
+        """
+
+        def _tokenize(tokenizer_name, sent):
+            tokenizer = get_tokenizer(tokenizer_name)
+            return tokenizer.tokenize(sent)
+
+        def is_end_document(seg):
+            tokenized_eod = _tokenize(self._tokenizer_name, "END OF ARTICLE")
+            return set(tokenized_eod).issubset(set(seg))
+
+        # The code below is adapted from the original ALBERT code. See:
+        # https://github.com/google-research/albert/blob/master/create_pretraining_data.py#L267.
+        f = open(path, "r")
+        # The dataset comes with one sentence per line, thus we split by
+        # line here.
+        current_chunk = [_tokenize(self._tokenizer_name, next(f))]
+        current_length = len(current_chunk[0])
+        target_seq_length = self.get_target_seq_length()
+        while len(current_chunk) > 0:
+            segment = next(f)
+            segment = _tokenize(self._tokenizer_name, segment)
+            if is_end_document(segment) or current_length >= target_seq_length:
+                for_next_chunk = []
+                if current_length > target_seq_length:
+                    # The most recently added sentence pushes the chunk past the target
+                    # length, so we save it for the next chunk (next example).
+                    for_next_chunk.append(current_chunk.pop())
+                if not is_end_document(segment):
+                    for_next_chunk.append(segment)
+                target_seq_length = self.get_target_seq_length()
+                if len(current_chunk) >= 2:
+                    # Make sure we have at least 2 sentences to distribute between the two
+                    # segments.
+                    a_end = random.randint(1, len(current_chunk) - 1)
+                    tokens_a = []
+                    for j in range(a_end):
+                        tokens_a.extend(current_chunk[j])
+                    tokens_b = []
+                    for j in range(a_end, len(current_chunk)):
+                        tokens_b.extend(current_chunk[j])
+                    in_order = random.random() < 0.5
+                    if in_order:
+                        yield (tokens_a, tokens_b, in_order)
+                    else:
+                        yield (tokens_b, tokens_a, in_order)
+                # If len(current_chunk) >= 2, we yield and reinitialize;
+                # if len(current_chunk) == 1, we do not yield, and reinitialize.
+                if len(for_next_chunk) > 0 and not is_end_document(segment):
+                    # Make sure the sentences sampled for one example all
+                    # belong to the same document.
+                    current_chunk = for_next_chunk
+                    current_length = sum([len(chunk) for chunk in for_next_chunk])
+                else:
+                    # Find the first sentence of the next example.
+                    try:  # next(f) raises StopIteration at the end of the file.
+                        current_chunk = [_tokenize(self._tokenizer_name, next(f))]
+                        current_length = len(current_chunk[0])
+                    except StopIteration:
+                        print("Done loading data for SOP")
+                        current_chunk = []
+                        current_length = 0
+            else:
+                current_chunk.append(segment)
+                current_length += len(segment)
+
+    def load_data(self):
+        pass
+
+    def process_split(
+        self, split, indexers, model_preprocessing_interface
+    ) -> Iterable[Type[Instance]]:
+        """Process a sentence order prediction split by indexing and creating fields.
+        We override PairClassificationTask.process_split because, given the dataset size,
+        SOP loads its data with a memory-efficient generator rather than the in-memory
+        splits used by typical PairClassificationTasks.
+        Args:
+            split: (list) a single list of sentences
+            indexers: (Indexer object) indexer to index input words
+        """
+
+        def _make_instance(sent_pairs_):
+            sent_a, sent_b, is_right_order = sent_pairs_
+            inp = model_preprocessing_interface.boundary_token_fn(sent_a, sent_b)
+            input_sent = sentence_to_text_field(inp, indexers)
+            label = LabelField(is_right_order, label_namespace="labels", skip_indexing=True)
+            d = {"inputs": input_sent, "labels": label}
+            return Instance(d)
+
+        for sent_pairs in split:
+            yield _make_instance(sent_pairs)
+
+    def get_split_text(self, split: str):
+        """Get split text as an iterable of records.
+        Args:
+            split: (str) should be one of 'train', 'val', or 'test'.
+        """
+        return self.get_data_iter(self.files_by_split[split])
+
+    def get_sentences(self) -> Iterable[Sequence[str]]:
+        """Yield sentences, used to compute vocabulary.
+        """
+        for split in self.files_by_split:
+            # Don't use the test set for vocab building.
+            if split.startswith("test"):
+                continue
+            for sent in self.get_data_iter(self.files_by_split[split]):
+                # Counting only sent[0] is enough for computing the vocab.
+                yield sent[0]
+
+    def count_examples(self):
+        """Compute the number of examples in each split by iterating over the data generator.
+        """
+        example_counts = {}
+        for split, split_path in self.files_by_split.items():
+            example_counts[split] = sum(1 for _ in self.get_data_iter(split_path))
+        self.example_counts = example_counts
diff --git a/scripts/sop/README.md b/scripts/sop/README.md
new file mode 100644
index 000000000..482736e27
--- /dev/null
+++ b/scripts/sop/README.md
@@ -0,0 +1,33 @@
+# Downloading the Wikipedia Corpus
+
+We use the preprocessing code from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT#getting-the-data,
+and the bash scripts provided here streamline the data generation steps in the NVIDIA repository.
+
+First, `git clone https://github.com/NVIDIA/DeepLearningExamples.git`.
+Then, move `scripts/sop/create_wiki_sop_data.sh` and `scripts/sop/get_small_english_wiki.sh` into `DeepLearningExamples/PyTorch/LanguageModeling/BERT/data`.
+
+Then, follow the instructions below.
+
+The NVIDIA script downloads the latest Wikipedia dump; we use the 2020-03-01 dump instead.
+To download the 2020-03-01 dump, replace line 29 of `DeepLearningExamples/PyTorch/LanguageModeling/BERT/data/WikiDownloader.py`:
+`'en' : 'https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2',` with `'en' : 'https://dumps.wikimedia.org/enwiki/20200301/enwiki-20200301-pages-articles.xml.bz2',`.
+
+The data creation for SOP is almost the same as for MLM, except for the following edits.
+In `DeepLearningExamples/PyTorch/LanguageModeling/BERT/data/TextSharding.py`, replace line 55:
+`self.articles[global_article_count] = line.rstrip()` with `self.articles[global_article_count] = line.rstrip() + "\n ========THIS IS THE END OF ARTICLE.========"`.
+This is because SOP requires a signal for the end of each Wikipedia article.
+
+Additionally, replace `'/workspace/wikiextractor/WikiExtractor.py'` in line 80 of
+`DeepLearningExamples/PyTorch/LanguageModeling/BERT/data/bertPrep.py` with `'wikiextractor/WikiExtractor.py'`.
+
+Run `bash create_wiki_sop_data.sh $lang $save_directory`.
+The NVIDIA code supports English (en) and Chinese (zh) Wikipedia.
+
+For example, to download and process English Wikipedia and save it in the `~/Download` directory, run
+`bash create_wiki_sop_data.sh en ~/Download`
+
+The above command downloads and processes the entire English Wikipedia.
+
+In our experiments, we only use a small subset (around 5%) of the entire English Wikipedia, which has the same number of sentences as WikiText-103.
+To get this subset, run `bash get_small_english_wiki.sh $path_to_wikicorpus_en`, where `$path_to_wikicorpus_en` is the directory where you saved the full processed `wikicorpus_en` corpus.
diff --git a/scripts/sop/create_wiki_sop_data.sh b/scripts/sop/create_wiki_sop_data.sh
new file mode 100755
index 000000000..0f0722e24
--- /dev/null
+++ b/scripts/sop/create_wiki_sop_data.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+lang=$1  # the language, e.g. 'en' for English Wikipedia
+export BERT_PREP_WORKING_DIR=$2
+
+# Clone wikiextractor if it doesn't exist
+if [ ! -d "wikiextractor" ]; then
+    git clone https://github.com/attardi/wikiextractor.git
+fi
+
+echo "Downloading $lang Wikipedia into directory $BERT_PREP_WORKING_DIR"
+# Download
+python3 bertPrep.py --action download --dataset wikicorpus_$lang
+
+# Properly format the text files
+python3 bertPrep.py --action text_formatting --dataset wikicorpus_$lang
+
+# Shard the text files
+python3 bertPrep.py --action sharding --dataset wikicorpus_$lang
+
+# Combine the sharded files into one file per split
+save_dir=$BERT_PREP_WORKING_DIR/sharded_training_shards_256_test_shards_256_fraction_0.2/wikicorpus_$lang
+cat $save_dir/*training*.txt > $save_dir/train_$lang.txt
+cat $save_dir/*test*.txt > $save_dir/test_$lang.txt
+rm -rf $save_dir/wiki*training*.txt
+rm -rf $save_dir/wiki*test*.txt
+
+# Remove some remaining XML tags
+sed -i 's/<[^>]*>//g' $save_dir/train_$lang.txt
+sed -i 's/<[^>]*>//g' $save_dir/test_$lang.txt
+
+echo "Your corpus is saved in $save_dir"
diff --git a/scripts/sop/get_small_english_wiki.sh b/scripts/sop/get_small_english_wiki.sh
new file mode 100644
index 000000000..e61311543
--- /dev/null
+++ b/scripts/sop/get_small_english_wiki.sh
@@ -0,0 +1,6 @@
+wiki_path=$1
+
+mkdir -p $wiki_path/wikipedia_sop_small
+head -3978309 $wiki_path/train_en.txt > $wiki_path/wikipedia_sop_small/train.txt
+head -10001 $wiki_path/test_en.txt > $wiki_path/wikipedia_sop_small/test.txt
+tail -8438 $wiki_path/train_en.txt > $wiki_path/wikipedia_sop_small/valid.txt
diff --git a/tests/tasks/test_sop.py b/tests/tasks/test_sop.py
new file mode 100644
index 000000000..99effc7e2
--- /dev/null
+++ b/tests/tasks/test_sop.py
@@ -0,0 +1,39 @@
+import unittest
+import tempfile
+import os
+from jiant.tasks.registry import REGISTRY
+
+
+class TestSOP(unittest.TestCase):
+    def setUp(self):
+        cls, _, kw = REGISTRY["wikipedia_corpus_sop"]
+        self.temp_dir = tempfile.mkdtemp()
+        self.max_seq_len = 24
+        self.SOPTask = cls(
+            os.path.join("wikipedia_corpus_sop"),
+            max_seq_len=self.max_seq_len,
+            name="wikipedia_corpus_sop",
+            tokenizer_name="roberta-large",
+            **kw,
+        )
+        os.mkdir(os.path.join(self.temp_dir, "wikipedia_corpus_sop"))
+        self.train_path = os.path.join(self.temp_dir, "wikipedia_corpus_sop", "train.txt")
+        with open(self.train_path, "w") as write_fn:
+            write_fn.write("1Let's see if SOP works. \n")
+            write_fn.write("1SOP is one of two pretraining objectives. \n")
+            write_fn.write("1The other one is MLM.")
+            write_fn.write("=========END OF ARTICLE======== \n")
+            write_fn.write("2NLP is pretty cool.\n")
+            write_fn.write("2An area of focus in the NYU lab is transfer learning.\n")
+            write_fn.write("2There's some pretty cool stuff.")
+            write_fn.write("=========END OF ARTICLE======== \n")
+
+    def test_sop_preprocessing(self):
+        train_examples = list(self.SOPTask.get_data_iter(self.train_path))
+        for example in train_examples:
+            # The document IDs should match, since seg_A and seg_B come from the same document.
+            assert example[0][0] == example[1][0]
+            # Make sure END OF ARTICLE markers are not included in any example.
+            assert "=" not in "".join(example[0] + example[1])
+            assert len(example[0]) + len(example[1]) <= self.max_seq_len - 3
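For reviewers, here is a standalone sketch of the chunk-and-split scheme that `SentenceOrderTask.get_data_iter` implements above. The `make_sop_examples` helper name and the whitespace tokenization are illustrative assumptions, not part of the patch, and the sketch omits the 90%/10% target-length sampling and the carry-over of overflow sentences that the real generator handles.

```python
import random
from typing import Iterable, List, Tuple


def make_sop_examples(
    sentences: List[List[str]], max_seq_len: int
) -> Iterable[Tuple[List[str], List[str], bool]]:
    """Yield (segment_a, segment_b, in_order) triples from one document's
    pre-tokenized sentences, following the chunking idea in get_data_iter."""
    target = max_seq_len - 3  # reserve room for [CLS], [SEP], [SEP]
    chunk, length = [], 0
    for sent in sentences + [None]:  # None marks the end of the document
        if sent is not None:
            chunk.append(sent)
            length += len(sent)
        if sent is None or length >= target:
            if len(chunk) >= 2:
                # Split the accumulated sentences at a random sentence boundary.
                a_end = random.randint(1, len(chunk) - 1)
                seg_a = [tok for s in chunk[:a_end] for tok in s]
                seg_b = [tok for s in chunk[a_end:] for tok in s]
                in_order = random.random() < 0.5  # 50%: keep order, else swap
                yield (seg_a, seg_b, True) if in_order else (seg_b, seg_a, False)
            chunk, length = [], 0


# Toy usage: one whitespace-tokenized "document" with three sentences.
doc = [s.split() for s in ["SOP is one objective .", "MLM is the other .", "Both are used ."]]
for seg_a, seg_b, label in make_sop_examples(doc, max_seq_len=16):
    print(label, seg_a, seg_b)
```

Each yielded triple mirrors the task's `(tokens_a, tokens_b, in_order)` format, where the label is True when the two segments keep their original document order.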