diff --git a/CHANGELOG.md b/CHANGELOG.md
index 45972066f..32763e469 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -46,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add predict_id in Listener
 - Add serve in Model
 - Added templates directory with OSS templates
+- Add Qdrant vector search support
 
 #### Bug Fixes
 
diff --git a/pyproject.toml b/pyproject.toml
index cb8f03587..9700f044c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,6 +55,7 @@ dependencies = [
     "python-magic",
     "apscheduler",
     "bson",
+    "qdrant-client>=1.10.0,<2",
 ]
 
 [project.optional-dependencies]
diff --git a/superduper/backends/base/backends.py b/superduper/backends/base/backends.py
index 7d6838695..9404ce397 100644
--- a/superduper/backends/base/backends.py
+++ b/superduper/backends/base/backends.py
@@ -1,9 +1,11 @@
 from superduper.vector_search.atlas import MongoAtlasVectorSearcher
 from superduper.vector_search.in_memory import InMemoryVectorSearcher
 from superduper.vector_search.lance import LanceVectorSearcher
+from superduper.vector_search.qdrant import QdrantVectorSearcher
 
 vector_searcher_implementations = {
     "lance": LanceVectorSearcher,
     "in_memory": InMemoryVectorSearcher,
     "mongodb+srv": MongoAtlasVectorSearcher,
+    "qdrant": QdrantVectorSearcher,
 }
diff --git a/superduper/base/config.py b/superduper/base/config.py
index 3a6b3c7a1..d0ede0ea1 100644
--- a/superduper/base/config.py
+++ b/superduper/base/config.py
@@ -18,7 +18,7 @@ def _dataclass_from_dict(data_class: t.Any, data: dict):
     for f in data:
         if (
             f in field_types
-            and hasattr(field_types[f], '__dataclass_fields__')
+            and hasattr(field_types[f], "__dataclass_fields__")
             and not isinstance(data[f], field_types[f])
         ):
             params[f] = _dataclass_from_dict(field_types[f], data[f])
@@ -39,10 +39,10 @@ def __call__(self, **kwargs):
         """Update the configuration with the given parameters."""
         parameters = self.dict()
         for k, v in kwargs.items():
-            if '__' in k:
-                parts = k.split('__')
+            if "__" in k:
+                parts = k.split("__")
                 parent = parts[0]
-                child = '__'.join(parts[1:])
+                child = "__".join(parts[1:])
                 parameters[parent] = getattr(self, parent)(**{child: v})
             else:
                 parameters[k] = v
@@ -89,8 +89,8 @@ class PollingStrategy(CDCStrategy):
     """
 
     auto_increment_field: t.Optional[str] = None
-    frequency: str = '30'
-    type: str = 'incremental'
+    frequency: str = "30"
+    type: str = "incremental"
 
 
 @dc.dataclass
@@ -102,7 +102,7 @@ class LogBasedStrategy(CDCStrategy):
     """
 
     resume_token: t.Optional[t.Dict[str, str]] = None
-    type: str = 'logbased'
+    type: str = "logbased"
 
 
 @dc.dataclass
@@ -129,7 +129,7 @@ class VectorSearch(BaseConfig):
     """
 
     uri: t.Optional[str] = None  # None implies local mode
-    type: str = 'in_memory'  # in_memory|lance
+    type: str = "in_memory"  # in_memory|lance|qdrant
 
     backfill_batch_size: int = 100
 
@@ -187,7 +187,7 @@ class Compute(BaseConfig):
 
     uri: t.Optional[str] = None
     kwargs: t.Dict = dc.field(default_factory=dict)
-    backend: str = 'local'
+    backend: str = "local"
 
 
 @dc.dataclass
@@ -223,11 +223,11 @@ class Cluster(BaseConfig):
 class LogLevel(str, Enum):
     """Enumerate log severity level # noqa."""
 
-    DEBUG = 'DEBUG'
-    INFO = 'INFO'
+    DEBUG = "DEBUG"
+    INFO = "INFO"
     SUCCESS = "SUCCESS"
-    WARN = 'WARN'
-    ERROR = 'ERROR'
+    WARN = "WARN"
+    ERROR = "ERROR"
 
 
 class LogType(str, Enum):
@@ -245,8 +245,8 @@ class LogType(str, Enum):
 class BytesEncoding(str, Enum):
     """Enumerate the encoding of bytes in the data backend # noqa."""
 
-    BYTES = 'Bytes'
-    BASE64 = 'Str'
+    BYTES = "Bytes"
+    BASE64 = "Str"
 
 
 @dc.dataclass
@@ -261,7 +261,7 @@ class Downloads(BaseConfig):
 
     folder: t.Optional[str] = None
     n_workers: int = 0
-    headers: t.Dict = dc.field(default_factory=lambda: {'User-Agent': 'me'})
+    headers: t.Dict = dc.field(default_factory=lambda: {"User-Agent": "me"})
     timeout: t.Optional[int] = None
 
 
@@ -286,13 +286,14 @@ class Config(BaseConfig):
         If True, the schema will be created if it does not exist.
     :param log_colorize: Whether to colorize the logs
    :param output_prefix: The prefix for the output table and output field key
+    :param vector_search_kwargs: The keyword arguments to pass to the vector search
     """
 
     envs: dc.InitVar[t.Optional[t.Dict[str, str]]] = None
 
-    data_backend: str = 'mongodb://mongodb:27017/test_db'
+    data_backend: str = "mongodb://mongodb:27017/test_db"
 
-    lance_home: str = os.path.join('.superduper', 'vector_indices')
+    lance_home: str = os.path.join(".superduper", "vector_indices")
 
     artifact_store: t.Optional[str] = None
     metadata_store: t.Optional[str] = None
@@ -309,7 +310,9 @@ class Config(BaseConfig):
     bytes_encoding: BytesEncoding = BytesEncoding.BYTES
     auto_schema: bool = True
 
-    output_prefix: str = '_outputs__'
+    output_prefix: str = "_outputs__"
+
+    vector_search_kwargs: t.Dict = dc.field(default_factory=dict)
 
     def __post_init__(self, envs):
         if envs is not None:
@@ -325,7 +328,7 @@ def hybrid_storage(self):
     def comparables(self):
         """A dict of `self` excluding some defined attributes."""
         _dict = dc.asdict(self)
-        list(map(_dict.pop, ('cluster', 'retries', 'downloads')))
+        list(map(_dict.pop, ("cluster", "retries", "downloads")))
         return _dict
 
     def match(self, cfg: t.Dict):
@@ -350,7 +353,7 @@ def to_yaml(self):
         import yaml
 
         def enum_representer(dumper, data):
-            return dumper.represent_scalar('tag:yaml.org,2002:str', str(data.value))
+            return dumper.represent_scalar("tag:yaml.org,2002:str", str(data.value))
 
         yaml.SafeDumper.add_representer(BytesEncoding, enum_representer)
         yaml.SafeDumper.add_representer(LogLevel, enum_representer)
@@ -370,7 +373,7 @@ def _diff(r1, r2):
     d = _diff_impl(r1, r2)
     out = {}
     for path, left, right in d:
-        out['.'.join(path)] = (left, right)
+        out[".".join(path)] = (left, right)
     return out
 
 
diff --git a/superduper/vector_search/qdrant.py b/superduper/vector_search/qdrant.py
new file mode 100644
index 000000000..6ebe67553
--- /dev/null
+++ b/superduper/vector_search/qdrant.py
@@ -0,0 +1,182 @@
+import typing as t
+import uuid
+from copy import deepcopy
+
+import numpy as np
+from qdrant_client import QdrantClient, models
+
+from superduper import CFG
+from superduper.vector_search.base import (
+    BaseVectorSearcher,
+    VectorIndexMeasureType,
+    VectorItem,
+)
+
+ID_PAYLOAD_KEY = "_id"
+
+
+class QdrantVectorSearcher(BaseVectorSearcher):
+    """
+    Implementation of a vector index using [Qdrant](https://qdrant.tech/).
+
+    :param identifier: Unique string identifier of index
+    :param dimensions: Dimension of the vector embeddings
+    :param h: Seed vectors ``numpy.ndarray``
+    :param index: list of IDs
+    :param measure: measure to assess similarity
+    """
+
+    def __init__(
+        self,
+        identifier: str,
+        dimensions: int,
+        h: t.Optional[np.ndarray] = None,
+        index: t.Optional[t.List[str]] = None,
+        measure: t.Optional[str] = None,
+    ):
+        super().__init__(identifier, dimensions, h, index, measure)
+        config_dict = deepcopy(CFG.vector_search_kwargs)
+        self.vector_name: t.Optional[str] = config_dict.pop("vector_name", None)
+        # Use an in-memory instance by default
+        # https://github.com/qdrant/qdrant-client#local-mode
+        config_dict = config_dict or {"location": ":memory:"}
+        self.client = QdrantClient(**config_dict)
+
+        self.collection_name = identifier
+        if not self.client.collection_exists(self.collection_name):
+            measure = (
+                measure.name if isinstance(measure, VectorIndexMeasureType) else measure
+            )
+            distance = self._distance_mapping(measure)
+            self.client.create_collection(
+                collection_name=self.collection_name,
+                vectors_config=models.VectorParams(size=dimensions, distance=distance),
+            )
+
+        self.initialize(identifier)
+
+        if h is not None and index is not None:
+            self.add(
+                [
+                    VectorItem(
+                        id=_id,
+                        vector=vector,
+                    )
+                    for _id, vector in zip(index, h)
+                ]
+            )
+
+    def __len__(self):
+        return self.client.get_collection(self.collection_name).vectors_count
+
+    def add(self, items: t.Sequence[VectorItem], cache: bool = False) -> None:
+        """Add vectors to the index.
+
+        :param items: List of vectors to add
+        :param cache: Cache vectors (not used in the Qdrant implementation).
+        """
+        points = [
+            models.PointStruct(
+                id=self._convert_id(item.id),
+                vector={self.vector_name: item.vector.tolist()}
+                if self.vector_name
+                else item.vector.tolist(),
+                payload={ID_PAYLOAD_KEY: item.id},
+            )
+            for item in items
+        ]
+        self.client.upsert(collection_name=self.collection_name, points=points)
+
+    def delete(self, ids: t.Sequence[str]) -> None:
+        """Delete vectors from the index.
+
+        :param ids: List of IDs to delete
+        """
+        self.client.delete(
+            collection_name=self.collection_name,
+            points_selector=models.Filter(
+                must=[
+                    models.FieldCondition(
+                        key=ID_PAYLOAD_KEY, match=models.MatchAny(any=list(ids))
+                    )
+                ]
+            ),
+        )
+
+    def find_nearest_from_id(
+        self,
+        _id,
+        n: int = 100,
+        within_ids: t.Sequence[str] = (),
+    ) -> t.Tuple[t.List[str], t.List[float]]:
+        """Find the nearest vectors to a given ID.
+
+        :param _id: ID to search
+        :param n: Number of results to return
+        :param within_ids: List of IDs to search within
+        """
+        return self._query_nearest(self._convert_id(_id), n, within_ids)
+
+    def find_nearest_from_array(
+        self,
+        h: np.typing.ArrayLike,
+        n: int = 100,
+        within_ids: t.Sequence[str] = (),
+    ) -> t.Tuple[t.List[str], t.List[float]]:
+        """Find the nearest vectors to a given vector.
+
+        :param h: Vector to search
+        :param n: Number of results to return
+        :param within_ids: List of IDs to search within
+        """
+        return self._query_nearest(h, n, within_ids)
+
+    def _query_nearest(
+        self,
+        query: t.Union[np.typing.ArrayLike, str],
+        n: int = 100,
+        within_ids: t.Sequence[str] = (),
+    ) -> t.Tuple[t.List[str], t.List[float]]:
+        query_filter = None
+        if within_ids:
+            query_filter = models.Filter(
+                must=[
+                    models.FieldCondition(
+                        key=ID_PAYLOAD_KEY, match=models.MatchAny(any=list(within_ids))
+                    )
+                ]
+            )
+
+        search_result = self.client.query_points(
+            collection_name=self.collection_name,
+            query=query,
+            limit=n,
+            query_filter=query_filter,
+            with_payload=[ID_PAYLOAD_KEY],
+            using=self.vector_name,
+        ).points
+
+        ids = [hit.payload[ID_PAYLOAD_KEY] for hit in search_result if hit.payload]
+        scores = [hit.score for hit in search_result]
+
+        return ids, scores
+
+    def _distance_mapping(self, measure: t.Optional[str] = None) -> models.Distance:
+        if measure == "cosine":
+            return models.Distance.COSINE
+        if measure == "l2":
+            return models.Distance.EUCLID
+        if measure == "dot":
+            return models.Distance.DOT
+        else:
+            raise ValueError(f"Unsupported measure: {measure}")
+
+    def _convert_id(self, _id: str) -> str:
+        """
+        Convert any string into a UUID string, deterministically.
+
+        Qdrant accepts UUID strings and unsigned integers as point IDs.
+        Seeding the conversion with a fixed namespace maps each string to
+        the same UUID every time, so re-adding the same original ID
+        overwrites the same point.
+        """
+        return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))
diff --git a/test/unittest/vector_search/test_vector_search.py b/test/unittest/vector_search/test_vector_search.py
index 084bff339..8eb891b5c 100644
--- a/test/unittest/vector_search/test_vector_search.py
+++ b/test/unittest/vector_search/test_vector_search.py
@@ -8,34 +8,36 @@
 from superduper.vector_search.base import VectorItem
 from superduper.vector_search.in_memory import InMemoryVectorSearcher
 from superduper.vector_search.lance import LanceVectorSearcher
+from superduper.vector_search.qdrant import QdrantVectorSearcher
 
 
 @pytest.fixture
 def index_data(monkeypatch):
     with tempfile.TemporaryDirectory() as unique_dir:
-        monkeypatch.setattr(CFG, 'lance_home', str(unique_dir))
+        monkeypatch.setattr(CFG, "lance_home", str(unique_dir))
         h = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]])
         ids = [str(uuid.uuid4()) for _ in range(h.shape[0])]
         yield h, ids, unique_dir
 
 
 @pytest.mark.parametrize(
-    "vector_index_cls", [InMemoryVectorSearcher, LanceVectorSearcher]
+    "vector_index_cls",
+    [InMemoryVectorSearcher, LanceVectorSearcher, QdrantVectorSearcher],
 )
-@pytest.mark.parametrize("measure", ['l2', 'dot', 'cosine'])
+@pytest.mark.parametrize("measure", ["l2", "dot", "cosine"])
 def test_index(index_data, measure, vector_index_cls):
     h, ids, ud = index_data
     h = vector_index_cls(
-        identifier='my-index', h=h, index=ids, measure=measure, dimensions=3
+        identifier="my-index", h=h, index=ids, measure=measure, dimensions=3
     )
-    y = np.array([0, 0.5, 0.5])
+    y = np.array([0, 0, 1])
     res, _ = h.find_nearest_from_array(y, 1)
     assert res[0] == ids[0]
 
     y = np.array([0.66, 0.66, 0.66])
-    h.add([VectorItem(id='new', vector=y)])
+    h.add([VectorItem(id="new", vector=y)])
     h.post_create()
     res, _ = h.find_nearest_from_array(y, 1)
-    assert res[0] == 'new'
+    assert res[0] == "new"
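
Usage sketch (for reviewers; assumptions: the searcher is constructed directly, and the "url" is a hypothetical Qdrant endpoint rather than a default). With CFG.vector_search_kwargs left empty, the searcher falls back to qdrant-client's embedded ":memory:" local mode, as exercised by the unit tests; any kwargs other than "vector_name" are forwarded to the QdrantClient constructor.

    import numpy as np

    from superduper import CFG
    from superduper.vector_search.base import VectorItem
    from superduper.vector_search.qdrant import QdrantVectorSearcher

    # Hypothetical deployment; omit entirely to use the in-memory local mode.
    CFG.vector_search_kwargs = {"url": "http://localhost:6333"}

    searcher = QdrantVectorSearcher(
        identifier="my-index", dimensions=3, measure="cosine"
    )
    searcher.add([VectorItem(id="doc-1", vector=np.array([0.0, 0.0, 1.0]))])

    # Returns parallel lists of IDs and scores; the payload carries the
    # original string ID, while the stored point ID is its uuid5 conversion.
    ids, scores = searcher.find_nearest_from_array(np.array([0.0, 0.0, 1.0]), n=1)
    assert ids[0] == "doc-1"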