Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(firestore): Adding vector search #10548

Merged
merged 12 commits into from
Jul 22, 2024
1 change: 1 addition & 0 deletions firestore/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ func (d *DocumentSnapshot) Data() map[string]interface{} {
// Slices are resized to the incoming value's size, while arrays that are too
// long have excess elements filled with zero values. If the array is too short,
// excess incoming values will be dropped.
// - Vectors convert to []float64
// - Maps convert to map[string]interface{}. When setting a struct field,
// maps of key type string and any value type are permitted, and are populated
// recursively.
Expand Down
148 changes: 85 additions & 63 deletions firestore/from_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,221 +32,236 @@ func setFromProtoValue(x interface{}, vproto *pb.Value, c *Client) error {
return setReflectFromProtoValue(v.Elem(), vproto, c)
}

// setReflectFromProtoValue sets v from a Firestore Value.
// v must be a settable value.
func setReflectFromProtoValue(v reflect.Value, vproto *pb.Value, c *Client) error {
// setReflectFromProtoValue sets vDest from a Firestore Value.
// vDest must be a settable value.
func setReflectFromProtoValue(vDest reflect.Value, vprotoSrc *pb.Value, c *Client) error {
typeErr := func() error {
return fmt.Errorf("firestore: cannot set type %s to %s", v.Type(), typeString(vproto))
return fmt.Errorf("firestore: cannot set type %s to %s", vDest.Type(), typeString(vprotoSrc))
}

val := vproto.ValueType
valTypeSrc := vprotoSrc.ValueType
// A Null value sets anything nullable to nil, and has no effect
// on anything else.
if _, ok := val.(*pb.Value_NullValue); ok {
switch v.Kind() {
if _, ok := valTypeSrc.(*pb.Value_NullValue); ok {
switch vDest.Kind() {
case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice:
v.Set(reflect.Zero(v.Type()))
vDest.Set(reflect.Zero(vDest.Type()))
}
return nil
}

// Handle special types first.
switch v.Type() {
switch vDest.Type() {
case typeOfByteSlice:
x, ok := val.(*pb.Value_BytesValue)
x, ok := valTypeSrc.(*pb.Value_BytesValue)
if !ok {
return typeErr()
}
v.SetBytes(x.BytesValue)
vDest.SetBytes(x.BytesValue)
return nil

case typeOfGoTime:
x, ok := val.(*pb.Value_TimestampValue)
x, ok := valTypeSrc.(*pb.Value_TimestampValue)
if !ok {
return typeErr()
}
if err := x.TimestampValue.CheckValid(); err != nil {
return err
}
v.Set(reflect.ValueOf(x.TimestampValue.AsTime()))
vDest.Set(reflect.ValueOf(x.TimestampValue.AsTime()))
return nil

case typeOfProtoTimestamp:
x, ok := val.(*pb.Value_TimestampValue)
x, ok := valTypeSrc.(*pb.Value_TimestampValue)
if !ok {
return typeErr()
}
v.Set(reflect.ValueOf(x.TimestampValue))
vDest.Set(reflect.ValueOf(x.TimestampValue))
return nil

case typeOfLatLng:
x, ok := val.(*pb.Value_GeoPointValue)
x, ok := valTypeSrc.(*pb.Value_GeoPointValue)
if !ok {
return typeErr()
}
v.Set(reflect.ValueOf(x.GeoPointValue))
vDest.Set(reflect.ValueOf(x.GeoPointValue))
return nil

case typeOfDocumentRef:
x, ok := val.(*pb.Value_ReferenceValue)
x, ok := valTypeSrc.(*pb.Value_ReferenceValue)
if !ok {
return typeErr()
}
dr, err := pathToDoc(x.ReferenceValue, c)
if err != nil {
return err
}
v.Set(reflect.ValueOf(dr))
vDest.Set(reflect.ValueOf(dr))
return nil

case typeOfVector32:
val, err := vector32FromProtoValue(vprotoSrc)
if err != nil {
return err
}
vDest.Set(reflect.ValueOf(val))
return nil
case typeOfVector64:
val, err := vector64FromProtoValue(vprotoSrc)
if err != nil {
return err
}
vDest.Set(reflect.ValueOf(val))
return nil
}

switch v.Kind() {
switch vDest.Kind() {
case reflect.Bool:
x, ok := val.(*pb.Value_BooleanValue)
x, ok := valTypeSrc.(*pb.Value_BooleanValue)
if !ok {
return typeErr()
}
v.SetBool(x.BooleanValue)
vDest.SetBool(x.BooleanValue)

case reflect.String:
x, ok := val.(*pb.Value_StringValue)
x, ok := valTypeSrc.(*pb.Value_StringValue)
if !ok {
return typeErr()
}
v.SetString(x.StringValue)
vDest.SetString(x.StringValue)

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var i int64
switch x := val.(type) {
switch x := valTypeSrc.(type) {
case *pb.Value_IntegerValue:
i = x.IntegerValue
case *pb.Value_DoubleValue:
f := x.DoubleValue
i = int64(f)
if float64(i) != f {
return fmt.Errorf("firestore: float %f does not fit into %s", f, v.Type())
return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type())
}
default:
return typeErr()
}
if v.OverflowInt(i) {
return overflowErr(v, i)
if vDest.OverflowInt(i) {
return overflowErr(vDest, i)
}
v.SetInt(i)
vDest.SetInt(i)

case reflect.Uint8, reflect.Uint16, reflect.Uint32:
var u uint64
switch x := val.(type) {
switch x := valTypeSrc.(type) {
case *pb.Value_IntegerValue:
u = uint64(x.IntegerValue)
case *pb.Value_DoubleValue:
f := x.DoubleValue
u = uint64(f)
if float64(u) != f {
return fmt.Errorf("firestore: float %f does not fit into %s", f, v.Type())
return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type())
}
default:
return typeErr()
}
if v.OverflowUint(u) {
return overflowErr(v, u)
if vDest.OverflowUint(u) {
return overflowErr(vDest, u)
}
v.SetUint(u)
vDest.SetUint(u)

case reflect.Float32, reflect.Float64:
var f float64
switch x := val.(type) {
switch x := valTypeSrc.(type) {
case *pb.Value_DoubleValue:
f = x.DoubleValue
case *pb.Value_IntegerValue:
f = float64(x.IntegerValue)
if int64(f) != x.IntegerValue {
return overflowErr(v, x.IntegerValue)
return overflowErr(vDest, x.IntegerValue)
}
default:
return typeErr()
}
if v.OverflowFloat(f) {
return overflowErr(v, f)
if vDest.OverflowFloat(f) {
return overflowErr(vDest, f)
}
v.SetFloat(f)
vDest.SetFloat(f)

case reflect.Slice:
x, ok := val.(*pb.Value_ArrayValue)
x, ok := valTypeSrc.(*pb.Value_ArrayValue)
if !ok {
return typeErr()
}
vals := x.ArrayValue.Values
vlen := v.Len()
vlen := vDest.Len()
xlen := len(vals)
// Make a slice of the right size, avoiding allocation if possible.
switch {
case vlen < xlen:
v.Set(reflect.MakeSlice(v.Type(), xlen, xlen))
vDest.Set(reflect.MakeSlice(vDest.Type(), xlen, xlen))
case vlen > xlen:
v.SetLen(xlen)
vDest.SetLen(xlen)
}
return populateRepeated(v, vals, xlen, c)
return populateRepeated(vDest, vals, xlen, c)

case reflect.Array:
x, ok := val.(*pb.Value_ArrayValue)
x, ok := valTypeSrc.(*pb.Value_ArrayValue)
if !ok {
return typeErr()
}
vals := x.ArrayValue.Values
xlen := len(vals)
vlen := v.Len()
vlen := vDest.Len()
minlen := vlen
// Set extra elements to their zero value.
if vlen > xlen {
z := reflect.Zero(v.Type().Elem())
z := reflect.Zero(vDest.Type().Elem())
for i := xlen; i < vlen; i++ {
v.Index(i).Set(z)
vDest.Index(i).Set(z)
}
minlen = xlen
}
return populateRepeated(v, vals, minlen, c)
return populateRepeated(vDest, vals, minlen, c)

case reflect.Map:
x, ok := val.(*pb.Value_MapValue)
x, ok := valTypeSrc.(*pb.Value_MapValue)
if !ok {
return typeErr()
}
return populateMap(v, x.MapValue.Fields, c)
return populateMap(vDest, x.MapValue.Fields, c)

case reflect.Ptr:
// If the pointer is nil, set it to a zero value.
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
if vDest.IsNil() {
vDest.Set(reflect.New(vDest.Type().Elem()))
}
return setReflectFromProtoValue(v.Elem(), vproto, c)
return setReflectFromProtoValue(vDest.Elem(), vprotoSrc, c)

case reflect.Struct:
x, ok := val.(*pb.Value_MapValue)
x, ok := valTypeSrc.(*pb.Value_MapValue)
if !ok {
return typeErr()
}
return populateStruct(v, x.MapValue.Fields, c)
return populateStruct(vDest, x.MapValue.Fields, c)

case reflect.Interface:
if v.NumMethod() == 0 { // empty interface
if vDest.NumMethod() == 0 { // empty interface
// If v holds a pointer, set the pointer.
if !v.IsNil() && v.Elem().Kind() == reflect.Ptr {
return setReflectFromProtoValue(v.Elem(), vproto, c)
if !vDest.IsNil() && vDest.Elem().Kind() == reflect.Ptr {
return setReflectFromProtoValue(vDest.Elem(), vprotoSrc, c)
}
// Otherwise, create a fresh value.
x, err := createFromProtoValue(vproto, c)
x, err := createFromProtoValue(vprotoSrc, c)
if err != nil {
return err
}
v.Set(reflect.ValueOf(x))
vDest.Set(reflect.ValueOf(x))
return nil
}
// Any other kind of interface is an error.
fallthrough

default:
return fmt.Errorf("firestore: cannot set type %s", v.Type())
return fmt.Errorf("firestore: cannot set type %s", vDest.Type())
}
return nil
}
Expand Down Expand Up @@ -389,8 +404,15 @@ func createFromProtoValue(vproto *pb.Value, c *Client) (interface{}, error) {
}
ret[k] = r
}
return ret, nil

typeVal, ok := ret[typeKey]
if !ok || typeVal != typeValVector {
// Map is not a vector. Return the map
return ret, nil
}

// Special handling for vector
return vectorFromProtoValue(vproto)
default:
return nil, fmt.Errorf("firestore: unknown value type %T", v)
}
Expand Down
Loading
Loading