Skip to content

Commit

Permalink
Add support for raster contrib from GCP
Browse files: browse the repository at this point in the history
  • Loading branch information
PSU3D0 committed Oct 7, 2024
1 parent a98bc09 commit 27601d1
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docprompt/schema/pipeline/rasterizer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -300,7 +300,7 @@ def cache_proportion(self, name: str) -> float:
"""Calculate the proportion of the document that is cached."""
lookup_key = f"{name}/" if not name.endswith("/") else name

return len(self.cache.list_prefix(lookup_key)) / len(self.document)
return min(len(self.cache.list_prefix(lookup_key)) / len(self.document), 1.0)

def fully_cached(self, name: str) -> bool:
"""Check if the entire document is cached."""
Expand Down
13 changes: 11 additions & 2 deletions docprompt/tasks/ocr/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,10 +17,19 @@ class BaseOCRProvider(
AbstractPageTaskProvider[Union[PdfDocument, ImageBytes], None, OcrPageResult]
):
def _populate_ocr_results(
self, document_node: "DocumentNode", results: Dict[int, OcrPageResult]
self,
document_node: "DocumentNode",
results: Dict[int, OcrPageResult],
add_images_to_raster_cache: bool = False,
raster_cache_key: str = "default",
) -> None:
for page_number, result in results.items():
result.contribute_to_document_node(document_node, page_number=page_number)
result.contribute_to_document_node(
document_node,
page_number=page_number,
add_images_to_raster_cache=add_images_to_raster_cache,
raster_cache_key=raster_cache_key,
)

@abstractmethod
def process_document_node(
Expand Down
16 changes: 15 additions & 1 deletion docprompt/tasks/ocr/gcp.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -446,6 +446,8 @@ class GoogleOcrProvider(BaseOCRProvider):
max_workers: int = Field(multiprocessing.cpu_count() * 2)
exclude_bounding_poly: bool = Field(False)
return_images: bool = Field(False)
add_images_to_raster_cache: bool = Field(False)
image_raster_cache_key: str = "default"
return_image_quality_scores: bool = Field(False)

_documentai: "documentai.DocumentProcessorServiceClient" = PrivateAttr()
Expand All @@ -456,6 +458,9 @@ def __init__(
processor_id: str,
service_account_info: Optional[Dict[str, str]] = None,
service_account_file: Optional[str] = None,
return_images: bool = False,
add_images_to_raster_cache: bool = False,
image_raster_cache_key: str = "default",
**kwargs,
):
super().__init__(project_id=project_id, processor_id=processor_id, **kwargs)
Expand All @@ -467,6 +472,10 @@ def __init__(
"service_account_file", service_account_file
)

self.return_images = return_images
self.add_images_to_raster_cache = add_images_to_raster_cache
self.image_raster_cache_key = image_raster_cache_key

try:
from google.cloud import documentai

Expand Down Expand Up @@ -649,6 +658,11 @@ def process_document_node(
result = self.invoke([document_node.document], start=start, stop=stop, **kwargs)

# For OCR, we also need to populate the ocr_results for powered search
self._populate_ocr_results(document_node, result)
self._populate_ocr_results(
document_node,
result,
add_images_to_raster_cache=self.add_images_to_raster_cache,
raster_cache_key=self.image_raster_cache_key,
)

return result
18 changes: 17 additions & 1 deletion docprompt/tasks/ocr/result.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,6 +39,8 @@ class OcrPageResult(BasePageResult):

@property
def pil_image(self):
if not self.raster_image:
return None
from PIL import Image

return Image.open(BytesIO(self.raster_image))
Expand All @@ -56,7 +58,12 @@ def blocks(self):
return self.block_level_blocks

def contribute_to_document_node(
self, document_node: DocumentNode, page_number: Optional[int] = None, **kwargs
self,
document_node: DocumentNode,
page_number: Optional[int] = None,
add_images_to_raster_cache: bool = False,
raster_cache_key: str = "default",
**kwargs,
) -> None:
if not page_number:
raise ValueError("Page number must be provided for page level results")
Expand All @@ -66,3 +73,12 @@ def contribute_to_document_node(
page_node.metadata.ocr_results = self
else:
super().contribute_to_document_node(document_node, page_number=page_number)

if self.raster_image is not None and add_images_to_raster_cache:
document_node.rasterizer.cache.set_image_for_page(
key=raster_cache_key,
page_number=page_number,
image_bytes=self.raster_image,
)

self.raster_image = None # We need to clear this for memory reasons

0 comments on commit 27601d1

Please sign in to comment.