Skip to content

Commit

Permalink
Feat: Add support for ChromaDB (run-llama#310)
Browse files Browse the repository at this point in the history
Co-authored-by: Aarav Navani <38411399+oofmeister27@users.noreply.github.com>
  • Loading branch information
thucpn and aaravnavani authored Jan 12, 2024
1 parent bb46afe commit 648482b
Show file tree
Hide file tree
Showing 10 changed files with 353 additions and 104 deletions.
5 changes: 5 additions & 0 deletions .changeset/purple-camels-walk.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"llamaindex": patch
---

Feat: Add support for Chroma DB as a vector store
2 changes: 1 addition & 1 deletion examples/astradb/load.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const collectionName = "movie_reviews";
async function main() {
try {
const reader = new PapaCSVReader(false);
const docs = await reader.loadData("astradb/data/movie_reviews.csv");
const docs = await reader.loadData("../data/movie_reviews.csv");

const astraVS = new AstraDBVectorStore();
await astraVS.create(collectionName, {
Expand Down
13 changes: 13 additions & 0 deletions examples/chromadb/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Chroma Vector Store Example

How to run `examples/chromadb/test.ts`:

Export your OpenAI API Key using `export OPEN_API_KEY=insert your api key here`

If you haven't installed chromadb, run `pip install chromadb`. Start the server using `chroma run`.

Now, open a new terminal window and inside `examples`, run `pnpx ts-node chromadb/test.ts`.

Here's the output for the input query `Tell me about Godfrey Cheshire's rating of La Sapienza.`:

`Godfrey Cheshire gave La Sapienza a rating of 4 out of 4, describing it as fresh and the most astonishing and important movie to emerge from France in quite some time.`
40 changes: 40 additions & 0 deletions examples/chromadb/test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import {
ChromaVectorStore,
PapaCSVReader,
storageContextFromDefaults,
VectorStoreIndex,
} from "llamaindex";

const collectionName = "movie_reviews";

async function main() {
const sourceFile: string = "./data/movie_reviews.csv";

try {
console.log(`Loading data from ${sourceFile}`);
const reader = new PapaCSVReader(false, ", ", "\n", {
header: true,
});
const docs = await reader.loadData(sourceFile);

console.log("Creating ChromaDB vector store");
const chromaVS = new ChromaVectorStore({ collectionName });
const ctx = await storageContextFromDefaults({ vectorStore: chromaVS });

console.log("Embedding documents and adding to index");
const index = await VectorStoreIndex.fromDocuments(docs, {
storageContext: ctx,
});

console.log("Querying index");
const queryEngine = index.asQueryEngine();
const response = await queryEngine.query(
"Tell me about Godfrey Cheshire's rating of La Sapienza.",
);
console.log(response.toString());
} catch (e) {
console.error(e);
}
}

main();

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions examples/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
"private": true,
"name": "examples",
"dependencies": {
"@notionhq/client": "^2.2.13",
"@datastax/astra-db-ts": "^0.1.2",
"@notionhq/client": "^2.2.13",
"@pinecone-database/pinecone": "^1.1.2",
"chromadb": "^1.7.3",
"commander": "^11.1.0",
"llamaindex": "latest",
"dotenv": "^16.3.1",
"llamaindex": "latest",
"mongodb": "^6.2.0"
},
"devDependencies": {
Expand Down
1 change: 1 addition & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"@pinecone-database/pinecone": "^1.1.2",
"@xenova/transformers": "^2.10.0",
"assemblyai": "^4.0.0",
"chromadb": "^1.7.3",
"file-type": "^18.7.0",
"js-tiktoken": "^1.0.8",
"lodash": "^4.17.21",
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/storage/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export * from "./indexStore/types";
export { SimpleKVStore } from "./kvStore/SimpleKVStore";
export * from "./kvStore/types";
export { AstraDBVectorStore } from "./vectorStore/AstraDBVectorStore";
export { ChromaVectorStore } from "./vectorStore/ChromaVectorStore";
export { MongoDBAtlasVectorSearch } from "./vectorStore/MongoDBAtlasVectorStore";
export { PGVectorStore } from "./vectorStore/PGVectorStore";
export { PineconeVectorStore } from "./vectorStore/PineconeVectorStore";
Expand Down
148 changes: 148 additions & 0 deletions packages/core/src/storage/vectorStore/ChromaVectorStore.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import {
AddParams,
ChromaClient,
ChromaClientParams,
Collection,
IncludeEnum,
QueryResponse,
Where,
WhereDocument,
} from "chromadb";
import { BaseNode, MetadataMode } from "../../Node";
import {
VectorStore,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
} from "./types";
import { metadataDictToNode, nodeToMetadata } from "./utils";

type ChromaDeleteOptions = {
where?: Where;
whereDocument?: WhereDocument;
};

type ChromaQueryOptions = {
whereDocument?: WhereDocument;
};

const DEFAULT_TEXT_KEY = "text";

export class ChromaVectorStore implements VectorStore {
storesText: boolean = true;
flatMetadata: boolean = true;
textKey: string;
private chromaClient: ChromaClient;
private collection: Collection | null = null;
private collectionName: string;

constructor(init: {
collectionName: string;
textKey?: string;
chromaClientParams?: ChromaClientParams;
}) {
this.collectionName = init.collectionName;
this.chromaClient = new ChromaClient(init.chromaClientParams);
this.textKey = init.textKey ?? DEFAULT_TEXT_KEY;
}

client(): ChromaClient {
return this.chromaClient;
}

async getCollection(): Promise<Collection> {
if (!this.collection) {
const coll = await this.chromaClient.createCollection({
name: this.collectionName,
});
this.collection = coll;
}
return this.collection;
}

private getDataToInsert(nodes: BaseNode[]): AddParams {
const metadatas = nodes.map((node) =>
nodeToMetadata(node, true, this.textKey, this.flatMetadata),
);
return {
embeddings: nodes.map((node) => node.getEmbedding()),
ids: nodes.map((node) => node.id_),
metadatas,
documents: nodes.map((node) => node.getContent(MetadataMode.NONE)),
};
}

async add(nodes: BaseNode[]): Promise<string[]> {
if (!nodes || nodes.length === 0) {
return [];
}

const dataToInsert = this.getDataToInsert(nodes);
const collection = await this.getCollection();
await collection.add(dataToInsert);
return nodes.map((node) => node.id_);
}

async delete(
refDocId: string,
deleteOptions?: ChromaDeleteOptions,
): Promise<void> {
const collection = await this.getCollection();
await collection.delete({
ids: [refDocId],
where: deleteOptions?.where,
whereDocument: deleteOptions?.whereDocument,
});
}

async query(
query: VectorStoreQuery,
options?: ChromaQueryOptions,
): Promise<VectorStoreQueryResult> {
if (query.docIds) {
throw new Error("ChromaDB does not support querying by docIDs");
}
if (query.mode != VectorStoreQueryMode.DEFAULT) {
throw new Error("ChromaDB does not support querying by mode");
}

const chromaWhere: { [x: string]: string | number | boolean } = {};
if (query.filters) {
query.filters.filters.map((filter) => {
const filterKey = filter.key;
const filterValue = filter.value;
chromaWhere[filterKey] = filterValue;
});
}

const collection = await this.getCollection();
const queryResponse: QueryResponse = await collection.query({
queryEmbeddings: query.queryEmbedding ?? undefined,
queryTexts: query.queryStr ?? undefined,
nResults: query.similarityTopK,
where: Object.keys(chromaWhere).length ? chromaWhere : undefined,
whereDocument: options?.whereDocument,
//ChromaDB doesn't return the result embeddings by default so we need to include them
include: [
IncludeEnum.Distances,
IncludeEnum.Metadatas,
IncludeEnum.Documents,
IncludeEnum.Embeddings,
],
});
const vectorStoreQueryResult: VectorStoreQueryResult = {
nodes: queryResponse.ids[0].map((id, index) => {
const text = (queryResponse.documents as string[][])[0][index];
const metaData = queryResponse.metadatas[0][index] ?? {};
const node = metadataDictToNode(metaData);
node.setContent(text);
return node;
}),
similarities: (queryResponse.distances as number[][])[0].map(
(distance) => 1 - distance,
),
ids: queryResponse.ids[0],
};
return vectorStoreQueryResult;
}
}
42 changes: 41 additions & 1 deletion pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 648482b

Please sign in to comment.