Skip to content

Commit

Permalink
[Go] add coffee-shop sample (#23)
Browse files Browse the repository at this point in the history
Also add some helper methods to simplify the sample code.
  • Loading branch information
ianlancetaylor authored May 3, 2024
1 parent 5126b42 commit bb9d212
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 0 deletions.
27 changes: 27 additions & 0 deletions go/ai/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ package ai

import (
"context"
"errors"
"fmt"
"strings"

"github.com/google/genkit/go/genkit"
)
Expand Down Expand Up @@ -61,3 +63,28 @@ type generatorAction struct {
func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb genkit.NoStream) (*GenerateResponse, error) {
return ga.action.Run(ctx, input, cb)
}

// Text returns the contents of the first candidate in a
// [GenerateResponse] as a string. It returns an error if there
// are no candidates or if the candidate has no message.
func (gr *GenerateResponse) Text() (string, error) {
if len(gr.Candidates) == 0 {
return "", errors.New("no candidates returned")
}
msg := gr.Candidates[0].Message
if msg == nil {
return "", errors.New("candidate with no message")
}
if len(msg.Content) == 0 {
return "", errors.New("candidate message has no content")
}
if len(msg.Content) == 1 {
return msg.Content[0].Text(), nil
} else {
var sb strings.Builder
for _, p := range msg.Content {
sb.WriteString(p.Text())
}
return sb.String(), nil
}
}
50 changes: 50 additions & 0 deletions go/genkit/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package dotprompt
import (
"context"
"errors"
"reflect"
"strings"

"github.com/google/genkit/go/ai"
"github.com/google/genkit/go/genkit"
Expand All @@ -37,6 +39,54 @@ type ActionInput struct {
Model string `json:"model,omitempty"`
}

// BuildVariables returns a map for [ActionInput.Variables] based
// on a pointer to a struct value. The struct value should have
// JSON tags that correspond to the Prompt's input schema.
// Only exported fields of the struct will be used.
func (p *Prompt) BuildVariables(input any) (map[string]any, error) {
v := reflect.ValueOf(input).Elem()
if v.Kind() != reflect.Struct {
return nil, errors.New("BuildVariables: not a pointer to a struct")
}
vt := v.Type()

// TODO(ianlancetaylor): Verify the struct with p.Frontmatter.Schema.

m := make(map[string]any)

fieldLoop:
for i := 0; i < vt.NumField(); i++ {
ft := vt.Field(i)
if ft.PkgPath != "" {
continue
}

jsonTag := ft.Tag.Get("json")
jsonName, rest, _ := strings.Cut(jsonTag, ",")
if jsonName == "" {
jsonName = ft.Name
}

vf := v.Field(i)

// If the field is the zero value, and omitempty is set,
// don't pass it as a prompt input variable.
if vf.IsZero() {
for rest != "" {
var key string
key, rest, _ = strings.Cut(rest, ",")
if key == "omitempty" {
continue fieldLoop
}
}
}

m[jsonName] = vf.Interface()
}

return m, nil
}

// buildRequest prepares an [ai.GenerateRequest] based on the prompt,
// using the input variables and other information in the [ActionInput].
func (p *Prompt) buildRequest(input *ActionInput) (*ai.GenerateRequest, error) {
Expand Down
192 changes: 192 additions & 0 deletions go/samples/coffee-shop/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This program can be manually tested like so:
// Start the server listening on port 3100:
//
// go run . &
//
// Tell it to run an action:
//
// curl -d '{"key":"/flow/testAllCoffeeFlows/testAllCoffeeFlows", "input":{"start": {"input":null}}}' http://localhost:3100/api/runAction
package main

import (
"context"
"fmt"
"log"
"os"

"github.com/google/genkit/go/ai"
"github.com/google/genkit/go/genkit"
"github.com/google/genkit/go/genkit/dotprompt"
"github.com/google/genkit/go/plugins/googleai"
"github.com/invopop/jsonschema"
)

const simpleGreetingPromptTemplate = `
You're a barista at a nice coffee shop.
A regular customer named {{customerName}} enters.
Greet the customer in one sentence, and recommend a coffee drink.
`

type simpleGreetingInput struct {
CustomerName string `json:"customerName"`
}

const greetingWithHistoryPromptTemplate = `
{{role "user"}}
Hi, my name is {{customerName}}. The time is {{currentTime}}. Who are you?
{{role "model"}}
I am Barb, a barista at this nice underwater-themed coffee shop called Krabby Kooffee.
I know pretty much everything there is to know about coffee,
and I can cheerfully recommend delicious coffee drinks to you based on whatever you like.
{{role "user"}}
Great. Last time I had {{previousOrder}}.
I want you to greet me in one sentence, and recommend a drink.
`

type customerTimeAndHistoryInput struct {
CustomerName string `json:"customerName"`
CurrentTime string `json:"currentTime"`
PreviousOrder string `json:"previousOrder"`
}

type testAllCoffeeFlowsOutput struct {
Pass bool `json:"pass"`
Replies []string `json:"replies,omitempty"`
Error string `json:"error,omitempty"`
}

func main() {
apiKey := os.Getenv("GEMINI_API_KEY")
if apiKey == "" {
fmt.Fprintln(os.Stderr, "coffee-shop example requires setting GEMINI_API_KEY in the environment.")
fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.")
os.Exit(1)
}

if err := googleai.Init(context.Background(), "gemini-pro", apiKey); err != nil {
log.Fatal(err)
}

simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting",
&dotprompt.Frontmatter{
Name: "simpleGreeting",
Model: "google-genai",
Input: dotprompt.FrontmatterInput{
Schema: jsonschema.Reflect(simpleGreetingInput{}),
},
Output: &ai.GenerateRequestOutput{
Format: ai.OutputFormatText,
},
},
simpleGreetingPromptTemplate,
nil,
)
if err != nil {
log.Fatal(err)
}

simpleGreetingFlow := genkit.DefineFlow("simpleGreeting", func(ctx context.Context, input *simpleGreetingInput, _ genkit.NoStream) (string, error) {
vars, err := simpleGreetingPrompt.BuildVariables(input)
if err != nil {
return "", err
}
ai := &dotprompt.ActionInput{ Variables: vars }
resp, err := simpleGreetingPrompt.Execute(ctx, ai)
if err != nil {
return "", err
}
text, err := resp.Text()
if err != nil {
return "", fmt.Errorf("simpleGreeting: %v", err)
}
return text, nil
})

greetingWithHistoryPrompt, err := dotprompt.Define("greetingWithHistory",
&dotprompt.Frontmatter{
Name: "greetingWithHistory",
Model: "google-genai",
Input: dotprompt.FrontmatterInput{
Schema: jsonschema.Reflect(customerTimeAndHistoryInput{}),
},
Output: &ai.GenerateRequestOutput{
Format: ai.OutputFormatText,
},
},
greetingWithHistoryPromptTemplate,
nil,
)
if err != nil {
log.Fatal(err)
}

greetingWithHistoryFlow := genkit.DefineFlow("greetingWithHistory", func(ctx context.Context, input *customerTimeAndHistoryInput, _ genkit.NoStream) (string, error) {
vars, err := greetingWithHistoryPrompt.BuildVariables(input)
if err != nil {
return "", err
}
ai := &dotprompt.ActionInput{ Variables: vars }
resp, err := greetingWithHistoryPrompt.Execute(ctx, ai)
if err != nil {
return "", err
}
text, err := resp.Text()
if err != nil {
return "", fmt.Errorf("greetingWithHistory: %v", err)
}
return text, nil
})

genkit.DefineFlow("testAllCoffeeFlows", func(ctx context.Context, _ struct{}, _ genkit.NoStream) (*testAllCoffeeFlowsOutput, error) {
test1, err := genkit.RunFlow(ctx, simpleGreetingFlow, &simpleGreetingInput{
CustomerName: "Sam",
})
if err != nil {
out := &testAllCoffeeFlowsOutput{
Pass: false,
Error: err.Error(),
}
return out, nil
}
test2, err := genkit.RunFlow(ctx, greetingWithHistoryFlow, &customerTimeAndHistoryInput{
CustomerName: "Sam",
CurrentTime: "09:45am",
PreviousOrder: "Caramel Macchiato",
})
if err != nil {
out := &testAllCoffeeFlowsOutput{
Pass: false,
Error: err.Error(),
}
return out, nil
}
out := &testAllCoffeeFlowsOutput{
Pass: true,
Replies: []string{
test1,
test2,
},
}
return out, nil
})

if err := genkit.StartDevServer(""); err != nil {
log.Fatal(err)
}
}

0 comments on commit bb9d212

Please sign in to comment.