Skip to content

Commit

Permalink
genai: add helper function NewUserContent (#188)
Browse files Browse the repository at this point in the history
Turns out we already had an internal function like this. The exported
one is slightly more ergonomic due to `...`

I tried finding all places where we could benefit from using it
  • Loading branch information
eliben committed Jul 19, 2024
1 parent 7f8e8fe commit 70f6989
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 83 deletions.
4 changes: 2 additions & 2 deletions genai/caching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func testCaching(t *testing.T, client *Client) {
argcc := &CachedContent{
Model: model,
Expiration: ExpireTimeOrTTL{TTL: ttl},
Contents: []*Content{{Role: "user", Parts: parts}},
Contents: []*Content{NewUserContent(parts...)},
}
cc := must(client.CreateCachedContent(ctx, argcc))
compare(cc, wantExpireTime)
Expand Down Expand Up @@ -158,7 +158,7 @@ func testCaching(t *testing.T, client *Client) {
txt := strings.Repeat("George Washington was the first president of the United States. ", 3000)
argcc := &CachedContent{
Model: model,
Contents: []*Content{{Role: "user", Parts: []Part{Text(txt)}}},
Contents: []*Content{NewUserContent(Text(txt))},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions genai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (m *GenerativeModel) StartChat() *ChatSession {
// SendMessage sends a request to the model as part of a chat session.
func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) {
// Call the underlying client with the entire history plus the argument Content.
cs.History = append(cs.History, newUserContent(parts))
cs.History = append(cs.History, NewUserContent(parts...))
req, err := cs.m.newGenerateContentRequest(cs.History...)
if err != nil {
return nil, err
Expand All @@ -48,7 +48,7 @@ func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*Generat

// SendMessageStream is like SendMessage, but with a streaming request.
func (cs *ChatSession) SendMessageStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator {
cs.History = append(cs.History, newUserContent(parts))
cs.History = append(cs.History, NewUserContent(parts...))
req, err := cs.m.newGenerateContentRequest(cs.History...)
if err != nil {
return &GenerateContentResponseIterator{err: err}
Expand Down
10 changes: 3 additions & 7 deletions genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func fullModelName(name string) string {

// GenerateContent produces a single request and response.
func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) {
content := newUserContent(parts)
content := NewUserContent(parts...)
req, err := m.newGenerateContentRequest(content)
if err != nil {
return nil, err
Expand All @@ -194,7 +194,7 @@ func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*
// GenerateContentStream returns an iterator that enumerates responses.
func (m *GenerativeModel) GenerateContentStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator {
iter := &GenerateContentResponseIterator{}
req, err := m.newGenerateContentRequest(newUserContent(parts))
req, err := m.newGenerateContentRequest(NewUserContent(parts...))
if err != nil {
iter.err = err
} else {
Expand Down Expand Up @@ -241,10 +241,6 @@ func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) (*pb.G
})
}

func newUserContent(parts []Part) *Content {
return &Content{Role: roleUser, Parts: parts}
}

// GenerateContentResponseIterator is an iterator over GenerateContentResponse.
type GenerateContentResponseIterator struct {
sc pb.GenerativeService_StreamGenerateContentClient
Expand Down Expand Up @@ -313,7 +309,7 @@ func (iter *GenerateContentResponseIterator) MergedResponse() *GenerateContentRe

// CountTokens counts the number of tokens in the content.
func (m *GenerativeModel) CountTokens(ctx context.Context, parts ...Part) (*CountTokensResponse, error) {
req, err := m.newCountTokensRequest(newUserContent(parts))
req, err := m.newCountTokensRequest(NewUserContent(parts...))
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ func TestRecoverPanic(t *testing.T) {
Response: map[string]any{"x": 1 + 2i}, // complex values are invalid
}
var m GenerativeModel
_, err := m.newGenerateContentRequest(newUserContent([]Part{fr}))
_, err := m.newGenerateContentRequest(NewUserContent(fr))
if err == nil {
t.Fatal("got nil, want error")
}
Expand Down
10 changes: 10 additions & 0 deletions genai/content.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,13 @@ func (c *Candidate) FunctionCalls() []FunctionCall {
}
return fcs
}

// NewUserContent returns a *Content with a "user" role set and one or more
// parts.
func NewUserContent(parts ...Part) *Content {
	// Copy the variadic slice rather than storing it directly: a call of the
	// form NewUserContent(s...) passes s itself, and the returned Content
	// must not alias caller-owned storage. append to a fresh empty slice
	// also keeps Parts non-nil when no parts are given, matching the
	// original loop's behavior.
	return &Content{Role: roleUser, Parts: append([]Part{}, parts...)}
}
2 changes: 1 addition & 1 deletion genai/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (m *EmbeddingModel) EmbedContentWithTitle(ctx context.Context, title string
func newEmbedContentRequest(model string, tt TaskType, title string, parts []Part) *pb.EmbedContentRequest {
req := &pb.EmbedContentRequest{
Model: model,
Content: newUserContent(parts).toProto(),
Content: NewUserContent(parts...).toProto(),
}
// A non-empty title overrides the task type.
if title != "" {
Expand Down
54 changes: 16 additions & 38 deletions genai/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,7 @@ func ExampleGenerativeModel_GenerateContent_config() {
model.SetTopP(0.5)
model.SetTopK(20)
model.SetMaxOutputTokens(100)
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are Yoda from Star Wars.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are Yoda from Star Wars."))
model.ResponseMIMEType = "application/json"
resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?"))
if err != nil {
Expand All @@ -212,9 +210,7 @@ func ExampleGenerativeModel_GenerateContent_systemInstruction() {
defer client.Close()

model := client.GenerativeModel("gemini-1.5-flash")
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko."))
resp, err := model.GenerateContent(ctx, genai.Text("Good morning! How are you?"))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -303,7 +299,6 @@ func ExampleGenerativeModel_GenerateContentStream() {
}
defer client.Close()

// START [text_gen_text_only_prompt_streaming]
model := client.GenerativeModel("gemini-1.5-flash")
iter := model.GenerateContentStream(ctx, genai.Text("Write a story about a magic backpack."))
for {
Expand All @@ -316,7 +311,7 @@ func ExampleGenerativeModel_GenerateContentStream() {
}
printResponse(resp)
}
// END [text_gen_text_only_prompt_streaming]

}

func ExampleGenerativeModel_GenerateContentStream_imagePrompt() {
Expand All @@ -327,7 +322,6 @@ func ExampleGenerativeModel_GenerateContentStream_imagePrompt() {
}
defer client.Close()

// START [text_gen_multimodal_one_image_prompt_streaming]
model := client.GenerativeModel("gemini-1.5-flash")

imgData, err := os.ReadFile(filepath.Join(testDataDir, "organ.jpg"))
Expand All @@ -347,7 +341,7 @@ func ExampleGenerativeModel_GenerateContentStream_imagePrompt() {
}
printResponse(resp)
}
// END [text_gen_multimodal_one_image_prompt_streaming]

}

func ExampleGenerativeModel_GenerateContentStream_videoPrompt() {
Expand All @@ -358,7 +352,6 @@ func ExampleGenerativeModel_GenerateContentStream_videoPrompt() {
}
defer client.Close()

// START [text_gen_multimodal_video_prompt_streaming]
model := client.GenerativeModel("gemini-1.5-flash")

file, err := uploadFile(ctx, client, filepath.Join(testDataDir, "earth.mp4"), "")
Expand All @@ -380,7 +373,7 @@ func ExampleGenerativeModel_GenerateContentStream_videoPrompt() {
}
printResponse(resp)
}
// END [text_gen_multimodal_video_prompt_streaming]

}

func ExampleGenerativeModel_CountTokens_contextWindow() {
Expand Down Expand Up @@ -447,7 +440,7 @@ func ExampleGenerativeModel_CountTokens_cachedContent() {
txt := strings.Repeat("George Washington was the first president of the United States. ", 3000)
argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
Contents: []*genai.Content{{Role: "user", Parts: []genai.Part{genai.Text(txt)}}},
Contents: []*genai.Content{genai.NewUserContent(genai.Text(txt))},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -657,9 +650,7 @@ func ExampleGenerativeModel_CountTokens_systemInstruction() {
// ( total_tokens: 10 )

// Same prompt, this time with system instruction
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko."))
respWithInstruction, err := model.CountTokens(ctx, genai.Text(prompt))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -1201,8 +1192,8 @@ func ExampleCachedContent_create() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1238,7 +1229,7 @@ func ExampleCachedContent_createFromChat() {

modelName := "gemini-1.5-flash-001"
model := client.GenerativeModel(modelName)
model.SystemInstruction = userContent(genai.Text("You are an expert analyzing transcripts."))
model.SystemInstruction = genai.NewUserContent(genai.Text("You are an expert analyzing transcripts."))

cs := model.StartChat()
resp, err := cs.SendMessage(ctx, genai.Text("Hi, could you summarize this transcript?"), fd)
Expand Down Expand Up @@ -1292,8 +1283,8 @@ func ExampleClient_GetCachedContent() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1337,8 +1328,8 @@ func ExampleClient_ListCachedContents() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1378,8 +1369,8 @@ func ExampleClient_UpdateCachedContent() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1455,19 +1446,6 @@ func ExampleClient_setProxy() {
printResponse(resp)
}

// userContent helps create a *genai.Content with a "user" role and one or
// more parts with less verbosity.
func userContent(parts ...genai.Part) *genai.Content {
	// Start from a non-nil empty slice so a zero-part call still yields
	// Parts: []genai.Part{}, then copy all parts in one append.
	c := &genai.Content{Role: "user", Parts: []genai.Part{}}
	c.Parts = append(c.Parts, parts...)
	return c
}

func printResponse(resp *genai.GenerateContentResponse) {
for _, cand := range resp.Candidates {
if cand.Content != nil {
Expand Down
45 changes: 13 additions & 32 deletions genai/internal/samples/docs-snippets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ func ExampleGenerativeModel_GenerateContent_config() {
model.SetTopP(0.5)
model.SetTopK(20)
model.SetMaxOutputTokens(100)
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are Yoda from Star Wars.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are Yoda from Star Wars."))
model.ResponseMIMEType = "application/json"
resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?"))
if err != nil {
Expand All @@ -218,9 +216,7 @@ func ExampleGenerativeModel_GenerateContent_systemInstruction() {

// [START system_instruction]
model := client.GenerativeModel("gemini-1.5-flash")
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko."))
resp, err := model.GenerateContent(ctx, genai.Text("Good morning! How are you?"))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -458,7 +454,7 @@ func ExampleGenerativeModel_CountTokens_cachedContent() {
txt := strings.Repeat("George Washington was the first president of the United States. ", 3000)
argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
Contents: []*genai.Content{{Role: "user", Parts: []genai.Part{genai.Text(txt)}}},
Contents: []*genai.Content{genai.NewUserContent(genai.Text(txt))},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -674,9 +670,7 @@ func ExampleGenerativeModel_CountTokens_systemInstruction() {
// ( total_tokens: 10 )

// Same prompt, this time with system instruction
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko."))
respWithInstruction, err := model.CountTokens(ctx, genai.Text(prompt))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -1235,8 +1229,8 @@ func ExampleCachedContent_create() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1273,7 +1267,7 @@ func ExampleCachedContent_createFromChat() {

modelName := "gemini-1.5-flash-001"
model := client.GenerativeModel(modelName)
model.SystemInstruction = userContent(genai.Text("You are an expert analyzing transcripts."))
model.SystemInstruction = genai.NewUserContent(genai.Text("You are an expert analyzing transcripts."))

cs := model.StartChat()
resp, err := cs.SendMessage(ctx, genai.Text("Hi, could you summarize this transcript?"), fd)
Expand Down Expand Up @@ -1330,8 +1324,8 @@ func ExampleClient_GetCachedContent() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1377,8 +1371,8 @@ func ExampleClient_ListCachedContents() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1419,8 +1413,8 @@ func ExampleClient_UpdateCachedContent() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1496,19 +1490,6 @@ func ExampleClient_setProxy() {
printResponse(resp)
}

// userContent helps create a *genai.Content with a "user" role and one or
// more parts with less verbosity.
func userContent(parts ...genai.Part) *genai.Content {
	// Pre-size a fresh slice (non-nil even for zero parts) and copy the
	// variadic arguments into it so the Content owns its own storage.
	collected := make([]genai.Part, 0, len(parts))
	collected = append(collected, parts...)
	return &genai.Content{Role: "user", Parts: collected}
}

func printResponse(resp *genai.GenerateContentResponse) {
for _, cand := range resp.Candidates {
if cand.Content != nil {
Expand Down

0 comments on commit 70f6989

Please sign in to comment.