From 70f698921f159bacdb012e982610cd6e040c19cc Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Fri, 19 Jul 2024 14:30:20 -0600 Subject: [PATCH] genai: add helper function NewUserContent (#188) Turns out we already had an internal function like this. The exported one is slightly more ergonomic due to `...` I tried finding all places where we could benefit form using it --- genai/caching_test.go | 4 +- genai/chat.go | 4 +- genai/client.go | 10 ++-- genai/client_test.go | 2 +- genai/content.go | 10 ++++ genai/embed.go | 2 +- genai/example_test.go | 54 ++++++-------------- genai/internal/samples/docs-snippets_test.go | 45 +++++----------- 8 files changed, 48 insertions(+), 83 deletions(-) diff --git a/genai/caching_test.go b/genai/caching_test.go index 9f2d0f0..27724a1 100644 --- a/genai/caching_test.go +++ b/genai/caching_test.go @@ -111,7 +111,7 @@ func testCaching(t *testing.T, client *Client) { argcc := &CachedContent{ Model: model, Expiration: ExpireTimeOrTTL{TTL: ttl}, - Contents: []*Content{{Role: "user", Parts: parts}}, + Contents: []*Content{NewUserContent(parts...)}, } cc := must(client.CreateCachedContent(ctx, argcc)) compare(cc, wantExpireTime) @@ -158,7 +158,7 @@ func testCaching(t *testing.T, client *Client) { txt := strings.Repeat("George Washington was the first president of the United States. ", 3000) argcc := &CachedContent{ Model: model, - Contents: []*Content{{Role: "user", Parts: []Part{Text(txt)}}}, + Contents: []*Content{NewUserContent(Text(txt))}, } cc, err := client.CreateCachedContent(ctx, argcc) if err != nil { diff --git a/genai/chat.go b/genai/chat.go index f4b95dc..334df1d 100644 --- a/genai/chat.go +++ b/genai/chat.go @@ -32,7 +32,7 @@ func (m *GenerativeModel) StartChat() *ChatSession { // SendMessage sends a request to the model as part of a chat session. func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) { // Call the underlying client with the entire history plus the argument Content. - cs.History = append(cs.History, newUserContent(parts)) + cs.History = append(cs.History, NewUserContent(parts...)) req, err := cs.m.newGenerateContentRequest(cs.History...) if err != nil { return nil, err @@ -48,7 +48,7 @@ func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*Generat // SendMessageStream is like SendMessage, but with a streaming request. func (cs *ChatSession) SendMessageStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator { - cs.History = append(cs.History, newUserContent(parts)) + cs.History = append(cs.History, NewUserContent(parts...)) req, err := cs.m.newGenerateContentRequest(cs.History...) if err != nil { return &GenerateContentResponseIterator{err: err} diff --git a/genai/client.go b/genai/client.go index 1d13112..7e81777 100644 --- a/genai/client.go +++ b/genai/client.go @@ -179,7 +179,7 @@ func fullModelName(name string) string { // GenerateContent produces a single request and response. func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) { - content := newUserContent(parts) + content := NewUserContent(parts...) req, err := m.newGenerateContentRequest(content) if err != nil { return nil, err @@ -194,7 +194,7 @@ func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (* // GenerateContentStream returns an iterator that enumerates responses. func (m *GenerativeModel) GenerateContentStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator { iter := &GenerateContentResponseIterator{} - req, err := m.newGenerateContentRequest(newUserContent(parts)) + req, err := m.newGenerateContentRequest(NewUserContent(parts...)) if err != nil { iter.err = err } else { @@ -241,10 +241,6 @@ func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) (*pb.G }) } -func newUserContent(parts []Part) *Content { - return &Content{Role: roleUser, Parts: parts} -} - // GenerateContentResponseIterator is an iterator over GnerateContentResponse. type GenerateContentResponseIterator struct { sc pb.GenerativeService_StreamGenerateContentClient @@ -313,7 +309,7 @@ func (iter *GenerateContentResponseIterator) MergedResponse() *GenerateContentRe // CountTokens counts the number of tokens in the content. func (m *GenerativeModel) CountTokens(ctx context.Context, parts ...Part) (*CountTokensResponse, error) { - req, err := m.newCountTokensRequest(newUserContent(parts)) + req, err := m.newCountTokensRequest(NewUserContent(parts...)) if err != nil { return nil, err } diff --git a/genai/client_test.go b/genai/client_test.go index a129e6f..16eb929 100644 --- a/genai/client_test.go +++ b/genai/client_test.go @@ -814,7 +814,7 @@ func TestRecoverPanic(t *testing.T) { Response: map[string]any{"x": 1 + 2i}, // complex values are invalid } var m GenerativeModel - _, err := m.newGenerateContentRequest(newUserContent([]Part{fr})) + _, err := m.newGenerateContentRequest(NewUserContent(fr)) if err == nil { t.Fatal("got nil, want error") } diff --git a/genai/content.go b/genai/content.go index 59769d3..030d1ce 100644 --- a/genai/content.go +++ b/genai/content.go @@ -170,3 +170,13 @@ func (c *Candidate) FunctionCalls() []FunctionCall { } return fcs } + +// NewUserContent returns a *Content with a "user" role set and one or more +// parts. +func NewUserContent(parts ...Part) *Content { + content := &Content{Role: roleUser, Parts: []Part{}} + for _, part := range parts { + content.Parts = append(content.Parts, part) + } + return content +} diff --git a/genai/embed.go b/genai/embed.go index 95dbd45..cea90ec 100644 --- a/genai/embed.go +++ b/genai/embed.go @@ -65,7 +65,7 @@ func (m *EmbeddingModel) EmbedContentWithTitle(ctx context.Context, title string func newEmbedContentRequest(model string, tt TaskType, title string, parts []Part) *pb.EmbedContentRequest { req := &pb.EmbedContentRequest{ Model: model, - Content: newUserContent(parts).toProto(), + Content: NewUserContent(parts...).toProto(), } // A non-empty title overrides the task type. if title != "" { diff --git a/genai/example_test.go b/genai/example_test.go index cfff518..00b8d54 100644 --- a/genai/example_test.go +++ b/genai/example_test.go @@ -191,9 +191,7 @@ func ExampleGenerativeModel_GenerateContent_config() { model.SetTopP(0.5) model.SetTopK(20) model.SetMaxOutputTokens(100) - model.SystemInstruction = &genai.Content{ - Parts: []genai.Part{genai.Text("You are Yoda from Star Wars.")}, - } + model.SystemInstruction = genai.NewUserContent(genai.Text("You are Yoda from Star Wars.")) model.ResponseMIMEType = "application/json" resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?")) if err != nil { @@ -212,9 +210,7 @@ func ExampleGenerativeModel_GenerateContent_systemInstruction() { defer client.Close() model := client.GenerativeModel("gemini-1.5-flash") - model.SystemInstruction = &genai.Content{ - Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")}, - } + model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko.")) resp, err := model.GenerateContent(ctx, genai.Text("Good morning! How are you?")) if err != nil { log.Fatal(err) @@ -303,7 +299,6 @@ func ExampleGenerativeModel_GenerateContentStream() { } defer client.Close() - // START [text_gen_text_only_prompt_streaming] model := client.GenerativeModel("gemini-1.5-flash") iter := model.GenerateContentStream(ctx, genai.Text("Write a story about a magic backpack.")) for { @@ -316,7 +311,7 @@ func ExampleGenerativeModel_GenerateContentStream() { } printResponse(resp) } - // END [text_gen_text_only_prompt_streaming] + } func ExampleGenerativeModel_GenerateContentStream_imagePrompt() { @@ -327,7 +322,6 @@ func ExampleGenerativeModel_GenerateContentStream_imagePrompt() { } defer client.Close() - // START [text_gen_multimodal_one_image_prompt_streaming] model := client.GenerativeModel("gemini-1.5-flash") imgData, err := os.ReadFile(filepath.Join(testDataDir, "organ.jpg")) @@ -347,7 +341,7 @@ func ExampleGenerativeModel_GenerateContentStream_imagePrompt() { } printResponse(resp) } - // END [text_gen_multimodal_one_image_prompt_streaming] + } func ExampleGenerativeModel_GenerateContentStream_videoPrompt() { @@ -358,7 +352,6 @@ func ExampleGenerativeModel_GenerateContentStream_videoPrompt() { } defer client.Close() - // START [text_gen_multimodal_video_prompt_streaming] model := client.GenerativeModel("gemini-1.5-flash") file, err := uploadFile(ctx, client, filepath.Join(testDataDir, "earth.mp4"), "") @@ -380,7 +373,7 @@ func ExampleGenerativeModel_GenerateContentStream_videoPrompt() { } printResponse(resp) } - // END [text_gen_multimodal_video_prompt_streaming] + } func ExampleGenerativeModel_CountTokens_contextWindow() { @@ -447,7 +440,7 @@ func ExampleGenerativeModel_CountTokens_cachedContent() { txt := strings.Repeat("George Washington was the first president of the United States. ", 3000) argcc := &genai.CachedContent{ Model: "gemini-1.5-flash-001", - Contents: []*genai.Content{{Role: "user", Parts: []genai.Part{genai.Text(txt)}}}, + Contents: []*genai.Content{genai.NewUserContent(genai.Text(txt))}, } cc, err := client.CreateCachedContent(ctx, argcc) if err != nil { @@ -657,9 +650,7 @@ func ExampleGenerativeModel_CountTokens_systemInstruction() { // ( total_tokens: 10 ) // Same prompt, this time with system instruction - model.SystemInstruction = &genai.Content{ - Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")}, - } + model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko.")) respWithInstruction, err := model.CountTokens(ctx, genai.Text(prompt)) if err != nil { log.Fatal(err) @@ -1201,8 +1192,8 @@ func ExampleCachedContent_create() { argcc := &genai.CachedContent{ Model: "gemini-1.5-flash-001", - SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")), - Contents: []*genai.Content{userContent(fd)}, + SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")), + Contents: []*genai.Content{genai.NewUserContent(fd)}, } cc, err := client.CreateCachedContent(ctx, argcc) if err != nil { @@ -1238,7 +1229,7 @@ func ExampleCachedContent_createFromChat() { modelName := "gemini-1.5-flash-001" model := client.GenerativeModel(modelName) - model.SystemInstruction = userContent(genai.Text("You are an expert analyzing transcripts.")) + model.SystemInstruction = genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")) cs := model.StartChat() resp, err := cs.SendMessage(ctx, genai.Text("Hi, could you summarize this transcript?"), fd) @@ -1292,8 +1283,8 @@ func ExampleClient_GetCachedContent() { argcc := &genai.CachedContent{ Model: "gemini-1.5-flash-001", - SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")), - Contents: []*genai.Content{userContent(fd)}, + SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")), + Contents: []*genai.Content{genai.NewUserContent(fd)}, } cc, err := client.CreateCachedContent(ctx, argcc) if err != nil { @@ -1337,8 +1328,8 @@ func ExampleClient_ListCachedContents() { argcc := &genai.CachedContent{ Model: "gemini-1.5-flash-001", - SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")), - Contents: []*genai.Content{userContent(fd)}, + SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")), + Contents: []*genai.Content{genai.NewUserContent(fd)}, } cc, err := client.CreateCachedContent(ctx, argcc) if err != nil { @@ -1378,8 +1369,8 @@ func ExampleClient_UpdateCachedContent() { argcc := &genai.CachedContent{ Model: "gemini-1.5-flash-001", - SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")), - Contents: []*genai.Content{userContent(fd)}, + SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")), + Contents: []*genai.Content{genai.NewUserContent(fd)}, } cc, err := client.CreateCachedContent(ctx, argcc) if err != nil { @@ -1455,19 +1446,6 @@ func ExampleClient_setProxy() { printResponse(resp) } -// userContent helps create a *genai.Content with a "user" role and one or -// more parts with less verbosity. -func userContent(parts ...genai.Part) *genai.Content { - content := &genai.Content{ - Role: "user", - Parts: []genai.Part{}, - } - for _, part := range parts { - content.Parts = append(content.Parts, part) - } - return content -} - func printResponse(resp *genai.GenerateContentResponse) { for _, cand := range resp.Candidates { if cand.Content != nil { diff --git a/genai/internal/samples/docs-snippets_test.go b/genai/internal/samples/docs-snippets_test.go index 323c772..2254514 100644 --- a/genai/internal/samples/docs-snippets_test.go +++ b/genai/internal/samples/docs-snippets_test.go @@ -196,9 +196,7 @@ func ExampleGenerativeModel_GenerateContent_config() { model.SetTopP(0.5) model.SetTopK(20) model.SetMaxOutputTokens(100) - model.SystemInstruction = &genai.Content{ - Parts: []genai.Part{genai.Text("You are Yoda from Star Wars.")}, - } + model.SystemInstruction = genai.NewUserContent(genai.Text("You are Yoda from Star Wars.")) model.ResponseMIMEType = "application/json" resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?")) if err != nil { @@ -218,9 +216,7 @@ func ExampleGenerativeModel_GenerateContent_systemInstruction() { // [START system_instruction] model := client.GenerativeModel("gemini-1.5-flash") - model.SystemInstruction = &genai.Content{ - Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")}, - } + model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko.")) resp, err := model.GenerateContent(ctx, genai.Text("Good morning! How are you?")) if err != nil { log.Fatal(err) @@ -458,7 +454,7 @@ func ExampleGenerativeModel_CountTokens_cachedContent() { txt := strings.Repeat("George Washington was the first president of the United States. ", 3000) argcc := &genai.CachedContent{ Model: "gemini-1.5-flash-001", - Contents: []*genai.Content{{Role: "user", Parts: []genai.Part{genai.Text(txt)}}}, + Contents: []*genai.Content{genai.NewUserContent(genai.Text(txt))}, } cc, err := client.CreateCachedContent(ctx, argcc) if err != nil { @@ -674,9 +670,7 @@ func ExampleGenerativeModel_CountTokens_systemInstruction() { // ( total_tokens: 10 ) // Same prompt, this time with system instruction - model.SystemInstruction = &genai.Content{ - Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")}, - } + model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko.")) respWithInstruction, err := model.CountTokens(ctx, genai.Text(prompt)) if err != nil { log.Fatal(err) @@ -1235,8 +1229,8 @@ func ExampleCachedContent_create() { argcc := &genai.CachedContent{ Model: "gemini-1.5-flash-001", - SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")), - Contents: []*genai.Content{userContent(fd)}, + SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")), + Contents: []*genai.Content{genai.NewUserContent(fd)}, } cc, err := client.CreateCachedContent(ctx, argcc) if err != nil { @@ -1273,7 +1267,7 @@ func ExampleCachedContent_createFromChat() { modelName := "gemini-1.5-flash-001" model := client.GenerativeModel(modelName) - model.SystemInstruction = userContent(genai.Text("You are an expert analyzing transcripts.")) + model.SystemInstruction = genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")) cs := model.StartChat() resp, err := cs.SendMessage(ctx, genai.Text("Hi, could you summarize this transcript?"), fd) @@ -1330,8 +1324,8 @@ func ExampleClient_GetCachedContent() { argcc := &genai.CachedContent{ Model: "gemini-1.5-flash-001", - SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")), - Contents: []*genai.Content{userContent(fd)}, + SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")), + Contents: []*genai.Content{genai.NewUserContent(fd)}, } cc, err := client.CreateCachedContent(ctx, argcc) if err != nil { @@ -1377,8 +1371,8 @@ func ExampleClient_ListCachedContents() { argcc := &genai.CachedContent{ Model: "gemini-1.5-flash-001", - SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")), - Contents: []*genai.Content{userContent(fd)}, + SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")), + Contents: []*genai.Content{genai.NewUserContent(fd)}, } cc, err := client.CreateCachedContent(ctx, argcc) if err != nil { @@ -1419,8 +1413,8 @@ func ExampleClient_UpdateCachedContent() { argcc := &genai.CachedContent{ Model: "gemini-1.5-flash-001", - SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")), - Contents: []*genai.Content{userContent(fd)}, + SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")), + Contents: []*genai.Content{genai.NewUserContent(fd)}, } cc, err := client.CreateCachedContent(ctx, argcc) if err != nil { @@ -1496,19 +1490,6 @@ func ExampleClient_setProxy() { printResponse(resp) } -// userContent helps create a *genai.Content with a "user" role and one or -// more parts with less verbosity. -func userContent(parts ...genai.Part) *genai.Content { - content := &genai.Content{ - Role: "user", - Parts: []genai.Part{}, - } - for _, part := range parts { - content.Parts = append(content.Parts, part) - } - return content -} - func printResponse(resp *genai.GenerateContentResponse) { for _, cand := range resp.Candidates { if cand.Content != nil {