[Go] ai.Model is a separate type from its action #458

Merged
1 commit merged on Jun 24, 2024
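
This PR turns ai.Model into a defined type with its own Generate method, so callers invoke the model directly instead of going through the package-level ai.Generate helper. A minimal sketch of the caller-facing change, patterned on the updated tests below (the package and function names here are illustrative, not part of the change):

```go
package example

import (
	"context"

	"github.com/firebase/genkit/go/ai"
)

// callModel shows the new call shape: Generate is a method on *ai.Model.
// Before this PR the equivalent call was ai.Generate(ctx, model, req, nil).
func callModel(ctx context.Context, model *ai.Model, req *ai.GenerateRequest) (*ai.GenerateResponse, error) {
	// A nil streaming callback asks for a non-streaming generation, as in the tests.
	return model.Generate(ctx, req, nil)
}
```
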
21 changes: 11 additions & 10 deletions go/ai/generate.go
@@ -28,8 +28,8 @@ import (
"github.com/firebase/genkit/go/internal/atype"
)

// A ModelAction is used to generate content from an AI model.
type ModelAction = core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk]
// A Model is used to generate content from an AI model.
type Model core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk]

// ModelStreamingCallback is the type for the streaming callback of a model.
type ModelStreamingCallback = func(context.Context, *GenerateResponseChunk) error
@@ -50,7 +50,7 @@ type ModelMetadata struct {

// DefineModel registers the given generate function as an action, and returns a
// [ModelAction] that runs it.
func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) *ModelAction {
func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) *Model {
metadataMap := map[string]any{}
if metadata != nil {
if metadata.Label != "" {
@@ -64,25 +64,26 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c
}
metadataMap["supports"] = supports
}
return core.DefineStreamingAction(provider, name, atype.Model, map[string]any{
return (*Model)(core.DefineStreamingAction(provider, name, atype.Model, map[string]any{
"model": metadataMap,
}, generate)
}, generate))
}

// LookupModel looks up a [ModelAction] registered by [DefineModel].
// It returns nil if the model was not defined.
func LookupModel(provider, name string) *ModelAction {
return core.LookupActionFor[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk](atype.Model, provider, name)
func LookupModel(provider, name string) *Model {
return (*Model)(core.LookupActionFor[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk](atype.Model, provider, name))
}

// Generate applies a [ModelAction] to some input, handling tool requests.
func Generate(ctx context.Context, g *ModelAction, req *GenerateRequest, cb ModelStreamingCallback) (*GenerateResponse, error) {
// Generate applies the [Model] to some input, handling tool requests.
func (m *Model) Generate(ctx context.Context, req *GenerateRequest, cb ModelStreamingCallback) (*GenerateResponse, error) {
if err := conformOutput(req); err != nil {
return nil, err
}

a := (*core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk])(m)
for {
resp, err := g.Run(ctx, req, cb)
resp, err := a.Run(ctx, req, cb)
if err != nil {
return nil, err
}
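
The core of the change is that Model is now a defined type over core.Action rather than an alias, so it carries its own method set, and explicit pointer conversions move between the two representations. Below is a self-contained sketch of that Go technique using a simplified stand-in Action type (not the real core.Action):

```go
package main

import (
	"context"
	"fmt"
)

// Action is a simplified stand-in for core.Action[In, Out, Stream].
type Action[In, Out any] struct {
	run func(context.Context, In) (Out, error)
}

// Run executes the wrapped function.
func (a *Action[In, Out]) Run(ctx context.Context, in In) (Out, error) {
	return a.run(ctx, in)
}

// Model is a defined type (no "="), so it gets its own method set;
// the methods of Action do not carry over.
type Model Action[string, string]

// Generate converts back to the underlying action type to reuse Run,
// mirroring what (*ai.Model).Generate does with core.Action.
func (m *Model) Generate(ctx context.Context, prompt string) (string, error) {
	a := (*Action[string, string])(m)
	return a.Run(ctx, prompt)
}

func main() {
	a := &Action[string, string]{
		run: func(_ context.Context, in string) (string, error) {
			return "echo: " + in, nil
		},
	}
	m := (*Model)(a) // legal: Model's underlying type is Action[string, string]
	out, _ := m.Generate(context.Background(), "hi")
	fmt.Println(out)
}
```

The conversion is legal because Model and the instantiated Action share an underlying type, so the registry's generic action machinery stays unchanged while callers get a model-specific API.
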
14 changes: 7 additions & 7 deletions go/plugins/dotprompt/dotprompt.go
@@ -81,12 +81,12 @@ type Config struct {
// The prompt variant.
Variant string
// The name of the model for which the prompt is input.
// If this is non-empty, ModelAction should be nil.
Model string
// If this is non-empty, Model should be nil.
ModelName string

// The ModelAction to use.
// The Model to use.
// If this is non-nil, Model should be the empty string.
ModelAction *ai.ModelAction
Model *ai.Model

// TODO(iant): document
Tools []*ai.ToolDefinition
@@ -224,7 +224,7 @@ func parseFrontmatter(data []byte) (name string, c Config, rest []byte, err erro

ret := Config{
Variant: fy.Variant,
Model: fy.Model,
ModelName: fy.Model,
Tools: fy.Tools,
Candidates: fy.Candidates,
GenerationConfig: fy.Config,
@@ -289,10 +289,10 @@ func Define(name, templateText string, cfg Config) (*Prompt, error) {
// This may be used for testing or for direct calls not using the
// genkit action and flow mechanisms.
func New(name, templateText string, cfg Config) (*Prompt, error) {
if cfg.Model == "" && cfg.ModelAction == nil {
if cfg.ModelName == "" && cfg.Model == nil {
return nil, errors.New("dotprompt.New: config must specify either Model or ModelAction")
}
if cfg.Model != "" && cfg.ModelAction != nil {
if cfg.ModelName != "" && cfg.Model != nil {
return nil, errors.New("dotprompt.New: config must specify exactly one of Model and ModelAction")
}
hash := fmt.Sprintf("%02x", sha256.Sum256([]byte(templateText)))
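
For dotprompt users, the string field Model is renamed to ModelName and the ModelAction field becomes Model (*ai.Model). A hedged sketch of how a caller adapts, assuming a *ai.Model obtained from ai.DefineModel or a plugin lookup; the prompt name and template are made up for illustration:

```go
package example

import (
	"github.com/firebase/genkit/go/ai"
	"github.com/firebase/genkit/go/plugins/dotprompt"
)

// definePrompt passes the model value itself; before this PR the field was
// called ModelAction, and the string field Model is now ModelName.
func definePrompt(m *ai.Model) (*dotprompt.Prompt, error) {
	return dotprompt.Define("exampleGreeting", "Say hello to {{name}}.", dotprompt.Config{
		Model:        m,                   // was: ModelAction: m
		OutputFormat: ai.OutputFormatText, // plain-text output, as in the samples
	})
}
```
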
4 changes: 2 additions & 2 deletions go/plugins/dotprompt/dotprompt_test.go
@@ -110,8 +110,8 @@ func TestPrompts(t *testing.T) {
t.Fatal(err)
}

if prompt.Model != test.model {
t.Errorf("got model %q want %q", prompt.Model, test.model)
if prompt.ModelName != test.model {
t.Errorf("got model %q want %q", prompt.ModelName, test.model)
}
if diff := cmpSchema(t, prompt.InputSchema, test.input); diff != "" {
t.Errorf("input schema mismatch (-want, +got):\n%s", diff)
6 changes: 3 additions & 3 deletions go/plugins/dotprompt/genkit.go
@@ -187,9 +187,9 @@ func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(contex
genReq.Context = pr.Context
}

model := p.ModelAction
model := p.Model
if model == nil {
modelName := p.Model
modelName := p.ModelName
if pr.Model != "" {
modelName = pr.Model
}
@@ -207,7 +207,7 @@ func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(contex
}
}

resp, err := ai.Generate(ctx, model, genReq, cb)
resp, err := model.Generate(ctx, genReq, cb)
if err != nil {
return nil, err
}
2 changes: 1 addition & 1 deletion go/plugins/dotprompt/genkit_test.go
@@ -43,7 +43,7 @@ func testGenerate(ctx context.Context, req *ai.GenerateRequest, cb func(context.

func TestExecute(t *testing.T) {
testModel := ai.DefineModel("test", "test", nil, testGenerate)
p, err := New("TestExecute", "TestExecute", Config{ModelAction: testModel})
p, err := New("TestExecute", "TestExecute", Config{Model: testModel})
if err != nil {
t.Fatal(err)
}
6 changes: 3 additions & 3 deletions go/plugins/googleai/googleai.go
@@ -90,7 +90,7 @@ func IsKnownModel(name string) bool {
// For known models, it can be nil, or if non-nil it will override the known value.
// It must be supplied for unknown models.
// Use [IsKnownModel] to determine if a model is known.
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.ModelAction, error) {
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
@@ -110,7 +110,7 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.ModelAction, erro
}

// requires state.mu
func defineModel(name string, caps ai.ModelCapabilities) *ai.ModelAction {
func defineModel(name string, caps ai.ModelCapabilities) *ai.Model {
meta := &ai.ModelMetadata{
Label: "Google AI - " + name,
Supports: caps,
@@ -147,7 +147,7 @@ func defineEmbedder(name string) *ai.EmbedderAction {

// Model returns the [ai.ModelAction] with the given name.
// It returns nil if the model was not configured.
func Model(name string) *ai.ModelAction {
func Model(name string) *ai.Model {
return ai.LookupModel(provider, name)
}

6 changes: 3 additions & 3 deletions go/plugins/googleai/googleai_test.go
@@ -116,7 +116,7 @@ func TestLive(t *testing.T) {
},
}

resp, err := ai.Generate(ctx, model, req, nil)
resp, err := model.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}
@@ -145,7 +145,7 @@ func TestLive(t *testing.T) {

out := ""
parts := 0
final, err := ai.Generate(ctx, model, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
final, err := model.Generate(ctx, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
parts++
out += c.Content[0].Text
return nil
@@ -184,7 +184,7 @@ func TestLive(t *testing.T) {
Tools: []*ai.ToolDefinition{toolDef},
}

resp, err := ai.Generate(ctx, model, req, nil)
resp, err := model.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}
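
Streaming follows the same pattern: the chunk callback is now passed to the Generate method on the model. A small sketch modeled on the live tests above (the accumulate helper is illustrative):

```go
package example

import (
	"context"

	"github.com/firebase/genkit/go/ai"
)

// accumulate streams a generation and concatenates the text of every chunk,
// mirroring how the live tests consume streaming output.
func accumulate(ctx context.Context, model *ai.Model, req *ai.GenerateRequest) (string, *ai.GenerateResponse, error) {
	out := ""
	final, err := model.Generate(ctx, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
		for _, p := range c.Content {
			out += p.Text
		}
		return nil
	})
	return out, final, err
}
```
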
4 changes: 2 additions & 2 deletions go/plugins/ollama/ollama.go
@@ -49,9 +49,9 @@ func defineModel(model ModelDefinition, serverAddress string) {
ai.DefineModel(provider, model.Name, meta, g.generate)
}

// Model returns the [ai.ModelAction] with the given name.
// Model returns the [ai.Model] with the given name.
// It returns nil if the model was not configured.
func Model(name string) *ai.ModelAction {
func Model(name string) *ai.Model {
return ai.LookupModel(provider, name)
}

6 changes: 3 additions & 3 deletions go/plugins/vertexai/vertexai.go
@@ -72,7 +72,7 @@ func Init(ctx context.Context, projectID, location string) error {
}

// DefineModel defines a model with the given name.
func DefineModel(name string) *ai.ModelAction {
func DefineModel(name string) *ai.Model {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
@@ -101,9 +101,9 @@ func DefineEmbedder(name string) *ai.EmbedderAction {
})
}

// Model returns the [ai.ModelAction] with the given name.
// Model returns the [ai.Model] with the given name.
// It returns nil if the model was not configured.
func Model(name string) *ai.ModelAction {
func Model(name string) *ai.Model {
return ai.LookupModel(provider, name)
}

6 changes: 3 additions & 3 deletions go/plugins/vertexai/vertexai_test.go
@@ -91,7 +91,7 @@ func TestLive(t *testing.T) {
},
}

resp, err := ai.Generate(ctx, model, req, nil)
resp, err := model.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}
@@ -120,7 +120,7 @@ func TestLive(t *testing.T) {
out := ""
parts := 0
model := vertexai.Model(modelName)
final, err := ai.Generate(ctx, model, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
final, err := model.Generate(ctx, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
parts++
for _, p := range c.Content {
out += p.Text
@@ -162,7 +162,7 @@ func TestLive(t *testing.T) {
Tools: []*ai.ToolDefinition{toolDef},
}

resp, err := ai.Generate(ctx, model, req, nil)
resp, err := model.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}
6 changes: 3 additions & 3 deletions go/samples/coffee-shop/main.go
@@ -119,7 +119,7 @@ func main() {
}
simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate,
dotprompt.Config{
ModelAction: g,
Model: g,
InputSchema: r.Reflect(simpleGreetingInput{}),
OutputFormat: ai.OutputFormatText,
},
@@ -157,7 +157,7 @@ func main() {

greetingWithHistoryPrompt, err := dotprompt.Define("greetingWithHistory", greetingWithHistoryPromptTemplate,
dotprompt.Config{
ModelAction: g,
Model: g,
InputSchema: jsonschema.Reflect(customerTimeAndHistoryInput{}),
OutputFormat: ai.OutputFormatText,
},
@@ -197,7 +197,7 @@ func main() {

simpleStructuredGreetingPrompt, err := dotprompt.Define("simpleStructuredGreeting", simpleStructuredGreetingPromptTemplate,
dotprompt.Config{
ModelAction: g,
Model: g,
InputSchema: jsonschema.Reflect(simpleGreetingInput{}),
OutputFormat: ai.OutputFormatJSON,
OutputSchema: outputSchema,
6 changes: 3 additions & 3 deletions go/samples/menu/s01.go
@@ -21,13 +21,13 @@ import (
"github.com/firebase/genkit/go/plugins/dotprompt"
)

func setup01(ctx context.Context, g *ai.ModelAction) error {
func setup01(ctx context.Context, g *ai.Model) error {
_, err := dotprompt.Define("s01_vanillaPrompt",
`You are acting as a helpful AI assistant named "Walt" that can answer
questions about the food available on the menu at Walt's Burgers.
Customer says: ${input.question}`,
dotprompt.Config{
ModelAction: g,
Model: g,
InputSchema: menuQuestionInputSchema,
},
)
@@ -67,7 +67,7 @@ func setup01(ctx context.Context, g *ai.ModelAction) error {
Question:
{{question}} ?`,
dotprompt.Config{
ModelAction: g,
Model: g,
InputSchema: menuQuestionInputSchema,
OutputFormat: ai.OutputFormatText,
},
4 changes: 2 additions & 2 deletions go/samples/menu/s02.go
@@ -46,7 +46,7 @@ func menu(ctx context.Context, input map[string]any) (map[string]any, error) {
return map[string]any{"menu": s}, nil
}

func setup02(ctx context.Context, m *ai.ModelAction) error {
func setup02(ctx context.Context, m *ai.Model) error {
ai.DefineTool(menuToolDef, nil, menu)

dataMenuPrompt, err := dotprompt.Define("s02_dataMenu",
@@ -61,7 +61,7 @@ func setup02(ctx context.Context, m *ai.ModelAction) error {
Question:
{{question}} ?`,
dotprompt.Config{
ModelAction: m,
Model: m,
InputSchema: menuQuestionInputSchema,
OutputFormat: ai.OutputFormatText,
Tools: []*ai.ToolDefinition{
6 changes: 3 additions & 3 deletions go/samples/menu/s03.go
@@ -55,7 +55,7 @@ func (ch *chatHistoryStore) Retrieve(sessionID string) chatHistory {
return ch.preamble
}

func setup03(ctx context.Context, model *ai.ModelAction) error {
func setup03(ctx context.Context, model *ai.Model) error {
chatPreamblePrompt, err := dotprompt.Define("s03_chatPreamble",
`
{{ role "user" }}
@@ -72,7 +72,7 @@ func setup03(ctx context.Context, model *ai.ModelAction) error {
{{~/each}}
Do you have any questions about the menu?`,
dotprompt.Config{
ModelAction: model,
Model: model,
InputSchema: dataMenuQuestionInputSchema,
OutputFormat: ai.OutputFormatText,
GenerationConfig: &ai.GenerationCommonConfig{
@@ -115,7 +115,7 @@ func setup03(ctx context.Context, model *ai.ModelAction) error {
req := &ai.GenerateRequest{
Messages: messages,
}
resp, err := ai.Generate(ctx, model, req, nil)
resp, err := model.Generate(ctx, req, nil)
if err != nil {
return nil, err
}
4 changes: 2 additions & 2 deletions go/samples/menu/s04.go
@@ -24,7 +24,7 @@ import (
"github.com/firebase/genkit/go/plugins/localvec"
)

func setup04(ctx context.Context, indexer *ai.Indexer, retriever *ai.Retriever, model *ai.ModelAction) error {
func setup04(ctx context.Context, indexer *ai.Indexer, retriever *ai.Retriever, model *ai.Model) error {
ragDataMenuPrompt, err := dotprompt.Define("s04_ragDataMenu",
`
You are acting as Walt, a helpful AI assistant here at the restaurant.
@@ -41,7 +41,7 @@ func setup04(ctx context.Context, indexer *ai.Indexer, retriever *ai.Retriever,
Answer this customer's question:
{{question}}?`,
dotprompt.Config{
ModelAction: model,
Model: model,
InputSchema: dataMenuQuestionInputSchema,
OutputFormat: ai.OutputFormatText,
GenerationConfig: &ai.GenerationCommonConfig{
6 changes: 3 additions & 3 deletions go/samples/menu/s05.go
@@ -29,15 +29,15 @@ type imageURLInput struct {
ImageURL string `json:"imageUrl"`
}

func setup05(ctx context.Context, gen, genVision *ai.ModelAction) error {
func setup05(ctx context.Context, gen, genVision *ai.Model) error {
readMenuPrompt, err := dotprompt.Define("s05_readMenu",
`
Extract _all_ of the text, in order,
from the following image of a restaurant menu.

{{media url=imageUrl}}`,
dotprompt.Config{
ModelAction: genVision,
Model: genVision,
InputSchema: jsonschema.Reflect(imageURLInput{}),
OutputFormat: ai.OutputFormatText,
GenerationConfig: &ai.GenerationCommonConfig{
@@ -62,7 +62,7 @@ func setup05(ctx context.Context, gen, genVision *ai.ModelAction) error {
{{question}}?
`,
dotprompt.Config{
ModelAction: gen,
Model: gen,
InputSchema: textMenuQuestionInputSchema,
OutputFormat: ai.OutputFormatText,
GenerationConfig: &ai.GenerationCommonConfig{
2 changes: 1 addition & 1 deletion go/samples/rag/main.go
@@ -95,7 +95,7 @@ func main() {
simpleQaPrompt, err := dotprompt.Define("simpleQaPrompt",
simpleQaPromptTemplate,
dotprompt.Config{
ModelAction: model,
Model: model,
InputSchema: jsonschema.Reflect(simpleQaPromptInput{}),
OutputFormat: ai.OutputFormatText,
},