diff --git a/go/ai/generate.go b/go/ai/generate.go index 7d7c958fc..0f2136f5c 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -28,8 +28,8 @@ import ( "github.com/firebase/genkit/go/internal/atype" ) -// A ModelAction is used to generate content from an AI model. -type ModelAction = core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk] +// A Model is used to generate content from an AI model. +type Model core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk] // ModelStreamingCallback is the type for the streaming callback of a model. type ModelStreamingCallback = func(context.Context, *GenerateResponseChunk) error @@ -50,7 +50,7 @@ type ModelMetadata struct { // DefineModel registers the given generate function as an action, and returns a // [ModelAction] that runs it. -func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) *ModelAction { +func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) *Model { metadataMap := map[string]any{} if metadata != nil { if metadata.Label != "" { @@ -64,25 +64,26 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c } metadataMap["supports"] = supports } - return core.DefineStreamingAction(provider, name, atype.Model, map[string]any{ + return (*Model)(core.DefineStreamingAction(provider, name, atype.Model, map[string]any{ "model": metadataMap, - }, generate) + }, generate)) } // LookupModel looks up a [ModelAction] registered by [DefineModel]. // It returns nil if the model was not defined. -func LookupModel(provider, name string) *ModelAction { - return core.LookupActionFor[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk](atype.Model, provider, name) +func LookupModel(provider, name string) *Model { + return (*Model)(core.LookupActionFor[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk](atype.Model, provider, name)) } -// Generate applies a [ModelAction] to some input, handling tool requests. -func Generate(ctx context.Context, g *ModelAction, req *GenerateRequest, cb ModelStreamingCallback) (*GenerateResponse, error) { +// Generate applies the [Model] to some input, handling tool requests. +func (m *Model) Generate(ctx context.Context, req *GenerateRequest, cb ModelStreamingCallback) (*GenerateResponse, error) { if err := conformOutput(req); err != nil { return nil, err } + a := (*core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk])(m) for { - resp, err := g.Run(ctx, req, cb) + resp, err := a.Run(ctx, req, cb) if err != nil { return nil, err } diff --git a/go/plugins/dotprompt/dotprompt.go b/go/plugins/dotprompt/dotprompt.go index 8c1b5385f..179c0fea0 100644 --- a/go/plugins/dotprompt/dotprompt.go +++ b/go/plugins/dotprompt/dotprompt.go @@ -81,12 +81,12 @@ type Config struct { // The prompt variant. Variant string // The name of the model for which the prompt is input. - // If this is non-empty, ModelAction should be nil. - Model string + // If this is non-empty, Model should be nil. + ModelName string - // The ModelAction to use. + // The Model to use. // If this is non-nil, Model should be the empty string. - ModelAction *ai.ModelAction + Model *ai.Model // TODO(iant): document Tools []*ai.ToolDefinition @@ -224,7 +224,7 @@ func parseFrontmatter(data []byte) (name string, c Config, rest []byte, err erro ret := Config{ Variant: fy.Variant, - Model: fy.Model, + ModelName: fy.Model, Tools: fy.Tools, Candidates: fy.Candidates, GenerationConfig: fy.Config, @@ -289,10 +289,10 @@ func Define(name, templateText string, cfg Config) (*Prompt, error) { // This may be used for testing or for direct calls not using the // genkit action and flow mechanisms. func New(name, templateText string, cfg Config) (*Prompt, error) { - if cfg.Model == "" && cfg.ModelAction == nil { + if cfg.ModelName == "" && cfg.Model == nil { return nil, errors.New("dotprompt.New: config must specify either Model or ModelAction") } - if cfg.Model != "" && cfg.ModelAction != nil { + if cfg.ModelName != "" && cfg.Model != nil { return nil, errors.New("dotprompt.New: config must specify exactly one of Model and ModelAction") } hash := fmt.Sprintf("%02x", sha256.Sum256([]byte(templateText))) diff --git a/go/plugins/dotprompt/dotprompt_test.go b/go/plugins/dotprompt/dotprompt_test.go index e25654d33..481ce3443 100644 --- a/go/plugins/dotprompt/dotprompt_test.go +++ b/go/plugins/dotprompt/dotprompt_test.go @@ -110,8 +110,8 @@ func TestPrompts(t *testing.T) { t.Fatal(err) } - if prompt.Model != test.model { - t.Errorf("got model %q want %q", prompt.Model, test.model) + if prompt.ModelName != test.model { + t.Errorf("got model %q want %q", prompt.ModelName, test.model) } if diff := cmpSchema(t, prompt.InputSchema, test.input); diff != "" { t.Errorf("input schema mismatch (-want, +got):\n%s", diff) diff --git a/go/plugins/dotprompt/genkit.go b/go/plugins/dotprompt/genkit.go index 6d20c67c4..2a924eddb 100644 --- a/go/plugins/dotprompt/genkit.go +++ b/go/plugins/dotprompt/genkit.go @@ -187,9 +187,9 @@ func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(contex genReq.Context = pr.Context } - model := p.ModelAction + model := p.Model if model == nil { - modelName := p.Model + modelName := p.ModelName if pr.Model != "" { modelName = pr.Model } @@ -207,7 +207,7 @@ func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(contex } } - resp, err := ai.Generate(ctx, model, genReq, cb) + resp, err := model.Generate(ctx, genReq, cb) if err != nil { return nil, err } diff --git a/go/plugins/dotprompt/genkit_test.go b/go/plugins/dotprompt/genkit_test.go index 0890d8d50..71bb3bb22 100644 --- a/go/plugins/dotprompt/genkit_test.go +++ b/go/plugins/dotprompt/genkit_test.go @@ -43,7 +43,7 @@ func testGenerate(ctx context.Context, req *ai.GenerateRequest, cb func(context. func TestExecute(t *testing.T) { testModel := ai.DefineModel("test", "test", nil, testGenerate) - p, err := New("TestExecute", "TestExecute", Config{ModelAction: testModel}) + p, err := New("TestExecute", "TestExecute", Config{Model: testModel}) if err != nil { t.Fatal(err) } diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 07c634fa7..04b5b30c2 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -90,7 +90,7 @@ func IsKnownModel(name string) bool { // For known models, it can be nil, or if non-nil it will override the known value. // It must be supplied for unknown models. // Use [IsKnownModel] to determine if a model is known. -func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.ModelAction, error) { +func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) { state.mu.Lock() defer state.mu.Unlock() if !state.initted { @@ -110,7 +110,7 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.ModelAction, erro } // requires state.mu -func defineModel(name string, caps ai.ModelCapabilities) *ai.ModelAction { +func defineModel(name string, caps ai.ModelCapabilities) *ai.Model { meta := &ai.ModelMetadata{ Label: "Google AI - " + name, Supports: caps, @@ -147,7 +147,7 @@ func defineEmbedder(name string) *ai.EmbedderAction { // Model returns the [ai.ModelAction] with the given name. // It returns nil if the model was not configured. -func Model(name string) *ai.ModelAction { +func Model(name string) *ai.Model { return ai.LookupModel(provider, name) } diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index 4e8faf72f..5998bd5cb 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -116,7 +116,7 @@ func TestLive(t *testing.T) { }, } - resp, err := ai.Generate(ctx, model, req, nil) + resp, err := model.Generate(ctx, req, nil) if err != nil { t.Fatal(err) } @@ -145,7 +145,7 @@ func TestLive(t *testing.T) { out := "" parts := 0 - final, err := ai.Generate(ctx, model, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error { + final, err := model.Generate(ctx, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error { parts++ out += c.Content[0].Text return nil @@ -184,7 +184,7 @@ func TestLive(t *testing.T) { Tools: []*ai.ToolDefinition{toolDef}, } - resp, err := ai.Generate(ctx, model, req, nil) + resp, err := model.Generate(ctx, req, nil) if err != nil { t.Fatal(err) } diff --git a/go/plugins/ollama/ollama.go b/go/plugins/ollama/ollama.go index 5cb4c5afb..77a9ce291 100644 --- a/go/plugins/ollama/ollama.go +++ b/go/plugins/ollama/ollama.go @@ -49,9 +49,9 @@ func defineModel(model ModelDefinition, serverAddress string) { ai.DefineModel(provider, model.Name, meta, g.generate) } -// Model returns the [ai.ModelAction] with the given name. +// Model returns the [ai.Model] with the given name. // It returns nil if the model was not configured. -func Model(name string) *ai.ModelAction { +func Model(name string) *ai.Model { return ai.LookupModel(provider, name) } diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index b0436734e..d9e9026f4 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -72,7 +72,7 @@ func Init(ctx context.Context, projectID, location string) error { } // DefineModel defines a model with the given name. -func DefineModel(name string) *ai.ModelAction { +func DefineModel(name string) *ai.Model { state.mu.Lock() defer state.mu.Unlock() if !state.initted { @@ -101,9 +101,9 @@ func DefineEmbedder(name string) *ai.EmbedderAction { }) } -// Model returns the [ai.ModelAction] with the given name. +// Model returns the [ai.Model] with the given name. // It returns nil if the model was not configured. -func Model(name string) *ai.ModelAction { +func Model(name string) *ai.Model { return ai.LookupModel(provider, name) } diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index b9126bb95..c26f10606 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -91,7 +91,7 @@ func TestLive(t *testing.T) { }, } - resp, err := ai.Generate(ctx, model, req, nil) + resp, err := model.Generate(ctx, req, nil) if err != nil { t.Fatal(err) } @@ -120,7 +120,7 @@ func TestLive(t *testing.T) { out := "" parts := 0 model := vertexai.Model(modelName) - final, err := ai.Generate(ctx, model, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error { + final, err := model.Generate(ctx, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error { parts++ for _, p := range c.Content { out += p.Text @@ -162,7 +162,7 @@ func TestLive(t *testing.T) { Tools: []*ai.ToolDefinition{toolDef}, } - resp, err := ai.Generate(ctx, model, req, nil) + resp, err := model.Generate(ctx, req, nil) if err != nil { t.Fatal(err) } diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index 30dffd63c..4fc0d2151 100755 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -119,7 +119,7 @@ func main() { } simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate, dotprompt.Config{ - ModelAction: g, + Model: g, InputSchema: r.Reflect(simpleGreetingInput{}), OutputFormat: ai.OutputFormatText, }, @@ -157,7 +157,7 @@ func main() { greetingWithHistoryPrompt, err := dotprompt.Define("greetingWithHistory", greetingWithHistoryPromptTemplate, dotprompt.Config{ - ModelAction: g, + Model: g, InputSchema: jsonschema.Reflect(customerTimeAndHistoryInput{}), OutputFormat: ai.OutputFormatText, }, @@ -197,7 +197,7 @@ func main() { simpleStructuredGreetingPrompt, err := dotprompt.Define("simpleStructuredGreeting", simpleStructuredGreetingPromptTemplate, dotprompt.Config{ - ModelAction: g, + Model: g, InputSchema: jsonschema.Reflect(simpleGreetingInput{}), OutputFormat: ai.OutputFormatJSON, OutputSchema: outputSchema, diff --git a/go/samples/menu/s01.go b/go/samples/menu/s01.go index f86894e4a..3c31360cf 100644 --- a/go/samples/menu/s01.go +++ b/go/samples/menu/s01.go @@ -21,13 +21,13 @@ import ( "github.com/firebase/genkit/go/plugins/dotprompt" ) -func setup01(ctx context.Context, g *ai.ModelAction) error { +func setup01(ctx context.Context, g *ai.Model) error { _, err := dotprompt.Define("s01_vanillaPrompt", `You are acting as a helpful AI assistant named "Walt" that can answer questions about the food available on the menu at Walt's Burgers. Customer says: ${input.question}`, dotprompt.Config{ - ModelAction: g, + Model: g, InputSchema: menuQuestionInputSchema, }, ) @@ -67,7 +67,7 @@ func setup01(ctx context.Context, g *ai.ModelAction) error { Question: {{question}} ?`, dotprompt.Config{ - ModelAction: g, + Model: g, InputSchema: menuQuestionInputSchema, OutputFormat: ai.OutputFormatText, }, diff --git a/go/samples/menu/s02.go b/go/samples/menu/s02.go index eaa389f79..b62e5f970 100644 --- a/go/samples/menu/s02.go +++ b/go/samples/menu/s02.go @@ -46,7 +46,7 @@ func menu(ctx context.Context, input map[string]any) (map[string]any, error) { return map[string]any{"menu": s}, nil } -func setup02(ctx context.Context, m *ai.ModelAction) error { +func setup02(ctx context.Context, m *ai.Model) error { ai.DefineTool(menuToolDef, nil, menu) dataMenuPrompt, err := dotprompt.Define("s02_dataMenu", @@ -61,7 +61,7 @@ func setup02(ctx context.Context, m *ai.ModelAction) error { Question: {{question}} ?`, dotprompt.Config{ - ModelAction: m, + Model: m, InputSchema: menuQuestionInputSchema, OutputFormat: ai.OutputFormatText, Tools: []*ai.ToolDefinition{ diff --git a/go/samples/menu/s03.go b/go/samples/menu/s03.go index d8b95bc58..9057105a9 100644 --- a/go/samples/menu/s03.go +++ b/go/samples/menu/s03.go @@ -55,7 +55,7 @@ func (ch *chatHistoryStore) Retrieve(sessionID string) chatHistory { return ch.preamble } -func setup03(ctx context.Context, model *ai.ModelAction) error { +func setup03(ctx context.Context, model *ai.Model) error { chatPreamblePrompt, err := dotprompt.Define("s03_chatPreamble", ` {{ role "user" }} @@ -72,7 +72,7 @@ func setup03(ctx context.Context, model *ai.ModelAction) error { {{~/each}} Do you have any questions about the menu?`, dotprompt.Config{ - ModelAction: model, + Model: model, InputSchema: dataMenuQuestionInputSchema, OutputFormat: ai.OutputFormatText, GenerationConfig: &ai.GenerationCommonConfig{ @@ -115,7 +115,7 @@ func setup03(ctx context.Context, model *ai.ModelAction) error { req := &ai.GenerateRequest{ Messages: messages, } - resp, err := ai.Generate(ctx, model, req, nil) + resp, err := model.Generate(ctx, req, nil) if err != nil { return nil, err } diff --git a/go/samples/menu/s04.go b/go/samples/menu/s04.go index 6b3beff8a..f860b0f49 100644 --- a/go/samples/menu/s04.go +++ b/go/samples/menu/s04.go @@ -24,7 +24,7 @@ import ( "github.com/firebase/genkit/go/plugins/localvec" ) -func setup04(ctx context.Context, indexer *ai.Indexer, retriever *ai.Retriever, model *ai.ModelAction) error { +func setup04(ctx context.Context, indexer *ai.Indexer, retriever *ai.Retriever, model *ai.Model) error { ragDataMenuPrompt, err := dotprompt.Define("s04_ragDataMenu", ` You are acting as Walt, a helpful AI assistant here at the restaurant. @@ -41,7 +41,7 @@ func setup04(ctx context.Context, indexer *ai.Indexer, retriever *ai.Retriever, Answer this customer's question: {{question}}?`, dotprompt.Config{ - ModelAction: model, + Model: model, InputSchema: dataMenuQuestionInputSchema, OutputFormat: ai.OutputFormatText, GenerationConfig: &ai.GenerationCommonConfig{ diff --git a/go/samples/menu/s05.go b/go/samples/menu/s05.go index 10601b98b..d124f44ea 100644 --- a/go/samples/menu/s05.go +++ b/go/samples/menu/s05.go @@ -29,7 +29,7 @@ type imageURLInput struct { ImageURL string `json:"imageUrl"` } -func setup05(ctx context.Context, gen, genVision *ai.ModelAction) error { +func setup05(ctx context.Context, gen, genVision *ai.Model) error { readMenuPrompt, err := dotprompt.Define("s05_readMenu", ` Extract _all_ of the text, in order, @@ -37,7 +37,7 @@ func setup05(ctx context.Context, gen, genVision *ai.ModelAction) error { {{media url=imageUrl}}`, dotprompt.Config{ - ModelAction: genVision, + Model: genVision, InputSchema: jsonschema.Reflect(imageURLInput{}), OutputFormat: ai.OutputFormatText, GenerationConfig: &ai.GenerationCommonConfig{ @@ -62,7 +62,7 @@ func setup05(ctx context.Context, gen, genVision *ai.ModelAction) error { {{question}}? `, dotprompt.Config{ - ModelAction: gen, + Model: gen, InputSchema: textMenuQuestionInputSchema, OutputFormat: ai.OutputFormatText, GenerationConfig: &ai.GenerationCommonConfig{ diff --git a/go/samples/rag/main.go b/go/samples/rag/main.go index e833e1e9b..0d18bb529 100644 --- a/go/samples/rag/main.go +++ b/go/samples/rag/main.go @@ -95,7 +95,7 @@ func main() { simpleQaPrompt, err := dotprompt.Define("simpleQaPrompt", simpleQaPromptTemplate, dotprompt.Config{ - ModelAction: model, + Model: model, InputSchema: jsonschema.Reflect(simpleQaPromptInput{}), OutputFormat: ai.OutputFormatText, },