Skip to content

Commit

Permalink
implement vertexai streaming (#384)
Browse files Browse the repository at this point in the history
* implement vertexai streaming

Fixes #344

* round of review

* use newly-added MergedResponse api

* small cleanups
  • Loading branch information
randall77 authored Jun 13, 2024
1 parent ba71c98 commit f594e1f
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 50 deletions.
28 changes: 14 additions & 14 deletions go/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ module github.com/firebase/genkit/go
go 1.22.0

require (
cloud.google.com/go/aiplatform v1.66.0
cloud.google.com/go/logging v1.9.0
cloud.google.com/go/vertexai v0.7.1
cloud.google.com/go/aiplatform v1.68.0
cloud.google.com/go/logging v1.10.0
cloud.google.com/go/vertexai v0.12.0
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.22.0
github.com/aymerick/raymond v2.0.2+incompatible
Expand All @@ -22,21 +22,21 @@ require (
go.opentelemetry.io/otel/sdk/metric v1.26.0
go.opentelemetry.io/otel/trace v1.26.0
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81
google.golang.org/api v0.178.0
google.golang.org/api v0.183.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1
)

require (
cloud.google.com/go v0.113.0 // indirect
cloud.google.com/go v0.114.0 // indirect
cloud.google.com/go/ai v0.5.0 // indirect
cloud.google.com/go/auth v0.4.0 // indirect
cloud.google.com/go/auth v0.5.1 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
cloud.google.com/go/compute/metadata v0.3.0 // indirect
cloud.google.com/go/iam v1.1.7 // indirect
cloud.google.com/go/iam v1.1.8 // indirect
cloud.google.com/go/longrunning v0.5.7 // indirect
cloud.google.com/go/monitoring v1.18.1 // indirect
cloud.google.com/go/trace v1.10.6 // indirect
cloud.google.com/go/monitoring v1.19.0 // indirect
cloud.google.com/go/trace v1.10.7 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.46.0 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
Expand All @@ -57,14 +57,14 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/net v0.25.0 // indirect
golang.org/x/oauth2 v0.20.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/text v0.15.0 // indirect
golang.org/x/time v0.5.0 // indirect
google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240506185236-b8a5c65736ae // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240506185236-b8a5c65736ae // indirect
google.golang.org/grpc v1.63.2 // indirect
google.golang.org/genproto v0.0.0-20240528184218-531527333157 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect
google.golang.org/grpc v1.64.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
56 changes: 28 additions & 28 deletions go/go.sum
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.113.0 h1:g3C70mn3lWfckKBiCVsAshabrDg01pQ0pnX1MNtnMkA=
cloud.google.com/go v0.113.0/go.mod h1:glEqlogERKYeePz6ZdkcLJ28Q2I6aERgDDErBg9GzO8=
cloud.google.com/go v0.114.0 h1:OIPFAdfrFDFO2ve2U7r/H5SwSbBzEdrBdE7xkgwc+kY=
cloud.google.com/go v0.114.0/go.mod h1:ZV9La5YYxctro1HTPug5lXH/GefROyW8PPD4T8n9J8E=
cloud.google.com/go/ai v0.5.0 h1:x8s4rDn5t9OVZvBCgtr5bZTH5X0O7JdE6zYo+O+MpRw=
cloud.google.com/go/ai v0.5.0/go.mod h1:96VBphk70e0zdXZrbtgPuKYRZsQ3UktSUXhuojwiKA8=
cloud.google.com/go/aiplatform v1.66.0 h1:bbFYY4JInclG10czRFUYj2rjD+obhh3Gi9zVlyoMgEc=
cloud.google.com/go/aiplatform v1.66.0/go.mod h1:bPQS0UjaXaTAq57UgP3XWDCtYFOIbXXpkMsl6uP4JAc=
cloud.google.com/go/auth v0.4.0 h1:vcJWEguhY8KuiHoSs/udg1JtIRYm3YAWPBE1moF1m3U=
cloud.google.com/go/auth v0.4.0/go.mod h1:tO/chJN3obc5AbRYFQDsuFbL4wW5y8LfbPtDCfgwOVE=
cloud.google.com/go/aiplatform v1.68.0 h1:EPPqgHDJpBZKRvv+OsB3cr0jYz3EL2pZ+802rBPcG8U=
cloud.google.com/go/aiplatform v1.68.0/go.mod h1:105MFA3svHjC3Oazl7yjXAmIR89LKhRAeNdnDKJczME=
cloud.google.com/go/auth v0.5.1 h1:0QNO7VThG54LUzKiQxv8C6x1YX7lUrzlAa1nVLF8CIw=
cloud.google.com/go/auth v0.5.1/go.mod h1:vbZT8GjzDf3AVqCcQmqeeM32U9HBFc32vVVAbwDsa6s=
cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/iam v1.1.7 h1:z4VHOhwKLF/+UYXAJDFwGtNF0b6gjsW1Pk9Ml0U/IoM=
cloud.google.com/go/iam v1.1.7/go.mod h1:J4PMPg8TtyurAUvSmPj8FF3EDgY1SPRZxcUGrn7WXGA=
cloud.google.com/go/logging v1.9.0 h1:iEIOXFO9EmSiTjDmfpbRjOxECO7R8C7b8IXUGOj7xZw=
cloud.google.com/go/logging v1.9.0/go.mod h1:1Io0vnZv4onoUnsVUQY3HZ3Igb1nBchky0A0y7BBBhE=
cloud.google.com/go/iam v1.1.8 h1:r7umDwhj+BQyz0ScZMp4QrGXjSTI3ZINnpgU2nlB/K0=
cloud.google.com/go/iam v1.1.8/go.mod h1:GvE6lyMmfxXauzNq8NbgJbeVQNspG+tcdL/W8QO1+zE=
cloud.google.com/go/logging v1.10.0 h1:f+ZXMqyrSJ5vZ5pE/zr0xC8y/M9BLNzQeLBwfeZ+wY4=
cloud.google.com/go/logging v1.10.0/go.mod h1:EHOwcxlltJrYGqMGfghSet736KR3hX1MAj614mrMk9I=
cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU=
cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng=
cloud.google.com/go/monitoring v1.18.1 h1:0yvFXK+xQd95VKo6thndjwnJMno7c7Xw1CwMByg0B+8=
cloud.google.com/go/monitoring v1.18.1/go.mod h1:52hTzJ5XOUMRm7jYi7928aEdVxBEmGwA0EjNJXIBvt8=
cloud.google.com/go/trace v1.10.6 h1:XF0Ejdw0NpRfAvuZUeQe3ClAG4R/9w5JYICo7l2weaw=
cloud.google.com/go/trace v1.10.6/go.mod h1:EABXagUjxGuKcZMy4pXyz0fJpE5Ghog3jzTxcEsVJS4=
cloud.google.com/go/vertexai v0.7.1 h1:CSdqsEwjklLIlI1e5SrsnkwG/I+CeJekkBbMTzeYhVg=
cloud.google.com/go/vertexai v0.7.1/go.mod h1:HfnfYR9aPS+qF2436S6Hzuw0Fp+PORjzK3ggqymdzSU=
cloud.google.com/go/monitoring v1.19.0 h1:NCXf8hfQi+Kmr56QJezXRZ6GPb80ZI7El1XztyUuLQI=
cloud.google.com/go/monitoring v1.19.0/go.mod h1:25IeMR5cQ5BoZ8j1eogHE5VPJLlReQ7zFp5OiLgiGZw=
cloud.google.com/go/trace v1.10.7 h1:gK8z2BIJQ3KIYGddw9RJLne5Fx0FEXkrEQzPaeEYVvk=
cloud.google.com/go/trace v1.10.7/go.mod h1:qk3eiKmZX0ar2dzIJN/3QhY2PIFh1eqcIdaN5uEjQPM=
cloud.google.com/go/vertexai v0.12.0 h1:zTadEo/CtsoyRXNx3uGCncoWAP1H2HakGqwznt+iMo8=
cloud.google.com/go/vertexai v0.12.0/go.mod h1:8u+d0TsvBfAAd2x5R6GMgbYhsLgo3J7lmP4bR8g2ig8=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0 h1:n3T26hyfDl9RdgcOjWvOFMh1lCBNuZ0JQ/3DM5pou8Y=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0/go.mod h1:3S7qK2nHOO2cLID3xk6H8f55D38XswhVFzKEk0nqIbY=
Expand Down Expand Up @@ -159,8 +159,8 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo=
golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
Expand All @@ -184,26 +184,26 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.178.0 h1:yoW/QMI4bRVCHF+NWOTa4cL8MoWL3Jnuc7FlcFF91Ok=
google.golang.org/api v0.178.0/go.mod h1:84/k2v8DFpDRebpGcooklv/lais3MEfqpaBLA12gl2U=
google.golang.org/api v0.183.0 h1:PNMeRDwo1pJdgNcFQ9GstuLe/noWKIc89pRWRLMvLwE=
google.golang.org/api v0.183.0/go.mod h1:q43adC5/pHoSZTx5h2mSmdF7NcyfW9JuDyIOJAgS9ZQ=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda h1:wu/KJm9KJwpfHWhkkZGohVC6KRrc1oJNr4jwtQMOQXw=
google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda/go.mod h1:g2LLCvCeCSir/JJSWosk19BR4NVxGqHUC6rxIRsd7Aw=
google.golang.org/genproto/googleapis/api v0.0.0-20240506185236-b8a5c65736ae h1:AH34z6WAGVNkllnKs5raNq3yRq93VnjBG6rpfub/jYk=
google.golang.org/genproto/googleapis/api v0.0.0-20240506185236-b8a5c65736ae/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240506185236-b8a5c65736ae h1:c55+MER4zkBS14uJhSZMGGmya0yJx5iHV4x/fpOSNRk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240506185236-b8a5c65736ae/go.mod h1:I7Y+G38R2bu5j1aLzfFmQfTcU/WnFuqDwLZAbvKTKpM=
google.golang.org/genproto v0.0.0-20240528184218-531527333157 h1:u7WMYrIrVvs0TF5yaKwKNbcJyySYf+HAIFXxWltJOXE=
google.golang.org/genproto v0.0.0-20240528184218-531527333157/go.mod h1:ubQlAQnzejB8uZzszhrTCU2Fyp6Vi7ZE5nn0c3W8+qQ=
google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 h1:+rdxYoE3E5htTEWIe15GlN6IfvbURM//Jt0mmkmm6ZU=
google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117/go.mod h1:OimBR/bc1wPO9iV4NC2bpyjy3VnAwZh5EBPQdtaE5oo=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 h1:1GBuWVLM/KMVUv1t1En5Gs+gFZCNd360GGb4sSxtrhU=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM=
google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY=
google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
Expand Down
51 changes: 43 additions & 8 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"cloud.google.com/go/vertexai/genai"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/plugins/internal/uri"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)

Expand Down Expand Up @@ -125,9 +126,6 @@ type generator struct {
}

func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
if cb != nil {
panic("streaming not supported yet") // TODO: streaming
}
gm := g.client.GenerativeModel(g.model)

// Translate from a ai.GenerateRequest to a genai request.
Expand All @@ -143,7 +141,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
gm.SetTemperature(float32(c.Temperature))
}
if c.TopK != 0 {
gm.SetTopK(float32(c.TopK))
gm.SetTopK(int32(c.TopK))
}
if c.TopP != 0 {
gm.SetTopP(float32(c.TopP))
Expand Down Expand Up @@ -213,12 +211,49 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
// TODO: gm.ToolConfig?

// Send out the actual request.
resp, err := cs.SendMessage(ctx, parts...)
if err != nil {
return nil, err
if cb == nil {
resp, err := cs.SendMessage(ctx, parts...)
if err != nil {
return nil, err
}

r := translateResponse(resp)
r.Request = input
return r, nil
}

r := translateResponse(resp)
// Streaming version.
iter := cs.SendMessageStream(ctx, parts...)
var r *ai.GenerateResponse
for {
chunk, err := iter.Next()
if err != nil {
if err == iterator.Done {
r = translateResponse(iter.MergedResponse())
break
}
return nil, err
}

// Process each candidate.
for _, c := range chunk.Candidates {
tc := translateCandidate(c)

// Call callback with the candidate info.
err := cb(ctx, &ai.GenerateResponseChunk{
Content: tc.Message.Content,
Index: tc.Index,
})
if err != nil {
return nil, err
}
}
}
if r == nil {
// No candidates were returned. Probably rare, but it might avoid a NPE
// to return an empty instead of nil result.
r = &ai.GenerateResponse{}
}
r.Request = input
return r, nil
}
Expand Down
40 changes: 40 additions & 0 deletions go/plugins/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,46 @@ func TestLive(t *testing.T) {
t.Error("Request field not set properly")
}
})
t.Run("streaming", func(t *testing.T) {
req := &ai.GenerateRequest{
Candidates: 1,
Messages: []*ai.Message{
&ai.Message{
Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the Golden State Warriors.")},
Role: ai.RoleUser,
},
},
}

out := ""
parts := 0
model := vertexai.Model(modelName)
final, err := ai.Generate(ctx, model, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
parts++
for _, p := range c.Content {
out += p.Text
}
return nil
})
if err != nil {
t.Fatal(err)
}
out2 := ""
for _, p := range final.Candidates[0].Message.Content {
out2 += p.Text
}
if out != out2 {
t.Errorf("streaming and final should contain the same text.\nstreaming:%s\nfinal:%s", out, out2)
}
const want = "Golden"
if !strings.Contains(out, want) {
t.Errorf("got %q, expecting it to contain %q", out, want)
}
if parts == 1 {
// Check if streaming actually occurred.
t.Errorf("expecting more than one part")
}
})
t.Run("tool", func(t *testing.T) {
req := &ai.GenerateRequest{
Candidates: 1,
Expand Down

0 comments on commit f594e1f

Please sign in to comment.