Skip to content

Commit

Permalink
[Go] add tools support to vertexai plugin (#156)
Browse files Browse the repository at this point in the history
  • Loading branch information
ianlancetaylor authored May 14, 2024
1 parent 69b0a7c commit 7910210
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 6 deletions.
63 changes: 57 additions & 6 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package vertexai

import (
"context"
"fmt"

"cloud.google.com/go/vertexai/genai"
"github.com/google/genkit/go/ai"
Expand All @@ -39,7 +40,7 @@ func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb

// Translate from a ai.GenerateRequest to a genai request.
gm.SetCandidateCount(int32(input.Candidates))
if c, ok := input.Config.(*ai.GenerationCommonConfig); ok {
if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil {
gm.SetMaxOutputTokens(int32(c.MaxOutputTokens))
gm.StopSequences = c.StopSequences
gm.SetTemperature(float32(c.Temperature))
Expand All @@ -65,7 +66,40 @@ func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb
if len(messages) > 0 {
parts = convertParts(messages[0].Content)
}
//TODO: convert input.Tools and append to gm.Tools

// Convert input.Tools and append to gm.Tools.
for _, t := range input.Tools {
schema := &genai.Schema{
Type: genai.TypeObject,
Properties: make(map[string]*genai.Schema),
}
for k, v := range t.InputSchema {
typ := genai.TypeUnspecified
switch v {
case "string":
typ = genai.TypeString
case "float64":
typ = genai.TypeNumber
case "int":
typ = genai.TypeInteger
case "bool":
typ = genai.TypeBoolean
default:
return nil, fmt.Errorf("schema value %q not supported", v)
}
schema.Properties[k] = &genai.Schema{Type: typ}
}

fd := &genai.FunctionDeclaration{
Name: t.Name,
Parameters: schema,
}

gm.Tools = append(gm.Tools, &genai.Tool{
FunctionDeclarations: []*genai.FunctionDeclaration{fd},
})
}
// TODO: gm.ToolConfig?

// Send out the actual request.
resp, err := cs.SendMessage(ctx, parts...)
Expand Down Expand Up @@ -103,10 +137,13 @@ func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb
p = ai.NewTextPart(string(part))
case genai.Blob:
p = ai.NewBlobPart(part.MIMEType, string(part.Data))
case genai.FunctionResponse:
p = ai.NewBlobPart("TODO", string(part.Name))
case genai.FunctionCall:
p = ai.NewToolRequestPart(&ai.ToolRequest{
Name: part.Name,
Input: part.Args,
})
default:
panic("unknown part type")
panic(fmt.Sprintf("unknown part #%v", part))
}
m.Content = append(m.Content, p)
}
Expand Down Expand Up @@ -159,7 +196,21 @@ func convertPart(p *ai.Part) genai.Part {
switch {
case p.IsText():
return genai.Text(p.Text())
default:
case p.IsBlob():
return genai.Blob{MIMEType: p.ContentType(), Data: []byte(p.Text())}
case p.IsToolResponse():
toolResp := p.ToolResponse()
return genai.FunctionResponse{
Name: toolResp.Name,
Response: toolResp.Output,
}
case p.IsToolRequest():
toolReq := p.ToolRequest()
return genai.FunctionCall{
Name: toolReq.Name,
Args: toolReq.Input,
}
default:
panic("unknown part type in a request")
}
}
76 changes: 76 additions & 0 deletions go/plugins/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package vertexai_test
import (
"context"
"flag"
"math"
"strings"
"testing"

Expand Down Expand Up @@ -57,4 +58,79 @@ func TestGenerator(t *testing.T) {
if !strings.Contains(out, "France") {
t.Errorf("got \"%s\", expecting it would contain \"France\"", out)
}
if resp.Request != req {
t.Error("Request field not set properly")
}
}

func TestGeneratorTool(t *testing.T) {
if *projectID == "" {
t.Skip("no -projectid provided")
}
ctx := context.Background()
g, err := vertexai.NewGenerator(ctx, "gemini-1.0-pro", *projectID, *location)
if err != nil {
t.Fatal(err)
}
req := &ai.GenerateRequest{
Candidates: 1,
Messages: []*ai.Message{
&ai.Message{
Content: []*ai.Part{ai.NewTextPart("what is 3.5 squared? Use the tool provided.")},
Role: ai.RoleUser,
},
},
Tools: []*ai.ToolDefinition{
&ai.ToolDefinition{
Name: "exponentiation",
InputSchema: map[string]any{"base": "float64", "exponent": "int"},
OutputSchema: map[string]any{"output": "float64"},
},
},
}

resp, err := g.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}
p := resp.Candidates[0].Message.Content[0]
if !p.IsToolRequest() {
t.Fatalf("tool not requested")
}
toolReq := p.ToolRequest()
if toolReq.Name != "exponentiation" {
t.Errorf("tool name is %q, want \"exponentiation\"", toolReq.Name)
}
if toolReq.Input["base"] != 3.5 {
t.Errorf("base is %f, want 3.5", toolReq.Input["base"])
}
if toolReq.Input["exponent"] != 2 && toolReq.Input["exponent"] != 2.0 {
// Note: 2.0 is wrong given the schema, but Gemini returns a float anyway.
t.Errorf("exponent is %f, want 2", toolReq.Input["exponent"])
}

// Update our conversation with the tool request the model made and our tool response.
// (Our "tool" is just math.Pow.)
req.Messages = append(req.Messages,
resp.Candidates[0].Message,
&ai.Message{
Content: []*ai.Part{ai.NewToolResponsePart(&ai.ToolResponse{
Name: "exponentiation",
Output: map[string]any{"output": math.Pow(3.5, 2)},
})},
Role: ai.RoleTool,
},
)

// Issue our request again.
resp, err = g.Generate(ctx, req, nil)
if err != nil {
t.Fatal(err)
}

// Check final response.
out := resp.Candidates[0].Message.Content[0].Text()
if !strings.Contains(out, "12.25") {
t.Errorf("got %s, expecting it to contain \"12.25\"", out)
}
}

0 comments on commit 7910210

Please sign in to comment.