Skip to content

Commit

Permalink
genai: add content caching
Browse files Browse the repository at this point in the history
Add support for CachedContent.

See https://ai.google.dev/gemini-api/docs/caching.
  • Loading branch information
jba committed Jun 25, 2024
1 parent 5592db3 commit f8dd22c
Show file tree
Hide file tree
Showing 9 changed files with 4,272 additions and 3,742 deletions.
189 changes: 189 additions & 0 deletions genai/caching.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package genai

import (
"context"
"errors"
"fmt"
"time"

gl "cloud.google.com/go/ai/generativelanguage/apiv1beta"
pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb"
"google.golang.org/api/iterator"
durationpb "google.golang.org/protobuf/types/known/durationpb"
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
)

type cacheClient = gl.CacheClient

var (
newCacheClient = gl.NewCacheClient
newCacheRESTClient = gl.NewCacheRESTClient
)

// GenerativeModelFromCachedContent returns a [GenerativeModel] that uses the given [CachedContent].
// The argument should come from a call to [Client.CreateCachedContent] or [Client.GetCachedContent].
func (c *Client) GenerativeModelFromCachedContent(cc *CachedContent) *GenerativeModel {
return &GenerativeModel{
c: c,
fullName: cc.Model,
CachedContentName: cc.Name,
}
}

// CreateCachedContent creates a new CachedContent.
// The argument should contain a model name and some data to be cached, which can include
// contents, a system instruction, tools and/or tool configuration. It can also
// include an expiration time or TTL. But it should not include a name; the system
// will generate one.
//
// The return value will contain the name, which should be used to refer to the CachedContent
// in other API calls. It will also hold various metadata like expiration and creation time.
// It will not contain any of the actual content provided as input.
//
// You can use the return value to create a model with [Client.GenerativeModelFromCachedContent].
// Or you can set [GenerativeModel.CachedContentName] to the name of the CachedContent, in which
// case you must ensure that the model provided in this call matches the name in the [GenerativeModel].
func (c *Client) CreateCachedContent(ctx context.Context, cc *CachedContent) (*CachedContent, error) {
if cc.Name != "" {
return nil, errors.New("genai.CreateCachedContent: do not provide a name; one will be generated")
}
pcc := cc.toProto()
pcc.Model = Ptr(fullModelName(cc.Model))
return c.cachedContentFromProto(c.cc.CreateCachedContent(ctx, &pb.CreateCachedContentRequest{
CachedContent: pcc,
}))
}

// GetCachedContent retrieves the CachedContent with the given name.
func (c *Client) GetCachedContent(ctx context.Context, name string) (*CachedContent, error) {
return c.cachedContentFromProto(c.cc.GetCachedContent(ctx, &pb.GetCachedContentRequest{Name: name}))
}

// DeleteCachedContent deletes the CachedContent with the given name.
func (c *Client) DeleteCachedContent(ctx context.Context, name string) error {
return c.cc.DeleteCachedContent(ctx, &pb.DeleteCachedContentRequest{Name: name})
}

// CachedContentToUpdate specifies which fields of a CachedContent to modify in a call to
// [Client.UpdateCachedContent].
type CachedContentToUpdate struct {
// If non-nil, update the expire time or TTL.
Expiration *ExpireTimeOrTTL
}

// UpdateCachedContent modifies the [CachedContent] according to the values
// of the [CachedContentToUpdate] struct.
// It returns the modified CachedContent.
//
// The argument CachedContent must have its Name field populated.
// If its UpdateTime field is non-zero, it will be compared with the update time
// of the stored CachedContent and the call will fail if they differ.
// This avoids a race condition when two updates are attempted concurrently.
// All other fields of the argument CachedContent are ignored.
func (c *Client) UpdateCachedContent(ctx context.Context, cc *CachedContent, ccu *CachedContentToUpdate) (*CachedContent, error) {
if ccu == nil || ccu.Expiration == nil {
return nil, errors.New("cloud.google.com/go/vertexai/genai.UpdateCachedContent: no update specified")
}
cc2 := &CachedContent{
Name: cc.Name,
UpdateTime: cc.UpdateTime,
Expiration: *ccu.Expiration,
}
mask := "expire_time"
if ccu.Expiration.ExpireTime.IsZero() {
mask = "ttl"
}
return c.cachedContentFromProto(c.cc.UpdateCachedContent(ctx, &pb.UpdateCachedContentRequest{
CachedContent: cc2.toProto(),
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{mask}},
}))
}

// ListCachedContents lists all the CachedContents associated with the project and location.
func (c *Client) ListCachedContents(ctx context.Context) *CachedContentIterator {
return &CachedContentIterator{
it: c.cc.ListCachedContents(ctx, &pb.ListCachedContentsRequest{}),
}
}

// A CachedContentIterator iterates over CachedContents.
type CachedContentIterator struct {
it *gl.CachedContentIterator
}

// Next returns the next result. Its second return value is iterator.Done if there are no more
// results. Once Next returns Done, all subsequent calls will return Done.
func (it *CachedContentIterator) Next() (*CachedContent, error) {
m, err := it.it.Next()
if err != nil {
return nil, err
}
return (CachedContent{}).fromProto(m), nil
}

// PageInfo supports pagination. See the google.golang.org/api/iterator package for details.
func (it *CachedContentIterator) PageInfo() *iterator.PageInfo {
return it.it.PageInfo()
}

func (c *Client) cachedContentFromProto(pcc *pb.CachedContent, err error) (*CachedContent, error) {
if err != nil {
return nil, err
}
cc := (CachedContent{}).fromProto(pcc)
return cc, nil
}

// ExpireTimeOrTTL describes the time when a resource expires.
// If ExpireTime is non-zero, it is the expiration time.
// Otherwise, the expiration time is the value of TTL ("time to live") added
// to the current time.
type ExpireTimeOrTTL struct {
ExpireTime time.Time
TTL time.Duration
}

// populateCachedContentTo populates some fields of p from v.
func populateCachedContentTo(p *pb.CachedContent, v *CachedContent) {
exp := v.Expiration
if !exp.ExpireTime.IsZero() {
p.Expiration = &pb.CachedContent_ExpireTime{
ExpireTime: timestamppb.New(exp.ExpireTime),
}
} else if exp.TTL != 0 {
p.Expiration = &pb.CachedContent_Ttl{
Ttl: durationpb.New(exp.TTL),
}
}
// If both fields of v.Expiration are zero, leave p.Expiration unset.
}

// populateCachedContentFrom populates some fields of v from p.
func populateCachedContentFrom(v *CachedContent, p *pb.CachedContent) {
if p.Expiration == nil {
return
}
switch e := p.Expiration.(type) {
case *pb.CachedContent_ExpireTime:
v.Expiration.ExpireTime = pvTimeFromProto(e.ExpireTime)
case *pb.CachedContent_Ttl:
v.Expiration.TTL = e.Ttl.AsDuration()
default:
panic(fmt.Sprintf("unknown type of CachedContent.Expiration: %T", p.Expiration))
}
}
195 changes: 195 additions & 0 deletions genai/caching_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package genai

import (
"context"
"path/filepath"
"strings"
"testing"
"time"

pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/api/iterator"
durationpb "google.golang.org/protobuf/types/known/durationpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
)

func TestPopulateCachedContent(t *testing.T) {
tm := time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC)
cmpOpt := cmpopts.IgnoreUnexported(
timestamppb.Timestamp{},
durationpb.Duration{},
)
for _, test := range []struct {
proto *pb.CachedContent
veneer *CachedContent
}{
{&pb.CachedContent{}, &CachedContent{}},
{
&pb.CachedContent{Expiration: &pb.CachedContent_ExpireTime{ExpireTime: timestamppb.New(tm)}},
&CachedContent{Expiration: ExpireTimeOrTTL{ExpireTime: tm}},
},
{
&pb.CachedContent{Expiration: &pb.CachedContent_Ttl{Ttl: durationpb.New(time.Hour)}},
&CachedContent{Expiration: ExpireTimeOrTTL{TTL: time.Hour}},
},
} {
var gotp pb.CachedContent
populateCachedContentTo(&gotp, test.veneer)
if g, w := gotp.Expiration, test.proto.Expiration; !cmp.Equal(g, w, cmpOpt) {
t.Errorf("from %v to proto: got %v, want %v", test.veneer.Expiration, g, w)
}

var gotv CachedContent
populateCachedContentFrom(&gotv, test.proto)
if g, w := gotv.Expiration, test.veneer.Expiration; !cmp.Equal(g, w) {
t.Errorf("from %v to veneer: got %v, want %v", test.proto.Expiration, g, w)
}
}
}

func testCaching(t *testing.T, client *Client) {
ctx := context.Background()
const model = "gemini-1.5-flash-001"

file := uploadFile(t, ctx, client, filepath.Join("testdata", "earth.mp4"))

t.Run("CRUD", func(t *testing.T) {
must := func(cc *CachedContent, err error) *CachedContent {
t.Helper()
if err != nil {
t.Fatal(err)
}
return cc
}

want := &CachedContent{
Model: "models/" + model,
UsageMetadata: &CachedContentUsageMetadata{TotalTokenCount: 36876},
}

compare := func(got *CachedContent, expireTime time.Time) {
t.Helper()
want.Expiration.ExpireTime = expireTime
if got.CreateTime.IsZero() {
t.Error("missing CreateTime")
}
if got.UpdateTime.IsZero() {
t.Error("missing UpdateTime")

}
if diff := cmp.Diff(want, got,
cmpopts.EquateApproxTime(10*time.Second),
cmpopts.IgnoreFields(CachedContent{}, "Name", "CreateTime", "UpdateTime")); diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
}

ttl := 30 * time.Minute
wantExpireTime := time.Now().Add(ttl)
// Replicate the file content multiple times to reach the minimum token threshold
// for cached content.
fd := FileData{MIMEType: "text/plain", URI: file.URI}
parts := make([]Part, 25)
for i := range parts {
parts[i] = fd
}
argcc := &CachedContent{
Model: model,
Expiration: ExpireTimeOrTTL{TTL: ttl},
Contents: []*Content{{Role: "user", Parts: parts}},
}
cc := must(client.CreateCachedContent(ctx, argcc))
compare(cc, wantExpireTime)
name := cc.Name
cc2 := must(client.GetCachedContent(ctx, name))
compare(cc2, wantExpireTime)
gotList := listAll(t, client.ListCachedContents(ctx))
var cc3 *CachedContent
for _, cc := range gotList {
if cc.Name == name {
cc3 = cc
break
}
}
if cc3 == nil {
t.Fatal("did not find created in list")
}
compare(cc3, wantExpireTime)

// Update using expire time.
newExpireTime := cc3.Expiration.ExpireTime.Add(15 * time.Minute)
cc4 := must(client.UpdateCachedContent(ctx, cc3, &CachedContentToUpdate{
Expiration: &ExpireTimeOrTTL{ExpireTime: newExpireTime},
}))
compare(cc4, newExpireTime)

t.Run("update-ttl", func(t *testing.T) {
// Update using TTL.
cc5 := must(client.UpdateCachedContent(ctx, cc4, &CachedContentToUpdate{
Expiration: &ExpireTimeOrTTL{TTL: ttl},
}))
compare(cc5, time.Now().Add(ttl))
})

if err := client.DeleteCachedContent(ctx, name); err != nil {
t.Fatal(err)
}

if err := client.DeleteCachedContent(ctx, "bad name"); err == nil {
t.Fatal("want error, got nil")
}
})
t.Run("generation", func(t *testing.T) {
txt := strings.Repeat("George Washington was the first president of the United States. ", 3000)
argcc := &CachedContent{
Model: model,
Contents: []*Content{{Role: "user", Parts: []Part{Text(txt)}}},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
t.Fatal(err)
}
defer client.DeleteCachedContent(ctx, cc.Name)
m := client.GenerativeModelFromCachedContent(cc)
res, err := m.GenerateContent(ctx, Text("Who was the first US president?"))
if err != nil {
t.Fatal(err)
}
got := responseString(res)
const want = "Washington"
if !strings.Contains(got, want) {
t.Errorf("got %q, want string containing %q", got, want)
}
})
}

func listAll(t *testing.T, iter *CachedContentIterator) []*CachedContent {
var ccs []*CachedContent
for {
cc, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
t.Fatal(err)
}
ccs = append(ccs, cc)
}
return ccs
}
2 changes: 1 addition & 1 deletion genai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (cs *ChatSession) SendMessageStream(ctx context.Context, parts ...Part) *Ge
return &GenerateContentResponseIterator{err: err}
}
req.GenerationConfig.CandidateCount = Ptr[int32](1)
streamClient, err := cs.m.c.c.StreamGenerateContent(ctx, req)
streamClient, err := cs.m.c.gc.StreamGenerateContent(ctx, req)
return &GenerateContentResponseIterator{
sc: streamClient,
err: err,
Expand Down
Loading

0 comments on commit f8dd22c

Please sign in to comment.