Merge pull request #3073 from flairNLP/context_boundaries
Explicit context boundaries in Transformer embeddings
alanakbik authored Jan 30, 2023
2 parents 2f017ea + fceba71 commit f99ad60
Showing 3 changed files with 66 additions and 3 deletions.
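This PR lets transformer embeddings mark where the actual sentence ends and the left/right document context begins by inserting an explicit `[SATZ]` separator token on both sides. A minimal usage sketch of the new flag, based on the test added below (model name and sentence text are just placeholders):

```python
from flair.data import Sentence
from flair.embeddings import TransformerWordEmbeddings

# enable document context and mark its boundaries with the new separator token
embeddings = TransformerWordEmbeddings(
    "distilbert-base-uncased",
    use_context=True,
    use_context_separator=True,  # new flag introduced here; defaults to True
)

sentence = Sentence("This is great!")
embeddings.embed(sentence)  # context expansion and separator insertion happen internally
print(sentence[0].embedding.shape)
```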
2 changes: 1 addition & 1 deletion flair/data.py
@@ -515,7 +515,7 @@ def idx(self) -> int:
if isinstance(self._internal_index, int):
return self._internal_index
else:
raise ValueError
return -1

@property
def text(self) -> str:
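The `flair/data.py` tweak makes `Token.idx` degrade gracefully instead of raising: tokens that were never assigned an internal index (such as the standalone separator `Token` objects created during context expansion) now report `-1`. A small sketch of the new behavior, assuming a bare `Token` starts without an internal index:

```python
from flair.data import Token

# a Token constructed outside of a Sentence has no internal index assigned;
# after this change, idx returns -1 rather than raising ValueError
boundary = Token("[SATZ]")
print(boundary.idx)  # expected: -1
```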
25 changes: 23 additions & 2 deletions flair/embeddings/transformer.py
@@ -37,6 +37,8 @@
register_embeddings,
)

SENTENCE_BOUNDARY_TAG: str = "[SATZ]"


@torch.jit.script_if_tracing
def pad_sequence_embeddings(all_hidden_states: List[torch.Tensor]) -> torch.Tensor:
@@ -295,6 +297,7 @@ def __init__(
force_max_length: bool = False,
feature_extractor: Optional[FeatureExtractionMixin] = None,
needs_manual_ocr: Optional[bool] = None,
use_context_separator: bool = True,
):
self.name = name
super().__init__()
@@ -313,6 +316,7 @@ def __init__(
self.fine_tune = fine_tune
self.force_max_length = force_max_length
self.feature_extractor = feature_extractor
self.use_context_separator = use_context_separator

tokenizer_params = list(inspect.signature(self.tokenizer.__call__).parameters.keys())
self.tokenizer_needs_ocr_boxes = "boxes" in tokenizer_params
@@ -345,6 +349,7 @@ def to_args(self):
"use_lang_emb": self.use_lang_emb,
"force_max_length": self.force_max_length,
"feature_extractor": self.feature_extractor,
"use_context_separator": self.use_context_separator,
}
if hasattr(self, "needs_manual_ocr"):
args["needs_manual_ocr"] = self.needs_manual_ocr
@@ -606,13 +611,13 @@ def __gather_flair_tokens(self, sentences: List[Sentence]) -> Tuple[List[List[To
sentence_tokens = []
for sentence in sentences:
# flair specific pre-tokenization
tokens, offset = self.__expand_sentence_with_context(sentence)
tokens, offset = self._expand_sentence_with_context(sentence)
sentence_tokens.append(tokens)
offsets.append(offset)
lengths.append(len(sentence))
return sentence_tokens, offsets, lengths

def __expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]:
def _expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]:
expand_context = self.context_length > 0 and (
not self.training or random.randint(1, 100) > (self.context_dropout * 100)
)
@@ -624,6 +629,10 @@ def __expand_sentence_with_context(self, sentence) -> Tuple[List[Token], int]:
left_context = sentence.left_context(self.context_length, self.respect_document_boundaries)
right_context = sentence.right_context(self.context_length, self.respect_document_boundaries)

if self.use_context_separator:
left_context = left_context + [Token(SENTENCE_BOUNDARY_TAG)]
right_context = [Token(SENTENCE_BOUNDARY_TAG)] + right_context

expanded_sentence = left_context + sentence.tokens + right_context

context_length = len(left_context)
@@ -926,6 +935,7 @@ def __init__(
name: Optional[str] = None,
force_max_length: bool = False,
needs_manual_ocr: Optional[bool] = None,
use_context_separator: bool = True,
**kwargs,
):
self.instance_parameters = self.get_instance_parameters(locals=locals())
@@ -1039,12 +1049,19 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
if needs_manual_ocr is not None:
self.needs_manual_ocr = needs_manual_ocr

# If we use a context separator, add a new special token
self.use_context_separator = use_context_separator
if use_context_separator:
self.tokenizer.add_special_tokens({"additional_special_tokens": [SENTENCE_BOUNDARY_TAG]})
transformer_model.resize_token_embeddings(len(self.tokenizer))

super().__init__(**self.to_args())

# most models have an initial BOS token, except for XLNet, T5 and GPT2
self.initial_cls_token: bool = self._has_initial_cls_token()

self.model = transformer_model

self.to(flair.device)
# when initializing, embeddings are in eval mode by default
self.eval()
@@ -1103,6 +1120,10 @@ def __setstate__(self, state):
layer_indexes = state.pop("layer_indexes")
state["layers"] = ",".join(map(str, layer_indexes))

if "use_context_separator" not in state:
# legacy Flair <= 0.12
state["use_context_separator"] = False

if "use_scalar_mix" in state:
# legacy Flair <= 0.7
state["layer_mean"] = state.pop("use_scalar_mix")
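The separator can only act as a boundary marker if the tokenizer keeps `[SATZ]` as a single token, which is why the `__init__` hunk above registers it as an additional special token and resizes the model's embedding matrix. A standalone sketch of that Hugging Face pattern outside of Flair (model name is a placeholder):

```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")

# register the boundary marker so the tokenizer never splits it into subwords
tokenizer.add_special_tokens({"additional_special_tokens": ["[SATZ]"]})

# grow the token embedding matrix to cover the newly added vocabulary entry
model.resize_token_embeddings(len(tokenizer))

# the marker now survives tokenization as one piece
print(tokenizer.tokenize("[SATZ] This is great ! [SATZ]"))
```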
42 changes: 42 additions & 0 deletions tests/embeddings/test_transformer_word_embeddings.py
@@ -97,6 +97,48 @@ def forward(
sentence.clear_embeddings()
assert torch.isclose(jit_token_embedding, loaded_jit_token_embedding).all()

def test_transformers_context_expansion(self, results_base_path):
emb = TransformerWordEmbeddings(
"distilbert-base-uncased", use_context=True, use_context_separator=True, respect_document_boundaries=True
)

# previous and next sentence as context
sentence_previous = Sentence("How is it?")
sentence_next = Sentence("Then again, maybe not...")

# test expansion for sentence without context
sentence = Sentence("This is great!")
expanded, _ = emb._expand_sentence_with_context(sentence=sentence)
assert " ".join([token.text for token in expanded]) == "[SATZ] This is great ! [SATZ]"

# test expansion with previous and next sentence as context
sentence = Sentence("This is great.")
sentence._previous_sentence = sentence_previous
sentence._next_sentence = sentence_next
expanded, _ = emb._expand_sentence_with_context(sentence=sentence)
assert (
" ".join([token.text for token in expanded])
== "How is it ? [SATZ] This is great . [SATZ] Then again , maybe not ..."
)

# test expansion if first sentence is document boundary
sentence = Sentence("This is great?")
sentence_previous.is_document_boundary = True
sentence._previous_sentence = sentence_previous
sentence._next_sentence = sentence_next
expanded, _ = emb._expand_sentence_with_context(sentence=sentence)
assert (
" ".join([token.text for token in expanded]) == "[SATZ] This is great ? [SATZ] Then again , maybe not ..."
)

# test expansion if we don't use context
emb.context_length = 0
sentence = Sentence("I am here.")
sentence._previous_sentence = sentence_previous
sentence._next_sentence = sentence_next
expanded, _ = emb._expand_sentence_with_context(sentence=sentence)
assert " ".join([token.text for token in expanded]) == "I am here ."

@pytest.mark.integration
def test_layoutlm_embeddings(self):
sentence = Sentence(["I", "love", "Berlin"])
