def __normalize_entity(entity: "ENTITY", nlp: "Language") -> Dict[str, List[Any]]:
    """Convert a replacement entity into per-token annotation lists.

    Args:
        entity: The replacement entity. Either a string (tokenized by calling
            ``nlp`` on it), a list of token texts, or a ``Span`` whose token
            attributes are copied.
        nlp: Pipeline called on ``entity`` when it is a string; not used for
            the other input types.

    Returns:
        A dict with the keys "ORTH", "SPACY", "POS", "TAG", "MORPH" and
        "LEMMA", each mapping to a list with one value per token. Attributes
        the input cannot provide fall back to defaults: trailing whitespace
        for every token, "PROPN" for POS and TAG, an empty morphology string,
        and the token text as lemma. Every list is freshly created, so the
        caller may mutate any of them without affecting ``entity`` or the
        other lists.

    Raises:
        ValueError: If ``entity`` is not a str, list or Span.
    """
    # `spacing` rather than `spacy`, so the package name is never shadowed.
    spacing = None
    pos = None
    tag = None
    morph = None
    lemma = None

    if isinstance(entity, str):
        ent_doc = nlp(entity)
        orth = [tok.text for tok in ent_doc]
        # store booleans so "SPACY" has one value type for every input kind
        spacing = [bool(tok.whitespace_) for tok in ent_doc]
    elif isinstance(entity, list):
        # copy, so the returned lists never alias the caller's replacement
        orth = list(entity)
    elif isinstance(entity, Span):
        orth = [tok.text for tok in entity]
        spacing = [bool(tok.whitespace_) for tok in entity]
        pos = [tok.pos_ for tok in entity]
        tag = [tok.tag_ for tok in entity]
        # use the string form of the morphology, not the MorphAnalysis
        # object, so the values match the default of "" below
        morph = [str(tok.morph) for tok in entity]
        lemma = [tok.lemma_ for tok in entity]
    else:
        raise ValueError(
            f"entity must be of type str, List[str] or Span, not {type(entity)}",
        )

    # if not specified use default values
    n_tokens = len(orth)
    if spacing is None:
        spacing = [True] * n_tokens
    if pos is None:
        pos = ["PROPN"] * n_tokens
    if tag is None:
        tag = ["PROPN"] * n_tokens
    if morph is None:
        morph = [""] * n_tokens
    if lemma is None:
        # separate list: "LEMMA" must not be the same object as "ORTH"
        lemma = list(orth)

    return {
        "ORTH": orth,
        "SPACY": spacing,
        "POS": pos,
        "TAG": tag,
        "MORPH": morph,
        "LEMMA": lemma,
    }
last token spacing in the previous entity + spacing[-1:] = [ent[-1].whitespace_] + tok_anno["SPACY"][i] = spacing offset_ = len_ent - (ent.end - ent.start) if example.y.has_annotation("HEAD") and resolve_dependencies: @@ -102,10 +165,10 @@ def ent_augmenter_v1( yield Example.from_dict(doc, example_dict) -@spacy.registry.augmenters("ents_replace_v1") +@registry.augmenters("ents_replace_v1") def create_ent_augmenter_v1( level: float, - ent_dict: Dict[str, Iterable[List[str]]], + ent_dict: Dict[str, Iterable[ENTITY]], replace_consistency: bool = True, resolve_dependencies: bool = True, ) -> Callable[[Language, Example], Iterator[Example]]: @@ -116,8 +179,10 @@ def create_ent_augmenter_v1( level: the percentage of entities to be augmented. ent_dict: A dictionary with keys corresponding the the entity type you wish to replace (e.g. "PER") and a itarable of the - replacements. A replacement is a list of string of the desired entity - replacement ["Kenneth", "Enevoldsen"]. + replacements entities. A replacement can be either 1) a list of string of the desired entity + i.e. ["Kenneth", "Enevoldsen"], 2) a string of the desired entity i.e. "Kenneth Enevoldsen", this + will be split using the tokenizer of the nlp pipeline, or 3) Span object with the desired entity, here all information will be passed + on except for the dependency tree. replace_consistency: Should an entity always be replaced with the same entity? Defaults to True. 
resolve_dependencies: Attempts to resolve the dependency tree @@ -160,7 +225,7 @@ def generator_from_name_dict( ] -@spacy.registry.augmenters("per_replace_v1") +@registry.augmenters("per_replace_v1") def create_per_replace_augmenter_v1( names: Dict[ str, @@ -256,7 +321,7 @@ def ent_format_augmenter_v1( yield Example.from_dict(doc, example_dict) -@spacy.registry.augmenters("ents_format_v1") +@registry.augmenters("ents_format_v1") def create_ent_format_augmenter_v1( reordering: List[Union[int, None]], formatter: List[Union[Callable[[Token], str], None]], diff --git a/tests/test_issue_170.py b/tests/test_issue_170.py index 47e7b409..56a76be7 100644 --- a/tests/test_issue_170.py +++ b/tests/test_issue_170.py @@ -58,11 +58,10 @@ def example_doc(nlp) -> Doc: def test_entity_with_no_dep(nlp, example_doc: Doc): level = 1.0 docs = [example_doc] - ents_as_str = ["Melvin R. Brown"] augmenter = augmenty.load( "ents_replace_v1", level=level, - ent_dict={"pers": [[s] for s in ents_as_str]}, + ent_dict={"pers": ["Melvin R. Brown"]}, replace_consistency=True, resolve_dependencies=True, ) @@ -72,3 +71,4 @@ def test_entity_with_no_dep(nlp, example_doc: Doc): aug_doc.text == "Melvin R. Brown and Melvin R. Brown (concussion protocol) are each progressing. SS Melvin R. Brown" ) + assert aug_doc[0].text == "Melvin"