Skip to content

Commit 1e6ee10

Browse files
authored
chore: follow-up on #626 MCP-246 (#660)
1 parent 930b947 commit 1e6ee10

File tree

8 files changed, with 112 additions (+112) and 33 deletions (-33).

src/common/config.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser";
66
import { Keychain } from "./keychain.js";
77
import type { Secret } from "./keychain.js";
88
import levenshtein from "ts-levenshtein";
9+
import type { Similarity } from "./search/vectorSearchEmbeddingsManager.js";
910

1011
// From: https://github.com/mongodb-js/mongosh/blob/main/packages/cli-repl/src/arg-parser.ts
1112
const OPTIONS = {
@@ -186,7 +187,7 @@ export interface UserConfig extends CliOptions {
186187
voyageApiKey: string;
187188
disableEmbeddingsValidation: boolean;
188189
vectorSearchDimensions: number;
189-
vectorSearchSimilarityFunction: "cosine" | "euclidean" | "dotProduct";
190+
vectorSearchSimilarityFunction: Similarity;
190191
}
191192

192193
export const defaultUserConfig: UserConfig = {

src/common/search/vectorSearchEmbeddingsManager.ts

Lines changed: 78 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,29 @@ import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-d
22
import { BSON, type Document } from "bson";
33
import type { UserConfig } from "../config.js";
44
import type { ConnectionManager } from "../connectionManager.js";
5+
import z from "zod";
6+
7+
export const similarityEnum = z.enum(["cosine", "euclidean", "dotProduct"]);
8+
export type Similarity = z.infer<typeof similarityEnum>;
9+
10+
export const quantizationEnum = z.enum(["none", "scalar", "binary"]);
11+
export type Quantization = z.infer<typeof quantizationEnum>;
512

613
export type VectorFieldIndexDefinition = {
714
type: "vector";
815
path: string;
916
numDimensions: number;
10-
quantization: "none" | "scalar" | "binary";
11-
similarity: "euclidean" | "cosine" | "dotProduct";
17+
quantization: Quantization;
18+
similarity: Similarity;
19+
};
20+
21+
export type VectorFieldValidationError = {
22+
path: string;
23+
expectedNumDimensions: number;
24+
expectedQuantization: Quantization;
25+
actualNumDimensions: number | "unknown";
26+
actualQuantization: Quantization | "unknown";
27+
error: "dimension-mismatch" | "quantization-mismatch" | "not-a-vector" | "not-numeric";
1228
};
1329

1430
export type EmbeddingNamespace = `${string}.${string}`;
@@ -54,7 +70,7 @@ export class VectorSearchEmbeddingsManager {
5470
const vectorSearchIndexes = allSearchIndexes.filter((index) => index.type === "vectorSearch");
5571
const vectorFields = vectorSearchIndexes
5672
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
57-
.flatMap<Document>((index) => (index.latestDefinition?.fields as Document) ?? [])
73+
.flatMap<Document>((index) => (index.latestDefinition?.fields as Document[]) ?? [])
5874
.filter((field) => this.isVectorFieldIndexDefinition(field));
5975

6076
this.embeddings.set(embeddingDefKey, vectorFields);
@@ -73,7 +89,7 @@ export class VectorSearchEmbeddingsManager {
7389
collection: string;
7490
},
7591
document: Document
76-
): Promise<VectorFieldIndexDefinition[]> {
92+
): Promise<VectorFieldValidationError[]> {
7793
const provider = await this.assertAtlasSearchIsAvailable();
7894
if (!provider) {
7995
return [];
@@ -87,15 +103,15 @@ export class VectorSearchEmbeddingsManager {
87103
}
88104

89105
const embeddings = await this.embeddingsForNamespace({ database, collection });
90-
return embeddings.filter((emb) => !this.documentPassesEmbeddingValidation(emb, document));
106+
return embeddings
107+
.map((emb) => this.getValidationErrorForDocument(emb, document))
108+
.filter((e) => e !== undefined);
91109
}
92110

93111
private async assertAtlasSearchIsAvailable(): Promise<NodeDriverServiceProvider | null> {
94112
const connectionState = this.connectionManager.currentConnectionState;
95-
if (connectionState.tag === "connected") {
96-
if (await connectionState.isSearchSupported()) {
97-
return connectionState.serviceProvider;
98-
}
113+
if (connectionState.tag === "connected" && (await connectionState.isSearchSupported())) {
114+
return connectionState.serviceProvider;
99115
}
100116

101117
return null;
@@ -105,56 +121,99 @@ export class VectorSearchEmbeddingsManager {
105121
return doc["type"] === "vector";
106122
}
107123

108-
private documentPassesEmbeddingValidation(definition: VectorFieldIndexDefinition, document: Document): boolean {
124+
private getValidationErrorForDocument(
125+
definition: VectorFieldIndexDefinition,
126+
document: Document
127+
): VectorFieldValidationError | undefined {
109128
const fieldPath = definition.path.split(".");
110129
let fieldRef: unknown = document;
111130

131+
const constructError = (
132+
details: Partial<Pick<VectorFieldValidationError, "error" | "actualNumDimensions" | "actualQuantization">>
133+
): VectorFieldValidationError => ({
134+
path: definition.path,
135+
expectedNumDimensions: definition.numDimensions,
136+
expectedQuantization: definition.quantization,
137+
actualNumDimensions: details.actualNumDimensions ?? "unknown",
138+
actualQuantization: details.actualQuantization ?? "unknown",
139+
error: details.error ?? "not-a-vector",
140+
});
141+
112142
for (const field of fieldPath) {
113143
if (fieldRef && typeof fieldRef === "object" && field in fieldRef) {
114144
fieldRef = (fieldRef as Record<string, unknown>)[field];
115145
} else {
116-
return true;
146+
return undefined;
117147
}
118148
}
119149

120150
switch (definition.quantization) {
121151
// Because quantization is not defined by the user
122152
// we have to trust them in the format they use.
123153
case "none":
124-
return true;
154+
return undefined;
125155
case "scalar":
126156
case "binary":
127157
if (fieldRef instanceof BSON.Binary) {
128158
try {
129159
const elements = fieldRef.toFloat32Array();
130-
return elements.length === definition.numDimensions;
160+
if (elements.length !== definition.numDimensions) {
161+
return constructError({
162+
actualNumDimensions: elements.length,
163+
actualQuantization: "binary",
164+
error: "dimension-mismatch",
165+
});
166+
}
167+
168+
return undefined;
131169
} catch {
132170
// bits are also supported
133171
try {
134172
const bits = fieldRef.toBits();
135-
return bits.length === definition.numDimensions;
173+
if (bits.length !== definition.numDimensions) {
174+
return constructError({
175+
actualNumDimensions: bits.length,
176+
actualQuantization: "binary",
177+
error: "dimension-mismatch",
178+
});
179+
}
180+
181+
return undefined;
136182
} catch {
137-
return false;
183+
return constructError({
184+
actualQuantization: "binary",
185+
error: "not-a-vector",
186+
});
138187
}
139188
}
140189
} else {
141190
if (!Array.isArray(fieldRef)) {
142-
return false;
191+
return constructError({
192+
error: "not-a-vector",
193+
});
143194
}
144195

145196
if (fieldRef.length !== definition.numDimensions) {
146-
return false;
197+
return constructError({
198+
actualNumDimensions: fieldRef.length,
199+
actualQuantization: "scalar",
200+
error: "dimension-mismatch",
201+
});
147202
}
148203

149204
if (!fieldRef.every((e) => this.isANumber(e))) {
150-
return false;
205+
return constructError({
206+
actualNumDimensions: fieldRef.length,
207+
actualQuantization: "scalar",
208+
error: "not-numeric",
209+
});
151210
}
152211
}
153212

154213
break;
155214
}
156215

157-
return true;
216+
return undefined;
158217
}
159218

160219
private isANumber(value: unknown): boolean {

src/common/session.ts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,15 +156,13 @@ export class Session extends EventEmitter<SessionEvents> {
156156
}
157157

158158
async assertSearchSupported(): Promise<void> {
159-
const availability = await this.isSearchSupported();
160-
if (!availability) {
159+
const isSearchSupported = await this.isSearchSupported();
160+
if (!isSearchSupported) {
161161
throw new MongoDBError(
162162
ErrorCodes.AtlasSearchNotSupported,
163163
"Atlas Search is not supported in the current cluster."
164164
);
165165
}
166-
167-
return;
168166
}
169167

170168
get serviceProvider(): NodeDriverServiceProvider {

src/tools/mongodb/create/createIndex.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
33
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
44
import { type ToolArgs, type OperationType, FeatureFlags } from "../../tool.js";
55
import type { IndexDirection } from "mongodb";
6+
import { quantizationEnum, similarityEnum } from "../../../common/search/vectorSearchEmbeddingsManager.js";
67

78
export class CreateIndexTool extends MongoDBToolBase {
89
private vectorSearchIndexDefinition = z.object({
@@ -37,15 +38,12 @@ export class CreateIndexTool extends MongoDBToolBase {
3738
.describe(
3839
"Number of vector dimensions that MongoDB Vector Search enforces at index-time and query-time"
3940
),
40-
similarity: z
41-
.enum(["cosine", "euclidean", "dotProduct"])
41+
similarity: similarityEnum
4242
.default(this.config.vectorSearchSimilarityFunction)
4343
.describe(
4444
"Vector similarity function to use to search for top K-nearest neighbors. You can set this field only for vector-type fields."
4545
),
46-
quantization: z
47-
.enum(["none", "scalar", "binary"])
48-
.optional()
46+
quantization: quantizationEnum
4947
.default("none")
5048
.describe(
5149
"Type of automatic vector quantization for your vectors. Use this setting only if your embeddings are float or double vectors."
@@ -125,6 +123,7 @@ export class CreateIndexTool extends MongoDBToolBase {
125123

126124
responseClarification =
127125
" Since this is a vector search index, it may take a while for the index to build. Use the `list-indexes` tool to check the index status.";
126+
128127
// clean up the embeddings cache so it considers the new index
129128
this.session.vectorSearchEmbeddingsManager.cleanupEmbeddingsForNamespace({ database, collection });
130129
}

src/tools/mongodb/create/insertMany.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ export class InsertManyTool extends MongoDBToolBase {
3939
// tell the LLM what happened
4040
const embeddingValidationMessages = [...embeddingValidations].map(
4141
(validation) =>
42-
`- Field ${validation.path} is an embedding with ${validation.numDimensions} dimensions and ${validation.quantization} quantization, and the provided value is not compatible.`
42+
`- Field ${validation.path} is an embedding with ${validation.expectedNumDimensions} dimensions and ${validation.expectedQuantization}` +
43+
` quantization, and the provided value is not compatible. Actual dimensions: ${validation.actualNumDimensions}, ` +
44+
`actual quantization: ${validation.actualQuantization}. Error: ${validation.error}`
4345
);
4446

4547
return {

src/tools/mongodb/mongodbTool.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ export abstract class MongoDBToolBase extends ToolBase {
4646
return this.session.serviceProvider;
4747
}
4848

49-
protected async ensureSearchIsSupported(): Promise<void> {
50-
return await this.session.assertSearchSupported();
49+
protected ensureSearchIsSupported(): Promise<void> {
50+
return this.session.assertSearchSupported();
5151
}
5252

5353
public register(server: Server): boolean {

tests/integration/tools/mongodb/create/insertMany.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ describeWithMongoDB(
186186
expect(content).toContain("There were errors when inserting documents. No document was inserted.");
187187
const untrustedContent = getDataFromUntrustedContent(content);
188188
expect(untrustedContent).toContain(
189-
"- Field embedding is an embedding with 8 dimensions and scalar quantization, and the provided value is not compatible."
189+
"- Field embedding is an embedding with 8 dimensions and scalar quantization, and the provided value is not compatible. Actual dimensions: unknown, actual quantization: unknown. Error: not-a-vector"
190190
);
191191

192192
const oopsieCount = await provider.countDocuments(integration.randomDbName(), "test", {

tests/unit/common/search/vectorSearchEmbeddingsManager.test.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { VectorSearchEmbeddingsManager } from "../../../../src/common/search/vec
44
import type {
55
EmbeddingNamespace,
66
VectorFieldIndexDefinition,
7+
VectorFieldValidationError,
78
} from "../../../../src/common/search/vectorSearchEmbeddingsManager.js";
89
import { BSON } from "bson";
910
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
@@ -275,6 +276,15 @@ describe("VectorSearchEmbeddingsManager", () => {
275276
);
276277

277278
expect(result).toHaveLength(1);
279+
const expectedError: VectorFieldValidationError = {
280+
actualNumDimensions: 3,
281+
actualQuantization: "scalar",
282+
error: "dimension-mismatch",
283+
expectedNumDimensions: 8,
284+
expectedQuantization: "scalar",
285+
path: "embedding_field",
286+
};
287+
expect(result[0]).toEqual(expectedError);
278288
});
279289

280290
it("documents inserting the field with correct dimensions, but wrong type are invalid", async () => {
@@ -284,6 +294,16 @@ describe("VectorSearchEmbeddingsManager", () => {
284294
);
285295

286296
expect(result).toHaveLength(1);
297+
const expectedError: VectorFieldValidationError = {
298+
actualNumDimensions: 8,
299+
actualQuantization: "scalar",
300+
error: "not-numeric",
301+
expectedNumDimensions: 8,
302+
expectedQuantization: "scalar",
303+
path: "embedding_field",
304+
};
305+
306+
expect(result[0]).toEqual(expectedError);
287307
});
288308

289309
it("documents inserting the field with correct dimensions and quantization in binary are valid", async () => {

0 commit comments

Comments
 (0)