Skip to content

Commit

Permalink
Fix of the textract document loader. Before refactor the classificati…
Browse files Browse the repository at this point in the history
…on tree
  • Loading branch information
enoch3712 committed Sep 23, 2024
1 parent 01aa385 commit 96ac4e3
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 101 deletions.
78 changes: 34 additions & 44 deletions extract_thinker/document_loader/document_loader_aws_textract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from typing import Any, List, Union
from PIL import Image
import boto3
import pdfium
import pypdfium2 as pdfium

from extract_thinker.document_loader.cached_document_loader import CachedDocumentLoader
from extract_thinker.utils import get_image_type
from extract_thinker.utils import get_file_extension, get_image_type

from cachetools import cachedmethod
from cachetools.keys import hashkey
Expand All @@ -35,23 +35,7 @@ def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, region_na
@classmethod
def from_client(cls, textract_client, content=None, cache_ttl=300):
return cls(textract_client=textract_client, content=content, cache_ttl=cache_ttl)

@cachedmethod(cache=attrgetter('cache'), key=lambda self, file_path: hashkey(file_path))
def load_content_from_file(self, file_path: str) -> Union[dict, object]:
try:
file_type = get_image_type(file_path)
if file_type in SUPPORTED_IMAGE_FORMATS:
with open(file_path, 'rb') as file:
file_bytes = file.read()
if file_type == 'pdf':
return self.process_pdf(file_bytes)
else:
return self.process_image(file_bytes)
else:
raise Exception(f"Unsupported file type: {file_path}")
except Exception as e:
raise Exception(f"Error processing file: {e}") from e


@cachedmethod(cache=attrgetter('cache'), key=lambda self, stream: hashkey(id(stream)))
def load_content_from_stream(self, stream: Union[BytesIO, str]) -> Union[dict, object]:
try:
Expand All @@ -67,43 +51,49 @@ def load_content_from_stream(self, stream: Union[BytesIO, str]) -> Union[dict, o
except Exception as e:
raise Exception(f"Error processing stream: {e}") from e

@cachedmethod(cache=attrgetter('cache'), key=lambda self, file_path: hashkey(file_path))
def load_content_from_file(self, file_path: str) -> Union[dict, object]:
try:
file_type = get_file_extension(file_path)
if file_type == 'pdf':
with open(file_path, 'rb') as file:
file_bytes = file.read()
return self.process_pdf(file_bytes)
elif file_type in SUPPORTED_IMAGE_FORMATS:
with open(file_path, 'rb') as file:
file_bytes = file.read()
return self.process_image(file_bytes)
else:
raise Exception(f"Unsupported file type: {file_path}")
except Exception as e:
raise Exception(f"Error processing file: {e}") from e

def process_pdf(self, pdf_bytes: bytes) -> dict:
for attempt in range(3):
try:
response = self.textract_client.analyze_document(
Document={'Bytes': pdf_bytes},
FeatureTypes=['TABLES']
)
return self._parse_analyze_document_response(response)
except Exception as e:
if attempt == 2:
raise Exception(f"Failed to process PDF after 3 attempts: {e}")
return {}

def process_image(self, image_bytes: bytes) -> dict:
for attempt in range(3):
try:
response = self.textract_client.analyze_document(
Document={'Bytes': image_bytes},
FeatureTypes=['TABLES', 'FORMS', 'LAYOUT']
FeatureTypes=['TABLES'] # Only extract tables
)
return self._parse_analyze_document_response(response)
except Exception as e:
if attempt == 2:
raise Exception(f"Failed to process image after 3 attempts: {e}")
return {}

def process_pdf(self, pdf_bytes: bytes) -> dict:
pdf = pdfium.PdfDocument(pdf_bytes)
result = {
"pages": [],
"tables": [],
"forms": [],
"layout": {}
}
for page_number in range(len(pdf)):
page = pdf.get_page(page_number)
pil_image = page.render().to_pil()
img_byte_arr = BytesIO()
pil_image.save(img_byte_arr, format='PNG')
img_byte_arr = img_byte_arr.getvalue()
page_result = self.process_image(img_byte_arr)
result["pages"].extend(page_result["pages"])
result["tables"].extend(page_result["tables"])
result["forms"].extend(page_result["forms"])
for key, value in page_result["layout"].items():
if key not in result["layout"]:
result["layout"][key] = []
result["layout"][key].extend(value)
return result

def _parse_analyze_document_response(self, response: dict) -> dict:
result = {
"pages": [],
Expand Down
2 changes: 1 addition & 1 deletion extract_thinker/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def classify(self, input: Union[str, IO], classifications: List[Classification],
# Check if the input is a valid file path
if os.path.isfile(input):
file_type = get_file_extension(input)
if file_type in SUPPORTED_IMAGE_FORMATS:
if file_type == 'pdf':
return self.classify_from_path(input, classifications)
elif file_type in SUPPORTED_EXCEL_FORMATS:
return self.classify_from_excel(input, classifications)
Expand Down
17 changes: 13 additions & 4 deletions extract_thinker/models/classification.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
from typing import Any, Optional
from pydantic import BaseModel
from typing import Any, Optional, Type
from pydantic import BaseModel, field_validator
from extract_thinker.models.contract import Contract
import os

class Classification(BaseModel):
name: str
description: str
contract: Optional[Contract] = None
image: Optional[str] = None # Path to the image file
contract: Optional[Type] = None
image: Optional[str] = None
extractor: Optional[Any] = None

@field_validator('contract', mode='before')
def validate_contract(cls, v):
if v is not None:
if not isinstance(v, type):
raise ValueError('contract must be a type')
if not issubclass(v, Contract):
raise ValueError('contract must be a subclass of Contract')
return v

def set_image(self, image_path: str):
if os.path.isfile(image_path):
self.image = image_path
Expand Down
3 changes: 2 additions & 1 deletion extract_thinker/models/classification_node.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List
from typing import List, Optional
from pydantic import BaseModel, Field
from extract_thinker.models.classification import Classification

class ClassificationNode(BaseModel):
name: str
classification: Classification
children: List['ClassificationNode'] = Field(default_factory=list)

Expand Down
83 changes: 61 additions & 22 deletions tests/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
import asyncio
from dotenv import load_dotenv

from extract_thinker.document_loader.document_loader_aws_textract import DocumentLoaderAWSTextract
from extract_thinker.extractor import Extractor
from extract_thinker.models.classification_node import ClassificationNode
from extract_thinker.models.classification_tree import ClassificationTree
from extract_thinker.process import Process, ClassificationStrategy
from extract_thinker.document_loader.document_loader_tesseract import DocumentLoaderTesseract
from extract_thinker.models.classification import Classification
from extract_thinker.models.classification_response import ClassificationResponse
from tests.models.invoice import InvoiceContract
from tests.models.driver_license import DriverLicense
from tests.models.invoice import CreditNoteContract, FinancialContract, InvoiceContract
from tests.models.driver_license import DriverLicense, IdentificationContract

# Setup environment and common paths
load_dotenv()
Expand Down Expand Up @@ -120,6 +121,21 @@ def test_with_contract():
assert result.name == "Invoice"


def setup_process_with_textract_extractor():
"""Sets up and returns a process configured with only the Textract extractor."""
# Initialize the Textract document loader
document_loader = DocumentLoaderAWSTextract()

# Initialize the Textract extractor
textract_extractor = Extractor(document_loader)
textract_extractor.load_llm("gpt-4o")

# Create the process with only the Textract extractor
process = Process()
process.add_classify_extractor([[textract_extractor]])

return process

def setup_process_with_gpt4_extractor():
"""Sets up and returns a process configured with only the GPT-4 extractor."""
tesseract_path = os.getenv("TESSERACT_PATH")
Expand Down Expand Up @@ -156,37 +172,60 @@ def test_with_tree():
"""Test classification using the tree strategy"""
process = setup_process_with_gpt4_extractor()

# Create Classification Nodes
financial_docs = ClassificationNode(
classification=Classification(name="Financial Documents", description="Documents related to finances")
)
invoice = ClassificationNode(
name="Financial Documents",
classification=Classification(
name="Invoice",
description="This is an invoice",
contract=InvoiceContract,
image=INVOICE_FILE_PATH
)
)
credit_note = ClassificationNode(
classification=Classification(name="Credit Note", description="This is a credit note")
name="Financial Documents",
description="This is a financial document",
contract=FinancialContract,
),
children=[
ClassificationNode(
name="Invoice",
classification=Classification(
name="Invoice",
description="This is an invoice",
contract=InvoiceContract,
)
),
ClassificationNode(
name="Credit Note",
classification=Classification(
name="Credit Note",
description="This is a credit note",
contract=CreditNoteContract,
)
)
]
)
financial_docs.children = [invoice, credit_note]

legal_docs = ClassificationNode(
classification=Classification(name="Legal Documents", description="Documents related to legal matters")
)
contract = ClassificationNode(
classification=Classification(name="Contract", description="This is a contract")
name="Identity Documents",
classification=Classification(
name="Identity Documents",
description="This is an identity document",
contract=IdentificationContract,
),
children=[
ClassificationNode(
name="Driver License",
classification=Classification(
name="Driver License",
description="This is a driver license",
contract=DriverLicense,
)
)
]
)
legal_docs.children = [contract]

# Create the classification tree
classification_tree = ClassificationTree(
nodes=[financial_docs, legal_docs]
)

result = process.classify(INVOICE_FILE_PATH, classification_tree, threshold=0.8)
current_dir = os.path.dirname(os.path.abspath(__file__))
pdf_path = os.path.join(current_dir, 'files','invoice.pdf')

result = process.classify(pdf_path, classification_tree, threshold=0.8)

assert result is not None
assert isinstance(result, Classification)
Expand Down
37 changes: 11 additions & 26 deletions tests/document_loader_aws_textract.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,21 @@
import os
import pytest
from moto import mock_textract
import boto3
from dotenv import load_dotenv

from extract_thinker.document_loader.document_loader_aws_textract import DocumentLoaderAWSTextract

load_dotenv()

@pytest.fixture
def aws_credentials():
"""Mocked AWS Credentials for moto."""
os.environ['AWS_ACCESS_KEY_ID'] = 'testing'
os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing'
os.environ['AWS_SECURITY_TOKEN'] = 'testing'
os.environ['AWS_SESSION_TOKEN'] = 'testing'
os.environ['AWS_DEFAULT_REGION'] = 'us-east-1'

@pytest.fixture
def aws_credentials():
"""Mocked AWS Credentials for moto."""
return {
'aws_access_key_id': os.getenv('AWS_ACCESS_KEY_ID'),
'aws_secret_access_key': os.getenv('AWS_SECRET_ACCESS_KEY'),
'region_name': os.getenv('AWS_DEFAULT_REGION')
}


def test_load_content_from_pdf(textract_client):
def test_load_content_from_pdf():
# Arrange
loader = DocumentLoaderAWSTextract.from_client(textract_client)
loader = DocumentLoaderAWSTextract(
aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
region_name=os.getenv('AWS_DEFAULT_REGION')
)

current_dir = os.path.dirname(os.path.abspath(__file__))
pdf_path = os.path.join(current_dir, 'test_files', 'sample.pdf')
pdf_path = os.path.join(current_dir, 'files','invoice.pdf')

# Act
result = loader.load_content_from_file(pdf_path)
Expand All @@ -43,5 +27,6 @@ def test_load_content_from_pdf(textract_client):
assert "forms" in result
assert "layout" in result
assert len(result["pages"]) > 0

# You may want to add more specific

if __name__ == "__main__":
test_load_content_from_pdf()
Binary file added tests/files/invoice.pdf
Binary file not shown.
6 changes: 5 additions & 1 deletion tests/models/driver_license.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from extract_thinker.models.contract import Contract

class IdentificationContract(Contract):
name: str
age: int
id_number: str

class DriverLicense(Contract):
name: str
age: int
license_number: str
license_number: str
13 changes: 11 additions & 2 deletions tests/models/invoice.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
from typing import List
from extract_thinker.models.contract import Contract


class LinesContract(Contract):
description: str
quantity: int
unit_price: int
amount: int


class InvoiceContract(Contract):
invoice_number: str
invoice_date: str
lines: List[LinesContract]
total_amount: int

class CreditNoteContract(Contract):
credit_note_number: str
credit_note_date: str
lines: List[LinesContract]
total_amount: int

class FinancialContract(Contract):
total_amount: int
document_number: str
document_date: str

0 comments on commit 96ac4e3

Please sign in to comment.