diff --git a/.gitignore b/.gitignore index 8bda64660..d477eb436 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ user_config.sh .idea .ipynb_checkpoints/ perluniprops/ +.DS_Store diff --git a/README.md b/README.md index ad7311d05..01eb6b4b1 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ A few things you might want to know about `jiant`: - `jiant` is configuration-driven. You can run an enormous variety of experiments by simply writing configuration files. Of course, if you need to add any major new features, you can also easily edit or extend the code. - `jiant` contains implementations of strong baselines for the [GLUE](https://gluebenchmark.com) and [SuperGLUE](https://super.gluebenchmark.com/) benchmarks, and it's the recommended starting point for work on these benchmarks. - `jiant` was developed at [the 2018 JSALT Workshop](https://www.clsp.jhu.edu/workshops/18-workshop/) by [the General-Purpose Sentence Representation Learning](https://jsalt18-sentence-repl.github.io/) team and is maintained by [the NYU Machine Learning for Language Lab](https://wp.nyu.edu/ml2/people/), with help from [many outside collaborators](https://github.com/nyu-mll/jiant/graphs/contributors) (especially Google AI Language's [Ian Tenney](https://ai.google/research/people/IanTenney)). -- `jiant` is built on [PyTorch](https://pytorch.org). It also uses many components from [AllenNLP](https://github.com/allenai/allennlp) and the HuggingFace PyTorch [implementations](https://github.com/huggingface/pytorch-pretrained-BERT) of BERT and GPT. +- `jiant` is built on [PyTorch](https://pytorch.org). It also uses many components from [AllenNLP](https://github.com/allenai/allennlp) and the HuggingFace PyTorch [implementations](https://github.com/huggingface/pytorch-transformers) of GPT, BERT, and XLNet. - The name `jiant` doesn't mean much. The 'j' stands for JSALT. That's all the acronym we have. ## Getting Started @@ -84,10 +84,10 @@ This package is released under the [MIT License](LICENSE.md). The material in th ## Acknowledgments -- Part of the development of `jiant` took at the 2018 Frederick Jelinek Memorial Summer Workshop on Speech and Language Technologies, and was supported by Johns Hopkins University with unrestricted gifts from Amazon, Facebook, Google, Microsoft and Mitsubishi Electric Research Laboratories. +- Part of the development of `jiant` took at the 2018 Frederick Jelinek Memorial Summer Workshop on Speech and Language Technologies, and was supported by Johns Hopkins University with unrestricted gifts from Amazon, Facebook, Google, Microsoft and Mitsubishi Electric Research Laboratories. - This work was made possible in part by a donation to NYU from Eric and Wendy Schmidt made by recommendation of the Schmidt Futures program. -- We gratefully acknowledge the support of NVIDIA Corporation with the donation of a Titan V GPU used at NYU in this work. +- We gratefully acknowledge the support of NVIDIA Corporation with the donation of a Titan V GPU used at NYU in this work. - Developer Alex Wang is supported by the National Science Foundation Graduate Research Fellowship Program under Grant No. DGE 1342536. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science diff --git a/cola_inference.py b/cola_inference.py index 04ad33e28..13331b7ee 100644 --- a/cola_inference.py +++ b/cola_inference.py @@ -47,10 +47,10 @@ from jiant.models import build_model from jiant.preprocess import build_indexers, build_tasks -from jiant.tasks.tasks import process_sentence, sentence_to_text_field +from jiant.tasks.tasks import tokenize_and_truncate, sentence_to_text_field from jiant.utils import config from jiant.utils.data_loaders import load_tsv -from jiant.utils.utils import check_arg_name, load_model_state +from jiant.utils.utils import check_arg_name, load_model_state, select_pool_type log.basicConfig(format="%(asctime)s: %(message)s", datefmt="%m/%d %I:%M:%S %p", level=log.INFO) @@ -121,6 +121,7 @@ def main(cl_arguments): cl_args = handle_arguments(cl_arguments) args = config.params_from_file(cl_args.config_file, cl_args.overrides) check_arg_name(args) + assert args.target_tasks == "cola", "Currently only supporting CoLA. ({})".format( args.target_tasks ) @@ -138,6 +139,11 @@ def main(cl_arguments): ) args.cuda = -1 + if args.tokenizer == "auto": + args.tokenizer = tokenizers.select_tokenizer(args) + if args.pool_type == "auto": + args.pool_type = select_pool_type(args) + # Prepare data # _, target_tasks, vocab, word_embs = build_tasks(args) tasks = sorted(set(target_tasks), key=lambda x: x.name) @@ -185,7 +191,7 @@ def run_repl(model, vocab, indexers, task, args): if input_string == "QUIT": break - tokens = process_sentence( + tokens = tokenize_and_truncate( tokenizer_name=task.tokenizer_name, sent=input_string, max_seq_len=args.max_seq_len ) print("TOKENS:", " ".join("[{}]".format(tok) for tok in tokens)) @@ -282,7 +288,7 @@ def load_cola_data(input_path, task, input_format, max_seq_len): with open(input_path, "r") as f_in: sentences = f_in.readlines() tokens = [ - process_sentence( + tokenize_and_truncate( tokenizer_name=task.tokenizer_name, sent=sentence, max_seq_len=max_seq_len ) for sentence in sentences diff --git a/config/ccg_bert.conf b/config/ccg_bert.conf index 037009d26..7f780fe7d 100644 --- a/config/ccg_bert.conf +++ b/config/ccg_bert.conf @@ -4,7 +4,6 @@ include "defaults.conf" pretrain_tasks = ccg target_tasks = ccg input_module = bert-base-uncased -tokenizer = ${input_module} do_target_task_training = 0 transfer_paradigm = finetune @@ -16,7 +15,6 @@ skip_embs = 1 // BERT-specific setup classifier = log_reg // following BERT paper -pool_type = first dropout = 0.1 // following BERT paper optimizer = bert_adam diff --git a/config/copa_bert.conf b/config/copa_bert.conf index 8b6c73dd6..f35cd3c3f 100644 --- a/config/copa_bert.conf +++ b/config/copa_bert.conf @@ -19,10 +19,8 @@ do_full_eval = 1 // Typical BERT base setup input_module = bert-base-uncased -tokenizer = bert-base-uncased transfer_paradigm = finetune classifier = log_reg -pool_type = first optimizer = bert_adam lr = 0.00001 sent_enc = none diff --git a/config/defaults.conf b/config/defaults.conf index 13562b2d3..91d97d816 100644 --- a/config/defaults.conf +++ b/config/defaults.conf @@ -90,7 +90,7 @@ transfer_paradigm = "frozen" // How to use pretrained model parameters during ta // "frozen" will train the downstream models on fixed // representations from the encoder model. // "finetune" will update the parameters of the encoders models as - // well as the downstream models. + // well as the downstream models. (This disables d_proj.) load_target_train_checkpoint = none // If not "none", load the specified model_state checkpoint // file when starting do_target_task_training. // Supports * wildcards. @@ -140,6 +140,9 @@ batch_size = 32 // Training batch size. optimizer = adam // Optimizer. All valid AllenNLP options are available, including 'sgd'. // Use 'bert_adam' for reproducing BERT experiments. // 'adam' uses the newer AMSGrad variant. + // Warning: bert_adam is designed for cases where the number of epochs is known + // in advance, so it may not behave reasonably unless max_epochs is set to a + // reasonable positive value. lr = 0.0001 // Initial learning rate. min_lr = 0.000001 // Minimum learning rate. Training will stop when our explicit LR decay lowers // the LR below this point or if any other stopping criterion applies. @@ -221,42 +224,41 @@ max_targ_word_v_size = 20000 // Maximum target word vocab size for seq2seq task // Input Handling // -input_module = "" // The word embedding or contextual word representation layer. - // Currently supported options: - // - scratch: Word embeddings trained from scratch. - // - glove: Leaded GloVe word embeddings. Typically used with - // tokenizer = MosesTokenizer. Note that this is not quite identical to the - // Stanford tokenizer used to train GloVe. - // - fastText: Leaded GloVe word embeddings. Use with - // tokenizer = MosesTokenizer. - // - elmo: AllenNLP's ELMo contextualized word vector model hidden states. Use - // with tokenizer = MosesTokenizer. - // - elmo-chars-only: The dynamic CNN-based word embedding layer of AllenNLP's - // ELMo, but not ELMo's LSTM layer hidden states. Use with - // tokenizer = MosesTokenizer. - // - bert-base-uncased, etc.: Any BERT model specifier that is valid for - // pytorch-pretrained-bert may be specified here. Use with - // tokenizer = ${input_module} - // We support the newer bert-large-uncased-whole-word-masking and - // bert-large-cased-whole-word-masking cased models, but they require - // the git development version of pytorch-pretrained-bert. To use these - // models, follow the instructions under 'From source' here: - // https://github.com/huggingface/pytorch-pretrained-BERT - // Most of these options use MosesTokenizer tokenization, but - // BERT and GPT need more specific tokenization (tokenizer config - // parameter should be equal to input_module for BERT, and should be - // equal to 'OpenAI.BPE' if input_module = gpt). - // For ELMo, BERT, and GPT, there are additional config parameters below. - -tokenizer = "MosesTokenizer" // The name of the tokenizer, passed to the Task constructor for - // appropriate handling during data loading. Currently supported - // options: - // - "": Split the input data on whitespace. - // - MosesTokenizer: Our standard word tokenizer. (Support for - // other NLTK tokenizers is pending.) - // - bert-uncased-base, etc.: Use the tokenizer supplied with - // pytorch-pretrained-bert that corresponds to that BERT model. - // - OpenAI.BPE: The tokenizer supplied with OpenAI GPT. +input_module = "" // The word embedding or contextual word representation layer. + // Currently supported options: + // - scratch: Word embeddings trained from scratch. + // - glove: Loaded GloVe word embeddings. Typically used with + // tokenizer = MosesTokenizer. Note that this is not quite identical to + // the Stanford tokenizer used to train GloVe. + // - fastText: Loaded fastText word embeddings. Use with + // tokenizer = MosesTokenizer. + // - elmo: AllenNLP's ELMo contextualized word vector model hidden states. Use + // with tokenizer = MosesTokenizer. + // - elmo-chars-only: The dynamic CNN-based word embedding layer of AllenNLP's + // ELMo, but not ELMo's LSTM layer hidden states. Use with + // tokenizer = MosesTokenizer. + // - gpt: The OpenAI GPT language model encoder. + // Use with tokenizer = OpenAI.BPE. + // - bert-base-uncased, etc.: Any BERT model specifier that is valid for + // pytorch-pretrained-bert may be specified here. Use with + // tokenizer = ${input_module} + // We support the newer bert-large-uncased-whole-word-masking and + // bert-large-cased-whole-word-masking cased models, but they require + // the git development version of pytorch-pretrained-bert. To use these + // models, follow the instructions under 'From source' here: + // https://github.com/huggingface/pytorch-pretrained-BERT + +tokenizer = auto // The name of the tokenizer, passed to the Task constructor for + // appropriate handling during data loading. Currently supported + // options: + // - auto: Select the tokenizer that matches the model specified in + // input_module above. Usually a safe default. + // - "": Split the input data on whitespace. + // - MosesTokenizer: Our standard word tokenizer. (Support for + // other NLTK tokenizers is pending.) + // - bert-uncased-base, etc.: Use the tokenizer supplied with + // pytorch-pretrained-bert that corresponds to that BERT model. + // - OpenAI.BPE: The tokenizer supplied with OpenAI GPT. word_embs_file = ${WORD_EMBS_FILE} // Path to embeddings file, used with glove and fastText. d_word = 300 // Dimension of word embeddings, used with scratch, glove, or fastText. @@ -282,22 +284,21 @@ openai_embeddings_mode = "none" // How to handle the embedding layer of the Ope // "mix" uses ELMo-style scalar mixing (with // learned weights) across all layers. -bert_embeddings_mode = "none" // How to handle the embedding layer of the - // BERT model: - // "none" or "top" returns only top-layer activation, - // "cat" returns top-layer concatenated with - // lexical layer, - // "only" returns only lexical layer, - // "mix" uses ELMo-style scalar mixing (with - // learned weights) across all layers. -bert_max_layer = -1 // Maximum layer to return from BERT encoder. Layer 0 is - // wordpiece embeddings. - // bert_embeddings_mode will behave as if the BERT encoder - // is truncated at this layer, so 'top' will return this - // layer, and 'mix' will return a mix of all layers up to - // and including this layer. - // Set to -1 to use all layers. - // Used for probing experiments. +pytorch_transformers_output_mode = "none" // How to handle the embedding layer of the + // BERT/XLNet model: + // "none" or "top" returns only top-layer activation, + // "cat" returns top-layer concatenated with + // lexical layer, + // "only" returns only lexical layer, + // "mix" uses ELMo-style scalar mixing (with learned + // weights) across all layers. +pytorch_transformers_max_layer = -1 // Maximum layer to return from BERT etc. encoder. Layer 0 is + // wordpiece embeddings. pytorch_transformers_embeddings_mode + // will behave as if the is truncated at this layer, so 'top' + // will return this layer, and 'mix' will return a mix of all + // layers up to and including this layer. + // Set to -1 to use all layers. + // Used for probing experiments. force_include_wsj_vocabulary = 0 // Set if using PTB parsing (grammar induction) task. Makes sure // to include WSJ vocabulary. @@ -320,7 +321,7 @@ n_layers_enc = 2 // Number of layers for a 'rnn' sent_enc. skip_embs = 1 // If true, concatenate the sent_enc's input (ELMo/GPT/BERT output or // embeddings) with the sent_enc's output. sep_embs_for_skip = 0 // Whether the skip embedding uses the same embedder object as the original - //embedding (before skip). + // embedding (before skip). // Only makes a difference if we are using ELMo weights, where it allows // the four tuned ELMo scalars to vary separately for each target task. n_layers_highway = 0 // Number of highway layers between the embedding layer and the sent_enc layer. [Deprecated.] @@ -364,8 +365,11 @@ pair_attn = 1 // If true, use attn in sentence-pair classification/regression t d_hid_attn = 512 // Post-attention LSTM state size. shared_pair_attn = 0 // If true, share pair_attn parameters across all tasks that use it. d_proj = 512 // Size of task-specific linear projection applied before before pooling. -pool_type = "max" // Type of pooling to reduce sequences of vectors into a single vector. - // Options: "max", "mean", "first", "final" + // Disabled when fine-tuning pytorch_transformers models. +pool_type = "auto" // Type of pooling to reduce sequences of vectors into a single vector. + // Options: "auto", "max", "mean", "first", "final" + // "auto" uses "first" for plain BERT (with no sent_enc), "final" for plain + // XLNet and GPT, and "max" in all other settings. span_classifier_loss_fn = "softmax" // Classifier loss function. Used only in some tasks (notably // span-related tasks), not mlp/fancy_mlp. Currently supports // sigmoid and softmax. diff --git a/config/examples/copa_bert.conf b/config/examples/copa_bert.conf index b7066ad0f..295c1c931 100644 --- a/config/examples/copa_bert.conf +++ b/config/examples/copa_bert.conf @@ -19,10 +19,8 @@ do_full_eval = 1 // Typical BERT base setup input_module = bert-base-uncased -tokenizer = bert-base-uncased transfer_paradigm = finetune classifier = log_reg -pool_type = first optimizer = bert_adam lr = 0.00001 sent_enc = none diff --git a/config/examples/stilts_example.conf b/config/examples/stilts_example.conf index 31166f54e..c1d3e3d68 100644 --- a/config/examples/stilts_example.conf +++ b/config/examples/stilts_example.conf @@ -18,8 +18,7 @@ batch_size = 24 write_preds = "val,test" //BERT-specific parameters -bert_embeddings_mode = "top" -pool_type = "first" +pytorch_transformers_output_mode = "top" sep_embs_for_skip = 1 sent_enc = "none" classifier = log_reg // following BERT paper @@ -34,6 +33,5 @@ patience = 20 max_vals = 10000 transfer_paradigm = "finetune" -tokenizer = "bert-base-uncased" input_module = "bert-base-uncased" diff --git a/config/superglue-bert.conf b/config/superglue-bert.conf index 2ec6acd2e..ffbad2ace 100644 --- a/config/superglue-bert.conf +++ b/config/superglue-bert.conf @@ -7,11 +7,10 @@ exp_name = "bert-large-cased" // Data and preprocessing settings max_seq_len = 256 // Mainly needed for MultiRC, to avoid over-truncating // But not 512 as that is really hard to fit in memory. -tokenizer = "bert-large-cased" + // Model settings input_module = "bert-large-cased" -bert_embeddings_mode = "top" -pool_type = "first" +pytorch_transformers_output_mode = "top" pair_attn = 0 // shouldn't be needed but JIC s2s = { attention = none diff --git a/environment.yml b/environment.yml index 240937df6..5c022258d 100644 --- a/environment.yml +++ b/environment.yml @@ -32,3 +32,9 @@ dependencies: - ftfy==5.4.1 - spacy==2.0.11 + # Warning: jiant currently depends on *both* pytorch_pretrained_bert > 0.6 _and_ + # pytorch_transformers > 1.0. These are the same package, though the name changed between + # these two versions. AllenNLP requires 0.6 to support the BertAdam optimizer, and jiant + # directly requires 1.0 to support XLNet and WWM-BERT. + # This AllenNLP issue is relevant: https://github.com/allenai/allennlp/issues/3067 + - pytorch-transformers==1.0.0 diff --git a/gcp/config/jiant_paths.sh b/gcp/config/jiant_paths.sh index f81f5fd92..202c78855 100644 --- a/gcp/config/jiant_paths.sh +++ b/gcp/config/jiant_paths.sh @@ -12,8 +12,8 @@ export JIANT_PROJECT_PREFIX="$HOME/exp" # pre-downloaded ELMo models export ELMO_SRC_DIR="/nfs/jiant/share/elmo" -# cache for BERT models -export PYTORCH_PRETRAINED_BERT_CACHE="/nfs/jiant/share/bert_cache" +# cache for BERT etc. models +export PYTORCH_PRETRAINED_BERT_CACHE="/nfs/jiant/share/pytorch_transformers_cache" # word embeddings export WORD_EMBS_FILE="/nfs/jiant/share/wiki-news-300d-1M.vec" diff --git a/gcp/kubernetes/run_batch.sh b/gcp/kubernetes/run_batch.sh index 78eef4c02..3c72faa7d 100755 --- a/gcp/kubernetes/run_batch.sh +++ b/gcp/kubernetes/run_batch.sh @@ -99,4 +99,3 @@ jsonnet -S -o "${YAML_FILE}" \ ## # Create the Kubernetes pod; this will actually launch the job. kubectl ${KUBECTL_MODE} -f "${YAML_FILE}" - diff --git a/gcp/kubernetes/templates/jiant_env.libsonnet b/gcp/kubernetes/templates/jiant_env.libsonnet index 03aaef32d..5da1abbf5 100644 --- a/gcp/kubernetes/templates/jiant_env.libsonnet +++ b/gcp/kubernetes/templates/jiant_env.libsonnet @@ -14,14 +14,14 @@ nfs_exp_dir: "/nfs/jiant/exp", # Name of pre-built Docker image, accessible from Kubernetes. - gcr_image: "gcr.io/google.com/jiant-stilts/jiant-conda:v0", + gcr_image: "gcr.io/google.com/jiant-stilts/jiant-conda:v2", # Default location for glue_data jiant_data_dir: "/nfs/jiant/share/glue_data", # Path to ELMO cache. elmo_src_dir: "/nfs/jiant/share/elmo", - # Path to BERT model cache; should be writable by Kubernetes workers. - bert_cache_path: "/nfs/jiant/share/bert_cache", + # Path to BERT etc. model cache; should be writable by Kubernetes workers. + pytorch_transformers_cache_path: "/nfs/jiant/share/pytorch_transformers_cache", # Path to default word embeddings file word_embs_file: "/nfs/jiant/share/wiki-news-300d-1M.vec", } diff --git a/gcp/kubernetes/templates/run_batch.jsonnet b/gcp/kubernetes/templates/run_batch.jsonnet index 3a61795e8..e32b6e399 100644 --- a/gcp/kubernetes/templates/run_batch.jsonnet +++ b/gcp/kubernetes/templates/run_batch.jsonnet @@ -36,7 +36,7 @@ function(job_name, command, project_dir, uid, fsgroup, }, { name: "PYTORCH_PRETRAINED_BERT_CACHE", - value: jiant_env.bert_cache_path, + value: jiant_env.pytorch_transformers_cache_path }, { name: "ELMO_SRC_DIR", diff --git a/jiant/bert/__init__.py b/jiant/bert/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/jiant/bert/utils.py b/jiant/bert/utils.py deleted file mode 100644 index 478ae0125..000000000 --- a/jiant/bert/utils.py +++ /dev/null @@ -1,163 +0,0 @@ -import logging as log -from typing import Dict - -import torch -import torch.nn as nn -from allennlp.modules import scalar_mix - -# huggingface implementation of BERT -import pytorch_pretrained_bert - -from jiant.preprocess import parse_task_list_arg - - -def _get_seg_ids(ids, sep_id): - """ Dynamically build the segment IDs for a concatenated pair of sentences - Searches for index SEP_ID in the tensor - - args: - ids (torch.LongTensor): batch of token IDs - - returns: - seg_ids (torch.LongTensor): batch of segment IDs - - example: - > sents = ["[CLS]", "I", "am", "a", "cat", ".", "[SEP]", "You", "like", "cats", "?", "[SEP]"] - > token_tensor = torch.Tensor([[vocab[w] for w in sent]]) # a tensor of token indices - > seg_ids = _get_seg_ids(token_tensor, sep_id=102) # BERT [SEP] ID - > assert seg_ids == torch.LongTensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) - """ - sep_idxs = (ids == sep_id).nonzero()[:, 1] - seg_ids = torch.ones_like(ids) - for row, idx in zip(seg_ids, sep_idxs[::2]): - row[: idx + 1].fill_(0) - return seg_ids - - -class BertEmbedderModule(nn.Module): - """ Wrapper for BERT module to fit into jiant APIs. """ - - def __init__(self, args, cache_dir=None): - super(BertEmbedderModule, self).__init__() - - self.model = pytorch_pretrained_bert.BertModel.from_pretrained( - args.input_module, cache_dir=cache_dir - ) - self.embeddings_mode = args.bert_embeddings_mode - self.num_layers = self.model.config.num_hidden_layers - if args.bert_max_layer >= 0: - self.max_layer = args.bert_max_layer - else: - self.max_layer = self.num_layers - assert self.max_layer <= self.num_layers - - tokenizer = pytorch_pretrained_bert.BertTokenizer.from_pretrained( - args.input_module, cache_dir=cache_dir - ) - self._sep_id = tokenizer.vocab["[SEP]"] - self._pad_id = tokenizer.vocab["[PAD]"] - - # Set trainability of this module. - for param in self.model.parameters(): - param.requires_grad = bool(args.transfer_paradigm == "finetune") - - # Configure scalar mixing, ELMo-style. - if self.embeddings_mode == "mix": - if args.transfer_paradigm == "frozen": - log.warning( - "NOTE: bert_embeddings_mode='mix', so scalar " - "mixing weights will be fine-tuned even if BERT " - "model is frozen." - ) - # TODO: if doing multiple target tasks, allow for multiple sets of - # scalars. See the ELMo implementation here: - # https://github.com/allenai/allennlp/blob/master/allennlp/modules/elmo.py#L115 - assert len(parse_task_list_arg(args.target_tasks)) <= 1, ( - "bert_embeddings_mode='mix' only supports a single set of " - "scalars (but if you need this feature, see the TODO in " - "the code!)" - ) - # Always have one more mixing weight, for lexical layer. - self.scalar_mix = scalar_mix.ScalarMix(self.max_layer + 1, do_layer_norm=False) - - def forward( - self, sent: Dict[str, torch.LongTensor], unused_task_name: str = "", is_pair_task=False - ) -> torch.FloatTensor: - """ Run BERT to get hidden states. - - This forward method does preprocessing on the go, - changing token IDs from preprocessed bert to - what AllenNLP indexes. - - Args: - sent: batch dictionary - is_pair_task (bool): true if input is a batch from a pair task - - Returns: - h: [batch_size, seq_len, d_emb] - """ - assert "bert_wpm_pretokenized" in sent - # [batch_size, var_seq_len] - ids = sent["bert_wpm_pretokenized"] - # BERT supports up to 512 tokens; see section 3.2 of https://arxiv.org/pdf/1810.04805.pdf - assert ids.size()[1] <= 512 - - mask = ids != 0 - # "Correct" ids to account for different indexing between BERT and - # AllenNLP. - # The AllenNLP indexer adds a '@@UNKNOWN@@' token to the - # beginning of the vocabulary, *and* treats that as index 1 (index 0 is - # reserved for padding). - ids[ids == 0] = self._pad_id + 2 # Shift the indices that were at 0 to become 2. - # Index 1 should never be used since the BERT WPM uses its own - # unk token, and handles this at the string level before indexing. - assert (ids > 1).all() - ids -= 2 # shift indices to match BERT wordpiece embeddings - - if self.embeddings_mode not in ["none", "top"]: - # This is redundant with the lookup inside BertModel, - # but doing so this way avoids the need to modify the BertModel - # code. - # Extract lexical embeddings; see - # https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L186 # noqa - h_lex = self.model.embeddings.word_embeddings(ids) - h_lex = self.model.embeddings.LayerNorm(h_lex) - # following our use of the OpenAI model, don't use dropout for - # probing. If you would like to use dropout, consider applying - # later on in the SentenceEncoder (see models.py). - # h_lex = self.model.embeddings.dropout(embeddings) - else: - h_lex = None # dummy; should not be accessed. - - if self.embeddings_mode != "only": - # encoded_layers is a list of layer activations, each of which is - # [batch_size, seq_len, output_dim] - token_types = _get_seg_ids(ids, self._sep_id) if is_pair_task else torch.zeros_like(ids) - encoded_layers, _ = self.model( - ids, token_type_ids=token_types, attention_mask=mask, output_all_encoded_layers=True - ) - else: - encoded_layers = [] # 'only' mode is embeddings only - - all_layers = [h_lex] + encoded_layers - all_layers = all_layers[: self.max_layer + 1] - - if self.embeddings_mode in ["none", "top"]: - h = all_layers[-1] - elif self.embeddings_mode == "only": - h = all_layers[0] - elif self.embeddings_mode == "cat": - h = torch.cat([all_layers[-1], all_layers[0]], dim=2) - elif self.embeddings_mode == "mix": - h = self.scalar_mix(all_layers, mask=mask) - else: - raise NotImplementedError(f"embeddings_mode={self.embeddings_mode}" " not supported.") - - # [batch_size, var_seq_len, output_dim] - return h - - def get_output_dim(self): - if self.embeddings_mode == "cat": - return 2 * self.model.config.hidden_size - else: - return self.model.config.hidden_size diff --git a/jiant/models.py b/jiant/models.py index ff2de552c..782750ac2 100644 --- a/jiant/models.py +++ b/jiant/models.py @@ -41,6 +41,7 @@ from jiant.modules.prpn.PRPN import PRPN from jiant.modules.seq2seq_decoder import Seq2SeqDecoder from jiant.modules.span_modules import SpanClassifierModule +from jiant.pytorch_transformers_interface import input_module_uses_pytorch_transformers from jiant.tasks.edge_probing import EdgeProbingTask from jiant.tasks.lm import LanguageModelingTask from jiant.tasks.lm_parsing import LanguageModelingParsingTask @@ -152,8 +153,14 @@ def build_sent_encoder(args, vocab, d_emb, tasks, embedder, cove_layer): elif any(isinstance(task, LanguageModelingTask) for task in tasks) or args.sent_enc == "bilm": assert_for_log(args.sent_enc in ["rnn", "bilm"], "Only RNNLM supported!") assert_for_log( - args.input_module != "elmo" and not args.input_module.startswith("bert"), - "LM with full ELMo and BERT not supported", + not ( + args.input_module == "elmo" + or args.input_module.startswith("bert") + or args.input_module.startswith("xlnet") + ), + f"Using input_module = {args.input_module} for language modeling is probably not a " + "good idea, since it allows the language model to use information from the right-hand " + "context.", ) bilm = BiLMEncoder(d_emb, args.d_hid, args.d_hid, args.n_layers_enc) sent_encoder = SentenceEncoder( @@ -219,31 +226,27 @@ def build_model(args, vocab, pretrained_embs, tasks): """ # Build embeddings. + cove_layer = None if args.input_module == "gpt": # Note: incompatible with other embedders, but logic in preprocess.py # should prevent these from being enabled anyway. from .openai_transformer_lm.utils import OpenAIEmbedderModule log.info("Using OpenAI transformer model.") - cove_layer = None # Here, this uses openAIEmbedder. embedder = OpenAIEmbedderModule(args) d_emb = embedder.get_output_dim() elif args.input_module.startswith("bert"): - # Note: incompatible with other embedders, but logic in preprocess.py - # should prevent these from being enabled anyway. - from .bert.utils import BertEmbedderModule + from jiant.pytorch_transformers_interface.modules import BertEmbedderModule log.info(f"Using BERT model ({args.input_module}).") - cove_layer = None - # Set PYTORCH_PRETRAINED_BERT_CACHE environment variable to an existing - # cache; see - # https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/file_utils.py # noqa - bert_cache_dir = os.getenv( - "PYTORCH_PRETRAINED_BERT_CACHE", os.path.join(args.exp_dir, "bert_cache") - ) - maybe_make_dir(bert_cache_dir) - embedder = BertEmbedderModule(args, cache_dir=bert_cache_dir) + embedder = BertEmbedderModule(args) + d_emb = embedder.get_output_dim() + elif args.input_module.startswith("xlnet"): + from jiant.pytorch_transformers_interface.modules import XLNetEmbedderModule + + log.info(f"Using XLNet model ({args.input_module}).") + embedder = XLNetEmbedderModule(args) d_emb = embedder.get_output_dim() else: # Default case, used for ELMo, CoVe, word embeddings, etc. @@ -308,11 +311,11 @@ def build_embeddings(args, vocab, tasks, pretrained_embs=None): d_word = args.d_word word_embs = nn.Embedding(n_token_vocab, d_word).weight else: - assert args.input_module.startswith("bert") or args.input_module in [ + assert input_module_uses_pytorch_transformers(args.input_module) or args.input_module in [ "gpt", "elmo", "elmo-chars-only", - ], "You do not have a valid value for input_module." + ], f"'{args.input_module}' is not a valid value for input_module." embeddings = None word_embs = None @@ -490,7 +493,10 @@ def build_task_specific_modules(task, model, d_sent, d_emb, vocab, embedder, arg task_params = model._get_task_params(task.name) if isinstance(task, SingleClassificationTask): module = build_single_sentence_module( - task=task, d_inp=d_sent, use_bert=model.use_bert, params=task_params + task=task, + d_inp=d_sent, + project_before_pooling=model.project_before_pooling, + params=task_params, ) setattr(model, "%s_mdl" % task.name, module) elif isinstance(task, (PairClassificationTask, PairRegressionTask, PairOrdinalRegressionTask)): @@ -513,7 +519,7 @@ def build_task_specific_modules(task, model, d_sent, d_emb, vocab, embedder, arg setattr(model, "%s_mdl" % task.name, hid2tag) elif isinstance(task, MultipleChoiceTask): module = build_multiple_choice_module( - task, d_sent, use_bert=model.use_bert, params=task_params + task, d_sent, project_before_pooling=model.project_before_pooling, params=task_params ) setattr(model, "%s_mdl" % task.name, module) elif isinstance(task, EdgeProbingTask): @@ -524,7 +530,7 @@ def build_task_specific_modules(task, model, d_sent, d_emb, vocab, embedder, arg setattr(model, "%s_decoder" % task.name, decoder) setattr(model, "%s_hid2voc" % task.name, hid2voc) elif isinstance(task, (MultiRCTask, ReCoRDTask)): - module = build_qa_module(task, d_sent, model.use_bert, task_params) + module = build_qa_module(task, d_sent, model.project_before_pooling, task_params) setattr(model, "%s_mdl" % task.name, module) else: raise ValueError("Module not found for %s" % task.name) @@ -580,13 +586,13 @@ def build_image_sent_module(task, d_inp, params): return pooler -def build_single_sentence_module(task, d_inp: int, use_bert: bool, params: Params): +def build_single_sentence_module(task, d_inp: int, project_before_pooling: bool, params: Params): """ Build a single sentence classifier args: - task (Task): task object, used to get the number of output classes - d_inp (int): input dimension to the module, needed for optional linear projection - - use_bert (bool): if using BERT, skip projection before pooling. + - project_before_pooling (bool): apply a projection layer before pooling. - params (Params): Params object with task-specific parameters returns: @@ -594,9 +600,12 @@ def build_single_sentence_module(task, d_inp: int, use_bert: bool, params: Param (optional) a linear projection, pooling, and an MLP classifier """ pooler = Pooler( - project=not use_bert, d_inp=d_inp, d_proj=params["d_proj"], pool_type=params["pool_type"] + project=project_before_pooling, + d_inp=d_inp, + d_proj=params["d_proj"], + pool_type=params["pool_type"], ) - d_out = d_inp if use_bert else params["d_proj"] + d_out = params["d_proj"] if project_before_pooling else d_inp classifier = Classifier.from_params(d_out, task.n_classes, params) module = SingleClassifier(pooler, classifier) return module @@ -623,34 +632,34 @@ def build_pair_attn(d_in, d_hid_attn): # Build the "pooler", which does pools a variable length sequence # possibly with a projection layer beforehand - if params["attn"] and not model.use_bert: + if params["attn"] and model.project_before_pooling: pooler = Pooler(project=False, d_inp=params["d_hid_attn"], d_proj=params["d_hid_attn"]) d_out = params["d_hid_attn"] * 2 else: pooler = Pooler( - project=not model.use_bert, + project=model.project_before_pooling, d_inp=d_inp, d_proj=params["d_proj"], pool_type=params["pool_type"], ) - d_out = d_inp if model.use_bert else params["d_proj"] + d_out = params["d_proj"] if model.project_before_pooling else d_inp # Build an attention module if necessary - if params["shared_pair_attn"] and params["attn"] and not model.use_bert: # shared attn + if params["shared_pair_attn"] and params["attn"]: # shared attn if not hasattr(model, "pair_attn"): pair_attn = build_pair_attn(d_inp, params["d_hid_attn"]) model.pair_attn = pair_attn else: pair_attn = model.pair_attn - elif params["attn"] and not model.use_bert: # non-shared attn + elif params["attn"]: # non-shared attn pair_attn = build_pair_attn(d_inp, params["d_hid_attn"]) else: # no attn pair_attn = None # Build the classifier n_classes = task.n_classes if hasattr(task, "n_classes") else 1 - if model.use_bert: - # BERT handles pair tasks by concatenating the inputs and classifying the joined + if model.use_pytorch_transformers: + # BERT/XLNet handle pair tasks by concatenating the inputs and classifying the joined # sequence, so we use a single sentence classifier if isinstance(task, WiCTask): d_out *= 3 # also pass the two contextual word representations @@ -680,12 +689,15 @@ def build_tagger(task, d_inp, out_dim): return hid2tag -def build_multiple_choice_module(task, d_sent, use_bert, params): +def build_multiple_choice_module(task, d_sent, project_before_pooling, params): """ Basic parts for MC task: reduce a vector representation for each model into a scalar. """ pooler = Pooler( - project=not use_bert, d_inp=d_sent, d_proj=params["d_proj"], pool_type=params["pool_type"] + project=project_before_pooling, + d_inp=d_sent, + d_proj=params["d_proj"], + pool_type=params["pool_type"], ) - d_out = d_sent if use_bert else params["d_proj"] + d_out = params["d_proj"] if project_before_pooling else d_sent choice2scalar = Classifier(d_out, n_classes=1, cls_type=params["cls_type"]) return SingleClassifier(pooler, choice2scalar) @@ -707,7 +719,7 @@ def build_decoder(task, d_inp, vocab, embedder, args): return decoder, hid2voc -def build_qa_module(task, d_inp, use_bert, params): +def build_qa_module(task, d_inp, project_before_pooling, params): """ Build a simple QA module that 1) pools representations (either of the joint (context, question, answer) or individually 2) projects down to two logits @@ -715,9 +727,12 @@ def build_qa_module(task, d_inp, use_bert, params): This module models each question-answer pair _individually_ """ pooler = Pooler( - project=not use_bert, d_inp=d_inp, d_proj=params["d_proj"], pool_type=params["pool_type"] + project=project_before_pooling, + d_inp=d_inp, + d_proj=params["d_proj"], + pool_type=params["pool_type"], ) - d_out = d_inp if use_bert else params["d_proj"] + d_out = params["d_proj"] if project_before_pooling else d_inp classifier = Classifier.from_params(d_out, 2, params) return SingleClassifier(pooler, classifier) @@ -736,7 +751,10 @@ def __init__(self, args, sent_encoder, vocab): self.vocab = vocab self.utilization = Average() if args.track_batch_utilization else None self.elmo = args.input_module == "elmo" - self.use_bert = bool(args.input_module.startswith("bert")) + self.use_pytorch_transformers = input_module_uses_pytorch_transformers(args.input_module) + self.project_before_pooling = not ( + self.use_pytorch_transformers and args.transfer_paradigm == "finetune" + ) # Rough heuristic. TODO: Make this directly user-controllable. self.sep_embs_for_skip = args.sep_embs_for_skip def forward(self, task, batch, predict=False): @@ -842,7 +860,7 @@ def _nli_diagnostic_forward(self, batch, task, predict): # embed the sentence classifier = self._get_classifier(task) - if self.use_bert: + if self.use_pytorch_transformers: sent, mask = self.sent_encoder(batch["inputs"], task) logits = classifier(sent, mask) else: @@ -880,7 +898,7 @@ def _pair_sentence_forward(self, batch, task, predict): # embed the sentence classifier = self._get_classifier(task) - if self.use_bert: + if self.use_pytorch_transformers: sent, mask = self.sent_encoder(batch["inputs"], task) # special case for WiC b/c we want to add representations of particular tokens if isinstance(task, WiCTask): @@ -1039,12 +1057,12 @@ def _mc_forward(self, batch, task, predict): logits = [] module = self._get_classifier(task) - if self.use_bert: + if self.use_pytorch_transformers: for choice_idx in range(task.n_choices): sent, mask = self.sent_encoder(batch["choice%d" % choice_idx], task) logit = module(sent, mask) logits.append(logit) - out["n_exs"] = batch["choice0"]["bert_wpm_pretokenized"].size(0) + out["n_exs"] = batch["choice0"]["pytorch_transformers_wpm_pretokenized"].size(0) else: ctx, ctx_mask = self.sent_encoder(batch["question"], task) for choice_idx in range(task.n_choices): @@ -1110,12 +1128,12 @@ def _multiple_choice_reading_comprehension_forward(self, batch, task, predict): """ out = {} classifier = self._get_classifier(task) - if self.use_bert: - # if using BERT, we concatenate the passage, question, and answer + if self.use_pytorch_transformers: + # if using BERT/XLNet, we concatenate the passage, question, and answer inp = batch["psg_qst_ans"] ex_embs, ex_mask = self.sent_encoder(inp, task) logits = classifier(ex_embs, ex_mask) - out["n_exs"] = inp["bert_wpm_pretokenized"].size(0) + out["n_exs"] = inp["pytorch_transformers_wpm_pretokenized"].size(0) else: # else, we embed each independently and concat them psg_emb, psg_mask = self.sent_encoder(batch["psg"], task) diff --git a/jiant/modules/sentence_encoder.py b/jiant/modules/sentence_encoder.py index 24e8c4aed..7c536e832 100644 --- a/jiant/modules/sentence_encoder.py +++ b/jiant/modules/sentence_encoder.py @@ -11,13 +11,13 @@ from allennlp.nn import InitializerApplicator, util from allennlp.modules import Highway, TimeDistributed -from ..bert.utils import BertEmbedderModule -from ..tasks.tasks import PairClassificationTask, PairRegressionTask -from ..utils import utils -from .simple_modules import NullPhraseLayer -from .bilm_encoder import BiLMEncoder -from .onlstm.ON_LSTM import ONLSTMStack -from .prpn.PRPN import PRPN +from jiant.pytorch_transformers_interface.modules import PytorchTransformersEmbedderModule +from jiant.tasks.tasks import PairClassificationTask, PairRegressionTask +from jiant.utils import utils +from jiant.modules.simple_modules import NullPhraseLayer +from jiant.modules.bilm_encoder import BiLMEncoder +from jiant.modules.onlstm.ON_LSTM import ONLSTMStack +from jiant.modules.prpn.PRPN import PRPN class SentenceEncoder(Model): @@ -83,34 +83,19 @@ def forward(self, sent, task, reset=True): if reset: self.reset_states() - # Embeddings - # Note: These highway modules are actually identity functions by - # default. - is_pair_task = isinstance(task, (PairClassificationTask, PairRegressionTask)) - # General sentence embeddings (for sentence encoder). # Skip this for probing runs that don't need it. if not isinstance(self._phrase_layer, NullPhraseLayer): - if isinstance(self._text_field_embedder, BertEmbedderModule): - word_embs_in_context = self._text_field_embedder(sent, is_pair_task=is_pair_task) - - else: - word_embs_in_context = self._text_field_embedder(sent) - word_embs_in_context = self._highway_layer(word_embs_in_context) + word_embs_in_context = self._highway_layer(self._text_field_embedder(sent)) else: word_embs_in_context = None # Task-specific sentence embeddings (e.g. custom ELMo weights). # Skip computing this if it won't be used. if self.sep_embs_for_skip: - if isinstance(self._text_field_embedder, BertEmbedderModule): - task_word_embs_in_context = self._text_field_embedder( - sent, task._classifier_name, is_pair_task=is_pair_task - ) - - else: - task_word_embs_in_context = self._text_field_embedder(sent, task._classifier_name) - task_word_embs_in_context = self._highway_layer(task_word_embs_in_context) + task_word_embs_in_context = self._highway_layer( + self._text_field_embedder(sent, task._classifier_name) + ) else: task_word_embs_in_context = None diff --git a/jiant/preprocess.py b/jiant/preprocess.py index 35d5a0b38..c989cb524 100644 --- a/jiant/preprocess.py +++ b/jiant/preprocess.py @@ -27,6 +27,7 @@ TokenCharactersIndexer, ) +from jiant.pytorch_transformers_interface import input_module_uses_pytorch_transformers from jiant.tasks import ( ALL_DIAGNOSTICS, ALL_COLA_NPI_TASKS, @@ -106,7 +107,7 @@ def del_field_tokens(instance): del field.tokens -def _index_split(task, split, indexers, vocab, record_file): +def _index_split(task, split, indexers, vocab, record_file, boundary_token_fn): """Index instances and stream to disk. Args: task: Task instance @@ -118,7 +119,7 @@ def _index_split(task, split, indexers, vocab, record_file): log_prefix = "\tTask %s (%s)" % (task.name, split) log.info("%s: Indexing from scratch.", log_prefix) split_text = task.get_split_text(split) - instance_iter = task.process_split(split_text, indexers) + instance_iter = task.process_split(split_text, indexers, boundary_token_fn) if hasattr(instance_iter, "__len__"): # if non-lazy log.warn( "%s: non-lazy Instance generation. You'll want to refactor " @@ -224,9 +225,9 @@ def _build_vocab(args, tasks, vocab_path: str): if args.input_module == "gpt": # Add pre-computed BPE vocabulary for OpenAI transformer model. add_openai_bpe_vocab(vocab, "openai_bpe") - if args.input_module.startswith("bert"): - # Add pre-computed BPE vocabulary for BERT model. - add_bert_wpm_vocab(vocab, args.input_module) + elif input_module_uses_pytorch_transformers(args.input_module): + # Add pre-computed BPE vocabulary for BERT/XLNet model. + add_pytorch_transformers_wpm_vocab(vocab, args.tokenizer) vocab.save_to_files(vocab_path) log.info("\tSaved vocab to %s", vocab_path) @@ -235,11 +236,12 @@ def _build_vocab(args, tasks, vocab_path: str): def build_indexers(args): indexers = {} - if not args.input_module.startswith("bert") and args.input_module not in ["elmo", "gpt"]: + if args.input_module in ["scratch", "glove", "fastText"]: indexers["words"] = SingleIdTokenIndexer() - if args.input_module == "elmo": + elif args.input_module in ["elmo", "elmo-chars-only"]: indexers["elmo"] = ELMoTokenCharactersIndexer("elmo") assert args.tokenizer in {"", "MosesTokenizer"} + if args.char_embs: indexers["chars"] = TokenCharactersIndexer("chars") if args.cove: @@ -247,6 +249,7 @@ def build_indexers(args): f"CoVe model expects Moses tokenization (MosesTokenizer);" " you are using args.tokenizer = {args.tokenizer}" ) + if args.input_module == "gpt": assert ( not indexers @@ -255,14 +258,16 @@ def build_indexers(args): args.tokenizer == "OpenAI.BPE" ), "OpenAI transformer uses custom BPE tokenization. Set tokenizer=OpenAI.BPE." indexers["openai_bpe_pretokenized"] = SingleIdTokenIndexer("openai_bpe") - if args.input_module.startswith("bert"): - assert not indexers, "BERT is not supported alongside other indexers due to tokenization." + elif input_module_uses_pytorch_transformers(args.input_module): + assert ( + not indexers + ), "pytorch_transformers modules like BERT/XLNet are not supported alongside other " + "indexers due to tokenization." assert args.tokenizer == args.input_module, ( - "BERT models use custom WPM tokenization for " - "each model, so tokenizer must match the " - "specified BERT model." + "BERT/XLNet models use custom WPM tokenization for each model, so tokenizer " + "must match the specified model." ) - indexers["bert_wpm_pretokenized"] = SingleIdTokenIndexer(args.input_module) + indexers["pytorch_transformers_wpm_pretokenized"] = SingleIdTokenIndexer(args.input_module) return indexers @@ -305,9 +310,7 @@ def build_tasks(args): # 3) build / load word vectors word_embs = None - if args.input_module not in ["elmo", "gpt", "scratch"] and not args.input_module.startswith( - "bert" - ): + if args.input_module in ["glove", "fastText"]: emb_file = os.path.join(args.exp_dir, "embs.pkl") if args.reload_vocab or not os.path.exists(emb_file): word_embs = _build_embeddings(args, vocab, emb_file) @@ -324,6 +327,19 @@ def build_tasks(args): 'Flag reload_indexing was set, but no tasks are set to reindex (use -o "args.reindex_tasks' ' = "task1,task2,..."")', ) + + # Set up boundary_token_fn, which applies SOS/EOS/SEP/CLS delimiters + if args.input_module.startswith("bert"): + from jiant.pytorch_transformers_interface.modules import BertEmbedderModule + + boundary_token_fn = BertEmbedderModule.apply_boundary_tokens + elif args.input_module.startswith("xlnet"): + from jiant.pytorch_transformers_interface.modules import XLNetEmbedderModule + + boundary_token_fn = XLNetEmbedderModule.apply_boundary_tokens + else: + boundary_token_fn = utils.apply_standard_boundary_tokens + for task in tasks: force_reindex = args.reload_indexing and task.name in reindex_tasks for split in ALL_SPLITS: @@ -338,7 +354,7 @@ def build_tasks(args): if os.path.exists(record_file) and os.path.islink(record_file): os.remove(record_file) - _index_split(task, split, indexers, vocab, record_file) + _index_split(task, split, indexers, vocab, record_file, boundary_token_fn) # Delete in-memory data - we'll lazy-load from disk later. # TODO: delete task.{split}_data_text as well? @@ -552,20 +568,27 @@ def add_task_label_vocab(vocab, task): vocab.add_token_to_namespace(label, namespace) -def add_bert_wpm_vocab(vocab, bert_model_name): - """Add BERT WPM vocabulary for use with pre-tokenized data. +def add_pytorch_transformers_wpm_vocab(vocab, tokenizer_name): + """Add BERT/XLNet WPM vocabulary for use with pre-tokenized data. - BertTokenizer has a convert_tokens_to_ids method, but this doesn't do - anything special so we can just use the standard indexers. + These tokenizers have a convert_tokens_to_ids method, but this doesn't do + anything special, so we can just use the standard indexers. """ - from pytorch_pretrained_bert import BertTokenizer + do_lower_case = "uncased" in tokenizer_name + + if tokenizer_name.startswith("bert"): + from pytorch_transformers import BertTokenizer + + tokenizer = BertTokenizer.from_pretrained(tokenizer_name, do_lower_case=do_lower_case) + else: + from pytorch_transformers import XLNetTokenizer + + tokenizer = XLNetTokenizer.from_pretrained(tokenizer_name, do_lower_case=do_lower_case) - do_lower_case = "uncased" in bert_model_name - tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=do_lower_case) - ordered_vocab = tokenizer.convert_ids_to_tokens(range(len(tokenizer.vocab))) - log.info("BERT WPM vocab (model=%s): %d tokens", bert_model_name, len(ordered_vocab)) + ordered_vocab = tokenizer.convert_ids_to_tokens(range(tokenizer.vocab_size)) + log.info("WPM vocab (%s): %d tokens", tokenizer_name, len(ordered_vocab)) for word in ordered_vocab: - vocab.add_token_to_namespace(word, bert_model_name) + vocab.add_token_to_namespace(word, tokenizer_name) def add_openai_bpe_vocab(vocab, namespace="openai_bpe"): diff --git a/jiant/pytorch_transformers_interface/__init__.py b/jiant/pytorch_transformers_interface/__init__.py new file mode 100644 index 000000000..8dcf98d83 --- /dev/null +++ b/jiant/pytorch_transformers_interface/__init__.py @@ -0,0 +1,14 @@ +""" +Warning: jiant currently depends on *both* pytorch_pretrained_bert > 0.6 _and_ +pytorch_transformers > 1.0 + +These are the same package, though the name changed between these two versions. AllenNLP requires +0.6 to support the BertAdam optimizer, and jiant directly requires 1.0 to support XLNet and +WWM-BERT. + +This AllenNLP issue is relevant: https://github.com/allenai/allennlp/issues/3067 +""" + + +def input_module_uses_pytorch_transformers(module_name): + return module_name.startswith("bert-") or module_name.startswith("xlnet-") diff --git a/jiant/pytorch_transformers_interface/modules.py b/jiant/pytorch_transformers_interface/modules.py new file mode 100644 index 000000000..67a9b6f07 --- /dev/null +++ b/jiant/pytorch_transformers_interface/modules.py @@ -0,0 +1,320 @@ +import copy +import logging as log +import os +from typing import Dict + +import torch +import torch.nn as nn +from allennlp.modules import scalar_mix + +import pytorch_transformers + +from jiant.preprocess import parse_task_list_arg +from jiant.utils import utils + + +class PytorchTransformersEmbedderModule(nn.Module): + """ Shared code for pytorch_transformers wrappers. + + Subclasses share a good deal of code, but have a number of subtle differences due to different + APIs from pytorch_transfromers. + """ + + def __init__(self, args): + super(PytorchTransformersEmbedderModule, self).__init__() + + self.cache_dir = os.getenv( + "PYTORCH_PRETRAINED_BERT_CACHE", + os.path.join(args.exp_dir, "pytorch_transformers_cache"), + ) + utils.maybe_make_dir(self.cache_dir) + + self.embeddings_mode = args.pytorch_transformers_output_mode + + # Integer token indices for special symbols. + self._sep_id = None + self._cls_id = None + self._pad_id = None + + # If set, treat these special tokens as part of input segments other than A/B. + self._SEG_ID_CLS = None + self._SEG_ID_SEP = None + + def parameter_setup(self, args): + # Set trainability of this module. + for param in self.model.parameters(): + param.requires_grad = bool(args.transfer_paradigm == "finetune") + + self.num_layers = self.model.config.num_hidden_layers + if args.pytorch_transformers_max_layer >= 0: + self.max_layer = args.pytorch_transformers_max_layer + assert self.max_layer <= self.num_layers + else: + self.max_layer = self.num_layers + + # Configure scalar mixing, ELMo-style. + if self.embeddings_mode == "mix": + if args.transfer_paradigm == "frozen": + log.warning( + "NOTE: pytorch_transformers_output_mode='mix', so scalar " + "mixing weights will be fine-tuned even if BERT " + "model is frozen." + ) + # TODO: if doing multiple target tasks, allow for multiple sets of + # scalars. See the ELMo implementation here: + # https://github.com/allenai/allennlp/blob/master/allennlp/modules/elmo.py#L115 + assert len(parse_task_list_arg(args.target_tasks)) <= 1, ( + "pytorch_transformers_output_mode='mix' only supports a single set of " + "scalars (but if you need this feature, see the TODO in " + "the code!)" + ) + # Always have one more mixing weight, for lexical layer. + self.scalar_mix = scalar_mix.ScalarMix(self.max_layer + 1, do_layer_norm=False) + + def prepare_output(self, lex_seq, hidden_states, mask): + """ + Convert the output of the pytorch_transformers module to a vector sequence as expected by jiant. + + args: + lex_seq: The sequence of input word embeddings as a tensor (batch_size, sequence_length, hidden_size). + Used only if embeddings_mode = "only". + hidden_states: A list of sequences of model hidden states as tensors (batch_size, sequence_length, hidden_size). + mask: A tensor with 1s in positions corresponding to non-padding tokens (batch_size, sequence_length). + + """ + available_layers = hidden_states[: self.max_layer + 1] + + if self.embeddings_mode in ["none", "top"]: + h = available_layers[-1] + elif self.embeddings_mode == "only": + h = lex_seq + elif self.embeddings_mode == "cat": + h = torch.cat([available_layers[-1], lex_seq], dim=2) + elif self.embeddings_mode == "mix": + h = self.scalar_mix(available_layers, mask=mask) + else: + raise NotImplementedError(f"embeddings_mode={self.embeddings_mode}" " not supported.") + + # [batch_size, var_seq_len, output_dim] + return h + + def get_output_dim(self): + if self.embeddings_mode == "cat": + return 2 * self.model.config.hidden_size + else: + return self.model.config.hidden_size + + def get_seg_ids(self, token_ids): + """ Dynamically build the segment IDs for a concatenated pair of sentences + Searches for index _sep_id in the tensor. Supports BERT or XLNet-style padding. + Sets padding tokens to segment zero. + + args: + token_ids (torch.LongTensor): batch of token IDs + + returns: + seg_ids (torch.LongTensor): batch of segment IDs + + example: + > sents = ["[CLS]", "I", "am", "a", "cat", ".", "[SEP]", "You", "like", "cats", "?", "[SEP]", "[PAD]"] + > token_tensor = torch.Tensor([[vocab[w] for w in sent]]) # a tensor of token indices + > seg_ids = get_seg_ids(token_tensor) + > assert seg_ids == torch.LongTensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0]) + """ + + sep_idxs = (token_ids == self._sep_id).nonzero()[:, 1] + seg_ids = torch.zeros_like(token_ids) + for row_idx, row in enumerate(token_ids): + sep_idxs = (row == self._sep_id).nonzero() + seg = 0 + prev_sep_idx = -1 + for sep_idx in sep_idxs: + seg_ids[row_idx, prev_sep_idx + 1 : sep_idx + 1].fill_(seg) + seg = 1 - seg # Alternate. + prev_sep_idx = sep_idx + + if self._SEG_ID_CLS is not None: + seg_ids[token_ids == self._cls_id] = self._SEG_ID_CLS + + if self._SEG_ID_SEP is not None: + seg_ids[token_ids == self._sep_id] = self._SEG_ID_SEP + + return seg_ids + + +class BertEmbedderModule(PytorchTransformersEmbedderModule): + """ Wrapper for BERT module to fit into jiant APIs. """ + + def __init__(self, args): + super(BertEmbedderModule, self).__init__(args) + + self.model = pytorch_transformers.BertModel.from_pretrained( + args.input_module, cache_dir=self.cache_dir, output_hidden_states=True + ) + + tokenizer = pytorch_transformers.BertTokenizer.from_pretrained( + args.input_module, cache_dir=self.cache_dir, do_lower_case="uncased" in args.tokenizer + ) # TODO: Speed things up slightly by reusing the previously-loaded tokenizer. + self._sep_id = tokenizer.convert_tokens_to_ids("[SEP]") + self._cls_id = tokenizer.convert_tokens_to_ids("[CLS]") + self._pad_id = tokenizer.convert_tokens_to_ids("[PAD]") + + self.parameter_setup(args) + + @staticmethod + def apply_boundary_tokens(s1, s2=None): + # BERT-style boundary token padding on string token sequences + if s2: + return ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"] + else: + return ["[CLS]"] + s1 + ["[SEP]"] + + def forward( + self, sent: Dict[str, torch.LongTensor], unused_task_name: str = "" + ) -> torch.FloatTensor: + """ Run BERT to get hidden states. + + This forward method does preprocessing on the go, + changing token IDs from preprocessed bert to + what AllenNLP indexes. + + Args: + sent: batch dictionary + + Returns: + h: [batch_size, seq_len, d_emb] + """ + assert "pytorch_transformers_wpm_pretokenized" in sent + # [batch_size, var_seq_len] + ids = sent["pytorch_transformers_wpm_pretokenized"] + # BERT supports up to 512 tokens; see section 3.2 of https://arxiv.org/pdf/1810.04805.pdf + assert ids.size()[1] <= 512 + + mask = ids != 0 + # "Correct" ids to account for different indexing between BERT and + # AllenNLP. + # The AllenNLP indexer adds a '@@UNKNOWN@@' token to the + # beginning of the vocabulary, *and* treats that as index 1 (index 0 is + # reserved for padding). + ids[ids == 0] = self._pad_id + 2 # Shift the indices that were at 0 to become 2. + # Index 1 should never be used since the BERT WPM uses its own + # unk token, and handles this at the string level before indexing. + assert (ids > 1).all() + ids -= 2 # shift indices to match BERT wordpiece embeddings + + if self.embeddings_mode not in ["none", "top"]: + # This is redundant with the lookup inside BertModel, + # but doing so this way avoids the need to modify the BertModel + # code. + # Extract lexical embeddings + lex_seq = self.model.embeddings.word_embeddings(ids) + lex_seq = self.model.embeddings.LayerNorm(lex_seq) + hidden_states = [] # dummy; should not be accessed. + # following our use of the OpenAI model, don't use dropout for + # probing. If you would like to use dropout, consider applying + # later on in the SentenceEncoder (see models.py). + # h_lex = self.model.embeddings.dropout(embeddings) + else: + lex_seq = None # dummy; should not be accessed. + + if self.embeddings_mode != "only": + # encoded_layers is a list of layer activations, each of which is + # [batch_size, seq_len, output_dim] + token_types = self.get_seg_ids(ids) + _, output_pooled_vec, hidden_states = self.model( + ids, token_type_ids=token_types, attention_mask=mask + ) + + # [batch_size, var_seq_len, output_dim] + return self.prepare_output(lex_seq, hidden_states, mask) + + +class XLNetEmbedderModule(PytorchTransformersEmbedderModule): + """ Wrapper for XLNet module to fit into jiant APIs. """ + + def __init__(self, args): + + super(XLNetEmbedderModule, self).__init__(args) + + self.model = pytorch_transformers.XLNetModel.from_pretrained( + args.input_module, cache_dir=self.cache_dir, output_hidden_states=True + ) + + tokenizer = pytorch_transformers.XLNetTokenizer.from_pretrained( + args.input_module, cache_dir=self.cache_dir, do_lower_case="uncased" in args.tokenizer + ) # TODO: Speed things up slightly by reusing the previously-loaded tokenizer. + self._sep_id = tokenizer.convert_tokens_to_ids("") + self._cls_id = tokenizer.convert_tokens_to_ids("") + self._pad_id = tokenizer.convert_tokens_to_ids("") + self._unk_id = tokenizer.convert_tokens_to_ids("") + + self.parameter_setup(args) + + # Segment IDs for CLS and SEP tokens. Unlike in BERT, these aren't part of the usual 0/1 input segments. + # Standard constants reused from pytorch_transformers. They aren't actually used within the pytorch_transformers code, so we're reproducing them here in case they're removed in a later cleanup. + self._SEG_ID_CLS = 2 + self._SEG_ID_SEP = 3 + + @staticmethod + def apply_boundary_tokens(s1, s2=None): + # XLNet-style boundary token marking on string token sequences + if s2: + return s1 + [""] + s2 + ["", ""] + else: + return s1 + ["", ""] + + def forward( + self, sent: Dict[str, torch.LongTensor], unused_task_name: str = "" + ) -> torch.FloatTensor: + """ Run XLNet to get hidden states. + + This forward method does preprocessing on the go, + changing token IDs from preprocessed word pieces to + what AllenNLP indexes. + + Args: + sent: batch dictionary + + Returns: + h: [batch_size, seq_len, d_emb] + """ + assert "pytorch_transformers_wpm_pretokenized" in sent + + # [batch_size, var_seq_len] + # Make a copy so our padding modifications below don't impact masking decisions elsewhere. + ids = copy.deepcopy(sent["pytorch_transformers_wpm_pretokenized"]) + + mask = ids != 0 + + # "Correct" ids to account for different indexing between XLNet and + # AllenNLP. + # The AllenNLP indexer adds a '@@UNKNOWN@@' token to the + # beginning of the vocabulary, *and* treats that as index 1 (index 0 is + # reserved for native padding). + ids[ids == 0] = self._pad_id + 2 # Rewrite padding indices. + ids[ids == 1] = self._unk_id + 2 # Rewrite UNK indices. + ids -= 2 # shift indices to match XLNet wordpiece embeddings + + if self.embeddings_mode not in ["none", "top"]: + # This is redundant with the lookup inside XLNetModel, + # but doing so this way avoids the need to modify the XLNetModel + # code. + lex_seq = self.model.word_embedding(ids) + hidden_states = [] # dummy; should not be accessed. + # following our use of the OpenAI model, don't use dropout for + # probing. If you would like to use dropout, consider applying + # later on in the SentenceEncoder (see models.py). + # h_lex = self.model.embeddings.dropout(embeddings) + else: + lex_seq = None # dummy; should not be accessed. + + if self.embeddings_mode != "only": + # encoded_layers is a list of layer activations, each of which is + # [batch_size, seq_len, output_dim] + token_types = self.get_seg_ids(ids) + _, output_mems, hidden_states = self.model( + ids, token_type_ids=token_types, attention_mask=mask + ) + + # [batch_size, var_seq_len, output_dim] + return self.prepare_output(lex_seq, hidden_states, mask) diff --git a/jiant/tasks/edge_probing.py b/jiant/tasks/edge_probing.py index 4090e27c7..0b6f0816a 100644 --- a/jiant/tasks/edge_probing.py +++ b/jiant/tasks/edge_probing.py @@ -178,19 +178,10 @@ def get_num_examples(cls, split_text): def _make_span_field(cls, s, text_field, offset=1): return SpanField(s[0] + offset, s[1] - 1 + offset, text_field) - def _pad_tokens(self, tokens): - """Pad tokens according to the current tokenization style.""" - if self.tokenizer_name.startswith("bert-"): - # standard padding for BERT; see - # https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/extract_features.py#L85 # noqa - return ["[CLS]"] + tokens + ["[SEP]"] - else: - return [utils.SOS_TOK] + tokens + [utils.EOS_TOK] - - def make_instance(self, record, idx, indexers) -> Type[Instance]: + def make_instance(self, record, idx, indexers, boundary_token_fn) -> Type[Instance]: """Convert a single record to an AllenNLP Instance.""" tokens = record["text"].split() # already space-tokenized by Moses - tokens = self._pad_tokens(tokens) + tokens = boundary_token_fn(tokens) # apply model-appropriate variants of [cls] and [sep]. text_field = sentence_to_text_field(tokens, indexers) d = {} @@ -218,11 +209,11 @@ def make_instance(self, record, idx, indexers) -> Type[Instance]: ) return Instance(d) - def process_split(self, records, indexers) -> Iterable[Type[Instance]]: + def process_split(self, records, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AllenNLP Instances. """ def _map_fn(r, idx): - return self.make_instance(r, idx, indexers) + return self.make_instance(r, idx, indexers, boundary_token_fn) return map(_map_fn, records, itertools.count()) diff --git a/jiant/tasks/lm.py b/jiant/tasks/lm.py index 6c548edb6..e0bdd6962 100644 --- a/jiant/tasks/lm.py +++ b/jiant/tasks/lm.py @@ -8,7 +8,7 @@ from allennlp.data.token_indexers import SingleIdTokenIndexer from allennlp.training.metrics import Average -from jiant.utils.data_loaders import process_sentence +from jiant.utils.data_loaders import tokenize_and_truncate from jiant.tasks.registry import register_task from jiant.tasks.tasks import ( UNK_TOK_ALLENNLP, @@ -81,9 +81,9 @@ def get_data_iter(self, path): toks = row.strip() if not toks: continue - yield process_sentence(self._tokenizer_name, toks, self.max_seq_len) + yield tokenize_and_truncate(self._tokenizer_name, toks, self.max_seq_len) - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """Process a language modeling split by indexing and creating fields. Args: split: (list) a single list of sentences @@ -95,6 +95,7 @@ def _make_instance(sent_): and bwd targs adds as a target for input to avoid issues with needing to strip extra tokens in the input for each direction """ + sent_ = boundary_token_fn(sent_) # Add and d = { "input": sentence_to_text_field(sent_, indexers), "targs": sentence_to_text_field(sent_[1:] + [sent_[0]], self.target_indexer), diff --git a/jiant/tasks/lm_parsing.py b/jiant/tasks/lm_parsing.py index ab6db3a16..b1b746e2b 100644 --- a/jiant/tasks/lm_parsing.py +++ b/jiant/tasks/lm_parsing.py @@ -32,12 +32,16 @@ def count_examples(self): example_counts[split] = int(math.ceil(allf / self.max_seq_len)) self.example_counts = example_counts - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """Process a language modeling split by indexing and creating fields. Args: split: (list) a single list of sentences indexers: (Indexer object) indexer to index input words + boundary_token_fn: Inserts start and end symbols for classification tasks. + Not used here. This may be a problem for GPT-2 or future LMs that use non-standard + boundary tokens. """ + del boundary_token_fn # Enforcing that this won't be used. def _make_instance(sent): """ Forward targs adds as a target for input diff --git a/jiant/tasks/qa.py b/jiant/tasks/qa.py index f40da6f1a..875e34839 100644 --- a/jiant/tasks/qa.py +++ b/jiant/tasks/qa.py @@ -11,7 +11,7 @@ from allennlp.data.fields import LabelField, MetadataField from allennlp.data import Instance -from jiant.utils.data_loaders import process_sentence +from jiant.utils.data_loaders import tokenize_and_truncate from jiant.tasks.tasks import Task from jiant.tasks.tasks import sentence_to_text_field @@ -111,20 +111,21 @@ def load_data_for_path(self, path): assert ( "version" in ex and ex["version"] == 1.1 - ), "MultiRC version is invalid! Example indices are likely incorrect. Please re-download the data from super.gluebenchmark.com ." + ), "MultiRC version is invalid! Example indices are likely incorrect. " + "Please re-download the data from super.gluebenchmark.com ." # each example has a passage field -> (text, questions) # text is the passage, which requires some preprocessing # questions is a list of questions, has fields (question, sentences_used, answers) - ex["passage"]["text"] = process_sentence( + ex["passage"]["text"] = tokenize_and_truncate( self.tokenizer_name, ex["passage"]["text"], self.max_seq_len ) for question in ex["passage"]["questions"]: - question["question"] = process_sentence( + question["question"] = tokenize_and_truncate( self.tokenizer_name, question["question"], self.max_seq_len ) for answer in question["answers"]: - answer["text"] = process_sentence( + answer["text"] = tokenize_and_truncate( self.tokenizer_name, answer["text"], self.max_seq_len ) examples.append(ex) @@ -143,9 +144,9 @@ def get_sentences(self) -> Iterable[Sequence[str]]: for answer in question["answers"]: yield answer["text"] - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AllenNLP Instances. """ - is_using_bert = "bert_wpm_pretokenized" in indexers + is_using_pytorch_transformers = "pytorch_transformers_wpm_pretokenized" in indexers def _make_instance(passage, question, answer, label, par_idx, qst_idx, ans_idx): """ pq_id: passage-question ID """ @@ -157,13 +158,13 @@ def _make_instance(passage, question, answer, label, par_idx, qst_idx, ans_idx): d["qst_idx"] = MetadataField(qst_idx) d["ans_idx"] = MetadataField(ans_idx) d["idx"] = MetadataField(ans_idx) # required by evaluate() - if is_using_bert: - inp = para + question[1:-1] + answer[1:] + if is_using_pytorch_transformers: + inp = boundary_token_fn(para, question + answer) d["psg_qst_ans"] = sentence_to_text_field(inp, indexers) else: - d["psg"] = sentence_to_text_field(passage, indexers) - d["qst"] = sentence_to_text_field(question, indexers) - d["ans"] = sentence_to_text_field(answer, indexers) + d["psg"] = sentence_to_text_field(boundary_token_fn(passage), indexers) + d["qst"] = sentence_to_text_field(boundary_token_fn(question), indexers) + d["ans"] = sentence_to_text_field(boundary_token_fn(answer), indexers) d["label"] = LabelField(label, label_namespace="labels", skip_indexing=True) return Instance(d) @@ -266,15 +267,17 @@ def tokenize_preserve_placeholder(sent): sent_parts = sent.split("@placeholder") assert len(sent_parts) == 2 sent_parts = [ - process_sentence(self.tokenizer_name, s, self.max_seq_len) for s in sent_parts + tokenize_and_truncate(self.tokenizer_name, s, self.max_seq_len) for s in sent_parts ] - return sent_parts[0][:-1] + ["@placeholder"] + sent_parts[1][1:] + return sent_parts[0] + ["@placeholder"] + sent_parts[1] examples = [] data = [json.loads(d) for d in open(path, encoding="utf-8")] for item in data: psg_id = item["idx"] - psg = process_sentence(self.tokenizer_name, item["passage"]["text"], self.max_seq_len) + psg = tokenize_and_truncate( + self.tokenizer_name, item["passage"]["text"], self.max_seq_len + ) ent_idxs = item["passage"]["entities"] ents = [item["passage"]["text"][idx["start"] : idx["end"] + 1] for idx in ent_idxs] qas = item["qas"] @@ -298,7 +301,6 @@ def tokenize_preserve_placeholder(sent): return examples def _load_answers(self) -> None: - """ """ answers = {} for split, split_path in self.files_by_split.items(): data = [json.loads(d) for d in open(split_path, encoding="utf-8")] @@ -322,9 +324,9 @@ def get_sentences(self) -> Iterable[Sequence[str]]: yield example["passage"] yield example["query"] - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AllenNLP Instances. """ - is_using_bert = "bert_wpm_pretokenized" in indexers + is_using_pytorch_transformers = "pytorch_transformers_wpm_pretokenized" in indexers def is_answer(x, ys): """ Given a list of answers, determine if x is an answer """ @@ -346,12 +348,12 @@ def _make_instance(psg, qst, ans_str, label, psg_idx, qst_idx, ans_idx): d["qst_idx"] = MetadataField(qst_idx) d["ans_idx"] = MetadataField(ans_idx) d["idx"] = MetadataField(ans_idx) # required by evaluate() - if is_using_bert: - inp = psg + qst[1:] + if is_using_pytorch_transformers: + inp = boundary_token_fn(psg, qst) d["psg_qst_ans"] = sentence_to_text_field(inp, indexers) else: - d["psg"] = sentence_to_text_field(psg, indexers) - d["qst"] = sentence_to_text_field(qst, indexers) + d["psg"] = sentence_to_text_field(boundary_token_fn(psg), indexers) + d["qst"] = sentence_to_text_field(boundary_token_fn(qst), indexers) d["label"] = LabelField(label, label_namespace="labels", skip_indexing=True) return Instance(d) @@ -362,7 +364,7 @@ def _make_instance(psg, qst, ans_str, label, psg_idx, qst_idx, ans_idx): ent_strs = example["ents"] ents = [ - process_sentence(self._tokenizer_name, ent, self.max_seq_len)[1:-1] + tokenize_and_truncate(self._tokenizer_name, ent, self.max_seq_len) for ent in ent_strs ] diff --git a/jiant/tasks/tasks.py b/jiant/tasks/tasks.py index 7c771c5bd..71308023d 100644 --- a/jiant/tasks/tasks.py +++ b/jiant/tasks/tasks.py @@ -32,7 +32,7 @@ load_diagnostic_tsv, load_span_data, load_tsv, - process_sentence, + tokenize_and_truncate, load_pair_nli_jsonl, ) from jiant.utils.tokenizers import get_tokenizer @@ -68,12 +68,14 @@ def atomic_tokenize( with the *first* nonatomic token in the list. """ for nonatomic_tok in nonatomic_toks: sent = sent.replace(nonatomic_tok, atomic_tok) - sent = process_sentence(tokenizer_name, sent, max_seq_len) + sent = tokenize_and_truncate(tokenizer_name, sent, max_seq_len) sent = [nonatomic_toks[0] if t == atomic_tok else t for t in sent] return sent -def process_single_pair_task_split(split, indexers, is_pair=True, classification=True): +def process_single_pair_task_split( + split, indexers, boundary_token_fn, is_pair=True, classification=True +): """ Convert a dataset of sentences into padded sequences of indices. Shared across several classes. @@ -81,6 +83,8 @@ def process_single_pair_task_split(split, indexers, is_pair=True, classification Args: - split (list[list[str]]): list of inputs (possibly pair) and outputs - indexers () + - boundary_token_fn (list[str], list[str] (optional) -> list[str]): + A function that appliese the appropriate EOS/SOS/SEP/CLS tokens to a token sequence. - is_pair (Bool) - classification (Bool) @@ -88,20 +92,20 @@ def process_single_pair_task_split(split, indexers, is_pair=True, classification - instances (Iterable[Instance]): an iterable of AllenNLP Instances with fields """ # check here if using bert to avoid passing model info to tasks - is_using_bert = "bert_wpm_pretokenized" in indexers + is_using_pytorch_transformers = "pytorch_transformers_wpm_pretokenized" in indexers def _make_instance(input1, input2, labels, idx): d = {} - d["sent1_str"] = MetadataField(" ".join(input1[1:-1])) - if is_using_bert and is_pair: - inp = input1 + input2[1:] # throw away input2 leading [CLS] + d["sent1_str"] = MetadataField(" ".join(input1)) + if is_using_pytorch_transformers and is_pair: + inp = boundary_token_fn(input1, input2) d["inputs"] = sentence_to_text_field(inp, indexers) - d["sent2_str"] = MetadataField(" ".join(input2[1:-1])) + d["sent2_str"] = MetadataField(" ".join(input2)) else: - d["input1"] = sentence_to_text_field(input1, indexers) + d["input1"] = sentence_to_text_field(boundary_token_fn(input1), indexers) if input2: - d["input2"] = sentence_to_text_field(input2, indexers) - d["sent2_str"] = MetadataField(" ".join(input2[1:-1])) + d["input2"] = sentence_to_text_field(boundary_token_fn(input2), indexers) + d["sent2_str"] = MetadataField(" ".join(input2)) if classification: d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True) else: @@ -249,7 +253,7 @@ def get_num_examples(self, split_text): """ return len(split_text[0]) - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AllenNLP Instances. """ raise NotImplementedError @@ -294,9 +298,9 @@ def get_metrics(self, reset=False): acc = self.scorer1.get_metric(reset) return {"accuracy": acc} - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AllenNLP Instances. """ - return process_single_pair_task_split(split, indexers, is_pair=False) + return process_single_pair_task_split(split, indexers, boundary_token_fn, is_pair=False) class PairClassificationTask(ClassificationTask): @@ -316,9 +320,9 @@ def get_metrics(self, reset=False): acc = self.scorer1.get_metric(reset) return {"accuracy": acc} - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AllenNLP Instances. """ - return process_single_pair_task_split(split, indexers, is_pair=True) + return process_single_pair_task_split(split, indexers, boundary_token_fn, is_pair=True) class PairRegressionTask(RegressionTask): @@ -337,9 +341,11 @@ def get_metrics(self, reset=False): mse = self.scorer1.get_metric(reset) return {"mse": mse} - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AllenNLP Instances. """ - return process_single_pair_task_split(split, indexers, is_pair=True, classification=False) + return process_single_pair_task_split( + split, indexers, boundary_token_fn, is_pair=True, classification=False + ) class PairOrdinalRegressionTask(RegressionTask): @@ -361,9 +367,11 @@ def get_metrics(self, reset=False): spearmanr = self.scorer2.get_metric(reset) return {"1-mse": 1 - mse, "mse": mse, "spearmanr": spearmanr} - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AllenNLP Instances. """ - return process_single_pair_task_split(split, indexers, is_pair=True, classification=False) + return process_single_pair_task_split( + split, indexers, boundary_token_fn, is_pair=True, classification=False + ) def update_metrics(self, logits, labels, tagmask=None): self.scorer1(mean_squared_error(logits, labels)) # update average MSE @@ -715,12 +723,12 @@ def load_data(self): ) log.info("\tFinished loading CoLA sperate domain.") - def process_split(self, split, indexers): + def process_split(self, split, indexers, boundary_token_fn): def _make_instance(input1, labels, tagids): """ from multiple types in one column create multiple fields """ d = {} - d["input1"] = sentence_to_text_field(input1, indexers) - d["sent1_str"] = MetadataField(" ".join(input1[1:-1])) + d["input1"] = sentence_to_text_field(boundary_token_fn(input1), indexers) + d["sent1_str"] = MetadataField(" ".join(input1)) d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True) d["tagmask"] = MultiLabelField( tagids, label_namespace="tags", skip_indexing=True, num_labels=len(self.tag_list) @@ -1330,9 +1338,9 @@ def update_scores_for_tag_group(ix_to_tags_dic, tag_group): self._scorer_all_mcc(preds, labels) self._scorer_all_acc(logits, labels) - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AllenNLP Instances. """ - is_using_bert = "bert_wpm_pretokenized" in indexers + is_using_pytorch_transformers = "pytorch_transformers_wpm_pretokenized" in indexers def create_labels_from_tags(fields_dict, ix_to_tag_dict, tag_arr, tag_group): # If there is something in this row then tag_group should be set to @@ -1355,16 +1363,16 @@ def create_labels_from_tags(fields_dict, ix_to_tag_dict, tag_arr, tag_group): def _make_instance(input1, input2, label, idx, lex_sem, pr_ar_str, logic, knowledge): """ from multiple types in one column create multiple fields """ d = {} - if is_using_bert: - inp = input1 + input2[1:] # drop the leading [CLS] token + if is_using_pytorch_transformers: + inp = boundary_token_fn(input1, input2) d["inputs"] = sentence_to_text_field(inp, indexers) else: - d["input1"] = sentence_to_text_field(input1, indexers) - d["input2"] = sentence_to_text_field(input2, indexers) + d["input1"] = sentence_to_text_field(boundary_token_fn(input1), indexers) + d["input2"] = sentence_to_text_field(boundary_token_fn(input2), indexers) d["labels"] = LabelField(label, label_namespace="labels", skip_indexing=True) d["idx"] = LabelField(idx, label_namespace="idx_tags", skip_indexing=True) - d["sent1_str"] = MetadataField(" ".join(input1[1:-1])) - d["sent2_str"] = MetadataField(" ".join(input2[1:-1])) + d["sent1_str"] = MetadataField(" ".join(input1)) + d["sent2_str"] = MetadataField(" ".join(input2)) # adds keys to dict "d" for every possible type in the column create_labels_from_tags(d, self.ix_to_lex_sem_dic, lex_sem, "lex_sem") @@ -1407,7 +1415,8 @@ def collect_metrics(ix_to_tag_dict, tag_group): # SuperGLUE diagnostic (2-class NLI), expects JSONL @register_task("broadcoverage-diagnostic", rel_path="RTE/diagnostics") class BroadCoverageDiagnosticTask(GLUEDiagnosticTask): - """ Class for SuperGLUE broad coverage (linguistics, commonsense, world knowledge) diagnostic task """ + """ Class for SuperGLUE broad coverage (linguistics, commonsense, world knowledge) + diagnostic task """ def __init__(self, path, max_seq_len, name, **kw): super().__init__(path, max_seq_len, name, n_classes=2, **kw) @@ -1456,10 +1465,12 @@ def create_score_function(scorer, arg_to_scorer, tags_dict, tag_group): targ_map = {"entailment": 1, "not_entailment": 0} data = [json.loads(d) for d in open(os.path.join(self.path, "AX-b.jsonl"))] sent1s = [ - process_sentence(self._tokenizer_name, d["sentence1"], self.max_seq_len) for d in data + tokenize_and_truncate(self._tokenizer_name, d["sentence1"], self.max_seq_len) + for d in data ] sent2s = [ - process_sentence(self._tokenizer_name, d["sentence2"], self.max_seq_len) for d in data + tokenize_and_truncate(self._tokenizer_name, d["sentence2"], self.max_seq_len) + for d in data ] labels = [targ_map[d["label"]] for d in data] idxs = [int(d["idx"]) for d in data] @@ -1552,21 +1563,21 @@ def load_data(self): ) log.info("\tFinished loading winogender (from SuperGLUE formatted data).") - def process_split(self, split, indexers): - is_using_bert = "bert_wpm_pretokenized" in indexers + def process_split(self, split, indexers, boundary_token_fn): + is_using_pytorch_transformers = "pytorch_transformers_wpm_pretokenized" in indexers def _make_instance(input1, input2, labels, idx, pair_id): d = {} - d["sent1_str"] = MetadataField(" ".join(input1[1:-1])) - if is_using_bert: - inp = input1 + input2[1:] # throw away input2 leading [CLS] + d["sent1_str"] = MetadataField(" ".join(input1)) + if is_using_pytorch_transformers: + inp = boundary_token_fn(input1, input2) d["inputs"] = sentence_to_text_field(inp, indexers) - d["sent2_str"] = MetadataField(" ".join(input2[1:-1])) + d["sent2_str"] = MetadataField(" ".join(input2)) else: - d["input1"] = sentence_to_text_field(input1, indexers) + d["input1"] = sentence_to_text_field(boundary_token_fn(input1), indexers) if input2: - d["input2"] = sentence_to_text_field(input2, indexers) - d["sent2_str"] = MetadataField(" ".join(input2[1:-1])) + d["input2"] = sentence_to_text_field(boundary_token_fn(input2), indexers) + d["sent2_str"] = MetadataField(" ".join(input2)) d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True) d["idx"] = LabelField(idx, label_namespace="idxs_tags", skip_indexing=True) d["pair_id"] = LabelField(pair_id, label_namespace="pair_id_tags", skip_indexing=True) @@ -1667,10 +1678,14 @@ def _load_jsonl(data_file): sent1s, sent2s, trgs, idxs = [], [], [], [] for example in data: sent1s.append( - process_sentence(self._tokenizer_name, example["premise"], self.max_seq_len) + tokenize_and_truncate( + self._tokenizer_name, example["premise"], self.max_seq_len + ) ) sent2s.append( - process_sentence(self._tokenizer_name, example["hypothesis"], self.max_seq_len) + tokenize_and_truncate( + self._tokenizer_name, example["hypothesis"], self.max_seq_len + ) ) trg = targ_map[example["label"]] if "label" in example else 0 trgs.append(trg) @@ -1906,13 +1921,13 @@ def get_sentences(self) -> Iterable[Sequence[str]]: for sent in self.load_data_for_path(path): yield sent - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process a language modeling split. Split is a single list of sentences here. """ def _make_instance(input1, input2, labels): d = {} - d["input1"] = sentence_to_text_field(input1, indexers) - d["input2"] = sentence_to_text_field(input2, indexers) + d["input1"] = sentence_to_text_field(boundary_token_fn(input1), indexers) + d["input2"] = sentence_to_text_field(boundary_token_fn(input2), indexers) d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True) return Instance(d) @@ -1976,8 +1991,8 @@ def load_data_for_path(self, path): row = row.strip().split("\t") if len(row) != 3 or not (row[0] and row[1] and row[2]): continue - sent1 = process_sentence(self._tokenizer_name, row[0], self.max_seq_len) - sent2 = process_sentence(self._tokenizer_name, row[1], self.max_seq_len) + sent1 = tokenize_and_truncate(self._tokenizer_name, row[0], self.max_seq_len) + sent2 = tokenize_and_truncate(self._tokenizer_name, row[1], self.max_seq_len) targ = int(row[2]) yield (sent1, sent2, targ) @@ -1999,18 +2014,18 @@ def count_examples(self): example_counts[split] = sum(1 for line in open(split_path)) self.example_counts = example_counts - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AllenNLP Instances. """ - is_using_bert = "bert_wpm_pretokenized" in indexers + is_using_pytorch_transformers = "pytorch_transformers_wpm_pretokenized" in indexers def _make_instance(input1, input2, labels): d = {} - if is_using_bert: - inp = input1 + input2[1:] # drop leading [CLS] token + if is_using_pytorch_transformers: + inp = boundary_token_fn(input1, input2) # drop leading [CLS] token d["inputs"] = sentence_to_text_field(inp, indexers) else: - d["input1"] = sentence_to_text_field(input1, indexers) - d["input2"] = sentence_to_text_field(input2, indexers) + d["input1"] = sentence_to_text_field(boundary_token_fn(input1), indexers) + d["input2"] = sentence_to_text_field(boundary_token_fn(input2), indexers) d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True) return Instance(d) @@ -2123,7 +2138,7 @@ def __init__(self, path, max_seq_len, name="ccg", **kw): self.val_data_text = None self.test_data_text = None - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process a tagging task """ inputs = [TextField(list(map(Token, sent)), token_indexers=indexers) for sent in split[0]] targs = [ @@ -2312,19 +2327,10 @@ def _make_span_field(self, s, text_field, offset=1): # so minus 1 at the end index. return SpanField(s[0] + offset, s[1] - 1 + offset, text_field) - def _pad_tokens(self, tokens): - """Pad tokens according to the current tokenization style.""" - if self.tokenizer_name.startswith("bert-"): - # standard padding for BERT; see - # https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/extract_features.py#L85 - return ["[CLS]"] + tokens + ["[SEP]"] - else: - return [utils.SOS_TOK] + tokens + [utils.EOS_TOK] - - def make_instance(self, record, idx, indexers) -> Type[Instance]: + def make_instance(self, record, idx, indexers, boundary_token_fn) -> Type[Instance]: """Convert a single record to an AllenNLP Instance.""" tokens = record["text"].split() - tokens = self._pad_tokens(tokens) + tokens = boundary_token_fn(tokens) text_field = sentence_to_text_field(tokens, indexers) example = {} @@ -2341,11 +2347,11 @@ def make_instance(self, record, idx, indexers) -> Type[Instance]: ) return Instance(example) - def process_split(self, records, indexers) -> Iterable[Type[Instance]]: + def process_split(self, records, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AllenNLP Instances. """ def _map_fn(r, idx): - return self.make_instance(r, idx, indexers) + return self.make_instance(r, idx, indexers, boundary_token_fn) return map(_map_fn, records, itertools.count()) @@ -2402,10 +2408,14 @@ def _load_data(data_file): sent1s, sent2s, targs, idxs = [], [], [], [] for example in data: sent1s.append( - process_sentence(self._tokenizer_name, example["premise"], self.max_seq_len) + tokenize_and_truncate( + self._tokenizer_name, example["premise"], self.max_seq_len + ) ) sent2s.append( - process_sentence(self._tokenizer_name, example["hypothesis"], self.max_seq_len) + tokenize_and_truncate( + self._tokenizer_name, example["hypothesis"], self.max_seq_len + ) ) trg = targ_map[example["label"]] if "label" in example else 0 targs.append(trg) @@ -2469,12 +2479,12 @@ def _process_preserving_word(sent, word): then concatenate everything together. This allows us to track where in the tokenized sequence the marked word is located. """ sent_parts = sent.split(word) - sent_tok1 = process_sentence(self._tokenizer_name, sent_parts[0], self.max_seq_len) - sent_tok2 = process_sentence(self._tokenizer_name, sent_parts[1], self.max_seq_len) - sent_mid = process_sentence(self._tokenizer_name, word, self.max_seq_len) - sent_tok = sent_tok1[:-1] + sent_mid[1:-1] + sent_tok2[1:] - start_idx = len(sent_tok1[:-1]) - end_idx = start_idx + len(sent_mid[1:-1]) + sent_tok1 = tokenize_and_truncate(self._tokenizer_name, sent_parts[0], self.max_seq_len) + sent_tok2 = tokenize_and_truncate(self._tokenizer_name, sent_parts[1], self.max_seq_len) + sent_mid = tokenize_and_truncate(self._tokenizer_name, word, self.max_seq_len) + sent_tok = sent_tok1 + sent_mid + sent_tok2 + start_idx = len(sent_tok1) + end_idx = start_idx + len(sent_mid) assert end_idx > start_idx, "Invalid marked word indices. Something is wrong." return sent_tok, start_idx, end_idx @@ -2498,7 +2508,9 @@ def _load_split(data_file): idxs.append(row["idx"]) assert ( "version" in row and row["version"] == 1.1 - ), "WiC version is not v1.1; examples indices are likely incorrect and data is likely pre-tokenized. Please re-download the data from super.gluebenchmark for the correct data." + ), "WiC version is not v1.1; examples indices are likely incorrect and data " + "is likely pre-tokenized. Please re-download the data from " + "super.gluebenchmark.com." return [sents1, sents2, idxs1, idxs2, trgs, idxs] self.train_data_text = _load_split(os.path.join(self.path, "train.jsonl")) @@ -2512,26 +2524,26 @@ def _load_split(data_file): ) log.info("\tFinished loading WiC data.") - def process_split(self, split, indexers): + def process_split(self, split, indexers, boundary_token_fn): """ Convert a dataset of sentences into padded sequences of indices. Shared across several classes. """ # check here if using bert to avoid passing model info to tasks - is_using_bert = "bert_wpm_pretokenized" in indexers + is_using_pytorch_transformers = "pytorch_transformers_wpm_pretokenized" in indexers def _make_instance(input1, input2, idxs1, idxs2, labels, idx): d = {} - d["sent1_str"] = MetadataField(" ".join(input1[1:-1])) - d["sent2_str"] = MetadataField(" ".join(input2[1:-1])) - if is_using_bert: - inp = input1 + input2[1:] # throw away input2 leading [CLS] + d["sent1_str"] = MetadataField(" ".join(input1)) + d["sent2_str"] = MetadataField(" ".join(input2)) + if is_using_pytorch_transformers: + inp = boundary_token_fn(input1, input2) d["inputs"] = sentence_to_text_field(inp, indexers) idxs2 = (idxs2[0] + len(input1), idxs2[1] + len(input1)) else: - d["input1"] = sentence_to_text_field(input1, indexers) - d["input2"] = sentence_to_text_field(input2, indexers) + d["input1"] = sentence_to_text_field(boundary_token_fn(input1), indexers) + d["input2"] = sentence_to_text_field(boundary_token_fn(input2), indexers) d["idx1"] = ListField([NumericField(i) for i in range(idxs1[0], idxs1[1])]) d["idx2"] = ListField([NumericField(i) for i in range(idxs2[0], idxs2[1])]) d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True) @@ -2595,13 +2607,17 @@ def _load_split(data_file): else "What happened as a result?" ) choices = [ - process_sentence(self._tokenizer_name, choice, self.max_seq_len) + tokenize_and_truncate(self._tokenizer_name, choice, self.max_seq_len) for choice in [choice1, choice2] ] targ = example["label"] if "label" in example else 0 - contexts.append(process_sentence(self._tokenizer_name, context, self.max_seq_len)) + contexts.append( + tokenize_and_truncate(self._tokenizer_name, context, self.max_seq_len) + ) choicess.append(choices) - questions.append(process_sentence(self._tokenizer_name, question, self.max_seq_len)) + questions.append( + tokenize_and_truncate(self._tokenizer_name, question, self.max_seq_len) + ) targs.append(targ) return [contexts, choicess, questions, targs] @@ -2616,19 +2632,23 @@ def _load_split(data_file): ) log.info("\tFinished loading COPA (as QA) data.") - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AlleNNLP Instances. """ - is_using_bert = "bert_wpm_pretokenized" in indexers + is_using_pytorch_transformers = "pytorch_transformers_wpm_pretokenized" in indexers def _make_instance(context, choices, question, label, idx): d = {} - d["question_str"] = MetadataField(" ".join(context[1:-1])) - if not is_using_bert: - d["question"] = sentence_to_text_field(context, indexers) + d["question_str"] = MetadataField(" ".join(context)) + if not is_using_pytorch_transformers: + d["question"] = sentence_to_text_field(boundary_token_fn(context), indexers) for choice_idx, choice in enumerate(choices): - inp = context + question[1:] + choice[1:] if is_using_bert else choice + inp = ( + boundary_token_fn(context, question + choice) + if is_using_pytorch_transformers + else boundary_token_fn(choice) + ) d["choice%d" % choice_idx] = sentence_to_text_field(inp, indexers) - d["choice%d_str" % choice_idx] = MetadataField(" ".join(choice[1:-1])) + d["choice%d_str" % choice_idx] = MetadataField(" ".join(choice)) d["label"] = LabelField(label, label_namespace="labels", skip_indexing=True) d["idx"] = LabelField(idx, label_namespace="idxs_tags", skip_indexing=True) return Instance(d) @@ -2671,13 +2691,13 @@ def _load_split(data_file): questions, choicess, targs = [], [], [] data = pd.read_csv(data_file) for ex_idx, ex in data.iterrows(): - sent1 = process_sentence(self._tokenizer_name, ex["sent1"], self.max_seq_len) + sent1 = tokenize_and_truncate(self._tokenizer_name, ex["sent1"], self.max_seq_len) questions.append(sent1) sent2_prefix = ex["sent2"] choices = [] for i in range(4): choice = sent2_prefix + " " + ex["ending%d" % i] - choice = process_sentence(self._tokenizer_name, choice, self.max_seq_len) + choice = tokenize_and_truncate(self._tokenizer_name, choice, self.max_seq_len) choices.append(choice) choicess.append(choices) targ = ex["label"] if "label" in ex else 0 @@ -2695,19 +2715,23 @@ def _load_split(data_file): ) log.info("\tFinished loading SWAG data.") - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AlleNNLP Instances. """ - is_using_bert = "bert_wpm_pretokenized" in indexers + is_using_pytorch_transformers = "pytorch_transformers_wpm_pretokenized" in indexers def _make_instance(question, choices, label, idx): d = {} - d["question_str"] = MetadataField(" ".join(question[1:-1])) - if not is_using_bert: - d["question"] = sentence_to_text_field(question, indexers) + d["question_str"] = MetadataField(" ".join(question)) + if not is_using_pytorch_transformers: + d["question"] = sentence_to_text_field(boundary_token_fn(question), indexers) for choice_idx, choice in enumerate(choices): - inp = question + choice[1:] if is_using_bert else choice + inp = ( + boundary_token_fn(question, choice) + if is_using_pytorch_transformers + else boundary_token_fn(choice) + ) d["choice%d" % choice_idx] = sentence_to_text_field(inp, indexers) - d["choice%d_str" % choice_idx] = MetadataField(" ".join(choice[1:-1])) + d["choice%d_str" % choice_idx] = MetadataField(" ".join(choice)) d["label"] = LabelField(label, label_namespace="labels", skip_indexing=True) d["idx"] = LabelField(idx, label_namespace="idxs_tags", skip_indexing=True) return Instance(d) @@ -2789,8 +2813,12 @@ def _load_jsonl(data_file): raw_data = [json.loads(d) for d in open(data_file, encoding="utf-8")] data = [] for d in raw_data: - question = process_sentence(self._tokenizer_name, d["question"], self.max_seq_len) - passage = process_sentence(self._tokenizer_name, d["passage"], self.max_seq_len) + question = tokenize_and_truncate( + self._tokenizer_name, d["question"], self.max_seq_len + ) + passage = tokenize_and_truncate( + self._tokenizer_name, d["passage"], self.max_seq_len + ) new_datum = {"question": question, "passage": passage} answer = d["label"] if "label" in d else False new_datum["label"] = answer @@ -2806,19 +2834,19 @@ def _load_jsonl(data_file): ] log.info("\tFinished loading BoolQ data.") - def process_split(self, split, indexers) -> Iterable[Type[Instance]]: + def process_split(self, split, indexers, boundary_token_fn) -> Iterable[Type[Instance]]: """ Process split text into a list of AlleNNLP Instances. """ - is_using_bert = "bert_wpm_pretokenized" in indexers + is_using_pytorch_transformers = "pytorch_transformers_wpm_pretokenized" in indexers def _make_instance(d, idx): new_d = {} - new_d["question_str"] = MetadataField(" ".join(d["question"][1:-1])) - new_d["passage_str"] = MetadataField(" ".join(d["passage"][1:-1])) - if not is_using_bert: - new_d["input1"] = sentence_to_text_field(d["passage"], indexers) - new_d["input2"] = sentence_to_text_field(d["question"], indexers) - else: # BERT - psg_qst = d["passage"] + d["question"][1:] + new_d["question_str"] = MetadataField(" ".join(d["question"])) + new_d["passage_str"] = MetadataField(" ".join(d["passage"])) + if not is_using_pytorch_transformers: + new_d["input1"] = sentence_to_text_field(boundary_token_fn(d["passage"]), indexers) + new_d["input2"] = sentence_to_text_field(boundary_token_fn(d["question"]), indexers) + else: # BERT/XLNet + psg_qst = boundary_token_fn(d["passage"], d["question"]) new_d["inputs"] = sentence_to_text_field(psg_qst, indexers) new_d["labels"] = LabelField(d["label"], label_namespace="labels", skip_indexing=True) new_d["idx"] = LabelField(idx, label_namespace="idxs_tags", skip_indexing=True) diff --git a/jiant/utils/data_loaders.py b/jiant/utils/data_loaders.py index 5337bbb92..0a30ab27f 100644 --- a/jiant/utils/data_loaders.py +++ b/jiant/utils/data_loaders.py @@ -13,18 +13,16 @@ from jiant.utils.tokenizers import get_tokenizer from jiant.utils.retokenize import realign_spans -BERT_CLS_TOK, BERT_SEP_TOK = "[CLS]", "[SEP]" -SOS_TOK, EOS_TOK = "", "" - def load_span_data(tokenizer_name, file_name, label_fn=None, has_labels=True): """ - Load a span-related task file in .jsonl format, does re-alignment of spans, and tokenizes the text. + Load a span-related task file in .jsonl format, does re-alignment of spans, and tokenizes + the text. Re-alignment of spans involves transforming the spans so that it matches the text after tokenization. - For example, given the original text: [Mr., Porter, is, nice] and bert-base-cased tokenization, we get - [Mr, ., Por, ter, is, nice ]. If the original span indices was [0,2], under the new tokenization, - it becomes [0, 3]. + For example, given the original text: [Mr., Porter, is, nice] and bert-base-cased + tokenization, we get [Mr, ., Por, ter, is, nice ]. If the original span indices was [0,2], + under the new tokenization, it becomes [0, 3]. The task file should of be of the following form: text: str, label: bool @@ -32,7 +30,8 @@ def load_span_data(tokenizer_name, file_name, label_fn=None, has_labels=True): Args: tokenizer_name: str, file_name: str, - label_fn: function that expects a row and outputs a transformed row with labels tarnsformed. + label_fn: function that expects a row and outputs a transformed row with labels + transformed. Returns: List of dictionaries of the aligned spans and tokenized text. """ @@ -48,27 +47,27 @@ def load_span_data(tokenizer_name, file_name, label_fn=None, has_labels=True): def load_pair_nli_jsonl(data_file, tokenizer_name, max_seq_len, targ_map): """ - Loads a pair NLI task. + Loads a pair NLI task. Parameters ----------------- data_file: path to data file, - tokenizer_name: str, - max_seq_len: int, - targ_map: a dictionary that maps labels to ints + tokenizer_name: str, + max_seq_len: int, + targ_map: a dictionary that maps labels to ints Returns ----------------- - sent1s: list of strings of tokenized first sentences, - sent2s: list of strings of tokenized second sentences, + sent1s: list of strings of tokenized first sentences, + sent2s: list of strings of tokenized second sentences, trgs: list of ints of labels, idxs: list of ints """ data = [json.loads(d) for d in open(data_file, encoding="utf-8")] sent1s, sent2s, trgs, idxs, pair_ids = [], [], [], [], [] for example in data: - sent1s.append(process_sentence(tokenizer_name, example["premise"], max_seq_len)) - sent2s.append(process_sentence(tokenizer_name, example["hypothesis"], max_seq_len)) + sent1s.append(tokenize_and_truncate(tokenizer_name, example["premise"], max_seq_len)) + sent2s.append(tokenize_and_truncate(tokenizer_name, example["hypothesis"], max_seq_len)) trg = targ_map[example["label"]] if "label" in example else 0 trgs.append(trg) idxs.append(example["idx"]) @@ -143,11 +142,11 @@ def load_tsv( if has_labels: mask = mask & rows[label_idx].notnull() rows = rows.loc[mask] - sent1s = rows[s1_idx].apply(lambda x: process_sentence(tokenizer_name, x, max_seq_len)) + sent1s = rows[s1_idx].apply(lambda x: tokenize_and_truncate(tokenizer_name, x, max_seq_len)) if s2_idx is None: sent2s = pd.Series() else: - sent2s = rows[s2_idx].apply(lambda x: process_sentence(tokenizer_name, x, max_seq_len)) + sent2s = rows[s2_idx].apply(lambda x: tokenize_and_truncate(tokenizer_name, x, max_seq_len)) label_fn = label_fn if label_fn is not None else (lambda x: x) if has_labels: @@ -236,8 +235,8 @@ def targs_to_idx(col_name): rows[col_name] = rows[col_name].apply(lambda x: [word_to_idx[x]] if x != "" else []) return word_to_idx, idx_to_word, rows[col_name] - sent1s = rows[s1_col].apply(lambda x: process_sentence(tokenizer_name, x, max_seq_len)) - sent2s = rows[s2_col].apply(lambda x: process_sentence(tokenizer_name, x, max_seq_len)) + sent1s = rows[s1_col].apply(lambda x: tokenize_and_truncate(tokenizer_name, x, max_seq_len)) + sent2s = rows[s2_col].apply(lambda x: tokenize_and_truncate(tokenizer_name, x, max_seq_len)) labels = rows[label_col].apply(lambda x: label_fn(x)) # Build indices for field attributes lex_sem_to_ix_dic, ix_to_lex_sem_dic, lex_sem = targs_to_idx("Lexical Semantics") @@ -286,17 +285,13 @@ def get_tag_list(tag_vocab): return tag_list -def process_sentence(tokenizer_name, sent, max_seq_len): - """process a sentence """ - max_seq_len -= 2 - assert max_seq_len > 0, "Max sequence length should be at least 2!" +def tokenize_and_truncate(tokenizer_name, sent, max_seq_len): + """Truncate and tokenize a sentence or paragraph.""" + max_seq_len -= 2 # For boundary tokens. tokenizer = get_tokenizer(tokenizer_name) - if tokenizer_name.startswith("bert-"): - sos_tok, eos_tok = BERT_CLS_TOK, BERT_SEP_TOK - else: - sos_tok, eos_tok = SOS_TOK, EOS_TOK + if isinstance(sent, str): - return [sos_tok] + tokenizer.tokenize(sent)[:max_seq_len] + [eos_tok] + return tokenizer.tokenize(sent)[:max_seq_len] elif isinstance(sent, list): assert isinstance(sent[0], str), "Invalid sentence found!" - return [sos_tok] + sent[:max_seq_len] + [eos_tok] + return sent[:max_seq_len] diff --git a/jiant/utils/retokenize.py b/jiant/utils/retokenize.py index 835617c59..b1819a379 100644 --- a/jiant/utils/retokenize.py +++ b/jiant/utils/retokenize.py @@ -98,7 +98,7 @@ def realign_spans(record, tokenizer_name): """ Builds the indices alignment while also tokenizing the input piece by piece. - Only BERT and Moses tokenization is supported currently. + Only BERT/XLNet and Moses tokenization is supported currently. Parameters ----------------------- @@ -286,7 +286,7 @@ def space_tokenize_with_eow(sentence): return [t + "" for t in sentence.split()] -def process_bert_wordpiece_for_alignment(t): +def process_wordpiece_for_alignment(t): """Add markers to ensure word-boundary alignment.""" if t.startswith("##"): return re.sub(r"^##", "", t) @@ -315,15 +315,15 @@ def align_openai(text: Text) -> Tuple[TokenAligner, List[Text]]: return ta, bpe_tokens -def align_bert(text: Text, model_name: str) -> Tuple[TokenAligner, List[Text]]: +def align_wpm(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]: # If using lowercase, do this for the source tokens for better matching. - do_lower_case = model_name.endswith("uncased") + do_lower_case = tokenizer_name.endswith("uncased") bow_tokens = space_tokenize_with_bow(text.lower() if do_lower_case else text) - bert_tokenizer = get_tokenizer(model_name) - wpm_tokens = bert_tokenizer.tokenize(text) + wpm_tokenizer = get_tokenizer(tokenizer_name) + wpm_tokens = wpm_tokenizer.tokenize(text) # Align using markers for stability w.r.t. word boundaries. - modified_wpm_tokens = list(map(process_bert_wordpiece_for_alignment, wpm_tokens)) + modified_wpm_tokens = list(map(process_wordpiece_for_alignment, wpm_tokens)) ta = TokenAligner(bow_tokens, modified_wpm_tokens) return ta, wpm_tokens @@ -333,7 +333,7 @@ def get_aligner_fn(tokenizer_name: Text): return align_moses elif tokenizer_name == "OpenAI.BPE": return align_openai - elif tokenizer_name.startswith("bert-"): - return functools.partial(align_bert, model_name=tokenizer_name) + elif tokenizer_name.startswith("bert-") or tokenizer_name.startswith("xlnet-"): + return functools.partial(align_wpm, tokenizer_name=tokenizer_name) else: raise ValueError(f"Unsupported tokenizer '{tokenizer_name}'") diff --git a/jiant/utils/tokenizers.py b/jiant/utils/tokenizers.py index 6d072466a..45ec87133 100644 --- a/jiant/utils/tokenizers.py +++ b/jiant/utils/tokenizers.py @@ -17,6 +17,22 @@ def tokenize(self, sentence): raise NotImplementedError +def select_tokenizer(args): + """ + Select a sane default tokenizer. + """ + if args.tokenizer == "auto": + if args.input_module.startswith("bert-") or args.input_module.startswith("xlnet-"): + tokenizer_name = args.input_module + elif args.input_module == "gpt": + tokenizer_name = "OpenAI.BPE" + else: + tokenizer_name = "MosesTokenizer" + else: + tokenizer_name = args.tokenizer + return tokenizer_name + + class OpenAIBPETokenizer(Tokenizer): # TODO: Add detokenize method to OpenAIBPE class def __init__(self): @@ -64,10 +80,15 @@ def detokenize(self, tokens): def get_tokenizer(tokenizer_name): log.info(f"\tLoading Tokenizer {tokenizer_name}") if tokenizer_name.startswith("bert-"): - from pytorch_pretrained_bert import BertTokenizer + from pytorch_transformers import BertTokenizer do_lower_case = tokenizer_name.endswith("uncased") tokenizer = BertTokenizer.from_pretrained(tokenizer_name, do_lower_case=do_lower_case) + elif tokenizer_name.startswith("xlnet-"): + from pytorch_transformers import XLNetTokenizer + + do_lower_case = tokenizer_name.endswith("uncased") + tokenizer = XLNetTokenizer.from_pretrained(tokenizer_name, do_lower_case=do_lower_case) elif tokenizer_name == "OpenAI.BPE": tokenizer = OpenAIBPETokenizer() elif tokenizer_name == "MosesTokenizer": diff --git a/jiant/utils/utils.py b/jiant/utils/utils.py index e1c446d7f..31c9b2562 100644 --- a/jiant/utils/utils.py +++ b/jiant/utils/utils.py @@ -33,6 +33,32 @@ _MOSES_DETOKENIZER = MosesDetokenizer() +def select_pool_type(args): + """ + Select a sane default sequence pooling type. + """ + if args.pool_type == "auto": + if args.sent_enc == "none" and args.input_module.startswith("bert-"): + pool_type = "first" + elif args.sent_enc == "none" and args.input_module.startswith("xlnet-"): + pool_type = "final" + elif args.sent_enc == "none" and args.input_module == "gpt": + pool_type = "final" + else: + pool_type = "max" + else: + pool_type = args.pool_type + return pool_type + + +def apply_standard_boundary_tokens(s1, s2=None): + """Apply and to sequences of string-valued tokens. + Corresponds to more complex functions used with models like XLNet and BERT. + """ + assert not s2, "apply_standard_boundary_tokens only supports single sequences" + return [SOS_TOK] + s1 + [EOS_TOK] + + def check_for_previous_checkpoints(serialization_dir, tasks, phase, load_model): """ Check if there are previous checkpoints. @@ -154,10 +180,12 @@ def parse_json_diff(diff): actual value of the replaced or inserted item, whereas for jsondiff.delete, we do not want to show deletions in our parameters. For example, for jsondiff.replace, the output of jsondiff may be the below: - {'mrpc': {replace: ConfigTree([('classifier_dropout', 0.1), ('classifier_hid_dim', 256), ('max_vals', 8), ('val_interval', 1)])}} + {'mrpc': {replace: ConfigTree([('classifier_dropout', 0.1), ('classifier_hid_dim', 256), + ('max_vals', 8), ('val_interval', 1)])}} since 'mrpc' was overriden in demo.conf. Thus, we only want to show the update and delete the replace. The output of this function will be: - {'mrpc': ConfigTree([('classifier_dropout', 0.1), ('classifier_hid_dim', 256), ('max_vals', 8), ('val_interval', 1)])} + {'mrpc': ConfigTree([('classifier_dropout', 0.1), ('classifier_hid_dim', 256), + ('max_vals', 8), ('val_interval', 1)])} See for more information on jsondiff. """ new_diff = {} diff --git a/main.py b/main.py index 4316a395e..d13c19ce0 100644 --- a/main.py +++ b/main.py @@ -26,7 +26,7 @@ from jiant.preprocess import build_tasks from jiant import tasks as task_modules from jiant.trainer import build_trainer -from jiant.utils import config +from jiant.utils import config, tokenizers from jiant.utils.utils import ( assert_for_log, load_model_state, @@ -35,6 +35,7 @@ sort_param_recursive, select_relevant_print_args, check_for_previous_checkpoints, + select_pool_type, delete_all_checkpoints, ) @@ -189,9 +190,13 @@ def check_configurations(args, pretrain_tasks, target_tasks): args.load_model or args.load_target_train_checkpoint not in ["none", ""] or args.allow_untrained_encoder_parameters - ), "Evaluating a model without training it on this run or loading a checkpoint. Set `allow_untrained_encoder_parameters` if you really want to use an untrained task model." + ), "Evaluating a model without training it on this run or loading a checkpoint. " + "Set `allow_untrained_encoder_parameters` if you really want to use an untrained " + "task model." log.warning( - "Evauluating a target task model without training it in this run. It's up to you to ensure that you are loading parameters that were sufficiently trained for this task." + "Evauluating a target task model without training it in this run. It's up to " + "you to ensure that you are loading parameters that were sufficiently trained " + "for this task." ) steps_log.write("Evaluating model on tasks: %s \n" % args.target_tasks) @@ -365,6 +370,11 @@ def initial_setup(args, cl_args): ) args.cuda = -1 + if args.tokenizer == "auto": + args.tokenizer = tokenizers.select_tokenizer(args) + if args.pool_type == "auto": + args.pool_type = select_pool_type(args) + return args, seed @@ -386,13 +396,14 @@ def check_arg_name(args): for task in task_modules.ALL_GLUE_TASKS + task_modules.ALL_SUPERGLUE_TASKS: assert_for_log( not args.regex_contains("^{}_".format(task)), - "Error: Attempting to load old task-specific args for task %s, please refer to the master branch's default configs for the most recent task specific argument structures." - % task, + "Error: Attempting to load old task-specific args for task %s, please refer to the " + "master branch's default configs for the most recent task specific argument " + "structures." % task, ) for old_name, new_name in name_dict.items(): assert_for_log( old_name not in args, - "Error: Attempting to load old arg name [%s], please update to new name [%s]" + "Error: Attempting to load old arg name %s, please update to new name %s." % (old_name, name_dict[old_name]), ) old_input_module_vals = [ @@ -405,7 +416,8 @@ def check_arg_name(args): for input_type in old_input_module_vals: assert_for_log( input_type not in args, - "Error: Attempting to load old arg name [%s], please use input_module config parameter and refer to master branch's default configs for current way to specify [%s]" + "Error: Attempting to load old arg name %s, please use input_module config " + "parameter and refer to master branch's default configs for current way to specify %s." % (input_type, input_type), ) diff --git a/scripts/edgeprobing/exp_fns.sh b/scripts/edgeprobing/exp_fns.sh index 763b9d74a..82113a787 100644 --- a/scripts/edgeprobing/exp_fns.sh +++ b/scripts/edgeprobing/exp_fns.sh @@ -119,7 +119,7 @@ function openai_cat_exp() { # Usage: openai_cat_exp OVERRIDES="exp_name=openai-cat-$1, run_name=run" OVERRIDES+=", target_tasks=$1" - OVERRIDES+=", openai_embeddings_mode=cat" + OVERRIDES+=", openai_output_mode=cat" run_exp "config/edgeprobe/edgeprobe_openai.conf" "${OVERRIDES}" } @@ -128,7 +128,7 @@ function openai_lex_exp() { # Usage: openai_lex_exp OVERRIDES="exp_name=openai-lex-$1, run_name=run" OVERRIDES+=", target_tasks=$1" - OVERRIDES+=", openai_embeddings_mode=only" + OVERRIDES+=", openai_output_mode=only" run_exp "config/edgeprobe/edgeprobe_openai.conf" "${OVERRIDES}" } @@ -137,7 +137,7 @@ function openai_mix_exp() { # Usage: openai_mix_exp OVERRIDES="exp_name=openai-mix-$1, run_name=run" OVERRIDES+=", target_tasks=$1" - OVERRIDES+=", openai_embeddings_mode=mix" + OVERRIDES+=", openai_output_mode=mix" run_exp "config/edgeprobe/edgeprobe_openai.conf" "${OVERRIDES}" } @@ -148,7 +148,7 @@ function openai_bwb_exp() { OVERRIDES="exp_name=openai-bwb-$1, run_name=run" OVERRIDES+=", target_tasks=$1" OVERRIDES+=", openai_transformer_ckpt=${CKPT_PATH}" - OVERRIDES+=", openai_embeddings_mode=cat" + OVERRIDES+=", openai_output_mode=cat" run_exp "config/edgeprobe/edgeprobe_openai.conf" "${OVERRIDES}" } @@ -162,7 +162,7 @@ function bert_cat_exp() { OVERRIDES="exp_name=bert-${2}-cat-${1}, run_name=run" OVERRIDES+=", target_tasks=$1" OVERRIDES+=", input_module=bert-$2" - OVERRIDES+=", bert_embeddings_mode=cat" + OVERRIDES+=", pytorch_transformers_output_mode=cat" run_exp "config/edgeprobe/edgeprobe_bert.conf" "${OVERRIDES}" } @@ -172,7 +172,7 @@ function bert_lex_exp() { OVERRIDES="exp_name=bert-${2}-lex-${1}, run_name=run" OVERRIDES+=", target_tasks=$1" OVERRIDES+=", input_module=bert-$2" - OVERRIDES+=", bert_embeddings_mode=only" + OVERRIDES+=", pytorch_transformers_output_mode=only" run_exp "config/edgeprobe/edgeprobe_bert.conf" "${OVERRIDES}" } @@ -182,7 +182,7 @@ function bert_mix_exp() { OVERRIDES="exp_name=bert-${2}-mix-${1}, run_name=run" OVERRIDES+=", target_tasks=$1" OVERRIDES+=", input_module=bert-$2" - OVERRIDES+=", bert_embeddings_mode=mix" + OVERRIDES+=", pytorch_transformers_output_mode=mix" run_exp "config/edgeprobe/edgeprobe_bert.conf" "${OVERRIDES}" } @@ -194,8 +194,8 @@ function bert_mix_k_exp() { OVERRIDES="exp_name=bert-${2}-mix_${3}-${1}, run_name=run" OVERRIDES+=", target_tasks=$1" OVERRIDES+=", input_module=bert-$2" - OVERRIDES+=", bert_embeddings_mode=mix" - OVERRIDES+=", bert_max_layer=${3}" + OVERRIDES+=", pytorch_transformers_output_mode=mix" + OVERRIDES+=", pytorch_transformers_max_layer=${3}" run_exp "config/edgeprobe/edgeprobe_bert.conf" "${OVERRIDES}" } @@ -205,7 +205,7 @@ function bert_at_k_exp() { OVERRIDES="exp_name=bert-${2}-at_${3}-${1}, run_name=run" OVERRIDES+=", target_tasks=$1" OVERRIDES+=", input_module=bert-$2" - OVERRIDES+=", bert_embeddings_mode=top" - OVERRIDES+=", bert_max_layer=${3}" + OVERRIDES+=", pytorch_transformers_output_mode=top" + OVERRIDES+=", pytorch_transformers_max_layer=${3}" run_exp "config/edgeprobe/edgeprobe_bert.conf" "${OVERRIDES}" } diff --git a/tests/test_checkpointing.py b/tests/test_checkpointing.py index f0c033cbd..7e8faa1f6 100644 --- a/tests/test_checkpointing.py +++ b/tests/test_checkpointing.py @@ -64,7 +64,7 @@ def setUp(self): self.temp_dir = tempfile.mkdtemp() self.path = os.path.join(self.temp_dir, "temp_dataset.tsv") self.wic = tasks.WiCTask(self.temp_dir, 100, "wic", tokenizer_name="MosesTokenizer") - indexers = {"bert_wpm_pretokenized": SingleIdTokenIndexer("bert-xe-cased")} + indexers = {"pytorch_transformers_wpm_pretokenized": SingleIdTokenIndexer("bert-xe-cased")} self.wic.val_data = [ Instance( { diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 3874eff99..815b3175b 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -50,7 +50,7 @@ def test_build_indexers(self): self.params3 = params_from_file(self.DEFAULTS_PATH, self.HOCON3) self.params4 = params_from_file(self.DEFAULTS_PATH, self.HOCON4) indexer = build_indexers(self.params1) - len(indexer) == 1 and list(indexer.keys())[0] == "bert_wpm_pretokenized" + len(indexer) == 1 and list(indexer.keys())[0] == "pytorch_transformers_wpm_pretokenized" indexer = build_indexers(self.params2) len(indexer) == 1 and list(indexer.keys())[0] == "words" indexer = build_indexers(self.params3) diff --git a/tests/test_pytorch_transformers_interface.py b/tests/test_pytorch_transformers_interface.py new file mode 100644 index 000000000..68420eb3d --- /dev/null +++ b/tests/test_pytorch_transformers_interface.py @@ -0,0 +1,74 @@ +import unittest +from unittest import mock +import torch +from jiant.pytorch_transformers_interface.modules import BertEmbedderModule, XLNetEmbedderModule + + +class TestPytorchTransformersInterface(unittest.TestCase): + def test_bert_apply_boundary_tokens(self): + s1 = ["A", "B", "C"] + s2 = ["D", "E"] + self.assertListEqual( + BertEmbedderModule.apply_boundary_tokens(s1), ["[CLS]", "A", "B", "C", "[SEP]"] + ) + self.assertListEqual( + BertEmbedderModule.apply_boundary_tokens(s1, s2), + ["[CLS]", "A", "B", "C", "[SEP]", "D", "E", "[SEP]"], + ) + + def test_xlnet_apply_boundary_tokens(self): + s1 = ["A", "B", "C"] + s2 = ["D", "E"] + self.assertListEqual( + XLNetEmbedderModule.apply_boundary_tokens(s1), ["A", "B", "C", "", ""] + ) + self.assertListEqual( + XLNetEmbedderModule.apply_boundary_tokens(s1, s2), + ["A", "B", "C", "", "D", "E", "", ""], + ) + + def test_bert_seg_ids(self): + bert_model = mock.Mock() + bert_model._sep_id = 3 + bert_model._cls_id = 5 + bert_model._pad_id = 7 + bert_model._SEG_ID_CLS = None + bert_model._SEG_ID_SEP = None + bert_model.get_seg_ids = BertEmbedderModule.get_seg_ids + + # [CLS] 8 [SEP] 9 10 11 [SEP] + # [CLS] 8 9 [SEP] 10 [SEP] [PAD] + inp = torch.Tensor([[5, 8, 3, 9, 10, 11, 3], [5, 8, 9, 3, 10, 3, 7]]) + output = bert_model.get_seg_ids(bert_model, inp) + assert torch.all( + torch.eq(output, torch.Tensor([[0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 0, 1, 1, 0]])) + ) + + # [CLS] 8 9 [SEP] + # [CLS] 8 [SEP] [PAD] + inp = torch.Tensor([[5, 8, 9, 3], [5, 9, 3, 7]]) + output = bert_model.get_seg_ids(bert_model, inp) + assert torch.all(torch.eq(output, torch.Tensor([[0, 0, 0, 0], [0, 0, 0, 0]]))) + + def test_xlnet_seg_ids(self): + xlnet_model = mock.Mock() + xlnet_model._sep_id = 3 + xlnet_model._cls_id = 5 + xlnet_model._pad_id = 7 + xlnet_model._SEG_ID_CLS = 2 + xlnet_model._SEG_ID_SEP = 3 + xlnet_model.get_seg_ids = XLNetEmbedderModule.get_seg_ids + + # 8 [SEP] 9 10 11 [SEP] [CLS] + # 8 [SEP] 9 10 [SEP] [CLS] [PAD] + inp = torch.Tensor([[8, 3, 9, 10, 11, 3, 5], [8, 3, 9, 10, 3, 5, 7]]) + output = xlnet_model.get_seg_ids(xlnet_model, inp) + assert torch.all( + torch.eq(output, torch.Tensor([[0, 3, 1, 1, 1, 3, 2], [0, 3, 1, 1, 3, 2, 0]])) + ) + + # 8 9 10 [SEP] [CLS] + # 8 9 [SEP] [CLS] [PAD] + inp = torch.Tensor([[8, 9, 10, 3, 5], [8, 9, 3, 5, 7]]) + output = xlnet_model.get_seg_ids(xlnet_model, inp) + assert torch.all(torch.eq(output, torch.Tensor([[0, 0, 0, 3, 2], [0, 0, 3, 2, 0]]))) diff --git a/tests/test_write_preds.py b/tests/test_write_preds.py index 18932117d..82bcfcb6a 100644 --- a/tests/test_write_preds.py +++ b/tests/test_write_preds.py @@ -89,7 +89,7 @@ def setUp(self): }, ] ) - indexers = {"bert_wpm_pretokenized": SingleIdTokenIndexer("bert-xe-cased")} + indexers = {"pytorch_transformers_wpm_pretokenized": SingleIdTokenIndexer("bert-xe-cased")} self.wic.val_data = [ Instance( { diff --git a/tutorials/adding_tasks.md b/tutorials/adding_tasks.md index b6b32981e..8c3bd5bc0 100644 --- a/tutorials/adding_tasks.md +++ b/tutorials/adding_tasks.md @@ -41,7 +41,7 @@ A lot of the following functions may already be written for your task type (espe -1. The `load_data` (inheritable) function is for loading your data. This function loads your TSV to a format that can be made into AllenNLP iterators. In this function, you will want to call `load_tsv` from `jiant/utils/data_loaders.py`, which loads and tokenizes the data. Currently, only English tokenization is supported. You can specify the fields to load from as parameters to `load_tsv` (which right now is based on number-based indexing). See [here](https://github.com/jsalt18-sentence-repl/jiant/blob/master/jiant/utils/data_loaders.py) for more documentation on `load_tsv`. An example is below +1. The `load_data` (inheritable) function is for loading your data. This function loads your TSV/JSONL/... to a format that can be made into AllenNLP iterators. In this function, you will want to call `load_tsv` from `jiant/utils/data_loaders.py`, which loads and tokenizes the data. Currently, only English tokenization is supported. You can specify the fields to load from as parameters to `load_tsv` (which right now is based on number-based indexing). See [here](https://github.com/jsalt18-sentence-repl/jiant/blob/master/jiant/utils/data_loaders.py) for more documentation on `load_tsv`. An example is below ```python def load_data(self, path, max_seq_len): @@ -70,9 +70,9 @@ A lot of the following functions may already be written for your task type (espe d['sent2_str'] = MetadataField(" ".join(input2[1:-1])) return Instance(d) ``` -3. `update_metrics` (inheritable) function is a function to update scorers, which are configerable scorers (mostly from AllenNLP) such as F1Measure or BooleanAccuracy that keeps track of task-specific scores. Let us say that we want to only update F1 and ignore accuracy. In that case, you can set self.scorers = [self.f1_scorer], and this will automatically set the inherited update_metrics function to only update the F1 scorer. +3. `update_metrics` (inheritable) is a function to update scorers, which are configerable scorers (mostly from AllenNLP) such as F1Measure or BooleanAccuracy that keeps track of task-specific scores. Let us say that we want to only update F1 and ignore accuracy. In that case, you can set self.scorers = [self.f1_scorer], and this will automatically set the inherited update_metrics function to only update the F1 scorer. -4. `get_metrics` is a function (inheritable) that returns the metrics from the updated scorers in dictionary form. Since we're only getting F1, we should set the get_metrics function to be: +4. `get_metrics` (inheritable) is a function that returns the metrics from the updated scorers in dictionary form. Since we're only getting F1, we should set the get_metrics function to be: ```python def get_metrics(self, reset=False): '''Get metrics specific to the task''' @@ -83,13 +83,13 @@ A lot of the following functions may already be written for your task type (espe ```python self.sentences = self.train_data_text[0] + self.val_data_text[0] ``` -6. `process_split` (inheritable) takes in a split of your data and produces an iterable of AllenNLP Instances. An Instance is a wrapper around a dictionary of (field_name, Field) pairs. Fields are objects to help with data processing (indexing, padding, etc.). See [here](https://github.com/nyu-mll/jiant/blob/1ee392e8dfba4a5fc5d0aca06a9f780a6a1b1e1e/jiant/tasks/tasks.py#L338) for an example. +6. `process_split` (inheritable) takes in a split of your data and produces an iterable of AllenNLP Instances. An Instance is a wrapper around a dictionary of (field_name, Field) pairs. Fields are objects to help with data processing (indexing, padding, etc.). This is handled for us here, since we inherit from `PairClassificationTask`, but if you're writing a task that inherits directly from `Task`, you should look to `PairClassificationTask` for an example of how to implement this method yourself. 7. `count_examples` (inheritable) sets `task.example_counts` (Dict[str:int]): the number of examples per split (train, val, test). 8. `val_metric` (inheritable) is a string variable that is the name of task-specific metric to track during training, e.g. F1 score. -9. `val_metric_decreases`(inheritable) is a boolean for whether or not the objective function should be minimized or maximized. The default is set to False. +9. `val_metric_decreases` (inheritable) is a boolean for whether or not the objective function should be minimized or maximized. The default is set to False. Your finished task class may look something like this: @@ -133,7 +133,7 @@ class SomeDataClassificationTask(PairClassificationTask): return {'f1': f1} ``` -Phew! Now, you also have to add the models you're going to use for your task, which lives in [`jiant/models/py`](https://github.com/zphang/jiant/blob/repretrain/jiant/models.py). +Phew! Now, you also have to add the models you're going to use for your task, which lives in [`jiant/models/py`](https://github.com/nyu-mll/jiant/blob/master/jiant/models.py). Since our task type is PairClassificationTask, a well supported type of task, we can skip this step. However, if your task type is not well supported (or you want to try a different sort of model), in `jiant/models/models.py`, you will need to change the `build_task_specific_module` function to include a branch for your logic. ```python @@ -168,5 +168,8 @@ Of course, don't forget to define your task-specific module building function! ``` Finally, all you have to do is add the task to either the `pretrain_tasks` or `target_tasks` parameter in the config file, and viola! Your task is added. -If you have any additions or suggested changes to this tutorial, please open an issue on [GitHub](https://github.com/nyu-mll/jiant)! +# Notes +## `boundary_token_fn` + +This method applies boundary tokens (like SOS/EOS) to the edges of your text. It also, for BERT and XLNet, applies tokens like [SEP] that delimit the two halves of a two-part input sequence. So, if you'd like to feeding a two-part input into a BERT/XLNet model as a single sequence, create two token sequences, and feed them to `boundary_token_fn` as two arguments. diff --git a/tutorials/setup_tutorial.md b/tutorials/setup_tutorial.md index 27bc34cae..caf7c9667 100644 --- a/tutorials/setup_tutorial.md +++ b/tutorials/setup_tutorial.md @@ -61,21 +61,21 @@ And the next time you start a notebook server, you should see `jiant` as an opti ### Optional -If you'll be using GPT, BERT, or other models supplied by `pytorch-pretrained-BERT`, then you may see speed gains from installing NVIDIA apex, following the instructions here: +If you'll be using GPT, BERT, or other models supplied by `pytorch-transformers`, then you may see speed gains from installing NVIDIA apex, following the instructions here: https://github.com/NVIDIA/apex#linux ## 2. Getting data and setting up our environment In this tutorial, we will be working with GLUE data. -The repo contains a convenience Python script for downloading all [GLUE](https://gluebenchmark.com/tasks) data: - +The repo contains a convenience Python script for downloading all [GLUE](https://gluebenchmark.com/) and [SuperGLUE](https://super.gluebenchmark.com/) tasks: ``` python scripts/download_glue_data.py --data_dir data --tasks all +python scripts/download_superglue_data.py --data_dir data --tasks all ``` -We also support quite a few other data sources (check [here](https://jiant.info/documentation#/?id=data-sources) for a list). +We also support quite a few other data sources (check [here](https://jiant.info/documentation#/?id=data-sources) for a list). Finally, you'll need to set a few environment variables in [user_config_template.sh](https://github.com/nyu-mll/jiant/blob/master/user_config_template.sh), which include: @@ -131,7 +131,7 @@ Some important options include: * `sent_enc`: If you want to train a new sentence encoder (rather than using a loaded one like BERT), specify it here. This is the only part of the `config/demo.conf` that we should change for our experiment since we want to train a biLSTM encoder. Thus, in your `config/tutorial.conf`, set `sent_enc=rnn`. * `pretrain_tasks`: This is a comma-delimited string of tasks. In `config/demo.conf`, this is set to "sst,mrpc", which is what we want. Note that we have `pretrain_tasks` as a separate field from `target_tasks` because our training loop handles the two phases differently (for example, multitask training is only supported in pretraining stage). Note that there should not be a space in-between tasks. * `target_tasks`: This is a comma-delimited string of tasks you want to fine-tune and evaluate on (in this case "sts-b,wnli"). -* `input_module`: This is a string specifying the type of contextualized word embedding you want to use. In `config/demo.conf`, this is already set to `scratch`. +* `input_module`: This is a string specifying the type of (contextualized) word embedding you want to use. In `config/demo.conf`, this is already set to `scratch`. * `val_interval`: This is the interval (in steps) at which you want to evaluate your model on the validation set during pretraining. A step is a batch update. * `exp_name`, which expects a string of your experiment name. * `run_name`, which expects a string of your run name. @@ -168,11 +168,12 @@ reload_indexing = 0 reload_vocab = 0 pretrain_tasks = "sst,mrpc" -target_tasks = "sts-b,wnli" +target_tasks = "sts-b,commitbank" classifier = mlp classifier_hid_dim = 32 max_seq_len = 10 max_word_v_size = 1000 +pair_attn = 0 input_module = scratch d_word = 50 @@ -194,7 +195,6 @@ sts-b += { max_vals = 16 val_interval = 10 } - ``` Now we get on to the actual experiment running!