Skip to content

Commit

Permalink
[Go] refactor googleai generate and copy to vertexai (#650)
Browse files Browse the repository at this point in the history
- Break the googleai generation code up into separate functions
  for readability.

- Copy the code into vertexai.

- Gate the header test behind a flag. It can't be run along
  with the live test, because both call genkit.Init.
  • Loading branch information
jba authored Jul 22, 2024
1 parent 5ed5129 commit fdf3523
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 151 deletions.
171 changes: 102 additions & 69 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,13 @@ func defineModel(name string, caps ai.ModelCapabilities) *ai.Model {
Label: labelPrefix + " - " + name,
Supports: caps,
}
g := generator{model: name, client: state.gclient}
return ai.DefineModel(provider, name, meta, g.generate)
return ai.DefineModel(provider, name, meta, func(
ctx context.Context,
input *ai.GenerateRequest,
cb func(context.Context, *ai.GenerateResponseChunk) error,
) (*ai.GenerateResponse, error) {
return generate(ctx, state.gclient, name, input, cb)
})
}

// IsDefinedModel reports whether the named [Model] is defined by this plugin.
Expand All @@ -157,6 +162,8 @@ func IsDefinedModel(name string) bool {

//copy:stop

//copy:start vertexai.go defineEmbedder

// DefineEmbedder defines an embedder with a given name.
func DefineEmbedder(name string) *ai.Embedder {
state.mu.Lock()
Expand All @@ -172,6 +179,8 @@ func IsDefinedEmbedder(name string) bool {
return ai.IsDefinedEmbedder(provider, name)
}

//copy:stop

// requires state.mu
func defineEmbedder(name string) *ai.Embedder {
return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) (*ai.EmbedResponse, error) {
Expand Down Expand Up @@ -213,16 +222,85 @@ func Embedder(name string) *ai.Embedder {

//copy:stop

type generator struct {
model string
client *genai.Client
//session *genai.ChatSession // non-nil if we're in the middle of a chat
}
//copy:start vertexai.go generate

func generate(
ctx context.Context,
client *genai.Client,
model string,
input *ai.GenerateRequest,
cb func(context.Context, *ai.GenerateResponseChunk) error,
) (*ai.GenerateResponse, error) {
gm := newModel(client, model, input)
cs, err := startChat(gm, input)
if err != nil {
return nil, err
}
// The last message gets added to the parts slice.
var parts []genai.Part
if len(input.Messages) > 0 {
last := input.Messages[len(input.Messages)-1]
var err error
parts, err = convertParts(last.Content)
if err != nil {
return nil, err
}
}

gm.Tools, err = convertTools(input.Tools)
if err != nil {
return nil, err
}
// Convert input.Tools and append to gm.Tools

// TODO: gm.ToolConfig?

// Send out the actual request.
if cb == nil {
resp, err := cs.SendMessage(ctx, parts...)
if err != nil {
return nil, err
}
r := translateResponse(resp)
r.Request = input
return r, nil
}

func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
gm := g.client.GenerativeModel(g.model)
// Streaming version.
iter := cs.SendMessageStream(ctx, parts...)
var r *ai.GenerateResponse
for {
chunk, err := iter.Next()
if err == iterator.Done {
r = translateResponse(iter.MergedResponse())
break
}
if err != nil {
return nil, err
}
// Send candidates to the callback.
for _, c := range chunk.Candidates {
tc := translateCandidate(c)
err := cb(ctx, &ai.GenerateResponseChunk{
Content: tc.Message.Content,
Index: tc.Index,
})
if err != nil {
return nil, err
}
}
}
if r == nil {
// No candidates were returned. Probably rare, but it might avoid a NPE
// to return an empty instead of nil result.
r = &ai.GenerateResponse{}
}
r.Request = input
return r, nil
}

// Translate from a ai.GenerateRequest to a genai request.
func newModel(client *genai.Client, model string, input *ai.GenerateRequest) *genai.GenerativeModel {
gm := client.GenerativeModel(model)
gm.SetCandidateCount(int32(input.Candidates))
if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil {
if c.MaxOutputTokens != 0 {
Expand All @@ -241,8 +319,11 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
gm.SetTopP(float32(c.TopP))
}
}
return gm
}

// Start a "chat".
// startChat starts a chat session and configures it with the input messages.
func startChat(gm *genai.GenerativeModel, input *ai.GenerateRequest) (*genai.ChatSession, error) {
cs := gm.StartChat()

// All but the last message goes in the history field.
Expand All @@ -259,18 +340,11 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
Role: string(m.Role),
})
}
// The last message gets added to the parts slice.
var parts []genai.Part
if len(messages) > 0 {
var err error
parts, err = convertParts(messages[0].Content)
if err != nil {
return nil, err
}
}

// Convert input.Tools and append to gm.Tools
for _, t := range input.Tools {
return cs, nil
}
func convertTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) {
var outTools []*genai.Tool
for _, t := range inTools {
schema := &genai.Schema{}
schema.Type = genai.TypeObject
schema.Properties = map[string]*genai.Schema{}
Expand All @@ -286,7 +360,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
case "bool":
typ = genai.TypeBoolean
default:
return nil, fmt.Errorf("schema value \"%s\" not allowed", v)
return nil, fmt.Errorf("schema value %q not allowed", v)
}
schema.Properties[k] = &genai.Schema{Type: typ}
}
Expand All @@ -295,54 +369,13 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
Parameters: schema,
Description: t.Description,
}
gm.Tools = append(gm.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{fd}})
}
// TODO: gm.ToolConfig?

// Send out the actual request.
if cb == nil {
resp, err := cs.SendMessage(ctx, parts...)
if err != nil {
return nil, err
}
r := translateResponse(resp)
r.Request = input
return r, nil
}

// Streaming version.
iter := cs.SendMessageStream(ctx, parts...)
var r *ai.GenerateResponse
for {
chunk, err := iter.Next()
if err == iterator.Done {
r = translateResponse(iter.MergedResponse())
break
}
if err != nil {
return nil, err
}
// Send candidates to the callback.
for _, c := range chunk.Candidates {
tc := translateCandidate(c)
err := cb(ctx, &ai.GenerateResponseChunk{
Content: tc.Message.Content,
Index: tc.Index,
})
if err != nil {
return nil, err
}
}
outTools = append(outTools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{fd}})
}
if r == nil {
// No candidates were returned. Probably rare, but it might avoid a NPE
// to return an empty instead of nil result.
r = &ai.GenerateResponse{}
}
r.Request = input
return r, nil
return outTools, nil
}

//copy:stop

//copy:start vertexai.go translateCandidate

// translateCandidate translates from a genai.GenerateContentResponse to an ai.GenerateResponse.
Expand Down
5 changes: 5 additions & 0 deletions go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ import (
// The tests here only work with an API key set to a valid value.
var apiKey = flag.String("key", "", "Gemini API key")

var header = flag.Bool("header", false, "run test for x-goog-client-api header")

// We can't test the DefineAll functions along with the other tests because
// we get duplicate definitions of models.
var testAll = flag.Bool("all", false, "test DefineAllXXX functions")
Expand Down Expand Up @@ -203,6 +205,9 @@ func TestLive(t *testing.T) {
}

func TestHeader(t *testing.T) {
if !*header {
t.Skip("skipped; to run, pass -header and don't run the live test")
}
ctx := context.Background()
var header http.Header
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
Loading

0 comments on commit fdf3523

Please sign in to comment.