From fcfdd979af364b389c695a6afcd812e6e1059559 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Wed, 7 Dec 2022 21:38:42 +0100 Subject: [PATCH 01/23] Implement functional encoding strategies --- flair/models/relation_classifier_model.py | 68 ++++++++++++++++------- 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 20bac08866..956f5d0316 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -1,5 +1,6 @@ import collections as co import itertools +from abc import ABC from operator import itemgetter from typing import ( Any, @@ -15,6 +16,8 @@ Tuple, Union, cast, + Literal, + Callable, ) import torch @@ -48,6 +51,40 @@ class _Entity(NamedTuple): label: Label +class EncodingStrategy(ABC): + @staticmethod + def entity_mask(entity: _Entity, role: Literal["H", "T"]) -> str: + return f"[{role}-{entity.label.value}]" + + @staticmethod + def entity_marker(entity: _Entity, role: Literal["H", "T"]) -> str: + space_tokenized_text: str = " ".join(token.text for token in entity.span) + return f"[{role}] {space_tokenized_text} [/{role}]" + + @staticmethod + def entity_marker_punctual(entity: _Entity, role: Literal["H", "T"]) -> str: + space_tokenized_text: str = " ".join(token.text for token in entity.span) + if role == "H": + return f"@ {space_tokenized_text} @" + if role == "T": + return f"# {space_tokenized_text} #" + raise ValueError() # TODO + + @staticmethod + def typed_entity_marker(entity: _Entity, role: Literal["H", "T"]) -> str: + space_tokenized_text: str = " ".join(token.text for token in entity.span) + return f"[{role}-{entity.label.value}] {space_tokenized_text} [/{role}-{entity.label.value}]" + + @staticmethod + def typed_entity_marker_punctual(entity: _Entity, role: Literal["H", "T"]) -> str: + space_tokenized_text: str = " ".join(token.text for token in entity.span) + if role == "H": + return f"@ * {entity.label.value} * {space_tokenized_text} @" + if role == "T": + return f"# ^ {entity.label.value} ^ {space_tokenized_text} #" + raise ValueError() # TODO + + # TODO: This closely shadows the RelationExtractor name. Maybe we need a better name here. # - MaskedRelationClassifier ? # This depends if this relation classification architecture should replace or offer as an alternative. @@ -84,10 +121,10 @@ def __init__( entity_label_types: Union[str, Sequence[str], Dict[str, Optional[Set[str]]]], entity_pair_labels: Optional[Set[Tuple[str, str]]] = None, entity_threshold: Optional[float] = None, + cross_augmentation: bool = True, + encoding_strategy: Callable[[_Entity, str], str] = EncodingStrategy.entity_marker, zero_tag_value: str = "O", allow_unk_tag: bool = True, - cross_augmentation: bool = True, - mask_type: str = "mark", **classifierargs, ) -> None: """ @@ -149,9 +186,9 @@ def __init__( self.entity_pair_labels = entity_pair_labels - self.cross_augmentation = cross_augmentation - self.mask_type = mask_type self.entity_threshold = entity_threshold + self.cross_augmentation = cross_augmentation + self.encoding_strategy = encoding_strategy # Auto-spawn on GPU, if available self.to(flair.device) @@ -222,17 +259,6 @@ def _entity_pair_permutations( yield head, tail, gold_label - def _mask(self, entity: _Entity, role: str) -> str: - if self.mask_type == "label-aware": - return f"[{role}-{entity.label.value}]" - if self.mask_type == "entity": - return f"[{role}-ENTITY]" - if self.mask_type == "mark": - return f"[[{role}-{entity.span.text}]]" - - # by default, use "mark" masking - return f"[[{role}-{entity.span.text}]]" - def _create_masked_sentence( self, head: _Entity, @@ -269,10 +295,10 @@ def _create_masked_sentence( for token in original_sentence: if token is head.span[0]: - masked_sentence_tokens.append(self._mask(entity=head, role="H")) + masked_sentence_tokens.append(self.encoding_strategy(head, "H")) elif token is tail.span[0]: - masked_sentence_tokens.append(self._mask(entity=tail, role="T")) + masked_sentence_tokens.append(self.encoding_strategy(tail, "T")) elif all( token is not non_leading_entity_token @@ -514,10 +540,10 @@ def _get_state_dict(self) -> Dict[str, Any]: "entity_label_types": self.entity_label_types, "entity_pair_labels": self.entity_pair_labels, "entity_threshold": self.entity_threshold, + "cross_augmentation": self.cross_augmentation, + "encoding_strategy": self.encoding_strategy, "zero_tag_value": self.zero_tag_value, "allow_unk_tag": self.allow_unk_tag, - "cross_augmentation": self.cross_augmentation, - "mask_type": self.mask_type, } return model_state @@ -531,10 +557,10 @@ def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs): entity_label_types=state["entity_label_types"], entity_pair_labels=state["entity_pair_labels"], entity_threshold=state["entity_threshold"], + cross_augmentation=state["cross_augmentation"], + encoding_strategy=state["encoding_strategy"], zero_tag_value=state["zero_tag_value"], allow_unk_tag=state["allow_unk_tag"], - cross_augmentation=state["cross_augmentation"], - mask_type=state["mask_type"], **kwargs, ) From 52d65f3886939ef93c65bbb14dbff605eb0afc2e Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Thu, 8 Dec 2022 02:14:53 +0100 Subject: [PATCH 02/23] Implement object-oriented encoding strategies --- flair/models/relation_classifier_model.py | 170 ++++++++++++++++------ 1 file changed, 128 insertions(+), 42 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 956f5d0316..3231b2f1bb 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -16,8 +16,8 @@ Tuple, Union, cast, - Literal, - Callable, + Protocol, + runtime_checkable, ) import torch @@ -26,7 +26,7 @@ import flair from flair.data import Corpus, Dictionary, Label, Relation, Sentence, Span, Token from flair.datasets import DataLoader, FlairDatapointDataset -from flair.embeddings import DocumentEmbeddings +from flair.embeddings import DocumentEmbeddings, TransformerDocumentEmbeddings from flair.tokenization import SpaceTokenizer @@ -41,6 +41,125 @@ class EncodedSentence(Sentence): pass +@runtime_checkable +class EncodingStrategy(Protocol): + """ + TODO: Write documentation + """ + + special_tokens: Set[str] + + def encode_head(self, head_span: Span, label: Label) -> str: + """ + TODO: Write documentation + """ + ... + + def encode_tail(self, tail_span: Span, label: Label) -> str: + """ + TODO: Write documentation + """ + ... + + +class EntityMask(EncodingStrategy): + """ + TODO: Write documentation + """ + + def __init__(self, special_tokens: Optional[Set[str]] = None) -> None: + self.special_tokens = {"[HEAD]", "[TAIL]"} if special_tokens is None else special_tokens + + def encode_head(self, head_span: Span, label: Label) -> str: + return "[HEAD]" + + def encode_tail(self, tail_span: Span, label: Label) -> str: + return "[TAIL]" + + +class TypedEntityMask(EncodingStrategy): + """ + TODO: Write documentation + """ + + def __init__(self) -> None: + self.special_tokens: Set[str] = set() + + def encode_head(self, head: Span, label: Label) -> str: + return f"[HEAD-{label.value}]" + + def encode_tail(self, tail: Span, label: Label) -> str: + return f"[TAIL-{label.value}]" + + +class EntityMarker(EncodingStrategy): + """ + TODO: Write documentation + """ + + def __init__(self, special_tokens: Optional[Set[str]] = None) -> None: + self.special_tokens = {"[HEAD]", "[\HEAD]", "[TAIL]", "[\TAIL]"} if special_tokens is None else special_tokens + + def encode_head(self, head: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in head) + return f"[HEAD] {space_tokenized_text} [/HEAD]" + + def encode_tail(self, tail: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in tail) + return f"[TAIL] {space_tokenized_text} [/TAIL]" + + +class TypedEntityMarker(EncodingStrategy): + """ + TODO: Write documentation + """ + + def __init__(self) -> None: + self.special_tokens: Set[str] = set() + + def encode_head(self, head: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in head) + return f"[HEAD-{label.value}] {space_tokenized_text} [/HEAD-{label.value}]" + + def encode_tail(self, tail: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in tail) + return f"[TAIL-{label.value}] {space_tokenized_text} [/TAIL-{label.value}]" + + +class EntityMarkerPunct(EncodingStrategy): + """ + TODO: Write documentation + """ + + def __init__(self) -> None: + self.special_tokens = set() + + def encode_head(self, head: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in head) + return f"@ {space_tokenized_text} @" + + def encode_tail(self, tail: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in tail) + return f"# {space_tokenized_text} #" + + +class TypedEntityMarkerPunct(EncodingStrategy): + """ + TODO: Write documentation + """ + + def __init__(self) -> None: + self.special_tokens = set() + + def encode_head(self, head: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in head) + return f"@ * {label.value} * {space_tokenized_text} @" + + def encode_tail(self, tail: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in tail) + return f"# ^ {label.value} ^ {space_tokenized_text} #" + + class _Entity(NamedTuple): """ A `_Entity` encapsulates either a relation's head or a tail span, including its label. @@ -51,40 +170,6 @@ class _Entity(NamedTuple): label: Label -class EncodingStrategy(ABC): - @staticmethod - def entity_mask(entity: _Entity, role: Literal["H", "T"]) -> str: - return f"[{role}-{entity.label.value}]" - - @staticmethod - def entity_marker(entity: _Entity, role: Literal["H", "T"]) -> str: - space_tokenized_text: str = " ".join(token.text for token in entity.span) - return f"[{role}] {space_tokenized_text} [/{role}]" - - @staticmethod - def entity_marker_punctual(entity: _Entity, role: Literal["H", "T"]) -> str: - space_tokenized_text: str = " ".join(token.text for token in entity.span) - if role == "H": - return f"@ {space_tokenized_text} @" - if role == "T": - return f"# {space_tokenized_text} #" - raise ValueError() # TODO - - @staticmethod - def typed_entity_marker(entity: _Entity, role: Literal["H", "T"]) -> str: - space_tokenized_text: str = " ".join(token.text for token in entity.span) - return f"[{role}-{entity.label.value}] {space_tokenized_text} [/{role}-{entity.label.value}]" - - @staticmethod - def typed_entity_marker_punctual(entity: _Entity, role: Literal["H", "T"]) -> str: - space_tokenized_text: str = " ".join(token.text for token in entity.span) - if role == "H": - return f"@ * {entity.label.value} * {space_tokenized_text} @" - if role == "T": - return f"# ^ {entity.label.value} ^ {space_tokenized_text} #" - raise ValueError() # TODO - - # TODO: This closely shadows the RelationExtractor name. Maybe we need a better name here. # - MaskedRelationClassifier ? # This depends if this relation classification architecture should replace or offer as an alternative. @@ -122,7 +207,7 @@ def __init__( entity_pair_labels: Optional[Set[Tuple[str, str]]] = None, entity_threshold: Optional[float] = None, cross_augmentation: bool = True, - encoding_strategy: Callable[[_Entity, str], str] = EncodingStrategy.entity_marker, + encoding_strategy: EncodingStrategy = TypedEntityMarker(), zero_tag_value: str = "O", allow_unk_tag: bool = True, **classifierargs, @@ -148,14 +233,15 @@ def __init__( i.e. the model classifies the relation for each entity pair in the cross product of *all* entity pairs (inefficient). :param entity_threshold: Only pre-labelled entities above this threshold are taken into account by the model. - :param zero_tag_value: The label to use for out-of-class relations - :param allow_unk_tag: If `False`, removes `` from the passed label dictionary, otherwise do nothing. :param cross_augmentation: If `True`, use cross augmentation to transform `Sentence`s into `EncodedSentenece`s. When cross augmentation is enabled, the transformation functions, e.g. `transform_corpus`, generate an encoded sentence for each entity pair in the cross product of all entities in the original sentence. When disabling cross augmentation, the transform functions only generate encoded sentences for each gold relation annotation in the original sentence. + :param encoding_strategy: TODO Write documentation + :param zero_tag_value: The label to use for out-of-class relations + :param allow_unk_tag: If `False`, removes `` from the passed label dictionary, otherwise do nothing. :param classifierargs: The remaining parameters passed to the underlying `DefaultClassifier` """ # Set label type and prepare label dictionary @@ -295,10 +381,10 @@ def _create_masked_sentence( for token in original_sentence: if token is head.span[0]: - masked_sentence_tokens.append(self.encoding_strategy(head, "H")) + masked_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label)) elif token is tail.span[0]: - masked_sentence_tokens.append(self.encoding_strategy(tail, "T")) + masked_sentence_tokens.append(self.encoding_strategy.encode_tail(tail.span, tail.label)) elif all( token is not non_leading_entity_token From 9bb142e863dca5b531aab6a97e56a6f4bc5e2c92 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Thu, 8 Dec 2022 02:15:25 +0100 Subject: [PATCH 03/23] Add special tokens to transformer embeddings if specified --- flair/models/relation_classifier_model.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 3231b2f1bb..988c060f81 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -1,6 +1,6 @@ import collections as co import itertools -from abc import ABC +import logging from operator import itemgetter from typing import ( Any, @@ -29,6 +29,8 @@ from flair.embeddings import DocumentEmbeddings, TransformerDocumentEmbeddings from flair.tokenization import SpaceTokenizer +logger: logging.Logger = logging.getLogger("flair") + class EncodedSentence(Sentence): """ @@ -276,6 +278,18 @@ def __init__( self.cross_augmentation = cross_augmentation self.encoding_strategy = encoding_strategy + # Add the special tokens from the encoding strategy + if self.encoding_strategy.special_tokens and isinstance(self.embeddings, TransformerDocumentEmbeddings): + special_tokens: List[str] = list(self.encoding_strategy.special_tokens) + tokenizer = self.embeddings.tokenizer + tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) + self.embeddings.model.resize_token_embeddings(len(tokenizer)) + + logger.info( + f"{self.__class__.__name__}: " + f"Added {', '.join(special_tokens)} as additional special tokens to {self.embeddings.name}" + ) + # Auto-spawn on GPU, if available self.to(flair.device) From 456c0bfe6cfc9cb68429cc43fbcbeb70096757e5 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Thu, 8 Dec 2022 02:27:18 +0100 Subject: [PATCH 04/23] Refactor encoding strategy special tokens to boolean flag instead of being fully configurable --- flair/models/relation_classifier_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 988c060f81..5adab9f510 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -69,8 +69,8 @@ class EntityMask(EncodingStrategy): TODO: Write documentation """ - def __init__(self, special_tokens: Optional[Set[str]] = None) -> None: - self.special_tokens = {"[HEAD]", "[TAIL]"} if special_tokens is None else special_tokens + def __init__(self, add_special_tokens: bool = False) -> None: + self.special_tokens = {"[HEAD]", "[TAIL]"} if add_special_tokens else set() def encode_head(self, head_span: Span, label: Label) -> str: return "[HEAD]" @@ -99,8 +99,8 @@ class EntityMarker(EncodingStrategy): TODO: Write documentation """ - def __init__(self, special_tokens: Optional[Set[str]] = None) -> None: - self.special_tokens = {"[HEAD]", "[\HEAD]", "[TAIL]", "[\TAIL]"} if special_tokens is None else special_tokens + def __init__(self, add_special_tokens: bool = False) -> None: + self.special_tokens = {"[HEAD]", "[\HEAD]", "[TAIL]", "[\TAIL]"} if add_special_tokens else set() def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) From 1587d855011eb94de9758c49eafe5ca72d676696 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Sun, 11 Dec 2022 00:44:29 +0100 Subject: [PATCH 05/23] Add examples for the selection strategies --- flair/models/relation_classifier_model.py | 69 ++++++++++++++++++++--- 1 file changed, 60 insertions(+), 9 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 5adab9f510..64432a0fac 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -46,27 +46,37 @@ class EncodedSentence(Sentence): @runtime_checkable class EncodingStrategy(Protocol): """ - TODO: Write documentation + The :class:`EncodingStrategy` protocol defines + the encoding of the head and tail entities in a sentence with a relation annotation. """ special_tokens: Set[str] def encode_head(self, head_span: Span, label: Label) -> str: """ - TODO: Write documentation + Returns the encoded string representation of the head span. + Multi-token head encodings tokens are separated by a space. """ ... def encode_tail(self, tail_span: Span, label: Label) -> str: """ - TODO: Write documentation + Returns the encoded string representation of the tail span. + Multi-token tail encodings tokens are separated by a space. """ ... class EntityMask(EncodingStrategy): """ - TODO: Write documentation + An `class`:EncodingStrategy: that masks the head and tail relation entities. + + Example: + For the `founded_by` relation from `ORG` to `PER` and + the sentence "Larry Page and Sergey Brin founded Google .", + the encoded sentences and relations are + - "[TAIL] and Sergey Brin founded [HEAD]" -> Relation(head='Google', tail='Larry Page') and + - "Larry Page and [TAIL] founded [HEAD]" -> Relation(head='Google', tail='Sergey Brin'). """ def __init__(self, add_special_tokens: bool = False) -> None: @@ -81,7 +91,14 @@ def encode_tail(self, tail_span: Span, label: Label) -> str: class TypedEntityMask(EncodingStrategy): """ - TODO: Write documentation + An `class`:EncodingStrategy: that masks the head and tail relation entities with their label. + + Example: + For the `founded_by` relation from `ORG` to `PER` and + the sentence "Larry Page and Sergey Brin founded Google .", + the encoded sentences and relations are + - "[TAIL-PER] and Sergey Brin founded [HEAD-ORG]" -> Relation(head='Google', tail='Larry Page') and + - "Larry Page and [TAIL-PER] founded [HEAD-ORG]" -> Relation(head='Google', tail='Sergey Brin'). """ def __init__(self) -> None: @@ -96,7 +113,16 @@ def encode_tail(self, tail: Span, label: Label) -> str: class EntityMarker(EncodingStrategy): """ - TODO: Write documentation + An `class`:EncodingStrategy: that marks the head and tail relation entities. + + Example: + For the `founded_by` relation from `ORG` to `PER` and + the sentence "Larry Page and Sergey Brin founded Google .", + the encoded sentences and relations are + - "[HEAD] Larry Page [/HEAD] and Sergey Brin founded [TAIL] Google [\TAIL]" + -> Relation(head='Google', tail='Larry Page') and + - "Larry Page and [HEAD] Sergey Brin [/HEAD] founded [TAIL] Google [\TAIL]" + -> Relation(head='Google', tail='Sergey Brin'). """ def __init__(self, add_special_tokens: bool = False) -> None: @@ -113,7 +139,16 @@ def encode_tail(self, tail: Span, label: Label) -> str: class TypedEntityMarker(EncodingStrategy): """ - TODO: Write documentation + An `class`:EncodingStrategy: that marks the head and tail relation entities with their label. + + Example: + For the `founded_by` relation from `ORG` to `PER` and + the sentence "Larry Page and Sergey Brin founded Google .", + the encoded sentences and relations are + - "[HEAD-PER] Larry Page [/HEAD-PER] and Sergey Brin founded [TAIL-ORG] Google [\TAIL-ORG]" + -> Relation(head='Google', tail='Larry Page') and + - "Larry Page and [HEAD-PER] Sergey Brin [/HEAD-PER] founded [TAIL-ORG] Google [\TAIL-ORG]" + -> Relation(head='Google', tail='Sergey Brin'). """ def __init__(self) -> None: @@ -130,7 +165,14 @@ def encode_tail(self, tail: Span, label: Label) -> str: class EntityMarkerPunct(EncodingStrategy): """ - TODO: Write documentation + An alternate version of `class`:EntityMarker: with punctuations as control tokens. + + Example: + For the `founded_by` relation from `ORG` to `PER` and + the sentence "Larry Page and Sergey Brin founded Google .", + the encoded sentences and relations are + - "@ Larry Page @ and Sergey Brin founded # Google #" -> Relation(head='Google', tail='Larry Page') and + - "Larry Page and @ Sergey Brin @ founded # Google #" -> Relation(head='Google', tail='Sergey Brin'). """ def __init__(self) -> None: @@ -147,7 +189,16 @@ def encode_tail(self, tail: Span, label: Label) -> str: class TypedEntityMarkerPunct(EncodingStrategy): """ - TODO: Write documentation + An alternate version of `class`:TypedEntityMarker: with punctuations as control tokens. + + Example: + For the `founded_by` relation from `ORG` to `PER` and + the sentence "Larry Page and Sergey Brin founded Google .", + the encoded sentences and relations are + - "@ * PER * Larry Page @ and Sergey Brin founded # * ORG * Google #" + -> Relation(head='Google', tail='Larry Page') and + - "Larry Page and @ * PER * Sergey Brin @ founded # * ORG * Google #" + -> Relation(head='Google', tail='Sergey Brin'). """ def __init__(self) -> None: From 1a300ad53223e6c96de3ab7fc6608a27aec17c5e Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Sun, 11 Dec 2022 00:49:36 +0100 Subject: [PATCH 06/23] - Rename `_create_masked_sentence` -> `_encoded_sentence` - Rename mentions of `masked_sentence` to `encoded_sentence` --- flair/models/relation_classifier_model.py | 36 ++++++++++------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 64432a0fac..1d4c038eb5 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -410,7 +410,7 @@ def _entity_pair_permutations( yield head, tail, gold_label - def _create_masked_sentence( + def _encode_sentence( self, head: _Entity, tail: _Entity, @@ -439,38 +439,41 @@ def _create_masked_sentence( non_leading_head_tokens: List[Token] = head.span.tokens[1:] non_leading_tail_tokens: List[Token] = tail.span.tokens[1:] - # We can not use the plaintext of the head/tail span in the sentence as the mask + # We can not use the plaintext of the head/tail span in the sentence as the mask/marker # since there may be multiple occurrences of the same entity mentioned in the sentence. # Therefore, we use the span's position in the sentence. - masked_sentence_tokens: List[str] = [] + encoded_sentence_tokens: List[str] = [] for token in original_sentence: if token is head.span[0]: - masked_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label)) + encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label)) elif token is tail.span[0]: - masked_sentence_tokens.append(self.encoding_strategy.encode_tail(tail.span, tail.label)) + encoded_sentence_tokens.append(self.encoding_strategy.encode_tail(tail.span, tail.label)) elif all( token is not non_leading_entity_token for non_leading_entity_token in itertools.chain(non_leading_head_tokens, non_leading_tail_tokens) ): - masked_sentence_tokens.append(token.text) + encoded_sentence_tokens.append(token.text) # Create masked sentence - masked_sentence: EncodedSentence = EncodedSentence( - " ".join(masked_sentence_tokens), use_tokenizer=SpaceTokenizer() + encoded_sentence: EncodedSentence = EncodedSentence( + " ".join(encoded_sentence_tokens), use_tokenizer=SpaceTokenizer() ) if gold_label is not None: # Add gold relation annotation as sentence label # Using the sentence label instead of annotating a separate `Relation` object is easier to manage since, # during prediction, the forward pass does not need any knowledge about the entities in the sentence. - masked_sentence.add_label(typename=self.label_type, value=gold_label, score=1.0) + encoded_sentence.add_label(typename=self.label_type, value=gold_label, score=1.0) - return masked_sentence + return encoded_sentence - def _encode_sentence_for_inference(self, sentence: Sentence) -> Iterator[Tuple[EncodedSentence, Relation]]: + def _encode_sentence_for_inference( + self, + sentence: Sentence, + ) -> Iterator[Tuple[EncodedSentence, Relation]]: """ Yields masked entity pair sentences annotated with their gold relation for all valid entity pair permutations. The created masked sentences are newly created sentences with no reference to the passed sentence. @@ -479,19 +482,12 @@ def _encode_sentence_for_inference(self, sentence: Sentence) -> Iterator[Tuple[E **exactly** one induced relation annotation, the gold annotation or `self.zero_tag_value`. - The created relations have head and tail spans from the original passed sentence. - Example: - For the `founded_by` relation from `ORG` to `PER` and - the sentence "Larry Page and Sergey Brin founded Google .", - the masked sentences and relations are - - "[T-PER] and Sergey Brin founded [H-ORG]" -> Relation(head='Google', tail='Larry Page') and - - "Larry Page and [T-PER] founded [H-ORG]" -> Relation(head='Google', tail='Sergey Brin'). - :param sentence: A flair `Sentence` object with entity annotations :return: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence """ for head, tail, gold_label in self._entity_pair_permutations(sentence): - masked_sentence: EncodedSentence = self._create_masked_sentence( + masked_sentence: EncodedSentence = self._encode_sentence( head=head, tail=tail, gold_label=gold_label if gold_label is not None else self.zero_tag_value, @@ -513,7 +509,7 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS else: continue # Skip generated data points that do not express an originally annotated relation - masked_sentence: EncodedSentence = self._create_masked_sentence( + masked_sentence: EncodedSentence = self._encode_sentence( head=head, tail=tail, gold_label=gold_label, From 46b3e0f089f37f2a0f2d6b76ae1e26714112fc3a Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Sun, 11 Dec 2022 00:51:50 +0100 Subject: [PATCH 07/23] Adjusts docstrings to fit the encoding strategies --- flair/models/relation_classifier_model.py | 38 ++++++++++------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 1d4c038eb5..73588c2fdb 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -243,8 +243,8 @@ class RelationClassifier(flair.nn.DefaultClassifier[EncodedSentence, EncodedSent The Relation Classifier Model builds upon a text classifier. The model generates an encoded sentence for each entity pair in the cross product of all entities in the original sentence. - In the encoded representation, the entities in the current entity pair are masked with special control tokens. - (For an example, see the docstring of the `_encode_sentence_for_training` function.) + In the encoded representation, the entities in the current entity pair are masked/marked with control tokens. + (For an example, see the docstrings of different encoding strategies, e.g. :class:`TypedEntityMarker`.) Then, for each encoded sentence, the model takes its document embedding and puts the resulting text representation(s) through a linear layer to get the class relation label. @@ -292,7 +292,7 @@ def __init__( in the cross product of all entities in the original sentence. When disabling cross augmentation, the transform functions only generate encoded sentences for each gold relation annotation in the original sentence. - :param encoding_strategy: TODO Write documentation + :param encoding_strategy: An instance of a class conforming the :class:`EncodingStrategy` protocol :param zero_tag_value: The label to use for out-of-class relations :param allow_unk_tag: If `False`, removes `` from the passed label dictionary, otherwise do nothing. :param classifierargs: The remaining parameters passed to the underlying `DefaultClassifier` @@ -417,19 +417,13 @@ def _encode_sentence( gold_label: Optional[str] = None, ) -> EncodedSentence: """ - Returns a new `Sentence` object with masked head and tail spans. - The label-aware mask is constructed from the head/tail span labels. - If provided, the masked sentence also has the corresponding gold label annotation in `self.label_type`. - - Example: - For the `head=Google`, `tail=Larry Page` and - the sentence "Larry Page and Sergey Brin founded Google .", - the masked sentence is "[T-PER] and Sergey Brin founded [H-ORG]". + Returns a new `Sentence` object with masked/marked head and tail spans according to the encoding strategy. + If provided, the encoded sentence also has the corresponding gold label annotation from `self.label_type`. :param head: The head `_Entity` :param tail: The tail `_Entity` :param gold_label: An optional gold label of the induced relation by the head and tail entity - :return: The masked sentence (with gold annotations) + :return: The `EncodedSentence` (with gold annotations) """ # Some sanity checks original_sentence: Sentence = head.span.sentence @@ -475,10 +469,12 @@ def _encode_sentence_for_inference( sentence: Sentence, ) -> Iterator[Tuple[EncodedSentence, Relation]]: """ - Yields masked entity pair sentences annotated with their gold relation for all valid entity pair permutations. - The created masked sentences are newly created sentences with no reference to the passed sentence. + Yields encoded sentences annotated with their gold relation and + the corresponding relation object in the original sentence for all valid entity pair permutations. + The created encoded sentences are newly created sentences with no reference to the passed sentence. + Important properties: - - Every sentence has exactly one masked head and tail entity token. Therefore, every encoded sentence has + - Every sentence has exactly one encoded head and tail entity token. Therefore, every encoded sentence has **exactly** one induced relation annotation, the gold annotation or `self.zero_tag_value`. - The created relations have head and tail spans from the original passed sentence. @@ -521,8 +517,8 @@ def transform_sentence(self, sentences: Union[Sentence, List[Sentence]]) -> List """ Transforms sentences into encoded sentences specific to the `RelationClassifier`. For more information on the internal sentence transformation procedure, - see the `RelationClassifier` architecture docstring and - the `_encode_sentence_for_training` and `_encode_sentence_for_inference` docstrings. + see the :class:`RelationClassifier` architecture and + the different :class:`EncodingStrategy` variants docstrings. :param sentences: A (list) of sentence(s) to transform :return: A list of encoded sentences specific to the `RelationClassifier` @@ -541,8 +537,8 @@ def transform_dataset(self, dataset: Dataset[Sentence]) -> FlairDatapointDataset Transforms a dataset into a dataset containing encoded sentences specific to the `RelationClassifier`. The returned dataset is stored in memory. For more information on the internal sentence transformation procedure, - see the `RelationClassifier` architecture docstring and - the `_encode_sentence_for_training` and `_encode_sentence_for_inference` docstrings. + see the :class:`RelationClassifier` architecture and + the different :class:`EncodingStrategy` variants docstrings. :param dataset: A dataset of sentences to transform :return: A dataset of encoded sentences specific to the `RelationClassifier` @@ -556,8 +552,8 @@ def transform_corpus(self, corpus: Corpus[Sentence]) -> Corpus[EncodedSentence]: Transforms a corpus into a corpus containing encoded sentences specific to the `RelationClassifier`. The splits of the returned corpus are stored in memory. For more information on the internal sentence transformation procedure, - see the `RelationClassifier` architecture docstring and - the `_encode_sentence_for_training` and `_encode_sentence_for_inference` docstrings. + see the :class:`RelationClassifier` architecture and + the different :class:`EncodingStrategy` variants docstrings. :param corpus: A corpus of sentences to transform :return: A corpus of encoded sentences specific to the `RelationClassifier` From 468470a591a80e30e584e4f6d5f6d189d353d0fd Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Sun, 11 Dec 2022 01:00:58 +0100 Subject: [PATCH 08/23] Cite relevant source paper --- flair/models/relation_classifier_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 73588c2fdb..490d788710 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -248,6 +248,8 @@ class RelationClassifier(flair.nn.DefaultClassifier[EncodedSentence, EncodedSent Then, for each encoded sentence, the model takes its document embedding and puts the resulting text representation(s) through a linear layer to get the class relation label. + The implemented encoding strategies are taken from this paper by Zhou et al.: https://arxiv.org/abs/2102.01373 + Note: Currently, the model has no multi-label support. """ From e8681e85fbfcef67ab733d4f40cf237ee869e97a Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Sun, 11 Dec 2022 01:13:28 +0100 Subject: [PATCH 09/23] Fix mypy --- flair/models/relation_classifier_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 490d788710..c4debb40a9 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -176,7 +176,7 @@ class EntityMarkerPunct(EncodingStrategy): """ def __init__(self) -> None: - self.special_tokens = set() + self.special_tokens: Set[str] = set() def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) @@ -202,7 +202,7 @@ class TypedEntityMarkerPunct(EncodingStrategy): """ def __init__(self) -> None: - self.special_tokens = set() + self.special_tokens: Set[str] = set() def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) From 86100efb81ec76ed9a8ac9d607050506ad820985 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Sun, 11 Dec 2022 02:05:00 +0100 Subject: [PATCH 10/23] Use raw strings for markers --- flair/models/relation_classifier_model.py | 56 +++++++++++++---------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index c4debb40a9..71dfff2518 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -112,7 +112,7 @@ def encode_tail(self, tail: Span, label: Label) -> str: class EntityMarker(EncodingStrategy): - """ + r""" An `class`:EncodingStrategy: that marks the head and tail relation entities. Example: @@ -126,19 +126,19 @@ class EntityMarker(EncodingStrategy): """ def __init__(self, add_special_tokens: bool = False) -> None: - self.special_tokens = {"[HEAD]", "[\HEAD]", "[TAIL]", "[\TAIL]"} if add_special_tokens else set() + self.special_tokens = {"[HEAD]", r"[\HEAD]", "[TAIL]", r"[\TAIL]"} if add_special_tokens else set() def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) - return f"[HEAD] {space_tokenized_text} [/HEAD]" + return rf"[HEAD] {space_tokenized_text} [/HEAD]" def encode_tail(self, tail: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in tail) - return f"[TAIL] {space_tokenized_text} [/TAIL]" + return rf"[TAIL] {space_tokenized_text} [/TAIL]" class TypedEntityMarker(EncodingStrategy): - """ + r""" An `class`:EncodingStrategy: that marks the head and tail relation entities with their label. Example: @@ -156,11 +156,11 @@ def __init__(self) -> None: def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) - return f"[HEAD-{label.value}] {space_tokenized_text} [/HEAD-{label.value}]" + return rf"[HEAD-{label.value}] {space_tokenized_text} [/HEAD-{label.value}]" def encode_tail(self, tail: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in tail) - return f"[TAIL-{label.value}] {space_tokenized_text} [/TAIL-{label.value}]" + return rf"[TAIL-{label.value}] {space_tokenized_text} [/TAIL-{label.value}]" class EntityMarkerPunct(EncodingStrategy): @@ -725,13 +725,17 @@ def allow_unk_tag(self) -> bool: def inspect_relations( corpus: Corpus[Sentence], relation_label_type: str, - entity_label_types: Optional[Union[Sequence[str], str]] = None, -) -> DefaultDict[str, Counter[Tuple[str, str]]]: - if entity_label_types is not None and not isinstance(entity_label_types, Sequence): - entity_label_types = [entity_label_types] + entity_label_types: Optional[Union[Set[Union[None, str]], str]] = None, +) -> Tuple[DefaultDict[str, Counter[Tuple[str, str]]], Counter[str]]: + + if entity_label_types is None: + entity_label_types = {None} + elif not isinstance(entity_label_types, Set): + entity_label_types = {entity_label_types} # Dictionary of [, ] relations: DefaultDict[str, Counter[Tuple[str, str]]] = co.defaultdict(co.Counter) + entity_word_counter = co.Counter() data_loader: DataLoader = DataLoader( ConcatDataset(split for split in [corpus.train, corpus.dev, corpus.test] if split is not None), @@ -742,19 +746,23 @@ def inspect_relations( for relation in sentence.get_relations(relation_label_type): entity_counter = relations[relation.get_label(relation_label_type).value] - head_relation_label: str - tail_relation_label: str - if entity_label_types is None: - head_relation_label = relation.first.get_label().value - tail_relation_label = relation.second.get_label().value - else: - head_relation_label = next( - relation.first.get_label(label_type).value for label_type in entity_label_types - ) - tail_relation_label = next( - relation.second.get_label(label_type).value for label_type in entity_label_types - ) + head_relation_label = next(relation.first.get_label(label_type).value for label_type in entity_label_types) + tail_relation_label = next(relation.second.get_label(label_type).value for label_type in entity_label_types) entity_counter.update([(head_relation_label, tail_relation_label)]) + entity_word_counter.update([(relation.first.text, relation.second.text)]) + + return relations, entity_word_counter - return relations + +def infer_entity_pair_labels( + corpus: Corpus[Sentence], + relation_label_type: str, + entity_label_types: Optional[Union[Set[str], str]] = None, +) -> Set[Tuple[str, str]]: + + relations, _ = inspect_relations(corpus, relation_label_type, entity_label_types) + entity_pair_labels: Set[Tuple[str, str]] = { + entity_pair for entity_pair_counter in relations.values() for entity_pair in entity_pair_counter.keys() + } + return entity_pair_labels From 11415024ce8b91854f12e0ac6696be3c979ccd58 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Sun, 11 Dec 2022 02:29:28 +0100 Subject: [PATCH 11/23] - Parameterize transformation test with and without cross augmentation into one test - Add transformation test for entity-mask and typed-entity-mask --- tests/models/test_relation_classifier.py | 112 +++++++++++++++-------- 1 file changed, 74 insertions(+), 38 deletions(-) diff --git a/tests/models/test_relation_classifier.py b/tests/models/test_relation_classifier.py index 2e0edb0ae8..a149618037 100644 --- a/tests/models/test_relation_classifier.py +++ b/tests/models/test_relation_classifier.py @@ -8,7 +8,12 @@ from flair.datasets import ColumnCorpus, DataLoader from flair.embeddings import TransformerDocumentEmbeddings from flair.models import RelationClassifier -from flair.models.relation_classifier_model import EncodedSentence +from flair.models.relation_classifier_model import ( + EncodedSentence, + EntityMask, + TypedEntityMask, + EncodingStrategy, +) from tests.model_test_utils import BaseModelTest @@ -118,59 +123,90 @@ def check_transformation_correctness( for sentence in map(itemgetter(0), data_loader) } == ground_truth - def test_transform_corpus_with_cross_augmentation(self, corpus: ColumnCorpus, embeddings) -> None: + @pytest.mark.parametrize( + "cross_augmentation", [True, False], ids=["with_cross_augmentation", "without_cross_augmentation"] + ) + @pytest.mark.parametrize( + "encoding_strategy, encoded_entity_pairs", + [ + (EntityMask(), [("[HEAD]", "[TAIL]") for _ in range(7)]), + ( + TypedEntityMask(), + [ + ("[HEAD-ORG]", "[TAIL-PER]"), + ("[HEAD-ORG]", "[TAIL-PER]"), + ("[HEAD-ORG]", "[TAIL-PER]"), + ("[HEAD-LOC]", "[TAIL-PER]"), + ("[HEAD-ORG]", "[TAIL-PER]"), + ("[HEAD-LOC]", "[TAIL-PER]"), + ("[HEAD-LOC]", "[TAIL-PER]"), + ], + ), + # (EntityMarker(), []), + # (TypedEntityMarker(), []), + # (EntityMarkerPunct(), []), + # (TypedEntityMarker(), []), + ], + ids=[ + c.__name__ + for c in ( + EntityMask, + TypedEntityMask, + # EntityMarker, + # TypedEntityMarker, + # EntityMarkerPunct, + # TypedEntityMarkerPunct, + ) + ], + ) + def test_transform_corpus( + self, + corpus: ColumnCorpus, + embeddings: TransformerDocumentEmbeddings, + cross_augmentation: bool, + encoding_strategy: EncodingStrategy, + encoded_entity_pairs: List[Tuple[str, str]], + ) -> None: label_dictionary = corpus.make_label_dictionary("relation") - model: RelationClassifier = self.build_model(embeddings, label_dictionary, cross_augmentation=True) + model: RelationClassifier = self.build_model( + embeddings, label_dictionary, cross_augmentation=cross_augmentation, encoding_strategy=encoding_strategy + ) transformed_corpus = model.transform_corpus(corpus) # Check sentence masking and relation label annotation on - # training, validation and test dataset (in this test they are the same) + # training, validation and test dataset (in this test the splits are the same) ground_truth: Set[Tuple[str, Tuple[str, ...]]] = { # Entity pair permutations of: "Larry Page and Sergey Brin founded Google ." - ("[[T-Larry Page]] and Sergey Brin founded [[H-Google]] .", ("founded_by",)), - ("Larry Page and [[T-Sergey Brin]] founded [[H-Google]] .", ("founded_by",)), + (f"{encoded_entity_pairs[0][1]} and Sergey Brin founded {encoded_entity_pairs[0][0]} .", ("founded_by",)), + (f"Larry Page and {encoded_entity_pairs[1][1]} founded {encoded_entity_pairs[1][0]} .", ("founded_by",)), # Entity pair permutations of: "Microsoft was founded by Bill Gates ." - ("[[H-Microsoft]] was founded by [[T-Bill Gates]] .", ("founded_by",)), + (f"{encoded_entity_pairs[2][0]} was founded by {encoded_entity_pairs[2][1]} .", ("founded_by",)), # Entity pair permutations of: "Konrad Zuse was born in Berlin on 22 June 1910 ." - ("[[T-Konrad Zuse]] was born in [[H-Berlin]] on 22 June 1910 .", ("place_of_birth",)), - # Entity pair permutations of: "Joseph Weizenbaum , a professor at MIT , was born in Berlin , Germany." - ("[[T-Joseph Weizenbaum]] , a professor at [[H-MIT]] , was born in Berlin , Germany .", ("O",)), ( - "[[T-Joseph Weizenbaum]] , a professor at MIT , was born in [[H-Berlin]] , Germany .", + f"{encoded_entity_pairs[3][1]} was born in {encoded_entity_pairs[3][0]} on 22 June 1910 .", ("place_of_birth",), ), - ( - "[[T-Joseph Weizenbaum]] , a professor at MIT , was born in Berlin , [[H-Germany]] .", - ("place_of_birth",), + # Entity pair permutations of: "Joseph Weizenbaum , a professor at MIT , was born in Berlin , Germany." + # This sentence is only included if we transform the corpus with cross augmentation + *( + [ + ( + f"{encoded_entity_pairs[4][1]} , a professor at {encoded_entity_pairs[4][0]} , " + f"was born in Berlin , Germany .", + ("O",), + ) + ] + if cross_augmentation + else [] ), - } - for split in (transformed_corpus.train, transformed_corpus.dev, transformed_corpus.test): - self.check_transformation_correctness(split, ground_truth) - - def test_transform_corpus_without_cross_augmentation(self, corpus: ColumnCorpus, embeddings) -> None: - label_dictionary = corpus.make_label_dictionary("relation") - embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased", layers="-1", fine_tune=True) - model: RelationClassifier = self.build_model(embeddings, label_dictionary, cross_augmentation=False) - - transformed_corpus = model.transform_corpus(corpus) - - # Check sentence masking and relation label annotation on - # training, validation and test dataset (in this test they are the same) - ground_truth: Set[Tuple[str, Tuple[str, ...]]] = { - # Entity pair permutations of: "Larry Page and Sergey Brin founded Google ." - ("[[T-Larry Page]] and Sergey Brin founded [[H-Google]] .", ("founded_by",)), - ("Larry Page and [[T-Sergey Brin]] founded [[H-Google]] .", ("founded_by",)), - # Entity pair permutations of: "Microsoft was founded by Bill Gates ." - ("[[H-Microsoft]] was founded by [[T-Bill Gates]] .", ("founded_by",)), - # Entity pair permutations of: "Konrad Zuse was born in Berlin on 22 June 1910 ." - ("[[T-Konrad Zuse]] was born in [[H-Berlin]] on 22 June 1910 .", ("place_of_birth",)), - # Entity pair permutations of: "Joseph Weizenbaum , a professor at MIT , was born in Berlin , Germany ." ( - "[[T-Joseph Weizenbaum]] , a professor at MIT , was born in [[H-Berlin]] , Germany .", + f"{encoded_entity_pairs[5][1]} , a professor at MIT , " + f"was born in {encoded_entity_pairs[5][0]} , Germany .", ("place_of_birth",), ), ( - "[[T-Joseph Weizenbaum]] , a professor at MIT , was born in Berlin , [[H-Germany]] .", + f"{encoded_entity_pairs[6][1]} , a professor at MIT , " + f"was born in Berlin , {encoded_entity_pairs[6][0]} .", ("place_of_birth",), ), } From 746879bf3ee7a320886473c0934849aea5b8659d Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Sun, 11 Dec 2022 02:34:37 +0100 Subject: [PATCH 12/23] Remove dead functions --- flair/models/relation_classifier_model.py | 46 ----------------------- 1 file changed, 46 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 71dfff2518..0a783c18f4 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -720,49 +720,3 @@ def zero_tag_value(self) -> str: @property def allow_unk_tag(self) -> bool: return self._allow_unk_tag - - -def inspect_relations( - corpus: Corpus[Sentence], - relation_label_type: str, - entity_label_types: Optional[Union[Set[Union[None, str]], str]] = None, -) -> Tuple[DefaultDict[str, Counter[Tuple[str, str]]], Counter[str]]: - - if entity_label_types is None: - entity_label_types = {None} - elif not isinstance(entity_label_types, Set): - entity_label_types = {entity_label_types} - - # Dictionary of [, ] - relations: DefaultDict[str, Counter[Tuple[str, str]]] = co.defaultdict(co.Counter) - entity_word_counter = co.Counter() - - data_loader: DataLoader = DataLoader( - ConcatDataset(split for split in [corpus.train, corpus.dev, corpus.test] if split is not None), - batch_size=1, - num_workers=0, - ) - for sentence in map(itemgetter(0), data_loader): - for relation in sentence.get_relations(relation_label_type): - entity_counter = relations[relation.get_label(relation_label_type).value] - - head_relation_label = next(relation.first.get_label(label_type).value for label_type in entity_label_types) - tail_relation_label = next(relation.second.get_label(label_type).value for label_type in entity_label_types) - - entity_counter.update([(head_relation_label, tail_relation_label)]) - entity_word_counter.update([(relation.first.text, relation.second.text)]) - - return relations, entity_word_counter - - -def infer_entity_pair_labels( - corpus: Corpus[Sentence], - relation_label_type: str, - entity_label_types: Optional[Union[Set[str], str]] = None, -) -> Set[Tuple[str, str]]: - - relations, _ = inspect_relations(corpus, relation_label_type, entity_label_types) - entity_pair_labels: Set[Tuple[str, str]] = { - entity_pair for entity_pair_counter in relations.values() for entity_pair in entity_pair_counter.keys() - } - return entity_pair_labels From 13c5f382ff67179230e443d0b9b742d613d6ea9c Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Sun, 11 Dec 2022 02:35:48 +0100 Subject: [PATCH 13/23] Isort --- flair/models/relation_classifier_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 0a783c18f4..52a2c7a13c 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -11,12 +11,12 @@ List, NamedTuple, Optional, + Protocol, Sequence, Set, Tuple, Union, cast, - Protocol, runtime_checkable, ) From ff1e40eca52c23a9a8a1a1d6a33385b9dc0fcacc Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Sun, 11 Dec 2022 02:54:00 +0100 Subject: [PATCH 14/23] Isort --- tests/models/test_relation_classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_relation_classifier.py b/tests/models/test_relation_classifier.py index a149618037..752ec6136a 100644 --- a/tests/models/test_relation_classifier.py +++ b/tests/models/test_relation_classifier.py @@ -10,9 +10,9 @@ from flair.models import RelationClassifier from flair.models.relation_classifier_model import ( EncodedSentence, + EncodingStrategy, EntityMask, TypedEntityMask, - EncodingStrategy, ) from tests.model_test_utils import BaseModelTest From f9b71555807e32dbeb413bd2c0fb5598498ed183 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Mon, 12 Dec 2022 02:33:42 +0100 Subject: [PATCH 15/23] Correct encoding strategy examples --- flair/models/relation_classifier_model.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 52a2c7a13c..21108ba31e 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -112,42 +112,42 @@ def encode_tail(self, tail: Span, label: Label) -> str: class EntityMarker(EncodingStrategy): - r""" + """ An `class`:EncodingStrategy: that marks the head and tail relation entities. Example: For the `founded_by` relation from `ORG` to `PER` and the sentence "Larry Page and Sergey Brin founded Google .", the encoded sentences and relations are - - "[HEAD] Larry Page [/HEAD] and Sergey Brin founded [TAIL] Google [\TAIL]" + - "[HEAD] Larry Page [/HEAD] and Sergey Brin founded [TAIL] Google [/TAIL]" -> Relation(head='Google', tail='Larry Page') and - - "Larry Page and [HEAD] Sergey Brin [/HEAD] founded [TAIL] Google [\TAIL]" + - "Larry Page and [HEAD] Sergey Brin [/HEAD] founded [TAIL] Google [/TAIL]" -> Relation(head='Google', tail='Sergey Brin'). """ def __init__(self, add_special_tokens: bool = False) -> None: - self.special_tokens = {"[HEAD]", r"[\HEAD]", "[TAIL]", r"[\TAIL]"} if add_special_tokens else set() + self.special_tokens = {"[HEAD]", "[\HEAD]", "[TAIL]", "[\TAIL]"} if add_special_tokens else set() def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) - return rf"[HEAD] {space_tokenized_text} [/HEAD]" + return f"[HEAD] {space_tokenized_text} [/HEAD]" def encode_tail(self, tail: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in tail) - return rf"[TAIL] {space_tokenized_text} [/TAIL]" + return f"[TAIL] {space_tokenized_text} [/TAIL]" class TypedEntityMarker(EncodingStrategy): - r""" + """ An `class`:EncodingStrategy: that marks the head and tail relation entities with their label. Example: For the `founded_by` relation from `ORG` to `PER` and the sentence "Larry Page and Sergey Brin founded Google .", the encoded sentences and relations are - - "[HEAD-PER] Larry Page [/HEAD-PER] and Sergey Brin founded [TAIL-ORG] Google [\TAIL-ORG]" + - "[HEAD-PER] Larry Page [/HEAD-PER] and Sergey Brin founded [TAIL-ORG] Google [/TAIL-ORG]" -> Relation(head='Google', tail='Larry Page') and - - "Larry Page and [HEAD-PER] Sergey Brin [/HEAD-PER] founded [TAIL-ORG] Google [\TAIL-ORG]" + - "Larry Page and [HEAD-PER] Sergey Brin [/HEAD-PER] founded [TAIL-ORG] Google [/TAIL-ORG]" -> Relation(head='Google', tail='Sergey Brin'). """ @@ -156,11 +156,11 @@ def __init__(self) -> None: def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) - return rf"[HEAD-{label.value}] {space_tokenized_text} [/HEAD-{label.value}]" + return f"[HEAD-{label.value}] {space_tokenized_text} [/HEAD-{label.value}]" def encode_tail(self, tail: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in tail) - return rf"[TAIL-{label.value}] {space_tokenized_text} [/TAIL-{label.value}]" + return f"[TAIL-{label.value}] {space_tokenized_text} [/TAIL-{label.value}]" class EntityMarkerPunct(EncodingStrategy): From 0934b3b2faa01eb204de2d9f9b87279507b55183 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Mon, 12 Dec 2022 02:34:18 +0100 Subject: [PATCH 16/23] Add the remaining tests for the encoding strategy --- tests/models/test_relation_classifier.py | 64 +++++++++++++++++++++--- 1 file changed, 56 insertions(+), 8 deletions(-) diff --git a/tests/models/test_relation_classifier.py b/tests/models/test_relation_classifier.py index 752ec6136a..d128a1e221 100644 --- a/tests/models/test_relation_classifier.py +++ b/tests/models/test_relation_classifier.py @@ -13,6 +13,10 @@ EncodingStrategy, EntityMask, TypedEntityMask, + EntityMarker, + TypedEntityMarker, + TypedEntityMarkerPunct, + EntityMarkerPunct, ) from tests.model_test_utils import BaseModelTest @@ -142,20 +146,64 @@ def check_transformation_correctness( ("[HEAD-LOC]", "[TAIL-PER]"), ], ), - # (EntityMarker(), []), - # (TypedEntityMarker(), []), - # (EntityMarkerPunct(), []), - # (TypedEntityMarker(), []), + ( + EntityMarker(), + [ + ("[HEAD] Google [/HEAD]", "[TAIL] Larry Page [/TAIL]"), + ("[HEAD] Google [/HEAD]", "[TAIL] Sergey Brin [/TAIL]"), + ("[HEAD] Microsoft [/HEAD]", "[TAIL] Bill Gates [/TAIL]"), + ("[HEAD] Berlin [/HEAD]", "[TAIL] Konrad Zuse [/TAIL]"), + ("[HEAD] MIT [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), + ("[HEAD] Berlin [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), + ("[HEAD] Germany [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), + ], + ), + ( + TypedEntityMarker(), + [ + ("[HEAD-ORG] Google [/HEAD-ORG]", "[TAIL-PER] Larry Page [/TAIL-PER]"), + ("[HEAD-ORG] Google [/HEAD-ORG]", "[TAIL-PER] Sergey Brin [/TAIL-PER]"), + ("[HEAD-ORG] Microsoft [/HEAD-ORG]", "[TAIL-PER] Bill Gates [/TAIL-PER]"), + ("[HEAD-LOC] Berlin [/HEAD-LOC]", "[TAIL-PER] Konrad Zuse [/TAIL-PER]"), + ("[HEAD-ORG] MIT [/HEAD-ORG]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), + ("[HEAD-LOC] Berlin [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), + ("[HEAD-LOC] Germany [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), + ], + ), + ( + EntityMarkerPunct(), + [ + ("@ Google @", "# Larry Page #"), + ("@ Google @", "# Sergey Brin #"), + ("@ Microsoft @", "# Bill Gates #"), + ("@ Berlin @", "# Konrad Zuse #"), + ("@ MIT @", "# Joseph Weizenbaum #"), + ("@ Berlin @", "# Joseph Weizenbaum #"), + ("@ Germany @", "# Joseph Weizenbaum #"), + ], + ), + ( + TypedEntityMarkerPunct(), + [ + ("@ * ORG * Google @", "# ^ PER ^ Larry Page #"), + ("@ * ORG * Google @", "# ^ PER ^ Sergey Brin #"), + ("@ * ORG * Microsoft @", "# ^ PER ^ Bill Gates #"), + ("@ * LOC * Berlin @", "# ^ PER ^ Konrad Zuse #"), + ("@ * ORG * MIT @", "# ^ PER ^ Joseph Weizenbaum #"), + ("@ * LOC * Berlin @", "# ^ PER ^ Joseph Weizenbaum #"), + ("@ * LOC * Germany @", "# ^ PER ^ Joseph Weizenbaum #"), + ], + ), ], ids=[ c.__name__ for c in ( EntityMask, TypedEntityMask, - # EntityMarker, - # TypedEntityMarker, - # EntityMarkerPunct, - # TypedEntityMarkerPunct, + EntityMarker, + TypedEntityMarker, + EntityMarkerPunct, + TypedEntityMarkerPunct, ) ], ) From b7991c896c2a09ba2c12cd1ca4898e201d682747 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Mon, 12 Dec 2022 02:50:36 +0100 Subject: [PATCH 17/23] Fix special tokens --- flair/models/relation_classifier_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 21108ba31e..1c887f516d 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -126,7 +126,7 @@ class EntityMarker(EncodingStrategy): """ def __init__(self, add_special_tokens: bool = False) -> None: - self.special_tokens = {"[HEAD]", "[\HEAD]", "[TAIL]", "[\TAIL]"} if add_special_tokens else set() + self.special_tokens = {"[HEAD]", "[/HEAD]", "[TAIL]", "[/TAIL]"} if add_special_tokens else set() def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) From 56063dc4459d82b20962c94e52e6aed16548fbe6 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Mon, 12 Dec 2022 02:53:09 +0100 Subject: [PATCH 18/23] Outsource encoding strategy test templates --- tests/models/test_relation_classifier.py | 128 ++++++++++------------- 1 file changed, 53 insertions(+), 75 deletions(-) diff --git a/tests/models/test_relation_classifier.py b/tests/models/test_relation_classifier.py index d128a1e221..67bd14f0f8 100644 --- a/tests/models/test_relation_classifier.py +++ b/tests/models/test_relation_classifier.py @@ -1,5 +1,5 @@ from operator import itemgetter -from typing import List, Optional, Set, Tuple +from typing import List, Optional, Set, Tuple, Dict import pytest from torch.utils.data import Dataset @@ -21,6 +21,56 @@ from tests.model_test_utils import BaseModelTest +encoding_strategies: Dict[EncodingStrategy, List[Tuple[str, str]]] = { + EntityMask(): [("[HEAD]", "[TAIL]") for _ in range(7)], + TypedEntityMask(): [ + ("[HEAD-ORG]", "[TAIL-PER]"), + ("[HEAD-ORG]", "[TAIL-PER]"), + ("[HEAD-ORG]", "[TAIL-PER]"), + ("[HEAD-LOC]", "[TAIL-PER]"), + ("[HEAD-ORG]", "[TAIL-PER]"), + ("[HEAD-LOC]", "[TAIL-PER]"), + ("[HEAD-LOC]", "[TAIL-PER]"), + ], + EntityMarker(): [ + ("[HEAD] Google [/HEAD]", "[TAIL] Larry Page [/TAIL]"), + ("[HEAD] Google [/HEAD]", "[TAIL] Sergey Brin [/TAIL]"), + ("[HEAD] Microsoft [/HEAD]", "[TAIL] Bill Gates [/TAIL]"), + ("[HEAD] Berlin [/HEAD]", "[TAIL] Konrad Zuse [/TAIL]"), + ("[HEAD] MIT [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), + ("[HEAD] Berlin [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), + ("[HEAD] Germany [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), + ], + TypedEntityMarker(): [ + ("[HEAD-ORG] Google [/HEAD-ORG]", "[TAIL-PER] Larry Page [/TAIL-PER]"), + ("[HEAD-ORG] Google [/HEAD-ORG]", "[TAIL-PER] Sergey Brin [/TAIL-PER]"), + ("[HEAD-ORG] Microsoft [/HEAD-ORG]", "[TAIL-PER] Bill Gates [/TAIL-PER]"), + ("[HEAD-LOC] Berlin [/HEAD-LOC]", "[TAIL-PER] Konrad Zuse [/TAIL-PER]"), + ("[HEAD-ORG] MIT [/HEAD-ORG]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), + ("[HEAD-LOC] Berlin [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), + ("[HEAD-LOC] Germany [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), + ], + EntityMarkerPunct(): [ + ("@ Google @", "# Larry Page #"), + ("@ Google @", "# Sergey Brin #"), + ("@ Microsoft @", "# Bill Gates #"), + ("@ Berlin @", "# Konrad Zuse #"), + ("@ MIT @", "# Joseph Weizenbaum #"), + ("@ Berlin @", "# Joseph Weizenbaum #"), + ("@ Germany @", "# Joseph Weizenbaum #"), + ], + TypedEntityMarkerPunct(): [ + ("@ * ORG * Google @", "# ^ PER ^ Larry Page #"), + ("@ * ORG * Google @", "# ^ PER ^ Sergey Brin #"), + ("@ * ORG * Microsoft @", "# ^ PER ^ Bill Gates #"), + ("@ * LOC * Berlin @", "# ^ PER ^ Konrad Zuse #"), + ("@ * ORG * MIT @", "# ^ PER ^ Joseph Weizenbaum #"), + ("@ * LOC * Berlin @", "# ^ PER ^ Joseph Weizenbaum #"), + ("@ * LOC * Germany @", "# ^ PER ^ Joseph Weizenbaum #"), + ], +} + + class TestRelationClassifier(BaseModelTest): model_cls = RelationClassifier train_label_type = "relation" @@ -132,80 +182,8 @@ def check_transformation_correctness( ) @pytest.mark.parametrize( "encoding_strategy, encoded_entity_pairs", - [ - (EntityMask(), [("[HEAD]", "[TAIL]") for _ in range(7)]), - ( - TypedEntityMask(), - [ - ("[HEAD-ORG]", "[TAIL-PER]"), - ("[HEAD-ORG]", "[TAIL-PER]"), - ("[HEAD-ORG]", "[TAIL-PER]"), - ("[HEAD-LOC]", "[TAIL-PER]"), - ("[HEAD-ORG]", "[TAIL-PER]"), - ("[HEAD-LOC]", "[TAIL-PER]"), - ("[HEAD-LOC]", "[TAIL-PER]"), - ], - ), - ( - EntityMarker(), - [ - ("[HEAD] Google [/HEAD]", "[TAIL] Larry Page [/TAIL]"), - ("[HEAD] Google [/HEAD]", "[TAIL] Sergey Brin [/TAIL]"), - ("[HEAD] Microsoft [/HEAD]", "[TAIL] Bill Gates [/TAIL]"), - ("[HEAD] Berlin [/HEAD]", "[TAIL] Konrad Zuse [/TAIL]"), - ("[HEAD] MIT [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), - ("[HEAD] Berlin [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), - ("[HEAD] Germany [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), - ], - ), - ( - TypedEntityMarker(), - [ - ("[HEAD-ORG] Google [/HEAD-ORG]", "[TAIL-PER] Larry Page [/TAIL-PER]"), - ("[HEAD-ORG] Google [/HEAD-ORG]", "[TAIL-PER] Sergey Brin [/TAIL-PER]"), - ("[HEAD-ORG] Microsoft [/HEAD-ORG]", "[TAIL-PER] Bill Gates [/TAIL-PER]"), - ("[HEAD-LOC] Berlin [/HEAD-LOC]", "[TAIL-PER] Konrad Zuse [/TAIL-PER]"), - ("[HEAD-ORG] MIT [/HEAD-ORG]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), - ("[HEAD-LOC] Berlin [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), - ("[HEAD-LOC] Germany [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), - ], - ), - ( - EntityMarkerPunct(), - [ - ("@ Google @", "# Larry Page #"), - ("@ Google @", "# Sergey Brin #"), - ("@ Microsoft @", "# Bill Gates #"), - ("@ Berlin @", "# Konrad Zuse #"), - ("@ MIT @", "# Joseph Weizenbaum #"), - ("@ Berlin @", "# Joseph Weizenbaum #"), - ("@ Germany @", "# Joseph Weizenbaum #"), - ], - ), - ( - TypedEntityMarkerPunct(), - [ - ("@ * ORG * Google @", "# ^ PER ^ Larry Page #"), - ("@ * ORG * Google @", "# ^ PER ^ Sergey Brin #"), - ("@ * ORG * Microsoft @", "# ^ PER ^ Bill Gates #"), - ("@ * LOC * Berlin @", "# ^ PER ^ Konrad Zuse #"), - ("@ * ORG * MIT @", "# ^ PER ^ Joseph Weizenbaum #"), - ("@ * LOC * Berlin @", "# ^ PER ^ Joseph Weizenbaum #"), - ("@ * LOC * Germany @", "# ^ PER ^ Joseph Weizenbaum #"), - ], - ), - ], - ids=[ - c.__name__ - for c in ( - EntityMask, - TypedEntityMask, - EntityMarker, - TypedEntityMarker, - EntityMarkerPunct, - TypedEntityMarkerPunct, - ) - ], + encoding_strategies.items(), + ids=[type(encoding_strategy).__name__ for encoding_strategy in encoding_strategies], ) def test_transform_corpus( self, From da3a564077c271837eb03035f7dd590dc3eccad1 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Mon, 12 Dec 2022 02:54:07 +0100 Subject: [PATCH 19/23] Use protocol from typing extensions to support python 3.7 --- flair/models/relation_classifier_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 1c887f516d..576e3176e5 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -11,17 +11,16 @@ List, NamedTuple, Optional, - Protocol, Sequence, Set, Tuple, Union, cast, - runtime_checkable, ) import torch from torch.utils.data.dataset import ConcatDataset, Dataset +from typing_extensions import Protocol, runtime_checkable import flair from flair.data import Corpus, Dictionary, Label, Relation, Sentence, Span, Token From d114d3caaed2f2ce580c627b6c3d45789e0f9bc0 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Mon, 12 Dec 2022 02:54:51 +0100 Subject: [PATCH 20/23] Isort --- tests/models/test_relation_classifier.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/models/test_relation_classifier.py b/tests/models/test_relation_classifier.py index 67bd14f0f8..a720167484 100644 --- a/tests/models/test_relation_classifier.py +++ b/tests/models/test_relation_classifier.py @@ -1,5 +1,5 @@ from operator import itemgetter -from typing import List, Optional, Set, Tuple, Dict +from typing import Dict, List, Optional, Set, Tuple import pytest from torch.utils.data import Dataset @@ -11,16 +11,15 @@ from flair.models.relation_classifier_model import ( EncodedSentence, EncodingStrategy, - EntityMask, - TypedEntityMask, EntityMarker, + EntityMarkerPunct, + EntityMask, TypedEntityMarker, TypedEntityMarkerPunct, - EntityMarkerPunct, + TypedEntityMask, ) from tests.model_test_utils import BaseModelTest - encoding_strategies: Dict[EncodingStrategy, List[Tuple[str, str]]] = { EntityMask(): [("[HEAD]", "[TAIL]") for _ in range(7)], TypedEntityMask(): [ From 18f781671bb957191beecfb496f2470d1c0c9fb2 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Mon, 12 Dec 2022 13:11:04 +0100 Subject: [PATCH 21/23] Refactor tests readability --- tests/models/test_relation_classifier.py | 41 ++++++++++++------------ 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/tests/models/test_relation_classifier.py b/tests/models/test_relation_classifier.py index a720167484..d03e3fc37d 100644 --- a/tests/models/test_relation_classifier.py +++ b/tests/models/test_relation_classifier.py @@ -27,45 +27,45 @@ ("[HEAD-ORG]", "[TAIL-PER]"), ("[HEAD-ORG]", "[TAIL-PER]"), ("[HEAD-LOC]", "[TAIL-PER]"), - ("[HEAD-ORG]", "[TAIL-PER]"), ("[HEAD-LOC]", "[TAIL-PER]"), ("[HEAD-LOC]", "[TAIL-PER]"), + ("[HEAD-ORG]", "[TAIL-PER]"), ], EntityMarker(): [ ("[HEAD] Google [/HEAD]", "[TAIL] Larry Page [/TAIL]"), ("[HEAD] Google [/HEAD]", "[TAIL] Sergey Brin [/TAIL]"), ("[HEAD] Microsoft [/HEAD]", "[TAIL] Bill Gates [/TAIL]"), ("[HEAD] Berlin [/HEAD]", "[TAIL] Konrad Zuse [/TAIL]"), - ("[HEAD] MIT [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), ("[HEAD] Berlin [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), ("[HEAD] Germany [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), + ("[HEAD] MIT [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), ], TypedEntityMarker(): [ ("[HEAD-ORG] Google [/HEAD-ORG]", "[TAIL-PER] Larry Page [/TAIL-PER]"), ("[HEAD-ORG] Google [/HEAD-ORG]", "[TAIL-PER] Sergey Brin [/TAIL-PER]"), ("[HEAD-ORG] Microsoft [/HEAD-ORG]", "[TAIL-PER] Bill Gates [/TAIL-PER]"), ("[HEAD-LOC] Berlin [/HEAD-LOC]", "[TAIL-PER] Konrad Zuse [/TAIL-PER]"), - ("[HEAD-ORG] MIT [/HEAD-ORG]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), ("[HEAD-LOC] Berlin [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), ("[HEAD-LOC] Germany [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), + ("[HEAD-ORG] MIT [/HEAD-ORG]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), ], EntityMarkerPunct(): [ ("@ Google @", "# Larry Page #"), ("@ Google @", "# Sergey Brin #"), ("@ Microsoft @", "# Bill Gates #"), ("@ Berlin @", "# Konrad Zuse #"), - ("@ MIT @", "# Joseph Weizenbaum #"), ("@ Berlin @", "# Joseph Weizenbaum #"), ("@ Germany @", "# Joseph Weizenbaum #"), + ("@ MIT @", "# Joseph Weizenbaum #"), ], TypedEntityMarkerPunct(): [ ("@ * ORG * Google @", "# ^ PER ^ Larry Page #"), ("@ * ORG * Google @", "# ^ PER ^ Sergey Brin #"), ("@ * ORG * Microsoft @", "# ^ PER ^ Bill Gates #"), ("@ * LOC * Berlin @", "# ^ PER ^ Konrad Zuse #"), - ("@ * ORG * MIT @", "# ^ PER ^ Joseph Weizenbaum #"), ("@ * LOC * Berlin @", "# ^ PER ^ Joseph Weizenbaum #"), ("@ * LOC * Germany @", "# ^ PER ^ Joseph Weizenbaum #"), + ("@ * ORG * MIT @", "# ^ PER ^ Joseph Weizenbaum #"), ], } @@ -212,28 +212,27 @@ def test_transform_corpus( ("place_of_birth",), ), # Entity pair permutations of: "Joseph Weizenbaum , a professor at MIT , was born in Berlin , Germany." - # This sentence is only included if we transform the corpus with cross augmentation - *( - [ - ( - f"{encoded_entity_pairs[4][1]} , a professor at {encoded_entity_pairs[4][0]} , " - f"was born in Berlin , Germany .", - ("O",), - ) - ] - if cross_augmentation - else [] - ), ( - f"{encoded_entity_pairs[5][1]} , a professor at MIT , " - f"was born in {encoded_entity_pairs[5][0]} , Germany .", + f"{encoded_entity_pairs[4][1]} , a professor at MIT , " + f"was born in {encoded_entity_pairs[4][0]} , Germany .", ("place_of_birth",), ), ( - f"{encoded_entity_pairs[6][1]} , a professor at MIT , " - f"was born in Berlin , {encoded_entity_pairs[6][0]} .", + f"{encoded_entity_pairs[5][1]} , a professor at MIT , " + f"was born in Berlin , {encoded_entity_pairs[5][0]} .", ("place_of_birth",), ), } + + if cross_augmentation: + # This sentence is only included if we transform the corpus with cross augmentation + ground_truth.add( + ( + f"{encoded_entity_pairs[6][1]} , a professor at {encoded_entity_pairs[6][0]} , " + f"was born in Berlin , Germany .", + ("O",), + ) + ) + for split in (transformed_corpus.train, transformed_corpus.dev, transformed_corpus.test): self.check_transformation_correctness(split, ground_truth) From 2a2b49eecb6103b6a9fe230559f91498cc5e0673 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Mon, 12 Dec 2022 13:11:27 +0100 Subject: [PATCH 22/23] Make encoding strategy abstract --- flair/models/relation_classifier_model.py | 36 ++++++++++------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 576e3176e5..770eeefa12 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -1,6 +1,7 @@ import collections as co import itertools import logging +from abc import ABC, abstractmethod from operator import itemgetter from typing import ( Any, @@ -20,7 +21,6 @@ import torch from torch.utils.data.dataset import ConcatDataset, Dataset -from typing_extensions import Protocol, runtime_checkable import flair from flair.data import Corpus, Dictionary, Label, Relation, Sentence, Span, Token @@ -42,15 +42,18 @@ class EncodedSentence(Sentence): pass -@runtime_checkable -class EncodingStrategy(Protocol): +class EncodingStrategy(ABC): """ The :class:`EncodingStrategy` protocol defines the encoding of the head and tail entities in a sentence with a relation annotation. """ - special_tokens: Set[str] + special_tokens: Set[str] = set() + def __init__(self, add_special_tokens: bool = False) -> None: + self.add_special_tokens = add_special_tokens + + @abstractmethod def encode_head(self, head_span: Span, label: Label) -> str: """ Returns the encoded string representation of the head span. @@ -58,6 +61,7 @@ def encode_head(self, head_span: Span, label: Label) -> str: """ ... + @abstractmethod def encode_tail(self, tail_span: Span, label: Label) -> str: """ Returns the encoded string representation of the tail span. @@ -78,8 +82,7 @@ class EntityMask(EncodingStrategy): - "Larry Page and [TAIL] founded [HEAD]" -> Relation(head='Google', tail='Sergey Brin'). """ - def __init__(self, add_special_tokens: bool = False) -> None: - self.special_tokens = {"[HEAD]", "[TAIL]"} if add_special_tokens else set() + special_tokens: Set[str] = {"[HEAD]", "[TAIL]"} def encode_head(self, head_span: Span, label: Label) -> str: return "[HEAD]" @@ -100,9 +103,6 @@ class TypedEntityMask(EncodingStrategy): - "Larry Page and [TAIL-PER] founded [HEAD-ORG]" -> Relation(head='Google', tail='Sergey Brin'). """ - def __init__(self) -> None: - self.special_tokens: Set[str] = set() - def encode_head(self, head: Span, label: Label) -> str: return f"[HEAD-{label.value}]" @@ -124,8 +124,7 @@ class EntityMarker(EncodingStrategy): -> Relation(head='Google', tail='Sergey Brin'). """ - def __init__(self, add_special_tokens: bool = False) -> None: - self.special_tokens = {"[HEAD]", "[/HEAD]", "[TAIL]", "[/TAIL]"} if add_special_tokens else set() + special_tokens: Set[str] = {"[HEAD]", "[/HEAD]", "[TAIL]", "[/TAIL]"} def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) @@ -150,9 +149,6 @@ class TypedEntityMarker(EncodingStrategy): -> Relation(head='Google', tail='Sergey Brin'). """ - def __init__(self) -> None: - self.special_tokens: Set[str] = set() - def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) return f"[HEAD-{label.value}] {space_tokenized_text} [/HEAD-{label.value}]" @@ -174,9 +170,6 @@ class EntityMarkerPunct(EncodingStrategy): - "Larry Page and @ Sergey Brin @ founded # Google #" -> Relation(head='Google', tail='Sergey Brin'). """ - def __init__(self) -> None: - self.special_tokens: Set[str] = set() - def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) return f"@ {space_tokenized_text} @" @@ -200,9 +193,6 @@ class TypedEntityMarkerPunct(EncodingStrategy): -> Relation(head='Google', tail='Sergey Brin'). """ - def __init__(self) -> None: - self.special_tokens: Set[str] = set() - def encode_head(self, head: Span, label: Label) -> str: space_tokenized_text: str = " ".join(token.text for token in head) return f"@ * {label.value} * {space_tokenized_text} @" @@ -331,7 +321,11 @@ def __init__( self.encoding_strategy = encoding_strategy # Add the special tokens from the encoding strategy - if self.encoding_strategy.special_tokens and isinstance(self.embeddings, TransformerDocumentEmbeddings): + if ( + self.encoding_strategy.add_special_tokens + and self.encoding_strategy.special_tokens + and isinstance(self.embeddings, TransformerDocumentEmbeddings) + ): special_tokens: List[str] = list(self.encoding_strategy.special_tokens) tokenizer = self.embeddings.tokenizer tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) From a3430f745f7e9a3be12a4361894f863831909189 Mon Sep 17 00:00:00 2001 From: Conrad Dobberstein <29147025+dobbersc@users.noreply.github.com> Date: Tue, 13 Dec 2022 13:33:15 +0100 Subject: [PATCH 23/23] Remove unused imports --- flair/models/relation_classifier_model.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 770eeefa12..cebdcf0e72 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -1,12 +1,8 @@ -import collections as co import itertools import logging from abc import ABC, abstractmethod -from operator import itemgetter from typing import ( Any, - Counter, - DefaultDict, Dict, Iterator, List, @@ -20,7 +16,7 @@ ) import torch -from torch.utils.data.dataset import ConcatDataset, Dataset +from torch.utils.data.dataset import Dataset import flair from flair.data import Corpus, Dictionary, Label, Relation, Sentence, Span, Token