From 87ca7024d16100ddc93d1804cfbf45e97960721b Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Thu, 16 Dec 2021 20:48:39 +0100 Subject: [PATCH 01/19] combine transformer embeddings --- flair/data.py | 89 +++-- flair/embeddings/base.py | 612 ++++++++++++++++++++++++++++++- flair/embeddings/document.py | 270 ++------------ flair/embeddings/token.py | 676 ++++------------------------------- 4 files changed, 759 insertions(+), 888 deletions(-) diff --git a/flair/data.py b/flair/data.py index 0196bb7637..a5f57fb515 100644 --- a/flair/data.py +++ b/flair/data.py @@ -3,6 +3,7 @@ import typing from abc import ABC, abstractmethod from collections import Counter, defaultdict +from functools import lru_cache from operator import itemgetter from pathlib import Path from typing import Callable, Dict, List, Optional, Union, cast @@ -289,10 +290,10 @@ def __len__(self): def __eq__(self, other): return ( - self.value == other.value - and self.score == other.score - and self.head.id_text == other.head.id_text - and self.tail.id_text == other.tail.id_text + self.value == other.value + and self.score == other.score + and self.head.id_text == other.head.id_text + and self.tail.id_text == other.tail.id_text ) @property @@ -402,12 +403,12 @@ class Token(DataPoint): """ def __init__( - self, - text: str, - idx: int = None, - head_id: int = None, - whitespace_after: bool = True, - start_position: int = None, + self, + text: str, + idx: int = None, + head_id: int = None, + whitespace_after: bool = True, + start_position: int = None, ): super().__init__() @@ -609,11 +610,11 @@ class Sentence(DataPoint): """ def __init__( - self, - text: Union[str, List[str]] = [], - use_tokenizer: Union[bool, Tokenizer, Callable] = True, - language_code: str = None, - start_position: int = None, + self, + text: Union[str, List[str]] = [], + use_tokenizer: Union[bool, Tokenizer, Callable] = True, + language_code: str = None, + start_position: int = None, ): """ Class to hold all meta related to a text (tokens, predictions, language code, ...) 
@@ -831,6 +832,42 @@ def clear_embeddings(self, embedding_names: List[str] = None): for token in self: token.clear_embeddings(embedding_names) + @lru_cache(maxsize=1) # cache last context, as training repeats calls + def left_context(self, context_length: int, respect_document_boundaries: bool = True): + sentence = self + left_context = [] + while True: + sentence = sentence.previous_sentence() + if sentence is None: + break + + if respect_document_boundaries and sentence.is_document_boundary: + break + + left_context = [t.text for t in sentence.tokens] + left_context + if len(left_context) > context_length: + left_context = left_context[-context_length:] + break + return left_context + + @lru_cache(maxsize=1) # cache last context, as training repeats calls + def right_context(self, context_length: int, respect_document_boundaries: bool = True): + sentence = self + right_context = [] + while True: + sentence = sentence.next_sentence() + if sentence is None: + break + if respect_document_boundaries and sentence.is_document_boundary: + break + + right_context += [t.text for t in sentence.tokens] + if len(right_context) > context_length: + right_context = right_context[: context_length] + break + return right_context + + def to_tagged_string(self, main_tag=None) -> str: list = [] for token in self.tokens: @@ -1140,12 +1177,12 @@ def is_in_memory(self) -> bool: class Corpus: def __init__( - self, - train: Dataset = None, - dev: Dataset = None, - test: Dataset = None, - name: str = "corpus", - sample_missing_splits: Union[bool, str] = True, + self, + train: Dataset = None, + dev: Dataset = None, + test: Dataset = None, + name: str = "corpus", + sample_missing_splits: Union[bool, str] = True, ): # set name self.name: str = name @@ -1184,11 +1221,11 @@ def test(self) -> Optional[Dataset]: return self._test def downsample( - self, - percentage: float = 0.1, - downsample_train=True, - downsample_dev=True, - downsample_test=True, + self, + percentage: float = 0.1, + downsample_train=True, + downsample_dev=True, + downsample_test=True, ): if downsample_train and self._train is not None: diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 7962bb87ed..8883087429 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -1,13 +1,29 @@ import inspect import logging +import os +import re +import tempfile +import zipfile from abc import abstractmethod -from typing import Dict, Generic, List, Sequence, Union +from io import BytesIO +import random +from typing import Dict, Generic, List, Sequence, Union, Optional import torch from torch.nn import Parameter, ParameterList +from transformers import ( + PreTrainedTokenizer, + AutoTokenizer, + PretrainedConfig, + AutoConfig, + AutoModel, + TransfoXLModel, + XLNetModel, + CONFIG_MAPPING, +) import flair -from flair.data import DT +from flair.data import DT, Sentence, Token log = logging.getLogger("flair") @@ -149,3 +165,595 @@ def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor: for weight, tensor in zip(normed_weights, tensors): pieces.append(weight * tensor) return self.gamma * sum(pieces) + + +class TransformerEmbedding(Embeddings[Sentence]): + NO_MAX_SEQ_LENGTH_MODELS = (XLNetModel, TransfoXLModel) + + def __init__( + self, + model: str = "bert-base-uncased", + fine_tune: bool = True, + layers: str = "all", + layer_mean: bool = True, + subtoken_pooling: str = "first", + cls_pooling: str = "cls", + is_token_embedding: bool = True, + is_document_embedding: bool = True, + allow_long_sentences: bool = False, + 
use_context: Union[bool, int] = False, + respect_document_boundaries: bool = True, + context_dropout: float = 0.5, + saved_config: Optional[PretrainedConfig] = None, + tokenizer_data: Optional[BytesIO] = None, + name: Optional[str] = None, + **kwargs, + ): + self.instance_parameters = self.get_instance_parameters(locals=locals()) + super().__init__() + # temporary fix to disable tokenizer parallelism warning + # (see https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning) + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # do not print transformer warnings as these are confusing in this case + from transformers import logging + + logging.set_verbosity_error() + + if tokenizer_data is None: + # load tokenizer and transformer model + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model, **kwargs) + else: + # load tokenizer from inmemory zip-file + self.tokenizer = self._tokenizer_from_bytes(tokenizer_data) + + if saved_config is None: + config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs) + self.model = AutoModel.from_pretrained(model, config=config) + else: + self.model = AutoModel.from_config(saved_config) + + self.truncate = True + + if isinstance(self.model, self.NO_MAX_SEQ_LENGTH_MODELS): + allow_long_sentences = False + self.truncate = False + + self.stride = self.tokenizer.model_max_length // 2 if allow_long_sentences else 0 + self.allow_long_sentences = allow_long_sentences + self.use_lang_emb = hasattr(self.model, "use_lang_emb") and self.model.use_lang_emb + + # model name + if name is None: + self.name = "transformer-" + str(model) + else: + self.name = name + self.base_model_name = str(model) + + self.token_embedding = is_token_embedding + self.document_embedding = is_document_embedding + + if not self.token_embedding and not self.document_embedding: + raise ValueError("either 'is_token_embedding' or 'is_document_embedding' needs to be set.") + + if self.token_embedding and self.document_embedding: + log.info("Using TransformerEmbedding for both token embeddings and document embeddings is experimental") + + if self.document_embedding and cls_pooling not in ["cls", "max", "mean"]: + raise ValueError(f"Document Pooling operation `{cls_pooling}` is not defined for TransformerEmbedding") + + if self.token_embedding and subtoken_pooling not in ["first", "last", "first_last", "mean"]: + raise ValueError(f"Subtoken Pooling operation `{subtoken_pooling}` is not defined for TransformerEmbedding") + + if self.document_embedding and cls_pooling == "cls" and allow_long_sentences: + log.warning( + "Using long sentences for Document embeddings is only beneficial for cls_pooling types 'mean' and 'max " + ) + + if isinstance(use_context, bool): + self.context_length: int = 64 if use_context else 0 + else: + self.context_length = use_context + + self.context_dropout = context_dropout + self.respect_document_boundaries = respect_document_boundaries + + self.to(flair.device) + + # embedding parameters + if layers == "all": + # send mini-token through to check how many layers the model has + hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1] + self.layer_indexes = [int(x) for x in range(len(hidden_states))] + else: + self.layer_indexes = [int(x) for x in layers.split(",")] + + self.cls_pooling = cls_pooling + self.subtoken_pooling = subtoken_pooling + + self.layer_mean = layer_mean + self.fine_tune = fine_tune + self.static_embeddings = not self.fine_tune + + # return 
length
+        self.embedding_length_internal = self._calculate_embedding_length()
+
+        self.special_tokens = []
+        # check if special tokens exist to circumvent error message
+        if self.tokenizer._bos_token:
+            self.special_tokens.append(self.tokenizer.bos_token)
+        if self.tokenizer._cls_token:
+            self.special_tokens.append(self.tokenizer.cls_token)
+
+        # most models have an initial BOS token, except for XLNet, T5 and GPT2
+        self.begin_offset = self._get_begin_offset_of_tokenizer()
+        self.initial_cls_token: bool = self._has_initial_cls_token()
+
+        # when initializing, embeddings are in eval mode by default
+        self.eval()
+
+    @property
+    def embedding_length(self) -> int:
+        if not hasattr(self, "embedding_length_internal"):
+            self.embedding_length_internal = self._calculate_embedding_length()
+
+        return self.embedding_length_internal
+
+    def _has_initial_cls_token(self) -> bool:
+        # most models have CLS token as last token (GPT-1, GPT-2, TransfoXL, XLNet, XLM), but BERT is initial
+        tokens = self.tokenizer.encode("a")
+        return tokens[0] == self.tokenizer.cls_token_id
+
+    def _get_begin_offset_of_tokenizer(self) -> int:
+        test_string = "a"
+        tokens = self.tokenizer.encode(test_string)
+        begin_offset = 0
+
+        for begin_offset, token in enumerate(tokens):
+            if (
+                self.tokenizer.decode([token]) == test_string
+                or self.tokenizer.decode([token]) == self.tokenizer.unk_token
+            ):
+                break
+        return begin_offset
+
+    def _calculate_embedding_length(self) -> int:
+        if not self.layer_mean:
+            length = len(self.layer_indexes) * self.model.config.hidden_size
+        else:
+            length = self.model.config.hidden_size
+
+        # in case of doubt: token embedding has higher priority than document embedding
+        if self.token_embedding and self.subtoken_pooling == "first_last":
+            length *= 2
+            if self.document_embedding:
+                log.warning(
+                    "Token embedding length and Document embedding length vary, due to `first_last` subtoken pooling, this might not be supported"
+                )
+        return length
+
+    @property
+    def embedding_type(self) -> str:
+        # in case of doubt: token embedding has higher priority than document embedding
+        return "word-level" if self.token_embedding else "sentence-level"
+
+    def _tokenizer_from_bytes(self, zip_data: BytesIO) -> PreTrainedTokenizer:
+        zip_obj = zipfile.ZipFile(zip_data)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            zip_obj.extractall(temp_dir)
+            return AutoTokenizer.from_pretrained(temp_dir)
+
+    def _tokenizer_bytes(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            files = self.tokenizer.save_pretrained(temp_dir)
+            zip_data = BytesIO()
+            zip_obj = zipfile.ZipFile(zip_data, "w+")
+            for f in files:
+                zip_obj.write(f, os.path.relpath(f, temp_dir))
+
+        zip_data.seek(0)
+        return zip_data
+
+    @staticmethod
+    def _remove_special_markup(text: str):
+        # remove special markup
+        text = re.sub("^Ġ", "", text)  # RoBERTa models
+        text = re.sub("^##", "", text)  # BERT models
+        text = re.sub("^▁", "", text)  # XLNet models
+        text = re.sub("</w>$", "", text)  # XLM models
+        return text
+
+    def _get_processed_token_text(self, token: Token) -> str:
+        pieces = self.tokenizer.tokenize(token.text)
+        token_text = ""
+        for piece in pieces:
+            token_text += self._remove_special_markup(piece)
+        token_text = token_text.lower()
+        return token_text
+
+    def __getstate__(self):
+        config_dict = self.model.config.to_dict()
+
+        # not necessary when loaded via model, but we keep it for now.
+ model_state_dict = self.model.state_dict() + + tokenizer_data = self._tokenizer_bytes() + + model_state = { + "model": self.base_model_name, + "fine_tune": self.fine_tune, + "layers": ",".join(map(str, self.layer_indexes)), + "layer_mean": self.layer_mean, + "subtoken_pooling": self.subtoken_pooling, + "cls_pooling": self.cls_pooling, + "is_token_embedding": self.token_embedding, + "is_document_embedding": self.document_embedding, + "allow_long_sentences": self.allow_long_sentences, + "config_state_dict": config_dict, + "tokenizer_data": tokenizer_data, + "name": self.name, + "model_state_dict": model_state_dict, + "context_length": self.context_length, + "respect_document_boundaries": self.respect_document_boundaries, + "context_dropout": self.context_dropout, + } + + return model_state + + def __setstate__(self, state): + config_state_dict = state.pop("config_state_dict", None) + model_state_dict = state.pop("model_state_dict", None) + + if "base_model_name" in state: + state["model"] = state.pop("base_model_name") + + state["use_context"] = state.pop("context_length", False) + + if "layer_indexes" in state: + layer_indexes = state.pop("layer_indexes") + state["layers"] = ",".join(map(str, layer_indexes)) + + if "use_scalar_mix" in state: + # legacy Flair <= 0.7 + state["layer_mean"] = state.pop("use_scalar_mix") + + if "is_token_embedding" not in state: + # legacy TransformerTokenEmbedding + state["is_token_embedding"] = "pooling_operation" in state + + if "is_document_embedding" not in state: + # Legacy TransformerDocumentEmbedding + state["is_document_embedding"] = "pooling" in state + + if "pooling_operation" in state: + # legacy TransformerTokenEmbedding + state["subtoken_pooling"] = state.pop("pooling_operation") + + if "cls_operation" in state: + # legacy TransformerDocumentEmbedding + state["cls_pooling"] = state.pop("pooling") + + if config_state_dict: + model_type = config_state_dict.get("model_type", "bert") + config_class = CONFIG_MAPPING[model_type] + config = config_class.from_dict(config_state_dict) + else: + config = None + + if "embedding_length_internal" in state: + del state["embedding_length_internal"] + + embedding = self.create_from_state(saved_config=config, **state) + + # copy values from new embedding + for key in embedding.__dict__.keys(): + self.__dict__[key] = embedding.__dict__[key] + + if model_state_dict: + self.load_state_dict(model_state_dict) + + @classmethod + def create_from_state(cls, **state): + return cls(**state) + + def _reconstruct_tokens_from_subtokens(self, sentence, subtokens): + word_iterator = iter(sentence) + token = next(word_iterator) + token_text = self._get_processed_token_text(token) + token_subtoken_lengths = [] + reconstructed_token = "" + subtoken_count = 0 + # iterate over subtokens and reconstruct tokens + for subtoken_id, subtoken in enumerate(subtokens): + + # remove special markup + subtoken = self._remove_special_markup(subtoken) + + # TODO check if this is necessary is this method is called before prepare_for_model + # check if reconstructed token is special begin token ([CLS] or similar) + if subtoken in self.special_tokens and subtoken_id == 0: + continue + + # some BERT tokenizers somehow omit words - in such cases skip to next token + if subtoken_count == 0 and not token_text.startswith(subtoken.lower()): + + while True: + token_subtoken_lengths.append(0) + token = next(word_iterator) + token_text = self._get_processed_token_text(token) + if token_text.startswith(subtoken.lower()): + break + + subtoken_count += 1 + + # 
append subtoken to reconstruct token + reconstructed_token = reconstructed_token + subtoken + + # check if reconstructed token is the same as current token + if reconstructed_token.lower() == token_text: + + # if so, add subtoken count + token_subtoken_lengths.append(subtoken_count) + + # reset subtoken count and reconstructed token + reconstructed_token = "" + subtoken_count = 0 + + # break from loop if all tokens are accounted for + if len(token_subtoken_lengths) < len(sentence): + token = next(word_iterator) + token_text = self._get_processed_token_text(token) + else: + break + + # if tokens are unaccounted for + while len(token_subtoken_lengths) < len(sentence) and len(token.text) == 1: + token_subtoken_lengths.append(0) + if len(token_subtoken_lengths) == len(sentence): + break + token = next(word_iterator) + + # check if all tokens were matched to subtokens + if token != sentence[-1]: + log.error(f"Tokenization MISMATCH in sentence '{sentence.to_tokenized_string()}'") + log.error(f"Last matched: '{token}'") + log.error(f"Last sentence: '{sentence[-1]}'") + log.error(f"subtokenized: '{subtokens}'") + return token_subtoken_lengths + + def _gather_tokenized_strings(self, sentences: List[Sentence]): + tokenized_sentences = [] + all_token_subtoken_lengths = [] + for sentence in sentences: + + # subtokenize the sentence + tokenized_string = sentence.to_tokenized_string() + + # transformer specific tokenization + subtokenized_sentence = self.tokenizer.tokenize(tokenized_string) + + # set zero embeddings for empty sentences and exclude + if len(subtokenized_sentence) == 0: + if self.token_embedding: + for token in sentence: + token.set_embedding(self.name, torch.zeros(self.embedding_length)) + if self.document_embedding: + sentence.set_embedding(self.name, torch.zeros(self.embedding_length)) + continue + + if self.token_embedding: + # determine into how many subtokens each token is split + all_token_subtoken_lengths.append( + self._reconstruct_tokens_from_subtokens(sentence, subtokenized_sentence) + ) + + # remember tokenized sentences and their subtokenization + tokenized_sentences.append(tokenized_string) + return tokenized_sentences, all_token_subtoken_lengths + + def _build_transformer_model_inputs(self, batch_encoding, tokenized_sentences): + model_kwargs = {} + input_ids = batch_encoding["input_ids"].to(flair.device) + + # Models such as FNet do not have an attention_mask + if "attention_mask" in batch_encoding: + model_kwargs["attention_mask"] = batch_encoding["attention_mask"].to(flair.device) + + # set language IDs for XLM-style transformers + if self.use_lang_emb: + model_kwargs["langs"] = torch.zeros_like(input_ids, dtype=input_ids.dtype) + for s_id, sentence_text in enumerate(tokenized_sentences): + sequence_length = len(sentence_text) + lang_id = self.tokenizer.lang2id.get(sentences[s_id].get_language_code(), 0) # type: ignore + model_kwargs["langs"][s_id][:sequence_length] = lang_id + return input_ids, model_kwargs + + def _combine_strided_sentences( + self, hidden_states: torch.Tensor, sentence_parts_lengths: torch.Tensor + ) -> List[torch.Tensor]: + sentence_idx_offset = 0 + sentence_hidden_states = [] + for nr_sentence_parts in sentence_parts_lengths: + sentence_hidden_state = hidden_states[:, sentence_idx_offset, ...] + sentence_idx_offset += 1 + + for i in range(1, nr_sentence_parts): + remainder_sentence_hidden_state = hidden_states[:, sentence_idx_offset, ...] 
+ sentence_idx_offset += 1 + sentence_hidden_state = torch.cat( + ( + sentence_hidden_state[:, : -1 - self.stride // 2, :], + remainder_sentence_hidden_state[:, 1 + self.stride // 2:, :], + ), + 1, + ) + sentence_hidden_states.append(sentence_hidden_state) + return sentence_hidden_states + + def _try_document_embedding_shortcut(self, hidden_states, sentences): + # cls first pooling can be done without recreating sentence hidden states + if ( + self.document_embedding + and not self.token_embedding + and self.cls_pooling == "cls" + and self.initial_cls_token + ): + embeddings_all_document_layers = hidden_states[:, :, 0] + if self.layer_mean: + document_embs = torch.mean(embeddings_all_document_layers, dim=0) + else: + document_embs = embeddings_all_document_layers.view(-1, embeddings_all_document_layers.size[-1]) + for (document_emb, sentence) in zip(document_embs, sentences): + sentence.set_embedding(self.name, document_emb) + return True + return False + + def _extract_document_embeddings(self, sentence_hidden_states, sentences): + for sentence_hidden_state, sentence in zip(sentence_hidden_states, sentences): + if self.cls_pooling == "cls": + index_of_cls_token = 0 if self.initial_cls_token else -1 + embedding_all_document_layers = sentence_hidden_state[:, index_of_cls_token, :] + elif self.cls_pooling == "mean": + embedding_all_document_layers = sentence_hidden_state.mean(dim=2) + elif self.cls_pooling == "max": + _, embedding_all_document_layers = sentence_hidden_state.max(dim=2) + else: + raise ValueError(f"cls pooling method: `{self.cls_pooling}` is not implemented") + if self.layer_mean: + document_emb = torch.mean(embedding_all_document_layers, dim=0) + else: + document_emb = embedding_all_document_layers.view(-1) + sentence.set_embedding(self.name, document_emb) + + def _extract_token_embeddings(self, sentence_hidden_states, sentences, all_token_subtoken_lengths): + for sentence_hidden_state, sentence, subtoken_lengths in zip( + sentence_hidden_states, sentences, all_token_subtoken_lengths + ): + subword_start_idx = self.begin_offset + + for token, n_subtokens in zip(sentence, subtoken_lengths): + if n_subtokens == 0: + token.set_embedding(self.name, torch.zeros(self.embedding_length)) + continue + subword_end_idx = subword_start_idx + n_subtokens + + current_embeddings = sentence_hidden_state[:, subword_start_idx:subword_end_idx] + subword_start_idx = subword_end_idx + + if self.subtoken_pooling == "first": + final_embedding = current_embeddings[:, 0] + elif self.subtoken_pooling == "last": + final_embedding = current_embeddings[:, -1] + elif self.subtoken_pooling == "first_last": + final_embedding = torch.cat([current_embeddings[:, 0], current_embeddings[:, -1]], dim=1) + elif self.subtoken_pooling == "mean": + final_embedding = current_embeddings.mean(dim=1) + else: + raise ValueError(f"subtoken pooling method: `{self.subtoken_pooling}` is not implemented") + + token.set_embedding(self.name, final_embedding) + + def _add_embeddings_to_sentences(self, sentences: List[Sentence]): + tokenized_sentences, all_token_subtoken_lengths = self._gather_tokenized_strings(sentences) + + # encode inputs + batch_encoding = self.tokenizer( + tokenized_sentences, + stride=self.stride, + return_overflowing_tokens=self.allow_long_sentences, + truncation=self.truncate, + padding=True, + return_tensors="pt", + ) + + input_ids, model_kwargs = self._build_transformer_model_inputs(batch_encoding, tokenized_sentences) + + hidden_states = self.model(input_ids, **model_kwargs)[-1] + + # make the tuple a 
tensor; makes working with it easier. + hidden_states = torch.stack(hidden_states) + + # only use layers that will be outputted + hidden_states = hidden_states[self.layer_indexes, :, :] + + gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad() + + with gradient_context: + + if self._try_document_embedding_shortcut(hidden_states, sentences): + return + + if self.allow_long_sentences: + sentence_hidden_states = self._combine_strided_sentences( + hidden_states, + sentence_parts_lengths=torch.unique( + batch_encoding["overflow_to_sample_mapping"], + return_counts=True, + sorted=True, + )[1].tolist(), + ) + else: + sentence_hidden_states = hidden_states.tolist() + + # remove padding tokens + sentence_hidden_states = [ + sentence_hidden_state[:, : len(subtokens), :] + for (subtokens, sentence_hidden_state) in zip(tokenized_sentences, sentence_hidden_states) + ] + + if self.document_embedding: + self._extract_document_embedding(sentence_hidden_states, sentences) + + if self.token_embedding: + self._extract_token_embeddings(sentence_hidden_states, sentences, all_token_subtoken_lengths) + + def _expand_sentence_with_context(self, sentence): + expand_context = not self.training or random.randint(1, 100) > (self.context_dropout * 100) + + left_context = [] + right_context = [] + + if expand_context: + left_context = sentence.left_context(self.context_length, self.respect_document_boundaries) + right_context = sentence.right_context(self.context_length, self.respect_document_boundaries) + + expanded_sentence = Sentence(left_context + [t.text for t in sentence.tokens] + right_context) + + context_length = len(left_context) + return expanded_sentence, context_length + + def _add_embeddings_internal(self, sentences: List[Sentence]): + expanded_sentences = [] + context_offsets = [] + + if self.context_length > 0: + # set context if not set already + previous_sentence = None + for sentence in sentences: + if sentence.is_context_set(): + continue + sentence._previous_sentence = previous_sentence + sentence._next_sentence = None + if previous_sentence: + previous_sentence._next_sentence = sentence + previous_sentence = sentence + + for sentence in sentences: + # create expanded sentence and remember context offsets + expanded_sentence, context_offset = self._expand_sentence_with_context(sentence) + expanded_sentences.append(expanded_sentence) + context_offsets.append(context_offset) + else: + expanded_sentences.extend(sentences) + + self._add_embeddings_to_sentences(expanded_sentences) + + # move embeddings from context back to original sentence (if using context) + if self.context_length > 0: + for original_sentence, expanded_sentence, context_offset in zip( + sentences, expanded_sentences, context_offsets + ): + for token_idx, token in enumerate(original_sentence): + token.set_embedding( + self.name, + expanded_sentence[token_idx + context_offset].get_embedding(self.name), + ) diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py index 15fed61a6d..3bfcfbb2cb 100644 --- a/flair/embeddings/document.py +++ b/flair/embeddings/document.py @@ -15,7 +15,7 @@ import flair from flair.data import Sentence -from flair.embeddings.base import Embeddings, ScalarMix +from flair.embeddings.base import Embeddings, ScalarMix, TransformerEmbedding from flair.embeddings.token import FlairEmbeddings, StackedEmbeddings, TokenEmbeddings from flair.nn import LockedDropout, WordDropout @@ -30,260 +30,40 @@ def embedding_type(self) -> str: return "sentence-level" -class 
TransformerDocumentEmbeddings(DocumentEmbeddings): +class TransformerDocumentEmbeddings(DocumentEmbeddings, TransformerEmbedding): + def __init__( - self, - model: str = "bert-base-uncased", - fine_tune: bool = True, - layers: str = "-1", - layer_mean: bool = False, - pooling: str = "cls", - **kwargs, + self, + model: str = "bert-base-uncased", # set parameters with different default values + layers: str = "-1", + layer_mean: bool = False, + is_token_embedding: bool = False, + **kwargs, ): """ Bidirectional transformer embeddings of words from various transformer architectures. :param model: name of transformer model (see https://huggingface.co/transformers/pretrained_models.html for options) - :param fine_tune: If True, allows transformers to be fine-tuned during training - :param batch_size: How many sentence to push through transformer at once. Set to 1 by default since transformer - models tend to be huge. :param layers: string indicating which layers to take for embedding (-1 is topmost layer) + :param cls_pooling: Pooling strategy for combining token level embeddings. options are 'cls', 'max', 'mean'. :param layer_mean: If True, uses a scalar mix of layers as embedding - :param pooling: Pooling strategy for combining token level embeddings. options are 'cls', 'max', 'mean'. + :param fine_tune: If True, allows transformers to be fine-tuned during training """ super().__init__() - - if pooling not in ["cls", "max", "mean"]: - raise ValueError(f"Pooling operation `{pooling}` is not defined for TransformerDocumentEmbeddings") - - # temporary fix to disable tokenizer parallelism warning - # (see https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning) - import os - - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - # do not print transformer warnings as these are confusing in this case - from transformers import logging - - logging.set_verbosity_error() - - # load tokenizer and transformer model - self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model, **kwargs) - if self.tokenizer.model_max_length > 1000000000: - self.tokenizer.model_max_length = 512 - log.info( - "No model_max_length in Tokenizer's config.json - setting it to 512. " - "Specify desired model_max_length by passing it as attribute to embedding instance." 
- ) - if "config" not in kwargs: - config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs) - self.model = AutoModel.from_pretrained(model, config=config) - else: - self.model = AutoModel.from_pretrained(None, **kwargs) - - logging.set_verbosity_warning() - - # model name - self.name = "transformer-document-" + str(model) - self.base_model_name = str(model) - - # when initializing, embeddings are in eval mode by default - self.model.eval() - self.model.to(flair.device) - - # embedding parameters - if layers == "all": - # send mini-token through to check how many layers the model has - hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1] - self.layer_indexes = [int(x) for x in range(len(hidden_states))] - else: - self.layer_indexes = [int(x) for x in layers.split(",")] - - self.layer_mean = layer_mean - self.fine_tune = fine_tune - self.static_embeddings = not self.fine_tune - self.pooling = pooling - - # check whether CLS is at beginning or end - self.initial_cls_token: bool = self._has_initial_cls_token(tokenizer=self.tokenizer) - - @staticmethod - def _has_initial_cls_token(tokenizer: PreTrainedTokenizer) -> bool: - # most models have CLS token as last token (GPT-1, GPT-2, TransfoXL, XLNet, XLM), but BERT is initial - tokens = tokenizer.encode("a") - initial_cls_token: bool = False - if tokens[0] == tokenizer.cls_token_id: - initial_cls_token = True - return initial_cls_token - - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: - """Add embeddings to all words in a list of sentences.""" - - # gradients are enabled if fine-tuning is enabled - gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad() - - with gradient_context: - - # first, subtokenize each sentence and find out into how many subtokens each token was divided - subtokenized_sentences = [] - - # subtokenize sentences - for sentence in sentences: - # tokenize and truncate to max subtokens (TODO: check better truncation strategies) - subtokenized_sentence = self.tokenizer.encode( - sentence.to_tokenized_string(), - add_special_tokens=True, - max_length=self.tokenizer.model_max_length, - truncation=True, - ) - - subtokenized_sentences.append( - torch.tensor(subtokenized_sentence, dtype=torch.long, device=flair.device) - ) - - # find longest sentence in batch - longest_sequence_in_batch: int = len(max(subtokenized_sentences, key=len)) - - # initialize batch tensors and mask - input_ids = torch.zeros( - [len(sentences), longest_sequence_in_batch], - dtype=torch.long, - device=flair.device, - ) - mask = torch.zeros( - [len(sentences), longest_sequence_in_batch], - dtype=torch.long, - device=flair.device, - ) - for s_id, sentence_embedding in enumerate(subtokenized_sentences): - sequence_length = len(sentence_embedding) - input_ids[s_id][:sequence_length] = sentence_embedding - mask[s_id][:sequence_length] = torch.ones(sequence_length) - - # put encoded batch through transformer model to get all hidden states of all encoder layers - hidden_states = ( - self.model(input_ids, attention_mask=mask)[-1] if len(sentences) > 1 else self.model(input_ids)[-1] - ) - - # iterate over all subtokenized sentences - for sentence_idx, (sentence, subtokens) in enumerate(zip(sentences, subtokenized_sentences)): - - if self.pooling == "cls": - index_of_CLS_token = 0 if self.initial_cls_token else len(subtokens) - 1 - - cls_embeddings_all_layers: List[torch.Tensor] = [ - hidden_states[layer][sentence_idx][index_of_CLS_token] for 
layer in self.layer_indexes - ] - - embeddings_all_layers = cls_embeddings_all_layers - - elif self.pooling == "mean": - mean_embeddings_all_layers: List[torch.Tensor] = [ - torch.mean( - hidden_states[layer][sentence_idx][: len(subtokens), :], - dim=0, - ) - for layer in self.layer_indexes - ] - - embeddings_all_layers = mean_embeddings_all_layers - - elif self.pooling == "max": - max_embeddings_all_layers: List[torch.Tensor] = [ - torch.max( - hidden_states[layer][sentence_idx][: len(subtokens), :], - dim=0, - )[0] - for layer in self.layer_indexes - ] - - embeddings_all_layers = max_embeddings_all_layers - - # use scalar mix of embeddings if so selected - if self.layer_mean: - sm = ScalarMix(mixture_size=len(embeddings_all_layers)) - sm_embeddings = sm(embeddings_all_layers) - - embeddings_all_layers = [sm_embeddings] - - # set the extracted embedding for the token - sentence.set_embedding(self.name, torch.cat(embeddings_all_layers)) - - return sentences - - @property - def embedding_length(self) -> int: - """Returns the length of the embedding vector.""" - return ( - len(self.layer_indexes) * self.model.config.hidden_size - if not self.layer_mean - else self.model.config.hidden_size - ) - - def __getstate__(self): - # special handling for serializing transformer models - config_state_dict = self.model.config.__dict__ - model_state_dict = self.model.state_dict() - - if not hasattr(self, "base_model_name"): - self.base_model_name = self.name.split("transformer-document-")[-1] - - # serialize the transformer models and the constructor arguments (but nothing else) - model_state = { - "config_state_dict": config_state_dict, - "model_state_dict": model_state_dict, - "embedding_length_internal": self.embedding_length, - "base_model_name": self.base_model_name, - "fine_tune": self.fine_tune, - "layer_indexes": self.layer_indexes, - "layer_mean": self.layer_mean, - "pooling": self.pooling, - } - - return model_state - - def __setstate__(self, d): - self.__dict__ = d - - # necessary for reverse compatibility with Flair <= 0.7 - if "use_scalar_mix" in self.__dict__.keys(): - self.__dict__["layer_mean"] = d["use_scalar_mix"] - - # special handling for deserializing transformer models - if "config_state_dict" in d: - - # load transformer model - model_type = d["config_state_dict"]["model_type"] if "model_type" in d["config_state_dict"] else "bert" - config_class = CONFIG_MAPPING[model_type] - loaded_config = config_class.from_dict(d["config_state_dict"]) - - # constructor arguments - layers = ",".join([str(idx) for idx in self.__dict__["layer_indexes"]]) - - # re-initialize transformer word embeddings with constructor arguments - embedding = TransformerDocumentEmbeddings( - model=self.__dict__["base_model_name"], - fine_tune=self.__dict__["fine_tune"], - layers=layers, - layer_mean=self.__dict__["layer_mean"], - config=loaded_config, - state_dict=d["model_state_dict"], - pooling=self.__dict__["pooling"] if "pooling" in self.__dict__ else "cls", - # for backward compatibility with previous models - ) - - # I have no idea why this is necessary, but otherwise it doesn't work - for key in embedding.__dict__.keys(): - self.__dict__[key] = embedding.__dict__[key] - - else: - model_name = self.__dict__["name"].split("transformer-document-")[-1] - # reload tokenizer to get around serialization issues - try: - tokenizer = AutoTokenizer.from_pretrained(model_name) - except: # noqa: E722 TODO: figure out possible exceptions - pass - self.tokenizer = tokenizer + TransformerEmbedding.__init__( + self, + 
model=model, + layers=layers, + layer_mean=layer_mean, + is_token_embedding=is_token_embedding, + is_document_embedding=True, + **kwargs) + + @classmethod + def create_from_state(cls, **state): + # this parameter is fixed + del state["is_document_embedding"] + return cls(**state) class DocumentPoolEmbeddings(DocumentEmbeddings): diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py index 0d81be0c13..f50f7b1b50 100644 --- a/flair/embeddings/token.py +++ b/flair/embeddings/token.py @@ -3,6 +3,7 @@ import os import re from collections import Counter +from io import BytesIO from pathlib import Path from typing import Dict, List, Optional, Sequence, Union @@ -19,12 +20,12 @@ AutoTokenizer, PreTrainedTokenizer, TransfoXLModel, - XLNetModel, + XLNetModel, PretrainedConfig, ) import flair from flair.data import Corpus, Dictionary, Sentence, Token, _iter_dataset -from flair.embeddings.base import Embeddings +from flair.embeddings.base import Embeddings, TransformerEmbedding from flair.file_utils import cached_path, instance_lru_cache, open_inside_zip log = logging.getLogger("flair") @@ -116,12 +117,12 @@ class WordEmbeddings(TokenEmbeddings): """Standard static word embeddings, such as GloVe or FastText.""" def __init__( - self, - embeddings: str, - field: str = None, - fine_tune: bool = False, - force_cpu: bool = True, - stable: bool = False, + self, + embeddings: str, + field: str = None, + fine_tune: bool = False, + force_cpu: bool = True, + stable: bool = False, ): """ Initializes classic word embeddings. Constructor downloads required files if not there. @@ -374,10 +375,10 @@ class CharacterEmbeddings(TokenEmbeddings): """Character embeddings of words, as proposed in Lample et al., 2016.""" def __init__( - self, - path_to_char_dict: str = None, - char_embedding_dim: int = 25, - hidden_size_char: int = 25, + self, + path_to_char_dict: str = None, + char_embedding_dim: int = 25, + hidden_size_char: int = 25, ): """Uses the default character dictionary if none provided.""" @@ -473,13 +474,13 @@ class FlairEmbeddings(TokenEmbeddings): """Contextual string embeddings of words, as proposed in Akbik et al., 2018.""" def __init__( - self, - model, - fine_tune: bool = False, - chars_per_chunk: int = 512, - with_whitespace: bool = True, - tokenized_lm: bool = True, - is_lower: bool = False, + self, + model, + fine_tune: bool = False, + chars_per_chunk: int = 512, + with_whitespace: bool = True, + tokenized_lm: bool = True, + is_lower: bool = False, ): """ initializes contextual string embeddings using a character-level language model. 
@@ -788,11 +789,11 @@ def __str__(self): class PooledFlairEmbeddings(TokenEmbeddings): def __init__( - self, - contextual_embeddings: Union[str, FlairEmbeddings], - pooling: str = "min", - only_capitalized: bool = False, - **kwargs, + self, + contextual_embeddings: Union[str, FlairEmbeddings], + pooling: str = "min", + only_capitalized: bool = False, + **kwargs, ): super().__init__() @@ -893,22 +894,13 @@ def __setstate__(self, d): self.word_embeddings[key] = self.word_embeddings[key].cpu() -class TransformerWordEmbeddings(TokenEmbeddings): - NO_MAX_SEQ_LENGTH_MODELS = [XLNetModel, TransfoXLModel] +class TransformerWordEmbeddings(TokenEmbeddings, TransformerEmbedding): def __init__( - self, - model: str = "bert-base-uncased", - layers: str = "all", - subtoken_pooling: str = "first", - layer_mean: bool = True, - fine_tune: bool = False, - allow_long_sentences: bool = True, - use_context: Union[bool, int] = False, - memory_effective_training: bool = True, - respect_document_boundaries: bool = True, - context_dropout: float = 0.5, - **kwargs, + self, + model: str = "bert-base-uncased", + is_document_embedding: bool = False, + **kwargs, ): """ Bidirectional transformer embeddings of words from various transformer architectures. @@ -921,564 +913,18 @@ def __init__( :param fine_tune: If True, allows transformers to be fine-tuned during training """ super().__init__() - self.instance_parameters = self.get_instance_parameters(locals=locals()) - - # temporary fix to disable tokenizer parallelism warning - # (see https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning) - import os - - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - # do not print transformer warnings as these are confusing in this case - from transformers import logging - - logging.set_verbosity_error() - - # load tokenizer and transformer model - self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model, **kwargs) - if self.tokenizer.model_max_length > 1000000000: - self.tokenizer.model_max_length = 512 - log.info( - "No model_max_length in Tokenizer's config.json - setting it to 512. " - "Specify desired model_max_length by passing it as attribute to embedding instance." - ) - if "config" not in kwargs: - config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs) - self.model = AutoModel.from_pretrained(model, config=config) - else: - self.model = AutoModel.from_pretrained(None, **kwargs) - - logging.set_verbosity_warning() - - if type(self.model) not in self.NO_MAX_SEQ_LENGTH_MODELS: - self.allow_long_sentences = allow_long_sentences - self.truncate = True - self.stride = self.tokenizer.model_max_length // 2 if allow_long_sentences else 0 - else: - # in the end, these models don't need this configuration - self.allow_long_sentences = False - self.truncate = False - self.stride = 0 - - self.use_lang_emb = hasattr(self.model, "use_lang_emb") and self.model.use_lang_emb - - # model name - self.name = "transformer-word-" + str(model) - self.base_model = str(model) - - # whether to detach gradients on overlong sentences - self.memory_effective_training = memory_effective_training - - # store whether to use context (and how much) - if isinstance(use_context, bool): - self.context_length: int = 64 if use_context else 0 - else: - self.context_length = use_context - - # dropout contexts - self.context_dropout = context_dropout - - # if using context, can we cross document boundaries? 
- self.respect_document_boundaries = respect_document_boundaries - - # send self to flair-device - self.to(flair.device) - - # embedding parameters - if layers == "all": - # send mini-token through to check how many layers the model has - hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1] - self.layer_indexes = [int(x) for x in range(len(hidden_states))] - else: - self.layer_indexes = [int(x) for x in layers.split(",")] - - self.pooling_operation = subtoken_pooling - self.layer_mean = layer_mean - self.fine_tune = fine_tune - self.static_embeddings = not self.fine_tune - - # calculate embedding length - if not self.layer_mean: - length = len(self.layer_indexes) * self.model.config.hidden_size - else: - length = self.model.config.hidden_size - if self.pooling_operation == "first_last": - length *= 2 - - # return length - self.embedding_length_internal = length - - self.special_tokens = [] - # check if special tokens exist to circumvent error message - if self.tokenizer._bos_token: - self.special_tokens.append(self.tokenizer.bos_token) - if self.tokenizer._cls_token: - self.special_tokens.append(self.tokenizer.cls_token) - - # most models have an intial BOS token, except for XLNet, T5 and GPT2 - self.begin_offset = self._get_begin_offset_of_tokenizer(tokenizer=self.tokenizer) - - # when initializing, embeddings are in eval mode by default - self.eval() - - @staticmethod - def _get_begin_offset_of_tokenizer(tokenizer: PreTrainedTokenizer) -> int: - test_string = "a" - tokens = tokenizer.encode(test_string) - - for begin_offset, token in enumerate(tokens): - if tokenizer.decode([token]) == test_string or tokenizer.decode([token]) == tokenizer.unk_token: - break - return begin_offset - - @staticmethod - def _remove_special_markup(text: str): - # remove special markup - text = re.sub("^Ġ", "", text) # RoBERTa models - text = re.sub("^##", "", text) # BERT models - text = re.sub("^▁", "", text) # XLNet models - text = re.sub("$", "", text) # XLM models - return text - - def _get_processed_token_text(self, token: Token) -> str: - pieces = self.tokenizer.tokenize(token.text) - token_text = "" - for piece in pieces: - token_text += self._remove_special_markup(piece) - token_text = token_text.lower() - return token_text - - def _add_embeddings_internal(self, sentences: List[Sentence]): - - # we require encoded subtokenized sentences, the mapping to original tokens and the number of - # parts that each sentence produces - all_token_subtoken_lengths = [] - - # if we also use context, first expand sentence to include context - if self.context_length > 0: - - # set context if not set already - previous_sentence = None - for sentence in sentences: - if sentence.is_context_set(): - continue - sentence._previous_sentence = previous_sentence - sentence._next_sentence = None - if previous_sentence: - previous_sentence._next_sentence = sentence - previous_sentence = sentence - - original_sentences = [] - expanded_sentences = [] - context_offsets = [] - - for sentence in sentences: - # in case of contextualization, we must remember non-expanded sentence - original_sentence = sentence - original_sentences.append(original_sentence) - - # create expanded sentence and remember context offsets - expanded_sentence, context_offset = self._expand_sentence_with_context(sentence) - expanded_sentences.append(expanded_sentence) - context_offsets.append(context_offset) - - # overwrite sentence with expanded sentence - sentence = expanded_sentence - - sentences = expanded_sentences - - 
tokenized_sentences = [] - for sentence in sentences: - - # subtokenize the sentence - tokenized_string = sentence.to_tokenized_string() - - # transformer specific tokenization - subtokenized_sentence = self.tokenizer.tokenize(tokenized_string) - - # set zero embeddings for empty sentences and exclude - if len(subtokenized_sentence) == 0: - for token in sentence: - token.set_embedding(self.name, torch.zeros(self.embedding_length)) - continue - - # determine into how many subtokens each token is split - token_subtoken_lengths = self.reconstruct_tokens_from_subtokens(sentence, subtokenized_sentence) - - # remember tokenized sentences and their subtokenization - tokenized_sentences.append(tokenized_string) - all_token_subtoken_lengths.append(token_subtoken_lengths) - - # encode inputs - batch_encoding = self.tokenizer( - tokenized_sentences, - stride=self.stride, - return_overflowing_tokens=self.allow_long_sentences, - truncation=self.truncate, - padding=True, - return_tensors="pt", - ) - - model_kwargs = {} - input_ids = batch_encoding["input_ids"].to(flair.device) - - # Models such as FNet do not have an attention_mask - if "attention_mask" in batch_encoding: - model_kwargs["attention_mask"] = batch_encoding["attention_mask"].to(flair.device) - - # determine which sentence was split into how many parts - sentence_parts_lengths = ( - torch.ones(len(tokenized_sentences), dtype=torch.int) - if not self.allow_long_sentences - else torch.unique( - batch_encoding["overflow_to_sample_mapping"], - return_counts=True, - sorted=True, - )[1].tolist() - ) - - # set language IDs for XLM-style transformers - if self.use_lang_emb: - model_kwargs["langs"] = torch.zeros_like(input_ids, dtype=input_ids.dtype) - for s_id, sentence_text in enumerate(tokenized_sentences): - sequence_length = len(sentence_text) - lang_id = self.tokenizer.lang2id.get(sentences[s_id].get_language_code(), 0) # type: ignore - model_kwargs["langs"][s_id][:sequence_length] = lang_id - - - - # gradients are enabled if fine-tuning is enabled - gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad() + TransformerEmbedding.__init__( + self, + model=model, + is_token_embedding=True, + is_document_embedding=is_document_embedding, + **kwargs) - with gradient_context: - # put encoded batch through transformer model to get all hidden states of all encoder layers - hidden_states = self.model(input_ids, **model_kwargs)[-1] - # make the tuple a tensor; makes working with it easier. - hidden_states = torch.stack(hidden_states) - - sentence_idx_offset = 0 - # iterate over all subtokenized sentences - for sentence_idx, ( - sentence, - subtoken_lengths, - nr_sentence_parts, - ) in enumerate(zip(sentences, all_token_subtoken_lengths, sentence_parts_lengths)): - - sentence_hidden_state = hidden_states[:, sentence_idx + sentence_idx_offset, ...] - - for i in range(1, nr_sentence_parts): - sentence_idx_offset += 1 - remainder_sentence_hidden_state = hidden_states[:, sentence_idx + sentence_idx_offset, ...] - # remove stride_size//2 at end of sentence_hidden_state, and half at beginning of remainder, - # in order to get some context into the embeddings of these words. - # also don't include the embedding of the extra [CLS] and [SEP] tokens. 
- sentence_hidden_state = torch.cat( - ( - sentence_hidden_state[:, : -1 - self.stride // 2, :], - remainder_sentence_hidden_state[:, 1 + self.stride // 2 :, :], - ), - 1, - ) - - subword_start_idx = self.begin_offset - - # for each token, get embedding - for token_idx, (token, number_of_subtokens) in enumerate(zip(sentence, subtoken_lengths)): - - # some tokens have no subtokens at all (if omitted by BERT tokenizer) so return zero vector - if number_of_subtokens == 0: - token.set_embedding(self.name, torch.zeros(self.embedding_length)) - continue - - subword_end_idx = subword_start_idx + number_of_subtokens - - subtoken_embeddings: List[torch.Tensor] = [] - - # get states from all selected layers, aggregate with pooling operation - for layer in self.layer_indexes: - current_embeddings = sentence_hidden_state[layer][subword_start_idx:subword_end_idx] - - if self.pooling_operation == "first": - final_embedding: torch.Tensor = current_embeddings[0] - - if self.pooling_operation == "last": - final_embedding = current_embeddings[-1] - - if self.pooling_operation == "first_last": - final_embedding = torch.cat([current_embeddings[0], current_embeddings[-1]]) - - if self.pooling_operation == "mean": - all_embeddings = [embedding.unsqueeze(0) for embedding in current_embeddings] - final_embedding = torch.mean(torch.cat(all_embeddings, dim=0), dim=0) - - subtoken_embeddings.append(final_embedding) - - # use layer mean of embeddings if so selected - if self.layer_mean and len(self.layer_indexes) > 1: - sm_embeddings = torch.mean(torch.stack(subtoken_embeddings, dim=1), dim=1) - subtoken_embeddings = [sm_embeddings] - - # set the extracted embedding for the token - token.set_embedding(self.name, torch.cat(subtoken_embeddings)) - - subword_start_idx += number_of_subtokens - - # move embeddings from context back to original sentence (if using context) - if self.context_length > 0: - for original_sentence, expanded_sentence, context_offset in zip( - original_sentences, sentences, context_offsets - ): - for token_idx, token in enumerate(original_sentence): - token.set_embedding( - self.name, - expanded_sentence[token_idx + context_offset].get_embedding(self.name), - ) - sentence = original_sentence - - def _expand_sentence_with_context(self, sentence): - - # remember original sentence - original_sentence = sentence - - import random - - expand_context = False if self.training and random.randint(1, 100) <= (self.context_dropout * 100) else True - - left_context = "" - right_context = "" - - if expand_context: - - # get left context - while True: - sentence = sentence.previous_sentence() - if sentence is None: - break - - if self.respect_document_boundaries and sentence.is_document_boundary: - break - - left_context = sentence.to_tokenized_string() + " " + left_context - left_context = left_context.strip() - if len(left_context.split(" ")) > self.context_length: - left_context = " ".join(left_context.split(" ")[-self.context_length :]) - break - original_sentence.left_context = left_context - - sentence = original_sentence - - # get right context - while True: - sentence = sentence.next_sentence() - if sentence is None: - break - if self.respect_document_boundaries and sentence.is_document_boundary: - break - - right_context += " " + sentence.to_tokenized_string() - right_context = right_context.strip() - if len(right_context.split(" ")) > self.context_length: - right_context = " ".join(right_context.split(" ")[: self.context_length]) - break - - original_sentence.right_context = right_context - - 
left_context_split = left_context.split(" ") - right_context_split = right_context.split(" ") - - # empty contexts should not introduce whitespace tokens - if left_context_split == [""]: - left_context_split = [] - if right_context_split == [""]: - right_context_split = [] - - # make expanded sentence - expanded_sentence = Sentence() - expanded_sentence.tokens = [ - Token(token) - for token in left_context_split + original_sentence.to_tokenized_string().split(" ") + right_context_split - ] - - context_length = len(left_context_split) - return expanded_sentence, context_length - - def reconstruct_tokens_from_subtokens(self, sentence, subtokens): - word_iterator = iter(sentence) - token = next(word_iterator) - token_text = self._get_processed_token_text(token) - token_subtoken_lengths = [] - reconstructed_token = "" - subtoken_count = 0 - # iterate over subtokens and reconstruct tokens - for subtoken_id, subtoken in enumerate(subtokens): - - # remove special markup - subtoken = self._remove_special_markup(subtoken) - - # TODO check if this is necessary is this method is called before prepare_for_model - # check if reconstructed token is special begin token ([CLS] or similar) - if subtoken in self.special_tokens and subtoken_id == 0: - continue - - # some BERT tokenizers somehow omit words - in such cases skip to next token - if subtoken_count == 0 and not token_text.startswith(subtoken.lower()): - - while True: - token_subtoken_lengths.append(0) - token = next(word_iterator) - token_text = self._get_processed_token_text(token) - if token_text.startswith(subtoken.lower()): - break - - subtoken_count += 1 - - # append subtoken to reconstruct token - reconstructed_token = reconstructed_token + subtoken - - # check if reconstructed token is the same as current token - if reconstructed_token.lower() == token_text: - - # if so, add subtoken count - token_subtoken_lengths.append(subtoken_count) - - # reset subtoken count and reconstructed token - reconstructed_token = "" - subtoken_count = 0 - - # break from loop if all tokens are accounted for - if len(token_subtoken_lengths) < len(sentence): - token = next(word_iterator) - token_text = self._get_processed_token_text(token) - else: - break - - # if tokens are unaccounted for - while len(token_subtoken_lengths) < len(sentence) and len(token.text) == 1: - token_subtoken_lengths.append(0) - if len(token_subtoken_lengths) == len(sentence): - break - token = next(word_iterator) - - # check if all tokens were matched to subtokens - if token != sentence[-1]: - log.error(f"Tokenization MISMATCH in sentence '{sentence.to_tokenized_string()}'") - log.error(f"Last matched: '{token}'") - log.error(f"Last sentence: '{sentence[-1]}'") - log.error(f"subtokenized: '{subtokens}'") - return token_subtoken_lengths - - @property - def embedding_length(self) -> int: - - if "embedding_length_internal" in self.__dict__.keys(): - return self.embedding_length_internal - - # """Returns the length of the embedding vector.""" - if not self.layer_mean: - length = len(self.layer_indexes) * self.model.config.hidden_size - else: - length = self.model.config.hidden_size - - if self.pooling_operation == "first_last": - length *= 2 - - self.__embedding_length = length - - return length - - def __getstate__(self): - # special handling for serializing transformer models - config_state_dict = self.model.config.__dict__ - model_state_dict = self.model.state_dict() - - if not hasattr(self, "base_model_name"): - self.base_model_name = self.name.split("transformer-word-")[-1] - - # 
serialize the transformer models and the constructor arguments (but nothing else) - model_state = { - "config_state_dict": config_state_dict, - "model_state_dict": model_state_dict, - "embedding_length_internal": self.embedding_length, - "base_model_name": self.base_model_name, - "name": self.name, - "layer_indexes": self.layer_indexes, - "subtoken_pooling": self.pooling_operation, - "context_length": self.context_length, - "layer_mean": self.layer_mean, - "fine_tune": self.fine_tune, - "allow_long_sentences": self.allow_long_sentences, - "memory_effective_training": self.memory_effective_training, - "respect_document_boundaries": self.respect_document_boundaries, - "context_dropout": self.context_dropout, - } - - return model_state - - def __setstate__(self, d): - self.__dict__ = d - - # necessary for reverse compatibility with Flair <= 0.7 - if "use_scalar_mix" in self.__dict__.keys(): - self.__dict__["layer_mean"] = d["use_scalar_mix"] - if "memory_effective_training" not in self.__dict__.keys(): - self.__dict__["memory_effective_training"] = True - if "pooling_operation" in self.__dict__.keys(): - self.__dict__["subtoken_pooling"] = d["pooling_operation"] - if "context_length" not in self.__dict__.keys(): - self.__dict__["context_length"] = 0 - if "use_context" in self.__dict__.keys(): - self.__dict__["context_length"] = 64 if self.__dict__["use_context"] else 0 - - if "context_dropout" not in self.__dict__.keys(): - self.__dict__["context_dropout"] = 0.5 - if "respect_document_boundaries" not in self.__dict__.keys(): - self.__dict__["respect_document_boundaries"] = True - if "memory_effective_training" not in self.__dict__.keys(): - self.__dict__["memory_effective_training"] = True - if "base_model_name" not in self.__dict__.keys(): - self.__dict__["base_model_name"] = self.__dict__["name"].split("transformer-word-")[-1] - - # special handling for deserializing transformer models - if "config_state_dict" in d: - - # load transformer model - model_type = d["config_state_dict"]["model_type"] if "model_type" in d["config_state_dict"] else "bert" - config_class = CONFIG_MAPPING[model_type] - loaded_config = config_class.from_dict(d["config_state_dict"]) - - # constructor arguments - layers = ",".join([str(idx) for idx in self.__dict__["layer_indexes"]]) - - # re-initialize transformer word embeddings with constructor arguments - embedding = TransformerWordEmbeddings( - model=self.__dict__["base_model_name"], - layers=layers, - subtoken_pooling=self.__dict__["subtoken_pooling"], - use_context=self.__dict__["context_length"], - layer_mean=self.__dict__["layer_mean"], - fine_tune=self.__dict__["fine_tune"], - allow_long_sentences=self.__dict__["allow_long_sentences"], - respect_document_boundaries=self.__dict__["respect_document_boundaries"], - memory_effective_training=self.__dict__["memory_effective_training"], - context_dropout=self.__dict__["context_dropout"], - config=loaded_config, - state_dict=d["model_state_dict"], - ) - - # I have no idea why this is necessary, but otherwise it doesn't work - for key in embedding.__dict__.keys(): - self.__dict__[key] = embedding.__dict__[key] - - else: - - # reload tokenizer to get around serialization issues - model_name = self.__dict__["name"].split("transformer-word-")[-1] - try: - tokenizer = AutoTokenizer.from_pretrained(model_name) - except: # noqa: E722 TODO: specify exceptions - pass - - self.tokenizer = tokenizer + @classmethod + def create_from_state(cls, **state): + # this parameter is fixed + del state["is_token_embedding"] + return 
cls(**state) class FastTextEmbeddings(TokenEmbeddings): @@ -1561,11 +1007,11 @@ class OneHotEmbeddings(TokenEmbeddings): """One-hot encoded embeddings.""" def __init__( - self, - vocab_dictionary: Dictionary, - field: str = "text", - embedding_length: int = 300, - stable: bool = False, + self, + vocab_dictionary: Dictionary, + field: str = "text", + embedding_length: int = 300, + stable: bool = False, ): """ Initializes one-hot encoded word embeddings and a trainable embedding layer @@ -1700,7 +1146,7 @@ def __str__(self): class MuseCrosslingualEmbeddings(TokenEmbeddings): def __init__( - self, + self, ): self.name: str = "muse-crosslingual" self.static_embeddings = True @@ -1826,14 +1272,14 @@ def __setstate__(self, state): class BytePairEmbeddings(TokenEmbeddings): def __init__( - self, - language: str = None, - dim: int = 50, - syllables: int = 100000, - cache_dir=None, - model_file_path: Path = None, - embedding_file_path: Path = None, - **kwargs, + self, + language: str = None, + dim: int = 50, + syllables: int = 100000, + cache_dir=None, + model_file_path: Path = None, + embedding_file_path: Path = None, + **kwargs, ): """ Initializes BP embeddings. Constructor downloads required files if not there. @@ -1846,7 +1292,7 @@ def __init__( self.name: str = f"bpe-{language}-{syllables}-{dim}" else: assert ( - model_file_path is not None and embedding_file_path is not None + model_file_path is not None and embedding_file_path is not None ), "Need to specify model_file_path and embedding_file_path if no language is given in BytePairEmbeddings(...)" dim = None # type: ignore @@ -1906,11 +1352,11 @@ class ELMoEmbeddings(TokenEmbeddings): Default is to concatene the top 3 layers in the LM.""" def __init__( - self, - model: str = "original", - options_file: str = None, - weight_file: str = None, - embedding_mode: str = "all", + self, + model: str = "original", + options_file: str = None, + weight_file: str = None, + embedding_mode: str = "all", ): super().__init__() From 004de59d19a8ebf5ddb6abdc4177d343ac028bc7 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Thu, 16 Dec 2021 20:56:49 +0100 Subject: [PATCH 02/19] combine transformer embeddings --- flair/data.py | 59 +++++++++---------- flair/embeddings/base.py | 73 +++++++++++------------ flair/embeddings/document.py | 25 +++----- flair/embeddings/token.py | 110 +++++++++++++++-------------------- 4 files changed, 123 insertions(+), 144 deletions(-) diff --git a/flair/data.py b/flair/data.py index a5f57fb515..da2bc53e76 100644 --- a/flair/data.py +++ b/flair/data.py @@ -290,10 +290,10 @@ def __len__(self): def __eq__(self, other): return ( - self.value == other.value - and self.score == other.score - and self.head.id_text == other.head.id_text - and self.tail.id_text == other.tail.id_text + self.value == other.value + and self.score == other.score + and self.head.id_text == other.head.id_text + and self.tail.id_text == other.tail.id_text ) @property @@ -403,12 +403,12 @@ class Token(DataPoint): """ def __init__( - self, - text: str, - idx: int = None, - head_id: int = None, - whitespace_after: bool = True, - start_position: int = None, + self, + text: str, + idx: int = None, + head_id: int = None, + whitespace_after: bool = True, + start_position: int = None, ): super().__init__() @@ -610,11 +610,11 @@ class Sentence(DataPoint): """ def __init__( - self, - text: Union[str, List[str]] = [], - use_tokenizer: Union[bool, Tokenizer, Callable] = True, - language_code: str = None, - start_position: int = None, + self, + text: Union[str, List[str]] 
= [], + use_tokenizer: Union[bool, Tokenizer, Callable] = True, + language_code: str = None, + start_position: int = None, ): """ Class to hold all meta related to a text (tokens, predictions, language code, ...) @@ -835,7 +835,7 @@ def clear_embeddings(self, embedding_names: List[str] = None): @lru_cache(maxsize=1) # cache last context, as training repeats calls def left_context(self, context_length: int, respect_document_boundaries: bool = True): sentence = self - left_context = [] + left_context: List[str] = [] while True: sentence = sentence.previous_sentence() if sentence is None: @@ -853,7 +853,7 @@ def left_context(self, context_length: int, respect_document_boundaries: bool = @lru_cache(maxsize=1) # cache last context, as training repeats calls def right_context(self, context_length: int, respect_document_boundaries: bool = True): sentence = self - right_context = [] + right_context: List[str] = [] while True: sentence = sentence.next_sentence() if sentence is None: @@ -863,11 +863,10 @@ def right_context(self, context_length: int, respect_document_boundaries: bool = right_context += [t.text for t in sentence.tokens] if len(right_context) > context_length: - right_context = right_context[: context_length] + right_context = right_context[:context_length] break return right_context - def to_tagged_string(self, main_tag=None) -> str: list = [] for token in self.tokens: @@ -1177,12 +1176,12 @@ def is_in_memory(self) -> bool: class Corpus: def __init__( - self, - train: Dataset = None, - dev: Dataset = None, - test: Dataset = None, - name: str = "corpus", - sample_missing_splits: Union[bool, str] = True, + self, + train: Dataset = None, + dev: Dataset = None, + test: Dataset = None, + name: str = "corpus", + sample_missing_splits: Union[bool, str] = True, ): # set name self.name: str = name @@ -1221,11 +1220,11 @@ def test(self) -> Optional[Dataset]: return self._test def downsample( - self, - percentage: float = 0.1, - downsample_train=True, - downsample_dev=True, - downsample_test=True, + self, + percentage: float = 0.1, + downsample_train=True, + downsample_dev=True, + downsample_test=True, ): if downsample_train and self._train is not None: diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 8883087429..2ad0283ca6 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -1,25 +1,25 @@ import inspect import logging import os +import random import re import tempfile import zipfile from abc import abstractmethod from io import BytesIO -import random -from typing import Dict, Generic, List, Sequence, Union, Optional +from typing import Dict, Generic, List, Optional, Sequence, Union import torch from torch.nn import Parameter, ParameterList from transformers import ( - PreTrainedTokenizer, - AutoTokenizer, - PretrainedConfig, + CONFIG_MAPPING, AutoConfig, AutoModel, + AutoTokenizer, + PretrainedConfig, + PreTrainedTokenizer, TransfoXLModel, XLNetModel, - CONFIG_MAPPING, ) import flair @@ -171,23 +171,23 @@ class TransformerEmbedding(Embeddings[Sentence]): NO_MAX_SEQ_LENGTH_MODELS = (XLNetModel, TransfoXLModel) def __init__( - self, - model: str = "bert-base-uncased", - fine_tune: bool = True, - layers: str = "all", - layer_mean: bool = True, - subtoken_pooling: str = "first", - cls_pooling: str = "cls", - is_token_embedding: bool = True, - is_document_embedding: bool = True, - allow_long_sentences: bool = False, - use_context: Union[bool, int] = False, - respect_document_boundaries: bool = True, - context_dropout: float = 0.5, - saved_config: 
Optional[PretrainedConfig] = None, - tokenizer_data: Optional[BytesIO] = None, - name: Optional[str] = None, - **kwargs, + self, + model: str = "bert-base-uncased", + fine_tune: bool = True, + layers: str = "all", + layer_mean: bool = True, + subtoken_pooling: str = "first", + cls_pooling: str = "cls", + is_token_embedding: bool = True, + is_document_embedding: bool = True, + allow_long_sentences: bool = False, + use_context: Union[bool, int] = False, + respect_document_boundaries: bool = True, + context_dropout: float = 0.5, + saved_config: Optional[PretrainedConfig] = None, + tokenizer_data: Optional[BytesIO] = None, + name: Optional[str] = None, + **kwargs, ): self.instance_parameters = self.get_instance_parameters(locals=locals()) super().__init__() @@ -311,8 +311,8 @@ def _get_begin_offset_of_tokenizer(self) -> int: for begin_offset, token in enumerate(tokens): if ( - self.tokenizer.decode([token]) == test_string - or self.tokenizer.decode([token]) == self.tokenizer.unk_token + self.tokenizer.decode([token]) == test_string + or self.tokenizer.decode([token]) == self.tokenizer.unk_token ): break return begin_offset @@ -330,6 +330,7 @@ def _calculate_embedding_length(self) -> int: log.warning( "Token embedding length and Document embedding length vary, due to `first_last` subtoken pooling, this might not be supported" ) + return length @property def embedding_type(self) -> str: @@ -550,7 +551,7 @@ def _gather_tokenized_strings(self, sentences: List[Sentence]): tokenized_sentences.append(tokenized_string) return tokenized_sentences, all_token_subtoken_lengths - def _build_transformer_model_inputs(self, batch_encoding, tokenized_sentences): + def _build_transformer_model_inputs(self, batch_encoding, tokenized_sentences, sentences): model_kwargs = {} input_ids = batch_encoding["input_ids"].to(flair.device) @@ -568,7 +569,7 @@ def _build_transformer_model_inputs(self, batch_encoding, tokenized_sentences): return input_ids, model_kwargs def _combine_strided_sentences( - self, hidden_states: torch.Tensor, sentence_parts_lengths: torch.Tensor + self, hidden_states: torch.Tensor, sentence_parts_lengths: torch.Tensor ) -> List[torch.Tensor]: sentence_idx_offset = 0 sentence_hidden_states = [] @@ -582,7 +583,7 @@ def _combine_strided_sentences( sentence_hidden_state = torch.cat( ( sentence_hidden_state[:, : -1 - self.stride // 2, :], - remainder_sentence_hidden_state[:, 1 + self.stride // 2:, :], + remainder_sentence_hidden_state[:, 1 + self.stride // 2 :, :], ), 1, ) @@ -592,10 +593,10 @@ def _combine_strided_sentences( def _try_document_embedding_shortcut(self, hidden_states, sentences): # cls first pooling can be done without recreating sentence hidden states if ( - self.document_embedding - and not self.token_embedding - and self.cls_pooling == "cls" - and self.initial_cls_token + self.document_embedding + and not self.token_embedding + and self.cls_pooling == "cls" + and self.initial_cls_token ): embeddings_all_document_layers = hidden_states[:, :, 0] if self.layer_mean: @@ -626,7 +627,7 @@ def _extract_document_embeddings(self, sentence_hidden_states, sentences): def _extract_token_embeddings(self, sentence_hidden_states, sentences, all_token_subtoken_lengths): for sentence_hidden_state, sentence, subtoken_lengths in zip( - sentence_hidden_states, sentences, all_token_subtoken_lengths + sentence_hidden_states, sentences, all_token_subtoken_lengths ): subword_start_idx = self.begin_offset @@ -665,7 +666,7 @@ def _add_embeddings_to_sentences(self, sentences: List[Sentence]): 
return_tensors="pt", ) - input_ids, model_kwargs = self._build_transformer_model_inputs(batch_encoding, tokenized_sentences) + input_ids, model_kwargs = self._build_transformer_model_inputs(batch_encoding, tokenized_sentences, sentences) hidden_states = self.model(input_ids, **model_kwargs)[-1] @@ -701,7 +702,7 @@ def _add_embeddings_to_sentences(self, sentences: List[Sentence]): ] if self.document_embedding: - self._extract_document_embedding(sentence_hidden_states, sentences) + self._extract_document_embeddings(sentence_hidden_states, sentences) if self.token_embedding: self._extract_token_embeddings(sentence_hidden_states, sentences, all_token_subtoken_lengths) @@ -750,7 +751,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): # move embeddings from context back to original sentence (if using context) if self.context_length > 0: for original_sentence, expanded_sentence, context_offset in zip( - sentences, expanded_sentences, context_offsets + sentences, expanded_sentences, context_offsets ): for token_idx, token in enumerate(original_sentence): token.set_embedding( diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py index 3bfcfbb2cb..c47052085f 100644 --- a/flair/embeddings/document.py +++ b/flair/embeddings/document.py @@ -5,17 +5,10 @@ from sklearn.feature_extraction.text import TfidfVectorizer from torch.nn import RNNBase from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence -from transformers import ( - CONFIG_MAPPING, - AutoConfig, - AutoModel, - AutoTokenizer, - PreTrainedTokenizer, -) import flair from flair.data import Sentence -from flair.embeddings.base import Embeddings, ScalarMix, TransformerEmbedding +from flair.embeddings.base import Embeddings, TransformerEmbedding from flair.embeddings.token import FlairEmbeddings, StackedEmbeddings, TokenEmbeddings from flair.nn import LockedDropout, WordDropout @@ -31,14 +24,13 @@ def embedding_type(self) -> str: class TransformerDocumentEmbeddings(DocumentEmbeddings, TransformerEmbedding): - def __init__( - self, - model: str = "bert-base-uncased", # set parameters with different default values - layers: str = "-1", - layer_mean: bool = False, - is_token_embedding: bool = False, - **kwargs, + self, + model: str = "bert-base-uncased", # set parameters with different default values + layers: str = "-1", + layer_mean: bool = False, + is_token_embedding: bool = False, + **kwargs, ): """ Bidirectional transformer embeddings of words from various transformer architectures. 
@@ -57,7 +49,8 @@ def __init__( layer_mean=layer_mean, is_token_embedding=is_token_embedding, is_document_embedding=True, - **kwargs) + **kwargs, + ) @classmethod def create_from_state(cls, **state): diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py index f50f7b1b50..aca594a437 100644 --- a/flair/embeddings/token.py +++ b/flair/embeddings/token.py @@ -3,7 +3,6 @@ import os import re from collections import Counter -from io import BytesIO from pathlib import Path from typing import Dict, List, Optional, Sequence, Union @@ -13,15 +12,6 @@ from bpemb import BPEmb from gensim.models import KeyedVectors from torch import nn -from transformers import ( - CONFIG_MAPPING, - AutoConfig, - AutoModel, - AutoTokenizer, - PreTrainedTokenizer, - TransfoXLModel, - XLNetModel, PretrainedConfig, -) import flair from flair.data import Corpus, Dictionary, Sentence, Token, _iter_dataset @@ -117,12 +107,12 @@ class WordEmbeddings(TokenEmbeddings): """Standard static word embeddings, such as GloVe or FastText.""" def __init__( - self, - embeddings: str, - field: str = None, - fine_tune: bool = False, - force_cpu: bool = True, - stable: bool = False, + self, + embeddings: str, + field: str = None, + fine_tune: bool = False, + force_cpu: bool = True, + stable: bool = False, ): """ Initializes classic word embeddings. Constructor downloads required files if not there. @@ -375,10 +365,10 @@ class CharacterEmbeddings(TokenEmbeddings): """Character embeddings of words, as proposed in Lample et al., 2016.""" def __init__( - self, - path_to_char_dict: str = None, - char_embedding_dim: int = 25, - hidden_size_char: int = 25, + self, + path_to_char_dict: str = None, + char_embedding_dim: int = 25, + hidden_size_char: int = 25, ): """Uses the default character dictionary if none provided.""" @@ -474,13 +464,13 @@ class FlairEmbeddings(TokenEmbeddings): """Contextual string embeddings of words, as proposed in Akbik et al., 2018.""" def __init__( - self, - model, - fine_tune: bool = False, - chars_per_chunk: int = 512, - with_whitespace: bool = True, - tokenized_lm: bool = True, - is_lower: bool = False, + self, + model, + fine_tune: bool = False, + chars_per_chunk: int = 512, + with_whitespace: bool = True, + tokenized_lm: bool = True, + is_lower: bool = False, ): """ initializes contextual string embeddings using a character-level language model. @@ -789,11 +779,11 @@ def __str__(self): class PooledFlairEmbeddings(TokenEmbeddings): def __init__( - self, - contextual_embeddings: Union[str, FlairEmbeddings], - pooling: str = "min", - only_capitalized: bool = False, - **kwargs, + self, + contextual_embeddings: Union[str, FlairEmbeddings], + pooling: str = "min", + only_capitalized: bool = False, + **kwargs, ): super().__init__() @@ -895,12 +885,11 @@ def __setstate__(self, d): class TransformerWordEmbeddings(TokenEmbeddings, TransformerEmbedding): - def __init__( - self, - model: str = "bert-base-uncased", - is_document_embedding: bool = False, - **kwargs, + self, + model: str = "bert-base-uncased", + is_document_embedding: bool = False, + **kwargs, ): """ Bidirectional transformer embeddings of words from various transformer architectures. 
@@ -914,11 +903,8 @@ def __init__( """ super().__init__() TransformerEmbedding.__init__( - self, - model=model, - is_token_embedding=True, - is_document_embedding=is_document_embedding, - **kwargs) + self, model=model, is_token_embedding=True, is_document_embedding=is_document_embedding, **kwargs + ) @classmethod def create_from_state(cls, **state): @@ -1007,11 +993,11 @@ class OneHotEmbeddings(TokenEmbeddings): """One-hot encoded embeddings.""" def __init__( - self, - vocab_dictionary: Dictionary, - field: str = "text", - embedding_length: int = 300, - stable: bool = False, + self, + vocab_dictionary: Dictionary, + field: str = "text", + embedding_length: int = 300, + stable: bool = False, ): """ Initializes one-hot encoded word embeddings and a trainable embedding layer @@ -1146,7 +1132,7 @@ def __str__(self): class MuseCrosslingualEmbeddings(TokenEmbeddings): def __init__( - self, + self, ): self.name: str = "muse-crosslingual" self.static_embeddings = True @@ -1272,14 +1258,14 @@ def __setstate__(self, state): class BytePairEmbeddings(TokenEmbeddings): def __init__( - self, - language: str = None, - dim: int = 50, - syllables: int = 100000, - cache_dir=None, - model_file_path: Path = None, - embedding_file_path: Path = None, - **kwargs, + self, + language: str = None, + dim: int = 50, + syllables: int = 100000, + cache_dir=None, + model_file_path: Path = None, + embedding_file_path: Path = None, + **kwargs, ): """ Initializes BP embeddings. Constructor downloads required files if not there. @@ -1292,7 +1278,7 @@ def __init__( self.name: str = f"bpe-{language}-{syllables}-{dim}" else: assert ( - model_file_path is not None and embedding_file_path is not None + model_file_path is not None and embedding_file_path is not None ), "Need to specify model_file_path and embedding_file_path if no language is given in BytePairEmbeddings(...)" dim = None # type: ignore @@ -1352,11 +1338,11 @@ class ELMoEmbeddings(TokenEmbeddings): Default is to concatene the top 3 layers in the LM.""" def __init__( - self, - model: str = "original", - options_file: str = None, - weight_file: str = None, - embedding_mode: str = "all", + self, + model: str = "original", + options_file: str = None, + weight_file: str = None, + embedding_mode: str = "all", ): super().__init__() From f059bbc29c9c384e0ab51673e6714f49c4ed2fce Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Thu, 16 Dec 2021 23:51:50 +0100 Subject: [PATCH 03/19] fix pooling_operation in TransformerEmbeddings --- flair/embeddings/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 2ad0283ca6..258a6c6e64 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -324,7 +324,7 @@ def _calculate_embedding_length(self) -> int: length = self.model.config.hidden_size # in case of doubt: token embedding has higher priority than document embedding - if self.token_embedding and self.pooling_operation == "first_last": + if self.token_embedding and self.subtoken_pooling == "first_last": length *= 2 if self.document_embedding: log.warning( From cd3e89314c90d0a7e580139b092a168cd70c9a7c Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 17 Dec 2021 00:18:48 +0100 Subject: [PATCH 04/19] fix loading state_dict --- flair/embeddings/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 258a6c6e64..a11594ebf0 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -450,7 +450,7 @@ 
def __setstate__(self, state): self.__dict__[key] = embedding.__dict__[key] if model_state_dict: - self.load_state_dict(model_state_dict) + self.model.load_state_dict(model_state_dict) @classmethod def create_from_state(cls, **state): From 9b1b333c6cbdc4e9e5a63ffed6fb39f174f6d5a1 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 17 Dec 2021 00:42:00 +0100 Subject: [PATCH 05/19] fix size call --- flair/embeddings/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index a11594ebf0..6540aa1c36 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -602,7 +602,7 @@ def _try_document_embedding_shortcut(self, hidden_states, sentences): if self.layer_mean: document_embs = torch.mean(embeddings_all_document_layers, dim=0) else: - document_embs = embeddings_all_document_layers.view(-1, embeddings_all_document_layers.size[-1]) + document_embs = embeddings_all_document_layers.view(-1, embeddings_all_document_layers.size()[-1]) for (document_emb, sentence) in zip(document_embs, sentences): sentence.set_embedding(self.name, document_emb) return True From fb515e8ab2bd62bcb4a402f5cad41f5724aad39d Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 17 Dec 2021 15:22:20 +0100 Subject: [PATCH 06/19] fix transformer embeddings tests --- flair/embeddings/base.py | 43 ++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 6540aa1c36..865446a1c3 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -270,7 +270,6 @@ def __init__( self.cls_pooling = cls_pooling self.subtoken_pooling = subtoken_pooling - self.layer_mean = layer_mean self.fine_tune = fine_tune self.static_embeddings = not self.fine_tune @@ -347,9 +346,11 @@ def _tokenizer_bytes(self): with tempfile.TemporaryDirectory() as temp_dir: files = self.tokenizer.save_pretrained(temp_dir) zip_data = BytesIO() - zip_obj = zipfile.ZipFile(zip_data, "w+") + zip_obj = zipfile.ZipFile(zip_data, "w") for f in files: - zip_obj.write(f, os.path.relpath(f, temp_dir)) + # transformers returns the "added_tokens.json" even if it doesn't create it + if os.path.exists(f): + zip_obj.write(f, os.path.relpath(f, temp_dir)) zip_data.seek(0) return zip_data @@ -560,12 +561,24 @@ def _build_transformer_model_inputs(self, batch_encoding, tokenized_sentences, s model_kwargs["attention_mask"] = batch_encoding["attention_mask"].to(flair.device) # set language IDs for XLM-style transformers - if self.use_lang_emb: + if self.use_lang_emb and self.tokenizer.lang2id is not None: model_kwargs["langs"] = torch.zeros_like(input_ids, dtype=input_ids.dtype) - for s_id, sentence_text in enumerate(tokenized_sentences): - sequence_length = len(sentence_text) - lang_id = self.tokenizer.lang2id.get(sentences[s_id].get_language_code(), 0) # type: ignore - model_kwargs["langs"][s_id][:sequence_length] = lang_id + if not self.allow_long_sentences: + for s_id, sentence_text in enumerate(tokenized_sentences): + lang_id = self.tokenizer.lang2id.get(sentences[s_id].get_language_code(), 0) + model_kwargs["langs"][s_id] = lang_id + else: + sentence_part_lengths = torch.unique( + batch_encoding["overflow_to_sample_mapping"], + return_counts=True, + sorted=True, + )[1].tolist() + sentence_idx = 0 + for sentence, part_length in zip(sentences, sentence_part_lengths): + lang_id = self.tokenizer.lang2id.get(sentence.get_language_code(), 0) + model_kwargs["langs"][sentence_idx : 
sentence_idx + part_length] = lang_id + sentence_idx += part_length + return input_ids, model_kwargs def _combine_strided_sentences( @@ -602,7 +615,7 @@ def _try_document_embedding_shortcut(self, hidden_states, sentences): if self.layer_mean: document_embs = torch.mean(embeddings_all_document_layers, dim=0) else: - document_embs = embeddings_all_document_layers.view(-1, embeddings_all_document_layers.size()[-1]) + document_embs = torch.flatten(embeddings_all_document_layers.permute((1, 0, 2)), 1) for (document_emb, sentence) in zip(document_embs, sentences): sentence.set_embedding(self.name, document_emb) return True @@ -636,10 +649,9 @@ def _extract_token_embeddings(self, sentence_hidden_states, sentences, all_token token.set_embedding(self.name, torch.zeros(self.embedding_length)) continue subword_end_idx = subword_start_idx + n_subtokens - + assert subword_start_idx < subword_end_idx <= sentence_hidden_state.size()[1] current_embeddings = sentence_hidden_state[:, subword_start_idx:subword_end_idx] subword_start_idx = subword_end_idx - if self.subtoken_pooling == "first": final_embedding = current_embeddings[:, 0] elif self.subtoken_pooling == "last": @@ -651,6 +663,11 @@ def _extract_token_embeddings(self, sentence_hidden_states, sentences, all_token else: raise ValueError(f"subtoken pooling method: `{self.subtoken_pooling}` is not implemented") + if self.layer_mean: + final_embedding = final_embedding.mean(dim=0) + else: + final_embedding = torch.flatten(final_embedding) + token.set_embedding(self.name, final_embedding) def _add_embeddings_to_sentences(self, sentences: List[Sentence]): @@ -693,11 +710,11 @@ def _add_embeddings_to_sentences(self, sentences: List[Sentence]): )[1].tolist(), ) else: - sentence_hidden_states = hidden_states.tolist() + sentence_hidden_states = list(hidden_states.permute((1, 0, 2, 3))) # remove padding tokens sentence_hidden_states = [ - sentence_hidden_state[:, : len(subtokens), :] + sentence_hidden_state[:, : len(subtokens) + 1, :] for (subtokens, sentence_hidden_state) in zip(tokenized_sentences, sentence_hidden_states) ] From 862fbd3e70055902db7e26519272f5bfdde41e84 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Tue, 28 Dec 2021 08:15:02 +0100 Subject: [PATCH 07/19] load state dict directly --- flair/embeddings/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 865446a1c3..9caca08ec6 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -444,15 +444,12 @@ def __setstate__(self, state): if "embedding_length_internal" in state: del state["embedding_length_internal"] - embedding = self.create_from_state(saved_config=config, **state) + embedding = self.create_from_state(saved_config=config, **state, state_dict=model_state_dict) # copy values from new embedding for key in embedding.__dict__.keys(): self.__dict__[key] = embedding.__dict__[key] - if model_state_dict: - self.model.load_state_dict(model_state_dict) - @classmethod def create_from_state(cls, **state): return cls(**state) From d1197299f7ab6c291f5466f4591d2dd88748109e Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Tue, 28 Dec 2021 08:41:55 +0100 Subject: [PATCH 08/19] don't save model state dict twice --- flair/embeddings/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 9caca08ec6..ebaf3334fb 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -393,7 +393,6 @@ def __getstate__(self): "config_state_dict": 
config_dict, "tokenizer_data": tokenizer_data, "name": self.name, - "model_state_dict": model_state_dict, "context_length": self.context_length, "respect_document_boundaries": self.respect_document_boundaries, "context_dropout": self.context_dropout, From a338821d34454e561d2c2c54ef3b7296b92c6a97 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Tue, 28 Dec 2021 09:23:43 +0100 Subject: [PATCH 09/19] also load model with kwargs --- flair/embeddings/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index ebaf3334fb..97ec59a892 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -211,7 +211,7 @@ def __init__( config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs) self.model = AutoModel.from_pretrained(model, config=config) else: - self.model = AutoModel.from_config(saved_config) + self.model = AutoModel.from_config(saved_config, **kwargs) self.truncate = True @@ -442,8 +442,10 @@ def __setstate__(self, state): if "embedding_length_internal" in state: del state["embedding_length_internal"] - - embedding = self.create_from_state(saved_config=config, **state, state_dict=model_state_dict) + if model_state_dict: + embedding = self.create_from_state(saved_config=config, **state, state_dict=model_state_dict) + else: + embedding = self.create_from_state(saved_config=config, **state) # copy values from new embedding for key in embedding.__dict__.keys(): From c5b7879c0fdda54788b1eb666ce10fbae6336d9c Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Tue, 28 Dec 2021 10:18:40 +0100 Subject: [PATCH 10/19] don't save config and tokenizer data as instance parameters --- flair/embeddings/base.py | 2 ++ flair/embeddings/document.py | 1 - flair/embeddings/token.py | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 97ec59a892..0933c7df0d 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -190,6 +190,8 @@ def __init__( **kwargs, ): self.instance_parameters = self.get_instance_parameters(locals=locals()) + del self.instance_parameters["saved_config"] + del self.instance_parameters["tokenizer_data"] super().__init__() # temporary fix to disable tokenizer parallelism warning # (see https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning) diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py index c47052085f..30539e2f61 100644 --- a/flair/embeddings/document.py +++ b/flair/embeddings/document.py @@ -41,7 +41,6 @@ def __init__( :param layer_mean: If True, uses a scalar mix of layers as embedding :param fine_tune: If True, allows transformers to be fine-tuned during training """ - super().__init__() TransformerEmbedding.__init__( self, model=model, diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py index aca594a437..fb78f29380 100644 --- a/flair/embeddings/token.py +++ b/flair/embeddings/token.py @@ -901,7 +901,6 @@ def __init__( :param layer_mean: If True, uses a scalar mix of layers as embedding :param fine_tune: If True, allows transformers to be fine-tuned during training """ - super().__init__() TransformerEmbedding.__init__( self, model=model, is_token_embedding=True, is_document_embedding=is_document_embedding, **kwargs ) From 63cd8711b86b6af0c67a06f7c03f66f782439962 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Tue, 28 Dec 2021 11:09:57 +0100 Subject: [PATCH 11/19] stop batch size from loading --- 
flair/embeddings/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 0933c7df0d..04943296c7 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -406,6 +406,9 @@ def __setstate__(self, state): config_state_dict = state.pop("config_state_dict", None) model_state_dict = state.pop("model_state_dict", None) + # legacy TransformerDocumentEmbedding + state.pop("batch_size", None) + if "base_model_name" in state: state["model"] = state.pop("base_model_name") From a8dba68dd73a7fee6d681f78b42648219c7a842f Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Tue, 28 Dec 2021 11:28:56 +0100 Subject: [PATCH 12/19] load state dict afterwards if provided --- flair/embeddings/base.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 04943296c7..2588d218be 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -266,9 +266,9 @@ def __init__( if layers == "all": # send mini-token through to check how many layers the model has hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1] - self.layer_indexes = [int(x) for x in range(len(hidden_states))] + self.layer_indexes = list(range(len(hidden_states))) else: - self.layer_indexes = [int(x) for x in layers.split(",")] + self.layer_indexes = list(map(int,layers.split(","))) self.cls_pooling = cls_pooling self.subtoken_pooling = subtoken_pooling @@ -377,9 +377,6 @@ def _get_processed_token_text(self, token: Token) -> str: def __getstate__(self): config_dict = self.model.config.to_dict() - # not necessary when loaded via model, but we keep it for now. - model_state_dict = self.model.state_dict() - tokenizer_data = self._tokenizer_bytes() model_state = { @@ -445,17 +442,15 @@ def __setstate__(self, state): else: config = None - if "embedding_length_internal" in state: - del state["embedding_length_internal"] - if model_state_dict: - embedding = self.create_from_state(saved_config=config, **state, state_dict=model_state_dict) - else: - embedding = self.create_from_state(saved_config=config, **state) + embedding = self.create_from_state(saved_config=config, **state) # copy values from new embedding for key in embedding.__dict__.keys(): self.__dict__[key] = embedding.__dict__[key] + if model_state_dict: + self.load_state_dict(model_state_dict) + @classmethod def create_from_state(cls, **state): return cls(**state) From ab1086bd26d1dfca766abdc037ed9ee84dc2d3ea Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Tue, 28 Dec 2021 13:32:01 +0100 Subject: [PATCH 13/19] load model state dict --- flair/embeddings/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 2588d218be..0af530d6a5 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -405,6 +405,7 @@ def __setstate__(self, state): # legacy TransformerDocumentEmbedding state.pop("batch_size", None) + state.pop("embedding_length_internal", None) if "base_model_name" in state: state["model"] = state.pop("base_model_name") @@ -449,7 +450,7 @@ def __setstate__(self, state): self.__dict__[key] = embedding.__dict__[key] if model_state_dict: - self.load_state_dict(model_state_dict) + self.model.load_state_dict(model_state_dict) @classmethod def create_from_state(cls, **state): From a7edfbc31d05a99b543f765e569c99599654c857 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Tue, 28 Dec 2021 13:56:45 +0100 
Subject: [PATCH 14/19] dummy commit to retrigger github actions --- flair/embeddings/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 0af530d6a5..af0a57b925 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -436,12 +436,12 @@ def __setstate__(self, state): # legacy TransformerDocumentEmbedding state["cls_pooling"] = state.pop("pooling") + config = None + if config_state_dict: model_type = config_state_dict.get("model_type", "bert") config_class = CONFIG_MAPPING[model_type] config = config_class.from_dict(config_state_dict) - else: - config = None embedding = self.create_from_state(saved_config=config, **state) From 8452bc95d374f9a3b025bf01173ea107e328be95 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Wed, 29 Dec 2021 11:59:05 +0100 Subject: [PATCH 15/19] remove legacy flag --- flair/embeddings/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index af0a57b925..7cfa97db72 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -406,6 +406,8 @@ def __setstate__(self, state): # legacy TransformerDocumentEmbedding state.pop("batch_size", None) state.pop("embedding_length_internal", None) + # legacy TransformerTokenEmbedding + state.pop("memory_effective_training", None) if "base_model_name" in state: state["model"] = state.pop("base_model_name") From 3ac3231f597e8d46ea35e343bcf6a61bfc396e45 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Wed, 29 Dec 2021 11:59:56 +0100 Subject: [PATCH 16/19] don't compute gradients if not finetune mode --- flair/embeddings/base.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 7cfa97db72..2e3c972344 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -686,17 +686,16 @@ def _add_embeddings_to_sentences(self, sentences: List[Sentence]): input_ids, model_kwargs = self._build_transformer_model_inputs(batch_encoding, tokenized_sentences, sentences) - hidden_states = self.model(input_ids, **model_kwargs)[-1] - - # make the tuple a tensor; makes working with it easier. - hidden_states = torch.stack(hidden_states) - - # only use layers that will be outputted - hidden_states = hidden_states[self.layer_indexes, :, :] - gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad() with gradient_context: + hidden_states = self.model(input_ids, **model_kwargs)[-1] + + # make the tuple a tensor; makes working with it easier. 
+ hidden_states = torch.stack(hidden_states) + + # only use layers that will be outputted + hidden_states = hidden_states[self.layer_indexes, :, :] if self._try_document_embedding_shortcut(hidden_states, sentences): return From cc0750e31bc70498f2bc0df54f983a3dfa3136a6 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Wed, 29 Dec 2021 12:28:43 +0100 Subject: [PATCH 17/19] context embedding also for TransformerDocumentEmbeddings --- flair/embeddings/base.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 2e3c972344..8b29a63227 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -770,8 +770,11 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): for original_sentence, expanded_sentence, context_offset in zip( sentences, expanded_sentences, context_offsets ): - for token_idx, token in enumerate(original_sentence): - token.set_embedding( - self.name, - expanded_sentence[token_idx + context_offset].get_embedding(self.name), - ) + if self.token_embedding: + for token_idx, token in enumerate(original_sentence): + token.set_embedding( + self.name, + expanded_sentence[token_idx + context_offset].get_embedding(self.name), + ) + if self.document_embedding: + original_sentence.set_embedding(self.name, expanded_sentence.get_embedding(self.name)) From 12f78cbb21770903bb96920621c131a70980719a Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Wed, 29 Dec 2021 12:46:36 +0100 Subject: [PATCH 18/19] fix document embedding extraction --- flair/embeddings/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index 8b29a63227..fd299240d3 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -628,9 +628,9 @@ def _extract_document_embeddings(self, sentence_hidden_states, sentences): index_of_cls_token = 0 if self.initial_cls_token else -1 embedding_all_document_layers = sentence_hidden_state[:, index_of_cls_token, :] elif self.cls_pooling == "mean": - embedding_all_document_layers = sentence_hidden_state.mean(dim=2) + embedding_all_document_layers = sentence_hidden_state.mean(dim=1) elif self.cls_pooling == "max": - _, embedding_all_document_layers = sentence_hidden_state.max(dim=2) + embedding_all_document_layers, _ = sentence_hidden_state.max(dim=1) else: raise ValueError(f"cls pooling method: `{self.cls_pooling}` is not implemented") if self.layer_mean: From bf8f757491c5b33ef8d8136f7df34c1d4cc158ce Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Thu, 30 Dec 2021 10:42:45 +0100 Subject: [PATCH 19/19] fix padding removal --- flair/embeddings/base.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py index fd299240d3..8b436109ec 100644 --- a/flair/embeddings/base.py +++ b/flair/embeddings/base.py @@ -238,9 +238,6 @@ def __init__( if not self.token_embedding and not self.document_embedding: raise ValueError("either 'is_token_embedding' or 'is_document_embedding' needs to be set.") - if self.token_embedding and self.document_embedding: - log.info("Using TransformerEmbedding for both token embeddings and document embeddings is experimental") - if self.document_embedding and cls_pooling not in ["cls", "max", "mean"]: raise ValueError(f"Document Pooling operation `{cls_pooling}` is not defined for TransformerEmbedding") @@ -268,7 +265,7 @@ def __init__( hidden_states = self.model(torch.tensor([1], 
device=flair.device).unsqueeze(0))[-1] self.layer_indexes = list(range(len(hidden_states))) else: - self.layer_indexes = list(map(int,layers.split(","))) + self.layer_indexes = list(map(int, layers.split(","))) self.cls_pooling = cls_pooling self.subtoken_pooling = subtoken_pooling @@ -526,6 +523,7 @@ def _reconstruct_tokens_from_subtokens(self, sentence, subtokens): def _gather_tokenized_strings(self, sentences: List[Sentence]): tokenized_sentences = [] all_token_subtoken_lengths = [] + subtoken_lengths = [] for sentence in sentences: # subtokenize the sentence @@ -548,10 +546,11 @@ def _gather_tokenized_strings(self, sentences: List[Sentence]): all_token_subtoken_lengths.append( self._reconstruct_tokens_from_subtokens(sentence, subtokenized_sentence) ) + subtoken_lengths.append(len(subtokenized_sentence)) # remember tokenized sentences and their subtokenization tokenized_sentences.append(tokenized_string) - return tokenized_sentences, all_token_subtoken_lengths + return tokenized_sentences, all_token_subtoken_lengths, subtoken_lengths def _build_transformer_model_inputs(self, batch_encoding, tokenized_sentences, sentences): model_kwargs = {} @@ -672,7 +671,7 @@ def _extract_token_embeddings(self, sentence_hidden_states, sentences, all_token token.set_embedding(self.name, final_embedding) def _add_embeddings_to_sentences(self, sentences: List[Sentence]): - tokenized_sentences, all_token_subtoken_lengths = self._gather_tokenized_strings(sentences) + tokenized_sentences, all_token_subtoken_lengths, subtoken_lengths = self._gather_tokenized_strings(sentences) # encode inputs batch_encoding = self.tokenizer( @@ -714,8 +713,8 @@ def _add_embeddings_to_sentences(self, sentences: List[Sentence]): # remove padding tokens sentence_hidden_states = [ - sentence_hidden_state[:, : len(subtokens) + 1, :] - for (subtokens, sentence_hidden_state) in zip(tokenized_sentences, sentence_hidden_states) + sentence_hidden_state[:, : subtoken_length + 1, :] + for (subtoken_length, sentence_hidden_state) in zip(subtoken_lengths, sentence_hidden_states) ] if self.document_embedding:
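
A minimal usage sketch of the combined embedding classes introduced by this patch
series. It assumes the constructor arguments visible in the diffs above (model,
layers, subtoken_pooling, use_context, layer_mean, cls_pooling) remain the public
interface of TransformerWordEmbeddings and TransformerDocumentEmbeddings; the
parameter values below are illustrative only, not defaults taken from the patches.

    from flair.data import Sentence
    from flair.embeddings import TransformerDocumentEmbeddings, TransformerWordEmbeddings

    # token-level embeddings; use_context triggers the left_context()/right_context()
    # expansion added to Sentence in patch 01
    word_embeddings = TransformerWordEmbeddings(
        model="bert-base-uncased",
        layers="-1",
        subtoken_pooling="first",
        use_context=True,
        fine_tune=False,
    )

    # document-level embeddings pooled from the initial [CLS] token
    document_embeddings = TransformerDocumentEmbeddings(
        model="bert-base-uncased",
        layers="-1",
        cls_pooling="cls",
    )

    sentence = Sentence("Berlin is the capital of Germany .")

    word_embeddings.embed(sentence)
    print(sentence[0].get_embedding().shape)   # one embedding vector per token

    document_embeddings.embed(sentence)
    print(sentence.get_embedding().shape)      # one embedding vector for the whole sentence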