[azopenai] Errors weren't propagating properly in image generation for OpenAI (#21125)

Hand-written code has to check the HTTP status and return an *azcore.ResponseError itself. This change adds that check to image generation and adds tests for all the areas that have hand-written code (ChatCompletions and Completions streaming, and the Dall-E integration with OpenAI).
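
For reference, the pattern the hand-written paths now follow is roughly the sketch below, built on azcore's runtime helpers (runtime.HasStatusCode, runtime.NewResponseError). The function name and surrounding plumbing are illustrative placeholders, not the SDK's actual internals.

```go
package example

import (
	"net/http"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
)

// doHandwrittenCall sketches the error-propagation pattern: after the pipeline
// round trip, a non-200 status is converted into an *azcore.ResponseError
// instead of being unmarshaled as if it were a success.
func doHandwrittenCall(pl runtime.Pipeline, req *policy.Request) error {
	resp, err := pl.Do(req)
	if err != nil {
		return err
	}

	if !runtime.HasStatusCode(resp, http.StatusOK) {
		// Callers can unwrap this error with errors.As, as the new tests do.
		return runtime.NewResponseError(resp)
	}

	// On success the body would be unmarshaled here, e.g. with runtime.UnmarshalAsJSON.
	return nil
}
```

The change to custom_client_image.go below applies exactly this check inside generateImageWithOpenAI.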

Fixes #21120
richardpark-msft authored Jul 11, 2023
1 parent c82eb8a commit 7de093f
Showing 6 changed files with 135 additions and 17 deletions.
2 changes: 1 addition & 1 deletion sdk/cognitiveservices/azopenai/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/cognitiveservices/azopenai",
"Tag": "go/cognitiveservices/azopenai_25f5951837"
"Tag": "go/cognitiveservices/azopenai_2b6f93a94d"
}
25 changes: 24 additions & 1 deletion sdk/cognitiveservices/azopenai/client_chat_completions_test.go
@@ -31,7 +31,7 @@ var chatCompletionsRequest = azopenai.ChatCompletionsOptions{
},
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
- Model: &openAIChatCompletionsModelDeployment,
+ Model: &openAIChatCompletionsModel,
}

var expectedContent = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10."
@@ -192,3 +192,26 @@ func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
require.ErrorAs(t, err, &respErr)
require.Equal(t, "DeploymentNotFound", respErr.ErrorCode)
}

func TestClient_GetChatCompletionsStream_Error(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skip()
}

doTest := func(t *testing.T, client *azopenai.Client) {
t.Helper()
streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil)
require.Empty(t, streamResp)
assertResponseIsError(t, err)
}

t.Run("AzureOpenAI", func(t *testing.T) {
client := newBogusAzureOpenAIClient(t, chatCompletionsModelDeployment)
doTest(t, client)
})

t.Run("OpenAI", func(t *testing.T) {
client := newBogusOpenAIClient(t)
doTest(t, client)
})
}
52 changes: 44 additions & 8 deletions sdk/cognitiveservices/azopenai/client_shared_test.go
@@ -12,6 +12,7 @@ import (
"strings"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
@@ -25,10 +26,10 @@
completionsModelDeployment string // env: AOAI_COMPLETIONS_MODEL_DEPLOYMENT
chatCompletionsModelDeployment string // env: AOAI_CHAT_COMPLETIONS_MODEL_DEPLOYMENT

- openAIKey string // env: OPENAI_API_KEY
- openAIEndpoint string // env: OPENAI_ENDPOINT
- openAICompletionsModelDeployment string // env: OPENAI_CHAT_COMPLETIONS_MODEL
- openAIChatCompletionsModelDeployment string // env: OPENAI_COMPLETIONS_MODEL
+ openAIKey string // env: OPENAI_API_KEY
+ openAIEndpoint string // env: OPENAI_ENDPOINT
+ openAICompletionsModel string // env: OPENAI_CHAT_COMPLETIONS_MODEL
+ openAIChatCompletionsModel string // env: OPENAI_COMPLETIONS_MODEL
)

const fakeEndpoint = "https://recordedhost/"
@@ -42,10 +43,10 @@ func init() {
openAIEndpoint = fakeEndpoint

completionsModelDeployment = "text-davinci-003"
- openAICompletionsModelDeployment = "text-davinci-003"
+ openAICompletionsModel = "text-davinci-003"

chatCompletionsModelDeployment = "gpt-4"
- openAIChatCompletionsModelDeployment = "gpt-4"
+ openAIChatCompletionsModel = "gpt-4"
} else {
if err := godotenv.Load(); err != nil {
fmt.Printf("Failed to load .env file: %s\n", err)
@@ -67,8 +68,8 @@ func init() {

openAIKey = os.Getenv("OPENAI_API_KEY")
openAIEndpoint = os.Getenv("OPENAI_ENDPOINT")
- openAICompletionsModelDeployment = os.Getenv("OPENAI_COMPLETIONS_MODEL")
- openAIChatCompletionsModelDeployment = os.Getenv("OPENAI_CHAT_COMPLETIONS_MODEL")
+ openAICompletionsModel = os.Getenv("OPENAI_COMPLETIONS_MODEL")
+ openAIChatCompletionsModel = os.Getenv("OPENAI_CHAT_COMPLETIONS_MODEL")

if openAIEndpoint != "" && !strings.HasSuffix(openAIEndpoint, "/") {
// (this just makes recording replacement easier)
@@ -88,6 +89,9 @@ func newRecordingTransporter(t *testing.T) policy.Transporter {
err = recording.AddHeaderRegexSanitizer("Api-Key", fakeAPIKey, "", nil)
require.NoError(t, err)

err = recording.AddHeaderRegexSanitizer("User-Agent", "fake-user-agent", ".*", nil)
require.NoError(t, err)

// "RequestUri": "https://openai-shared.openai.azure.com/openai/deployments/text-davinci-003/completions?api-version=2023-03-15-preview",
err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(endpoint), nil)
require.NoError(t, err)
@@ -138,3 +142,35 @@ func newClientOptionsForTest(t *testing.T) *azopenai.ClientOptions {

return co
}

// newBogusAzureOpenAIClient creates a client that uses an invalid key, which will cause Azure OpenAI to return
// a failure.
func newBogusAzureOpenAIClient(t *testing.T, modelDeploymentID string) *azopenai.Client {
cred, err := azopenai.NewKeyCredential("bogus-api-key")
require.NoError(t, err)

client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, modelDeploymentID, newClientOptionsForTest(t))
require.NoError(t, err)
return client
}

// newBogusOpenAIClient creates a client that uses an invalid key, which will cause OpenAI to return
// a failure.
func newBogusOpenAIClient(t *testing.T) *azopenai.Client {
cred, err := azopenai.NewKeyCredential("bogus-api-key")
require.NoError(t, err)

client, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)
return client
}

func assertResponseIsError(t *testing.T, err error) {
t.Helper()

var respErr *azcore.ResponseError
require.ErrorAs(t, err, &respErr)

// we sometimes get rate limited but (for this kind of test) it's actually okay
require.Truef(t, respErr.StatusCode == http.StatusUnauthorized || respErr.StatusCode == http.StatusTooManyRequests, "An acceptable error comes back (actual: %d)", respErr.StatusCode)
}
4 changes: 4 additions & 0 deletions sdk/cognitiveservices/azopenai/custom_client_image.go
@@ -72,6 +72,10 @@ func generateImageWithOpenAI(ctx context.Context, client *Client, body ImageGene
return CreateImageResponse{}, err
}

if !runtime.HasStatusCode(resp, http.StatusOK) {
return CreateImageResponse{}, runtime.NewResponseError(resp)
}

var gens *ImageGenerations

if err := runtime.UnmarshalAsJSON(resp, &gens); err != nil {
32 changes: 32 additions & 0 deletions sdk/cognitiveservices/azopenai/custom_client_image_test.go
@@ -40,6 +40,24 @@ func TestImageGeneration_OpenAI(t *testing.T) {
testImageGeneration(t, client, azopenai.ImageGenerationResponseFormatURL)
}

func TestImageGeneration_AzureOpenAI_WithError(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skip()
}

client := newBogusAzureOpenAIClient(t, "")
testImageGenerationFailure(t, client)
}

func TestImageGeneration_OpenAI_WithError(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skip()
}

client := newBogusOpenAIClient(t)
testImageGenerationFailure(t, client)
}

func TestImageGeneration_OpenAI_Base64(t *testing.T) {
client := newOpenAIClientForTest(t)
testImageGeneration(t, client, azopenai.ImageGenerationResponseFormatB64JSON)
@@ -76,3 +94,17 @@ func testImageGeneration(t *testing.T, client *azopenai.Client, responseFormat a
}
}
}

func testImageGenerationFailure(t *testing.T, bogusClient *azopenai.Client) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

resp, err := bogusClient.CreateImage(ctx, azopenai.ImageGenerationOptions{
Prompt: to.Ptr("a cat"),
Size: to.Ptr(azopenai.ImageSize256x256),
ResponseFormat: to.Ptr(azopenai.ImageGenerationResponseFormatURL),
}, nil)
require.Empty(t, resp)

assertResponseIsError(t, err)
}
37 changes: 30 additions & 7 deletions sdk/cognitiveservices/azopenai/custom_client_test.go
@@ -16,6 +16,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
"github.com/stretchr/testify/require"
)

@@ -88,12 +89,7 @@ func TestGetCompletionsStream_AzureOpenAI(t *testing.T) {
}

func TestGetCompletionsStream_OpenAI(t *testing.T) {
- cred, err := azopenai.NewKeyCredential(openAIKey)
- require.NoError(t, err)
-
- client, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t))
- require.NoError(t, err)
-
+ client := newOpenAIClientForTest(t)
testGetCompletionsStream(t, client, false)
}

@@ -102,7 +98,7 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, isAzure boo
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048)),
Temperature: to.Ptr(float32(0.0)),
- Model: to.Ptr(openAICompletionsModelDeployment),
+ Model: to.Ptr(openAICompletionsModel),
}

response, err := client.GetCompletionsStream(context.TODO(), body, nil)
@@ -142,3 +138,30 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, isAzure boo
require.Equal(t, want, got)
require.Equal(t, 86, eventCount)
}

func TestClient_GetCompletions_Error(t *testing.T) {
if recording.GetRecordMode() == recording.PlaybackMode {
t.Skip()
}

doTest := func(t *testing.T, client *azopenai.Client) {
streamResp, err := client.GetCompletionsStream(context.Background(), azopenai.CompletionsOptions{
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
Model: &openAICompletionsModel,
}, nil)
require.Empty(t, streamResp)
assertResponseIsError(t, err)
}

t.Run("AzureOpenAI", func(t *testing.T) {
client := newBogusAzureOpenAIClient(t, completionsModelDeployment)
doTest(t, client)
})

t.Run("OpenAI", func(t *testing.T) {
client := newBogusOpenAIClient(t)
doTest(t, client)
})
}
