From 20913e437c3d87e347ef219efc07a2043bee5774 Mon Sep 17 00:00:00 2001 From: Jesse Swanson Date: Thu, 14 Jan 2021 10:48:30 -0800 Subject: [PATCH 1/5] update to transformers v4.2.1 --- jiant/proj/main/modeling/heads.py | 14 +++++++++----- jiant/proj/main/modeling/taskmodels.py | 16 +++++++++++++--- jiant/tasks/lib/ropes.py | 4 +++- jiant/tasks/lib/templates/squad_style/core.py | 10 ++++++---- jiant/tasks/lib/templates/squad_style/utils.py | 2 +- requirements-no-torch.txt | 4 ++-- setup.py | 4 ++-- 7 files changed, 36 insertions(+), 18 deletions(-) diff --git a/jiant/proj/main/modeling/heads.py b/jiant/proj/main/modeling/heads.py index 38b637c99..b68588282 100644 --- a/jiant/proj/main/modeling/heads.py +++ b/jiant/proj/main/modeling/heads.py @@ -108,8 +108,10 @@ class BertMLMHead(BaseMLMHead): def __init__(self, hidden_size, vocab_size, layer_norm_eps=1e-12, hidden_act="gelu"): super().__init__() self.dense = nn.Linear(hidden_size, hidden_size) - self.transform_act_fn = transformers.modeling_bert.ACT2FN[hidden_act] - self.LayerNorm = transformers.modeling_bert.BertLayerNorm(hidden_size, eps=layer_norm_eps) + self.transform_act_fn = transformers.models.bert.modeling_bert.ACT2FN[hidden_act] + self.LayerNorm = transformers.models.bert.modeling_bert.BertLayerNorm( + hidden_size, eps=layer_norm_eps + ) self.decoder = nn.Linear(hidden_size, vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(vocab_size), requires_grad=True) @@ -132,7 +134,9 @@ class RobertaMLMHead(BaseMLMHead): def __init__(self, hidden_size, vocab_size, layer_norm_eps=1e-12): super().__init__() self.dense = nn.Linear(hidden_size, hidden_size) - self.layer_norm = transformers.modeling_bert.BertLayerNorm(hidden_size, eps=layer_norm_eps) + self.layer_norm = transformers.models.bert.modeling_bert.BertLayerNorm( + hidden_size, eps=layer_norm_eps + ) self.decoder = nn.Linear(hidden_size, vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(vocab_size), requires_grad=True) @@ -143,7 +147,7 @@ def __init__(self, hidden_size, vocab_size, layer_norm_eps=1e-12): def forward(self, unpooled): x = self.dense(unpooled) - x = transformers.modeling_bert.gelu(x) + x = transformers.models.bert.modeling_bert.gelu(x) x = self.layer_norm(x) # project back to size of vocabulary with bias @@ -161,7 +165,7 @@ def __init__(self, hidden_size, embedding_size, vocab_size, hidden_act="gelu"): self.bias = nn.Parameter(torch.zeros(vocab_size), requires_grad=True) self.dense = nn.Linear(hidden_size, embedding_size) self.decoder = nn.Linear(embedding_size, vocab_size) - self.activation = transformers.modeling_bert.ACT2FN[hidden_act] + self.activation = transformers.models.bert.modeling_bert.ACT2FN[hidden_act] # Need a link between the two variables so that the bias is correctly resized with # `resize_token_embeddings` diff --git a/jiant/proj/main/modeling/taskmodels.py b/jiant/proj/main/modeling/taskmodels.py index f656229d9..67bd06145 100644 --- a/jiant/proj/main/modeling/taskmodels.py +++ b/jiant/proj/main/modeling/taskmodels.py @@ -334,7 +334,12 @@ def get_output_from_encoder(encoder, input_ids, segment_ids, input_mask) -> Enco def get_output_from_standard_transformer_models(encoder, input_ids, segment_ids, input_mask): - output = encoder(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask) + output = encoder( + input_ids=input_ids, + token_type_ids=segment_ids, + attention_mask=input_mask, + return_dict=False, + ) pooled, unpooled, other = output[1], output[0], output[2:] return pooled, unpooled, other @@ -347,7 +352,7 
@@ def get_output_from_bart_models(encoder, input_ids, input_mask): # sentence representation is the final decoder state. # That's what we use for `unpooled` here. dec_last, dec_all, enc_last, enc_all = encoder( - input_ids=input_ids, attention_mask=input_mask, output_hidden_states=True, + input_ids=input_ids, attention_mask=input_mask, output_hidden_states=True, return_dict=False ) unpooled = dec_last @@ -361,7 +366,12 @@ def get_output_from_bart_models(encoder, input_ids, input_mask): def get_output_from_electra(encoder, input_ids, segment_ids, input_mask): - output = encoder(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask) + output = encoder( + input_ids=input_ids, + token_type_ids=segment_ids, + attention_mask=input_mask, + return_dict=False, + ) unpooled = output[0] pooled = unpooled[:, 0, :] return pooled, unpooled, output diff --git a/jiant/tasks/lib/ropes.py b/jiant/tasks/lib/ropes.py index 74c823ad7..8c022bf0e 100644 --- a/jiant/tasks/lib/ropes.py +++ b/jiant/tasks/lib/ropes.py @@ -90,7 +90,9 @@ def to_feature_list( # (This may not apply for future added models that don't start with a CLS token, # such as XLNet/GPT-2) sequence_added_tokens = 1 - sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair + sequence_pair_added_tokens = ( + tokenizer.model_max_length - tokenizer.model_max_length_sentences_pair + ) span_doc_tokens = all_doc_tokens while len(spans) * doc_stride < len(all_doc_tokens): diff --git a/jiant/tasks/lib/templates/squad_style/core.py b/jiant/tasks/lib/templates/squad_style/core.py index b19e216f1..b57167e2b 100644 --- a/jiant/tasks/lib/templates/squad_style/core.py +++ b/jiant/tasks/lib/templates/squad_style/core.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import Union, List, Dict, Optional -from transformers.tokenization_bert import whitespace_tokenize +from transformers.models.bert.tokenization_bert import whitespace_tokenize from jiant.tasks.lib.templates.squad_style import utils as squad_utils from jiant.shared.constants import PHASE @@ -144,11 +144,13 @@ def to_feature_list( # in the way they compute mask of added tokens. 
tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower() sequence_added_tokens = ( - tokenizer.max_len - tokenizer.max_len_single_sentence + 1 + tokenizer.model_max_length - tokenizer.model_max_length_single_sentence + 1 if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET - else tokenizer.max_len - tokenizer.max_len_single_sentence + else tokenizer.model_max_length - tokenizer.model_max_length_single_sentence + ) + sequence_pair_added_tokens = ( + tokenizer.model_max_length - tokenizer.model_max_length_sentences_pair ) - sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair span_doc_tokens = all_doc_tokens while len(spans) * doc_stride < len(all_doc_tokens): diff --git a/jiant/tasks/lib/templates/squad_style/utils.py b/jiant/tasks/lib/templates/squad_style/utils.py index bc84fc9eb..9cd35daa0 100644 --- a/jiant/tasks/lib/templates/squad_style/utils.py +++ b/jiant/tasks/lib/templates/squad_style/utils.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import List, Dict -from transformers.tokenization_bert import BasicTokenizer +from transformers.models.bert.tokenization_bert import BasicTokenizer from jiant.utils.display import maybe_tqdm diff --git a/requirements-no-torch.txt b/requirements-no-torch.txt index 43b9b2bd2..65aa90de2 100644 --- a/requirements-no-torch.txt +++ b/requirements-no-torch.txt @@ -13,6 +13,6 @@ seqeval==0.0.12 scikit-learn==0.22.2.post1 scipy==1.4.1 sentencepiece==0.1.86 -tokenizers==0.8.1.rc2 +tokenizers==0.9.4 tqdm==4.46.0 -transformers==3.1.0 +transformers==4.2.1 diff --git a/setup.py b/setup.py index ed46c0070..01c4cb51b 100644 --- a/setup.py +++ b/setup.py @@ -72,10 +72,10 @@ "scikit-learn == 0.22.2.post1", "scipy == 1.4.1", "sentencepiece == 0.1.86", - "tokenizers == 0.8.1.rc2", + "tokenizers == 0.9.4", "torch >= 1.5.0", "tqdm == 4.46.0", - "transformers == 3.1.0", + "transformers == 4.2.1", "torchvision == 0.6.0", ], extras_require=extras, From 816496b483cf9c29e30163693354f5bbf3559cf7 Mon Sep 17 00:00:00 2001 From: Jesse Swanson Date: Mon, 18 Jan 2021 20:36:52 -0800 Subject: [PATCH 2/5] use default return_dict in taskmodels and remove hidden state context manager in models. 
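Note (illustrative sketch, not part of this diff): with transformers >= 4.x the
encoder's forward() returns a ModelOutput by default (return_dict=True), so the
task models read output fields by attribute instead of indexing a tuple, and
per-layer hidden states are requested with output_hidden_states=True rather
than the old hidden-state context manager. A minimal example of the new access
pattern, assuming a BERT checkpoint such as "bert-base-uncased":

    import torch
    import transformers

    # Any BERT-style checkpoint works here; "bert-base-uncased" is just an example.
    tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
    model = transformers.AutoModel.from_pretrained("bert-base-uncased")

    batch = tokenizer("jiant on transformers v4", return_tensors="pt")
    with torch.no_grad():
        # output_hidden_states=True replaces the removed context manager
        output = model(**batch, output_hidden_states=True)

    unpooled = output.last_hidden_state   # formerly output[0]
    pooled = output.pooler_output         # formerly output[1]
    hidden_states = output.hidden_states  # tuple of per-layer states, formerly output[2:]
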
--- jiant/proj/main/modeling/taskmodels.py | 55 +++++++++++++++++--------- jiant/utils/transformer_utils.py | 32 --------------- 2 files changed, 37 insertions(+), 50 deletions(-) delete mode 100644 jiant/utils/transformer_utils.py diff --git a/jiant/proj/main/modeling/taskmodels.py b/jiant/proj/main/modeling/taskmodels.py index 67bd06145..68ad48b42 100644 --- a/jiant/proj/main/modeling/taskmodels.py +++ b/jiant/proj/main/modeling/taskmodels.py @@ -6,7 +6,6 @@ import torch.nn as nn import jiant.proj.main.modeling.heads as heads -import jiant.utils.transformer_utils as transformer_utils from jiant.proj.main.components.outputs import LogitsOutput, LogitsAndLossOutput from jiant.utils.python.datastructures import take_one from jiant.shared.model_setup import ModelArchitectures @@ -234,8 +233,8 @@ def __init__(self, encoder, pooler_head: heads.AbstractPoolerHead, layer): self.layer = layer def forward(self, batch, task, tokenizer, compute_loss: bool = False): - with transformer_utils.output_hidden_states_context(self.encoder): - encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) + encoder_output = get_output_from_encoder_and_batch( + encoder=self.encoder, batch=batch, output_hidden_states=True) # A tuple of layers of hidden states hidden_states = take_one(encoder_output.other) layer_hidden_states = hidden_states[self.layer] @@ -267,7 +266,7 @@ class EncoderOutput: # Extend later with attention, hidden_acts, etc -def get_output_from_encoder_and_batch(encoder, batch) -> EncoderOutput: +def get_output_from_encoder_and_batch(encoder, batch, output_hidden_states=False) -> EncoderOutput: """Pass batch to encoder, return encoder model output. Args: @@ -283,10 +282,13 @@ def get_output_from_encoder_and_batch(encoder, batch) -> EncoderOutput: input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask, + output_hidden_states=output_hidden_states, ) -def get_output_from_encoder(encoder, input_ids, segment_ids, input_mask) -> EncoderOutput: +def get_output_from_encoder( + encoder, input_ids, segment_ids, input_mask, output_hidden_states=False +) -> EncoderOutput: """Pass inputs to encoder, return encoder output. 
Args: @@ -310,18 +312,29 @@ def get_output_from_encoder(encoder, input_ids, segment_ids, input_mask) -> Enco ModelArchitectures.XLM_ROBERTA, ]: pooled, unpooled, other = get_output_from_standard_transformer_models( - encoder=encoder, input_ids=input_ids, segment_ids=segment_ids, input_mask=input_mask, + encoder=encoder, + input_ids=input_ids, + segment_ids=segment_ids, + input_mask=input_mask, + output_hidden_states=output_hidden_states, ) elif model_arch == ModelArchitectures.ELECTRA: pooled, unpooled, other = get_output_from_electra( - encoder=encoder, input_ids=input_ids, segment_ids=segment_ids, input_mask=input_mask, + encoder=encoder, + input_ids=input_ids, + segment_ids=segment_ids, + input_mask=input_mask, + output_hidden_states=output_hidden_states, ) elif model_arch in [ ModelArchitectures.BART, ModelArchitectures.MBART, ]: pooled, unpooled, other = get_output_from_bart_models( - encoder=encoder, input_ids=input_ids, input_mask=input_mask, + encoder=encoder, + input_ids=input_ids, + input_mask=input_mask, + output_hidden_states=output_hidden_states, ) else: raise KeyError(model_arch) @@ -333,28 +346,34 @@ def get_output_from_encoder(encoder, input_ids, segment_ids, input_mask) -> Enco return EncoderOutput(pooled=pooled, unpooled=unpooled) -def get_output_from_standard_transformer_models(encoder, input_ids, segment_ids, input_mask): +def get_output_from_standard_transformer_models( + encoder, input_ids, segment_ids, input_mask, output_hidden_states=False +): output = encoder( input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, - return_dict=False, + output_hidden_states=output_hidden_states, ) - pooled, unpooled, other = output[1], output[0], output[2:] - return pooled, unpooled, other + return output.pooler_output, output.last_hidden_state, output -def get_output_from_bart_models(encoder, input_ids, input_mask): +def get_output_from_bart_models(encoder, input_ids, input_mask, output_hidden_states=False): # BART and mBART and encoder-decoder architectures. # As described in the BART paper and implemented in Transformers, # for single input tasks, the encoder input is the sequence, # the decode input is 1-shifted sequence, and the resulting # sentence representation is the final decoder state. # That's what we use for `unpooled` here. 
- dec_last, dec_all, enc_last, enc_all = encoder( - input_ids=input_ids, attention_mask=input_mask, output_hidden_states=True, return_dict=False + output = encoder( + input_ids=input_ids, attention_mask=input_mask, output_hidden_states=output_hidden_states, ) - unpooled = dec_last + dec_last = output.last_hidden_state + dec_all = output.decoder_hidden_states + enc_last = output.encoder_last_hidden_state + enc_all = output.encoder_hidden_states + + unpooled = output other = (enc_all + dec_all,) @@ -370,9 +389,9 @@ def get_output_from_electra(encoder, input_ids, segment_ids, input_mask): input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, - return_dict=False, + output_hidden_states=output_hidden_states, ) - unpooled = output[0] + unpooled = output.last_hidden_state pooled = unpooled[:, 0, :] return pooled, unpooled, output diff --git a/jiant/utils/transformer_utils.py b/jiant/utils/transformer_utils.py deleted file mode 100644 index 844b6c368..000000000 --- a/jiant/utils/transformer_utils.py +++ /dev/null @@ -1,32 +0,0 @@ -import contextlib - -from jiant.shared.model_resolution import ModelArchitectures - - -@contextlib.contextmanager -def output_hidden_states_context(encoder): - model_arch = ModelArchitectures.from_encoder(encoder) - if model_arch in ( - ModelArchitectures.BERT, - ModelArchitectures.ROBERTA, - ModelArchitectures.ALBERT, - ModelArchitectures.XLM_ROBERTA, - ModelArchitectures.ELECTRA, - ): - if hasattr(encoder.encoder, "output_hidden_states"): - # Transformers < v2 - modified_obj = encoder.encoder - elif hasattr(encoder.encoder.config, "output_hidden_states"): - # Transformers >= v3 - modified_obj = encoder.encoder.config - else: - raise RuntimeError(f"Failed to convert model {type(encoder)} to output hidden states") - old_value = modified_obj.output_hidden_states - modified_obj.output_hidden_states = True - yield - modified_obj.output_hidden_states = old_value - elif model_arch in (ModelArchitectures.BART, ModelArchitectures.MBART): - yield - return - else: - raise KeyError(model_arch) From d6d989cbfe03fcbf2621c238951c0ef8ec282930 Mon Sep 17 00:00:00 2001 From: Jesse Swanson Date: Mon, 18 Jan 2021 21:31:43 -0800 Subject: [PATCH 3/5] return hidden states in output of model wrapper --- jiant/proj/main/modeling/taskmodels.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/jiant/proj/main/modeling/taskmodels.py b/jiant/proj/main/modeling/taskmodels.py index 68ad48b42..7bc1d3bb4 100644 --- a/jiant/proj/main/modeling/taskmodels.py +++ b/jiant/proj/main/modeling/taskmodels.py @@ -234,7 +234,8 @@ def __init__(self, encoder, pooler_head: heads.AbstractPoolerHead, layer): def forward(self, batch, task, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch( - encoder=self.encoder, batch=batch, output_hidden_states=True) + encoder=self.encoder, batch=batch, output_hidden_states=True + ) # A tuple of layers of hidden states hidden_states = take_one(encoder_output.other) layer_hidden_states = hidden_states[self.layer] @@ -355,7 +356,7 @@ def get_output_from_standard_transformer_models( attention_mask=input_mask, output_hidden_states=output_hidden_states, ) - return output.pooler_output, output.last_hidden_state, output + return output.pooler_output, output.last_hidden_state, output.hidden_states def get_output_from_bart_models(encoder, input_ids, input_mask, output_hidden_states=False): @@ -368,32 +369,30 @@ def get_output_from_bart_models(encoder, input_ids, input_mask, output_hidden_st 
output = encoder( input_ids=input_ids, attention_mask=input_mask, output_hidden_states=output_hidden_states, ) - dec_last = output.last_hidden_state dec_all = output.decoder_hidden_states - enc_last = output.encoder_last_hidden_state enc_all = output.encoder_hidden_states unpooled = output - other = (enc_all + dec_all,) + hidden_states = (enc_all + dec_all,) bsize, slen = input_ids.shape batch_idx = torch.arange(bsize).to(input_ids.device) # Get last non-pad index pooled = unpooled[batch_idx, slen - input_ids.eq(encoder.config.pad_token_id).sum(1) - 1] - return pooled, unpooled, other + return pooled, unpooled, hidden_states -def get_output_from_electra(encoder, input_ids, segment_ids, input_mask): +def get_output_from_electra(encoder, input_ids, segment_ids, input_mask, output_hidden_states=False): output = encoder( input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, output_hidden_states=output_hidden_states, ) - unpooled = output.last_hidden_state + unpooled = output.hidden_states pooled = unpooled[:, 0, :] - return pooled, unpooled, output + return pooled, unpooled, output.hidden_states def compute_mlm_loss(logits, masked_lm_labels): From f912c231056607c150dd9f736a27165c08aba13f Mon Sep 17 00:00:00 2001 From: Jesse Swanson Date: Thu, 25 Feb 2021 09:48:57 -0500 Subject: [PATCH 4/5] update to transformers 4.3.3 --- requirements-no-torch.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-no-torch.txt b/requirements-no-torch.txt index 65aa90de2..c9257de9f 100644 --- a/requirements-no-torch.txt +++ b/requirements-no-torch.txt @@ -15,4 +15,4 @@ scipy==1.4.1 sentencepiece==0.1.86 tokenizers==0.9.4 tqdm==4.46.0 -transformers==4.2.1 +transformers==4.3.3 diff --git a/setup.py b/setup.py index 01c4cb51b..31f16f366 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ "tokenizers == 0.9.4", "torch >= 1.5.0", "tqdm == 4.46.0", - "transformers == 4.2.1", + "transformers == 4.3.3", "torchvision == 0.6.0", ], extras_require=extras, From 6404f1365acb819a5e708a49d434327f22e50986 Mon Sep 17 00:00:00 2001 From: Jesse Swanson Date: Thu, 25 Feb 2021 09:55:53 -0500 Subject: [PATCH 5/5] black --- jiant/proj/main/modeling/taskmodels.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jiant/proj/main/modeling/taskmodels.py b/jiant/proj/main/modeling/taskmodels.py index 7bc1d3bb4..bdf2c3bc8 100644 --- a/jiant/proj/main/modeling/taskmodels.py +++ b/jiant/proj/main/modeling/taskmodels.py @@ -383,7 +383,9 @@ def get_output_from_bart_models(encoder, input_ids, input_mask, output_hidden_st return pooled, unpooled, hidden_states -def get_output_from_electra(encoder, input_ids, segment_ids, input_mask, output_hidden_states=False): +def get_output_from_electra( + encoder, input_ids, segment_ids, input_mask, output_hidden_states=False +): output = encoder( input_ids=input_ids, token_type_ids=segment_ids,