feat: added context string selectors
percevalw committed May 19, 2024
1 parent b756307 commit 78cc968
Showing 5 changed files with 217 additions and 59 deletions.
1 change: 1 addition & 0 deletions changelog.md
@@ -14,6 +14,7 @@
 - Added a `context_getter` SpanGetter argument to the `eds.matcher` class to only retrieve entities inside the spans returned by the getter
 - Added a `filter_expr` parameter to scorers to filter the documents to score
 - Added a new `required` field to `eds.contextual_matcher` assign patterns to only match if the required field has been found, and an `include` parameter (similar to `exclude`) to search for required patterns without assigning them to the entity
+- Added context strings (e.g., "words[0:5] | sents[0:1]") to the `eds.contextual_matcher` component to allow for more complex patterns in the selection of the window around the trigger spans

 ### Changed

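As a quick illustration of the changelog entry above: the new context strings plug directly into `eds.contextual_matcher` patterns. A minimal sketch, assuming the `eds.*` factory API of this release; the diabetes pattern itself is a hypothetical example:

```python
import edsnlp
import edsnlp.pipes as eds

nlp = edsnlp.blank("eds")
nlp.add_pipe(eds.sentences())  # sentence boundaries are needed for "sents[...]"
nlp.add_pipe(
    eds.contextual_matcher(
        label="diabetes",
        patterns=dict(
            source="diabetes",
            regex=r"diab[eè]te",
            assign=[
                dict(
                    name="type",
                    regex=r"type\s*(\d)",
                    # New-style window: 5 words after the trigger, or the
                    # trigger's sentence plus the next one (union of the two)
                    window="words[0:5] | sents[0:1]",
                ),
            ],
        ),
    )
)
```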
20 changes: 3 additions & 17 deletions edsnlp/pipes/core/contextual_matcher/contextual_matcher.py
@@ -252,23 +252,15 @@ def filter_one(self, span: Span) -> Span:
         source = span.label_
         to_keep = True
         for exclude in self.patterns[source].exclude:
-            snippet = get_window(
-                doclike=span,
-                window=exclude.window,
-                limit_to_sentence=exclude.limit_to_sentence,
-            )
+            snippet = exclude.window(span)

             if next(exclude.matcher(snippet, as_spans=True), None) is not None:
                 to_keep = False
                 logger.trace(f"Entity {span} was filtered out")
                 break

         for include in self.patterns[source].include:
-            snippet = get_window(
-                doclike=span,
-                window=include.window,
-                limit_to_sentence=include.limit_to_sentence,
-            )
+            snippet = include.window(span)

             if next(include.matcher(snippet, as_spans=True), None) is None:
                 to_keep = False
@@ -308,13 +300,7 @@ def assign_one(self, span: Span) -> Span:
         for assign in self.patterns[source].assign:
             assign: SingleAssignModel
             window = assign.window
-            limit_to_sentence = assign.limit_to_sentence
-
-            snippet = get_window(
-                doclike=span,
-                window=window,
-                limit_to_sentence=limit_to_sentence,
-            )
+            snippet = window(span)

             matcher: RegexMatcher = assign.matcher
             if matcher is not None:
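The effect of this refactor is that a pattern's `window` is now a `Context` callable rather than raw bounds passed to `get_window`. A minimal sketch of the equivalence, assuming the example tokenization noted in the comments:

```python
import edsnlp
from edsnlp.utils.span_getters import SentenceContext, WordContext

nlp = edsnlp.blank("eds")
nlp.add_pipe("eds.sentences")
doc = nlp("Le patient est diabétique. Il est traité par metformine.")
trigger = doc[3:4]  # "diabétique", assuming this tokenization

# Roughly what get_window(doclike=span, window=(-5, 5), limit_to_sentence=True)
# used to compute, now expressed as a composable Context object:
window = WordContext(-5, 5) & SentenceContext(0, 0)
snippet = window(trigger)  # a spaCy Span, clipped to the trigger's sentence
print(snippet.text)
```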
95 changes: 57 additions & 38 deletions edsnlp/pipes/core/contextual_matcher/models.py
@@ -1,37 +1,14 @@
 import re
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union

 import regex
 from pydantic import BaseModel, Extra, validator

 from edsnlp.matchers.utils import ListOrStr
-from edsnlp.utils.span_getters import SpanGetterArg
+from edsnlp.utils.span_getters import Context, SentenceContext, SpanGetterArg
 from edsnlp.utils.typing import AsList

 Flags = Union[re.RegexFlag, int]
-Window = Union[
-    Tuple[int, int],
-    List[int],
-    int,
-]
-
-
-def normalize_window(cls, v):
-    if v is None:
-        return v
-    if isinstance(v, list):
-        assert (
-            len(v) == 2
-        ), "`window` should be a tuple/list of two integer, or a single integer"
-        v = tuple(v)
-    if isinstance(v, int):
-        assert v != 0, "The provided `window` should not be 0"
-        if v < 0:
-            return (v, 0)
-        if v > 0:
-            return (0, v)
-    assert v[0] < v[1], "The provided `window` should contain at least 1 token"
-    return v
-
-
 class AssignDict(dict):
@@ -101,9 +78,10 @@ class SingleExcludeModel(BaseModel):
     ----------
     regex: ListOrStr
         A single Regex or a list of Regexes
-    window: Optional[Window]
+    window: Optional[Context]
         Size of the context to use (in number of words). You can provide the window as:
+        - A [context string][context-string]
         - A positive integer, in this case the used context will be taken **after**
           the extraction
         - A negative integer, in this case the used context will be taken **before**
@@ -121,8 +99,8 @@
"""

regex: ListOrStr = []
window: Optional[Window] = None
limit_to_sentence: Optional[bool] = True
limit_to_sentence: Optional[bool] = None
window: Optional[Context] = None
regex_flags: Optional[Flags] = None
regex_attr: Optional[str] = None
matcher: Optional[Any] = None
@@ -133,7 +111,20 @@ def exclude_regex_validation(cls, v):
             v = [v]
         return v

-    _normalize_window = validator("window", allow_reuse=True)(normalize_window)
+    @validator("limit_to_sentence", pre=True, always=True)
+    def backward_compat_auto_limit_to_sentence(cls, v, values):
+        if (
+            isinstance(values.get("window"), (type(None), int, tuple, list))
+            and v is None
+        ):
+            v = True
+        return v
+
+    @validator("window", always=True)
+    def backward_compat_intersect_sentence(cls, v, values):
+        if values.get("limit_to_sentence"):
+            v = v & SentenceContext(0, 0)
+        return v


class SingleIncludeModel(BaseModel):
@@ -146,9 +137,10 @@
     ----------
     regex: ListOrStr
         A single Regex or a list of Regexes
-    window: Optional[Window]
+    window: Optional[Context]
         Size of the context to use (in number of words). You can provide the window as:
+        - A [context string][context-string]
         - A positive integer, in this case the used context will be taken **after**
           the extraction
         - A negative integer, in this case the used context will be taken **before**
@@ -166,8 +158,8 @@
"""

regex: ListOrStr = []
window: Optional[Window] = None
limit_to_sentence: Optional[bool] = True
limit_to_sentence: Optional[bool] = None
window: Optional[Context] = None
regex_flags: Optional[Flags] = None
regex_attr: Optional[str] = None
matcher: Optional[Any] = None
@@ -178,7 +170,20 @@ def exclude_regex_validation(cls, v):
             v = [v]
         return v

-    _normalize_window = validator("window", allow_reuse=True)(normalize_window)
+    @validator("limit_to_sentence", pre=True, always=True)
+    def backward_compat_auto_limit_to_sentence(cls, v, values):
+        if (
+            isinstance(values.get("window"), (type(None), int, tuple, list))
+            and v is None
+        ):
+            v = True
+        return v
+
+    @validator("window", always=True)
+    def backward_compat_intersect_sentence(cls, v, values):
+        if values.get("limit_to_sentence"):
+            v = v & SentenceContext(0, 0)
+        return v


class ExcludeModel(AsList[SingleExcludeModel]):
@@ -204,9 +209,10 @@ class SingleAssignModel(BaseModel):
     ----------
     name: ListOrStr
         A name (string)
-    window: Optional[Window]
+    window: Optional[Context]
         Size of the context to use (in number of words). You can provide the window as:
+        - A [context string][context-string]
         - A positive integer, in this case the used context will be taken **after**
           the extraction
         - A negative integer, in this case the used context will be taken **before**
@@ -217,7 +223,7 @@
     span_getter: Optional[SpanGetterArg]
         A span getter to pick the assigned spans from already extracted entities
         in the doc.
-    regex: Optional[Window]
+    regex: Optional[Context]
         A dictionary where keys are labels and values are **Regexes with a single
         capturing group**
     replace_entity: Optional[bool]
@@ -235,8 +241,8 @@
     name: str
     regex: Optional[str] = None
     span_getter: Optional[SpanGetterArg] = None
-    window: Optional[Window] = None
-    limit_to_sentence: Optional[bool] = True
+    limit_to_sentence: Optional[bool] = None
+    window: Optional[Context] = None
     regex_flags: Optional[Flags] = None
     regex_attr: Optional[str] = None
     replace_entity: bool = False
@@ -259,7 +265,20 @@ def check_single_regex_group(cls, pat):

         return pat

-    _normalize_window = validator("window", allow_reuse=True)(normalize_window)
+    @validator("limit_to_sentence", pre=True, always=True)
+    def backward_compat_auto_limit_to_sentence(cls, v, values):
+        if (
+            isinstance(values.get("window"), (type(None), int, tuple, list))
+            and v is None
+        ):
+            v = True
+        return v
+
+    @validator("window", always=True)
+    def backward_compat_intersect_sentence(cls, v, values):
+        if values.get("limit_to_sentence"):
+            v = v & SentenceContext(0, 0)
+        return v


class AssignModel(AsList[SingleAssignModel]):
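The two `backward_compat_*` validators keep legacy configurations working: integer and tuple windows are still accepted (normalized by `Context.validate`), and `limit_to_sentence` is now expressed as an intersection with `SentenceContext(0, 0)`. A small sketch of the normalization, derived from the code above:

```python
from edsnlp.utils.span_getters import Context, SentenceContext

print(Context.validate(-5))       # words[-5:0]  (legacy negative int)
print(Context.validate((-5, 5)))  # words[-5:5]  (legacy tuple)
print(Context.validate("words[-5:5] & sents[0:0]"))  # parsed context string

# What limit_to_sentence=True amounts to after validation:
window = Context.validate(-5) & SentenceContext(0, 0)
print(window)  # words[-5:0] & sents[0:0]
```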
151 changes: 151 additions & 0 deletions edsnlp/utils/span_getters.py
@@ -1,3 +1,4 @@
+import abc
 from collections import defaultdict
 from typing import (
     TYPE_CHECKING,
@@ -11,6 +12,7 @@
     Union,
 )

+import numpy as np
 from pydantic import NonNegativeInt
 from spacy.tokens import Doc, Span

@@ -303,3 +305,152 @@ def __call__(self, span: Union[Doc, Span]) -> Union[Span, List[Span]]:
         end = min(len(span.doc), max(end, max_end_sent))

         return span.doc[start:end]


+class ContextMeta(abc.ABCMeta):
+    pass
+
+
+class Context(abc.ABC, metaclass=ContextMeta):
+    @abc.abstractmethod
+    def __call__(self, span: Span) -> Span:
+        pass
+
+    # logical ops
+    def __and__(self, other: "Context"):
+        # fmt: off
+        return IntersectionContext([
+            *(self.contexts if isinstance(self, IntersectionContext) else (self,)),
+            *(other.contexts if isinstance(other, IntersectionContext) else (other,))
+        ])
+        # fmt: on
+
+    def __or__(self, other: "Context"):
+        # fmt: off
+        return UnionContext([
+            *(self.contexts if isinstance(self, UnionContext) else (self,)),
+            *(other.contexts if isinstance(other, UnionContext) else (other,))
+        ])
+        # fmt: on
+
+    @classmethod
+    def parse(cls, query):
+        return eval(
+            query,
+            {"__builtins__": None},
+            {
+                "words": WordContext,
+                "sents": SentenceContext,
+            },
+        )
+
+    @classmethod
+    def validate(cls, obj, config=None):
+        if isinstance(obj, str):
+            return cls.parse(obj)
+        if isinstance(obj, tuple):
+            assert len(obj) == 2
+            return WordContext(*obj)
+        if isinstance(obj, int):
+            assert obj != 0, "The provided `window` should not be 0"
+            return WordContext(obj, 0) if obj < 0 else WordContext(0, obj)
+        raise ValueError(f"Invalid context: {obj}")
+
+    @classmethod
+    def __get_validators__(cls):
+        yield cls.validate
+
+
+class LeafContextMeta(ContextMeta):
+    def __getitem__(cls, item) -> Span:
+        assert isinstance(item, slice)
+        before = item.start
+        after = item.stop
+        return cls(before, after)
+
+
+class LeafContext(Context, metaclass=LeafContextMeta):
+    pass
+
+
+class WordContext(LeafContext):
+    def __init__(
+        self,
+        before: Optional[int] = None,
+        after: Optional[int] = None,
+    ):
+        self.before = before
+        self.after = after
+
+    def __call__(self, span):
+        start = span.start + self.before if self.before is not None else 0
+        end = span.end + self.after if self.after is not None else len(span.doc)
+        return span.doc[max(0, start) : min(len(span.doc), end)]
+
+    def __repr__(self):
+        return "words[{}:{}]".format(self.before, self.after)
+
+
+class SentenceContext(LeafContext):
+    def __init__(
+        self,
+        before: Optional[int] = None,
+        after: Optional[int] = None,
+    ):
+        self.before = before
+        self.after = after
+
+    def __call__(self, span):
+        sent_starts = span.doc.to_array("SENT_START") == 1
+        sent_indices = sent_starts.cumsum()
+        sent_indices = sent_indices - sent_indices[span.start]
+
+        start_idx = end_idx = None
+        if self.before is not None:
+            start = sent_starts & (sent_indices == self.before)
+            x = np.flatnonzero(start)
+            start_idx = x[-1] if len(x) else 0
+
+        if self.after is not None:
+            end = sent_starts & (sent_indices == self.after + 1)
+            x = np.flatnonzero(end)
+            end_idx = x[0] - 1 if len(x) else len(span.doc)
+
+        return span.doc[start_idx:end_idx]
+
+    def __repr__(self):
+        return "sents[{}:{}]".format(self.before, self.after)
+
+
+class UnionContext(Context):
+    def __init__(
+        self,
+        contexts: AsList[Context],
+    ):
+        self.contexts = contexts
+
+    def __call__(self, span):
+        results = [context(span) for context in self.contexts]
+        min_word = min([span.start for span in results])
+        max_word = max([span.end for span in results])
+        return span.doc[min_word:max_word]
+
+    def __repr__(self):
+        return " | ".join(repr(context) for context in self.contexts)
+
+
+class IntersectionContext(Context):
+    def __init__(
+        self,
+        contexts: AsList[Context],
+    ):
+        self.contexts = contexts
+
+    def __call__(self, span):
+        results = [context(span) for context in self.contexts]
+        min_word = max([span.start for span in results])
+        max_word = min([span.end for span in results])
+        return span.doc[min_word:max_word]
+
+    def __repr__(self):
+        return " & ".join(repr(context) for context in self.contexts)