feat: Qdrant vector search support #2428

Merged · 9 commits · Sep 10, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -46,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add predict_id in Listener
 - Add serve in Model
 - Added templates directory with OSS templates
+- Qdrant vector search support

 #### Bug Fixes
1 change: 1 addition & 0 deletions pyproject.toml
@@ -55,6 +55,7 @@ dependencies = [
     "python-magic",
     "apscheduler",
     "bson",
+    "qdrant-client>=1.10.0,<2"
 ]

 [project.optional-dependencies]
2 changes: 2 additions & 0 deletions superduper/backends/base/backends.py
@@ -1,9 +1,11 @@
 from superduper.vector_search.atlas import MongoAtlasVectorSearcher
 from superduper.vector_search.in_memory import InMemoryVectorSearcher
 from superduper.vector_search.lance import LanceVectorSearcher
+from superduper.vector_search.qdrant import QdrantVectorSearcher

 vector_searcher_implementations = {
     "lance": LanceVectorSearcher,
     "in_memory": InMemoryVectorSearcher,
     "mongodb+srv": MongoAtlasVectorSearcher,
+    "qdrant": QdrantVectorSearcher,
 }
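
For context, this registry is keyed by the configured vector-search type, so selecting Qdrant amounts to a dictionary lookup. A minimal sketch of that dispatch (the get_searcher helper is hypothetical, not part of this PR; superduper's actual call site may differ):

# Hypothetical dispatch over the registry added above.
def get_searcher(type_: str):
    try:
        return vector_searcher_implementations[type_]
    except KeyError:
        raise ValueError(f"Unsupported vector search type: {type_}")

searcher_cls = get_searcher("qdrant")  # -> QdrantVectorSearcher
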
47 changes: 25 additions & 22 deletions superduper/base/config.py
@@ -18,7 +18,7 @@ def _dataclass_from_dict(data_class: t.Any, data: dict):
     for f in data:
         if (
             f in field_types
-            and hasattr(field_types[f], '__dataclass_fields__')
+            and hasattr(field_types[f], "__dataclass_fields__")
             and not isinstance(data[f], field_types[f])
         ):
             params[f] = _dataclass_from_dict(field_types[f], data[f])
@@ -39,10 +39,10 @@ def __call__(self, **kwargs):
         """Update the configuration with the given parameters."""
         parameters = self.dict()
         for k, v in kwargs.items():
-            if '__' in k:
-                parts = k.split('__')
+            if "__" in k:
+                parts = k.split("__")
                 parent = parts[0]
-                child = '__'.join(parts[1:])
+                child = "__".join(parts[1:])
                 parameters[parent] = getattr(self, parent)(**{child: v})
             else:
                 parameters[k] = v
@@ -89,8 +89,8 @@ class PollingStrategy(CDCStrategy):
     """

     auto_increment_field: t.Optional[str] = None
-    frequency: str = '30'
-    type: str = 'incremental'
+    frequency: str = "30"
+    type: str = "incremental"


 @dc.dataclass
@@ -102,7 +102,7 @@ class LogBasedStrategy(CDCStrategy):
     """

     resume_token: t.Optional[t.Dict[str, str]] = None
-    type: str = 'logbased'
+    type: str = "logbased"


 @dc.dataclass
@@ -129,7 +129,7 @@ class VectorSearch(BaseConfig):
     """

     uri: t.Optional[str] = None  # None implies local mode
-    type: str = 'in_memory'  # in_memory|lance
+    type: str = "in_memory"  # in_memory|lance|qdrant
     backfill_batch_size: int = 100

@@ -187,7 +187,7 @@ class Compute(BaseConfig):

     uri: t.Optional[str] = None
     kwargs: t.Dict = dc.field(default_factory=dict)
-    backend: str = 'local'
+    backend: str = "local"


 @dc.dataclass
@@ -223,11 +223,11 @@ class Cluster(BaseConfig):
 class LogLevel(str, Enum):
     """Enumerate log severity level # noqa."""

-    DEBUG = 'DEBUG'
-    INFO = 'INFO'
+    DEBUG = "DEBUG"
+    INFO = "INFO"
     SUCCESS = "SUCCESS"
-    WARN = 'WARN'
-    ERROR = 'ERROR'
+    WARN = "WARN"
+    ERROR = "ERROR"


 class LogType(str, Enum):
@@ -245,8 +245,8 @@ class LogType(str, Enum):
 class BytesEncoding(str, Enum):
     """Enumerate the encoding of bytes in the data backend # noqa."""

-    BYTES = 'Bytes'
-    BASE64 = 'Str'
+    BYTES = "Bytes"
+    BASE64 = "Str"


 @dc.dataclass
@@ -261,7 +261,7 @@ class Downloads(BaseConfig):

     folder: t.Optional[str] = None
     n_workers: int = 0
-    headers: t.Dict = dc.field(default_factory=lambda: {'User-Agent': 'me'})
+    headers: t.Dict = dc.field(default_factory=lambda: {"User-Agent": "me"})
     timeout: t.Optional[int] = None

@@ -286,13 +286,14 @@ class Config(BaseConfig):
         If True, the schema will be created if it does not exist.
     :param log_colorize: Whether to colorize the logs
     :param output_prefix: The prefix for the output table and output field key
+    :param vector_search_kwargs: The keyword arguments to pass to the vector search
     """

     envs: dc.InitVar[t.Optional[t.Dict[str, str]]] = None

-    data_backend: str = 'mongodb://mongodb:27017/test_db'
+    data_backend: str = "mongodb://mongodb:27017/test_db"

-    lance_home: str = os.path.join('.superduper', 'vector_indices')
+    lance_home: str = os.path.join(".superduper", "vector_indices")

     artifact_store: t.Optional[str] = None
     metadata_store: t.Optional[str] = None
@@ -309,7 +310,9 @@ class Config(BaseConfig):

     bytes_encoding: BytesEncoding = BytesEncoding.BYTES
     auto_schema: bool = True
-    output_prefix: str = '_outputs__'
+    output_prefix: str = "_outputs__"
+
+    vector_search_kwargs: t.Dict = dc.field(default_factory=dict)

     def __post_init__(self, envs):
         if envs is not None:
@@ -325,7 +328,7 @@ def hybrid_storage(self):
     def comparables(self):
         """A dict of `self` excluding some defined attributes."""
         _dict = dc.asdict(self)
-        list(map(_dict.pop, ('cluster', 'retries', 'downloads')))
+        list(map(_dict.pop, ("cluster", "retries", "downloads")))
         return _dict

     def match(self, cfg: t.Dict):
@@ -350,7 +353,7 @@ def to_yaml(self):
         import yaml

         def enum_representer(dumper, data):
-            return dumper.represent_scalar('tag:yaml.org,2002:str', str(data.value))
+            return dumper.represent_scalar("tag:yaml.org,2002:str", str(data.value))

         yaml.SafeDumper.add_representer(BytesEncoding, enum_representer)
         yaml.SafeDumper.add_representer(LogLevel, enum_representer)
@@ -370,7 +373,7 @@ def _diff(r1, r2):
     d = _diff_impl(r1, r2)
     out = {}
     for path, left, right in d:
-        out['.'.join(path)] = (left, right)
+        out[".".join(path)] = (left, right)
     return out
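
Taken together with the new "qdrant" option on VectorSearch.type, enabling Qdrant should reduce to two settings, with vector_search_kwargs forwarded verbatim to QdrantClient. A rough sketch (the attribute paths are inferred from the hunks above, and the URL is a placeholder):

from superduper import CFG

# Select the Qdrant searcher; path assumed to be cluster.vector_search.type.
CFG.cluster.vector_search.type = "qdrant"
# Passed through to QdrantClient(**kwargs). If left empty, the searcher
# falls back to an in-memory client (location=":memory:").
CFG.vector_search_kwargs = {"url": "http://localhost:6333"}
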
182 changes: 182 additions & 0 deletions superduper/vector_search/qdrant.py
@@ -0,0 +1,182 @@
import typing as t
import uuid
from copy import deepcopy

import numpy as np
from qdrant_client import QdrantClient, models

from superduper import CFG
from superduper.vector_search.base import (
    BaseVectorSearcher,
    VectorIndexMeasureType,
    VectorItem,
)

ID_PAYLOAD_KEY = "_id"


class QdrantVectorSearcher(BaseVectorSearcher):
    """
    Implementation of a vector index using [Qdrant](https://qdrant.tech/).

    :param identifier: Unique string identifier of index
    :param dimensions: Dimension of the vector embeddings
    :param h: Seed vectors ``numpy.ndarray``
    :param index: list of IDs
    :param measure: measure to assess similarity
    """

    def __init__(
        self,
        identifier: str,
        dimensions: int,
        h: t.Optional[np.ndarray] = None,
        index: t.Optional[t.List[str]] = None,
        measure: t.Optional[str] = None,
    ):
        super().__init__(identifier, dimensions, h, index, measure)
        config_dict = deepcopy(CFG.vector_search_kwargs)
        self.vector_name: t.Optional[str] = config_dict.pop("vector_name", None)
        # Use an in-memory instance by default
        # https://github.com/qdrant/qdrant-client#local-mode
        config_dict = config_dict or {"location": ":memory:"}
        self.client = QdrantClient(**config_dict)

        self.collection_name = identifier
        if not self.client.collection_exists(self.collection_name):
            measure = (
                measure.name if isinstance(measure, VectorIndexMeasureType) else measure
            )
            distance = self._distance_mapping(measure)
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(size=dimensions, distance=distance),
            )

        self.initialize(identifier)

        if h is not None and index is not None:
            self.add(
                [
                    VectorItem(
                        id=_id,
                        vector=vector,
                    )
                    for _id, vector in zip(index, h)
                ]
            )

    def __len__(self):
        return self.client.get_collection(self.collection_name).vectors_count

    def add(self, items: t.Sequence[VectorItem], cache: bool = False) -> None:
        """Add vectors to the index.

        :param items: List of vectors to add
        :param cache: Cache vectors (not used in Qdrant implementation).
        """
        points = [
            models.PointStruct(
                id=self._convert_id(item.id),
                vector={self.vector_name: item.vector.tolist()}
                if self.vector_name
                else item.vector.tolist(),
                payload={ID_PAYLOAD_KEY: item.id},
            )
            for item in items
        ]
        self.client.upsert(collection_name=self.collection_name, points=points)

    def delete(self, ids: t.Sequence[str]) -> None:
        """Delete vectors from the index.

        :param ids: List of IDs to delete
        """
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=models.Filter(
                must=[
                    models.FieldCondition(
                        key=ID_PAYLOAD_KEY, match=models.MatchAny(any=list(ids))
                    )
                ]
            ),
        )

    def find_nearest_from_id(
        self,
        _id,
        n: int = 100,
        within_ids: t.Sequence[str] = (),
    ) -> t.Tuple[t.List[str], t.List[float]]:
        """Find the nearest vectors to a given ID.

        :param _id: ID to search
        :param n: Number of results to return
        :param within_ids: List of IDs to search within
        """
        return self._query_nearest(_id, n, within_ids)

    def find_nearest_from_array(
        self,
        h: np.typing.ArrayLike,
        n: int = 100,
        within_ids: t.Sequence[str] = (),
    ) -> t.Tuple[t.List[str], t.List[float]]:
        """Find the nearest vectors to a given vector.

        :param h: Vector to search
        :param n: Number of results to return
        :param within_ids: List of IDs to search within
        """
        return self._query_nearest(h, n, within_ids)

    def _query_nearest(
        self,
        query: t.Union[np.typing.ArrayLike, str],
        n: int = 100,
        within_ids: t.Sequence[str] = (),
    ) -> t.Tuple[t.List[str], t.List[float]]:
        query_filter = None
        if within_ids:
            query_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key=ID_PAYLOAD_KEY, match=models.MatchAny(any=list(within_ids))
                    )
                ]
            )

        search_result = self.client.query_points(
            collection_name=self.collection_name,
            query=query,
            limit=n,
            query_filter=query_filter,
            with_payload=[ID_PAYLOAD_KEY],
            using=self.vector_name,
        ).points

        ids = [hit.payload[ID_PAYLOAD_KEY] for hit in search_result if hit.payload]
        scores = [hit.score for hit in search_result]

        return ids, scores

    def _distance_mapping(self, measure: t.Optional[str] = None) -> models.Distance:
        if measure == "cosine":
            return models.Distance.COSINE
        if measure == "l2":
            return models.Distance.EUCLID
        if measure == "dot":
            return models.Distance.DOT
        else:
            raise ValueError(f"Unsupported measure: {measure}")

    def _convert_id(self, _id: str) -> str:
        """
        Converts any string into a UUID string based on a seed.

        Qdrant accepts UUID strings and unsigned integers as point ID.
        We use a seed to convert each string into a UUID string deterministically.
        This allows us to overwrite the same point with the original ID.
        """
        return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))
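
To sanity-check the new searcher end to end, an in-memory smoke test along these lines should work (an untested sketch; it assumes BaseVectorSearcher.initialize has no external side effects in local mode):

import numpy as np

from superduper.vector_search.base import VectorItem
from superduper.vector_search.qdrant import QdrantVectorSearcher

# With CFG.vector_search_kwargs empty, the client is QdrantClient(location=":memory:").
searcher = QdrantVectorSearcher(identifier="smoke_test", dimensions=3, measure="cosine")
searcher.add(
    [
        VectorItem(id="a", vector=np.array([1.0, 0.0, 0.0])),
        VectorItem(id="b", vector=np.array([0.0, 1.0, 0.0])),
    ]
)
# Point IDs are stored as deterministic UUIDs via uuid.uuid5(uuid.NAMESPACE_DNS, id),
# but queries return the original IDs from the payload.
ids, scores = searcher.find_nearest_from_array(np.array([1.0, 0.1, 0.0]), n=1)
print(ids, scores)  # expected: ["a"] with the top cosine score
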