From 26c07e15436e1cad08ea3e8abd3dc24ba3198bd3 Mon Sep 17 00:00:00 2001 From: Pengcheng He Date: Thu, 4 Feb 2021 18:45:26 -0500 Subject: [PATCH] Integrate DeBERTa v2(the 1.5B model surpassed human performance on SuperGLUE); Add DeBERTa v2 900M,1.5B models; --- src/transformers/models/deberta/__init__.py | 1 + .../models/deberta/configuration_deberta.py | 8 + .../models/deberta/gpt2_tokenizer.py | 380 ++++++++++++++ .../models/deberta/modeling_deberta.py | 377 ++++++++++---- .../models/deberta/spm_tokenizer.py | 277 ++++++++++ .../models/deberta/tokenization_deberta.py | 493 ++---------------- 6 files changed, 976 insertions(+), 560 deletions(-) create mode 100644 src/transformers/models/deberta/gpt2_tokenizer.py create mode 100644 src/transformers/models/deberta/spm_tokenizer.py diff --git a/src/transformers/models/deberta/__init__.py b/src/transformers/models/deberta/__init__.py index 2a489b124033cd..7d1ebe77510ce0 100644 --- a/src/transformers/models/deberta/__init__.py +++ b/src/transformers/models/deberta/__init__.py @@ -24,6 +24,7 @@ _import_structure = { "configuration_deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig"], "tokenization_deberta": ["DebertaTokenizer"], + "tokenization_debertav2": ["DebertaTokenizerV2"], } if is_torch_available(): diff --git a/src/transformers/models/deberta/configuration_deberta.py b/src/transformers/models/deberta/configuration_deberta.py index 25dd39cade87d4..e2f92abe1a54ad 100644 --- a/src/transformers/models/deberta/configuration_deberta.py +++ b/src/transformers/models/deberta/configuration_deberta.py @@ -23,6 +23,14 @@ DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { "microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/config.json", "microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/config.json", + "microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/config.json", + "microsoft/deberta-xlarge-v2": 
"https://huggingface.co/microsoft/deberta-xlarge-v2/resolve/main/config.json", + "microsoft/deberta-xxlarge-v2": "https://huggingface.co/microsoft/deberta-xxlarge-v2/resolve/main/config.json", + "microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/config.json", + "microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/config.json", + "microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/config.json", + "microsoft/deberta-xlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xlarge-v2-mnli/resolve/main/config.json", + "microsoft/deberta-xxlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xxlarge-v2-mnli/resolve/main/config.json", } diff --git a/src/transformers/models/deberta/gpt2_tokenizer.py b/src/transformers/models/deberta/gpt2_tokenizer.py new file mode 100644 index 00000000000000..88e12fc07e2ec7 --- /dev/null +++ b/src/transformers/models/deberta/gpt2_tokenizer.py @@ -0,0 +1,380 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" GPT2 Tokenization class for model DeBERTa.""" + +import os +import unicodedata +from functools import lru_cache + + +try: + import regex as re +except ImportError: + raise ImportError("Please install regex with: pip install regex") + +___all__ = ["GPT2Tokenizer"] + +VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"} + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode + strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're + at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a signficant + percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode + strings. And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2 ** 8): + if b not in bs: + bs.append(b) + cs.append(2 ** 8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length + strings). 
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Encoder: + def __init__(self, encoder, bpe_merges, errors="replace"): + self.encoder = encoder + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def split_to_words(self, text): + return list(re.findall(self.pat, text)) + + def encode(self, text): + bpe_tokens = [] + for token in self.split_to_words(text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def decode(self, tokens): + text = 
"".join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + +def get_encoder(encoder, vocab): + return Encoder( + encoder=encoder, + bpe_merges=vocab, + ) + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +class GPT2Tokenizer(object): + """ + A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer + + Args: + vocab_file (:obj:`str`, optional): + The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases + `_, e.g. "bpe_encoder", default: `None`. + + If it's `None`, then it will download the vocabulary in the latest release from GitHub. 
The vocabulary file + is a state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used + in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. The difference between our wrapped GPT2 + tokenizer and RoBERTa wrapped tokenizer are, + + - Special tokens, unlike `RoBERTa` which use ``, `` as the `start` token and `end` token of a + sentence. We use `[CLS]` and `[SEP]` as the `start` and `end` token of input sentence which is the same + as `BERT`. + + - We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, + `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264 + + special_tokens (:obj:`list`, optional): + List of special tokens to be added to the end of the vocabulary. + """ + + def __init__(self, vocab_file=None, special_tokens=None, **kwargs): + import torch + + self.pad_token = "[PAD]" + self.sep_token = "[SEP]" + self.unk_token = "[UNK]" + self.cls_token = "[CLS]" + + self.symbols = [] + self.count = [] + self.indices = {} + self.pad_token_id = self.add_symbol(self.pad_token) + self.cls_token_id = self.add_symbol(self.cls_token) + self.sep_token_id = self.add_symbol(self.sep_token) + self.unk_token_id = self.add_symbol(self.unk_token) + + self.gpt2_encoder = torch.load(vocab_file) + self.bpe = get_encoder(self.gpt2_encoder["encoder"], self.gpt2_encoder["vocab"]) + for w, n in self.gpt2_encoder["dict_map"]: + self.add_symbol(w, n) + + self.mask_token = "[MASK]" + self.mask_id = self.add_symbol(self.mask_token) + self.special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"] + if special_tokens is not None: + for t in special_tokens: + self.add_special_token(t) + + self.vocab = self.indices + self.ids_to_tokens = self.symbols + + def tokenize(self, text): + """ + Convert an input text to tokens. + + Args: + text (:obj:`str`): input text to be tokenized. 
+ + Returns: + A list of byte tokens where each token represent the byte id in GPT2 byte dictionary + + Example:: + >>> tokenizer = GPT2Tokenizer() + >>> text = "Hello world!" + >>> tokens = tokenizer.tokenize(text) + >>> print(tokens) + ['15496', '995', '0'] + """ + bpe = self._encode(text) + + return [t for t in bpe.split(" ") if t] + + def convert_tokens_to_ids(self, tokens): + """ + Convert list of tokens to ids + + Args: + tokens (:obj:`list`): list of tokens + + Returns: + List of ids + """ + + return [self.vocab[t] for t in tokens] + + def convert_ids_to_tokens(self, ids): + """ + Convert list of ids to tokens + + Args: + ids (:obj:`list`): list of ids + + Returns: + List of tokens + """ + + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def split_to_words(self, text): + return self.bpe.split_to_words(text) + + def decode(self, tokens): + """ + Decode list of tokens to text strings + + Args: + tokens (:obj:`list`): list of tokens. + + Returns: + Text string corresponds to the input tokens. + + Example:: + >>> tokenizer = GPT2Tokenizer() + >>> text = "Hello world!" + >>> tokens = tokenizer.tokenize(text) + >>> print(tokens) + ['15496', '995', '0'] + >>> tokenizer.decode(tokens) + 'Hello world!' + """ + return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens]) + + def add_special_token(self, token): + """ + Adds a special token to the dictionary + + Args: + token (:obj:`str`): Tthe new token/word to be added to the vocabulary. + + Returns: + The id of new token in the vocabulary. 
+ + """ + self.special_tokens.append(token) + return self.add_symbol(token) + + def part_of_whole_word(self, token, is_bos=False): + if is_bos: + return True + s = self._decode(token) + if len(s) == 1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0])): + return False + + return not s.startswith(" ") + + def sym(self, id): + return self.ids_to_tokens[id] + + def id(self, sym): + return self.vocab[sym] + + def _encode(self, x: str) -> str: + return " ".join(map(str, self.bpe.encode(x))) + + def _decode(self, x: str) -> str: + return self.bpe.decode(map(int, x.split())) + + def add_symbol(self, word, n=1): + """ + Adds a word to the dictionary + + Args: + word (:obj:`str`): Tthe new token/word to be added to the vocabulary. + n (int, optional): The frequency of the word. + + Returns: + The id of the new word. + + """ + if word in self.indices: + idx = self.indices[word] + self.count[idx] = self.count[idx] + n + return idx + else: + idx = len(self.symbols) + self.indices[word] = idx + self.symbols.append(word) + self.count.append(n) + return idx + + def save_pretrained(self, path: str, filename_prefix: str = None): + import torch + + filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] + if filename_prefix is not None: + filename = filename_prefix + "-" + filename + full_path = os.path.join(path, filename) + torch.save(self.gpt2_encoder, full_path) + return (full_path,) diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 2d38b38297aab8..574d3e7062fc60 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -16,11 +16,14 @@ import math from collections.abc import Sequence +from functools import lru_cache +import numpy as np import torch from packaging import version from torch import _softmax_backward_data, nn from torch.nn import CrossEntropyLoss +from torch.nn import LayerNorm 
as DebertaLayerNorm from ...activations import ACT2FN from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward @@ -44,6 +47,14 @@ DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ "microsoft/deberta-base", "microsoft/deberta-large", + "microsoft/deberta-xlarge", + "microsoft/deberta-xlarge-v2", + "microsoft/deberta-xxlarge-v2", + "microsoft/deberta-base-mnli", + "microsoft/deberta-large-mnli", + "microsoft/deberta-xlarge-mnli", + "microsoft/deberta-xlarge-v2-mnli", + "microsoft/deberta-xxlarge-v2-mnli", ] @@ -214,24 +225,34 @@ def get_context(self): return self.drop_prob -class DebertaLayerNorm(nn.Module): - """LayerNorm module in the TF style (epsilon inside the square root).""" +def MaskedLayerNorm(layerNorm, input, mask=None): + """ + Masked LayerNorm which will apply mask over the output of LayerNorm to avoid inaccurate updatings to the LayerNorm + module. - def __init__(self, size, eps=1e-12): - super().__init__() - self.weight = nn.Parameter(torch.ones(size)) - self.bias = nn.Parameter(torch.zeros(size)) - self.variance_epsilon = eps + Args: + layernorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function + input (:obj:`torch.tensor`): The input tensor + mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicate the output of that element will be ignored, i.e. 
set to `0` - def forward(self, hidden_states): - input_type = hidden_states.dtype - hidden_states = hidden_states.float() - mean = hidden_states.mean(-1, keepdim=True) - variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) - hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon) - hidden_states = hidden_states.to(input_type) - y = self.weight * hidden_states + self.bias - return y + Example:: + + # Create a tensor b x n x d + x = torch.randn([1,10,100]) + m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int) + LayerNorm = DeBERTa.deberta.LayerNorm(100) + y = MaskedLayerNorm(LayerNorm, x, m) + + """ + output = layerNorm(input).to(input) + if mask is None: + return output + if mask.dim() != input.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(output.dtype) + return output * mask class DebertaSelfOutput(nn.Module): @@ -349,6 +370,29 @@ def forward( return layer_output +class ConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, "conv_kernel_size", 3) + groups = getattr(config, "conv_groups", 1) + self.conv_act = getattr(config, "conv_act", "tanh") + self.conv = torch.nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) + self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + output_states = MaskedLayerNorm(self.LayerNorm, residual_states + out, input_mask) + + return output_states + + class DebertaEncoder(nn.Module): """Modified BertEncoder with 
relative position bias support""" @@ -360,10 +404,25 @@ def __init__(self, config): self.max_relative_positions = getattr(config, "max_relative_positions", -1) if self.max_relative_positions < 1: self.max_relative_positions = config.max_position_embeddings - self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size) + self.position_buckets = getattr(config, "position_buckets", -1) + pos_ebd_size = self.max_relative_positions * 2 + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) + + self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] + if "layer_norm" in self.norm_rel_ebd: + self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + kernel_size = getattr(config, "conv_kernel_size", 0) + self.with_conv = False + if kernel_size > 0: + self.with_conv = True + self.conv = ConvLayer(config) def get_rel_embedding(self): rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) return rel_embeddings def get_attention_mask(self, attention_mask): @@ -379,7 +438,9 @@ def get_attention_mask(self, attention_mask): def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): if self.relative_attention and relative_pos is None: q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) - relative_pos = build_relative_position(q, hidden_states.size(-2), hidden_states.device) + relative_pos = build_relative_position( + q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) return relative_pos def forward( @@ -392,6 +453,10 @@ def forward( relative_pos=None, return_dict=True, ): + if attention_mask.dim() <= 2: + input_mask = 
attention_mask + else: + input_mask = (attention_mask.sum(-2) > 0).byte() attention_mask = self.get_attention_mask(attention_mask) relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) @@ -408,7 +473,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - hidden_states = layer_module( + output_states = layer_module( next_kv, attention_mask, output_attentions, @@ -417,29 +482,42 @@ def forward( rel_embeddings=rel_embeddings, ) if output_attentions: - hidden_states, att_m = hidden_states + output_states, att_m = output_states + + if i == 0 and self.with_conv: + output_states = self.conv(hidden_states, output_states, input_mask) if query_states is not None: - query_states = hidden_states + query_states = output_states if isinstance(hidden_states, Sequence): next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None else: - next_kv = hidden_states + next_kv = output_states if output_attentions: all_attentions = all_attentions + (att_m,) if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (output_states,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions ) -def build_relative_position(query_size, key_size, device): +def make_log_bucket_position(relative_pos, bucket_size, max_position): + sign = np.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos)) + log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid + bucket_pos = 
np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int) + return bucket_pos + + +@lru_cache(maxsize=128) +def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1): """ Build relative position according to the query and key @@ -450,15 +528,19 @@ def build_relative_position(query_size, key_size, device): Args: query_size (int): the length of query key_size (int): the length of key + bucket_size (int): the size of position bucket + max_position (int): the maxium allowed absolute positoin Return: :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size] """ - - q_ids = torch.arange(query_size, dtype=torch.long, device=device) - k_ids = torch.arange(key_size, dtype=torch.long, device=device) - rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1) + q_ids = np.arange(0, query_size) + k_ids = np.arange(0, key_size) + rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1)) + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) + rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long) rel_pos_ids = rel_pos_ids[:query_size, :] rel_pos_ids = rel_pos_ids.unsqueeze(0) return rel_pos_ids @@ -498,39 +580,41 @@ def __init__(self, config): "heads (%d)" % (config.hidden_size, config.num_attention_heads) ) self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + _attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.in_proj = torch.nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False) - self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) - self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), 
dtype=torch.float)) - self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.share_att_key = getattr(config, "share_att_key", False) + self.pos_att_type = config.pos_att_type self.relative_attention = getattr(config, "relative_attention", False) - self.talking_head = getattr(config, "talking_head", False) - - if self.talking_head: - self.head_logits_proj = torch.nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False) - self.head_weights_proj = torch.nn.Linear( - config.num_attention_heads, config.num_attention_heads, bias=False - ) if self.relative_attention: + self.position_buckets = getattr(config, "position_buckets", -1) self.max_relative_positions = getattr(config, "max_relative_positions", -1) if self.max_relative_positions < 1: self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + self.pos_dropout = StableDropout(config.hidden_dropout_prob) - if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: - self.pos_proj = torch.nn.Linear(config.hidden_size, self.all_head_size, bias=False) - if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: - self.pos_q_proj = torch.nn.Linear(config.hidden_size, self.all_head_size) + if not self.share_att_key: + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = StableDropout(config.attention_probs_dropout_prob) + 
self._register_load_state_dict_pre_hook(self._pre_load_hook) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1) + def transpose_for_scores(self, x, attention_heads): + new_x_shape = x.size()[:-1] + (attention_heads, -1) x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) + return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) def forward( self, @@ -571,51 +655,46 @@ def forward( """ if query_states is None: - qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1) - query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1) - else: - - def linear(w, b, x): - if b is not None: - return torch.matmul(x, w.t()) + b.t() - else: - return torch.matmul(x, w.t()) # + b.t() - - ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0) - qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)] - qkvb = [None] * 3 - - q = linear(qkvw[0], qkvb[0], query_states) - k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)] - query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]] - - query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :]) - value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :]) + query_states = hidden_states + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) rel_att = None # Take the dot product between "query" and "key" to get the raw attention scores. 
- scale_factor = 1 + len(self.pos_att_type) + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + if "p2p" in self.pos_att_type: + scale_factor += 1 scale = math.sqrt(query_layer.size(-1) * scale_factor) - query_layer = query_layer / scale - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) - rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) if rel_att is not None: attention_scores = attention_scores + rel_att + attention_scores = attention_scores + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) # bxhxlxd - if self.talking_head: - attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) - attention_probs = self.dropout(attention_probs) - if self.talking_head: - attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + _attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(_attention_probs) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) new_context_layer_shape = context_layer.size()[:-2] + (-1,) 
context_layer = context_layer.view(*new_context_layer_shape) if return_att: @@ -623,61 +702,147 @@ def linear(w, b, x): else: return context_layer - def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): if relative_pos is None: q = query_layer.size(-2) - relative_pos = build_relative_position(q, key_layer.size(-2), query_layer.device) + relative_pos = build_relative_position( + q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) if relative_pos.dim() == 2: relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) elif relative_pos.dim() == 3: relative_pos = relative_pos.unsqueeze(1) # bxhxqxk elif relative_pos.dim() != 4: - raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") + raise ValueError(f"Relative postion ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") - att_span = min(max(query_layer.size(-2), key_layer.size(-2)), self.max_relative_positions) + att_span = self.pos_ebd_size relative_pos = relative_pos.long().to(query_layer.device) - rel_embeddings = rel_embeddings[ - self.max_relative_positions - att_span : self.max_relative_positions + att_span, : - ].unsqueeze(0) - if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: - pos_key_layer = self.pos_proj(rel_embeddings) - pos_key_layer = self.transpose_for_scores(pos_key_layer) - if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: - pos_query_layer = self.pos_q_proj(rel_embeddings) - pos_query_layer = self.transpose_for_scores(pos_query_layer) + rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :].unsqueeze( + 0 + ) # .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) + if self.share_att_key: + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), self.num_attention_heads 
+ ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + else: + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) score = 0 # content->position if "c2p" in self.pos_att_type: - c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2)) + scale = math.sqrt(pos_key_layer.size(-1) * scale_factor) + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) - c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos)) - score += c2p_att + c2p_att = torch.gather( + c2p_att, + dim=-1, + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), + ) + score += c2p_att / scale # position->content if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: - pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor) - if query_layer.size(-2) != key_layer.size(-2): - r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device) + scale = math.sqrt(pos_query_layer.size(-1) * scale_factor) + if key_layer.size(-2) != query_layer.size(-2): + r_pos = build_relative_position( + key_layer.size(-2), + key_layer.size(-2), 
+                    bucket_size=self.position_buckets,
+                    max_position=self.max_relative_positions,
+                ).to(query_layer.device)
+                r_pos = r_pos.unsqueeze(0)
             else:
                 r_pos = relative_pos
+
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
             if query_layer.size(-2) != key_layer.size(-2):
                 pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)

         if "p2c" in self.pos_att_type:
-            p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
+            p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
             p2c_att = torch.gather(
-                p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
+                p2c_att,
+                dim=-1,
+                index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
             if query_layer.size(-2) != key_layer.size(-2):
-                p2c_att = torch.gather(p2c_att, dim=-2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer))
-            score += p2c_att
+                p2c_att = torch.gather(
+                    p2c_att,
+                    dim=-2,
+                    index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))),
+                )
+            score += p2c_att / scale
+
+        # position->position
+        if "p2p" in self.pos_att_type:
+            pos_query = pos_query_layer[:, :, att_span:, :]
+            p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
+            p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
+            if query_layer.size(-2) != key_layer.size(-2):
+                p2p_att = torch.gather(
+                    p2p_att,
+                    dim=-2,
+                    index=pos_index.expand(query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))),
+                )
+            p2p_att = torch.gather(
+                p2p_att,
+                dim=-1,
+                index=c2p_pos.expand(
+                    [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
+                ),
+            )
+            score += p2p_att

         return score

+    def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        """Rewrite, in place, a checkpoint that stores a fused ``in_proj`` into the split-projection key layout.
+
+        The conversion only runs when ``state_dict`` contains ``<prefix>in_proj.weight`` but not
+        ``<prefix>query_proj.weight``; in every other case ``state_dict`` is left untouched.
+
+        NOTE(review): the signature looks like a ``torch.nn.Module`` load-state-dict pre-hook and the ``v1_*``
+        local names suggest the fused layout is the DeBERTa v1 one; the hook registration is not visible in
+        this hunk -- confirm both.
+
+        Args:
+            state_dict: Mapping from parameter names to tensors. Entries are added and deleted in place.
+            prefix: String prepended to every key that is read, written or deleted here.
+            local_metadata, strict, missing_keys, unexpected_keys, error_msgs: Accepted but not used.
+        """
+        # Current parameters of this module; source for the two biases that are not read from the checkpoint.
+        self_state = self.state_dict()
+        if ((prefix + "query_proj.weight") not in state_dict) and ((prefix + "in_proj.weight") in state_dict):
+            v1_proj = state_dict[prefix + "in_proj.weight"]
+            # Regroup the fused weight as (num_attention_heads, rows_per_head, in_features) and cut each head's
+            # rows into three equal q / k / v chunks along dim 1.
+            v1_proj = v1_proj.unsqueeze(0).reshape(self.num_attention_heads, -1, v1_proj.size(-1))
+            q, k, v = v1_proj.chunk(3, dim=1)
+            # Flatten each chunk back to a 2-D weight with the original number of input features.
+            state_dict[prefix + "query_proj.weight"] = q.reshape(-1, v1_proj.size(-1))
+            state_dict[prefix + "key_proj.weight"] = k.reshape(-1, v1_proj.size(-1))
+            # No key bias is read from the checkpoint: the module's current value is used instead.
+            state_dict[prefix + "key_proj.bias"] = self_state["key_proj.bias"]
+            state_dict[prefix + "value_proj.weight"] = v.reshape(-1, v1_proj.size(-1))
+            v1_query_bias = state_dict[prefix + "q_bias"]
+            state_dict[prefix + "query_proj.bias"] = v1_query_bias
+            v1_value_bias = state_dict[prefix + "v_bias"]
+            state_dict[prefix + "value_proj.bias"] = v1_value_bias
+
+            # Position projections are renamed one-to-one.
+            # NOTE(review): these lookups are unconditional. Assuming plain mapping semantics, a checkpoint
+            # without the pos_proj / pos_q_proj entries, or a module without a pos_key_proj.bias parameter,
+            # raises KeyError here -- confirm every supported checkpoint and configuration provides them.
+            v1_pos_key_proj = state_dict[prefix + "pos_proj.weight"]
+            state_dict[prefix + "pos_key_proj.weight"] = v1_pos_key_proj
+            v1_pos_query_proj = state_dict[prefix + "pos_q_proj.weight"]
+            state_dict[prefix + "pos_query_proj.weight"] = v1_pos_query_proj
+            v1_pos_query_proj_bias = state_dict[prefix + "pos_q_proj.bias"]
+            state_dict[prefix + "pos_query_proj.bias"] = v1_pos_query_proj_bias
+            # As with key_proj.bias, the position-key bias is taken from the module's current value.
+            state_dict[prefix + "pos_key_proj.bias"] = self_state["pos_key_proj.bias"]
+
+            # Remove the old-layout keys (presumably so the loader does not see them as unexpected -- confirm).
+            del state_dict[prefix + "in_proj.weight"]
+            del state_dict[prefix + "q_bias"]
+            del state_dict[prefix + "v_bias"]
+            del state_dict[prefix + "pos_proj.weight"]
+            del state_dict[prefix + "pos_q_proj.weight"]
+            del state_dict[prefix + "pos_q_proj.bias"]
+

 class DebertaEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""

diff --git a/src/transformers/models/deberta/spm_tokenizer.py b/src/transformers/models/deberta/spm_tokenizer.py
new file mode 100644
index 00000000000000..d7b94fa7c40928
--- /dev/null
+++ b/src/transformers/models/deberta/spm_tokenizer.py
@@ -0,0 +1,277 @@
+# coding=utf-8
+# Copyright 2020 Microsoft and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Tokenization class for model DeBERTa."""
+
+import os
+import unicodedata
+
+import sentencepiece as sp
+import six
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "spm.model"}
+
+__all__ = ["SPMTokenizer"]
+
+# U+2581 (b"\xe2\x96\x81" in UTF-8): the prefix this module treats as the start-of-word marker on a piece.
+_WORD_START = "\u2581"
+
+
+class SPMTokenizer:
+    """Tokenizer backed by a SentencePiece model, with BERT-style special tokens.
+
+    Args:
+        vocab_file: Path to the serialized SentencePiece model. Must exist.
+        do_lower_case: Accepted but not used by this class.
+        special_tokens: Optional list of extra special tokens, registered after the five built-in ones.
+        bpe_dropout: Accepted but not used by this class.
+        split_by_punct: When true, text is split on punctuation before being handed to SentencePiece.
+        **kwargs: Accepted and ignored.
+
+    Raises:
+        FileNotFoundError: If ``vocab_file`` does not exist.
+    """
+
+    def __init__(
+        self, vocab_file, do_lower_case=False, special_tokens=None, bpe_dropout=0, split_by_punct=False, **kwargs
+    ):
+        self.split_by_punct = split_by_punct
+        # Explicit check rather than `assert`, which is stripped under `python -O`.
+        if not os.path.exists(vocab_file):
+            raise FileNotFoundError("SentencePiece model file not found: {}".format(vocab_file))
+        spm = sp.SentencePieceProcessor()
+        spm.load(vocab_file)
+        bpe_vocab_size = spm.GetPieceSize()
+        # id -> piece list and piece -> id dict built from the model. `ids_to_tokens` is the name read by
+        # `convert_ids_to_tokens` and `sym`.
+        self.ids_to_tokens = [spm.IdToPiece(i) for i in range(bpe_vocab_size)]
+        self.vocab = {piece: i for i, piece in enumerate(self.ids_to_tokens)}
+        # Alias (same list object) for callers using the singular spelling.
+        self.id_to_tokens = self.ids_to_tokens
+        # NOTE(review): the original source carried a commented-out mapping [PAD]=0, [CLS]=1, [SEP]=2, [UNK]=3;
+        # presumably the shipped model already defines those pieces with these ids -- confirm against spm.model.
+
+        _special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"]
+        self.special_tokens = []
+        if special_tokens is not None:
+            _special_tokens.extend(special_tokens)
+        for t in _special_tokens:
+            self.add_special_token(t)
+
+        self.spm = spm
+
+    def tokenize(self, text):
+        """Split ``text`` into pieces; empty pieces and pieces missing from the vocabulary become "[UNK]"."""
+        pieces = self._encode_as_pieces(text)
+        return [p if p in self.vocab and p != "" else "[UNK]" for p in pieces]
+
+    def convert_tokens_to_ids(self, tokens):
+        """Map each token to its id; tokens missing from the vocabulary map to id 1."""
+        # NOTE(review): the fallback id is hard-coded to 1. The commented-out mapping in the original source
+        # lists [CLS]=1 and [UNK]=3, so 1 may not be the [UNK] id -- confirm against the shipped model.
+        return [self.vocab.get(t, 1) for t in tokens]
+
+    def convert_ids_to_tokens(self, ids):
+        """Map each id to its token string."""
+        return [self.ids_to_tokens[i] for i in ids]
+
+    def decode(self, tokens, start=-1, end=-1, raw_text=None):
+        """Turn tokens back into text.
+
+        Without ``raw_text`` the non-special tokens are decoded by SentencePiece. With ``raw_text`` the text is
+        split into words, each word is re-tokenized to map token positions to words, and the words spanning
+        token positions ``start`` to ``end`` are joined and returned (``end`` at or past ``len(tokens)`` means
+        "to the last word").
+        """
+        if raw_text is None:
+            return self.spm.decode_pieces([t for t in tokens if t not in self.special_tokens])
+
+        words = self.split_to_words(raw_text)
+        word_tokens = [self.tokenize(w) for w in words]
+        # token2words[j] is the index of the word that produced the j-th token.
+        # NOTE(review): assumes re-tokenizing word by word yields at most len(tokens) tokens; otherwise the
+        # assignment below raises IndexError.
+        token2words = [0] * len(tokens)
+        tid = 0
+        for i, w in enumerate(word_tokens):
+            for _ in w:
+                token2words[tid] = i
+                tid += 1
+        word_start = token2words[start]
+        word_end = token2words[end] if end < len(tokens) else len(words)
+        return "".join(words[word_start:word_end])
+
+    def add_special_token(self, token):
+        """Register ``token`` as special, appending it to the vocabulary if it is new, and return its id."""
+        if token not in self.special_tokens:
+            self.special_tokens.append(token)
+            if token not in self.vocab:
+                self.vocab[token] = len(self.vocab)
+                self.ids_to_tokens.append(token)
+        return self.id(token)
+
+    def part_of_whole_word(self, token, is_bos=False):
+        """Return True when ``token`` does not begin a word (or when ``is_bos`` is set)."""
+        if is_bos:
+            return True
+        if token in self.special_tokens:
+            return False
+        if len(token) == 1 and (_is_whitespace(token) or _is_control(token) or _is_punctuation(token)):
+            return False
+        return not token.startswith(_WORD_START)
+
+    def pad(self):
+        return "[PAD]"
+
+    def bos(self):
+        return "[CLS]"
+
+    def eos(self):
+        return "[SEP]"
+
+    def unk(self):
+        return "[UNK]"
+
+    def mask(self):
+        return "[MASK]"
+
+    def sym(self, id):
+        """Return the token string stored at index ``id``."""
+        return self.ids_to_tokens[id]
+
+    def id(self, sym):
+        """Return the id of ``sym``, or 1 when it is not in the vocabulary (see `convert_tokens_to_ids`)."""
+        return self.vocab.get(sym, 1)
+
+    def _encode_as_pieces(self, text):
+        """Encode ``text`` into pieces, optionally splitting it on punctuation first."""
+        text = convert_to_unicode(text)
+        if not self.split_by_punct:
+            return self.spm.encode_as_pieces(text)
+        words = self._run_split_on_punc(text)
+        return [p for w in words for p in self.spm.encode_as_pieces(w)]
+
+    def split_to_words(self, text):
+        """Split ``text`` into consecutive substrings ("words") aligned with the SentencePiece word starts."""
+        # Work on the same unicode string the pieces are computed from, so that `text.index` below compares
+        # str with str even when bytes were passed in.
+        text = convert_to_unicode(text)
+        pieces = self._encode_as_pieces(text)
+        words = []
+        offset = 0
+        prev_end = 0
+        for i, p in enumerate(pieces):
+            if p.startswith(_WORD_START):
+                # A new word starts: flush the text accumulated for the previous one.
+                if offset > prev_end:
+                    words.append(text[prev_end:offset])
+                prev_end = offset
+                w = p.replace(_WORD_START, "")
+            else:
+                w = p
+
+            try:
+                s = text.index(w, offset)
+            except ValueError:
+                # The piece does not occur verbatim in the remaining text: advance by one character.
+                offset = offset + 1
+                continue
+
+            # Look ahead to the next piece with visible content.
+            pn = ""
+            k = i + 1
+            while k < len(pieces):
+                pn = pieces[k].replace(_WORD_START, "")
+                if len(pn) > 0:
+                    break
+                k += 1
+
+            if len(pn) > 0 and pn in text[offset:s]:
+                # The next piece occurs before this match, so the match is too far ahead: step one character.
+                offset = offset + 1
+            else:
+                offset = s + len(w)
+
+        if prev_end < offset:
+            words.append(text[prev_end:offset])
+
+        return words
+
+    def _run_strip_accents(self, text):
+        """Strips accents from a piece of text."""
+        text = unicodedata.normalize("NFD", text)
+        # "Mn" is the category of the combining marks left behind by NFD decomposition.
+        return "".join(char for char in text if unicodedata.category(char) != "Mn")
+
+    def _run_split_on_punc(self, text):
+        """Splits punctuation on a piece of text."""
+        output = []
+        start_new_word = True
+        for char in text:
+            if _is_punctuation(char):
+                # Each punctuation character becomes its own chunk.
+                output.append([char])
+                start_new_word = True
+            else:
+                if start_new_word:
+                    output.append([])
+                start_new_word = False
+                output[-1].append(char)
+
+        return ["".join(x) for x in output]
+
+    def save_pretrained(self, path: str, filename_prefix: str = None):
+        """Write the serialized SentencePiece model under ``path`` and return a 1-tuple with the file path."""
+        filename = next(iter(VOCAB_FILES_NAMES.values()))
+        if filename_prefix is not None:
+            filename = filename_prefix + "-" + filename
+        full_path = os.path.join(path, filename)
+        with open(full_path, "wb") as fs:
+            fs.write(self.spm.serialized_model_proto())
+        return (full_path,)
+
+
+def _is_whitespace(char):
+    """Checks whether `char` is a whitespace character."""
+    # \t, \n, and \r are technically control characters but we treat them
+    # as whitespace since they are generally considered as such.
+    if char in (" ", "\t", "\n", "\r"):
+        return True
+    return unicodedata.category(char) == "Zs"
+
+
+def _is_control(char):
+    """Checks whether `char` is a control character."""
+    # These are technically control characters but we count them as whitespace
+    # characters.
+    if char in ("\t", "\n", "\r"):
+        return False
+    return unicodedata.category(char).startswith("C")
+
+
+def _is_punctuation(char):
+    """Checks whether `char` is a punctuation character."""
+    cp = ord(char)
+    # We treat all non-letter/number ASCII as punctuation.
+    # Characters such as "^", "$", and "`" are not in the Unicode
+    # Punctuation class but we treat them as punctuation anyways, for
+    # consistency.
+    if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126:
+        return True
+    return unicodedata.category(char).startswith("P")
+
+
+def convert_to_unicode(text):
+    """Converts `text` to Unicode (if it's not already), assuming utf-8 input.
+
+    Raises:
+        ValueError: If `text` is neither a text string nor a byte string.
+    """
+    # six.text_type / six.binary_type are str / bytes on Python 3 and unicode / str on Python 2, so text
+    # passes through unchanged and byte strings are decoded on both.
+    if isinstance(text, six.text_type):
+        return text
+    if isinstance(text, six.binary_type):
+        return text.decode("utf-8", "ignore")
+    raise ValueError("Unsupported string type: %s" % (type(text)))
diff --git a/src/transformers/models/deberta/tokenization_deberta.py b/src/transformers/models/deberta/tokenization_deberta.py
index 4edba5fd599944..22c5007e4f5231 100644
--- a/src/transformers/models/deberta/tokenization_deberta.py
+++ b/src/transformers/models/deberta/tokenization_deberta.py
@@ -15,476 +15,57 @@
 """ Tokenization class for model DeBERTa."""

 import os
-import pathlib
-import random
-import unicodedata
-from
functools import lru_cache from typing import Optional, Tuple -from zipfile import ZipFile - -import tqdm - -import requests from ...tokenization_utils import PreTrainedTokenizer -from ...utils import logging - - -try: - import regex as re -except ImportError: - raise ImportError("Please install regex with: pip install regex") - +from .gpt2_tokenizer import GPT2Tokenizer +from .spm_tokenizer import SPMTokenizer -logger = logging.get_logger(__name__) - -VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"} PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { "microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/bpe_encoder.bin", "microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/bpe_encoder.bin", + "microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/bpe_encoder.bin", + "microsoft/deberta-xlarge-v2": "https://huggingface.co/microsoft/deberta-xlarge-v2/resolve/main/spm.model", + "microsoft/deberta-xxlarge-v2": "https://huggingface.co/microsoft/deberta-xxlarge-v2/resolve/main/spm.model", + "microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/bpe_encoder.bin", + "microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/bpe_encoder.bin", + "microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/bpe_encoder.bin", + "microsoft/deberta-xlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xlarge-v2-mnli/resolve/main/spm.model", + "microsoft/deberta-xxlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xxlarge-v2-mnli/resolve/main/spm.model", } } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { "microsoft/deberta-base": 512, "microsoft/deberta-large": 512, + "microsoft/deberta-xlarge": 512, + "microsoft/deberta-xlarge-v2": 512, + "microsoft/deberta-xxlarge-v2": 512, + "microsoft/deberta-base-mnli": 512, + "microsoft/deberta-large-mnli": 512, + 
"microsoft/deberta-xlarge-mnli": 512, + "microsoft/deberta-xlarge-v2-mnli": 512, + "microsoft/deberta-xxlarge-v2-mnli": 512, } PRETRAINED_INIT_CONFIGURATION = { - "microsoft/deberta-base": {"do_lower_case": False}, - "microsoft/deberta-large": {"do_lower_case": False}, + "microsoft/deberta-base": {"do_lower_case": False, "vocab_type": "gpt2"}, + "microsoft/deberta-large": {"do_lower_case": False, "vocab_type": "gpt2"}, + "microsoft/deberta-xlarge": {"do_lower_case": False, "vocab_type": "gpt2"}, + "microsoft/deberta-xlarge-v2": {"do_lower_case": False, "vocab_type": "spm"}, + "microsoft/deberta-xxlarge-v2": {"do_lower_case": False, "vocab_type": "spm"}, + "microsoft/deberta-base-mnli": {"do_lower_case": False, "vocab_type": "gpt2"}, + "microsoft/deberta-large-mnli": {"do_lower_case": False, "vocab_type": "gpt2"}, + "microsoft/deberta-xlarge-mnli": {"do_lower_case": False, "vocab_type": "gpt2"}, + "microsoft/deberta-xlarge-v2-mnli": {"do_lower_case": False, "vocab_type": "spm"}, + "microsoft/deberta-xxlarge-v2-mnli": {"do_lower_case": False, "vocab_type": "spm"}, } __all__ = ["DebertaTokenizer"] - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode - strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're - at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a signficant - percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode - strings. And avoids mapping to whitespace/control characters the bpe code barfs on. 
- """ - bs = ( - list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) - ) - cs = bs[:] - n = 0 - for b in range(2 ** 8): - if b not in bs: - bs.append(b) - cs.append(2 ** 8 + n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """ - Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length - strings). - """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -class Encoder: - def __init__(self, encoder, bpe_merges, errors="replace"): - self.encoder = encoder - self.decoder = {v: k for k, v in self.encoder.items()} - self.errors = errors # how to handle errors in decoding - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges)))) - self.cache = {} - self.random = random.Random(0) - - # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions - self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token) - pairs = get_pairs(word) - - if not pairs: - return token - - while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except Exception: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word) - 1 and word[i + 1] == second: - new_word.append(first + second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break 
- else: - pairs = get_pairs(word) - word = " ".join(word) - self.cache[token] = word - return word - - def split_to_words(self, text): - return list(re.findall(self.pat, text)) - - def encode(self, text): - bpe_tokens = [] - for token in self.split_to_words(text): - token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) - return bpe_tokens - - def decode(self, tokens): - text = "".join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) - return text - - -def get_encoder(encoder, vocab): - return Encoder( - encoder=encoder, - bpe_merges=vocab, - ) - - -def _is_whitespace(char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False - - -def _is_control(char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char == "\t" or char == "\n" or char == "\r": - return False - cat = unicodedata.category(char) - if cat.startswith("C"): - return True - return False - - -def _is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. 
- if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False - - -def download_asset(name, tag=None, no_cache=False, cache_dir=None): - _tag = tag - if _tag is None: - _tag = "latest" - if not cache_dir: - cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/") - os.makedirs(cache_dir, exist_ok=True) - output = os.path.join(cache_dir, name) - if os.path.exists(output) and (not no_cache): - return output - - repo = "https://api.github.com/repos/microsoft/DeBERTa/releases" - releases = requests.get(repo).json() - if tag and tag != "latest": - release = [r for r in releases if r["name"].lower() == tag.lower()] - if len(release) != 1: - raise Exception(f"{tag} can't be found in the repository.") - else: - release = releases[0] - asset = [s for s in release["assets"] if s["name"].lower() == name.lower()] - if len(asset) != 1: - raise Exception(f"{name} can't be found in the release.") - url = asset[0]["url"] - headers = {} - headers["Accept"] = "application/octet-stream" - resp = requests.get(url, stream=True, headers=headers) - if resp.status_code != 200: - raise Exception(f"Request for {url} return {resp.status_code}, {resp.text}") - try: - with open(output, "wb") as fs: - progress = tqdm( - total=int(resp.headers["Content-Length"]) if "Content-Length" in resp.headers else -1, - ncols=80, - desc=f"Downloading {name}", - ) - for c in resp.iter_content(chunk_size=1024 * 1024): - fs.write(c) - progress.update(len(c)) - progress.close() - except Exception: - os.remove(output) - raise - - return output - - -def load_vocab(name=None, tag=None, no_cache=False, cache_dir=None): - import torch - - if name is None: - name = "bpe_encoder" - - model_path = name - if model_path and (not os.path.exists(model_path)) and not (("/" in model_path) or ("\\" in model_path)): - _tag = tag - if _tag is None: - _tag = 
"latest" - if not cache_dir: - cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/") - os.makedirs(cache_dir, exist_ok=True) - out_dir = os.path.join(cache_dir, name) - model_path = os.path.join(out_dir, "bpe_encoder.bin") - if (not os.path.exists(model_path)) or no_cache: - asset = download_asset(name + ".zip", tag=tag, no_cache=no_cache, cache_dir=cache_dir) - with ZipFile(asset, "r") as zipf: - for zip_info in zipf.infolist(): - if zip_info.filename[-1] == "/": - continue - zip_info.filename = os.path.basename(zip_info.filename) - zipf.extract(zip_info, out_dir) - elif not model_path: - return None, None - - encoder_state = torch.load(model_path) - return encoder_state - - -class GPT2Tokenizer(object): - """ - A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer - - Args: - vocab_file (:obj:`str`, optional): - The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases - `_, e.g. "bpe_encoder", default: `None`. - - If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file - is a state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used - in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. The difference between our wrapped GPT2 - tokenizer and RoBERTa wrapped tokenizer are, - - - Special tokens, unlike `RoBERTa` which use ``, `` as the `start` token and `end` token of a - sentence. We use `[CLS]` and `[SEP]` as the `start` and `end` token of input sentence which is the same - as `BERT`. - - - We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, - `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264 - - special_tokens (:obj:`list`, optional): - List of special tokens to be added to the end of the vocabulary. 
- """ - - def __init__(self, vocab_file=None, special_tokens=None): - self.pad_token = "[PAD]" - self.sep_token = "[SEP]" - self.unk_token = "[UNK]" - self.cls_token = "[CLS]" - - self.symbols = [] - self.count = [] - self.indices = {} - self.pad_token_id = self.add_symbol(self.pad_token) - self.cls_token_id = self.add_symbol(self.cls_token) - self.sep_token_id = self.add_symbol(self.sep_token) - self.unk_token_id = self.add_symbol(self.unk_token) - - self.gpt2_encoder = load_vocab(vocab_file) - self.bpe = get_encoder(self.gpt2_encoder["encoder"], self.gpt2_encoder["vocab"]) - for w, n in self.gpt2_encoder["dict_map"]: - self.add_symbol(w, n) - - self.mask_token = "[MASK]" - self.mask_id = self.add_symbol(self.mask_token) - self.special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"] - if special_tokens is not None: - for t in special_tokens: - self.add_special_token(t) - - self.vocab = self.indices - self.ids_to_tokens = self.symbols - - def tokenize(self, text): - """ - Convert an input text to tokens. - - Args: - text (:obj:`str`): input text to be tokenized. - - Returns: - A list of byte tokens where each token represent the byte id in GPT2 byte dictionary - - Example:: - >>> tokenizer = GPT2Tokenizer() - >>> text = "Hello world!" 
- >>> tokens = tokenizer.tokenize(text) - >>> print(tokens) - ['15496', '995', '0'] - """ - bpe = self._encode(text) - - return [t for t in bpe.split(" ") if t] - - def convert_tokens_to_ids(self, tokens): - """ - Convert list of tokens to ids - - Args: - tokens (:obj:`list`): list of tokens - - Returns: - List of ids - """ - - return [self.vocab[t] for t in tokens] - - def convert_ids_to_tokens(self, ids): - """ - Convert list of ids to tokens - - Args: - ids (:obj:`list`): list of ids - - Returns: - List of tokens - """ - - tokens = [] - for i in ids: - tokens.append(self.ids_to_tokens[i]) - return tokens - - def split_to_words(self, text): - return self.bpe.split_to_words(text) - - def decode(self, tokens): - """ - Decode list of tokens to text strings - - Args: - tokens (:obj:`list`): list of tokens. - - Returns: - Text string corresponds to the input tokens. - - Example:: - >>> tokenizer = GPT2Tokenizer() - >>> text = "Hello world!" - >>> tokens = tokenizer.tokenize(text) - >>> print(tokens) - ['15496', '995', '0'] - >>> tokenizer.decode(tokens) - 'Hello world!' - """ - return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens]) - - def add_special_token(self, token): - """ - Adds a special token to the dictionary - - Args: - token (:obj:`str`): Tthe new token/word to be added to the vocabulary. - - Returns: - The id of new token in the vocabulary. 
- - """ - self.special_tokens.append(token) - return self.add_symbol(token) - - def part_of_whole_word(self, token, is_bos=False): - if is_bos: - return True - s = self._decode(token) - if len(s) == 1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0])): - return False - - return not s.startswith(" ") - - def sym(self, id): - return self.ids_to_tokens[id] - - def id(self, sym): - return self.vocab[sym] - - def _encode(self, x: str) -> str: - return " ".join(map(str, self.bpe.encode(x))) - - def _decode(self, x: str) -> str: - return self.bpe.decode(map(int, x.split())) - - def add_symbol(self, word, n=1): - """ - Adds a word to the dictionary - - Args: - word (:obj:`str`): Tthe new token/word to be added to the vocabulary. - n (int, optional): The frequency of the word. - - Returns: - The id of the new word. - - """ - if word in self.indices: - idx = self.indices[word] - self.count[idx] = self.count[idx] + n - return idx - else: - idx = len(self.symbols) - self.indices[word] = idx - self.symbols.append(word) - self.count.append(n) - return idx - - def save_pretrained(self, path: str, filename_prefix: str = None): - import torch - - filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] - if filename_prefix is not None: - filename = filename_prefix + "-" + filename - full_path = os.path.join(path, filename) - torch.save(self.gpt2_encoder, full_path) - return (full_path,) +VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"} class DebertaTokenizer(PreTrainedTokenizer): @@ -522,6 +103,7 @@ def __init__( self, vocab_file, do_lower_case=False, + vocab_type="gpt2", unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", @@ -545,7 +127,10 @@ def __init__( "model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) ) self.do_lower_case = do_lower_case - self.gpt2_tokenizer = GPT2Tokenizer(vocab_file) + if vocab_type.lower() == "gpt2": + self._tokenizer = GPT2Tokenizer(vocab_file, 
**kwargs) + else: + self._tokenizer = SPMTokenizer(vocab_file, **kwargs) @property def vocab_size(self): @@ -553,7 +138,7 @@ def vocab_size(self): @property def vocab(self): - return self.gpt2_tokenizer.vocab + return self._tokenizer.vocab def get_vocab(self): vocab = self.vocab.copy() @@ -564,7 +149,7 @@ def _tokenize(self, text): """Take as input a string and return a list of strings (tokens) for words/sub-words""" if self.do_lower_case: text = text.lower() - return self.gpt2_tokenizer.tokenize(text) + return self._tokenizer.tokenize(text) def _convert_token_to_id(self, token): """ Converts a token (str) in an id using the vocab. """ @@ -572,11 +157,11 @@ def _convert_token_to_id(self, token): def _convert_id_to_token(self, index): """Converts an index (integer) in a token (str) using the vocab.""" - return self.gpt2_tokenizer.sym(index) if index < self.vocab_size else self.unk_token + return self._tokenizer.sym(index) if index < self.vocab_size else self.unk_token def convert_tokens_to_string(self, tokens): """ Converts a sequence of tokens (string) in a single string. """ - return self.gpt2_tokenizer.decode(tokens) + return self._tokenizer.decode(tokens) def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): """ @@ -671,4 +256,4 @@ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): return (text, kwargs) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: - return self.gpt2_tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix) + return self._tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix)