-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2533 from MaxDall/master
Implements a RegexpTagger
- Loading branch information
Showing
1 changed file
with
120 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import re | ||
from dataclasses import dataclass, field | ||
from typing import List, Tuple, Dict, Union | ||
|
||
from flair.data import Sentence, Span, SpanLabel, Token | ||
|
||
|
||
class RegexpTagger: | ||
@dataclass | ||
class TokenCollection: | ||
""" | ||
A utility class for RegexpTagger to hold all tokens for a given Sentence and define some functionality | ||
:param sentence: A Sentence object | ||
""" | ||
sentence: Sentence | ||
__tokens_start_pos: List[int] = field(init=False, default_factory=list) | ||
__tokens_end_pos: List[int] = field(init=False, default_factory=list) | ||
|
||
def __post_init__(self): | ||
for token in self.tokens: | ||
self.__tokens_start_pos.append(token.start_pos) | ||
self.__tokens_end_pos.append(token.end_pos) | ||
|
||
@property | ||
def tokens(self) -> List[Token]: | ||
return list(self.sentence) | ||
|
||
def get_token_span(self, span: Tuple[int, int]) -> Span: | ||
""" | ||
Given an interval specified with start and end pos as tuple, this function returns a Span object | ||
spanning the tokens included in the interval. If the interval is overlapping with a token span, a | ||
ValueError is raised | ||
:param span: Start and end pos of the requested span as tuple | ||
:return: A span object spanning the requested token interval | ||
""" | ||
span_start: int = self.__tokens_start_pos.index(span[0]) | ||
span_end: int = self.__tokens_end_pos.index(span[1]) | ||
return Span(self.tokens[span_start:span_end + 1]) | ||
|
||
def __init__(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]): | ||
""" | ||
This tagger is capable of tagging sentence objects with given regexp -> label mappings. | ||
I.e: The tuple (r'(["\'])(?:(?=(\\?))\2.)*?\1', 'QUOTE') maps every match of the regexp to | ||
a <QUOTE> labeled span and therefore labels the given sentence object with RegexpTagger.predict(). | ||
This tagger supports multilabeling so tokens can be included in multiple labeled spans. | ||
The regexp are compiled internally and an re.error will be raised if the compilation of a given regexp fails. | ||
If a match violates (in this case overlaps) a token span, an exception is raised. | ||
:param mapping: A list of tuples or a single tuple representing a mapping as regexp -> label | ||
""" | ||
self._regexp_mapping: Dict[str, re.Pattern] = {} | ||
self.register_labels(mapping=mapping) | ||
|
||
@property | ||
def registered_labels(self): | ||
return self._regexp_mapping | ||
|
||
def register_labels(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]): | ||
""" | ||
Register a regexp -> label mapping. | ||
:param mapping: A list of tuples or a single tuple representing a mapping as regexp -> label | ||
""" | ||
mapping = self._listify(mapping) | ||
|
||
for regexp, label in mapping: | ||
try: | ||
self._regexp_mapping[label] = re.compile(regexp) | ||
except re.error as err: | ||
raise re.error( | ||
f"Couldn't compile regexp '{regexp}' for label '{label}'. Aborted with error: '{err.msg}'") | ||
|
||
def remove_labels(self, labels: Union[List[str], str]): | ||
""" | ||
Remove a registered regexp -> label mapping given by label. | ||
:param labels: A list of labels or a single label as strings. | ||
""" | ||
labels = self._listify(labels) | ||
|
||
for label in labels: | ||
if not self._regexp_mapping.get(label): | ||
continue | ||
self._regexp_mapping.pop(label) | ||
|
||
@staticmethod | ||
def _listify(element: object) -> list: | ||
if not isinstance(element, list): | ||
return [element] | ||
else: | ||
return element | ||
|
||
def predict(self, sentences: Union[List[Sentence], Sentence]) -> List[Sentence]: | ||
""" | ||
Predict the given sentences according to the registered mappings. | ||
""" | ||
if not sentences: | ||
return sentences | ||
|
||
sentences = self._listify(sentences) | ||
for sentence in sentences: | ||
self._label(sentence) | ||
return sentences | ||
|
||
def _label(self, sentence: Sentence): | ||
""" | ||
This will add a complex_label to the given sentence for every match.span() for every registered_mapping. | ||
If a match span overlaps with a token span an exception is raised. | ||
""" | ||
collection = RegexpTagger.TokenCollection(sentence) | ||
|
||
for label, pattern in self._regexp_mapping.items(): | ||
for match in pattern.finditer(sentence.to_original_text()): | ||
span: Tuple[int, int] = match.span() | ||
try: | ||
token_span = collection.get_token_span(span) | ||
except ValueError: | ||
raise Exception(f"The match span {span} for label '{label}' is overlapping with a token!") | ||
sentence.add_complex_label(label, SpanLabel(token_span, label)) |