Skip to content

Commit c545ca5

Browse files
hugoaguirrehendrixmar
authored andcommitted
fix(go/plugins/googlegenai): show config schema in model runner view (#3484)
1 parent de9ac7b commit c545ca5

File tree

4 files changed

+107
-51
lines changed

4 files changed

+107
-51
lines changed

go/plugins/googlegenai/gemini.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,13 @@ func configToMap(config any) map[string]any {
8989
r := jsonschema.Reflector{
9090
DoNotReference: true, // Prevent $ref usage
9191
ExpandedStruct: true, // Include all fields directly
92-
// Prevent stack overflow panic due type traversal recursion (circular references)
93-
// [genai.Schema] should not be used at this point since Schema is provided later
9492
// NOTE: keep track of updated fields in [genai.GenerateContentConfig] since
9593
// they could create runtime panics when parsing fields with type recursion
96-
IgnoredTypes: []any{genai.Schema{}},
94+
IgnoredTypes: []any{
95+
genai.Schema{},
96+
},
9797
}
98+
9899
schema := r.Reflect(config)
99100
result := base.SchemaAsMap(schema)
100101
return result
@@ -141,9 +142,8 @@ func newModel(client *genai.Client, name string, opts ai.ModelOptions) ai.Model
141142

142143
var config any
143144
config = &genai.GenerateContentConfig{}
144-
if imageOpts, found := supportedImagenModels[name]; found {
145+
if strings.Contains(name, "imagen") {
145146
config = &genai.GenerateImagesConfig{}
146-
opts = imageOpts
147147
}
148148
meta := &ai.ModelOptions{
149149
Label: opts.Label,
@@ -712,7 +712,6 @@ func toGeminiParts(parts []*ai.Part) ([]*genai.Part, error) {
712712

713713
// toGeminiPart converts a [ai.Part] to a [genai.Part].
714714
func toGeminiPart(p *ai.Part) (*genai.Part, error) {
715-
716715
switch {
717716
case p.IsReasoning():
718717
// TODO: go-genai does not support genai.NewPartFromThought()

go/plugins/googlegenai/googlegenai.go

Lines changed: 89 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -324,51 +324,81 @@ func (ga *GoogleAI) ListActions(ctx context.Context) []api.ActionDesc {
324324
"systemRole": true,
325325
"tools": true,
326326
"toolChoice": true,
327-
"constrained": true,
327+
"constrained": "no-tools",
328328
},
329-
"versions": []string{},
330-
"stage": string(ai.ModelStageStable),
329+
"versions": []string{},
330+
"stage": string(ai.ModelStageStable),
331+
"customOptions": configToMap(&genai.GenerateContentConfig{}),
331332
},
332333
}
333334
metadata["label"] = fmt.Sprintf("%s - %s", googleAILabelPrefix, name)
334335

335336
actions = append(actions, api.ActionDesc{
336337
Type: api.ActionTypeModel,
337-
Name: fmt.Sprintf("%s/%s", googleAIProvider, name),
338-
Key: fmt.Sprintf("/%s/%s/%s", api.ActionTypeModel, googleAIProvider, name),
338+
Name: api.NewName(googleAIProvider, name),
339+
Key: api.NewKey(api.ActionTypeModel, googleAIProvider, name),
340+
Metadata: metadata,
341+
})
342+
}
343+
344+
for _, name := range models.imagen {
345+
metadata := map[string]any{
346+
"model": map[string]any{
347+
"supports": map[string]any{
348+
"media": true,
349+
"multiturn": true,
350+
"systemRole": false,
351+
"tools": false,
352+
"toolChoice": false,
353+
"constrained": "no-tools",
354+
},
355+
"versions": []string{},
356+
"stage": string(ai.ModelStageStable),
357+
"customOptions": configToMap(&genai.GenerateImagesConfig{}),
358+
},
359+
}
360+
metadata["label"] = fmt.Sprintf("%s - %s", googleAILabelPrefix, name)
361+
362+
actions = append(actions, api.ActionDesc{
363+
Type: api.ActionTypeModel,
364+
Name: api.NewName(googleAIProvider, name),
365+
Key: api.NewKey(api.ActionTypeModel, googleAIProvider, name),
339366
Metadata: metadata,
340367
})
341368
}
342369

343370
for _, e := range models.embedders {
344371
actions = append(actions, api.ActionDesc{
345372
Type: api.ActionTypeEmbedder,
346-
Name: fmt.Sprintf("%s/%s", googleAIProvider, e),
347-
Key: fmt.Sprintf("/%s/%s/%s", api.ActionTypeEmbedder, googleAIProvider, e),
373+
Name: api.NewName(googleAIProvider, e),
374+
Key: api.NewKey(api.ActionTypeEmbedder, googleAIProvider, e),
348375
})
349376
}
350377

351378
return actions
352379
}
353380

354381
func (ga *GoogleAI) ResolveAction(atype api.ActionType, name string) api.Action {
382+
var config any
355383
switch atype {
356384
case api.ActionTypeEmbedder:
357385
return newEmbedder(ga.gclient, name, &ai.EmbedderOptions{}).(api.Action)
358386
case api.ActionTypeModel:
359-
var supports *ai.ModelSupports
360-
if strings.Contains(name, "gemini") || strings.Contains(name, "gemma") {
361-
supports = &Multimodal
387+
supports := &Multimodal
388+
config = &genai.GenerateContentConfig{}
389+
if strings.Contains(name, "imagen") {
390+
supports = &Media
391+
config = &genai.GenerateImagesConfig{}
362392
}
363393

364394
return newModel(ga.gclient, name, ai.ModelOptions{
365-
Label: fmt.Sprintf("%s - %s", googleAILabelPrefix, name),
366-
Stage: ai.ModelStageStable,
367-
Versions: []string{},
368-
Supports: supports,
395+
Label: fmt.Sprintf("%s - %s", googleAILabelPrefix, name),
396+
Stage: ai.ModelStageStable,
397+
Versions: []string{},
398+
Supports: supports,
399+
ConfigSchema: configToMap(config),
369400
}).(api.Action)
370401
}
371-
372402
return nil
373403
}
374404

@@ -388,47 +418,77 @@ func (v *VertexAI) ListActions(ctx context.Context) []api.ActionDesc {
388418
"systemRole": true,
389419
"tools": true,
390420
"toolChoice": true,
391-
"constrained": true,
421+
"constrained": "no-tools",
422+
},
423+
"versions": []string{},
424+
"stage": string(ai.ModelStageStable),
425+
"customOptions": configToMap(&genai.GenerateContentConfig{}),
426+
},
427+
}
428+
metadata["label"] = fmt.Sprintf("%s - %s", vertexAILabelPrefix, name)
429+
actions = append(actions, api.ActionDesc{
430+
Type: api.ActionTypeModel,
431+
Name: api.NewName(vertexAIProvider, name),
432+
Key: api.NewKey(api.ActionTypeModel, vertexAIProvider, name),
433+
Metadata: metadata,
434+
})
435+
}
436+
437+
for _, name := range models.imagen {
438+
metadata := map[string]any{
439+
"model": map[string]any{
440+
"supports": map[string]any{
441+
"media": true,
442+
"multiturn": true,
443+
"systemRole": false,
444+
"tools": false,
445+
"toolChoice": false,
446+
"constrained": "no-tools",
392447
},
393-
"versions": []string{},
394-
"stage": string(ai.ModelStageStable),
448+
"versions": []string{},
449+
"stage": string(ai.ModelStageStable),
450+
"customOptions": configToMap(&genai.GenerateImagesConfig{}),
395451
},
396452
}
397453
metadata["label"] = fmt.Sprintf("%s - %s", vertexAILabelPrefix, name)
398454
actions = append(actions, api.ActionDesc{
399455
Type: api.ActionTypeModel,
400-
Name: fmt.Sprintf("%s/%s", vertexAIProvider, name),
401-
Key: fmt.Sprintf("/%s/%s/%s", api.ActionTypeModel, vertexAIProvider, name),
456+
Name: api.NewName(vertexAIProvider, name),
457+
Key: api.NewKey(api.ActionTypeModel, vertexAIProvider, name),
402458
Metadata: metadata,
403459
})
404460
}
405461

406462
for _, e := range models.embedders {
407463
actions = append(actions, api.ActionDesc{
408464
Type: api.ActionTypeEmbedder,
409-
Name: fmt.Sprintf("%s/%s", vertexAIProvider, e),
410-
Key: fmt.Sprintf("/%s/%s/%s", api.ActionTypeEmbedder, vertexAIProvider, e),
465+
Name: api.NewName(vertexAIProvider, e),
466+
Key: api.NewKey(api.ActionTypeEmbedder, vertexAIProvider, e),
411467
})
412468
}
413469

414470
return actions
415471
}
416472

417473
func (v *VertexAI) ResolveAction(atype api.ActionType, name string) api.Action {
474+
var config any
418475
switch atype {
419476
case api.ActionTypeEmbedder:
420477
return newEmbedder(v.gclient, name, &ai.EmbedderOptions{}).(api.Action)
421478
case api.ActionTypeModel:
422-
var supports *ai.ModelSupports
423-
if strings.Contains(name, "gemini") {
424-
supports = &Multimodal
479+
supports := &Multimodal
480+
config = &genai.GenerateContentConfig{}
481+
if strings.Contains(name, "imagen") {
482+
supports = &Media
483+
config = &genai.GenerateImagesConfig{}
425484
}
426485

427486
return newModel(v.gclient, name, ai.ModelOptions{
428-
Label: fmt.Sprintf("%s - %s", vertexAILabelPrefix, name),
429-
Stage: ai.ModelStageStable,
430-
Versions: []string{},
431-
Supports: supports,
487+
Label: fmt.Sprintf("%s - %s", vertexAILabelPrefix, name),
488+
Stage: ai.ModelStageStable,
489+
Versions: []string{},
490+
Supports: supports,
491+
ConfigSchema: configToMap(config),
432492
}).(api.Action)
433493
}
434494
return nil

go/plugins/googlegenai/models.go

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,6 @@ type genaiModels struct {
406406
func listGenaiModels(ctx context.Context, client *genai.Client) (genaiModels, error) {
407407
models := genaiModels{}
408408
allowedModels := []string{"gemini", "gemma"}
409-
allowedImagenModels := []string{"imagen"}
410409

411410
for item, err := range client.Models.All(ctx) {
412411
var name string
@@ -428,22 +427,20 @@ func listGenaiModels(ctx context.Context, client *genai.Client) (genaiModels, er
428427
continue
429428
}
430429

431-
found := slices.ContainsFunc(allowedModels, func(s string) bool {
432-
return strings.Contains(name, s)
433-
})
434-
// filter out: Aqa, Text-bison, Chat, learnlm
435-
if found {
436-
models.gemini = append(models.gemini, name)
430+
if slices.Contains(item.SupportedActions, "predict") && strings.Contains(name, "imagen") {
431+
models.imagen = append(models.imagen, name)
437432
continue
438433
}
439434

440-
found = slices.ContainsFunc(allowedImagenModels, func(s string) bool {
441-
return strings.Contains(name, s)
442-
})
443-
// filter out: Aqa, Text-bison, Chat, learnlm
444-
if found {
445-
models.imagen = append(models.imagen, name)
446-
continue
435+
if slices.Contains(item.SupportedActions, "generateContent") {
436+
found := slices.ContainsFunc(allowedModels, func(s string) bool {
437+
return strings.Contains(name, s)
438+
})
439+
// filter out: Aqa, Text-bison, Chat, learnlm
440+
if found {
441+
models.gemini = append(models.gemini, name)
442+
continue
443+
}
447444
}
448445
}
449446

go/samples/prompts/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func SimplePrompt(ctx context.Context, g *genkit.Genkit) {
6262
// Define prompt with default model and system text.
6363
helloPrompt := genkit.DefinePrompt(
6464
g, "SimplePrompt",
65-
ai.WithModelName("vertexai/gemini-2.0-flash-lite"), // Override the default model.
65+
ai.WithModelName("vertexai/gemini-2.5-pro"), // Override the default model.
6666
ai.WithSystem("You are a helpful AI assistant named Walt. Greet the user."),
6767
ai.WithPrompt("Hello, who are you?"),
6868
)
@@ -272,7 +272,7 @@ func PromptWithExecuteOverrides(ctx context.Context, g *genkit.Genkit) {
272272

273273
// Call the model and add additional messages from the user.
274274
resp, err := helloPrompt.Execute(ctx,
275-
ai.WithModel(googlegenai.VertexAIModel(g, "gemini-2.0-flash-lite")),
275+
ai.WithModel(googlegenai.VertexAIModel(g, "gemini-2.5-pro")),
276276
ai.WithMessages(ai.NewUserTextMessage("And I like turtles.")),
277277
)
278278
if err != nil {

0 commit comments

Comments
 (0)