Skip to content

Commit

Permalink
Milvus: Auto-create collection if it doesn't exist and improve working with existing collections (#32262)
Browse files Browse the repository at this point in the history

Co-authored-by: flash1293 <flash1293@users.noreply.github.com>
  • Loading branch information
Joe Reuter and flash1293 authored Nov 9, 2023
1 parent a34e337 commit e7a1972
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from airbyte_cdk.models import ConfiguredAirbyteCatalog
from airbyte_cdk.models.airbyte_protocol import DestinationSyncMode
from destination_milvus.config import MilvusIndexingConfigModel
from pymilvus import Collection, DataType, connections
from pymilvus.exceptions import DescribeCollectionException
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility

CLOUD_DEPLOYMENT_MODE = "cloud"

Expand All @@ -35,7 +34,7 @@ def _connect(self):
token=self.config.auth.token if self.config.auth.mode == "token" else "",
)

def _create_client(self):
def _connect_with_timeout(self):
# Run connect in a separate process as it will hang if the token is invalid.
proc = Process(target=self._connect)
proc.start()
Expand All @@ -46,12 +45,31 @@ def _create_client(self):
proc.join()
raise Exception("Connection timed out, check your host and credentials")

def _create_index(self, collection: Collection):
"""
Create an index on the vector field when auto-creating the collection.
This uses an IVF_FLAT index with 1024 clusters. This is a good default for most use cases. If more control is needed, the index can be created manually (this is also stated in the documentation)
"""
collection.create_index(
field_name=self.config.vector_field, index_params={"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 1024}}
)

def _create_client(self):
self._connect_with_timeout()
# If the process exited within 5 seconds, it's safe to connect on the main process to execute the command
self._connect()

if not utility.has_collection(self.config.collection):
pk = FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True)
vector = FieldSchema(name=self.config.vector_field, dtype=DataType.FLOAT_VECTOR, dim=self.embedder_dimensions)
schema = CollectionSchema(fields=[pk, vector], enable_dynamic_field=True)
collection = Collection(name=self.config.collection, schema=schema)
self._create_index(collection)

self._collection = Collection(self.config.collection)
self._collection.load()
self._primary_key = next((field["name"] for field in self._collection.describe()["fields"] if field["is_primary"]), None)
self._primary_key = self._collection.primary_field.name

def check(self) -> Optional[str]:
deployment_mode = os.environ.get("DEPLOYMENT_MODE", "")
Expand All @@ -70,8 +88,6 @@ def check(self) -> Optional[str]:
return f"Vector field {self.config.vector_field} is not a vector"
if vector_field["params"]["dim"] != self.embedder_dimensions:
return f"Vector field {self.config.vector_field} is not a {self.embedder_dimensions}-dimensional vector"
except DescribeCollectionException:
return f"Collection {self.config.collection} does not exist"
except Exception as e:
return format_exception(e)
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from destination_milvus.destination import DestinationMilvus
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Milvus
from pymilvus import Collection, connections
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility


class MilvusIntegrationTest(BaseIntegrationTest):
Expand All @@ -28,26 +28,37 @@ class MilvusIntegrationTest(BaseIntegrationTest):

def _init_milvus(self):
connections.connect(alias="test_driver", uri=self.config["indexing"]["host"], token=self.config["indexing"]["auth"]["token"])
self._collection = Collection(self.config["indexing"]["collection"], using="test_driver")

def _clean_index(self):
self._init_milvus()
entities = self._collection.query(
expr="pk != 0",
)
if len(entities) > 0:
id_list_expr = ", ".join([str(entity["pk"]) for entity in entities])
self._collection.delete(expr=f"pk in [{id_list_expr}]")
if utility.has_collection(self.config["indexing"]["collection"], using="test_driver"):
utility.drop_collection(self.config["indexing"]["collection"], using="test_driver")

def setUp(self):
with open("secrets/config.json", "r") as f:
self.config = json.loads(f.read())
self._clean_index()
self._init_milvus()

def test_check_valid_config(self):
outcome = DestinationMilvus().check(logging.getLogger("airbyte"), self.config)
assert outcome.status == Status.SUCCEEDED

def _create_collection(self, vector_dimensions=1536):
pk = FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True)
vector = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_dimensions)
schema = CollectionSchema(fields=[pk, vector], enable_dynamic_field=True)
collection = Collection(name=self.config["indexing"]["collection"], schema=schema, using="test_driver")
collection.create_index(
field_name="vector", index_params={"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 1024}}
)

def test_check_valid_config_pre_created_collection(self):
self._create_collection()
outcome = DestinationMilvus().check(logging.getLogger("airbyte"), self.config)
assert outcome.status == Status.SUCCEEDED

def test_check_invalid_config_vector_dimension(self):
self._create_collection(vector_dimensions=666)
outcome = DestinationMilvus().check(logging.getLogger("airbyte"), self.config)
assert outcome.status == Status.FAILED

def test_check_invalid_config(self):
outcome = DestinationMilvus().check(
logging.getLogger("airbyte"),
Expand Down Expand Up @@ -77,14 +88,15 @@ def test_write(self):
# initial sync
destination = DestinationMilvus()
list(destination.write(self.config, catalog, [*first_record_chunk, first_state_message]))
self._collection.flush()
assert len(self._collection.query(expr="pk != 0")) == 5
collection = Collection(self.config["indexing"]["collection"], using="test_driver")
collection.flush()
assert len(collection.query(expr="pk != 0")) == 5

# incrementalally update a doc
incremental_catalog = self._get_configured_catalog(DestinationSyncMode.append_dedup)
list(destination.write(self.config, incremental_catalog, [self._record("mystream", "Cats are nice", 2), first_state_message]))
self._collection.flush()
result = self._collection.search(
collection.flush()
result = collection.search(
anns_field=self.config["indexing"]["vector_field"],
param={},
data=[[0] * OPEN_AI_VECTOR_SIZE],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ data:
connectorSubtype: vectorstore
connectorType: destination
definitionId: 65de8962-48c9-11ee-be56-0242ac120002
dockerImageTag: 0.0.7
dockerImageTag: 0.0.8
dockerRepository: airbyte/destination-milvus
githubIssueLabel: destination-milvus
icon: milvus.svg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@

import os
import unittest
from unittest.mock import Mock, call
from unittest.mock import Mock, call, patch

from airbyte_cdk.models.airbyte_protocol import AirbyteStream, DestinationSyncMode, SyncMode
from destination_milvus.config import MilvusIndexingConfigModel, NoAuth, TokenAuth
from destination_milvus.indexer import MilvusIndexer
from pymilvus import DataType
from pymilvus.exceptions import DescribeCollectionException


@patch("destination_milvus.indexer.connections")
@patch("destination_milvus.indexer.utility")
@patch("destination_milvus.indexer.Collection")
class TestMilvusIndexer(unittest.TestCase):
def setUp(self):
self.mock_config = MilvusIndexingConfigModel(
Expand All @@ -28,11 +30,11 @@ def setUp(self):
}
)
self.milvus_indexer = MilvusIndexer(self.mock_config, 128)
self.milvus_indexer._create_client = Mock() # This is mocked out because testing separate processes is hard
self.milvus_indexer._connect_with_timeout = Mock() # Mocking this out to avoid testing multiprocessing
self.milvus_indexer._collection = Mock()

def test_check_returns_expected_result(self):
self.milvus_indexer._collection.describe.return_value = {
def test_check_returns_expected_result(self, mock_Collection, mock_utility, mock_connections):
mock_Collection.return_value.describe.return_value = {
"auto_id": True,
"fields": [{"name": "vector", "type": DataType.FLOAT_VECTOR, "params": {"dim": 128}}],
}
Expand All @@ -41,10 +43,10 @@ def test_check_returns_expected_result(self):

self.assertIsNone(result)

self.milvus_indexer._collection.describe.assert_called()
mock_Collection.return_value.describe.assert_called()

def test_check_secure_endpoint(self):
self.milvus_indexer._collection.describe.return_value = {
def test_check_secure_endpoint(self, mock_Collection, mock_utility, mock_connections):
mock_Collection.return_value.describe.return_value = {
"auto_id": True,
"fields": [{"name": "vector", "type": DataType.FLOAT_VECTOR, "params": {"dim": 128}}],
}
Expand Down Expand Up @@ -75,49 +77,55 @@ def test_check_secure_endpoint(self):

self.assertEqual(result, expected_error_message)

def test_check_handles_failure_conditions(self):
# Test 1: Collection does not exist
self.milvus_indexer._collection.describe.side_effect = DescribeCollectionException("Some error")

result = self.milvus_indexer.check()
self.assertEqual(result, f"Collection {self.mock_config.collection} does not exist")

# Test 2: General exception in describe
self.milvus_indexer._collection.describe.side_effect = Exception("Random exception")
def test_check_handles_failure_conditions(self, mock_Collection, mock_utility, mock_connections):
# Test 1: General exception in describe
mock_Collection.return_value.describe.side_effect = Exception("Random exception")
result = self.milvus_indexer.check()
self.assertTrue("Random exception" in result) # Assuming format_exception includes the exception message

# Test 3: auto_id is not True
self.milvus_indexer._collection.describe.return_value = {"auto_id": False}
self.milvus_indexer._collection.describe.side_effect = None
# Test 2: auto_id is not True
mock_Collection.return_value.describe.return_value = {"auto_id": False}
mock_Collection.return_value.describe.side_effect = None
result = self.milvus_indexer.check()
self.assertEqual(result, "Only collections with auto_id are supported")

# Test 4: Vector field not found
self.milvus_indexer._collection.describe.return_value = {"auto_id": True, "fields": [{"name": "wrong_vector_field"}]}
# Test 3: Vector field not found
mock_Collection.return_value.describe.return_value = {"auto_id": True, "fields": [{"name": "wrong_vector_field"}]}
result = self.milvus_indexer.check()
self.assertEqual(result, f"Vector field {self.mock_config.vector_field} not found")

# Test 5: Vector field is not a vector
self.milvus_indexer._collection.describe.return_value = {
# Test 4: Vector field is not a vector
mock_Collection.return_value.describe.return_value = {
"auto_id": True,
"fields": [{"name": self.mock_config.vector_field, "type": DataType.INT32}],
}
result = self.milvus_indexer.check()
self.assertEqual(result, f"Vector field {self.mock_config.vector_field} is not a vector")

# Test 6: Vector field dimension mismatch
self.milvus_indexer._collection.describe.return_value = {
# Test 5: Vector field dimension mismatch
mock_Collection.return_value.describe.return_value = {
"auto_id": True,
"fields": [{"name": self.mock_config.vector_field, "type": DataType.FLOAT_VECTOR, "params": {"dim": 64}}],
}
result = self.milvus_indexer.check()
self.assertEqual(result, f"Vector field {self.mock_config.vector_field} is not a 128-dimensional vector")

def test_pre_sync_calls_delete(self):
def test_pre_sync_creates_collection(self, mock_Collection, mock_utility, mock_connections):
self.milvus_indexer.config.collection = "ad_hoc"
self.milvus_indexer.config.vector_field = "my_vector_field"
mock_utility.has_collection.return_value = False
self.milvus_indexer.pre_sync(
Mock(streams=[Mock(destination_sync_mode=DestinationSyncMode.append, stream=Mock(name="some_stream"))])
)
mock_Collection.assert_has_calls([call("ad_hoc")])
mock_Collection.return_value.create_index.assert_has_calls(
[call(field_name="my_vector_field", index_params={"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 1024}})]
)

def test_pre_sync_calls_delete(self, mock_Collection, mock_utility, mock_connections):
mock_iterator = Mock()
mock_iterator.next.side_effect = [[{"id": 1}], []]
self.milvus_indexer._collection.query_iterator.return_value = mock_iterator
mock_Collection.return_value.query_iterator.return_value = mock_iterator

self.milvus_indexer.pre_sync(
Mock(
Expand All @@ -130,25 +138,25 @@ def test_pre_sync_calls_delete(self):
)
)

self.milvus_indexer._collection.query_iterator.assert_called_with(expr='_ab_stream == "some_stream"')
self.milvus_indexer._collection.delete.assert_called_with(expr="id in [1]")
mock_Collection.return_value.query_iterator.assert_called_with(expr='_ab_stream == "some_stream"')
mock_Collection.return_value.delete.assert_called_with(expr="id in [1]")

def test_pre_sync_does_not_call_delete(self):
def test_pre_sync_does_not_call_delete(self, mock_Collection, mock_utility, mock_connections):
self.milvus_indexer.pre_sync(
Mock(streams=[Mock(destination_sync_mode=DestinationSyncMode.append, stream=Mock(name="some_stream"))])
)

self.milvus_indexer._collection.delete.assert_not_called()
mock_Collection.return_value.delete.assert_not_called()

def test_index_calls_insert(self):
def test_index_calls_insert(self, mock_Collection, mock_utility, mock_connections):
self.milvus_indexer._primary_key = "id"
self.milvus_indexer.index(
[Mock(metadata={"key": "value", "id": 5}, page_content="some content", embedding=[1, 2, 3])], None, "some_stream"
)

self.milvus_indexer._collection.insert.assert_called_with([{"key": "value", "vector": [1, 2, 3], "text": "some content", "_id": 5}])

def test_index_calls_delete(self):
def test_index_calls_delete(self, mock_Collection, mock_utility, mock_connections):
mock_iterator = Mock()
mock_iterator.next.side_effect = [[{"id": "123"}, {"id": "456"}], [{"id": "789"}], []]
self.milvus_indexer._collection.query_iterator.return_value = mock_iterator
Expand Down
7 changes: 5 additions & 2 deletions docs/integrations/destinations/milvus.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ For testing purposes, it's also possible to use the [Fake embeddings](https://py

### Indexing

To get started, create a new collection in your Milvus instance. Make sure that
If the specified collection doesn't exist, the connector will create it for you with a primary key field `pk` and the configured vector field matching the embedding configuration. Dynamic fields will be enabled. The vector field will have an L2 IVF_FLAT index with an `nlist` parameter of 1024.

If you want to change any of these settings, create a new collection in your Milvus instance yourself. Make sure that
* The primary key field is set to [auto_id](https://milvus.io/docs/create_collection.md)
* There is a vector field with the correct dimensionality (1536 for OpenAI, 1024 for Cohere) and [a configured index](https://milvus.io/docs/build_index.md)

Expand All @@ -81,7 +83,7 @@ When using a self-hosted Milvus cluster, the collection needs to be created usin
```python
from pymilvus import CollectionSchema, FieldSchema, DataType, connections, Collection

connections.connection() # connect to locally running Milvus instance without authentication
connections.connect() # connect to locally running Milvus instance without authentication

pk = FieldSchema(name="pk",dtype=DataType.INT64, is_primary=True, auto_id=True)
vector = FieldSchema(name="vector",dtype=DataType.FLOAT_VECTOR,dim=1536)
Expand All @@ -107,6 +109,7 @@ vector_store.similarity_search("test")

| Version | Date | Pull Request | Subject |
|:--------| :--------- |:--------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------|
| 0.0.8 | 2023-11-08 | [#32262](https://github.com/airbytehq/airbyte/pull/32262) | Auto-create collection if it doesn't exist |
| 0.0.7 | 2023-10-23 | [#31563](https://github.com/airbytehq/airbyte/pull/31563) | Add field mapping option |
| 0.0.6 | 2023-10-19 | [#31599](https://github.com/airbytehq/airbyte/pull/31599) | Base image migration: remove Dockerfile and use the python-connector-base image |
| 0.0.5 | 2023-10-15 | [#31329](https://github.com/airbytehq/airbyte/pull/31329) | Add OpenAI-compatible embedder option |
Expand Down

0 comments on commit e7a1972

Please sign in to comment.