Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix EL failure with sentence-crossing entities #12398

Merged
merged 16 commits into from
Mar 14, 2023
14 changes: 10 additions & 4 deletions spacy/pipeline/entity_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,18 +474,24 @@ def predict(self, docs: Iterable[Doc]) -> List[str]:

# Looping through each entity in batch (TODO: rewrite)
for j, ent in enumerate(ent_batch):
sent_index = sentences.index(ent.sent)
assert sent_index >= 0
assert hasattr(ent, "sents")
sents = list(ent.sents)
sent_indices = (
sentences.index(sents[0]),
sentences.index(sents[-1]),
)
assert sent_indices[1] >= sent_indices[0] >= 0

if self.incl_context:
# get n_neighbour sentences, clipped to the length of the document
start_sentence = max(0, sent_index - self.n_sents)
start_sentence = max(0, sent_indices[0] - self.n_sents)
end_sentence = min(
len(sentences) - 1, sent_index + self.n_sents
len(sentences) - 1, sent_indices[1] + self.n_sents
)
start_token = sentences[start_sentence].start
end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc()

# currently, the context is the same for each entity in a sentence (should be refined)
sentence_encoding = self.model.predict([sent_doc])[0]
sentence_encoding_t = sentence_encoding.T
Expand Down
50 changes: 19 additions & 31 deletions spacy/tests/pipeline/test_entity_linker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Callable, Iterable, Dict, Any
from typing import Callable, Iterable, Dict, Any, Tuple

import pytest
from numpy.testing import assert_equal

from spacy import registry, util
from spacy import registry, util, Language
from spacy.attrs import ENT_KB_ID
from spacy.compat import pickle
from spacy.kb import Candidate, InMemoryLookupKB, get_candidates, KnowledgeBase
Expand Down Expand Up @@ -108,18 +108,23 @@ def test_issue7065():


@pytest.mark.issue(7065)
def test_issue7065_b():
@pytest.mark.parametrize("entity_in_first_sentence", [True, False])
def test_sentence_crossing_ents(entity_in_first_sentence: bool):
"""Tests if NEL crashes if entities cross sentence boundaries and the first associated sentence doesn't have an
entity.
entity_in_first_sentence (bool): Whether to include an entity in the first sentence associated with the
sentence-crossing entity.
"""
# Test that the NEL doesn't crash when an entity crosses a sentence boundary
nlp = English()
vector_length = 3
nlp.add_pipe("sentencizer")
text = "Mahler 's Symphony No. 8 was beautiful."
entities = [(0, 6, "PERSON"), (10, 24, "WORK")]
links = {
(0, 6): {"Q7304": 1.0, "Q270853": 0.0},
(10, 24): {"Q7304": 0.0, "Q270853": 1.0},
}
sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
entities = [(10, 24, "WORK")]
links = {(10, 24): {"Q7304": 0.0, "Q270853": 1.0}}
if entity_in_first_sentence:
entities.append((0, 6, "PERSON"))
links[(0, 6)] = {"Q7304": 1.0, "Q270853": 0.0}
sent_starts = [1, -1, 0, 0, 0, 1, 0, 0, 0]
doc = nlp(text)
example = Example.from_dict(
doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
Expand All @@ -145,31 +150,14 @@ def create_kb(vocab):

# Create the Entity Linker component and add it to the pipeline
entity_linker = nlp.add_pipe("entity_linker", last=True)
entity_linker.set_kb(create_kb)
entity_linker.set_kb(create_kb) # type: ignore
# train the NEL pipe
optimizer = nlp.initialize(get_examples=lambda: train_examples)
for i in range(2):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
nlp.update(train_examples, sgd=optimizer)

# Add a custom rule-based component to mimic NER
patterns = [
{"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
{
"label": "WORK",
"pattern": [
{"LOWER": "symphony"},
{"LOWER": "no"},
{"LOWER": "."},
{"LOWER": "8"},
],
},
]
ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
ruler.add_patterns(patterns)
# test the trained model - this should not throw E148
doc = nlp(text)
assert doc
# This shouldn't crash.
entity_linker.predict([example.reference]) # type: ignore


def test_no_entities():
Expand Down