From 8bd26be1d8cde140567278da044bd81d7852eeef Mon Sep 17 00:00:00 2001 From: Junlin Zhou Date: Wed, 6 Nov 2024 16:06:01 +0800 Subject: [PATCH] refactor: rewrite ColumnDocCompressor --- src/tablegpt/retriever/compressor.py | 46 +++++++----- tests/retriever/test_compressor.py | 101 +++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 19 deletions(-) create mode 100644 tests/retriever/test_compressor.py diff --git a/src/tablegpt/retriever/compressor.py b/src/tablegpt/retriever/compressor.py index 4d4ade5..853c611 100644 --- a/src/tablegpt/retriever/compressor.py +++ b/src/tablegpt/retriever/compressor.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import defaultdict from sys import version_info from typing import TYPE_CHECKING @@ -21,6 +22,12 @@ def override(func): class ColumnDocCompressor(BaseDocumentCompressor): + """Compresses documents by regrouping them by column. + + The TableGPT Agent generates documents at the cell level (format: {column_name: cell_value}) to enhance retrieval accuracy. + However, after retrieval, these documents need to be recombined by column before being sent to the LLM for processing. + """ + @override def compress_documents( self, @@ -28,23 +35,24 @@ def compress_documents( query: str, # noqa: ARG002 callbacks: Callbacks | None = None, # noqa: ARG002 ) -> Sequence[Document]: - # column name -> document - # TODO: we can perform a map-reduce here. - cols: dict[str, Document] = {} + if not documents: + return [] + + # Initialize defaultdict to collect documents by column + # Document.page_content cannot be None + cols = defaultdict(lambda: Document(page_content="", metadata={})) + for doc in documents: - key = doc.metadata["file_name"] + ":" + doc.metadata["column"] - if key not in cols: - # TODO: what's the difference between this and doc.copy()? - cols[key] = Document( - page_content=f"column:{doc.metadata['column']}", - metadata={ - "file_name": doc.metadata["file_name"], - "column": doc.metadata["column"], - "dtype": doc.metadata["dtype"], - "n_unique": doc.metadata["n_unique"], - "values": [doc.metadata["value"]], - }, - ) - else: - cols[key].metadata["values"] += [doc.metadata["value"]] - return cols.values() + key = f"{doc.metadata['file_name']}:{doc.metadata['column']}" + + # Initialize if key is encountered first time + if not cols[key].page_content: + cols[key].page_content = f"column: {doc.metadata['column']}" + # Copy all metadata, excluding 'value' (if needed) + cols[key].metadata = {k: v for k, v in doc.metadata.items() if k != "value"} + cols[key].metadata["values"] = [] + + # Append value to the existing document's values list + cols[key].metadata["values"].append(doc.metadata["value"]) + + return list(cols.values()) diff --git a/tests/retriever/test_compressor.py b/tests/retriever/test_compressor.py new file mode 100644 index 0000000..30c971b --- /dev/null +++ b/tests/retriever/test_compressor.py @@ -0,0 +1,101 @@ +import unittest + +from langchain_core.documents import Document +from tablegpt.retriever.compressor import ColumnDocCompressor + + +class TestCompressDocuments(unittest.TestCase): + def setUp(self): + self.processor = ColumnDocCompressor() + + def test_single_column_single_file(self): + documents = [ + Document( + page_content="cell content", + metadata={"file_name": "file1", "column": "A", "dtype": "int", "n_unique": 5, "value": 1}, + ), + Document( + page_content="cell content", + metadata={"file_name": "file1", "column": "A", "dtype": "int", "n_unique": 5, "value": 2}, + ), + ] + + expected_output = [ + Document( + page_content="column: A", + metadata={"file_name": "file1", "column": "A", "dtype": "int", "n_unique": 5, "values": [1, 2]}, + ) + ] + + result = self.processor.compress_documents(documents, query="") + assert result == expected_output + + def test_multiple_columns_single_file(self): + documents = [ + Document( + page_content="A:1", + metadata={"file_name": "file1", "column": "A", "dtype": "int", "n_unique": 5, "value": 1}, + ), + Document( + page_content="B:hello", + metadata={"file_name": "file1", "column": "B", "dtype": "str", "n_unique": 3, "value": "hello"}, + ), + ] + + expected_output = [ + Document( + page_content="column: A", + metadata={"file_name": "file1", "column": "A", "dtype": "int", "n_unique": 5, "values": [1]}, + ), + Document( + page_content="column: B", + metadata={"file_name": "file1", "column": "B", "dtype": "str", "n_unique": 3, "values": ["hello"]}, + ), + ] + + result = self.processor.compress_documents(documents, query="") + assert result == expected_output + + def test_multiple_columns_multiple_files(self): + documents = [ + Document( + page_content="cell content", + metadata={"file_name": "file1", "column": "A", "dtype": "int", "n_unique": 5, "value": 1}, + ), + Document( + page_content="cell content", + metadata={"file_name": "file2", "column": "A", "dtype": "int", "n_unique": 4, "value": 2}, + ), + Document( + page_content="cell content", + metadata={"file_name": "file2", "column": "B", "dtype": "str", "n_unique": 3, "value": "world"}, + ), + ] + + expected_output = [ + Document( + page_content="column: A", + metadata={"file_name": "file1", "column": "A", "dtype": "int", "n_unique": 5, "values": [1]}, + ), + Document( + page_content="column: A", + metadata={"file_name": "file2", "column": "A", "dtype": "int", "n_unique": 4, "values": [2]}, + ), + Document( + page_content="column: B", + metadata={"file_name": "file2", "column": "B", "dtype": "str", "n_unique": 3, "values": ["world"]}, + ), + ] + + result = self.processor.compress_documents(documents, query="") + assert result == expected_output + + def test_empty_input(self): + documents = [] + expected_output = [] + result = self.processor.compress_documents(documents, query="") + assert result == expected_output + + +if __name__ == "__main__": + unittest.main()