Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[azopenai] Updating to the 2023-07-01 API surface #21169

Merged
merged 8 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/cognitiveservices/azopenai/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/cognitiveservices/azopenai",
"Tag": "go/cognitiveservices/azopenai_2b6f93a94d"
"Tag": "go/cognitiveservices/azopenai_63852f374c"
}
60 changes: 52 additions & 8 deletions sdk/cognitiveservices/azopenai/autorest.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ go: true
use: "@autorest/go@4.0.0-preview.52"
title: "OpenAI"
slice-elements-byval: true
remove-non-reference-schema: true
# can't use this since it removes an innererror type that we want ()
# remove-non-reference-schema: true
```

## Transformations
Expand Down Expand Up @@ -81,7 +82,7 @@ directive:
where: $.components.schemas["ImageOperation"].properties.status
transform: $["$ref"] = $.anyOf[0]["$ref"];delete $.anyOf;
- from: openapi-document
where: $.components.schemas["ImageGenerationOptions"].properties
where: $.components.schemas.ImageGenerationOptions.properties
transform: |
$.size["$ref"] = "#/components/schemas/ImageSize"; delete $.allOf;
$.response_format["$ref"] = "#/components/schemas/ImageGenerationResponseFormat"; delete $.allOf;
Expand All @@ -93,11 +94,12 @@ directive:
- from: openapi-document
where: $.components.schemas["ImageOperationStatus"].properties.status
transform: $["$ref"] = "#/components/schemas/State"; delete $.allOf;
- from: openapi-document
where: $.components.schemas["ContentFilterResult"].properties.severity
transform: $["$ref"] = "#/components/schemas/ContentFilterSeverity"; delete $.allOf;
- from: openapi-document
where: $.components.schemas["ChatChoice"].properties.finish_reason
transform: >
delete $.oneOf;
$["$ref"] = "#/components/schemas/CompletionsFinishReason";
transform: $["$ref"] = "#/components/schemas/CompletionsFinishReason"; delete $.oneOf;
# Fix "AutoGenerated" models
- from: openapi-document
where: $.components.schemas["ChatCompletions"].properties.usage
Expand Down Expand Up @@ -163,7 +165,7 @@ directive:
- client.go
- models.go
- options.go
- response_types.go
- response_types.go
where: $
transform: return $.replace(/Client(\w+)((?:Options|Response))/g, "$1$2");

Expand All @@ -172,10 +174,19 @@ directive:
where: $
transform: return $.replace(/runtime\.JoinPaths\(client.endpoint, urlPath\)/g, "client.formatURL(urlPath)");

# Some ImageGenerations hackery to represent the ImageLocation/ImagePayload polymorphism.
# - Remove the auto-generated ImageGenerationsDataItem.
# - Replace the ImageGenerations.Data type with []ImageGenerationDataItem
# - from: models.go
# where: $
# transform: |
# return $.replace(/type ImageGenerationsDataItem struct {[^}]+}/, "// ImageGenerationsDataItem represents an image URL or payload\ntype ImageGenerationsDataItem struct{\nImageLocation\nImagePayload\n}")
# $.replace(/(type ImageGenerations struct.+?)Data any/g, "$1Data []ImageGenerationsDataItem")

- from: models.go
where: $
transform: |
return $.replace(/type ImageGenerationsDataItem struct {[^}]+}/, "// ImageGenerationsDataItem represents an image URL or payload\ntype ImageGenerationsDataItem struct{\nImageLocation\nImagePayload\n}");
return $.replace(/(type ImageGenerations struct.+?)Data any/sg, "$1Data []ImageGenerationsDataItem")

# delete the auto-generated ImageGenerationsDataItem, we handle that custom
- from: models.go
Expand Down Expand Up @@ -218,6 +229,17 @@ directive:
.replace(/BeginAzureBatchImageGenerationInternal/g, "beginAzureBatchImageGeneration")
.replace(/BatchImageGenerationOperationResponse/g, "batchImageGenerationOperationResponse");

# BUG: ChatCompletionsOptionsFunctionCall is another one of those "here's mutually exclusive values" options...
- from:
- models.go
- models_serde.go
where: $
transform: |
return $
.replace(/populateAny\(objectMap, "function_call", c.FunctionCall\)/, 'populate(objectMap, "function_call", c.FunctionCall)')
.replace(/\/\/ ChatCompletionsOptionsFunctionCall.+?\n}/, "")
.replace(/FunctionCall any/, "FunctionCall *ChatCompletionsOptionsFunctionCall");

# fix some casing
- from:
- client.go
Expand All @@ -228,8 +250,30 @@ directive:
where: $
transform: return $.replace(/Logprobs/g, "LogProbs")

# remove PossibleazureOpenAIOperationStateValues, since we don't expose the poller
# delete ContentFilterResult in favor of our custom representation.
- from:
- models.go
- models_serde.go
where: $
transform: |
return $.replace(/\/\/ ContentFilterResult.+?\n}/s, "")
.replace(/\/\/ MarshalJSON implements the json.Marshaller interface for type ContentFilterResult.+?\n}/s, "")
.replace(/\/\/ UnmarshalJSON implements the json.Unmarshaller interface for type ContentFilterResult.+?\n}/s, "");

- from: constants.go
where: $
transform: return $.replace(/\/\/ PossibleazureOpenAIOperationStateValues returns.+?\n}/s, "");

# fix incorrect property name for content filtering
# TODO: I imagine we should able to fix this in the tsp?
- from: models_serde.go
where: $
transform: |
return $
.replace(/ case "selfHarm":/g, ' case "self_harm":')
.replace(/populate\(objectMap, "selfHarm", c.SelfHarm\)/g, 'populate(objectMap, "self_harm", c.SelfHarm)');

- from: client.go
where: $
transform: return $.replace(/runtime\.NewResponseError/sg, "client.newError");
```
26 changes: 13 additions & 13 deletions sdk/cognitiveservices/azopenai/client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 37 additions & 21 deletions sdk/cognitiveservices/azopenai/client_chat_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,33 @@ func TestClient_GetChatCompletions(t *testing.T) {
chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, chatCompletionsModelDeployment, newClientOptionsForTest(t))
require.NoError(t, err)

testGetChatCompletions(t, chatClient)
testGetChatCompletions(t, chatClient, true)
}

func TestClient_GetChatCompletionsStream(t *testing.T) {
cred, err := azopenai.NewKeyCredential(apiKey)
require.NoError(t, err)

chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, chatCompletionsModelDeployment, newClientOptionsForTest(t))
require.NoError(t, err)

testGetChatCompletionsStream(t, chatClient, true)
chatClient := newAzureOpenAIClientForTest(t, canaryChatCompletionsModelDeployment, true)
testGetChatCompletionsStream(t, chatClient)
}

func TestClient_OpenAI_GetChatCompletions(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}

chatClient := newOpenAIClientForTest(t)
testGetChatCompletions(t, chatClient)
testGetChatCompletions(t, chatClient, false)
}

func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}

chatClient := newOpenAIClientForTest(t)
testGetChatCompletionsStream(t, chatClient, false)
testGetChatCompletionsStream(t, chatClient)
}

func testGetChatCompletions(t *testing.T, client *azopenai.Client) {
func testGetChatCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
expected := azopenai.ChatCompletions{
Choices: []azopenai.ChatChoice{
{
Expand All @@ -91,6 +94,15 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client) {
resp, err := client.GetChatCompletions(context.Background(), chatCompletionsRequest, nil)
require.NoError(t, err)

if isAzure {
// Azure also provides content-filtering. This particular prompt and responses
// will be considered safe.
expected.PromptAnnotations = []azopenai.PromptFilterResult{
{PromptIndex: to.Ptr[int32](0), ContentFilterResults: (*azopenai.PromptFilterResultContentFilterResults)(safeContentFilter)},
}
expected.Choices[0].ContentFilterResults = safeContentFilter
}

require.NotEmpty(t, resp.ID)
require.NotEmpty(t, resp.Created)

Expand All @@ -100,17 +112,10 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client) {
require.Equal(t, expected, resp.ChatCompletions)
}

func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, isAzure bool) {
func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client) {
streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil)
require.NoError(t, err)

if isAzure {
// there's a bug right now where the first event comes back empty
// Issue: https://github.com/Azure/azure-sdk-for-go/issues/21086
_, err := streamResp.ChatCompletionsStream.Read()
require.NoError(t, err)
}

// the data comes back differently for streaming
// 1. the text comes back in the ChatCompletion.Delta field
// 2. the role is only sent on the first streamed ChatCompletion
Expand All @@ -125,6 +130,18 @@ func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, isAzure
}

require.NoError(t, err)

if completion.PromptAnnotations != nil {
require.Equal(t, []azopenai.PromptFilterResult{
{PromptIndex: to.Ptr[int32](0), ContentFilterResults: (*azopenai.PromptFilterResultContentFilterResults)(safeContentFilter)},
}, completion.PromptAnnotations)
}

if len(completion.Choices) == 0 {
// you can get empty entries that contain just metadata (ie, prompt annotations)
continue
}

require.Equal(t, 1, len(completion.Choices))
choices = append(choices, completion.Choices[0])
}
Expand All @@ -140,7 +157,6 @@ func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, isAzure
}

require.Equal(t, expectedContent, message, "Ultimately, the same result as GetChatCompletions(), just sent across the .Delta field instead")

require.Equal(t, azopenai.ChatRoleAssistant, expectedRole)
}

Expand All @@ -167,7 +183,7 @@ func TestClient_GetChatCompletions_DefaultAzureCredential(t *testing.T) {
})
require.NoError(t, err)

testGetChatCompletions(t, chatClient)
testGetChatCompletions(t, chatClient, true)
}

func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
Expand Down
Loading