diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 20bac08866..cebdcf0e72 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -1,10 +1,8 @@ -import collections as co import itertools -from operator import itemgetter +import logging +from abc import ABC, abstractmethod from typing import ( Any, - Counter, - DefaultDict, Dict, Iterator, List, @@ -18,14 +16,16 @@ ) import torch -from torch.utils.data.dataset import ConcatDataset, Dataset +from torch.utils.data.dataset import Dataset import flair from flair.data import Corpus, Dictionary, Label, Relation, Sentence, Span, Token from flair.datasets import DataLoader, FlairDatapointDataset -from flair.embeddings import DocumentEmbeddings +from flair.embeddings import DocumentEmbeddings, TransformerDocumentEmbeddings from flair.tokenization import SpaceTokenizer +logger: logging.Logger = logging.getLogger("flair") + class EncodedSentence(Sentence): """ @@ -38,6 +38,166 @@ class EncodedSentence(Sentence): pass +class EncodingStrategy(ABC): + """ + The :class:`EncodingStrategy` protocol defines + the encoding of the head and tail entities in a sentence with a relation annotation. + """ + + special_tokens: Set[str] = set() + + def __init__(self, add_special_tokens: bool = False) -> None: + self.add_special_tokens = add_special_tokens + + @abstractmethod + def encode_head(self, head_span: Span, label: Label) -> str: + """ + Returns the encoded string representation of the head span. + The tokens of a multi-token head encoding are separated by a space. + """ + ... + + @abstractmethod + def encode_tail(self, tail_span: Span, label: Label) -> str: + """ + Returns the encoded string representation of the tail span. + The tokens of a multi-token tail encoding are separated by a space. + """ + ... + + +class EntityMask(EncodingStrategy): + """ + An :class:`EncodingStrategy` that masks the head and tail relation entities. 
+ + Example: + For the `founded_by` relation from `ORG` to `PER` and + the sentence "Larry Page and Sergey Brin founded Google .", + the encoded sentences and relations are + - "[TAIL] and Sergey Brin founded [HEAD]" -> Relation(head='Google', tail='Larry Page') and + - "Larry Page and [TAIL] founded [HEAD]" -> Relation(head='Google', tail='Sergey Brin'). + """ + + special_tokens: Set[str] = {"[HEAD]", "[TAIL]"} + + def encode_head(self, head_span: Span, label: Label) -> str: + return "[HEAD]" + + def encode_tail(self, tail_span: Span, label: Label) -> str: + return "[TAIL]" + + +class TypedEntityMask(EncodingStrategy): + """ + An :class:`EncodingStrategy` that masks the head and tail relation entities with their label. + + Example: + For the `founded_by` relation from `ORG` to `PER` and + the sentence "Larry Page and Sergey Brin founded Google .", + the encoded sentences and relations are + - "[TAIL-PER] and Sergey Brin founded [HEAD-ORG]" -> Relation(head='Google', tail='Larry Page') and + - "Larry Page and [TAIL-PER] founded [HEAD-ORG]" -> Relation(head='Google', tail='Sergey Brin'). + """ + + def encode_head(self, head: Span, label: Label) -> str: + return f"[HEAD-{label.value}]" + + def encode_tail(self, tail: Span, label: Label) -> str: + return f"[TAIL-{label.value}]" + + +class EntityMarker(EncodingStrategy): + """ + An :class:`EncodingStrategy` that marks the head and tail relation entities. + + Example: + For the `founded_by` relation from `ORG` to `PER` and + the sentence "Larry Page and Sergey Brin founded Google .", + the encoded sentences and relations are + - "[TAIL] Larry Page [/TAIL] and Sergey Brin founded [HEAD] Google [/HEAD]" + -> Relation(head='Google', tail='Larry Page') and + - "Larry Page and [TAIL] Sergey Brin [/TAIL] founded [HEAD] Google [/HEAD]" + -> Relation(head='Google', tail='Sergey Brin'). 
+ """ + + special_tokens: Set[str] = {"[HEAD]", "[/HEAD]", "[TAIL]", "[/TAIL]"} + + def encode_head(self, head: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in head) + return f"[HEAD] {space_tokenized_text} [/HEAD]" + + def encode_tail(self, tail: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in tail) + return f"[TAIL] {space_tokenized_text} [/TAIL]" + + +class TypedEntityMarker(EncodingStrategy): + """ + An :class:`EncodingStrategy` that marks the head and tail relation entities with their label. + + Example: + For the `founded_by` relation from `ORG` to `PER` and + the sentence "Larry Page and Sergey Brin founded Google .", + the encoded sentences and relations are + - "[TAIL-PER] Larry Page [/TAIL-PER] and Sergey Brin founded [HEAD-ORG] Google [/HEAD-ORG]" + -> Relation(head='Google', tail='Larry Page') and + - "Larry Page and [TAIL-PER] Sergey Brin [/TAIL-PER] founded [HEAD-ORG] Google [/HEAD-ORG]" + -> Relation(head='Google', tail='Sergey Brin'). + """ + + def encode_head(self, head: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in head) + return f"[HEAD-{label.value}] {space_tokenized_text} [/HEAD-{label.value}]" + + def encode_tail(self, tail: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in tail) + return f"[TAIL-{label.value}] {space_tokenized_text} [/TAIL-{label.value}]" + + +class EntityMarkerPunct(EncodingStrategy): + """ + An alternate version of :class:`EntityMarker` with punctuation marks as control tokens. + + Example: + For the `founded_by` relation from `ORG` to `PER` and + the sentence "Larry Page and Sergey Brin founded Google .", + the encoded sentences and relations are + - "# Larry Page # and Sergey Brin founded @ Google @" -> Relation(head='Google', tail='Larry Page') and + - "Larry Page and # Sergey Brin # founded @ Google @" -> Relation(head='Google', tail='Sergey Brin'). 
+ """ + + def encode_head(self, head: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in head) + return f"@ {space_tokenized_text} @" + + def encode_tail(self, tail: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in tail) + return f"# {space_tokenized_text} #" + + +class TypedEntityMarkerPunct(EncodingStrategy): + """ + An alternate version of :class:`TypedEntityMarker` with punctuation marks as control tokens. + + Example: + For the `founded_by` relation from `ORG` to `PER` and + the sentence "Larry Page and Sergey Brin founded Google .", + the encoded sentences and relations are + - "# ^ PER ^ Larry Page # and Sergey Brin founded @ * ORG * Google @" + -> Relation(head='Google', tail='Larry Page') and + - "Larry Page and # ^ PER ^ Sergey Brin # founded @ * ORG * Google @" + -> Relation(head='Google', tail='Sergey Brin'). + """ + + def encode_head(self, head: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in head) + return f"@ * {label.value} * {space_tokenized_text} @" + + def encode_tail(self, tail: Span, label: Label) -> str: + space_tokenized_text: str = " ".join(token.text for token in tail) + return f"# ^ {label.value} ^ {space_tokenized_text} #" + + class _Entity(NamedTuple): """ A `_Entity` encapsulates either a relation's head or a tail span, including its label. @@ -68,11 +228,13 @@ class RelationClassifier(flair.nn.DefaultClassifier[EncodedSentence, EncodedSent The Relation Classifier Model builds upon a text classifier. The model generates an encoded sentence for each entity pair in the cross product of all entities in the original sentence. - In the encoded representation, the entities in the current entity pair are masked with special control tokens. - (For an example, see the docstring of the `_encode_sentence_for_training` function.) 
+ In the encoded representation, the entities in the current entity pair are masked/marked with control tokens. + (For an example, see the docstrings of different encoding strategies, e.g. :class:`TypedEntityMarker`.) Then, for each encoded sentence, the model takes its document embedding and puts the resulting text representation(s) through a linear layer to get the class relation label. + The implemented encoding strategies are taken from this paper by Zhou et al.: https://arxiv.org/abs/2102.01373 + Note: Currently, the model has no multi-label support. """ @@ -84,10 +246,10 @@ def __init__( entity_label_types: Union[str, Sequence[str], Dict[str, Optional[Set[str]]]], entity_pair_labels: Optional[Set[Tuple[str, str]]] = None, entity_threshold: Optional[float] = None, + cross_augmentation: bool = True, + encoding_strategy: EncodingStrategy = TypedEntityMarker(), zero_tag_value: str = "O", allow_unk_tag: bool = True, - cross_augmentation: bool = True, - mask_type: str = "mark", **classifierargs, ) -> None: """ @@ -111,14 +273,15 @@ def __init__( i.e. the model classifies the relation for each entity pair in the cross product of *all* entity pairs (inefficient). :param entity_threshold: Only pre-labelled entities above this threshold are taken into account by the model. - :param zero_tag_value: The label to use for out-of-class relations - :param allow_unk_tag: If `False`, removes `` from the passed label dictionary, otherwise do nothing. :param cross_augmentation: If `True`, use cross augmentation to transform `Sentence`s into `EncodedSentenece`s. When cross augmentation is enabled, the transformation functions, e.g. `transform_corpus`, generate an encoded sentence for each entity pair in the cross product of all entities in the original sentence. When disabling cross augmentation, the transform functions only generate encoded sentences for each gold relation annotation in the original sentence. 
+ :param encoding_strategy: An instance of a class conforming to the :class:`EncodingStrategy` protocol + :param zero_tag_value: The label to use for out-of-class relations + :param allow_unk_tag: If `False`, removes `<unk>` from the passed label dictionary, otherwise do nothing. :param classifierargs: The remaining parameters passed to the underlying `DefaultClassifier` """ # Set label type and prepare label dictionary @@ -149,9 +312,25 @@ def __init__( self.entity_pair_labels = entity_pair_labels - self.cross_augmentation = cross_augmentation - self.mask_type = mask_type self.entity_threshold = entity_threshold + self.cross_augmentation = cross_augmentation + self.encoding_strategy = encoding_strategy + + # Add the special tokens from the encoding strategy + if ( + self.encoding_strategy.add_special_tokens + and self.encoding_strategy.special_tokens + and isinstance(self.embeddings, TransformerDocumentEmbeddings) + ): + special_tokens: List[str] = list(self.encoding_strategy.special_tokens) + tokenizer = self.embeddings.tokenizer + tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) + self.embeddings.model.resize_token_embeddings(len(tokenizer)) + + logger.info( + f"{self.__class__.__name__}: " + f"Added {', '.join(special_tokens)} as additional special tokens to {self.embeddings.name}" + ) # Auto-spawn on GPU, if available self.to(flair.device) @@ -222,37 +401,20 @@ def _entity_pair_permutations( yield head, tail, gold_label - def _mask(self, entity: _Entity, role: str) -> str: - if self.mask_type == "label-aware": - return f"[{role}-{entity.label.value}]" - if self.mask_type == "entity": - return f"[{role}-ENTITY]" - if self.mask_type == "mark": - return f"[[{role}-{entity.span.text}]]" - - # by default, use "mark" masking - return f"[[{role}-{entity.span.text}]]" - - def _create_masked_sentence( + def _encode_sentence( self, head: _Entity, tail: _Entity, gold_label: Optional[str] = None, ) -> EncodedSentence: """ - Returns a new `Sentence` 
object with masked head and tail spans. - The label-aware mask is constructed from the head/tail span labels. - If provided, the masked sentence also has the corresponding gold label annotation in `self.label_type`. - - Example: - For the `head=Google`, `tail=Larry Page` and - the sentence "Larry Page and Sergey Brin founded Google .", - the masked sentence is "[T-PER] and Sergey Brin founded [H-ORG]". + Returns a new `Sentence` object with masked/marked head and tail spans according to the encoding strategy. + If provided, the encoded sentence also has the corresponding gold label annotation from `self.label_type`. :param head: The head `_Entity` :param tail: The tail `_Entity` :param gold_label: An optional gold label of the induced relation by the head and tail entity - :return: The masked sentence (with gold annotations) + :return: The `EncodedSentence` (with gold annotations) """ # Some sanity checks original_sentence: Sentence = head.span.sentence @@ -262,59 +424,57 @@ def _create_masked_sentence( non_leading_head_tokens: List[Token] = head.span.tokens[1:] non_leading_tail_tokens: List[Token] = tail.span.tokens[1:] - # We can not use the plaintext of the head/tail span in the sentence as the mask + # We can not use the plaintext of the head/tail span in the sentence as the mask/marker # since there may be multiple occurrences of the same entity mentioned in the sentence. # Therefore, we use the span's position in the sentence. 
- masked_sentence_tokens: List[str] = [] + encoded_sentence_tokens: List[str] = [] for token in original_sentence: if token is head.span[0]: - masked_sentence_tokens.append(self._mask(entity=head, role="H")) + encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label)) elif token is tail.span[0]: - masked_sentence_tokens.append(self._mask(entity=tail, role="T")) + encoded_sentence_tokens.append(self.encoding_strategy.encode_tail(tail.span, tail.label)) elif all( token is not non_leading_entity_token for non_leading_entity_token in itertools.chain(non_leading_head_tokens, non_leading_tail_tokens) ): - masked_sentence_tokens.append(token.text) + encoded_sentence_tokens.append(token.text) # Create masked sentence - masked_sentence: EncodedSentence = EncodedSentence( - " ".join(masked_sentence_tokens), use_tokenizer=SpaceTokenizer() + encoded_sentence: EncodedSentence = EncodedSentence( + " ".join(encoded_sentence_tokens), use_tokenizer=SpaceTokenizer() ) if gold_label is not None: # Add gold relation annotation as sentence label # Using the sentence label instead of annotating a separate `Relation` object is easier to manage since, # during prediction, the forward pass does not need any knowledge about the entities in the sentence. - masked_sentence.add_label(typename=self.label_type, value=gold_label, score=1.0) + encoded_sentence.add_label(typename=self.label_type, value=gold_label, score=1.0) - return masked_sentence + return encoded_sentence - def _encode_sentence_for_inference(self, sentence: Sentence) -> Iterator[Tuple[EncodedSentence, Relation]]: + def _encode_sentence_for_inference( + self, + sentence: Sentence, + ) -> Iterator[Tuple[EncodedSentence, Relation]]: """ - Yields masked entity pair sentences annotated with their gold relation for all valid entity pair permutations. - The created masked sentences are newly created sentences with no reference to the passed sentence. 
+ Yields encoded sentences annotated with their gold relation and + the corresponding relation object in the original sentence for all valid entity pair permutations. + The created encoded sentences are newly created sentences with no reference to the passed sentence. + Important properties: - - Every sentence has exactly one masked head and tail entity token. Therefore, every encoded sentence has + - Every sentence has exactly one encoded head and tail entity token. Therefore, every encoded sentence has **exactly** one induced relation annotation, the gold annotation or `self.zero_tag_value`. - The created relations have head and tail spans from the original passed sentence. - Example: - For the `founded_by` relation from `ORG` to `PER` and - the sentence "Larry Page and Sergey Brin founded Google .", - the masked sentences and relations are - - "[T-PER] and Sergey Brin founded [H-ORG]" -> Relation(head='Google', tail='Larry Page') and - - "Larry Page and [T-PER] founded [H-ORG]" -> Relation(head='Google', tail='Sergey Brin'). 
- :param sentence: A flair `Sentence` object with entity annotations :return: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence """ for head, tail, gold_label in self._entity_pair_permutations(sentence): - masked_sentence: EncodedSentence = self._create_masked_sentence( + masked_sentence: EncodedSentence = self._encode_sentence( head=head, tail=tail, gold_label=gold_label if gold_label is not None else self.zero_tag_value, @@ -336,7 +496,7 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS else: continue # Skip generated data points that do not express an originally annotated relation - masked_sentence: EncodedSentence = self._create_masked_sentence( + masked_sentence: EncodedSentence = self._encode_sentence( head=head, tail=tail, gold_label=gold_label, @@ -348,8 +508,8 @@ def transform_sentence(self, sentences: Union[Sentence, List[Sentence]]) -> List """ Transforms sentences into encoded sentences specific to the `RelationClassifier`. For more information on the internal sentence transformation procedure, - see the `RelationClassifier` architecture docstring and - the `_encode_sentence_for_training` and `_encode_sentence_for_inference` docstrings. + see the :class:`RelationClassifier` architecture and + the different :class:`EncodingStrategy` variants docstrings. :param sentences: A (list) of sentence(s) to transform :return: A list of encoded sentences specific to the `RelationClassifier` @@ -368,8 +528,8 @@ def transform_dataset(self, dataset: Dataset[Sentence]) -> FlairDatapointDataset Transforms a dataset into a dataset containing encoded sentences specific to the `RelationClassifier`. The returned dataset is stored in memory. For more information on the internal sentence transformation procedure, - see the `RelationClassifier` architecture docstring and - the `_encode_sentence_for_training` and `_encode_sentence_for_inference` docstrings. 
+ see the :class:`RelationClassifier` architecture and + the different :class:`EncodingStrategy` variants docstrings. :param dataset: A dataset of sentences to transform :return: A dataset of encoded sentences specific to the `RelationClassifier` @@ -383,8 +543,8 @@ def transform_corpus(self, corpus: Corpus[Sentence]) -> Corpus[EncodedSentence]: Transforms a corpus into a corpus containing encoded sentences specific to the `RelationClassifier`. The splits of the returned corpus are stored in memory. For more information on the internal sentence transformation procedure, - see the `RelationClassifier` architecture docstring and - the `_encode_sentence_for_training` and `_encode_sentence_for_inference` docstrings. + see the :class:`RelationClassifier` architecture and + the different :class:`EncodingStrategy` variants docstrings. :param corpus: A corpus of sentences to transform :return: A corpus of encoded sentences specific to the `RelationClassifier` @@ -514,10 +674,10 @@ def _get_state_dict(self) -> Dict[str, Any]: "entity_label_types": self.entity_label_types, "entity_pair_labels": self.entity_pair_labels, "entity_threshold": self.entity_threshold, + "cross_augmentation": self.cross_augmentation, + "encoding_strategy": self.encoding_strategy, "zero_tag_value": self.zero_tag_value, "allow_unk_tag": self.allow_unk_tag, - "cross_augmentation": self.cross_augmentation, - "mask_type": self.mask_type, } return model_state @@ -531,10 +691,10 @@ def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs): entity_label_types=state["entity_label_types"], entity_pair_labels=state["entity_pair_labels"], entity_threshold=state["entity_threshold"], + cross_augmentation=state["cross_augmentation"], + encoding_strategy=state["encoding_strategy"], zero_tag_value=state["zero_tag_value"], allow_unk_tag=state["allow_unk_tag"], - cross_augmentation=state["cross_augmentation"], - mask_type=state["mask_type"], **kwargs, ) @@ -549,41 +709,3 @@ def zero_tag_value(self) -> str: 
@property def allow_unk_tag(self) -> bool: return self._allow_unk_tag - - -def inspect_relations( - corpus: Corpus[Sentence], - relation_label_type: str, - entity_label_types: Optional[Union[Sequence[str], str]] = None, -) -> DefaultDict[str, Counter[Tuple[str, str]]]: - if entity_label_types is not None and not isinstance(entity_label_types, Sequence): - entity_label_types = [entity_label_types] - - # Dictionary of [, ] - relations: DefaultDict[str, Counter[Tuple[str, str]]] = co.defaultdict(co.Counter) - - data_loader: DataLoader = DataLoader( - ConcatDataset(split for split in [corpus.train, corpus.dev, corpus.test] if split is not None), - batch_size=1, - num_workers=0, - ) - for sentence in map(itemgetter(0), data_loader): - for relation in sentence.get_relations(relation_label_type): - entity_counter = relations[relation.get_label(relation_label_type).value] - - head_relation_label: str - tail_relation_label: str - if entity_label_types is None: - head_relation_label = relation.first.get_label().value - tail_relation_label = relation.second.get_label().value - else: - head_relation_label = next( - relation.first.get_label(label_type).value for label_type in entity_label_types - ) - tail_relation_label = next( - relation.second.get_label(label_type).value for label_type in entity_label_types - ) - - entity_counter.update([(head_relation_label, tail_relation_label)]) - - return relations diff --git a/tests/models/test_relation_classifier.py b/tests/models/test_relation_classifier.py index 2e0edb0ae8..d03e3fc37d 100644 --- a/tests/models/test_relation_classifier.py +++ b/tests/models/test_relation_classifier.py @@ -1,5 +1,5 @@ from operator import itemgetter -from typing import List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple import pytest from torch.utils.data import Dataset @@ -8,9 +8,67 @@ from flair.datasets import ColumnCorpus, DataLoader from flair.embeddings import TransformerDocumentEmbeddings from flair.models import 
RelationClassifier -from flair.models.relation_classifier_model import EncodedSentence +from flair.models.relation_classifier_model import ( + EncodedSentence, + EncodingStrategy, + EntityMarker, + EntityMarkerPunct, + EntityMask, + TypedEntityMarker, + TypedEntityMarkerPunct, + TypedEntityMask, +) from tests.model_test_utils import BaseModelTest +encoding_strategies: Dict[EncodingStrategy, List[Tuple[str, str]]] = { + EntityMask(): [("[HEAD]", "[TAIL]") for _ in range(7)], + TypedEntityMask(): [ + ("[HEAD-ORG]", "[TAIL-PER]"), + ("[HEAD-ORG]", "[TAIL-PER]"), + ("[HEAD-ORG]", "[TAIL-PER]"), + ("[HEAD-LOC]", "[TAIL-PER]"), + ("[HEAD-LOC]", "[TAIL-PER]"), + ("[HEAD-LOC]", "[TAIL-PER]"), + ("[HEAD-ORG]", "[TAIL-PER]"), + ], + EntityMarker(): [ + ("[HEAD] Google [/HEAD]", "[TAIL] Larry Page [/TAIL]"), + ("[HEAD] Google [/HEAD]", "[TAIL] Sergey Brin [/TAIL]"), + ("[HEAD] Microsoft [/HEAD]", "[TAIL] Bill Gates [/TAIL]"), + ("[HEAD] Berlin [/HEAD]", "[TAIL] Konrad Zuse [/TAIL]"), + ("[HEAD] Berlin [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), + ("[HEAD] Germany [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), + ("[HEAD] MIT [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), + ], + TypedEntityMarker(): [ + ("[HEAD-ORG] Google [/HEAD-ORG]", "[TAIL-PER] Larry Page [/TAIL-PER]"), + ("[HEAD-ORG] Google [/HEAD-ORG]", "[TAIL-PER] Sergey Brin [/TAIL-PER]"), + ("[HEAD-ORG] Microsoft [/HEAD-ORG]", "[TAIL-PER] Bill Gates [/TAIL-PER]"), + ("[HEAD-LOC] Berlin [/HEAD-LOC]", "[TAIL-PER] Konrad Zuse [/TAIL-PER]"), + ("[HEAD-LOC] Berlin [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), + ("[HEAD-LOC] Germany [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), + ("[HEAD-ORG] MIT [/HEAD-ORG]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), + ], + EntityMarkerPunct(): [ + ("@ Google @", "# Larry Page #"), + ("@ Google @", "# Sergey Brin #"), + ("@ Microsoft @", "# Bill Gates #"), + ("@ Berlin @", "# Konrad Zuse #"), + ("@ Berlin @", "# Joseph Weizenbaum #"), + ("@ Germany @", "# 
Joseph Weizenbaum #"), + ("@ MIT @", "# Joseph Weizenbaum #"), + ], + TypedEntityMarkerPunct(): [ + ("@ * ORG * Google @", "# ^ PER ^ Larry Page #"), + ("@ * ORG * Google @", "# ^ PER ^ Sergey Brin #"), + ("@ * ORG * Microsoft @", "# ^ PER ^ Bill Gates #"), + ("@ * LOC * Berlin @", "# ^ PER ^ Konrad Zuse #"), + ("@ * LOC * Berlin @", "# ^ PER ^ Joseph Weizenbaum #"), + ("@ * LOC * Germany @", "# ^ PER ^ Joseph Weizenbaum #"), + ("@ * ORG * MIT @", "# ^ PER ^ Joseph Weizenbaum #"), + ], +} + class TestRelationClassifier(BaseModelTest): model_cls = RelationClassifier @@ -118,61 +176,63 @@ def check_transformation_correctness( for sentence in map(itemgetter(0), data_loader) } == ground_truth - def test_transform_corpus_with_cross_augmentation(self, corpus: ColumnCorpus, embeddings) -> None: + @pytest.mark.parametrize( + "cross_augmentation", [True, False], ids=["with_cross_augmentation", "without_cross_augmentation"] + ) + @pytest.mark.parametrize( + "encoding_strategy, encoded_entity_pairs", + encoding_strategies.items(), + ids=[type(encoding_strategy).__name__ for encoding_strategy in encoding_strategies], + ) + def test_transform_corpus( + self, + corpus: ColumnCorpus, + embeddings: TransformerDocumentEmbeddings, + cross_augmentation: bool, + encoding_strategy: EncodingStrategy, + encoded_entity_pairs: List[Tuple[str, str]], + ) -> None: label_dictionary = corpus.make_label_dictionary("relation") - model: RelationClassifier = self.build_model(embeddings, label_dictionary, cross_augmentation=True) + model: RelationClassifier = self.build_model( + embeddings, label_dictionary, cross_augmentation=cross_augmentation, encoding_strategy=encoding_strategy + ) transformed_corpus = model.transform_corpus(corpus) # Check sentence masking and relation label annotation on - # training, validation and test dataset (in this test they are the same) + # training, validation and test dataset (in this test the splits are the same) ground_truth: Set[Tuple[str, Tuple[str, ...]]] = { # 
Entity pair permutations of: "Larry Page and Sergey Brin founded Google ." - ("[[T-Larry Page]] and Sergey Brin founded [[H-Google]] .", ("founded_by",)), - ("Larry Page and [[T-Sergey Brin]] founded [[H-Google]] .", ("founded_by",)), + (f"{encoded_entity_pairs[0][1]} and Sergey Brin founded {encoded_entity_pairs[0][0]} .", ("founded_by",)), + (f"Larry Page and {encoded_entity_pairs[1][1]} founded {encoded_entity_pairs[1][0]} .", ("founded_by",)), # Entity pair permutations of: "Microsoft was founded by Bill Gates ." - ("[[H-Microsoft]] was founded by [[T-Bill Gates]] .", ("founded_by",)), + (f"{encoded_entity_pairs[2][0]} was founded by {encoded_entity_pairs[2][1]} .", ("founded_by",)), # Entity pair permutations of: "Konrad Zuse was born in Berlin on 22 June 1910 ." - ("[[T-Konrad Zuse]] was born in [[H-Berlin]] on 22 June 1910 .", ("place_of_birth",)), - # Entity pair permutations of: "Joseph Weizenbaum , a professor at MIT , was born in Berlin , Germany." - ("[[T-Joseph Weizenbaum]] , a professor at [[H-MIT]] , was born in Berlin , Germany .", ("O",)), ( - "[[T-Joseph Weizenbaum]] , a professor at MIT , was born in [[H-Berlin]] , Germany .", + f"{encoded_entity_pairs[3][1]} was born in {encoded_entity_pairs[3][0]} on 22 June 1910 .", ("place_of_birth",), ), + # Entity pair permutations of: "Joseph Weizenbaum , a professor at MIT , was born in Berlin , Germany." 
( - "[[T-Joseph Weizenbaum]] , a professor at MIT , was born in Berlin , [[H-Germany]] .", - ("place_of_birth",), - ), - } - for split in (transformed_corpus.train, transformed_corpus.dev, transformed_corpus.test): - self.check_transformation_correctness(split, ground_truth) - - def test_transform_corpus_without_cross_augmentation(self, corpus: ColumnCorpus, embeddings) -> None: - label_dictionary = corpus.make_label_dictionary("relation") - embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased", layers="-1", fine_tune=True) - model: RelationClassifier = self.build_model(embeddings, label_dictionary, cross_augmentation=False) - - transformed_corpus = model.transform_corpus(corpus) - - # Check sentence masking and relation label annotation on - # training, validation and test dataset (in this test they are the same) - ground_truth: Set[Tuple[str, Tuple[str, ...]]] = { - # Entity pair permutations of: "Larry Page and Sergey Brin founded Google ." - ("[[T-Larry Page]] and Sergey Brin founded [[H-Google]] .", ("founded_by",)), - ("Larry Page and [[T-Sergey Brin]] founded [[H-Google]] .", ("founded_by",)), - # Entity pair permutations of: "Microsoft was founded by Bill Gates ." - ("[[H-Microsoft]] was founded by [[T-Bill Gates]] .", ("founded_by",)), - # Entity pair permutations of: "Konrad Zuse was born in Berlin on 22 June 1910 ." - ("[[T-Konrad Zuse]] was born in [[H-Berlin]] on 22 June 1910 .", ("place_of_birth",)), - # Entity pair permutations of: "Joseph Weizenbaum , a professor at MIT , was born in Berlin , Germany ." 
- ( - "[[T-Joseph Weizenbaum]] , a professor at MIT , was born in [[H-Berlin]] , Germany .", + f"{encoded_entity_pairs[4][1]} , a professor at MIT , " + f"was born in {encoded_entity_pairs[4][0]} , Germany .", ("place_of_birth",), ), ( - "[[T-Joseph Weizenbaum]] , a professor at MIT , was born in Berlin , [[H-Germany]] .", + f"{encoded_entity_pairs[5][1]} , a professor at MIT , " + f"was born in Berlin , {encoded_entity_pairs[5][0]} .", ("place_of_birth",), ), } + + if cross_augmentation: + # This sentence is only included if we transform the corpus with cross augmentation + ground_truth.add( + ( + f"{encoded_entity_pairs[6][1]} , a professor at {encoded_entity_pairs[6][0]} , " + f"was born in Berlin , Germany .", + ("O",), + ) + ) + for split in (transformed_corpus.train, transformed_corpus.dev, transformed_corpus.test): self.check_transformation_correctness(split, ground_truth)