Skip to content

Commit

Permalink
feat: keep span annotations for ents
Browse files Browse the repository at this point in the history
  • Loading branch information
KennethEnevoldsen committed Sep 19, 2023
1 parent f996c77 commit 8dcafc5
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 10 deletions.
80 changes: 70 additions & 10 deletions src/augmenty/span/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,26 @@
from spacy.training import Example
from spacy.util import registry

from augmenty import span

from ..augment_utilities import make_text_from_orth
from .utils import offset_range

# create entity type
ENTITY = Union[str, List[str], Span, Doc]


def __normalize_entity(entity: ENTITY, nlp: Language) -> Dict[str, List[Any]]:
def _spacing_to_str(spacing: Union[List[str], List[bool]]) -> List[str]:
def to_string(x: Union[str, bool]) -> str:
if isinstance(x, str):
return x
else:
return " " if x else ""

return [to_string(x) for x in spacing]


def __normalize_entity(entity: ENTITY, nlp: Language) -> Dict[str, Any]:
spacy = None
pos = None
tag = None
Expand All @@ -50,7 +63,7 @@ def __normalize_entity(entity: ENTITY, nlp: Language) -> Dict[str, List[Any]]:
)
# if not specifed use default values
if spacy is None:
spacy = [True] * len(orth)
spacy = [" "] * len(orth)
if pos is None:
pos = ["PROPN"] * len(orth)
if tag is None:
Expand All @@ -60,16 +73,47 @@ def __normalize_entity(entity: ENTITY, nlp: Language) -> Dict[str, List[Any]]:
if lemma is None:
lemma = orth

_spacy = _spacing_to_str(spacy)
str_repr = ""
for e, s in zip(orth[:-1], _spacy[:-1]):
str_repr += e + s
str_repr += orth[-1]

return {
"ORTH": orth,
"SPACY": spacy,
"POS": pos,
"TAG": tag,
"MORPH": morph,
"LEMMA": lemma,
"STR": str_repr,
}


def _update_span_annotations(
span_anno: Dict[str, list],
ent: Span,
offset: int,
entity_offset: int,
) -> Dict[str, list]:
"""Update the span annotations to be in line with the new doc."""
ent_range = (ent.start + offset, ent.end + offset)

for anno_key, spans in span_anno.items():
new_spans = []
for span_start, span_end, _, __ in spans:
span_start, span_end = offset_range(
current_range=(span_start, span_end),
inserted_range=ent_range,
offset=entity_offset,
)
new_spans.append((span_start, span_end, _, __))

span_anno[anno_key] = new_spans

return span_anno


def ent_augmenter_v1(
nlp: Language,
example: Example,
Expand All @@ -82,10 +126,14 @@ def ent_augmenter_v1(
example_dict = example.to_dict()

offset = 0
str_offset = 0

spans_anno = example_dict["doc_annotation"]["spans"]
tok_anno = example_dict["token_annotation"]
ents = example_dict["doc_annotation"]["entities"]
if example.y.has_annotation("HEAD") and resolve_dependencies:

should_update_heads = example.y.has_annotation("HEAD") and resolve_dependencies
if should_update_heads:
head = np.array(tok_anno["HEAD"])

for ent in example.y.ents:
Expand All @@ -105,10 +153,13 @@ def ent_augmenter_v1(
normalized_ent = __normalize_entity(new_ent, nlp)
new_ent = normalized_ent["ORTH"]
spacing = normalized_ent["SPACY"]
str_ent = normalized_ent["STR"]

# Handle token annotations
len_ent = len(new_ent)
i = slice(ent.start + offset, ent.end + offset)
str_len_ent = len(str_ent)
ent_range = (ent.start + offset, ent.end + offset)
i = slice(*ent_range)
tok_anno["ORTH"][i] = new_ent
tok_anno["LEMMA"][i] = normalized_ent["LEMMA"]

Expand All @@ -125,11 +176,12 @@ def ent_augmenter_v1(
spacing[-1:] = [ent[-1].whitespace_]
tok_anno["SPACY"][i] = spacing

offset_ = len_ent - (ent.end - ent.start)
if example.y.has_annotation("HEAD") and resolve_dependencies:
entity_offset = len_ent - (ent.end - ent.start)
entity_str_offset = str_len_ent - len(ent.text)
if should_update_heads:
# Handle HEAD

head[head > ent.start + offset] += offset_
head[head > ent.start + offset] += entity_offset
# keep first head correcting for changing entity size, set rest to
# refer to index of first name
head = np.concatenate(
Expand All @@ -142,7 +194,15 @@ def ent_augmenter_v1(
np.array(head[ent.end + offset :]), # after
],
)
offset += offset_

spans_anno = _update_span_annotations(
spans_anno,
ent,
str_offset,
entity_str_offset,
)
offset += entity_offset
str_offset += entity_str_offset

# Handle entities IOB tags
if len_ent == 1:
Expand All @@ -154,8 +214,8 @@ def ent_augmenter_v1(
+ ["L-" + ent.label_]
)

if example.y.has_annotation("HEAD") and resolve_dependencies:
tok_anno["HEAD"] = head.tolist()
if should_update_heads:
tok_anno["HEAD"] = head.tolist() # type: ignore
else:
tok_anno["HEAD"] = list(range(len(tok_anno["ORTH"])))

Expand Down
39 changes: 39 additions & 0 deletions src/augmenty/span/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Tuple


def offset_range(
current_range: Tuple[int, int],
inserted_range: Tuple[int, int],
offset: int,
) -> Tuple[int, int]:
"""Update current range based on inserted range and previous range.
Args:
current_range: The range you wish the indices to be updated for.
inserted_range: The range of the inserted range.
offset: The offset to apply to the current range.
"""

start, end = current_range

if offset == 0:
return current_range

is_within_range = (
inserted_range[0] <= start <= inserted_range[1]
or inserted_range[0] <= end <= inserted_range[1]
)
if is_within_range:
return start, end + offset

is_before_range = start < inserted_range[0]
if is_before_range:
return start, end

is_after_range = end > inserted_range[1]
if is_after_range:
return start + offset, end + offset

raise ValueError(
f"Current range {current_range} is not within inserted range {inserted_range}",
)
83 changes: 83 additions & 0 deletions tests/test_spans.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,93 @@
from typing import Callable

import pytest
from spacy.language import Language
from spacy.tokens import Doc

import augmenty

from .fixtures import nlp_en, nlp_en_md # noqa


@pytest.fixture
def doc(nlp_en: Language) -> Doc: # noqa
doc = Doc(
nlp_en.vocab,
words=[
"Augmenty",
"is",
"a",
"wonderful",
"tool",
"for",
"augmentation",
".",
],
spaces=[True] * 6 + [False] * 2,
ents=["B-ORG"] + ["O"] * 7,
)
return doc


@pytest.fixture
def ent_augmenter():
ent_augmenter = augmenty.load(
"ents_replace_v1", # type: ignore
level=1.00,
ent_dict={"ORG": [["SpaCy"]]},
)
return ent_augmenter


@pytest.mark.parametrize(
"nlp",
[
pytest.lazy_fixture("nlp_en"),
pytest.lazy_fixture("nlp_en_md"),
],
)
def test_ent_replace_with_span_annotations(
doc: Doc,
ent_augmenter: Callable,
nlp: Language,
):
# add span annotations
positive_noun_chunks = [doc[3:5]]
is_augmenty = [doc[0:1]]
doc.spans["positive_noun_chunks"] = positive_noun_chunks
doc.spans["is_augmenty"] = is_augmenty

docs = list(augmenty.docs([doc], augmenter=ent_augmenter, nlp=nlp))

# Check spans
doc_pos_noun_chunks = docs[0].spans["positive_noun_chunks"]
assert doc_pos_noun_chunks[0].text == "wonderful tool", "the span is not maintained"

doc_is_augmenty = docs[0].spans["is_augmenty"]
assert doc_is_augmenty[0].text == "SpaCy", "the span is not maintained"


@pytest.mark.parametrize(
"nlp",
[
pytest.lazy_fixture("nlp_en"),
pytest.lazy_fixture("nlp_en_md"),
],
)
def test_ent_replace_with_cats_annotations(
doc: Doc,
ent_augmenter: Callable,
nlp: Language,
):
# add doc annotations
doc.cats["is_positive"] = 1

# augment
docs = list(augmenty.docs([doc], augmenter=ent_augmenter, nlp=nlp))

assert docs[0].cats["is_positive"] == 1.0, "the document category is not maintained"


def test_create_ent_replace(nlp_en_md, nlp_en): # noqa F811
doc = Doc(
nlp_en.vocab,
Expand Down

0 comments on commit 8dcafc5

Please sign in to comment.