From 4a9b058f4833587e1ebb54cd56726eb559c575ec Mon Sep 17 00:00:00 2001 From: Haokun Liu Date: Tue, 28 Jan 2020 09:51:31 -0500 Subject: [PATCH] Update to transformers 2.3.0 & Add ALBERT (#990) * fix roberta tokenization error * update transformers * update alignment func * trim input_module * update lm head * update albert special tokens * input_module_to_pretokenized -> transformer_input_module_to_tokenizer_id * update ccg alignment * fix wic retokenize * update wic docstring, remove unnecessary condition * refactor record task to avoid tokenization problem Co-authored-by: Sam Bowman --- README.md | 2 +- environment.yml | 7 +- gcp/config/jiant_paths.sh | 2 +- gcp/kubernetes/templates/jiant_env.libsonnet | 2 +- gcp/kubernetes/templates/run_batch.jsonnet | 4 +- gcp/set_up_workstation.sh | 4 +- jiant/config/defaults.conf | 55 +++--- jiant/config/examples/stilts_example.conf | 2 +- jiant/config/superglue_bert.conf | 2 +- .../__init__.py | 56 ++++++ .../modules.py | 172 ++++++++++++------ jiant/models.py | 36 ++-- jiant/modules/sentence_encoder.py | 3 +- jiant/preprocess.py | 45 +++-- .../__init__.py | 65 ------- jiant/tasks/qa.py | 63 ++----- jiant/tasks/tasks.py | 47 +++-- jiant/utils/retokenize.py | 37 +--- jiant/utils/tokenizers.py | 9 +- jiant/utils/utils.py | 1 + scripts/ccg/align_tags_to_bert.py | 30 ++- scripts/demo.with_docker.sh | 2 +- scripts/edgeprobing/exp_fns.sh | 14 +- setup.py | 4 +- ...est_huggingface_transformers_interface.py} | 20 +- tests/test_retokenize.py | 88 ++++----- tutorials/setup_tutorial.md | 2 +- 27 files changed, 395 insertions(+), 379 deletions(-) create mode 100644 jiant/huggingface_transformers_interface/__init__.py rename jiant/{pytorch_transformers_interface => huggingface_transformers_interface}/modules.py (79%) delete mode 100644 jiant/pytorch_transformers_interface/__init__.py rename tests/{test_pytorch_transformers_interface.py => test_huggingface_transformers_interface.py} (90%) diff --git a/README.md b/README.md index 8199a6f42..b231f5d38 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ A few things you might want to know about `jiant`: - `jiant` is configuration-driven. You can run an enormous variety of experiments by simply writing configuration files. Of course, if you need to add any major new features, you can also easily edit or extend the code. - `jiant` contains implementations of strong baselines for the [GLUE](https://gluebenchmark.com) and [SuperGLUE](https://super.gluebenchmark.com/) benchmarks, and it's the recommended starting point for work on these benchmarks. - `jiant` was developed at [the 2018 JSALT Workshop](https://www.clsp.jhu.edu/workshops/18-workshop/) by [the General-Purpose Sentence Representation Learning](https://jsalt18-sentence-repl.github.io/) team and is maintained by [the NYU Machine Learning for Language Lab](https://wp.nyu.edu/ml2/people/), with help from [many outside collaborators](https://github.com/nyu-mll/jiant/graphs/contributors) (especially Google AI Language's [Ian Tenney](https://ai.google/research/people/IanTenney)). -- `jiant` is built on [PyTorch](https://pytorch.org). It also uses many components from [AllenNLP](https://github.com/allenai/allennlp) and the HuggingFace PyTorch [implementations](https://github.com/huggingface/pytorch-transformers) of GPT, BERT, and XLNet. +- `jiant` is built on [PyTorch](https://pytorch.org). It also uses many components from [AllenNLP](https://github.com/allenai/allennlp) and the HuggingFace Transformers [implementations](https://github.com/huggingface/transformers) for GPT, BERT and other transformer models. - The name `jiant` doesn't mean much. The 'j' stands for JSALT. That's all the acronym we have. ## Getting Started diff --git a/environment.yml b/environment.yml index 3b4c2ca70..c6188e910 100644 --- a/environment.yml +++ b/environment.yml @@ -30,7 +30,7 @@ dependencies: # for --remote_log functionality - google-cloud-logging==1.11.0 - # for some tokenizers in pytorch-transformers + # for some tokenizers in huggingface transformers - spacy==2.1 - ftfy @@ -39,9 +39,8 @@ dependencies: - sacremoses # Warning: jiant currently depends on *both* pytorch_pretrained_bert > 0.6 _and_ - # pytorch_transformers > 1.0. These are the same package, though the name changed between + # transformers > 2.3.0. These are the same package, though the name changed between # these two versions. AllenNLP requires 0.6 to support the BertAdam optimizer, and jiant # directly requires 1.0 to support XLNet and WWM-BERT. # This AllenNLP issue is relevant: https://github.com/allenai/allennlp/issues/3067 - - sacremoses - - pytorch-transformers==1.2.0 + - transformers==2.3.0 diff --git a/gcp/config/jiant_paths.sh b/gcp/config/jiant_paths.sh index 202c78855..6924a2a0d 100644 --- a/gcp/config/jiant_paths.sh +++ b/gcp/config/jiant_paths.sh @@ -13,7 +13,7 @@ export JIANT_PROJECT_PREFIX="$HOME/exp" # pre-downloaded ELMo models export ELMO_SRC_DIR="/nfs/jiant/share/elmo" # cache for BERT etc. models -export PYTORCH_PRETRAINED_BERT_CACHE="/nfs/jiant/share/pytorch_transformers_cache" +export HUGGINGFACE_TRANSFORMERS_CACHE="/nfs/jiant/share/transformers_cache" # word embeddings export WORD_EMBS_FILE="/nfs/jiant/share/wiki-news-300d-1M.vec" diff --git a/gcp/kubernetes/templates/jiant_env.libsonnet b/gcp/kubernetes/templates/jiant_env.libsonnet index 5da1abbf5..f8388317f 100644 --- a/gcp/kubernetes/templates/jiant_env.libsonnet +++ b/gcp/kubernetes/templates/jiant_env.libsonnet @@ -21,7 +21,7 @@ # Path to ELMO cache. elmo_src_dir: "/nfs/jiant/share/elmo", # Path to BERT etc. model cache; should be writable by Kubernetes workers. - pytorch_transformers_cache_path: "/nfs/jiant/share/pytorch_transformers_cache", + transformers_cache_path: "/nfs/jiant/share/transformers_cache", # Path to default word embeddings file word_embs_file: "/nfs/jiant/share/wiki-news-300d-1M.vec", } diff --git a/gcp/kubernetes/templates/run_batch.jsonnet b/gcp/kubernetes/templates/run_batch.jsonnet index e32b6e399..d43daa04d 100644 --- a/gcp/kubernetes/templates/run_batch.jsonnet +++ b/gcp/kubernetes/templates/run_batch.jsonnet @@ -35,8 +35,8 @@ function(job_name, command, project_dir, uid, fsgroup, value: jiant_env.jiant_data_dir, }, { - name: "PYTORCH_PRETRAINED_BERT_CACHE", - value: jiant_env.pytorch_transformers_cache_path + name: "HUGGINGFACE_TRANSFORMERS_CACHE", + value: jiant_env.transformers_cache_path }, { name: "ELMO_SRC_DIR", diff --git a/gcp/set_up_workstation.sh b/gcp/set_up_workstation.sh index d8c82d4b5..d730d8bda 100755 --- a/gcp/set_up_workstation.sh +++ b/gcp/set_up_workstation.sh @@ -26,8 +26,8 @@ source /etc/profile.d/jiant_paths.sh if [ ! -d "${JIANT_PROJECT_PREFIX}" ]; then mkdir "${JIANT_PROJECT_PREFIX}" fi -if [ ! -d "${PYTORCH_PRETRAINED_BERT_CACHE}" ]; then - sudo mkdir -m 0777 "${PYTORCH_PRETRAINED_BERT_CACHE}" +if [ ! -d "${HUGGINGFACE_TRANSFORMERS_CACHE}" ]; then + sudo mkdir -m 0777 "${HUGGINGFACE_TRANSFORMERS_CACHE}" fi # Build the conda environment, and activate diff --git a/jiant/config/defaults.conf b/jiant/config/defaults.conf index 4dcf37f5f..1b6513ba0 100644 --- a/jiant/config/defaults.conf +++ b/jiant/config/defaults.conf @@ -244,20 +244,23 @@ input_module = "" // The word embedding or contextual word representation layer // - elmo-chars-only: The dynamic CNN-based word embedding layer of AllenNLP's // ELMo, but not ELMo's LSTM layer hidden states. Use with // tokenizer = MosesTokenizer. - // - bert-base-uncased, etc.: Any BERT model from pytorch_transformers. + // - bert-base-uncased, etc.: Any BERT model from transformers. // - roberta-base / roberta-large / roberta-large-mnli: RoBERTa model from - // pytorch_transformers. + // transformers. + // - albert-base-v1 / albert-large-v1 / albert-xlarge-v1 / albert-xxlarge-v1 + // - albert-base-v2 / albert-large-v2 / albert-xlarge-v2 / albert-xxlarge-v2: + // ALBERT model from transformers. // - xlnet-base-cased / xlnet-large-cased: XLNet Model from - // pytorch_transformers. + // transformers. // - openai-gpt: The OpenAI GPT language model encoder from - // pytorch_transformers. - // - gpt2 / gpt2-medium / gpt2-large: The OpenAI GPT-2 language model encoder from - // pytorch_transformers. + // transformers. + // - gpt2 / gpt2-medium / gpt2-large/ gpt2-xl: The OpenAI GPT-2 language model + // encoder from transformers. // - transfo-xl-wt103: The Transformer-XL language model encoder from - // pytorch_transformers. + // transformers. // - xlm-mlm-en-2048: XLM english language model encoder from - // pytorch_transformers. - // Note: Any input_module from pytorch_transformers requires + // transformers. + // Note: Any input_module from transformers requires // tokenizer = ${input_module} or auto. tokenizer = auto // The name of the tokenizer, passed to the Task constructor for @@ -269,7 +272,7 @@ tokenizer = auto // The name of the tokenizer, passed to the Task constructor f // - MosesTokenizer: Our standard word tokenizer. (Support for // other NLTK tokenizers is pending.) // - bert-uncased-base, etc.: Use the tokenizer supplied with - // pytorch_transformers that corresponds the input_module. + // transformers that corresponds the input_module. // - SplitChars: Splits the input into individual characters. word_embs_file = ${WORD_EMBS_FILE} // Path to embeddings file, used with glove and fastText. @@ -284,21 +287,21 @@ d_char = 100 // Dimension of trained char embeddings. n_char_filters = 100 // Number of filters in trained char CNN. char_filter_sizes = "2,3,4,5" // Size of char CNN filters. -pytorch_transformers_output_mode = "none" // How to handle the embedding layer of the - // BERT/XLNet model: - // "none" or "top" returns only top-layer activation, - // "cat" returns top-layer concatenated with - // lexical layer, - // "only" returns only lexical layer, - // "mix" uses ELMo-style scalar mixing (with learned - // weights) across all layers. -pytorch_transformers_max_layer = -1 // Maximum layer to return from BERT etc. encoder. Layer 0 is - // wordpiece embeddings. pytorch_transformers_embeddings_mode - // will behave as if the is truncated at this layer, so 'top' - // will return this layer, and 'mix' will return a mix of all - // layers up to and including this layer. - // Set to -1 to use all layers. - // Used for probing experiments. +transformers_output_mode = "none" // How to handle the embedding layer of the + // BERT/XLNet model: + // "none" or "top" returns only top-layer activation, + // "cat" returns top-layer concatenated with + // lexical layer, + // "only" returns only lexical layer, + // "mix" uses ELMo-style scalar mixing (with learned + // weights) across all layers. +transformers_max_layer = -1 // Maximum layer to return from BERT etc. encoder. Layer 0 is + // wordpiece embeddings. transformers_embeddings_mode + // will behave as if the is truncated at this layer, so 'top' + // will return this layer, and 'mix' will return a mix of all + // layers up to and including this layer. + // Set to -1 to use all layers. + // Used for probing experiments. force_include_wsj_vocabulary = 0 // Set if using PTB parsing (grammar induction) task. Makes sure // to include WSJ vocabulary. @@ -365,7 +368,7 @@ pair_attn = 1 // If true, use attn in sentence-pair classification/regression t d_hid_attn = 512 // Post-attention LSTM state size. shared_pair_attn = 0 // If true, share pair_attn parameters across all tasks that use it. d_proj = 512 // Size of task-specific linear projection applied before before pooling. - // Disabled when fine-tuning pytorch_transformers models. + // Disabled when fine-tuning transformers models. pool_type = "auto" // Type of pooling to reduce sequences of vectors into a single vector. // Options: "auto", "max", "mean", "first", "final" // "auto" uses "first" for plain BERT (with no sent_enc), "final" for plain diff --git a/jiant/config/examples/stilts_example.conf b/jiant/config/examples/stilts_example.conf index dd2e6115e..e358f4f9b 100644 --- a/jiant/config/examples/stilts_example.conf +++ b/jiant/config/examples/stilts_example.conf @@ -18,7 +18,7 @@ batch_size = 24 write_preds = "val,test" //BERT-specific parameters -pytorch_transformers_output_mode = "top" +transformers_output_mode = "top" sep_embs_for_skip = 1 sent_enc = "none" classifier = log_reg // following BERT paper diff --git a/jiant/config/superglue_bert.conf b/jiant/config/superglue_bert.conf index ffbad2ace..0e8a6d553 100644 --- a/jiant/config/superglue_bert.conf +++ b/jiant/config/superglue_bert.conf @@ -10,7 +10,7 @@ max_seq_len = 256 // Mainly needed for MultiRC, to avoid over-truncating // Model settings input_module = "bert-large-cased" -pytorch_transformers_output_mode = "top" +transformers_output_mode = "top" pair_attn = 0 // shouldn't be needed but JIC s2s = { attention = none diff --git a/jiant/huggingface_transformers_interface/__init__.py b/jiant/huggingface_transformers_interface/__init__.py new file mode 100644 index 000000000..74f232161 --- /dev/null +++ b/jiant/huggingface_transformers_interface/__init__.py @@ -0,0 +1,56 @@ +""" +Warning: jiant currently depends on *both* pytorch_pretrained_bert > 0.6 _and_ +transformers > 2.3 + +These are the same package, though the name changed between these two versions. AllenNLP requires +0.6 to support the BertAdam optimizer, and jiant directly requires 2.3. + +This AllenNLP issue is relevant: https://github.com/allenai/allennlp/issues/3067 + +TODO: We do not support non-English versions of XLM, if you need them, add some code in XLMEmbedderModule +to prepare langs input to transformers.XLMModel +""" + +# All the supported input_module from huggingface transformers +# input_modules mapped to the same string share vocabulary +transformer_input_module_to_tokenizer_name = { + "bert-base-uncased": "bert_uncased", + "bert-large-uncased": "bert_uncased", + "bert-large-uncased-whole-word-masking": "bert_uncased", + "bert-large-uncased-whole-word-masking-finetuned-squad": "bert_uncased", + "bert-base-cased": "bert_cased", + "bert-large-cased": "bert_cased", + "bert-large-cased-whole-word-masking": "bert_cased", + "bert-large-cased-whole-word-masking-finetuned-squad": "bert_cased", + "bert-base-cased-finetuned-mrpc": "bert_cased", + "bert-base-multilingual-uncased": "bert_multilingual_uncased", + "bert-base-multilingual-cased": "bert_multilingual_cased", + "roberta-base": "roberta", + "roberta-large": "roberta", + "roberta-large-mnli": "roberta", + "xlnet-base-cased": "xlnet_cased", + "xlnet-large-cased": "xlnet_cased", + "openai-gpt": "openai_gpt", + "gpt2": "gpt2", + "gpt2-medium": "gpt2", + "gpt2-large": "gpt2", + "gpt2-xl": "gpt2", + "transfo-xl-wt103": "transfo_xl", + "xlm-mlm-en-2048": "xlm_en", + "albert-base-v1": "albert", + "albert-large-v1": "albert", + "albert-xlarge-v1": "albert", + "albert-xxlarge-v1": "albert", + "albert-base-v2": "albert", + "albert-large-v2": "albert", + "albert-xlarge-v2": "albert", + "albert-xxlarge-v2": "albert", +} + + +def input_module_uses_transformers(input_module): + return input_module in transformer_input_module_to_tokenizer_name + + +def input_module_tokenizer_name(input_module): + return transformer_input_module_to_tokenizer_name[input_module] diff --git a/jiant/pytorch_transformers_interface/modules.py b/jiant/huggingface_transformers_interface/modules.py similarity index 79% rename from jiant/pytorch_transformers_interface/modules.py rename to jiant/huggingface_transformers_interface/modules.py index 65d903602..d8957e0cd 100644 --- a/jiant/pytorch_transformers_interface/modules.py +++ b/jiant/huggingface_transformers_interface/modules.py @@ -7,30 +7,29 @@ import torch.nn as nn from allennlp.modules import scalar_mix -import pytorch_transformers +import transformers from jiant.utils.options import parse_task_list_arg from jiant.utils import utils -from jiant.pytorch_transformers_interface import input_module_tokenizer_name +from jiant.huggingface_transformers_interface import input_module_tokenizer_name -class PytorchTransformersEmbedderModule(nn.Module): - """ Shared code for pytorch_transformers wrappers. +class HuggingfaceTransformersEmbedderModule(nn.Module): + """ Shared code for transformers wrappers. Subclasses share a good deal of code, but have a number of subtle differences due to different - APIs from pytorch_transfromers. + APIs from transfromers. """ def __init__(self, args): - super(PytorchTransformersEmbedderModule, self).__init__() + super(HuggingfaceTransformersEmbedderModule, self).__init__() self.cache_dir = os.getenv( - "PYTORCH_PRETRAINED_BERT_CACHE", - os.path.join(args.exp_dir, "pytorch_transformers_cache"), + "HUGGINGFACE_TRANSFORMERS_CACHE", os.path.join(args.exp_dir, "transformers_cache") ) utils.maybe_make_dir(self.cache_dir) - self.output_mode = args.pytorch_transformers_output_mode + self.output_mode = args.transformers_output_mode self.input_module = args.input_module self.max_pos = None self.tokenizer_required = input_module_tokenizer_name(args.input_module) @@ -51,8 +50,8 @@ def parameter_setup(self, args): param.requires_grad = bool(args.transfer_paradigm == "finetune") self.num_layers = self.model.config.num_hidden_layers - if args.pytorch_transformers_max_layer >= 0: - self.max_layer = args.pytorch_transformers_max_layer + if args.transformers_max_layer >= 0: + self.max_layer = args.transformers_max_layer assert self.max_layer <= self.num_layers else: self.max_layer = self.num_layers @@ -70,7 +69,7 @@ def parameter_setup(self, args): if self.output_mode == "mix": if args.transfer_paradigm == "frozen": log.warning( - "NOTE: pytorch_transformers_output_mode='mix', so scalar " + "NOTE: transformers_output_mode='mix', so scalar " "mixing weights will be fine-tuned even if BERT " "model is frozen." ) @@ -78,7 +77,7 @@ def parameter_setup(self, args): # scalars. See the ELMo implementation here: # https://github.com/allenai/allennlp/blob/master/allennlp/modules/elmo.py#L115 assert len(parse_task_list_arg(args.target_tasks)) <= 1, ( - "pytorch_transformers_output_mode='mix' only supports a single set of " + "transformers_output_mode='mix' only supports a single set of " "scalars (but if you need this feature, see the TODO in " "the code!)" ) @@ -86,7 +85,7 @@ def parameter_setup(self, args): self.scalar_mix = scalar_mix.ScalarMix(self.max_layer + 1, do_layer_norm=False) def correct_sent_indexing(self, sent): - """ Correct id difference between pytorch_transformers and AllenNLP. + """ Correct id difference between transformers and AllenNLP. The AllenNLP indexer adds'@@UNKNOWN@@' token as index 1, and '@@PADDING@@' as index 0 args: @@ -99,14 +98,14 @@ def correct_sent_indexing(self, sent): """ assert ( self.tokenizer_required in sent - ), "pytorch_transformers cannot find correcpondingly tokenized input" + ), "transformers cannot find correcpondingly tokenized input" ids = sent[self.tokenizer_required] input_mask = (ids != 0).long() pad_mask = (ids == 0).long() - # map AllenNLP @@PADDING@@ to _pad_id in specific pytorch_transformer + # map AllenNLP @@PADDING@@ to _pad_id in specific transformer vocab unk_mask = (ids == 1).long() - # map AllenNLP @@UNKNOWN@@ to _unk_id in specific pytorch_transformer + # map AllenNLP @@UNKNOWN@@ to _unk_id in specific transformer vocab valid_mask = (ids > 1).long() # shift ordinary indexes by 2 to match pretrained token embedding indexes if self._unk_id is not None: @@ -115,7 +114,7 @@ def correct_sent_indexing(self, sent): ids = (ids - 2) * valid_mask + self._pad_id * pad_mask assert ( unk_mask == 0 - ).all(), "out-of-vocabulary token found in the input, but _unk_id of pytorch_transformers model is not specified" + ).all(), "out-of-vocabulary token found in the input, but _unk_id of transformers model is not specified" if self.max_pos is not None: assert ( ids.size()[-1] <= self.max_pos @@ -126,7 +125,7 @@ def correct_sent_indexing(self, sent): def prepare_output(self, lex_seq, hidden_states, input_mask): """ - Convert the output of the pytorch_transformers module to a vector sequence as expected by jiant. + Convert the output of the transformers module to a vector sequence as expected by jiant. args: lex_seq: The sequence of input word embeddings as a tensor (batch_size, sequence_length, hidden_size). @@ -226,7 +225,7 @@ def apply_lm_boundary_tokens(s1, get_offset=False): raise NotImplementedError def forward(self, sent, task_name): - """ Run pytorch_transformers model and return output representation + """ Run transformers model and return output representation This function should be implmented in subclasses. args: @@ -252,19 +251,19 @@ def get_pretrained_lm_head(self): raise NotImplementedError -class BertEmbedderModule(PytorchTransformersEmbedderModule): +class BertEmbedderModule(HuggingfaceTransformersEmbedderModule): """ Wrapper for BERT module to fit into jiant APIs. - Check PytorchTransformersEmbedderModule for function definitions """ + Check HuggingfaceTransformersEmbedderModule for function definitions """ def __init__(self, args): super(BertEmbedderModule, self).__init__(args) - self.model = pytorch_transformers.BertModel.from_pretrained( + self.model = transformers.BertModel.from_pretrained( args.input_module, cache_dir=self.cache_dir, output_hidden_states=True ) self.max_pos = self.model.config.max_position_embeddings - self.tokenizer = pytorch_transformers.BertTokenizer.from_pretrained( + self.tokenizer = transformers.BertTokenizer.from_pretrained( args.input_module, cache_dir=self.cache_dir, do_lower_case="uncased" in args.tokenizer ) # TODO: Speed things up slightly by reusing the previously-loaded tokenizer. self._sep_id = self.tokenizer.convert_tokens_to_ids("[SEP]") @@ -301,7 +300,7 @@ def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> tor return self.prepare_output(lex_seq, hidden_states, input_mask) def get_pretrained_lm_head(self): - model_with_lm_head = pytorch_transformers.BertForMaskedLM.from_pretrained( + model_with_lm_head = transformers.BertForMaskedLM.from_pretrained( self.input_module, cache_dir=self.cache_dir ) lm_head = model_with_lm_head.cls @@ -309,19 +308,19 @@ def get_pretrained_lm_head(self): return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1)) -class RobertaEmbedderModule(PytorchTransformersEmbedderModule): +class RobertaEmbedderModule(HuggingfaceTransformersEmbedderModule): """ Wrapper for RoBERTa module to fit into jiant APIs. - Check PytorchTransformersEmbedderModule for function definitions """ + Check HuggingfaceTransformersEmbedderModule for function definitions """ def __init__(self, args): super(RobertaEmbedderModule, self).__init__(args) - self.model = pytorch_transformers.RobertaModel.from_pretrained( + self.model = transformers.RobertaModel.from_pretrained( args.input_module, cache_dir=self.cache_dir, output_hidden_states=True ) self.max_pos = self.model.config.max_position_embeddings - self.tokenizer = pytorch_transformers.RobertaTokenizer.from_pretrained( + self.tokenizer = transformers.RobertaTokenizer.from_pretrained( args.input_module, cache_dir=self.cache_dir ) # TODO: Speed things up slightly by reusing the previously-loaded tokenizer. self._sep_id = self.tokenizer.convert_tokens_to_ids("") @@ -355,7 +354,7 @@ def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> tor return self.prepare_output(lex_seq, hidden_states, input_mask) def get_pretrained_lm_head(self): - model_with_lm_head = pytorch_transformers.RobertaForMaskedLM.from_pretrained( + model_with_lm_head = transformers.RobertaForMaskedLM.from_pretrained( self.input_module, cache_dir=self.cache_dir ) lm_head = model_with_lm_head.lm_head @@ -363,18 +362,75 @@ def get_pretrained_lm_head(self): return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1)) -class XLNetEmbedderModule(PytorchTransformersEmbedderModule): +class AlbertEmbedderModule(HuggingfaceTransformersEmbedderModule): + """ Wrapper for ALBERT module to fit into jiant APIs. + Check HuggingfaceTransformersEmbedderModule for function definitions """ + + def __init__(self, args): + super(AlbertEmbedderModule, self).__init__(args) + + self.model = transformers.AlbertModel.from_pretrained( + args.input_module, cache_dir=self.cache_dir, output_hidden_states=True + ) + self.max_pos = self.model.config.max_position_embeddings + + self.tokenizer = transformers.AlbertTokenizer.from_pretrained( + args.input_module, cache_dir=self.cache_dir + ) # TODO: Speed things up slightly by reusing the previously-loaded tokenizer. + self._sep_id = self.tokenizer.convert_tokens_to_ids("[SEP]") + self._cls_id = self.tokenizer.convert_tokens_to_ids("[CLS]") + self._pad_id = self.tokenizer.convert_tokens_to_ids("") + self._unk_id = self.tokenizer.convert_tokens_to_ids("") + + self.parameter_setup(args) + + @staticmethod + def apply_boundary_tokens(s1, s2=None, get_offset=False): + # ALBERT-style boundary token padding on string token sequences + if s2: + s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"] + if get_offset: + return s, 1, len(s1) + 2 + else: + s = ["[CLS]"] + s1 + ["[SEP]"] + if get_offset: + return s, 1 + return s + + def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor: + ids, input_mask = self.correct_sent_indexing(sent) + hidden_states, lex_seq = [], None + if self.output_mode not in ["none", "top"]: + lex_seq = self.model.embeddings.word_embeddings(ids) + lex_seq = self.model.embeddings.LayerNorm(lex_seq) + if self.output_mode != "only": + token_types = self.get_seg_ids(ids, input_mask) + _, output_pooled_vec, hidden_states = self.model( + ids, token_type_ids=token_types, attention_mask=input_mask + ) + return self.prepare_output(lex_seq, hidden_states, input_mask) + + def get_pretrained_lm_head(self): + model_with_lm_head = transformers.AlbertForMaskedLM.from_pretrained( + self.input_module, cache_dir=self.cache_dir + ) + lm_head = model_with_lm_head.predictions + lm_head.decoder.weight = self.model.embeddings.word_embeddings.weight + return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1)) + + +class XLNetEmbedderModule(HuggingfaceTransformersEmbedderModule): """ Wrapper for XLNet module to fit into jiant APIs. - Check PytorchTransformersEmbedderModule for function definitions """ + Check HuggingfaceTransformersEmbedderModule for function definitions """ def __init__(self, args): super(XLNetEmbedderModule, self).__init__(args) - self.model = pytorch_transformers.XLNetModel.from_pretrained( + self.model = transformers.XLNetModel.from_pretrained( args.input_module, cache_dir=self.cache_dir, output_hidden_states=True ) - self.tokenizer = pytorch_transformers.XLNetTokenizer.from_pretrained( + self.tokenizer = transformers.XLNetTokenizer.from_pretrained( args.input_module, cache_dir=self.cache_dir, do_lower_case="uncased" in args.tokenizer ) # TODO: Speed things up slightly by reusing the previously-loaded tokenizer. self._sep_id = self.tokenizer.convert_tokens_to_ids("") @@ -385,8 +441,8 @@ def __init__(self, args): self.parameter_setup(args) # Segment IDs for CLS and SEP tokens. Unlike in BERT, these aren't part of the usual 0/1 - # input segments. Standard constants reused from pytorch_transformers. They aren't actually - # used within the pytorch_transformers code, so we're reproducing them here in case they're + # input segments. Standard constants reused from transformers. They aren't actually + # used within the transformers code, so we're reproducing them here in case they're # removed in a later cleanup. self._SEG_ID_CLS = 2 self._SEG_ID_SEP = 3 @@ -417,7 +473,7 @@ def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> tor return self.prepare_output(lex_seq, hidden_states, input_mask) def get_pretrained_lm_head(self, args): - model_with_lm_head = pytorch_transformers.XLNetLMHeadModel.from_pretrained( + model_with_lm_head = transformers.XLNetLMHeadModel.from_pretrained( self.input_module, cache_dir=self.cache_dir ) lm_head = model_with_lm_head.lm_loss @@ -425,19 +481,19 @@ def get_pretrained_lm_head(self, args): return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1)) -class OpenAIGPTEmbedderModule(PytorchTransformersEmbedderModule): +class OpenAIGPTEmbedderModule(HuggingfaceTransformersEmbedderModule): """ Wrapper for OpenAI GPT module to fit into jiant APIs. - Check PytorchTransformersEmbedderModule for function definitions """ + Check HuggingfaceTransformersEmbedderModule for function definitions """ def __init__(self, args): super(OpenAIGPTEmbedderModule, self).__init__(args) - self.model = pytorch_transformers.OpenAIGPTModel.from_pretrained( + self.model = transformers.OpenAIGPTModel.from_pretrained( args.input_module, cache_dir=self.cache_dir, output_hidden_states=True ) # TODO: Speed things up slightly by reusing the previously-loaded tokenizer. self.max_pos = self.model.config.n_positions - self.tokenizer = pytorch_transformers.OpenAIGPTTokenizer.from_pretrained( + self.tokenizer = transformers.OpenAIGPTTokenizer.from_pretrained( args.input_module, cache_dir=self.cache_dir ) self._pad_id = self.tokenizer.convert_tokens_to_ids("\n") @@ -480,7 +536,7 @@ def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> tor return self.prepare_output(lex_seq, hidden_states, input_mask) def get_pretrained_lm_head(self, args): - model_with_lm_head = pytorch_transformers.OpenAIGPTLMHeadModel.from_pretrained( + model_with_lm_head = transformers.OpenAIGPTLMHeadModel.from_pretrained( self.input_module, cache_dir=self.cache_dir ) lm_head = model_with_lm_head.lm_head @@ -488,19 +544,19 @@ def get_pretrained_lm_head(self, args): return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1)) -class GPT2EmbedderModule(PytorchTransformersEmbedderModule): +class GPT2EmbedderModule(HuggingfaceTransformersEmbedderModule): """ Wrapper for GPT-2 module to fit into jiant APIs. - Check PytorchTransformersEmbedderModule for function definitions """ + Check HuggingfaceTransformersEmbedderModule for function definitions """ def __init__(self, args): super(GPT2EmbedderModule, self).__init__(args) - self.model = pytorch_transformers.GPT2Model.from_pretrained( + self.model = transformers.GPT2Model.from_pretrained( args.input_module, cache_dir=self.cache_dir, output_hidden_states=True ) # TODO: Speed things up slightly by reusing the previously-loaded tokenizer. self.max_pos = self.model.config.n_positions - self.tokenizer = pytorch_transformers.GPT2Tokenizer.from_pretrained( + self.tokenizer = transformers.GPT2Tokenizer.from_pretrained( args.input_module, cache_dir=self.cache_dir ) self._pad_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>") @@ -542,7 +598,7 @@ def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> tor return self.prepare_output(lex_seq, hidden_states, input_mask) def get_pretrained_lm_head(self): - model_with_lm_head = pytorch_transformers.GPT2LMHeadModel.from_pretrained( + model_with_lm_head = transformers.GPT2LMHeadModel.from_pretrained( self.input_module, cache_dir=self.cache_dir ) lm_head = model_with_lm_head.lm_head @@ -550,18 +606,18 @@ def get_pretrained_lm_head(self): return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1)) -class TransfoXLEmbedderModule(PytorchTransformersEmbedderModule): +class TransfoXLEmbedderModule(HuggingfaceTransformersEmbedderModule): """ Wrapper for Transformer-XL module to fit into jiant APIs. - Check PytorchTransformersEmbedderModule for function definitions """ + Check HuggingfaceTransformersEmbedderModule for function definitions """ def __init__(self, args): super(TransfoXLEmbedderModule, self).__init__(args) - self.model = pytorch_transformers.TransfoXLModel.from_pretrained( + self.model = transformers.TransfoXLModel.from_pretrained( args.input_module, cache_dir=self.cache_dir, output_hidden_states=True ) # TODO: Speed things up slightly by reusing the previously-loaded tokenizer. - self.tokenizer = pytorch_transformers.TransfoXLTokenizer.from_pretrained( + self.tokenizer = transformers.TransfoXLTokenizer.from_pretrained( args.input_module, cache_dir=self.cache_dir ) self._pad_id = self.tokenizer.convert_tokens_to_ids("") @@ -604,8 +660,8 @@ def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> tor return self.prepare_output(lex_seq, hidden_states, input_mask) def get_pretrained_lm_head(self): - # Note: pytorch_transformers didn't implement TransfoXLLMHeadModel, use this in eval only - model_with_lm_head = pytorch_transformers.TransfoXLLMHeadModel.from_pretrained( + # Note: transformers didn't implement TransfoXLLMHeadModel, use this in eval only + model_with_lm_head = transformers.TransfoXLLMHeadModel.from_pretrained( self.input_module, cache_dir=self.cache_dir ) lm_head = model_with_lm_head.crit @@ -620,19 +676,19 @@ def get_pretrained_lm_head(self): return lm_head -class XLMEmbedderModule(PytorchTransformersEmbedderModule): +class XLMEmbedderModule(HuggingfaceTransformersEmbedderModule): """ Wrapper for XLM module to fit into jiant APIs. - Check PytorchTransformersEmbedderModule for function definitions """ + Check HuggingfaceTransformersEmbedderModule for function definitions """ def __init__(self, args): super(XLMEmbedderModule, self).__init__(args) - self.model = pytorch_transformers.XLMModel.from_pretrained( + self.model = transformers.XLMModel.from_pretrained( args.input_module, cache_dir=self.cache_dir, output_hidden_states=True ) # TODO: Speed things up slightly by reusing the previously-loaded tokenizer. self.max_pos = self.model.config.max_position_embeddings - self.tokenizer = pytorch_transformers.XLMTokenizer.from_pretrained( + self.tokenizer = transformers.XLMTokenizer.from_pretrained( args.input_module, cache_dir=self.cache_dir ) self._unk_id = self.tokenizer.convert_tokens_to_ids("") @@ -663,7 +719,7 @@ def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> tor return self.prepare_output(lex_seq, hidden_states, input_mask) def get_pretrained_lm_head(self): - model_with_lm_head = pytorch_transformers.XLMWithLMHeadModel.from_pretrained( + model_with_lm_head = transformers.XLMWithLMHeadModel.from_pretrained( self.input_module, cache_dir=self.cache_dir ) lm_head = model_with_lm_head.pred_layer diff --git a/jiant/models.py b/jiant/models.py index 78355a39a..41295a22d 100644 --- a/jiant/models.py +++ b/jiant/models.py @@ -43,7 +43,7 @@ from jiant.modules.prpn.PRPN import PRPN from jiant.modules.seq2seq_decoder import Seq2SeqDecoder from jiant.modules.span_modules import SpanClassifierModule -from jiant.pytorch_transformers_interface import input_module_uses_pytorch_transformers +from jiant.huggingface_transformers_interface import input_module_uses_transformers from jiant.tasks.edge_probing import EdgeProbingTask from jiant.tasks.lm import LanguageModelingTask from jiant.tasks.lm_parsing import LanguageModelingParsingTask @@ -237,43 +237,49 @@ def build_model(args, vocab, pretrained_embs, tasks, cuda_devices): # Build embeddings. cove_layer = None if args.input_module.startswith("bert-"): - from jiant.pytorch_transformers_interface.modules import BertEmbedderModule + from jiant.huggingface_transformers_interface.modules import BertEmbedderModule log.info(f"Using BERT model ({args.input_module}).") embedder = BertEmbedderModule(args) d_emb = embedder.get_output_dim() elif args.input_module.startswith("roberta-"): - from jiant.pytorch_transformers_interface.modules import RobertaEmbedderModule + from jiant.huggingface_transformers_interface.modules import RobertaEmbedderModule log.info(f"Using RoBERTa model ({args.input_module}).") embedder = RobertaEmbedderModule(args) d_emb = embedder.get_output_dim() + elif args.input_module.startswith("albert-"): + from jiant.huggingface_transformers_interface.modules import AlbertEmbedderModule + + log.info(f"Using ALBERT model ({args.input_module}).") + embedder = AlbertEmbedderModule(args) + d_emb = embedder.get_output_dim() elif args.input_module.startswith("xlnet-"): - from jiant.pytorch_transformers_interface.modules import XLNetEmbedderModule + from jiant.huggingface_transformers_interface.modules import XLNetEmbedderModule log.info(f"Using XLNet model ({args.input_module}).") embedder = XLNetEmbedderModule(args) d_emb = embedder.get_output_dim() elif args.input_module.startswith("openai-gpt"): - from jiant.pytorch_transformers_interface.modules import OpenAIGPTEmbedderModule + from jiant.huggingface_transformers_interface.modules import OpenAIGPTEmbedderModule log.info(f"Using OpenAI GPT model ({args.input_module}).") embedder = OpenAIGPTEmbedderModule(args) d_emb = embedder.get_output_dim() elif args.input_module.startswith("gpt2"): - from jiant.pytorch_transformers_interface.modules import GPT2EmbedderModule + from jiant.huggingface_transformers_interface.modules import GPT2EmbedderModule log.info(f"Using GPT-2 model ({args.input_module}).") embedder = GPT2EmbedderModule(args) d_emb = embedder.get_output_dim() elif args.input_module.startswith("transfo-xl-"): - from jiant.pytorch_transformers_interface.modules import TransfoXLEmbedderModule + from jiant.huggingface_transformers_interface.modules import TransfoXLEmbedderModule log.info(f"Using Transformer-XL model ({args.input_module}).") embedder = TransfoXLEmbedderModule(args) d_emb = embedder.get_output_dim() elif args.input_module.startswith("xlm-"): - from jiant.pytorch_transformers_interface.modules import XLMEmbedderModule + from jiant.huggingface_transformers_interface.modules import XLMEmbedderModule log.info(f"Using XLM model ({args.input_module}).") embedder = XLMEmbedderModule(args) @@ -343,7 +349,7 @@ def build_embeddings(args, vocab, tasks, pretrained_embs=None): d_word = args.d_word word_embs = nn.Embedding(n_token_vocab, d_word).weight else: - assert input_module_uses_pytorch_transformers(args.input_module) or args.input_module in [ + assert input_module_uses_transformers(args.input_module) or args.input_module in [ "elmo", "elmo-chars-only", ], f"'{args.input_module}' is not a valid value for input_module." @@ -544,10 +550,10 @@ def build_task_specific_modules(task, model, d_sent, d_emb, vocab, embedder, arg setattr(model, "%s_hid2voc" % task.name, hid2voc) setattr(model, "%s_mdl" % task.name, hid2voc) elif isinstance(task, LanguageModelingTask): - assert not input_module_uses_pytorch_transformers(args.input_module), ( - "our LM Task does not support pytorch_transformers, if you need them, try to update", + assert not input_module_uses_transformers(args.input_module), ( + "our LM Task does not support transformers, if you need them, try to update", "corresponding parts of the code. You may find get_pretrained_lm_head and", - "apply_lm_boundary_tokens from pytorch_transformer_interface.module useful,", + "apply_lm_boundary_tokens from huggingface_transformers_interface.module useful,", "do check if they are working correctly though.", ) d_sent = args.d_hid + (args.skip_embs * d_emb) @@ -817,7 +823,7 @@ def __init__(self, args, sent_encoder, vocab, cuda_devices): self.uses_pair_embedding = input_module_uses_pair_embedding(args.input_module) self.uses_mirrored_pair = input_module_uses_mirrored_pair(args.input_module) self.project_before_pooling = not ( - input_module_uses_pytorch_transformers(args.input_module) + input_module_uses_transformers(args.input_module) and args.transfer_paradigm == "finetune" ) # Rough heuristic. TODO: Make this directly user-controllable. self.sep_embs_for_skip = args.sep_embs_for_skip @@ -1344,9 +1350,9 @@ def input_module_uses_pair_embedding(input_module): running on pair tasks, like what GPT / BERT do on MNLI. It seems redundant now, but it allows us to load similar models from other sources later on """ - from jiant.pytorch_transformers_interface import input_module_uses_pytorch_transformers + from jiant.huggingface_transformers_interface import input_module_uses_transformers - return input_module_uses_pytorch_transformers(input_module) + return input_module_uses_transformers(input_module) def input_module_uses_mirrored_pair(input_module): diff --git a/jiant/modules/sentence_encoder.py b/jiant/modules/sentence_encoder.py index 1b6ad750c..e8d99d0e7 100644 --- a/jiant/modules/sentence_encoder.py +++ b/jiant/modules/sentence_encoder.py @@ -11,7 +11,6 @@ from allennlp.nn import InitializerApplicator, util from allennlp.modules import Highway, TimeDistributed -from jiant.pytorch_transformers_interface.modules import PytorchTransformersEmbedderModule from jiant.tasks.tasks import PairClassificationTask, PairRegressionTask from jiant.utils import utils from jiant.modules.simple_modules import NullPhraseLayer @@ -84,7 +83,7 @@ def forward(self, sent, task, reset=True): self.reset_states() # General sentence embeddings (for sentence encoder). - # Make sent_mask first, pytorch_transformers text_field_embedder will change the token index + # Make sent_mask first, transformers text_field_embedder will change the token index sent_mask = util.get_text_field_mask(sent).float() # Skip this for probing runs that don't need it. if not isinstance(self._phrase_layer, NullPhraseLayer): diff --git a/jiant/preprocess.py b/jiant/preprocess.py index dcefb98f8..072f35253 100644 --- a/jiant/preprocess.py +++ b/jiant/preprocess.py @@ -28,13 +28,14 @@ TokenCharactersIndexer, ) -from jiant.pytorch_transformers_interface import ( - input_module_uses_pytorch_transformers, +from jiant.huggingface_transformers_interface import ( + input_module_uses_transformers, input_module_tokenizer_name, ) -from pytorch_transformers import ( +from transformers import ( BertTokenizer, RobertaTokenizer, + AlbertTokenizer, XLNetTokenizer, OpenAIGPTTokenizer, GPT2Tokenizer, @@ -263,9 +264,9 @@ def _build_vocab(args: config.Params, tasks: List[Task], vocab_path: str): if args.force_include_wsj_vocabulary: # Add WSJ full vocabulary for PTB F1 parsing tasks. add_wsj_vocab(vocab, args.data_dir) - if input_module_uses_pytorch_transformers(args.input_module): - # Add pre-computed vocabulary of corresponding tokenizer for pytorch_transformers models. - add_pytorch_transformers_vocab(vocab, args.tokenizer) + if input_module_uses_transformers(args.input_module): + # Add pre-computed vocabulary of corresponding tokenizer for transformers models. + add_transformers_vocab(vocab, args.tokenizer) vocab.save_to_files(vocab_path) log.info("\tSaved vocab to %s", vocab_path) @@ -288,13 +289,13 @@ def build_indexers(args): " you are using args.tokenizer = {args.tokenizer}" ) - if input_module_uses_pytorch_transformers(args.input_module): + if input_module_uses_transformers(args.input_module): assert ( not indexers - ), "pytorch_transformers modules like BERT/XLNet are not supported alongside other " + ), "transformers modules like BERT/XLNet are not supported alongside other " "indexers due to tokenization." assert args.tokenizer == args.input_module, ( - "pytorch_transformers models use custom tokenization for each model, so tokenizer " + "transformers models use custom tokenization for each model, so tokenizer " "must match the specified model." ) tokenizer_name = input_module_tokenizer_name(args.input_module) @@ -671,8 +672,8 @@ def add_task_label_vocab(vocab, task): vocab.add_token_to_namespace(label, namespace) -def add_pytorch_transformers_vocab(vocab, tokenizer_name): - """Add vocabulary from tokenizers in pytorch_transformers for use with pre-tokenized data. +def add_transformers_vocab(vocab, tokenizer_name): + """Add vocabulary from tokenizers in transformers for use with pre-tokenized data. These tokenizers have a convert_tokens_to_ids method, but this doesn't do anything special, so we can just use the standard indexers. @@ -683,6 +684,8 @@ def add_pytorch_transformers_vocab(vocab, tokenizer_name): tokenizer = BertTokenizer.from_pretrained(tokenizer_name, do_lower_case=do_lower_case) elif tokenizer_name.startswith("roberta-"): tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name) + elif tokenizer_name.startswith("albert-"): + tokenizer = AlbertTokenizer.from_pretrained(tokenizer_name) elif tokenizer_name.startswith("xlnet-"): tokenizer = XLNetTokenizer.from_pretrained(tokenizer_name, do_lower_case=do_lower_case) elif tokenizer_name.startswith("openai-gpt"): @@ -709,7 +712,7 @@ def add_pytorch_transformers_vocab(vocab, tokenizer_name): # do not use tokenizer.vocab_size, it does not include newly added token ordered_vocab = tokenizer.convert_ids_to_tokens(range(vocab_size)) - log.info("Added pytorch_transformers vocab (%s): %d tokens", tokenizer_name, len(ordered_vocab)) + log.info("Added transformers vocab (%s): %d tokens", tokenizer_name, len(ordered_vocab)) for word in ordered_vocab: vocab.add_token_to_namespace(word, input_module_tokenizer_name(tokenizer_name)) @@ -745,34 +748,38 @@ def __init__(self, args): lm_boundary_token_fn = None if args.input_module.startswith("bert-"): - from jiant.pytorch_transformers_interface.modules import BertEmbedderModule + from jiant.huggingface_transformers_interface.modules import BertEmbedderModule boundary_token_fn = BertEmbedderModule.apply_boundary_tokens elif args.input_module.startswith("roberta-"): - from jiant.pytorch_transformers_interface.modules import RobertaEmbedderModule + from jiant.huggingface_transformers_interface.modules import RobertaEmbedderModule boundary_token_fn = RobertaEmbedderModule.apply_boundary_tokens + elif args.input_module.startswith("albert-"): + from jiant.huggingface_transformers_interface.modules import AlbertEmbedderModule + + boundary_token_fn = AlbertEmbedderModule.apply_boundary_tokens elif args.input_module.startswith("xlnet-"): - from jiant.pytorch_transformers_interface.modules import XLNetEmbedderModule + from jiant.huggingface_transformers_interface.modules import XLNetEmbedderModule boundary_token_fn = XLNetEmbedderModule.apply_boundary_tokens elif args.input_module.startswith("openai-gpt"): - from jiant.pytorch_transformers_interface.modules import OpenAIGPTEmbedderModule + from jiant.huggingface_transformers_interface.modules import OpenAIGPTEmbedderModule boundary_token_fn = OpenAIGPTEmbedderModule.apply_boundary_tokens lm_boundary_token_fn = OpenAIGPTEmbedderModule.apply_lm_boundary_tokens elif args.input_module.startswith("gpt2"): - from jiant.pytorch_transformers_interface.modules import GPT2EmbedderModule + from jiant.huggingface_transformers_interface.modules import GPT2EmbedderModule boundary_token_fn = GPT2EmbedderModule.apply_boundary_tokens lm_boundary_token_fn = GPT2EmbedderModule.apply_lm_boundary_tokens elif args.input_module.startswith("transfo-xl-"): - from jiant.pytorch_transformers_interface.modules import TransfoXLEmbedderModule + from jiant.huggingface_transformers_interface.modules import TransfoXLEmbedderModule boundary_token_fn = TransfoXLEmbedderModule.apply_boundary_tokens lm_boundary_token_fn = TransfoXLEmbedderModule.apply_lm_boundary_tokens elif args.input_module.startswith("xlm-"): - from jiant.pytorch_transformers_interface.modules import XLMEmbedderModule + from jiant.huggingface_transformers_interface.modules import XLMEmbedderModule boundary_token_fn = XLMEmbedderModule.apply_boundary_tokens else: diff --git a/jiant/pytorch_transformers_interface/__init__.py b/jiant/pytorch_transformers_interface/__init__.py deleted file mode 100644 index e198f9c8e..000000000 --- a/jiant/pytorch_transformers_interface/__init__.py +++ /dev/null @@ -1,65 +0,0 @@ -""" -Warning: jiant currently depends on *both* pytorch_pretrained_bert > 0.6 _and_ -pytorch_transformers > 1.0 - -These are the same package, though the name changed between these two versions. AllenNLP requires -0.6 to support the BertAdam optimizer, and jiant directly requires 1.0 to support XLNet and -WWM-BERT. - -This AllenNLP issue is relevant: https://github.com/allenai/allennlp/issues/3067 - -Note: huggingface forgot to upload bert-large-uncased-whole-word-masking-finetuned-squad -When they fix it, remove this note -https://github.com/huggingface/pytorch-transformers/issues/763 - -TODO: We do not support non-English versions of XLM, if you need them, add some code in XLMEmbedderModule -to prepare langs input to pytorch_transformers.XLMModel -""" - - -def input_module_uses_pytorch_transformers(input_module): - return ( - input_module.startswith("bert-") - or input_module.startswith("roberta-") - or input_module.startswith("xlnet-") - or input_module.startswith("gpt2") - or input_module.startswith("openai-gpt") - or input_module.startswith("transfo-xl-") - or input_module.startswith("xlm-") - ) - - -def input_module_tokenizer_name(input_module): - input_module_to_pretokenized = { - "bert-base-uncased": "bert_uncased", - "bert-large-uncased": "bert_uncased", - "bert-large-uncased-whole-word-masking": "bert_uncased", - "bert-large-uncased-whole-word-masking-finetuned-squad": "bert_uncased", - "bert-base-cased": "bert_cased", - "bert-large-cased": "bert_cased", - "bert-large-cased-whole-word-masking": "bert_cased", - "bert-large-cased-whole-word-masking-finetuned-squad": "bert_cased", - "bert-base-cased-finetuned-mrpc": "bert_cased", - "bert-base-multilingual-uncased": "bert_multilingual_uncased", - "bert-base-multilingual-cased": "bert_multilingual_cased", - "bert-base-chinese": "bert_chinese", - "bert-base-german-cased": "bert_german_cased", - "roberta-base": "roberta", - "roberta-large": "roberta", - "roberta-large-mnli": "roberta", - "xlnet-base-cased": "xlnet_cased", - "xlnet-large-cased": "xlnet_cased", - "openai-gpt": "openai_gpt", - "gpt2": "gpt2", - "gpt2-medium": "gpt2", - "gpt2-large": "gpt2", - "transfo-xl-wt103": "transfo_xl", - "xlm-mlm-en-2048": "xlm_en", - "xlm-mlm-ende-1024": "xlm_ende", - "xlm-mlm-enfr-1024": "xlm_enfr", - "xlm-clm-enfr-1024": "xlm_enfr", - "xlm-mlm-enro-1024": "xlm_enro", - "xlm-mlm-tlm-xnli15-1024": "xlm_xnli", - "xlm-mlm-xnli15-1024": "xlm_xnli", - } - return input_module_to_pretokenized[input_module] diff --git a/jiant/tasks/qa.py b/jiant/tasks/qa.py index c0efb345d..c94ad0bf7 100644 --- a/jiant/tasks/qa.py +++ b/jiant/tasks/qa.py @@ -27,11 +27,7 @@ from jiant.tasks.tasks import Task, SpanPredictionTask, MultipleChoiceTask from jiant.tasks.tasks import sentence_to_text_field from jiant.tasks.registry import register_task -from ..utils.retokenize import ( - space_tokenize_with_spans, - find_space_token_span, - create_tokenization_alignment, -) +from ..utils.retokenize import space_tokenize_with_spans, find_space_token_span, get_aligner_fn @register_task("multirc", rel_path="MultiRC/") @@ -232,20 +228,6 @@ def get_split_text(self, split: str): def load_data_for_path(self, path, split): """ Load data """ - def tokenize_preserve_placeholder(sent, max_ent_length): - """ Tokenize questions while preserving @placeholder token """ - sent_parts = sent.split("@placeholder") - assert len(sent_parts) == 2 - placeholder_loc = len( - tokenize_and_truncate( - self.tokenizer_name, sent_parts[0], self.max_seq_len - max_ent_length - ) - ) - sent_tok = tokenize_and_truncate( - self.tokenizer_name, sent, self.max_seq_len - max_ent_length - ) - return sent_tok[:placeholder_loc] + ["@placeholder"] + sent_tok[placeholder_loc:] - examples = [] data = [json.loads(d) for d in open(path, encoding="utf-8")] for item in data: @@ -255,10 +237,9 @@ def tokenize_preserve_placeholder(sent, max_ent_length): ) ent_idxs = item["passage"]["entities"] ents = [item["passage"]["text"][idx["start"] : idx["end"] + 1] for idx in ent_idxs] - max_ent_length = max([idx["end"] - idx["start"] + 1 for idx in ent_idxs]) qas = item["qas"] for qa in qas: - qst = tokenize_preserve_placeholder(qa["query"], max_ent_length) + qst = qa["query"] qst_id = qa["idx"] if "answers" in qa: anss = [a["text"] for a in qa["answers"]] @@ -311,9 +292,8 @@ def is_answer(x, ys): def insert_ent(ent, template): """ Replace ent into template (query with @placeholder) """ - assert "@placeholder" in template, "No placeholder detected!" - split_idx = template.index("@placeholder") - return template[:split_idx] + ent + template[split_idx + 1 :] + len(template.split("@placeholder")) == 2, "No placeholder detected!" + return template.replace("@placeholder", ent) def _make_instance(psg, qst, ans_str, label, psg_idx, qst_idx, ans_idx): """ pq_id: passage-question ID """ @@ -343,19 +323,17 @@ def _make_instance(psg, qst, ans_str, label, psg_idx, qst_idx, ans_idx): psg = example["passage"] qst_template = example["query"] - ent_strs = example["ents"] - ents = [ - tokenize_and_truncate(self._tokenizer_name, ent, self.max_seq_len) - for ent in ent_strs - ] + ents = example["ents"] anss = example["answers"] par_idx = example["psg_id"] qst_idx = example["qst_id"] - for ent_idx, (ent, ent_str) in enumerate(zip(ents, ent_strs)): - label = is_answer(ent_str, anss) - qst = insert_ent(ent, qst_template) - yield _make_instance(psg, qst, ent_str, label, par_idx, qst_idx, ent_idx) + for ent_idx, ent in enumerate(ents): + label = is_answer(ent, anss) + qst = tokenize_and_truncate( + self.tokenizer_name, insert_ent(ent, qst_template), self.max_seq_len + ) + yield _make_instance(psg, qst, ent, label, par_idx, qst_idx, ent_idx) def count_examples(self): """ Compute here b/c we're streaming the sentences. """ @@ -819,21 +797,18 @@ def remap_ptb_passage_and_answer_spans(ptb_tokens, answer_span, moses, tokenizer ) # We project the space-tokenized answer to processed-tokens (e.g. BERT). # The latter is used for training/predicting. - space_to_actual_token_map = create_tokenization_alignment( - tokens=detok_sent.split(), tokenizer_name=tokenizer_name - ) + aligner_fn = get_aligner_fn(tokenizer_name) + token_aligner, actual_tokens = aligner_fn(detok_sent) # space_processed_token_map is a list of tuples # (space_token, processed_token (e.g. BERT), space_token_index) # We will need this to map from token predictions to str spans - space_processed_token_map = [] - for i, (space_token, actual_token_ls) in enumerate(space_to_actual_token_map): - for actual_token in actual_token_ls: - space_processed_token_map.append((actual_token, space_token, i)) - ans_actual_token_span = ( - sum(len(_[1]) for _ in space_to_actual_token_map[: ans_space_token_span[0]]), - sum(len(_[1]) for _ in space_to_actual_token_map[: ans_space_token_span[1]]), - ) + space_processed_token_map = [ + (actual_tokens[actual_idx], space_token, space_idx) + for space_idx, (space_token, _, _) in enumerate(space_tokens_with_spans) + for actual_idx in token_aligner.project_tokens(space_idx) + ] + ans_actual_token_span = token_aligner.project_span(*ans_space_token_span) return { "detok_sent": detok_sent, diff --git a/jiant/tasks/tasks.py b/jiant/tasks/tasks.py index 50eedfedb..2050ea7f6 100644 --- a/jiant/tasks/tasks.py +++ b/jiant/tasks/tasks.py @@ -36,6 +36,7 @@ load_pair_nli_jsonl, ) from jiant.utils.tokenizers import get_tokenizer +from jiant.utils.retokenize import get_aligner_fn from jiant.tasks.registry import register_task # global task registry from jiant.metrics.winogender_metrics import GenderParity from jiant.metrics.nli_metrics import NLITwoClassAccuracy @@ -2365,9 +2366,9 @@ class CCGTaggingTask(TaggingTask): def __init__(self, path, max_seq_len, name, tokenizer_name, **kw): """ There are 1363 supertags in CCGBank without introduced token. """ - from jiant.pytorch_transformers_interface import input_module_uses_pytorch_transformers + from jiant.huggingface_transformers_interface import input_module_uses_transformers - subword_tokenization = input_module_uses_pytorch_transformers(tokenizer_name) + subword_tokenization = input_module_uses_transformers(tokenizer_name) super().__init__( name, 1363 + int(subword_tokenization), tokenizer_name=tokenizer_name, **kw ) @@ -2721,16 +2722,18 @@ def load_data(self): trg_map = {"true": 1, "false": 0, True: 1, False: 0} + aligner_fn = get_aligner_fn(self._tokenizer_name) + def _process_preserving_word(sent, word): - """ Tokenize the subsequence before the [first] instance of the word and after, - then concatenate everything together. This allows us to track where in the tokenized - sequence the marked word is located. """ - sent_parts = sent.split(word) - sent_tok1 = tokenize_and_truncate(self._tokenizer_name, sent_parts[0], self.max_seq_len) - sent_mid = tokenize_and_truncate(self._tokenizer_name, word, self.max_seq_len) - sent_tok = tokenize_and_truncate(self._tokenizer_name, sent, self.max_seq_len) - start_idx = len(sent_tok1) - end_idx = start_idx + len(sent_mid) + """ Find out the index of the [first] instance of the word in the original sentence, + and project the span containing marked word to the span containing tokens created from + the marked word. """ + token_aligner, sent_tok = aligner_fn(sent) + raw_start_idx = len(sent.split(word)[0].split(" ")) - 1 + # after spliting, there could be three cases, 1. a tailing space, 2. characters in front + # of the keyword, 3. the sentence starts with the keyword + raw_end_idx = len(word.split()) + raw_start_idx + start_idx, end_idx = token_aligner.project_span(raw_start_idx, raw_end_idx) assert end_idx > start_idx, "Invalid marked word indices. Something is wrong." return sent_tok, start_idx, end_idx @@ -2786,7 +2789,7 @@ def _make_instance(input1, input2, idxs1, idxs2, labels, idx): inp, offset1, offset2 = model_preprocessing_interface.boundary_token_fn( input1, input2, get_offset=True ) - d["inputs"] = sentence_to_text_field(inp, indexers) + d["inputs"] = sentence_to_text_field(inp[: self.max_seq_len], indexers) else: inp1, offset1 = model_preprocessing_interface.boundary_token_fn( input1, get_offset=True @@ -2794,13 +2797,25 @@ def _make_instance(input1, input2, idxs1, idxs2, labels, idx): inp2, offset2 = model_preprocessing_interface.boundary_token_fn( input2, get_offset=True ) - d["input1"] = sentence_to_text_field(inp1, indexers) - d["input2"] = sentence_to_text_field(inp2, indexers) + d["input1"] = sentence_to_text_field(inp1[: self.max_seq_len], indexers) + d["input2"] = sentence_to_text_field(inp2[: self.max_seq_len], indexers) d["idx1"] = ListField( - [NumericField(i) for i in range(idxs1[0] + offset1, idxs1[1] + offset1)] + [ + NumericField(i) + for i in range( + min(idxs1[0] + offset1, self.max_seq_len - 1), + min(idxs1[1] + offset1, self.max_seq_len), + ) + ] ) d["idx2"] = ListField( - [NumericField(i) for i in range(idxs2[0] + offset2, idxs2[1] + offset2)] + [ + NumericField(i) + for i in range( + min(idxs2[0] + offset2, self.max_seq_len - 1), + min(idxs2[1] + offset2, self.max_seq_len), + ) + ] ) d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True) d["idx"] = LabelField(idx, label_namespace="idxs_tags", skip_indexing=True) diff --git a/jiant/utils/retokenize.py b/jiant/utils/retokenize.py index 6cde76df5..518d26cc1 100644 --- a/jiant/utils/retokenize.py +++ b/jiant/utils/retokenize.py @@ -95,35 +95,6 @@ def _mat_from_spans_sparse(spans: Sequence[Tuple[int, int]], n_chars: int) -> Ma return sparse.csr_matrix((data, (ridxs, cidxs)), shape=(len(spans), n_chars)) -def create_tokenization_alignment( - tokens: Sequence[str], tokenizer_name: str -) -> Sequence[Tuple[str, str]]: - """ - Builds alignment mapping between space tokenization and tokenization of - choice. - - Example: - Input: ['Larger', 'than', 'life.'] - Output: [('Larger', ['ĠL', 'arger']), ('than', ['Ġthan']), ('life.', ['Ġlife', '.'])] - - Parameters - ----------------------- - tokens: list[(str)]. list of tokens, - tokenizer_name: str - - Returns - ----------------------- - tokenization_mapping: list[(str, str)], list of tuples with (orig_token, tokenized_token). - - """ - tokenizer = get_tokenizer(tokenizer_name) - tokenization_mapping = [] - for tok in tokens: - aligned_tok = tokenizer.tokenize(tok) - tokenization_mapping.append((tok, aligned_tok)) - return tokenization_mapping - - def realign_spans(record, tokenizer_name): """ Builds the indices alignment while also tokenizing the input @@ -375,7 +346,7 @@ def align_wpm( def align_sentencepiece( text: Text, sentencepiece_tokenizer: Tokenizer ) -> Tuple[TokenAligner, List[Text]]: - """Alignment fn for SentencePiece Tokenizer, used in XLNET + """Alignment fn for SentencePiece Tokenizer, used in XLNET and ALBERT """ bow_tokens = space_tokenize_with_bow(text) sentencepiece_tokens = sentencepiece_tokenizer.tokenize(text) @@ -402,8 +373,12 @@ def align_bytebpe(text: Text, bytebpe_tokenizer: Tokenizer) -> Tuple[TokenAligne bow_tokens = space_tokenize_with_bow(text) bytebpe_tokens = bytebpe_tokenizer.tokenize(text) + if len(bytebpe_tokens) > 0: + bytebpe_tokens[0] = "Ġ" + bytebpe_tokens[0] modified_bytebpe_tokens = list(map(process_bytebpe_for_alignment, bytebpe_tokens)) ta = TokenAligner(bow_tokens, modified_bytebpe_tokens) + if len(bytebpe_tokens) > 0: + bytebpe_tokens[0] = re.sub(r"^Ġ", "", bytebpe_tokens[0]) return ta, bytebpe_tokens @@ -425,7 +400,7 @@ def get_aligner_fn(tokenizer_name: Text): elif tokenizer_name.startswith("openai-gpt") or tokenizer_name.startswith("xlm-mlm-en-"): bpe_tokenizer = get_tokenizer(tokenizer_name) return functools.partial(align_bpe, bpe_tokenizer=bpe_tokenizer) - elif tokenizer_name.startswith("xlnet-"): + elif tokenizer_name.startswith("xlnet-") or tokenizer_name.startswith("albert-"): sentencepiece_tokenizer = get_tokenizer(tokenizer_name) return functools.partial( align_sentencepiece, sentencepiece_tokenizer=sentencepiece_tokenizer diff --git a/jiant/utils/tokenizers.py b/jiant/utils/tokenizers.py index 1c6ba8497..4a071d0a6 100644 --- a/jiant/utils/tokenizers.py +++ b/jiant/utils/tokenizers.py @@ -10,10 +10,11 @@ from sacremoses import MosesDetokenizer from sacremoses import MosesTokenizer as SacreMosesTokenizer from nltk.tokenize.simple import SpaceTokenizer -from jiant.pytorch_transformers_interface import input_module_uses_pytorch_transformers -from pytorch_transformers import ( +from jiant.huggingface_transformers_interface import input_module_uses_transformers +from transformers import ( BertTokenizer, RobertaTokenizer, + AlbertTokenizer, XLNetTokenizer, OpenAIGPTTokenizer, GPT2Tokenizer, @@ -32,7 +33,7 @@ def select_tokenizer(args): Select a sane default tokenizer. """ if args.tokenizer == "auto": - if input_module_uses_pytorch_transformers(args.input_module): + if input_module_uses_transformers(args.input_module): tokenizer_name = args.input_module else: tokenizer_name = "MosesTokenizer" @@ -97,6 +98,8 @@ def get_tokenizer(tokenizer_name): tokenizer = BertTokenizer.from_pretrained(tokenizer_name, do_lower_case=do_lower_case) elif tokenizer_name.startswith("roberta-"): tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name) + elif tokenizer_name.startswith("albert-"): + tokenizer = AlbertTokenizer.from_pretrained(tokenizer_name) elif tokenizer_name.startswith("xlnet-"): do_lower_case = tokenizer_name.endswith("uncased") tokenizer = XLNetTokenizer.from_pretrained(tokenizer_name, do_lower_case=do_lower_case) diff --git a/jiant/utils/utils.py b/jiant/utils/utils.py index dc6f615ef..37e6afb32 100644 --- a/jiant/utils/utils.py +++ b/jiant/utils/utils.py @@ -82,6 +82,7 @@ def select_pool_type(args): if ( args.input_module.startswith("bert-") or args.input_module.startswith("roberta-") + or args.input_module.startswith("albert-") or args.input_module.startswith("xlm-") ): pool_type = "first" diff --git a/scripts/ccg/align_tags_to_bert.py b/scripts/ccg/align_tags_to_bert.py index 3abf8aedc..a28805005 100644 --- a/scripts/ccg/align_tags_to_bert.py +++ b/scripts/ccg/align_tags_to_bert.py @@ -5,8 +5,7 @@ import pandas as pd -from jiant import utils -from jiant.utils import retokenize +from jiant.utils.retokenize import get_aligner_fn """ @@ -30,22 +29,15 @@ def get_tags(text, current_tags, tokenizer_name, tag_dict): - aligner_fn = retokenize.get_aligner_fn(tokenizer_name) + aligner_fn = get_aligner_fn(tokenizer_name) assert len(text) == len(current_tags) - res_tags = [] introduced_tokenizer_tag = len(tag_dict) - for i in range(len(text)): - token = text[i] - _, new_toks = aligner_fn(token) - res_tags.append(tag_dict[current_tags[i]]) - if len(new_toks) > 1: - for tok in new_toks[1:]: - res_tags.append(introduced_tokenizer_tag) - # based on BERT-paper for wordpiece, we only keep the tag - # for the first part of the word. - _, aligned_text = aligner_fn(" ".join(text)) - assert len(aligned_text) == len(res_tags) - str_tags = [str(s) for s in res_tags] + token_aligner, aligned_text = aligner_fn(" ".join(text)) + aligned_tags = [introduced_tokenizer_tag for token in aligned_text] + for text_idx, text_tag in enumerate(current_tags): + aligned_idx = token_aligner.project_tokens(text_idx)[0] + aligned_tags[aligned_idx] = tag_dict[text_tag] + str_tags = [str(s) for s in aligned_tags] return " ".join(str_tags) @@ -80,9 +72,11 @@ def align_ccg(split, tokenizer_name, data_dir): None, saves tag alligned files to same directory as the original file. """ tags_to_id = json.load(open(data_dir + "tags_to_id.json", "r")) - ccg_text = pd.read_csv(data_dir + "ccg." + split, names=["text", "tags"], delimiter="\t") + ccg_text = pd.read_csv( + os.path.join(data_dir, "ccg." + split), names=["text", "tags"], delimiter="\t" + ) result = align_tags_BERT(ccg_text, tokenizer_name, tags_to_id) - result.to_csv(data_dir + "ccg." + split + "." + tokenizer_name, sep="\t") + result.to_csv(os.path.join(data_dir, "ccg." + split + "." + tokenizer_name), sep="\t") def main(arguments): diff --git a/scripts/demo.with_docker.sh b/scripts/demo.with_docker.sh index aadbc703e..cc4bbe3e4 100755 --- a/scripts/demo.with_docker.sh +++ b/scripts/demo.with_docker.sh @@ -42,6 +42,6 @@ sudo docker run --runtime=nvidia --rm \ -e "ELMO_SRC_DIR=$ELMO_SRC_DIR" \ -e "WORD_EMBS_FILE=$WORD_EMBS_FILE" \ -e "JIANT_PROJECT_PREFIX=$NFS_PATH/exp/$USER" \ - -e "PYTORCH_PRETRAINED_BERT_CACHE=$PYTORCH_PRETRAINED_BERT_CACHE" \ + -e "HUGGINGFACE_TRANSFORMERS_CACHE=$HUGGINGFACE_TRANSFORMERS_CACHE" \ --user $(id -u):$(id -g) \ -i ${IMAGE_NAME} "${COMMAND[@]}" diff --git a/scripts/edgeprobing/exp_fns.sh b/scripts/edgeprobing/exp_fns.sh index ac5f5e221..25d73fda4 100644 --- a/scripts/edgeprobing/exp_fns.sh +++ b/scripts/edgeprobing/exp_fns.sh @@ -162,7 +162,7 @@ function bert_cat_exp() { OVERRIDES="exp_name=bert-${2}-cat-${1}, run_name=run" OVERRIDES+=", target_tasks=$1" OVERRIDES+=", input_module=bert-$2" - OVERRIDES+=", pytorch_transformers_output_mode=cat" + OVERRIDES+=", transformers_output_mode=cat" run_exp "jiant/config/edgeprobe/edgeprobe_bert.conf" "${OVERRIDES}" } @@ -172,7 +172,7 @@ function bert_lex_exp() { OVERRIDES="exp_name=bert-${2}-lex-${1}, run_name=run" OVERRIDES+=", target_tasks=$1" OVERRIDES+=", input_module=bert-$2" - OVERRIDES+=", pytorch_transformers_output_mode=only" + OVERRIDES+=", transformers_output_mode=only" run_exp "jiant/config/edgeprobe/edgeprobe_bert.conf" "${OVERRIDES}" } @@ -182,7 +182,7 @@ function bert_mix_exp() { OVERRIDES="exp_name=bert-${2}-mix-${1}, run_name=run" OVERRIDES+=", target_tasks=$1" OVERRIDES+=", input_module=bert-$2" - OVERRIDES+=", pytorch_transformers_output_mode=mix" + OVERRIDES+=", transformers_output_mode=mix" run_exp "jiant/config/edgeprobe/edgeprobe_bert.conf" "${OVERRIDES}" } @@ -194,8 +194,8 @@ function bert_mix_k_exp() { OVERRIDES="exp_name=bert-${2}-mix_${3}-${1}, run_name=run" OVERRIDES+=", target_tasks=$1" OVERRIDES+=", input_module=bert-$2" - OVERRIDES+=", pytorch_transformers_output_mode=mix" - OVERRIDES+=", pytorch_transformers_max_layer=${3}" + OVERRIDES+=", transformers_output_mode=mix" + OVERRIDES+=", transformers_max_layer=${3}" run_exp "jiant/config/edgeprobe/edgeprobe_bert.conf" "${OVERRIDES}" } @@ -205,7 +205,7 @@ function bert_at_k_exp() { OVERRIDES="exp_name=bert-${2}-at_${3}-${1}, run_name=run" OVERRIDES+=", target_tasks=$1" OVERRIDES+=", input_module=bert-$2" - OVERRIDES+=", pytorch_transformers_output_mode=top" - OVERRIDES+=", pytorch_transformers_max_layer=${3}" + OVERRIDES+=", transformers_output_mode=top" + OVERRIDES+=", transformers_max_layer=${3}" run_exp "jiant/config/edgeprobe/edgeprobe_bert.conf" "${OVERRIDES}" } diff --git a/setup.py b/setup.py index 4929eb732..1c2c55c0e 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ "jiant.modules", "jiant.modules.onlstm", "jiant.modules.prpn", - "jiant.pytorch_transformers_interface", + "jiant.huggingface_transformers_interface", "jiant.tasks", "jiant.utils", ], @@ -65,7 +65,7 @@ "pyhocon==0.3.35", "python-Levenshtein==0.12.0", "sacremoses", - "pytorch-transformers==1.2.0", + "transformers==2.3.0", "ftfy", "spacy", ], diff --git a/tests/test_pytorch_transformers_interface.py b/tests/test_huggingface_transformers_interface.py similarity index 90% rename from tests/test_pytorch_transformers_interface.py rename to tests/test_huggingface_transformers_interface.py index 4f422b4ff..de5e0dc7e 100644 --- a/tests/test_pytorch_transformers_interface.py +++ b/tests/test_huggingface_transformers_interface.py @@ -2,10 +2,11 @@ from unittest import mock import torch import copy -from jiant.pytorch_transformers_interface.modules import ( - PytorchTransformersEmbedderModule, +from jiant.huggingface_transformers_interface.modules import ( + HuggingfaceTransformersEmbedderModule, BertEmbedderModule, RobertaEmbedderModule, + AlbertEmbedderModule, XLNetEmbedderModule, OpenAIGPTEmbedderModule, GPT2EmbedderModule, @@ -14,7 +15,7 @@ ) -class TestPytorchTransformersInterface(unittest.TestCase): +class TestHuggingfaceTransformersInterface(unittest.TestCase): def test_bert_apply_boundary_tokens(self): s1 = ["A", "B", "C"] s2 = ["D", "E"] @@ -37,6 +38,17 @@ def test_roberta_apply_boundary_tokens(self): ["", "A", "B", "C", "", "", "D", "E", ""], ) + def test_albert_apply_boundary_tokens(self): + s1 = ["A", "B", "C"] + s2 = ["D", "E"] + self.assertListEqual( + AlbertEmbedderModule.apply_boundary_tokens(s1), ["[CLS]", "A", "B", "C", "[SEP]"] + ) + self.assertListEqual( + AlbertEmbedderModule.apply_boundary_tokens(s1, s2), + ["[CLS]", "A", "B", "C", "[SEP]", "D", "E", "[SEP]"], + ) + def test_xlnet_apply_boundary_tokens(self): s1 = ["A", "B", "C"] s2 = ["D", "E"] @@ -77,7 +89,7 @@ def test_correct_sent_indexing(self): model._unk_id = 10 model.max_pos = None model.tokenizer_required = "correct_tokenizer" - model.correct_sent_indexing = PytorchTransformersEmbedderModule.correct_sent_indexing + model.correct_sent_indexing = HuggingfaceTransformersEmbedderModule.correct_sent_indexing allenNLP_indexed = torch.LongTensor([[7, 10, 5, 11, 1, 13, 5], [7, 10, 11, 5, 1, 5, 0]]) diff --git a/tests/test_retokenize.py b/tests/test_retokenize.py index daffb92e6..b512651a9 100644 --- a/tests/test_retokenize.py +++ b/tests/test_retokenize.py @@ -57,15 +57,15 @@ def test_moses(self): ] aligner_fn = retokenize.get_aligner_fn("transfo-xl-wt103") - tas, tokens = zip(*(aligner_fn(sent) for sent in self.text)) - tas, tokens = list(tas), list(tokens) + token_aligners, tokens = zip(*(aligner_fn(sent) for sent in self.text)) + token_aligners, tokens = list(token_aligners), list(tokens) token_index_tgt = [ - [ta.project_tokens(idxs).tolist() for idxs in token_idxs] - for ta, token_idxs in zip(tas, self.token_index_src) + [token_aligner.project_tokens(idxs).tolist() for idxs in token_idxs] + for token_aligner, token_idxs in zip(token_aligners, self.token_index_src) ] span_index_tgt = [ - [ta.project_span(start, end) for (start, end) in span_idxs] - for ta, span_idxs in zip(tas, self.span_index_src) + [token_aligner.project_span(start, end) for (start, end) in span_idxs] + for token_aligner, span_idxs in zip(token_aligners, self.span_index_src) ] assert self.tokens == tokens assert self.token_index_tgt == token_index_tgt @@ -126,21 +126,16 @@ def test_wpm(self): ] aligner_fn = retokenize.get_aligner_fn("bert-base-cased") - tas, tokens = zip(*(aligner_fn(sent) for sent in self.text)) - tas, tokens = list(tas), list(tokens) + token_aligners, tokens = zip(*(aligner_fn(sent) for sent in self.text)) + token_aligners, tokens = list(token_aligners), list(tokens) token_index_tgt = [ - [ta.project_tokens(idxs).tolist() for idxs in token_idxs] - for ta, token_idxs in zip(tas, self.token_index_src) + [token_aligner.project_tokens(idxs).tolist() for idxs in token_idxs] + for token_aligner, token_idxs in zip(token_aligners, self.token_index_src) ] span_index_tgt = [ - [ta.project_span(start, end) for (start, end) in span_idxs] - for ta, span_idxs in zip(tas, self.span_index_src) + [token_aligner.project_span(start, end) for (start, end) in span_idxs] + for token_aligner, span_idxs in zip(token_aligners, self.span_index_src) ] - orig_tokens = self.text[0].split() - alignment_map = retokenize.create_tokenization_alignment(orig_tokens, "bert-base-cased") - wpm_tokens = self.tokens[0] - for i, v in enumerate(alignment_map): - assert v[0] == orig_tokens[i] and ",".join(v[1]) == wpm_tokens[i] assert self.tokens == tokens assert self.token_index_tgt == token_index_tgt assert self.span_index_tgt == span_index_tgt @@ -204,21 +199,16 @@ def test_bpe(self): ] aligner_fn = retokenize.get_aligner_fn("openai-gpt") - tas, tokens = zip(*(aligner_fn(sent) for sent in self.text)) - tas, tokens = list(tas), list(tokens) + token_aligners, tokens = zip(*(aligner_fn(sent) for sent in self.text)) + token_aligners, tokens = list(token_aligners), list(tokens) token_index_tgt = [ - [ta.project_tokens(idxs).tolist() for idxs in token_idxs] - for ta, token_idxs in zip(tas, self.token_index_src) + [token_aligner.project_tokens(idxs).tolist() for idxs in token_idxs] + for token_aligner, token_idxs in zip(token_aligners, self.token_index_src) ] span_index_tgt = [ - [ta.project_span(start, end) for (start, end) in span_idxs] - for ta, span_idxs in zip(tas, self.span_index_src) + [token_aligner.project_span(start, end) for (start, end) in span_idxs] + for token_aligner, span_idxs in zip(token_aligners, self.span_index_src) ] - orig_tokens = self.text[0].split() - alignment_map = retokenize.create_tokenization_alignment(orig_tokens, "openai-gpt") - bpe_tokens = self.tokens[0] - for i, v in enumerate(alignment_map): - assert v[0] == orig_tokens[i] and ",".join(v[1]) == bpe_tokens[i] assert self.tokens == tokens assert self.token_index_tgt == token_index_tgt assert self.span_index_tgt == span_index_tgt @@ -292,31 +282,26 @@ def test_sentencepiece(self): ] aligner_fn = retokenize.get_aligner_fn("xlnet-base-cased") - tas, tokens = zip(*(aligner_fn(sent) for sent in self.text)) - tas, tokens = list(tas), list(tokens) + token_aligners, tokens = zip(*(aligner_fn(sent) for sent in self.text)) + token_aligners, tokens = list(token_aligners), list(tokens) token_index_tgt = [ - [ta.project_tokens(idxs).tolist() for idxs in token_idxs] - for ta, token_idxs in zip(tas, self.token_index_src) + [token_aligner.project_tokens(idxs).tolist() for idxs in token_idxs] + for token_aligner, token_idxs in zip(token_aligners, self.token_index_src) ] span_index_tgt = [ - [ta.project_span(start, end) for (start, end) in span_idxs] - for ta, span_idxs in zip(tas, self.span_index_src) + [token_aligner.project_span(start, end) for (start, end) in span_idxs] + for token_aligner, span_idxs in zip(token_aligners, self.span_index_src) ] - orig_tokens = self.text[0].split() - alignment_map = retokenize.create_tokenization_alignment(orig_tokens, "xlnet-base-cased") - se_tokens = self.tokens[0] - for i, v in enumerate(alignment_map): - assert v[0] == orig_tokens[i] and ",".join(v[1]) == se_tokens[i] assert self.tokens == tokens assert self.token_index_tgt == token_index_tgt assert self.span_index_tgt == span_index_tgt def test_bytebpe(self): self.tokens = [ - ["ĠMembers", "Ġof", "Ġthe", "ĠHouse", "Ġcl", "apped", "Ġtheir", "Ġhands"], - ["ĠI", "Ġlook", "Ġat", "ĠSarah", "'s", "Ġdog", ".", "ĠIt", "Ġwas", "Ġcute", ".", "!"], + ["Members", "Ġof", "Ġthe", "ĠHouse", "Ġcl", "apped", "Ġtheir", "Ġhands"], + ["I", "Ġlook", "Ġat", "ĠSarah", "'s", "Ġdog", ".", "ĠIt", "Ġwas", "Ġcute", ".", "!"], [ - "ĠMr", + "Mr", ".", "ĠImm", "elt", @@ -333,7 +318,7 @@ def test_bytebpe(self): "Ġrules", ".", ], - ["ĠWhat", "?"], + ["What", "?"], ] self.token_index_tgt = [ [[0], [1], [2], [3], [4, 5], [6], [7]], @@ -349,21 +334,16 @@ def test_bytebpe(self): ] aligner_fn = retokenize.get_aligner_fn("roberta-base") - tas, tokens = zip(*(aligner_fn(sent) for sent in self.text)) - tas, tokens = list(tas), list(tokens) + token_aligners, tokens = zip(*(aligner_fn(sent) for sent in self.text)) + token_aligners, tokens = list(token_aligners), list(tokens) token_index_tgt = [ - [ta.project_tokens(idxs).tolist() for idxs in token_idxs] - for ta, token_idxs in zip(tas, self.token_index_src) + [token_aligner.project_tokens(idxs).tolist() for idxs in token_idxs] + for token_aligner, token_idxs in zip(token_aligners, self.token_index_src) ] span_index_tgt = [ - [ta.project_span(start, end) for (start, end) in span_idxs] - for ta, span_idxs in zip(tas, self.span_index_src) + [token_aligner.project_span(start, end) for (start, end) in span_idxs] + for token_aligner, span_idxs in zip(token_aligners, self.span_index_src) ] - orig_tokens = self.text[0].split() - alignment_map = retokenize.create_tokenization_alignment(orig_tokens, "roberta-base") - bytebpe_tokens = ["ĠMembers", "Ġof", "Ġthe", "ĠHouse", "Ġcl,apped", "Ġtheir", "Ġhands"] - for i, v in enumerate(alignment_map): - assert v[0] == orig_tokens[i] and ",".join(v[1]) == bytebpe_tokens[i] assert self.tokens == tokens assert self.token_index_tgt == token_index_tgt assert self.span_index_tgt == span_index_tgt diff --git a/tutorials/setup_tutorial.md b/tutorials/setup_tutorial.md index 6333e42df..b1abcb0da 100644 --- a/tutorials/setup_tutorial.md +++ b/tutorials/setup_tutorial.md @@ -61,7 +61,7 @@ And the next time you start a notebook server, you should see `jiant` as an opti ### Optional -If you'll be using GPT, BERT, or other models supplied by `pytorch-transformers`, then you may see speed gains from installing NVIDIA apex, following the instructions here: +If you'll be using GPT, BERT, or other models supplied by `transformers`, then you may see speed gains from installing NVIDIA apex, following the instructions here: https://github.com/NVIDIA/apex#linux