Skip to content

Commit

Permalink
Vector DB CDK: Add omit_raw_text flag (#32698)
Browse files Browse the repository at this point in the history
Co-authored-by: flash1293 <flash1293@users.noreply.github.com>
  • Loading branch information
Joe Reuter and flash1293 authored Nov 30, 2023
1 parent 8f7abc2 commit 28e8692
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

import dpath.util
from airbyte_cdk.utils.oneof_option_config import OneOfOptionConfig
from airbyte_cdk.utils.spec_schema_transformations import resolve_refs
from pydantic import BaseModel, Field


Expand Down Expand Up @@ -217,3 +219,57 @@ class Config(OneOfOptionConfig):
title = "Cohere"
description = "Use the Cohere API to embed text."
discriminator = "mode"


class VectorDBConfigModel(BaseModel):
    """
    The configuration model for the Vector DB based destinations. This model is used to generate the UI for the destination configuration,
    as well as to provide type safety for the configuration passed to the destination.
    The configuration model is composed of four parts:
    * Processing configuration
    * Embedding configuration
    * Indexing configuration
    * Advanced configuration
    Processing, embedding and advanced configuration are provided by this base class, while the indexing configuration is provided by the destination connector in the sub class.
    """

    # Which embedding service to use. Pydantic selects the concrete model via the
    # "mode" discriminator field; the extra "group"/"type" keys are additional
    # schema attributes passed through to the generated spec.
    embedding: Union[
        OpenAIEmbeddingConfigModel,
        CohereEmbeddingConfigModel,
        FakeEmbeddingConfigModel,
        AzureOpenAIEmbeddingConfigModel,
        OpenAICompatibleEmbeddingConfigModel,
    ] = Field(..., title="Embedding", description="Embedding configuration", discriminator="mode", group="embedding", type="object")
    # Chunking/processing configuration applied to records before embedding.
    processing: ProcessingConfigModel
    # Advanced option: when True, only the vector and metadata are stored in the
    # destination — the embedded raw text itself is dropped.
    omit_raw_text: bool = Field(
        default=False,
        title="Do not store raw text",
        group="advanced",
        description="Do not store the text that gets embedded along with the vector and the metadata in the destination. If set to true, only the vector and the metadata will be stored - in this case raw text for LLM use cases needs to be retrieved from another source.",
    )

    class Config:
        title = "Destination Config"
        # "groups" drives how fields are sectioned in the generated configuration UI;
        # the "indexing" group is populated by the destination-specific subclass.
        schema_extra = {
            "groups": [
                {"id": "processing", "title": "Processing"},
                {"id": "embedding", "title": "Embedding"},
                {"id": "indexing", "title": "Indexing"},
                {"id": "advanced", "title": "Advanced"},
            ]
        }

    @staticmethod
    def remove_discriminator(schema: Dict[str, Any]) -> None:
        """pydantic adds "discriminator" to the schema for oneOfs, which is not treated right by the platform as we inline all references"""
        # Deletes every nested "discriminator" key in-place; "**" matches at any depth.
        dpath.util.delete(schema, "properties/**/discriminator")

    @classmethod
    def schema(cls, by_alias: bool = True, ref_template: str = "") -> Dict[str, Any]:
        """we're overriding the schema classmethod to enable some post-processing"""
        # NOTE(review): by_alias and ref_template are accepted for signature
        # compatibility but are NOT forwarded to super().schema() — confirm intentional.
        schema: Dict[str, Any] = super().schema()
        # Inline all $ref definitions so the platform receives a self-contained spec.
        schema = resolve_refs(schema)
        cls.remove_discriminator(schema)
        return schema
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

@dataclass
class Chunk:
page_content: str
page_content: Optional[str]
metadata: Dict[str, Any]
record: AirbyteRecordMessage
embedding: Optional[List[float]] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def index(self, document_chunks: List[Chunk], namespace: str, stream: str) -> No
"""
Index a list of document chunks.
This method should be used to index the documents in the destination.
This method should be used to index the documents in the destination. If page_content is None, the document should be indexed without the raw text.
All chunks belong to the stream and namespace specified in the parameters.
"""
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@ class Writer:
The destination connector is responsible to create a writer instance and pass the input messages iterable to the write method.
The batch size can be configured by the destination connector to give the freedom of either letting the user configure it or hardcoding it to a sensible value depending on the destination.
The omit_raw_text parameter can be used to omit the raw text from the documents. This can be useful if the raw text is very large and not needed for the destination.
"""

def __init__(self, processing_config: ProcessingConfigModel, indexer: Indexer, embedder: Embedder, batch_size: int) -> None:
def __init__(
self, processing_config: ProcessingConfigModel, indexer: Indexer, embedder: Embedder, batch_size: int, omit_raw_text: bool
) -> None:
self.processing_config = processing_config
self.indexer = indexer
self.embedder = embedder
self.batch_size = batch_size
self.omit_raw_text = omit_raw_text
self._init_batch()

def _init_batch(self) -> None:
Expand All @@ -45,6 +49,8 @@ def _process_batch(self) -> None:
embeddings = self.embedder.embed_chunks(documents)
for i, document in enumerate(documents):
document.embedding = embeddings[i]
if self.omit_raw_text:
document.page_content = None
self.indexer.index(documents, namespace, stream)

self._init_batch()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional
from unittest.mock import ANY, MagicMock, call

import pytest
from airbyte_cdk.destinations.vector_db_based import ProcessingConfigModel, Writer
from airbyte_cdk.models.airbyte_protocol import (
AirbyteLogMessage,
Expand Down Expand Up @@ -53,7 +54,8 @@ def generate_mock_embedder():
return mock_embedder


def test_write():
@pytest.mark.parametrize("omit_raw_text", [True, False])
def test_write(omit_raw_text: bool):
"""
Basic test for the write method, batcher and document processor.
"""
Expand All @@ -74,7 +76,7 @@ def test_write():
mock_indexer.post_sync.return_value = [post_sync_log_message]

# Create the DestinationLangchain instance
writer = Writer(config_model, mock_indexer, mock_embedder, BATCH_SIZE)
writer = Writer(config_model, mock_indexer, mock_embedder, BATCH_SIZE, omit_raw_text)

output_messages = writer.write(configured_catalog, input_messages)
output_message = next(output_messages)
Expand All @@ -88,6 +90,14 @@ def test_write():
assert mock_indexer.delete.call_count == 2
assert mock_embedder.embed_chunks.call_count == 2

if omit_raw_text:
for call_args in mock_indexer.index.call_args_list:
for chunk in call_args[0][0]:
if omit_raw_text:
assert chunk.page_content is None
else:
assert chunk.page_content is not None

output_message = next(output_messages)
assert output_message == post_sync_log_message

Expand Down Expand Up @@ -138,7 +148,7 @@ def test_write_stream_namespace_split():
mock_indexer.post_sync.return_value = []

# Create the DestinationLangchain instance
writer = Writer(config_model, mock_indexer, mock_embedder, BATCH_SIZE)
writer = Writer(config_model, mock_indexer, mock_embedder, BATCH_SIZE, False)

output_messages = writer.write(configured_catalog, input_messages)
next(output_messages)
Expand Down

0 comments on commit 28e8692

Please sign in to comment.