Skip to content

Commit

Permalink
Improve search. Update groups
Browse files Browse the repository at this point in the history
  • Loading branch information
PSU3D0 committed May 10, 2024
1 parent 14bac96 commit 1cfb695
Show file tree
Hide file tree
Showing 9 changed files with 1,616 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
sudo apt update && sudo apt install -y ghostscript
python -m pip install --upgrade pip
pip install -U pdm
pdm install -d
pdm install -G:all
- name: Lint with Ruff
run: |
pdm run ruff --output-format=github .
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/preview.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
sudo apt update && sudo apt install -y ghostscript
python -m pip install --upgrade pip
pip install pdm
pdm install -d
pdm install -G:all
- name: Run Pytest
run:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pdm
pdm install -G:all
- name: Build wheels and source tarball
run: >-
Expand Down
1 change: 1 addition & 0 deletions docprompt/provenance/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def search(
self,
query: str,
page_number: Optional[int] = None,
*,
refine_to_word: bool = True,
require_exact_match: bool = True,
) -> List[ProvenanceSource]:
Expand Down
13 changes: 7 additions & 6 deletions docprompt/provenance/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,18 @@ class ProvenanceSource(BaseModel):

document_name: str
page_number: PositiveInt
text_location: PageTextLocation
text_location: Optional[PageTextLocation] = None

@computed_field # type: ignore
@property
def source_block(self) -> Optional[TextBlock]:
if self.text_location.merged_source_block:
return self.text_location.merged_source_block
if self.text_location.source_blocks:
return self.text_location.source_blocks[0]
if self.text_location:
if self.text_location.merged_source_block:
return self.text_location.merged_source_block
if self.text_location.source_blocks:
return self.text_location.source_blocks[0]

return None
return None

@property
def text(self) -> str:
Expand Down
37 changes: 10 additions & 27 deletions docprompt/provenance/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from docprompt.schema.layout import NormBBox, TextBlock
from typing import Any, Iterable, List, Optional
from rapidfuzz import fuzz
from rapidfuzz.utils import default_process

try:
import tantivy
Expand Down Expand Up @@ -37,28 +38,6 @@ def preprocess_query_text(text: str) -> str:
return text


def fuzzify_token(
token: str,
lowercase: bool = True,
strip_punct: bool = True,
strip_whitespace: bool = True,
) -> str:
"""
Convert a token into a form that is suitable for fuzzy matching.
"""
if strip_whitespace:
token = token.strip()

if strip_punct:
token = token.strip(".")
token = token.strip(",")

if lowercase:
token = token.lower()

return token


def word_tokenize(text: str) -> List[str]:
"""
Tokenize a string into words.
Expand Down Expand Up @@ -124,22 +103,26 @@ def refine_block_to_word_level(
tokenized_query = word_tokenize(query)

if len(tokenized_query) == 1:
fuzzified = fuzzify_token(tokenized_query[0])
fuzzified = default_process(tokenized_query[0])
for word_level_block in intersecting_word_level_blocks:
if fuzz.ratio(fuzzified, fuzzify_token(word_level_block.text)) > 87.5:
if fuzz.ratio(fuzzified, default_process(word_level_block.text)) > 87.5:
return word_level_block, [word_level_block]
else:
fuzzified_word_level_texts = [
default_process(word_level_block.text)
for word_level_block in intersecting_word_level_blocks
]

# Populate the block mapping
token_block_mapping = defaultdict(set)

first_word = tokenized_query[0]
last_word = tokenized_query[-1]

for token in tokenized_query:
fuzzified_token = fuzzify_token(token)
fuzzified_token = default_process(token)
for i, word_level_block in enumerate(intersecting_word_level_blocks):
fuzzified_block_text = fuzzify_token(word_level_block.text)
if fuzz.ratio(fuzzified_token, fuzzified_block_text) > 87.5:
if fuzz.ratio(fuzzified_token, fuzzified_word_level_texts[i]) > 87.5:
token_block_mapping[token].add(i)

graph = networkx.DiGraph()
Expand Down
1,583 changes: 1,576 additions & 7 deletions pdm.lock

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,9 @@ license = {text = "Apache-2.0"}
classifiers = ["Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Natural Language :: English", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12"]

[project.optional-dependencies]
test = ["isort<6.0.0,>=5.12.0", "flake8<7.0.0,>=6.1.0", "flake8-docstrings<2.0.0,>=1.7.0", "mypy<2.0.0,>=1.6.1", "pytest<8.0.0,>=7.4.2", "pytest-cov<5.0.0,>=4.1.0", "ruff<1.0.0,>=0.3.3"]
dev = ["tox<4.0.0,>=3.20.1", "virtualenv<21.0.0,>=20.2.2", "pip<21.0.0,>=20.3.1", "twine<4.0.0,>=3.3.0", "pre-commit<3.0.0,>=2.12.0", "toml<1.0.0,>=0.10.2", "bump2version<2.0.0,>=1.0.1"]
doc = ["mkdocs<2.0.0,>=1.1.2", "mkdocs-include-markdown-plugin<2.0.0,>=1.0.0", "mkdocs-material<7.0.0,>=6.1.7", "mkdocstrings<1.0.0,>=0.15.2", "mkdocs-autorefs<1.0.0,>=0.2.1"]
google = ["google-cloud-documentai>=2.20.0"]
azure = ["azure-ai-formrecognizer>=3.3.0"]
search = ["tantivy<1.0.0,>=0.21.0", "rtree<2.0.0,>=1.2.0", "networkx>=2.8.8"]
search = ["tantivy<1.0.0,>=0.21.0", "rtree<2.0.0,>=1.2.0", "networkx<3.2,>=2.8.8"]

[project.scripts]
docprompt = "docprompt.cli:main"
Expand Down Expand Up @@ -79,6 +76,10 @@ distribution = true
[tool.pdm.build]
includes = ["docprompt", "tests"]

[tool.pdm.dev-dependencies]
test = ["isort<6.0.0,>=5.12.0", "flake8<7.0.0,>=6.1.0", "flake8-docstrings<2.0.0,>=1.7.0", "mypy<2.0.0,>=1.6.1", "pytest<8.0.0,>=7.4.2", "pytest-cov<5.0.0,>=4.1.0", "ruff<1.0.0,>=0.3.3"]
dev = ["tox<4.0.0,>=3.20.1", "virtualenv<21.0.0,>=20.2.2", "pip<21.0.0,>=20.3.1", "twine<4.0.0,>=3.3.0", "pre-commit<3.0.0,>=2.12.0", "toml<1.0.0,>=0.10.2", "bump2version<2.0.0,>=1.0.1"]

[tool.ruff]
target-version = "py38"

Expand Down
14 changes: 14 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ def test_search():

assert len(result_multiple_words) == 1

sources = result_multiple_words[0].text_location.source_blocks

assert len(sources) == 2

result_multiple_words = locator.search(
"MMAX2 system", page_number=1, refine_to_word=False
)

assert len(result_multiple_words) == 1

sources = result_multiple_words[0].text_location.source_blocks

assert len(sources) == 1

n_best = locator.search_n_best("and", n=3)

assert len(n_best) == 3
Expand Down

0 comments on commit 1cfb695

Please sign in to comment.