From 874dc065a97175195554aac53f19a3af0893fa1e Mon Sep 17 00:00:00 2001 From: Pengcheng He Date: Thu, 4 Feb 2021 18:45:26 -0500 Subject: [PATCH 1/8] Integrate DeBERTa v2(the 1.5B model surpassed human performance on SuperGLUE); Add DeBERTa v2 900M,1.5B models; --- docs/source/pretrained_models.rst | 19 +- .../models/deberta/configuration_deberta.py | 8 + .../models/deberta/gpt2_tokenizer.py | 380 ++++++++++++++ .../models/deberta/modeling_deberta.py | 377 +++++++++---- .../models/deberta/spm_tokenizer.py | 277 ++++++++++ .../models/deberta/tokenization_deberta.py | 494 ++---------------- 6 files changed, 995 insertions(+), 560 deletions(-) create mode 100644 src/transformers/models/deberta/gpt2_tokenizer.py create mode 100644 src/transformers/models/deberta/spm_tokenizer.py diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index d213267197771e..c1a71ad35d8bb5 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -443,15 +443,30 @@ For the full list, refer to `https://huggingface.co/models `__) | +--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| DeBERTa | ``microsoft/deberta-base`` | | 12-layer, 768-hidden, 12-heads, ~125M parameters | +| DeBERTa | ``microsoft/deberta-base`` | | 12-layer, 768-hidden, 12-heads, ~140M parameters | | | | | DeBERTa using the BERT-base architecture | | | | | | | | (see `details `__) | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``microsoft/deberta-large`` | | 24-layer, 1024-hidden, 16-heads, ~390M parameters | +| | ``microsoft/deberta-large`` | | 24-layer, 1024-hidden, 16-heads, ~400M parameters | | | | | DeBERTa using the BERT-large architecture | | | | | | | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``microsoft/deberta-xlarge`` | | 48-layer, 1024-hidden, 16-heads, ~750M parameters | +| | | | DeBERTa XLarge with similar BERT architecture | +| | | | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``microsoft/deberta-xlarge-v2`` | | 24-layer, 1536-hidden, 24-heads, ~900M parameters | +| | | | DeBERTa XLarge V2 with similar BERT architecture | +| | | | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``microsoft/deberta-xxlarge-v2`` | | 48-layer, 1536-hidden, 24-heads, ~1.5B parameters | +| | | | DeBERTa XXLarge V2 with similar BERT architecture | +| | | | +| | | (see `details `__) | +--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | SqueezeBERT | ``squeezebert/squeezebert-uncased`` | | 12-layer, 768-hidden, 12-heads, 51M parameters, 4.3x faster than 
bert-base-uncased on a smartphone. | | | | | SqueezeBERT architecture pretrained from scratch on masked language model (MLM) and sentence order prediction (SOP) tasks. | diff --git a/src/transformers/models/deberta/configuration_deberta.py b/src/transformers/models/deberta/configuration_deberta.py index 25dd39cade87d4..e2f92abe1a54ad 100644 --- a/src/transformers/models/deberta/configuration_deberta.py +++ b/src/transformers/models/deberta/configuration_deberta.py @@ -23,6 +23,14 @@ DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { "microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/config.json", "microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/config.json", + "microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/config.json", + "microsoft/deberta-xlarge-v2": "https://huggingface.co/microsoft/deberta-xlarge-v2/resolve/main/config.json", + "microsoft/deberta-xxlarge-v2": "https://huggingface.co/microsoft/deberta-xxlarge-v2/resolve/main/config.json", + "microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/config.json", + "microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/config.json", + "microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/config.json", + "microsoft/deberta-xlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xlarge-v2-mnli/resolve/main/config.json", + "microsoft/deberta-xxlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xxlarge-v2-mnli/resolve/main/config.json", } diff --git a/src/transformers/models/deberta/gpt2_tokenizer.py b/src/transformers/models/deberta/gpt2_tokenizer.py new file mode 100644 index 00000000000000..88e12fc07e2ec7 --- /dev/null +++ b/src/transformers/models/deberta/gpt2_tokenizer.py @@ -0,0 +1,380 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" GPT2 Tokenization class for model DeBERTa.""" + +import os +import unicodedata +from functools import lru_cache + + +try: + import regex as re +except ImportError: + raise ImportError("Please install regex with: pip install regex") + +___all__ = ["GPT2Tokenizer"] + +VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"} + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode + strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're + at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a signficant + percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode + strings. And avoids mapping to whitespace/control characters the bpe code barfs on. 
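    As a quick illustration (added for clarity; the exact characters follow from the byte
    ranges chosen in the code below), the mapping covers all 256 byte values, keeps
    printable bytes as-is and shifts the remaining bytes past 0xFF:

    Example::

        >>> m = bytes_to_unicode()
        >>> len(m)
        256
        >>> [m[b] for b in "é".encode("utf-8")]
        ['Ã', '©']
        >>> m[0]  # control bytes are remapped to otherwise unused code points
        'Ā'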
+ """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2 ** 8): + if b not in bs: + bs.append(b) + cs.append(2 ** 8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length + strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Encoder: + def __init__(self, encoder, bpe_merges, errors="replace"): + self.encoder = encoder + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def split_to_words(self, text): + return list(re.findall(self.pat, text)) + + def encode(self, text): + bpe_tokens = [] + for token in self.split_to_words(text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + +def get_encoder(encoder, vocab): + return Encoder( + encoder=encoder, + bpe_merges=vocab, + ) + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. 
+ if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +class GPT2Tokenizer(object): + """ + A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer + + Args: + vocab_file (:obj:`str`, optional): + The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases + `_, e.g. "bpe_encoder", default: `None`. + + If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file + is a state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used + in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. The difference between our wrapped GPT2 + tokenizer and RoBERTa wrapped tokenizer are, + + - Special tokens, unlike `RoBERTa` which use ``, `` as the `start` token and `end` token of a + sentence. We use `[CLS]` and `[SEP]` as the `start` and `end` token of input sentence which is the same + as `BERT`. + + - We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, + `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264 + + special_tokens (:obj:`list`, optional): + List of special tokens to be added to the end of the vocabulary. + """ + + def __init__(self, vocab_file=None, special_tokens=None, **kwargs): + import torch + + self.pad_token = "[PAD]" + self.sep_token = "[SEP]" + self.unk_token = "[UNK]" + self.cls_token = "[CLS]" + + self.symbols = [] + self.count = [] + self.indices = {} + self.pad_token_id = self.add_symbol(self.pad_token) + self.cls_token_id = self.add_symbol(self.cls_token) + self.sep_token_id = self.add_symbol(self.sep_token) + self.unk_token_id = self.add_symbol(self.unk_token) + + self.gpt2_encoder = torch.load(vocab_file) + self.bpe = get_encoder(self.gpt2_encoder["encoder"], self.gpt2_encoder["vocab"]) + for w, n in self.gpt2_encoder["dict_map"]: + self.add_symbol(w, n) + + self.mask_token = "[MASK]" + self.mask_id = self.add_symbol(self.mask_token) + self.special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"] + if special_tokens is not None: + for t in special_tokens: + self.add_special_token(t) + + self.vocab = self.indices + self.ids_to_tokens = self.symbols + + def tokenize(self, text): + """ + Convert an input text to tokens. + + Args: + text (:obj:`str`): input text to be tokenized. + + Returns: + A list of byte tokens where each token represent the byte id in GPT2 byte dictionary + + Example:: + >>> tokenizer = GPT2Tokenizer() + >>> text = "Hello world!" 
+ >>> tokens = tokenizer.tokenize(text) + >>> print(tokens) + ['15496', '995', '0'] + """ + bpe = self._encode(text) + + return [t for t in bpe.split(" ") if t] + + def convert_tokens_to_ids(self, tokens): + """ + Convert list of tokens to ids + + Args: + tokens (:obj:`list`): list of tokens + + Returns: + List of ids + """ + + return [self.vocab[t] for t in tokens] + + def convert_ids_to_tokens(self, ids): + """ + Convert list of ids to tokens + + Args: + ids (:obj:`list`): list of ids + + Returns: + List of tokens + """ + + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def split_to_words(self, text): + return self.bpe.split_to_words(text) + + def decode(self, tokens): + """ + Decode list of tokens to text strings + + Args: + tokens (:obj:`list`): list of tokens. + + Returns: + Text string corresponds to the input tokens. + + Example:: + >>> tokenizer = GPT2Tokenizer() + >>> text = "Hello world!" + >>> tokens = tokenizer.tokenize(text) + >>> print(tokens) + ['15496', '995', '0'] + >>> tokenizer.decode(tokens) + 'Hello world!' + """ + return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens]) + + def add_special_token(self, token): + """ + Adds a special token to the dictionary + + Args: + token (:obj:`str`): Tthe new token/word to be added to the vocabulary. + + Returns: + The id of new token in the vocabulary. + + """ + self.special_tokens.append(token) + return self.add_symbol(token) + + def part_of_whole_word(self, token, is_bos=False): + if is_bos: + return True + s = self._decode(token) + if len(s) == 1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0])): + return False + + return not s.startswith(" ") + + def sym(self, id): + return self.ids_to_tokens[id] + + def id(self, sym): + return self.vocab[sym] + + def _encode(self, x: str) -> str: + return " ".join(map(str, self.bpe.encode(x))) + + def _decode(self, x: str) -> str: + return self.bpe.decode(map(int, x.split())) + + def add_symbol(self, word, n=1): + """ + Adds a word to the dictionary + + Args: + word (:obj:`str`): Tthe new token/word to be added to the vocabulary. + n (int, optional): The frequency of the word. + + Returns: + The id of the new word. 
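            Example (illustrative only; concrete ids depend on the loaded vocabulary, and
            ``tokenizer`` is assumed to be an already constructed :class:`GPT2Tokenizer`)::

                >>> idx = tokenizer.add_symbol("[CUSTOM]")
                >>> tokenizer.sym(idx)
                '[CUSTOM]'
                >>> tokenizer.add_symbol("[CUSTOM]") == idx  # re-adding only bumps the count
                True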
+ + """ + if word in self.indices: + idx = self.indices[word] + self.count[idx] = self.count[idx] + n + return idx + else: + idx = len(self.symbols) + self.indices[word] = idx + self.symbols.append(word) + self.count.append(n) + return idx + + def save_pretrained(self, path: str, filename_prefix: str = None): + import torch + + filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] + if filename_prefix is not None: + filename = filename_prefix + "-" + filename + full_path = os.path.join(path, filename) + torch.save(self.gpt2_encoder, full_path) + return (full_path,) diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 2d38b38297aab8..574d3e7062fc60 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -16,11 +16,14 @@ import math from collections.abc import Sequence +from functools import lru_cache +import numpy as np import torch from packaging import version from torch import _softmax_backward_data, nn from torch.nn import CrossEntropyLoss +from torch.nn import LayerNorm as DebertaLayerNorm from ...activations import ACT2FN from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward @@ -44,6 +47,14 @@ DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ "microsoft/deberta-base", "microsoft/deberta-large", + "microsoft/deberta-xlarge", + "microsoft/deberta-xlarge-v2", + "microsoft/deberta-xxlarge-v2", + "microsoft/deberta-base-mnli", + "microsoft/deberta-large-mnli", + "microsoft/deberta-xlarge-mnli", + "microsoft/deberta-xlarge-v2-mnli", + "microsoft/deberta-xxlarge-v2-mnli", ] @@ -214,24 +225,34 @@ def get_context(self): return self.drop_prob -class DebertaLayerNorm(nn.Module): - """LayerNorm module in the TF style (epsilon inside the square root).""" +def MaskedLayerNorm(layerNorm, input, mask=None): + """ + Masked LayerNorm which will apply mask over the output of LayerNorm to avoid inaccurate updatings to the LayerNorm + module. - def __init__(self, size, eps=1e-12): - super().__init__() - self.weight = nn.Parameter(torch.ones(size)) - self.bias = nn.Parameter(torch.zeros(size)) - self.variance_epsilon = eps + Args: + layernorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function + input (:obj:`torch.tensor`): The input tensor + mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicate the output of that element will be ignored, i.e. 
set to `0` - def forward(self, hidden_states): - input_type = hidden_states.dtype - hidden_states = hidden_states.float() - mean = hidden_states.mean(-1, keepdim=True) - variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) - hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon) - hidden_states = hidden_states.to(input_type) - y = self.weight * hidden_states + self.bias - return y + Example:: + + # Create a tensor b x n x d + x = torch.randn([1,10,100]) + m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int) + LayerNorm = DeBERTa.deberta.LayerNorm(100) + y = MaskedLayerNorm(LayerNorm, x, m) + + """ + output = layerNorm(input).to(input) + if mask is None: + return output + if mask.dim() != input.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(output.dtype) + return output * mask class DebertaSelfOutput(nn.Module): @@ -349,6 +370,29 @@ def forward( return layer_output +class ConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, "conv_kernel_size", 3) + groups = getattr(config, "conv_groups", 1) + self.conv_act = getattr(config, "conv_act", "tanh") + self.conv = torch.nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) + self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + output_states = MaskedLayerNorm(self.LayerNorm, residual_states + out, input_mask) + + return output_states + + class DebertaEncoder(nn.Module): """Modified BertEncoder with relative position bias support""" @@ -360,10 +404,25 @@ def __init__(self, config): self.max_relative_positions = getattr(config, "max_relative_positions", -1) if self.max_relative_positions < 1: self.max_relative_positions = config.max_position_embeddings - self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size) + self.position_buckets = getattr(config, "position_buckets", -1) + pos_ebd_size = self.max_relative_positions * 2 + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) + + self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] + if "layer_norm" in self.norm_rel_ebd: + self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + kernel_size = getattr(config, "conv_kernel_size", 0) + self.with_conv = False + if kernel_size > 0: + self.with_conv = True + self.conv = ConvLayer(config) def get_rel_embedding(self): rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) return rel_embeddings def get_attention_mask(self, attention_mask): @@ -379,7 +438,9 @@ def get_attention_mask(self, attention_mask): def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): if self.relative_attention and relative_pos is None: q = query_states.size(-2) if query_states is 
not None else hidden_states.size(-2) - relative_pos = build_relative_position(q, hidden_states.size(-2), hidden_states.device) + relative_pos = build_relative_position( + q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) return relative_pos def forward( @@ -392,6 +453,10 @@ def forward( relative_pos=None, return_dict=True, ): + if attention_mask.dim() <= 2: + input_mask = attention_mask + else: + input_mask = (attention_mask.sum(-2) > 0).byte() attention_mask = self.get_attention_mask(attention_mask) relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) @@ -408,7 +473,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - hidden_states = layer_module( + output_states = layer_module( next_kv, attention_mask, output_attentions, @@ -417,29 +482,42 @@ def forward( rel_embeddings=rel_embeddings, ) if output_attentions: - hidden_states, att_m = hidden_states + output_states, att_m = output_states + + if i == 0 and self.with_conv: + output_states = self.conv(hidden_states, output_states, input_mask) if query_states is not None: - query_states = hidden_states + query_states = output_states if isinstance(hidden_states, Sequence): next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None else: - next_kv = hidden_states + next_kv = output_states if output_attentions: all_attentions = all_attentions + (att_m,) if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (output_states,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions ) -def build_relative_position(query_size, key_size, device): +def make_log_bucket_position(relative_pos, bucket_size, max_position): + sign = np.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos)) + log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid + bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int) + return bucket_pos + + +@lru_cache(maxsize=128) +def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1): """ Build relative position according to the query and key @@ -450,15 +528,19 @@ def build_relative_position(query_size, key_size, device): Args: query_size (int): the length of query key_size (int): the length of key + bucket_size (int): the size of position bucket + max_position (int): the maxium allowed absolute positoin Return: :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size] """ - - q_ids = torch.arange(query_size, dtype=torch.long, device=device) - k_ids = torch.arange(key_size, dtype=torch.long, device=device) - rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1) + q_ids = np.arange(0, query_size) + k_ids = np.arange(0, key_size) + rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1)) + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) + rel_pos_ids = torch.tensor(rel_pos_ids, 
dtype=torch.long) rel_pos_ids = rel_pos_ids[:query_size, :] rel_pos_ids = rel_pos_ids.unsqueeze(0) return rel_pos_ids @@ -498,39 +580,41 @@ def __init__(self, config): "heads (%d)" % (config.hidden_size, config.num_attention_heads) ) self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + _attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.in_proj = torch.nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False) - self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) - self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) - self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.share_att_key = getattr(config, "share_att_key", False) + self.pos_att_type = config.pos_att_type self.relative_attention = getattr(config, "relative_attention", False) - self.talking_head = getattr(config, "talking_head", False) - - if self.talking_head: - self.head_logits_proj = torch.nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False) - self.head_weights_proj = torch.nn.Linear( - config.num_attention_heads, config.num_attention_heads, bias=False - ) if self.relative_attention: + self.position_buckets = getattr(config, "position_buckets", -1) self.max_relative_positions = getattr(config, "max_relative_positions", -1) if self.max_relative_positions < 1: self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + self.pos_dropout = StableDropout(config.hidden_dropout_prob) - if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: - self.pos_proj = torch.nn.Linear(config.hidden_size, self.all_head_size, bias=False) - if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: - self.pos_q_proj = torch.nn.Linear(config.hidden_size, self.all_head_size) + if not self.share_att_key: + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = StableDropout(config.attention_probs_dropout_prob) + self._register_load_state_dict_pre_hook(self._pre_load_hook) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1) + def transpose_for_scores(self, x, attention_heads): + new_x_shape = x.size()[:-1] + (attention_heads, -1) x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) + return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) def forward( self, @@ -571,51 +655,46 @@ def forward( """ if query_states is None: - qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1) - query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1) - else: - - def linear(w, b, x): - if b is not None: - return torch.matmul(x, w.t()) + b.t() 
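# Editorial sketch (comments only, not part of the patch): the make_log_bucket_position
# helper above keeps short relative distances exact and squeezes longer ones into
# log-spaced buckets. For example, assuming (hypothetically) bucket_size=256 and
# max_position=512, mid = 128, so:
#   relative positions in [-128, 128]             -> returned unchanged
#   relative position 200                         -> bucket 169
#   relative positions +/-511 (the largest ones)  -> buckets +/-255
# This is why DebertaEncoder above only needs position_buckets * 2 rows for
# rel_embeddings instead of max_relative_positions * 2 when position_buckets > 0.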
- else: - return torch.matmul(x, w.t()) # + b.t() - - ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0) - qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)] - qkvb = [None] * 3 - - q = linear(qkvw[0], qkvb[0], query_states) - k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)] - query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]] - - query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :]) - value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :]) + query_states = hidden_states + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) rel_att = None # Take the dot product between "query" and "key" to get the raw attention scores. - scale_factor = 1 + len(self.pos_att_type) + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + if "p2p" in self.pos_att_type: + scale_factor += 1 scale = math.sqrt(query_layer.size(-1) * scale_factor) - query_layer = query_layer / scale - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) - rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) if rel_att is not None: attention_scores = attention_scores + rel_att + attention_scores = attention_scores + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) # bxhxlxd - if self.talking_head: - attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) - attention_probs = self.dropout(attention_probs) - if self.talking_head: - attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + _attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(_attention_probs) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.view(*new_context_layer_shape) if return_att: @@ -623,61 +702,147 @@ def linear(w, b, x): else: return context_layer - def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): if relative_pos is None: q = query_layer.size(-2) - relative_pos = build_relative_position(q, key_layer.size(-2), query_layer.device) + relative_pos 
= build_relative_position( + q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) if relative_pos.dim() == 2: relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) elif relative_pos.dim() == 3: relative_pos = relative_pos.unsqueeze(1) # bxhxqxk elif relative_pos.dim() != 4: - raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") + raise ValueError(f"Relative postion ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") - att_span = min(max(query_layer.size(-2), key_layer.size(-2)), self.max_relative_positions) + att_span = self.pos_ebd_size relative_pos = relative_pos.long().to(query_layer.device) - rel_embeddings = rel_embeddings[ - self.max_relative_positions - att_span : self.max_relative_positions + att_span, : - ].unsqueeze(0) - if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: - pos_key_layer = self.pos_proj(rel_embeddings) - pos_key_layer = self.transpose_for_scores(pos_key_layer) - if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: - pos_query_layer = self.pos_q_proj(rel_embeddings) - pos_query_layer = self.transpose_for_scores(pos_query_layer) + rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :].unsqueeze( + 0 + ) # .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) + if self.share_att_key: + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + else: + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) score = 0 # content->position if "c2p" in self.pos_att_type: - c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2)) + scale = math.sqrt(pos_key_layer.size(-1) * scale_factor) + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) - c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos)) - score += c2p_att + c2p_att = torch.gather( + c2p_att, + dim=-1, + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), + ) + score += c2p_att / scale # position->content if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: - pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor) - if query_layer.size(-2) != key_layer.size(-2): - r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device) + scale = math.sqrt(pos_query_layer.size(-1) * scale_factor) + if key_layer.size(-2) != query_layer.size(-2): + r_pos = build_relative_position( + key_layer.size(-2), + key_layer.size(-2), + bucket_size=self.position_buckets, + 
max_position=self.max_relative_positions, + ).to(query_layer.device) + r_pos = r_pos.unsqueeze(0) else: r_pos = relative_pos + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) if query_layer.size(-2) != key_layer.size(-2): pos_index = relative_pos[:, :, :, 0].unsqueeze(-1) if "p2c" in self.pos_att_type: - p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) p2c_att = torch.gather( - p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer) + p2c_att, + dim=-1, + index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), ).transpose(-1, -2) if query_layer.size(-2) != key_layer.size(-2): - p2c_att = torch.gather(p2c_att, dim=-2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer)) - score += p2c_att + p2c_att = torch.gather( + p2c_att, + dim=-2, + index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))), + ) + score += p2c_att / scale + + # position->position + if "p2p" in self.pos_att_type: + pos_query = pos_query_layer[:, :, att_span:, :] + p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2)) + p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:]) + if query_layer.size(-2) != key_layer.size(-2): + p2p_att = torch.gather( + p2p_att, + dim=-2, + index=pos_index.expand(query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))), + ) + p2p_att = torch.gather( + p2p_att, + dim=-1, + index=c2p_pos.expand( + [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)] + ), + ) + score += p2p_att return score + def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + self_state = self.state_dict() + if ((prefix + "query_proj.weight") not in state_dict) and ((prefix + "in_proj.weight") in state_dict): + v1_proj = state_dict[prefix + "in_proj.weight"] + v1_proj = v1_proj.unsqueeze(0).reshape(self.num_attention_heads, -1, v1_proj.size(-1)) + q, k, v = v1_proj.chunk(3, dim=1) + state_dict[prefix + "query_proj.weight"] = q.reshape(-1, v1_proj.size(-1)) + state_dict[prefix + "key_proj.weight"] = k.reshape(-1, v1_proj.size(-1)) + state_dict[prefix + "key_proj.bias"] = self_state["key_proj.bias"] + state_dict[prefix + "value_proj.weight"] = v.reshape(-1, v1_proj.size(-1)) + v1_query_bias = state_dict[prefix + "q_bias"] + state_dict[prefix + "query_proj.bias"] = v1_query_bias + v1_value_bias = state_dict[prefix + "v_bias"] + state_dict[prefix + "value_proj.bias"] = v1_value_bias + + v1_pos_key_proj = state_dict[prefix + "pos_proj.weight"] + state_dict[prefix + "pos_key_proj.weight"] = v1_pos_key_proj + v1_pos_query_proj = state_dict[prefix + "pos_q_proj.weight"] + state_dict[prefix + "pos_query_proj.weight"] = v1_pos_query_proj + v1_pos_query_proj_bias = state_dict[prefix + "pos_q_proj.bias"] + state_dict[prefix + "pos_query_proj.bias"] = v1_pos_query_proj_bias + state_dict[prefix + "pos_key_proj.bias"] = self_state["pos_key_proj.bias"] + + del state_dict[prefix + "in_proj.weight"] + del state_dict[prefix + "q_bias"] + del state_dict[prefix + "v_bias"] + del state_dict[prefix + "pos_proj.weight"] + del state_dict[prefix + "pos_q_proj.weight"] + del state_dict[prefix + "pos_q_proj.bias"] + class DebertaEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" diff --git a/src/transformers/models/deberta/spm_tokenizer.py 
b/src/transformers/models/deberta/spm_tokenizer.py new file mode 100644 index 00000000000000..d7b94fa7c40928 --- /dev/null +++ b/src/transformers/models/deberta/spm_tokenizer.py @@ -0,0 +1,277 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization class for model DeBERTa.""" + +import os +import unicodedata + +import sentencepiece as sp +import six + + +VOCAB_FILES_NAMES = {"vocab_file": "spm.model"} + +__all__ = ["SPMTokenizer"] + + +class SPMTokenizer: + def __init__( + self, vocab_file, do_lower_case=False, special_tokens=None, bpe_dropout=0, split_by_punct=False, **kwargs + ): + self.split_by_punct = split_by_punct + spm = sp.SentencePieceProcessor() + assert os.path.exists(vocab_file) + spm.load(vocab_file) + bpe_vocab_size = spm.GetPieceSize() + # Token map + # 0+1 + # 1+1 + # 2+1 + self.vocab = {spm.IdToPiece(i): i for i in range(bpe_vocab_size)} + self.id_to_tokens = [spm.IdToPiece(i) for i in range(bpe_vocab_size)] + # self.vocab['[PAD]'] = 0 + # self.vocab['[CLS]'] = 1 + # self.vocab['[SEP]'] = 2 + # self.vocab['[UNK]'] = 3 + + _special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"] + self.special_tokens = [] + if special_tokens is not None: + _special_tokens.extend(special_tokens) + for t in _special_tokens: + self.add_special_token(t) + + self.spm = spm + + def tokenize(self, text): + pieces = self._encode_as_pieces(text) + + def _norm(x): + if x not in self.vocab or x == "": + return "[UNK]" + else: + return x + + pieces = [_norm(p) for p in pieces] + return pieces + + def convert_tokens_to_ids(self, tokens): + return [self.vocab[t] if t in self.vocab else 1 for t in tokens] + + def convert_ids_to_tokens(self, ids): + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def decode(self, tokens, start=-1, end=-1, raw_text=None): + if raw_text is None: + return self.spm.decode_pieces([t for t in tokens if t not in self.special_tokens]) + else: + words = self.split_to_words(raw_text) + word_tokens = [self.tokenize(w) for w in words] + token2words = [0] * len(tokens) + tid = 0 + for i, w in enumerate(word_tokens): + for k, t in enumerate(w): + token2words[tid] = i + tid += 1 + word_start = token2words[start] + word_end = token2words[end] if end < len(tokens) else len(words) + text = "".join(words[word_start:word_end]) + return text + + def add_special_token(self, token): + if token not in self.special_tokens: + self.special_tokens.append(token) + if token not in self.vocab: + self.vocab[token] = len(self.vocab) + self.id_to_tokens.append(token) + return self.id(token) + + def part_of_whole_word(self, token, is_bos=False): + if is_bos: + return True + if ( + len(token) == 1 + and (_is_whitespace(list(token)[0]) or _is_control(list(token)[0]) or _is_punctuation(list(token)[0])) + ) or token in self.special_tokens: + return False + + word_start = b"\xe2\x96\x81".decode("utf-8") + return not token.startswith(word_start) + + def pad(self): + return "[PAD]" + + 
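    # Editorial note (comments only, not part of the patch): b"\xe2\x96\x81" used in
    # part_of_whole_word() above decodes to the SentencePiece word-start marker "▁"
    # (U+2581). For a hypothetical encoding such as
    #     spm.encode_as_pieces("Hello world") -> ["▁Hello", "▁world"]
    # every piece that begins with the marker starts a new word, while continuation
    # pieces without it (e.g. "ing") are treated as part of the preceding word.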
def bos(self): + return "[CLS]" + + def eos(self): + return "[SEP]" + + def unk(self): + return "[UNK]" + + def mask(self): + return "[MASK]" + + def sym(self, id): + return self.ids_to_tokens[id] + + def id(self, sym): + return self.vocab[sym] if sym in self.vocab else 1 + + def _encode_as_pieces(self, text): + text = convert_to_unicode(text) + if self.split_by_punct: + words = self._run_split_on_punc(text) + pieces = [self.spm.encode_as_pieces(w) for w in words] + return [p for w in pieces for p in w] + else: + return self.spm.encode_as_pieces(text) + + def split_to_words(self, text): + pieces = self._encode_as_pieces(text) + word_start = b"\xe2\x96\x81".decode("utf-8") + words = [] + offset = 0 + prev_end = 0 + for i, p in enumerate(pieces): + if p.startswith(word_start): + if offset > prev_end: + words.append(text[prev_end:offset]) + prev_end = offset + w = p.replace(word_start, "") + else: + w = p + try: + s = text.index(w, offset) + pn = "" + k = i + 1 + while k < len(pieces): + pn = pieces[k].replace(word_start, "") + if len(pn) > 0: + break + k += 1 + + if len(pn) > 0 and pn in text[offset:s]: + offset = offset + 1 + else: + offset = s + len(w) + except Exception: + offset = offset + 1 + + if prev_end < offset: + words.append(text[prev_end:offset]) + + return words + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def save_pretrained(self, path: str, filename_prefix: str = None): + filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] + if filename_prefix is not None: + filename = filename_prefix + "-" + filename + full_path = os.path.join(path, filename) + with open(full_path, "wb") as fs: + fs.write(self.spm.serialized_model_proto()) + return (full_path,) + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. 
+ if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") diff --git a/src/transformers/models/deberta/tokenization_deberta.py b/src/transformers/models/deberta/tokenization_deberta.py index 4edba5fd599944..f3ed41fe3d991a 100644 --- a/src/transformers/models/deberta/tokenization_deberta.py +++ b/src/transformers/models/deberta/tokenization_deberta.py @@ -15,476 +15,61 @@ """ Tokenization class for model DeBERTa.""" import os -import pathlib -import random -import unicodedata -from functools import lru_cache from typing import Optional, Tuple -from zipfile import ZipFile - -import tqdm - -import requests +from ...file_utils import is_sentencepiece_available from ...tokenization_utils import PreTrainedTokenizer -from ...utils import logging - +from .gpt2_tokenizer import GPT2Tokenizer -try: - import regex as re -except ImportError: - raise ImportError("Please install regex with: pip install regex") +if is_sentencepiece_available(): + from .spm_tokenizer import SPMTokenizer -logger = logging.get_logger(__name__) - -VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"} PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { "microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/bpe_encoder.bin", "microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/bpe_encoder.bin", + "microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/bpe_encoder.bin", + "microsoft/deberta-xlarge-v2": "https://huggingface.co/microsoft/deberta-xlarge-v2/resolve/main/spm.model", + "microsoft/deberta-xxlarge-v2": "https://huggingface.co/microsoft/deberta-xxlarge-v2/resolve/main/spm.model", + "microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/bpe_encoder.bin", + "microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/bpe_encoder.bin", + "microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/bpe_encoder.bin", + "microsoft/deberta-xlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xlarge-v2-mnli/resolve/main/spm.model", + "microsoft/deberta-xxlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xxlarge-v2-mnli/resolve/main/spm.model", } } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { "microsoft/deberta-base": 512, "microsoft/deberta-large": 512, + "microsoft/deberta-xlarge": 512, + "microsoft/deberta-xlarge-v2": 512, + "microsoft/deberta-xxlarge-v2": 512, + "microsoft/deberta-base-mnli": 512, + "microsoft/deberta-large-mnli": 512, + "microsoft/deberta-xlarge-mnli": 512, + "microsoft/deberta-xlarge-v2-mnli": 512, + "microsoft/deberta-xxlarge-v2-mnli": 512, } PRETRAINED_INIT_CONFIGURATION = { - "microsoft/deberta-base": {"do_lower_case": False}, - "microsoft/deberta-large": {"do_lower_case": False}, + 
"microsoft/deberta-base": {"do_lower_case": False, "vocab_type": "gpt2"}, + "microsoft/deberta-large": {"do_lower_case": False, "vocab_type": "gpt2"}, + "microsoft/deberta-xlarge": {"do_lower_case": False, "vocab_type": "gpt2"}, + "microsoft/deberta-xlarge-v2": {"do_lower_case": False, "vocab_type": "spm"}, + "microsoft/deberta-xxlarge-v2": {"do_lower_case": False, "vocab_type": "spm"}, + "microsoft/deberta-base-mnli": {"do_lower_case": False, "vocab_type": "gpt2"}, + "microsoft/deberta-large-mnli": {"do_lower_case": False, "vocab_type": "gpt2"}, + "microsoft/deberta-xlarge-mnli": {"do_lower_case": False, "vocab_type": "gpt2"}, + "microsoft/deberta-xlarge-v2-mnli": {"do_lower_case": False, "vocab_type": "spm"}, + "microsoft/deberta-xxlarge-v2-mnli": {"do_lower_case": False, "vocab_type": "spm"}, } __all__ = ["DebertaTokenizer"] - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode - strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're - at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a signficant - percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode - strings. And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = ( - list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) - ) - cs = bs[:] - n = 0 - for b in range(2 ** 8): - if b not in bs: - bs.append(b) - cs.append(2 ** 8 + n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """ - Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length - strings). 
- """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -class Encoder: - def __init__(self, encoder, bpe_merges, errors="replace"): - self.encoder = encoder - self.decoder = {v: k for k, v in self.encoder.items()} - self.errors = errors # how to handle errors in decoding - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges)))) - self.cache = {} - self.random = random.Random(0) - - # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions - self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token) - pairs = get_pairs(word) - - if not pairs: - return token - - while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except Exception: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word) - 1 and word[i + 1] == second: - new_word.append(first + second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = " ".join(word) - self.cache[token] = word - return word - - def split_to_words(self, text): - return list(re.findall(self.pat, text)) - - def encode(self, text): - bpe_tokens = [] - for token in self.split_to_words(text): - token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) - return bpe_tokens - - def decode(self, tokens): - text = "".join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) - return text - - -def get_encoder(encoder, vocab): - return Encoder( - encoder=encoder, - bpe_merges=vocab, - ) - - -def _is_whitespace(char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False - - -def _is_control(char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char == "\t" or char == "\n" or char == "\r": - return False - cat = unicodedata.category(char) - if cat.startswith("C"): - return True - return False - - -def _is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. 
- if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False - - -def download_asset(name, tag=None, no_cache=False, cache_dir=None): - _tag = tag - if _tag is None: - _tag = "latest" - if not cache_dir: - cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/") - os.makedirs(cache_dir, exist_ok=True) - output = os.path.join(cache_dir, name) - if os.path.exists(output) and (not no_cache): - return output - - repo = "https://api.github.com/repos/microsoft/DeBERTa/releases" - releases = requests.get(repo).json() - if tag and tag != "latest": - release = [r for r in releases if r["name"].lower() == tag.lower()] - if len(release) != 1: - raise Exception(f"{tag} can't be found in the repository.") - else: - release = releases[0] - asset = [s for s in release["assets"] if s["name"].lower() == name.lower()] - if len(asset) != 1: - raise Exception(f"{name} can't be found in the release.") - url = asset[0]["url"] - headers = {} - headers["Accept"] = "application/octet-stream" - resp = requests.get(url, stream=True, headers=headers) - if resp.status_code != 200: - raise Exception(f"Request for {url} return {resp.status_code}, {resp.text}") - try: - with open(output, "wb") as fs: - progress = tqdm( - total=int(resp.headers["Content-Length"]) if "Content-Length" in resp.headers else -1, - ncols=80, - desc=f"Downloading {name}", - ) - for c in resp.iter_content(chunk_size=1024 * 1024): - fs.write(c) - progress.update(len(c)) - progress.close() - except Exception: - os.remove(output) - raise - - return output - - -def load_vocab(name=None, tag=None, no_cache=False, cache_dir=None): - import torch - - if name is None: - name = "bpe_encoder" - - model_path = name - if model_path and (not os.path.exists(model_path)) and not (("/" in model_path) or ("\\" in model_path)): - _tag = tag - if _tag is None: - _tag = "latest" - if not cache_dir: - cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/") - os.makedirs(cache_dir, exist_ok=True) - out_dir = os.path.join(cache_dir, name) - model_path = os.path.join(out_dir, "bpe_encoder.bin") - if (not os.path.exists(model_path)) or no_cache: - asset = download_asset(name + ".zip", tag=tag, no_cache=no_cache, cache_dir=cache_dir) - with ZipFile(asset, "r") as zipf: - for zip_info in zipf.infolist(): - if zip_info.filename[-1] == "/": - continue - zip_info.filename = os.path.basename(zip_info.filename) - zipf.extract(zip_info, out_dir) - elif not model_path: - return None, None - - encoder_state = torch.load(model_path) - return encoder_state - - -class GPT2Tokenizer(object): - """ - A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer - - Args: - vocab_file (:obj:`str`, optional): - The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases - `_, e.g. "bpe_encoder", default: `None`. - - If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file - is a state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used - in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. The difference between our wrapped GPT2 - tokenizer and RoBERTa wrapped tokenizer are, - - - Special tokens, unlike `RoBERTa` which use ``, `` as the `start` token and `end` token of a - sentence. 
We use `[CLS]` and `[SEP]` as the `start` and `end` token of input sentence which is the same - as `BERT`. - - - We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, - `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264 - - special_tokens (:obj:`list`, optional): - List of special tokens to be added to the end of the vocabulary. - """ - - def __init__(self, vocab_file=None, special_tokens=None): - self.pad_token = "[PAD]" - self.sep_token = "[SEP]" - self.unk_token = "[UNK]" - self.cls_token = "[CLS]" - - self.symbols = [] - self.count = [] - self.indices = {} - self.pad_token_id = self.add_symbol(self.pad_token) - self.cls_token_id = self.add_symbol(self.cls_token) - self.sep_token_id = self.add_symbol(self.sep_token) - self.unk_token_id = self.add_symbol(self.unk_token) - - self.gpt2_encoder = load_vocab(vocab_file) - self.bpe = get_encoder(self.gpt2_encoder["encoder"], self.gpt2_encoder["vocab"]) - for w, n in self.gpt2_encoder["dict_map"]: - self.add_symbol(w, n) - - self.mask_token = "[MASK]" - self.mask_id = self.add_symbol(self.mask_token) - self.special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"] - if special_tokens is not None: - for t in special_tokens: - self.add_special_token(t) - - self.vocab = self.indices - self.ids_to_tokens = self.symbols - - def tokenize(self, text): - """ - Convert an input text to tokens. - - Args: - text (:obj:`str`): input text to be tokenized. - - Returns: - A list of byte tokens where each token represent the byte id in GPT2 byte dictionary - - Example:: - >>> tokenizer = GPT2Tokenizer() - >>> text = "Hello world!" - >>> tokens = tokenizer.tokenize(text) - >>> print(tokens) - ['15496', '995', '0'] - """ - bpe = self._encode(text) - - return [t for t in bpe.split(" ") if t] - - def convert_tokens_to_ids(self, tokens): - """ - Convert list of tokens to ids - - Args: - tokens (:obj:`list`): list of tokens - - Returns: - List of ids - """ - - return [self.vocab[t] for t in tokens] - - def convert_ids_to_tokens(self, ids): - """ - Convert list of ids to tokens - - Args: - ids (:obj:`list`): list of ids - - Returns: - List of tokens - """ - - tokens = [] - for i in ids: - tokens.append(self.ids_to_tokens[i]) - return tokens - - def split_to_words(self, text): - return self.bpe.split_to_words(text) - - def decode(self, tokens): - """ - Decode list of tokens to text strings - - Args: - tokens (:obj:`list`): list of tokens. - - Returns: - Text string corresponds to the input tokens. - - Example:: - >>> tokenizer = GPT2Tokenizer() - >>> text = "Hello world!" - >>> tokens = tokenizer.tokenize(text) - >>> print(tokens) - ['15496', '995', '0'] - >>> tokenizer.decode(tokens) - 'Hello world!' - """ - return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens]) - - def add_special_token(self, token): - """ - Adds a special token to the dictionary - - Args: - token (:obj:`str`): Tthe new token/word to be added to the vocabulary. - - Returns: - The id of new token in the vocabulary. 
- - """ - self.special_tokens.append(token) - return self.add_symbol(token) - - def part_of_whole_word(self, token, is_bos=False): - if is_bos: - return True - s = self._decode(token) - if len(s) == 1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0])): - return False - - return not s.startswith(" ") - - def sym(self, id): - return self.ids_to_tokens[id] - - def id(self, sym): - return self.vocab[sym] - - def _encode(self, x: str) -> str: - return " ".join(map(str, self.bpe.encode(x))) - - def _decode(self, x: str) -> str: - return self.bpe.decode(map(int, x.split())) - - def add_symbol(self, word, n=1): - """ - Adds a word to the dictionary - - Args: - word (:obj:`str`): Tthe new token/word to be added to the vocabulary. - n (int, optional): The frequency of the word. - - Returns: - The id of the new word. - - """ - if word in self.indices: - idx = self.indices[word] - self.count[idx] = self.count[idx] + n - return idx - else: - idx = len(self.symbols) - self.indices[word] = idx - self.symbols.append(word) - self.count.append(n) - return idx - - def save_pretrained(self, path: str, filename_prefix: str = None): - import torch - - filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] - if filename_prefix is not None: - filename = filename_prefix + "-" + filename - full_path = os.path.join(path, filename) - torch.save(self.gpt2_encoder, full_path) - return (full_path,) +VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"} class DebertaTokenizer(PreTrainedTokenizer): @@ -522,6 +107,7 @@ def __init__( self, vocab_file, do_lower_case=False, + vocab_type="gpt2", unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", @@ -531,6 +117,7 @@ def __init__( ): super().__init__( do_lower_case=do_lower_case, + vocab_type="gpt2", unk_token=unk_token, sep_token=sep_token, pad_token=pad_token, @@ -545,7 +132,10 @@ def __init__( "model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) ) self.do_lower_case = do_lower_case - self.gpt2_tokenizer = GPT2Tokenizer(vocab_file) + if vocab_type.lower() == "gpt2": + self._tokenizer = GPT2Tokenizer(vocab_file, **kwargs) + else: + self._tokenizer = SPMTokenizer(vocab_file, **kwargs) @property def vocab_size(self): @@ -553,7 +143,7 @@ def vocab_size(self): @property def vocab(self): - return self.gpt2_tokenizer.vocab + return self._tokenizer.vocab def get_vocab(self): vocab = self.vocab.copy() @@ -564,7 +154,7 @@ def _tokenize(self, text): """Take as input a string and return a list of strings (tokens) for words/sub-words""" if self.do_lower_case: text = text.lower() - return self.gpt2_tokenizer.tokenize(text) + return self._tokenizer.tokenize(text) def _convert_token_to_id(self, token): """ Converts a token (str) in an id using the vocab. """ @@ -572,11 +162,11 @@ def _convert_token_to_id(self, token): def _convert_id_to_token(self, index): """Converts an index (integer) in a token (str) using the vocab.""" - return self.gpt2_tokenizer.sym(index) if index < self.vocab_size else self.unk_token + return self._tokenizer.sym(index) if index < self.vocab_size else self.unk_token def convert_tokens_to_string(self, tokens): """ Converts a sequence of tokens (string) in a single string. 
""" - return self.gpt2_tokenizer.decode(tokens) + return self._tokenizer.decode(tokens) def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): """ @@ -671,4 +261,4 @@ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): return (text, kwargs) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: - return self.gpt2_tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix) + return self._tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix) From 75bc95a49d2531f3f6251b821430de6c0a4b8150 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 9 Feb 2021 16:42:41 +0100 Subject: [PATCH 2/8] DeBERTa-v2 --- README.md | 1 + docs/source/index.rst | 66 +- docs/source/model_doc/deberta_v2.rst | 118 ++ src/transformers/__init__.py | 22 + .../models/auto/configuration_auto.py | 4 + src/transformers/models/auto/modeling_auto.py | 14 + .../models/auto/tokenization_auto.py | 4 + .../models/deberta/configuration_deberta.py | 4 - .../models/deberta/gpt2_tokenizer.py | 380 ----- .../models/deberta/modeling_deberta.py | 434 ++--- .../models/deberta/spm_tokenizer.py | 277 --- .../models/deberta/tokenization_deberta.py | 486 +++++- .../models/deberta_v2/__init__.py | 72 + .../deberta_v2/configuration_deberta_v2.py | 138 ++ .../models/deberta_v2/modeling_deberta_v2.py | 1514 +++++++++++++++++ .../deberta_v2/tokenization_deberta_v2.py | 491 ++++++ src/transformers/utils/dummy_pt_objects.py | 57 + tests/test_modeling_deberta_v2.py | 290 ++++ tests/test_tokenization_deberta_v2.py | 192 +++ 19 files changed, 3537 insertions(+), 1027 deletions(-) create mode 100644 docs/source/model_doc/deberta_v2.rst delete mode 100644 src/transformers/models/deberta/gpt2_tokenizer.py delete mode 100644 src/transformers/models/deberta/spm_tokenizer.py create mode 100644 src/transformers/models/deberta_v2/__init__.py create mode 100644 src/transformers/models/deberta_v2/configuration_deberta_v2.py create mode 100644 src/transformers/models/deberta_v2/modeling_deberta_v2.py create mode 100644 src/transformers/models/deberta_v2/tokenization_deberta_v2.py create mode 100644 tests/test_modeling_deberta_v2.py create mode 100644 tests/test_tokenization_deberta_v2.py diff --git a/README.md b/README.md index cae90de239e3d5..f3014e8900b005 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[ConvBERT](https://huggingface.co/transformers/model_doc/convbert.html)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan. 1. **[CTRL](https://huggingface.co/transformers/model_doc/ctrl.html)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. 1. **[DeBERTa](https://huggingface.co/transformers/model_doc/deberta.html)** (from Microsoft Research) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. +1. 
**[DeBERTa-v2](https://huggingface.co/transformers/model_doc/deberta_v2.html)** (from Microsoft Research) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DialoGPT](https://huggingface.co/transformers/model_doc/dialogpt.html)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. 1. **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/master/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/master/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation) and a German version of DistilBERT. 1. **[DPR](https://huggingface.co/transformers/model_doc/dpr.html)** (from Facebook) released with the paper [Dense Passage Retrieval diff --git a/docs/source/index.rst b/docs/source/index.rst index 63ddcbba5c7ee4..e6cfd2beaaf582 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -117,95 +117,98 @@ and conversion utilities for the following models: 12. :doc:`DeBERTa ` (from Microsoft Research) released with the paper `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. -13. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale +13. :doc:`DeBERTa-v2 ` (from Microsoft Research) released with the paper `DeBERTa: + Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong + Liu, Jianfeng Gao, Weizhu Chen. +14. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation `__ by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. -14. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a +15. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter `__ by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into `DistilGPT2 `__, RoBERTa into `DistilRoBERTa `__, Multilingual BERT into `DistilmBERT `__ and a German version of DistilBERT. -15. :doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain +16. :doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain Question Answering `__ by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. -16. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: +17. 
:doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: Pre-training text encoders as discriminators rather than generators `__ by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. -17. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model +18. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model Pre-training for French `__ by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab. -18. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: +19. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing `__ by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. -19. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative +20. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative Pre-Training `__ by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. -20. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask +21. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask Learners `__ by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. -21. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training +22. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training of Text and Layout for Document Image Understanding `__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. -22. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer +23. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -23. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document +24. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -24. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality +25. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering `__ by Hao Tan and Mohit Bansal. -25. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by +26. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by Jörg Tiedemann. The `Marian Framework `__ is being developed by the Microsoft Translator Team. -26. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for +27. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -27. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible +28. 
:doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning `__ by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. -28. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +29. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding `__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -29. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +30. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer `__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -30. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +31. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `__> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -31. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting +32. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -32. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +33. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -33. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +34. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -34. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +35. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -35. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +36. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -36. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +37. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -37. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +38. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -38. 
:doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +39. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -39. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +40. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -40. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +41. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -41. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +42. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -42. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +43. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. @@ -246,6 +249,8 @@ TensorFlow and/or Flax. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | DeBERTa | ✅ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| DeBERTa-v2 | ✅ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | DistilBERT | ✅ | ✅ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | ELECTRA | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -389,6 +394,7 @@ TensorFlow and/or Flax. model_doc/convbert model_doc/ctrl model_doc/deberta + model_doc/deberta_v2 model_doc/dialogpt model_doc/distilbert model_doc/dpr diff --git a/docs/source/model_doc/deberta_v2.rst b/docs/source/model_doc/deberta_v2.rst new file mode 100644 index 00000000000000..06d57aa4e8439a --- /dev/null +++ b/docs/source/model_doc/deberta_v2.rst @@ -0,0 +1,118 @@ +.. + Copyright 2020 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. 
+
+DeBERTa-v2
+-----------------------------------------------------------------------------------------------------------------------
+
+Overview
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
+<https://arxiv.org/abs/2006.03654>`__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It is based on Google's
+BERT model released in 2018 and Facebook's RoBERTa model released in 2019.
+
+It builds on RoBERTa with disentangled attention and an enhanced mask decoder, and is trained with half of the data
+used in RoBERTa.
+
+The abstract from the paper is the following:
+
+*Recent progress in pre-trained neural language models has significantly improved the performance of many natural
+language processing (NLP) tasks. In this paper we propose a new model architecture DeBERTa (Decoding-enhanced BERT with
+disentangled attention) that improves the BERT and RoBERTa models using two novel techniques. The first is the
+disentangled attention mechanism, where each word is represented using two vectors that encode its content and
+position, respectively, and the attention weights among words are computed using disentangled matrices on their
+contents and relative positions. Second, an enhanced mask decoder is used to replace the output softmax layer to
+predict the masked tokens for model pretraining. We show that these two techniques significantly improve the efficiency
+of model pretraining and performance of downstream tasks. Compared to RoBERTa-Large, a DeBERTa model trained on half of
+the training data performs consistently better on a wide range of NLP tasks, achieving improvements on MNLI by +0.9%
+(90.2% vs. 91.1%), on SQuAD v2.0 by +2.3% (88.4% vs. 90.7%) and RACE by +3.6% (83.2% vs. 86.8%). The DeBERTa code and
+pre-trained models will be made publicly available at https://github.com/microsoft/DeBERTa.*
+
+
+The following information is available directly on the `original implementation repository
+<https://github.com/microsoft/DeBERTa>`__. DeBERTa v2 is the second version of the DeBERTa model. It includes the 1.5B
+model used for the SuperGLUE single-model submission, which achieves a score of 89.9 versus the human baseline of 89.8.
+You can find more details about this submission in the authors'
+`blog <https://www.microsoft.com/en-us/research/blog/microsoft-deberta-surpasses-human-performance-on-the-superglue-benchmark/>`__.
+
+New in v2:
+
+- **Vocabulary** In v2 the tokenizer is changed to use a new vocabulary of size 128K built from the training data.
+  Instead of a GPT2-based tokenizer, the tokenizer is now a
+  `sentencepiece-based <https://github.com/google/sentencepiece>`__ tokenizer.
+- **nGiE (nGram Induced Input Encoding)** The DeBERTa-v2 model uses an additional convolution layer alongside the first
+  transformer layer to better learn the local dependency of input tokens.
+- **Sharing position projection matrix with content projection matrix in attention layer** Based on previous
+  experiments, this can save parameters without affecting the performance.
+- **Apply bucket to encode relative positions** The DeBERTa-v2 model uses a log bucket to encode relative positions,
+  similar to T5.
+- **900M model & 1.5B model** Two additional model sizes are available, 900M and 1.5B, which significantly improve the
+  performance of downstream tasks.
+
+The original code can be found `here <https://github.com/microsoft/DeBERTa>`__.
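+
+As a quick start, the snippet below is a minimal usage sketch: it loads one of the new v2 checkpoints with the
+``DebertaV2Tokenizer`` and ``DebertaV2Model`` classes and runs a single forward pass. The checkpoint identifier
+``microsoft/deberta-xlarge-v2`` is used here only for illustration and may differ from the final published name::
+
+    >>> from transformers import DebertaV2Tokenizer, DebertaV2Model
+
+    >>> tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/deberta-xlarge-v2")
+    >>> model = DebertaV2Model.from_pretrained("microsoft/deberta-xlarge-v2")
+
+    >>> inputs = tokenizer("DeBERTa-v2 uses a SentencePiece-based tokenizer.", return_tensors="pt")
+    >>> outputs = model(**inputs)
+    >>> last_hidden_state = outputs.last_hidden_state  # shape: [batch_size, sequence_length, hidden_size]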
+ + +DebertaV2Config +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DebertaV2Config + :members: + + +DebertaV2Tokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DebertaV2Tokenizer + :members: build_inputs_with_special_tokens, get_special_tokens_mask, + create_token_type_ids_from_sequences, save_vocabulary + + +DebertaV2Model +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DebertaV2Model + :members: + + +DebertaV2PreTrainedModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DebertaV2PreTrainedModel + :members: + + +DebertaV2ForMaskedLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DebertaV2ForMaskedLM + :members: + + +DebertaV2ForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DebertaV2ForSequenceClassification + :members: + + +DebertaV2ForTokenClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DebertaV2ForTokenClassification + :members: + + +DebertaV2ForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.DebertaV2ForQuestionAnswering + :members: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 30a7b41940565d..62451be24247ce 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -157,6 +157,7 @@ "models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"], "models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"], "models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"], + "models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config", "DebertaV2Tokenizer"], "models.distilbert": ["DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DistilBertConfig", "DistilBertTokenizer"], "models.dpr": [ "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -515,6 +516,17 @@ "DebertaForQuestionAnswering", ] ) + _import_structure["models.deberta_v2"].extend( + [ + "DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST", + "DebertaV2ForSequenceClassification", + "DebertaV2Model", + "DebertaV2ForMaskedLM", + "DebertaV2PreTrainedModel", + "DebertaV2ForTokenClassification", + "DebertaV2ForQuestionAnswering", + ] + ) _import_structure["models.distilbert"].extend( [ "DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1287,6 +1299,7 @@ from .models.convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertTokenizer from .models.ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, CTRLTokenizer from .models.deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig, DebertaTokenizer + from .models.deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config, DebertaV2Tokenizer from .models.distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig, DistilBertTokenizer from .models.dpr import ( DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -1604,6 +1617,15 @@ DebertaModel, DebertaPreTrainedModel, ) + from .models.deberta_v2 import ( + DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST, + DebertaV2ForMaskedLM, + DebertaV2ForQuestionAnswering, + DebertaV2ForSequenceClassification, + DebertaV2ForTokenClassification, + DebertaV2Model, + DebertaV2PreTrainedModel, + ) from .models.distilbert import ( DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST, DistilBertForMaskedLM, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 581520e98808ea..afb02fc36aaa0d 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -31,6 +31,7 @@ from ..convbert.configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig from ..ctrl.configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig from ..deberta.configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig +from ..deberta_v2.configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config from ..distilbert.configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig from ..dpr.configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig from ..electra.configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig @@ -103,6 +104,7 @@ LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, + DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, 
@@ -138,6 +140,7 @@ ("reformer", ReformerConfig), ("longformer", LongformerConfig), ("roberta", RobertaConfig), + ("deberta_v2", DebertaV2Config), ("deberta", DebertaConfig), ("flaubert", FlaubertConfig), ("fsmt", FSMTConfig), @@ -200,6 +203,7 @@ ("funnel", "Funnel Transformer"), ("lxmert", "LXMERT"), ("deberta", "DeBERTa"), + ("deberta_v2", "DeBERTa-v2"), ("layoutlm", "LayoutLM"), ("dpr", "DPR"), ("rag", "RAG"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 99545895383478..64c527e1ec8abe 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -84,6 +84,13 @@ DebertaForTokenClassification, DebertaModel, ) +from ..deberta_v2.modeling_deberta_v2 import ( + DebertaV2ForMaskedLM, + DebertaV2ForQuestionAnswering, + DebertaV2ForSequenceClassification, + DebertaV2ForTokenClassification, + DebertaV2Model, +) from ..distilbert.modeling_distilbert import ( DistilBertForMaskedLM, DistilBertForMultipleChoice, @@ -254,6 +261,7 @@ ConvBertConfig, CTRLConfig, DebertaConfig, + DebertaV2Config, DistilBertConfig, DPRConfig, ElectraConfig, @@ -332,6 +340,7 @@ (LxmertConfig, LxmertModel), (BertGenerationConfig, BertGenerationEncoder), (DebertaConfig, DebertaModel), + (DebertaV2Config, DebertaV2Model), (DPRConfig, DPRQuestionEncoder), (XLMProphetNetConfig, XLMProphetNetModel), (ProphetNetConfig, ProphetNetModel), @@ -408,6 +417,7 @@ (MPNetConfig, MPNetForMaskedLM), (TapasConfig, TapasForMaskedLM), (DebertaConfig, DebertaForMaskedLM), + (DebertaV2Config, DebertaV2ForMaskedLM), ] ) @@ -465,6 +475,7 @@ (MPNetConfig, MPNetForMaskedLM), (TapasConfig, TapasForMaskedLM), (DebertaConfig, DebertaForMaskedLM), + (DebertaV2Config, DebertaV2ForMaskedLM), ] ) @@ -510,6 +521,7 @@ (ElectraConfig, ElectraForSequenceClassification), (FunnelConfig, FunnelForSequenceClassification), (DebertaConfig, DebertaForSequenceClassification), + (DebertaV2Config, DebertaV2ForSequenceClassification), (GPT2Config, GPT2ForSequenceClassification), (OpenAIGPTConfig, OpenAIGPTForSequenceClassification), (ReformerConfig, ReformerForSequenceClassification), @@ -545,6 +557,7 @@ (LxmertConfig, LxmertForQuestionAnswering), (MPNetConfig, MPNetForQuestionAnswering), (DebertaConfig, DebertaForQuestionAnswering), + (DebertaV2Config, DebertaV2ForQuestionAnswering), ] ) @@ -577,6 +590,7 @@ (FunnelConfig, FunnelForTokenClassification), (MPNetConfig, MPNetForTokenClassification), (DebertaConfig, DebertaForTokenClassification), + (DebertaV2Config, DebertaV2ForTokenClassification), ] ) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 470e8ce8e860a8..3151c0e971d16c 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -66,6 +66,7 @@ ConvBertConfig, CTRLConfig, DebertaConfig, + DebertaV2Config, DistilBertConfig, DPRConfig, ElectraConfig, @@ -108,6 +109,7 @@ from ..barthez.tokenization_barthez import BarthezTokenizer from ..bert_generation.tokenization_bert_generation import BertGenerationTokenizer from ..camembert.tokenization_camembert import CamembertTokenizer + from ..deberta_v2.tokenization_deberta_v2 import DebertaV2Tokenizer from ..marian.tokenization_marian import MarianTokenizer from ..mbart.tokenization_mbart import MBartTokenizer from ..mt5 import MT5Tokenizer @@ -122,6 +124,7 @@ BarthezTokenizer = None BertGenerationTokenizer = None CamembertTokenizer = None + DebertaV2Tokenizer = None 
MarianTokenizer = None MBartTokenizer = None MT5Tokenizer = None @@ -233,6 +236,7 @@ (FSMTConfig, (FSMTTokenizer, None)), (BertGenerationConfig, (BertGenerationTokenizer, None)), (DebertaConfig, (DebertaTokenizer, None)), + (DebertaV2Config, (DebertaV2Tokenizer, None)), (RagConfig, (RagTokenizer, None)), (XLMProphetNetConfig, (XLMProphetNetTokenizer, None)), (ProphetNetConfig, (ProphetNetTokenizer, None)), diff --git a/src/transformers/models/deberta/configuration_deberta.py b/src/transformers/models/deberta/configuration_deberta.py index e2f92abe1a54ad..30a984f62005d3 100644 --- a/src/transformers/models/deberta/configuration_deberta.py +++ b/src/transformers/models/deberta/configuration_deberta.py @@ -24,13 +24,9 @@ "microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/config.json", "microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/config.json", "microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/config.json", - "microsoft/deberta-xlarge-v2": "https://huggingface.co/microsoft/deberta-xlarge-v2/resolve/main/config.json", - "microsoft/deberta-xxlarge-v2": "https://huggingface.co/microsoft/deberta-xxlarge-v2/resolve/main/config.json", "microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/config.json", "microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/config.json", "microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/config.json", - "microsoft/deberta-xlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xlarge-v2-mnli/resolve/main/config.json", - "microsoft/deberta-xxlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xxlarge-v2-mnli/resolve/main/config.json", } diff --git a/src/transformers/models/deberta/gpt2_tokenizer.py b/src/transformers/models/deberta/gpt2_tokenizer.py deleted file mode 100644 index 88e12fc07e2ec7..00000000000000 --- a/src/transformers/models/deberta/gpt2_tokenizer.py +++ /dev/null @@ -1,380 +0,0 @@ -# coding=utf-8 -# Copyright 2020 Microsoft and the HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" GPT2 Tokenization class for model DeBERTa.""" - -import os -import unicodedata -from functools import lru_cache - - -try: - import regex as re -except ImportError: - raise ImportError("Please install regex with: pip install regex") - -___all__ = ["GPT2Tokenizer"] - -VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"} - - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode - strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're - at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a signficant - percentage of your normal, say, 32K bpe vocab. 
To avoid that, we want lookup tables between utf-8 bytes and unicode - strings. And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = ( - list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) - ) - cs = bs[:] - n = 0 - for b in range(2 ** 8): - if b not in bs: - bs.append(b) - cs.append(2 ** 8 + n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """ - Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length - strings). - """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -class Encoder: - def __init__(self, encoder, bpe_merges, errors="replace"): - self.encoder = encoder - self.decoder = {v: k for k, v in self.encoder.items()} - self.errors = errors # how to handle errors in decoding - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges)))) - self.cache = {} - - # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions - self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token) - pairs = get_pairs(word) - - if not pairs: - return token - - while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except Exception: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word) - 1 and word[i + 1] == second: - new_word.append(first + second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = " ".join(word) - self.cache[token] = word - return word - - def split_to_words(self, text): - return list(re.findall(self.pat, text)) - - def encode(self, text): - bpe_tokens = [] - for token in self.split_to_words(text): - token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) - return bpe_tokens - - def decode(self, tokens): - text = "".join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) - return text - - -def get_encoder(encoder, vocab): - return Encoder( - encoder=encoder, - bpe_merges=vocab, - ) - - -def _is_whitespace(char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False - - -def _is_control(char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. 
- if char == "\t" or char == "\n" or char == "\r": - return False - cat = unicodedata.category(char) - if cat.startswith("C"): - return True - return False - - -def _is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. - if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False - - -class GPT2Tokenizer(object): - """ - A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer - - Args: - vocab_file (:obj:`str`, optional): - The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases - `_, e.g. "bpe_encoder", default: `None`. - - If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file - is a state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used - in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. The difference between our wrapped GPT2 - tokenizer and RoBERTa wrapped tokenizer are, - - - Special tokens, unlike `RoBERTa` which use ``, `` as the `start` token and `end` token of a - sentence. We use `[CLS]` and `[SEP]` as the `start` and `end` token of input sentence which is the same - as `BERT`. - - - We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, - `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264 - - special_tokens (:obj:`list`, optional): - List of special tokens to be added to the end of the vocabulary. - """ - - def __init__(self, vocab_file=None, special_tokens=None, **kwargs): - import torch - - self.pad_token = "[PAD]" - self.sep_token = "[SEP]" - self.unk_token = "[UNK]" - self.cls_token = "[CLS]" - - self.symbols = [] - self.count = [] - self.indices = {} - self.pad_token_id = self.add_symbol(self.pad_token) - self.cls_token_id = self.add_symbol(self.cls_token) - self.sep_token_id = self.add_symbol(self.sep_token) - self.unk_token_id = self.add_symbol(self.unk_token) - - self.gpt2_encoder = torch.load(vocab_file) - self.bpe = get_encoder(self.gpt2_encoder["encoder"], self.gpt2_encoder["vocab"]) - for w, n in self.gpt2_encoder["dict_map"]: - self.add_symbol(w, n) - - self.mask_token = "[MASK]" - self.mask_id = self.add_symbol(self.mask_token) - self.special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"] - if special_tokens is not None: - for t in special_tokens: - self.add_special_token(t) - - self.vocab = self.indices - self.ids_to_tokens = self.symbols - - def tokenize(self, text): - """ - Convert an input text to tokens. - - Args: - text (:obj:`str`): input text to be tokenized. - - Returns: - A list of byte tokens where each token represent the byte id in GPT2 byte dictionary - - Example:: - >>> tokenizer = GPT2Tokenizer() - >>> text = "Hello world!" 
- >>> tokens = tokenizer.tokenize(text) - >>> print(tokens) - ['15496', '995', '0'] - """ - bpe = self._encode(text) - - return [t for t in bpe.split(" ") if t] - - def convert_tokens_to_ids(self, tokens): - """ - Convert list of tokens to ids - - Args: - tokens (:obj:`list`): list of tokens - - Returns: - List of ids - """ - - return [self.vocab[t] for t in tokens] - - def convert_ids_to_tokens(self, ids): - """ - Convert list of ids to tokens - - Args: - ids (:obj:`list`): list of ids - - Returns: - List of tokens - """ - - tokens = [] - for i in ids: - tokens.append(self.ids_to_tokens[i]) - return tokens - - def split_to_words(self, text): - return self.bpe.split_to_words(text) - - def decode(self, tokens): - """ - Decode list of tokens to text strings - - Args: - tokens (:obj:`list`): list of tokens. - - Returns: - Text string corresponds to the input tokens. - - Example:: - >>> tokenizer = GPT2Tokenizer() - >>> text = "Hello world!" - >>> tokens = tokenizer.tokenize(text) - >>> print(tokens) - ['15496', '995', '0'] - >>> tokenizer.decode(tokens) - 'Hello world!' - """ - return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens]) - - def add_special_token(self, token): - """ - Adds a special token to the dictionary - - Args: - token (:obj:`str`): Tthe new token/word to be added to the vocabulary. - - Returns: - The id of new token in the vocabulary. - - """ - self.special_tokens.append(token) - return self.add_symbol(token) - - def part_of_whole_word(self, token, is_bos=False): - if is_bos: - return True - s = self._decode(token) - if len(s) == 1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0])): - return False - - return not s.startswith(" ") - - def sym(self, id): - return self.ids_to_tokens[id] - - def id(self, sym): - return self.vocab[sym] - - def _encode(self, x: str) -> str: - return " ".join(map(str, self.bpe.encode(x))) - - def _decode(self, x: str) -> str: - return self.bpe.decode(map(int, x.split())) - - def add_symbol(self, word, n=1): - """ - Adds a word to the dictionary - - Args: - word (:obj:`str`): Tthe new token/word to be added to the vocabulary. - n (int, optional): The frequency of the word. - - Returns: - The id of the new word. 
- - """ - if word in self.indices: - idx = self.indices[word] - self.count[idx] = self.count[idx] + n - return idx - else: - idx = len(self.symbols) - self.indices[word] = idx - self.symbols.append(word) - self.count.append(n) - return idx - - def save_pretrained(self, path: str, filename_prefix: str = None): - import torch - - filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] - if filename_prefix is not None: - filename = filename_prefix + "-" + filename - full_path = os.path.join(path, filename) - torch.save(self.gpt2_encoder, full_path) - return (full_path,) diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 574d3e7062fc60..60b9546379a2f1 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -16,14 +16,10 @@ import math from collections.abc import Sequence -from functools import lru_cache -import numpy as np import torch -from packaging import version from torch import _softmax_backward_data, nn from torch.nn import CrossEntropyLoss -from torch.nn import LayerNorm as DebertaLayerNorm from ...activations import ACT2FN from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward @@ -43,18 +39,15 @@ _CONFIG_FOR_DOC = "DebertaConfig" _TOKENIZER_FOR_DOC = "DebertaTokenizer" +_CHECKPOINT_FOR_DOC = "microsoft/deberta-base" DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ "microsoft/deberta-base", "microsoft/deberta-large", "microsoft/deberta-xlarge", - "microsoft/deberta-xlarge-v2", - "microsoft/deberta-xxlarge-v2", "microsoft/deberta-base-mnli", "microsoft/deberta-large-mnli", "microsoft/deberta-xlarge-mnli", - "microsoft/deberta-xlarge-v2-mnli", - "microsoft/deberta-xxlarge-v2-mnli", ] @@ -65,7 +58,7 @@ def __init__(self, config): self.dropout = StableDropout(config.pooler_dropout) self.config = config - def forward(self, hidden_states, mask=None): + def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. @@ -85,27 +78,28 @@ class XSoftmax(torch.autograd.Function): Masked Softmax which is optimized for saving memory Args: - input (:obj:`torch.tensor`): The input tensor that will apply softmax. - mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation. - dim (int): The dimension that will apply softmax + input (:obj:`torch.tensor`): The input tensor that will apply softmax. + mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation. + dim (int): The dimension that will apply softmax Example:: - import torch - from transformers.models.deberta import XSoftmax - # Make a tensor - x = torch.randn([4,20,100]) - # Create a mask - mask = (x>0).int() - y = XSoftmax.apply(x, mask, dim=-1) + + >>> import torch + >>> from transformers.models.deberta.modeling_deberta import XSoftmax + + >>> # Make a tensor + >>> x = torch.randn([4,20,100]) + + >>> # Create a mask + >>> mask = (x>0).int() + + >>> y = XSoftmax.apply(x, mask, dim=-1) """ @staticmethod def forward(self, input, mask, dim): self.dim = dim - if version.Version(torch.__version__) >= version.Version("1.2.0a"): - rmask = ~(mask.bool()) - else: - rmask = (1 - mask).byte() # This line is not supported by Onnx tracing. 
+ rmask = ~(mask.bool()) output = input.masked_fill(rmask, float("-inf")) output = torch.softmax(output, self.dim) @@ -138,10 +132,7 @@ def get_mask(input, local_context): mask = local_context.mask if local_context.reuse_mask else None if dropout > 0 and mask is None: - if version.Version(torch.__version__) >= version.Version("1.2.0a"): - mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool() - else: - mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).byte() + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool() if isinstance(local_context, DropoutContext): if local_context.mask is None: @@ -177,9 +168,7 @@ class StableDropout(torch.nn.Module): Optimized dropout module for stabilizing the training Args: - drop_prob (float): the dropout probabilities - """ def __init__(self, drop_prob): @@ -194,8 +183,6 @@ def forward(self, x): Args: x (:obj:`torch.tensor`): The input tensor to apply dropout - - """ if self.training and self.drop_prob > 0: return XDropout.apply(x, self.get_context()) @@ -225,34 +212,24 @@ def get_context(self): return self.drop_prob -def MaskedLayerNorm(layerNorm, input, mask=None): - """ - Masked LayerNorm which will apply mask over the output of LayerNorm to avoid inaccurate updatings to the LayerNorm - module. +class DebertaLayerNorm(nn.Module): + """LayerNorm module in the TF style (epsilon inside the square root).""" - Args: - layernorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function - input (:obj:`torch.tensor`): The input tensor - mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicate the output of that element will be ignored, i.e. set to `0` - - Example:: - - # Create a tensor b x n x d - x = torch.randn([1,10,100]) - m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int) - LayerNorm = DeBERTa.deberta.LayerNorm(100) - y = MaskedLayerNorm(LayerNorm, x, m) + def __init__(self, size, eps=1e-12): + super().__init__() + self.weight = nn.Parameter(torch.ones(size)) + self.bias = nn.Parameter(torch.zeros(size)) + self.variance_epsilon = eps - """ - output = layerNorm(input).to(input) - if mask is None: - return output - if mask.dim() != input.dim(): - if mask.dim() == 4: - mask = mask.squeeze(1).squeeze(1) - mask = mask.unsqueeze(2) - mask = mask.to(output.dtype) - return output * mask + def forward(self, hidden_states): + input_type = hidden_states.dtype + hidden_states = hidden_states.float() + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon) + hidden_states = hidden_states.to(input_type) + y = self.weight * hidden_states + self.bias + return y class DebertaSelfOutput(nn.Module): @@ -323,7 +300,7 @@ def forward(self, hidden_states): class DebertaOutput(nn.Module): def __init__(self, config): - super(DebertaOutput, self).__init__() + super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps) self.dropout = StableDropout(config.hidden_dropout_prob) @@ -338,7 +315,7 @@ def forward(self, hidden_states, input_tensor): class DebertaLayer(nn.Module): def __init__(self, config): - super(DebertaLayer, self).__init__() + super().__init__() self.attention = DebertaAttention(config) self.intermediate = DebertaIntermediate(config) self.output = DebertaOutput(config) @@ -370,29 +347,6 @@ def forward( return layer_output -class 
ConvLayer(nn.Module): - def __init__(self, config): - super().__init__() - kernel_size = getattr(config, "conv_kernel_size", 3) - groups = getattr(config, "conv_groups", 1) - self.conv_act = getattr(config, "conv_act", "tanh") - self.conv = torch.nn.Conv1d( - config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups - ) - self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps) - self.dropout = StableDropout(config.hidden_dropout_prob) - self.config = config - - def forward(self, hidden_states, residual_states, input_mask): - out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() - rmask = (1 - input_mask).bool() - out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) - out = ACT2FN[self.conv_act](self.dropout(out)) - output_states = MaskedLayerNorm(self.LayerNorm, residual_states + out, input_mask) - - return output_states - - class DebertaEncoder(nn.Module): """Modified BertEncoder with relative position bias support""" @@ -404,25 +358,10 @@ def __init__(self, config): self.max_relative_positions = getattr(config, "max_relative_positions", -1) if self.max_relative_positions < 1: self.max_relative_positions = config.max_position_embeddings - self.position_buckets = getattr(config, "position_buckets", -1) - pos_ebd_size = self.max_relative_positions * 2 - if self.position_buckets > 0: - pos_ebd_size = self.position_buckets * 2 - self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) - - self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] - if "layer_norm" in self.norm_rel_ebd: - self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) - kernel_size = getattr(config, "conv_kernel_size", 0) - self.with_conv = False - if kernel_size > 0: - self.with_conv = True - self.conv = ConvLayer(config) + self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size) def get_rel_embedding(self): rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None - if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): - rel_embeddings = self.LayerNorm(rel_embeddings) return rel_embeddings def get_attention_mask(self, attention_mask): @@ -438,9 +377,7 @@ def get_attention_mask(self, attention_mask): def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): if self.relative_attention and relative_pos is None: q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) - relative_pos = build_relative_position( - q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions - ) + relative_pos = build_relative_position(q, hidden_states.size(-2), hidden_states.device) return relative_pos def forward( @@ -453,10 +390,6 @@ def forward( relative_pos=None, return_dict=True, ): - if attention_mask.dim() <= 2: - input_mask = attention_mask - else: - input_mask = (attention_mask.sum(-2) > 0).byte() attention_mask = self.get_attention_mask(attention_mask) relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) @@ -473,7 +406,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - output_states = layer_module( + hidden_states = layer_module( next_kv, attention_mask, output_attentions, @@ -482,42 +415,29 @@ def forward( rel_embeddings=rel_embeddings, ) if output_attentions: - output_states, att_m = output_states - - 
if i == 0 and self.with_conv: - output_states = self.conv(hidden_states, output_states, input_mask) + hidden_states, att_m = hidden_states if query_states is not None: - query_states = output_states + query_states = hidden_states if isinstance(hidden_states, Sequence): next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None else: - next_kv = output_states + next_kv = hidden_states if output_attentions: all_attentions = all_attentions + (att_m,) if output_hidden_states: - all_hidden_states = all_hidden_states + (output_states,) + all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions ) -def make_log_bucket_position(relative_pos, bucket_size, max_position): - sign = np.sign(relative_pos) - mid = bucket_size // 2 - abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos)) - log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid - bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int) - return bucket_pos - - -@lru_cache(maxsize=128) -def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1): +def build_relative_position(query_size, key_size, device): """ Build relative position according to the query and key @@ -528,19 +448,15 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=- Args: query_size (int): the length of query key_size (int): the length of key - bucket_size (int): the size of position bucket - max_position (int): the maxium allowed absolute positoin Return: :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size] """ - q_ids = np.arange(0, query_size) - k_ids = np.arange(0, key_size) - rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1)) - if bucket_size > 0 and max_position > 0: - rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) - rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long) + + q_ids = torch.arange(query_size, dtype=torch.long, device=device) + k_ids = torch.arange(key_size, dtype=torch.long, device=device) + rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1) rel_pos_ids = rel_pos_ids[:query_size, :] rel_pos_ids = rel_pos_ids.unsqueeze(0) return rel_pos_ids @@ -580,41 +496,39 @@ def __init__(self, config): "heads (%d)" % (config.hidden_size, config.num_attention_heads) ) self.num_attention_heads = config.num_attention_heads - _attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) - self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) - self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.in_proj = torch.nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False) + self.q_bias = 
torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) + self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] - self.share_att_key = getattr(config, "share_att_key", False) - self.pos_att_type = config.pos_att_type self.relative_attention = getattr(config, "relative_attention", False) + self.talking_head = getattr(config, "talking_head", False) + + if self.talking_head: + self.head_logits_proj = torch.nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False) + self.head_weights_proj = torch.nn.Linear( + config.num_attention_heads, config.num_attention_heads, bias=False + ) if self.relative_attention: - self.position_buckets = getattr(config, "position_buckets", -1) self.max_relative_positions = getattr(config, "max_relative_positions", -1) if self.max_relative_positions < 1: self.max_relative_positions = config.max_position_embeddings - self.pos_ebd_size = self.max_relative_positions - if self.position_buckets > 0: - self.pos_ebd_size = self.position_buckets - self.pos_dropout = StableDropout(config.hidden_dropout_prob) - if not self.share_att_key: - if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: - self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) - if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: - self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_proj = torch.nn.Linear(config.hidden_size, self.all_head_size, bias=False) + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_q_proj = torch.nn.Linear(config.hidden_size, self.all_head_size) self.dropout = StableDropout(config.attention_probs_dropout_prob) - self._register_load_state_dict_pre_hook(self._pre_load_hook) - def transpose_for_scores(self, x, attention_heads): - new_x_shape = x.size()[:-1] + (attention_heads, -1) + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1) x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) + return x.permute(0, 2, 1, 3) def forward( self, @@ -655,46 +569,51 @@ def forward( """ if query_states is None: - query_states = hidden_states - query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) - key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) - value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) + qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1) + query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1) + else: + + def linear(w, b, x): + if b is not None: + return torch.matmul(x, w.t()) + b.t() + else: + return torch.matmul(x, w.t()) # + b.t() + + ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0) + qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)] + qkvb = [None] * 3 + + q = linear(qkvw[0], qkvb[0], query_states) + k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)] + query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]] + + query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :]) + value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :]) 
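For readers comparing the two attention variants in this patch: the restored v1 layer above packs query/key/value into one bias-free `in_proj` weight and re-adds learnable biases only on the query and value paths, while the removed lines (now part of the new deberta_v2 module) used separate `query_proj`/`key_proj`/`value_proj` projections. The standalone sketch below, with made-up sizes and random weights, illustrates only the packed layout and the per-head split; it is not the patched implementation itself.

    import torch

    batch, seq, hidden_size, num_heads = 2, 16, 768, 12
    head_size = hidden_size // num_heads

    x = torch.randn(batch, seq, hidden_size)
    in_proj = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=False)
    q_bias = torch.zeros(hidden_size)   # biases live outside the packed projection
    v_bias = torch.zeros(hidden_size)   # the key path gets no bias at all

    qkv = in_proj(x)                    # (batch, seq, 3*hidden)
    # assuming the 3*hidden output is laid out per head as [q_h | k_h | v_h], as the v1 code
    # above does, reshaping to (batch, heads, seq, 3*head_size) and chunking the last dim
    # recovers the per-head query/key/value slices
    qkv = qkv.view(batch, seq, num_heads, 3 * head_size).permute(0, 2, 1, 3)
    q, k, v = qkv.chunk(3, dim=-1)      # each (batch, heads, seq, head_size)
    q = q + q_bias.view(num_heads, 1, head_size)
    v = v + v_bias.view(num_heads, 1, head_size)
    print(q.shape, k.shape, v.shape)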
rel_att = None # Take the dot product between "query" and "key" to get the raw attention scores. - scale_factor = 1 - if "c2p" in self.pos_att_type: - scale_factor += 1 - if "p2c" in self.pos_att_type: - scale_factor += 1 - if "p2p" in self.pos_att_type: - scale_factor += 1 + scale_factor = 1 + len(self.pos_att_type) scale = math.sqrt(query_layer.size(-1) * scale_factor) - attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale + query_layer = query_layer / scale + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) - rel_att = self.disentangled_attention_bias( - query_layer, key_layer, relative_pos, rel_embeddings, scale_factor - ) + rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor) if rel_att is not None: attention_scores = attention_scores + rel_att - attention_scores = attention_scores - attention_scores = attention_scores.view( - -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) - ) # bxhxlxd - _attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) - attention_probs = self.dropout(_attention_probs) - context_layer = torch.bmm( - attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer - ) - context_layer = ( - context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) - .permute(0, 2, 1, 3) - .contiguous() - ) + if self.talking_head: + attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(attention_probs) + if self.talking_head: + attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.view(*new_context_layer_shape) if return_att: @@ -702,147 +621,61 @@ def forward( else: return context_layer - def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): if relative_pos is None: q = query_layer.size(-2) - relative_pos = build_relative_position( - q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions - ) + relative_pos = build_relative_position(q, key_layer.size(-2), query_layer.device) if relative_pos.dim() == 2: relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) elif relative_pos.dim() == 3: relative_pos = relative_pos.unsqueeze(1) # bxhxqxk elif relative_pos.dim() != 4: - raise ValueError(f"Relative postion ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. 
{relative_pos.dim()}") - att_span = self.pos_ebd_size + att_span = min(max(query_layer.size(-2), key_layer.size(-2)), self.max_relative_positions) relative_pos = relative_pos.long().to(query_layer.device) + rel_embeddings = rel_embeddings[ + self.max_relative_positions - att_span : self.max_relative_positions + att_span, : + ].unsqueeze(0) + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_key_layer = self.pos_proj(rel_embeddings) + pos_key_layer = self.transpose_for_scores(pos_key_layer) - rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :].unsqueeze( - 0 - ) # .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) - if self.share_att_key: - pos_query_layer = self.transpose_for_scores( - self.query_proj(rel_embeddings), self.num_attention_heads - ).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1 - ) # .split(self.all_head_size, dim=-1) - pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1 - ) # .split(self.all_head_size, dim=-1) - else: - if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: - pos_key_layer = self.transpose_for_scores( - self.pos_key_proj(rel_embeddings), self.num_attention_heads - ).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1 - ) # .split(self.all_head_size, dim=-1) - if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: - pos_query_layer = self.transpose_for_scores( - self.pos_query_proj(rel_embeddings), self.num_attention_heads - ).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1 - ) # .split(self.all_head_size, dim=-1) + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_query_layer = self.pos_q_proj(rel_embeddings) + pos_query_layer = self.transpose_for_scores(pos_query_layer) score = 0 # content->position if "c2p" in self.pos_att_type: - scale = math.sqrt(pos_key_layer.size(-1) * scale_factor) - c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2)) c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) - c2p_att = torch.gather( - c2p_att, - dim=-1, - index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), - ) - score += c2p_att / scale + c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos)) + score += c2p_att # position->content if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: - scale = math.sqrt(pos_query_layer.size(-1) * scale_factor) - if key_layer.size(-2) != query_layer.size(-2): - r_pos = build_relative_position( - key_layer.size(-2), - key_layer.size(-2), - bucket_size=self.position_buckets, - max_position=self.max_relative_positions, - ).to(query_layer.device) - r_pos = r_pos.unsqueeze(0) + pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor) + if query_layer.size(-2) != key_layer.size(-2): + r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device) else: r_pos = relative_pos - p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) if query_layer.size(-2) != key_layer.size(-2): pos_index = relative_pos[:, :, :, 0].unsqueeze(-1) if "p2c" in self.pos_att_type: - p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2)) p2c_att = torch.gather( - p2c_att, - dim=-1, - 
index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), + p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer) ).transpose(-1, -2) if query_layer.size(-2) != key_layer.size(-2): - p2c_att = torch.gather( - p2c_att, - dim=-2, - index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))), - ) - score += p2c_att / scale - - # position->position - if "p2p" in self.pos_att_type: - pos_query = pos_query_layer[:, :, att_span:, :] - p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2)) - p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:]) - if query_layer.size(-2) != key_layer.size(-2): - p2p_att = torch.gather( - p2p_att, - dim=-2, - index=pos_index.expand(query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))), - ) - p2p_att = torch.gather( - p2p_att, - dim=-1, - index=c2p_pos.expand( - [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)] - ), - ) - score += p2p_att + p2c_att = torch.gather(p2c_att, dim=-2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer)) + score += p2c_att return score - def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - self_state = self.state_dict() - if ((prefix + "query_proj.weight") not in state_dict) and ((prefix + "in_proj.weight") in state_dict): - v1_proj = state_dict[prefix + "in_proj.weight"] - v1_proj = v1_proj.unsqueeze(0).reshape(self.num_attention_heads, -1, v1_proj.size(-1)) - q, k, v = v1_proj.chunk(3, dim=1) - state_dict[prefix + "query_proj.weight"] = q.reshape(-1, v1_proj.size(-1)) - state_dict[prefix + "key_proj.weight"] = k.reshape(-1, v1_proj.size(-1)) - state_dict[prefix + "key_proj.bias"] = self_state["key_proj.bias"] - state_dict[prefix + "value_proj.weight"] = v.reshape(-1, v1_proj.size(-1)) - v1_query_bias = state_dict[prefix + "q_bias"] - state_dict[prefix + "query_proj.bias"] = v1_query_bias - v1_value_bias = state_dict[prefix + "v_bias"] - state_dict[prefix + "value_proj.bias"] = v1_value_bias - - v1_pos_key_proj = state_dict[prefix + "pos_proj.weight"] - state_dict[prefix + "pos_key_proj.weight"] = v1_pos_key_proj - v1_pos_query_proj = state_dict[prefix + "pos_q_proj.weight"] - state_dict[prefix + "pos_query_proj.weight"] = v1_pos_query_proj - v1_pos_query_proj_bias = state_dict[prefix + "pos_q_proj.bias"] - state_dict[prefix + "pos_query_proj.bias"] = v1_pos_query_proj_bias - state_dict[prefix + "pos_key_proj.bias"] = self_state["pos_key_proj.bias"] - - del state_dict[prefix + "in_proj.weight"] - del state_dict[prefix + "q_bias"] - del state_dict[prefix + "v_bias"] - del state_dict[prefix + "pos_proj.weight"] - del state_dict[prefix + "pos_q_proj.weight"] - del state_dict[prefix + "pos_q_proj.bias"] - class DebertaEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -866,7 +699,6 @@ def __init__(self, config): self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps) self.dropout = StableDropout(config.hidden_dropout_prob) - self.output_to_half = False self.config = config # position_ids (1, len position emb) is contiguous in memory and exported when serialized @@ -928,6 +760,7 @@ class DebertaPreTrainedModel(PreTrainedModel): config_class = DebertaConfig base_model_prefix = "deberta" _keys_to_ignore_on_load_missing = ["position_ids"] + 
_keys_to_ignore_on_load_unexpected = ["position_embeddings"] def _init_weights(self, module): """ Initialize the weights """ @@ -1032,7 +865,7 @@ class PreTrainedModel @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, - checkpoint="microsoft/deberta-base", + checkpoint=_CHECKPOINT_FOR_DOC, output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC, ) @@ -1118,7 +951,6 @@ def forward( @add_start_docstrings("""DeBERTa Model with a `language modeling` head on top. """, DEBERTA_START_DOCSTRING) class DebertaForMaskedLM(DebertaPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] @@ -1139,7 +971,7 @@ def set_output_embeddings(self, new_embeddings): @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, - checkpoint="microsoft/deberta-base", + checkpoint=_CHECKPOINT_FOR_DOC, output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC, ) @@ -1279,7 +1111,7 @@ def set_input_embeddings(self, new_embeddings): @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, - checkpoint="microsoft/deberta-base", + checkpoint=_CHECKPOINT_FOR_DOC, output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC, ) @@ -1359,7 +1191,6 @@ def forward( DEBERTA_START_DOCSTRING, ) class DebertaForTokenClassification(DebertaPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): @@ -1375,7 +1206,7 @@ def __init__(self, config): @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, - checkpoint="microsoft/deberta-base", + checkpoint=_CHECKPOINT_FOR_DOC, output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC, ) @@ -1448,7 +1279,6 @@ def forward( DEBERTA_START_DOCSTRING, ) class DebertaForQuestionAnswering(DebertaPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): @@ -1463,7 +1293,7 @@ def __init__(self, config): @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, - checkpoint="microsoft/deberta-base", + checkpoint=_CHECKPOINT_FOR_DOC, output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC, ) diff --git a/src/transformers/models/deberta/spm_tokenizer.py b/src/transformers/models/deberta/spm_tokenizer.py deleted file mode 100644 index d7b94fa7c40928..00000000000000 --- a/src/transformers/models/deberta/spm_tokenizer.py +++ /dev/null @@ -1,277 +0,0 @@ -# coding=utf-8 -# Copyright 2020 Microsoft and the HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
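A note on the `_keys_to_ignore_on_load_*` edits above: these class attributes only control which missing or unexpected checkpoint keys `from_pretrained` reports in its warnings; they do not change what actually gets loaded. A plain `re`-based sketch of that kind of filtering (not the Transformers implementation) for illustration:

    import re

    def keys_to_warn_about(keys, ignore_patterns):
        """Keep only the keys that do not match any ignore pattern."""
        return [k for k in keys if not any(re.search(p, k) for p in ignore_patterns)]

    unexpected = [
        "deberta.embeddings.position_embeddings.weight",
        "some.other.unexpected.weight",
    ]
    print(keys_to_warn_about(unexpected, ["position_embeddings"]))
    # -> ['some.other.unexpected.weight']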
-# See the License for the specific language governing permissions and -# limitations under the License. -""" Tokenization class for model DeBERTa.""" - -import os -import unicodedata - -import sentencepiece as sp -import six - - -VOCAB_FILES_NAMES = {"vocab_file": "spm.model"} - -__all__ = ["SPMTokenizer"] - - -class SPMTokenizer: - def __init__( - self, vocab_file, do_lower_case=False, special_tokens=None, bpe_dropout=0, split_by_punct=False, **kwargs - ): - self.split_by_punct = split_by_punct - spm = sp.SentencePieceProcessor() - assert os.path.exists(vocab_file) - spm.load(vocab_file) - bpe_vocab_size = spm.GetPieceSize() - # Token map - # 0+1 - # 1+1 - # 2+1 - self.vocab = {spm.IdToPiece(i): i for i in range(bpe_vocab_size)} - self.id_to_tokens = [spm.IdToPiece(i) for i in range(bpe_vocab_size)] - # self.vocab['[PAD]'] = 0 - # self.vocab['[CLS]'] = 1 - # self.vocab['[SEP]'] = 2 - # self.vocab['[UNK]'] = 3 - - _special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"] - self.special_tokens = [] - if special_tokens is not None: - _special_tokens.extend(special_tokens) - for t in _special_tokens: - self.add_special_token(t) - - self.spm = spm - - def tokenize(self, text): - pieces = self._encode_as_pieces(text) - - def _norm(x): - if x not in self.vocab or x == "": - return "[UNK]" - else: - return x - - pieces = [_norm(p) for p in pieces] - return pieces - - def convert_tokens_to_ids(self, tokens): - return [self.vocab[t] if t in self.vocab else 1 for t in tokens] - - def convert_ids_to_tokens(self, ids): - tokens = [] - for i in ids: - tokens.append(self.ids_to_tokens[i]) - return tokens - - def decode(self, tokens, start=-1, end=-1, raw_text=None): - if raw_text is None: - return self.spm.decode_pieces([t for t in tokens if t not in self.special_tokens]) - else: - words = self.split_to_words(raw_text) - word_tokens = [self.tokenize(w) for w in words] - token2words = [0] * len(tokens) - tid = 0 - for i, w in enumerate(word_tokens): - for k, t in enumerate(w): - token2words[tid] = i - tid += 1 - word_start = token2words[start] - word_end = token2words[end] if end < len(tokens) else len(words) - text = "".join(words[word_start:word_end]) - return text - - def add_special_token(self, token): - if token not in self.special_tokens: - self.special_tokens.append(token) - if token not in self.vocab: - self.vocab[token] = len(self.vocab) - self.id_to_tokens.append(token) - return self.id(token) - - def part_of_whole_word(self, token, is_bos=False): - if is_bos: - return True - if ( - len(token) == 1 - and (_is_whitespace(list(token)[0]) or _is_control(list(token)[0]) or _is_punctuation(list(token)[0])) - ) or token in self.special_tokens: - return False - - word_start = b"\xe2\x96\x81".decode("utf-8") - return not token.startswith(word_start) - - def pad(self): - return "[PAD]" - - def bos(self): - return "[CLS]" - - def eos(self): - return "[SEP]" - - def unk(self): - return "[UNK]" - - def mask(self): - return "[MASK]" - - def sym(self, id): - return self.ids_to_tokens[id] - - def id(self, sym): - return self.vocab[sym] if sym in self.vocab else 1 - - def _encode_as_pieces(self, text): - text = convert_to_unicode(text) - if self.split_by_punct: - words = self._run_split_on_punc(text) - pieces = [self.spm.encode_as_pieces(w) for w in words] - return [p for w in pieces for p in w] - else: - return self.spm.encode_as_pieces(text) - - def split_to_words(self, text): - pieces = self._encode_as_pieces(text) - word_start = b"\xe2\x96\x81".decode("utf-8") - words = [] - offset = 0 - prev_end = 
0 - for i, p in enumerate(pieces): - if p.startswith(word_start): - if offset > prev_end: - words.append(text[prev_end:offset]) - prev_end = offset - w = p.replace(word_start, "") - else: - w = p - try: - s = text.index(w, offset) - pn = "" - k = i + 1 - while k < len(pieces): - pn = pieces[k].replace(word_start, "") - if len(pn) > 0: - break - k += 1 - - if len(pn) > 0 and pn in text[offset:s]: - offset = offset + 1 - else: - offset = s + len(w) - except Exception: - offset = offset + 1 - - if prev_end < offset: - words.append(text[prev_end:offset]) - - return words - - def _run_strip_accents(self, text): - """Strips accents from a piece of text.""" - text = unicodedata.normalize("NFD", text) - output = [] - for char in text: - cat = unicodedata.category(char) - if cat == "Mn": - continue - output.append(char) - return "".join(output) - - def _run_split_on_punc(self, text): - """Splits punctuation on a piece of text.""" - chars = list(text) - i = 0 - start_new_word = True - output = [] - while i < len(chars): - char = chars[i] - if _is_punctuation(char): - output.append([char]) - start_new_word = True - else: - if start_new_word: - output.append([]) - start_new_word = False - output[-1].append(char) - i += 1 - - return ["".join(x) for x in output] - - def save_pretrained(self, path: str, filename_prefix: str = None): - filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] - if filename_prefix is not None: - filename = filename_prefix + "-" + filename - full_path = os.path.join(path, filename) - with open(full_path, "wb") as fs: - fs.write(self.spm.serialized_model_proto()) - return (full_path,) - - -def _is_whitespace(char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False - - -def _is_control(char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char == "\t" or char == "\n" or char == "\r": - return False - cat = unicodedata.category(char) - if cat.startswith("C"): - return True - return False - - -def _is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. 
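The `split_to_words` logic in the SPM tokenizer being deleted above relies on sentencepiece's convention of prefixing word-initial pieces with the U+2581 "▁" marker. A dependency-free toy sketch of that grouping, run on a hand-written piece list rather than real sentencepiece output:

    word_start = b"\xe2\x96\x81".decode("utf-8")   # "▁", sentencepiece's word-boundary marker

    def group_pieces(pieces):
        """Group sub-word pieces into words using the leading '▁' marker."""
        words = []
        for piece in pieces:
            if piece.startswith(word_start) or not words:
                words.append(piece.replace(word_start, ""))
            else:
                words[-1] += piece
        return words

    print(group_pieces(["▁Hel", "lo", "▁wor", "ld", "!"]))
    # -> ['Hello', 'world!']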
- if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False - - -def convert_to_unicode(text): - """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" - if six.PY3: - if isinstance(text, str): - return text - elif isinstance(text, bytes): - return text.decode("utf-8", "ignore") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - elif six.PY2: - if isinstance(text, str): - return text.decode("utf-8", "ignore") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - else: - raise ValueError("Not running on Python2 or Python 3?") diff --git a/src/transformers/models/deberta/tokenization_deberta.py b/src/transformers/models/deberta/tokenization_deberta.py index f3ed41fe3d991a..9e8c8497408c9a 100644 --- a/src/transformers/models/deberta/tokenization_deberta.py +++ b/src/transformers/models/deberta/tokenization_deberta.py @@ -15,29 +15,39 @@ """ Tokenization class for model DeBERTa.""" import os +import pathlib +import random +import unicodedata +from functools import lru_cache from typing import Optional, Tuple +from zipfile import ZipFile + +import tqdm + +import requests -from ...file_utils import is_sentencepiece_available from ...tokenization_utils import PreTrainedTokenizer -from .gpt2_tokenizer import GPT2Tokenizer +from ...utils import logging + +try: + import regex as re +except ImportError: + raise ImportError("Please install regex with: pip install regex") -if is_sentencepiece_available(): - from .spm_tokenizer import SPMTokenizer +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"} PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { "microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/bpe_encoder.bin", "microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/bpe_encoder.bin", "microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/bpe_encoder.bin", - "microsoft/deberta-xlarge-v2": "https://huggingface.co/microsoft/deberta-xlarge-v2/resolve/main/spm.model", - "microsoft/deberta-xxlarge-v2": "https://huggingface.co/microsoft/deberta-xxlarge-v2/resolve/main/spm.model", "microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/bpe_encoder.bin", "microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/bpe_encoder.bin", "microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/bpe_encoder.bin", - "microsoft/deberta-xlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xlarge-v2-mnli/resolve/main/spm.model", - "microsoft/deberta-xxlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xxlarge-v2-mnli/resolve/main/spm.model", } } @@ -45,31 +55,444 @@ "microsoft/deberta-base": 512, "microsoft/deberta-large": 512, "microsoft/deberta-xlarge": 512, - "microsoft/deberta-xlarge-v2": 512, - "microsoft/deberta-xxlarge-v2": 512, "microsoft/deberta-base-mnli": 512, "microsoft/deberta-large-mnli": 512, "microsoft/deberta-xlarge-mnli": 512, - "microsoft/deberta-xlarge-v2-mnli": 512, - "microsoft/deberta-xxlarge-v2-mnli": 512, } PRETRAINED_INIT_CONFIGURATION = { - "microsoft/deberta-base": {"do_lower_case": False, "vocab_type": "gpt2"}, - "microsoft/deberta-large": {"do_lower_case": False, "vocab_type": "gpt2"}, - 
"microsoft/deberta-xlarge": {"do_lower_case": False, "vocab_type": "gpt2"}, - "microsoft/deberta-xlarge-v2": {"do_lower_case": False, "vocab_type": "spm"}, - "microsoft/deberta-xxlarge-v2": {"do_lower_case": False, "vocab_type": "spm"}, - "microsoft/deberta-base-mnli": {"do_lower_case": False, "vocab_type": "gpt2"}, - "microsoft/deberta-large-mnli": {"do_lower_case": False, "vocab_type": "gpt2"}, - "microsoft/deberta-xlarge-mnli": {"do_lower_case": False, "vocab_type": "gpt2"}, - "microsoft/deberta-xlarge-v2-mnli": {"do_lower_case": False, "vocab_type": "spm"}, - "microsoft/deberta-xxlarge-v2-mnli": {"do_lower_case": False, "vocab_type": "spm"}, + "microsoft/deberta-base": {"do_lower_case": False}, + "microsoft/deberta-large": {"do_lower_case": False}, } __all__ = ["DebertaTokenizer"] -VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"} + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode + strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're + at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a signficant + percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode + strings. And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2 ** 8): + if b not in bs: + bs.append(b) + cs.append(2 ** 8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length + strings). 
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Encoder: + def __init__(self, encoder, bpe_merges, errors="replace"): + self.encoder = encoder + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges)))) + self.cache = {} + self.random = random.Random(0) + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def split_to_words(self, text): + return list(re.findall(self.pat, text)) + + def encode(self, text): + bpe_tokens = [] + for token in self.split_to_words(text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + +def get_encoder(encoder, vocab): + return Encoder( + encoder=encoder, + bpe_merges=vocab, + ) + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. 
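A sanity check on the byte-level machinery above: `bytes_to_unicode` defines a reversible mapping from raw UTF-8 bytes to printable unicode characters, which is what lets the GPT-2 style BPE handle arbitrary text without an unknown-byte fallback. The snippet below exercises only that mapping (it repeats the same table construction; the learned merges and vocabulary lookup of the `Encoder` are not involved):

    def bytes_to_unicode():
        # same construction as above: bytes in the three printable ranges map to themselves,
        # every other byte is shifted to an unused code point above 255
        bs = (
            list(range(ord("!"), ord("~") + 1))
            + list(range(ord("¡"), ord("¬") + 1))
            + list(range(ord("®"), ord("ÿ") + 1))
        )
        cs = bs[:]
        n = 0
        for b in range(2 ** 8):
            if b not in bs:
                bs.append(b)
                cs.append(2 ** 8 + n)
                n += 1
        return dict(zip(bs, [chr(c) for c in cs]))

    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}

    text = "héllo world!"
    mapped = "".join(byte_encoder[b] for b in text.encode("utf-8"))
    restored = bytearray(byte_decoder[c] for c in mapped).decode("utf-8")
    assert restored == text
    print(mapped)   # 'hÃ©lloĠworld!' — the space byte surfaces as the printable 'Ġ'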
+ if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +def download_asset(name, tag=None, no_cache=False, cache_dir=None): + _tag = tag + if _tag is None: + _tag = "latest" + if not cache_dir: + cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/") + os.makedirs(cache_dir, exist_ok=True) + output = os.path.join(cache_dir, name) + if os.path.exists(output) and (not no_cache): + return output + + repo = "https://api.github.com/repos/microsoft/DeBERTa/releases" + releases = requests.get(repo).json() + if tag and tag != "latest": + release = [r for r in releases if r["name"].lower() == tag.lower()] + if len(release) != 1: + raise Exception(f"{tag} can't be found in the repository.") + else: + release = releases[0] + asset = [s for s in release["assets"] if s["name"].lower() == name.lower()] + if len(asset) != 1: + raise Exception(f"{name} can't be found in the release.") + url = asset[0]["url"] + headers = {} + headers["Accept"] = "application/octet-stream" + resp = requests.get(url, stream=True, headers=headers) + if resp.status_code != 200: + raise Exception(f"Request for {url} return {resp.status_code}, {resp.text}") + try: + with open(output, "wb") as fs: + progress = tqdm( + total=int(resp.headers["Content-Length"]) if "Content-Length" in resp.headers else -1, + ncols=80, + desc=f"Downloading {name}", + ) + for c in resp.iter_content(chunk_size=1024 * 1024): + fs.write(c) + progress.update(len(c)) + progress.close() + except Exception: + os.remove(output) + raise + + return output + + +def load_vocab(name=None, tag=None, no_cache=False, cache_dir=None): + import torch + + if name is None: + name = "bpe_encoder" + + model_path = name + if model_path and (not os.path.exists(model_path)) and not (("/" in model_path) or ("\\" in model_path)): + _tag = tag + if _tag is None: + _tag = "latest" + if not cache_dir: + cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/") + os.makedirs(cache_dir, exist_ok=True) + out_dir = os.path.join(cache_dir, name) + model_path = os.path.join(out_dir, "bpe_encoder.bin") + if (not os.path.exists(model_path)) or no_cache: + asset = download_asset(name + ".zip", tag=tag, no_cache=no_cache, cache_dir=cache_dir) + with ZipFile(asset, "r") as zipf: + for zip_info in zipf.infolist(): + if zip_info.filename[-1] == "/": + continue + zip_info.filename = os.path.basename(zip_info.filename) + zipf.extract(zip_info, out_dir) + elif not model_path: + return None, None + + encoder_state = torch.load(model_path) + return encoder_state + + +class GPT2Tokenizer(object): + """ + A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer + + Args: + vocab_file (:obj:`str`, optional): + The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases + `_, e.g. "bpe_encoder", default: `None`. + + If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file + is a state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used + in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. The difference between our wrapped GPT2 + tokenizer and RoBERTa wrapped tokenizer are, + + - Special tokens, unlike `RoBERTa` which use ``, `` as the `start` token and `end` token of a + sentence. 
We use `[CLS]` and `[SEP]` as the `start` and `end` token of input sentence which is the same + as `BERT`. + + - We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, + `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264 + + special_tokens (:obj:`list`, optional): + List of special tokens to be added to the end of the vocabulary. + """ + + def __init__(self, vocab_file=None, special_tokens=None): + self.pad_token = "[PAD]" + self.sep_token = "[SEP]" + self.unk_token = "[UNK]" + self.cls_token = "[CLS]" + + self.symbols = [] + self.count = [] + self.indices = {} + self.pad_token_id = self.add_symbol(self.pad_token) + self.cls_token_id = self.add_symbol(self.cls_token) + self.sep_token_id = self.add_symbol(self.sep_token) + self.unk_token_id = self.add_symbol(self.unk_token) + + self.gpt2_encoder = load_vocab(vocab_file) + self.bpe = get_encoder(self.gpt2_encoder["encoder"], self.gpt2_encoder["vocab"]) + for w, n in self.gpt2_encoder["dict_map"]: + self.add_symbol(w, n) + + self.mask_token = "[MASK]" + self.mask_id = self.add_symbol(self.mask_token) + self.special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"] + if special_tokens is not None: + for t in special_tokens: + self.add_special_token(t) + + self.vocab = self.indices + self.ids_to_tokens = self.symbols + + def tokenize(self, text): + """ + Convert an input text to tokens. + + Args: + text (:obj:`str`): input text to be tokenized. + + Returns: + A list of byte tokens where each token represent the byte id in GPT2 byte dictionary + + Example:: + >>> tokenizer = GPT2Tokenizer() + >>> text = "Hello world!" + >>> tokens = tokenizer.tokenize(text) + >>> print(tokens) + ['15496', '995', '0'] + """ + bpe = self._encode(text) + + return [t for t in bpe.split(" ") if t] + + def convert_tokens_to_ids(self, tokens): + """ + Convert list of tokens to ids + + Args: + tokens (:obj:`list`): list of tokens + + Returns: + List of ids + """ + + return [self.vocab[t] for t in tokens] + + def convert_ids_to_tokens(self, ids): + """ + Convert list of ids to tokens + + Args: + ids (:obj:`list`): list of ids + + Returns: + List of tokens + """ + + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def split_to_words(self, text): + return self.bpe.split_to_words(text) + + def decode(self, tokens): + """ + Decode list of tokens to text strings + + Args: + tokens (:obj:`list`): list of tokens. + + Returns: + Text string corresponds to the input tokens. + + Example:: + >>> tokenizer = GPT2Tokenizer() + >>> text = "Hello world!" + >>> tokens = tokenizer.tokenize(text) + >>> print(tokens) + ['15496', '995', '0'] + >>> tokenizer.decode(tokens) + 'Hello world!' + """ + return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens]) + + def add_special_token(self, token): + """ + Adds a special token to the dictionary + + Args: + token (:obj:`str`): Tthe new token/word to be added to the vocabulary. + + Returns: + The id of new token in the vocabulary. 
+ + """ + self.special_tokens.append(token) + return self.add_symbol(token) + + def part_of_whole_word(self, token, is_bos=False): + if is_bos: + return True + s = self._decode(token) + if len(s) == 1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0])): + return False + + return not s.startswith(" ") + + def sym(self, id): + return self.ids_to_tokens[id] + + def id(self, sym): + return self.vocab[sym] + + def _encode(self, x: str) -> str: + return " ".join(map(str, self.bpe.encode(x))) + + def _decode(self, x: str) -> str: + return self.bpe.decode(map(int, x.split())) + + def add_symbol(self, word, n=1): + """ + Adds a word to the dictionary + + Args: + word (:obj:`str`): Tthe new token/word to be added to the vocabulary. + n (int, optional): The frequency of the word. + + Returns: + The id of the new word. + + """ + if word in self.indices: + idx = self.indices[word] + self.count[idx] = self.count[idx] + n + return idx + else: + idx = len(self.symbols) + self.indices[word] = idx + self.symbols.append(word) + self.count.append(n) + return idx + + def save_pretrained(self, path: str, filename_prefix: str = None): + import torch + + filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] + if filename_prefix is not None: + filename = filename_prefix + "-" + filename + full_path = os.path.join(path, filename) + torch.save(self.gpt2_encoder, full_path) + return (full_path,) class DebertaTokenizer(PreTrainedTokenizer): @@ -107,7 +530,6 @@ def __init__( self, vocab_file, do_lower_case=False, - vocab_type="gpt2", unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", @@ -117,7 +539,6 @@ def __init__( ): super().__init__( do_lower_case=do_lower_case, - vocab_type="gpt2", unk_token=unk_token, sep_token=sep_token, pad_token=pad_token, @@ -132,10 +553,7 @@ def __init__( "model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) ) self.do_lower_case = do_lower_case - if vocab_type.lower() == "gpt2": - self._tokenizer = GPT2Tokenizer(vocab_file, **kwargs) - else: - self._tokenizer = SPMTokenizer(vocab_file, **kwargs) + self.gpt2_tokenizer = GPT2Tokenizer(vocab_file) @property def vocab_size(self): @@ -143,7 +561,7 @@ def vocab_size(self): @property def vocab(self): - return self._tokenizer.vocab + return self.gpt2_tokenizer.vocab def get_vocab(self): vocab = self.vocab.copy() @@ -154,7 +572,7 @@ def _tokenize(self, text): """Take as input a string and return a list of strings (tokens) for words/sub-words""" if self.do_lower_case: text = text.lower() - return self._tokenizer.tokenize(text) + return self.gpt2_tokenizer.tokenize(text) def _convert_token_to_id(self, token): """ Converts a token (str) in an id using the vocab. """ @@ -162,11 +580,11 @@ def _convert_token_to_id(self, token): def _convert_id_to_token(self, index): """Converts an index (integer) in a token (str) using the vocab.""" - return self._tokenizer.sym(index) if index < self.vocab_size else self.unk_token + return self.gpt2_tokenizer.sym(index) if index < self.vocab_size else self.unk_token def convert_tokens_to_string(self, tokens): """ Converts a sequence of tokens (string) in a single string. 
""" - return self._tokenizer.decode(tokens) + return self.gpt2_tokenizer.decode(tokens) def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): """ @@ -261,4 +679,4 @@ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): return (text, kwargs) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: - return self._tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix) + return self.gpt2_tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix) diff --git a/src/transformers/models/deberta_v2/__init__.py b/src/transformers/models/deberta_v2/__init__.py new file mode 100644 index 00000000000000..b00c8f1afb8189 --- /dev/null +++ b/src/transformers/models/deberta_v2/__init__.py @@ -0,0 +1,72 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...file_utils import _BaseLazyModule, is_torch_available + + +_import_structure = { + "configuration_deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"], + "tokenization_deberta_v2": ["DebertaV2Tokenizer"], +} + +if is_torch_available(): + _import_structure["modeling_deberta_v2"] = [ + "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", + "DebertaV2ForSequenceClassification", + "DebertaV2Model", + "DebertaV2ForMaskedLM", + "DebertaV2PreTrainedModel", + "DebertaV2ForTokenClassification", + "DebertaV2ForQuestionAnswering", + ] + + +if TYPE_CHECKING: + from .configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config + from .tokenization_deberta_v2 import DebertaV2Tokenizer + + if is_torch_available(): + from .modeling_deberta_v2 import ( + DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST, + DebertaV2ForMaskedLM, + DebertaV2ForQuestionAnswering, + DebertaV2ForSequenceClassification, + DebertaV2ForTokenClassification, + DebertaV2Model, + DebertaV2PreTrainedModel, + ) + +else: + import importlib + import os + import sys + + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) diff --git a/src/transformers/models/deberta_v2/configuration_deberta_v2.py b/src/transformers/models/deberta_v2/configuration_deberta_v2.py new file mode 100644 index 00000000000000..128a8701352266 --- /dev/null +++ b/src/transformers/models/deberta_v2/configuration_deberta_v2.py @@ -0,0 +1,138 @@ +# coding=utf-8 +# Copyright 2020, Microsoft and the HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" DeBERTa-v2 model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/deberta-xlarge-v2": "https://huggingface.co/microsoft/deberta-xlarge-v2/resolve/main/config.json", + "microsoft/deberta-xxlarge-v2": "https://huggingface.co/microsoft/deberta-xxlarge-v2/resolve/main/config.json", + "microsoft/deberta-xlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xlarge-v2-mnli/resolve/main/config.json", + "microsoft/deberta-xxlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xxlarge-v2-mnli/resolve/main/config.json", +} + + +class DebertaV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.DebertaV2Model`. It is used + to instantiate a DeBERTa-v2 model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the DeBERTa + `microsoft/deberta-base `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + Arguments: + vocab_size (:obj:`int`, `optional`, defaults to 30522): + Vocabulary size of the DeBERTa-v2 model. Defines the number of different tokens that can be represented by + the :obj:`inputs_ids` passed when calling :class:`~transformers.DebertaV2Model`. + hidden_size (:obj:`int`, `optional`, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"`, :obj:`"gelu"`, :obj:`"tanh"`, :obj:`"gelu_fast"`, + :obj:`"mish"`, :obj:`"linear"`, :obj:`"sigmoid"` and :obj:`"gelu_new"` are supported. + hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 512): + The maximum sequence length that this model might ever be used with. 
Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.DebertaModel` or + :class:`~transformers.TFDebertaModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): + The epsilon used by the layer normalization layers. + relative_attention (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether use relative position encoding. + max_relative_positions (:obj:`int`, `optional`, defaults to 1): + The range of relative positions :obj:`[-max_position_embeddings, max_position_embeddings]`. Use the same + value as :obj:`max_position_embeddings`. + pad_token_id (:obj:`int`, `optional`, defaults to 0): + The value used to pad input_ids. + position_biased_input (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether add absolute position embedding to content embedding. + pos_att_type (:obj:`List[str]`, `optional`): + The type of relative position attention, it can be a combination of :obj:`["p2c", "c2p", "p2p"]`, e.g. + :obj:`["p2c"]`, :obj:`["p2c", "c2p"]`, :obj:`["p2c", "c2p", 'p2p"]`. + layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): + The epsilon used by the layer normalization layers. + """ + model_type = "deberta" + + def __init__( + self, + vocab_size=50265, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=0, + initializer_range=0.02, + layer_norm_eps=1e-7, + relative_attention=False, + max_relative_positions=-1, + pad_token_id=0, + position_biased_input=True, + pos_att_type=None, + pooler_dropout=0, + pooler_hidden_act="gelu", + **kwargs + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.relative_attention = relative_attention + self.max_relative_positions = max_relative_positions + self.pad_token_id = pad_token_id + self.position_biased_input = position_biased_input + + # Backwards compatibility + if type(pos_att_type) == str: + pos_att_type = [x.strip() for x in pos_att_type.lower().split("|")] + + self.pos_att_type = pos_att_type + self.vocab_size = vocab_size + self.layer_norm_eps = layer_norm_eps + + self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size) + self.pooler_dropout = pooler_dropout + self.pooler_hidden_act = pooler_hidden_act diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py new file mode 100644 index 00000000000000..834b542ba505cf --- /dev/null +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -0,0 +1,1514 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the Hugging Face Inc. team. 
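Once a transformers build containing this patch is installed, the new configuration class above can be used on its own; a hedged usage sketch follows, with illustrative hyper-parameters rather than the values of any released checkpoint:

    # assumes a transformers build that includes this patch
    from transformers.models.deberta_v2 import DebertaV2Config

    config = DebertaV2Config(
        hidden_size=1536,
        num_hidden_layers=24,
        num_attention_heads=24,
        relative_attention=True,
        max_relative_positions=-1,        # falls back to max_position_embeddings
        pos_att_type="p2c|c2p",           # the string form is split into ["p2c", "c2p"]
    )
    print(config.pos_att_type, config.layer_norm_eps)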
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeBERTa-v2 model. """ + +import math +from collections.abc import Sequence + +import numpy as np +import torch +from torch import _softmax_backward_data, nn +from torch.nn import CrossEntropyLoss, LayerNorm + +from ...activations import ACT2FN +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_deberta_v2 import DebertaV2Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DebertaV2Config" +_TOKENIZER_FOR_DOC = "DebertaV2Tokenizer" +_CHECKPOINT_FOR_DOC = "microsoft/deberta-xlarge-v2" + +DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/deberta-xlarge-v2", + "microsoft/deberta-xxlarge-v2", + "microsoft/deberta-xlarge-v2-mnli", + "microsoft/deberta-xxlarge-v2-mnli", +] + + +# Copied from transformers.models.deberta.modeling_deberta.ContextPooler +class ContextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) + self.dropout = StableDropout(config.pooler_dropout) + self.config = config + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) + return pooled_output + + @property + def output_dim(self): + return self.config.hidden_size + + +# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2 +class XSoftmax(torch.autograd.Function): + """ + Masked Softmax which is optimized for saving memory + + Args: + input (:obj:`torch.tensor`): The input tensor that will apply softmax. + mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation. 
+ dim (int): The dimension that will apply softmax + + Example:: + + >>> import torch + >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax + + >>> # Make a tensor + >>> x = torch.randn([4,20,100]) + + >>> # Create a mask + >>> mask = (x>0).int() + + >>> y = XSoftmax.apply(x, mask, dim=-1) + """ + + @staticmethod + def forward(self, input, mask, dim): + self.dim = dim + rmask = ~(mask.bool()) + + output = input.masked_fill(rmask, float("-inf")) + output = torch.softmax(output, self.dim) + output.masked_fill_(rmask, 0) + self.save_for_backward(output) + return output + + @staticmethod + def backward(self, grad_output): + (output,) = self.saved_tensors + inputGrad = _softmax_backward_data(grad_output, output, self.dim, output) + return inputGrad, None, None + + +# Copied from transformers.models.deberta.modeling_deberta.DropoutContext +class DropoutContext(object): + def __init__(self): + self.dropout = 0 + self.mask = None + self.scale = 1 + self.reuse_mask = True + + +# Copied from transformers.models.deberta.modeling_deberta.get_mask +def get_mask(input, local_context): + if not isinstance(local_context, DropoutContext): + dropout = local_context + mask = None + else: + dropout = local_context.dropout + dropout *= local_context.scale + mask = local_context.mask if local_context.reuse_mask else None + + if dropout > 0 and mask is None: + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool() + + if isinstance(local_context, DropoutContext): + if local_context.mask is None: + local_context.mask = mask + + return mask, dropout + + +# Copied from transformers.models.deberta.modeling_deberta.XDropout +class XDropout(torch.autograd.Function): + """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" + + @staticmethod + def forward(ctx, input, local_ctx): + mask, dropout = get_mask(input, local_ctx) + ctx.scale = 1.0 / (1 - dropout) + if dropout > 0: + ctx.save_for_backward(mask) + return input.masked_fill(mask, 0) * ctx.scale + else: + return input + + @staticmethod + def backward(ctx, grad_output): + if ctx.scale > 1: + (mask,) = ctx.saved_tensors + return grad_output.masked_fill(mask, 0) * ctx.scale, None + else: + return grad_output, None + + +# Copied from transformers.models.deberta.modeling_deberta.StableDropout +class StableDropout(torch.nn.Module): + """ + Optimized dropout module for stabilizing the training + + Args: + drop_prob (float): the dropout probabilities + """ + + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + self.count = 0 + self.context_stack = None + + def forward(self, x): + """ + Call the module + + Args: + x (:obj:`torch.tensor`): The input tensor to apply dropout + """ + if self.training and self.drop_prob > 0: + return XDropout.apply(x, self.get_context()) + return x + + def clear_context(self): + self.count = 0 + self.context_stack = None + + def init_context(self, reuse_mask=True, scale=1): + if self.context_stack is None: + self.context_stack = [] + self.count = 0 + for c in self.context_stack: + c.reuse_mask = reuse_mask + c.scale = scale + + def get_context(self): + if self.context_stack is not None: + if self.count >= len(self.context_stack): + self.context_stack.append(DropoutContext()) + ctx = self.context_stack[self.count] + ctx.dropout = self.drop_prob + self.count += 1 + return ctx + else: + return self.drop_prob + + +def MaskedLayerNorm(layerNorm, input, mask=None): + """ + Masked LayerNorm which will apply mask over the 
output of LayerNorm to avoid inaccurate updates to the LayerNorm + module. + + Args: + layernorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function + input (:obj:`torch.tensor`): The input tensor + mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicate the output of that element will be ignored, i.e. set to `0` + + Example:: + + # Create a tensor b x n x d + x = torch.randn([1,10,100]) + m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int) + LayerNorm = DeBERTa.deberta.LayerNorm(100) + y = MaskedLayerNorm(LayerNorm, x, m) + + """ + output = layerNorm(input).to(input) + if mask is None: + return output + if mask.dim() != input.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(output.dtype) + return output * mask + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm +class DebertaV2SelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 +class DebertaV2Attention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = DebertaV2SelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + return_att=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + self_output = self.self( + hidden_states, + attention_mask, + return_att, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if return_att: + self_output, att_matrix = self_output + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states) + + if return_att: + return (attention_output, att_matrix) + else: + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 +class DebertaV2Intermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm +class DebertaV2Output(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states 
+ input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 +class DebertaV2Layer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = DebertaV2Attention(config) + self.intermediate = DebertaV2Intermediate(config) + self.output = DebertaV2Output(config) + + def forward( + self, + hidden_states, + attention_mask, + return_att=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + attention_output = self.attention( + hidden_states, + attention_mask, + return_att=return_att, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if return_att: + attention_output, att_matrix = attention_output + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + if return_att: + return (layer_output, att_matrix) + else: + return layer_output + + +class ConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, "conv_kernel_size", 3) + groups = getattr(config, "conv_groups", 1) + self.conv_act = getattr(config, "conv_act", "tanh") + self.conv = torch.nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + output_states = MaskedLayerNorm(self.LayerNorm, residual_states + out, input_mask) + + return output_states + + +class DebertaV2Encoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + + self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + self.position_buckets = getattr(config, "position_buckets", -1) + pos_ebd_size = self.max_relative_positions * 2 + + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + + self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) + + self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] + + if "layer_norm" in self.norm_rel_ebd: + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + kernel_size = getattr(config, "conv_kernel_size", 0) + self.with_conv = False + if kernel_size > 0: + self.with_conv = True + self.conv = ConvLayer(config) + + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + 
attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) + attention_mask = attention_mask.byte() + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) + relative_pos = build_relative_position( + q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_hidden_states=True, + output_attentions=False, + query_states=None, + relative_pos=None, + return_dict=True, + ): + if attention_mask.dim() <= 2: + input_mask = attention_mask + else: + input_mask = (attention_mask.sum(-2) > 0).byte() + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[0] + else: + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + for i, layer_module in enumerate(self.layer): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + output_states = layer_module( + next_kv, + attention_mask, + output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + output_states, att_m = output_states + + if i == 0 and self.with_conv: + output_states = self.conv(hidden_states, output_states, input_mask) + + if query_states is not None: + query_states = output_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None + else: + next_kv = output_states + + if output_attentions: + all_attentions = all_attentions + (att_m,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if not return_dict: + return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +def make_log_bucket_position(relative_pos, bucket_size, max_position): + sign = np.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos)) + log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid + bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int) + return bucket_pos + + +def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1): + """ + Build relative position according to the query and key + + We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key + :math:`P_k` is range from (0, key_size), The relative positions from query to key is :math:`R_{q \\rightarrow k} = + P_q - P_k` + + Args: + query_size (int): the length of query + key_size (int): the length of key + bucket_size (int): the size of position bucket + max_position (int): the maxium allowed absolute positoin + + Return: + :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size] + + """ + 
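To make the sign and shape conventions concrete, here is a small NumPy recomputation of the unbucketed case (``bucket_size <= 0``); it is a standalone sketch, not a call into the module:

    import numpy as np

    query_size, key_size = 3, 5
    q_ids = np.arange(query_size)
    k_ids = np.arange(key_size)
    # rel[i, j] = P_q(i) - P_k(j): negative when the key sits to the right of the query
    rel = q_ids[:, None] - k_ids[None, :]
    print(rel)
    # [[ 0 -1 -2 -3 -4]
    #  [ 1  0 -1 -2 -3]
    #  [ 2  1  0 -1 -2]]
    # build_relative_position additionally adds a leading batch axis, giving shape [1, 3, 5]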
q_ids = np.arange(0, query_size) + k_ids = np.arange(0, key_size) + rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1)) + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) + rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = rel_pos_ids.unsqueeze(0) + return rel_pos_ids + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) + + +class DisentangledSelfAttention(torch.nn.Module): + """ + Disentangled self-attention module + + Parameters: + config (:obj:`DebertaV2Config`): + A model config class instance with the configuration to build a new model. The schema is similar to + `BertConfig`, for more details, please refer :class:`~transformers.DebertaV2Config` + + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.num_attention_heads = config.num_attention_heads + _attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + + self.share_att_key = getattr(config, "share_att_key", False) + self.pos_att_type = config.pos_att_type + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.position_buckets = getattr(config, "position_buckets", -1) + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + + self.pos_dropout = StableDropout(config.hidden_dropout_prob) + + if not self.share_att_key: + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = StableDropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x, attention_heads): + new_x_shape = x.size()[:-1] + (attention_heads, -1) + x = x.view(*new_x_shape) + return 
x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) + + def forward( + self, + hidden_states, + attention_mask, + return_att=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + """ + Call the module + + Args: + hidden_states (:obj:`torch.FloatTensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + `Attention(Q,K,V)` + + attention_mask (:obj:`torch.ByteTensor`): + An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, `N` is the maximum + sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j` + th token. + + return_att (:obj:`bool`, optional): + Whether return the attention matrix. + + query_states (:obj:`torch.FloatTensor`, optional): + The `Q` state in `Attention(Q,K,V)`. + + relative_pos (:obj:`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [`B`, `N`, `N`] with + values ranging in [`-max_relative_positions`, `max_relative_positions`]. + + rel_embeddings (:obj:`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [:math:`2 \\times + \\text{max_relative_positions}`, `hidden_size`]. + + + """ + if query_states is None: + query_states = hidden_states + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + if "p2p" in self.pos_att_type: + scale_factor += 1 + scale = math.sqrt(query_layer.size(-1) * scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + attention_scores = attention_scores + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) + + # bxhxlxd + _attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(_attention_probs) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) + new_context_layer_shape = context_layer.size()[:-2] + (-1,) + context_layer = context_layer.view(*new_context_layer_shape) + if return_att: + return (context_layer, attention_probs) + else: + return context_layer + + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + q = query_layer.size(-2) + relative_pos = build_relative_position( + q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + 
relative_pos = relative_pos.unsqueeze(1) + # bxhxqxk + elif relative_pos.dim() != 4: + raise ValueError(f"Relative postion ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") + + att_span = self.pos_ebd_size + relative_pos = relative_pos.long().to(query_layer.device) + + rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :].unsqueeze(0) + if self.share_att_key: + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) + else: + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + + score = 0 + # content->position + if "c2p" in self.pos_att_type: + scale = math.sqrt(pos_key_layer.size(-1) * scale_factor) + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch.gather( + c2p_att, + dim=-1, + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), + ) + score += c2p_att / scale + + # position->content + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + scale = math.sqrt(pos_query_layer.size(-1) * scale_factor) + if key_layer.size(-2) != query_layer.size(-2): + r_pos = build_relative_position( + key_layer.size(-2), + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ).to(query_layer.device) + r_pos = r_pos.unsqueeze(0) + else: + r_pos = relative_pos + + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) + if query_layer.size(-2) != key_layer.size(-2): + pos_index = relative_pos[:, :, :, 0].unsqueeze(-1) + + if "p2c" in self.pos_att_type: + p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.gather( + p2c_att, + dim=-1, + index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), + ).transpose(-1, -2) + if query_layer.size(-2) != key_layer.size(-2): + p2c_att = torch.gather( + p2c_att, + dim=-2, + index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))), + ) + score += p2c_att / scale + + # position->position + if "p2p" in self.pos_att_type: + pos_query = pos_query_layer[:, :, att_span:, :] + p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2)) + p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:]) + if query_layer.size(-2) != key_layer.size(-2): + p2p_att = torch.gather( + p2p_att, + dim=-2, + index=pos_index.expand(query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))), + ) + p2p_att = torch.gather( + p2p_att, + dim=-1, + index=c2p_pos.expand( + [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)] + ), + ) + score += p2p_att + + return score + + +# Copied from 
transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm +class DebertaV2Embeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + pad_token_id = getattr(config, "pad_token_id", 0) + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id) + + self.position_biased_input = getattr(config, "position_biased_input", True) + if not self.position_biased_input: + self.position_embeddings = None + else: + self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size) + + if config.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size) + + if self.embedding_size != config.hidden_size: + self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + embeddings += position_embeddings + if self.config.type_vocab_size > 0: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings += token_type_embeddings + + if self.embedding_size != self.config.hidden_size: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2 +class DebertaV2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
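The embedding module above sums word, (optional) position and (optional) token-type embeddings, projects to ``hidden_size`` when ``embedding_size`` differs, applies LayerNorm, then masks out padding. A condensed, self-contained sketch of that flow; the sizes and variable names below are illustrative, not taken from any checkpoint:

    import torch
    from torch import nn

    vocab_size, embedding_size, hidden_size = 10, 8, 16
    word = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
    pos = nn.Embedding(32, embedding_size)
    proj = nn.Linear(embedding_size, hidden_size, bias=False)  # only needed because embedding_size != hidden_size
    norm = nn.LayerNorm(hidden_size, eps=1e-7)

    input_ids = torch.tensor([[5, 6, 7, 0]])                   # 0 is the padding id
    mask = (input_ids != 0).float()
    position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)

    x = word(input_ids) + pos(position_ids)                    # the position_biased_input=True case
    x = norm(proj(x)) * mask.unsqueeze(-1)                     # zero out padded positions, as in the module above
    print(x.shape)                                             # torch.Size([1, 4, 16])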
+ """ + + config_class = DebertaV2Config + base_model_prefix = "deberta" + _keys_to_ignore_on_load_missing = ["position_ids"] + _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +DEBERTA_START_DOCSTRING = r""" + The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention + `_ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build on top of + BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two + improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data. + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior.``` + + + Parameters: + config (:class:`~transformers.DebertaV2Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +DEBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`transformers.DebertaV2Tokenizer`. See + :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. 
+ output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 +class DebertaV2Model(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = DebertaV2Embeddings(config) + self.encoder = DebertaV2Encoder(config) + self.z_steps = 0 + self.config = config + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError("The prune function is not implemented in DeBERTa model.") + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + encoded_layers = encoder_outputs[1] + + if self.z_steps > 1: + hidden_states = encoded_layers[-2] + layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] + query_states = encoded_layers[-1] + rel_embeddings = self.encoder.get_rel_embedding() + attention_mask = self.encoder.get_attention_mask(attention_mask) + rel_pos = self.encoder.get_rel_pos(embedding_output) + for layer in layers[1:]: + query_states = layer( + 
hidden_states, + attention_mask, + return_att=False, + query_states=query_states, + relative_pos=rel_pos, + rel_embeddings=rel_embeddings, + ) + encoded_layers.append(query_states) + + sequence_output = encoded_layers[-1] + + if not return_dict: + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top. """, DEBERTA_START_DOCSTRING) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 +class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.deberta = DebertaV2Model(config) + self.cls = DebertaV2OnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. 
Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta +class DebertaV2PredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta +class DebertaV2LMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = DebertaV2PredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta +class DebertaV2OnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = DebertaV2LMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@add_start_docstrings( + """ + DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. 
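Before the classification head below, a small sketch of the masked-LM labeling convention used above: positions set to ``-100`` are skipped by ``CrossEntropyLoss`` (its default ``ignore_index``), so only the masked positions contribute to the loss. Shapes here are illustrative:

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 11
    prediction_scores = torch.randn(1, 4, vocab_size)     # (batch, seq_len, vocab)
    labels = torch.tensor([[-100, 7, -100, 2]])           # only positions 1 and 3 are scored
    loss_fct = CrossEntropyLoss()                         # ignore_index defaults to -100
    loss = loss_fct(prediction_scores.view(-1, vocab_size), labels.view(-1))
    print(loss)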
+ """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 +class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = torch.nn.Linear(output_dim, num_labels) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + self.init_weights() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # regression task + loss_fn = torch.nn.MSELoss() + logits = logits.view(-1).to(labels.dtype) + loss = loss_fn(logits, labels.view(-1)) + elif labels.dim() == 1 or labels.size(-1) == 1: + label_index = (labels >= 0).nonzero() + labels = labels.long() + if label_index.size(0) > 0: + labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1))) + labels = torch.gather(labels, 0, label_index.view(-1)) + loss_fct = CrossEntropyLoss() + loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) + else: + loss = torch.tensor(0).to(logits) + else: + log_softmax = torch.nn.LogSoftmax(-1) + loss = -((log_softmax(logits) * labels).sum(-1)).mean() + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + else: + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. 
for + Named-Entity-Recognition (NER) tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2 +class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
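The token-classification loss above only scores positions where ``attention_mask == 1``; a standalone sketch of that masking step (tensor sizes are illustrative):

    import torch
    from torch.nn import CrossEntropyLoss

    num_labels, seq_len = 3, 4
    logits = torch.randn(1, seq_len, num_labels)
    labels = torch.tensor([[0, 2, 1, 1]])
    attention_mask = torch.tensor([[1, 1, 1, 0]])          # last token is padding

    loss_fct = CrossEntropyLoss()
    active = attention_mask.view(-1) == 1
    active_labels = torch.where(active, labels.view(-1),
                                torch.tensor(loss_fct.ignore_index).type_as(labels))
    loss = loss_fct(logits.view(-1, num_labels), active_labels)
    print(loss)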
+ """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2 +class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/deberta_v2/tokenization_deberta_v2.py 
b/src/transformers/models/deberta_v2/tokenization_deberta_v2.py new file mode 100644 index 00000000000000..97b9c04150eea9 --- /dev/null +++ b/src/transformers/models/deberta_v2/tokenization_deberta_v2.py @@ -0,0 +1,491 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization class for model DeBERTa.""" + +import os +import unicodedata +from typing import Optional, Tuple + +import sentencepiece as sp +import six + +from ...tokenization_utils import PreTrainedTokenizer + + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "microsoft/deberta-xlarge-v2": "https://huggingface.co/microsoft/deberta-xlarge-v2/resolve/main/spm.model", + "microsoft/deberta-xxlarge-v2": "https://huggingface.co/microsoft/deberta-xxlarge-v2/resolve/main/spm.model", + "microsoft/deberta-xlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xlarge-v2-mnli/resolve/main/spm.model", + "microsoft/deberta-xxlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xxlarge-v2-mnli/resolve/main/spm.model", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "microsoft/deberta-xlarge-v2": 512, + "microsoft/deberta-xxlarge-v2": 512, + "microsoft/deberta-xlarge-v2-mnli": 512, + "microsoft/deberta-xxlarge-v2-mnli": 512, +} + +PRETRAINED_INIT_CONFIGURATION = { + "microsoft/deberta-xlarge-v2": {"do_lower_case": False}, + "microsoft/deberta-xxlarge-v2": {"do_lower_case": False}, + "microsoft/deberta-xlarge-v2-mnli": {"do_lower_case": False}, + "microsoft/deberta-xxlarge-v2-mnli": {"do_lower_case": False}, +} + +VOCAB_FILES_NAMES = {"vocab_file": "spm.model"} + + +class DebertaV2Tokenizer(PreTrainedTokenizer): + r""" + Constructs a DeBERTa-v2 tokenizer. Based on `SentencePiece `__. + + Args: + vocab_file (:obj:`str`): + `SentencePiece `__ file (generally has a `.spm` extension) that + contains the vocabulary necessary to instantiate a tokenizer. + do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to lowercase the input when tokenizing. + unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. 
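The tokenizer below is a thin wrapper around a SentencePiece model. A minimal sketch of the underlying calls it relies on; the local path ``spm.model`` is an assumption (you would first download the file from one of the checkpoints listed above):

    import sentencepiece as sp

    spm_model = sp.SentencePieceProcessor()
    spm_model.load("spm.model")                  # assumed local copy of a checkpoint's spm.model
    pieces = spm_model.encode_as_pieces("Hello world!")
    ids = [spm_model.PieceToId(p) for p in pieces]
    print(pieces, ids)                           # sub-word pieces and their vocabulary ids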
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__( + self, + vocab_file, + do_lower_case=False, + split_by_punct=False, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + **kwargs + ): + super().__init__( + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + split_by_punct=split_by_punct, + **kwargs, + ) + + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + "model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) + ) + self.do_lower_case = do_lower_case + self.split_by_punct = split_by_punct + self._tokenizer = SPMTokenizer(vocab_file, split_by_punct=split_by_punct) + + @property + def vocab_size(self): + return len(self.vocab) + + @property + def vocab(self): + return self._tokenizer.vocab + + def get_vocab(self): + vocab = self.vocab.copy() + vocab.update(self.get_added_vocab()) + return vocab + + def _tokenize(self, text): + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + if self.do_lower_case: + text = text.lower() + return self._tokenizer.tokenize(text) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self._tokenizer.spm.PieceToId(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self._tokenizer.spm.IdToPiece(index) if index < self.vocab_size else self.unk_token + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + return self._tokenizer.decode(tokens) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A DeBERTa sequence has the following format: + + - single sequence: [CLS] X [SEP] + - pair of sequences: [CLS] A [SEP] B [SEP] + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. 
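For the single-sequence and pair cases handled by ``build_inputs_with_special_tokens`` above, a tiny standalone sketch of the resulting layouts (the ids are placeholders, not real vocabulary ids):

    cls_id, sep_id = 1, 2                  # placeholder special-token ids
    token_ids_0 = [11, 12, 13]             # "sentence A" pieces
    token_ids_1 = [21, 22]                 # "sentence B" pieces

    single = [cls_id] + token_ids_0 + [sep_id]
    pair = [cls_id] + token_ids_0 + [sep_id] + token_ids_1 + [sep_id]
    print(single)   # [1, 11, 12, 13, 2]             -> [CLS] A [SEP]
    print(pair)     # [1, 11, 12, 13, 2, 21, 22, 2]  -> [CLS] A [SEP] B [SEP]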
+ token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + return list( + map( + lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, + token_ids_0, + ) + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa + sequence pair mask has the following format: + + :: + + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + + If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given + sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", False) + if is_split_into_words or add_prefix_space: + text = " " + text + return (text, kwargs) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + return self._tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix) + + +class SPMTokenizer: + def __init__(self, vocab_file, split_by_punct=False): + self.split_by_punct = split_by_punct + self.vocab_file = vocab_file + spm = sp.SentencePieceProcessor() + assert os.path.exists(vocab_file) + spm.load(vocab_file) + bpe_vocab_size = spm.GetPieceSize() + # Token map + # 0+1 + # 1+1 + # 2+1 + self.vocab = {spm.IdToPiece(i): i for i in range(bpe_vocab_size)} + self.id_to_tokens = [spm.IdToPiece(i) for i in range(bpe_vocab_size)] + # self.vocab['[PAD]'] = 0 + # self.vocab['[CLS]'] = 1 + # self.vocab['[SEP]'] = 2 + # self.vocab['[UNK]'] = 3 + + self.spm = spm + + def __getstate__(self): + state = self.__dict__.copy() + state["spm"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + self.spm = sp.SentencePieceProcessor() + self.spm.Load(self.vocab_file) + + def tokenize(self, text): + pieces = self._encode_as_pieces(text) + + def _norm(x): + if x not in self.vocab or x == "": + return "[UNK]" + else: + return x + + pieces = [_norm(p) for p in pieces] + return pieces + + def convert_ids_to_tokens(self, ids): + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def decode(self, tokens, start=-1, end=-1, raw_text=None): + if raw_text is None: + return self.spm.decode_pieces([t for t in tokens]) + else: + 
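+            # When raw_text is given, map each piece back to the source word it came from
+            # and return the span of raw_text words covered by tokens[start:end].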
words = self.split_to_words(raw_text) + word_tokens = [self.tokenize(w) for w in words] + token2words = [0] * len(tokens) + tid = 0 + for i, w in enumerate(word_tokens): + for k, t in enumerate(w): + token2words[tid] = i + tid += 1 + word_start = token2words[start] + word_end = token2words[end] if end < len(tokens) else len(words) + text = "".join(words[word_start:word_end]) + return text + + def add_special_token(self, token): + if token not in self.special_tokens: + self.special_tokens.append(token) + if token not in self.vocab: + self.vocab[token] = len(self.vocab) - 1 + self.id_to_tokens.append(token) + return self.id(token) + + def part_of_whole_word(self, token, is_bos=False): + if is_bos: + return True + if ( + len(token) == 1 + and (_is_whitespace(list(token)[0]) or _is_control(list(token)[0]) or _is_punctuation(list(token)[0])) + ) or token in self.special_tokens: + return False + + word_start = b"\xe2\x96\x81".decode("utf-8") + return not token.startswith(word_start) + + def pad(self): + return "[PAD]" + + def bos(self): + return "[CLS]" + + def eos(self): + return "[SEP]" + + def unk(self): + return "[UNK]" + + def mask(self): + return "[MASK]" + + def sym(self, id): + return self.ids_to_tokens[id] + + def id(self, sym): + return self.vocab[sym] if sym in self.vocab else 1 + + def _encode_as_pieces(self, text): + text = convert_to_unicode(text) + if self.split_by_punct: + words = self._run_split_on_punc(text) + pieces = [self.spm.encode_as_pieces(w) for w in words] + return [p for w in pieces for p in w] + else: + return self.spm.encode_as_pieces(text) + + def split_to_words(self, text): + pieces = self._encode_as_pieces(text) + word_start = b"\xe2\x96\x81".decode("utf-8") + words = [] + offset = 0 + prev_end = 0 + for i, p in enumerate(pieces): + if p.startswith(word_start): + if offset > prev_end: + words.append(text[prev_end:offset]) + prev_end = offset + w = p.replace(word_start, "") + else: + w = p + try: + s = text.index(w, offset) + pn = "" + k = i + 1 + while k < len(pieces): + pn = pieces[k].replace(word_start, "") + if len(pn) > 0: + break + k += 1 + + if len(pn) > 0 and pn in text[offset:s]: + offset = offset + 1 + else: + offset = s + len(w) + except Exception: + offset = offset + 1 + + if prev_end < offset: + words.append(text[prev_end:offset]) + + return words + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def save_pretrained(self, path: str, filename_prefix: str = None): + filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] + if filename_prefix is not None: + filename = filename_prefix + "-" + filename + full_path = os.path.join(path, filename) + with open(full_path, "wb") as fs: + fs.write(self.spm.serialized_model_proto()) + return (full_path,) + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as 
whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 63dede9b285450..3a8608a98dfd56 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -883,6 +883,63 @@ def from_pretrained(self, *args, **kwargs): requires_pytorch(self) +DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class DebertaV2ForMaskedLM: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class DebertaV2ForQuestionAnswering: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class DebertaV2ForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class DebertaV2ForTokenClassification: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class DebertaV2Model: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class DebertaV2PreTrainedModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/test_modeling_deberta_v2.py b/tests/test_modeling_deberta_v2.py new file mode 100644 index 00000000000000..a67bc6b88096f3 --- /dev/null +++ b/tests/test_modeling_deberta_v2.py @@ -0,0 +1,290 @@ +# coding=utf-8 +# Copyright 2018 Microsoft Authors and the HuggingFace Inc. team. 
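For reference, a minimal standalone sketch (standard library only) of the rule that ``_is_punctuation`` above implements: any non-letter/number ASCII character, for example ``^`` or ``$``, counts as punctuation even though Unicode does not place it in a "P" category::

    import unicodedata

    def is_punctuation(char):
        # Mirrors _is_punctuation above: non-letter/number ASCII is punctuation,
        # and so is anything in a Unicode "P" category.
        cp = ord(char)
        if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126:
            return True
        return unicodedata.category(char).startswith("P")

    print([c for c in "fal$é-test^" if is_punctuation(c)])  # ['$', '-', '^']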
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +import unittest + +import numpy as np + +from transformers import is_torch_available +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import ( + DebertaV2Config, + DebertaV2ForMaskedLM, + DebertaV2ForQuestionAnswering, + DebertaV2ForSequenceClassification, + DebertaV2ForTokenClassification, + DebertaV2Model, + ) + from transformers.models.deberta_v2.modeling_deberta_v2 import DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST + + +@require_torch +class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + ( + DebertaV2Model, + DebertaV2ForMaskedLM, + DebertaV2ForSequenceClassification, + DebertaV2ForTokenClassification, + DebertaV2ForQuestionAnswering, + ) + if is_torch_available() + else () + ) + + test_torchscript = False + test_pruning = False + test_head_masking = False + is_encoder_decoder = False + + class DebertaV2ModelTester(object): + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + relative_attention=False, + position_biased_input=True, + pos_att_type="None", + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.relative_attention = relative_attention + self.position_biased_input = position_biased_input + self.pos_att_type = pos_att_type + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + 
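+            # ids_tensor builds a random integer tensor of the requested shape with values in
+            # [0, vocab_size), so the ids, masks and labels prepared here are synthetic but correctly shaped.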
token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = DebertaV2Config( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + relative_attention=self.relative_attention, + position_biased_input=self.position_biased_input, + pos_att_type=self.pos_att_type, + ) + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def check_loss_output(self, result): + self.parent.assertListEqual(list(result.loss.size()), []) + + def create_and_check_deberta_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = DebertaV2Model(config=config) + model.to(torch_device) + model.eval() + sequence_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0] + sequence_output = model(input_ids, token_type_ids=token_type_ids)[0] + sequence_output = model(input_ids)[0] + + self.parent.assertListEqual( + list(sequence_output.size()), [self.batch_size, self.seq_length, self.hidden_size] + ) + + def create_and_check_deberta_for_masked_lm( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = DebertaV2ForMaskedLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_deberta_for_sequence_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = DebertaV2ForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) + self.parent.assertListEqual(list(result.logits.size()), [self.batch_size, self.num_labels]) + self.check_loss_output(result) + + def create_and_check_deberta_for_token_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = DebertaV2ForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + + def create_and_check_deberta_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = 
DebertaV2ForQuestionAnswering(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + start_positions=sequence_labels, + end_positions=sequence_labels, + ) + self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) + self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + def setUp(self): + self.model_tester = DebertaV2ModelTest.DebertaV2ModelTester(self) + self.config_tester = ConfigTester(self, config_class=DebertaV2Config, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_deberta_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deberta_model(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deberta_for_sequence_classification(*config_and_inputs) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deberta_for_masked_lm(*config_and_inputs) + + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deberta_for_question_answering(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deberta_for_token_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = DebertaV2Model.from_pretrained(model_name) + self.assertIsNotNone(model) + + +@require_torch +@require_sentencepiece +@require_tokenizers +class DebertaV2ModelIntegrationTest(unittest.TestCase): + @unittest.skip(reason="Model not available yet") + def test_inference_masked_lm(self): + pass + + @slow + def test_inference_no_head(self): + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + model = DebertaV2Model.from_pretrained("microsoft/deberta-xlarge-v2") + + input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + output = model(input_ids)[0] + # compare the actual values for a slice. + expected_slice = torch.tensor( + [[[-0.2913, 0.2647, 0.5627], [-0.4318, 0.1389, 0.3881], [-0.2929, -0.2489, 0.3452]]] + ) + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4), f"{output[:, :3, :3]}") diff --git a/tests/test_tokenization_deberta_v2.py b/tests/test_tokenization_deberta_v2.py new file mode 100644 index 00000000000000..ee0ef37a6228ec --- /dev/null +++ b/tests/test_tokenization_deberta_v2.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2019 Hugging Face inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import unittest + +from transformers import DebertaV2Tokenizer +from transformers.testing_utils import require_sentencepiece, require_tokenizers + +from .test_tokenization_common import TokenizerTesterMixin + + +SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/spiece.model") + + +@require_sentencepiece +@require_tokenizers +class DebertaV2TokenizationTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = DebertaV2Tokenizer + rust_tokenizer_class = None + test_rust_tokenizer = False + + def setUp(self): + super().setUp() + + # We have a SentencePiece fixture for testing + tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB) + tokenizer.save_pretrained(self.tmpdirname) + + def get_input_output_texts(self, tokenizer): + input_text = "this is a test" + output_text = "this is a test" + return input_text, output_text + + def test_rust_and_python_full_tokenizers(self): + if not self.test_rust_tokenizer: + return + + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer() + + sequence = "I was born in 92000, and this is falsé." + + tokens = tokenizer.tokenize(sequence) + rust_tokens = rust_tokenizer.tokenize(sequence) + self.assertListEqual(tokens, rust_tokens) + + ids = tokenizer.encode(sequence, add_special_tokens=False) + rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False) + self.assertListEqual(ids, rust_ids) + + rust_tokenizer = self.get_rust_tokenizer() + ids = tokenizer.encode(sequence) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + + def test_full_tokenizer(self): + tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB, keep_accents=True) + + tokens = tokenizer.tokenize("This is a test") + self.assertListEqual(tokens, ["▁", "[UNK]", "his", "▁is", "▁a", "▁test"]) + + self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [13, 1, 4398, 25, 21, 1289]) + + tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") + self.assertListEqual( + tokens, + [ + "▁", + "[UNK]", + "▁was", + "▁born", + "▁in", + "▁9", + "2000", + ",", + "▁and", + "▁this", + "▁is", + "▁fal", + "s", + "[UNK]", + ".", + ], + ) + ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual(ids, [13, 1, 23, 386, 19, 561, 3050, 15, 17, 48, 25, 8256, 18, 1, 9]) + + back_tokens = tokenizer.convert_ids_to_tokens(ids) + self.assertListEqual( + back_tokens, + [ + "▁", + "", + "▁was", + "▁born", + "▁in", + "▁9", + "2000", + ",", + "▁and", + "▁this", + "▁is", + "▁fal", + "s", + "", + ".", + ], + ) + + def test_sequence_builders(self): + tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB) + + text = tokenizer.encode("sequence builders") + text_2 = tokenizer.encode("multi-sequence build") + + encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) + + assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + text_2 + [ + tokenizer.sep_token_id + ] + + def test_tokenizer_integration(self): + 
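+        # Encode several sequence pairs with the pretrained vocabulary, then compare the padded
+        # input_ids / token_type_ids / attention_mask and the decoded text against fixed expectations.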
tokenizer_classes = [self.tokenizer_class] + if self.test_rust_tokenizer: + tokenizer_classes.append(self.rust_tokenizer_class) + + for tokenizer_class in tokenizer_classes: + tokenizer = tokenizer_class.from_pretrained("microsoft/deberta-xlarge-v2") + + sequences = [ + [ + "DeBERTa: Decoding-enhanced BERT with Disentangled Attention", + "DeBERTa: Decoding-enhanced BERT with Disentangled Attention", + ], + [ + "Recent progress in pre-trained neural language models has significantly improved the performance of many natural language processing (NLP) tasks.", + "DeBERTa: Decoding-enhanced BERT with Disentangled Attention", + ], + [ + "In this paper we propose a new model architecture DeBERTa", + "DeBERTa: Decoding-enhanced BERT with Disentangled Attention", + ], + ] + + encoding = tokenizer(sequences, padding=True) + decoded_sequences = [tokenizer.decode(seq, skip_special_tokens=True) for seq in encoding["input_ids"]] + + # fmt: off + expected_encoding = { + 'input_ids': [ + [1, 1804, 69418, 191, 43, 117056, 18, 44596, 448, 37132, 19, 8655, 10625, 69860, 21149, 2, 1804, 69418, 191, 43, 117056, 18, 44596, 448, 37132, 19, 8655, 10625, 69860, 21149, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 9755, 1944, 11, 1053, 18, 16899, 12730, 1072, 1506, 45, 2497, 2510, 5, 610, 9, 127, 699, 1072, 2101, 36, 99388, 53, 2930, 4, 2, 1804, 69418, 191, 43, 117056, 18, 44596, 448, 37132, 19, 8655, 10625, 69860, 21149, 2], + [1, 84, 32, 778, 42, 9441, 10, 94, 735, 3372, 1804, 69418, 191, 2, 1804, 69418, 191, 43, 117056, 18, 44596, 448, 37132, 19, 8655, 10625, 69860, 21149, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + 'token_type_ids': [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + 'attention_mask': [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ] + } + + expected_decoded_sequences = [ + 'DeBERTa: Decoding-enhanced BERT with Disentangled Attention DeBERTa: Decoding-enhanced BERT with Disentangled Attention', + 'Recent progress in pre-trained neural language models has significantly improved the performance of many natural language processing (NLP) tasks. 
DeBERTa: Decoding-enhanced BERT with Disentangled Attention', + 'In this paper we propose a new model architecture DeBERTa DeBERTa: Decoding-enhanced BERT with Disentangled Attention' + ] + # fmt: on + + self.assertDictEqual(encoding.data, expected_encoding) + + for expected, decoded in zip(expected_decoded_sequences, decoded_sequences): + self.assertEqual(expected, decoded) From c4f57cae69195c5170e1f46dd3cc84ed0b9067a9 Mon Sep 17 00:00:00 2001 From: Pengcheng He <38195654+BigBird01@users.noreply.github.com> Date: Mon, 15 Feb 2021 02:13:04 -0800 Subject: [PATCH 3/8] Fix v2 model loading issue (#10129) --- .../models/auto/configuration_auto.py | 5 ++-- .../models/deberta/modeling_deberta.py | 16 ++++++++++++ .../deberta_v2/configuration_deberta_v2.py | 10 +++---- .../models/deberta_v2/modeling_deberta_v2.py | 26 +++++++++++++++---- .../deberta_v2/tokenization_deberta_v2.py | 24 ++++++++--------- tests/test_modeling_deberta_v2.py | 2 +- 6 files changed, 57 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index afb02fc36aaa0d..a5dcff3b3aab7c 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -140,7 +140,7 @@ ("reformer", ReformerConfig), ("longformer", LongformerConfig), ("roberta", RobertaConfig), - ("deberta_v2", DebertaV2Config), + ("deberta-v2", DebertaV2Config), ("deberta", DebertaConfig), ("flaubert", FlaubertConfig), ("fsmt", FSMTConfig), @@ -202,8 +202,8 @@ ("encoder-decoder", "Encoder decoder"), ("funnel", "Funnel Transformer"), ("lxmert", "LXMERT"), + ("deberta-v2", "DeBERTa-v2"), ("deberta", "DeBERTa"), - ("deberta_v2", "DeBERTa-v2"), ("layoutlm", "LayoutLM"), ("dpr", "DPR"), ("rag", "RAG"), @@ -370,7 +370,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): {'foo': False} """ config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) - if "model_type" in config_dict: config_class = CONFIG_MAPPING[config_dict["model_type"]] return config_class.from_dict(config_dict, **kwargs) diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 60b9546379a2f1..c3f48b4d60a006 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -762,6 +762,10 @@ class DebertaPreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_missing = ["position_ids"] _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + def __init__(self, config): + super().__init__(config) + self._register_load_state_dict_pre_hook(self._pre_load_hook) + def _init_weights(self, module): """ Initialize the weights """ if isinstance(module, (nn.Linear, nn.Embedding)): @@ -771,6 +775,18 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + self_state = self.state_dict() + if ( + ("classifier.weight" in self_state) + and ("classifier.weight" in state_dict) + and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size() + ): + logger.warning("Ignore mismatched classifer head.") + del state_dict["classifier.weight"] + if "classifier.bias" in state_dict: + del state_dict["classifier.bias"] + DEBERTA_START_DOCSTRING = r""" The DeBERTa model was proposed in `DeBERTa: 
Decoding-enhanced BERT with Disentangled Attention diff --git a/src/transformers/models/deberta_v2/configuration_deberta_v2.py b/src/transformers/models/deberta_v2/configuration_deberta_v2.py index 128a8701352266..87f439e5b50aae 100644 --- a/src/transformers/models/deberta_v2/configuration_deberta_v2.py +++ b/src/transformers/models/deberta_v2/configuration_deberta_v2.py @@ -21,10 +21,10 @@ logger = logging.get_logger(__name__) DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "microsoft/deberta-xlarge-v2": "https://huggingface.co/microsoft/deberta-xlarge-v2/resolve/main/config.json", - "microsoft/deberta-xxlarge-v2": "https://huggingface.co/microsoft/deberta-xxlarge-v2/resolve/main/config.json", - "microsoft/deberta-xlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xlarge-v2-mnli/resolve/main/config.json", - "microsoft/deberta-xxlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xxlarge-v2-mnli/resolve/main/config.json", + "microsoft/deberta-v2-xlarge": "https://huggingface.co/microsoft/deberta-v2-xlarge/resolve/main/config.json", + "microsoft/deberta-v2-xxlarge": "https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/config.json", + "microsoft/deberta-v2-xlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/config.json", + "microsoft/deberta-v2-xxlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/config.json", } @@ -83,7 +83,7 @@ class DebertaV2Config(PretrainedConfig): layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): The epsilon used by the layer normalization layers. """ - model_type = "deberta" + model_type = "deberta-v2" def __init__( self, diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 834b542ba505cf..93709dd0dc0206 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -40,13 +40,13 @@ _CONFIG_FOR_DOC = "DebertaV2Config" _TOKENIZER_FOR_DOC = "DebertaV2Tokenizer" -_CHECKPOINT_FOR_DOC = "microsoft/deberta-xlarge-v2" +_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge" DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "microsoft/deberta-xlarge-v2", - "microsoft/deberta-xxlarge-v2", - "microsoft/deberta-xlarge-v2-mnli", - "microsoft/deberta-xxlarge-v2-mnli", + "microsoft/deberta-v2-xlarge", + "microsoft/deberta-v2-xxlarge", + "microsoft/deberta-v2-xlarge-mnli", + "microsoft/deberta-v2-xxlarge-mnli", ] @@ -901,6 +901,10 @@ class DebertaV2PreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_missing = ["position_ids"] _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + def __init__(self, config): + super().__init__(config) + self._register_load_state_dict_pre_hook(self._pre_load_hook) + def _init_weights(self, module): """ Initialize the weights """ if isinstance(module, (nn.Linear, nn.Embedding)): @@ -910,6 +914,18 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + self_state = self.state_dict() + if ( + ("classifier.weight" in self_state) + and ("classifier.weight" in state_dict) + and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size() + ): + logger.warning("Ignore mismatched classifer head.") + del state_dict["classifier.weight"] + if "classifier.bias" in state_dict: + del 
state_dict["classifier.bias"] + DEBERTA_START_DOCSTRING = r""" The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention diff --git a/src/transformers/models/deberta_v2/tokenization_deberta_v2.py b/src/transformers/models/deberta_v2/tokenization_deberta_v2.py index 97b9c04150eea9..564705fe5264f2 100644 --- a/src/transformers/models/deberta_v2/tokenization_deberta_v2.py +++ b/src/transformers/models/deberta_v2/tokenization_deberta_v2.py @@ -26,25 +26,25 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "microsoft/deberta-xlarge-v2": "https://huggingface.co/microsoft/deberta-xlarge-v2/resolve/main/spm.model", - "microsoft/deberta-xxlarge-v2": "https://huggingface.co/microsoft/deberta-xxlarge-v2/resolve/main/spm.model", - "microsoft/deberta-xlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xlarge-v2-mnli/resolve/main/spm.model", - "microsoft/deberta-xxlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xxlarge-v2-mnli/resolve/main/spm.model", + "microsoft/deberta-v2-xlarge": "https://huggingface.co/microsoft/deberta-v2-xlarge/resolve/main/spm.model", + "microsoft/deberta-v2-xxlarge": "https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/spm.model", + "microsoft/deberta-v2-xlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/spm.model", + "microsoft/deberta-v2-xxlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/spm.model", } } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "microsoft/deberta-xlarge-v2": 512, - "microsoft/deberta-xxlarge-v2": 512, - "microsoft/deberta-xlarge-v2-mnli": 512, - "microsoft/deberta-xxlarge-v2-mnli": 512, + "microsoft/deberta-v2-xlarge": 512, + "microsoft/deberta-v2-xxlarge": 512, + "microsoft/deberta-v2-xlarge-mnli": 512, + "microsoft/deberta-v2-xxlarge-mnli": 512, } PRETRAINED_INIT_CONFIGURATION = { - "microsoft/deberta-xlarge-v2": {"do_lower_case": False}, - "microsoft/deberta-xxlarge-v2": {"do_lower_case": False}, - "microsoft/deberta-xlarge-v2-mnli": {"do_lower_case": False}, - "microsoft/deberta-xxlarge-v2-mnli": {"do_lower_case": False}, + "microsoft/deberta-v2-xlarge": {"do_lower_case": False}, + "microsoft/deberta-v2-xxlarge": {"do_lower_case": False}, + "microsoft/deberta-v2-xlarge-mnli": {"do_lower_case": False}, + "microsoft/deberta-v2-xxlarge-mnli": {"do_lower_case": False}, } VOCAB_FILES_NAMES = {"vocab_file": "spm.model"} diff --git a/tests/test_modeling_deberta_v2.py b/tests/test_modeling_deberta_v2.py index a67bc6b88096f3..1f183aa6ec3f1b 100644 --- a/tests/test_modeling_deberta_v2.py +++ b/tests/test_modeling_deberta_v2.py @@ -279,7 +279,7 @@ def test_inference_no_head(self): np.random.seed(0) torch.manual_seed(0) torch.cuda.manual_seed_all(0) - model = DebertaV2Model.from_pretrained("microsoft/deberta-xlarge-v2") + model = DebertaV2Model.from_pretrained("microsoft/deberta-v2-xlarge") input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) output = model(input_ids)[0] From 6d15b921ba25755b2ecdf8beac21c17efe1e4a67 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 19 Feb 2021 16:33:53 -0500 Subject: [PATCH 4/8] Doc members --- docs/source/model_doc/deberta.rst | 10 +++++----- docs/source/model_doc/deberta_v2.rst | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/source/model_doc/deberta.rst b/docs/source/model_doc/deberta.rst index fac01ce7edb7e0..027e5f9165ad9c 100644 --- a/docs/source/model_doc/deberta.rst +++ b/docs/source/model_doc/deberta.rst @@ -60,7 
+60,7 @@ DebertaModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.DebertaModel - :members: + :members: forward DebertaPreTrainedModel @@ -74,25 +74,25 @@ DebertaForMaskedLM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.DebertaForMaskedLM - :members: + :members: forward DebertaForSequenceClassification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.DebertaForSequenceClassification - :members: + :members: forward DebertaForTokenClassification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.DebertaForTokenClassification - :members: + :members: forward DebertaForQuestionAnswering ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.DebertaForQuestionAnswering - :members: + :members: forward diff --git a/docs/source/model_doc/deberta_v2.rst b/docs/source/model_doc/deberta_v2.rst index 06d57aa4e8439a..45eadb4d4d7a6b 100644 --- a/docs/source/model_doc/deberta_v2.rst +++ b/docs/source/model_doc/deberta_v2.rst @@ -80,39 +80,39 @@ DebertaV2Model ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.DebertaV2Model - :members: + :members: forward DebertaV2PreTrainedModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.DebertaV2PreTrainedModel - :members: + :members: forward DebertaV2ForMaskedLM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.DebertaV2ForMaskedLM - :members: + :members: forward DebertaV2ForSequenceClassification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.DebertaV2ForSequenceClassification - :members: + :members: forward DebertaV2ForTokenClassification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.DebertaV2ForTokenClassification - :members: + :members: forward DebertaV2ForQuestionAnswering ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
autoclass:: transformers.DebertaV2ForQuestionAnswering - :members: + :members: forward From dd5a8126a1149aae86cfa14fa8dfc4f1aff578ac Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Fri, 19 Feb 2021 22:34:21 +0100 Subject: [PATCH 5/8] Update src/transformers/models/deberta/modeling_deberta.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/deberta/modeling_deberta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index c3f48b4d60a006..4d6a9bd4d105f2 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -782,7 +782,7 @@ def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_key and ("classifier.weight" in state_dict) and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size() ): - logger.warning("Ignore mismatched classifer head.") + logger.warning(f"The checkpoint classifier head has a shape {state_dict["classifier.weight"].size()} and this model classifer head has a shape {self_state["classifier.weight"].size()}. Ignoring the checkpoint weights. You should train your model on new data.") del state_dict["classifier.weight"] if "classifier.bias" in state_dict: del state_dict["classifier.bias"] From 0a11c2c533440a824b645116333d6bb2936b2798 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 19 Feb 2021 16:46:52 -0500 Subject: [PATCH 6/8] Address Sylvain's comments --- .../deberta_v2/configuration_deberta_v2.py | 32 ++++++++-------- .../models/deberta_v2/modeling_deberta_v2.py | 7 ++-- tests/test_tokenization_deberta_v2.py | 38 ++----------------- 3 files changed, 24 insertions(+), 53 deletions(-) diff --git a/src/transformers/models/deberta_v2/configuration_deberta_v2.py b/src/transformers/models/deberta_v2/configuration_deberta_v2.py index 87f439e5b50aae..9870979fb8401a 100644 --- a/src/transformers/models/deberta_v2/configuration_deberta_v2.py +++ b/src/transformers/models/deberta_v2/configuration_deberta_v2.py @@ -33,22 +33,22 @@ class DebertaV2Config(PretrainedConfig): This is the configuration class to store the configuration of a :class:`~transformers.DebertaV2Model`. It is used to instantiate a DeBERTa-v2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the DeBERTa - `microsoft/deberta-base `__ architecture. + `microsoft/deberta-v2-xlarge `__ architecture. Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. Arguments: - vocab_size (:obj:`int`, `optional`, defaults to 30522): + vocab_size (:obj:`int`, `optional`, defaults to 128100): Vocabulary size of the DeBERTa-v2 model. Defines the number of different tokens that can be represented by the :obj:`inputs_ids` passed when calling :class:`~transformers.DebertaV2Model`. - hidden_size (:obj:`int`, `optional`, defaults to 768): + hidden_size (:obj:`int`, `optional`, defaults to 1536): Dimensionality of the encoder layers and the pooler layer. - num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + num_hidden_layers (:obj:`int`, `optional`, defaults to 24): Number of hidden layers in the Transformer encoder. 
- num_attention_heads (:obj:`int`, `optional`, defaults to 12): + num_attention_heads (:obj:`int`, `optional`, defaults to 24): Number of attention heads for each attention layer in the Transformer encoder. - intermediate_size (:obj:`int`, `optional`, defaults to 3072): + intermediate_size (:obj:`int`, `optional`, defaults to 6144): Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): The non-linear activation function (function or string) in the encoder and pooler. If string, @@ -61,21 +61,21 @@ class DebertaV2Config(PretrainedConfig): max_position_embeddings (:obj:`int`, `optional`, defaults to 512): The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). - type_vocab_size (:obj:`int`, `optional`, defaults to 2): + type_vocab_size (:obj:`int`, `optional`, defaults to 0): The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.DebertaModel` or :class:`~transformers.TFDebertaModel`. initializer_range (:obj:`float`, `optional`, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-7): The epsilon used by the layer normalization layers. - relative_attention (:obj:`bool`, `optional`, defaults to :obj:`False`): + relative_attention (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether use relative position encoding. - max_relative_positions (:obj:`int`, `optional`, defaults to 1): + max_relative_positions (:obj:`int`, `optional`, defaults to -1): The range of relative positions :obj:`[-max_position_embeddings, max_position_embeddings]`. Use the same value as :obj:`max_position_embeddings`. pad_token_id (:obj:`int`, `optional`, defaults to 0): The value used to pad input_ids. - position_biased_input (:obj:`bool`, `optional`, defaults to :obj:`True`): + position_biased_input (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether add absolute position embedding to content embedding. pos_att_type (:obj:`List[str]`, `optional`): The type of relative position attention, it can be a combination of :obj:`["p2c", "c2p", "p2p"]`, e.g. @@ -87,11 +87,11 @@ class DebertaV2Config(PretrainedConfig): def __init__( self, - vocab_size=50265, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, + vocab_size=128100, + hidden_size=1536, + num_hidden_layers=24, + num_attention_heads=24, + intermediate_size=6144, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 93709dd0dc0206..f7ab8deecb82fb 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -223,9 +223,10 @@ def MaskedLayerNorm(layerNorm, input, mask=None): module. Args: - layernorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function - input (:obj:`torch.tensor`): The input tensor - mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicate the output of that element will be ignored, i.e. 
set to `0` + layernorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function + input (:obj:`torch.tensor`): The input tensor + mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicates the + output of that element will be ignored, i.e. set to `0` Example:: diff --git a/tests/test_tokenization_deberta_v2.py b/tests/test_tokenization_deberta_v2.py index ee0ef37a6228ec..2fdf74d003c49e 100644 --- a/tests/test_tokenization_deberta_v2.py +++ b/tests/test_tokenization_deberta_v2.py @@ -77,25 +77,10 @@ def test_full_tokenizer(self): self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [13, 1, 4398, 25, 21, 1289]) tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") + # fmt: off self.assertListEqual( tokens, - [ - "▁", - "[UNK]", - "▁was", - "▁born", - "▁in", - "▁9", - "2000", - ",", - "▁and", - "▁this", - "▁is", - "▁fal", - "s", - "[UNK]", - ".", - ], + ["▁", "[UNK]", "▁was", "▁born", "▁in", "▁9", "2000", ",", "▁and", "▁this", "▁is", "▁fal", "s", "[UNK]", "."], ) ids = tokenizer.convert_tokens_to_ids(tokens) self.assertListEqual(ids, [13, 1, 23, 386, 19, 561, 3050, 15, 17, 48, 25, 8256, 18, 1, 9]) @@ -103,24 +88,9 @@ def test_full_tokenizer(self): back_tokens = tokenizer.convert_ids_to_tokens(ids) self.assertListEqual( back_tokens, - [ - "▁", - "", - "▁was", - "▁born", - "▁in", - "▁9", - "2000", - ",", - "▁and", - "▁this", - "▁is", - "▁fal", - "s", - "", - ".", - ], + ["▁", "", "▁was", "▁born", "▁in", "▁9", "2000", ",", "▁and", "▁this", "▁is", "▁fal", "s", "", "."], ) + # fmt: on def test_sequence_builders(self): tokenizer = DebertaV2Tokenizer(SAMPLE_VOCAB) From 8fae0216e65309dbd632b6b97118282f5b593707 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 19 Feb 2021 18:00:03 -0500 Subject: [PATCH 7/8] Address Patrick's comments Co-authored-by: Patrick von Platen --- .../models/deberta/modeling_deberta.py | 17 +++-- .../models/deberta_v2/modeling_deberta_v2.py | 76 ++++++++----------- .../deberta_v2/tokenization_deberta_v2.py | 2 +- 3 files changed, 44 insertions(+), 51 deletions(-) diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 4d6a9bd4d105f2..0f90ffa7b02fac 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -360,9 +360,7 @@ def __init__(self, config): self.max_relative_positions = config.max_position_embeddings self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size) - def get_rel_embedding(self): - rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None - return rel_embeddings + self.rel_embeddings_weights = self.rel_embeddings.weight if self.relative_attention else None def get_attention_mask(self, attention_mask): if attention_mask.dim() <= 2: @@ -400,7 +398,7 @@ def forward( next_kv = hidden_states[0] else: next_kv = hidden_states - rel_embeddings = self.get_rel_embedding() + rel_embeddings = self.rel_embeddings_weights for i, layer_module in enumerate(self.layer): if output_hidden_states: @@ -776,13 +774,20 @@ def _init_weights(self, module): module.bias.data.zero_() def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + """ + Removes the classifier if it doesn't have the correct number of labels. 
+ """ self_state = self.state_dict() if ( ("classifier.weight" in self_state) and ("classifier.weight" in state_dict) and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size() ): - logger.warning(f"The checkpoint classifier head has a shape {state_dict["classifier.weight"].size()} and this model classifer head has a shape {self_state["classifier.weight"].size()}. Ignoring the checkpoint weights. You should train your model on new data.") + logger.warning( + f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model " + f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint " + f"weights. You should train your model on new data." + ) del state_dict["classifier.weight"] if "classifier.bias" in state_dict: del state_dict["classifier.bias"] @@ -939,7 +944,7 @@ def forward( hidden_states = encoded_layers[-2] layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] query_states = encoded_layers[-1] - rel_embeddings = self.encoder.get_rel_embedding() + rel_embeddings = self.encoder.rel_embeddings_weights attention_mask = self.encoder.get_attention_mask(attention_mask) rel_pos = self.encoder.get_rel_pos(embedding_output) for layer in layers[1:]: diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index f7ab8deecb82fb..a24b359ea3f709 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -217,37 +217,6 @@ def get_context(self): return self.drop_prob -def MaskedLayerNorm(layerNorm, input, mask=None): - """ - Masked LayerNorm which will apply mask over the output of LayerNorm to avoid inaccurate updates to the LayerNorm - module. - - Args: - layernorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function - input (:obj:`torch.tensor`): The input tensor - mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicates the - output of that element will be ignored, i.e. 
set to `0` - - Example:: - - # Create a tensor b x n x d - x = torch.randn([1,10,100]) - m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int) - LayerNorm = DeBERTa.deberta.LayerNorm(100) - y = MaskedLayerNorm(LayerNorm, x, m) - - """ - output = layerNorm(input).to(input) - if mask is None: - return output - if mask.dim() != input.dim(): - if mask.dim() == 4: - mask = mask.squeeze(1).squeeze(1) - mask = mask.unsqueeze(2) - mask = mask.to(output.dtype) - return output * mask - - # Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm class DebertaV2SelfOutput(nn.Module): def __init__(self, config): @@ -385,7 +354,20 @@ def forward(self, hidden_states, residual_states, input_mask): rmask = (1 - input_mask).bool() out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) out = ACT2FN[self.conv_act](self.dropout(out)) - output_states = MaskedLayerNorm(self.LayerNorm, residual_states + out, input_mask) + + layer_norm_input = residual_states + out + output = self.LayerNorm(layer_norm_input).to(layer_norm_input) + + if input_mask is None: + output_states = output + else: + if input_mask.dim() != layer_norm_input.dim(): + if input_mask.dim() == 4: + input_mask = input_mask.squeeze(1).squeeze(1) + input_mask = input_mask.unsqueeze(2) + + input_mask = input_mask.to(output.dtype) + output_states = output * input_mask return output_states @@ -417,11 +399,10 @@ def __init__(self, config): if "layer_norm" in self.norm_rel_ebd: self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) - kernel_size = getattr(config, "conv_kernel_size", 0) - self.with_conv = False - if kernel_size > 0: - self.with_conv = True - self.conv = ConvLayer(config) + self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None + self.rel_embeddings_weights = self.rel_embeddings.weight if self.relative_attention else None + if self.rel_embeddings_weights is not None and ("layer_norm" in self.norm_rel_ebd): + self.rel_embeddings_weights = nn.Parameter(self.LayerNorm(self.rel_embeddings_weights)) def get_rel_embedding(self): rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None @@ -488,7 +469,7 @@ def forward( if output_attentions: output_states, att_m = output_states - if i == 0 and self.with_conv: + if i == 0 and self.conv is not None: output_states = self.conv(hidden_states, output_states, input_mask) if query_states is not None: @@ -586,7 +567,7 @@ def __init__(self, config): "heads (%d)" % (config.hidden_size, config.num_attention_heads) ) self.num_attention_heads = config.num_attention_heads - _attention_head_size = int(config.hidden_size / config.num_attention_heads) + _attention_head_size = config.hidden_size // config.num_attention_heads self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) @@ -689,9 +670,9 @@ def forward( -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) ) - # bxhxlxd - _attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) - attention_probs = self.dropout(_attention_probs) + # bsz x height x length x dimension + attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(attention_probs) context_layer = torch.bmm( attention_probs.view(-1, attention_probs.size(-2), 
attention_probs.size(-1)), value_layer ) @@ -717,7 +698,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) elif relative_pos.dim() == 3: relative_pos = relative_pos.unsqueeze(1) - # bxhxqxk + # bsz x height x query x key elif relative_pos.dim() != 4: raise ValueError(f"Relative postion ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") @@ -916,13 +897,20 @@ def _init_weights(self, module): module.bias.data.zero_() def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + """ + Removes the classifier if it doesn't have the correct number of labels. + """ self_state = self.state_dict() if ( ("classifier.weight" in self_state) and ("classifier.weight" in state_dict) and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size() ): - logger.warning("Ignore mismatched classifer head.") + logger.warning( + f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model " + f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint " + f"weights. You should train your model on new data." + ) del state_dict["classifier.weight"] if "classifier.bias" in state_dict: del state_dict["classifier.bias"] diff --git a/src/transformers/models/deberta_v2/tokenization_deberta_v2.py b/src/transformers/models/deberta_v2/tokenization_deberta_v2.py index 564705fe5264f2..c7edc10111ac53 100644 --- a/src/transformers/models/deberta_v2/tokenization_deberta_v2.py +++ b/src/transformers/models/deberta_v2/tokenization_deberta_v2.py @@ -108,7 +108,7 @@ def __init__( if not os.path.isfile(vocab_file): raise ValueError( "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " - "model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) + "model use `tokenizer = DebertaV2Tokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) ) self.do_lower_case = do_lower_case self.split_by_punct = split_by_punct From c9c557ea570475064fb5305209bf270ae980689c Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 19 Feb 2021 18:02:13 -0500 Subject: [PATCH 8/8] Style --- src/transformers/models/deberta/modeling_deberta.py | 8 +++++--- src/transformers/models/deberta_v2/modeling_deberta_v2.py | 3 --- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 0f90ffa7b02fac..d83a5959467c00 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -360,7 +360,9 @@ def __init__(self, config): self.max_relative_positions = config.max_position_embeddings self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size) - self.rel_embeddings_weights = self.rel_embeddings.weight if self.relative_attention else None + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + return rel_embeddings def get_attention_mask(self, attention_mask): if attention_mask.dim() <= 2: @@ -398,7 +400,7 @@ def forward( next_kv = hidden_states[0] else: next_kv = hidden_states - rel_embeddings = self.rel_embeddings_weights + rel_embeddings = self.get_rel_embedding() for i, layer_module in enumerate(self.layer): if output_hidden_states: @@ -944,7 +946,7 @@ def forward( hidden_states = encoded_layers[-2] layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] query_states = encoded_layers[-1] - rel_embeddings = self.encoder.rel_embeddings_weights + rel_embeddings = self.encoder.get_rel_embedding() attention_mask = self.encoder.get_attention_mask(attention_mask) rel_pos = self.encoder.get_rel_pos(embedding_output) for layer in layers[1:]: diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index a24b359ea3f709..29f495b4816cb3 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -400,9 +400,6 @@ def __init__(self, config): self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None - self.rel_embeddings_weights = self.rel_embeddings.weight if self.relative_attention else None - if self.rel_embeddings_weights is not None and ("layer_norm" in self.norm_rel_ebd): - self.rel_embeddings_weights = nn.Parameter(self.LayerNorm(self.rel_embeddings_weights)) def get_rel_embedding(self): rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
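A short illustrative sketch, assuming a ``transformers`` build that already contains these commits, of the DeBERTa-v2 configuration defaults documented in the diffs above (only values that appear explicitly there are asserted)::

    from transformers import DebertaV2Config, DebertaV2Model

    # Defaults mirror the deberta-v2-xlarge architecture described in configuration_deberta_v2.py.
    config = DebertaV2Config()
    assert config.model_type == "deberta-v2"
    assert (config.vocab_size, config.hidden_size) == (128100, 1536)
    assert (config.num_hidden_layers, config.num_attention_heads, config.intermediate_size) == (24, 24, 6144)

    # A scaled-down configuration is convenient for smoke tests; pretrained weights would
    # instead be loaded with DebertaV2Model.from_pretrained("microsoft/deberta-v2-xlarge").
    tiny = DebertaV2Config(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
    model = DebertaV2Model(tiny)
    print(model.config.hidden_size)  # 128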