Skip to content

Commit ded8b3d

Browse files
committed
[Go] ai.Model is a separate type from its action
This continues the work begun in #402 of making the main ai types distinct from their underlying actions, instead of aliases. This allows the types to have methods, unstead of using top-level functions. It also clarifies documentation and other output, like panic stack traces.
1 parent d241336 commit ded8b3d

File tree

17 files changed

+55
-54
lines changed

17 files changed

+55
-54
lines changed

go/ai/generate.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ import (
2828
"github.com/firebase/genkit/go/internal/atype"
2929
)
3030

31-
// A ModelAction is used to generate content from an AI model.
32-
type ModelAction = core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk]
31+
// A Model is used to generate content from an AI model.
32+
type Model core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk]
3333

3434
// ModelStreamingCallback is the type for the streaming callback of a model.
3535
type ModelStreamingCallback = func(context.Context, *GenerateResponseChunk) error
@@ -50,7 +50,7 @@ type ModelMetadata struct {
5050

5151
// DefineModel registers the given generate function as an action, and returns a
5252
// [ModelAction] that runs it.
53-
func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) *ModelAction {
53+
func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) *Model {
5454
metadataMap := map[string]any{}
5555
if metadata != nil {
5656
if metadata.Label != "" {
@@ -64,25 +64,26 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c
6464
}
6565
metadataMap["supports"] = supports
6666
}
67-
return core.DefineStreamingAction(provider, name, atype.Model, map[string]any{
67+
return (*Model)(core.DefineStreamingAction(provider, name, atype.Model, map[string]any{
6868
"model": metadataMap,
69-
}, generate)
69+
}, generate))
7070
}
7171

7272
// LookupModel looks up a [ModelAction] registered by [DefineModel].
7373
// It returns nil if the model was not defined.
74-
func LookupModel(provider, name string) *ModelAction {
75-
return core.LookupActionFor[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk](atype.Model, provider, name)
74+
func LookupModel(provider, name string) *Model {
75+
return (*Model)(core.LookupActionFor[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk](atype.Model, provider, name))
7676
}
7777

78-
// Generate applies a [ModelAction] to some input, handling tool requests.
79-
func Generate(ctx context.Context, g *ModelAction, req *GenerateRequest, cb ModelStreamingCallback) (*GenerateResponse, error) {
78+
// Generate applies the [Model] to some input, handling tool requests.
79+
func (m *Model) Generate(ctx context.Context, req *GenerateRequest, cb ModelStreamingCallback) (*GenerateResponse, error) {
8080
if err := conformOutput(req); err != nil {
8181
return nil, err
8282
}
8383

84+
a := (*core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk])(m)
8485
for {
85-
resp, err := g.Run(ctx, req, cb)
86+
resp, err := a.Run(ctx, req, cb)
8687
if err != nil {
8788
return nil, err
8889
}

go/plugins/dotprompt/dotprompt.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,12 @@ type Config struct {
8181
// The prompt variant.
8282
Variant string
8383
// The name of the model for which the prompt is input.
84-
// If this is non-empty, ModelAction should be nil.
85-
Model string
84+
// If this is non-empty, Model should be nil.
85+
ModelName string
8686

87-
// The ModelAction to use.
87+
// The Model to use.
8888
// If this is non-nil, Model should be the empty string.
89-
ModelAction *ai.ModelAction
89+
Model *ai.Model
9090

9191
// TODO(iant): document
9292
Tools []*ai.ToolDefinition
@@ -224,7 +224,7 @@ func parseFrontmatter(data []byte) (name string, c Config, rest []byte, err erro
224224

225225
ret := Config{
226226
Variant: fy.Variant,
227-
Model: fy.Model,
227+
ModelName: fy.Model,
228228
Tools: fy.Tools,
229229
Candidates: fy.Candidates,
230230
GenerationConfig: fy.Config,
@@ -289,10 +289,10 @@ func Define(name, templateText string, cfg Config) (*Prompt, error) {
289289
// This may be used for testing or for direct calls not using the
290290
// genkit action and flow mechanisms.
291291
func New(name, templateText string, cfg Config) (*Prompt, error) {
292-
if cfg.Model == "" && cfg.ModelAction == nil {
292+
if cfg.ModelName == "" && cfg.Model == nil {
293293
return nil, errors.New("dotprompt.New: config must specify either Model or ModelAction")
294294
}
295-
if cfg.Model != "" && cfg.ModelAction != nil {
295+
if cfg.ModelName != "" && cfg.Model != nil {
296296
return nil, errors.New("dotprompt.New: config must specify exactly one of Model and ModelAction")
297297
}
298298
hash := fmt.Sprintf("%02x", sha256.Sum256([]byte(templateText)))

go/plugins/dotprompt/dotprompt_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ func TestPrompts(t *testing.T) {
110110
t.Fatal(err)
111111
}
112112

113-
if prompt.Model != test.model {
114-
t.Errorf("got model %q want %q", prompt.Model, test.model)
113+
if prompt.ModelName != test.model {
114+
t.Errorf("got model %q want %q", prompt.ModelName, test.model)
115115
}
116116
if diff := cmpSchema(t, prompt.InputSchema, test.input); diff != "" {
117117
t.Errorf("input schema mismatch (-want, +got):\n%s", diff)

go/plugins/dotprompt/genkit.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,9 @@ func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(contex
187187
genReq.Context = pr.Context
188188
}
189189

190-
model := p.ModelAction
190+
model := p.Model
191191
if model == nil {
192-
modelName := p.Model
192+
modelName := p.ModelName
193193
if pr.Model != "" {
194194
modelName = pr.Model
195195
}
@@ -207,7 +207,7 @@ func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(contex
207207
}
208208
}
209209

210-
resp, err := ai.Generate(ctx, model, genReq, cb)
210+
resp, err := model.Generate(ctx, genReq, cb)
211211
if err != nil {
212212
return nil, err
213213
}

go/plugins/dotprompt/genkit_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func testGenerate(ctx context.Context, req *ai.GenerateRequest, cb func(context.
4343

4444
func TestExecute(t *testing.T) {
4545
testModel := ai.DefineModel("test", "test", nil, testGenerate)
46-
p, err := New("TestExecute", "TestExecute", Config{ModelAction: testModel})
46+
p, err := New("TestExecute", "TestExecute", Config{Model: testModel})
4747
if err != nil {
4848
t.Fatal(err)
4949
}

go/plugins/googleai/googleai.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ func IsKnownModel(name string) bool {
9090
// For known models, it can be nil, or if non-nil it will override the known value.
9191
// It must be supplied for unknown models.
9292
// Use [IsKnownModel] to determine if a model is known.
93-
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.ModelAction, error) {
93+
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
9494
state.mu.Lock()
9595
defer state.mu.Unlock()
9696
if !state.initted {
@@ -110,7 +110,7 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.ModelAction, erro
110110
}
111111

112112
// requires state.mu
113-
func defineModel(name string, caps ai.ModelCapabilities) *ai.ModelAction {
113+
func defineModel(name string, caps ai.ModelCapabilities) *ai.Model {
114114
meta := &ai.ModelMetadata{
115115
Label: "Google AI - " + name,
116116
Supports: caps,
@@ -147,7 +147,7 @@ func defineEmbedder(name string) *ai.EmbedderAction {
147147

148148
// Model returns the [ai.ModelAction] with the given name.
149149
// It returns nil if the model was not configured.
150-
func Model(name string) *ai.ModelAction {
150+
func Model(name string) *ai.Model {
151151
return ai.LookupModel(provider, name)
152152
}
153153

go/plugins/googleai/googleai_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ func TestLive(t *testing.T) {
116116
},
117117
}
118118

119-
resp, err := ai.Generate(ctx, model, req, nil)
119+
resp, err := model.Generate(ctx, req, nil)
120120
if err != nil {
121121
t.Fatal(err)
122122
}
@@ -145,7 +145,7 @@ func TestLive(t *testing.T) {
145145

146146
out := ""
147147
parts := 0
148-
final, err := ai.Generate(ctx, model, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
148+
final, err := model.Generate(ctx, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
149149
parts++
150150
out += c.Content[0].Text
151151
return nil
@@ -184,7 +184,7 @@ func TestLive(t *testing.T) {
184184
Tools: []*ai.ToolDefinition{toolDef},
185185
}
186186

187-
resp, err := ai.Generate(ctx, model, req, nil)
187+
resp, err := model.Generate(ctx, req, nil)
188188
if err != nil {
189189
t.Fatal(err)
190190
}

go/plugins/ollama/ollama.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ func defineModel(model ModelDefinition, serverAddress string) {
4949
ai.DefineModel(provider, model.Name, meta, g.generate)
5050
}
5151

52-
// Model returns the [ai.ModelAction] with the given name.
52+
// Model returns the [ai.Model] with the given name.
5353
// It returns nil if the model was not configured.
54-
func Model(name string) *ai.ModelAction {
54+
func Model(name string) *ai.Model {
5555
return ai.LookupModel(provider, name)
5656
}
5757

go/plugins/vertexai/vertexai.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func Init(ctx context.Context, projectID, location string) error {
7272
}
7373

7474
// DefineModel defines a model with the given name.
75-
func DefineModel(name string) *ai.ModelAction {
75+
func DefineModel(name string) *ai.Model {
7676
state.mu.Lock()
7777
defer state.mu.Unlock()
7878
if !state.initted {
@@ -101,9 +101,9 @@ func DefineEmbedder(name string) *ai.EmbedderAction {
101101
})
102102
}
103103

104-
// Model returns the [ai.ModelAction] with the given name.
104+
// Model returns the [ai.Model] with the given name.
105105
// It returns nil if the model was not configured.
106-
func Model(name string) *ai.ModelAction {
106+
func Model(name string) *ai.Model {
107107
return ai.LookupModel(provider, name)
108108
}
109109

go/plugins/vertexai/vertexai_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func TestLive(t *testing.T) {
9191
},
9292
}
9393

94-
resp, err := ai.Generate(ctx, model, req, nil)
94+
resp, err := model.Generate(ctx, req, nil)
9595
if err != nil {
9696
t.Fatal(err)
9797
}
@@ -120,7 +120,7 @@ func TestLive(t *testing.T) {
120120
out := ""
121121
parts := 0
122122
model := vertexai.Model(modelName)
123-
final, err := ai.Generate(ctx, model, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
123+
final, err := model.Generate(ctx, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
124124
parts++
125125
for _, p := range c.Content {
126126
out += p.Text
@@ -162,7 +162,7 @@ func TestLive(t *testing.T) {
162162
Tools: []*ai.ToolDefinition{toolDef},
163163
}
164164

165-
resp, err := ai.Generate(ctx, model, req, nil)
165+
resp, err := model.Generate(ctx, req, nil)
166166
if err != nil {
167167
t.Fatal(err)
168168
}

0 commit comments

Comments
 (0)