Update to Transformers v4.3.3 (#1266)
* use default return_dict in taskmodels and remove hidden state context manager in models.

* return hidden states in output of model wrapper

* update to transformers 4.3.3

* black
jeswan authored Feb 25, 2021
1 parent 84f2f5a commit b2cfb2a
Showing 8 changed files with 75 additions and 69 deletions.
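For orientation, the core API shift behind most of these diffs: Transformers 4.x models return output objects by default (`return_dict=True`), so hidden states are requested per forward call with `output_hidden_states=True` and read off the result by attribute, rather than toggled through the `output_hidden_states_context` manager that this commit removes. A minimal standalone sketch of the 4.x pattern (the model name and example inputs are illustrative, not taken from this commit):

import torch
import transformers

# Illustrative encoder; any BERT-style checkpoint behaves the same way.
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
encoder = transformers.AutoModel.from_pretrained("bert-base-uncased")

batch = tokenizer("jiant now targets Transformers 4.3.3", return_tensors="pt")
with torch.no_grad():
    output = encoder(**batch, output_hidden_states=True)

unpooled = output.last_hidden_state   # (batch, seq_len, hidden)
pooled = output.pooler_output         # (batch, hidden)
hidden_states = output.hidden_states  # tuple: embeddings + one tensor per layer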
14 changes: 9 additions & 5 deletions jiant/proj/main/modeling/heads.py
@@ -108,8 +108,10 @@ class BertMLMHead(BaseMLMHead):
     def __init__(self, hidden_size, vocab_size, layer_norm_eps=1e-12, hidden_act="gelu"):
         super().__init__()
         self.dense = nn.Linear(hidden_size, hidden_size)
-        self.transform_act_fn = transformers.modeling_bert.ACT2FN[hidden_act]
-        self.LayerNorm = transformers.modeling_bert.BertLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.transform_act_fn = transformers.models.bert.modeling_bert.ACT2FN[hidden_act]
+        self.LayerNorm = transformers.models.bert.modeling_bert.BertLayerNorm(
+            hidden_size, eps=layer_norm_eps
+        )
 
         self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
         self.bias = nn.Parameter(torch.zeros(vocab_size), requires_grad=True)
@@ -132,7 +134,9 @@ class RobertaMLMHead(BaseMLMHead):
     def __init__(self, hidden_size, vocab_size, layer_norm_eps=1e-12):
         super().__init__()
         self.dense = nn.Linear(hidden_size, hidden_size)
-        self.layer_norm = transformers.modeling_bert.BertLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.layer_norm = transformers.models.bert.modeling_bert.BertLayerNorm(
+            hidden_size, eps=layer_norm_eps
+        )
 
         self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
         self.bias = nn.Parameter(torch.zeros(vocab_size), requires_grad=True)
@@ -143,7 +147,7 @@ def __init__(self, hidden_size, vocab_size, layer_norm_eps=1e-12):
 
     def forward(self, unpooled):
         x = self.dense(unpooled)
-        x = transformers.modeling_bert.gelu(x)
+        x = transformers.models.bert.modeling_bert.gelu(x)
         x = self.layer_norm(x)
 
         # project back to size of vocabulary with bias
@@ -161,7 +165,7 @@ def __init__(self, hidden_size, embedding_size, vocab_size, hidden_act="gelu"):
         self.bias = nn.Parameter(torch.zeros(vocab_size), requires_grad=True)
         self.dense = nn.Linear(hidden_size, embedding_size)
         self.decoder = nn.Linear(embedding_size, vocab_size)
-        self.activation = transformers.modeling_bert.ACT2FN[hidden_act]
+        self.activation = transformers.models.bert.modeling_bert.ACT2FN[hidden_act]
 
         # Need a link between the two variables so that the bias is correctly resized with
         # `resize_token_embeddings`
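The head changes above mostly track the Transformers 4.x reorganization that moved model-specific code under `transformers.models.<model_name>`. A quick, hedged sanity check of the new path (assumes Transformers 4.3.3 is installed; the "gelu" key is just an example):

import transformers

# Old 3.x location: transformers.modeling_bert.ACT2FN
# New 4.x location, as used in the diff above:
gelu_fn = transformers.models.bert.modeling_bert.ACT2FN["gelu"]
print(gelu_fn)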
74 changes: 52 additions & 22 deletions jiant/proj/main/modeling/taskmodels.py
@@ -6,7 +6,6 @@
 import torch.nn as nn
 
 import jiant.proj.main.modeling.heads as heads
-import jiant.utils.transformer_utils as transformer_utils
 from jiant.proj.main.components.outputs import LogitsOutput, LogitsAndLossOutput
 from jiant.utils.python.datastructures import take_one
 from jiant.shared.model_resolution import ModelArchitectures
@@ -234,8 +233,9 @@ def __init__(self, encoder, pooler_head: heads.AbstractPoolerHead, layer):
         self.layer = layer
 
     def forward(self, batch, task, tokenizer, compute_loss: bool = False):
-        with transformer_utils.output_hidden_states_context(self.encoder):
-            encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
+        encoder_output = get_output_from_encoder_and_batch(
+            encoder=self.encoder, batch=batch, output_hidden_states=True
+        )
         # A tuple of layers of hidden states
         hidden_states = take_one(encoder_output.other)
         layer_hidden_states = hidden_states[self.layer]
@@ -267,7 +267,7 @@ class EncoderOutput:
     # Extend later with attention, hidden_acts, etc
 
 
-def get_output_from_encoder_and_batch(encoder, batch) -> EncoderOutput:
+def get_output_from_encoder_and_batch(encoder, batch, output_hidden_states=False) -> EncoderOutput:
     """Pass batch to encoder, return encoder model output.
 
     Args:
@@ -283,10 +283,13 @@ def get_output_from_encoder_and_batch(encoder, batch) -> EncoderOutput:
         input_ids=batch.input_ids,
         segment_ids=batch.segment_ids,
         input_mask=batch.input_mask,
+        output_hidden_states=output_hidden_states,
     )
 
 
-def get_output_from_encoder(encoder, input_ids, segment_ids, input_mask) -> EncoderOutput:
+def get_output_from_encoder(
+    encoder, input_ids, segment_ids, input_mask, output_hidden_states=False
+) -> EncoderOutput:
     """Pass inputs to encoder, return encoder output.
 
     Args:
@@ -310,18 +313,29 @@ def get_output_from_encoder(encoder, input_ids, segment_ids, input_mask) -> EncoderOutput:
         ModelArchitectures.XLM_ROBERTA,
     ]:
         pooled, unpooled, other = get_output_from_standard_transformer_models(
-            encoder=encoder, input_ids=input_ids, segment_ids=segment_ids, input_mask=input_mask,
+            encoder=encoder,
+            input_ids=input_ids,
+            segment_ids=segment_ids,
+            input_mask=input_mask,
+            output_hidden_states=output_hidden_states,
         )
     elif model_arch == ModelArchitectures.ELECTRA:
         pooled, unpooled, other = get_output_from_electra(
-            encoder=encoder, input_ids=input_ids, segment_ids=segment_ids, input_mask=input_mask,
+            encoder=encoder,
+            input_ids=input_ids,
+            segment_ids=segment_ids,
+            input_mask=input_mask,
+            output_hidden_states=output_hidden_states,
         )
     elif model_arch in [
         ModelArchitectures.BART,
         ModelArchitectures.MBART,
     ]:
         pooled, unpooled, other = get_output_from_bart_models(
-            encoder=encoder, input_ids=input_ids, input_mask=input_mask,
+            encoder=encoder,
+            input_ids=input_ids,
+            input_mask=input_mask,
+            output_hidden_states=output_hidden_states,
         )
     else:
         raise KeyError(model_arch)
@@ -333,38 +347,54 @@ def get_output_from_encoder(encoder, input_ids, segment_ids, input_mask) -> EncoderOutput:
     return EncoderOutput(pooled=pooled, unpooled=unpooled)
 
 
-def get_output_from_standard_transformer_models(encoder, input_ids, segment_ids, input_mask):
-    output = encoder(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
-    pooled, unpooled, other = output[1], output[0], output[2:]
-    return pooled, unpooled, other
+def get_output_from_standard_transformer_models(
+    encoder, input_ids, segment_ids, input_mask, output_hidden_states=False
+):
+    output = encoder(
+        input_ids=input_ids,
+        token_type_ids=segment_ids,
+        attention_mask=input_mask,
+        output_hidden_states=output_hidden_states,
+    )
+    return output.pooler_output, output.last_hidden_state, output.hidden_states
 
 
-def get_output_from_bart_models(encoder, input_ids, input_mask):
+def get_output_from_bart_models(encoder, input_ids, input_mask, output_hidden_states=False):
     # BART and mBART are encoder-decoder architectures.
     # As described in the BART paper and implemented in Transformers,
     # for single-input tasks, the encoder input is the sequence,
     # the decoder input is the 1-shifted sequence, and the resulting
     # sentence representation is the final decoder state.
     # That's what we use for `unpooled` here.
-    dec_last, dec_all, enc_last, enc_all = encoder(
-        input_ids=input_ids, attention_mask=input_mask, output_hidden_states=True,
+    output = encoder(
+        input_ids=input_ids, attention_mask=input_mask, output_hidden_states=output_hidden_states,
     )
-    unpooled = dec_last
+    dec_all = output.decoder_hidden_states
+    enc_all = output.encoder_hidden_states
 
-    other = (enc_all + dec_all,)
+    unpooled = output.last_hidden_state
+
+    hidden_states = (enc_all + dec_all,)
 
     bsize, slen = input_ids.shape
     batch_idx = torch.arange(bsize).to(input_ids.device)
     # Get last non-pad index
     pooled = unpooled[batch_idx, slen - input_ids.eq(encoder.config.pad_token_id).sum(1) - 1]
-    return pooled, unpooled, other
+    return pooled, unpooled, hidden_states
 
 
-def get_output_from_electra(encoder, input_ids, segment_ids, input_mask):
-    output = encoder(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
-    unpooled = output[0]
+def get_output_from_electra(
+    encoder, input_ids, segment_ids, input_mask, output_hidden_states=False
+):
+    output = encoder(
+        input_ids=input_ids,
+        token_type_ids=segment_ids,
+        attention_mask=input_mask,
+        output_hidden_states=output_hidden_states,
+    )
+    unpooled = output.last_hidden_state
     pooled = unpooled[:, 0, :]
-    return pooled, unpooled, output
+    return pooled, unpooled, output.hidden_states
 
 
 def compute_mlm_loss(logits, masked_lm_labels):
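For readers skimming the BART branch above: the pooled representation is the decoder state at the last non-padding position of each sequence. A self-contained sketch of that indexing trick with toy tensors (the pad id of 1 mirrors BART's default config; shapes and values are made up):

import torch

pad_token_id = 1
input_ids = torch.tensor([[0, 5, 7, 9, 2, 1],    # one trailing pad
                          [0, 4, 2, 1, 1, 1]])   # three trailing pads
unpooled = torch.randn(2, 6, 4)                  # stand-in for decoder hidden states

bsize, slen = input_ids.shape
batch_idx = torch.arange(bsize)
# Index of the last non-pad token in each row: 4 and 2 here.
last_idx = slen - input_ids.eq(pad_token_id).sum(1) - 1
pooled = unpooled[batch_idx, last_idx]           # shape (2, 4)
assert pooled.shape == (2, 4)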
4 changes: 3 additions & 1 deletion jiant/tasks/lib/ropes.py
@@ -90,7 +90,9 @@ def to_feature_list(
         # (This may not apply for future added models that don't start with a CLS token,
         # such as XLNet/GPT-2)
         sequence_added_tokens = 1
-        sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
+        sequence_pair_added_tokens = (
+            tokenizer.model_max_length - tokenizer.max_len_sentences_pair
+        )
 
         span_doc_tokens = all_doc_tokens
         while len(spans) * doc_stride < len(all_doc_tokens):
10 changes: 6 additions & 4 deletions jiant/tasks/lib/templates/squad_style/core.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from typing import Union, List, Dict, Optional
 
-from transformers.tokenization_bert import whitespace_tokenize
+from transformers.models.bert.tokenization_bert import whitespace_tokenize
 
 from jiant.tasks.lib.templates.squad_style import utils as squad_utils
 from jiant.shared.constants import PHASE
@@ -144,11 +144,13 @@ def to_feature_list(
         # in the way they compute mask of added tokens.
         tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
         sequence_added_tokens = (
-            tokenizer.max_len - tokenizer.max_len_single_sentence + 1
+            tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1
             if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET
-            else tokenizer.max_len - tokenizer.max_len_single_sentence
+            else tokenizer.model_max_length - tokenizer.max_len_single_sentence
         )
-        sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
+        sequence_pair_added_tokens = (
+            tokenizer.model_max_length - tokenizer.max_len_sentences_pair
+        )
 
         span_doc_tokens = all_doc_tokens
         while len(spans) * doc_stride < len(all_doc_tokens):
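The tokenizer updates in ropes.py and core.py replace the removed `max_len` attribute with `model_max_length`; `max_len_single_sentence` and `max_len_sentences_pair` still exist in 4.x, so the differences count the special tokens added for one or two segments. A short illustration (the tokenizer choice and printed values assume a standard BERT tokenizer and are not part of this commit):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")

print(tokenizer.model_max_length)  # 512 for BERT; replaces the old `max_len`
# Special tokens added around one segment ([CLS] ... [SEP]) -> 2
print(tokenizer.model_max_length - tokenizer.max_len_single_sentence)
# Special tokens added around a pair ([CLS] ... [SEP] ... [SEP]) -> 3
print(tokenizer.model_max_length - tokenizer.max_len_sentences_pair)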
2 changes: 1 addition & 1 deletion jiant/tasks/lib/templates/squad_style/utils.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from typing import List, Dict
 
-from transformers.tokenization_bert import BasicTokenizer
+from transformers.models.bert.tokenization_bert import BasicTokenizer
 from jiant.utils.display import maybe_tqdm
 
 
32 changes: 0 additions & 32 deletions jiant/utils/transformer_utils.py

This file was deleted.

4 changes: 2 additions & 2 deletions requirements-no-torch.txt
@@ -13,6 +13,6 @@ seqeval==0.0.12
 scikit-learn==0.22.2.post1
 scipy==1.4.1
 sentencepiece==0.1.86
-tokenizers==0.8.1.rc2
+tokenizers==0.9.4
 tqdm==4.46.0
-transformers==3.1.0
+transformers==4.3.3
4 changes: 2 additions & 2 deletions setup.py
@@ -72,10 +72,10 @@
         "scikit-learn == 0.22.2.post1",
         "scipy == 1.4.1",
         "sentencepiece == 0.1.86",
-        "tokenizers == 0.8.1.rc2",
+        "tokenizers == 0.9.4",
         "torch >= 1.5.0",
         "tqdm == 4.46.0",
-        "transformers == 3.1.0",
+        "transformers == 4.3.3",
         "torchvision == 0.6.0",
     ],
     extras_require=extras,
