-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for CachedContent. See https://ai.google.dev/gemini-api/docs/caching.
- Loading branch information
Showing
9 changed files
with
4,272 additions
and
3,742 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package genai | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"time" | ||
|
||
gl "cloud.google.com/go/ai/generativelanguage/apiv1beta" | ||
pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" | ||
"google.golang.org/api/iterator" | ||
durationpb "google.golang.org/protobuf/types/known/durationpb" | ||
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb" | ||
timestamppb "google.golang.org/protobuf/types/known/timestamppb" | ||
) | ||
|
||
type cacheClient = gl.CacheClient | ||
|
||
var ( | ||
newCacheClient = gl.NewCacheClient | ||
newCacheRESTClient = gl.NewCacheRESTClient | ||
) | ||
|
||
// GenerativeModelFromCachedContent returns a [GenerativeModel] that uses the given [CachedContent]. | ||
// The argument should come from a call to [Client.CreateCachedContent] or [Client.GetCachedContent]. | ||
func (c *Client) GenerativeModelFromCachedContent(cc *CachedContent) *GenerativeModel { | ||
return &GenerativeModel{ | ||
c: c, | ||
fullName: cc.Model, | ||
CachedContentName: cc.Name, | ||
} | ||
} | ||
|
||
// CreateCachedContent creates a new CachedContent. | ||
// The argument should contain a model name and some data to be cached, which can include | ||
// contents, a system instruction, tools and/or tool configuration. It can also | ||
// include an expiration time or TTL. But it should not include a name; the system | ||
// will generate one. | ||
// | ||
// The return value will contain the name, which should be used to refer to the CachedContent | ||
// in other API calls. It will also hold various metadata like expiration and creation time. | ||
// It will not contain any of the actual content provided as input. | ||
// | ||
// You can use the return value to create a model with [Client.GenerativeModelFromCachedContent]. | ||
// Or you can set [GenerativeModel.CachedContentName] to the name of the CachedContent, in which | ||
// case you must ensure that the model provided in this call matches the name in the [GenerativeModel]. | ||
func (c *Client) CreateCachedContent(ctx context.Context, cc *CachedContent) (*CachedContent, error) { | ||
if cc.Name != "" { | ||
return nil, errors.New("genai.CreateCachedContent: do not provide a name; one will be generated") | ||
} | ||
pcc := cc.toProto() | ||
pcc.Model = Ptr(fullModelName(cc.Model)) | ||
return c.cachedContentFromProto(c.cc.CreateCachedContent(ctx, &pb.CreateCachedContentRequest{ | ||
CachedContent: pcc, | ||
})) | ||
} | ||
|
||
// GetCachedContent retrieves the CachedContent with the given name. | ||
func (c *Client) GetCachedContent(ctx context.Context, name string) (*CachedContent, error) { | ||
return c.cachedContentFromProto(c.cc.GetCachedContent(ctx, &pb.GetCachedContentRequest{Name: name})) | ||
} | ||
|
||
// DeleteCachedContent deletes the CachedContent with the given name. | ||
func (c *Client) DeleteCachedContent(ctx context.Context, name string) error { | ||
return c.cc.DeleteCachedContent(ctx, &pb.DeleteCachedContentRequest{Name: name}) | ||
} | ||
|
||
// CachedContentToUpdate specifies which fields of a CachedContent to modify in a call to | ||
// [Client.UpdateCachedContent]. | ||
type CachedContentToUpdate struct { | ||
// If non-nil, update the expire time or TTL. | ||
Expiration *ExpireTimeOrTTL | ||
} | ||
|
||
// UpdateCachedContent modifies the [CachedContent] according to the values | ||
// of the [CachedContentToUpdate] struct. | ||
// It returns the modified CachedContent. | ||
// | ||
// The argument CachedContent must have its Name field populated. | ||
// If its UpdateTime field is non-zero, it will be compared with the update time | ||
// of the stored CachedContent and the call will fail if they differ. | ||
// This avoids a race condition when two updates are attempted concurrently. | ||
// All other fields of the argument CachedContent are ignored. | ||
func (c *Client) UpdateCachedContent(ctx context.Context, cc *CachedContent, ccu *CachedContentToUpdate) (*CachedContent, error) { | ||
if ccu == nil || ccu.Expiration == nil { | ||
return nil, errors.New("cloud.google.com/go/vertexai/genai.UpdateCachedContent: no update specified") | ||
} | ||
cc2 := &CachedContent{ | ||
Name: cc.Name, | ||
UpdateTime: cc.UpdateTime, | ||
Expiration: *ccu.Expiration, | ||
} | ||
mask := "expire_time" | ||
if ccu.Expiration.ExpireTime.IsZero() { | ||
mask = "ttl" | ||
} | ||
return c.cachedContentFromProto(c.cc.UpdateCachedContent(ctx, &pb.UpdateCachedContentRequest{ | ||
CachedContent: cc2.toProto(), | ||
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{mask}}, | ||
})) | ||
} | ||
|
||
// ListCachedContents lists all the CachedContents associated with the project and location. | ||
func (c *Client) ListCachedContents(ctx context.Context) *CachedContentIterator { | ||
return &CachedContentIterator{ | ||
it: c.cc.ListCachedContents(ctx, &pb.ListCachedContentsRequest{}), | ||
} | ||
} | ||
|
||
// A CachedContentIterator iterates over CachedContents. | ||
type CachedContentIterator struct { | ||
it *gl.CachedContentIterator | ||
} | ||
|
||
// Next returns the next result. Its second return value is iterator.Done if there are no more | ||
// results. Once Next returns Done, all subsequent calls will return Done. | ||
func (it *CachedContentIterator) Next() (*CachedContent, error) { | ||
m, err := it.it.Next() | ||
if err != nil { | ||
return nil, err | ||
} | ||
return (CachedContent{}).fromProto(m), nil | ||
} | ||
|
||
// PageInfo supports pagination. See the google.golang.org/api/iterator package for details. | ||
func (it *CachedContentIterator) PageInfo() *iterator.PageInfo { | ||
return it.it.PageInfo() | ||
} | ||
|
||
func (c *Client) cachedContentFromProto(pcc *pb.CachedContent, err error) (*CachedContent, error) { | ||
if err != nil { | ||
return nil, err | ||
} | ||
cc := (CachedContent{}).fromProto(pcc) | ||
return cc, nil | ||
} | ||
|
||
// ExpireTimeOrTTL describes the time when a resource expires. | ||
// If ExpireTime is non-zero, it is the expiration time. | ||
// Otherwise, the expiration time is the value of TTL ("time to live") added | ||
// to the current time. | ||
type ExpireTimeOrTTL struct { | ||
ExpireTime time.Time | ||
TTL time.Duration | ||
} | ||
|
||
// populateCachedContentTo populates some fields of p from v. | ||
func populateCachedContentTo(p *pb.CachedContent, v *CachedContent) { | ||
exp := v.Expiration | ||
if !exp.ExpireTime.IsZero() { | ||
p.Expiration = &pb.CachedContent_ExpireTime{ | ||
ExpireTime: timestamppb.New(exp.ExpireTime), | ||
} | ||
} else if exp.TTL != 0 { | ||
p.Expiration = &pb.CachedContent_Ttl{ | ||
Ttl: durationpb.New(exp.TTL), | ||
} | ||
} | ||
// If both fields of v.Expiration are zero, leave p.Expiration unset. | ||
} | ||
|
||
// populateCachedContentFrom populates some fields of v from p. | ||
func populateCachedContentFrom(v *CachedContent, p *pb.CachedContent) { | ||
if p.Expiration == nil { | ||
return | ||
} | ||
switch e := p.Expiration.(type) { | ||
case *pb.CachedContent_ExpireTime: | ||
v.Expiration.ExpireTime = pvTimeFromProto(e.ExpireTime) | ||
case *pb.CachedContent_Ttl: | ||
v.Expiration.TTL = e.Ttl.AsDuration() | ||
default: | ||
panic(fmt.Sprintf("unknown type of CachedContent.Expiration: %T", p.Expiration)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package genai | ||
|
||
import ( | ||
"context" | ||
"path/filepath" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" | ||
"github.com/google/go-cmp/cmp" | ||
"github.com/google/go-cmp/cmp/cmpopts" | ||
"google.golang.org/api/iterator" | ||
durationpb "google.golang.org/protobuf/types/known/durationpb" | ||
timestamppb "google.golang.org/protobuf/types/known/timestamppb" | ||
) | ||
|
||
func TestPopulateCachedContent(t *testing.T) { | ||
tm := time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC) | ||
cmpOpt := cmpopts.IgnoreUnexported( | ||
timestamppb.Timestamp{}, | ||
durationpb.Duration{}, | ||
) | ||
for _, test := range []struct { | ||
proto *pb.CachedContent | ||
veneer *CachedContent | ||
}{ | ||
{&pb.CachedContent{}, &CachedContent{}}, | ||
{ | ||
&pb.CachedContent{Expiration: &pb.CachedContent_ExpireTime{ExpireTime: timestamppb.New(tm)}}, | ||
&CachedContent{Expiration: ExpireTimeOrTTL{ExpireTime: tm}}, | ||
}, | ||
{ | ||
&pb.CachedContent{Expiration: &pb.CachedContent_Ttl{Ttl: durationpb.New(time.Hour)}}, | ||
&CachedContent{Expiration: ExpireTimeOrTTL{TTL: time.Hour}}, | ||
}, | ||
} { | ||
var gotp pb.CachedContent | ||
populateCachedContentTo(&gotp, test.veneer) | ||
if g, w := gotp.Expiration, test.proto.Expiration; !cmp.Equal(g, w, cmpOpt) { | ||
t.Errorf("from %v to proto: got %v, want %v", test.veneer.Expiration, g, w) | ||
} | ||
|
||
var gotv CachedContent | ||
populateCachedContentFrom(&gotv, test.proto) | ||
if g, w := gotv.Expiration, test.veneer.Expiration; !cmp.Equal(g, w) { | ||
t.Errorf("from %v to veneer: got %v, want %v", test.proto.Expiration, g, w) | ||
} | ||
} | ||
} | ||
|
||
func testCaching(t *testing.T, client *Client) { | ||
ctx := context.Background() | ||
const model = "gemini-1.5-flash-001" | ||
|
||
file := uploadFile(t, ctx, client, filepath.Join("testdata", "earth.mp4")) | ||
|
||
t.Run("CRUD", func(t *testing.T) { | ||
must := func(cc *CachedContent, err error) *CachedContent { | ||
t.Helper() | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
return cc | ||
} | ||
|
||
want := &CachedContent{ | ||
Model: "models/" + model, | ||
UsageMetadata: &CachedContentUsageMetadata{TotalTokenCount: 36876}, | ||
} | ||
|
||
compare := func(got *CachedContent, expireTime time.Time) { | ||
t.Helper() | ||
want.Expiration.ExpireTime = expireTime | ||
if got.CreateTime.IsZero() { | ||
t.Error("missing CreateTime") | ||
} | ||
if got.UpdateTime.IsZero() { | ||
t.Error("missing UpdateTime") | ||
|
||
} | ||
if diff := cmp.Diff(want, got, | ||
cmpopts.EquateApproxTime(10*time.Second), | ||
cmpopts.IgnoreFields(CachedContent{}, "Name", "CreateTime", "UpdateTime")); diff != "" { | ||
t.Errorf("mismatch (-want, +got):\n%s", diff) | ||
} | ||
} | ||
|
||
ttl := 30 * time.Minute | ||
wantExpireTime := time.Now().Add(ttl) | ||
// Replicate the file content multiple times to reach the minimum token threshold | ||
// for cached content. | ||
fd := FileData{MIMEType: "text/plain", URI: file.URI} | ||
parts := make([]Part, 25) | ||
for i := range parts { | ||
parts[i] = fd | ||
} | ||
argcc := &CachedContent{ | ||
Model: model, | ||
Expiration: ExpireTimeOrTTL{TTL: ttl}, | ||
Contents: []*Content{{Role: "user", Parts: parts}}, | ||
} | ||
cc := must(client.CreateCachedContent(ctx, argcc)) | ||
compare(cc, wantExpireTime) | ||
name := cc.Name | ||
cc2 := must(client.GetCachedContent(ctx, name)) | ||
compare(cc2, wantExpireTime) | ||
gotList := listAll(t, client.ListCachedContents(ctx)) | ||
var cc3 *CachedContent | ||
for _, cc := range gotList { | ||
if cc.Name == name { | ||
cc3 = cc | ||
break | ||
} | ||
} | ||
if cc3 == nil { | ||
t.Fatal("did not find created in list") | ||
} | ||
compare(cc3, wantExpireTime) | ||
|
||
// Update using expire time. | ||
newExpireTime := cc3.Expiration.ExpireTime.Add(15 * time.Minute) | ||
cc4 := must(client.UpdateCachedContent(ctx, cc3, &CachedContentToUpdate{ | ||
Expiration: &ExpireTimeOrTTL{ExpireTime: newExpireTime}, | ||
})) | ||
compare(cc4, newExpireTime) | ||
|
||
t.Run("update-ttl", func(t *testing.T) { | ||
// Update using TTL. | ||
cc5 := must(client.UpdateCachedContent(ctx, cc4, &CachedContentToUpdate{ | ||
Expiration: &ExpireTimeOrTTL{TTL: ttl}, | ||
})) | ||
compare(cc5, time.Now().Add(ttl)) | ||
}) | ||
|
||
if err := client.DeleteCachedContent(ctx, name); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
if err := client.DeleteCachedContent(ctx, "bad name"); err == nil { | ||
t.Fatal("want error, got nil") | ||
} | ||
}) | ||
t.Run("generation", func(t *testing.T) { | ||
txt := strings.Repeat("George Washington was the first president of the United States. ", 3000) | ||
argcc := &CachedContent{ | ||
Model: model, | ||
Contents: []*Content{{Role: "user", Parts: []Part{Text(txt)}}}, | ||
} | ||
cc, err := client.CreateCachedContent(ctx, argcc) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
defer client.DeleteCachedContent(ctx, cc.Name) | ||
m := client.GenerativeModelFromCachedContent(cc) | ||
res, err := m.GenerateContent(ctx, Text("Who was the first US president?")) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
got := responseString(res) | ||
const want = "Washington" | ||
if !strings.Contains(got, want) { | ||
t.Errorf("got %q, want string containing %q", got, want) | ||
} | ||
}) | ||
} | ||
|
||
func listAll(t *testing.T, iter *CachedContentIterator) []*CachedContent { | ||
var ccs []*CachedContent | ||
for { | ||
cc, err := iter.Next() | ||
if err == iterator.Done { | ||
break | ||
} | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
ccs = append(ccs, cc) | ||
} | ||
return ccs | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.