Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added chroma support #11

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,5 @@ docs/build
.pypirc
*.tar.gz
*.whl
*.db
*.db
.idea
3 changes: 2 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ certifi==2023.11.17
cffi==1.16.0
charset-normalizer==3.3.2
cryptography==41.0.7
chromadb==0.4.17
docutils==0.20.1
idna==3.6
imagesize==1.4.1
Expand All @@ -17,7 +18,7 @@ markdown-it-py==3.0.0
MarkupSafe==2.1.3
mdurl==0.1.2
more-itertools==10.2.0
nh3==0.2.15
opentelemetry-api~=1.12
packaging==23.2
pkginfo==1.9.6
pycparser==2.21
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies = [
"annotated-types==0.6.0",
"anyio==4.2.0",
"certifi==2023.11.17",
"chromadb==0.4.17",
"charset-normalizer==3.3.2",
"distro==1.9.0",
"exceptiongroup==1.2.0",
Expand Down
180 changes: 180 additions & 0 deletions src/brdata_rag_tools/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import faiss
import numpy as np
from pgvector.sqlalchemy import Vector

import chromadb
from chromadb.api.types import Where

from sqlalchemy import String, text, BLOB
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session, Mapped, mapped_column
Expand Down Expand Up @@ -140,6 +144,7 @@ def write_rows(self, rows: List[Type[BaseClass]], create_embeddings: bool = True
if expunge:
session.expunge_all()


@dataclass
class IndexWrapper:
embeddings: faiss.METRIC_INNER_PRODUCT
Expand Down Expand Up @@ -334,3 +339,178 @@ def retrieve_embedding(self, row_id: str, table: BaseClass) -> np.array:
table.id == row_id).first()

return embedding


class Chroma(Database):
    """
    This class represents a ChromaDB vector store accessed through the ``chromadb`` client.

    If ``database`` is set, it is used as the storage path of a
    ``chromadb.PersistentClient``; otherwise an in-memory ``chromadb.Client()``
    without persistence is used (intended for tests). Every table class is stored
    in a collection named after the class (``table.__name__``).

    All constructor arguments are forwarded to :class:`Database`.
    NOTE(review): the code in this class only reads ``database``; confirm whether
    ``user``, ``password``, ``host`` and ``port`` have any effect for Chroma.

    :param user: The username to connect to the database, if run as a different service.
    :type user: str
    :param database: Path of the persistent ChromaDB storage. If empty, a non-persistent client is used.
    :type database: str
    :param password: The password to connect to the database, if run as a different service.
    :type password: str
    :param host: The host address of the database, if run as a different service. Default is None.
    :type host: str
    :param port: The port number of the database, if run as a different service. Default is None.
    :type port: int
    :param verbose: Whether to enable verbose output. Default is False.
    :type verbose: bool
    """

    def __init__(self, user: str = None, database: str = None, password: str = None,
                 host: str = None, port: int = None, verbose: bool = False):
        super().__init__(user, database, password, host, port, verbose, vector_type=Vector)

    def _create_engine(self):
        """
        Create the ChromaDB client used as ``engine``.

        :return: A persistent client rooted at ``self.database`` if it is set,
            otherwise an in-memory client.
        """
        if self.database:
            # run productively
            return chromadb.PersistentClient(path=self.database)
        # run for test purposes without persistence
        return chromadb.Client()

    def _get_or_create_collection(self, table):
        """
        Return the collection for ``table``, creating it with cosine distance if missing.

        Every code path that may create a collection goes through this helper so the
        distance function does not depend on whether reading or writing happens first.
        """
        # https://docs.trychroma.com/usage-guide#changing-the-distance-function
        return self.engine.get_or_create_collection(name=table.__name__,
                                                    metadata={"hnsw:space": "cosine"})

    @staticmethod
    def _uses_custom_embedder(embedder) -> bool:
        """Return False only for the ``ChromaEmbedder`` placeholder (matched by class name)."""
        return type(embedder).__name__ != 'ChromaEmbedder'

    @staticmethod
    def _row_metadata(row, exclude) -> dict:
        """
        Build the ChromaDB metadata dict from a row's instance attributes.

        :param row: The row object whose ``__dict__`` is used.
        :param exclude: Attribute names that must not end up in the metadata.
        :return: A new dict without excluded and underscore-prefixed attributes.
        """
        metadata = {}
        for key, value in row.__dict__.items():
            # skip explicitly excluded fields and internal (underscore) attributes
            if key in exclude or key.startswith("_"):
                continue
            # ChromaDB does not accept some kind of metadata, so change to str
            if type(value) not in (str, float, int, bool):
                value = str(value)
            metadata[key] = value
        return metadata

    def write_rows(self, rows: List[Type[BaseClass]], create_embeddings: bool = True, expunge=False, expire=True):
        """
        Write rows to the database and optionally create embeddings for the rows.

        Rows whose id already exists in the collection are skipped. Unless the rows use
        the ``ChromaEmbedder`` placeholder, embeddings are created with the rows'
        embedder before writing; otherwise ChromaDB embeds the documents itself.

        :param rows: A list of rows to be written to the database. Rows must be instances of BaseClass or its subclasses.
            An empty list is a no-op.
        :param create_embeddings: Currently not evaluated by this implementation.
        :param expunge: Currently not evaluated by this implementation.
        :param expire: Currently not evaluated by this implementation.
        :return: None
        """
        if not rows:
            return

        table = type(rows[0])
        collection = self._get_or_create_collection(table)

        # a set keeps the "already stored?" check O(1) per row
        existing_ids = set(self.get_existing_row_ids(table))
        new_rows = [row for row in rows if row.id not in existing_ids]
        if not new_rows:
            return

        embedder = rows[0].embedding_type.model
        custom_embedder = self._uses_custom_embedder(embedder)

        excluded = ['id', 'embedding_source']
        if custom_embedder:
            new_rows = embedder.create_embedding_bulk(new_rows)
            # the vector is passed separately and must not be stored as metadata
            excluded.append('embedding')

        for row in new_rows:
            add_kwargs = {
                "documents": str(row.embedding_source),
                "metadatas": self._row_metadata(row, excluded),
                "ids": row.id,
            }
            if custom_embedder:
                add_kwargs["embeddings"] = list(row.embedding)
            # without "embeddings", ChromaDB's own embedding is used
            collection.add(**add_kwargs)

    def update_rows(self, entries: List, update_metadatas: bool = False):
        """
        Update the metadata of existing entries.

        All non-internal attributes of each entry except ``id`` are written as metadata.
        NOTE(review): unlike ``write_rows`` this keeps ``embedding_source`` (and
        ``embedding``, if set) in the metadata - confirm that this is intended.

        :param entries: The entries to update; their ``id`` selects the stored rows.
            An empty list is a no-op.
        :param update_metadatas: Currently not evaluated; metadata is always updated.
        :return: None
        """
        if not entries:
            return

        table = type(entries[0])
        collection = self.engine.get_collection(name=table.__name__)

        ids = [entry.id for entry in entries]
        metadatas = [self._row_metadata(entry, ('id',)) for entry in entries]

        collection.update(ids=ids, metadatas=metadatas)

    def retrieve_similar_content(self, prompt, table: Type[BaseClass],
                                 embedding_type: EmbeddingConfig = None,
                                 limit: int = 50, max_dist: float = 100, where: Where = None) -> List:
        """
        Retrieve similar content based on a prompt. The function creates an embedding with the specified embedding type
        and queries the associated database for the most similar matches.

        :param prompt: The prompt for which similar content needs to be found.
        :param table: The table in which the content is stored.
        :param embedding_type: The type of embedding to be used. (default: None, stored in table class)
        :param limit: The maximum number of similar content to be retrieved (default: 50).
        :param max_dist: The maximum cosine distance between embedding vectors (default: 100)
        :param where: query metadata parameters, see chromadb docs for details (default: None, no filter)
        :return: A list of ``table`` instances carrying the stored metadata plus ``id``,
            ``embedding_source`` and ``cosine_dist``.
        """
        embedder = (embedding_type or table.embedding_type).model
        collection = self._get_or_create_collection(table)

        # build a fresh filter dict per call instead of sharing a mutable default
        query_kwargs = {"n_results": limit, "where": where or {}}

        if self._uses_custom_embedder(embedder):
            prompt_embedding = list(embedder.create_embedding(prompt))
            query = collection.query(query_embeddings=[prompt_embedding], **query_kwargs)
        else:
            # Use ChromaDB's own embedder
            query = collection.query(query_texts=[prompt], **query_kwargs)

        results = []
        if not query:
            return results

        # one prompt was sent, so only the first result list of each field is relevant
        hits = zip(query['ids'][0], query['documents'][0], query['metadatas'][0], query['distances'][0])
        for row_id, document, metadata, distance in hits:
            # stops at the first hit beyond max_dist; assumes hits are ordered by
            # ascending distance - TODO confirm against the chromadb docs
            if distance > max_dist:
                break
            entry = table()
            if metadata:
                entry.__dict__.update(metadata)
            entry.__dict__.update(id=row_id, embedding_source=document, cosine_dist=distance)
            results.append(entry)

        return results

    def retrieve_embedding(self, row_id: str, table: BaseClass) -> np.array:
        """
        Return the stored embedding vector of a single row.

        :param row_id: The id of the row.
        :param table: The table (collection) in which the row is stored.
        :return: The embedding as a numpy array.
        """
        collection = self.engine.get_collection(name=table.__name__)
        result = collection.get(ids=[row_id], include=['embeddings'])

        return np.array(result['embeddings'][0])

    def create_tables(self):
        """
        Not implemented for Chroma; collections are created by ``write_rows``.

        NOTE(review): this *returns* a ``NotImplementedError`` instance instead of
        raising it. The behaviour is kept as-is so callers that invoke
        ``create_tables()`` on any backend keep working - confirm whether a silent
        no-op or a raise is wanted.
        """
        return NotImplementedError("create_tables is not implemented")

    def get_existing_row_ids(self, table: BaseClass):
        """
        Return the ids of all rows stored in the collection of ``table``.

        :param table: The table (collection) to inspect.
        :return: The ``ids`` list of the collection.
        """
        collection = self.engine.get_collection(name=table.__name__)
        # include=[] asks for no additional fields besides the ids
        stored = collection.get(include=[])

        return stored['ids']
16 changes: 16 additions & 0 deletions src/brdata_rag_tools/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
"sentence_transformers": {
"dimension": 1024
},
"ChromaEmbedder": {
"dimension": 1024 # not sure... XXX
},
}

user_models = {}
Expand Down Expand Up @@ -49,6 +52,7 @@ class EmbeddingConfig(Enum):
"""
SENTENCE_TRANSFORMERS = "sentence_transformers"
TF_IDF = "tfidf"
CHROMAEMBEDDER = "ChromaEmbedder"

@property
def dimension(self):
Expand All @@ -69,6 +73,8 @@ def model(self):
return SentenceTransformer()
elif self == self.TF_IDF:
raise NotImplementedError()
elif self == self.CHROMAEMBEDDER:
return ChromaEmbedder()
else:
try:
return user_models[self.value]()
Expand Down Expand Up @@ -213,4 +219,14 @@ def create_embedding_bulk(self, rows: List[Type[BaseClass]]) -> List[

return rows

class ChromaEmbedder(Embedder):
    """
    Placeholder embedder standing for Chroma's own embedding.

    It computes nothing itself: both embedding methods raise ``NotImplementedError``.
    """

    def __init__(self, endpoint: str = None, auth_token: Optional[str] = None):
        # Both arguments are accepted but ignored; the attributes are always
        # set to empty strings.
        self.endpoint = ""
        self.auth_token = ""

    def create_embedding_bulk(self, rows: List[Type[BaseClass]]):
        """
        Not supported by this placeholder.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError(
            "ChromaEmbedder does not create embeddings itself; ChromaDB embeds the documents."
        )

    def create_embedding(self, text: str) -> np.array:
        """
        Not supported by this placeholder.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError(
            "ChromaEmbedder does not create embeddings itself; ChromaDB embeds the query text."
        )
37 changes: 36 additions & 1 deletion test/test_databases.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from src.brdata_rag_tools.databases import PGVector, FAISS
from src.brdata_rag_tools.databases import PGVector, FAISS, Chroma
from src.brdata_rag_tools.embeddings import EmbeddingConfig, Embedder, register

from sqlalchemy.orm import Mapped, mapped_column
Expand Down Expand Up @@ -126,3 +126,38 @@ class Podcast(abstract_table):
assert len(simcont) == 1
assert isinstance(simcont[0], dict)
assert simcont[0]["cosine_dist"] == 0

def test_chroma():
    """Write rows to a non-persistent Chroma database and query them with Chroma's own embedder."""
    database = Chroma()
    assert isinstance(database, Chroma)

    abstract_table = database.create_abstract_embedding_table(EmbeddingConfig.CHROMAEMBEDDER)
    # the abstract table has to declare all three base columns
    assert {"id", "embedding_source", "embedding"} <= set(abstract_table.__annotations__)

    class Podcast(abstract_table):
        __tablename__ = "testchroma"
        title: Mapped[str] = mapped_column(String)
        url: Mapped[str] = mapped_column(String)

    # three rows sharing the same embedding source ...
    podcasts = [
        Podcast(title="TRUE CRIME - Unter Verdacht",
                id=str(i),  # ChromaDb only accepts strings as ID
                url="example.com",
                embedding_source="test")
        for i in range(3)
    ]
    # ... and one row with a different embedding source
    podcasts.append(Podcast(title="TRUE CRIME - Unter Verdacht",
                            id="4",
                            url="example.com",
                            embedding_source="Different Vector"))

    database.write_rows(podcasts, create_embeddings=True)

    simcont = database.retrieve_similar_content(prompt="Hallo Test.", table=Podcast, max_dist=0.5)

    assert len(simcont) == 3
    assert isinstance(simcont[0], Podcast)
    assert simcont[0].cosine_dist < 0.5