Skip to content

Commit

Permalink
fix: Entity augmenter now allows for passing in spans or string
Browse files Browse the repository at this point in the history
  • Loading branch information
KennethEnevoldsen committed May 1, 2023
1 parent 80e85cb commit faf2abc
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 21 deletions.
103 changes: 84 additions & 19 deletions src/augmenty/span/entities.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,91 @@
import random
from functools import partial
from typing import Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Union,
)

import numpy as np
import spacy
from spacy.language import Language
from spacy.tokens import Token
from spacy.tokens import Span, Token
from spacy.training import Example
from spacy.util import registry

from ..augment_utilities import make_text_from_orth

# create entity type
ENTITY = Union[str, List[str], Span]


def __normalize_entity(entity: ENTITY, nlp: Language) -> Dict[str, List[Any]]:
    """Convert a replacement entity into per-token annotation lists.

    Args:
        entity: The replacement entity. Either a string (tokenized with
            ``nlp``), a list of token strings, or a spaCy ``Span``.
        nlp: The pipeline used to tokenize ``entity`` when it is a string.

    Returns:
        A dictionary with the keys "ORTH", "SPACY", "POS", "TAG", "MORPH" and
        "LEMMA", each holding one item per token. For string and ``Span``
        input, "SPACY" holds each token's ``whitespace_`` value; for list
        input it is ``True`` for every token. Attributes that the input does
        not provide fall back to defaults ("PROPN" for POS and TAG, an empty
        string for MORPH, and the token text for LEMMA). Every returned list
        is a fresh object, so callers may modify one in place without
        affecting the other lists or the ``entity`` that was passed in.

    Raises:
        ValueError: If ``entity`` is not a str, a list or a Span.
    """
    # Named `spacing` (not `spacy`) so the imported `spacy` module is not
    # shadowed inside this function.
    spacing = None
    pos = None
    tag = None
    morph = None
    lemma = None

    if isinstance(entity, str):
        ent_doc = nlp(entity)
        orth = [tok.text for tok in ent_doc]
        spacing = [tok.whitespace_ for tok in ent_doc]
    elif isinstance(entity, list):
        # Copy so the caller's list is never shared with the returned dict.
        orth = list(entity)
    elif isinstance(entity, Span):
        orth = [tok.text for tok in entity]
        spacing = [tok.whitespace_ for tok in entity]
        pos = [tok.pos_ for tok in entity]
        tag = [tok.tag_ for tok in entity]
        morph = [tok.morph for tok in entity]
        lemma = [tok.lemma_ for tok in entity]
    else:
        raise ValueError(
            f"entity must be of type str, List[str] or Span, not {type(entity)}",
        )

    # If not specified, use default values.
    n_tokens = len(orth)
    if spacing is None:
        spacing = [True] * n_tokens
    if pos is None:
        pos = ["PROPN"] * n_tokens
    if tag is None:
        tag = ["PROPN"] * n_tokens
    if morph is None:
        morph = [""] * n_tokens
    if lemma is None:
        # Copy rather than alias: "LEMMA" and "ORTH" must be independent lists.
        lemma = list(orth)

    return {
        "ORTH": orth,
        "SPACY": spacing,
        "POS": pos,
        "TAG": tag,
        "MORPH": morph,
        "LEMMA": lemma,
    }


def ent_augmenter_v1(
nlp: Language,
example: Example,
level: float,
ent_dict: Dict[str, Iterable[List[str]]],
ent_dict: Dict[str, Iterable[ENTITY]],
replace_consistency: bool,
resolve_dependencies: bool,
) -> Iterator[Example]:
replaced_ents = {} # type: Dict[str, List[str]]
replaced_ents: Dict[str, ENTITY] = {}
example_dict = example.to_dict()

offset = 0

tok_anno = example_dict["token_annotation"]
ents = example_dict["doc_annotation"]["entities"]
if example.y.has_annotation("HEAD"):
if example.y.has_annotation("HEAD") and resolve_dependencies:
head = np.array(tok_anno["HEAD"])

for ent in example.y.ents:
Expand All @@ -43,24 +102,28 @@ def ent_augmenter_v1(
if replace_consistency:
replaced_ents[ent.text] = new_ent

normalized_ent = __normalize_entity(new_ent, nlp)
new_ent = normalized_ent["ORTH"]
spacing = normalized_ent["SPACY"]

# Handle token annotations
len_ent = len(new_ent)
i = slice(ent.start + offset, ent.end + offset)
tok_anno["ORTH"][i] = new_ent
tok_anno["LEMMA"][i] = new_ent
tok_anno["LEMMA"][i] = normalized_ent["LEMMA"]

tok_anno["TAG"][i] = ["PROPN"] * len_ent
tok_anno["POS"][i] = ["PROPN"] * len_ent
tok_anno["POS"][i] = normalized_ent["POS"]
tok_anno["TAG"][i] = normalized_ent["TAG"]

tok_anno["MORPH"][i] = [""] * len_ent
tok_anno["MORPH"][i] = normalized_ent["MORPH"]
tok_anno["DEP"][i] = [ent[0].dep_] + ["flat"] * (len_ent - 1)

# Set sentence start based on first token in previous entity
tok_anno["SENT_START"][i] = [ent[0].sent_start] + [0] * (len_ent - 1)

# set spacing to be whitespace for all tokens except the last one
# which is set based on the original entity
tok_anno["SPACY"][i] = [True] * (len_ent - 1) + [bool(ent[-1].whitespace_)]
# set the last spacing to be equal to the last token spacing in the previous entity
spacing[-1:] = [ent[-1].whitespace_]
tok_anno["SPACY"][i] = spacing

offset_ = len_ent - (ent.end - ent.start)
if example.y.has_annotation("HEAD") and resolve_dependencies:
Expand Down Expand Up @@ -102,10 +165,10 @@ def ent_augmenter_v1(
yield Example.from_dict(doc, example_dict)


@spacy.registry.augmenters("ents_replace_v1")
@registry.augmenters("ents_replace_v1")
def create_ent_augmenter_v1(
level: float,
ent_dict: Dict[str, Iterable[List[str]]],
ent_dict: Dict[str, Iterable[ENTITY]],
replace_consistency: bool = True,
resolve_dependencies: bool = True,
) -> Callable[[Language, Example], Iterator[Example]]:
Expand All @@ -116,8 +179,10 @@ def create_ent_augmenter_v1(
level: the percentage of entities to be augmented.
ent_dict: A dictionary with keys corresponding
to the entity type you wish to replace (e.g. "PER") and an iterable of the
replacements. A replacement is a list of string of the desired entity
replacement ["Kenneth", "Enevoldsen"].
replacement entities. A replacement can be either 1) a list of strings for the desired entity,
e.g. ["Kenneth", "Enevoldsen"], 2) a string for the desired entity, e.g. "Kenneth Enevoldsen", which
will be split using the tokenizer of the nlp pipeline, or 3) a Span object with the desired entity; here all
information will be passed on except for the dependency tree.
replace_consistency: Should an entity always be replaced with
the same entity? Defaults to True.
resolve_dependencies: Attempts to resolve the dependency tree
Expand Down Expand Up @@ -160,7 +225,7 @@ def generator_from_name_dict(
]


@spacy.registry.augmenters("per_replace_v1")
@registry.augmenters("per_replace_v1")
def create_per_replace_augmenter_v1(
names: Dict[
str,
Expand Down Expand Up @@ -256,7 +321,7 @@ def ent_format_augmenter_v1(
yield Example.from_dict(doc, example_dict)


@spacy.registry.augmenters("ents_format_v1")
@registry.augmenters("ents_format_v1")
def create_ent_format_augmenter_v1(
reordering: List[Union[int, None]],
formatter: List[Union[Callable[[Token], str], None]],
Expand Down
4 changes: 2 additions & 2 deletions tests/test_issue_170.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,10 @@ def example_doc(nlp) -> Doc:
def test_entity_with_no_dep(nlp, example_doc: Doc):
level = 1.0
docs = [example_doc]
ents_as_str = ["Melvin R. Brown"]
augmenter = augmenty.load(
"ents_replace_v1",
level=level,
ent_dict={"pers": [[s] for s in ents_as_str]},
ent_dict={"pers": ["Melvin R. Brown"]},
replace_consistency=True,
resolve_dependencies=True,
)
Expand All @@ -72,3 +71,4 @@ def test_entity_with_no_dep(nlp, example_doc: Doc):
aug_doc.text
== "Melvin R. Brown and Melvin R. Brown (concussion protocol) are each progressing. SS Melvin R. Brown"
)
assert aug_doc[0].text == "Melvin"

0 comments on commit faf2abc

Please sign in to comment.