Skip to content

Commit

Permalink
Merge pull request #25 from climatepolicyradar/feature/pdct-1112-bugfix-parser-output-to-passage-level-method-not-including
Browse files Browse the repository at this point in the history

Add pdf page metadata to passage level method output
  • Loading branch information
THOR300 authored May 29, 2024
2 parents f3445f4 + 3d90cb9 commit 49936e3
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 34 deletions.
50 changes: 40 additions & 10 deletions src/cpr_sdk/parser_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json
import logging
import logging.config
from collections import Counter
from datetime import date
from enum import Enum
import json
from typing import List, Optional, Sequence, Tuple, TypeVar, Union, Any
from typing import Any, Final, List, Optional, Sequence, Tuple, TypeVar, Union

from cpr_sdk.pipeline_general_models import (
CONTENT_TYPE_HTML,
Expand All @@ -18,10 +18,13 @@

_LOGGER = logging.getLogger(__name__)

PARSER_METADATA_KEY = "parser_metadata"
AZURE_API_VERSION_KEY = "azure_api_version"
AZURE_MODEL_ID_KEY = "azure_model_id"
PARSING_DATE_KEY = "parsing_date"
PARSER_METADATA_KEY: Final = "parser_metadata"
AZURE_API_VERSION_KEY: Final = "azure_api_version"
AZURE_MODEL_ID_KEY: Final = "azure_model_id"
PARSING_DATE_KEY: Final = "parsing_date"
PDF_PAGE_METADATA_KEY: Final = "pdf_data_page_metadata"
PDF_DATA_PASSAGE_LEVEL_EXPAND_FIELDS: Final = {"text_blocks", "page_metadata"}
HTML_DATA_PASSAGE_LEVEL_EXPAND_FIELDS: Final = {"text_blocks"}


class VerticalFlipError(Exception):
Expand Down Expand Up @@ -395,22 +398,30 @@ def to_passage_level_json(self) -> list[dict[str, Any]]:
if self.text_blocks is None:
return []

common_fields_dict = json.loads(
fixed_fields_dict = json.loads(
self.model_dump_json(
exclude={
"pdf_data": {"text_blocks", "page_metadata"},
"html_data": {"text_blocks"},
"pdf_data": PDF_DATA_PASSAGE_LEVEL_EXPAND_FIELDS,
"html_data": HTML_DATA_PASSAGE_LEVEL_EXPAND_FIELDS,
}
)
)

passages_array = [
common_fields_dict
fixed_fields_dict
| json.loads(block.model_dump_json(exclude={"text"}))
| {"text": block.to_string(), "block_index": idx}
for idx, block in enumerate(self.text_blocks)
]

for passage in passages_array:
page_number = passage.get("page_number")
passage[PDF_PAGE_METADATA_KEY] = (
self.get_page_metadata_by_page_number(page_number)
if page_number
else None
)

empty_html_text_block_keys: list[str] = list(HTMLTextBlock.model_fields.keys())
empty_pdf_text_block_keys: list[str] = list(PDFTextBlock.model_fields.keys())

Expand All @@ -425,3 +436,22 @@ def to_passage_level_json(self) -> list[dict[str, Any]]:
passages_array_filled.append(passage)

return passages_array_filled

def get_page_metadata_by_page_number(self, page_number: int) -> Optional[dict]:
    """
    Retrieve metadata for the first PDF page matching the given page number.

    The matching PDFPageMetadata object is round-tripped through
    model_dump_json + json.loads (rather than model_dump) because enums and
    nested pydantic objects persist when using the model_dump method, and we
    don't want those rich objects when we push to huggingface.

    :param page_number: The page number to match against
        self.pdf_data.page_metadata entries.
    :return: A plain-dict representation of the first matching
        PDFPageMetadata entry, or None if there is no pdf_data, no page
        metadata, or no entry with a matching page number.
    """
    if self.pdf_data and self.pdf_data.page_metadata:
        for metadata in self.pdf_data.page_metadata:
            if metadata.page_number == page_number:
                # Serialise to plain JSON-compatible values (no enums /
                # nested pydantic models).
                return json.loads(metadata.model_dump_json())
    return None
2 changes: 1 addition & 1 deletion src/cpr_sdk/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_MAJOR = "1"
_MINOR = "1"
_PATCH = "2"
_PATCH = "3"
_SUFFIX = ""

VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
Expand Down
67 changes: 45 additions & 22 deletions tests/test_parser_models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
import pydantic
import pytest

from cpr_sdk.parser_models import (
HTML_DATA_PASSAGE_LEVEL_EXPAND_FIELDS,
PDF_DATA_PASSAGE_LEVEL_EXPAND_FIELDS,
HTMLData,
HTMLTextBlock,
ParserInput,
ParserOutput,
PDFData,
PDFTextBlock,
VerticalFlipError,
HTMLTextBlock,
TextBlock,
VerticalFlipError,
PDF_PAGE_METADATA_KEY,
)
from cpr_sdk.pipeline_general_models import (
CONTENT_TYPE_HTML,
CONTENT_TYPE_PDF,
BackendDocument,
)
from cpr_sdk.pipeline_general_models import CONTENT_TYPE_HTML, CONTENT_TYPE_PDF


def test_parser_input_object(parser_output_json_pdf) -> None:
Expand Down Expand Up @@ -157,6 +165,24 @@ def test_to_passage_level_json_method(
parser_output_json_html: dict,
) -> None:
"""Test that we can successfully create a passage level array from the text blocks."""
expected_top_level_fields = set(
list(TextBlock.model_fields.keys())
+ list(HTMLTextBlock.model_fields.keys())
+ list(PDFTextBlock.model_fields.keys())
+ list(ParserOutput.model_fields.keys())
+ ["block_index", PDF_PAGE_METADATA_KEY]
)

expected_document_metadata_fields = set(BackendDocument.model_fields.keys())

expected_html_data_fields = set(HTMLData.model_fields.keys())
for field in HTML_DATA_PASSAGE_LEVEL_EXPAND_FIELDS:
expected_html_data_fields.remove(field)

expected_pdf_data_fields = set(PDFData.model_fields.keys())
for field in PDF_DATA_PASSAGE_LEVEL_EXPAND_FIELDS:
expected_pdf_data_fields.remove(field)

parser_output_pdf = ParserOutput.model_validate(parser_output_json_pdf)
passage_level_array_pdf = parser_output_pdf.to_passage_level_json()

Expand All @@ -167,25 +193,22 @@ def test_to_passage_level_json_method(
assert len(passage_level_array_html) == len(parser_output_html.text_blocks)

for passage_level_array in [passage_level_array_pdf, passage_level_array_html]:
assert all(isinstance(passage, dict) for passage in passage_level_array)

first_doc_keys = set(passage_level_array[0].keys())
assert all(
set(passage.keys()) == first_doc_keys for passage in passage_level_array
)

expected_model_fields = set(
list(TextBlock.model_fields.keys())
+ list(HTMLTextBlock.model_fields.keys())
+ list(PDFTextBlock.model_fields.keys())
+ list(ParserOutput.model_fields.keys())
+ ["block_index"]
)

assert all(
set(passage.keys()) == expected_model_fields
for passage in passage_level_array
)
for passage in passage_level_array:
assert isinstance(passage, dict)
assert set(passage.keys()) == first_doc_keys
assert set(passage.keys()) == expected_top_level_fields
assert (
set(passage["document_metadata"].keys())
== expected_document_metadata_fields
)

if passage["document_content_type"] == CONTENT_TYPE_PDF:
assert set(passage["pdf_data"].keys()) == expected_pdf_data_fields
elif passage["document_content_type"] == CONTENT_TYPE_HTML:
assert set(passage["html_data"].keys()) == expected_html_data_fields
else:
raise ValueError("Document content type must be either PDF or HTML")

passage_level_array_pdf_first_doc = passage_level_array_pdf[0]
passage_level_array_html_first_doc = passage_level_array_html[0]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_search_adaptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_vespa_search_adaptor__works(fake_vespa_credentials):
)
@pytest.mark.vespa
def test_vespa_search_adaptor__is_fast_enough(fake_vespa_credentials, params):
MAX_SPEED_MS = 750
MAX_SPEED_MS = 850

avg_ms = profile_search(fake_vespa_credentials, params=params)
assert avg_ms <= MAX_SPEED_MS
Expand Down

0 comments on commit 49936e3

Please sign in to comment.