diff --git a/src/common/config.ts b/src/common/config.ts index c9505fd90..68c6ebc17 100644 --- a/src/common/config.ts +++ b/src/common/config.ts @@ -6,6 +6,7 @@ import { generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser"; import { Keychain } from "./keychain.js"; import type { Secret } from "./keychain.js"; import levenshtein from "ts-levenshtein"; +import type { Similarity } from "./search/vectorSearchEmbeddingsManager.js"; // From: https://github.com/mongodb-js/mongosh/blob/main/packages/cli-repl/src/arg-parser.ts const OPTIONS = { @@ -186,7 +187,7 @@ export interface UserConfig extends CliOptions { voyageApiKey: string; disableEmbeddingsValidation: boolean; vectorSearchDimensions: number; - vectorSearchSimilarityFunction: "cosine" | "euclidean" | "dotProduct"; + vectorSearchSimilarityFunction: Similarity; } export const defaultUserConfig: UserConfig = { diff --git a/src/common/search/vectorSearchEmbeddingsManager.ts b/src/common/search/vectorSearchEmbeddingsManager.ts index 65ab0cd77..b6c06e485 100644 --- a/src/common/search/vectorSearchEmbeddingsManager.ts +++ b/src/common/search/vectorSearchEmbeddingsManager.ts @@ -2,13 +2,29 @@ import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-d import { BSON, type Document } from "bson"; import type { UserConfig } from "../config.js"; import type { ConnectionManager } from "../connectionManager.js"; +import z from "zod"; + +export const similarityEnum = z.enum(["cosine", "euclidean", "dotProduct"]); +export type Similarity = z.infer; + +export const quantizationEnum = z.enum(["none", "scalar", "binary"]); +export type Quantization = z.infer; export type VectorFieldIndexDefinition = { type: "vector"; path: string; numDimensions: number; - quantization: "none" | "scalar" | "binary"; - similarity: "euclidean" | "cosine" | "dotProduct"; + quantization: Quantization; + similarity: Similarity; +}; + +export type VectorFieldValidationError = { + path: string; + expectedNumDimensions: number; + expectedQuantization: Quantization; + actualNumDimensions: number | "unknown"; + actualQuantization: Quantization | "unknown"; + error: "dimension-mismatch" | "quantization-mismatch" | "not-a-vector" | "not-numeric"; }; export type EmbeddingNamespace = `${string}.${string}`; @@ -54,7 +70,7 @@ export class VectorSearchEmbeddingsManager { const vectorSearchIndexes = allSearchIndexes.filter((index) => index.type === "vectorSearch"); const vectorFields = vectorSearchIndexes // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access - .flatMap((index) => (index.latestDefinition?.fields as Document) ?? []) + .flatMap((index) => (index.latestDefinition?.fields as Document[]) ?? []) .filter((field) => this.isVectorFieldIndexDefinition(field)); this.embeddings.set(embeddingDefKey, vectorFields); @@ -73,7 +89,7 @@ export class VectorSearchEmbeddingsManager { collection: string; }, document: Document - ): Promise { + ): Promise { const provider = await this.assertAtlasSearchIsAvailable(); if (!provider) { return []; @@ -87,15 +103,15 @@ export class VectorSearchEmbeddingsManager { } const embeddings = await this.embeddingsForNamespace({ database, collection }); - return embeddings.filter((emb) => !this.documentPassesEmbeddingValidation(emb, document)); + return embeddings + .map((emb) => this.getValidationErrorForDocument(emb, document)) + .filter((e) => e !== undefined); } private async assertAtlasSearchIsAvailable(): Promise { const connectionState = this.connectionManager.currentConnectionState; - if (connectionState.tag === "connected") { - if (await connectionState.isSearchSupported()) { - return connectionState.serviceProvider; - } + if (connectionState.tag === "connected" && (await connectionState.isSearchSupported())) { + return connectionState.serviceProvider; } return null; @@ -105,15 +121,29 @@ export class VectorSearchEmbeddingsManager { return doc["type"] === "vector"; } - private documentPassesEmbeddingValidation(definition: VectorFieldIndexDefinition, document: Document): boolean { + private getValidationErrorForDocument( + definition: VectorFieldIndexDefinition, + document: Document + ): VectorFieldValidationError | undefined { const fieldPath = definition.path.split("."); let fieldRef: unknown = document; + const constructError = ( + details: Partial> + ): VectorFieldValidationError => ({ + path: definition.path, + expectedNumDimensions: definition.numDimensions, + expectedQuantization: definition.quantization, + actualNumDimensions: details.actualNumDimensions ?? "unknown", + actualQuantization: details.actualQuantization ?? "unknown", + error: details.error ?? "not-a-vector", + }); + for (const field of fieldPath) { if (fieldRef && typeof fieldRef === "object" && field in fieldRef) { fieldRef = (fieldRef as Record)[field]; } else { - return true; + return undefined; } } @@ -121,40 +151,69 @@ export class VectorSearchEmbeddingsManager { // Because quantization is not defined by the user // we have to trust them in the format they use. case "none": - return true; + return undefined; case "scalar": case "binary": if (fieldRef instanceof BSON.Binary) { try { const elements = fieldRef.toFloat32Array(); - return elements.length === definition.numDimensions; + if (elements.length !== definition.numDimensions) { + return constructError({ + actualNumDimensions: elements.length, + actualQuantization: "binary", + error: "dimension-mismatch", + }); + } + + return undefined; } catch { // bits are also supported try { const bits = fieldRef.toBits(); - return bits.length === definition.numDimensions; + if (bits.length !== definition.numDimensions) { + return constructError({ + actualNumDimensions: bits.length, + actualQuantization: "binary", + error: "dimension-mismatch", + }); + } + + return undefined; } catch { - return false; + return constructError({ + actualQuantization: "binary", + error: "not-a-vector", + }); } } } else { if (!Array.isArray(fieldRef)) { - return false; + return constructError({ + error: "not-a-vector", + }); } if (fieldRef.length !== definition.numDimensions) { - return false; + return constructError({ + actualNumDimensions: fieldRef.length, + actualQuantization: "scalar", + error: "dimension-mismatch", + }); } if (!fieldRef.every((e) => this.isANumber(e))) { - return false; + return constructError({ + actualNumDimensions: fieldRef.length, + actualQuantization: "scalar", + error: "not-numeric", + }); } } break; } - return true; + return undefined; } private isANumber(value: unknown): boolean { diff --git a/src/common/session.ts b/src/common/session.ts index b53e3bec9..958c28355 100644 --- a/src/common/session.ts +++ b/src/common/session.ts @@ -156,15 +156,13 @@ export class Session extends EventEmitter { } async assertSearchSupported(): Promise { - const availability = await this.isSearchSupported(); - if (!availability) { + const isSearchSupported = await this.isSearchSupported(); + if (!isSearchSupported) { throw new MongoDBError( ErrorCodes.AtlasSearchNotSupported, "Atlas Search is not supported in the current cluster." ); } - - return; } get serviceProvider(): NodeDriverServiceProvider { diff --git a/src/tools/mongodb/create/createIndex.ts b/src/tools/mongodb/create/createIndex.ts index 9a8997aa1..e535f4fe3 100644 --- a/src/tools/mongodb/create/createIndex.ts +++ b/src/tools/mongodb/create/createIndex.ts @@ -3,6 +3,7 @@ import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { type ToolArgs, type OperationType, FeatureFlags } from "../../tool.js"; import type { IndexDirection } from "mongodb"; +import { quantizationEnum, similarityEnum } from "../../../common/search/vectorSearchEmbeddingsManager.js"; export class CreateIndexTool extends MongoDBToolBase { private vectorSearchIndexDefinition = z.object({ @@ -37,15 +38,12 @@ export class CreateIndexTool extends MongoDBToolBase { .describe( "Number of vector dimensions that MongoDB Vector Search enforces at index-time and query-time" ), - similarity: z - .enum(["cosine", "euclidean", "dotProduct"]) + similarity: similarityEnum .default(this.config.vectorSearchSimilarityFunction) .describe( "Vector similarity function to use to search for top K-nearest neighbors. You can set this field only for vector-type fields." ), - quantization: z - .enum(["none", "scalar", "binary"]) - .optional() + quantization: quantizationEnum .default("none") .describe( "Type of automatic vector quantization for your vectors. Use this setting only if your embeddings are float or double vectors." @@ -125,6 +123,7 @@ export class CreateIndexTool extends MongoDBToolBase { responseClarification = " Since this is a vector search index, it may take a while for the index to build. Use the `list-indexes` tool to check the index status."; + // clean up the embeddings cache so it considers the new index this.session.vectorSearchEmbeddingsManager.cleanupEmbeddingsForNamespace({ database, collection }); } diff --git a/src/tools/mongodb/create/insertMany.ts b/src/tools/mongodb/create/insertMany.ts index fbf1556a7..fa3fc3651 100644 --- a/src/tools/mongodb/create/insertMany.ts +++ b/src/tools/mongodb/create/insertMany.ts @@ -39,7 +39,9 @@ export class InsertManyTool extends MongoDBToolBase { // tell the LLM what happened const embeddingValidationMessages = [...embeddingValidations].map( (validation) => - `- Field ${validation.path} is an embedding with ${validation.numDimensions} dimensions and ${validation.quantization} quantization, and the provided value is not compatible.` + `- Field ${validation.path} is an embedding with ${validation.expectedNumDimensions} dimensions and ${validation.expectedQuantization}` + + ` quantization, and the provided value is not compatible. Actual dimensions: ${validation.actualNumDimensions}, ` + + `actual quantization: ${validation.actualQuantization}. Error: ${validation.error}` ); return { diff --git a/src/tools/mongodb/mongodbTool.ts b/src/tools/mongodb/mongodbTool.ts index dc1345082..ce4ce6042 100644 --- a/src/tools/mongodb/mongodbTool.ts +++ b/src/tools/mongodb/mongodbTool.ts @@ -46,8 +46,8 @@ export abstract class MongoDBToolBase extends ToolBase { return this.session.serviceProvider; } - protected async ensureSearchIsSupported(): Promise { - return await this.session.assertSearchSupported(); + protected ensureSearchIsSupported(): Promise { + return this.session.assertSearchSupported(); } public register(server: Server): boolean { diff --git a/tests/integration/tools/mongodb/create/insertMany.test.ts b/tests/integration/tools/mongodb/create/insertMany.test.ts index d426a791f..54baa8869 100644 --- a/tests/integration/tools/mongodb/create/insertMany.test.ts +++ b/tests/integration/tools/mongodb/create/insertMany.test.ts @@ -186,7 +186,7 @@ describeWithMongoDB( expect(content).toContain("There were errors when inserting documents. No document was inserted."); const untrustedContent = getDataFromUntrustedContent(content); expect(untrustedContent).toContain( - "- Field embedding is an embedding with 8 dimensions and scalar quantization, and the provided value is not compatible." + "- Field embedding is an embedding with 8 dimensions and scalar quantization, and the provided value is not compatible. Actual dimensions: unknown, actual quantization: unknown. Error: not-a-vector" ); const oopsieCount = await provider.countDocuments(integration.randomDbName(), "test", { diff --git a/tests/unit/common/search/vectorSearchEmbeddingsManager.test.ts b/tests/unit/common/search/vectorSearchEmbeddingsManager.test.ts index e9becac04..ad6949668 100644 --- a/tests/unit/common/search/vectorSearchEmbeddingsManager.test.ts +++ b/tests/unit/common/search/vectorSearchEmbeddingsManager.test.ts @@ -4,6 +4,7 @@ import { VectorSearchEmbeddingsManager } from "../../../../src/common/search/vec import type { EmbeddingNamespace, VectorFieldIndexDefinition, + VectorFieldValidationError, } from "../../../../src/common/search/vectorSearchEmbeddingsManager.js"; import { BSON } from "bson"; import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; @@ -275,6 +276,15 @@ describe("VectorSearchEmbeddingsManager", () => { ); expect(result).toHaveLength(1); + const expectedError: VectorFieldValidationError = { + actualNumDimensions: 3, + actualQuantization: "scalar", + error: "dimension-mismatch", + expectedNumDimensions: 8, + expectedQuantization: "scalar", + path: "embedding_field", + }; + expect(result[0]).toEqual(expectedError); }); it("documents inserting the field with correct dimensions, but wrong type are invalid", async () => { @@ -284,6 +294,16 @@ describe("VectorSearchEmbeddingsManager", () => { ); expect(result).toHaveLength(1); + const expectedError: VectorFieldValidationError = { + actualNumDimensions: 8, + actualQuantization: "scalar", + error: "not-numeric", + expectedNumDimensions: 8, + expectedQuantization: "scalar", + path: "embedding_field", + }; + + expect(result[0]).toEqual(expectedError); }); it("documents inserting the field with correct dimensions and quantization in binary are valid", async () => {