From bab2ae883c26b3a2962473fdc05a52aa95f332e9 Mon Sep 17 00:00:00 2001 From: fernandoalonso Date: Tue, 10 Dec 2024 01:21:25 +0000 Subject: [PATCH 1/4] Feat: add context caching to go --- go/plugins/vertexai/cache.go | 252 ++++++++++++++++++++++++++++++++ go/plugins/vertexai/vertexai.go | 19 ++- 2 files changed, 268 insertions(+), 3 deletions(-) create mode 100644 go/plugins/vertexai/cache.go diff --git a/go/plugins/vertexai/cache.go b/go/plugins/vertexai/cache.go new file mode 100644 index 000000000..ee29d54f3 --- /dev/null +++ b/go/plugins/vertexai/cache.go @@ -0,0 +1,252 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vertexai + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "time" + + "cloud.google.com/go/vertexai/genai" + "github.com/firebase/genkit/go/ai" +) + +// CacheConfigDetails holds configuration details for caching. +// Adjust fields as needed for your use case. +type CacheConfigDetails struct { + // TTLSeconds is how long to keep the cached content. + // If zero, defaults to 60 minutes. + TTLSeconds int +} + +var ( + INVALID_ARGUMENT_MESSAGES = struct { + modelVersion string + tools string + }{ + modelVersion: "Invalid modelVersion specified.", + tools: "Tools are not supported with context caching.", + } +) + +// getContentForCache inspects the request and modelVersion, and constructs a +// genai.CachedContent that should be cached. +// This is where you decide what goes into the cache: large documents, system instructions, etc. +func getContentForCache( + request *ai.ModelRequest, + modelVersion string, + cacheConfig *CacheConfigDetails, +) (*genai.CachedContent, error) { + // Example logic: + // 1. Extract the system instruction from the request (if any). + // 2. Include user-provided large content (like PDFs or text) that should be cached repeatedly. + + var systemInstruction string + var userParts []*genai.Content + + // Gather system messages (if any) + for _, m := range request.Messages { + if m.Role == ai.RoleSystem { + // Convert system message parts to text + sysParts := []string{} + for _, p := range m.Content { + if p.IsText() { + sysParts = append(sysParts, p.Text) + } + } + if len(sysParts) > 0 { + systemInstruction = strings.Join(sysParts, "\n") + } + } + } + + // We could also gather large user content from the first user message as an example: + // This is arbitrary logic for demonstration. + if len(request.Messages) > 0 { + // Take the first user message for caching (if any) + for _, m := range request.Messages { + if m.Role == ai.RoleUser { + parts, err := convertParts(m.Content) + if err != nil { + return nil, err + } + userParts = append(userParts, &genai.Content{ + Role: "user", + Parts: parts, + }) + break + } + } + } + + if systemInstruction == "" && len(userParts) == 0 { + // Nothing to cache + return nil, fmt.Errorf("no content to cache") + } + + // Create the cached content with a system instruction and user content (if any). 
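+	// Only the system instruction and the first user message end up in the cache;
+	// the rest of the conversation is still sent normally with each request.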
+ content := &genai.CachedContent{ + Model: modelVersion, + SystemInstruction: &genai.Content{ + Role: "system", + Parts: []genai.Part{genai.Text(systemInstruction)}, + }, + Contents: userParts, + } + + return content, nil +} + +// generateCacheKey creates a unique key for the cached content based on its contents. +// We can hash the system instruction and model version. +func generateCacheKey(content *genai.CachedContent) string { + hash := sha256.New() + if content.SystemInstruction != nil { + for _, p := range content.SystemInstruction.Parts { + if t, ok := p.(genai.Text); ok { + hash.Write([]byte(t)) + } + } + } + hash.Write([]byte(content.Model)) + + // Also incorporate any user content parts to ensure uniqueness + for _, c := range content.Contents { + for _, p := range c.Parts { + switch v := p.(type) { + case genai.Text: + hash.Write([]byte(v)) + case genai.Blob: + hash.Write([]byte(v.MIMEType)) + hash.Write(v.Data) + } + } + } + + return hex.EncodeToString(hash.Sum(nil)) +} + +// calculateTTL returns the TTL as a time.Duration. +func calculateTTL(cacheConfig *CacheConfigDetails) time.Duration { + if cacheConfig == nil || cacheConfig.TTLSeconds <= 0 { + return 60 * time.Minute + } + return time.Duration(cacheConfig.TTLSeconds) * time.Second +} + +// lookupContextCache attempts to find a cached content by its displayName. +// Currently, the Vertex AI client does not provide a direct way to look up by displayName. +// If you have a known name from a previous run, you could store that externally. +// For this demonstration, we return nil to indicate no cache found. +func lookupContextCache(ctx context.Context, client *genai.Client, cacheKey string) (*genai.CachedContent, error) { + // Since we cannot directly list or get by displayName at this time, + // we will return nil, nil indicating not found. + // In a real implementation, you could store cacheKey->cachedContentName in a database + // and then call client.GetCachedContent with that name. + return nil, nil +} + +// getKeysFrom returns the keys from the given map as a slice of strings, it is using to get the supported models +func getKeysFrom(m map[string]ai.ModelCapabilities) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +// contains checks if a slice contains a given string. +func contains(slice []string, target string) bool { + for _, s := range slice { + if s == target { + return true + } + } + return false +} + +func countTokensInMessages(messages []*ai.Message) int { + totalTokens := 0 + for _, msg := range messages { + for _, part := range msg.Content { + if part.IsText() { + words := strings.Fields(part.Text) + totalTokens += len(words) + } + } + } + return totalTokens +} + +// validateContextCacheRequest decides if we should try caching for this request. +// For demonstration, we will cache if there are more than 2 messages or if there's a system prompt. +func validateContextCacheRequest(request *ai.ModelRequest, modelVersion string) error { + models := getKeysFrom(knownCaps) + if modelVersion == "" || !contains(models, modelVersion) { + return fmt.Errorf(INVALID_ARGUMENT_MESSAGES.modelVersion) + } + if len(request.Tools) > 0 { + return fmt.Errorf(INVALID_ARGUMENT_MESSAGES.tools) + } + + tokenCount := countTokensInMessages(request.Messages) + // The minimum input token count for context caching is 32,768, and the maximum is the same as the maximum for the given model. 
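+	// Note: countTokensInMessages only approximates this by splitting text on whitespace,
+	// so treat the check below as a rough lower bound rather than the tokenizer's exact count.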
+ // https://ai.google.dev/gemini-api/docs/caching?lang=go + const minTokens = 32768 + if tokenCount < minTokens { + return fmt.Errorf("the cached content is of %d tokens. The minimum token count to start caching is %d.", tokenCount, minTokens) + } + + // If we reach here, request is valid for context caching + return nil +} + +// handleCacheIfNeeded checks if caching should be used, attempts to find or create the cache, +// and returns the cached content if applicable. +func handleCacheIfNeeded( + ctx context.Context, + client *genai.Client, + request *ai.ModelRequest, + modelVersion string, + cacheConfig *CacheConfigDetails, +) (*genai.CachedContent, error) { + + if cacheConfig == nil || validateContextCacheRequest(request, modelVersion) != nil { + return nil, nil + } + cachedContent, err := getContentForCache(request, modelVersion, cacheConfig) + if err != nil { + return nil, nil + } + + cachedContent.Model = modelVersion + cacheKey := generateCacheKey(cachedContent) + + existingCache, err := lookupContextCache(ctx, client, cacheKey) + if err == nil && existingCache != nil { + return existingCache, nil + } + + cachedContent.Expiration = genai.ExpireTimeOrTTL{TTL: calculateTTL(cacheConfig)} + newCache, err := client.CreateCachedContent(ctx, cachedContent) + if err != nil { + return nil, fmt.Errorf("failed to create cache: %w", err) + } + + return newCache, nil +} diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index b3215e03e..373eb0a8f 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -39,9 +39,9 @@ const ( var ( knownCaps = map[string]ai.ModelCapabilities{ - "gemini-1.0-pro": gemini.BasicText, - "gemini-1.5-pro": gemini.Multimodal, - "gemini-1.5-flash": gemini.Multimodal, + "gemini-1.0-pro": gemini.BasicText, + "gemini-1.5-pro": gemini.Multimodal, + "gemini-1.5-flash-002": gemini.Multimodal, } knownEmbedders = []string{ @@ -238,10 +238,23 @@ func generate( input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error, ) (*ai.ModelResponse, error) { + cacheConfig := &CacheConfigDetails{ + TTLSeconds: 3600, // hardcoded to 1 hour + } + + // Attempt to handle caching before creating the model. + cache, err := handleCacheIfNeeded(ctx, client, input, model, cacheConfig) + if err != nil { + return nil, err + } + gm, err := newModel(client, model, input) if err != nil { return nil, err } + if cache != nil { + gm.CachedContentName = cache.Name + } cs, err := startChat(gm, input) if err != nil { return nil, err From 04fccbcfd95a32501e62970ea71844cbe8e7c69e Mon Sep 17 00:00:00 2001 From: fernandoalonso Date: Tue, 10 Dec 2024 01:28:30 +0000 Subject: [PATCH 2/4] Fix: refactor --- go/plugins/vertexai/cache.go | 28 +--------------------------- 1 file changed, 1 insertion(+), 27 deletions(-) diff --git a/go/plugins/vertexai/cache.go b/go/plugins/vertexai/cache.go index ee29d54f3..64f717255 100644 --- a/go/plugins/vertexai/cache.go +++ b/go/plugins/vertexai/cache.go @@ -52,17 +52,11 @@ func getContentForCache( modelVersion string, cacheConfig *CacheConfigDetails, ) (*genai.CachedContent, error) { - // Example logic: - // 1. Extract the system instruction from the request (if any). - // 2. Include user-provided large content (like PDFs or text) that should be cached repeatedly. 
- var systemInstruction string var userParts []*genai.Content - // Gather system messages (if any) for _, m := range request.Messages { if m.Role == ai.RoleSystem { - // Convert system message parts to text sysParts := []string{} for _, p := range m.Content { if p.IsText() { @@ -75,10 +69,7 @@ func getContentForCache( } } - // We could also gather large user content from the first user message as an example: - // This is arbitrary logic for demonstration. if len(request.Messages) > 0 { - // Take the first user message for caching (if any) for _, m := range request.Messages { if m.Role == ai.RoleUser { parts, err := convertParts(m.Content) @@ -95,11 +86,9 @@ func getContentForCache( } if systemInstruction == "" && len(userParts) == 0 { - // Nothing to cache return nil, fmt.Errorf("no content to cache") } - // Create the cached content with a system instruction and user content (if any). content := &genai.CachedContent{ Model: modelVersion, SystemInstruction: &genai.Content{ @@ -149,17 +138,7 @@ func calculateTTL(cacheConfig *CacheConfigDetails) time.Duration { return time.Duration(cacheConfig.TTLSeconds) * time.Second } -// lookupContextCache attempts to find a cached content by its displayName. -// Currently, the Vertex AI client does not provide a direct way to look up by displayName. -// If you have a known name from a previous run, you could store that externally. -// For this demonstration, we return nil to indicate no cache found. -func lookupContextCache(ctx context.Context, client *genai.Client, cacheKey string) (*genai.CachedContent, error) { - // Since we cannot directly list or get by displayName at this time, - // we will return nil, nil indicating not found. - // In a real implementation, you could store cacheKey->cachedContentName in a database - // and then call client.GetCachedContent with that name. 
- return nil, nil -} + // getKeysFrom returns the keys from the given map as a slice of strings, it is using to get the supported models func getKeysFrom(m map[string]ai.ModelCapabilities) []string { @@ -237,11 +216,6 @@ func handleCacheIfNeeded( cachedContent.Model = modelVersion cacheKey := generateCacheKey(cachedContent) - existingCache, err := lookupContextCache(ctx, client, cacheKey) - if err == nil && existingCache != nil { - return existingCache, nil - } - cachedContent.Expiration = genai.ExpireTimeOrTTL{TTL: calculateTTL(cacheConfig)} newCache, err := client.CreateCachedContent(ctx, cachedContent) if err != nil { From 6f196c7609ef51ffc394b87d5314d508c6cf62e5 Mon Sep 17 00:00:00 2001 From: fernandoalonso Date: Tue, 10 Dec 2024 01:41:20 +0000 Subject: [PATCH 3/4] Fix: unit test --- go/plugins/vertexai/cache.go | 1 + 1 file changed, 1 insertion(+) diff --git a/go/plugins/vertexai/cache.go b/go/plugins/vertexai/cache.go index 64f717255..c20c148b0 100644 --- a/go/plugins/vertexai/cache.go +++ b/go/plugins/vertexai/cache.go @@ -217,6 +217,7 @@ func handleCacheIfNeeded( cacheKey := generateCacheKey(cachedContent) cachedContent.Expiration = genai.ExpireTimeOrTTL{TTL: calculateTTL(cacheConfig)} + cachedContent.Name = cacheKey newCache, err := client.CreateCachedContent(ctx, cachedContent) if err != nil { return nil, fmt.Errorf("failed to create cache: %w", err) From 7ef509007e35c1dc3a9789b1894bace0f99bd8e2 Mon Sep 17 00:00:00 2001 From: alonsopec89 Date: Tue, 17 Dec 2024 11:56:56 -0600 Subject: [PATCH 4/4] Fix: attempt to fix spaces to 4 --- go/plugins/vertexai/cache.go | 2 - go/plugins/vertexai/vertexai.go | 898 ++++++++++++++++---------------- 2 files changed, 449 insertions(+), 451 deletions(-) diff --git a/go/plugins/vertexai/cache.go b/go/plugins/vertexai/cache.go index c20c148b0..cdae71efd 100644 --- a/go/plugins/vertexai/cache.go +++ b/go/plugins/vertexai/cache.go @@ -138,8 +138,6 @@ func calculateTTL(cacheConfig *CacheConfigDetails) time.Duration { return time.Duration(cacheConfig.TTLSeconds) * time.Second } - - // getKeysFrom returns the keys from the given map as a slice of strings, it is using to get the supported models func getKeysFrom(m map[string]ai.ModelCapabilities) []string { keys := make([]string, 0, len(m)) diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 373eb0a8f..eb8dbed8b 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -15,121 +15,121 @@ package vertexai import ( - "context" - "fmt" - "os" - "runtime" - "strings" - "sync" - - aiplatform "cloud.google.com/go/aiplatform/apiv1" - "cloud.google.com/go/vertexai/genai" - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/internal" - "github.com/firebase/genkit/go/plugins/internal/gemini" - "github.com/firebase/genkit/go/plugins/internal/uri" - "google.golang.org/api/iterator" - "google.golang.org/api/option" + "context" + "fmt" + "os" + "runtime" + "strings" + "sync" + + aiplatform "cloud.google.com/go/aiplatform/apiv1" + "cloud.google.com/go/vertexai/genai" + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/internal" + "github.com/firebase/genkit/go/plugins/internal/gemini" + "github.com/firebase/genkit/go/plugins/internal/uri" + "google.golang.org/api/iterator" + "google.golang.org/api/option" ) const ( - provider = "vertexai" - labelPrefix = "Vertex AI" + provider = "vertexai" + labelPrefix = "Vertex AI" ) var ( - knownCaps = map[string]ai.ModelCapabilities{ - "gemini-1.0-pro": gemini.BasicText, - 
"gemini-1.5-pro": gemini.Multimodal, - "gemini-1.5-flash-002": gemini.Multimodal, - } - - knownEmbedders = []string{ - "textembedding-gecko@003", - "textembedding-gecko@002", - "textembedding-gecko@001", - "text-embedding-004", - "textembedding-gecko-multilingual@001", - "text-multilingual-embedding-002", - "multimodalembedding", - } + knownCaps = map[string]ai.ModelCapabilities{ + "gemini-1.0-pro": gemini.BasicText, + "gemini-1.5-pro": gemini.Multimodal, + "gemini-1.5-flash": gemini.Multimodal, + } + + knownEmbedders = []string{ + "textembedding-gecko@003", + "textembedding-gecko@002", + "textembedding-gecko@001", + "text-embedding-004", + "textembedding-gecko-multilingual@001", + "text-multilingual-embedding-002", + "multimodalembedding", + } ) var state struct { - mu sync.Mutex - initted bool - projectID string - location string - gclient *genai.Client - pclient *aiplatform.PredictionClient + mu sync.Mutex + initted bool + projectID string + location string + gclient *genai.Client + pclient *aiplatform.PredictionClient } // Config is the configuration for the plugin. type Config struct { - // The cloud project to use for Vertex AI. - // If empty, the values of the environment variables GCLOUD_PROJECT - // and GOOGLE_CLOUD_PROJECT will be consulted, in that order. - ProjectID string - // The location of the Vertex AI service. The default is "us-central1". - Location string - // Options to the Vertex AI client. - ClientOptions []option.ClientOption + // The cloud project to use for Vertex AI. + // If empty, the values of the environment variables GCLOUD_PROJECT + // and GOOGLE_CLOUD_PROJECT will be consulted, in that order. + ProjectID string + // The location of the Vertex AI service. The default is "us-central1". + Location string + // Options to the Vertex AI client. + ClientOptions []option.ClientOption } // Init initializes the plugin and all known models and embedders. // After calling Init, you may call [DefineModel] and [DefineEmbedder] to create // and register any additional generative models and embedders func Init(ctx context.Context, cfg *Config) error { - if cfg == nil { - cfg = &Config{} - } - state.mu.Lock() - defer state.mu.Unlock() - if state.initted { - panic("vertexai.Init already called") - } - - state.projectID = cfg.ProjectID - if state.projectID == "" { - state.projectID = os.Getenv("GCLOUD_PROJECT") - } - if state.projectID == "" { - state.projectID = os.Getenv("GOOGLE_CLOUD_PROJECT") - } - if state.projectID == "" { - return fmt.Errorf("vertexai.Init: Vertex AI requires setting GCLOUD_PROJECT or GOOGLE_CLOUD_PROJECT in the environment") - } - - state.location = cfg.Location - if state.location == "" { - state.location = "us-central1" - } - var err error - // Client for Gemini SDK. - opts := append([]option.ClientOption{genai.WithClientInfo("genkit-go", internal.Version)}, cfg.ClientOptions...) - state.gclient, err = genai.NewClient(ctx, state.projectID, state.location, opts...) - if err != nil { - return err - } - endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", state.location) - numConns := max(runtime.GOMAXPROCS(0), 4) - o := []option.ClientOption{ - option.WithEndpoint(endpoint), - option.WithGRPCConnectionPool(numConns), - } - - state.pclient, err = aiplatform.NewPredictionClient(ctx, o...) 
- if err != nil { - return err - } - state.initted = true - for model, caps := range knownCaps { - defineModel(model, caps) - } - for _, e := range knownEmbedders { - defineEmbedder(e) - } - return nil + if cfg == nil { + cfg = &Config{} + } + state.mu.Lock() + defer state.mu.Unlock() + if state.initted { + panic("vertexai.Init already called") + } + + state.projectID = cfg.ProjectID + if state.projectID == "" { + state.projectID = os.Getenv("GCLOUD_PROJECT") + } + if state.projectID == "" { + state.projectID = os.Getenv("GOOGLE_CLOUD_PROJECT") + } + if state.projectID == "" { + return fmt.Errorf("vertexai.Init: Vertex AI requires setting GCLOUD_PROJECT or GOOGLE_CLOUD_PROJECT in the environment") + } + + state.location = cfg.Location + if state.location == "" { + state.location = "us-central1" + } + var err error + // Client for Gemini SDK. + opts := append([]option.ClientOption{genai.WithClientInfo("genkit-go", internal.Version)}, cfg.ClientOptions...) + state.gclient, err = genai.NewClient(ctx, state.projectID, state.location, opts...) + if err != nil { + return err + } + endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", state.location) + numConns := max(runtime.GOMAXPROCS(0), 4) + o := []option.ClientOption{ + option.WithEndpoint(endpoint), + option.WithGRPCConnectionPool(numConns), + } + + state.pclient, err = aiplatform.NewPredictionClient(ctx, o...) + if err != nil { + return err + } + state.initted = true + for model, caps := range knownCaps { + defineModel(model, caps) + } + for _, e := range knownEmbedders { + defineEmbedder(e) + } + return nil } //copy:sink defineModel from ../googleai/googleai.go @@ -140,42 +140,42 @@ func Init(ctx context.Context, cfg *Config) error { // Use [IsDefinedModel] to determine if a model is already defined. // After [Init] is called, only the known models are defined. 
func DefineModel(name string, caps *ai.ModelCapabilities) (ai.Model, error) { - state.mu.Lock() - defer state.mu.Unlock() - if !state.initted { - panic(provider + ".Init not called") - } - var mc ai.ModelCapabilities - if caps == nil { - var ok bool - mc, ok = knownCaps[name] - if !ok { - return nil, fmt.Errorf("%s.DefineModel: called with unknown model %q and nil ModelCapabilities", provider, name) - } - } else { - mc = *caps - } - return defineModel(name, mc), nil + state.mu.Lock() + defer state.mu.Unlock() + if !state.initted { + panic(provider + ".Init not called") + } + var mc ai.ModelCapabilities + if caps == nil { + var ok bool + mc, ok = knownCaps[name] + if !ok { + return nil, fmt.Errorf("%s.DefineModel: called with unknown model %q and nil ModelCapabilities", provider, name) + } + } else { + mc = *caps + } + return defineModel(name, mc), nil } // requires state.mu func defineModel(name string, caps ai.ModelCapabilities) ai.Model { - meta := &ai.ModelMetadata{ - Label: labelPrefix + " - " + name, - Supports: caps, - } - return ai.DefineModel(provider, name, meta, func( - ctx context.Context, - input *ai.ModelRequest, - cb func(context.Context, *ai.ModelResponseChunk) error, - ) (*ai.ModelResponse, error) { - return generate(ctx, state.gclient, name, input, cb) - }) + meta := &ai.ModelMetadata{ + Label: labelPrefix + " - " + name, + Supports: caps, + } + return ai.DefineModel(provider, name, meta, func( + ctx context.Context, + input *ai.ModelRequest, + cb func(context.Context, *ai.ModelResponseChunk) error, + ) (*ai.ModelResponse, error) { + return generate(ctx, state.gclient, name, input, cb) + }) } // IsDefinedModel reports whether the named [Model] is defined by this plugin. func IsDefinedModel(name string) bool { - return ai.IsDefinedModel(provider, name) + return ai.IsDefinedModel(provider, name) } // DO NOT MODIFY above ^^^^ @@ -186,17 +186,17 @@ func IsDefinedModel(name string) bool { // DefineEmbedder defines an embedder with a given name. func DefineEmbedder(name string) ai.Embedder { - state.mu.Lock() - defer state.mu.Unlock() - if !state.initted { - panic(provider + ".Init not called") - } - return defineEmbedder(name) + state.mu.Lock() + defer state.mu.Unlock() + if !state.initted { + panic(provider + ".Init not called") + } + return defineEmbedder(name) } // IsDefinedEmbedder reports whether the named [Embedder] is defined by this plugin. func IsDefinedEmbedder(name string) bool { - return ai.IsDefinedEmbedder(provider, name) + return ai.IsDefinedEmbedder(provider, name) } // DO NOT MODIFY above ^^^^ @@ -204,10 +204,10 @@ func IsDefinedEmbedder(name string) bool { // requires state.mu func defineEmbedder(name string) ai.Embedder { - fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", state.projectID, state.location, name) - return ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { - return embed(ctx, fullName, state.pclient, req) - }) + fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", state.projectID, state.location, name) + return ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { + return embed(ctx, fullName, state.pclient, req) + }) } //copy:sink lookups from ../googleai/googleai.go @@ -216,13 +216,13 @@ func defineEmbedder(name string) ai.Embedder { // Model returns the [ai.Model] with the given name. // It returns nil if the model was not defined. 
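// For example, after Init has registered the known models:
//
//	m := vertexai.Model("gemini-1.5-pro")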
func Model(name string) ai.Model { - return ai.LookupModel(provider, name) + return ai.LookupModel(provider, name) } // Embedder returns the [ai.Embedder] with the given name. // It returns nil if the embedder was not defined. func Embedder(name string) ai.Embedder { - return ai.LookupEmbedder(provider, name) + return ai.LookupEmbedder(provider, name) } // DO NOT MODIFY above ^^^^ @@ -232,256 +232,256 @@ func Embedder(name string) ai.Embedder { // DO NOT MODIFY below vvvv func generate( - ctx context.Context, - client *genai.Client, - model string, - input *ai.ModelRequest, - cb func(context.Context, *ai.ModelResponseChunk) error, + ctx context.Context, + client *genai.Client, + model string, + input *ai.ModelRequest, + cb func(context.Context, *ai.ModelResponseChunk) error, ) (*ai.ModelResponse, error) { - cacheConfig := &CacheConfigDetails{ - TTLSeconds: 3600, // hardcoded to 1 hour - } - - // Attempt to handle caching before creating the model. - cache, err := handleCacheIfNeeded(ctx, client, input, model, cacheConfig) - if err != nil { - return nil, err - } - - gm, err := newModel(client, model, input) - if err != nil { - return nil, err - } - if cache != nil { - gm.CachedContentName = cache.Name - } - cs, err := startChat(gm, input) - if err != nil { - return nil, err - } - // The last message gets added to the parts slice. - var parts []genai.Part - if len(input.Messages) > 0 { - last := input.Messages[len(input.Messages)-1] - var err error - parts, err = convertParts(last.Content) - if err != nil { - return nil, err - } - } - - gm.Tools, err = convertTools(input.Tools) - if err != nil { - return nil, err - } - // Convert input.Tools and append to gm.Tools - - // TODO: gm.ToolConfig? - - // Send out the actual request. - if cb == nil { - resp, err := cs.SendMessage(ctx, parts...) - if err != nil { - return nil, err - } - r := translateResponse(resp) - r.Request = input - return r, nil - } - - // Streaming version. - iter := cs.SendMessageStream(ctx, parts...) - var r *ai.ModelResponse - for { - chunk, err := iter.Next() - if err == iterator.Done { - r = translateResponse(iter.MergedResponse()) - break - } - if err != nil { - return nil, err - } - // Send candidates to the callback. - for _, c := range chunk.Candidates { - tc := translateCandidate(c) - err := cb(ctx, &ai.ModelResponseChunk{ - Content: tc.Message.Content, - }) - if err != nil { - return nil, err - } - } - } - if r == nil { - // No candidates were returned. Probably rare, but it might avoid a NPE - // to return an empty instead of nil result. - r = &ai.ModelResponse{} - } - r.Request = input - return r, nil + cacheConfig := &CacheConfigDetails{ + TTLSeconds: 3600, // hardcoded to 1 hour + } + + // Attempt to handle caching before creating the model. + cache, err := handleCacheIfNeeded(ctx, client, input, model, cacheConfig) + if err != nil { + return nil, err + } + + gm, err := newModel(client, model, input) + if err != nil { + return nil, err + } + if cache != nil { + gm.CachedContentName = cache.Name + } + cs, err := startChat(gm, input) + if err != nil { + return nil, err + } + // The last message gets added to the parts slice. + var parts []genai.Part + if len(input.Messages) > 0 { + last := input.Messages[len(input.Messages)-1] + var err error + parts, err = convertParts(last.Content) + if err != nil { + return nil, err + } + } + + gm.Tools, err = convertTools(input.Tools) + if err != nil { + return nil, err + } + // Convert input.Tools and append to gm.Tools + + // TODO: gm.ToolConfig? 
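+    // Note: when a cache was attached above, validateContextCacheRequest has already
+    // rejected any request that carries tools, so gm.Tools stays empty in that case.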
+ + // Send out the actual request. + if cb == nil { + resp, err := cs.SendMessage(ctx, parts...) + if err != nil { + return nil, err + } + r := translateResponse(resp) + r.Request = input + return r, nil + } + + // Streaming version. + iter := cs.SendMessageStream(ctx, parts...) + var r *ai.ModelResponse + for { + chunk, err := iter.Next() + if err == iterator.Done { + r = translateResponse(iter.MergedResponse()) + break + } + if err != nil { + return nil, err + } + // Send candidates to the callback. + for _, c := range chunk.Candidates { + tc := translateCandidate(c) + err := cb(ctx, &ai.ModelResponseChunk{ + Content: tc.Message.Content, + }) + if err != nil { + return nil, err + } + } + } + if r == nil { + // No candidates were returned. Probably rare, but it might avoid a NPE + // to return an empty instead of nil result. + r = &ai.ModelResponse{} + } + r.Request = input + return r, nil } func newModel(client *genai.Client, model string, input *ai.ModelRequest) (*genai.GenerativeModel, error) { - gm := client.GenerativeModel(model) - gm.SetCandidateCount(1) - if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil { - if c.MaxOutputTokens != 0 { - gm.SetMaxOutputTokens(int32(c.MaxOutputTokens)) - } - if len(c.StopSequences) > 0 { - gm.StopSequences = c.StopSequences - } - if c.Temperature != 0 { - gm.SetTemperature(float32(c.Temperature)) - } - if c.TopK != 0 { - gm.SetTopK(int32(c.TopK)) - } - if c.TopP != 0 { - gm.SetTopP(float32(c.TopP)) - } - } - for _, m := range input.Messages { - systemParts, err := convertParts(m.Content) - if err != nil { - return nil, err - - } - // system prompts go into GenerativeModel.SystemInstruction field. - if m.Role == ai.RoleSystem { - gm.SystemInstruction = &genai.Content{ - Parts: systemParts, - Role: string(m.Role), - } - } - } - return gm, nil + gm := client.GenerativeModel(model) + gm.SetCandidateCount(1) + if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil { + if c.MaxOutputTokens != 0 { + gm.SetMaxOutputTokens(int32(c.MaxOutputTokens)) + } + if len(c.StopSequences) > 0 { + gm.StopSequences = c.StopSequences + } + if c.Temperature != 0 { + gm.SetTemperature(float32(c.Temperature)) + } + if c.TopK != 0 { + gm.SetTopK(int32(c.TopK)) + } + if c.TopP != 0 { + gm.SetTopP(float32(c.TopP)) + } + } + for _, m := range input.Messages { + systemParts, err := convertParts(m.Content) + if err != nil { + return nil, err + + } + // system prompts go into GenerativeModel.SystemInstruction field. + if m.Role == ai.RoleSystem { + gm.SystemInstruction = &genai.Content{ + Parts: systemParts, + Role: string(m.Role), + } + } + } + return gm, nil } // startChat starts a chat session and configures it with the input messages. func startChat(gm *genai.GenerativeModel, input *ai.ModelRequest) (*genai.ChatSession, error) { - cs := gm.StartChat() - - // All but the last message goes in the history field. - messages := input.Messages - for len(messages) > 1 { - m := messages[0] - messages = messages[1:] - - // skip system prompt message, it's handled separately. - if m.Role == ai.RoleSystem { - continue - } - - parts, err := convertParts(m.Content) - if err != nil { - return nil, err - } - cs.History = append(cs.History, &genai.Content{ - Parts: parts, - Role: string(m.Role), - }) - } - return cs, nil + cs := gm.StartChat() + + // All but the last message goes in the history field. + messages := input.Messages + for len(messages) > 1 { + m := messages[0] + messages = messages[1:] + + // skip system prompt message, it's handled separately. 
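+        // (newModel has already copied it into gm.SystemInstruction.)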
+ if m.Role == ai.RoleSystem { + continue + } + + parts, err := convertParts(m.Content) + if err != nil { + return nil, err + } + cs.History = append(cs.History, &genai.Content{ + Parts: parts, + Role: string(m.Role), + }) + } + return cs, nil } func convertTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) { - var outTools []*genai.Tool - for _, t := range inTools { - inputSchema, err := convertSchema(t.InputSchema, t.InputSchema) - if err != err { - return nil, err - } - fd := &genai.FunctionDeclaration{ - Name: t.Name, - Parameters: inputSchema, - Description: t.Description, - } - outTools = append(outTools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{fd}}) - } - return outTools, nil + var outTools []*genai.Tool + for _, t := range inTools { + inputSchema, err := convertSchema(t.InputSchema, t.InputSchema) + if err != err { + return nil, err + } + fd := &genai.FunctionDeclaration{ + Name: t.Name, + Parameters: inputSchema, + Description: t.Description, + } + outTools = append(outTools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{fd}}) + } + return outTools, nil } func convertSchema(originalSchema map[string]any, genkitSchema map[string]any) (*genai.Schema, error) { - // this covers genkitSchema == nil and {} - // genkitSchema will be {} if it's any - if len(genkitSchema) == 0 { - return nil, nil - } - if v, ok := genkitSchema["$ref"]; ok { - ref := v.(string) - return convertSchema(originalSchema, resolveRef(originalSchema, ref)) - } - schema := &genai.Schema{} - - switch genkitSchema["type"].(string) { - case "string": - schema.Type = genai.TypeString - case "float64": - schema.Type = genai.TypeNumber - case "number": - schema.Type = genai.TypeNumber - case "int": - schema.Type = genai.TypeInteger - case "bool": - schema.Type = genai.TypeBoolean - case "object": - schema.Type = genai.TypeObject - case "array": - schema.Type = genai.TypeArray - default: - return nil, fmt.Errorf("schema type %q not allowed", genkitSchema["type"]) - } - if v, ok := genkitSchema["required"]; ok { - schema.Required = castToStringArray(v.([]any)) - } - if v, ok := genkitSchema["description"]; ok { - schema.Description = v.(string) - } - if v, ok := genkitSchema["format"]; ok { - schema.Format = v.(string) - } - if v, ok := genkitSchema["enum"]; ok { - schema.Enum = castToStringArray(v.([]any)) - } - if v, ok := genkitSchema["items"]; ok { - items, err := convertSchema(originalSchema, v.(map[string]any)) - if err != nil { - return nil, err - } - schema.Items = items - } - if val, ok := genkitSchema["properties"]; ok { - props := map[string]*genai.Schema{} - for k, v := range val.(map[string]any) { - p, err := convertSchema(originalSchema, v.(map[string]any)) - if err != nil { - return nil, err - } - props[k] = p - } - schema.Properties = props - } - // Nullable -- not supported in jsonschema.Schema - - return schema, nil + // this covers genkitSchema == nil and {} + // genkitSchema will be {} if it's any + if len(genkitSchema) == 0 { + return nil, nil + } + if v, ok := genkitSchema["$ref"]; ok { + ref := v.(string) + return convertSchema(originalSchema, resolveRef(originalSchema, ref)) + } + schema := &genai.Schema{} + + switch genkitSchema["type"].(string) { + case "string": + schema.Type = genai.TypeString + case "float64": + schema.Type = genai.TypeNumber + case "number": + schema.Type = genai.TypeNumber + case "int": + schema.Type = genai.TypeInteger + case "bool": + schema.Type = genai.TypeBoolean + case "object": + schema.Type = genai.TypeObject + case 
"array": + schema.Type = genai.TypeArray + default: + return nil, fmt.Errorf("schema type %q not allowed", genkitSchema["type"]) + } + if v, ok := genkitSchema["required"]; ok { + schema.Required = castToStringArray(v.([]any)) + } + if v, ok := genkitSchema["description"]; ok { + schema.Description = v.(string) + } + if v, ok := genkitSchema["format"]; ok { + schema.Format = v.(string) + } + if v, ok := genkitSchema["enum"]; ok { + schema.Enum = castToStringArray(v.([]any)) + } + if v, ok := genkitSchema["items"]; ok { + items, err := convertSchema(originalSchema, v.(map[string]any)) + if err != nil { + return nil, err + } + schema.Items = items + } + if val, ok := genkitSchema["properties"]; ok { + props := map[string]*genai.Schema{} + for k, v := range val.(map[string]any) { + p, err := convertSchema(originalSchema, v.(map[string]any)) + if err != nil { + return nil, err + } + props[k] = p + } + schema.Properties = props + } + // Nullable -- not supported in jsonschema.Schema + + return schema, nil } func resolveRef(originalSchema map[string]any, ref string) map[string]any { - tkns := strings.Split(ref, "/") - // refs look like: $/ref/foo -- we need the foo part - name := tkns[len(tkns)-1] - defs := originalSchema["$defs"].(map[string]any) - return defs[name].(map[string]any) + tkns := strings.Split(ref, "/") + // refs look like: $/ref/foo -- we need the foo part + name := tkns[len(tkns)-1] + defs := originalSchema["$defs"].(map[string]any) + return defs[name].(map[string]any) } func castToStringArray(i []any) []string { - // Is there a better way to do this?? - var r []string - for _, v := range i { - r = append(r, v.(string)) - } - return r + // Is there a better way to do this?? + var r []string + for _, v := range i { + r = append(r, v.(string)) + } + return r } // DO NOT MODIFY above ^^^^ @@ -492,42 +492,42 @@ func castToStringArray(i []any) []string { // translateCandidate translates from a genai.GenerateContentResponse to an ai.ModelResponse. 
func translateCandidate(cand *genai.Candidate) *ai.ModelResponse { - m := &ai.ModelResponse{} - switch cand.FinishReason { - case genai.FinishReasonStop: - m.FinishReason = ai.FinishReasonStop - case genai.FinishReasonMaxTokens: - m.FinishReason = ai.FinishReasonLength - case genai.FinishReasonSafety: - m.FinishReason = ai.FinishReasonBlocked - case genai.FinishReasonRecitation: - m.FinishReason = ai.FinishReasonBlocked - case genai.FinishReasonOther: - m.FinishReason = ai.FinishReasonOther - default: // Unspecified - m.FinishReason = ai.FinishReasonUnknown - } - msg := &ai.Message{} - msg.Role = ai.Role(cand.Content.Role) - for _, part := range cand.Content.Parts { - var p *ai.Part - switch part := part.(type) { - case genai.Text: - p = ai.NewTextPart(string(part)) - case genai.Blob: - p = ai.NewMediaPart(part.MIMEType, string(part.Data)) - case genai.FunctionCall: - p = ai.NewToolRequestPart(&ai.ToolRequest{ - Name: part.Name, - Input: part.Args, - }) - default: - panic(fmt.Sprintf("unknown part %#v", part)) - } - msg.Content = append(msg.Content, p) - } - m.Message = msg - return m + m := &ai.ModelResponse{} + switch cand.FinishReason { + case genai.FinishReasonStop: + m.FinishReason = ai.FinishReasonStop + case genai.FinishReasonMaxTokens: + m.FinishReason = ai.FinishReasonLength + case genai.FinishReasonSafety: + m.FinishReason = ai.FinishReasonBlocked + case genai.FinishReasonRecitation: + m.FinishReason = ai.FinishReasonBlocked + case genai.FinishReasonOther: + m.FinishReason = ai.FinishReasonOther + default: // Unspecified + m.FinishReason = ai.FinishReasonUnknown + } + msg := &ai.Message{} + msg.Role = ai.Role(cand.Content.Role) + for _, part := range cand.Content.Parts { + var p *ai.Part + switch part := part.(type) { + case genai.Text: + p = ai.NewTextPart(string(part)) + case genai.Blob: + p = ai.NewMediaPart(part.MIMEType, string(part.Data)) + case genai.FunctionCall: + p = ai.NewToolRequestPart(&ai.ToolRequest{ + Name: part.Name, + Input: part.Args, + }) + default: + panic(fmt.Sprintf("unknown part %#v", part)) + } + msg.Content = append(msg.Content, p) + } + m.Message = msg + return m } // DO NOT MODIFY above ^^^^ @@ -538,15 +538,15 @@ func translateCandidate(cand *genai.Candidate) *ai.ModelResponse { // Translate from a genai.GenerateContentResponse to a ai.ModelResponse. func translateResponse(resp *genai.GenerateContentResponse) *ai.ModelResponse { - r := translateCandidate(resp.Candidates[0]) - - r.Usage = &ai.GenerationUsage{} - if u := resp.UsageMetadata; u != nil { - r.Usage.InputTokens = int(u.PromptTokenCount) - r.Usage.OutputTokens = int(u.CandidatesTokenCount) - r.Usage.TotalTokens = int(u.TotalTokenCount) - } - return r + r := translateCandidate(resp.Candidates[0]) + + r.Usage = &ai.GenerationUsage{} + if u := resp.UsageMetadata; u != nil { + r.Usage.InputTokens = int(u.PromptTokenCount) + r.Usage.OutputTokens = int(u.CandidatesTokenCount) + r.Usage.TotalTokens = int(u.TotalTokenCount) + } + return r } // DO NOT MODIFY above ^^^^ @@ -557,47 +557,47 @@ func translateResponse(resp *genai.GenerateContentResponse) *ai.ModelResponse { // convertParts converts a slice of *ai.Part to a slice of genai.Part. 
func convertParts(parts []*ai.Part) ([]genai.Part, error) { - res := make([]genai.Part, 0, len(parts)) - for _, p := range parts { - part, err := convertPart(p) - if err != nil { - return nil, err - } - res = append(res, part) - } - return res, nil + res := make([]genai.Part, 0, len(parts)) + for _, p := range parts { + part, err := convertPart(p) + if err != nil { + return nil, err + } + res = append(res, part) + } + return res, nil } // convertPart converts a *ai.Part to a genai.Part. func convertPart(p *ai.Part) (genai.Part, error) { - switch { - case p.IsText(): - return genai.Text(p.Text), nil - case p.IsMedia(): - contentType, data, err := uri.Data(p) - if err != nil { - return nil, err - } - return genai.Blob{MIMEType: contentType, Data: data}, nil - case p.IsData(): - panic(fmt.Sprintf("%s does not support Data parts", provider)) - case p.IsToolResponse(): - toolResp := p.ToolResponse - fr := genai.FunctionResponse{ - Name: toolResp.Name, - Response: toolResp.Output, - } - return fr, nil - case p.IsToolRequest(): - toolReq := p.ToolRequest - fc := genai.FunctionCall{ - Name: toolReq.Name, - Args: toolReq.Input, - } - return fc, nil - default: - panic("unknown part type in a request") - } + switch { + case p.IsText(): + return genai.Text(p.Text), nil + case p.IsMedia(): + contentType, data, err := uri.Data(p) + if err != nil { + return nil, err + } + return genai.Blob{MIMEType: contentType, Data: data}, nil + case p.IsData(): + panic(fmt.Sprintf("%s does not support Data parts", provider)) + case p.IsToolResponse(): + toolResp := p.ToolResponse + fr := genai.FunctionResponse{ + Name: toolResp.Name, + Response: toolResp.Output, + } + return fr, nil + case p.IsToolRequest(): + toolReq := p.ToolRequest + fc := genai.FunctionCall{ + Name: toolReq.Name, + Args: toolReq.Input, + } + return fc, nil + default: + panic("unknown part type in a request") + } } // DO NOT MODIFY above ^^^^
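
For reference, a rough sketch of how an application could exercise the new caching path end to end. It relies on the API surface visible in this patch (vertexai.Init, vertexai.Config, vertexai.Model, ai.ModelRequest, ai.Message, ai.NewTextPart) plus one assumption outside it: that ai.Model exposes a Generate method with the same signature as the DefineModel callback. The project ID, model name, and prompt text are placeholders.

package main

import (
	"context"
	"log"
	"strings"

	"github.com/firebase/genkit/go/ai"
	"github.com/firebase/genkit/go/plugins/vertexai"
)

func main() {
	ctx := context.Background()
	if err := vertexai.Init(ctx, &vertexai.Config{ProjectID: "my-project"}); err != nil {
		log.Fatal(err)
	}

	// A large system prompt: caching only engages once the whitespace-word count
	// clears the 32,768 minimum enforced by validateContextCacheRequest.
	largeSystemPrompt := strings.Repeat("background material to reuse across calls ", 8000)

	req := &ai.ModelRequest{
		Messages: []*ai.Message{
			{Role: ai.RoleSystem, Content: []*ai.Part{ai.NewTextPart(largeSystemPrompt)}},
			{Role: ai.RoleUser, Content: []*ai.Part{ai.NewTextPart("Summarize the material above.")}},
		},
	}

	m := vertexai.Model("gemini-1.5-pro")
	// Assumption: ai.Model has a Generate method matching the DefineModel callback signature.
	resp, err := m.Generate(ctx, req, nil)
	if err != nil {
		log.Fatal(err)
	}
	log.Println(resp.Message.Content[0].Text)
}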