Skip to content

Commit

Permalink
[Go] ai.Model is a separate type from its action
Browse files Browse the repository at this point in the history
This continues the work begun in #402 of making the main ai types
distinct from their underlying actions, instead of aliases.

This allows the types to have methods, unstead of using top-level
functions.

It also clarifies documentation and other output, like panic stack
traces.
  • Loading branch information
jba committed Jun 24, 2024
1 parent 442ca56 commit 329c99f
Show file tree
Hide file tree
Showing 17 changed files with 55 additions and 54 deletions.
21 changes: 11 additions & 10 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ import (
"github.com/firebase/genkit/go/internal/atype"
)

// A ModelAction is used to generate content from an AI model.
type ModelAction = core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk]
// A Model is used to generate content from an AI model.
type Model core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk]

// ModelStreamingCallback is the type for the streaming callback of a model.
type ModelStreamingCallback = func(context.Context, *GenerateResponseChunk) error
Expand All @@ -50,7 +50,7 @@ type ModelMetadata struct {

// DefineModel registers the given generate function as an action, and returns a
// [ModelAction] that runs it.
func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) *ModelAction {
func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) *Model {
metadataMap := map[string]any{}
if metadata != nil {
if metadata.Label != "" {
Expand All @@ -64,25 +64,26 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c
}
metadataMap["supports"] = supports
}
return core.DefineStreamingAction(provider, name, atype.Model, map[string]any{
return (*Model)(core.DefineStreamingAction(provider, name, atype.Model, map[string]any{
"model": metadataMap,
}, generate)
}, generate))
}

// LookupModel looks up a [ModelAction] registered by [DefineModel].
// It returns nil if the model was not defined.
func LookupModel(provider, name string) *ModelAction {
return core.LookupActionFor[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk](atype.Model, provider, name)
func LookupModel(provider, name string) *Model {
return (*Model)(core.LookupActionFor[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk](atype.Model, provider, name))
}

// Generate applies a [ModelAction] to some input, handling tool requests.
func Generate(ctx context.Context, g *ModelAction, req *GenerateRequest, cb ModelStreamingCallback) (*GenerateResponse, error) {
// Generate applies the [Model] to some input, handling tool requests.
func (m *Model) Generate(ctx context.Context, req *GenerateRequest, cb ModelStreamingCallback) (*GenerateResponse, error) {
if err := conformOutput(req); err != nil {
return nil, err
}

a := (*core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk])(m)
for {
resp, err := g.Run(ctx, req, cb)
resp, err := a.Run(ctx, req, cb)
if err != nil {
return nil, err
}
Expand Down
14 changes: 7 additions & 7 deletions go/plugins/dotprompt/dotprompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ type Config struct {
// The prompt variant.
Variant string
// The name of the model for which the prompt is input.
// If this is non-empty, ModelAction should be nil.
Model string
// If this is non-empty, Model should be nil.
ModelName string

// The ModelAction to use.
// The Model to use.
// If this is non-nil, Model should be the empty string.
ModelAction *ai.ModelAction
Model *ai.Model

// TODO(iant): document
Tools []*ai.ToolDefinition
Expand Down Expand Up @@ -224,7 +224,7 @@ func parseFrontmatter(data []byte) (name string, c Config, rest []byte, err erro

ret := Config{
Variant: fy.Variant,
Model: fy.Model,
ModelName: fy.Model,
Tools: fy.Tools,
Candidates: fy.Candidates,
GenerationConfig: fy.Config,
Expand Down Expand Up @@ -289,10 +289,10 @@ func Define(name, templateText string, cfg Config) (*Prompt, error) {
// This may be used for testing or for direct calls not using the
// genkit action and flow mechanisms.
func New(name, templateText string, cfg Config) (*Prompt, error) {
if cfg.Model == "" && cfg.ModelAction == nil {
if cfg.ModelName == "" && cfg.Model == nil {
return nil, errors.New("dotprompt.New: config must specify either Model or ModelAction")
}
if cfg.Model != "" && cfg.ModelAction != nil {
if cfg.ModelName != "" && cfg.Model != nil {
return nil, errors.New("dotprompt.New: config must specify exactly one of Model and ModelAction")
}
hash := fmt.Sprintf("%02x", sha256.Sum256([]byte(templateText)))
Expand Down
4 changes: 2 additions & 2 deletions go/plugins/dotprompt/dotprompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ func TestPrompts(t *testing.T) {
t.Fatal(err)
}

if prompt.Model != test.model {
t.Errorf("got model %q want %q", prompt.Model, test.model)
if prompt.ModelName != test.model {
t.Errorf("got model %q want %q", prompt.ModelName, test.model)
}
if diff := cmpSchema(t, prompt.InputSchema, test.input); diff != "" {
t.Errorf("input schema mismatch (-want, +got):\n%s", diff)
Expand Down
6 changes: 3 additions & 3 deletions go/plugins/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(contex
genReq.Context = pr.Context
}

model := p.ModelAction
model := p.Model
if model == nil {
modelName := p.Model
modelName := p.ModelName
if pr.Model != "" {
modelName = pr.Model
}
Expand All @@ -207,7 +207,7 @@ func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(contex
}
}

resp, err := ai.Generate(ctx, model, genReq, cb)
resp, err := model.Generate(ctx, genReq, cb)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/dotprompt/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func testGenerate(ctx context.Context, req *ai.GenerateRequest, cb func(context.

func TestExecute(t *testing.T) {
testModel := ai.DefineModel("test", "test", nil, testGenerate)
p, err := New("TestExecute", "TestExecute", Config{ModelAction: testModel})
p, err := New("TestExecute", "TestExecute", Config{Model: testModel})
if err != nil {
t.Fatal(err)
}
Expand Down
6 changes: 3 additions & 3 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func IsKnownModel(name string) bool {
// For known models, it can be nil, or if non-nil it will override the known value.
// It must be supplied for unknown models.
// Use [IsKnownModel] to determine if a model is known.
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.ModelAction, error) {
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
Expand All @@ -110,7 +110,7 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.ModelAction, erro
}

// requires state.mu
func defineModel(name string, caps ai.ModelCapabilities) *ai.ModelAction {
func defineModel(name string, caps ai.ModelCapabilities) *ai.Model {
meta := &ai.ModelMetadata{
Label: "Google AI - " + name,
Supports: caps,
Expand Down Expand Up @@ -147,7 +147,7 @@ func defineEmbedder(name string) *ai.EmbedderAction {

// Model returns the [ai.ModelAction] with the given name.
// It returns nil if the model was not configured.
func Model(name string) *ai.ModelAction {
func Model(name string) *ai.Model {
return ai.LookupModel(provider, name)
}

Expand Down
6 changes: 3 additions & 3 deletions go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func TestLive(t *testing.T) {
},
}

resp, err := ai.Generate(ctx, model, req, nil)
resp, err := model.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -145,7 +145,7 @@ func TestLive(t *testing.T) {

out := ""
parts := 0
final, err := ai.Generate(ctx, model, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
final, err := model.Generate(ctx, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
parts++
out += c.Content[0].Text
return nil
Expand Down Expand Up @@ -184,7 +184,7 @@ func TestLive(t *testing.T) {
Tools: []*ai.ToolDefinition{toolDef},
}

resp, err := ai.Generate(ctx, model, req, nil)
resp, err := model.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions go/plugins/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ func defineModel(model ModelDefinition, serverAddress string) {
ai.DefineModel(provider, model.Name, meta, g.generate)
}

// Model returns the [ai.ModelAction] with the given name.
// Model returns the [ai.Model] with the given name.
// It returns nil if the model was not configured.
func Model(name string) *ai.ModelAction {
func Model(name string) *ai.Model {
return ai.LookupModel(provider, name)
}

Expand Down
6 changes: 3 additions & 3 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func Init(ctx context.Context, projectID, location string) error {
}

// DefineModel defines a model with the given name.
func DefineModel(name string) *ai.ModelAction {
func DefineModel(name string) *ai.Model {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
Expand Down Expand Up @@ -101,9 +101,9 @@ func DefineEmbedder(name string) *ai.EmbedderAction {
})
}

// Model returns the [ai.ModelAction] with the given name.
// Model returns the [ai.Model] with the given name.
// It returns nil if the model was not configured.
func Model(name string) *ai.ModelAction {
func Model(name string) *ai.Model {
return ai.LookupModel(provider, name)
}

Expand Down
6 changes: 3 additions & 3 deletions go/plugins/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestLive(t *testing.T) {
},
}

resp, err := ai.Generate(ctx, model, req, nil)
resp, err := model.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestLive(t *testing.T) {
out := ""
parts := 0
model := vertexai.Model(modelName)
final, err := ai.Generate(ctx, model, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
final, err := model.Generate(ctx, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
parts++
for _, p := range c.Content {
out += p.Text
Expand Down Expand Up @@ -162,7 +162,7 @@ func TestLive(t *testing.T) {
Tools: []*ai.ToolDefinition{toolDef},
}

resp, err := ai.Generate(ctx, model, req, nil)
resp, err := model.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
6 changes: 3 additions & 3 deletions go/samples/coffee-shop/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func main() {
}
simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate,
dotprompt.Config{
ModelAction: g,
Model: g,
InputSchema: r.Reflect(simpleGreetingInput{}),
OutputFormat: ai.OutputFormatText,
},
Expand Down Expand Up @@ -157,7 +157,7 @@ func main() {

greetingWithHistoryPrompt, err := dotprompt.Define("greetingWithHistory", greetingWithHistoryPromptTemplate,
dotprompt.Config{
ModelAction: g,
Model: g,
InputSchema: jsonschema.Reflect(customerTimeAndHistoryInput{}),
OutputFormat: ai.OutputFormatText,
},
Expand Down Expand Up @@ -197,7 +197,7 @@ func main() {

simpleStructuredGreetingPrompt, err := dotprompt.Define("simpleStructuredGreeting", simpleStructuredGreetingPromptTemplate,
dotprompt.Config{
ModelAction: g,
Model: g,
InputSchema: jsonschema.Reflect(simpleGreetingInput{}),
OutputFormat: ai.OutputFormatJSON,
OutputSchema: outputSchema,
Expand Down
6 changes: 3 additions & 3 deletions go/samples/menu/s01.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ import (
"github.com/firebase/genkit/go/plugins/dotprompt"
)

func setup01(ctx context.Context, g *ai.ModelAction) error {
func setup01(ctx context.Context, g *ai.Model) error {
_, err := dotprompt.Define("s01_vanillaPrompt",
`You are acting as a helpful AI assistant named "Walt" that can answer
questions about the food available on the menu at Walt's Burgers.
Customer says: ${input.question}`,
dotprompt.Config{
ModelAction: g,
Model: g,
InputSchema: menuQuestionInputSchema,
},
)
Expand Down Expand Up @@ -67,7 +67,7 @@ func setup01(ctx context.Context, g *ai.ModelAction) error {
Question:
{{question}} ?`,
dotprompt.Config{
ModelAction: g,
Model: g,
InputSchema: menuQuestionInputSchema,
OutputFormat: ai.OutputFormatText,
},
Expand Down
4 changes: 2 additions & 2 deletions go/samples/menu/s02.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func menu(ctx context.Context, input map[string]any) (map[string]any, error) {
return map[string]any{"menu": s}, nil
}

func setup02(ctx context.Context, m *ai.ModelAction) error {
func setup02(ctx context.Context, m *ai.Model) error {
ai.DefineTool(menuToolDef, nil, menu)

dataMenuPrompt, err := dotprompt.Define("s02_dataMenu",
Expand All @@ -61,7 +61,7 @@ func setup02(ctx context.Context, m *ai.ModelAction) error {
Question:
{{question}} ?`,
dotprompt.Config{
ModelAction: m,
Model: m,
InputSchema: menuQuestionInputSchema,
OutputFormat: ai.OutputFormatText,
Tools: []*ai.ToolDefinition{
Expand Down
6 changes: 3 additions & 3 deletions go/samples/menu/s03.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (ch *chatHistoryStore) Retrieve(sessionID string) chatHistory {
return ch.preamble
}

func setup03(ctx context.Context, model *ai.ModelAction) error {
func setup03(ctx context.Context, model *ai.Model) error {
chatPreamblePrompt, err := dotprompt.Define("s03_chatPreamble",
`
{{ role "user" }}
Expand All @@ -72,7 +72,7 @@ func setup03(ctx context.Context, model *ai.ModelAction) error {
{{~/each}}
Do you have any questions about the menu?`,
dotprompt.Config{
ModelAction: model,
Model: model,
InputSchema: dataMenuQuestionInputSchema,
OutputFormat: ai.OutputFormatText,
GenerationConfig: &ai.GenerationCommonConfig{
Expand Down Expand Up @@ -115,7 +115,7 @@ func setup03(ctx context.Context, model *ai.ModelAction) error {
req := &ai.GenerateRequest{
Messages: messages,
}
resp, err := ai.Generate(ctx, model, req, nil)
resp, err := model.Generate(ctx, req, nil)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions go/samples/menu/s04.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
"github.com/firebase/genkit/go/plugins/localvec"
)

func setup04(ctx context.Context, indexer *ai.Indexer, retriever *ai.Retriever, model *ai.ModelAction) error {
func setup04(ctx context.Context, indexer *ai.Indexer, retriever *ai.Retriever, model *ai.Model) error {
ragDataMenuPrompt, err := dotprompt.Define("s04_ragDataMenu",
`
You are acting as Walt, a helpful AI assistant here at the restaurant.
Expand All @@ -41,7 +41,7 @@ func setup04(ctx context.Context, indexer *ai.Indexer, retriever *ai.Retriever,
Answer this customer's question:
{{question}}?`,
dotprompt.Config{
ModelAction: model,
Model: model,
InputSchema: dataMenuQuestionInputSchema,
OutputFormat: ai.OutputFormatText,
GenerationConfig: &ai.GenerationCommonConfig{
Expand Down
6 changes: 3 additions & 3 deletions go/samples/menu/s05.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ type imageURLInput struct {
ImageURL string `json:"imageUrl"`
}

func setup05(ctx context.Context, gen, genVision *ai.ModelAction) error {
func setup05(ctx context.Context, gen, genVision *ai.Model) error {
readMenuPrompt, err := dotprompt.Define("s05_readMenu",
`
Extract _all_ of the text, in order,
from the following image of a restaurant menu.
{{media url=imageUrl}}`,
dotprompt.Config{
ModelAction: genVision,
Model: genVision,
InputSchema: jsonschema.Reflect(imageURLInput{}),
OutputFormat: ai.OutputFormatText,
GenerationConfig: &ai.GenerationCommonConfig{
Expand All @@ -62,7 +62,7 @@ func setup05(ctx context.Context, gen, genVision *ai.ModelAction) error {
{{question}}?
`,
dotprompt.Config{
ModelAction: gen,
Model: gen,
InputSchema: textMenuQuestionInputSchema,
OutputFormat: ai.OutputFormatText,
GenerationConfig: &ai.GenerationCommonConfig{
Expand Down
2 changes: 1 addition & 1 deletion go/samples/rag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func main() {
simpleQaPrompt, err := dotprompt.Define("simpleQaPrompt",
simpleQaPromptTemplate,
dotprompt.Config{
ModelAction: model,
Model: model,
InputSchema: jsonschema.Reflect(simpleQaPromptInput{}),
OutputFormat: ai.OutputFormatText,
},
Expand Down

0 comments on commit 329c99f

Please sign in to comment.