diff --git a/go/go.mod b/go/go.mod index c6133f9d11..439d08af7d 100644 --- a/go/go.mod +++ b/go/go.mod @@ -28,7 +28,7 @@ require ( go.opentelemetry.io/otel/trace v1.26.0 golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 golang.org/x/tools v0.23.0 - google.golang.org/api v0.188.0 + google.golang.org/api v0.189.0 google.golang.org/protobuf v1.34.2 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 @@ -37,10 +37,10 @@ require ( require ( cloud.google.com/go v0.115.0 // indirect cloud.google.com/go/ai v0.8.1-0.20240711230438-265963bd5b91 // indirect - cloud.google.com/go/auth v0.7.0 // indirect - cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect - cloud.google.com/go/compute/metadata v0.4.0 // indirect - cloud.google.com/go/firestore v1.15.0 // indirect + cloud.google.com/go/auth v0.7.2 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.3 // indirect + cloud.google.com/go/compute/metadata v0.5.0 // indirect + cloud.google.com/go/firestore v1.16.0 // indirect cloud.google.com/go/iam v1.1.10 // indirect cloud.google.com/go/longrunning v0.5.9 // indirect cloud.google.com/go/monitoring v1.20.1 // indirect @@ -54,7 +54,7 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/go-logr/logr v1.4.1 // indirect + github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.21.2 // indirect github.com/go-openapi/errors v0.22.0 // indirect @@ -70,7 +70,7 @@ require ( github.com/golang/protobuf v1.5.4 // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect - github.com/googleapis/gax-go/v2 v2.12.5 // indirect + github.com/googleapis/gax-go/v2 v2.13.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect @@ -90,8 +90,8 @@ require ( golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect google.golang.org/appengine/v2 v2.0.2 // indirect - google.golang.org/genproto v0.0.0-20240708141625-4ad9e859172b // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240708141625-4ad9e859172b // indirect + google.golang.org/genproto v0.0.0-20240722135656-d784300faade // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240722135656-d784300faade // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240722135656-d784300faade // indirect google.golang.org/grpc v1.65.0 // indirect ) diff --git a/go/go.sum b/go/go.sum index fbd0eb794b..fcef3852af 100644 --- a/go/go.sum +++ b/go/go.sum @@ -7,12 +7,20 @@ cloud.google.com/go/aiplatform v1.68.0 h1:EPPqgHDJpBZKRvv+OsB3cr0jYz3EL2pZ+802rB cloud.google.com/go/aiplatform v1.68.0/go.mod h1:105MFA3svHjC3Oazl7yjXAmIR89LKhRAeNdnDKJczME= cloud.google.com/go/auth v0.7.0 h1:kf/x9B3WTbBUHkC+1VS8wwwli9TzhSt0vSTVBmMR8Ts= cloud.google.com/go/auth v0.7.0/go.mod h1:D+WqdrpcjmiCgWrXmLLxOVq1GACoE36chW6KXoEvuIw= +cloud.google.com/go/auth v0.7.2 h1:uiha352VrCDMXg+yoBtaD0tUF4Kv9vrtrWPYXwutnDE= +cloud.google.com/go/auth v0.7.2/go.mod h1:VEc4p5NNxycWQTMQEDQF0bd6aTMb6VgYDXEwiJJQAbs= cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= +cloud.google.com/go/auth/oauth2adapt v0.2.3 h1:MlxF+Pd3OmSudg/b1yZ5lJwoXCEaeedAguodky1PcKI= +cloud.google.com/go/auth/oauth2adapt v0.2.3/go.mod h1:tMQXOfZzFuNuUxOypHlQEXgdfX5cuhwU+ffUuXRJE8I= cloud.google.com/go/compute/metadata v0.4.0 h1:vHzJCWaM4g8XIcm8kopr3XmDA4Gy/lblD3EhhSux05c= cloud.google.com/go/compute/metadata v0.4.0/go.mod h1:SIQh1Kkb4ZJ8zJ874fqVkslA29PRXuleyj6vOzlbK7M= +cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= +cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= cloud.google.com/go/firestore v1.15.0 h1:/k8ppuWOtNuDHt2tsRV42yI21uaGnKDEQnRFeBpbFF8= cloud.google.com/go/firestore v1.15.0/go.mod h1:GWOxFXcv8GZUtYpWHw/w6IuYNux/BtmeVTMmjrm4yhk= +cloud.google.com/go/firestore v1.16.0 h1:YwmDHcyrxVRErWcgxunzEaZxtNbc8QoFYA/JOEwDPgc= +cloud.google.com/go/firestore v1.16.0/go.mod h1:+22v/7p+WNBSQwdSwP57vz47aZiY+HrDkrOsJNhk7rg= cloud.google.com/go/iam v1.1.10 h1:ZSAr64oEhQSClwBL670MsJAW5/RLiC6kfw3Bqmd5ZDI= cloud.google.com/go/iam v1.1.10/go.mod h1:iEgMq62sg8zx446GCaijmA2Miwg5o3UbO+nI47WHJps= cloud.google.com/go/logging v1.10.0 h1:f+ZXMqyrSJ5vZ5pE/zr0xC8y/M9BLNzQeLBwfeZ+wY4= @@ -73,6 +81,8 @@ github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSw github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-openapi/analysis v0.21.2 h1:hXFrOYFHUAMQdu6zwAiKKJHJQ8kqZs1ux/ru1P1wLJU= @@ -174,6 +184,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfF github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA= github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E= +github.com/googleapis/gax-go/v2 v2.13.0 h1:yitjD5f7jQHhyDsnhKEBU52NdvvdSeGzlAnDPT0hH1s= +github.com/googleapis/gax-go/v2 v2.13.0/go.mod h1:Z/fvTZXF8/uw7Xu5GuslPw+bplx6SS338j1Is2S+B7A= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI= github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= @@ -388,6 +400,8 @@ golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSm golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= google.golang.org/api v0.188.0 h1:51y8fJ/b1AaaBRJr4yWm96fPcuxSo0JcegXE3DaHQHw= google.golang.org/api v0.188.0/go.mod h1:VR0d+2SIiWOYG3r/jdm7adPW9hI2aRv9ETOSCQ9Beag= +google.golang.org/api v0.189.0 h1:equMo30LypAkdkLMBqfeIqtyAnlyig1JSZArl4XPwdI= +google.golang.org/api v0.189.0/go.mod h1:FLWGJKb0hb+pU2j+rJqwbnsF+ym+fQs73rbJ+KAUgy8= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine/v2 v2.0.2 h1:MSqyWy2shDLwG7chbwBJ5uMyw6SNqJzhJHNDwYB0Akk= @@ -397,10 +411,16 @@ google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98 google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20240708141625-4ad9e859172b h1:dSTjko30weBaMj3eERKc0ZVXW4GudCswM3m+P++ukU0= google.golang.org/genproto v0.0.0-20240708141625-4ad9e859172b/go.mod h1:FfBgJBJg9GcpPvKIuHSZ/aE1g2ecGL74upMzGZjiGEY= +google.golang.org/genproto v0.0.0-20240722135656-d784300faade h1:lKFsS7wpngDgSCeFn7MoLy+wBDQZ1UQIJD4UNM1Qvkg= +google.golang.org/genproto v0.0.0-20240722135656-d784300faade/go.mod h1:FfBgJBJg9GcpPvKIuHSZ/aE1g2ecGL74upMzGZjiGEY= google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 h1:0+ozOGcrp+Y8Aq8TLNN2Aliibms5LEzsq99ZZmAGYm0= google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094/go.mod h1:fJ/e3If/Q67Mj99hin0hMhiNyCRmt6BQ2aWIJshUSJw= +google.golang.org/genproto/googleapis/api v0.0.0-20240722135656-d784300faade h1:WxZOF2yayUHpHSbUE6NMzumUzBxYc3YGwo0YHnbzsJY= +google.golang.org/genproto/googleapis/api v0.0.0-20240722135656-d784300faade/go.mod h1:mw8MG/Qz5wfgYr6VqVCiZcHe/GJEfI+oGGDCohaVgB0= google.golang.org/genproto/googleapis/rpc v0.0.0-20240708141625-4ad9e859172b h1:04+jVzTs2XBnOZcPsLnmrTGqltqJbZQ1Ey26hjYdQQ0= google.golang.org/genproto/googleapis/rpc v0.0.0-20240708141625-4ad9e859172b/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240722135656-d784300faade h1:oCRSWfwGXQsqlVdErcyTt4A93Y8fo0/9D4b1gnI++qo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240722135656-d784300faade/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= diff --git a/go/plugins/firebase/firebase.go b/go/plugins/firebase/firebase.go index e2f4a85b4d..61645fceff 100644 --- a/go/plugins/firebase/firebase.go +++ b/go/plugins/firebase/firebase.go @@ -16,31 +16,77 @@ package firebase import ( "context" + "fmt" + "log" "sync" firebase "firebase.google.com/go/v4" - "firebase.google.com/go/v4/auth" ) -type FirebaseApp interface { - Auth(ctx context.Context) (*auth.Client, error) +var state struct { + mu sync.Mutex + initted bool + app *firebase.App } -var ( - app *firebase.App - mutex sync.Mutex -) +// FirebasePluginConfig is the configuration for the Firebase plugin. +type FirebasePluginConfig struct { + AuthOverride *map[string]interface{} `json:"databaseAuthVariableOverride"` + DatabaseURL string `json:"databaseURL"` + ProjectID string `json:"projectId"` + ServiceAccountID string `json:"serviceAccountId"` + StorageBucket string `json:"storageBucket"` +} + +// Init initializes the Firebase app with the provided configuration. +// If called more than once, it logs a message and returns nil. +func Init(ctx context.Context, cfg *FirebasePluginConfig) error { + state.mu.Lock() + defer state.mu.Unlock() + + if state.initted { + log.Println("firebase.Init: already called, returning without reinitializing") + return nil + } -// app returns a cached Firebase app. -func App(ctx context.Context) (FirebaseApp, error) { - mutex.Lock() - defer mutex.Unlock() - if app == nil { - newApp, err := firebase.NewApp(ctx, nil) - if err != nil { - return nil, err - } - app = newApp + // Prepare the Firebase config + firebaseConfig := &firebase.Config{ + AuthOverride: cfg.AuthOverride, + DatabaseURL: cfg.DatabaseURL, + ProjectID: cfg.ProjectID, // Allow ProjectID to be empty + ServiceAccountID: cfg.ServiceAccountID, + StorageBucket: cfg.StorageBucket, } - return app, nil + + // Initialize Firebase app with service account key if provided + app, err := firebase.NewApp(ctx, firebaseConfig) + if err != nil { + return fmt.Errorf("firebase.Init: %w", err) + } + + state.app = app + state.initted = true + + return nil +} + +func unInit() { + state.mu.Lock() + defer state.mu.Unlock() + + state.initted = false + state.app = nil +} + +// App returns a cached Firebase app. +// If the app is not initialized, it returns an error. +func App(ctx context.Context) (*firebase.App, error) { + state.mu.Lock() + defer state.mu.Unlock() + + if !state.initted { + return nil, fmt.Errorf("firebase.App: Firebase app not initialized. Call Init first") + } + + return state.app, nil } diff --git a/go/plugins/firebase/firebase_app_test.go b/go/plugins/firebase/firebase_app_test.go new file mode 100644 index 0000000000..2cd4b8fa08 --- /dev/null +++ b/go/plugins/firebase/firebase_app_test.go @@ -0,0 +1,78 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firebase + +import ( + "context" + "testing" +) + +func TestApp(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + tests := []struct { + name string + setup func() error + expectedError string + }{ + { + name: "Get App before initialization", + setup: func() error { + // No initialization setup here, calling App directly should fail + return nil + }, + expectedError: "firebase.App: Firebase app not initialized. Call Init first", + }, + { + name: "Get App after successful initialization", + setup: func() error { + // Properly initialize the app + config := &FirebasePluginConfig{ + ProjectID: *firebaseProjectID, + } + return Init(ctx, config) + }, + expectedError: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer unInit() + // Execute setup + if err := tt.setup(); err != nil { + t.Fatalf("Setup failed: %v", err) + } + + // Now test the App function + app, err := App(ctx) + + if tt.expectedError != "" { + if err == nil || err.Error() != tt.expectedError { + t.Errorf("Expected error %q, got %v", tt.expectedError, err) + } + if app != nil { + t.Errorf("Expected no app, got %v", app) + } + } else if err != nil { + t.Errorf("Unexpected error: %v", err) + } else if app == nil { + t.Errorf("Expected a valid app instance, got nil") + } + }) + } +} diff --git a/go/plugins/firebase/firebase_init_test.go b/go/plugins/firebase/firebase_init_test.go new file mode 100644 index 0000000000..9b9dd70db9 --- /dev/null +++ b/go/plugins/firebase/firebase_init_test.go @@ -0,0 +1,88 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firebase + +import ( + "context" + "flag" + "testing" +) + +// Define the flag with a default value of "demo-test" +var firebaseProjectID = flag.String("firebase-project-id", "demo-test", "Firebase project ID") + +func TestInit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + tests := []struct { + name string + config *FirebasePluginConfig + expectedError string + setup func() error + }{ + { + name: "Successful initialization", + config: &FirebasePluginConfig{ + ProjectID: *firebaseProjectID, + }, + expectedError: "", + setup: func() error { + return nil // No setup required, first call should succeed + }, + }, + { + name: "Initialization when already initialized", + config: &FirebasePluginConfig{ + ProjectID: *firebaseProjectID, + }, + expectedError: "", + setup: func() error { + return Init(ctx, &FirebasePluginConfig{ProjectID: *firebaseProjectID}) // Initialize once + }, + }, + { + name: "Initialization with missing ProjectID", + config: &FirebasePluginConfig{ + ProjectID: "", + }, + expectedError: "", // No error expected, as ProjectID can be inferred + setup: func() error { + return nil // No setup required + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer unInit() + + if err := tt.setup(); err != nil { + t.Fatalf("Setup failed: %v", err) + } + + err := Init(ctx, tt.config) + + if tt.expectedError != "" { + if err == nil || err.Error() != tt.expectedError { + t.Errorf("Expected error %q, got %v", tt.expectedError, err) + } + } else if err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} diff --git a/go/plugins/firebase/retriever.go b/go/plugins/firebase/retriever.go new file mode 100644 index 0000000000..5cb92247fe --- /dev/null +++ b/go/plugins/firebase/retriever.go @@ -0,0 +1,125 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firebase + +import ( + "context" + "fmt" + + "cloud.google.com/go/firestore" + "github.com/firebase/genkit/go/ai" +) + +type VectorType int + +const ( + Vector64 VectorType = iota +) + +const provider = "firebase" + +type RetrieverOptions struct { + Name string + Label string + Client *firestore.Client + Collection string + Embedder ai.Embedder + VectorField string + MetadataFields []string + ContentField string + Limit int + DistanceMeasure firestore.DistanceMeasure + VectorType VectorType +} + +func DefineFirestoreRetriever(cfg RetrieverOptions) (ai.Retriever, error) { + if cfg.VectorType != Vector64 { + return nil, fmt.Errorf("DefineFirestoreRetriever: only Vector64 is supported") + } + if cfg.Client == nil { + return nil, fmt.Errorf("DefineFirestoreRetriever: Firestore client is not provided") + } + + Retrieve := func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { + + if req.Document == nil { + return nil, fmt.Errorf("DefineFirestoreRetriever: Request document is nil") + } + + // Generate query embedding using the Embedder + embedRequest := &ai.EmbedRequest{Documents: []*ai.Document{req.Document}} + embedResponse, err := cfg.Embedder.Embed(ctx, embedRequest) + if err != nil { + return nil, fmt.Errorf("DefineFirestoreRetriever: Embedding failed: %v", err) + } + + if len(embedResponse.Embeddings) == 0 { + return nil, fmt.Errorf("DefineFirestoreRetriever: No embeddings returned") + } + + queryEmbedding := embedResponse.Embeddings[0].Embedding + if len(queryEmbedding) == 0 { + return nil, fmt.Errorf("DefineFirestoreRetriever: Generated embedding is empty") + } + + // Convert to []float64 + queryEmbedding64 := make([]float64, len(queryEmbedding)) + for i, val := range queryEmbedding { + queryEmbedding64[i] = float64(val) + } + // Perform the FindNearest query + vectorQuery := cfg.Client.Collection(cfg.Collection).FindNearest( + cfg.VectorField, + firestore.Vector64(queryEmbedding64), + cfg.Limit, + cfg.DistanceMeasure, + nil, + ) + iter := vectorQuery.Documents(ctx) + + results, err := iter.GetAll() + if err != nil { + return nil, fmt.Errorf("DefineFirestoreRetriever: FindNearest query failed: %v", err) + } + + // Prepare the documents to return in the response + var documents []*ai.Document + for _, result := range results { + data := result.Data() + + // Ensure content field exists and is of type string + content, ok := data[cfg.ContentField].(string) + if !ok { + fmt.Printf("Content field %s missing or not a string in document %s", cfg.ContentField, result.Ref.ID) + continue + } + + // Extract metadata fields + metadata := make(map[string]interface{}) + for _, field := range cfg.MetadataFields { + if value, ok := data[field]; ok { + metadata[field] = value + } + } + + doc := ai.DocumentFromText(content, metadata) + documents = append(documents, doc) + } + + return &ai.RetrieverResponse{Documents: documents}, nil + } + + return ai.DefineRetriever(provider, cfg.Name, Retrieve), nil +} diff --git a/go/plugins/firebase/retriever_test.go b/go/plugins/firebase/retriever_test.go new file mode 100644 index 0000000000..bef9e52179 --- /dev/null +++ b/go/plugins/firebase/retriever_test.go @@ -0,0 +1,236 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firebase + +import ( + "context" + "flag" + "testing" + + "cloud.google.com/go/firestore" + firebasev4 "firebase.google.com/go/v4" + "github.com/firebase/genkit/go/ai" + "google.golang.org/api/iterator" +) + +var ( + testProjectID = flag.String("test-project-id", "", "GCP Project ID to use for tests") + testCollection = flag.String("test-collection", "testR2", "Firestore collection to use for tests") + testVectorField = flag.String("test-vector-field", "embedding", "Field name for vector embeddings") +) + +// MockEmbedder implements the Embedder interface for testing purposes +type MockEmbedder struct{} + +func (e *MockEmbedder) Name() string { + return "MockEmbedder" +} + +func (e *MockEmbedder) Embed(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { + var embeddings []*ai.DocumentEmbedding + for _, doc := range req.Documents { + var embedding []float32 + switch doc.Content[0].Text { + case "This is document one": + // Embedding for document one is the closest to the query + embedding = []float32{0.9, 0.1, 0.0} + case "This is document two": + // Embedding for document two is less close to the query + embedding = []float32{0.7, 0.2, 0.1} + case "This is document three": + // Embedding for document three is even further from the query + embedding = []float32{0.4, 0.3, 0.3} + case "This is input query": + // Embedding for the input query + embedding = []float32{0.9, 0.1, 0.0} + default: + // Default embedding for any other documents + embedding = []float32{0.0, 0.0, 0.0} + } + + embeddings = append(embeddings, &ai.DocumentEmbedding{Embedding: embedding}) + } + return &ai.EmbedResponse{Embeddings: embeddings}, nil +} + +// To run this test you must have a Firestore database initialized in a GCP project, with a vector indexed collection (of dimension 3). +// Warning: This test will delete all documents in the collection in cleanup. + +func TestFirestoreRetriever(t *testing.T) { + + // skip if flags aren't defined + if *testProjectID == "" { + t.Skip("Skipping test due to missing flags") + } + if *testCollection == "" { + t.Skip("Skipping test due to missing flags") + } + if *testVectorField == "" { + t.Skip("Skipping test due to missing flags") + } + + ctx := context.Background() + + // Initialize Firebase app + conf := &firebasev4.Config{ProjectID: *testProjectID} + app, err := firebasev4.NewApp(ctx, conf) + if err != nil { + t.Fatalf("Failed to create Firebase app: %v", err) + } + + // Initialize Firestore client + client, err := app.Firestore(ctx) + if err != nil { + t.Fatalf("Failed to create Firestore client: %v", err) + } + defer client.Close() + + // Clean up the collection before the test + defer deleteCollection(ctx, client, *testCollection, t) + + // Initialize the embedder + embedder := &MockEmbedder{} + + // Insert test documents with embeddings generated by the embedder + testDocs := []struct { + ID string + Text string + Data map[string]interface{} + }{ + {"doc1", "This is document one", map[string]interface{}{"metadata": "meta1"}}, + {"doc2", "This is document two", map[string]interface{}{"metadata": "meta2"}}, + {"doc3", "This is document three", map[string]interface{}{"metadata": "meta3"}}, + } + + // Expected document text content in order of relevance for the query + expectedTexts := []string{ + "This is document one", + "This is document two", + } + + for _, doc := range testDocs { + // Create an ai.Document + aiDoc := ai.DocumentFromText(doc.Text, doc.Data) + + // Generate embedding + embedRequest := &ai.EmbedRequest{Documents: []*ai.Document{aiDoc}} + embedResponse, err := embedder.Embed(ctx, embedRequest) + if err != nil { + t.Fatalf("Failed to generate embedding for document %s: %v", doc.ID, err) + } + + if len(embedResponse.Embeddings) == 0 { + t.Fatalf("No embeddings returned for document %s", doc.ID) + } + + embedding := embedResponse.Embeddings[0].Embedding + if len(embedding) == 0 { + t.Fatalf("Generated embedding is empty for document %s", doc.ID) + } + + // Convert to []float64 + embedding64 := make([]float64, len(embedding)) + for i, val := range embedding { + embedding64[i] = float64(val) + } + + // Store in Firestore + _, err = client.Collection(*testCollection).Doc(doc.ID).Set(ctx, map[string]interface{}{ + "text": doc.Text, + "metadata": doc.Data["metadata"], + *testVectorField: firestore.Vector64(embedding64), + }) + if err != nil { + t.Fatalf("Failed to insert document %s: %v", doc.ID, err) + } + t.Logf("Inserted document: %s with embedding: %v", doc.ID, embedding64) + } + + // Define retriever options + retrieverOptions := RetrieverOptions{ + Name: "test-retriever", + Label: "Test Retriever", + Client: client, + Collection: *testCollection, + Embedder: embedder, + VectorField: *testVectorField, + MetadataFields: []string{"metadata"}, + ContentField: "text", + Limit: 2, + DistanceMeasure: firestore.DistanceMeasureEuclidean, + VectorType: Vector64, + } + + // Define the retriever + retriever, err := DefineFirestoreRetriever(retrieverOptions) + if err != nil { + t.Fatalf("Failed to define retriever: %v", err) + } + + // Create a retriever request with the input document + queryText := "This is input query" + inputDocument := ai.DocumentFromText(queryText, nil) + + req := &ai.RetrieverRequest{ + Document: inputDocument, + } + + // Perform the retrieval + resp, err := retriever.Retrieve(ctx, req) + if err != nil { + t.Fatalf("Retriever failed: %v", err) + } + + // Check the retrieved documents + if len(resp.Documents) == 0 { + t.Fatalf("No documents retrieved") + } + + // Verify the content of all retrieved documents against the expected list + for i, doc := range resp.Documents { + if i >= len(expectedTexts) { + t.Errorf("More documents retrieved than expected. Retrieved: %d, Expected: %d", len(resp.Documents), len(expectedTexts)) + break + } + + if doc.Content[0].Text != expectedTexts[i] { + t.Errorf("Mismatch in document %d content. Expected: '%s', Got: '%s'", i+1, expectedTexts[i], doc.Content[0].Text) + } else { + t.Logf("Retrieved Document %d matches expected content: '%s'", i+1, expectedTexts[i]) + } + } +} + +func deleteCollection(ctx context.Context, client *firestore.Client, collectionName string, t *testing.T) { + // Get all documents in the collection + iter := client.Collection(collectionName).Documents(ctx) + for { + doc, err := iter.Next() + if err == iterator.Done { + break // No more documents + } + if err != nil { + t.Fatalf("Failed to iterate documents for deletion: %v", err) + } + + // Delete each document + _, err = doc.Ref.Delete(ctx) + if err != nil { + t.Errorf("Failed to delete document %s: %v", doc.Ref.ID, err) + } else { + t.Logf("Deleted document: %s", doc.Ref.ID) + } + } +} diff --git a/go/plugins/firebase/test_project/.gitignore b/go/plugins/firebase/test_project/.gitignore new file mode 100644 index 0000000000..dbb58ffbfa --- /dev/null +++ b/go/plugins/firebase/test_project/.gitignore @@ -0,0 +1,66 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +firebase-debug.log* +firebase-debug.*.log* + +# Firebase cache +.firebase/ + +# Firebase config + +# Uncomment this if you'd like others to create their own Firebase project. +# For a team working on the same Firebase project(s), it is recommended to leave +# it commented so all members can deploy to the same project(s) in .firebaserc. +# .firebaserc + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (http://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variables file +.env diff --git a/go/plugins/firebase/test_project/firebase.json b/go/plugins/firebase/test_project/firebase.json new file mode 100644 index 0000000000..a21361e89d --- /dev/null +++ b/go/plugins/firebase/test_project/firebase.json @@ -0,0 +1,15 @@ +{ + "firestore": { + "rules": "firestore.rules", + "indexes": "firestore.indexes.json" + }, + "emulators": { + "firestore": { + "port": 8080 + }, + "ui": { + "enabled": true + }, + "singleProjectMode": true + } +}