diff --git a/docprompt/tasks/table_extraction/anthropic.py b/docprompt/tasks/table_extraction/anthropic.py
index 7c44d7c..07f1dcf 100644
--- a/docprompt/tasks/table_extraction/anthropic.py
+++ b/docprompt/tasks/table_extraction/anthropic.py
@@ -98,7 +98,7 @@ def parse_response(response: str, **kwargs) -> TableExtractionPageResult:
return result
-async def _prepare_messages(
+def _prepare_messages(
document_images: Iterable[bytes],
start: Optional[int] = None,
stop: Optional[int] = None,
@@ -130,7 +130,7 @@ class AnthropicTableExtractionProvider(BaseTableExtractionProvider):
async def _ainvoke(
self, input: Iterable[bytes], config: Optional[None] = None
) -> List[TableExtractionPageResult]:
- messages = await _prepare_messages(input)
+ messages = _prepare_messages(input)
completions = await inference.run_batch_inference_anthropic(messages)
diff --git a/tests/tasks/table_extraction/__init__.py b/tests/tasks/table_extraction/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/tasks/table_extraction/test_anthropic.py b/tests/tasks/table_extraction/test_anthropic.py
new file mode 100644
index 0000000..ee77dec
--- /dev/null
+++ b/tests/tasks/table_extraction/test_anthropic.py
@@ -0,0 +1,180 @@
+"""
+Test the Anthropic implementation of the table extraction task.
+"""
+
+from unittest.mock import patch
+
+import pytest
+from bs4 import BeautifulSoup
+
+from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage
+from docprompt.tasks.table_extraction.anthropic import (
+ AnthropicTableExtractionProvider,
+ _headers_from_tree,
+ _prepare_messages,
+ _rows_from_tree,
+ _title_from_tree,
+ parse_response,
+)
+from docprompt.tasks.table_extraction.schema import (
+ TableCell,
+ TableExtractionPageResult,
+ TableHeader,
+ TableRow,
+)
+
+
+@pytest.fixture
+def mock_image_bytes():
+ return b"mock_image_bytes"
+
+
+class TestAnthropicTableExtractionProvider:
+ @pytest.fixture
+ def provider(self):
+ return AnthropicTableExtractionProvider()
+
+ def test_provider_name(self, provider):
+ assert provider.name == "anthropic"
+
+ @pytest.mark.asyncio
+ async def test_ainvoke(self, provider, mock_image_bytes):
+ mock_completions = [
+ "
",
+ "",
+ ]
+
+ with (
+ patch(
+ "docprompt.tasks.table_extraction.anthropic._prepare_messages"
+ ) as mock_prepare,
+ patch(
+ "docprompt.utils.inference.run_batch_inference_anthropic"
+ ) as mock_inference,
+ ):
+ mock_prepare.return_value = "mock_messages"
+ mock_inference.return_value = mock_completions
+
+ result = await provider._ainvoke([mock_image_bytes, mock_image_bytes])
+
+ assert len(result) == 2
+ assert all(isinstance(r, TableExtractionPageResult) for r in result)
+ assert result[0].tables[0].title == "Test Table"
+ assert result[1].tables[0].title is None
+ assert all(r.provider_name == "anthropic" for r in result)
+
+ mock_prepare.assert_called_once_with([mock_image_bytes, mock_image_bytes])
+ mock_inference.assert_called_once_with("mock_messages")
+
+
+def test_prepare_messages(mock_image_bytes):
+ messages = _prepare_messages([mock_image_bytes])
+
+ assert len(messages) == 1
+ assert len(messages[0]) == 1
+ assert isinstance(messages[0][0], OpenAIMessage)
+ assert messages[0][0].role == "user"
+ assert len(messages[0][0].content) == 2
+ assert isinstance(messages[0][0].content[0], OpenAIComplexContent)
+ assert messages[0][0].content[0].type == "image_url"
+ assert isinstance(messages[0][0].content[0].image_url, OpenAIImageURL)
+ assert messages[0][0].content[0].image_url.url == mock_image_bytes.decode()
+ assert isinstance(messages[0][0].content[1], OpenAIComplexContent)
+ assert messages[0][0].content[1].type == "text"
+ assert (
+ "Identify and extract all tables from the document"
+ in messages[0][0].content[1].text
+ )
+
+
+def test_parse_response():
+ response = """
+
+ Test Table
+
+
+
+
+
+
+ Data1
+ Data2
+
+
+
+ """
+ result = parse_response(response)
+
+ assert isinstance(result, TableExtractionPageResult)
+ assert len(result.tables) == 1
+ assert result.tables[0].title == "Test Table"
+ assert len(result.tables[0].headers) == 2
+ assert result.tables[0].headers[0].text == "Col1"
+ assert len(result.tables[0].rows) == 1
+ assert result.tables[0].rows[0].cells[0].text == "Data1"
+
+
+def test_title_from_tree():
+ soup = BeautifulSoup("", "xml")
+ assert _title_from_tree(soup.table) == "Test Title"
+
+ soup = BeautifulSoup("", "xml")
+ assert _title_from_tree(soup.table) is None
+
+
+def test_headers_from_tree():
+ soup = BeautifulSoup(
+ "",
+ "xml",
+ )
+ headers = _headers_from_tree(soup.table)
+ assert len(headers) == 2
+ assert all(isinstance(h, TableHeader) for h in headers)
+ assert headers[0].text == "Col1"
+
+ soup = BeautifulSoup("", "xml")
+ assert _headers_from_tree(soup.table) == []
+
+
+def test_rows_from_tree():
+ soup = BeautifulSoup(
+ "",
+ "xml",
+ )
+ rows = _rows_from_tree(soup.table)
+ assert len(rows) == 1
+ assert isinstance(rows[0], TableRow)
+ assert len(rows[0].cells) == 2
+ assert all(isinstance(c, TableCell) for c in rows[0].cells)
+ assert rows[0].cells[0].text == "Data1"
+
+ soup = BeautifulSoup("", "xml")
+ assert _rows_from_tree(soup.table) == []
+
+
+@pytest.mark.parametrize(
+ "input_str,sub_str,expected",
+ [
+ ("abcghi", "", [3, 24]),
+ ("notables", "", []),
+ ("", "", [0, 7, 14]),
+ ],
+)
+def test_find_start_indices(input_str, sub_str, expected):
+ from docprompt.tasks.table_extraction.anthropic import _find_start_indices
+
+ assert _find_start_indices(input_str, sub_str) == expected
+
+
+@pytest.mark.parametrize(
+ "input_str,sub_str,expected",
+ [
+ ("abc
def
ghi", "
", [11, 22]),
+ ("notables", "
", []),
+ ("
", "", [8, 16, 24]),
+ ],
+)
+def test_find_end_indices(input_str, sub_str, expected):
+ from docprompt.tasks.table_extraction.anthropic import _find_end_indices
+
+ assert _find_end_indices(input_str, sub_str) == expected
diff --git a/tests/tasks/table_extraction/test_base.py b/tests/tasks/table_extraction/test_base.py
new file mode 100644
index 0000000..5b280e9
--- /dev/null
+++ b/tests/tasks/table_extraction/test_base.py
@@ -0,0 +1,224 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pydantic import ValidationError
+
+from docprompt import DocumentNode
+from docprompt.schema.layout import NormBBox
+from docprompt.schema.pipeline.node.page import PageNode
+from docprompt.tasks.table_extraction.base import BaseTableExtractionProvider
+from docprompt.tasks.table_extraction.schema import (
+ ExtractedTable,
+ TableCell,
+ TableExtractionPageResult,
+ TableHeader,
+ TableRow,
+)
+
+
+class TestExtractedTable:
+ def test_to_markdown_string_with_title(self):
+ table = ExtractedTable(
+ title="Sample Table",
+ headers=[TableHeader(text="Header 1"), TableHeader(text="Header 2")],
+ rows=[
+ TableRow(
+ cells=[
+ TableCell(text="Row 1, Col 1"),
+ TableCell(text="Row 1, Col 2"),
+ ]
+ ),
+ TableRow(
+ cells=[
+ TableCell(text="Row 2, Col 1"),
+ TableCell(text="Row 2, Col 2"),
+ ]
+ ),
+ ],
+ )
+ expected_markdown = (
+ "# Sample Table\n\n"
+ "|Header 1|Header 2|\n"
+ "|---|---|\n"
+ "|Row 1, Col 1|Row 1, Col 2|\n"
+ "|Row 2, Col 1|Row 2, Col 2|"
+ )
+
+ assert table.to_markdown_string() == expected_markdown
+
+ def test_to_markdown_string_without_title(self):
+ table = ExtractedTable(
+ headers=[TableHeader(text="Header 1"), TableHeader(text="Header 2")],
+ rows=[
+ TableRow(
+ cells=[
+ TableCell(text="Row 1, Col 1"),
+ TableCell(text="Row 1, Col 2"),
+ ]
+ ),
+ ],
+ )
+ expected_markdown = (
+ "|Header 1|Header 2|\n" "|---|---|\n" "|Row 1, Col 1|Row 1, Col 2|"
+ )
+ assert table.to_markdown_string() == expected_markdown
+
+ def test_extracted_table_with_bbox(self):
+ bbox = NormBBox(x0=0.1, top=0.1, x1=0.9, bottom=0.9)
+ table = ExtractedTable(bbox=bbox)
+ assert table.bbox == bbox
+
+
+class TestTableExtractionPageResult:
+ def test_initialization(self):
+ table1 = ExtractedTable(title="Table 1")
+ table2 = ExtractedTable(title="Table 2")
+ result = TableExtractionPageResult(
+ tables=[table1, table2], provider_name="test"
+ )
+ assert len(result.tables) == 2
+ assert result.tables[0].title == "Table 1"
+ assert result.tables[1].title == "Table 2"
+
+ def test_task_name(self):
+ result = TableExtractionPageResult(provider_name="test")
+ assert result.task_name == "table_extraction"
+
+ def test_empty_tables(self):
+ result = TableExtractionPageResult(provider_name="test")
+ assert result.tables == []
+
+ def test_invalid_table_type(self):
+ with pytest.raises(ValidationError):
+ TableExtractionPageResult(tables=["Not a table"], provider_name="test")
+
+
+class TestTableComponents:
+ def test_table_header(self):
+ header = TableHeader(text="Header 1")
+ assert header.text == "Header 1"
+ assert header.bbox is None
+
+ bbox = NormBBox(x0=0.1, top=0.1, x1=0.2, bottom=0.2)
+ header_with_bbox = TableHeader(text="Header 2", bbox=bbox)
+ assert header_with_bbox.text == "Header 2"
+ assert header_with_bbox.bbox == bbox
+
+ def test_table_cell(self):
+ cell = TableCell(text="Cell 1")
+ assert cell.text == "Cell 1"
+ assert cell.bbox is None
+
+ bbox = NormBBox(x0=0.1, top=0.1, x1=0.2, bottom=0.2)
+ cell_with_bbox = TableCell(text="Cell 2", bbox=bbox)
+ assert cell_with_bbox.text == "Cell 2"
+ assert cell_with_bbox.bbox == bbox
+
+ def test_table_row(self):
+ cell1 = TableCell(text="Cell 1")
+ cell2 = TableCell(text="Cell 2")
+ row = TableRow(cells=[cell1, cell2])
+ assert len(row.cells) == 2
+ assert row.cells[0].text == "Cell 1"
+ assert row.cells[1].text == "Cell 2"
+ assert row.bbox is None
+
+ bbox = NormBBox(x0=0.1, top=0.1, x1=0.9, bottom=0.2)
+ row_with_bbox = TableRow(cells=[cell1, cell2], bbox=bbox)
+ assert row_with_bbox.bbox == bbox
+
+ def test_normbbox_validation(self):
+ with pytest.raises(ValidationError):
+ NormBBox(x0=1.1, top=0.1, x1=0.9, bottom=0.9) # x0 > 1
+
+ with pytest.raises(ValidationError):
+ NormBBox(x0=0.1, top=-0.1, x1=0.9, bottom=0.9) # top < 0
+
+ with pytest.raises(ValidationError):
+ NormBBox(x0=0.1, top=0.1, x1=1.1, bottom=0.9) # x1 > 1
+
+ with pytest.raises(ValidationError):
+ NormBBox(x0=0.1, top=0.1, x1=0.9, bottom=1.1) # bottom > 1
+
+ # Valid NormBBox should not raise an exception
+ NormBBox(x0=0.1, top=0.1, x1=0.9, bottom=0.9)
+
+
+class TestBaseTableExtractionProvider:
+ @pytest.fixture
+ def mock_document_node(self):
+ mock_node = MagicMock(spec=DocumentNode)
+ mock_node.page_nodes = [MagicMock(spec=PageNode) for _ in range(5)]
+ for pnode in mock_node.page_nodes:
+ pnode.rasterizer.rasterize.return_value = b"image"
+ mock_node.__len__.return_value = len(mock_node.page_nodes)
+ return mock_node
+
+ @pytest.mark.parametrize(
+ "start,stop,expected_keys,expected_results",
+ [
+ (2, 4, [2, 3, 4], {2: "TABLE-0", 3: "TABLE-1", 4: "TABLE-2"}),
+ (3, None, [3, 4, 5], {3: "TABLE-0", 4: "TABLE-1", 5: "TABLE-2"}),
+ (None, 2, [1, 2], {1: "TABLE-0", 2: "TABLE-1"}),
+ (
+ None,
+ None,
+ [1, 2, 3, 4, 5],
+ {
+ 1: "TABLE-0",
+ 2: "TABLE-1",
+ 3: "TABLE-2",
+ 4: "TABLE-3",
+ 5: "TABLE-4",
+ },
+ ),
+ ],
+ )
+ def test_process_document_node_with_start_stop(
+ self, mock_document_node, start, stop, expected_keys, expected_results
+ ):
+ class TestProvider(BaseTableExtractionProvider):
+ name = "test"
+
+ def _invoke(self, input, config, **kwargs):
+ return [
+ TableExtractionPageResult(
+ tables=[ExtractedTable(title=f"TABLE-{i}")],
+ provider_name="test",
+ )
+ for i in range(len(input))
+ ]
+
+ provider = TestProvider()
+ result = provider.process_document_node(
+ mock_document_node, start=start, stop=stop
+ )
+
+ assert list(result.keys()) == expected_keys
+ assert all(isinstance(v, TableExtractionPageResult) for v in result.values())
+ assert {k: v.tables[0].title for k, v in result.items()} == expected_results
+
+ with patch.object(provider, "_invoke") as mock_invoke:
+ provider.process_document_node(mock_document_node, start=start, stop=stop)
+ mock_invoke.assert_called_once()
+ expected_invoke_length = len(expected_keys)
+ assert len(mock_invoke.call_args[0][0]) == expected_invoke_length
+
+ def test_process_document_node_rasterization(self, mock_document_node):
+ class TestProvider(BaseTableExtractionProvider):
+ name = "test"
+
+ def _invoke(self, input, config, **kwargs):
+ return [
+ TableExtractionPageResult(
+ tables=[ExtractedTable(title=f"TABLE-{i}")],
+ provider_name="test",
+ )
+ for i in range(len(input))
+ ]
+
+ provider = TestProvider()
+ provider.process_document_node(mock_document_node)
+
+ for page_node in mock_document_node.page_nodes:
+ page_node.rasterizer.rasterize.assert_called_once_with("default")