diff --git a/docprompt/tasks/table_extraction/anthropic.py b/docprompt/tasks/table_extraction/anthropic.py index 7c44d7c..07f1dcf 100644 --- a/docprompt/tasks/table_extraction/anthropic.py +++ b/docprompt/tasks/table_extraction/anthropic.py @@ -98,7 +98,7 @@ def parse_response(response: str, **kwargs) -> TableExtractionPageResult: return result -async def _prepare_messages( +def _prepare_messages( document_images: Iterable[bytes], start: Optional[int] = None, stop: Optional[int] = None, @@ -130,7 +130,7 @@ class AnthropicTableExtractionProvider(BaseTableExtractionProvider): async def _ainvoke( self, input: Iterable[bytes], config: Optional[None] = None ) -> List[TableExtractionPageResult]: - messages = await _prepare_messages(input) + messages = _prepare_messages(input) completions = await inference.run_batch_inference_anthropic(messages) diff --git a/tests/tasks/table_extraction/__init__.py b/tests/tasks/table_extraction/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/tasks/table_extraction/test_anthropic.py b/tests/tasks/table_extraction/test_anthropic.py new file mode 100644 index 0000000..ee77dec --- /dev/null +++ b/tests/tasks/table_extraction/test_anthropic.py @@ -0,0 +1,180 @@ +""" +Test the Anthropic implementation of the table extraction task. +""" + +from unittest.mock import patch + +import pytest +from bs4 import BeautifulSoup + +from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage +from docprompt.tasks.table_extraction.anthropic import ( + AnthropicTableExtractionProvider, + _headers_from_tree, + _prepare_messages, + _rows_from_tree, + _title_from_tree, + parse_response, +) +from docprompt.tasks.table_extraction.schema import ( + TableCell, + TableExtractionPageResult, + TableHeader, + TableRow, +) + + +@pytest.fixture +def mock_image_bytes(): + return b"mock_image_bytes" + + +class TestAnthropicTableExtractionProvider: + @pytest.fixture + def provider(self): + return AnthropicTableExtractionProvider() + + def test_provider_name(self, provider): + assert provider.name == "anthropic" + + @pytest.mark.asyncio + async def test_ainvoke(self, provider, mock_image_bytes): + mock_completions = [ + "Test Table
Col1
Data1
", + "
Col2
Data2
", + ] + + with ( + patch( + "docprompt.tasks.table_extraction.anthropic._prepare_messages" + ) as mock_prepare, + patch( + "docprompt.utils.inference.run_batch_inference_anthropic" + ) as mock_inference, + ): + mock_prepare.return_value = "mock_messages" + mock_inference.return_value = mock_completions + + result = await provider._ainvoke([mock_image_bytes, mock_image_bytes]) + + assert len(result) == 2 + assert all(isinstance(r, TableExtractionPageResult) for r in result) + assert result[0].tables[0].title == "Test Table" + assert result[1].tables[0].title is None + assert all(r.provider_name == "anthropic" for r in result) + + mock_prepare.assert_called_once_with([mock_image_bytes, mock_image_bytes]) + mock_inference.assert_called_once_with("mock_messages") + + +def test_prepare_messages(mock_image_bytes): + messages = _prepare_messages([mock_image_bytes]) + + assert len(messages) == 1 + assert len(messages[0]) == 1 + assert isinstance(messages[0][0], OpenAIMessage) + assert messages[0][0].role == "user" + assert len(messages[0][0].content) == 2 + assert isinstance(messages[0][0].content[0], OpenAIComplexContent) + assert messages[0][0].content[0].type == "image_url" + assert isinstance(messages[0][0].content[0].image_url, OpenAIImageURL) + assert messages[0][0].content[0].image_url.url == mock_image_bytes.decode() + assert isinstance(messages[0][0].content[1], OpenAIComplexContent) + assert messages[0][0].content[1].type == "text" + assert ( + "Identify and extract all tables from the document" + in messages[0][0].content[1].text + ) + + +def test_parse_response(): + response = """ + + Test Table + +
Col1
+
Col2
+
+ + + Data1 + Data2 + + +
+ """ + result = parse_response(response) + + assert isinstance(result, TableExtractionPageResult) + assert len(result.tables) == 1 + assert result.tables[0].title == "Test Table" + assert len(result.tables[0].headers) == 2 + assert result.tables[0].headers[0].text == "Col1" + assert len(result.tables[0].rows) == 1 + assert result.tables[0].rows[0].cells[0].text == "Data1" + + +def test_title_from_tree(): + soup = BeautifulSoup("Test Title
", "xml") + assert _title_from_tree(soup.table) == "Test Title" + + soup = BeautifulSoup("
", "xml") + assert _title_from_tree(soup.table) is None + + +def test_headers_from_tree(): + soup = BeautifulSoup( + "
Col1
Col2
", + "xml", + ) + headers = _headers_from_tree(soup.table) + assert len(headers) == 2 + assert all(isinstance(h, TableHeader) for h in headers) + assert headers[0].text == "Col1" + + soup = BeautifulSoup("
", "xml") + assert _headers_from_tree(soup.table) == [] + + +def test_rows_from_tree(): + soup = BeautifulSoup( + "Data1Data2
", + "xml", + ) + rows = _rows_from_tree(soup.table) + assert len(rows) == 1 + assert isinstance(rows[0], TableRow) + assert len(rows[0].cells) == 2 + assert all(isinstance(c, TableCell) for c in rows[0].cells) + assert rows[0].cells[0].text == "Data1" + + soup = BeautifulSoup("
", "xml") + assert _rows_from_tree(soup.table) == [] + + +@pytest.mark.parametrize( + "input_str,sub_str,expected", + [ + ("abcdef
ghijkl
", "", [3, 24]), + ("notables", "
", []), + ("
", "
", [0, 7, 14]), + ], +) +def test_find_start_indices(input_str, sub_str, expected): + from docprompt.tasks.table_extraction.anthropic import _find_start_indices + + assert _find_start_indices(input_str, sub_str) == expected + + +@pytest.mark.parametrize( + "input_str,sub_str,expected", + [ + ("abc
defghi", "", [11, 22]), + ("notables", "", []), + ("", "", [8, 16, 24]), + ], +) +def test_find_end_indices(input_str, sub_str, expected): + from docprompt.tasks.table_extraction.anthropic import _find_end_indices + + assert _find_end_indices(input_str, sub_str) == expected diff --git a/tests/tasks/table_extraction/test_base.py b/tests/tasks/table_extraction/test_base.py new file mode 100644 index 0000000..5b280e9 --- /dev/null +++ b/tests/tasks/table_extraction/test_base.py @@ -0,0 +1,224 @@ +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from docprompt import DocumentNode +from docprompt.schema.layout import NormBBox +from docprompt.schema.pipeline.node.page import PageNode +from docprompt.tasks.table_extraction.base import BaseTableExtractionProvider +from docprompt.tasks.table_extraction.schema import ( + ExtractedTable, + TableCell, + TableExtractionPageResult, + TableHeader, + TableRow, +) + + +class TestExtractedTable: + def test_to_markdown_string_with_title(self): + table = ExtractedTable( + title="Sample Table", + headers=[TableHeader(text="Header 1"), TableHeader(text="Header 2")], + rows=[ + TableRow( + cells=[ + TableCell(text="Row 1, Col 1"), + TableCell(text="Row 1, Col 2"), + ] + ), + TableRow( + cells=[ + TableCell(text="Row 2, Col 1"), + TableCell(text="Row 2, Col 2"), + ] + ), + ], + ) + expected_markdown = ( + "# Sample Table\n\n" + "|Header 1|Header 2|\n" + "|---|---|\n" + "|Row 1, Col 1|Row 1, Col 2|\n" + "|Row 2, Col 1|Row 2, Col 2|" + ) + + assert table.to_markdown_string() == expected_markdown + + def test_to_markdown_string_without_title(self): + table = ExtractedTable( + headers=[TableHeader(text="Header 1"), TableHeader(text="Header 2")], + rows=[ + TableRow( + cells=[ + TableCell(text="Row 1, Col 1"), + TableCell(text="Row 1, Col 2"), + ] + ), + ], + ) + expected_markdown = ( + "|Header 1|Header 2|\n" "|---|---|\n" "|Row 1, Col 1|Row 1, Col 2|" + ) + assert table.to_markdown_string() == expected_markdown + + def test_extracted_table_with_bbox(self): + bbox = NormBBox(x0=0.1, top=0.1, x1=0.9, bottom=0.9) + table = ExtractedTable(bbox=bbox) + assert table.bbox == bbox + + +class TestTableExtractionPageResult: + def test_initialization(self): + table1 = ExtractedTable(title="Table 1") + table2 = ExtractedTable(title="Table 2") + result = TableExtractionPageResult( + tables=[table1, table2], provider_name="test" + ) + assert len(result.tables) == 2 + assert result.tables[0].title == "Table 1" + assert result.tables[1].title == "Table 2" + + def test_task_name(self): + result = TableExtractionPageResult(provider_name="test") + assert result.task_name == "table_extraction" + + def test_empty_tables(self): + result = TableExtractionPageResult(provider_name="test") + assert result.tables == [] + + def test_invalid_table_type(self): + with pytest.raises(ValidationError): + TableExtractionPageResult(tables=["Not a table"], provider_name="test") + + +class TestTableComponents: + def test_table_header(self): + header = TableHeader(text="Header 1") + assert header.text == "Header 1" + assert header.bbox is None + + bbox = NormBBox(x0=0.1, top=0.1, x1=0.2, bottom=0.2) + header_with_bbox = TableHeader(text="Header 2", bbox=bbox) + assert header_with_bbox.text == "Header 2" + assert header_with_bbox.bbox == bbox + + def test_table_cell(self): + cell = TableCell(text="Cell 1") + assert cell.text == "Cell 1" + assert cell.bbox is None + + bbox = NormBBox(x0=0.1, top=0.1, x1=0.2, bottom=0.2) + cell_with_bbox = TableCell(text="Cell 2", bbox=bbox) + assert cell_with_bbox.text == "Cell 2" + assert cell_with_bbox.bbox == bbox + + def test_table_row(self): + cell1 = TableCell(text="Cell 1") + cell2 = TableCell(text="Cell 2") + row = TableRow(cells=[cell1, cell2]) + assert len(row.cells) == 2 + assert row.cells[0].text == "Cell 1" + assert row.cells[1].text == "Cell 2" + assert row.bbox is None + + bbox = NormBBox(x0=0.1, top=0.1, x1=0.9, bottom=0.2) + row_with_bbox = TableRow(cells=[cell1, cell2], bbox=bbox) + assert row_with_bbox.bbox == bbox + + def test_normbbox_validation(self): + with pytest.raises(ValidationError): + NormBBox(x0=1.1, top=0.1, x1=0.9, bottom=0.9) # x0 > 1 + + with pytest.raises(ValidationError): + NormBBox(x0=0.1, top=-0.1, x1=0.9, bottom=0.9) # top < 0 + + with pytest.raises(ValidationError): + NormBBox(x0=0.1, top=0.1, x1=1.1, bottom=0.9) # x1 > 1 + + with pytest.raises(ValidationError): + NormBBox(x0=0.1, top=0.1, x1=0.9, bottom=1.1) # bottom > 1 + + # Valid NormBBox should not raise an exception + NormBBox(x0=0.1, top=0.1, x1=0.9, bottom=0.9) + + +class TestBaseTableExtractionProvider: + @pytest.fixture + def mock_document_node(self): + mock_node = MagicMock(spec=DocumentNode) + mock_node.page_nodes = [MagicMock(spec=PageNode) for _ in range(5)] + for pnode in mock_node.page_nodes: + pnode.rasterizer.rasterize.return_value = b"image" + mock_node.__len__.return_value = len(mock_node.page_nodes) + return mock_node + + @pytest.mark.parametrize( + "start,stop,expected_keys,expected_results", + [ + (2, 4, [2, 3, 4], {2: "TABLE-0", 3: "TABLE-1", 4: "TABLE-2"}), + (3, None, [3, 4, 5], {3: "TABLE-0", 4: "TABLE-1", 5: "TABLE-2"}), + (None, 2, [1, 2], {1: "TABLE-0", 2: "TABLE-1"}), + ( + None, + None, + [1, 2, 3, 4, 5], + { + 1: "TABLE-0", + 2: "TABLE-1", + 3: "TABLE-2", + 4: "TABLE-3", + 5: "TABLE-4", + }, + ), + ], + ) + def test_process_document_node_with_start_stop( + self, mock_document_node, start, stop, expected_keys, expected_results + ): + class TestProvider(BaseTableExtractionProvider): + name = "test" + + def _invoke(self, input, config, **kwargs): + return [ + TableExtractionPageResult( + tables=[ExtractedTable(title=f"TABLE-{i}")], + provider_name="test", + ) + for i in range(len(input)) + ] + + provider = TestProvider() + result = provider.process_document_node( + mock_document_node, start=start, stop=stop + ) + + assert list(result.keys()) == expected_keys + assert all(isinstance(v, TableExtractionPageResult) for v in result.values()) + assert {k: v.tables[0].title for k, v in result.items()} == expected_results + + with patch.object(provider, "_invoke") as mock_invoke: + provider.process_document_node(mock_document_node, start=start, stop=stop) + mock_invoke.assert_called_once() + expected_invoke_length = len(expected_keys) + assert len(mock_invoke.call_args[0][0]) == expected_invoke_length + + def test_process_document_node_rasterization(self, mock_document_node): + class TestProvider(BaseTableExtractionProvider): + name = "test" + + def _invoke(self, input, config, **kwargs): + return [ + TableExtractionPageResult( + tables=[ExtractedTable(title=f"TABLE-{i}")], + provider_name="test", + ) + for i in range(len(input)) + ] + + provider = TestProvider() + provider.process_document_node(mock_document_node) + + for page_node in mock_document_node.page_nodes: + page_node.rasterizer.rasterize.assert_called_once_with("default")