Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Return computed distance and set distance thresholds on VectorQueries #2090

Merged
merged 18 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 96 additions & 72 deletions api-report/firestore.api.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dev/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ export type {AggregateQuery} from './reference/aggregate-query';
export type {AggregateQuerySnapshot} from './reference/aggregate-query-snapshot';
export type {VectorQuery} from './reference/vector-query';
export type {VectorQuerySnapshot} from './reference/vector-query-snapshot';
export type {VectorQueryOptions} from './reference/vector-query-options';
export {BulkWriter} from './bulk-writer';
export type {BulkWriterError} from './bulk-writer';
export type {BundleBuilder} from './bundle';
Expand Down
80 changes: 69 additions & 11 deletions dev/src/reference/query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,9 @@ export class Query<
* @param options - Options that control the vector query. `limit` specifies the upper bound of documents to return and must
* be a positive integer with a maximum value of 1000. `distanceMeasure` specifies what type of distance is calculated
* when performing the query.
*
* @deprecated Use the new {@link findNearest} implementation
* accepting a single `options` param.
*/
findNearest(
vectorField: string | firestore.FieldPath,
Expand All @@ -637,30 +640,85 @@ export class Query<
limit: number;
distanceMeasure: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT';
}
): VectorQuery<AppModelType, DbModelType>;

/**
* Returns a query that can perform vector distance (similarity) search with given parameters.
*
* The returned query, when executed, performs a distance (similarity) search on the specified
* `vectorField` against the given `queryVector` and returns the top documents that are closest
* to the `queryVector`.
*
* Only documents whose `vectorField` field is a {@link VectorValue} of the same dimension as `queryVector`
* participate in the query; all other documents are ignored.
*
* @example
* ```
* // Returns up to 10 documents whose 'embedding' vectors are closest to [41, 42] by Euclidean distance.
* const vectorQuery = col.findNearest({
* vectorField: 'embedding',
* queryVector: [41, 42],
* limit: 10,
* distanceMeasure: 'EUCLIDEAN',
* distanceResultField: 'distance',
* distanceThreshold: 0.125
* });
*
* const querySnapshot = await vectorQuery.get();
* querySnapshot.forEach(...);
* ```
* @param options - An argument specifying the behavior of the {@link VectorQuery} returned by this function.
* See {@link VectorQueryOptions}.
*/
findNearest(
options: VectorQueryOptions
): VectorQuery<AppModelType, DbModelType>;

findNearest(
vectorFieldOrOptions: string | firestore.FieldPath | VectorQueryOptions,
queryVector?: firestore.VectorValue | Array<number>,
options?: {
limit?: number;
distanceMeasure?: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT';
}
): VectorQuery<AppModelType, DbModelType> {
validateFieldPath('vectorField', vectorField);
if (
typeof vectorFieldOrOptions === 'string' ||
vectorFieldOrOptions instanceof FieldPath
) {
const fnOptions: VectorQueryOptions = {
distanceMeasure: options!.distanceMeasure!,
limit: options!.limit!,
queryVector: queryVector!,
vectorField: vectorFieldOrOptions,
};
return this._findNearest(fnOptions);
} else {
return this._findNearest(vectorFieldOrOptions as VectorQueryOptions);
}
}

_findNearest(
options: VectorQueryOptions
): VectorQuery<AppModelType, DbModelType> {
validateFieldPath('vectorField', options.vectorField);

if (options.limit <= 0) {
throw invalidArgumentMessage('options.limit', 'positive limit number');
throw invalidArgumentMessage('limit', 'positive limit number');
}

if (
(Array.isArray(queryVector)
? queryVector.length
: queryVector.toArray().length) === 0
(Array.isArray(options.queryVector)
? options.queryVector.length
: options.queryVector.toArray().length) === 0
) {
throw invalidArgumentMessage(
'queryVector',
'vector size must be larger than 0'
);
}

return new VectorQuery<AppModelType, DbModelType>(
this,
vectorField,
queryVector,
new VectorQueryOptions(options.limit, options.distanceMeasure)
);
return new VectorQuery<AppModelType, DbModelType>(this, options);
}

/**
Expand Down
59 changes: 42 additions & 17 deletions dev/src/reference/vector-query-options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,48 @@
* limitations under the License.
*/

export class VectorQueryOptions {
constructor(
readonly limit: number,
readonly distanceMeasure: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT'
) {}
import * as firestore from '@google-cloud/firestore';

isEqual(other: VectorQueryOptions): boolean {
if (this === other) {
return true;
}
if (!(other instanceof VectorQueryOptions)) {
return false;
}
/**
* Specifies the behavior of the {@link VectorQuery} generated by a call to {@link Query.findNearest}.
*/
export interface VectorQueryOptions {
/**
* A string or {@link FieldPath} specifying the vector field to search on.
*/
vectorField: string | firestore.FieldPath;

/**
* The {@link VectorValue} used to measure the distance from `vectorField` values in the documents.
*/
queryVector: firestore.VectorValue | Array<number>;

/**
* Specifies the upper bound on the number of documents to return. Must be a positive integer with a maximum value of 1000.
*/
limit: number;

/**
* Specifies what type of distance is calculated when performing the query.
*/
distanceMeasure: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT';

/**
* Optionally specifies the name of a field that will be set on each returned DocumentSnapshot,
* which will contain the computed distance for the document.
*/
distanceResultField?: string | firestore.FieldPath;

return (
this.limit === other.limit &&
this.distanceMeasure === other.distanceMeasure
);
}
/**
* Specifies a similarity threshold: documents that are less similar than the threshold are not returned.
* The meaning of the threshold depends on the specified `distanceMeasure`:
*
* - For `distanceMeasure: "EUCLIDEAN"`, the meaning of `distanceThreshold` is:
* SELECT docs WHERE euclidean_distance <= distanceThreshold
* - For `distanceMeasure: "COSINE"`, the meaning of `distanceThreshold` is:
* SELECT docs WHERE cosine_distance <= distanceThreshold
* - For `distanceMeasure: "DOT_PRODUCT"`, the meaning of `distanceThreshold` is:
* SELECT docs WHERE dot_product_distance >= distanceThreshold
*/
distanceThreshold?: number;
}
55 changes: 38 additions & 17 deletions dev/src/reference/vector-query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ export class VectorQuery<
*/
constructor(
private readonly _query: Query<AppModelType, DbModelType>,
private readonly vectorField: string | firestore.FieldPath,
private readonly queryVector: firestore.VectorValue | Array<number>,
private readonly options: VectorQueryOptions
private readonly _options: VectorQueryOptions
) {
this._queryUtil = new QueryUtil<
AppModelType,
Expand All @@ -79,19 +77,31 @@ export class VectorQuery<
* @internal
*/
private get _rawVectorField(): string {
return typeof this.vectorField === 'string'
? this.vectorField
: this.vectorField.toString();
return typeof this._options.vectorField === 'string'
? this._options.vectorField
: this._options.vectorField.toString();
}

/**
 * The configured `distanceResultField` option as a plain string. String
 * values are returned unchanged; any other defined value is converted with
 * its `toString()` method. Yields `undefined` when the option is not set.
 *
 * @private
 * @internal
 */
private get _rawDistanceResultField(): string | undefined {
  const resultField = this._options.distanceResultField;

  if (typeof resultField === 'undefined') {
    return undefined;
  }
  if (typeof resultField === 'string') {
    return resultField;
  }
  return resultField.toString();
}

/**
* @private
* @internal
*/
private get _rawQueryVector(): Array<number> {
return Array.isArray(this.queryVector)
? this.queryVector
: this.queryVector.toArray();
return Array.isArray(this._options.queryVector)
? this._options.queryVector
: this._options.queryVector.toArray();
}

/**
Expand Down Expand Up @@ -157,7 +167,7 @@ export class VectorQuery<
}

/**
* Internal method for serializing a query to its RunAggregationQuery proto
* Internal method for serializing a query to its proto
* representation with an optional transaction id.
*
* @private
Expand All @@ -170,17 +180,25 @@ export class VectorQuery<
): api.IRunQueryRequest {
const queryProto = this._query.toProto(transactionOrReadTime);

const queryVector = Array.isArray(this.queryVector)
? new VectorValue(this.queryVector)
: (this.queryVector as VectorValue);
const queryVector = Array.isArray(this._options.queryVector)
? new VectorValue(this._options.queryVector)
: (this._options.queryVector as VectorValue);

queryProto.structuredQuery!.findNearest = {
limit: {value: this.options.limit},
distanceMeasure: this.options.distanceMeasure,
limit: {value: this._options.limit},
distanceMeasure: this._options.distanceMeasure,
vectorField: {
fieldPath: FieldPath.fromArgument(this.vectorField).formattedName,
fieldPath: FieldPath.fromArgument(this._options.vectorField)
.formattedName,
},
queryVector: queryVector._toProto(this._query._serializer),
distanceResultField: this._options?.distanceResultField
? FieldPath.fromArgument(this._options.distanceResultField!)
.formattedName
: undefined,
distanceThreshold: this._options?.distanceThreshold
? {value: this._options?.distanceThreshold}
: undefined,
};

if (explainOptions) {
Expand Down Expand Up @@ -253,7 +271,10 @@ export class VectorQuery<
return (
this._rawVectorField === other._rawVectorField &&
isPrimitiveArrayEqual(this._rawQueryVector, other._rawQueryVector) &&
this.options.isEqual(other.options)
this._options.limit === other._options.limit &&
this._options.distanceMeasure === other._options.distanceMeasure &&
this._options.distanceThreshold === other._options.distanceThreshold &&
this._rawDistanceResultField === other._rawDistanceResultField
);
}
}
Loading
Loading