Skip to content

Commit

Permalink
genai: get Live tests to pass (#141)
Browse files Browse the repository at this point in the history
Skip or fix tests to conform to current model behavior.

Skipping the blocking test because I can't find a prompt that
gets blocked.
  • Loading branch information
jba authored Jun 23, 2024
1 parent 6b2b0ac commit 26ef0cd
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func TestLive(t *testing.T) {

t.Run("streaming-counting", func(t *testing.T) {
// Verify only that we don't crash. See #18.
iter := model.GenerateContentStream(ctx, Text("count 1 to 100."))
iter := model.GenerateContentStream(ctx, Text("count 1 to 10."))
_ = responsesString(t, iter)
})
t.Run("streaming-error", func(t *testing.T) {
Expand Down Expand Up @@ -150,8 +150,9 @@ func TestLive(t *testing.T) {
})

t.Run("blocked", func(t *testing.T) {
t.Skip("skipping until we find a prompt that is blocked")
// Only happens with streaming at the moment.
iter := model.GenerateContentStream(ctx, Text("How do I make a bomb?"))
iter := model.GenerateContentStream(ctx, Text("???"))
resps, err := all(iter)
if err == nil {
for _, r := range resps {
Expand All @@ -174,38 +175,46 @@ func TestLive(t *testing.T) {
}
})
t.Run("max-tokens", func(t *testing.T) {
// Verify that setting max output tokens truncates the response.
// (It does not result in FinishReasonMaxTokens.)
maxModel := client.GenerativeModel(defaultModel)
maxModel.Temperature = Ptr(float32(0))
maxModel.SetMaxOutputTokens(10)
maxModel.SetMaxOutputTokens(3)
res, err := maxModel.GenerateContent(ctx, Text("What is a dog?"))
if err != nil {
t.Fatal(err)
}
got := res.Candidates[0].FinishReason
want := FinishReasonMaxTokens
if got != want && got != FinishReasonOther { // TODO: should not need FinishReasonOther
t.Errorf("got %s, want %s", got, want)
if got, want := responseString(res), "A dog is"; got != want {
t.Errorf("got %q, want %q", got, want)
}
gotr := res.Candidates[0].FinishReason
wantr := FinishReasonStop
if gotr != wantr {
t.Errorf("got %s, want %s", gotr, wantr)
}
})
t.Run("max-tokens-streaming", func(t *testing.T) {
maxModel := client.GenerativeModel(defaultModel)
maxModel.Temperature = Ptr[float32](0)
maxModel.MaxOutputTokens = Ptr[int32](10)
maxModel.MaxOutputTokens = Ptr[int32](3)
iter := maxModel.GenerateContentStream(ctx, Text("What is a dog?"))
var merged *GenerateContentResponse
for {
res, err := iter.Next()
_, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
t.Fatal(err)
}
merged = joinResponses(merged, res)
}
want := FinishReasonMaxTokens
if got := merged.Candidates[0].FinishReason; got != want && got != FinishReasonOther { // TODO: see above
t.Errorf("got %s, want %s", got, want)
res := iter.MergedResponse()
if got, want := responseString(res), "A dog is"; got != want {
t.Errorf("got %q, want %q", got, want)
}
gotr := res.Candidates[0].FinishReason
wantr := FinishReasonStop
if gotr != wantr {
t.Errorf("got %s, want %s", gotr, wantr)
}
})
t.Run("count-tokens", func(t *testing.T) {
Expand Down

0 comments on commit 26ef0cd

Please sign in to comment.