diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 6bd4b3eaef..8b192542a4 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -91,12 +91,10 @@ func Init(ctx context.Context, apiKey string) (err error) { state.client = client state.initted = true for model, caps := range knownCaps { - if _, err := DefineModel(model, &caps); err != nil { - return fmt.Errorf("googleai.Init: failed to define known model %q: %w", model, err) - } + defineModel(model, caps) } for _, e := range knownEmbedders { - DefineEmbedder(e) + defineEmbedder(e) } return nil } @@ -111,8 +109,8 @@ func IsKnownModel(name string) bool { // The second argument describes the capability of the model. // Use [IsKnownModel] to determine if a model is known. func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) { - // state.mu.Lock() - // defer state.mu.Unlock() + state.mu.Lock() + defer state.mu.Unlock() if !state.initted { panic("googleai.Init not called") } @@ -152,8 +150,8 @@ func defineModel(name string, caps ai.ModelCapabilities) *ai.Model { // DefineEmbedder defines an embedder with a given name. func DefineEmbedder(name string) *ai.Embedder { - // state.mu.Lock() - // defer state.mu.Unlock() + state.mu.Lock() + defer state.mu.Unlock() if !state.initted { panic("googleai.Init not called") } diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 2b5b6f6058..a822189a59 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -114,12 +114,10 @@ func Init(ctx context.Context, projectID, location string) error { } state.initted = true for model, caps := range knownCaps { - if _, err := DefineModel(model, &caps); err != nil { - return fmt.Errorf("vertexai.Init: failed to define known model %q: %w", model, err) - } + defineModel(model, caps) } for _, e := range knownEmbedders { - DefineEmbedder(e) + defineEmbedder(e) } return nil } @@ -128,8 +126,8 @@ func Init(ctx context.Context, projectID, location string) error { // The second argument describes the capability of the model. // Use [IsKnownModel] to determine if a model is known. func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) { - // state.mu.Lock() - // defer state.mu.Unlock() + state.mu.Lock() + defer state.mu.Unlock() if !state.initted { panic("vertexai.Init not called") } @@ -143,13 +141,17 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) { } else { mc = *caps } + return defineModel(name, mc), nil +} +// requires state.mu +func defineModel(name string, mc ai.ModelCapabilities) *ai.Model { meta := &ai.ModelMetadata{ Label: "Vertex AI - " + name, Supports: mc, } g := &generator{model: name, client: state.gclient} - return ai.DefineModel(provider, name, meta, g.generate), nil + return ai.DefineModel(provider, name, meta, g.generate) } // IsKnownModel reports whether a model is known to this plugin. @@ -169,13 +171,18 @@ func KnownModels() []string { return keys } -// DefineModel defines an embedder with the given name. +// DefineEmbedder defines an embedder with the given name. func DefineEmbedder(name string) *ai.Embedder { - // state.mu.Lock() - // defer state.mu.Unlock() + state.mu.Lock() + defer state.mu.Unlock() if !state.initted { panic("vertexai.Init not called") } + return defineEmbedder(name) +} + +// requires state.mu +func defineEmbedder(name string) *ai.Embedder { fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", state.projectID, state.location, name) return ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { return embed(ctx, fullName, state.pclient, req)