Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,10 @@ func Init(ctx context.Context, apiKey string) (err error) {
state.client = client
state.initted = true
for model, caps := range knownCaps {
if _, err := DefineModel(model, &caps); err != nil {
return fmt.Errorf("googleai.Init: failed to define known model %q: %w", model, err)
}
defineModel(model, caps)
}
for _, e := range knownEmbedders {
DefineEmbedder(e)
defineEmbedder(e)
}
return nil
}
Expand All @@ -111,8 +109,8 @@ func IsKnownModel(name string) bool {
// The second argument describes the capability of the model.
// Use [IsKnownModel] to determine if a model is known.
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
// state.mu.Lock()
// defer state.mu.Unlock()
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic("googleai.Init not called")
}
Expand Down Expand Up @@ -152,8 +150,8 @@ func defineModel(name string, caps ai.ModelCapabilities) *ai.Model {

// DefineEmbedder defines an embedder with a given name.
func DefineEmbedder(name string) *ai.Embedder {
// state.mu.Lock()
// defer state.mu.Unlock()
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic("googleai.Init not called")
}
Expand Down
27 changes: 17 additions & 10 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,10 @@ func Init(ctx context.Context, projectID, location string) error {
}
state.initted = true
for model, caps := range knownCaps {
if _, err := DefineModel(model, &caps); err != nil {
return fmt.Errorf("vertexai.Init: failed to define known model %q: %w", model, err)
}
defineModel(model, caps)
}
for _, e := range knownEmbedders {
DefineEmbedder(e)
defineEmbedder(e)
}
return nil
}
Expand All @@ -128,8 +126,8 @@ func Init(ctx context.Context, projectID, location string) error {
// The second argument describes the capability of the model.
// Use [IsKnownModel] to determine if a model is known.
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
// state.mu.Lock()
// defer state.mu.Unlock()
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic("vertexai.Init not called")
}
Expand All @@ -143,13 +141,17 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
} else {
mc = *caps
}
return defineModel(name, mc), nil
}

// requires state.mu
func defineModel(name string, mc ai.ModelCapabilities) *ai.Model {
meta := &ai.ModelMetadata{
Label: "Vertex AI - " + name,
Supports: mc,
}
g := &generator{model: name, client: state.gclient}
return ai.DefineModel(provider, name, meta, g.generate), nil
return ai.DefineModel(provider, name, meta, g.generate)
}

// IsKnownModel reports whether a model is known to this plugin.
Expand All @@ -169,13 +171,18 @@ func KnownModels() []string {
return keys
}

// DefineModel defines an embedder with the given name.
// DefineEmbedder defines an embedder with the given name.
func DefineEmbedder(name string) *ai.Embedder {
// state.mu.Lock()
// defer state.mu.Unlock()
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic("vertexai.Init not called")
}
return defineEmbedder(name)
}

// requires state.mu
func defineEmbedder(name string) *ai.Embedder {
fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", state.projectID, state.location, name)
return ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
return embed(ctx, fullName, state.pclient, req)
Expand Down