From e56da8d2d706a86183c3857f9a7ab4e6e5885c93 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?J=C3=BAlio=20Almeida?=
Date: Thu, 15 Aug 2024 14:35:07 +0100
Subject: [PATCH] AWS textract working with testing

---
 .../document_loader_aws_textract.py           | 303 ++++++++++++++++++
 poetry.lock                                   | 109 +++++++++-
 pyproject.toml                                |   1 +
 tests/document_loader_aws_textract.py         |  47 ++++
 4 files changed, 455 insertions(+), 5 deletions(-)
 create mode 100644 extract_thinker/document_loader/document_loader_aws_textract.py
 create mode 100644 tests/document_loader_aws_textract.py

diff --git a/extract_thinker/document_loader/document_loader_aws_textract.py b/extract_thinker/document_loader/document_loader_aws_textract.py
new file mode 100644
index 0000000..537c032
--- /dev/null
+++ b/extract_thinker/document_loader/document_loader_aws_textract.py
@@ -0,0 +1,303 @@
+import asyncio
+import hashlib
+from io import BytesIO
+from operator import attrgetter
+import os
+import threading
+from typing import Any, List, Union
+from PIL import Image
+import boto3
+# NOTE(review): PdfDocument / get_page / render().to_pil() below look like the
+# pypdfium2 API, which is usually imported as `import pypdfium2 as pdfium` --
+# confirm that a top-level `pdfium` module actually resolves here.
+import pdfium
+
+from extract_thinker.document_loader.cached_document_loader import CachedDocumentLoader
+from extract_thinker.utils import get_image_type
+
+from cachetools import cachedmethod
+from cachetools.keys import hashkey
+from queue import Queue
+
+SUPPORTED_IMAGE_FORMATS = ["jpeg", "png", "pdf"]
+
+# How many times analyze_document is attempted before the error is propagated.
+MAX_TEXTRACT_ATTEMPTS = 3
+
+
+def _stream_cache_key(self, stream):
+    """Return the cache key used by ``load_content_from_stream``.
+
+    In-memory content is keyed on a SHA-256 digest instead of ``id(stream)``:
+    an id can be reused by a different object once the first one has been
+    garbage collected, which would return another document's cached result.
+    """
+    if isinstance(stream, BytesIO):
+        data = stream.getvalue()
+    elif isinstance(stream, (bytes, bytearray)):
+        data = stream
+    else:
+        return hashkey('stream', stream)
+    return hashkey('stream', hashlib.sha256(data).hexdigest())
+
+
+class DocumentLoaderAWSTextract(CachedDocumentLoader):
+    """Document loader backed by the synchronous AWS Textract ``analyze_document`` call.
+
+    JPEG/PNG input is sent to Textract as-is; PDFs are rendered to one PNG per
+    page and the per-page results are merged into a single result dict with the
+    keys ``pages``, ``tables``, ``forms`` and ``layout``.
+    """
+
+    def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, region_name=None, textract_client=None, content=None, cache_ttl=300):
+        """Create a loader from a ready-made client or from explicit credentials.
+
+        Args:
+            aws_access_key_id: AWS access key, used when no client is passed.
+            aws_secret_access_key: AWS secret key, used when no client is passed.
+            region_name: AWS region, used when no client is passed.
+            textract_client: Existing Textract client; takes precedence over the credentials.
+            content: Forwarded to ``CachedDocumentLoader.__init__``.
+            cache_ttl: Forwarded to ``CachedDocumentLoader.__init__``.
+
+        Raises:
+            ValueError: If neither a client nor all three credential values are given.
+        """
+        super().__init__(content, cache_ttl)
+        if textract_client:
+            self.textract_client = textract_client
+        elif aws_access_key_id and aws_secret_access_key and region_name:
+            self.textract_client = boto3.client(
+                'textract',
+                aws_access_key_id=aws_access_key_id,
+                aws_secret_access_key=aws_secret_access_key,
+                region_name=region_name
+            )
+        else:
+            raise ValueError("Either provide a textract_client or aws credentials (access key, secret key, and region).")
+
+    @classmethod
+    def from_client(cls, textract_client, content=None, cache_ttl=300):
+        """Build a loader around an already configured Textract client."""
+        return cls(textract_client=textract_client, content=content, cache_ttl=cache_ttl)
+
+    @cachedmethod(cache=attrgetter('cache'), key=lambda self, file_path: hashkey(file_path))
+    def load_content_from_file(self, file_path: str) -> Union[dict, object]:
+        """Analyze the image or PDF at ``file_path``; results are cached per path.
+
+        Raises:
+            Exception: If the type is unsupported or processing fails (cause chained).
+        """
+        try:
+            file_type = get_image_type(file_path)
+            if file_type not in SUPPORTED_IMAGE_FORMATS:
+                raise Exception(f"Unsupported file type: {file_path}")
+            with open(file_path, 'rb') as file:
+                file_bytes = file.read()
+            return self._process_bytes(file_type, file_bytes)
+        except Exception as e:
+            raise Exception(f"Error processing file: {e}") from e
+
+    @cachedmethod(cache=attrgetter('cache'), key=_stream_cache_key)
+    def load_content_from_stream(self, stream: Union[BytesIO, str]) -> Union[dict, object]:
+        """Analyze an in-memory image or PDF; results are cached per content.
+
+        Raises:
+            Exception: If the type is unsupported or processing fails (cause chained).
+        """
+        try:
+            file_type = get_image_type(stream)
+            if file_type not in SUPPORTED_IMAGE_FORMATS:
+                raise Exception(f"Unsupported stream type: {stream}")
+            file_bytes = stream.getvalue() if isinstance(stream, BytesIO) else stream
+            return self._process_bytes(file_type, file_bytes)
+        except Exception as e:
+            raise Exception(f"Error processing stream: {e}") from e
+
+    def _process_bytes(self, file_type: str, file_bytes: bytes) -> dict:
+        """Dispatch raw content to the PDF or the image pipeline."""
+        if file_type == 'pdf':
+            return self.process_pdf(file_bytes)
+        return self.process_image(file_bytes)
+
+    def process_image(self, image_bytes: bytes) -> dict:
+        """Send one image to Textract and return the parsed result.
+
+        Only the Textract call is retried; a failure while parsing the response
+        is deterministic, so it is raised right away instead of paying for the
+        same request again.
+
+        Raises:
+            Exception: If every attempt fails; the last error is chained.
+        """
+        last_error = None
+        for _ in range(MAX_TEXTRACT_ATTEMPTS):
+            try:
+                response = self.textract_client.analyze_document(
+                    Document={'Bytes': image_bytes},
+                    FeatureTypes=['TABLES', 'FORMS', 'LAYOUT']
+                )
+            except Exception as e:
+                last_error = e
+                continue
+            return self._parse_analyze_document_response(response)
+        raise Exception(
+            f"Failed to process image after {MAX_TEXTRACT_ATTEMPTS} attempts: {last_error}"
+        ) from last_error
+
+    def process_pdf(self, pdf_bytes: bytes) -> dict:
+        """Render every PDF page to PNG, analyze it and merge the per-page results."""
+        pdf = pdfium.PdfDocument(pdf_bytes)
+        result = self._empty_result()
+        for page_number in range(len(pdf)):
+            page = pdf.get_page(page_number)
+            pil_image = page.render().to_pil()
+            buffer = BytesIO()
+            pil_image.save(buffer, format='PNG')
+            page_result = self.process_image(buffer.getvalue())
+            for section in ("pages", "tables", "forms"):
+                result[section].extend(page_result[section])
+            for key, value in page_result["layout"].items():
+                result["layout"].setdefault(key, []).extend(value)
+        return result
+
+    @staticmethod
+    def _empty_result() -> dict:
+        """Return a fresh, empty result structure."""
+        return {"pages": [], "tables": [], "forms": [], "layout": {}}
+
+    @staticmethod
+    def _index_blocks(blocks) -> dict:
+        """Map block ids to blocks; an already built index is returned unchanged."""
+        if isinstance(blocks, dict):
+            return blocks
+        return {block['Id']: block for block in blocks}
+
+    @staticmethod
+    def _related_ids(block, relationship_type) -> list:
+        """Collect the ids of all ``relationship_type`` relationships of ``block``."""
+        ids = []
+        for relationship in block.get('Relationships') or []:
+            if relationship.get('Type') == relationship_type:
+                ids.extend(relationship.get('Ids', []))
+        return ids
+
+    def _child_text(self, block, blocks_by_id) -> str:
+        """Join the text of ``block``'s CHILD blocks in relationship order.
+
+        Children without ``Text`` (for example SELECTION_ELEMENT) are skipped.
+        """
+        children = (blocks_by_id.get(child_id) for child_id in self._related_ids(block, 'CHILD'))
+        return ' '.join(child['Text'] for child in children if child and 'Text' in child)
+
+    def _parse_analyze_document_response(self, response: dict) -> dict:
+        """Convert a raw ``analyze_document`` response into the loader's result dict."""
+        blocks = response['Blocks']
+        blocks_by_id = self._index_blocks(blocks)
+        result = self._empty_result()
+        current_page = {"paragraphs": [], "lines": [], "words": []}
+
+        for block in blocks:
+            block_type = block['BlockType']
+            if block_type == 'PAGE':
+                # A PAGE block starts a new page; flush the previous one if it has content.
+                if any(current_page.values()):
+                    result["pages"].append(current_page)
+                current_page = {"paragraphs": [], "lines": [], "words": []}
+            elif block_type == 'LINE':
+                current_page["lines"].append(block['Text'])
+            elif block_type == 'WORD':
+                current_page["words"].append(block['Text'])
+            elif block_type == 'TABLE':
+                result["tables"].append(self._parse_table(block, blocks_by_id))
+            elif block_type == 'KEY_VALUE_SET':
+                if 'KEY' in (block.get('EntityTypes') or []):
+                    # KEY_VALUE_SET blocks carry no Text of their own; the key text
+                    # lives in the block's CHILD words.
+                    key = block.get('Text') or self._child_text(block, blocks_by_id)
+                    value = self._find_value_for_key(block, blocks_by_id)
+                    result["forms"].append({"key": key, "value": value})
+            elif block_type in ['CELL', 'SELECTION_ELEMENT']:
+                self._add_to_layout(result["layout"], block)
+
+        if any(current_page.values()):
+            result["pages"].append(current_page)
+
+        return result
+
+    def _parse_table(self, table_block, blocks):
+        """Return the table as a row-major list of lists of cell text.
+
+        ``blocks`` may be the response's block list or an id -> block index.
+        A table without cells yields an empty list.
+        """
+        blocks_by_id = self._index_blocks(blocks)
+        # Only CHILD relationships point at cells; a table can also carry other
+        # relationship types (e.g. merged cells or titles), in any order.
+        children = (blocks_by_id.get(cell_id) for cell_id in self._related_ids(table_block, 'CHILD'))
+        cells = [child for child in children if child and child['BlockType'] == 'CELL']
+        if not cells:
+            return []
+
+        rows = max(cell['RowIndex'] for cell in cells)
+        cols = max(cell['ColumnIndex'] for cell in cells)
+        table = [['' for _ in range(cols)] for _ in range(rows)]
+        for cell in cells:
+            table[cell['RowIndex'] - 1][cell['ColumnIndex'] - 1] = self._child_text(cell, blocks_by_id)
+        return table
+
+    def _find_value_for_key(self, key_block, blocks):
+        """Return the text of the VALUE block linked to ``key_block``, or ``''``.
+
+        ``blocks`` may be the response's block list or an id -> block index.
+        """
+        blocks_by_id = self._index_blocks(blocks)
+        for value_id in self._related_ids(key_block, 'VALUE'):
+            value_block = blocks_by_id.get(value_id)
+            if value_block is not None:
+                return self._child_text(value_block, blocks_by_id)
+        return ''
+
+    def _add_to_layout(self, layout, block):
+        """Append a summary of ``block`` to ``layout`` under its block type."""
+        layout_item = {
+            'id': block['Id'],
+            'text': block.get('Text', ''),
+            'confidence': block['Confidence'],
+            'geometry': block['Geometry']
+        }
+        # Optional fields are copied only when Textract reported them.
+        for source_key, target_key in (
+            ('RowIndex', 'row_index'),
+            ('ColumnIndex', 'column_index'),
+            ('SelectionStatus', 'selection_status'),
+        ):
+            if source_key in block:
+                layout_item[target_key] = block[source_key]
+        layout.setdefault(block['BlockType'], []).append(layout_item)
+
+    def load_content_from_stream_list(self, stream: BytesIO) -> List[Any]:
+        """Analyze every image that ``convert_to_images`` produces for ``stream``."""
+        images = self.convert_to_images(stream)
+        return self._process_images(images)
+
+    def load_content_from_file_list(self, input: List[Union[str, BytesIO]]) -> List[Any]:
+        """Analyze every image that ``convert_to_images`` produces for ``input``."""
+        images = self.convert_to_images(input)
+        return self._process_images(images)
+
+    def _process_images(self, images: dict) -> List[Any]:
+        """Analyze every image and pair it with its Textract content.
+
+        This is a plain method on purpose: ``process_image`` is synchronous, so
+        gathering its results with asyncio failed, and the public ``*_list``
+        loaders returned an un-awaited coroutine instead of a list.
+        """
+        contents = []
+        for image in images.values():
+            content = self.process_image(image)
+            contents.append({
+                "image": Image.open(BytesIO(image)) if isinstance(image, bytes) else image,
+                "content": content
+            })
+        return contents
\ No newline at end of file
diff --git a/poetry.lock b/poetry.lock
index 52fa16c..ab877a3 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "aiohttp"
@@ -238,8 +238,8 @@ pathspec = ">=0.9.0,<1"
 platformdirs = ">=2"
 tomli = ">=0.2.6,<2.0.0"
 typing-extensions = [
-    {version = ">=3.10.0.0,<3.10.0.1 || >3.10.0.1", markers = "python_version >= \"3.10\""},
     {version = ">=3.10.0.0", markers = "python_version < \"3.10\""},
+    {version = ">=3.10.0.0,<3.10.0.1 || >3.10.0.1", markers = "python_version >= \"3.10\""},
 ]
 
 [package.extras]
@@ -249,6 +249,47 @@ jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
 python2 = ["typed-ast (>=1.4.3)"]
 uvloop = ["uvloop (>=0.15.2)"]
 
+[[package]]
+name = "boto3"
+version = "1.34.161"
+description = "The AWS SDK for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "boto3-1.34.161-py3-none-any.whl", hash = "sha256:4ef285334a0edc3047e27a04caf00f7742e32c0f03a361101e768014ac5709dd"},
+    {file = "boto3-1.34.161.tar.gz", hash = "sha256:a872d8fdb3203c1eb0b12fa9e9d879e6f7fd02983a485f02189e6d5914ccd834"},
+]
+
+[package.dependencies]
+botocore = ">=1.34.161,<1.35.0"
+jmespath = ">=0.7.1,<2.0.0"
+s3transfer = ">=0.10.0,<0.11.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
+
+[[package]]
+name = "botocore"
+version = "1.34.161"
+description = "Low-level, data-driven core of boto 3."
+optional = false +python-versions = ">=3.8" +files = [ + {file = "botocore-1.34.161-py3-none-any.whl", hash = "sha256:6c606d2da6f62fde06880aff1190566af208875c29938b6b68741e607817975a"}, + {file = "botocore-1.34.161.tar.gz", hash = "sha256:16381bfb786142099abf170ce734b95a402a3a7f8e4016358712ac333c5568b2"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, + {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.21.2)"] + [[package]] name = "cachetools" version = "5.3.3" @@ -614,12 +655,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" @@ -1025,6 +1066,17 @@ files = [ {file = "jiter-0.4.2.tar.gz", hash = "sha256:29b9d44f23f0c05f46d482f4ebf03213ee290d77999525d0975a17f875bf1eea"}, ] +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = false +python-versions = 
">=3.7" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + [[package]] name = "litellm" version = "1.40.8" @@ -1935,6 +1987,20 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, +] + +[package.dependencies] +six = ">=1.5" + [[package]] name = "python-docx" version = "1.1.2" @@ -2183,6 +2249,23 @@ files = [ [package.dependencies] pyasn1 = ">=0.1.3" +[[package]] +name = "s3transfer" +version = "0.10.2" +description = "An Amazon S3 Transfer Manager" +optional = false +python-versions = ">=3.8" +files = [ + {file = "s3transfer-0.10.2-py3-none-any.whl", hash = "sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"}, + {file = "s3transfer-0.10.2.tar.gz", hash = "sha256:0711534e9356d3cc692fdde846b4a1e4b0cb6519971860796e6bc4c7aea00ef6"}, +] + +[package.dependencies] +botocore = ">=1.33.2,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] + [[package]] name = "shellingham" version = "1.5.4" @@ -2459,6 +2542,22 @@ files = [ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +[[package]] +name = "urllib3" 
+version = "1.26.19"
+description = "HTTP library with thread-safe connection pooling, file post, and more."
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
+files = [
+    {file = "urllib3-1.26.19-py2.py3-none-any.whl", hash = "sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3"},
+    {file = "urllib3-1.26.19.tar.gz", hash = "sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429"},
+]
+
+[package.extras]
+brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
+secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
+socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
+
 [[package]]
 name = "urllib3"
 version = "2.2.1"
@@ -2613,4 +2712,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "5fd0f5a498a864d7a694f02f8e87a5ebdbec85d7c00ce054f81d33984abe88f8"
+content-hash = "80be233c997072eb67a8bba585fe04e4d97c968504bf7e165e58fd0a62f1a224"
diff --git a/pyproject.toml b/pyproject.toml
index b015f8b..9e0215f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ openpyxl = "^3.1.2"
 pypdf2 = "^3.0.1"
 azure-ai-formrecognizer = "^3.3.3"
 google-cloud-documentai = "^2.29.1"
+boto3 = "^1.34.161"
 
 [tool.poetry.dev-dependencies]
 flake8 = "^3.9.2"
diff --git a/tests/document_loader_aws_textract.py b/tests/document_loader_aws_textract.py
new file mode 100644
index 0000000..4ce11ed
--- /dev/null
+++ b/tests/document_loader_aws_textract.py
@@ -0,0 +1,47 @@
+import os
+import pytest
+# NOTE(review): mock_textract is unused in this module and moto 5 replaced the
+# per-service decorators with mock_aws -- confirm this import is still wanted.
+from moto import mock_textract
+import boto3
+from dotenv import load_dotenv
+
+from extract_thinker.document_loader.document_loader_aws_textract import DocumentLoaderAWSTextract
+
+load_dotenv()
+
+@pytest.fixture
+def aws_credentials():
+    """AWS credentials taken from the environment (populated from .env above)."""
+    return {
+        'aws_access_key_id': os.getenv('AWS_ACCESS_KEY_ID'),
+        'aws_secret_access_key': os.getenv('AWS_SECRET_ACCESS_KEY'),
+        'region_name': os.getenv('AWS_DEFAULT_REGION')
+    }
+
+@pytest.fixture
+def textract_client(aws_credentials):
+    """Textract client for the tests; skips them when credentials are missing."""
+    if not all(aws_credentials.values()):
+        pytest.skip("AWS credentials are not configured")
+    return boto3.client('textract', **aws_credentials)
+
+
+def test_load_content_from_pdf(textract_client):
+    # Arrange
+    loader = DocumentLoaderAWSTextract.from_client(textract_client)
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    pdf_path = os.path.join(current_dir, 'test_files', 'sample.pdf')
+
+    # Act
+    result = loader.load_content_from_file(pdf_path)
+
+    # Assert
+    assert isinstance(result, dict)
+    assert "pages" in result
+    assert "tables" in result
+    assert "forms" in result
+    assert "layout" in result
+    assert len(result["pages"]) > 0
+
+    # TODO: add more specific assertions on the extracted content.
\ No newline at end of file