This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Use ChromaDB for our embeddings database and similarity search #460

Merged: 11 commits, Jan 12, 2024
222 changes: 112 additions & 110 deletions mentat/embeddings.py
@@ -1,77 +1,107 @@
import json
import logging
import os
import sqlite3
from pathlib import Path
from timeit import default_timer

import numpy as np
import chromadb
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction

from mentat.code_feature import CodeFeature, count_feature_tokens
from mentat.errors import MentatError
from mentat.llm_api_handler import (
count_tokens,
model_context_size,
model_price_per_1000_tokens,
)
from mentat.llm_api_handler import model_price_per_1000_tokens
from mentat.session_context import SESSION_CONTEXT
from mentat.session_input import ask_yes_no
from mentat.utils import mentat_dir_path, sha256
from mentat.utils import mentat_dir_path

EMBEDDINGS_API_BATCH_SIZE = 2048
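# Chroma persists all collections on disk under the mentat data dir (e.g. ~/.mentat/chroma).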
client = chromadb.PersistentClient(path=str(mentat_dir_path / "chroma"))


class EmbeddingsDatabase:
def __init__(self, output_dir: Path | None = None):
self.output_dir = output_dir or mentat_dir_path
os.makedirs(self.output_dir, exist_ok=True)
self.path = Path(self.output_dir) / "embeddings.sqlite3"
self._connect()
class Collection:
_collection = None

def _connect(self):
self.conn = sqlite3.connect(self.path)
with self.conn as db:
db.execute(
"CREATE TABLE IF NOT EXISTS embeddings "
"(checksum TEXT PRIMARY KEY, vector BLOB)"
)

def set(self, items: dict[str, list[float]]):
with self.conn as db:
db.executemany(
"INSERT OR REPLACE INTO embeddings (checksum, vector) VALUES (?, ?)",
[
(key, sqlite3.Binary(json.dumps(value).encode("utf-8")))
for key, value in items.items()
],
)

def get(self, keys: list[str]) -> dict[str, list[float]]:
with self.conn as db:
cursor = db.execute(
"SELECT checksum, vector FROM embeddings WHERE checksum IN"
f" ({','.join(['?']*len(keys))})",
keys,
)
return {row[0]: json.loads(row[1]) for row in cursor.fetchall()}

def exists(self, key: str) -> bool:
with self.conn as db:
cursor = db.execute("SELECT 1 FROM embeddings WHERE checksum=?", (key,))
return cursor.fetchone() is not None

def __del__(self):
self.conn.close()


database = EmbeddingsDatabase()
def __init__(self, embedding_model: str):
api_key = os.getenv("OPENAI_API_KEY")
# src: https://cookbook.openai.com/examples/vector_databases/chroma/using_chroma_for_embeddings_search
embedding_function = OpenAIEmbeddingFunction(
api_key=api_key, model_name=embedding_model
)
self._collection = client.get_or_create_collection(
name=f"mentat-{embedding_model}",
embedding_function=embedding_function, # type: ignore
)
self.migrate_old_db()

def exists(self, id: str) -> bool:
assert self._collection is not None, "Collection not initialized"
return len(self._collection.get(id)["ids"]) > 0

def add(self, checksums: list[str], texts: list[str]) -> None:
assert self._collection is not None, "Collection not initialized"
return self._collection.add( # type: ignore
ids=checksums,
documents=texts,
metadatas=[{"active": False} for _ in checksums],
)

def query(self, prompt: str, checksums: list[str]) -> dict[str, float]:
assert self._collection is not None, "Collection not initialized"

def _cosine_similarity(v1: list[float], v2: list[float]) -> float:
"""Calculate the cosine similarity between two vectors."""
dot_product = np.dot(v1, v2)
norm_v1 = np.linalg.norm(v1)
norm_v2 = np.linalg.norm(v2)
return dot_product / (norm_v1 * norm_v2) # pyright: ignore
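# Temporarily mark the candidate ids "active" so the query below is restricted
# to them via the where filter, then reset the flag afterwards.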
self._collection.update( # type: ignore
ids=checksums,
metadatas=[{"active": True} for _ in checksums],
)
results = self._collection.query( # type: ignore
query_texts=[prompt],
where={"active": True},
n_results=len(checksums) + 1,
)
self._collection.update( # type: ignore
ids=checksums,
metadatas=[{"active": False} for _ in checksums],
)
assert results["distances"], "Error calculating distances"
return {c: e for c, e in zip(results["ids"][0], results["distances"][0])}
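
For orientation, a minimal sketch of how this wrapper is driven (the checksums and texts are hypothetical; it assumes chromadb is installed and OPENAI_API_KEY is set):

collection = Collection("text-embedding-ada-002")
checksums = ["abc123", "def456"]              # hypothetical content checksums
texts = ["def foo(): ...", "def bar(): ..."]  # hypothetical code snippets
new_ids = [c for c in checksums if not collection.exists(c)]
if new_ids:
    collection.add(new_ids, [t for c, t in zip(checksums, texts) if c in new_ids])
scores = collection.query("a function named foo", checksums)
# Chroma's default metric is L2 distance, so lower scores mean more similar.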

def migrate_old_db(self):
"""Temporary helper function to migrate sqlite3 to chromadb

Prior to January 2024, embeddings were fetched directly from the OpenAI API in
batches and saved to a db. We're currently using the same embeddings (ada-2) with
ChromaDB, so we might as well save the effort of re-fetching them. One drawback
is that ChromaDB saves the actual text, while our old schema did not, so migrated
records will have an empty documents field. This shouldn't be a problem. If it is,
we can just update the 'exists' method to require a non-empty "document" field.

TODO: erase this method/call after a few months
"""

[Review thread, resolved]
Member: This is a nice optimization, but I'm not sure the extra complexity is worth saving users a couple of cents to rebuild. It might be unlikely, but some bug from migrated records being slightly different could be a big headache and wouldn't be worth the risk.
Member Author: Ok, makes sense. I was just looking at the size of my sqlite file, but (1) mine is probably bigger than anyone's and (2) it is indeed really fast and cheap to reload.
path = mentat_dir_path / "embeddings.sqlite3"
if not path.exists():
return
import json
import sqlite3

try:
conn = sqlite3.connect(path)
cursor = conn.execute("SELECT checksum, vector FROM embeddings")
results = {row[0]: json.loads(row[1]) for row in cursor.fetchall()}
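# Keep only embeddings not already in Chroma; 1536 is the vector size of text-embedding-ada-002.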
results = {
k: v
for k, v in results.items()
if not self.exists(k) and len(v) == 1536
}
if results:
ids = list(results.keys())
embeddings = list(results.values())
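# Add in batches of 1,000 records, a conservative cap under Chroma's per-call batch limit.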
batches = len(ids) // 1000 + 1
for i in range(batches):
_ids = ids[i * 1000 : (i + 1) * 1000]
_embeddings = embeddings[i * 1000 : (i + 1) * 1000]
self._collection.add( # type: ignore
ids=_ids,
embeddings=_embeddings,
metadatas=[{"active": False} for _ in _ids],
)
path.unlink()
except Exception as e:
logging.debug(f"Error migrating old embeddings database: {e}")
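
As the migration docstring notes, migrated records carry an empty documents field. If that ever matters, a stricter existence check might look like the following sketch (exists_with_document is a hypothetical name, not part of this diff):

def exists_with_document(self, id: str) -> bool:
    # Hypothetical variant of exists(): require a non-empty stored document,
    # so records migrated from sqlite without their text count as missing.
    assert self._collection is not None, "Collection not initialized"
    result = self._collection.get(id, include=["documents"])
    docs = result.get("documents") or []
    return len(result["ids"]) > 0 and bool(docs and docs[0])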


async def get_feature_similarity_scores(
@@ -80,38 +110,22 @@ async def get_feature_similarity_scores(
loading_multiplier: float = 0.0,
) -> list[float]:
"""Return the similarity scores for a given prompt and list of features."""
global database
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
cost_tracker = session_context.cost_tracker
embedding_model = session_context.config.embedding_model
llm_api_handler = session_context.llm_api_handler

max_model_tokens = model_context_size(embedding_model)
if max_model_tokens is None:
raise MentatError(f"Missing model context size for {embedding_model}.")
prompt_tokens = count_tokens(prompt, embedding_model, False)
if prompt_tokens > max_model_tokens:
stream.send(
f"Warning: Prompt contains {prompt_tokens} tokens, but the model"
f" can only handle {max_model_tokens} tokens. Ignoring embeddings."
)
return [0.0 for _ in features]
# Initialize DB
collection = Collection(embedding_model)

prompt_checksum = sha256(prompt)
# Identify which items need embeddings.
checksums: list[str] = [f.get_checksum() for f in features]
tokens: list[int] = await count_feature_tokens(features, embedding_model)
embed_texts = list[str]()
embed_checksums = list[str]()
embed_tokens = list[int]()
if not database.exists(prompt_checksum):
embed_texts.append(prompt)
embed_checksums.append(prompt_checksum)
embed_tokens.append(prompt_tokens)
for feature, checksum, token in zip(features, checksums, tokens):
if token > max_model_tokens:
continue
if not database.exists(checksum):
if not collection.exists(checksum) and checksum not in embed_checksums:
embed_texts.append("\n".join(feature.get_code_message()))
embed_checksums.append(checksum)
embed_tokens.append(token)
@@ -134,42 +148,30 @@ async def get_feature_similarity_scores(
stream.send("Ignoring embeddings for now.")
return [0.0 for _ in checksums]
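# Rough scale (assuming ada-002's rate of $0.0001 per 1K tokens): even 500K tokens re-embed for about $0.05.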

# Fetch embeddings in batches
if len(embed_texts) == 0:
n_batches = 0
else:
n_batches = len(embed_texts) // EMBEDDINGS_API_BATCH_SIZE + 1
for batch in range(n_batches):
# Load embeddings
if embed_texts:
start_time = default_timer()
if loading_multiplier:
stream.send(
f"Fetching embeddings, batch {batch+1}/{n_batches}",
f"Fetching embeddings for {len(embed_texts)} documents",
channel="loading",
progress=(100 / n_batches) * loading_multiplier,
progress=50 * loading_multiplier,
)
start_time = default_timer()
i_start, i_end = (
batch * EMBEDDINGS_API_BATCH_SIZE,
(batch + 1) * EMBEDDINGS_API_BATCH_SIZE,
)
_texts = embed_texts[i_start:i_end]
_checksums = embed_checksums[i_start:i_end]
_tokens = embed_tokens[i_start:i_end]

response = await llm_api_handler.call_embedding_api(_texts, embedding_model)
collection.add(embed_checksums, embed_texts)
cost_tracker.log_api_call_stats(
sum(_tokens),
sum(embed_tokens),
0,
embedding_model,
default_timer() - start_time,
)
database.set({k: v for k, v in zip(_checksums, response)})

# Calculate similarity score for each feature
prompt_embedding = database.get([prompt_checksum])[prompt_checksum]
embeddings = database.get(checksums)
scores = [
_cosine_similarity(prompt_embedding, embeddings[k]) if k in embeddings else 0.0
for k in checksums
]

return scores
# Get similarity scores
if loading_multiplier:
stream.send(
"Matching relevant documents based on embedding similarity",
channel="loading",
progress=(50 if embed_texts else 100) * loading_multiplier,
)
_checksums = list(set(checksums))
scores = collection.query(prompt, _checksums)
return [scores[f.get_checksum()] for f in features]
2 changes: 1 addition & 1 deletion mentat/feature_filters/embedding_similarity_filter.py
@@ -19,7 +19,7 @@ async def score(
self.query, features, self.loading_multiplier
)
features_scored = zip(features, sim_scores)
return sorted(features_scored, key=lambda x: x[1], reverse=True)
return sorted(features_scored, key=lambda x: x[1])
Member Author: We were using cosine_similarity, which is larger-is-better, but Chroma's default similarity function is 'l2', which is lower-is-better.
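
To make the flipped sort concrete: OpenAI embeddings are unit-normalized, and for unit vectors squared L2 distance and cosine similarity are linked by ||a - b||^2 = 2 - 2*cos(a, b), so ascending L2 order is exactly descending cosine order. A quick sanity check:

import numpy as np

a = np.array([1.0, 0.0])             # unit-length stand-ins for embeddings
b = np.array([0.6, 0.8])
cos = float(a @ b)                   # cosine similarity = 0.6
l2_sq = float(np.sum((a - b) ** 2))  # squared L2 distance = 0.8
assert abs(l2_sq - (2 - 2 * cos)) < 1e-9
# Higher cosine similarity <=> lower L2 distance, hence dropping reverse=True.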


async def filter(
self,
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,5 +1,6 @@
attrs==23.1.0
backoff==2.2.1
chromadb==0.4.22
fire==0.5.0
gitpython==3.1.37
jinja2==3.1.2
11 changes: 0 additions & 11 deletions tests/conftest.py
@@ -220,17 +220,6 @@ def set_unstreamed_values(value):
return completion_mock


@pytest.fixture(scope="function")
def mock_call_embedding_api(mocker):
embedding_mock = mocker.patch.object(LlmApiHandler, "call_embedding_api")

def set_embedding_values(value):
embedding_mock.return_value = value

embedding_mock.set_embedding_values = set_embedding_values
return embedding_mock


### Auto-used fixtures


32 changes: 0 additions & 32 deletions tests/embeddings_test.py

This file was deleted.

19 changes: 19 additions & 0 deletions tests/license_check.py
@@ -9,6 +9,12 @@
"tiktoken",
# openai is Apache 2.0; for some reason, after updating to 1.0, pip-licenses thinks it's UNKNOWN
"openai",
# UNKNOWN
"chroma-hnswlib",
"pyinstaller-hooks-contrib",
# "Other/Proprietary License"
"pinecone-client", # part of chromadb - we don't use directly, but potentially problematic
"pygls",
]
accepted_licenses = [
"BSD License",
@@ -26,6 +32,19 @@
"GNU General Public License (GPL)",
"Public Domain",
"The Unlicense (Unlicense)",
"Apache License, Version 2.0",
"Apache License v2.0",
"Apache License 2.0",
"GNU Lesser General Public License v2 or later (LGPLv2+)",
"ISC",
"LGPL-2.1-only",
"GNU General Public License v3 (GPLv3)",
"Apache-2.0",
"GNU Library or Lesser General Public License (LGPL)",
"GNU General Public License v2 (GPLv2)",
"Apache Software License v2",
"Artistic License",
"GNU General Public License v2 or later (GPLv2+)",
]

