Skip to content

Commit 70bb3a8

Browse files
committed
[Go] plugins: fix locking
Ensure that the global state is protected during Init, DefineModel and DefineEmbedder for the googleai and vertexai plugins.
1 parent 959c7f2 commit 70bb3a8

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

go/plugins/googleai/googleai.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,10 @@ func Init(ctx context.Context, apiKey string) (err error) {
9191
state.client = client
9292
state.initted = true
9393
for model, caps := range knownCaps {
94-
if _, err := DefineModel(model, &caps); err != nil {
95-
return fmt.Errorf("googleai.Init: failed to define known model %q: %w", model, err)
96-
}
94+
defineModel(model, caps)
9795
}
9896
for _, e := range knownEmbedders {
99-
DefineEmbedder(e)
97+
defineEmbedder(e)
10098
}
10199
return nil
102100
}
@@ -111,8 +109,8 @@ func IsKnownModel(name string) bool {
111109
// The second argument describes the capability of the model.
112110
// Use [IsKnownModel] to determine if a model is known.
113111
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
114-
// state.mu.Lock()
115-
// defer state.mu.Unlock()
112+
state.mu.Lock()
113+
defer state.mu.Unlock()
116114
if !state.initted {
117115
panic("googleai.Init not called")
118116
}
@@ -152,8 +150,8 @@ func defineModel(name string, caps ai.ModelCapabilities) *ai.Model {
152150

153151
// DefineEmbedder defines an embedder with a given name.
154152
func DefineEmbedder(name string) *ai.Embedder {
155-
// state.mu.Lock()
156-
// defer state.mu.Unlock()
153+
state.mu.Lock()
154+
defer state.mu.Unlock()
157155
if !state.initted {
158156
panic("googleai.Init not called")
159157
}

go/plugins/vertexai/vertexai.go

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,10 @@ func Init(ctx context.Context, projectID, location string) error {
114114
}
115115
state.initted = true
116116
for model, caps := range knownCaps {
117-
if _, err := DefineModel(model, &caps); err != nil {
118-
return fmt.Errorf("vertexai.Init: failed to define known model %q: %w", model, err)
119-
}
117+
defineModel(model, caps)
120118
}
121119
for _, e := range knownEmbedders {
122-
DefineEmbedder(e)
120+
defineEmbedder(e)
123121
}
124122
return nil
125123
}
@@ -128,8 +126,8 @@ func Init(ctx context.Context, projectID, location string) error {
128126
// The second argument describes the capability of the model.
129127
// Use [IsKnownModel] to determine if a model is known.
130128
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
131-
// state.mu.Lock()
132-
// defer state.mu.Unlock()
129+
state.mu.Lock()
130+
defer state.mu.Unlock()
133131
if !state.initted {
134132
panic("vertexai.Init not called")
135133
}
@@ -143,13 +141,17 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
143141
} else {
144142
mc = *caps
145143
}
144+
return defineModel(name, mc), nil
145+
}
146146

147+
// requires state.mu
148+
func defineModel(name string, mc ai.ModelCapabilities) *ai.Model {
147149
meta := &ai.ModelMetadata{
148150
Label: "Vertex AI - " + name,
149151
Supports: mc,
150152
}
151153
g := &generator{model: name, client: state.gclient}
152-
return ai.DefineModel(provider, name, meta, g.generate), nil
154+
return ai.DefineModel(provider, name, meta, g.generate)
153155
}
154156

155157
// IsKnownModel reports whether a model is known to this plugin.
@@ -169,13 +171,18 @@ func KnownModels() []string {
169171
return keys
170172
}
171173

172-
// DefineModel defines an embedder with the given name.
174+
// DefineEmbedder defines an embedder with the given name.
173175
func DefineEmbedder(name string) *ai.Embedder {
174-
// state.mu.Lock()
175-
// defer state.mu.Unlock()
176+
state.mu.Lock()
177+
defer state.mu.Unlock()
176178
if !state.initted {
177179
panic("vertexai.Init not called")
178180
}
181+
return defineEmbedder(name)
182+
}
183+
184+
// requires state.mu
185+
func defineEmbedder(name string) *ai.Embedder {
179186
fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", state.projectID, state.location, name)
180187
return ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
181188
return embed(ctx, fullName, state.pclient, req)

0 commit comments

Comments
 (0)