From ce1c4594c4505585cf5eeb41f9f5a9bfe21e473a Mon Sep 17 00:00:00 2001
From: Alan Akbik
Date: Wed, 25 Jan 2023 16:52:49 +0100
Subject: [PATCH 1/6] Add context boundaries

---
 flair/data.py                   |  2 +-
 flair/embeddings/transformer.py | 12 +++++++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/flair/data.py b/flair/data.py
index 94ea19ff9c..28a495e60a 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -515,7 +515,7 @@ def idx(self) -> int:
         if isinstance(self._internal_index, int):
             return self._internal_index
         else:
-            raise ValueError
+            return -1
 
     @property
     def text(self) -> str:
diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index c77d5c0d4d..46dc75608e 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -295,6 +295,7 @@ def __init__(
         force_max_length: bool = False,
         feature_extractor: Optional[FeatureExtractionMixin] = None,
         needs_manual_ocr: Optional[bool] = None,
+        use_context_separator: bool = True,
     ):
         self.name = name
         super().__init__()
@@ -313,6 +314,7 @@ def __init__(
         self.fine_tune = fine_tune
         self.force_max_length = force_max_length
         self.feature_extractor = feature_extractor
+        self.use_context_separator = use_context_separator
 
         tokenizer_params = list(inspect.signature(self.tokenizer.__call__).parameters.keys())
         self.tokenizer_needs_ocr_boxes = "boxes" in tokenizer_params
@@ -345,6 +347,7 @@ def to_args(self):
             "use_lang_emb": self.use_lang_emb,
             "force_max_length": self.force_max_length,
             "feature_extractor": self.feature_extractor,
+            "use_context_separator": self.use_context_separator,
         }
         if hasattr(self, "needs_manual_ocr"):
             args["needs_manual_ocr"] = self.needs_manual_ocr
@@ -612,7 +615,7 @@ def __gather_flair_tokens(self, sentences: List[Sentence]) -> Tuple[List[List[To
             lengths.append(len(sentence))
         return sentence_tokens, offsets, lengths
 
-    def __expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]:
+    def _expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]:
         expand_context = self.context_length > 0 and (
             not self.training or random.randint(1, 100) > (self.context_dropout * 100)
         )
@@ -624,6 +627,10 @@ def __expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]:
             left_context = sentence.left_context(self.context_length, self.respect_document_boundaries)
             right_context = sentence.right_context(self.context_length, self.respect_document_boundaries)
 
+            if self.use_context_separator:
+                left_context = left_context + [Token(self.tokenizer.sep_token)]
+                right_context = [Token(self.tokenizer.sep_token)] + right_context
+
         expanded_sentence = left_context + sentence.tokens + right_context
 
         context_length = len(left_context)
@@ -926,6 +933,7 @@ def __init__(
         name: Optional[str] = None,
         force_max_length: bool = False,
         needs_manual_ocr: Optional[bool] = None,
+        use_context_separator: bool = True,
         **kwargs,
     ):
         self.instance_parameters = self.get_instance_parameters(locals=locals())
@@ -1038,6 +1046,8 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
             self.embedding_length_internal = self._calculate_embedding_length(transformer_model)
         if needs_manual_ocr is not None:
             self.needs_manual_ocr = needs_manual_ocr
+        if use_context_separator is not None:
+            self.use_context_separator = use_context_separator
 
         super().__init__(**self.to_args())
 

From 347cc78ee53e2114de173e43940a105c7f95f022 Mon Sep 17 00:00:00 2001
From: Alan Akbik
Date: Thu, 26 Jan 2023 12:27:50 +0100
Subject: [PATCH 2/6] Explicit document boundaries

---
 flair/embeddings/transformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index 46dc75608e..8cd1bea2b3 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -609,7 +609,7 @@ def __gather_flair_tokens(self, sentences: List[Sentence]) -> Tuple[List[List[To
         sentence_tokens = []
         for sentence in sentences:
             # flair specific pre-tokenization
-            tokens, offset = self.__expand_sentence_with_context(sentence)
+            tokens, offset = self._expand_sentence_with_context(sentence)
             sentence_tokens.append(tokens)
             offsets.append(offset)
             lengths.append(len(sentence))

From 4f354ce2cee99d38bc8d9ea2b5169cb5b0008eba Mon Sep 17 00:00:00 2001
From: Alan Akbik
Date: Thu, 26 Jan 2023 15:57:56 +0100
Subject: [PATCH 3/6] Add option to add different context tokens

---
 flair/embeddings/transformer.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index 8cd1bea2b3..efc9bb9184 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -628,8 +628,8 @@ def _expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]:
             right_context = sentence.right_context(self.context_length, self.respect_document_boundaries)
 
             if self.use_context_separator:
-                left_context = left_context + [Token(self.tokenizer.sep_token)]
-                right_context = [Token(self.tokenizer.sep_token)] + right_context
+                left_context = left_context + [Token(self.use_context_separator)]
+                right_context = [Token(self.use_context_separator)] + right_context
 
         expanded_sentence = left_context + sentence.tokens + right_context
 
@@ -1046,8 +1046,16 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
             self.embedding_length_internal = self._calculate_embedding_length(transformer_model)
         if needs_manual_ocr is not None:
             self.needs_manual_ocr = needs_manual_ocr
+
+        # add special context
+        self.use_context_separator = False
         if use_context_separator is not None:
-            self.use_context_separator = use_context_separator
+            if use_context_separator is True:
+                self.use_context_separator = self.tokenizer.sep_token
+            if type(use_context_separator) == str:
+                self.use_context_separator = use_context_separator
+                self.tokenizer.add_special_tokens({"additional_special_tokens": [use_context_separator]})
+                transformer_model.resize_token_embeddings(len(self.tokenizer))
 
         super().__init__(**self.to_args())
 
@@ -1055,6 +1063,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
         self.initial_cls_token: bool = self._has_initial_cls_token()
 
         self.model = transformer_model
+        self.to(flair.device)
 
         # when initializing, embeddings are in eval mode by default
         self.eval()

From c4d7d9ede46325cc4bdfc6b51cff876955e32932 Mon Sep 17 00:00:00 2001
From: Alan Akbik
Date: Fri, 27 Jan 2023 14:04:45 +0100
Subject: [PATCH 4/6] Support legacy models without context separator

---
 flair/embeddings/transformer.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index efc9bb9184..3f0d12a940 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -1047,15 +1047,11 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
         if needs_manual_ocr is not None:
             self.needs_manual_ocr = needs_manual_ocr
 
-        # add special context
-        self.use_context_separator = False
-        if use_context_separator is not None:
-            if use_context_separator is True:
-                self.use_context_separator = self.tokenizer.sep_token
-            if type(use_context_separator) == str:
-                self.use_context_separator = use_context_separator
-                self.tokenizer.add_special_tokens({"additional_special_tokens": [use_context_separator]})
-                transformer_model.resize_token_embeddings(len(self.tokenizer))
+        # If we use a context separator, add a new special token
+        self.use_context_separator = use_context_separator
+        if use_context_separator:
+            self.tokenizer.add_special_tokens({"additional_special_tokens": ["[KONTEXT]"]})
+            transformer_model.resize_token_embeddings(len(self.tokenizer))
 
         super().__init__(**self.to_args())
 
@@ -1122,6 +1118,10 @@ def __setstate__(self, state):
             layer_indexes = state.pop("layer_indexes")
             state["layers"] = ",".join(map(str, layer_indexes))
 
+        if "use_context_separator" not in state:
+            # legacy Flair <= 0.12
+            state["use_context_separator"] = False
+
         if "use_scalar_mix" in state:
             # legacy Flair <= 0.7
             state["layer_mean"] = state.pop("use_scalar_mix")

From 6da65a4f665bb333f639067bc5f868046880e63b Mon Sep 17 00:00:00 2001
From: Alan Akbik
Date: Fri, 27 Jan 2023 16:46:19 +0100
Subject: [PATCH 5/6] Add correct token

---
 flair/embeddings/transformer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index 3f0d12a940..9159f1c04a 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -628,8 +628,8 @@ def _expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]:
             right_context = sentence.right_context(self.context_length, self.respect_document_boundaries)
 
             if self.use_context_separator:
-                left_context = left_context + [Token(self.use_context_separator)]
-                right_context = [Token(self.use_context_separator)] + right_context
+                left_context = left_context + [Token("[KONTEXT]")]
+                right_context = [Token("[KONTEXT]")] + right_context
 
         expanded_sentence = left_context + sentence.tokens + right_context
 

From fceba71ce038fba03ca5b6e87fa5ee2d94800305 Mon Sep 17 00:00:00 2001
From: Alan Akbik
Date: Mon, 30 Jan 2023 12:14:31 +0100
Subject: [PATCH 6/6] Add tests and make special marker a constant

---
 flair/embeddings/transformer.py               |  8 ++--
 .../test_transformer_word_embeddings.py       | 42 +++++++++++++++++++
 2 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index 9159f1c04a..036d95e845 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -37,6 +37,8 @@
     register_embeddings,
 )
 
+SENTENCE_BOUNDARY_TAG: str = "[SATZ]"
+
 
 @torch.jit.script_if_tracing
 def pad_sequence_embeddings(all_hidden_states: List[torch.Tensor]) -> torch.Tensor:
@@ -628,8 +630,8 @@ def _expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]:
             right_context = sentence.right_context(self.context_length, self.respect_document_boundaries)
 
             if self.use_context_separator:
-                left_context = left_context + [Token("[KONTEXT]")]
-                right_context = [Token("[KONTEXT]")] + right_context
+                left_context = left_context + [Token(SENTENCE_BOUNDARY_TAG)]
+                right_context = [Token(SENTENCE_BOUNDARY_TAG)] + right_context
 
         expanded_sentence = left_context + sentence.tokens + right_context
 
@@ -1050,7 +1052,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
         # If we use a context separator, add a new special token
         self.use_context_separator = use_context_separator
         if use_context_separator:
-            self.tokenizer.add_special_tokens({"additional_special_tokens": ["[KONTEXT]"]})
+            self.tokenizer.add_special_tokens({"additional_special_tokens": [SENTENCE_BOUNDARY_TAG]})
             transformer_model.resize_token_embeddings(len(self.tokenizer))
 
         super().__init__(**self.to_args())
diff --git a/tests/embeddings/test_transformer_word_embeddings.py b/tests/embeddings/test_transformer_word_embeddings.py
index fe00de0903..4ed6db1df9 100644
--- a/tests/embeddings/test_transformer_word_embeddings.py
+++ b/tests/embeddings/test_transformer_word_embeddings.py
@@ -97,6 +97,48 @@ def forward(
         sentence.clear_embeddings()
         assert torch.isclose(jit_token_embedding, loaded_jit_token_embedding).all()
 
+    def test_transformers_context_expansion(self, results_base_path):
+        emb = TransformerWordEmbeddings(
+            "distilbert-base-uncased", use_context=True, use_context_separator=True, respect_document_boundaries=True
+        )
+
+        # previous and next sentence as context
+        sentence_previous = Sentence("How is it?")
+        sentence_next = Sentence("Then again, maybe not...")
+
+        # test expansion for sentence without context
+        sentence = Sentence("This is great!")
+        expanded, _ = emb._expand_sentence_with_context(sentence=sentence)
+        assert " ".join([token.text for token in expanded]) == "[SATZ] This is great ! [SATZ]"
+
+        # test expansion for with previous and next as context
+        sentence = Sentence("This is great.")
+        sentence._previous_sentence = sentence_previous
+        sentence._next_sentence = sentence_next
+        expanded, _ = emb._expand_sentence_with_context(sentence=sentence)
+        assert (
+            " ".join([token.text for token in expanded])
+            == "How is it ? [SATZ] This is great . [SATZ] Then again , maybe not ..."
+        )
+
+        # test expansion if first sentence is document boundary
+        sentence = Sentence("This is great?")
+        sentence_previous.is_document_boundary = True
+        sentence._previous_sentence = sentence_previous
+        sentence._next_sentence = sentence_next
+        expanded, _ = emb._expand_sentence_with_context(sentence=sentence)
+        assert (
+            " ".join([token.text for token in expanded]) == "[SATZ] This is great ? [SATZ] Then again , maybe not ..."
+        )
+
+        # test expansion if we don't use context
+        emb.context_length = 0
+        sentence = Sentence("I am here.")
+        sentence._previous_sentence = sentence_previous
+        sentence._next_sentence = sentence_next
+        expanded, _ = emb._expand_sentence_with_context(sentence=sentence)
+        assert " ".join([token.text for token in expanded]) == "I am here ."
+
     @pytest.mark.integration
     def test_layoutlm_embeddings(self):
         sentence = Sentence(["I", "love", "Berlin"])
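Usage sketch (illustration only, not part of the patches above): how the new use_context_separator flag is meant to be used once this branch is installed. The model name "distilbert-base-uncased" mirrors the new test; setting _previous_sentence/_next_sentence by hand is the same internal linking the test uses, since sentences read from a Corpus get these links automatically.

    from flair.data import Sentence
    from flair.embeddings import TransformerWordEmbeddings

    # context-aware embeddings; the [SATZ] marker separates the current sentence
    # from its expanded left/right context (use_context_separator defaults to True on this branch)
    embeddings = TransformerWordEmbeddings(
        "distilbert-base-uncased",
        use_context=True,
        use_context_separator=True,
    )

    # link a sentence to its neighbours so context expansion has something to pull in
    previous_sentence = Sentence("How is it?")
    current_sentence = Sentence("This is great.")
    next_sentence = Sentence("Then again, maybe not...")
    current_sentence._previous_sentence = previous_sentence
    current_sentence._next_sentence = next_sentence

    # embed and inspect the per-token vectors
    embeddings.embed(current_sentence)
    for token in current_sentence:
        print(token.text, token.embedding.shape)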