Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Masked Relation Classifier #2748

Merged
merged 70 commits into from
Nov 6, 2022
Merged

Masked Relation Classifier #2748

merged 70 commits into from
Nov 6, 2022

Conversation

dobbersc
Copy link
Collaborator

@dobbersc dobbersc commented May 2, 2022

This PR implements a new or alternate relation classifier.

Relation Classification (RC) is the task of identifying the semantic relation between two entities in a text.
In contrast to (end-to-end) Relation Extraction (RE), RC requires pre-labelled entities.

Example: For the founded_by relation from ORG (head) to PER (tail) and the sentence "Larry Page and Sergey Brin founded Google .", we extract the relations

  • founded_by(head='Google', tail='Larry Page') and
  • founded_by(head='Google', tail='Sergey Brin').

The Relation Classifier Model builds upon a text classifier. The model generates an encoded sentence for each entity pair in the cross product of all entities in the original sentence. In the encoded representation, the entities in the current entity pair are masked with special control tokens. (For an example, see the docstring of the _encode_sentence_for_training function.) Then, for each encoded sentence, the model takes its document embedding and puts the resulting text representation(s) through a linear layer to get the class relation label.

In the following, I leave some results of the masked relation classifier vs. the current relation extractor on CONLL04. I have not optimized their hyperparameters to the fullest. Nevertheless, the difference is quite clear.

Current Relation Extractor

Training Script
from pathlib import Path

import torch

import flair
from flair.data import Sentence
from flair.datasets import RE_ENGLISH_CONLL04
from flair.embeddings import TransformerWordEmbeddings
from flair.models import RelationExtractor
from flair.trainers import ModelTrainer

flair.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

out_dir: Path = Path("conll04")


def train(
    transformer: str = "distilbert-base-uncased",
    max_epochs: int = 10,
    learning_rate: float = 5e-5,
    mini_batch_size: int = 32,
    seed: int = 42,
) -> None:
    flair.set_seed(seed)

    # Step 1: Create the training data
    # The relation extractor is *not* trained end-to-end.
    # A corpus for training the relation extractor requires annotated entities and relations.
    corpus: RE_ENGLISH_CONLL04 = RE_ENGLISH_CONLL04()

    # Print example
    sentence: Sentence = corpus.test[0]
    print(f"Example {sentence}")  # 'ner' is the entity label type, 'relation' is the relation label type

    # Step 2: Make the label dictionary from the corpus
    label_dictionary = corpus.make_label_dictionary("relation")
    label_dictionary.add_item("O")

    # Step 3: Initialize fine-tunable transformer embeddings
    embeddings = TransformerWordEmbeddings(model=transformer, layers="-1", fine_tune=True)

    # Step 4: Initialize relation classifier
    model: RelationExtractor = RelationExtractor(
        embeddings=embeddings,
        label_dictionary=label_dictionary,
        label_type="relation",
        entity_label_type="ner",
        entity_pair_filters=[  # Define valid entity pair combinations, used as relation candidates
            ("Loc", "Loc"),
            ("Peop", "Loc"),
            ("Peop", "Org"),
            ("Org", "Loc"),
            ("Peop", "Peop"),
        ],
    )

    # Step 5: Initialize trainer
    trainer: ModelTrainer = ModelTrainer(model, corpus)

    # Step 6: Run fine-tuning
    trainer.fine_tune(
        out_dir,
        max_epochs=max_epochs,
        learning_rate=learning_rate,
        mini_batch_size=mini_batch_size,
        main_evaluation_metric=("macro avg", "f1-score"),
    )


def predict_example() -> None:
    # Step 1: Load trained relation extraction model
    model: RelationExtractor = RelationExtractor.load(out_dir / "final-model.pt")

    # Step 2: Create sentences with entity annotations (as these are required by the relation extraction model)
    # In production, use another sequence tagger model to tag the relevant entities.
    sentence: Sentence = Sentence(
        "On April 14, while attending a play at the Ford Theatre in Washington, "
        "Lincoln was shot in the head by actor John Wilkes Booth."
    )
    sentence[10:12].add_label(typename="ner", value="Loc", score=1.0)  # Ford Theatre -> Loc
    sentence[13:14].add_label(typename="ner", value="Loc", score=1.0)  # Washington -> Loc
    sentence[15:16].add_label(typename="ner", value="Peop", score=1.0)  # Lincoln -> Peop
    sentence[23:26].add_label(typename="ner", value="Peop", score=1.0)  # John Wilkes Booth -> Peop

    # Step 3: Predict
    model.predict(sentence)
    print(sentence)


if __name__ == "__main__":
    train()
    predict_example()
Results:
- F-score (micro) 0.7129
- F-score (macro) 0.7353
- Accuracy 0.5592

By class:
              precision    recall  f1-score   support

     Live_In     0.6907    0.6700    0.6802       100
 OrgBased_In     0.7473    0.6476    0.6939       105
  Located_In     0.6750    0.5745    0.6207        94
    Work_For     0.7778    0.7368    0.7568        76
        Kill     0.9348    0.9149    0.9247        47

   micro avg     0.7461    0.6825    0.7129       422
   macro avg     0.7651    0.7088    0.7353       422
weighted avg     0.7441    0.6825    0.7114       422

Masked Relation Classifier

Training Script
from pathlib import Path

import torch

import flair
from flair.data import Sentence
from flair.datasets import RE_ENGLISH_CONLL04
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import RelationClassifier
from flair.trainers import ModelTrainer

flair.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

out_dir: Path = Path("conll04-masked")


def train(
    transformer: str = "distilbert-base-uncased",
    max_epochs: int = 10,
    learning_rate: float = 5e-5,
    mini_batch_size: int = 32,
    seed: int = 42,
) -> None:
    flair.set_seed(seed)

    # Step 1: Create the training data
    # The relation extractor is *not* trained end-to-end.
    # A corpus for training the relation extractor requires annotated entities and relations.
    corpus: RE_ENGLISH_CONLL04 = RE_ENGLISH_CONLL04()

    # Print example
    sentence: Sentence = corpus.test[0]
    print(f"Example {sentence}")  # 'ner' is the entity label type, 'relation' is the relation label type

    # Step 2: Make the label dictionary from the corpus
    label_dictionary = corpus.make_label_dictionary("relation")

    # Step 3: Initialize fine-tunable transformer embedding
    embeddings = TransformerDocumentEmbeddings(model=transformer, layers="-1", fine_tune=True)

    # Step 4: Initialize relation classifier
    model: RelationClassifier = RelationClassifier(
        document_embeddings=embeddings,
        label_dictionary=label_dictionary,
        label_type="relation",
        entity_label_types="ner",
        entity_pair_labels={  # Define valid entity pair combinations, used as relation candidates
            ("Loc", "Loc"),  # Located_In
            ("Peop", "Loc"),  # Live_In
            ("Peop", "Org"),  # Work_For
            ("Org", "Loc"),  # OrgBased_In
            ("Peop", "Peop"),  # Kill
        },
        allow_unk_tag=False,
        mask_remainder=True,
        cross_augmentation=True,
    )

    # Step 5: Initialize trainer on transformed corpus
    trainer: ModelTrainer = ModelTrainer(model=model, corpus=model.transform_corpus(corpus))

    # Step 6: Run fine-tuning
    trainer.fine_tune(
        out_dir,
        max_epochs=max_epochs,
        learning_rate=learning_rate,
        mini_batch_size=mini_batch_size,
        main_evaluation_metric=("macro avg", "f1-score"),
    )


def predict_example() -> None:
    # Step 1: Load trained relation extraction model
    model: RelationClassifier = RelationClassifier.load(out_dir / "final-model.pt")

    # Step 2: Create sentences with entity annotations (as these are required by the relation extraction model)
    # In production, use another sequence tagger model to tag the relevant entities.
    sentence: Sentence = Sentence(
        "On April 14, while attending a play at the Ford Theatre in Washington, "
        "Lincoln was shot in the head by actor John Wilkes Booth."
    )
    sentence[10:12].add_label(typename="ner", value="Loc", score=1.0)  # Ford Theatre -> Loc
    sentence[13:14].add_label(typename="ner", value="Loc", score=1.0)  # Washington -> Loc
    sentence[15:16].add_label(typename="ner", value="Peop", score=1.0)  # Lincoln -> Peop
    sentence[23:26].add_label(typename="ner", value="Peop", score=1.0)  # John Wilkes Booth -> Peop

    # Step 3: Predict
    model.predict(sentence)
    print(sentence)


if __name__ == "__main__":
    train()
    predict_example()
Results:
- F-score (micro) 0.8045
- F-score (macro) 0.8191
- Accuracy 0.9184

By class:
              precision    recall  f1-score   support

     Live_In     0.8100    0.8100    0.8100       100
 OrgBased_In     0.8171    0.6381    0.7166       105
  Located_In     0.8788    0.6170    0.7250        94
    Work_For     0.9079    0.9079    0.9079        76
        Kill     0.9362    0.9362    0.9362        47

   micro avg     0.8598    0.7559    0.8045       422
   macro avg     0.8700    0.7818    0.8191       422
weighted avg     0.8588    0.7559    0.7995       422

dobbersc added 30 commits April 8, 2022 22:25
- Correct docstring relation direction
- Add option do remove the `<unk>` tag from the passed label dictionary
@dobbersc
Copy link
Collaborator Author

dobbersc commented Jun 17, 2022

I think this PR is ready for review now. I have some more ideas to incorporate but they are beyond the basic functionality and may be added later within smaller PRs that are easier to review. For example:

  • Automatic detection of the entity_pair_labels similar to the make_label_dictionary of the corpus.
  • Better integration of the corpus transform functions and support in_memory=False. Maybe a general transformation method/parameter for dataset objects.
  • Multi-label support for the relation classifier (properly very niche).
  • Add a tutorial similar to NER + RE #2726
  • Add pre-trained models similar to the existing relation extractor

I could also add some more benchmarks if desired.

@helpmefindaname
Copy link
Collaborator

hi @dobbersc as you introduce some kind of special tokens, have you tried adding them specifically to the vocabulary of the transformer embeddings? You could do this by adding something like:

if isinstance(self.document_embeddings, TransformerDocumentEmbeddings) :
    special_tokens_dict = {'additional_special_tokens': ['[T-ORG]','[H-ORG]','[R-ORG]','[T-LOC]','[H-LOC]','[R-LOC]']} # cross product of "T-", "H-", "R-" and all entity types
    num_added_toks = self.document_embeddings.tokenizer.add_special_tokens(special_tokens_dict)  # already handles possible duplication of tokens, so no need to check if they were added already.
    self.document_embeddings.model.resize_token_embeddings(len(tokenizer)  # keeps the previous embeddings and adds random initialisation otherwise

in the __init__ of the MaskedRelationClassifier.

I would be interested if that yields even some more improvements

# Conflicts:
#	tests/test_relation_classifier.py
@dobbersc
Copy link
Collaborator Author

dobbersc commented Jul 6, 2022

Hey @helpmefindaname, thank you for your advice. These are the results I got on CONLL04 (same hyperparameters and training script as in the PR description):

Masks Additional Special Tokens F1-Score (micro) F1-Score (macro) Accuracy
with label
([H-PER], [T-PER], [R-PER], ...)
no 0.8010 0.8140 0.9179
with label
([H-PER], [T-PER], [R-PER], ...)
yes 0.7760 0.7923 0.9016
without label
([HEAD], [TAIL], [REMAINDER])
no 0.7306 0.7509 0.8947
without label
([HEAD], [TAIL], [REMAINDER])
yes 0.7211 0.7423 0.8963

Unfortunately, I get worse scores with added special tokens.

Initially, when I tested the first two configurations (with label masks), I suspected that distilbert associates a meaning to the labels PER, ORG, LOC, etc., that is useful for this task. Since for added special tokens, distilbert initializes its embedding layer's weights with random values I assumed that this information is now lost and has to be re-learned. But after the last two configurations (without label masks), I don't get why the scores are decreasing, when I add the special tokens. Do you have any ideas here?

Code I added to the `__init__` (only works for CONLL04)...
# Add the cross-product of "H-", "T-", "R-" and all entity types
if isinstance(self.document_embeddings, TransformerDocumentEmbeddings):
    special_tokens: List[str] = [
        mask_func(label)
        for mask_func, label in itertools.product(
            [self._label_aware_head_mask, self._label_aware_tail_mask, self._label_aware_remainder_mask],
            ["Loc", "Peop", "Org"],  # TODO: Retrieve these dynamically
        )
    ]

    tokenizer = self.document_embeddings.tokenizer
    num_added_tokens = tokenizer.add_special_tokens(
        {"additional_special_tokens": special_tokens}
    )
    self.document_embeddings.model.resize_token_embeddings(len(tokenizer))

    log.info(
        f"{self.__class__.__name__}: "
        f"Added {num_added_tokens} {special_tokens} additional special tokens to {self.document_embeddings.name}"
    )

@helpmefindaname
Copy link
Collaborator

Hi @dobbersc
interesting and surprising results.
looking at the tokenization without special tokens:

'[',
 'h',
 '-',
 'lo',
 '##c',
 ']',
 'and',
 '[',
 't',
 '-',
 'lo',
 '##c',
 ']',

we see that there are many overlapping tokens.

Maybe it goes in the direction that it uses tokens like [ ] - to specifically learn information that represents a token, while the token type has an extra encoding and the label also a different one. In comparison, the special tokens would need to learn that representation independently of each other, which is, of course, harder.

Respectively, I suppose that if we have [Head] [TAIL] and [Reminder] will have a shared token representation if [ ] learn specific token representations.

That said, I would have the theory (I can check them myself, if you give me a week), that it might be beneficial to introduce special tokens that represent only a part. Let's say we encode the tokens as [TOKEN-{HEAD|TAIL|REMAINDER}-{Per|Peop|Org}] and add the following special tokens: [TOKEN -HEAD -TAIL -REMAINDER -Per] -Peop] -Org] it might perform better as now there would be dedicated tokens for each type of learning allowing shared representations but also would be less confused by random [ and ] tokens (I just assume that this might be an issue, I didn't test it).

Another completely unrelated idea: As you mentioned that the Labels might be useful for the task: You could add a mechanism to rename the labels e.g. LOC to Location, PER to Person as the latter might have more semantic meaning. That trick was used at the TARS paper and as we see a good increase by label information, the replacement might improve it even a bit more.

@alanakbik
Copy link
Collaborator

Hello @dobbersc @helpmefindaname very interesting results and discussion!

Regarding renaming labels: With the new corpus logic, renaming of any label is now easy and requires only to define a label_name_map. See this example of a sentence with and without mapped labels:

### EXAMPLE 1: load WITHOUT label map
corpus = RE_ENGLISH_CONLL04()

# get example sentence
example_sentence: Sentence = corpus.train[1]

# print sentence text, its NER labels, and its relations
print(example_sentence.text)
for entity in example_sentence.get_labels('ner'):
    print(entity)
for relation in example_sentence.get_labels('relation'):
    print(relation)


### EXAMPLE 2: load WITH label map
corpus = RE_ENGLISH_CONLL04(label_name_map={'Loc': 'Location',
                                            'Peop': 'Person',
                                            'Org': 'Organization',
                                            'Live_In': 'Lives in'})
# get example sentence
example_sentence: Sentence = corpus.train[1]

# print sentence text, its NER labels, and its relations
print(example_sentence.text)
for entity in example_sentence.get_labels('ner'):
    print(entity)
for relation in example_sentence.get_labels('relation'):
    print(relation)

@helpmefindaname
Copy link
Collaborator

Hi again,

I did some testing and basically all my ideas lead to an decrease of scores, here are my runs, all with some adjustments in the tokens :

with label [H-PER], [T-PER], [R-PER], .... no special token:

0.7985
0.8143
0.9153

with label [H-PERSON], [T-PERSON], [R-PERSON], .... no special token:
(rename labels for better naming)

0.7889
0.8037
0.9121

with label [TOKEN-HEAD-PER], [TOKEN-TAIL-PER], [TOKEN-RREMAINDER-PER], .... special tokens ("[TOKEN", "-HEAD-", "-TAIL-", "-REMAINDER-", "PER]", "LOC]", "ORG]",:

0.792
0.8078
0.9121

with label [H-PER], [T-PER], [R-PER], .... special tokens ("PER", "LOC", "ORG"):

0.7839
0.7998
0.9095

with label [TOKEN-HEAD-PER], [TOKEN-TAIL-PER], [TOKEN-RREMAINDER-PER], .... special tokens ("[TOKEN", "-HEAD-", "-TAIL-", "-REMAINDER-",:

0.7854
0.7982
0.9105

@alanakbik
Copy link
Collaborator

Thanks for sharing these interesting results @helpmefindaname!

I wonder if the masking of the "remainder" NER tags could be a problem since they share many subtokens with HEAD and TAIL tags - and anyway only the head and tail NER are really relevant and what the algorithm should be focusing on.

@dobbersc could you do some training runs to evaluate the impact on accuracy of whether mask_remainder is set to True or False?

@dobbersc
Copy link
Collaborator Author

dobbersc commented Jul 12, 2022

@dobbersc could you do some training runs to evaluate the impact on the accuracy of whether mask_remainder is set to True or False?

I don't have any concrete numbers saved from my runs while experimenting with the model on the fly. But in general, with mask_remainder set to true, it actually improved the performance. I'll get back to you with some concrete results.

@dobbersc
Copy link
Collaborator Author

dobbersc commented Jul 17, 2022

Sorry that I only post now. I did some more masking experimentation similar to @helpmefindaname but nothing that got better scores. Here are some scores for the mask_remainder setting:

Mask Pattern mask_remainder Additional Special Tokens F1-Score (micro) F1-Score (macro) Accuracy
[H/T/R-\<Label>] True None 0.801 0.814 0.9179
[H/T-\<Label>] False None 0.7921 0.8091 0.9116
[H/T/R-\<Expanded Label>] ([H-PERSON], etc.) True None 0.7869 0.8 0.9111
[H/T-\<Expanded Label>] ([H-PERSON], etc.) False None 0.794 0.8097 0.9126
[ENTITY-HEAD/TAIL/REMAINDER-\<Label>]
Example: [ENTITY-HEAD-Peop]
True [ENTITY, -HEAD-, -TAIL-, -REMAINDER- 0.7897 0.8039 0.9095
[ENTITY-HEAD/TAIL-\<Label>]
Example: [ENTITY-HEAD-Peop]
False [ENTITY, -HEAD-, -TAIL- 0.798 0.8114 0.9153
[ENTITY-HEAD/TAIL/REMAINDER-\<Label>]
Example: [ENTITY-HEAD-Peop]
True [ENTITY, -HEAD-, -TAIL-, -REMAINDER-, Peop], Loc], Org], Other] 0.7928 0.8064 0.9153
[ENTITY-HEAD/TAIL-\<Label>]
Example: [ENTITY-HEAD-Peop]
False [ENTITY, -HEAD-, -TAIL-, Peop], Loc], Org], Other] 0.798 0.8125 0.9153
[HEAD/TAIL/REMAINDER-\<LABEL>-ENTITY]
Example: [HEAD-Peop-ENTITY]
True [HEAD-, [TAIL-, [REMAINDER-, -ENTITY] 0.7823 0.7977 0.9068
[HEAD/TAIL-\<LABEL>-ENTITY]
Example: [HEAD-Peop-ENTITY]
False [HEAD-, [TAIL-, -ENTITY] 0.7977 0.8069 0.9111

For the original masking strategy ([H-PER], etc.) the run with mask_remainder=True yields better results. But for the other training runs where we experimented with adding special tokens, mask_remainder=False scores a bit better.
Since the difference in scores is very little, would it be helpful to evaluate the model on some more benchmarks?

@alanakbik
Copy link
Collaborator

@dobbersc thanks a lot for this great implementation of masked relation extraction! From my experiments, it looks to significantly outperform our prior relation extractor. I'll probably train one or two models to include with the next Flair release.

@alanakbik alanakbik merged commit 8830564 into master Nov 6, 2022
@alanakbik alanakbik deleted the masked-relation-classifier branch November 6, 2022 21:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement Improving of an existing feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants