From 28e86926246eeacbac791971ed2add567d499a14 Mon Sep 17 00:00:00 2001 From: Joe Reuter Date: Thu, 30 Nov 2023 09:49:02 +0100 Subject: [PATCH] Vector DB CDK: Add omit_raw_text flag (#32698) Co-authored-by: flash1293 --- .../destinations/vector_db_based/config.py | 58 ++++++++++++++++++- .../vector_db_based/document_processor.py | 2 +- .../destinations/vector_db_based/indexer.py | 2 +- .../destinations/vector_db_based/writer.py | 8 ++- .../vector_db_based/writer_test.py | 16 ++++- 5 files changed, 79 insertions(+), 7 deletions(-) diff --git a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/config.py b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/config.py index 369b1ee27cbe..0f42e151653a 100644 --- a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/config.py +++ b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/config.py @@ -2,9 +2,11 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union +import dpath.util from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig +from airbyte_cdk.utils.spec_schema_transformations import resolve_refs from pydantic import BaseModel, Field @@ -217,3 +219,57 @@ class Config(OneOfOptionConfig): title = "Cohere" description = "Use the Cohere API to embed text." discriminator = "mode" + + +class VectorDBConfigModel(BaseModel): + """ + The configuration model for the Vector DB based destinations. This model is used to generate the UI for the destination configuration, + as well as to provide type safety for the configuration passed to the destination. + + The configuration model is composed of four parts: + * Processing configuration + * Embedding configuration + * Indexing configuration + * Advanced configuration + + Processing, embedding and advanced configuration are provided by this base class, while the indexing configuration is provided by the destination connector in the sub class. + """ + + embedding: Union[ + OpenAIEmbeddingConfigModel, + CohereEmbeddingConfigModel, + FakeEmbeddingConfigModel, + AzureOpenAIEmbeddingConfigModel, + OpenAICompatibleEmbeddingConfigModel, + ] = Field(..., title="Embedding", description="Embedding configuration", discriminator="mode", group="embedding", type="object") + processing: ProcessingConfigModel + omit_raw_text: bool = Field( + default=False, + title="Do not store raw text", + group="advanced", + description="Do not store the text that gets embedded along with the vector and the metadata in the destination. If set to true, only the vector and the metadata will be stored - in this case raw text for LLM use cases needs to be retrieved from another source.", + ) + + class Config: + title = "Destination Config" + schema_extra = { + "groups": [ + {"id": "processing", "title": "Processing"}, + {"id": "embedding", "title": "Embedding"}, + {"id": "indexing", "title": "Indexing"}, + {"id": "advanced", "title": "Advanced"}, + ] + } + + @staticmethod + def remove_discriminator(schema: Dict[str, Any]) -> None: + """pydantic adds "discriminator" to the schema for oneOfs, which is not treated right by the platform as we inline all references""" + dpath.util.delete(schema, "properties/**/discriminator") + + @classmethod + def schema(cls, by_alias: bool = True, ref_template: str = "") -> Dict[str, Any]: + """we're overriding the schema classmethod to enable some post-processing""" + schema: Dict[str, Any] = super().schema() + schema = resolve_refs(schema) + cls.remove_discriminator(schema) + return schema diff --git a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/document_processor.py b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/document_processor.py index 3ed3e3511dd1..b5a8a07eda77 100644 --- a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/document_processor.py +++ b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/document_processor.py @@ -24,7 +24,7 @@ @dataclass class Chunk: - page_content: str + page_content: Optional[str] metadata: Dict[str, Any] record: AirbyteRecordMessage embedding: Optional[List[float]] = None diff --git a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/indexer.py b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/indexer.py index 20bccd58f4e2..c49f576a6709 100644 --- a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/indexer.py +++ b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/indexer.py @@ -42,7 +42,7 @@ def index(self, document_chunks: List[Chunk], namespace: str, stream: str) -> No """ Index a list of document chunks. - This method should be used to index the documents in the destination. + This method should be used to index the documents in the destination. If page_content is None, the document should be indexed without the raw text. All chunks belong to the stream and namespace specified in the parameters. """ pass diff --git a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/writer.py b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/writer.py index e54143722a40..e8d58abb4ad6 100644 --- a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/writer.py +++ b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/writer.py @@ -23,13 +23,17 @@ class Writer: The destination connector is responsible to create a writer instance and pass the input messages iterable to the write method. The batch size can be configured by the destination connector to give the freedom of either letting the user configure it or hardcoding it to a sensible value depending on the destination. + The omit_raw_text parameter can be used to omit the raw text from the documents. This can be useful if the raw text is very large and not needed for the destination. """ - def __init__(self, processing_config: ProcessingConfigModel, indexer: Indexer, embedder: Embedder, batch_size: int) -> None: + def __init__( + self, processing_config: ProcessingConfigModel, indexer: Indexer, embedder: Embedder, batch_size: int, omit_raw_text: bool + ) -> None: self.processing_config = processing_config self.indexer = indexer self.embedder = embedder self.batch_size = batch_size + self.omit_raw_text = omit_raw_text self._init_batch() def _init_batch(self) -> None: @@ -45,6 +49,8 @@ def _process_batch(self) -> None: embeddings = self.embedder.embed_chunks(documents) for i, document in enumerate(documents): document.embedding = embeddings[i] + if self.omit_raw_text: + document.page_content = None self.indexer.index(documents, namespace, stream) self._init_batch() diff --git a/airbyte-cdk/python/unit_tests/destinations/vector_db_based/writer_test.py b/airbyte-cdk/python/unit_tests/destinations/vector_db_based/writer_test.py index e97532bf280c..dff570d6e698 100644 --- a/airbyte-cdk/python/unit_tests/destinations/vector_db_based/writer_test.py +++ b/airbyte-cdk/python/unit_tests/destinations/vector_db_based/writer_test.py @@ -5,6 +5,7 @@ from typing import Optional from unittest.mock import ANY, MagicMock, call +import pytest from airbyte_cdk.destinations.vector_db_based import ProcessingConfigModel, Writer from airbyte_cdk.models.airbyte_protocol import ( AirbyteLogMessage, @@ -53,7 +54,8 @@ def generate_mock_embedder(): return mock_embedder -def test_write(): +@pytest.mark.parametrize("omit_raw_text", [True, False]) +def test_write(omit_raw_text: bool): """ Basic test for the write method, batcher and document processor. """ @@ -74,7 +76,7 @@ def test_write(): mock_indexer.post_sync.return_value = [post_sync_log_message] # Create the DestinationLangchain instance - writer = Writer(config_model, mock_indexer, mock_embedder, BATCH_SIZE) + writer = Writer(config_model, mock_indexer, mock_embedder, BATCH_SIZE, omit_raw_text) output_messages = writer.write(configured_catalog, input_messages) output_message = next(output_messages) @@ -88,6 +90,14 @@ def test_write(): assert mock_indexer.delete.call_count == 2 assert mock_embedder.embed_chunks.call_count == 2 + if omit_raw_text: + for call_args in mock_indexer.index.call_args_list: + for chunk in call_args[0][0]: + if omit_raw_text: + assert chunk.page_content is None + else: + assert chunk.page_content is not None + output_message = next(output_messages) assert output_message == post_sync_log_message @@ -138,7 +148,7 @@ def test_write_stream_namespace_split(): mock_indexer.post_sync.return_value = [] # Create the DestinationLangchain instance - writer = Writer(config_model, mock_indexer, mock_embedder, BATCH_SIZE) + writer = Writer(config_model, mock_indexer, mock_embedder, BATCH_SIZE, False) output_messages = writer.write(configured_catalog, input_messages) next(output_messages)