Snowflake Cortex destination: Bug fixes (#38206)
bindipankhudi authored May 15, 2024
1 parent 5ecaef0 commit e19e634
Showing 10 changed files with 77 additions and 22 deletions.
@@ -17,6 +17,7 @@ class PasswordBasedAuthorizationModel(BaseModel):
airbyte_secret=True,
description="Enter the password you want to use to access the database",
examples=["AIRBYTE_PASSWORD"],
order=7,
)

class Config:
@@ -28,42 +29,42 @@ class SnowflakeCortexIndexingModel(BaseModel):
host: str = Field(
...,
title="Host",
airbyte_secret=True,
order=1,
description="Enter the account name you want to use to access the database. This is usually the identifier before .snowflakecomputing.com",
examples=["AIRBYTE_ACCOUNT"],
)
role: str = Field(
...,
title="Role",
airbyte_secret=True,
order=2,
description="Enter the role that you want to use to access Snowflake",
examples=["AIRBYTE_ROLE", "ACCOUNTADMIN"],
)
warehouse: str = Field(
...,
title="Warehouse",
airbyte_secret=True,
order=3,
description="Enter the name of the warehouse that you want to sync data into",
examples=["AIRBYTE_WAREHOUSE"],
)
database: str = Field(
...,
title="Database",
airbyte_secret=True,
order=4,
description="Enter the name of the database that you want to sync data into",
examples=["AIRBYTE_DATABASE"],
)
default_schema: str = Field(
...,
title="Default Schema",
airbyte_secret=True,
order=5,
description="Enter the name of the default schema",
examples=["AIRBYTE_SCHEMA"],
)
username: str = Field(
...,
title="Username",
airbyte_secret=True,
order=6,
description="Enter the name of the user you want to use to access the database",
examples=["AIRBYTE_USER"],
)
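For context, a minimal sketch (hypothetical model, assuming pydantic v1 semantics in which unknown Field kwargs are copied into the generated JSON schema) of why adding order and dropping airbyte_secret on these fields changes the connector spec shown further below:

# Illustration only, not part of this commit. In pydantic v1, extra keyword
# arguments passed to Field (such as order and airbyte_secret) are emitted as
# extra keys in the JSON schema that the connector spec is generated from.
from pydantic import BaseModel, Field

class ExampleIndexingModel(BaseModel):
    host: str = Field(..., title="Host", order=1, description="Account identifier")
    password: str = Field(..., title="Password", airbyte_secret=True, order=7)

print(ExampleIndexingModel.schema()["properties"]["host"])
# roughly: {'title': 'Host', 'description': 'Account identifier', 'order': 1, 'type': 'string'}
print(ExampleIndexingModel.schema()["properties"]["password"])
# roughly: {'title': 'Password', 'airbyte_secret': True, 'order': 7, 'type': 'string'}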
@@ -16,7 +16,7 @@
from destination_snowflake_cortex.config import ConfigModel
from destination_snowflake_cortex.indexer import SnowflakeCortexIndexer

BATCH_SIZE = 32
BATCH_SIZE = 150


class DestinationSnowflakeCortex(Destination):
@@ -2,6 +2,7 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#

import copy
import uuid
from typing import Any, Iterable, Optional

@@ -85,7 +86,7 @@ def _get_updated_catalog(self) -> ConfiguredAirbyteCatalog:
metadata -> metadata of the record
embedding -> embedding of the document content
"""
updated_catalog = self.catalog
updated_catalog = copy.deepcopy(self.catalog)
# update each stream in the catalog
for stream in updated_catalog.streams:
# TO-DO: Revisit this - Clear existing properties, if any, since we are not entirely sure what's in the configured catalog.
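A minimal sketch (hypothetical data; plain dicts standing in for the pydantic catalog model) of the aliasing bug the new deepcopy avoids: assigning self.catalog directly would mean that clearing the stream properties here also mutates the catalog object the rest of the indexer still relies on:

# Illustration only, not the connector code.
import copy

catalog = {"streams": [{"name": "mystream", "properties": {"text": "str"}}]}

alias = catalog                                   # old behaviour: same object
alias["streams"][0]["properties"] = {}            # clearing properties...
assert catalog["streams"][0]["properties"] == {}  # ...clobbers the original too

catalog = {"streams": [{"name": "mystream", "properties": {"text": "str"}}]}
updated = copy.deepcopy(catalog)                  # new behaviour: independent copy
updated["streams"][0]["properties"] = {}
assert catalog["streams"][0]["properties"] == {"text": "str"}  # original preserved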
@@ -144,7 +145,8 @@ def get_write_strategy(self, stream_name: str) -> WriteStrategy:
for stream in self.catalog.streams:
if stream.stream.name == stream_name:
if stream.destination_sync_mode == DestinationSyncMode.overwrite:
return WriteStrategy.REPLACE
# we will use append here since we will remove the existing records and add new ones.
return WriteStrategy.APPEND
if stream.destination_sync_mode == DestinationSyncMode.append:
return WriteStrategy.APPEND
if stream.destination_sync_mode == DestinationSyncMode.append_dedup:
@@ -170,10 +172,22 @@ def index(self, document_chunks: Iterable[Any], namespace: str, stream: str):
cortex_processor.process_airbyte_messages(airbyte_messages, self.get_write_strategy(stream))

def delete(self, delete_ids: list[str], namespace: str, stream: str):
# delete is generally used when we use full refresh/overwrite strategy.
# PyAirbyte's sync will take care of overwriting the records. Hence, we don't need to do anything here.
# this delete is specific to vector stores, hence not implemented here
pass

def pre_sync(self, catalog: ConfiguredAirbyteCatalog) -> None:
"""
Run before the sync starts. This method makes sure that all records in the destination that belong to streams with a destination sync mode of overwrite are deleted.
"""
table_list = self.default_processor._get_tables_list()
for stream in catalog.streams:
# remove all records for streams with overwrite mode
if stream.destination_sync_mode == DestinationSyncMode.overwrite:
stream_name = stream.stream.name
if stream_name.lower() in [table.lower() for table in table_list]:
self.default_processor._execute_sql(f"DELETE FROM {stream_name}")
pass

def check(self) -> Optional[str]:
self.default_processor._get_tables_list()
# TODO: check to see if vector type is available in snowflake instance
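Taken together, get_write_strategy and pre_sync implement overwrite in two steps: existing rows are deleted up front, then new records are appended. A self-contained sketch of that flow (hypothetical names; the real connector delegates the SQL work to PyAirbyte's default processor):

# Illustration only. FakeProcessor stands in for PyAirbyte's SQL processor.
class FakeProcessor:
    def __init__(self):
        self.executed = []

    def _get_tables_list(self):
        return ["mystream"]          # tables that already exist in Snowflake

    def _execute_sql(self, sql):
        self.executed.append(sql)

def pre_sync_overwrite(processor, overwrite_streams):
    tables = [t.lower() for t in processor._get_tables_list()]
    for name in overwrite_streams:
        if name.lower() in tables:
            # Clear the table up front so the subsequent write can simply APPEND
            # instead of needing a REPLACE strategy.
            processor._execute_sql(f"DELETE FROM {name}")

processor = FakeProcessor()
pre_sync_overwrite(processor, ["mystream", "not_created_yet"])
assert processor.executed == ["DELETE FROM mystream"]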
@@ -124,6 +124,22 @@ def test_write(self):
assert(len(result) == 1)
assert(result[0] == "str_col: Cats are nice")


def test_overwrite_mode_deletes_records(self):
self._delete_table("mystream")
catalog = self._get_configured_catalog(DestinationSyncMode.overwrite)
first_state_message = self._state({"state": "1"})
first_record_chunk = [self._record("mystream", f"Dogs are number {i}", i) for i in range(4)]

# initial sync with replace
destination = DestinationSnowflakeCortex()
list(destination.write(self.config, catalog, [*first_record_chunk, first_state_message]))
assert(self._get_record_count("mystream") == 4)

# following should replace existing records
append_catalog = self._get_configured_catalog(DestinationSyncMode.overwrite)
list(destination.write(self.config, append_catalog, [self._record("mystream", "Cats are nice", 6), first_state_message]))
assert(self._get_record_count("mystream") == 1)

"""
The following tests are not code-specific, but are useful to confirm that the Cortex functions are available and behaving as expected
@@ -319,42 +319,42 @@
"host": {
"title": "Host",
"description": "Enter the account name you want to use to access the database. This is usually the identifier before .snowflakecomputing.com",
"airbyte_secret": true,
"order": 1,
"examples": ["AIRBYTE_ACCOUNT"],
"type": "string"
},
"role": {
"title": "Role",
"description": "Enter the role that you want to use to access Snowflake",
"airbyte_secret": true,
"order": 2,
"examples": ["AIRBYTE_ROLE", "ACCOUNTADMIN"],
"type": "string"
},
"warehouse": {
"title": "Warehouse",
"description": "Enter the name of the warehouse that you want to sync data into",
"airbyte_secret": true,
"order": 3,
"examples": ["AIRBYTE_WAREHOUSE"],
"type": "string"
},
"database": {
"title": "Database",
"description": "Enter the name of the database that you want to sync data into",
"airbyte_secret": true,
"order": 4,
"examples": ["AIRBYTE_DATABASE"],
"type": "string"
},
"default_schema": {
"title": "Default Schema",
"description": "Enter the name of the default schema",
"airbyte_secret": true,
"order": 5,
"examples": ["AIRBYTE_SCHEMA"],
"type": "string"
},
"username": {
"title": "Username",
"description": "Enter the name of the user you want to use to access the database",
"airbyte_secret": true,
"order": 6,
"examples": ["AIRBYTE_USER"],
"type": "string"
},
@@ -367,6 +367,7 @@
"description": "Enter the password you want to use to access the database",
"airbyte_secret": true,
"examples": ["AIRBYTE_PASSWORD"],
"order": 7,
"type": "string"
}
},
@@ -13,7 +13,7 @@ data:
connectorSubtype: vectorstore
connectorType: destination
definitionId: d9e5418d-f0f4-4d19-a8b1-5630543638e2
dockerImageTag: 0.1.0
dockerImageTag: 0.1.1
dockerRepository: airbyte/destination-snowflake-cortex
documentationUrl: https://docs.airbyte.com/integrations/destinations/snowflake-cortex
githubIssueLabel: destination-snowflake-cortex
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "airbyte-destination-snowflake-cortex"
version = "0.1.0"
version = "0.1.1"
description = "Airbyte destination implementation for Snowflake cortex."
authors = ["Airbyte <contact@airbyte.io>"]
license = "MIT"
@@ -81,5 +81,5 @@ def test_write(self, MockedEmbedder, MockedSnowflakeCortexIndexer, MockedWriter)
destination = DestinationSnowflakeCortex()
list(destination.write(self.config, configured_catalog, input_messages))

MockedWriter.assert_called_once_with(self.config_model.processing, mock_indexer, mock_embedder, batch_size=32, omit_raw_text=False)
MockedWriter.assert_called_once_with(self.config_model.processing, mock_indexer, mock_embedder, batch_size=150, omit_raw_text=False)
mock_writer.write.assert_called_once_with(configured_catalog, input_messages)
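An editorial aside, not part of the PR: if BATCH_SIZE is importable from the destination module (the diff suggests it lives alongside DestinationSnowflakeCortex), the assertion could reference the constant instead of hard-coding 150, so this unit test would not need touching on the next batch-size change:

# Hedged variant of the assertion above; assumes BATCH_SIZE is exported by
# destination_snowflake_cortex.destination, as the diff suggests.
from destination_snowflake_cortex.destination import BATCH_SIZE

MockedWriter.assert_called_once_with(
    self.config_model.processing, mock_indexer, mock_embedder,
    batch_size=BATCH_SIZE, omit_raw_text=False,
)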
@@ -135,7 +135,7 @@ def test_create_state_message():
def test_get_write_strategy():
indexer = _create_snowflake_cortex_indexer(generate_catalog())
assert(indexer.get_write_strategy('example_stream') == WriteStrategy.MERGE)
assert(indexer.get_write_strategy('example_stream2') == WriteStrategy.REPLACE)
assert(indexer.get_write_strategy('example_stream2') == WriteStrategy.APPEND)
assert(indexer.get_write_strategy('example_stream3') == WriteStrategy.APPEND)

def test_get_document_id():
@@ -184,6 +184,28 @@ def test_check():
assert result == None


def test_pre_sync_table_does_not_exist():
indexer = _create_snowflake_cortex_indexer(generate_catalog())
mock_processor = MagicMock()
indexer.default_processor = mock_processor

mock_processor._get_tables_list.return_value = ["table1", "table2"]
mock_processor._execute_sql.return_value = None
indexer.pre_sync(generate_catalog())
mock_processor._get_tables_list.assert_called_once()
mock_processor._execute_sql.assert_not_called()

def test_pre_sync_table_exists():
indexer = _create_snowflake_cortex_indexer(generate_catalog())
mock_processor = MagicMock()
indexer.default_processor = mock_processor

mock_processor._get_tables_list.return_value = ["example_stream2", "table2"]
mock_processor._execute_sql.return_value = None
indexer.pre_sync(generate_catalog())
mock_processor._get_tables_list.assert_called_once()
mock_processor._execute_sql.assert_called_once()

def generate_catalog():
return ConfiguredAirbyteCatalog.parse_obj(
{
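The two pre_sync unit tests above differ only in the mocked table list and whether a DELETE is expected; a hedged sketch (editorial suggestion, reusing the module's existing helpers) of collapsing them with pytest.mark.parametrize:

# Suggested refactor only; relies on _create_snowflake_cortex_indexer and
# generate_catalog already defined in this test module.
import pytest
from unittest.mock import MagicMock

@pytest.mark.parametrize(
    "tables, expect_delete",
    [
        (["table1", "table2"], False),           # no overwrite stream has a table yet
        (["example_stream2", "table2"], True),   # overwrite stream exists -> DELETE issued
    ],
)
def test_pre_sync(tables, expect_delete):
    indexer = _create_snowflake_cortex_indexer(generate_catalog())
    mock_processor = MagicMock()
    indexer.default_processor = mock_processor
    mock_processor._get_tables_list.return_value = tables

    indexer.pre_sync(generate_catalog())

    mock_processor._get_tables_list.assert_called_once()
    assert mock_processor._execute_sql.called is expect_delete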
3 changes: 2 additions & 1 deletion docs/integrations/destinations/snowflake-cortex.md
@@ -2,7 +2,7 @@

## Overview

This page guides you through the process of setting up the [Snowflake](https://pinecone.io/) as a vector destination.
This page guides you through the process of setting up [Snowflake](https://www.snowflake.com/en/) as a vector destination.

There are three parts to this:
* Processing - split up individual records into chunks so they fit the context window, and decide which fields to use as context and which are supplementary metadata.
@@ -81,4 +81,5 @@ To get started, sign up for [Snowflake](https://www.snowflake.com/en/). Ensure y

| Version | Date | Pull Request | Subject |
|:--------| :--------- |:--------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------|
| 0.1.1 | 2024-05-15 | [#38206](https://github.com/airbytehq/airbyte/pull/38206) | Bug fixes. |
| 0.1.0 | 2024-05-13 | [#37333](https://github.com/airbytehq/airbyte/pull/36807) | Add support for Snowflake as a Vector destination. |
