Skip to content

Commit

Permalink
Working around a bug in Azure when doing completion/chatcompletion st…
Browse files Browse the repository at this point in the history
…reaming.
  • Loading branch information
Richard Park committed Jun 30, 2023
1 parent 1d3d3e1 commit 206bcca
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
12 changes: 9 additions & 3 deletions sdk/cognitiveservices/azopenai/client_chat_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestClient_GetChatCompletionsStream(t *testing.T) {
chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, chatCompletionsModelDeployment, newClientOptionsForTest(t))
require.NoError(t, err)

testGetChatCompletionsStream(t, chatClient)
testGetChatCompletionsStream(t, chatClient, true)
}

func TestClient_OpenAI_GetChatCompletions(t *testing.T) {
Expand All @@ -64,7 +64,7 @@ func TestClient_OpenAI_GetChatCompletions(t *testing.T) {

func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testGetChatCompletionsStream(t, chatClient)
testGetChatCompletionsStream(t, chatClient, false)
}

func testGetChatCompletions(t *testing.T, client *azopenai.Client) {
Expand Down Expand Up @@ -100,10 +100,16 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client) {
require.Equal(t, expected, resp.ChatCompletions)
}

func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client) {
func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, isAzure bool) {
streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil)
require.NoError(t, err)

if isAzure {
// there's a bug right now where the first event comes back empty
_, err := streamResp.ChatCompletionsStream.Read()
require.NoError(t, err)
}

// the data comes back differently for streaming
// 1. the text comes back in the ChatCompletion.Delta field
// 2. the role is only sent on the first streamed ChatCompletion
Expand Down
3 changes: 2 additions & 1 deletion sdk/cognitiveservices/azopenai/client_embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -67,7 +68,7 @@ func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentI
ctx: context.TODO(),
deploymentID: modelOrDeploymentID,
body: azopenai.EmbeddingsOptions{
Input: []byte("\"Your text string goes here\""),
Input: []*string{to.Ptr("\"Your text string goes here\"")},
Model: &modelOrDeploymentID,
},
options: nil,
Expand Down
17 changes: 9 additions & 8 deletions sdk/cognitiveservices/azopenai/custom_client_image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

package azopenai
package azopenai_test

import (
"context"
Expand All @@ -13,14 +13,15 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai"
"github.com/stretchr/testify/require"
)

func TestImageGeneration_AzureOpenAI(t *testing.T) {
cred, err := NewKeyCredential(apiKey)
cred, err := azopenai.NewKeyCredential(apiKey)
require.NoError(t, err)

client, err := NewClientWithKeyCredential(endpoint, cred, streamingModelDeployment, newClientOptionsForTest(t))
client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, "", newClientOptionsForTest(t))
require.NoError(t, err)

testImageGeneration(t, client)
Expand All @@ -31,18 +32,18 @@ func TestImageGeneration_OpenAI(t *testing.T) {
testImageGeneration(t, client)
}

func testImageGeneration(t *testing.T, client *Client) {
func testImageGeneration(t *testing.T, client *azopenai.Client) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

resp, err := client.CreateImage(ctx, ImageGenerationOptions{
resp, err := client.CreateImage(ctx, azopenai.ImageGenerationOptions{
Prompt: to.Ptr("a cat"),
Size: to.Ptr(ImageSize256x256),
ResponseFormat: to.Ptr(ImageGenerationResponseFormatURL),
Size: to.Ptr(azopenai.ImageSize256x256),
ResponseFormat: to.Ptr(azopenai.ImageGenerationResponseFormatURL),
}, nil)
require.NoError(t, err)

headResp, err := http.DefaultClient.Head(*resp.Data[0].Result.(ImageLocation).URL)
headResp, err := http.DefaultClient.Head(*resp.Data[0].Result.(azopenai.ImageLocation).URL)
require.NoError(t, err)
require.Equal(t, http.StatusOK, headResp.StatusCode)
}

0 comments on commit 206bcca

Please sign in to comment.