Add tests and make special marker a constant
alanakbik committed Jan 30, 2023
1 parent 6da65a4 commit fceba71
Showing 2 changed files with 47 additions and 3 deletions.
8 changes: 5 additions & 3 deletions flair/embeddings/transformer.py
@@ -37,6 +37,8 @@
     register_embeddings,
 )
 
+SENTENCE_BOUNDARY_TAG: str = "[SATZ]"
+
 
 @torch.jit.script_if_tracing
 def pad_sequence_embeddings(all_hidden_states: List[torch.Tensor]) -> torch.Tensor:
@@ -628,8 +630,8 @@ def _expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]:
         right_context = sentence.right_context(self.context_length, self.respect_document_boundaries)
 
         if self.use_context_separator:
-            left_context = left_context + [Token("[KONTEXT]")]
-            right_context = [Token("[KONTEXT]")] + right_context
+            left_context = left_context + [Token(SENTENCE_BOUNDARY_TAG)]
+            right_context = [Token(SENTENCE_BOUNDARY_TAG)] + right_context
 
         expanded_sentence = left_context + sentence.tokens + right_context
 
@@ -1050,7 +1052,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
         # If we use a context separator, add a new special token
         self.use_context_separator = use_context_separator
         if use_context_separator:
-            self.tokenizer.add_special_tokens({"additional_special_tokens": ["[KONTEXT]"]})
+            self.tokenizer.add_special_tokens({"additional_special_tokens": [SENTENCE_BOUNDARY_TAG]})
             transformer_model.resize_token_embeddings(len(self.tokenizer))
 
         super().__init__(**self.to_args())
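
For context, a minimal standalone sketch (not part of this commit; it uses only the Hugging Face transformers API, with "distilbert-base-uncased" as an arbitrary example model) of what registering the marker as a special token achieves:

from transformers import AutoModel, AutoTokenizer

SENTENCE_BOUNDARY_TAG = "[SATZ]"

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")

# Registering the marker as a special token prevents the tokenizer from
# splitting it into subwords; resizing then gives the new token id its own
# row in the embedding matrix.
tokenizer.add_special_tokens({"additional_special_tokens": [SENTENCE_BOUNDARY_TAG]})
model.resize_token_embeddings(len(tokenizer))

# The marker now survives tokenization intact, e.g. (expected output):
# ['[SATZ]', 'this', 'is', 'great', '!', '[SATZ]']
print(tokenizer.tokenize("[SATZ] This is great ! [SATZ]"))
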
42 changes: 42 additions & 0 deletions tests/embeddings/test_transformer_word_embeddings.py
@@ -97,6 +97,48 @@ def forward(
         sentence.clear_embeddings()
         assert torch.isclose(jit_token_embedding, loaded_jit_token_embedding).all()
 
+    def test_transformers_context_expansion(self, results_base_path):
+        emb = TransformerWordEmbeddings(
+            "distilbert-base-uncased", use_context=True, use_context_separator=True, respect_document_boundaries=True
+        )
+
+        # previous and next sentence as context
+        sentence_previous = Sentence("How is it?")
+        sentence_next = Sentence("Then again, maybe not...")
+
+        # test expansion for sentence without context
+        sentence = Sentence("This is great!")
+        expanded, _ = emb._expand_sentence_with_context(sentence=sentence)
+        assert " ".join([token.text for token in expanded]) == "[SATZ] This is great ! [SATZ]"
+
+        # test expansion with previous and next sentence as context
+        sentence = Sentence("This is great.")
+        sentence._previous_sentence = sentence_previous
+        sentence._next_sentence = sentence_next
+        expanded, _ = emb._expand_sentence_with_context(sentence=sentence)
+        assert (
+            " ".join([token.text for token in expanded])
+            == "How is it ? [SATZ] This is great . [SATZ] Then again , maybe not ..."
+        )
+
+        # test expansion if first sentence is document boundary
+        sentence = Sentence("This is great?")
+        sentence_previous.is_document_boundary = True
+        sentence._previous_sentence = sentence_previous
+        sentence._next_sentence = sentence_next
+        expanded, _ = emb._expand_sentence_with_context(sentence=sentence)
+        assert (
+            " ".join([token.text for token in expanded]) == "[SATZ] This is great ? [SATZ] Then again , maybe not ..."
+        )
+
+        # test expansion if we don't use context
+        emb.context_length = 0
+        sentence = Sentence("I am here.")
+        sentence._previous_sentence = sentence_previous
+        sentence._next_sentence = sentence_next
+        expanded, _ = emb._expand_sentence_with_context(sentence=sentence)
+        assert " ".join([token.text for token in expanded]) == "I am here ."
+
     @pytest.mark.integration
     def test_layoutlm_embeddings(self):
         sentence = Sentence(["I", "love", "Berlin"])
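
As a usage note, a minimal sketch (not part of this commit) of how the context feature is exercised end to end; the model name and sentence chaining mirror the new test, and `_previous_sentence` is the same internal attribute the test sets:

from flair.data import Sentence
from flair.embeddings import TransformerWordEmbeddings

# Context expansion with the [SATZ] separator, as configured in the test.
emb = TransformerWordEmbeddings(
    "distilbert-base-uncased", use_context=True, use_context_separator=True
)

previous = Sentence("How is it?")
current = Sentence("This is great.")
current._previous_sentence = previous  # chain sentences so context can expand

# Embedding the sentence runs the context expansion internally.
emb.embed(current)
for token in current:
    print(token.text, token.embedding.shape)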
