Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/vector index/on load #2396

Merged
merged 4 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Optimize the logic of ready_ids in trigger_ids.
- Move all plugins superduperdb/ext/* to /plugins
- Optimize the logic for file saving and retrieval in the artifact_store.
- Add backfill on load of vector index

#### New Features & Functionality

Expand Down Expand Up @@ -52,6 +53,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Change default encoding to sqlvector
- Fix some links in documentation
- Change `__dataclass_params__` to `_dataclass_params`
- Make component reload after caching in apply

## [0.3.0](https://github.com/superduper-io/superduper/compare/0.3.0...0.2.0) (2024-Jun-21)

Expand Down
22 changes: 15 additions & 7 deletions superduper/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,18 @@ def cdc(self, cdc):
self._cdc = cdc

def initialize_vector_searcher(
self, identifier, searcher_type: t.Optional[str] = None
self, vi, searcher_type: t.Optional[str] = None
) -> t.Optional[BaseVectorSearcher]:
"""
Initialize vector searcher.

:param identifier: Identifying string to component.
:param vi: Identifying string to component.
:param searcher_type: Searcher type (in_memory|native).
"""
searcher_type = searcher_type or s.CFG.cluster.vector_search.type

vi = self.vector_indices.force_load(identifier)
if isinstance(vi, str):
vi = self.vector_indices.force_load(vi)
from superduper import VectorIndex

assert isinstance(vi, VectorIndex)
Expand Down Expand Up @@ -704,9 +705,10 @@ def _apply(
dependencies = [*deps, *dependencies] # type: ignore[list-item]

object.post_create(self)
self._add_component_to_cache(object)
these_jobs = object.schedule_jobs(self, dependencies=dependencies)
jobs.extend(these_jobs)

self._add_component_to_cache(object)
return jobs

def _change_component_reference_prefix(self, serialized):
Expand Down Expand Up @@ -898,8 +900,11 @@ def _add_component_to_cache(self, component: Component):
"""
type_id = component.type_id
if cm := self.type_id_to_cache_mapping.get(type_id):
# NOTE: We need to reload the object since in `schedule_jobs`
# of the object, `db.replace` might be performed.
# e.g model prediction object is replace with updated datatype.
self.load(type_id, component.identifier)
kartik4949 marked this conversation as resolved.
Show resolved Hide resolved
getattr(self, cm)[component.identifier] = component
component.on_load(self)

def infer_schema(
self, data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None
Expand Down Expand Up @@ -938,16 +943,19 @@ class LoadDict(dict):
field: t.Optional[str] = None
callable: t.Optional[t.Callable] = None

def __missing__(self, key: str):
def __missing__(self, key: t.Union[str, Component]):
if self.field is not None:
key = key.identifier if isinstance(key, Component) else key
value = self[key] = self.database.load(
self.field,
key,
)
else:
msg = f'callable is ``None`` for {key}'
assert self.callable is not None, msg
value = self[key] = self.callable(key)
value = self.callable(key)
key = key.identifier if isinstance(key, Component) else key
self[key] = value
return value

def force_load(self, key: str):
Expand Down
94 changes: 92 additions & 2 deletions superduper/components/vector_index.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import dataclasses as dc
import itertools
import typing as t

import numpy as np
import tqdm
from overrides import override

from superduper import CFG, logging
from superduper.backends.base.query import Query
from superduper.base.datalayer import Datalayer, DBEvent
from superduper.base.document import Document
Expand All @@ -15,11 +18,83 @@
from superduper.jobs.job import FunctionJob
from superduper.misc.annotations import component
from superduper.misc.special_dicts import MongoStyleDict
from superduper.vector_search.base import VectorIndexMeasureType
from superduper.vector_search.base import VectorIndexMeasureType, VectorItem
from superduper.vector_search.update_tasks import copy_vectors, delete_vectors

KeyType = t.Union[str, t.List, t.Dict]

T = t.TypeVar('T')


def ibatch(iterable: t.Iterable[T], batch_size: int) -> t.Iterator[t.List[T]]:
    """Lazily split `iterable` into lists of at most `batch_size` items.

    The final batch may be shorter than `batch_size`; an empty iterable
    yields nothing.

    :param iterable: the iterable to batch
    :param batch_size: maximum number of items per emitted batch
    """
    it = iter(iterable)
    # Pull one item eagerly so the loop terminates cleanly on exhaustion,
    # then top the batch up to `batch_size` with islice.
    for first in it:
        batch = [first]
        batch.extend(itertools.islice(it, batch_size - 1))
        yield batch


def backfill_vector_search(db, vi, searcher):
    """
    Backfill vector search from model outputs of a given vector index.

    Streams the indexing listener's output records in batches and inserts
    them into ``searcher`` as ``VectorItem``s, then finalizes the searcher
    via ``post_create``.

    :param db: Datalayer instance.
    :param vi: ``VectorIndex`` component whose listener outputs are loaded.
    :param searcher: FastVectorSearch instance to load model outputs as vectors.
    """
    from superduper.components.datatype import _BaseEncodable

    logging.info(f"Loading vectors of vector-index: '{vi.identifier}'")

    if vi.indexing_listener.select is None:
        raise ValueError('.select must be set')

    # The listener writes model outputs under this key; those outputs are
    # the vectors to backfill.
    outputs_key = vi.indexing_listener.outputs
    query = db[outputs_key].select()

    logging.info(str(query))
    # presumably '_source' holds the id of the originating document — TODO confirm
    id_field = '_source'

    progress = tqdm.tqdm(desc='Loading vectors into vector-table...')
    notfound = 0
    found = 0
    # Batch the result stream so large output tables are not materialized
    # in memory all at once.
    for record_batch in ibatch(
        db.execute(query),
        CFG.cluster.vector_search.backfill_batch_size,
    ):
        items = []
        for record in record_batch:
            id = record[id_field]
            assert not isinstance(vi.indexing_listener.model, str)
            try:
                h = record[outputs_key]
            except KeyError:
                # Record has no output under `outputs_key`: count and skip.
                notfound += 1
                continue
            else:
                found += 1
                if isinstance(h, _BaseEncodable):
                    # Decode wrapped/encoded vectors to their raw form.
                    h = h.unpack()
            items.append(VectorItem.create(id=str(id), vector=h))
        if items:
            searcher.add(items)
            progress.update(len(items))

    if notfound:
        # NOTE(review): the message is passed as two positional arguments —
        # confirm this logging wrapper concatenates them rather than
        # treating the second as a format arg.
        logging.warn(
            f'{notfound} document/rows were missing outputs ',
            'key hence skipping vector loading for those.',
        )

    searcher.post_create()
    # NOTE(review): 'succesfully' is misspelled in the log message below.
    logging.info(f'Loaded {found} vectors into vector index succesfully')


class VectorIndex(Component):
"""
Expand Down Expand Up @@ -57,6 +132,22 @@ def on_load(self, db: Datalayer) -> None:
Listener, db.load('listener', self.compatible_listener)
)

# Backfill vectors into vi
searcher = db.fast_vector_searchers[self]
if not searcher.is_initialized():
backfill_vector_search(db, self, searcher=searcher.searcher)
searcher.initialize()

def __hash__(self):
    """Hash on ``(type_id, identifier)``, the component's identity pair."""
    identity = (self.type_id, self.identifier)
    return hash(identity)

def __eq__(self, other: Component):
    """Compare components by identity pair.

    Two components are equal when both ``identifier`` and ``type_id``
    match; anything that is not a ``Component`` is unequal.

    :param other: Object to compare against.
    """
    if isinstance(other, Component):
        # BUG FIX: the original used `self.type_id and other.type_id`
        # (a truthiness conjunction), so any two components with the same
        # identifier compared equal regardless of type — inconsistent with
        # __hash__, which hashes (type_id, identifier).
        return (
            self.identifier == other.identifier
            and self.type_id == other.type_id
        )
    return False

def get_vector(
self,
like: Document,
Expand Down Expand Up @@ -155,7 +246,6 @@ def get_nearest(
)[0]

searcher = db.fast_vector_searchers[self.identifier]

return searcher.find_nearest_from_array(h, within_ids=within_ids, n=n)

def cleanup(self, db: Datalayer):
Expand Down
2 changes: 1 addition & 1 deletion superduper/jobs/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, uri: t.Optional[str] = None):
super().__init__(uri=uri)
self.consumer = self.build_consumer()

def build_consumer(self):
def build_consumer(self, **kwargs):
"""Build consumer client."""
return LocalQueueConsumer()

Expand Down
2 changes: 2 additions & 0 deletions superduper/misc/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ def load_plugin(name: str):
"""
if name == 'local':
return importlib.import_module('superduper.backends.local')
if name == 'ray':
return importlib.import_module('superduper_services.compute.ray.compute')
logging.info(f"Loading plugin: {name}")
plugin = importlib.import_module(f'superduper_{name}')
return plugin
11 changes: 10 additions & 1 deletion superduper/vector_search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import enum
import typing as t
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass

import numpy
Expand Down Expand Up @@ -31,7 +32,15 @@ def __init__(
index: t.Optional[t.List[str]] = None,
measure: t.Optional[str] = None,
):
pass
self._init_vi: t.Dict = defaultdict(lambda: False)

def initialize(self, identifier):
    """Initialize vector index.

    Marks the vector index named ``identifier`` as initialized so that
    ``is_initialized`` reports ``True`` and backfill is not repeated.

    :param identifier: Identifier of the vector index.
    """
    self._init_vi[identifier] = True

def is_initialized(self, identifier):
    """Check if vector index initialized.

    :param identifier: Identifier of the vector index.
    """
    # _init_vi defaults to False (defaultdict(lambda: False)), so unknown
    # identifiers report not-initialized; NOTE: the lookup inserts the key.
    return self._init_vi[identifier]

@classmethod
def from_component(cls, vi: 'VectorIndex'):
Expand Down
3 changes: 3 additions & 0 deletions superduper/vector_search/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def __init__(
self.lookup = None

self.identifier = identifier
super().__init__(
identifier=identifier, dimensions=dimensions, h=h, measure=measure
)

def __len__(self):
if self.h is not None:
Expand Down
28 changes: 27 additions & 1 deletion superduper/vector_search/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,33 @@ def __init__(self, db: 'Datalayer', vector_searcher, vector_index: str):
self.searcher = vector_searcher
self.vector_index = vector_index

def initialize(self):
    """Initialize vector index."""
    if CFG.cluster.vector_search.uri is None:
        # No remote service configured: initialize the local searcher.
        self.searcher.initialize(self.vector_index)
        return
    # Remote mode: delegate initialization to the vector-search service.
    request_server(
        service='vector_search',
        endpoint='initialize',
        args={
            'vector_index': self.vector_index,
        },
    )

def is_initialized(self):
    """Check if vector index initialized."""
    if CFG.cluster.vector_search.uri is None:
        # Local mode: ask the in-process searcher directly.
        return self.searcher.is_initialized(self.vector_index)
    # Remote mode: query the vector-search service for its status.
    response = request_server(
        service='vector_search',
        endpoint='is_initialized',
        args={
            'vector_index': self.vector_index,
        },
    )
    return response['status']

@staticmethod
def drop_remote(index):
"""Drop a vector index from the remote.
Expand All @@ -41,7 +68,6 @@ def drop(self):
"""Drop the vector index from the remote."""
if CFG.cluster.vector_search.uri is not None:
self.drop_remote(self.vector_index)
request_server

def __len__(self):
return len(self.searcher)
Expand Down
23 changes: 23 additions & 0 deletions test/unittest/component/test_vector_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
def test_vector_index_recovery(db):
    """After evicting the cached searcher, loading the index must rebuild it."""
    from test.utils.usecase.vector_search import build_vector_index

    build_vector_index(db)

    documents = db["documents"]
    pk = documents.primary_id
    index_name = "vector_index"
    seed = list(documents.select().execute())[50]

    # Simulate restart
    del db.fast_vector_searchers[index_name]

    db.load('vector_index', index_name)

    results = (
        documents.like({"x": seed["x"]}, vector_index=index_name, n=10)
        .select()
        .execute()
    )

    matched_ids = [row[pk] for row in list(results)]
    assert len(matched_ids) == 10
Loading