Skip to content

Commit

Permalink
implement vertexai streaming
Browse files Browse the repository at this point in the history
Fixes #344
  • Loading branch information
randall77 committed Jun 10, 2024
1 parent 9d5bc47 commit 3309d31
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 29 deletions.
8 changes: 4 additions & 4 deletions go/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ module github.com/firebase/genkit/go
go 1.22.0

require (
cloud.google.com/go/aiplatform v1.66.0
cloud.google.com/go/aiplatform v1.67.0
cloud.google.com/go/logging v1.9.0
cloud.google.com/go/vertexai v0.7.1
cloud.google.com/go/vertexai v0.10.0
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.22.0
github.com/aymerick/raymond v2.0.2+incompatible
Expand All @@ -22,15 +22,15 @@ require (
go.opentelemetry.io/otel/sdk/metric v1.26.0
go.opentelemetry.io/otel/trace v1.26.0
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81
google.golang.org/api v0.178.0
google.golang.org/api v0.180.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1
)

require (
cloud.google.com/go v0.113.0 // indirect
cloud.google.com/go/ai v0.5.0 // indirect
cloud.google.com/go/auth v0.4.0 // indirect
cloud.google.com/go/auth v0.4.1 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
cloud.google.com/go/compute/metadata v0.3.0 // indirect
cloud.google.com/go/iam v1.1.7 // indirect
Expand Down
8 changes: 8 additions & 0 deletions go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ cloud.google.com/go/ai v0.5.0 h1:x8s4rDn5t9OVZvBCgtr5bZTH5X0O7JdE6zYo+O+MpRw=
cloud.google.com/go/ai v0.5.0/go.mod h1:96VBphk70e0zdXZrbtgPuKYRZsQ3UktSUXhuojwiKA8=
cloud.google.com/go/aiplatform v1.66.0 h1:bbFYY4JInclG10czRFUYj2rjD+obhh3Gi9zVlyoMgEc=
cloud.google.com/go/aiplatform v1.66.0/go.mod h1:bPQS0UjaXaTAq57UgP3XWDCtYFOIbXXpkMsl6uP4JAc=
cloud.google.com/go/aiplatform v1.67.0 h1:YWeqD4BjYwrmY4fa+isGcw0P81lJ3dKVxbWxdBchoiU=
cloud.google.com/go/aiplatform v1.67.0/go.mod h1:s/sJ6btBEr6bKnrNWdK9ZgHCvwbZNdP90b3DDtxxw+Y=
cloud.google.com/go/auth v0.4.0 h1:vcJWEguhY8KuiHoSs/udg1JtIRYm3YAWPBE1moF1m3U=
cloud.google.com/go/auth v0.4.0/go.mod h1:tO/chJN3obc5AbRYFQDsuFbL4wW5y8LfbPtDCfgwOVE=
cloud.google.com/go/auth v0.4.1 h1:Z7YNIhlWRtrnKlZke7z3GMqzvuYzdc2z98F9D1NV5Hg=
cloud.google.com/go/auth v0.4.1/go.mod h1:QVBuVEKpCn4Zp58hzRGvL0tjRGU0YqdRTdCHM1IHnro=
cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
Expand All @@ -23,6 +27,8 @@ cloud.google.com/go/trace v1.10.6 h1:XF0Ejdw0NpRfAvuZUeQe3ClAG4R/9w5JYICo7l2weaw
cloud.google.com/go/trace v1.10.6/go.mod h1:EABXagUjxGuKcZMy4pXyz0fJpE5Ghog3jzTxcEsVJS4=
cloud.google.com/go/vertexai v0.7.1 h1:CSdqsEwjklLIlI1e5SrsnkwG/I+CeJekkBbMTzeYhVg=
cloud.google.com/go/vertexai v0.7.1/go.mod h1:HfnfYR9aPS+qF2436S6Hzuw0Fp+PORjzK3ggqymdzSU=
cloud.google.com/go/vertexai v0.10.0 h1:k157bLrtyajGtAAZnqdEn8lwFlUTG3BgHc7kvWbP/3s=
cloud.google.com/go/vertexai v0.10.0/go.mod h1:w/Zb22QvOVvxx5CGM4fPzH3WA6gwUkId9juA7pigzFI=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0 h1:n3T26hyfDl9RdgcOjWvOFMh1lCBNuZ0JQ/3DM5pou8Y=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0/go.mod h1:3S7qK2nHOO2cLID3xk6H8f55D38XswhVFzKEk0nqIbY=
Expand Down Expand Up @@ -186,6 +192,8 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.178.0 h1:yoW/QMI4bRVCHF+NWOTa4cL8MoWL3Jnuc7FlcFF91Ok=
google.golang.org/api v0.178.0/go.mod h1:84/k2v8DFpDRebpGcooklv/lais3MEfqpaBLA12gl2U=
google.golang.org/api v0.180.0 h1:M2D87Yo0rGBPWpo1orwfCLehUUL6E7/TYe5gvMQWDh4=
google.golang.org/api v0.180.0/go.mod h1:51AiyoEg1MJPSZ9zvklA8VnRILPXxn1iVen9v25XHAE=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
Expand Down
115 changes: 90 additions & 25 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"errors"
"fmt"
"reflect"
"runtime"

aiplatform "cloud.google.com/go/aiplatform/apiv1"
Expand Down Expand Up @@ -125,9 +126,6 @@ type generator struct {
}

func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
if cb != nil {
panic("streaming not supported yet") // TODO: streaming
}
gm := g.client.GenerativeModel(g.model)

// Translate from a ai.GenerateRequest to a genai request.
Expand All @@ -143,7 +141,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
gm.SetTemperature(float32(c.Temperature))
}
if c.TopK != 0 {
gm.SetTopK(float32(c.TopK))
gm.SetTopK(int32(c.TopK))
}
if c.TopP != 0 {
gm.SetTopP(float32(c.TopP))
Expand Down Expand Up @@ -213,13 +211,77 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
// TODO: gm.ToolConfig?

// Send out the actual request.
resp, err := cs.SendMessage(ctx, parts...)
if err != nil {
return nil, err
if cb == nil {
resp, err := cs.SendMessage(ctx, parts...)
if err != nil {
return nil, err
}

r := translateResponse(resp)
r.Request = input
return r, nil
}

r := translateResponse(resp)
r.Request = input
// Streaming version.
iter := cs.SendMessageStream(ctx, parts...)
r := &ai.GenerateResponse{Request: input, Candidates: make([]*ai.Candidate, input.Candidates)}
for {
chunk, err := iter.Next()
if err != nil {
if err.Error() == "no more items in iterator" {
break
}
return nil, err
}

// Process each candidate.
for _, c := range chunk.Candidates {
tc := translateCandidate(c)

// Call callback with the candidate info.
err := cb(ctx, &ai.GenerateResponseChunk{
Content: tc.Message.Content,
Index: tc.Index,
})
if err != nil {
return nil, err
}
// Save candidate in full response structure.
if old := r.Candidates[tc.Index]; old == nil {
r.Candidates[tc.Index] = tc
} else {
// Need to merge two "parts" of a candidate.
// Currently, we:
// - append the Message content
// - merge the FinishReason
// - assert everything else is unchanged
// (We do that 3rd step first.)
c1 := *r.Candidates[tc.Index]
c2 := *tc
m1 := *c1.Message
m2 := *c2.Message
c1.Message = &m1
c2.Message = &m2
m1.Content = nil
m2.Content = nil
c1.FinishReason = ai.FinishReasonUnknown
c2.FinishReason = ai.FinishReasonUnknown
if !reflect.DeepEqual(&c1, &c2) {
return nil, fmt.Errorf("some candidate fields unexpectedly changed")
}

// Append the Parts to the final candidate.
old.Message.Content = append(old.Message.Content, tc.Message.Content...)
// Merge the FinishReasons.
if old.FinishReason == ai.FinishReasonUnknown {
old.FinishReason = tc.FinishReason
} else if old.FinishReason != tc.FinishReason {
return nil, fmt.Errorf("invalid finish reason transition: %s to %s", old.FinishReason, tc.FinishReason)
}
}
}
// TODO: use chunk.PromptFeedback, chunk.UsageMetadata
}
return r, nil
}

Expand All @@ -242,30 +304,33 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate {
c.FinishReason = ai.FinishReasonUnknown
}
m := &ai.Message{}
m.Role = ai.Role(cand.Content.Role)
for _, part := range cand.Content.Parts {
var p *ai.Part
switch part := part.(type) {
case genai.Text:
p = ai.NewTextPart(string(part))
case genai.Blob:
p = ai.NewMediaPart(part.MIMEType, string(part.Data))
case genai.FunctionCall:
p = ai.NewToolRequestPart(&ai.ToolRequest{
Name: part.Name,
Input: part.Args,
})
default:
panic(fmt.Sprintf("unknown part %#v", part))
if cand.Content != nil {
m.Role = ai.Role(cand.Content.Role)
for _, part := range cand.Content.Parts {
var p *ai.Part
switch part := part.(type) {
case genai.Text:
p = ai.NewTextPart(string(part))
case genai.Blob:
p = ai.NewMediaPart(part.MIMEType, string(part.Data))
case genai.FunctionCall:
p = ai.NewToolRequestPart(&ai.ToolRequest{
Name: part.Name,
Input: part.Args,
})
default:
panic(fmt.Sprintf("unknown part %#v", part))
}
m.Content = append(m.Content, p)
}
m.Content = append(m.Content, p)
}
c.Message = m
return c
}

// Translate from a genai.GenerateContentResponse to a ai.GenerateResponse.
func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse {
// Note: this path doesn't get used when streaming.
r := &ai.GenerateResponse{}
for _, c := range resp.Candidates {
r.Candidates = append(r.Candidates, translateCandidate(c))
Expand Down
40 changes: 40 additions & 0 deletions go/plugins/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,46 @@ func TestLive(t *testing.T) {
t.Error("Request field not set properly")
}
})
t.Run("streaming", func(t *testing.T) {
req := &ai.GenerateRequest{
Candidates: 1,
Messages: []*ai.Message{
&ai.Message{
Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the Golden State Warriors.")},
Role: ai.RoleUser,
},
},
}

out := ""
parts := 0
model := vertexai.Model(modelName)
final, err := ai.Generate(ctx, model, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
parts++
for _, p := range c.Content {
out += p.Text
}
return nil
})
if err != nil {
t.Fatal(err)
}
out2 := ""
for _, p := range final.Candidates[0].Message.Content {
out2 += p.Text
}
if out != out2 {
t.Errorf("streaming and final should contain the same text.\nstreaming:%s\nfinal:%s", out, out2)
}
const want = "Golden"
if !strings.Contains(out, want) {
t.Errorf("got %q, expecting it to contain %q", out, want)
}
if parts == 1 {
// Check if streaming actually occurred.
t.Errorf("expecting more than one part")
}
})
t.Run("tool", func(t *testing.T) {
req := &ai.GenerateRequest{
Candidates: 1,
Expand Down

0 comments on commit 3309d31

Please sign in to comment.