From 5eada9acc3db387dbab98002fc038b2ccd81d76f Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Mon, 15 Jul 2024 09:02:00 +0000 Subject: [PATCH 1/9] feat(firestore): Adding vector search --- firestore/docref.go | 22 ++++++ firestore/from_value.go | 170 +++++++++++++++++++++++++--------------- firestore/query.go | 96 +++++++++++++++++++++++ firestore/to_value.go | 3 + 4 files changed, 229 insertions(+), 62 deletions(-) diff --git a/firestore/docref.go b/firestore/docref.go index 822316316c98..52180900039b 100644 --- a/firestore/docref.go +++ b/firestore/docref.go @@ -638,6 +638,28 @@ func (s sentinel) String() string { } } +// VectorType represpresents a vector +type VectorType interface { + isVectorType() + toProtoValue() (*pb.Value, bool, error) +} + +// Vector represents a vector in the form of a float64 array +type Vector []float64 + +func (_ Vector) isVectorType() {} + +func (vector Vector) toProtoValue() (*pb.Value, bool, error) { + if vector == nil { + return nullValue, false, nil + } + + vectorMap := map[string]interface{}{} + vectorMap["__type__"] = "__vector__" + vectorMap["value"] = []float64(vector) + return mapToProtoValue(reflect.ValueOf(vectorMap)) +} + // An Update describes an update to a value referred to by a path. // An Update should have either a non-empty Path or a non-empty FieldPath, // but not both. diff --git a/firestore/from_value.go b/firestore/from_value.go index df68465a9e57..bdc9ffc0c585 100644 --- a/firestore/from_value.go +++ b/firestore/from_value.go @@ -32,63 +32,67 @@ func setFromProtoValue(x interface{}, vproto *pb.Value, c *Client) error { return setReflectFromProtoValue(v.Elem(), vproto, c) } -// setReflectFromProtoValue sets v from a Firestore Value. -// v must be a settable value. -func setReflectFromProtoValue(v reflect.Value, vproto *pb.Value, c *Client) error { +// setReflectFromProtoValue sets vDest from a Firestore Value. +// vDest must be a settable value. +func setReflectFromProtoValue(vDest reflect.Value, vprotoSrc *pb.Value, c *Client) error { typeErr := func() error { - return fmt.Errorf("firestore: cannot set type %s to %s", v.Type(), typeString(vproto)) + return fmt.Errorf("firestore: cannot set type %s to %s", vDest.Type(), typeString(vprotoSrc)) } - val := vproto.ValueType + typeErrWithArgs := func(destType string) error { + return fmt.Errorf("firestore: cannot set type %s to %s", destType, typeString(vprotoSrc)) + } + + valTypeSrc := vprotoSrc.ValueType // A Null value sets anything nullable to nil, and has no effect // on anything else. - if _, ok := val.(*pb.Value_NullValue); ok { - switch v.Kind() { + if _, ok := valTypeSrc.(*pb.Value_NullValue); ok { + switch vDest.Kind() { case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice: - v.Set(reflect.Zero(v.Type())) + vDest.Set(reflect.Zero(vDest.Type())) } return nil } // Handle special types first. - switch v.Type() { + switch vDest.Type() { case typeOfByteSlice: - x, ok := val.(*pb.Value_BytesValue) + x, ok := valTypeSrc.(*pb.Value_BytesValue) if !ok { return typeErr() } - v.SetBytes(x.BytesValue) + vDest.SetBytes(x.BytesValue) return nil case typeOfGoTime: - x, ok := val.(*pb.Value_TimestampValue) + x, ok := valTypeSrc.(*pb.Value_TimestampValue) if !ok { return typeErr() } if err := x.TimestampValue.CheckValid(); err != nil { return err } - v.Set(reflect.ValueOf(x.TimestampValue.AsTime())) + vDest.Set(reflect.ValueOf(x.TimestampValue.AsTime())) return nil case typeOfProtoTimestamp: - x, ok := val.(*pb.Value_TimestampValue) + x, ok := valTypeSrc.(*pb.Value_TimestampValue) if !ok { return typeErr() } - v.Set(reflect.ValueOf(x.TimestampValue)) + vDest.Set(reflect.ValueOf(x.TimestampValue)) return nil case typeOfLatLng: - x, ok := val.(*pb.Value_GeoPointValue) + x, ok := valTypeSrc.(*pb.Value_GeoPointValue) if !ok { return typeErr() } - v.Set(reflect.ValueOf(x.GeoPointValue)) + vDest.Set(reflect.ValueOf(x.GeoPointValue)) return nil case typeOfDocumentRef: - x, ok := val.(*pb.Value_ReferenceValue) + x, ok := valTypeSrc.(*pb.Value_ReferenceValue) if !ok { return typeErr() } @@ -96,157 +100,199 @@ func setReflectFromProtoValue(v reflect.Value, vproto *pb.Value, c *Client) erro if err != nil { return err } - v.Set(reflect.ValueOf(dr)) + vDest.Set(reflect.ValueOf(dr)) + return nil + + case typeOfVector: + /* + Vector is stored as: + { + "__type__": "__vector__", + "value": []float64{}, + } + but needs to be returned as firestore.Vector to the user + */ + + // Convert Firestore proto map from Go map + vectorMapDest := map[string]interface{}{} + vectorMapDestVal := reflect.ValueOf(vectorMapDest) + x, ok := valTypeSrc.(*pb.Value_MapValue) + if !ok { + // Vector not stored as map in Firestore + return typeErrWithArgs("Vector") + } + err := populateMap(vectorMapDestVal, x.MapValue.Fields, c) + if err != nil { + return err + } + + // Convert value at "value" key to array of floats + anyArr, isInterfaceArr := vectorMapDest["value"].([]interface{}) + if !isInterfaceArr { + // value at "value" key is not an array + return typeErrWithArgs("Vector") + } + floats := []float64{} + for _, v := range anyArr { + // Convert each element of []interface{} to float64 + floatVal, isFloat := v.(float64) + if isFloat { + floats = append(floats, floatVal) + } + } + + // Set Vector in destination + vDest.Set(reflect.ValueOf(floats)) return nil } - switch v.Kind() { + switch vDest.Kind() { case reflect.Bool: - x, ok := val.(*pb.Value_BooleanValue) + x, ok := valTypeSrc.(*pb.Value_BooleanValue) if !ok { return typeErr() } - v.SetBool(x.BooleanValue) + vDest.SetBool(x.BooleanValue) case reflect.String: - x, ok := val.(*pb.Value_StringValue) + x, ok := valTypeSrc.(*pb.Value_StringValue) if !ok { return typeErr() } - v.SetString(x.StringValue) + vDest.SetString(x.StringValue) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: var i int64 - switch x := val.(type) { + switch x := valTypeSrc.(type) { case *pb.Value_IntegerValue: i = x.IntegerValue case *pb.Value_DoubleValue: f := x.DoubleValue i = int64(f) if float64(i) != f { - return fmt.Errorf("firestore: float %f does not fit into %s", f, v.Type()) + return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type()) } default: return typeErr() } - if v.OverflowInt(i) { - return overflowErr(v, i) + if vDest.OverflowInt(i) { + return overflowErr(vDest, i) } - v.SetInt(i) + vDest.SetInt(i) case reflect.Uint8, reflect.Uint16, reflect.Uint32: var u uint64 - switch x := val.(type) { + switch x := valTypeSrc.(type) { case *pb.Value_IntegerValue: u = uint64(x.IntegerValue) case *pb.Value_DoubleValue: f := x.DoubleValue u = uint64(f) if float64(u) != f { - return fmt.Errorf("firestore: float %f does not fit into %s", f, v.Type()) + return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type()) } default: return typeErr() } - if v.OverflowUint(u) { - return overflowErr(v, u) + if vDest.OverflowUint(u) { + return overflowErr(vDest, u) } - v.SetUint(u) + vDest.SetUint(u) case reflect.Float32, reflect.Float64: var f float64 - switch x := val.(type) { + switch x := valTypeSrc.(type) { case *pb.Value_DoubleValue: f = x.DoubleValue case *pb.Value_IntegerValue: f = float64(x.IntegerValue) if int64(f) != x.IntegerValue { - return overflowErr(v, x.IntegerValue) + return overflowErr(vDest, x.IntegerValue) } default: return typeErr() } - if v.OverflowFloat(f) { - return overflowErr(v, f) + if vDest.OverflowFloat(f) { + return overflowErr(vDest, f) } - v.SetFloat(f) + vDest.SetFloat(f) case reflect.Slice: - x, ok := val.(*pb.Value_ArrayValue) + x, ok := valTypeSrc.(*pb.Value_ArrayValue) if !ok { return typeErr() } vals := x.ArrayValue.Values - vlen := v.Len() + vlen := vDest.Len() xlen := len(vals) // Make a slice of the right size, avoiding allocation if possible. switch { case vlen < xlen: - v.Set(reflect.MakeSlice(v.Type(), xlen, xlen)) + vDest.Set(reflect.MakeSlice(vDest.Type(), xlen, xlen)) case vlen > xlen: - v.SetLen(xlen) + vDest.SetLen(xlen) } - return populateRepeated(v, vals, xlen, c) + return populateRepeated(vDest, vals, xlen, c) case reflect.Array: - x, ok := val.(*pb.Value_ArrayValue) + x, ok := valTypeSrc.(*pb.Value_ArrayValue) if !ok { return typeErr() } vals := x.ArrayValue.Values xlen := len(vals) - vlen := v.Len() + vlen := vDest.Len() minlen := vlen // Set extra elements to their zero value. if vlen > xlen { - z := reflect.Zero(v.Type().Elem()) + z := reflect.Zero(vDest.Type().Elem()) for i := xlen; i < vlen; i++ { - v.Index(i).Set(z) + vDest.Index(i).Set(z) } minlen = xlen } - return populateRepeated(v, vals, minlen, c) + return populateRepeated(vDest, vals, minlen, c) case reflect.Map: - x, ok := val.(*pb.Value_MapValue) + x, ok := valTypeSrc.(*pb.Value_MapValue) if !ok { return typeErr() } - return populateMap(v, x.MapValue.Fields, c) + return populateMap(vDest, x.MapValue.Fields, c) case reflect.Ptr: // If the pointer is nil, set it to a zero value. - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) + if vDest.IsNil() { + vDest.Set(reflect.New(vDest.Type().Elem())) } - return setReflectFromProtoValue(v.Elem(), vproto, c) + return setReflectFromProtoValue(vDest.Elem(), vprotoSrc, c) case reflect.Struct: - x, ok := val.(*pb.Value_MapValue) + x, ok := valTypeSrc.(*pb.Value_MapValue) if !ok { return typeErr() } - return populateStruct(v, x.MapValue.Fields, c) + return populateStruct(vDest, x.MapValue.Fields, c) case reflect.Interface: - if v.NumMethod() == 0 { // empty interface + if vDest.NumMethod() == 0 { // empty interface // If v holds a pointer, set the pointer. - if !v.IsNil() && v.Elem().Kind() == reflect.Ptr { - return setReflectFromProtoValue(v.Elem(), vproto, c) + if !vDest.IsNil() && vDest.Elem().Kind() == reflect.Ptr { + return setReflectFromProtoValue(vDest.Elem(), vprotoSrc, c) } // Otherwise, create a fresh value. - x, err := createFromProtoValue(vproto, c) + x, err := createFromProtoValue(vprotoSrc, c) if err != nil { return err } - v.Set(reflect.ValueOf(x)) + vDest.Set(reflect.ValueOf(x)) return nil } // Any other kind of interface is an error. fallthrough default: - return fmt.Errorf("firestore: cannot set type %s", v.Type()) + return fmt.Errorf("firestore: cannot set type %s", vDest.Type()) } return nil } diff --git a/firestore/query.go b/firestore/query.go index 73ef4b4cca01..2c41e539ac91 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -61,6 +61,8 @@ type Query struct { // readOptions specifies constraints for reading results from the query // e.g. read time readSettings *readSettings + + findNearest *pb.StructuredQuery_FindNearest } // DocumentID is the special field name representing the ID of a document @@ -364,6 +366,97 @@ func (q Query) Deserialize(bytes []byte) (Query, error) { return q.fromProto(&runQueryRequest) } +// DistanceMeasure is the distance measure to use when comparing vectors. +type DistanceMeasure int32 + +const ( + // Measures the EUCLIDEAN distance between the vectors. See + // [Euclidean](https://en.wikipedia.org/wiki/Euclidean_distance) to learn + // more + DistanceMeasureEuclidean DistanceMeasure = 1 + + // Compares vectors based on the angle between them, which allows you to + // measure similarity that isn't based on the vectors magnitude. + // We recommend using DOT_PRODUCT with unit normalized vectors instead of + // COSINE distance, which is mathematically equivalent with better + // performance. See [Cosine + // Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) to learn + // more. + DistanceMeasureCosine DistanceMeasure = 2 + + // Similar to cosine but is affected by the magnitude of the vectors. See + // [Dot Product](https://en.wikipedia.org/wiki/Dot_product) to learn more. + DistanceMeasureDotProduct DistanceMeasure = 3 +) + +// FindNearestOpts is options to use while building FindNearest vector query +type FindNearestOpts struct { + Limit int + Measure DistanceMeasure +} + +// VectorQuery represents a vector query +type VectorQuery struct { + Query +} + +// FindNearest returns a query that can perform vector distance (similarity) search with given parameters. +// +// The returned query, when executed, performs a distance (similarity) search on the specified +// 'vectorField' against the given 'queryVector' and returns the top documents that are closest +// to the 'queryVector;. +// +// Only documents whose 'vectorField' field is a Vector of the same dimension as 'queryVector' +// participate in the query, all other documents are ignored. +// +// The 'vectorField' argument can be a single field or a dot-separated sequence of +// fields, and must not contain any of the runes "˜*/[]". +func (q Query) FindNearest(vectorField string, queryVector VectorType, options FindNearestOpts) VectorQuery { + vq := VectorQuery{ + Query: q, + } + + // Validate field path + fieldPath, err := parseDotSeparatedString(vectorField) + if err != nil { + vq.Query.err = err + return vq + } + return q.FindNearestPath(fieldPath, queryVector, options) +} + +// FindNearestPath is similar to FindNearest but accepts field path in the form of array of strings +func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector VectorType, options FindNearestOpts) VectorQuery { + vq := VectorQuery{ + Query: q, + } + // Convert field path field reference + vectorFieldRef, err := fref(vectorFieldPath) + if err != nil { + vq.Query.err = err + return vq + } + + pbVal, sawTransform, err := toProtoValue(reflect.ValueOf(queryVector)) + if err != nil { + vq.Query.err = err + return vq + } + if sawTransform { + vq.Query.err = errors.New("firestore: transforms disallowed in query value") + return vq + } + + vq.Query.findNearest = &pb.StructuredQuery_FindNearest{ + VectorField: vectorFieldRef, + QueryVector: pbVal, + Limit: &wrapperspb.Int32Value{Value: trunc32(options.Limit)}, + DistanceMeasure: pb.StructuredQuery_FindNearest_DistanceMeasure(options.Measure), + } + + return vq +} + // NewAggregationQuery returns an AggregationQuery with this query as its // base query. func (q *Query) NewAggregationQuery() *AggregationQuery { @@ -475,6 +568,8 @@ func (q Query) fromProto(pbQuery *pb.RunQueryRequest) (Query, error) { q.limit = limit } + q.findNearest = pbq.GetFindNearest() + // NOTE: limit to last isn't part of the proto, this is a client-side concept // limitToLast bool return q, q.err @@ -556,6 +651,7 @@ func (q Query) toProto() (*pb.StructuredQuery, error) { return nil, err } p.EndAt = cursor + p.FindNearest = q.findNearest return p, nil } diff --git a/firestore/to_value.go b/firestore/to_value.go index 0921ef9e6c51..fec55753a762 100644 --- a/firestore/to_value.go +++ b/firestore/to_value.go @@ -34,6 +34,7 @@ var ( typeOfLatLng = reflect.TypeOf((*latlng.LatLng)(nil)) typeOfDocumentRef = reflect.TypeOf((*DocumentRef)(nil)) typeOfProtoTimestamp = reflect.TypeOf((*ts.Timestamp)(nil)) + typeOfVector = reflect.TypeOf(Vector{}) ) // toProtoValue converts a Go value to a Firestore Value protobuf. @@ -69,6 +70,8 @@ func toProtoValue(v reflect.Value) (pbv *pb.Value, sawTransform bool, err error) return nullValue, false, nil } return &pb.Value{ValueType: &pb.Value_TimestampValue{TimestampValue: x}}, false, nil + case VectorType: + return x.toProtoValue() case *latlng.LatLng: if x == nil { // gRPC doesn't like nil oneofs. Use NullValue. From 2594a79aae24fd332aa9ba1282aaa2087d815f4b Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Mon, 15 Jul 2024 18:44:01 +0000 Subject: [PATCH 2/9] feat(firestore): refactoring code --- firestore/docref.go | 22 -------- firestore/from_value.go | 42 +------------- firestore/query.go | 42 ++++++++------ firestore/to_value.go | 16 ++++-- firestore/vector.go | 121 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 159 insertions(+), 84 deletions(-) create mode 100644 firestore/vector.go diff --git a/firestore/docref.go b/firestore/docref.go index 52180900039b..822316316c98 100644 --- a/firestore/docref.go +++ b/firestore/docref.go @@ -638,28 +638,6 @@ func (s sentinel) String() string { } } -// VectorType represpresents a vector -type VectorType interface { - isVectorType() - toProtoValue() (*pb.Value, bool, error) -} - -// Vector represents a vector in the form of a float64 array -type Vector []float64 - -func (_ Vector) isVectorType() {} - -func (vector Vector) toProtoValue() (*pb.Value, bool, error) { - if vector == nil { - return nullValue, false, nil - } - - vectorMap := map[string]interface{}{} - vectorMap["__type__"] = "__vector__" - vectorMap["value"] = []float64(vector) - return mapToProtoValue(reflect.ValueOf(vectorMap)) -} - // An Update describes an update to a value referred to by a path. // An Update should have either a non-empty Path or a non-empty FieldPath, // but not both. diff --git a/firestore/from_value.go b/firestore/from_value.go index bdc9ffc0c585..06e6bad72251 100644 --- a/firestore/from_value.go +++ b/firestore/from_value.go @@ -39,10 +39,6 @@ func setReflectFromProtoValue(vDest reflect.Value, vprotoSrc *pb.Value, c *Clien return fmt.Errorf("firestore: cannot set type %s to %s", vDest.Type(), typeString(vprotoSrc)) } - typeErrWithArgs := func(destType string) error { - return fmt.Errorf("firestore: cannot set type %s to %s", destType, typeString(vprotoSrc)) - } - valTypeSrc := vprotoSrc.ValueType // A Null value sets anything nullable to nil, and has no effect // on anything else. @@ -104,45 +100,11 @@ func setReflectFromProtoValue(vDest reflect.Value, vprotoSrc *pb.Value, c *Clien return nil case typeOfVector: - /* - Vector is stored as: - { - "__type__": "__vector__", - "value": []float64{}, - } - but needs to be returned as firestore.Vector to the user - */ - - // Convert Firestore proto map from Go map - vectorMapDest := map[string]interface{}{} - vectorMapDestVal := reflect.ValueOf(vectorMapDest) - x, ok := valTypeSrc.(*pb.Value_MapValue) - if !ok { - // Vector not stored as map in Firestore - return typeErrWithArgs("Vector") - } - err := populateMap(vectorMapDestVal, x.MapValue.Fields, c) + vector, err := vectorFromProtoValue(vprotoSrc) if err != nil { return err } - - // Convert value at "value" key to array of floats - anyArr, isInterfaceArr := vectorMapDest["value"].([]interface{}) - if !isInterfaceArr { - // value at "value" key is not an array - return typeErrWithArgs("Vector") - } - floats := []float64{} - for _, v := range anyArr { - // Convert each element of []interface{} to float64 - floatVal, isFloat := v.(float64) - if isFloat { - floats = append(floats, floatVal) - } - } - - // Set Vector in destination - vDest.Set(reflect.ValueOf(floats)) + vDest.Set(reflect.ValueOf(vector)) return nil } diff --git a/firestore/query.go b/firestore/query.go index 2c41e539ac91..e1b4c6bcf50d 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -366,27 +366,33 @@ func (q Query) Deserialize(bytes []byte) (Query, error) { return q.fromProto(&runQueryRequest) } -// DistanceMeasure is the distance measure to use when comparing vectors. +// DistanceMeasure is the distance measure to use when comparing vectors with [Query.FindNearest] or [Query.FindNearestPath]. type DistanceMeasure int32 const ( - // Measures the EUCLIDEAN distance between the vectors. See - // [Euclidean](https://en.wikipedia.org/wiki/Euclidean_distance) to learn + // DistanceMeasureEuclidean is used to measures the Euclidean distance between the vectors. See + // [Euclidean] to learn // more - DistanceMeasureEuclidean DistanceMeasure = 1 + // + // [Euclidean]: https://en.wikipedia.org/wiki/Euclidean_distance + DistanceMeasureEuclidean DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_EUCLIDEAN) - // Compares vectors based on the angle between them, which allows you to + // DistanceMeasureEuclidean compares vectors based on the angle between them, which allows you to // measure similarity that isn't based on the vectors magnitude. - // We recommend using DOT_PRODUCT with unit normalized vectors instead of - // COSINE distance, which is mathematically equivalent with better + // We recommend using dot product with unit normalized vectors instead of + // cosine distance, which is mathematically equivalent with better // performance. See [Cosine - // Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) to learn + // Similarity] to learn // more. - DistanceMeasureCosine DistanceMeasure = 2 - - // Similar to cosine but is affected by the magnitude of the vectors. See - // [Dot Product](https://en.wikipedia.org/wiki/Dot_product) to learn more. - DistanceMeasureDotProduct DistanceMeasure = 3 + // + // [Cosine Similarity]: https://en.wikipedia.org/wiki/Cosine_similarity + DistanceMeasureCosine DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_COSINE) + + // DistanceMeasureDotProduct is similar to cosine but is affected by the magnitude of the vectors. See + // [Dot Product] to learn more. + // + // [Dot Product]: https://en.wikipedia.org/wiki/Dot_product) + DistanceMeasureDotProduct DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_DOT_PRODUCT) ) // FindNearestOpts is options to use while building FindNearest vector query @@ -403,13 +409,13 @@ type VectorQuery struct { // FindNearest returns a query that can perform vector distance (similarity) search with given parameters. // // The returned query, when executed, performs a distance (similarity) search on the specified -// 'vectorField' against the given 'queryVector' and returns the top documents that are closest -// to the 'queryVector;. +// vectorField against the given queryVector and returns the top documents that are closest +// to the queryVector;. // -// Only documents whose 'vectorField' field is a Vector of the same dimension as 'queryVector' +// Only documents whose vectorField field is a Vector of the same dimension as queryVector // participate in the query, all other documents are ignored. // -// The 'vectorField' argument can be a single field or a dot-separated sequence of +// The vectorField argument can be a single field or a dot-separated sequence of // fields, and must not contain any of the runes "˜*/[]". func (q Query) FindNearest(vectorField string, queryVector VectorType, options FindNearestOpts) VectorQuery { vq := VectorQuery{ @@ -425,7 +431,7 @@ func (q Query) FindNearest(vectorField string, queryVector VectorType, options F return q.FindNearestPath(fieldPath, queryVector, options) } -// FindNearestPath is similar to FindNearest but accepts field path in the form of array of strings +// FindNearestPath is similar to FindNearest but it accepts [FieldPath] func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector VectorType, options FindNearestOpts) VectorQuery { vq := VectorQuery{ Query: q, diff --git a/firestore/to_value.go b/firestore/to_value.go index fec55753a762..27614c2ca7ea 100644 --- a/firestore/to_value.go +++ b/firestore/to_value.go @@ -70,8 +70,8 @@ func toProtoValue(v reflect.Value) (pbv *pb.Value, sawTransform bool, err error) return nullValue, false, nil } return &pb.Value{ValueType: &pb.Value_TimestampValue{TimestampValue: x}}, false, nil - case VectorType: - return x.toProtoValue() + case Vector: + return vectorToProtoValue(x), false, nil case *latlng.LatLng: if x == nil { // gRPC doesn't like nil oneofs. Use NullValue. @@ -98,9 +98,9 @@ func toProtoValue(v reflect.Value) (pbv *pb.Value, sawTransform bool, err error) case reflect.Uint8, reflect.Uint16, reflect.Uint32: return &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(v.Uint())}}, false, nil case reflect.Float32, reflect.Float64: - return &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: v.Float()}}, false, nil + return floatToProtoValue(v.Float()), false, nil case reflect.String: - return &pb.Value{ValueType: &pb.Value_StringValue{StringValue: v.String()}}, false, nil + return stringToProtoValue(v.String()), false, nil case reflect.Array: return arrayToProtoValue(v) case reflect.Slice: @@ -125,6 +125,14 @@ func toProtoValue(v reflect.Value) (pbv *pb.Value, sawTransform bool, err error) } } +func stringToProtoValue(s string) *pb.Value { + return &pb.Value{ValueType: &pb.Value_StringValue{StringValue: s}} +} + +func floatToProtoValue(f float64) *pb.Value { + return &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: f}} +} + // arrayToProtoValue converts a array to a Firestore Value protobuf and reports // whether a transform was encountered. func arrayToProtoValue(v reflect.Value) (*pb.Value, bool, error) { diff --git a/firestore/vector.go b/firestore/vector.go new file mode 100644 index 000000000000..c21aff9a1ed2 --- /dev/null +++ b/firestore/vector.go @@ -0,0 +1,121 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "fmt" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" +) + +const ( + typeKey = "__type__" + typeValVector = "__vector__" + valueKey = "value" +) + +// VectorType represpresents a vector +type VectorType interface { + isVectorType() +} + +// Vector represents a vector in the form of a float64 array +type Vector []float64 + +func (Vector) isVectorType() {} + +// vectorToProtoValue returns a Firestore [pb.Value] representing the Vector. +func vectorToProtoValue(v Vector) *pb.Value { + if v == nil { + return nullValue + } + pbVals := make([]*pb.Value, len(v)) + for i, val := range v { + pbVals[i] = floatToProtoValue(float64(val)) + } + + return &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: pbVals}, + }, + }, + }, + }, + }, + } +} + +func vectorFromProtoValue(v *pb.Value) (Vector, error) { + /* + Vector is stored as: + { + "__type__": "__vector__", + "value": []float64{}, + } + but needs to be returned as firestore.Vector to the user + */ + if v == nil { + return nil, nil + } + pbMap, ok := v.ValueType.(*pb.Value_MapValue) + if !ok { + return nil, fmt.Errorf("firestore: cannot convert %v to *pb.Value_MapValue", v.ValueType) + } + m := pbMap.MapValue.Fields + var typeVal string + typeVal, err := stringFromProtoValue(m[typeKey]) + if err != nil { + return nil, err + } + if typeVal != typeValVector { + return nil, fmt.Errorf("firestore: value of %v : %v is not %v", typeKey, typeVal, typeValVector) + } + pbVal, ok := m[valueKey] + if !ok { + return nil, fmt.Errorf("firestore: %v not present in %v", valueKey, m) + } + + pbArr, ok := pbVal.ValueType.(*pb.Value_ArrayValue) + if !ok { + return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_ArrayValue", pbVal.ValueType) + } + + pbArrVals := pbArr.ArrayValue.Values + floats := make([]float64, len(pbArrVals)) + for i, fval := range pbArrVals { + dv, ok := fval.ValueType.(*pb.Value_DoubleValue) + if !ok { + return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_DoubleValue", fval.ValueType) + } + floats[i] = float64(dv.DoubleValue) + } + return Vector(floats), nil +} + +func stringFromProtoValue(v *pb.Value) (string, error) { + if v == nil { + return "", fmt.Errorf("firestore: failed to convert %v to string", v) + } + sv, ok := v.ValueType.(*pb.Value_StringValue) + if !ok { + return "", fmt.Errorf("firestore: failed to convert %v to *pb.Value_StringValue", v.ValueType) + } + return sv.StringValue, nil +} From 687df77332924bd616e9f26602b7a015a8fca8c3 Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Mon, 15 Jul 2024 20:11:58 +0000 Subject: [PATCH 3/9] feat(firestore): Resolving vet failures --- firestore/query.go | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/firestore/query.go b/firestore/query.go index e1b4c6bcf50d..8804b4a0f13d 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -377,7 +377,7 @@ const ( // [Euclidean]: https://en.wikipedia.org/wiki/Euclidean_distance DistanceMeasureEuclidean DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_EUCLIDEAN) - // DistanceMeasureEuclidean compares vectors based on the angle between them, which allows you to + // DistanceMeasureCosine compares vectors based on the angle between them, which allows you to // measure similarity that isn't based on the vectors magnitude. // We recommend using dot product with unit normalized vectors instead of // cosine distance, which is mathematically equivalent with better @@ -403,7 +403,7 @@ type FindNearestOpts struct { // VectorQuery represents a vector query type VectorQuery struct { - Query + q Query } // FindNearest returns a query that can perform vector distance (similarity) search with given parameters. @@ -419,41 +419,46 @@ type VectorQuery struct { // fields, and must not contain any of the runes "˜*/[]". func (q Query) FindNearest(vectorField string, queryVector VectorType, options FindNearestOpts) VectorQuery { vq := VectorQuery{ - Query: q, + q: q, } // Validate field path fieldPath, err := parseDotSeparatedString(vectorField) if err != nil { - vq.Query.err = err + vq.q.err = err return vq } return q.FindNearestPath(fieldPath, queryVector, options) } +// Documents returns an iterator over the vector query's resulting documents. +func (vq VectorQuery) Documents(ctx context.Context) *DocumentIterator { + return vq.q.Documents(ctx) +} + // FindNearestPath is similar to FindNearest but it accepts [FieldPath] func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector VectorType, options FindNearestOpts) VectorQuery { vq := VectorQuery{ - Query: q, + q: q, } // Convert field path field reference vectorFieldRef, err := fref(vectorFieldPath) if err != nil { - vq.Query.err = err + vq.q.err = err return vq } pbVal, sawTransform, err := toProtoValue(reflect.ValueOf(queryVector)) if err != nil { - vq.Query.err = err + vq.q.err = err return vq } if sawTransform { - vq.Query.err = errors.New("firestore: transforms disallowed in query value") + vq.q.err = errors.New("firestore: transforms disallowed in query value") return vq } - vq.Query.findNearest = &pb.StructuredQuery_FindNearest{ + vq.q.findNearest = &pb.StructuredQuery_FindNearest{ VectorField: vectorFieldRef, QueryVector: pbVal, Limit: &wrapperspb.Int32Value{Value: trunc32(options.Limit)}, From 1f785a617c84b74e7d73ba80cd871189dd64fa7d Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Tue, 16 Jul 2024 18:00:56 +0000 Subject: [PATCH 4/9] feat(firestore): Adding unit and integration tests --- firestore/from_value.go | 10 +- firestore/integration_test.go | 252 ++++++++++++++++++++++++-------- firestore/query.go | 25 +--- firestore/query_test.go | 97 +++++++++++++ firestore/vector.go | 15 +- firestore/vector_test.go | 261 ++++++++++++++++++++++++++++++++++ 6 files changed, 572 insertions(+), 88 deletions(-) create mode 100644 firestore/vector_test.go diff --git a/firestore/from_value.go b/firestore/from_value.go index 06e6bad72251..de0332795187 100644 --- a/firestore/from_value.go +++ b/firestore/from_value.go @@ -397,8 +397,16 @@ func createFromProtoValue(vproto *pb.Value, c *Client) (interface{}, error) { } ret[k] = r } - return ret, nil + typeVal, ok := ret[typeKey] + if !ok || typeVal != typeValVector { + // Map is not a vector. Return the map + return ret, nil + } + + // Special handling for vector + vector, err := vectorFromProtoValue(vproto) + return vector, err default: return nil, fmt.Errorf("firestore: unknown value type %T", v) } diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 36c710378b8e..1ee89c52299a 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -62,7 +62,7 @@ func TestMain(m *testing.M) { if status != 0 { os.Exit(status) } - cleanupIntegrationTest() + // cleanupIntegrationTest() } os.Exit(0) @@ -75,13 +75,13 @@ const ( ) var ( - iClient *Client - iAdminClient *apiv1.FirestoreAdminClient - iColl *CollectionRef - collectionIDs = uid.NewSpace("go-integration-test", nil) - wantDBPath string - indexNames []string - testParams map[string]interface{} + iClient *Client + iAdminClient *apiv1.FirestoreAdminClient + iColl *CollectionRef + collectionIDs = uid.NewSpace("go-integration-test", nil) + wantDBPath string + testParams map[string]interface{} + seededFirstIndex bool ) func initIntegrationTest() { @@ -146,16 +146,68 @@ func initIntegrationTest() { integrationTestStruct.Ref = refDoc } +type vectorIndex struct { + dimension int32 + fieldPath string +} + +func createVectorIndexes(ctx context.Context, dbPath string, vectorModeIndexes []vectorIndex) []string { + indexNames := make([]string, len(vectorModeIndexes)) + indexParent := fmt.Sprintf("%s/collectionGroups/%s", dbPath, iColl.ID) + + var wg sync.WaitGroup + + // create vectore mode indexes + for i, vectorModeIndex := range vectorModeIndexes { + wg.Add(1) + req := &adminpb.CreateIndexRequest{ + Parent: indexParent, + Index: &adminpb.Index{ + QueryScope: adminpb.Index_COLLECTION, + Fields: []*adminpb.Index_IndexField{ + { + FieldPath: vectorModeIndex.fieldPath, + ValueMode: &adminpb.Index_IndexField_VectorConfig_{ + VectorConfig: &adminpb.Index_IndexField_VectorConfig{ + Dimension: vectorModeIndex.dimension, + Type: &adminpb.Index_IndexField_VectorConfig_Flat{ + Flat: &adminpb.Index_IndexField_VectorConfig_FlatIndex{}, + }, + }, + }, + }, + }, + }, + } + fmt.Printf("req: %+v\n", req) + op, createErr := iAdminClient.CreateIndex(ctx, req) + if createErr != nil { + log.Fatalf("CreateIndex vectorindexes: %v", createErr) + } + if i == 0 && !seededFirstIndex { + seededFirstIndex = true + handleCreateIndexResp(ctx, indexNames, &wg, i, op) + } else { + go handleCreateIndexResp(ctx, indexNames, &wg, i, op) + } + } + + wg.Wait() + return indexNames +} + // createIndexes creates composite indexes on provided Firestore database // Indexes are required to run queries with composite filters on multiple fields. // Without indexes, FailedPrecondition rpc error is seen with // desc 'The query requires multiple indexes'. -func createIndexes(ctx context.Context, dbPath string, indexFields [][]string) { - indexNames = make([]string, len(indexFields)) +func createIndexes(ctx context.Context, dbPath string, orderModeindexFields [][]string) []string { + indexNames := make([]string, len(orderModeindexFields)) indexParent := fmt.Sprintf("%s/collectionGroups/%s", dbPath, iColl.ID) var wg sync.WaitGroup - for i, fields := range indexFields { + + // Create order mode indexes + for i, fields := range orderModeindexFields { wg.Add(1) var adminPbIndexFields []*adminpb.Index_IndexField for _, field := range fields { @@ -177,17 +229,21 @@ func createIndexes(ctx context.Context, dbPath string, indexFields [][]string) { if createErr != nil { log.Fatalf("CreateIndex: %v", createErr) } - if i == 0 { + if i == 0 && !seededFirstIndex { + seededFirstIndex = true // Seed first index to prevent FirestoreMetadataWrite.BootstrapDatabase Concurrent access error - handleCreateIndexResp(ctx, &wg, i, op) + handleCreateIndexResp(ctx, indexNames, &wg, i, op) } else { - go handleCreateIndexResp(ctx, &wg, i, op) + go handleCreateIndexResp(ctx, indexNames, &wg, i, op) } } + wg.Wait() + return indexNames } -func handleCreateIndexResp(ctx context.Context, wg *sync.WaitGroup, i int, op *apiv1.CreateIndexOperation) { +// handleCreateIndexResp handles create index response and puts the created index name at index i in the indexNames array +func handleCreateIndexResp(ctx context.Context, indexNames []string, wg *sync.WaitGroup, i int, op *apiv1.CreateIndexOperation) { defer wg.Done() createdIndex, waitErr := op.Wait(ctx) if waitErr != nil { @@ -197,7 +253,7 @@ func handleCreateIndexResp(ctx context.Context, wg *sync.WaitGroup, i int, op *a } // deleteIndexes deletes composite indexes created in createIndexes function -func deleteIndexes(ctx context.Context) { +func deleteIndexes(ctx context.Context, indexNames []string) { for _, indexName := range indexNames { err := iAdminClient.DeleteIndex(ctx, &adminpb.DeleteIndexRequest{ Name: indexName, @@ -293,10 +349,6 @@ func deleteDocument(ctx context.Context, docRef *DocumentRef, bulkwriter *BulkWr func cleanupIntegrationTest() { if iClient != nil { - adminCtx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) - defer cancel() - deleteIndexes(adminCtx) - ctx := context.Background() deleteCollection(ctx, iColl) iClient.Close() @@ -346,44 +398,46 @@ var ( // Use this when writing a doc. integrationTestMap = map[string]interface{}{ - "int": 1, - "int8": int8(2), - "int16": int16(3), - "int32": int32(4), - "int64": int64(5), - "uint8": uint8(6), - "uint16": uint16(7), - "uint32": uint32(8), - "str": "two", - "bool": true, - "float": 3.14, - "null": nil, - "bytes": []byte("bytes"), - "*": map[string]interface{}{"`": 4}, - "time": integrationTime, - "geo": integrationGeo, - "ref": nil, // populated by initIntegrationTest + "int": 1, + "int8": int8(2), + "int16": int16(3), + "int32": int32(4), + "int64": int64(5), + "uint8": uint8(6), + "uint16": uint16(7), + "uint32": uint32(8), + "str": "two", + "bool": true, + "float": 3.14, + "null": nil, + "bytes": []byte("bytes"), + "*": map[string]interface{}{"`": 4}, + "time": integrationTime, + "geo": integrationGeo, + "ref": nil, // populated by initIntegrationTest + "embeddedField": Vector{1.0, 2.0, 3.0}, } // The returned data is slightly different. wantIntegrationTestMap = map[string]interface{}{ - "int": int64(1), - "int8": int64(2), - "int16": int64(3), - "int32": int64(4), - "int64": int64(5), - "uint8": int64(6), - "uint16": int64(7), - "uint32": int64(8), - "str": "two", - "bool": true, - "float": 3.14, - "null": nil, - "bytes": []byte("bytes"), - "*": map[string]interface{}{"`": int64(4)}, - "time": wantIntegrationTime, - "geo": integrationGeo, - "ref": nil, // populated by initIntegrationTest + "int": int64(1), + "int8": int64(2), + "int16": int64(3), + "int32": int64(4), + "int64": int64(5), + "uint8": int64(6), + "uint16": int64(7), + "uint32": int64(8), + "str": "two", + "bool": true, + "float": 3.14, + "null": nil, + "bytes": []byte("bytes"), + "*": map[string]interface{}{"`": int64(4)}, + "time": wantIntegrationTime, + "geo": integrationGeo, + "ref": nil, // populated by initIntegrationTest + "embeddedField": Vector{1.0, 2.0, 3.0}, } integrationTestStruct = integrationTestStructType{ @@ -873,7 +927,8 @@ func TestIntegration_QueryDocuments_WhereEntity(t *testing.T) { {"weight", "height"}} adminCtx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) defer cancel() - createIndexes(adminCtx, wantDBPath, indexFields) + indexNames := createIndexes(adminCtx, wantDBPath, indexFields) + defer deleteIndexes(adminCtx, indexNames) h := testHelper{t} nowTime := time.Now() @@ -2462,10 +2517,13 @@ func TestIntegration_AggregationQueries(t *testing.T) { client := integrationClient(t) indexFields := [][]string{ - {"weight", "model"}} + {"weight", "model"}, + {"weight", "height"}, + } adminCtx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) defer cancel() - createIndexes(adminCtx, wantDBPath, indexFields) + indexNames := createIndexes(adminCtx, wantDBPath, indexFields) + defer deleteIndexes(adminCtx, indexNames) h := testHelper{t} docs := []map[string]interface{}{ @@ -2782,3 +2840,83 @@ func TestIntegration_ClientReadTime(t *testing.T) { } } } + +func TestIntegration_FindNearest(t *testing.T) { + adminCtx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + indexNames := createVectorIndexes(adminCtx, wantDBPath, []vectorIndex{ + { + fieldPath: "EmbeddedField", + dimension: 3, + }, + }) + // defer deleteIndexes(adminCtx, indexNames) + time.Sleep(30 * time.Second) + fmt.Printf("indexNames: %+v\n", indexNames) + type coffeeBean struct { + ID string + EmbeddedField Vector + } + + beans := []coffeeBean{ + { + ID: "Robusta", + EmbeddedField: []float32{1, 2, 3}, + }, + { + ID: "Excelsa", + EmbeddedField: []float32{4, 5, 6}, + }, + { + ID: "Arabica", + EmbeddedField: []float32{100, 200, 300}, // too far from query vector. not within findNearest limit + }, + + { + ID: "Liberica", + EmbeddedField: []float32{1, 2}, // Not enough dimensions as compared to query vector. + }, + } + h := testHelper{t} + coll := integrationColl(t) + ctx := context.Background() + var docRefs []*DocumentRef + + // create documents with vector field + for i := 0; i < len(beans); i++ { + doc := coll.NewDoc() + docRefs = append(docRefs, doc) + h.mustCreate(doc, beans[i]) + } + + // Query documents with a vector field + vectorQuery := iColl.FindNearest("EmbeddedField", []float32{1, 2, 3}, FindNearestOpts{ + Limit: 2, + Measure: DistanceMeasureEuclidean, + }) + + iter := vectorQuery.Documents(ctx) + gotDocs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %+v", err) + } + + if len(gotDocs) != 2 { + t.Fatalf("Expected 2 results, got %d", len(gotDocs)) + } + + for i, doc := range gotDocs { + gotBean := coffeeBean{} + err := doc.DataTo(&gotBean) + if err != nil { + t.Errorf("#%v: DataTo: %+v", doc.Ref.ID, err) + } + if beans[i].ID != gotBean.ID { + t.Errorf("#%v: want: %v, got: %v", i, beans[i].ID, gotBean.ID) + } + } + + // t.Cleanup(func() { + // deleteDocuments(docRefs) + // }) +} diff --git a/firestore/query.go b/firestore/query.go index 8804b4a0f13d..8212aad3014e 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -371,8 +371,7 @@ type DistanceMeasure int32 const ( // DistanceMeasureEuclidean is used to measures the Euclidean distance between the vectors. See - // [Euclidean] to learn - // more + // [Euclidean] to learn more // // [Euclidean]: https://en.wikipedia.org/wiki/Euclidean_distance DistanceMeasureEuclidean DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_EUCLIDEAN) @@ -381,9 +380,7 @@ const ( // measure similarity that isn't based on the vectors magnitude. // We recommend using dot product with unit normalized vectors instead of // cosine distance, which is mathematically equivalent with better - // performance. See [Cosine - // Similarity] to learn - // more. + // performance. See [Cosine Similarity] to learn more. // // [Cosine Similarity]: https://en.wikipedia.org/wiki/Cosine_similarity DistanceMeasureCosine DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_COSINE) @@ -391,7 +388,7 @@ const ( // DistanceMeasureDotProduct is similar to cosine but is affected by the magnitude of the vectors. See // [Dot Product] to learn more. // - // [Dot Product]: https://en.wikipedia.org/wiki/Dot_product) + // [Dot Product]: https://en.wikipedia.org/wiki/Dot_product DistanceMeasureDotProduct DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_DOT_PRODUCT) ) @@ -417,7 +414,7 @@ type VectorQuery struct { // // The vectorField argument can be a single field or a dot-separated sequence of // fields, and must not contain any of the runes "˜*/[]". -func (q Query) FindNearest(vectorField string, queryVector VectorType, options FindNearestOpts) VectorQuery { +func (q Query) FindNearest(vectorField string, queryVector Vector, options FindNearestOpts) VectorQuery { vq := VectorQuery{ q: q, } @@ -437,7 +434,7 @@ func (vq VectorQuery) Documents(ctx context.Context) *DocumentIterator { } // FindNearestPath is similar to FindNearest but it accepts [FieldPath] -func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector VectorType, options FindNearestOpts) VectorQuery { +func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector Vector, options FindNearestOpts) VectorQuery { vq := VectorQuery{ q: q, } @@ -448,19 +445,9 @@ func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector VectorType return vq } - pbVal, sawTransform, err := toProtoValue(reflect.ValueOf(queryVector)) - if err != nil { - vq.q.err = err - return vq - } - if sawTransform { - vq.q.err = errors.New("firestore: transforms disallowed in query value") - return vq - } - vq.q.findNearest = &pb.StructuredQuery_FindNearest{ VectorField: vectorFieldRef, - QueryVector: pbVal, + QueryVector: vectorToProtoValue(queryVector), Limit: &wrapperspb.Int32Value{Value: trunc32(options.Limit)}, DistanceMeasure: pb.StructuredQuery_FindNearest_DistanceMeasure(options.Measure), } diff --git a/firestore/query_test.go b/firestore/query_test.go index 228fb8ef9e29..93a9814bf04f 100644 --- a/firestore/query_test.go +++ b/firestore/query_test.go @@ -650,6 +650,39 @@ func createTestScenarios(t *testing.T) []toProtoScenario { }, }, }, + { + desc: `q.Where("a", ">", 5).FindNearest`, + in: q.Where("a", ">", 5). + FindNearest("embeddedField", []float32{100, 200, 300}, FindNearestOpts{Limit: 2, Measure: DistanceMeasureEuclidean}).q, + want: &pb.StructuredQuery{ + Where: filtr([]string{"a"}, ">", 5), + FindNearest: &pb.StructuredQuery_FindNearest{ + VectorField: fref1("embeddedField"), + QueryVector: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{ + Values: []*pb.Value{ + {ValueType: &pb.Value_DoubleValue{DoubleValue: 100}}, + {ValueType: &pb.Value_DoubleValue{DoubleValue: 200}}, + {ValueType: &pb.Value_DoubleValue{DoubleValue: 300}}, + }, + }, + }, + }, + }, + }, + }, + }, + Limit: &wrapperspb.Int32Value{Value: trunc32(2)}, + DistanceMeasure: pb.StructuredQuery_FindNearest_EUCLIDEAN, + }, + }, + }, } } @@ -1359,3 +1392,67 @@ func TestWithAvgPath(t *testing.T) { } } } + +func TestFindNearest(t *testing.T) { + ctx := context.Background() + c, srv, cleanup := newMock(t) + defer cleanup() + + const dbPath = "projects/projectID/databases/(default)" + mapFields := map[string]*pb.Value{ + typeKey: {ValueType: &pb.Value_StringValue{StringValue: typeValVector}}, + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{ + Values: []*pb.Value{ + {ValueType: &pb.Value_DoubleValue{DoubleValue: 1}}, + {ValueType: &pb.Value_DoubleValue{DoubleValue: 2}}, + {ValueType: &pb.Value_DoubleValue{DoubleValue: 2}}, + }, + }, + }, + }, + } + wantPBDocs := []*pb.Document{ + { + Name: dbPath + "/documents/C/a", + CreateTime: aTimestamp, + UpdateTime: aTimestamp, + Fields: map[string]*pb.Value{"EmbeddedField": mapval(mapFields)}, + }, + } + srv.addRPC(nil, []interface{}{ + &pb.RunQueryResponse{Document: wantPBDocs[0]}, + }) + + testcases := []struct { + desc string + path string + wantErr bool + }{ + { + desc: "Invalid path", + path: "path*", + wantErr: true, + }, + { + desc: "Valid path", + path: "path", + wantErr: false, + }, + } + for _, tc := range testcases { + + vQuery := c.Collection("C").FindNearest(tc.path, []float32{5, 6, 7}, FindNearestOpts{ + Limit: 2, + Measure: DistanceMeasureEuclidean, + }) + + _, err := vQuery.Documents(ctx).GetAll() + if err == nil && tc.wantErr { + t.Fatalf("%s: got nil wanted error", tc.desc) + } else if err != nil && !tc.wantErr { + t.Fatalf("%s: got %v, want nil", tc.desc, err) + } + } +} diff --git a/firestore/vector.go b/firestore/vector.go index c21aff9a1ed2..37a445218fe7 100644 --- a/firestore/vector.go +++ b/firestore/vector.go @@ -26,15 +26,8 @@ const ( valueKey = "value" ) -// VectorType represpresents a vector -type VectorType interface { - isVectorType() -} - -// Vector represents a vector in the form of a float64 array -type Vector []float64 - -func (Vector) isVectorType() {} +// Vector represents a vector in the form of a float32 array +type Vector []float32 // vectorToProtoValue returns a Firestore [pb.Value] representing the Vector. func vectorToProtoValue(v Vector) *pb.Value { @@ -98,13 +91,13 @@ func vectorFromProtoValue(v *pb.Value) (Vector, error) { } pbArrVals := pbArr.ArrayValue.Values - floats := make([]float64, len(pbArrVals)) + floats := make([]float32, len(pbArrVals)) for i, fval := range pbArrVals { dv, ok := fval.ValueType.(*pb.Value_DoubleValue) if !ok { return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_DoubleValue", fval.ValueType) } - floats[i] = float64(dv.DoubleValue) + floats[i] = float32(dv.DoubleValue) } return Vector(floats), nil } diff --git a/firestore/vector_test.go b/firestore/vector_test.go new file mode 100644 index 000000000000..a63ece13fa72 --- /dev/null +++ b/firestore/vector_test.go @@ -0,0 +1,261 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "testing" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestVectorToProtoValue(t *testing.T) { + tests := []struct { + name string + v Vector + want *pb.Value + }{ + { + name: "nil vector", + v: nil, + want: nullValue, + }, + { + name: "empty vector", + v: Vector{}, + want: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: []*pb.Value{}}, + }, + }, + }, + }, + }, + }, + }, + { + name: "multiple element vector", + v: Vector{1.0, 2.0, 3.0}, + want: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: []*pb.Value{floatToProtoValue(1.0), floatToProtoValue(2.0), floatToProtoValue(3.0)}}, + }, + }, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := vectorToProtoValue(tt.v) + if !testEqual(got, tt.want) { + t.Errorf("vectorToProtoValue() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestVectorFromProtoValue(t *testing.T) { + tests := []struct { + name string + v *pb.Value + want Vector + wantErr bool + }{ + { + name: "nil value", + v: nil, + want: nil, + }, + { + name: "empty vector", + v: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: []*pb.Value{}}, + }, + }, + }, + }, + }, + }, + want: Vector{}, + }, + { + name: "multiple element vector", + v: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: []*pb.Value{floatToProtoValue(1.0), floatToProtoValue(2.0), floatToProtoValue(3.0)}}, + }, + }, + }, + }, + }, + }, + want: Vector{1.0, 2.0, 3.0}, + }, + { + name: "invalid type", + v: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue("invalid_type"), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: []*pb.Value{floatToProtoValue(1.0), floatToProtoValue(2.0), floatToProtoValue(3.0)}}, + }, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "missing type", + v: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: []*pb.Value{floatToProtoValue(1.0), floatToProtoValue(2.0), floatToProtoValue(3.0)}}, + }, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "missing value", + v: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "invalid value", + v: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_StringValue{ + StringValue: "invalid_value", + }, + }, + }, + }, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := vectorFromProtoValue(tt.v) + if (err != nil) != tt.wantErr { + t.Errorf("vectorFromProtoValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if !cmp.Equal(got, tt.want, cmpopts.EquateEmpty()) { + t.Errorf("vectorFromProtoValue() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestStringFromProtoValue(t *testing.T) { + tests := []struct { + name string + v *pb.Value + want string + wantErr bool + }{ + { + name: "nil value", + v: nil, + wantErr: true, + }, + { + name: "string value", + v: &pb.Value{ + ValueType: &pb.Value_StringValue{ + StringValue: "test_string", + }, + }, + want: "test_string", + }, + { + name: "invalid value", + v: &pb.Value{ + ValueType: &pb.Value_IntegerValue{ + IntegerValue: 123, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := stringFromProtoValue(tt.v) + if (err != nil) != tt.wantErr { + t.Errorf("stringFromProtoValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if got != tt.want { + t.Errorf("stringFromProtoValue() = %v, want %v", got, tt.want) + } + }) + } +} From 9af66eb8abd597f7af43cb6a2293ee85807c11b6 Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Wed, 17 Jul 2024 20:00:12 +0000 Subject: [PATCH 5/9] feat(firestore): Fixing tests and refactoring code --- firestore/integration_test.go | 44 ++++++++++++++++++++--------------- firestore/query.go | 14 +++++------ firestore/query_test.go | 42 ++++++++++++++++++++++++++++----- firestore/vector.go | 21 +++++++++++++---- 4 files changed, 84 insertions(+), 37 deletions(-) diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 1ee89c52299a..110c7fd36572 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -151,9 +151,10 @@ type vectorIndex struct { fieldPath string } -func createVectorIndexes(ctx context.Context, dbPath string, vectorModeIndexes []vectorIndex) []string { +func createVectorIndexes(t *testing.T, ctx context.Context, dbPath string, vectorModeIndexes []vectorIndex) []string { + collRef := integrationColl(t) indexNames := make([]string, len(vectorModeIndexes)) - indexParent := fmt.Sprintf("%s/collectionGroups/%s", dbPath, iColl.ID) + indexParent := fmt.Sprintf("%s/collectionGroups/%s", dbPath, collRef.ID) var wg sync.WaitGroup @@ -2825,7 +2826,11 @@ func TestIntegration_ClientReadTime(t *testing.T) { } tm := time.Now().Add(-time.Minute) + oldReadSettings := *c.readSettings c.WithReadOptions(ReadTime(tm)) + t.Cleanup(func() { + c.readSettings = &oldReadSettings + }) ds, err := c.GetAll(ctx, docs) if err != nil { @@ -2842,17 +2847,22 @@ func TestIntegration_ClientReadTime(t *testing.T) { } func TestIntegration_FindNearest(t *testing.T) { + collRef := integrationColl(t) adminCtx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) - defer cancel() - indexNames := createVectorIndexes(adminCtx, wantDBPath, []vectorIndex{ + t.Cleanup(func() { + cancel() + }) + + indexNames := createVectorIndexes(t, adminCtx, wantDBPath, []vectorIndex{ { fieldPath: "EmbeddedField", dimension: 3, }, }) - // defer deleteIndexes(adminCtx, indexNames) - time.Sleep(30 * time.Second) - fmt.Printf("indexNames: %+v\n", indexNames) + t.Cleanup(func() { + deleteIndexes(adminCtx, indexNames) + }) + type coffeeBean struct { ID string EmbeddedField Vector @@ -2861,26 +2871,29 @@ func TestIntegration_FindNearest(t *testing.T) { beans := []coffeeBean{ { ID: "Robusta", - EmbeddedField: []float32{1, 2, 3}, + EmbeddedField: []float64{1, 2, 3}, }, { ID: "Excelsa", - EmbeddedField: []float32{4, 5, 6}, + EmbeddedField: []float64{4, 5, 6}, }, { ID: "Arabica", - EmbeddedField: []float32{100, 200, 300}, // too far from query vector. not within findNearest limit + EmbeddedField: []float64{100, 200, 300}, // too far from query vector. not within findNearest limit }, { ID: "Liberica", - EmbeddedField: []float32{1, 2}, // Not enough dimensions as compared to query vector. + EmbeddedField: []float64{1, 2}, // Not enough dimensions as compared to query vector. }, } h := testHelper{t} coll := integrationColl(t) ctx := context.Background() var docRefs []*DocumentRef + t.Cleanup(func() { + deleteDocuments(docRefs) + }) // create documents with vector field for i := 0; i < len(beans); i++ { @@ -2890,10 +2903,7 @@ func TestIntegration_FindNearest(t *testing.T) { } // Query documents with a vector field - vectorQuery := iColl.FindNearest("EmbeddedField", []float32{1, 2, 3}, FindNearestOpts{ - Limit: 2, - Measure: DistanceMeasureEuclidean, - }) + vectorQuery := collRef.FindNearest("EmbeddedField", []float64{1, 2, 3}, 2, DistanceMeasureEuclidean, nil) iter := vectorQuery.Documents(ctx) gotDocs, err := iter.GetAll() @@ -2915,8 +2925,4 @@ func TestIntegration_FindNearest(t *testing.T) { t.Errorf("#%v: want: %v, got: %v", i, beans[i].ID, gotBean.ID) } } - - // t.Cleanup(func() { - // deleteDocuments(docRefs) - // }) } diff --git a/firestore/query.go b/firestore/query.go index 8212aad3014e..0df7529b7748 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -394,8 +394,6 @@ const ( // FindNearestOpts is options to use while building FindNearest vector query type FindNearestOpts struct { - Limit int - Measure DistanceMeasure } // VectorQuery represents a vector query @@ -414,7 +412,7 @@ type VectorQuery struct { // // The vectorField argument can be a single field or a dot-separated sequence of // fields, and must not contain any of the runes "˜*/[]". -func (q Query) FindNearest(vectorField string, queryVector Vector, options FindNearestOpts) VectorQuery { +func (q Query) FindNearest(vectorField string, queryVector Vector, limit int, measure DistanceMeasure, options *FindNearestOpts) VectorQuery { vq := VectorQuery{ q: q, } @@ -425,7 +423,7 @@ func (q Query) FindNearest(vectorField string, queryVector Vector, options FindN vq.q.err = err return vq } - return q.FindNearestPath(fieldPath, queryVector, options) + return q.FindNearestPath(fieldPath, queryVector, limit, measure, options) } // Documents returns an iterator over the vector query's resulting documents. @@ -433,8 +431,8 @@ func (vq VectorQuery) Documents(ctx context.Context) *DocumentIterator { return vq.q.Documents(ctx) } -// FindNearestPath is similar to FindNearest but it accepts [FieldPath] -func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector Vector, options FindNearestOpts) VectorQuery { +// FindNearestPath is similar to FindNearest but it accepts a [FieldPath]. +func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector Vector, limit int, measure DistanceMeasure, options *FindNearestOpts) VectorQuery { vq := VectorQuery{ q: q, } @@ -448,8 +446,8 @@ func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector Vector, op vq.q.findNearest = &pb.StructuredQuery_FindNearest{ VectorField: vectorFieldRef, QueryVector: vectorToProtoValue(queryVector), - Limit: &wrapperspb.Int32Value{Value: trunc32(options.Limit)}, - DistanceMeasure: pb.StructuredQuery_FindNearest_DistanceMeasure(options.Measure), + Limit: &wrapperspb.Int32Value{Value: trunc32(limit)}, + DistanceMeasure: pb.StructuredQuery_FindNearest_DistanceMeasure(measure), } return vq diff --git a/firestore/query_test.go b/firestore/query_test.go index 93a9814bf04f..f2f94e9d1782 100644 --- a/firestore/query_test.go +++ b/firestore/query_test.go @@ -651,9 +651,42 @@ func createTestScenarios(t *testing.T) []toProtoScenario { }, }, { - desc: `q.Where("a", ">", 5).FindNearest`, + desc: `q.Where("a", ">", 5).FindNearest float64 vector`, in: q.Where("a", ">", 5). - FindNearest("embeddedField", []float32{100, 200, 300}, FindNearestOpts{Limit: 2, Measure: DistanceMeasureEuclidean}).q, + FindNearest("embeddedField", []float64{100, 200, 300}, 2, DistanceMeasureEuclidean, nil).q, + want: &pb.StructuredQuery{ + Where: filtr([]string{"a"}, ">", 5), + FindNearest: &pb.StructuredQuery_FindNearest{ + VectorField: fref1("embeddedField"), + QueryVector: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{ + Values: []*pb.Value{ + {ValueType: &pb.Value_DoubleValue{DoubleValue: 100}}, + {ValueType: &pb.Value_DoubleValue{DoubleValue: 200}}, + {ValueType: &pb.Value_DoubleValue{DoubleValue: 300}}, + }, + }, + }, + }, + }, + }, + }, + }, + Limit: &wrapperspb.Int32Value{Value: trunc32(2)}, + DistanceMeasure: pb.StructuredQuery_FindNearest_EUCLIDEAN, + }, + }, + }, + { + desc: `q.Where("a", ">", 5).FindNearest float32 vector`, + in: q.Where("a", ">", 5). + FindNearest("embeddedField", ToVector([]float32{100, 200, 300}), 2, DistanceMeasureEuclidean, nil).q, want: &pb.StructuredQuery{ Where: filtr([]string{"a"}, ">", 5), FindNearest: &pb.StructuredQuery_FindNearest{ @@ -1443,10 +1476,7 @@ func TestFindNearest(t *testing.T) { } for _, tc := range testcases { - vQuery := c.Collection("C").FindNearest(tc.path, []float32{5, 6, 7}, FindNearestOpts{ - Limit: 2, - Measure: DistanceMeasureEuclidean, - }) + vQuery := c.Collection("C").FindNearest(tc.path, []float64{5, 6, 7}, 2, DistanceMeasureEuclidean, nil) _, err := vQuery.Documents(ctx).GetAll() if err == nil && tc.wantErr { diff --git a/firestore/vector.go b/firestore/vector.go index 37a445218fe7..e662e12f1b4a 100644 --- a/firestore/vector.go +++ b/firestore/vector.go @@ -26,8 +26,21 @@ const ( valueKey = "value" ) -// Vector represents a vector in the form of a float32 array -type Vector []float32 +// Vector is an embedding vector. +type Vector []float64 + +func ToVector[fType float32 | float64](arr []fType) Vector { + var arrAny interface{} + if arrFloat64, ok := arrAny.([]float64); ok { + return Vector(arrFloat64) // Type assertion, no conversion needed + } + + vec := make(Vector, len(arr)) + for i, val := range arr { + vec[i] = float64(val) + } + return vec +} // vectorToProtoValue returns a Firestore [pb.Value] representing the Vector. func vectorToProtoValue(v Vector) *pb.Value { @@ -91,13 +104,13 @@ func vectorFromProtoValue(v *pb.Value) (Vector, error) { } pbArrVals := pbArr.ArrayValue.Values - floats := make([]float32, len(pbArrVals)) + floats := make([]float64, len(pbArrVals)) for i, fval := range pbArrVals { dv, ok := fval.ValueType.(*pb.Value_DoubleValue) if !ok { return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_DoubleValue", fval.ValueType) } - floats[i] = float32(dv.DoubleValue) + floats[i] = dv.DoubleValue } return Vector(floats), nil } From e3807b361e06d20c735162446490135ec57f4be8 Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Wed, 17 Jul 2024 20:07:52 +0000 Subject: [PATCH 6/9] feat(firestore): Resolving vet failures --- firestore/integration_test.go | 4 ++-- firestore/vector.go | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 110c7fd36572..e2d5dffaf5e2 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -151,7 +151,7 @@ type vectorIndex struct { fieldPath string } -func createVectorIndexes(t *testing.T, ctx context.Context, dbPath string, vectorModeIndexes []vectorIndex) []string { +func createVectorIndexes(ctx context.Context, t *testing.T, dbPath string, vectorModeIndexes []vectorIndex) []string { collRef := integrationColl(t) indexNames := make([]string, len(vectorModeIndexes)) indexParent := fmt.Sprintf("%s/collectionGroups/%s", dbPath, collRef.ID) @@ -2853,7 +2853,7 @@ func TestIntegration_FindNearest(t *testing.T) { cancel() }) - indexNames := createVectorIndexes(t, adminCtx, wantDBPath, []vectorIndex{ + indexNames := createVectorIndexes(adminCtx, t, wantDBPath, []vectorIndex{ { fieldPath: "EmbeddedField", dimension: 3, diff --git a/firestore/vector.go b/firestore/vector.go index e662e12f1b4a..1ecc80c9eee7 100644 --- a/firestore/vector.go +++ b/firestore/vector.go @@ -29,6 +29,7 @@ const ( // Vector is an embedding vector. type Vector []float64 +// ToVector converts float32 or float64 slice to Firestore embedding vector. func ToVector[fType float32 | float64](arr []fType) Vector { var arrAny interface{} if arrFloat64, ok := arrAny.([]float64); ok { From 695b011e7df3e046b4157bc6ffcca4a236f5d398 Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Fri, 19 Jul 2024 18:09:02 +0000 Subject: [PATCH 7/9] feat(firestore): Refactoring code --- firestore/document.go | 1 + firestore/from_value.go | 16 +++++--- firestore/integration_test.go | 38 +++++++++++------- firestore/query.go | 37 ++++++++++++------ firestore/query_test.go | 39 ++++++++++++------ firestore/to_value.go | 7 +++- firestore/vector.go | 74 +++++++++++++++++++++-------------- firestore/vector_test.go | 12 +++--- 8 files changed, 142 insertions(+), 82 deletions(-) diff --git a/firestore/document.go b/firestore/document.go index 8d63e23edd33..cc384db8409e 100644 --- a/firestore/document.go +++ b/firestore/document.go @@ -97,6 +97,7 @@ func (d *DocumentSnapshot) Data() map[string]interface{} { // Slices are resized to the incoming value's size, while arrays that are too // long have excess elements filled with zero values. If the array is too short, // excess incoming values will be dropped. +// - Vectors convert to []float64 // - Maps convert to map[string]interface{}. When setting a struct field, // maps of key type string and any value type are permitted, and are populated // recursively. diff --git a/firestore/from_value.go b/firestore/from_value.go index de0332795187..8ff05c7410bd 100644 --- a/firestore/from_value.go +++ b/firestore/from_value.go @@ -99,12 +99,19 @@ func setReflectFromProtoValue(vDest reflect.Value, vprotoSrc *pb.Value, c *Clien vDest.Set(reflect.ValueOf(dr)) return nil - case typeOfVector: - vector, err := vectorFromProtoValue(vprotoSrc) + case typeOfVector32: + val, err := vector32FromProtoValue(vprotoSrc) if err != nil { return err } - vDest.Set(reflect.ValueOf(vector)) + vDest.Set(reflect.ValueOf(val)) + return nil + case typeOfVector64: + val, err := vector64FromProtoValue(vprotoSrc) + if err != nil { + return err + } + vDest.Set(reflect.ValueOf(val)) return nil } @@ -405,8 +412,7 @@ func createFromProtoValue(vproto *pb.Value, c *Client) (interface{}, error) { } // Special handling for vector - vector, err := vectorFromProtoValue(vproto) - return vector, err + return vectorFromProtoValue(vproto) default: return nil, fmt.Errorf("firestore: unknown value type %T", v) } diff --git a/firestore/integration_test.go b/firestore/integration_test.go index e2d5dffaf5e2..a1101a166593 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -180,7 +180,6 @@ func createVectorIndexes(ctx context.Context, t *testing.T, dbPath string, vecto }, }, } - fmt.Printf("req: %+v\n", req) op, createErr := iAdminClient.CreateIndex(ctx, req) if createErr != nil { log.Fatalf("CreateIndex vectorindexes: %v", createErr) @@ -416,7 +415,7 @@ var ( "time": integrationTime, "geo": integrationGeo, "ref": nil, // populated by initIntegrationTest - "embeddedField": Vector{1.0, 2.0, 3.0}, + "embeddedField": Vector64{1.0, 2.0, 3.0}, } // The returned data is slightly different. @@ -438,7 +437,7 @@ var ( "time": wantIntegrationTime, "geo": integrationGeo, "ref": nil, // populated by initIntegrationTest - "embeddedField": Vector{1.0, 2.0, 3.0}, + "embeddedField": Vector64{1.0, 2.0, 3.0}, } integrationTestStruct = integrationTestStructType{ @@ -2863,28 +2862,39 @@ func TestIntegration_FindNearest(t *testing.T) { deleteIndexes(adminCtx, indexNames) }) + queryField := "EmbeddedField64" type coffeeBean struct { - ID string - EmbeddedField Vector + ID string + EmbeddedField64 Vector64 + EmbeddedField32 Vector32 + Float32s []float32 // When querying, saving and retrieving, this should be retrieved as []float32 and not Vector32 } beans := []coffeeBean{ { - ID: "Robusta", - EmbeddedField: []float64{1, 2, 3}, + ID: "Robusta", + EmbeddedField64: []float64{1, 2, 3}, + EmbeddedField32: []float32{1, 2, 3}, + Float32s: []float32{1, 2, 3}, }, { - ID: "Excelsa", - EmbeddedField: []float64{4, 5, 6}, + ID: "Excelsa", + EmbeddedField64: []float64{4, 5, 6}, + EmbeddedField32: []float32{4, 5, 6}, + Float32s: []float32{4, 5, 6}, }, { - ID: "Arabica", - EmbeddedField: []float64{100, 200, 300}, // too far from query vector. not within findNearest limit + ID: "Arabica", + EmbeddedField64: []float64{100, 200, 300}, // too far from query vector. not within findNearest limit + EmbeddedField32: []float32{100, 200, 300}, + Float32s: []float32{100, 200, 300}, }, { - ID: "Liberica", - EmbeddedField: []float64{1, 2}, // Not enough dimensions as compared to query vector. + ID: "Liberica", + EmbeddedField64: []float64{1, 2}, // Not enough dimensions as compared to query vector. + EmbeddedField32: []float32{1, 2}, + Float32s: []float32{1, 2}, }, } h := testHelper{t} @@ -2903,7 +2913,7 @@ func TestIntegration_FindNearest(t *testing.T) { } // Query documents with a vector field - vectorQuery := collRef.FindNearest("EmbeddedField", []float64{1, 2, 3}, 2, DistanceMeasureEuclidean, nil) + vectorQuery := collRef.FindNearest(queryField, []float64{1, 2, 3}, 2, DistanceMeasureEuclidean, nil) iter := vectorQuery.Documents(ctx) gotDocs, err := iter.GetAll() diff --git a/firestore/query.go b/firestore/query.go index 0df7529b7748..a843e57fef19 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -392,8 +392,8 @@ const ( DistanceMeasureDotProduct DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_DOT_PRODUCT) ) -// FindNearestOpts is options to use while building FindNearest vector query -type FindNearestOpts struct { +// FindNearestOptions is options to use while building FindNearest vector query +type FindNearestOptions struct { } // VectorQuery represents a vector query @@ -412,16 +412,14 @@ type VectorQuery struct { // // The vectorField argument can be a single field or a dot-separated sequence of // fields, and must not contain any of the runes "˜*/[]". -func (q Query) FindNearest(vectorField string, queryVector Vector, limit int, measure DistanceMeasure, options *FindNearestOpts) VectorQuery { - vq := VectorQuery{ - q: q, - } - +func (q Query) FindNearest(vectorField string, queryVector interface{}, limit int, measure DistanceMeasure, options *FindNearestOptions) VectorQuery { // Validate field path fieldPath, err := parseDotSeparatedString(vectorField) if err != nil { - vq.q.err = err - return vq + q.err = err + return VectorQuery{ + q: q, + } } return q.FindNearestPath(fieldPath, queryVector, limit, measure, options) } @@ -432,24 +430,37 @@ func (vq VectorQuery) Documents(ctx context.Context) *DocumentIterator { } // FindNearestPath is similar to FindNearest but it accepts a [FieldPath]. -func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector Vector, limit int, measure DistanceMeasure, options *FindNearestOpts) VectorQuery { +func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector interface{}, limit int, measure DistanceMeasure, options *FindNearestOptions) VectorQuery { vq := VectorQuery{ q: q, } - // Convert field path field reference + + // Convert field path to field reference vectorFieldRef, err := fref(vectorFieldPath) if err != nil { vq.q.err = err return vq } + var fnvq *pb.Value + switch v := queryVector.(type) { + case Vector32: + fnvq = vectorToProtoValue([]float32(v)) + case Vector64: + fnvq = vectorToProtoValue([]float64(v)) + default: + vq.q.err = errors.New("firestore: queryVector must be Vector32 or Vector64") + return VectorQuery{ + q: q, + } + } + vq.q.findNearest = &pb.StructuredQuery_FindNearest{ VectorField: vectorFieldRef, - QueryVector: vectorToProtoValue(queryVector), + QueryVector: fnvq, Limit: &wrapperspb.Int32Value{Value: trunc32(limit)}, DistanceMeasure: pb.StructuredQuery_FindNearest_DistanceMeasure(measure), } - return vq } diff --git a/firestore/query_test.go b/firestore/query_test.go index f2f94e9d1782..63e6d568dfdc 100644 --- a/firestore/query_test.go +++ b/firestore/query_test.go @@ -686,7 +686,7 @@ func createTestScenarios(t *testing.T) []toProtoScenario { { desc: `q.Where("a", ">", 5).FindNearest float32 vector`, in: q.Where("a", ">", 5). - FindNearest("embeddedField", ToVector([]float32{100, 200, 300}), 2, DistanceMeasureEuclidean, nil).q, + FindNearest("embeddedField", []float32{100, 200, 300}, 2, DistanceMeasureEuclidean, nil).q, want: &pb.StructuredQuery{ Where: filtr([]string{"a"}, ">", 5), FindNearest: &pb.StructuredQuery_FindNearest{ @@ -742,7 +742,6 @@ func TestQueryFromProtoRoundTrip(t *testing.T) { if err != nil { t.Fatalf("%s: %v", test.desc, err) } - fmt.Printf("proto: %v\n", proto) got, err := Query{c: c}.Deserialize(proto) if err != nil { t.Fatalf("%s: %v", test.desc, err) @@ -1454,14 +1453,12 @@ func TestFindNearest(t *testing.T) { Fields: map[string]*pb.Value{"EmbeddedField": mapval(mapFields)}, }, } - srv.addRPC(nil, []interface{}{ - &pb.RunQueryResponse{Document: wantPBDocs[0]}, - }) testcases := []struct { - desc string - path string - wantErr bool + desc string + path string + queryVector interface{} + wantErr bool }{ { desc: "Invalid path", @@ -1469,14 +1466,30 @@ func TestFindNearest(t *testing.T) { wantErr: true, }, { - desc: "Valid path", - path: "path", - wantErr: false, + desc: "Valid path", + path: "path", + queryVector: []float64{5, 6, 7}, + wantErr: false, + }, + { + desc: "Invalid vector type", + path: "path", + queryVector: "abcd", + wantErr: false, + }, + { + desc: "Valid vector type", + path: "path", + queryVector: []float32{5, 6, 7}, + wantErr: false, }, } for _, tc := range testcases { - - vQuery := c.Collection("C").FindNearest(tc.path, []float64{5, 6, 7}, 2, DistanceMeasureEuclidean, nil) + srv.reset() + srv.addRPC(nil, []interface{}{ + &pb.RunQueryResponse{Document: wantPBDocs[0]}, + }) + vQuery := c.Collection("C").FindNearest(tc.path, tc.queryVector, 2, DistanceMeasureEuclidean, nil) _, err := vQuery.Documents(ctx).GetAll() if err == nil && tc.wantErr { diff --git a/firestore/to_value.go b/firestore/to_value.go index 27614c2ca7ea..9976b16e7251 100644 --- a/firestore/to_value.go +++ b/firestore/to_value.go @@ -34,7 +34,8 @@ var ( typeOfLatLng = reflect.TypeOf((*latlng.LatLng)(nil)) typeOfDocumentRef = reflect.TypeOf((*DocumentRef)(nil)) typeOfProtoTimestamp = reflect.TypeOf((*ts.Timestamp)(nil)) - typeOfVector = reflect.TypeOf(Vector{}) + typeOfVector64 = reflect.TypeOf(Vector64{}) + typeOfVector32 = reflect.TypeOf(Vector32{}) ) // toProtoValue converts a Go value to a Firestore Value protobuf. @@ -70,7 +71,9 @@ func toProtoValue(v reflect.Value) (pbv *pb.Value, sawTransform bool, err error) return nullValue, false, nil } return &pb.Value{ValueType: &pb.Value_TimestampValue{TimestampValue: x}}, false, nil - case Vector: + case Vector32: + return vectorToProtoValue(x), false, nil + case Vector64: return vectorToProtoValue(x), false, nil case *latlng.LatLng: if x == nil { diff --git a/firestore/vector.go b/firestore/vector.go index 1ecc80c9eee7..2514acffbc03 100644 --- a/firestore/vector.go +++ b/firestore/vector.go @@ -26,25 +26,13 @@ const ( valueKey = "value" ) -// Vector is an embedding vector. -type Vector []float64 - -// ToVector converts float32 or float64 slice to Firestore embedding vector. -func ToVector[fType float32 | float64](arr []fType) Vector { - var arrAny interface{} - if arrFloat64, ok := arrAny.([]float64); ok { - return Vector(arrFloat64) // Type assertion, no conversion needed - } - - vec := make(Vector, len(arr)) - for i, val := range arr { - vec[i] = float64(val) - } - return vec -} +// Vector64 is an embedding vector. +type Vector64 []float64 +type Vector32 []float32 // vectorToProtoValue returns a Firestore [pb.Value] representing the Vector. -func vectorToProtoValue(v Vector) *pb.Value { +// The calling function should check for type safety +func vectorToProtoValue[vType float32 | float64](v []vType) *pb.Value { if v == nil { return nullValue } @@ -69,14 +57,51 @@ func vectorToProtoValue(v Vector) *pb.Value { } } -func vectorFromProtoValue(v *pb.Value) (Vector, error) { +func vectorFromProtoValue(v *pb.Value) (interface{}, error) { + return vector64FromProtoValue(v) +} + +func vector32FromProtoValue(v *pb.Value) (Vector32, error) { + pbArrVals, err := pbValToVectorVals(v) + if err != nil { + return nil, err + } + + floats := make([]float32, len(pbArrVals)) + for i, fval := range pbArrVals { + dv, ok := fval.ValueType.(*pb.Value_DoubleValue) + if !ok { + return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_DoubleValue", fval.ValueType) + } + floats[i] = float32(dv.DoubleValue) + } + return floats, nil +} + +func vector64FromProtoValue(v *pb.Value) (Vector64, error) { + pbArrVals, err := pbValToVectorVals(v) + if err != nil { + return nil, err + } + + floats := make([]float64, len(pbArrVals)) + for i, fval := range pbArrVals { + dv, ok := fval.ValueType.(*pb.Value_DoubleValue) + if !ok { + return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_DoubleValue", fval.ValueType) + } + floats[i] = dv.DoubleValue + } + return floats, nil +} + +func pbValToVectorVals(v *pb.Value) ([]*pb.Value, error) { /* Vector is stored as: { "__type__": "__vector__", "value": []float64{}, } - but needs to be returned as firestore.Vector to the user */ if v == nil { return nil, nil @@ -104,16 +129,7 @@ func vectorFromProtoValue(v *pb.Value) (Vector, error) { return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_ArrayValue", pbVal.ValueType) } - pbArrVals := pbArr.ArrayValue.Values - floats := make([]float64, len(pbArrVals)) - for i, fval := range pbArrVals { - dv, ok := fval.ValueType.(*pb.Value_DoubleValue) - if !ok { - return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_DoubleValue", fval.ValueType) - } - floats[i] = dv.DoubleValue - } - return Vector(floats), nil + return pbArr.ArrayValue.Values, nil } func stringFromProtoValue(v *pb.Value) (string, error) { diff --git a/firestore/vector_test.go b/firestore/vector_test.go index a63ece13fa72..9e1497b7ba0e 100644 --- a/firestore/vector_test.go +++ b/firestore/vector_test.go @@ -25,7 +25,7 @@ import ( func TestVectorToProtoValue(t *testing.T) { tests := []struct { name string - v Vector + v Vector64 want *pb.Value }{ { @@ -35,7 +35,7 @@ func TestVectorToProtoValue(t *testing.T) { }, { name: "empty vector", - v: Vector{}, + v: Vector64{}, want: &pb.Value{ ValueType: &pb.Value_MapValue{ MapValue: &pb.MapValue{ @@ -53,7 +53,7 @@ func TestVectorToProtoValue(t *testing.T) { }, { name: "multiple element vector", - v: Vector{1.0, 2.0, 3.0}, + v: Vector64{1.0, 2.0, 3.0}, want: &pb.Value{ ValueType: &pb.Value_MapValue{ MapValue: &pb.MapValue{ @@ -84,7 +84,7 @@ func TestVectorFromProtoValue(t *testing.T) { tests := []struct { name string v *pb.Value - want Vector + want Vector64 wantErr bool }{ { @@ -108,7 +108,7 @@ func TestVectorFromProtoValue(t *testing.T) { }, }, }, - want: Vector{}, + want: Vector64{}, }, { name: "multiple element vector", @@ -126,7 +126,7 @@ func TestVectorFromProtoValue(t *testing.T) { }, }, }, - want: Vector{1.0, 2.0, 3.0}, + want: Vector64{1.0, 2.0, 3.0}, }, { name: "invalid type", From b6f5f25240656b0af39d5a0cf8cac9c9edf19437 Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Mon, 22 Jul 2024 15:45:24 +0000 Subject: [PATCH 8/9] feat(firestore): Resolving review comments --- firestore/integration_test.go | 5 ++--- firestore/query.go | 14 ++++++++------ firestore/query_test.go | 2 +- firestore/vector.go | 4 +++- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/firestore/integration_test.go b/firestore/integration_test.go index a1101a166593..dfc19692c545 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -2851,10 +2851,10 @@ func TestIntegration_FindNearest(t *testing.T) { t.Cleanup(func() { cancel() }) - + queryField := "EmbeddedField64" indexNames := createVectorIndexes(adminCtx, t, wantDBPath, []vectorIndex{ { - fieldPath: "EmbeddedField", + fieldPath: queryField, dimension: 3, }, }) @@ -2862,7 +2862,6 @@ func TestIntegration_FindNearest(t *testing.T) { deleteIndexes(adminCtx, indexNames) }) - queryField := "EmbeddedField64" type coffeeBean struct { ID string EmbeddedField64 Vector64 diff --git a/firestore/query.go b/firestore/query.go index a843e57fef19..e46de704943e 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -392,7 +392,7 @@ const ( DistanceMeasureDotProduct DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_DOT_PRODUCT) ) -// FindNearestOptions is options to use while building FindNearest vector query +// FindNearestOptions are options for a FindNearest vector query. type FindNearestOptions struct { } @@ -412,7 +412,7 @@ type VectorQuery struct { // // The vectorField argument can be a single field or a dot-separated sequence of // fields, and must not contain any of the runes "˜*/[]". -func (q Query) FindNearest(vectorField string, queryVector interface{}, limit int, measure DistanceMeasure, options *FindNearestOptions) VectorQuery { +func (q Query) FindNearest(vectorField string, queryVector any, limit int, measure DistanceMeasure, options *FindNearestOptions) VectorQuery { // Validate field path fieldPath, err := parseDotSeparatedString(vectorField) if err != nil { @@ -430,7 +430,7 @@ func (vq VectorQuery) Documents(ctx context.Context) *DocumentIterator { } // FindNearestPath is similar to FindNearest but it accepts a [FieldPath]. -func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector interface{}, limit int, measure DistanceMeasure, options *FindNearestOptions) VectorQuery { +func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector any, limit int, measure DistanceMeasure, options *FindNearestOptions) VectorQuery { vq := VectorQuery{ q: q, } @@ -446,13 +446,15 @@ func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector interface{ switch v := queryVector.(type) { case Vector32: fnvq = vectorToProtoValue([]float32(v)) + case []float32: + fnvq = vectorToProtoValue([]float32(v)) case Vector64: fnvq = vectorToProtoValue([]float64(v)) + case []float64: + fnvq = vectorToProtoValue([]float64(v)) default: vq.q.err = errors.New("firestore: queryVector must be Vector32 or Vector64") - return VectorQuery{ - q: q, - } + return vq } vq.q.findNearest = &pb.StructuredQuery_FindNearest{ diff --git a/firestore/query_test.go b/firestore/query_test.go index 63e6d568dfdc..106a1bbe15bd 100644 --- a/firestore/query_test.go +++ b/firestore/query_test.go @@ -1475,7 +1475,7 @@ func TestFindNearest(t *testing.T) { desc: "Invalid vector type", path: "path", queryVector: "abcd", - wantErr: false, + wantErr: true, }, { desc: "Valid vector type", diff --git a/firestore/vector.go b/firestore/vector.go index 2514acffbc03..3b89d2772573 100644 --- a/firestore/vector.go +++ b/firestore/vector.go @@ -26,8 +26,10 @@ const ( valueKey = "value" ) -// Vector64 is an embedding vector. +// Vector64 is an embedding vector of float64s. type Vector64 []float64 + +// Vector32 is an embedding vector of float32s. type Vector32 []float32 // vectorToProtoValue returns a Firestore [pb.Value] representing the Vector. From 6db828ca8b90357c3a7b410a3b49d903b66fd21d Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Mon, 22 Jul 2024 19:51:42 +0000 Subject: [PATCH 9/9] feat(firestore): Resolving review comments --- firestore/integration_test.go | 2 +- firestore/query.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/firestore/integration_test.go b/firestore/integration_test.go index dfc19692c545..bfa690d5524e 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -62,7 +62,7 @@ func TestMain(m *testing.M) { if status != 0 { os.Exit(status) } - // cleanupIntegrationTest() + cleanupIntegrationTest() } os.Exit(0) diff --git a/firestore/query.go b/firestore/query.go index e46de704943e..4a1254d27306 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -447,11 +447,11 @@ func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector any, limit case Vector32: fnvq = vectorToProtoValue([]float32(v)) case []float32: - fnvq = vectorToProtoValue([]float32(v)) + fnvq = vectorToProtoValue(v) case Vector64: fnvq = vectorToProtoValue([]float64(v)) case []float64: - fnvq = vectorToProtoValue([]float64(v)) + fnvq = vectorToProtoValue(v) default: vq.q.err = errors.New("firestore: queryVector must be Vector32 or Vector64") return vq