Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Classification fix and tests running properly. Refactor of DocumentLo… #35

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions extract_thinker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .document_loader.document_loader_spreadsheet import DocumentLoaderSpreadSheet
from .document_loader.document_loader_azure_document_intelligence import DocumentLoaderAzureForm
from .document_loader.document_loader_pypdf import DocumentLoaderPyPdf
from .document_loader.document_loader_text import DocumentLoaderText
from .models import classification, classification_response
from .process import Process, ClassificationStrategy
from .splitter import Splitter
Expand All @@ -24,7 +23,6 @@
'DocumentLoaderSpreadSheet',
'DocumentLoaderAzureForm',
'DocumentLoaderPyPdf',
'DocumentLoaderText',
'classification',
'classification_response',
'Process',
Expand Down
2 changes: 1 addition & 1 deletion extract_thinker/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
# extractor.loadfile("C:\\Users\\Lopez\\Desktop\\MagniFinance\\examples\\outputTestOne.pdf").split(classifications)

extractor.load_document_loader(
DocumentLoaderTesseract("C:\\Program Files\\Tesseract-OCR\\tesseract.exe")
DocumentLoaderTesseract(os.getenv("TESSERACT_PATH"))
)
extractor.load_llm("claude-3-haiku-20240307")

Expand Down
29 changes: 28 additions & 1 deletion extract_thinker/document_loader/document_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,33 @@
import concurrent.futures
from typing import Any, Dict, List, Union
from cachetools import TTLCache

import os
from extract_thinker.utils import get_file_extension

class DocumentLoader(ABC):
def __init__(self, content: Any = None, cache_ttl: int = 300):
self.content = content
self.file_path = None
self.cache = TTLCache(maxsize=100, ttl=cache_ttl)

def can_handle(self, source: Union[str, BytesIO]) -> bool:
    """Report whether this loader supports ``source``.

    A string is treated as a file path: it must name an existing file, and
    the value returned by ``get_file_extension`` is used as the detected
    type. A ``BytesIO`` is sniffed by opening it as an image and taking the
    detected image format. Any other kind of input is rejected.

    Args:
        source: File path or in-memory stream to probe.

    Returns:
        True when the detected type appears, case-insensitively, in
        ``self.SUPPORTED_FORMATS``; False otherwise, including whenever
        detection raises.
    """
    try:
        if isinstance(source, str):
            if not os.path.isfile(source):
                return False
            file_type = get_file_extension(source)
        elif isinstance(source, BytesIO):
            source.seek(0)
            try:
                file_type = Image.open(source).format
            finally:
                # Rewind even when sniffing fails, so the next loader that
                # probes or reads this stream starts from the beginning.
                source.seek(0)
        else:
            return False

        if file_type is None:
            # The type could not be determined, so nothing can match.
            return False

        # NOTE(review): assumes SUPPORTED_FORMATS is reachable as an
        # attribute of the loader; if it is not, the lookup raises and the
        # handler below turns that into False - confirm on the subclasses.
        supported = {fmt.lower() for fmt in self.SUPPORTED_FORMATS}
        return file_type.lower() in supported
    except Exception:
        # Deliberate best-effort probe: an unreadable or unrecognised source
        # means "this loader cannot handle it", not an error for the caller.
        return False

@abstractmethod
def load_content_from_file(self, file_path: str) -> Union[str, object]:
pass
Expand All @@ -22,6 +41,14 @@ def load_content_from_file(self, file_path: str) -> Union[str, object]:
def load_content_from_stream(self, stream: BytesIO) -> Union[str, object]:
pass

def load(self, source: Union[str, BytesIO]) -> Any:
    """Load content from a file path or an in-memory stream.

    Args:
        source: A ``str`` is passed to ``load_content_from_file``; a
            ``BytesIO`` is passed to ``load_content_from_stream``.

    Returns:
        Whatever the selected ``load_content_from_*`` method returns.

    Raises:
        ValueError: If ``source`` is neither a ``str`` nor a ``BytesIO``.
    """
    if isinstance(source, str):
        return self.load_content_from_file(source)
    if isinstance(source, BytesIO):
        return self.load_content_from_stream(source)
    raise ValueError("Source must be a file path or a stream.")

def getContent(self) -> Any:
    """Return the object currently stored in ``self.content``."""
    return self.content

Expand Down
4 changes: 3 additions & 1 deletion extract_thinker/document_loader/document_loader_pypdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from typing import Any, Dict, List, Union
from PyPDF2 import PdfReader
from extract_thinker.document_loader.document_loader_llm_image import DocumentLoaderLLMImage
from extract_thinker.utils import get_file_extension

SUPPORTED_FORMATS = ['pdf']

class DocumentLoaderPyPdf(DocumentLoaderLLMImage):
def __init__(self, content: Any = None, cache_ttl: int = 300):
Expand Down Expand Up @@ -38,4 +40,4 @@ def extract_data_from_pdf(self, reader: PdfReader) -> Union[str, Dict[str, Any]]
# if image_data:
# document_data["images"].append(image_data)

return document_data
return document_data
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
from extract_thinker.document_loader.cached_document_loader import CachedDocumentLoader
from cachetools import cachedmethod
from cachetools.keys import hashkey
from extract_thinker.utils import get_file_extension

SUPPORTED_FORMATS = ['xls', 'xlsx', 'xlsm', 'xlsb', 'odf', 'ods', 'odt', 'csv']

class DocumentLoaderSpreadSheet(CachedDocumentLoader):

def __init__(self, content=None, cache_ttl=300):
super().__init__(content, cache_ttl)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,4 @@ def load_content_from_file_list(self, input: List[Union[str, BytesIO]]) -> List[
image, content = output_queue.get()
contents.append({"image": Image.open(image), "content": content})

return contents
return contents
24 changes: 0 additions & 24 deletions extract_thinker/document_loader/document_loader_text.py

This file was deleted.

11 changes: 0 additions & 11 deletions extract_thinker/document_loader/text_extract_loader.py

This file was deleted.

46 changes: 17 additions & 29 deletions extract_thinker/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,12 @@

from extract_thinker.utils import get_file_extension, encode_image, json_to_formatted_string
import yaml
import litellm

SUPPORTED_IMAGE_FORMATS = ["jpeg", "png", "bmp", "tiff"]
SUPPORTED_EXCEL_FORMATS = ['.xls', '.xlsx', '.xlsm', '.xlsb', '.odf', '.ods', '.odt', '.csv']


class Extractor:
def __init__(
self, processor: Optional[DocumentLoader] = None, llm: Optional[LLM] = None
self, document_loader: Optional[DocumentLoader] = None, llm: Optional[LLM] = None
):
self.document_loader: Optional[DocumentLoader] = processor
self.document_loader: Optional[DocumentLoader] = document_loader
self.llm: Optional[LLM] = llm
self.file: Optional[str] = None
self.document_loaders_by_file_type: Dict[str, DocumentLoader] = {}
Expand All @@ -47,10 +42,14 @@ def add_interceptor(
"Interceptor must be an instance of LoaderInterceptor or LlmInterceptor"
)

def set_document_loader_for_file_type(
self, file_type: str, document_loader: DocumentLoader
):
self.document_loaders_by_file_type[file_type] = document_loader
def get_document_loader_for_file(self, source: Union[str, IO]) -> "DocumentLoader":
    """Pick the first configured loader that reports it can handle ``source``.

    ``self.document_loader`` is consulted first; if it is unset or declines,
    the loaders in ``self.document_loaders_by_file_type`` are tried in the
    dictionary's iteration order.

    Args:
        source: File path or stream that is passed to each loader's
            ``can_handle``.

    Returns:
        The first loader whose ``can_handle(source)`` is truthy.

    Raises:
        ValueError: If no loader accepts ``source``.
    """
    default_loader = self.document_loader
    if default_loader and default_loader.can_handle(source):
        return default_loader

    for loader in self.document_loaders_by_file_type.values():
        if loader.can_handle(source):
            return loader

    raise ValueError("No suitable document loader found for the input.")

def get_document_loader_for_file(self, file: str) -> DocumentLoader:
_, ext = os.path.splitext(file)
Expand Down Expand Up @@ -229,23 +228,12 @@ def classify(self, input: Union[str, IO], classifications: List[Classification],
if image:
return self.classify_from_image(input, classifications)

if isinstance(input, str):
# Check if the input is a valid file path
if os.path.isfile(input):
file_type = get_file_extension(input)
if file_type == 'pdf':
return self.classify_from_path(input, classifications)
elif file_type in SUPPORTED_EXCEL_FORMATS:
return self.classify_from_excel(input, classifications)
else:
raise ValueError(f"Unsupported file type: {input}")
else:
raise ValueError(f"No such file: {input}")
elif hasattr(input, 'read'):
# Check if the input is a stream (like a file object)
return self.classify_from_stream(input, classifications)
else:
raise ValueError("Input must be a file path or a stream.")
document_loader = self.get_document_loader_for_file(input)
if document_loader is None:
raise ValueError("No suitable document loader found for the input.")

content = document_loader.load(input)
return self._classify(content, classifications)

async def classify_async(self, input: Union[str, IO], classifications: List[Classification]):
return await asyncio.to_thread(self.classify, input, classifications)
Expand All @@ -256,7 +244,7 @@ def _extract(self,
response_model,
vision=False,
is_stream=False
):
):
# call all the llm interceptors before calling the llm
for interceptor in self.llm_interceptors:
interceptor.intercept(self.llm)
Expand Down
22 changes: 16 additions & 6 deletions extract_thinker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,19 @@ def is_pdf_stream(stream: Union[BytesIO, str]) -> bool:
# logger.error(f"Error checking if stream is PDF: {e}")
return False

def get_image_type(source):
    """Return the lowercase image format of ``source``, or None.

    Args:
        source: A file path (``str``) or a ``BytesIO`` holding image data.
            A stream is read from position 0 and rewound afterwards.

    Returns:
        The detected image format in lowercase, or None when ``source`` is
        of another type, cannot be opened as an image, or has no detectable
        format.
    """
    try:
        if isinstance(source, str):
            img = Image.open(source)
        elif isinstance(source, BytesIO):
            source.seek(0)
            try:
                img = Image.open(source)
            finally:
                # Rewind on failure too, so callers can still read the stream.
                source.seek(0)
        else:
            return None
    except IOError:
        return None

    # The format attribute can be unset; report that as "unknown" instead of
    # failing on None.lower().
    image_format = img.format
    return image_format.lower() if image_format else None

def verify_json(json_content: str):
try:
Expand Down Expand Up @@ -134,9 +141,12 @@ def extract_json(text):


def get_file_extension(file_path):
    """Return the lowercase extension of ``file_path`` without the leading dot.

    Args:
        file_path: Path to inspect.

    Returns:
        The text after the last dot of the final path component, lowercased
        (``""`` when there is no extension), or None when ``file_path`` is
        not a ``str``.
    """
    if not isinstance(file_path, str):
        return None
    _, ext = os.path.splitext(file_path)
    return ext[1:].lower()  # drop the leading dot


def json_to_formatted_string(data):
Expand Down
2 changes: 1 addition & 1 deletion medium_posts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from utils import remove_json_format

# local path to tesseract
pytesseract.pytesseract.tesseract_cmd = 'C:\\Program Files\\Tesseract-OCR\\tesseract.exe'
pytesseract.pytesseract.tesseract_cmd = os.getenv("TESSERACT_PATH")
# docker path to tesseract
#os.environ.get('TESSERACT_PATH', 'tesseract')

Expand Down
77 changes: 39 additions & 38 deletions tests/test_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,44 @@ def arrange_process_with_extractors():
return process


def setup_process_with_textract_extractor():
    """Sets up and returns a process configured with only the Textract extractor."""
    # Initialize the Textract document loader
    document_loader = DocumentLoaderAWSTextract()

    # Initialize the Textract extractor and attach the LLM it classifies with
    textract_extractor = Extractor(document_loader)
    textract_extractor.load_llm("gpt-4o")

    # Create the process with only the Textract extractor
    # (passed as a nested list holding that single extractor)
    process = Process()
    process.add_classify_extractor([[textract_extractor]])

    return process


def setup_process_with_gpt4_extractor():
    """Sets up and returns a process configured with only the GPT-4 extractor.

    The Tesseract location is read from the ``TESSERACT_PATH`` environment
    variable.

    Raises:
        ValueError: If ``TESSERACT_PATH`` is unset or empty.
    """
    tesseract_path = os.getenv("TESSERACT_PATH")
    if not tesseract_path:
        # Fail early with a clear message rather than passing None to the loader.
        raise ValueError("TESSERACT_PATH environment variable is not set")
    print(f"Tesseract path: {tesseract_path}")
    document_loader = DocumentLoaderTesseract(tesseract_path)

    # Initialize the GPT-4 extractor
    gpt_4_extractor = Extractor(document_loader)
    gpt_4_extractor.load_llm("gpt-4o")

    # Create the process with only the GPT-4 extractor
    process = Process()
    process.add_classify_extractor([[gpt_4_extractor]])

    return process


def test_classify_feature():
"""Test classification using a single feature."""
extractor = setup_extractors()[1] # Using the second configured extractor
extractor = setup_extractors()[1]
result = extractor.classify(INVOICE_FILE_PATH, COMMON_CLASSIFICATIONS)

assert result is not None
Expand Down Expand Up @@ -100,7 +135,7 @@ def test_classify_higher_order():
def test_classify_both():
"""Test classification using both consensus and higher order strategies with a threshold."""
process = arrange_process_with_extractors()
result = process.classify(INVOICE_FILE_PATH, COMMON_CLASSIFICATIONS, strategy=ClassificationStrategy.BOTH, threshold=9)
result = process.classify(INVOICE_FILE_PATH, COMMON_CLASSIFICATIONS, strategy=ClassificationStrategy.CONSENSUS_WITH_THRESHOLD, threshold=9)

assert result is not None
assert isinstance(result, ClassificationResponse)
Expand All @@ -121,37 +156,6 @@ def test_with_contract():
assert result.name == "Invoice"


def setup_process_with_textract_extractor():
"""Sets up and returns a process configured with only the Textract extractor."""
# Initialize the Textract document loader
document_loader = DocumentLoaderAWSTextract()

# Initialize the Textract extractor
textract_extractor = Extractor(document_loader)
textract_extractor.load_llm("gpt-4o")

# Create the process with only the Textract extractor
process = Process()
process.add_classify_extractor([[textract_extractor]])

return process

def setup_process_with_gpt4_extractor():
"""Sets up and returns a process configured with only the GPT-4 extractor."""
tesseract_path = os.getenv("TESSERACT_PATH")
document_loader = DocumentLoaderTesseract(tesseract_path)

# Initialize the GPT-4 extractor
gpt_4_extractor = Extractor(document_loader)
gpt_4_extractor.load_llm("gpt-4o")

# Create the process with only the GPT-4 extractor
process = Process()
process.add_classify_extractor([[gpt_4_extractor]])

return process


def test_with_image():
"""Test classification using both consensus and higher order strategies with a threshold."""
process = setup_process_with_gpt4_extractor()
Expand All @@ -168,6 +172,7 @@ def test_with_image():
assert isinstance(result, ClassificationResponse)
assert result.name == "Invoice"


def test_with_tree():
"""Test classification using the tree strategy"""
process = setup_process_with_gpt4_extractor()
Expand Down Expand Up @@ -228,8 +233,4 @@ def test_with_tree():
result = process.classify(pdf_path, classification_tree, threshold=0.8)

assert result is not None
assert result.name == "Invoice"


if __name__ == "__main__":
test_classify_feature()
assert result.name == "Invoice"
Loading