diff --git a/airbyte-integrations/connectors/destination-milvus/destination_milvus/indexer.py b/airbyte-integrations/connectors/destination-milvus/destination_milvus/indexer.py index 8d059c12a7c1..34c624d79054 100644 --- a/airbyte-integrations/connectors/destination-milvus/destination_milvus/indexer.py +++ b/airbyte-integrations/connectors/destination-milvus/destination_milvus/indexer.py @@ -13,8 +13,7 @@ from airbyte_cdk.models import ConfiguredAirbyteCatalog from airbyte_cdk.models.airbyte_protocol import DestinationSyncMode from destination_milvus.config import MilvusIndexingConfigModel -from pymilvus import Collection, DataType, connections -from pymilvus.exceptions import DescribeCollectionException +from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility CLOUD_DEPLOYMENT_MODE = "cloud" @@ -35,7 +34,7 @@ def _connect(self): token=self.config.auth.token if self.config.auth.mode == "token" else "", ) - def _create_client(self): + def _connect_with_timeout(self): # Run connect in a separate process as it will hang if the token is invalid. proc = Process(target=self._connect) proc.start() @@ -46,12 +45,31 @@ def _create_client(self): proc.join() raise Exception("Connection timed out, check your host and credentials") + def _create_index(self, collection: Collection): + """ + Create an index on the vector field when auto-creating the collection. + + This uses an IVF_FLAT index with 1024 clusters. This is a good default for most use cases. 
If more control is needed, the index can be created manually (this is also stated in the documentation) + """ + collection.create_index( + field_name=self.config.vector_field, index_params={"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 1024}} + ) + + def _create_client(self): + self._connect_with_timeout() # If the process exited within 5 seconds, it's safe to connect on the main process to execute the command self._connect() + if not utility.has_collection(self.config.collection): + pk = FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True) + vector = FieldSchema(name=self.config.vector_field, dtype=DataType.FLOAT_VECTOR, dim=self.embedder_dimensions) + schema = CollectionSchema(fields=[pk, vector], enable_dynamic_field=True) + collection = Collection(name=self.config.collection, schema=schema) + self._create_index(collection) + self._collection = Collection(self.config.collection) self._collection.load() - self._primary_key = next((field["name"] for field in self._collection.describe()["fields"] if field["is_primary"]), None) + self._primary_key = self._collection.primary_field.name def check(self) -> Optional[str]: deployment_mode = os.environ.get("DEPLOYMENT_MODE", "") @@ -70,8 +88,6 @@ def check(self) -> Optional[str]: return f"Vector field {self.config.vector_field} is not a vector" if vector_field["params"]["dim"] != self.embedder_dimensions: return f"Vector field {self.config.vector_field} is not a {self.embedder_dimensions}-dimensional vector" - except DescribeCollectionException: - return f"Collection {self.config.collection} does not exist" except Exception as e: return format_exception(e) return None diff --git a/airbyte-integrations/connectors/destination-milvus/integration_tests/milvus_integration_test.py b/airbyte-integrations/connectors/destination-milvus/integration_tests/milvus_integration_test.py index ee6e7646670e..731ba7edbe76 100644 --- 
a/airbyte-integrations/connectors/destination-milvus/integration_tests/milvus_integration_test.py +++ b/airbyte-integrations/connectors/destination-milvus/integration_tests/milvus_integration_test.py @@ -11,7 +11,7 @@ from destination_milvus.destination import DestinationMilvus from langchain.embeddings import OpenAIEmbeddings from langchain.vectorstores import Milvus -from pymilvus import Collection, connections +from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility class MilvusIntegrationTest(BaseIntegrationTest): @@ -28,26 +28,37 @@ class MilvusIntegrationTest(BaseIntegrationTest): def _init_milvus(self): connections.connect(alias="test_driver", uri=self.config["indexing"]["host"], token=self.config["indexing"]["auth"]["token"]) - self._collection = Collection(self.config["indexing"]["collection"], using="test_driver") - - def _clean_index(self): - self._init_milvus() - entities = self._collection.query( - expr="pk != 0", - ) - if len(entities) > 0: - id_list_expr = ", ".join([str(entity["pk"]) for entity in entities]) - self._collection.delete(expr=f"pk in [{id_list_expr}]") + if utility.has_collection(self.config["indexing"]["collection"], using="test_driver"): + utility.drop_collection(self.config["indexing"]["collection"], using="test_driver") def setUp(self): with open("secrets/config.json", "r") as f: self.config = json.loads(f.read()) - self._clean_index() + self._init_milvus() def test_check_valid_config(self): outcome = DestinationMilvus().check(logging.getLogger("airbyte"), self.config) assert outcome.status == Status.SUCCEEDED + def _create_collection(self, vector_dimensions=1536): + pk = FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True) + vector = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_dimensions) + schema = CollectionSchema(fields=[pk, vector], enable_dynamic_field=True) + collection = Collection(name=self.config["indexing"]["collection"], schema=schema, 
using="test_driver") + collection.create_index( + field_name="vector", index_params={"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 1024}} + ) + + def test_check_valid_config_pre_created_collection(self): + self._create_collection() + outcome = DestinationMilvus().check(logging.getLogger("airbyte"), self.config) + assert outcome.status == Status.SUCCEEDED + + def test_check_invalid_config_vector_dimension(self): + self._create_collection(vector_dimensions=666) + outcome = DestinationMilvus().check(logging.getLogger("airbyte"), self.config) + assert outcome.status == Status.FAILED + def test_check_invalid_config(self): outcome = DestinationMilvus().check( logging.getLogger("airbyte"), @@ -77,14 +88,15 @@ def test_write(self): # initial sync destination = DestinationMilvus() list(destination.write(self.config, catalog, [*first_record_chunk, first_state_message])) - self._collection.flush() - assert len(self._collection.query(expr="pk != 0")) == 5 + collection = Collection(self.config["indexing"]["collection"], using="test_driver") + collection.flush() + assert len(collection.query(expr="pk != 0")) == 5 # incrementalally update a doc incremental_catalog = self._get_configured_catalog(DestinationSyncMode.append_dedup) list(destination.write(self.config, incremental_catalog, [self._record("mystream", "Cats are nice", 2), first_state_message])) - self._collection.flush() - result = self._collection.search( + collection.flush() + result = collection.search( anns_field=self.config["indexing"]["vector_field"], param={}, data=[[0] * OPEN_AI_VECTOR_SIZE], diff --git a/airbyte-integrations/connectors/destination-milvus/metadata.yaml b/airbyte-integrations/connectors/destination-milvus/metadata.yaml index f411cfa5a79b..9424f86b7640 100644 --- a/airbyte-integrations/connectors/destination-milvus/metadata.yaml +++ b/airbyte-integrations/connectors/destination-milvus/metadata.yaml @@ -22,7 +22,7 @@ data: connectorSubtype: vectorstore connectorType: destination 
definitionId: 65de8962-48c9-11ee-be56-0242ac120002 - dockerImageTag: 0.0.7 + dockerImageTag: 0.0.8 dockerRepository: airbyte/destination-milvus githubIssueLabel: destination-milvus icon: milvus.svg diff --git a/airbyte-integrations/connectors/destination-milvus/unit_tests/indexer_test.py b/airbyte-integrations/connectors/destination-milvus/unit_tests/indexer_test.py index 009833703374..f88d064862c3 100644 --- a/airbyte-integrations/connectors/destination-milvus/unit_tests/indexer_test.py +++ b/airbyte-integrations/connectors/destination-milvus/unit_tests/indexer_test.py @@ -4,15 +4,17 @@ import os import unittest -from unittest.mock import Mock, call +from unittest.mock import Mock, call, patch from airbyte_cdk.models.airbyte_protocol import AirbyteStream, DestinationSyncMode, SyncMode from destination_milvus.config import MilvusIndexingConfigModel, NoAuth, TokenAuth from destination_milvus.indexer import MilvusIndexer from pymilvus import DataType -from pymilvus.exceptions import DescribeCollectionException +@patch("destination_milvus.indexer.connections") +@patch("destination_milvus.indexer.utility") +@patch("destination_milvus.indexer.Collection") class TestMilvusIndexer(unittest.TestCase): def setUp(self): self.mock_config = MilvusIndexingConfigModel( @@ -28,11 +30,11 @@ def setUp(self): } ) self.milvus_indexer = MilvusIndexer(self.mock_config, 128) - self.milvus_indexer._create_client = Mock() # This is mocked out because testing separate processes is hard + self.milvus_indexer._connect_with_timeout = Mock() # Mocking this out to avoid testing multiprocessing self.milvus_indexer._collection = Mock() - def test_check_returns_expected_result(self): - self.milvus_indexer._collection.describe.return_value = { + def test_check_returns_expected_result(self, mock_Collection, mock_utility, mock_connections): + mock_Collection.return_value.describe.return_value = { "auto_id": True, "fields": [{"name": "vector", "type": DataType.FLOAT_VECTOR, "params": {"dim": 128}}], } 
@@ -41,10 +43,10 @@ def test_check_returns_expected_result(self): self.assertIsNone(result) - self.milvus_indexer._collection.describe.assert_called() + mock_Collection.return_value.describe.assert_called() - def test_check_secure_endpoint(self): - self.milvus_indexer._collection.describe.return_value = { + def test_check_secure_endpoint(self, mock_Collection, mock_utility, mock_connections): + mock_Collection.return_value.describe.return_value = { "auto_id": True, "fields": [{"name": "vector", "type": DataType.FLOAT_VECTOR, "params": {"dim": 128}}], } @@ -75,49 +77,55 @@ def test_check_secure_endpoint(self): self.assertEqual(result, expected_error_message) - def test_check_handles_failure_conditions(self): - # Test 1: Collection does not exist - self.milvus_indexer._collection.describe.side_effect = DescribeCollectionException("Some error") - - result = self.milvus_indexer.check() - self.assertEqual(result, f"Collection {self.mock_config.collection} does not exist") - - # Test 2: General exception in describe - self.milvus_indexer._collection.describe.side_effect = Exception("Random exception") + def test_check_handles_failure_conditions(self, mock_Collection, mock_utility, mock_connections): + # Test 1: General exception in describe + mock_Collection.return_value.describe.side_effect = Exception("Random exception") result = self.milvus_indexer.check() self.assertTrue("Random exception" in result) # Assuming format_exception includes the exception message - # Test 3: auto_id is not True - self.milvus_indexer._collection.describe.return_value = {"auto_id": False} - self.milvus_indexer._collection.describe.side_effect = None + # Test 2: auto_id is not True + mock_Collection.return_value.describe.return_value = {"auto_id": False} + mock_Collection.return_value.describe.side_effect = None result = self.milvus_indexer.check() self.assertEqual(result, "Only collections with auto_id are supported") - # Test 4: Vector field not found - 
self.milvus_indexer._collection.describe.return_value = {"auto_id": True, "fields": [{"name": "wrong_vector_field"}]} + # Test 3: Vector field not found + mock_Collection.return_value.describe.return_value = {"auto_id": True, "fields": [{"name": "wrong_vector_field"}]} result = self.milvus_indexer.check() self.assertEqual(result, f"Vector field {self.mock_config.vector_field} not found") - # Test 5: Vector field is not a vector - self.milvus_indexer._collection.describe.return_value = { + # Test 4: Vector field is not a vector + mock_Collection.return_value.describe.return_value = { "auto_id": True, "fields": [{"name": self.mock_config.vector_field, "type": DataType.INT32}], } result = self.milvus_indexer.check() self.assertEqual(result, f"Vector field {self.mock_config.vector_field} is not a vector") - # Test 6: Vector field dimension mismatch - self.milvus_indexer._collection.describe.return_value = { + # Test 5: Vector field dimension mismatch + mock_Collection.return_value.describe.return_value = { "auto_id": True, "fields": [{"name": self.mock_config.vector_field, "type": DataType.FLOAT_VECTOR, "params": {"dim": 64}}], } result = self.milvus_indexer.check() self.assertEqual(result, f"Vector field {self.mock_config.vector_field} is not a 128-dimensional vector") - def test_pre_sync_calls_delete(self): + def test_pre_sync_creates_collection(self, mock_Collection, mock_utility, mock_connections): + self.milvus_indexer.config.collection = "ad_hoc" + self.milvus_indexer.config.vector_field = "my_vector_field" + mock_utility.has_collection.return_value = False + self.milvus_indexer.pre_sync( + Mock(streams=[Mock(destination_sync_mode=DestinationSyncMode.append, stream=Mock(name="some_stream"))]) + ) + mock_Collection.assert_has_calls([call("ad_hoc")]) + mock_Collection.return_value.create_index.assert_has_calls( + [call(field_name="my_vector_field", index_params={"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 1024}})] + ) + + def 
test_pre_sync_calls_delete(self, mock_Collection, mock_utility, mock_connections): mock_iterator = Mock() mock_iterator.next.side_effect = [[{"id": 1}], []] - self.milvus_indexer._collection.query_iterator.return_value = mock_iterator + mock_Collection.return_value.query_iterator.return_value = mock_iterator self.milvus_indexer.pre_sync( Mock( @@ -130,17 +138,17 @@ def test_pre_sync_calls_delete(self): ) ) - self.milvus_indexer._collection.query_iterator.assert_called_with(expr='_ab_stream == "some_stream"') - self.milvus_indexer._collection.delete.assert_called_with(expr="id in [1]") + mock_Collection.return_value.query_iterator.assert_called_with(expr='_ab_stream == "some_stream"') + mock_Collection.return_value.delete.assert_called_with(expr="id in [1]") - def test_pre_sync_does_not_call_delete(self): + def test_pre_sync_does_not_call_delete(self, mock_Collection, mock_utility, mock_connections): self.milvus_indexer.pre_sync( Mock(streams=[Mock(destination_sync_mode=DestinationSyncMode.append, stream=Mock(name="some_stream"))]) ) - self.milvus_indexer._collection.delete.assert_not_called() + mock_Collection.return_value.delete.assert_not_called() - def test_index_calls_insert(self): + def test_index_calls_insert(self, mock_Collection, mock_utility, mock_connections): self.milvus_indexer._primary_key = "id" self.milvus_indexer.index( [Mock(metadata={"key": "value", "id": 5}, page_content="some content", embedding=[1, 2, 3])], None, "some_stream" @@ -148,7 +156,7 @@ def test_index_calls_insert(self): self.milvus_indexer._collection.insert.assert_called_with([{"key": "value", "vector": [1, 2, 3], "text": "some content", "_id": 5}]) - def test_index_calls_delete(self): + def test_index_calls_delete(self, mock_Collection, mock_utility, mock_connections): mock_iterator = Mock() mock_iterator.next.side_effect = [[{"id": "123"}, {"id": "456"}], [{"id": "789"}], []] self.milvus_indexer._collection.query_iterator.return_value = mock_iterator diff --git 
a/docs/integrations/destinations/milvus.md b/docs/integrations/destinations/milvus.md index 946abab6c321..7615374b2a8e 100644 --- a/docs/integrations/destinations/milvus.md +++ b/docs/integrations/destinations/milvus.md @@ -57,7 +57,9 @@ For testing purposes, it's also possible to use the [Fake embeddings](https://py ### Indexing -To get started, create a new collection in your Milvus instance. Make sure that +If the specified collection doesn't exist, the connector will create it for you with a primary key field `pk` and the configured vector field matching the embedding configuration. Dynamic fields will be enabled. The vector field will have an L2 IVF_FLAT index with an `nlist` parameter of 1024. + +If you want to change any of these settings, create a new collection in your Milvus instance yourself. Make sure that * The primary key field is set to [auto_id](https://milvus.io/docs/create_collection.md) * There is a vector field with the correct dimensionality (1536 for OpenAI, 1024 for Cohere) and [a configured index](https://milvus.io/docs/build_index.md) @@ -81,7 +83,7 @@ When using a self-hosted Milvus cluster, the collection needs to be created usin ```python from pymilvus import CollectionSchema, FieldSchema, DataType, connections, Collection -connections.connection() # connect to locally running Milvus instance without authentication +connections.connect() # connect to locally running Milvus instance without authentication pk = FieldSchema(name="pk",dtype=DataType.INT64, is_primary=True, auto_id=True) vector = FieldSchema(name="vector",dtype=DataType.FLOAT_VECTOR,dim=1536) @@ -107,6 +109,7 @@ vector_store.similarity_search("test") | Version | Date | Pull Request | Subject | |:--------| :--------- |:--------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------| +| 0.0.8 | 2023-11-08 | 
[#32262](https://github.com/airbytehq/airbyte/pull/32262) | Auto-create collection if it doesn't exist | | 0.0.7 | 2023-10-23 | [#31563](https://github.com/airbytehq/airbyte/pull/31563) | Add field mapping option | | 0.0.6 | 2023-10-19 | [31599](https://github.com/airbytehq/airbyte/pull/31599) | Base image migration: remove Dockerfile and use the python-connector-base image | | 0.0.5 | 2023-10-15 | [#31329](https://github.com/airbytehq/airbyte/pull/31329) | Add OpenAI-compatible embedder option |