Skip to content

Commit

Permalink
genai: fix tool-using example and test (#163)
Browse files Browse the repository at this point in the history
The model/API is sensitive to our weather example and emit 400s occasionally.

Reuse the movies example from the official docs (https://ai.google.dev/gemini-api/docs/function-calling) instead

Fixes #160
  • Loading branch information
eliben committed Jul 9, 2024
1 parent f043fb2 commit e5d888e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 61 deletions.
37 changes: 19 additions & 18 deletions genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
const (
defaultModel = "gemini-1.0-pro"
modelForVision = "gemini-1.5-flash"
modelForTools = "gemini-1.5-pro-latest"
imageFile = "personWorkingOnComputer.jpg"
)

Expand Down Expand Up @@ -332,24 +333,24 @@ func TestLive(t *testing.T) {
})

t.Run("tools", func(t *testing.T) {
weatherChat := func(t *testing.T, s *Schema, fcm FunctionCallingMode) {
weatherTool := &Tool{
movieChat := func(t *testing.T, s *Schema, fcm FunctionCallingMode) {
movieTool := &Tool{
FunctionDeclarations: []*FunctionDeclaration{{
Name: "CurrentWeather",
Description: "Get the current weather in a given location",
Name: "find_theaters",
Description: "find theaters based on location and optionally movie title which is currently playing in theaters",
Parameters: s,
}},
}
model := client.GenerativeModel(defaultModel)
model := client.GenerativeModel(modelForTools)
model.SetTemperature(0)
model.Tools = []*Tool{weatherTool}
model.Tools = []*Tool{movieTool}
model.ToolConfig = &ToolConfig{
FunctionCallingConfig: &FunctionCallingConfig{
Mode: fcm,
},
}
session := model.StartChat()
res, err := session.SendMessage(ctx, Text("What is the weather like in New York?"))
res, err := session.SendMessage(ctx, Text("Which theaters in Mountain View show Barbie movie?"))
if err != nil {
t.Fatal(err)
}
Expand All @@ -364,46 +365,46 @@ func TestLive(t *testing.T) {
t.Fatalf("got %d FunctionCalls, want 1", len(funcalls))
}
funcall := funcalls[0]
if g, w := funcall.Name, weatherTool.FunctionDeclarations[0].Name; g != w {
if g, w := funcall.Name, movieTool.FunctionDeclarations[0].Name; g != w {
t.Errorf("FunctionCall.Name: got %q, want %q", g, w)
}
locArg, ok := funcall.Args["location"].(string)
if !ok {
t.Fatal(`funcall.Args["location"] is not a string`)
}
if c := "New York"; !strings.Contains(locArg, c) {
if c := "Mountain View"; !strings.Contains(locArg, c) {
t.Errorf(`FunctionCall.Args["location"]: got %q, want string containing %q`, locArg, c)
}
res, err = session.SendMessage(ctx, FunctionResponse{
Name: weatherTool.FunctionDeclarations[0].Name,
Name: movieTool.FunctionDeclarations[0].Name,
Response: map[string]any{
"weather_there": "cold",
"theater": "AMC16",
},
})
if err != nil {
t.Fatal(err)
}
checkMatch(t, responseString(res), "(it's|it is|weather) .*cold")
checkMatch(t, responseString(res), "AMC")
}
schema := &Schema{
Type: TypeObject,
Properties: map[string]*Schema{
"location": {
Type: TypeString,
Description: "The city and state, e.g. San Francisco, CA",
Description: "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616",
},
"unit": {
Type: TypeString,
Enum: []string{"celsius", "fahrenheit"},
"title": {
Type: TypeString,
Description: "Any movie title",
},
},
Required: []string{"location"},
}
t.Run("direct", func(t *testing.T) {
weatherChat(t, schema, FunctionCallingAuto)
movieChat(t, schema, FunctionCallingAuto)
})
t.Run("none", func(t *testing.T) {
weatherChat(t, schema, FunctionCallingNone)
movieChat(t, schema, FunctionCallingNone)
})
})

Expand Down
62 changes: 19 additions & 43 deletions genai/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,50 +364,36 @@ func ExampleTool() {
}
defer client.Close()

currentWeather := func(city string) string {
switch city {
case "New York, NY":
return "cold"
case "Miami, FL":
return "warm"
default:
return "unknown"
}
}

// To use functions / tools, we have to first define a schema that describes
// the function to the model. The schema is similar to OpenAPI 3.0.
//
// In this example, we create a single function that provides the model with
// a weather forecast in a given location.
schema := &genai.Schema{
Type: genai.TypeObject,
Properties: map[string]*genai.Schema{
"location": {
Type: genai.TypeString,
Description: "The city and state, for example 'San Francisco, CA'. Both city and state are mandatory",
Description: "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616",
},
"unit": {
Type: genai.TypeString,
Enum: []string{"celsius", "fahrenheit"},
"title": {
Type: genai.TypeString,
Description: "Any movie title",
},
},
Required: []string{"location"},
}

weatherTool := &genai.Tool{
movieTool := &genai.Tool{
FunctionDeclarations: []*genai.FunctionDeclaration{{
Name: "CurrentWeather",
Description: "Get the current weather in a given location",
Name: "find_theaters",
Description: "find theaters based on location and optionally movie title which is currently playing in theaters",
Parameters: schema,
}},
}

model := client.GenerativeModel("gemini-1.5-flash-latest")
model := client.GenerativeModel("gemini-1.5-pro-latest")

// Before initiating a conversation, we tell the model which tools it has
// at its disposal.
model.Tools = []*genai.Tool{weatherTool}
model.Tools = []*genai.Tool{movieTool}

// For using tools, the chat mode is useful because it provides the required
// chat context. A model needs to have tools supplied to it in the chat
Expand All @@ -417,45 +403,35 @@ func ExampleTool() {
//
// 1. We send a question to the model
// 2. The model recognizes that it needs to use a tool to answer the question,
// an returns a FunctionCall response asking to use the CurrentWeather
// tool.
// an returns a FunctionCall response asking to use the tool.
// 3. We send a FunctionResponse message, simulating the return value of
// CurrentWeather for the model's query.
// the tool for the model's query.
// 4. The model provides its text answer in response to this message.
session := model.StartChat()

res, err := session.SendMessage(ctx, genai.Text("What is the weather like in New York?"))
res, err := session.SendMessage(ctx, genai.Text("Which theaters in Mountain View show Barbie movie?"))
if err != nil {
log.Fatal(err)
log.Fatalf("session.SendMessage: %v", err)
}

part := res.Candidates[0].Content.Parts[0]
funcall, ok := part.(genai.FunctionCall)
if !ok {
log.Fatalf("expected FunctionCall: %v", part)
}

if funcall.Name != "CurrentWeather" {
log.Fatalf("expected CurrentWeather: %v", funcall.Name)
if !ok || funcall.Name != "find_theaters" {
log.Fatalf("expected FunctionCall to find_theaters: %v", part)
}

// Expect the model to pass a proper string "location" argument to the tool.
locArg, ok := funcall.Args["location"].(string)
if !ok {
if _, ok := funcall.Args["location"].(string); !ok {
log.Fatalf("expected string: %v", funcall.Args["location"])
}

weatherData := currentWeather(locArg)
// Provide the model with a hard-coded reply.
res, err = session.SendMessage(ctx, genai.FunctionResponse{
Name: weatherTool.FunctionDeclarations[0].Name,
Name: movieTool.FunctionDeclarations[0].Name,
Response: map[string]any{
"weather": weatherData,
"theater": "AMC16",
},
})
if err != nil {
log.Fatal(err)
}

printResponse(res)
}

Expand Down

0 comments on commit e5d888e

Please sign in to comment.