Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/common/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser";
import { Keychain } from "./keychain.js";
import type { Secret } from "./keychain.js";
import levenshtein from "ts-levenshtein";
import type { Similarity } from "./search/vectorSearchEmbeddingsManager.js";

// From: https://github.com/mongodb-js/mongosh/blob/main/packages/cli-repl/src/arg-parser.ts
const OPTIONS = {
Expand Down Expand Up @@ -186,7 +187,7 @@ export interface UserConfig extends CliOptions {
voyageApiKey: string;
disableEmbeddingsValidation: boolean;
vectorSearchDimensions: number;
vectorSearchSimilarityFunction: "cosine" | "euclidean" | "dotProduct";
vectorSearchSimilarityFunction: Similarity;
}

export const defaultUserConfig: UserConfig = {
Expand Down
97 changes: 78 additions & 19 deletions src/common/search/vectorSearchEmbeddingsManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,29 @@ import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-d
import { BSON, type Document } from "bson";
import type { UserConfig } from "../config.js";
import type { ConnectionManager } from "../connectionManager.js";
import z from "zod";

export const similarityEnum = z.enum(["cosine", "euclidean", "dotProduct"]);
export type Similarity = z.infer<typeof similarityEnum>;

export const quantizationEnum = z.enum(["none", "scalar", "binary"]);
export type Quantization = z.infer<typeof quantizationEnum>;

export type VectorFieldIndexDefinition = {
type: "vector";
path: string;
numDimensions: number;
quantization: "none" | "scalar" | "binary";
similarity: "euclidean" | "cosine" | "dotProduct";
quantization: Quantization;
similarity: Similarity;
};

export type VectorFieldValidationError = {
path: string;
expectedNumDimensions: number;
expectedQuantization: Quantization;
actualNumDimensions: number | "unknown";
actualQuantization: Quantization | "unknown";
error: "dimension-mismatch" | "quantization-mismatch" | "not-a-vector" | "not-numeric";
};

export type EmbeddingNamespace = `${string}.${string}`;
Expand Down Expand Up @@ -54,7 +70,7 @@ export class VectorSearchEmbeddingsManager {
const vectorSearchIndexes = allSearchIndexes.filter((index) => index.type === "vectorSearch");
const vectorFields = vectorSearchIndexes
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
.flatMap<Document>((index) => (index.latestDefinition?.fields as Document) ?? [])
.flatMap<Document>((index) => (index.latestDefinition?.fields as Document[]) ?? [])
.filter((field) => this.isVectorFieldIndexDefinition(field));

this.embeddings.set(embeddingDefKey, vectorFields);
Expand All @@ -73,7 +89,7 @@ export class VectorSearchEmbeddingsManager {
collection: string;
},
document: Document
): Promise<VectorFieldIndexDefinition[]> {
): Promise<VectorFieldValidationError[]> {
const provider = await this.assertAtlasSearchIsAvailable();
if (!provider) {
return [];
Expand All @@ -87,15 +103,15 @@ export class VectorSearchEmbeddingsManager {
}

const embeddings = await this.embeddingsForNamespace({ database, collection });
return embeddings.filter((emb) => !this.documentPassesEmbeddingValidation(emb, document));
return embeddings
.map((emb) => this.getValidationErrorForDocument(emb, document))
.filter((e) => e !== undefined);
}

private async assertAtlasSearchIsAvailable(): Promise<NodeDriverServiceProvider | null> {
const connectionState = this.connectionManager.currentConnectionState;
if (connectionState.tag === "connected") {
if (await connectionState.isSearchSupported()) {
return connectionState.serviceProvider;
}
if (connectionState.tag === "connected" && (await connectionState.isSearchSupported())) {
return connectionState.serviceProvider;
}

return null;
Expand All @@ -105,56 +121,99 @@ export class VectorSearchEmbeddingsManager {
return doc["type"] === "vector";
}

private documentPassesEmbeddingValidation(definition: VectorFieldIndexDefinition, document: Document): boolean {
private getValidationErrorForDocument(
definition: VectorFieldIndexDefinition,
document: Document
): VectorFieldValidationError | undefined {
const fieldPath = definition.path.split(".");
let fieldRef: unknown = document;

const constructError = (
details: Partial<Pick<VectorFieldValidationError, "error" | "actualNumDimensions" | "actualQuantization">>
): VectorFieldValidationError => ({
path: definition.path,
expectedNumDimensions: definition.numDimensions,
expectedQuantization: definition.quantization,
actualNumDimensions: details.actualNumDimensions ?? "unknown",
actualQuantization: details.actualQuantization ?? "unknown",
error: details.error ?? "not-a-vector",
});

for (const field of fieldPath) {
if (fieldRef && typeof fieldRef === "object" && field in fieldRef) {
fieldRef = (fieldRef as Record<string, unknown>)[field];
} else {
return true;
return undefined;
}
}

switch (definition.quantization) {
// Because quantization is not defined by the user
// we have to trust them in the format they use.
case "none":
return true;
return undefined;
case "scalar":
case "binary":
if (fieldRef instanceof BSON.Binary) {
try {
const elements = fieldRef.toFloat32Array();
return elements.length === definition.numDimensions;
if (elements.length !== definition.numDimensions) {
return constructError({
actualNumDimensions: elements.length,
actualQuantization: "binary",
error: "dimension-mismatch",
});
}

return undefined;
} catch {
// bits are also supported
try {
const bits = fieldRef.toBits();
return bits.length === definition.numDimensions;
if (bits.length !== definition.numDimensions) {
return constructError({
actualNumDimensions: bits.length,
actualQuantization: "binary",
error: "dimension-mismatch",
});
}

return undefined;
} catch {
return false;
return constructError({
actualQuantization: "binary",
error: "not-a-vector",
});
}
}
} else {
if (!Array.isArray(fieldRef)) {
return false;
return constructError({
error: "not-a-vector",
});
}

if (fieldRef.length !== definition.numDimensions) {
return false;
return constructError({
actualNumDimensions: fieldRef.length,
actualQuantization: "scalar",
error: "dimension-mismatch",
});
}

if (!fieldRef.every((e) => this.isANumber(e))) {
return false;
return constructError({
actualNumDimensions: fieldRef.length,
actualQuantization: "scalar",
error: "not-numeric",
});
}
}

break;
}

return true;
return undefined;
}

private isANumber(value: unknown): boolean {
Expand Down
6 changes: 2 additions & 4 deletions src/common/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,13 @@ export class Session extends EventEmitter<SessionEvents> {
}

async assertSearchSupported(): Promise<void> {
const availability = await this.isSearchSupported();
if (!availability) {
const isSearchSupported = await this.isSearchSupported();
if (!isSearchSupported) {
throw new MongoDBError(
ErrorCodes.AtlasSearchNotSupported,
"Atlas Search is not supported in the current cluster."
);
}

return;
}

get serviceProvider(): NodeDriverServiceProvider {
Expand Down
9 changes: 4 additions & 5 deletions src/tools/mongodb/create/createIndex.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import { type ToolArgs, type OperationType, FeatureFlags } from "../../tool.js";
import type { IndexDirection } from "mongodb";
import { quantizationEnum, similarityEnum } from "../../../common/search/vectorSearchEmbeddingsManager.js";

export class CreateIndexTool extends MongoDBToolBase {
private vectorSearchIndexDefinition = z.object({
Expand Down Expand Up @@ -37,15 +38,12 @@ export class CreateIndexTool extends MongoDBToolBase {
.describe(
"Number of vector dimensions that MongoDB Vector Search enforces at index-time and query-time"
),
similarity: z
.enum(["cosine", "euclidean", "dotProduct"])
similarity: similarityEnum
.default(this.config.vectorSearchSimilarityFunction)
.describe(
"Vector similarity function to use to search for top K-nearest neighbors. You can set this field only for vector-type fields."
),
quantization: z
.enum(["none", "scalar", "binary"])
.optional()
quantization: quantizationEnum
.default("none")
.describe(
"Type of automatic vector quantization for your vectors. Use this setting only if your embeddings are float or double vectors."
Expand Down Expand Up @@ -125,6 +123,7 @@ export class CreateIndexTool extends MongoDBToolBase {

responseClarification =
" Since this is a vector search index, it may take a while for the index to build. Use the `list-indexes` tool to check the index status.";

// clean up the embeddings cache so it considers the new index
this.session.vectorSearchEmbeddingsManager.cleanupEmbeddingsForNamespace({ database, collection });
}
Expand Down
4 changes: 3 additions & 1 deletion src/tools/mongodb/create/insertMany.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ export class InsertManyTool extends MongoDBToolBase {
// tell the LLM what happened
const embeddingValidationMessages = [...embeddingValidations].map(
(validation) =>
`- Field ${validation.path} is an embedding with ${validation.numDimensions} dimensions and ${validation.quantization} quantization, and the provided value is not compatible.`
`- Field ${validation.path} is an embedding with ${validation.expectedNumDimensions} dimensions and ${validation.expectedQuantization}` +
` quantization, and the provided value is not compatible. Actual dimensions: ${validation.actualNumDimensions}, ` +
`actual quantization: ${validation.actualQuantization}. Error: ${validation.error}`
);

return {
Expand Down
4 changes: 2 additions & 2 deletions src/tools/mongodb/mongodbTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ export abstract class MongoDBToolBase extends ToolBase {
return this.session.serviceProvider;
}

protected async ensureSearchIsSupported(): Promise<void> {
return await this.session.assertSearchSupported();
protected ensureSearchIsSupported(): Promise<void> {
return this.session.assertSearchSupported();
}

public register(server: Server): boolean {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ describeWithMongoDB(
expect(content).toContain("There were errors when inserting documents. No document was inserted.");
const untrustedContent = getDataFromUntrustedContent(content);
expect(untrustedContent).toContain(
"- Field embedding is an embedding with 8 dimensions and scalar quantization, and the provided value is not compatible."
"- Field embedding is an embedding with 8 dimensions and scalar quantization, and the provided value is not compatible. Actual dimensions: unknown, actual quantization: unknown. Error: not-a-vector"
);

const oopsieCount = await provider.countDocuments(integration.randomDbName(), "test", {
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/common/search/vectorSearchEmbeddingsManager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { VectorSearchEmbeddingsManager } from "../../../../src/common/search/vec
import type {
EmbeddingNamespace,
VectorFieldIndexDefinition,
VectorFieldValidationError,
} from "../../../../src/common/search/vectorSearchEmbeddingsManager.js";
import { BSON } from "bson";
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
Expand Down Expand Up @@ -275,6 +276,15 @@ describe("VectorSearchEmbeddingsManager", () => {
);

expect(result).toHaveLength(1);
const expectedError: VectorFieldValidationError = {
actualNumDimensions: 3,
actualQuantization: "scalar",
error: "dimension-mismatch",
expectedNumDimensions: 8,
expectedQuantization: "scalar",
path: "embedding_field",
};
expect(result[0]).toEqual(expectedError);
});

it("documents inserting the field with correct dimensions, but wrong type are invalid", async () => {
Expand All @@ -284,6 +294,16 @@ describe("VectorSearchEmbeddingsManager", () => {
);

expect(result).toHaveLength(1);
const expectedError: VectorFieldValidationError = {
actualNumDimensions: 8,
actualQuantization: "scalar",
error: "not-numeric",
expectedNumDimensions: 8,
expectedQuantization: "scalar",
path: "embedding_field",
};

expect(result[0]).toEqual(expectedError);
});

it("documents inserting the field with correct dimensions and quantization in binary are valid", async () => {
Expand Down
Loading