Skip to content

Add feature : add chromadb support as a vector database #60

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions flask4modelcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,16 @@ def response_hitquery(cache_resp):
# redis_config = configparser.ConfigParser()
# redis_config.read('modelcache/config/redis_config.ini')

# chromadb_config = configparser.ConfigParser()
# chromadb_config.read('modelcache/config/chromadb_config.ini')

data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))


# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
# VectorBase("chromadb", dimension=data2vec.dimension, chromadb_config=chromadb_config))

# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
# VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config))

Expand Down
2 changes: 2 additions & 0 deletions modelcache/config/chromadb_config.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[chromadb]
persist_directory=''
92 changes: 92 additions & 0 deletions modelcache/manager/vector_data/chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import List

import numpy as np
import logging
from modelcache.manager.vector_data.base import VectorBase, VectorData
from modelcache.utils import import_chromadb, import_torch

import_torch()
import_chromadb()

import chromadb


class Chromadb(VectorBase):

def __init__(
self,
persist_directory="./chromadb",
top_k: int = 1,
):
self.collection_name = "modelcache"
self.top_k = top_k

self._client = chromadb.PersistentClient(path=persist_directory)
self._collection = None

def mul_add(self, datas: List[VectorData], model=None):
collection_name_model = self.collection_name + '_' + model
self._collection = self._client.get_or_create_collection(name=collection_name_model)

data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas)))
self._collection.add(embeddings=data_array, ids=id_array)

def search(self, data: np.ndarray, top_k: int = -1, model=None):
collection_name_model = self.collection_name + '_' + model
self._collection = self._client.get_or_create_collection(name=collection_name_model)

if self._collection.count() == 0:
return []
if top_k == -1:
top_k = self.top_k
results = self._collection.query(
query_embeddings=[data.tolist()],
n_results=top_k,
include=["distances"],
)
return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]]))

def rebuild(self, ids=None):
pass

def delete(self, ids, model=None):
try:
collection_name_model = self.collection_name + '_' + model
self._collection = self._client.get_or_create_collection(name=collection_name_model)
# 查询集合中实际存在的 ID
ids_str = [str(x) for x in ids]
existing_ids = set(self._collection.get(ids=ids_str).ids)

# 删除存在的 ID
if existing_ids:
self._collection.delete(list(existing_ids))

# 返回实际删除的条目数量
return len(existing_ids)

except Exception as e:
logging.error('Error during deletion: {}'.format(e))
raise ValueError(str(e))

def rebuild_col(self, model):
collection_name_model = self.collection_name + '_' + model

# 检查集合是否存在,如果存在则删除
collections = self._client.list_collections()
if any(col.name == collection_name_model for col in collections):
self._client.delete_collection(collection_name_model)
else:
return 'model collection not found, please check!'

try:
self._client.create_collection(collection_name_model)
except Exception as e:
logging.info(f'rebuild_collection: {e}')
raise ValueError(str(e))

def flush(self):
# chroma无flush方法
pass

def close(self):
pass
8 changes: 3 additions & 5 deletions modelcache/manager/vector_data/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,11 @@ def get(name, **kwargs):
elif name == "chromadb":
from modelcache.manager.vector_data.chroma import Chromadb

client_settings = kwargs.get("client_settings", None)
persist_directory = kwargs.get("persist_directory", None)
collection_name = kwargs.get("collection_name", COLLECTION_NAME)
chromadb_config = kwargs.get("chromadb_config", None)
persist_directory = chromadb_config.get('chromadb','persist_directory')

vector_base = Chromadb(
client_settings=client_settings,
persist_directory=persist_directory,
collection_name=collection_name,
top_k=top_k,
)
elif name == "hnswlib":
Expand Down
4 changes: 4 additions & 0 deletions modelcache/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,7 @@ def import_pillow():

def import_redis():
_check_library("redis")


def import_chromadb():
_check_library("chromadb", package="chromadb")
2 changes: 2 additions & 0 deletions modelcache_mm/config/chromadb_config.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[chromadb]
persist_directory=./chromadb
99 changes: 99 additions & 0 deletions modelcache_mm/manager/vector_data/chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import List

import numpy as np
import logging
from modelcache_mm.manager.vector_data.base import VectorBase, VectorData
from modelcache_mm.utils import import_chromadb, import_torch
from modelcache_mm.utils.index_util import get_mm_index_name

import_torch()
import_chromadb()

import chromadb


class Chromadb(VectorBase):

def __init__(
self,
persist_directory="./chromadb",
top_k: int = 1,
):
self.top_k = top_k

self._client = chromadb.PersistentClient(path=persist_directory)
self._collection = None

def create(self, model=None, mm_type=None):
try:
collection_name_model = get_mm_index_name(model, mm_type)
# collection_name_model = self.collection_name + '_' + model
self._client.get_or_create_collection(name=collection_name_model)
except Exception as e:
raise ValueError(str(e))

def add(self, datas: List[VectorData], model=None, mm_type=None):
collection_name_model = get_mm_index_name(model, mm_type)
self._collection = self._client.get_or_create_collection(name=collection_name_model)

data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas)))
self._collection.add(embeddings=data_array, ids=id_array)

def search(self, data: np.ndarray, top_k: int = -1, model=None, mm_type='mm'):
collection_name_model = get_mm_index_name(model, mm_type)
self._collection = self._client.get_or_create_collection(name=collection_name_model)

if self._collection.count() == 0:
return []
if top_k == -1:
top_k = self.top_k
results = self._collection.query(
query_embeddings=[data.tolist()],
n_results=top_k,
include=["distances"],
)
return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]]))

def delete(self, ids, model=None, mm_type=None):
try:
collection_name_model = get_mm_index_name(model, mm_type)
self._collection = self._client.get_or_create_collection(name=collection_name_model)
# 查询集合中实际存在的 ID
ids_str = [str(x) for x in ids]
existing_ids = set(self._collection.get(ids=ids_str).ids)

# 删除存在的 ID
if existing_ids:
self._collection.delete(list(existing_ids))

# 返回实际删除的条目数量
return len(existing_ids)

except Exception as e:
logging.error('Error during deletion: {}'.format(e))
raise ValueError(str(e))

def rebuild_idx(self, model, mm_type=None):
collection_name_model = get_mm_index_name(model, mm_type)

# 检查集合是否存在,如果存在则删除
collections = self._client.list_collections()
if any(col.name == collection_name_model for col in collections):
self._client.delete_collection(collection_name_model)
else:
return 'model collection not found, please check!'

try:
self._client.create_collection(collection_name_model)
except Exception as e:
logging.info(f'rebuild_collection: {e}')
raise ValueError(str(e))

def rebuild(self, ids=None):
pass

def flush(self):
pass

def close(self):
pass
9 changes: 9 additions & 0 deletions modelcache_mm/manager/vector_data/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ def get(name, **kwargs):
dimension=dimension,
top_k=top_k
)
elif name == "chromadb":
from modelcache_mm.manager.vector_data.chroma import Chromadb

chromadb_config = kwargs.get("chromadb_config", None)
persist_directory = chromadb_config.get('chromadb', 'persist_directory')
vector_base = Chromadb(
persist_directory=persist_directory,
top_k=top_k,
)
else:
raise NotFoundError("vector store", name)
return vector_base
4 changes: 4 additions & 0 deletions modelcache_mm/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,7 @@ def import_pillow():

def import_redis():
_check_library("redis")


def import_chromadb():
_check_library("chromadb", package="chromadb")
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ faiss-cpu==1.7.4
redis==5.0.1
modelscope==1.14.0
fastapi==0.115.5
uvicorn==0.32.0
uvicorn==0.32.0
chromadb==0.5.23