Classification complete. Tesseract can now receive PDFs
enoch3712 committed Sep 24, 2024
1 parent 96ac4e3 commit 32815c2
Showing 6 changed files with 128 additions and 40 deletions.
17 changes: 8 additions & 9 deletions extract_thinker/document_loader/document_loader_aws_textract.py
@@ -9,7 +9,7 @@
import pypdfium2 as pdfium

from extract_thinker.document_loader.cached_document_loader import CachedDocumentLoader
from extract_thinker.utils import get_file_extension, get_image_type
from extract_thinker.utils import get_file_extension, get_image_type, is_pdf_stream

from cachetools import cachedmethod
from cachetools.keys import hashkey
@@ -35,19 +35,18 @@ def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, region_na
@classmethod
def from_client(cls, textract_client, content=None, cache_ttl=300):
return cls(textract_client=textract_client, content=content, cache_ttl=cache_ttl)

@cachedmethod(cache=attrgetter('cache'), key=lambda self, stream: hashkey(id(stream)))
def load_content_from_stream(self, stream: Union[BytesIO, str]) -> Union[dict, object]:
try:
file_type = get_image_type(stream)
if file_type in SUPPORTED_IMAGE_FORMATS:
if is_pdf_stream(stream):
file_bytes = stream.getvalue() if isinstance(stream, BytesIO) else stream
return self.process_pdf(file_bytes)
elif get_image_type(stream) in SUPPORTED_IMAGE_FORMATS:
file_bytes = stream.getvalue() if isinstance(stream, BytesIO) else stream
if file_type == 'pdf':
return self.process_pdf(file_bytes)
else:
return self.process_image(file_bytes)
return self.process_image(file_bytes)
else:
raise Exception(f"Unsupported stream type: {stream}")
raise Exception(f"Unsupported stream type: {get_file_extension(stream) if isinstance(stream, str) else 'unknown'}")
except Exception as e:
raise Exception(f"Error processing stream: {e}") from e

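
A minimal usage sketch of the reordered stream check, assuming the loader class is named DocumentLoaderAWSTextract and that a boto3 Textract client is available (both assumptions; neither name is confirmed by this diff):

import boto3
from io import BytesIO
from extract_thinker.document_loader.document_loader_aws_textract import DocumentLoaderAWSTextract  # assumed class name

# is_pdf_stream() now runs before the image-type probe, so PDF bytes route straight to process_pdf().
loader = DocumentLoaderAWSTextract.from_client(boto3.client("textract", region_name="us-east-1"))

with open("statement.pdf", "rb") as f:  # placeholder file name
    result = loader.load_content_from_stream(BytesIO(f.read()))
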
38 changes: 37 additions & 1 deletion extract_thinker/document_loader/document_loader_tesseract.py
@@ -7,7 +7,7 @@
import pytesseract

from extract_thinker.document_loader.cached_document_loader import CachedDocumentLoader
from extract_thinker.utils import get_image_type
from extract_thinker.utils import get_file_extension, get_image_type, is_pdf_stream

from cachetools import cachedmethod
from cachetools.keys import hashkey
@@ -29,6 +29,11 @@ def __init__(self, tesseract_cmd, isContainer=False, content=None, cache_ttl=300
@cachedmethod(cache=attrgetter('cache'), key=lambda self, file_path: hashkey(file_path))
def load_content_from_file(self, file_path: str) -> Union[str, object]:
try:
file_type = get_file_extension(file_path)

if is_pdf_stream(file_path):
with open(file_path, 'rb') as file:
return self.process_pdf(file)
file_type = get_image_type(file_path)
if file_type in SUPPORTED_IMAGE_FORMATS:
image = Image.open(file_path)
@@ -43,6 +48,8 @@ def load_content_from_file(self, file_path: str) -> Union[str, object]:
@cachedmethod(cache=attrgetter('cache'), key=lambda self, stream: hashkey(id(stream)))
def load_content_from_stream(self, stream: Union[BytesIO, str]) -> Union[str, object]:
try:
if is_pdf_stream(stream):
return self.process_pdf(stream)
file_type = get_image_type(stream)
if file_type in SUPPORTED_IMAGE_FORMATS:
image = Image.open(stream)
@@ -53,6 +60,35 @@ def load_content_from_stream(self, stream: Union[BytesIO, str]) -> Union[str, object]:
raise Exception(f"Unsupported stream type: {stream}")
except Exception as e:
raise Exception(f"Error processing stream: {e}") from e

def process_pdf(self, stream: BytesIO) -> str:
"""
Processes a PDF by converting its pages to images and extracting text from each image.
Args:
stream (BytesIO): The PDF file as a BytesIO stream.
Returns:
str: The extracted text from all pages.
"""
try:
# Reset stream position
stream.seek(0)
# Copy the stream into a fresh BytesIO so page-to-image conversion works on an independent buffer
file = BytesIO(stream.read())
images = self.convert_to_images(file)
extracted_text = []

for page_number, image_bytes in images.items():
image = BytesIO(image_bytes[0])
text = self.process_image(image)
extracted_text.append(text)

# Combine text from all pages
self.content = "\n".join(extracted_text)
return self.content
except Exception as e:
raise Exception(f"Error processing PDF: {e}") from e

def process_image(self, image: BytesIO) -> str:
for attempt in range(3):
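
A usage sketch of the new Tesseract PDF path, assuming the loader class is named DocumentLoaderTesseract and that a local tesseract binary exists at the given path (assumptions; only the constructor signature and method names appear in this diff):

from io import BytesIO
from extract_thinker.document_loader.document_loader_tesseract import DocumentLoaderTesseract  # assumed class name

loader = DocumentLoaderTesseract(tesseract_cmd="/usr/bin/tesseract")  # assumed binary location

# File path: is_pdf_stream() reads the '%PDF-' header, so the new PDF branch is taken.
text = loader.load_content_from_file("invoice.pdf")  # placeholder file

# In-memory stream: process_pdf() rasterizes each page, OCRs it, and joins the pages with newlines.
with open("invoice.pdf", "rb") as f:
    text_from_stream = loader.load_content_from_stream(BytesIO(f.read()))
print(text_from_stream[:200])
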
2 changes: 1 addition & 1 deletion extract_thinker/models/classification_response.py
@@ -4,7 +4,7 @@

class ClassificationResponse(BaseModel):
name: str
confidence: Optional[int] = Field("From 1 to 10. 10 being the highest confidence. Always integer", ge=1, le=10)
confidence: int = Field("From 1 to 10. 10 being the highest confidence. Always integer", ge=1, le=10)

def __hash__(self):
return hash((self.name))
69 changes: 43 additions & 26 deletions extract_thinker/process.py
@@ -101,34 +101,51 @@ async def _classify_tree_async(
threshold: float,
image: bool
) -> Optional[Classification]:
async def classify_node(node: ClassificationNode) -> tuple[Classification, float]:
classifications = [node.classification] + [child.classification for child in node.children]
results = await asyncio.gather(*(
self._classify_async(self.extractor_groups[0][0], file, [classification], image)
for classification in classifications
))

best_result = max(results, key=lambda x: x.confidence)

if not node.children or best_result.name == node.classification.name:
return best_result, best_result.confidence

for child in node.children:
if child.classification.name == best_result.name:
return await classify_node(child)

raise ValueError("Inconsistent classification tree structure")

"""
Perform classification in a hierarchical, level-by-level approach.
"""
best_classification = None
best_confidence = -1

for root_node in classification_tree.nodes:
classification, confidence = await classify_node(root_node)
if confidence > best_confidence:
best_classification = classification
best_confidence = confidence
current_nodes = classification_tree.nodes

while current_nodes:
# Get the list of classifications at the current level
classifications = [node.classification for node in current_nodes]

# Classify among the current level's classifications
classification = await self._classify_async(
extractor=self.extractor_groups[0][0],
file=file,
classifications=classifications,
image=image
)

if classification.confidence < threshold:
raise ValueError(
f"Classification confidence {classification.confidence} "
f"for '{classification.classification}' is below the threshold of {threshold}."
)

best_classification = classification

matching_node = next(
(
node for node in current_nodes
if node.classification.name == best_classification.name
),
None
)

if matching_node is None:
raise ValueError(
f"No matching node found for classification '{classification.classification}'."
)

if matching_node.children:
current_nodes = matching_node.children
else:
break

return best_classification if best_confidence >= threshold else None
return best_classification

async def classify_extractor(self, session, extractor, file):
return await session.run(extractor.classify, file)
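
To make the new level-by-level tree walk in _classify_tree_async concrete, here is a standalone, synchronous sketch of the same control flow; Node and Result are simplified stand-ins, not the library's ClassificationTree/ClassificationNode API, and classify_level is a toy placeholder for _classify_async:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:
    name: str
    children: List["Node"] = field(default_factory=list)

@dataclass
class Result:
    name: str
    confidence: float

def classify_level(names: List[str]) -> Result:
    # Toy stand-in for _classify_async: prefer a name containing 'Invoice', else the first option.
    best = next((n for n in names if "Invoice" in n), names[0])
    return Result(name=best, confidence=0.9)

def classify_tree(roots: List[Node], threshold: float) -> Result:
    current, best = roots, None
    while current:
        result = classify_level([node.name for node in current])
        if result.confidence < threshold:
            raise ValueError(f"Confidence {result.confidence} is below the threshold of {threshold}")
        best = result
        match = next(node for node in current if node.name == result.name)
        current = match.children  # descend one level; an empty list ends the walk at a leaf
    return best

tree = [Node("Financial", [Node("Invoice"), Node("Receipt")]), Node("Legal")]
print(classify_tree(tree, threshold=0.8))  # Result(name='Invoice', confidence=0.9)
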
41 changes: 39 additions & 2 deletions extract_thinker/utils.py
@@ -7,12 +7,50 @@
from pydantic import BaseModel
import typing
import os

from io import BytesIO
from typing import Union

def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")

def is_pdf_stream(stream: Union[BytesIO, str]) -> bool:
"""
Checks if the provided stream is a PDF.
Args:
stream (Union[BytesIO, str]): The stream to check. It can be a BytesIO object or a file path as a string.
Returns:
bool: True if the stream is a PDF, False otherwise.
"""
try:
if isinstance(stream, BytesIO):
# Save the current position
current_position = stream.tell()
# Move to the start of the stream
stream.seek(0)
# Read the first 5 bytes to check the PDF signature
header = stream.read(5)
# Restore the original position
stream.seek(current_position)
elif isinstance(stream, str):
if os.path.isfile(stream):
with open(stream, 'rb') as file:
header = file.read(5)
else:
# If it's not a file path, assume it's a bytes string
header = stream.encode()[:5] if isinstance(stream, str) else b''
else:
# Unsupported type
return False

# PDF files start with '%PDF-'
return header == b'%PDF-'
except Exception as e:
# Optional: Log the exception if logging is set up
# logger.error(f"Error checking if stream is PDF: {e}")
return False

def get_image_type(image_path):
try:
@@ -21,7 +59,6 @@ def get_image_type(image_path):
except IOError as e:
return f"An error occurred: {str(e)}"


def verify_json(json_content: str):
try:
data = json.loads(json_content)
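
A quick sanity check of the new is_pdf_stream helper, sketched with placeholder data:

from io import BytesIO
from extract_thinker.utils import is_pdf_stream

# A BytesIO that starts with the PDF signature is detected without disturbing the caller's read position.
buf = BytesIO(b"%PDF-1.7\n% rest of the document")
buf.read(2)                    # move the cursor away from the start
assert is_pdf_stream(buf) is True
assert buf.tell() == 2         # the original position is restored

# Strings that are not existing file paths fall back to a byte-prefix check.
assert is_pdf_stream("definitely not a pdf") is False
assert is_pdf_stream(BytesIO(b"\x89PNG\r\n")) is False
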
1 change: 0 additions & 1 deletion tests/classify.py
@@ -228,7 +228,6 @@ def test_with_tree():
result = process.classify(pdf_path, classification_tree, threshold=0.8)

assert result is not None
assert isinstance(result, Classification)
assert result.name == "Invoice"

if __name__ == "__main__":
