diff --git a/eng/config.json b/eng/config.json index 18dbd410be92..d04f536773cd 100644 --- a/eng/config.json +++ b/eng/config.json @@ -36,6 +36,10 @@ "Name": "azfile", "CoverageGoal": 0.75 }, + { + "Name": "azopenai", + "CoverageGoal": 0.45 + }, { "Name": "aztemplate", "CoverageGoal": 0.50 diff --git a/sdk/cognitiveservices/azopenai/CHANGELOG.md b/sdk/cognitiveservices/azopenai/CHANGELOG.md new file mode 100644 index 000000000000..c00ce1cf1381 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/CHANGELOG.md @@ -0,0 +1,5 @@ +# Release History + +## 0.1.0 (unreleased) + +* Initial release of the `azopenai` library diff --git a/sdk/cognitiveservices/azopenai/LICENSE.txt b/sdk/cognitiveservices/azopenai/LICENSE.txt new file mode 100644 index 000000000000..ec703274aadd --- /dev/null +++ b/sdk/cognitiveservices/azopenai/LICENSE.txt @@ -0,0 +1,21 @@ + MIT License + +Copyright (c) Microsoft Corporation. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE \ No newline at end of file diff --git a/sdk/cognitiveservices/azopenai/README.md b/sdk/cognitiveservices/azopenai/README.md new file mode 100644 index 000000000000..8b23a1894444 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/README.md @@ -0,0 +1,107 @@ +# Azure OpenAI client module for Go + +Azure OpenAI is a managed service that allows developers to deploy, tune, and generate content from OpenAI models on Azure resources. + +The Azure OpenAI client library for GO is an adaptation of OpenAI's REST APIs that provides an idiomatic interface and rich integration with the rest of the Azure SDK ecosystem. + +[Source code][azopenai_repo] | [Package (pkg.go.dev)][azopenai_pkg_go] | [REST API documentation][openai_rest_docs] | [Product documentation][openai_docs] + +## Getting started + +### Prerequisites + +* Go, version 1.18 or higher - [Install Go](https://go.dev/doc/install) +* [Azure subscription][azure_sub] +* [Azure OpenAI access][azure_openai_access] + +### Install the packages + +Install the `azopenai` and `azidentity` modules with `go get`: + +```bash +go get github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai +go get github.com/Azure/azure-sdk-for-go/sdk/azidentity +``` + +The [azidentity][azure_identity] module is used for authentication during client construction. + +### Authentication + + + +#### Create a client + +Constructing the client requires your vault's URL, which you can get from the Azure CLI or the Azure Portal. + +```go +import ( + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" +) + +func main() { + endpoint := "https://" + apiKey := "" + + var err error + cred := azopenai.KeyCredential{APIKey: apiKey} + client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, &options) + if err != nil { + // TODO: handle error + } +} +``` + +## Key concepts + +See [Key concepts][openai_key_concepts] in the product documentation for more details about general concepts. + +## Troubleshooting + +### Error Handling + +All methods that send HTTP requests return `*azcore.ResponseError` when these requests fail. `ResponseError` has error details and the raw response from the service. + +### Logging + +This module uses the logging implementation in `azcore`. To turn on logging for all Azure SDK modules, set `AZURE_SDK_GO_LOGGING` to `all`. By default, the logger writes to stderr. Use the `azcore/log` package to control log output. For example, logging only HTTP request and response events, and printing them to stdout: + +```go +import azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + +// Print log events to stdout +azlog.SetListener(func(cls azlog.Event, msg string) { + fmt.Println(msg) +}) + +// Includes only requests and responses in credential logs +azlog.SetEvents(azlog.EventRequest, azlog.EventResponse) +``` + +## Contributing + +This project welcomes contributions and suggestions. Most contributions require you to agree to a [Contributor License Agreement (CLA)][cla] declaring that you have the right to, and actually do, grant us the rights to use your contribution. + +When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate +the PR appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to +do this once across all repos using our CLA. + +This project has adopted the [Microsoft Open Source Code of Conduct][coc]. For more information, see +the [Code of Conduct FAQ][coc_faq] or contact [opencode@microsoft.com][coc_contact] with any additional questions or +comments. + + + +[azure_openai_access]: https://learn.microsoft.com/azure/cognitive-services/openai/overview#how-do-i-get-access-to-azure-openai +[azopenai_repo]: https://github.com/Azure/azure-sdk-for-go/tree/main/sdk +[azopenai_pkg_go]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk +[azure_identity]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity +[azure_sub]: https://azure.microsoft.com/free/ +[openai_docs]: https://learn.microsoft.com/azure/cognitive-services/openai +[openai_key_concepts]: https://learn.microsoft.com/azure/cognitive-services/openai/overview#key-concepts +[openai_rest_docs]: https://learn.microsoft.com/azure/cognitive-services/openai/reference +[cla]: https://cla.microsoft.com +[coc]: https://opensource.microsoft.com/codeofconduct/ +[coc_faq]: https://opensource.microsoft.com/codeofconduct/faq/ +[coc_contact]: mailto:opencode@microsoft.com \ No newline at end of file diff --git a/sdk/cognitiveservices/azopenai/assets.json b/sdk/cognitiveservices/azopenai/assets.json new file mode 100644 index 000000000000..ba5284a4a597 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/assets.json @@ -0,0 +1,6 @@ +{ + "AssetsRepo": "Azure/azure-sdk-assets", + "AssetsRepoPrefixPath": "go", + "TagPrefix": "go/cognitiveservices/azopenai", + "Tag": "go/cognitiveservices/azopenai_0bc6dc4171" +} diff --git a/sdk/cognitiveservices/azopenai/autorest.md b/sdk/cognitiveservices/azopenai/autorest.md new file mode 100644 index 000000000000..fff153b3f8c2 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/autorest.md @@ -0,0 +1,150 @@ +# Go + +These settings apply only when `--go` is specified on the command line. + +``` yaml +input-file: +- https://github.com/mikekistler/azure-rest-api-specs/blob/baed660fd853b4a387ca9f0b9491fd1414b66e9e/specification/cognitiveservices/data-plane/AzureOpenAI/inference/preview/2023-03-15-preview/inference.json +output-folder: ../azopenai +clear-output-folder: false +module: github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai +license-header: MICROSOFT_MIT_NO_VERSION +openapi-type: data-plane +go: true +use: "@autorest/go@4.0.0-preview.50" +``` + +## Transformations + +``` yaml +directive: + # Add x-ms-parameter-location to parameters in x-ms-parameterized-host + - from: openapi-document + where: $.servers.0.variables.endpoint + debug: true + transform: $["x-ms-parameter-location"] = "client"; + + # Make deploymentId a client parameter + # This must be done in each operation as the parameter is not defined in the components section + - from: openapi-document + where: $.paths..parameters..[?(@.name=='deploymentId')] + transform: $["x-ms-parameter-location"] = "client"; + + # Update operationIds to combine all operations into a single client + - rename-operation: + from: getCompletions + to: OpenAI_GetCompletions + - rename-operation: + from: getEmbeddings + to: OpenAI_GetEmbeddings + - rename-operation: + from: getChatCompletions + to: OpenAI_GetChatCompletions + + # Mark request bodies as required (TypeSpec issue #1838) + - from: openapi-document + where: $.paths["/deployments/{deploymentId}/completions"].post.requestBody + transform: $["required"] = true; + - from: openapi-document + where: $.paths["/deployments/{deploymentId}/embeddings"].post.requestBody + transform: $["required"] = true; + + # Remove stream property from CompletionsOptions and ChatCompletionsOptions + - from: openapi-document + where: $.components.schemas["CompletionsOptions"] + transform: delete $.properties.stream; + - from: openapi-document + where: $.components.schemas["ChatCompletionsOptions"] + transform: delete $.properties.stream; + + # Replace anyOf schemas with an empty schema (no type) to get an "any" type generated + - from: openapi-document + where: '$.components.schemas["EmbeddingsOptions"].properties["input"]' + transform: delete $.anyOf; + + # Fix autorest bug + - from: openapi-document + where: $.components.schemas["ChatMessage"].properties.role + transform: > + delete $.allOf; + $["$ref"] = "#/components/schemas/ChatRole"; + + # Fix another autorest bug + - from: openapi-document + where: $.components.schemas["Choice"].properties.finish_reason + transform: > + delete $.oneOf; + $["$ref"] = "#/components/schemas/CompletionsFinishReason"; + - from: openapi-document + where: $.components.schemas["ChatChoice"].properties.finish_reason + transform: > + delete $.oneOf; + $["$ref"] = "#/components/schemas/CompletionsFinishReason"; + + # Fix "AutoGenerated" models + - from: openapi-document + where: $.components.schemas["ChatCompletions"].properties.usage + transform: > + delete $.allOf; + $["$ref"] = "#/components/schemas/CompletionsUsage"; + - from: openapi-document + where: $.components.schemas["Completions"].properties.usage + transform: > + delete $.allOf; + $["$ref"] = "#/components/schemas/CompletionsUsage"; + + # + # strip out the deploymentID validation code - we absorbed this into the endpoint. + # + # urlPath := "/deployments/{deploymentId}/embeddings" + # if client.deploymentID == "" { + # return nil, errors.New("parameter client.deploymentID cannot be empty") + # } + # urlPath = strings.ReplaceAll(urlPath, "{deploymentId}", url.PathEscape(client.deploymentID)) + - from: client.go + where: $ + transform: >- + return $.replace( + /(\s+)urlPath\s*:=\s*"\/deployments\/\{deploymentId\}\/([^"]+)".+?url\.PathEscape.+?\n/gs, + "$1urlPath := \"$2\"\n") + + # splice out the auto-generated `deploymentID` field from the client + - from: client.go + where: $ + transform: >- + return $.replace( + /(type Client struct.+?)deploymentID string([^}]+})/s, + "$1$2") + + # delete unused error models + - from: models.go + where: $ + transform: >- + return $.replace( + /\/\/ AzureCoreFoundations.*?type AzureCoreFoundations(Error|ErrorResponse|ErrorResponseError|InnerError|InnerErrorInnererror|ErrorInnererror) struct \{[^}]+\}/gs, + "") + - from: models_serde.go + where: $ + transform: >- + return $.replace( + /\/\/ (UnmarshalJSON|MarshalJSON) implements.*?AzureCoreFoundations.*?func.+?\n}/gs, + "") + - from: models_serde.go + where: $ + transform: return $.replace(/(?:\/\/.*\s)?func \(\w \*?(?:ErrorResponse|ErrorResponseError|InnerError|InnerErrorInnererror)\).*\{\s(?:.+\s)+\}\s/g, ""); + + - from: constants.go + where: $ + transform: >- + return $.replace( + /type ServiceAPIVersions string.+PossibleServiceAPIVersionsValues.+?\n}/gs, + "") + + # delete client name prefix from method options and response types + - from: + - client.go + - models.go + - response_types.go + where: $ + transform: return $.replace(/Client(\w+)((?:Options|Response))/g, "$1$2"); +``` diff --git a/sdk/cognitiveservices/azopenai/build.go b/sdk/cognitiveservices/azopenai/build.go new file mode 100644 index 000000000000..12924d856949 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/build.go @@ -0,0 +1,11 @@ +//go:build go1.18 +// +build go1.18 + +//go:generate autorest ./autorest.md +//go:generate go mod tidy +//go:generate goimports -w . + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai diff --git a/sdk/cognitiveservices/azopenai/ci.yml b/sdk/cognitiveservices/azopenai/ci.yml new file mode 100644 index 000000000000..90db9512357e --- /dev/null +++ b/sdk/cognitiveservices/azopenai/ci.yml @@ -0,0 +1,28 @@ +# NOTE: Please refer to https://aka.ms/azsdk/engsys/ci-yaml before editing this file. +trigger: + branches: + include: + - main + - feature/* + - hotfix/* + - release/* + paths: + include: + - sdk/cognitiveservices/azopenai + - eng/ + +pr: + branches: + include: + - main + - feature/* + - hotfix/* + - release/* + paths: + include: + - sdk/cognitiveservices/azopenai + +stages: + - template: /eng/pipelines/templates/jobs/archetype-sdk-client.yml + parameters: + ServiceDirectory: "cognitiveservices/azopenai" diff --git a/sdk/cognitiveservices/azopenai/client.go b/sdk/cognitiveservices/azopenai/client.go new file mode 100644 index 000000000000..cd7237df5032 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/client.go @@ -0,0 +1,174 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. +// Code generated by Microsoft (R) AutoRest Code Generator. DO NOT EDIT. +// Changes may cause incorrect behavior and will be lost if the code is regenerated. + +package azopenai + +import ( + "context" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" +) + +// Client contains the methods for the OpenAI group. +// Don't use this type directly, use a constructor function instead. +type Client struct { + internal *azcore.Client + endpoint string +} + +// GetChatCompletions - Gets chat completions for the provided chat messages. Completions support a wide variety of tasks +// and generate text that continues from or "completes" provided prompt data. +// If the operation fails it returns an *azcore.ResponseError type. +// +// Generated from API version 2023-03-15-preview +// - options - GetChatCompletionsOptions contains the optional parameters for the Client.GetChatCompletions method. +func (client *Client) GetChatCompletions(ctx context.Context, body ChatCompletionsOptions, options *GetChatCompletionsOptions) (GetChatCompletionsResponse, error) { + var err error + req, err := client.getChatCompletionsCreateRequest(ctx, body, options) + if err != nil { + return GetChatCompletionsResponse{}, err + } + httpResp, err := client.internal.Pipeline().Do(req) + if err != nil { + return GetChatCompletionsResponse{}, err + } + if !runtime.HasStatusCode(httpResp, http.StatusOK) { + err = runtime.NewResponseError(httpResp) + return GetChatCompletionsResponse{}, err + } + resp, err := client.getChatCompletionsHandleResponse(httpResp) + return resp, err +} + +// getChatCompletionsCreateRequest creates the GetChatCompletions request. +func (client *Client) getChatCompletionsCreateRequest(ctx context.Context, body ChatCompletionsOptions, options *GetChatCompletionsOptions) (*policy.Request, error) { + urlPath := "chat/completions" + req, err := runtime.NewRequest(ctx, http.MethodPost, runtime.JoinPaths(client.endpoint, urlPath)) + if err != nil { + return nil, err + } + reqQP := req.Raw().URL.Query() + reqQP.Set("api-version", "2023-03-15-preview") + req.Raw().URL.RawQuery = reqQP.Encode() + req.Raw().Header["Accept"] = []string{"application/json"} + if err := runtime.MarshalAsJSON(req, body); err != nil { + return nil, err + } + return req, nil +} + +// getChatCompletionsHandleResponse handles the GetChatCompletions response. +func (client *Client) getChatCompletionsHandleResponse(resp *http.Response) (GetChatCompletionsResponse, error) { + result := GetChatCompletionsResponse{} + if err := runtime.UnmarshalAsJSON(resp, &result.ChatCompletions); err != nil { + return GetChatCompletionsResponse{}, err + } + return result, nil +} + +// GetCompletions - Gets completions for the provided input prompts. Completions support a wide variety of tasks and generate +// text that continues from or "completes" provided prompt data. +// If the operation fails it returns an *azcore.ResponseError type. +// +// Generated from API version 2023-03-15-preview +// - options - GetCompletionsOptions contains the optional parameters for the Client.GetCompletions method. +func (client *Client) GetCompletions(ctx context.Context, body CompletionsOptions, options *GetCompletionsOptions) (GetCompletionsResponse, error) { + var err error + req, err := client.getCompletionsCreateRequest(ctx, body, options) + if err != nil { + return GetCompletionsResponse{}, err + } + httpResp, err := client.internal.Pipeline().Do(req) + if err != nil { + return GetCompletionsResponse{}, err + } + if !runtime.HasStatusCode(httpResp, http.StatusOK) { + err = runtime.NewResponseError(httpResp) + return GetCompletionsResponse{}, err + } + resp, err := client.getCompletionsHandleResponse(httpResp) + return resp, err +} + +// getCompletionsCreateRequest creates the GetCompletions request. +func (client *Client) getCompletionsCreateRequest(ctx context.Context, body CompletionsOptions, options *GetCompletionsOptions) (*policy.Request, error) { + urlPath := "completions" + req, err := runtime.NewRequest(ctx, http.MethodPost, runtime.JoinPaths(client.endpoint, urlPath)) + if err != nil { + return nil, err + } + reqQP := req.Raw().URL.Query() + reqQP.Set("api-version", "2023-03-15-preview") + req.Raw().URL.RawQuery = reqQP.Encode() + req.Raw().Header["Accept"] = []string{"application/json"} + if err := runtime.MarshalAsJSON(req, body); err != nil { + return nil, err + } + return req, nil +} + +// getCompletionsHandleResponse handles the GetCompletions response. +func (client *Client) getCompletionsHandleResponse(resp *http.Response) (GetCompletionsResponse, error) { + result := GetCompletionsResponse{} + if err := runtime.UnmarshalAsJSON(resp, &result.Completions); err != nil { + return GetCompletionsResponse{}, err + } + return result, nil +} + +// GetEmbeddings - Return the embeddings for a given prompt. +// If the operation fails it returns an *azcore.ResponseError type. +// +// Generated from API version 2023-03-15-preview +// - options - GetEmbeddingsOptions contains the optional parameters for the Client.GetEmbeddings method. +func (client *Client) GetEmbeddings(ctx context.Context, body EmbeddingsOptions, options *GetEmbeddingsOptions) (GetEmbeddingsResponse, error) { + var err error + req, err := client.getEmbeddingsCreateRequest(ctx, body, options) + if err != nil { + return GetEmbeddingsResponse{}, err + } + httpResp, err := client.internal.Pipeline().Do(req) + if err != nil { + return GetEmbeddingsResponse{}, err + } + if !runtime.HasStatusCode(httpResp, http.StatusOK) { + err = runtime.NewResponseError(httpResp) + return GetEmbeddingsResponse{}, err + } + resp, err := client.getEmbeddingsHandleResponse(httpResp) + return resp, err +} + +// getEmbeddingsCreateRequest creates the GetEmbeddings request. +func (client *Client) getEmbeddingsCreateRequest(ctx context.Context, body EmbeddingsOptions, options *GetEmbeddingsOptions) (*policy.Request, error) { + urlPath := "embeddings" + req, err := runtime.NewRequest(ctx, http.MethodPost, runtime.JoinPaths(client.endpoint, urlPath)) + if err != nil { + return nil, err + } + reqQP := req.Raw().URL.Query() + reqQP.Set("api-version", "2023-03-15-preview") + req.Raw().URL.RawQuery = reqQP.Encode() + req.Raw().Header["Accept"] = []string{"application/json"} + if err := runtime.MarshalAsJSON(req, body); err != nil { + return nil, err + } + return req, nil +} + +// getEmbeddingsHandleResponse handles the GetEmbeddings response. +func (client *Client) getEmbeddingsHandleResponse(resp *http.Response) (GetEmbeddingsResponse, error) { + result := GetEmbeddingsResponse{} + if err := runtime.UnmarshalAsJSON(resp, &result.Embeddings); err != nil { + return GetEmbeddingsResponse{}, err + } + return result, nil +} diff --git a/sdk/cognitiveservices/azopenai/client_shared_test.go b/sdk/cognitiveservices/azopenai/client_shared_test.go new file mode 100644 index 000000000000..c3d643709801 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/client_shared_test.go @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +import ( + "fmt" + "os" + "regexp" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" + "github.com/joho/godotenv" + "github.com/stretchr/testify/require" +) + +var ( + endpoint string // env: AOAI_ENDPOINT + apiKey string // env: AOAI_API_KEY + streamingModelDeployment string // env: AOAI_STREAMING_MODEL_DEPLOYMENT + + openAIKey string // env: OPENAI_API_KEY + openAIEndpoint string // env: OPENAI_ENDPOINT +) + +const fakeEndpoint = "https://recordedhost/" +const fakeAPIKey = "redacted" + +func init() { + if recording.GetRecordMode() == recording.PlaybackMode { + endpoint = fakeEndpoint + apiKey = fakeAPIKey + openAIKey = fakeAPIKey + openAIEndpoint = fakeEndpoint + streamingModelDeployment = "text-davinci-003" + } else { + if err := godotenv.Load(); err != nil { + fmt.Printf("Failed to load .env file: %s\n", err) + os.Exit(1) + } + + endpoint = os.Getenv("AOAI_ENDPOINT") + + if endpoint != "" && !strings.HasSuffix(endpoint, "/") { + // (this just makes recording replacement easier) + endpoint += "/" + } + + apiKey = os.Getenv("AOAI_API_KEY") + + // Ex: text-davinci-003 + streamingModelDeployment = os.Getenv("AOAI_STREAMING_MODEL_DEPLOYMENT") + + openAIKey = os.Getenv("OPENAI_API_KEY") + openAIEndpoint = os.Getenv("OPENAI_ENDPOINT") + + if openAIEndpoint != "" && !strings.HasSuffix(openAIEndpoint, "/") { + // (this just makes recording replacement easier) + openAIEndpoint += "/" + } + } +} + +func newRecordingTransporter(t *testing.T) policy.Transporter { + transport, err := recording.NewRecordingHTTPClient(t, nil) + require.NoError(t, err) + + err = recording.Start(t, "sdk/cognitiveservices/azopenai/testdata", nil) + require.NoError(t, err) + + if recording.GetRecordMode() != recording.PlaybackMode { + err = recording.AddHeaderRegexSanitizer("Api-Key", fakeAPIKey, "", nil) + require.NoError(t, err) + + // "RequestUri": "https://openai-shared.openai.azure.com/openai/deployments/text-davinci-003/completions?api-version=2023-03-15-preview", + err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(endpoint), nil) + require.NoError(t, err) + + if openAIEndpoint != "" { + err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(openAIEndpoint), nil) + require.NoError(t, err) + } + } + + t.Cleanup(func() { + err := recording.Stop(t, nil) + require.NoError(t, err) + }) + + return transport +} + +func newClientOptionsForTest(t *testing.T) *ClientOptions { + co := &ClientOptions{} + co.Transport = newRecordingTransporter(t) + return co +} diff --git a/sdk/cognitiveservices/azopenai/client_test.go b/sdk/cognitiveservices/azopenai/client_test.go new file mode 100644 index 000000000000..7c2335eaff1a --- /dev/null +++ b/sdk/cognitiveservices/azopenai/client_test.go @@ -0,0 +1,337 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +import ( + "context" + "log" + "net/http" + "os" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/require" +) + +func TestClient_GetChatCompletions(t *testing.T) { + deploymentID := "gpt-35-turbo" + + cred := KeyCredential{APIKey: apiKey} + chatClient, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, newClientOptionsForTest(t)) + require.NoError(t, err) + + testGetChatCompletions(t, chatClient, deploymentID) +} + +func TestClient_GetChatCompletions_DefaultAzureCredential(t *testing.T) { + if recording.GetRecordMode() == recording.PlaybackMode { + t.Skipf("Not running this test in playback (for now)") + } + + if os.Getenv("USE_TOKEN_CREDS") != "true" { + t.Skipf("USE_TOKEN_CREDS is not true, disabling token credential tests") + } + + deploymentID := "gpt-35-turbo" + + recordingTransporter := newRecordingTransporter(t) + + dac, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{ + ClientOptions: policy.ClientOptions{ + Transport: recordingTransporter, + }, + }) + require.NoError(t, err) + + chatClient, err := NewClient(endpoint, dac, deploymentID, &ClientOptions{ + ClientOptions: policy.ClientOptions{Transport: recordingTransporter}, + }) + require.NoError(t, err) + + testGetChatCompletions(t, chatClient, deploymentID) +} + +func TestClient_OpenAI_GetChatCompletions(t *testing.T) { + chatClient := newOpenAIClientForTest(t) + testGetChatCompletions(t, chatClient, "gpt-3.5-turbo") +} + +func TestClient_OpenAI_InvalidModel(t *testing.T) { + chatClient := newOpenAIClientForTest(t) + + _, err := chatClient.GetChatCompletions(context.Background(), ChatCompletionsOptions{ + Messages: []*ChatMessage{ + { + Role: to.Ptr(ChatRoleSystem), + Content: to.Ptr("hello"), + }, + }, + Model: to.Ptr("non-existent-model"), + }, nil) + + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, http.StatusNotFound, respErr.StatusCode) + require.Contains(t, respErr.Error(), "The model `non-existent-model` does not exist") +} + +func testGetChatCompletions(t *testing.T, chatClient *Client, modelOrDeployment string) { + type args struct { + ctx context.Context + deploymentID string + body ChatCompletionsOptions + options *GetChatCompletionsOptions + } + + tests := []struct { + name string + client *Client + args args + want GetChatCompletionsResponse + + wantErr bool + }{ + { + name: "ChatCompletions", + client: chatClient, + args: args{ + ctx: context.TODO(), + deploymentID: modelOrDeployment, + body: ChatCompletionsOptions{ + Messages: []*ChatMessage{ + { + Role: to.Ptr(ChatRole("user")), + Content: to.Ptr("Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."), + }, + }, + MaxTokens: to.Ptr(int32(1024)), + Temperature: to.Ptr(float32(0.0)), + Model: &modelOrDeployment, + }, + options: nil, + }, + want: GetChatCompletionsResponse{ + ChatCompletions: ChatCompletions{ + Choices: []*ChatChoice{ + { + Message: &ChatChoiceMessage{ + Role: to.Ptr(ChatRole("assistant")), + Content: to.Ptr("1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100."), + }, + Index: to.Ptr(int32(0)), + FinishReason: to.Ptr(CompletionsFinishReason("stop")), + }, + }, + Usage: &CompletionsUsage{ + CompletionTokens: to.Ptr(int32(299)), + PromptTokens: to.Ptr(int32(37)), + TotalTokens: to.Ptr(int32(336)), + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.client.GetChatCompletions(tt.args.ctx, tt.args.body, tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("Client.GetChatCompletions() error = %v, wantErr %v", err, tt.wantErr) + return + } + opts := cmpopts.IgnoreFields(ChatCompletions{}, "Created", "ID") + if diff := cmp.Diff(tt.want.ChatCompletions, got.ChatCompletions, opts); diff != "" { + t.Errorf("Client.GetCompletions(): -want, +got:\n%s", diff) + } + }) + } +} + +func TestClient_GetChatCompletions_InvalidModel(t *testing.T) { + cred := KeyCredential{APIKey: apiKey} + chatClient, err := NewClientWithKeyCredential(endpoint, cred, "thisdoesntexist", newClientOptionsForTest(t)) + require.NoError(t, err) + + _, err = chatClient.GetChatCompletions(context.Background(), ChatCompletionsOptions{ + Messages: []*ChatMessage{ + { + Role: to.Ptr(ChatRole("user")), + Content: to.Ptr("Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."), + }, + }, + MaxTokens: to.Ptr(int32(1024)), + Temperature: to.Ptr(float32(0.0)), + }, nil) + + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, "DeploymentNotFound", respErr.ErrorCode) +} + +func TestClient_GetEmbeddings_InvalidModel(t *testing.T) { + cred := KeyCredential{APIKey: apiKey} + chatClient, err := NewClientWithKeyCredential(endpoint, cred, "thisdoesntexist", newClientOptionsForTest(t)) + require.NoError(t, err) + + _, err = chatClient.GetEmbeddings(context.Background(), EmbeddingsOptions{}, nil) + + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, "DeploymentNotFound", respErr.ErrorCode) +} + +func TestClient_GetCompletions(t *testing.T) { + type args struct { + ctx context.Context + deploymentID string + body CompletionsOptions + options *GetCompletionsOptions + } + cred := KeyCredential{APIKey: apiKey} + client, err := NewClientWithKeyCredential(endpoint, cred, streamingModelDeployment, newClientOptionsForTest(t)) + if err != nil { + log.Fatalf("%v", err) + } + tests := []struct { + name string + client *Client + args args + want GetCompletionsResponse + wantErr bool + }{ + { + name: "chatbot", + client: client, + args: args{ + ctx: context.TODO(), + deploymentID: streamingModelDeployment, + body: CompletionsOptions{ + Prompt: []*string{to.Ptr("What is Azure OpenAI?")}, + MaxTokens: to.Ptr(int32(2048 - 127)), + Temperature: to.Ptr(float32(0.0)), + }, + options: nil, + }, + want: GetCompletionsResponse{ + Completions: Completions{ + Choices: []*Choice{ + { + Text: to.Ptr("\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models."), + Index: to.Ptr(int32(0)), + FinishReason: to.Ptr(CompletionsFinishReason("stop")), + Logprobs: nil, + }, + }, + Usage: &CompletionsUsage{ + CompletionTokens: to.Ptr(int32(85)), + PromptTokens: to.Ptr(int32(6)), + TotalTokens: to.Ptr(int32(91)), + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.client.GetCompletions(tt.args.ctx, tt.args.body, tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("Client.GetCompletions() error = %v, wantErr %v", err, tt.wantErr) + return + } + opts := cmpopts.IgnoreFields(Completions{}, "Created", "ID") + if diff := cmp.Diff(tt.want.Completions, got.Completions, opts); diff != "" { + t.Errorf("Client.GetCompletions(): -want, +got:\n%s", diff) + } + }) + } +} + +func TestClient_OpenAI_GetEmbeddings(t *testing.T) { + client := newOpenAIClientForTest(t) + modelID := "text-similarity-curie-001" + testGetEmbeddings(t, client, modelID) +} + +func TestClient_GetEmbeddings(t *testing.T) { + // model deployment points to `text-similarity-curie-001` + deploymentID := "embedding" + + cred := KeyCredential{APIKey: apiKey} + client, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, newClientOptionsForTest(t)) + require.NoError(t, err) + + testGetEmbeddings(t, client, deploymentID) +} + +func testGetEmbeddings(t *testing.T, client *Client, modelOrDeploymentID string) { + type args struct { + ctx context.Context + deploymentID string + body EmbeddingsOptions + options *GetEmbeddingsOptions + } + + tests := []struct { + name string + client *Client + args args + want GetEmbeddingsResponse + wantErr bool + }{ + { + name: "Embeddings", + client: client, + args: args{ + ctx: context.TODO(), + deploymentID: modelOrDeploymentID, + body: EmbeddingsOptions{ + Input: []byte("\"Your text string goes here\""), + Model: &modelOrDeploymentID, + }, + options: nil, + }, + want: GetEmbeddingsResponse{ + Embeddings{ + Data: []*EmbeddingItem{}, + Usage: &EmbeddingsUsage{}, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.client.GetEmbeddings(tt.args.ctx, tt.args.body, tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("Client.GetEmbeddings() error = %v, wantErr %v", err, tt.wantErr) + return + } + if len(got.Embeddings.Data[0].Embedding) != 4096 { + t.Errorf("Client.GetEmbeddings() len(Data) want 4096, got %d", len(got.Embeddings.Data)) + return + } + }) + } +} + +func newOpenAIClientForTest(t *testing.T) *Client { + if openAIKey == "" { + t.Skipf("OPENAI_API_KEY not defined, skipping OpenAI public endpoint test") + } + + chatClient, err := NewClientForOpenAI(openAIEndpoint, KeyCredential{APIKey: openAIKey}, newClientOptionsForTest(t)) + require.NoError(t, err) + + return chatClient +} diff --git a/sdk/cognitiveservices/azopenai/constants.go b/sdk/cognitiveservices/azopenai/constants.go new file mode 100644 index 000000000000..b35f00d25a9a --- /dev/null +++ b/sdk/cognitiveservices/azopenai/constants.go @@ -0,0 +1,45 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. +// Code generated by Microsoft (R) AutoRest Code Generator. DO NOT EDIT. +// Changes may cause incorrect behavior and will be lost if the code is regenerated. + +package azopenai + +// ChatRole - A description of the intended purpose of a message within a chat completions interaction. +type ChatRole string + +const ( + ChatRoleAssistant ChatRole = "assistant" + ChatRoleSystem ChatRole = "system" + ChatRoleUser ChatRole = "user" +) + +// PossibleChatRoleValues returns the possible values for the ChatRole const type. +func PossibleChatRoleValues() []ChatRole { + return []ChatRole{ + ChatRoleAssistant, + ChatRoleSystem, + ChatRoleUser, + } +} + +// CompletionsFinishReason - Representation of the manner in which a completions response concluded. +type CompletionsFinishReason string + +const ( + CompletionsFinishReasonContentFilter CompletionsFinishReason = "content_filter" + CompletionsFinishReasonLength CompletionsFinishReason = "length" + CompletionsFinishReasonStop CompletionsFinishReason = "stop" +) + +// PossibleCompletionsFinishReasonValues returns the possible values for the CompletionsFinishReason const type. +func PossibleCompletionsFinishReasonValues() []CompletionsFinishReason { + return []CompletionsFinishReason{ + CompletionsFinishReasonContentFilter, + CompletionsFinishReasonLength, + CompletionsFinishReasonStop, + } +} diff --git a/sdk/cognitiveservices/azopenai/custom_client.go b/sdk/cognitiveservices/azopenai/custom_client.go new file mode 100644 index 000000000000..98780286a11a --- /dev/null +++ b/sdk/cognitiveservices/azopenai/custom_client.go @@ -0,0 +1,166 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +// this file contains handwritten additions to the generated code + +import ( + "context" + "encoding/json" + "net/http" + "net/url" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" +) + +const ( + clientName = "azopenai.Client" + apiVersion = "2023-03-15-preview" + tokenScope = "https://cognitiveservices.azure.com/.default" +) + +// Clients + +// ClientOptions contains optional settings for Client. +type ClientOptions struct { + azcore.ClientOptions +} + +// NewClient creates a new instance of Client that connects to an Azure OpenAI endpoint. +// - endpoint - Azure OpenAI service endpoint, for example: https://{your-resource-name}.openai.azure.com +// - credential - used to authorize requests. Usually a credential from [github.com/Azure/azure-sdk-for-go/sdk/azidentity]. +// - deploymentID - the deployment ID of the model to query +// - options - client options, pass nil to accept the default values. +func NewClient(endpoint string, credential azcore.TokenCredential, deploymentID string, options *ClientOptions) (*Client, error) { + if options == nil { + options = &ClientOptions{} + } + + authPolicy := runtime.NewBearerTokenPolicy(credential, []string{tokenScope}, nil) + azcoreClient, err := azcore.NewClient(clientName, version, runtime.PipelineOptions{PerRetry: []policy.Policy{authPolicy}}, &options.ClientOptions) + if err != nil { + return nil, err + } + + fullEndpoint := formatAzureOpenAIURL(endpoint, deploymentID) + return &Client{endpoint: fullEndpoint, internal: azcoreClient}, nil +} + +// NewClientWithKeyCredential creates a new instance of Client that connects to an Azure OpenAI endpoint. +// - endpoint - Azure OpenAI service endpoint, for example: https://{your-resource-name}.openai.azure.com +// - credential - used to authorize requests with an API Key credential +// - deploymentID - the deployment ID of the model to query +// - options - client options, pass nil to accept the default values. +func NewClientWithKeyCredential(endpoint string, credential KeyCredential, deploymentID string, options *ClientOptions) (*Client, error) { + if options == nil { + options = &ClientOptions{} + } + + authPolicy := newAPIKeyPolicy(credential, "api-key") + azcoreClient, err := azcore.NewClient(clientName, version, runtime.PipelineOptions{PerRetry: []policy.Policy{authPolicy}}, &options.ClientOptions) + if err != nil { + return nil, err + } + + fullEndpoint := formatAzureOpenAIURL(endpoint, deploymentID) + return &Client{endpoint: fullEndpoint, internal: azcoreClient}, nil +} + +// NewClientForOpenAI creates a new instance of Client which connects to the public OpenAI endpoint. +// - endpoint - OpenAI service endpoint, for example: https://api.openai.com/v1 +// - credential - used to authorize requests with an API Key credential +// - options - client options, pass nil to accept the default values. +func NewClientForOpenAI(endpoint string, credential KeyCredential, options *ClientOptions) (*Client, error) { + if options == nil { + options = &ClientOptions{} + } + openAIPolicy := newOpenAIPolicy(credential) + azcoreClient, err := azcore.NewClient(clientName, version, runtime.PipelineOptions{PerRetry: []policy.Policy{openAIPolicy}}, &options.ClientOptions) + if err != nil { + return nil, err + } + return &Client{endpoint: endpoint, internal: azcoreClient}, nil +} + +// openAIPolicy is an internal pipeline policy to remove the api-version query parameter +type openAIPolicy struct { + cred KeyCredential +} + +// newOpenAIPolicy creates a new instance of openAIPolicy. +// cred: a KeyCredential implementation. +func newOpenAIPolicy(cred KeyCredential) *openAIPolicy { + return &openAIPolicy{cred: cred} +} + +// Do returns a function which adapts a request to target OpenAI. +// Specifically, it removes the api-version query parameter. +func (b *openAIPolicy) Do(req *policy.Request) (*http.Response, error) { + q := req.Raw().URL.Query() + q.Del("api-version") + req.Raw().Header.Set("authorization", "Bearer "+b.cred.APIKey) + return req.Next() +} + +// Methods that return streaming response + +type streamCompletionsOptions struct { + CompletionsOptions + Stream bool `json:"stream"` +} + +func (o streamCompletionsOptions) MarshalJSON() ([]byte, error) { + bytes, err := o.CompletionsOptions.MarshalJSON() + if err != nil { + return nil, err + } + objectMap := make(map[string]any) + err = json.Unmarshal(bytes, &objectMap) + if err != nil { + return nil, err + } + objectMap["stream"] = o.Stream + return json.Marshal(objectMap) +} + +// GetCompletionsStream - Return the completions for a given prompt as a sequence of events. +// If the operation fails it returns an *azcore.ResponseError type. +// - options - GetCompletionsOptions contains the optional parameters for the Client.GetCompletions method. +func (client *Client) GetCompletionsStream(ctx context.Context, body CompletionsOptions, options *GetCompletionsStreamOptions) (GetCompletionsStreamResponse, error) { + req, err := client.getCompletionsCreateRequest(ctx, CompletionsOptions{}, &GetCompletionsOptions{}) + + if err != nil { + return GetCompletionsStreamResponse{}, err + } + + if err := runtime.MarshalAsJSON(req, streamCompletionsOptions{body, true}); err != nil { + return GetCompletionsStreamResponse{}, err + } + + runtime.SkipBodyDownload(req) + + resp, err := client.internal.Pipeline().Do(req) + + if err != nil { + return GetCompletionsStreamResponse{}, err + } + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return GetCompletionsStreamResponse{}, runtime.NewResponseError(resp) + } + + return GetCompletionsStreamResponse{ + CompletionsStream: newEventReader[Completions](resp.Body), + }, nil +} + +func formatAzureOpenAIURL(endpoint, deploymentID string) string { + escapedDeplID := url.PathEscape(deploymentID) + return runtime.JoinPaths(endpoint, "openai", "deployments", escapedDeplID) +} diff --git a/sdk/cognitiveservices/azopenai/custom_client_test.go b/sdk/cognitiveservices/azopenai/custom_client_test.go new file mode 100644 index 000000000000..412e04169310 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/custom_client_test.go @@ -0,0 +1,125 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +import ( + "context" + "io" + "reflect" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" +) + +func TestNewClient(t *testing.T) { + type args struct { + endpoint string + credential azcore.TokenCredential + deploymentID string + options *ClientOptions + } + tests := []struct { + name string + args args + want *Client + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewClient(tt.args.endpoint, tt.args.credential, tt.args.deploymentID, tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewClient() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewClientWithKeyCredential(t *testing.T) { + type args struct { + endpoint string + credential KeyCredential + deploymentID string + options *ClientOptions + } + tests := []struct { + name string + args args + want *Client + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewClientWithKeyCredential(tt.args.endpoint, tt.args.credential, tt.args.deploymentID, tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("NewClientWithKeyCredential() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewClientWithKeyCredential() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_GetCompletionsStream(t *testing.T) { + body := CompletionsOptions{ + Prompt: []*string{to.Ptr("What is Azure OpenAI?")}, + MaxTokens: to.Ptr(int32(2048 - 127)), + Temperature: to.Ptr(float32(0.0)), + } + cred := KeyCredential{APIKey: apiKey} + + client, err := NewClientWithKeyCredential(endpoint, cred, streamingModelDeployment, newClientOptionsForTest(t)) + if err != nil { + t.Errorf("NewClientWithKeyCredential() error = %v", err) + return + } + response, err := client.GetCompletionsStream(context.TODO(), body, nil) + if err != nil { + t.Errorf("Client.GetCompletionsStream() error = %v", err) + return + } + reader := response.CompletionsStream + defer reader.Close() + + var sb strings.Builder + var eventCount int + for { + event, err := reader.Read() + if err == io.EOF { + break + } + eventCount++ + if err != nil { + t.Errorf("reader.Read() error = %v", err) + return + } + sb.WriteString(*event.Choices[0].Text) + } + got := sb.String() + const want = "\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models." + if got != want { + i := 0 + for i < len(got) && i < len(want) && got[i] == want[i] { + i++ + } + t.Errorf("Client.GetCompletionsStream() text[%d] = %c, want %c", i, got[i], want[i]) + } + if eventCount != 86 { + t.Errorf("Client.GetCompletionsStream() got = %v, want %v", eventCount, 1) + } +} diff --git a/sdk/cognitiveservices/azopenai/custom_models.go b/sdk/cognitiveservices/azopenai/custom_models.go new file mode 100644 index 000000000000..fb320bebfb02 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/custom_models.go @@ -0,0 +1,20 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +// Models for methods that return streaming response + +// GetCompletionsStreamOptions contains the optional parameters for the Client.GetCompletions method. +type GetCompletionsStreamOptions struct { + // placeholder for future optional parameters +} + +// GetCompletionsStreamResponse is the response from [GetCompletionsStream]. +type GetCompletionsStreamResponse struct { + // CompletionsStream returns the stream of completions. Token limits and other settings may limit the number of completions returned by the service. + CompletionsStream *EventReader[Completions] +} diff --git a/sdk/cognitiveservices/azopenai/event_reader.go b/sdk/cognitiveservices/azopenai/event_reader.go new file mode 100644 index 000000000000..c98b74ddd6af --- /dev/null +++ b/sdk/cognitiveservices/azopenai/event_reader.go @@ -0,0 +1,64 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +import ( + "bufio" + "encoding/json" + "errors" + "io" + "strings" +) + +// EventReader streams events dynamically from an OpenAI endpoint. +type EventReader[T any] struct { + reader io.Reader // Required for Closing + scanner *bufio.Scanner +} + +func newEventReader[T any](r io.Reader) *EventReader[T] { + return &EventReader[T]{reader: r, scanner: bufio.NewScanner(r)} +} + +// Read reads the next event from the stream. +// Returns io.EOF when there are no further events. +func (er *EventReader[T]) Read() (T, error) { + // https://html.spec.whatwg.org/multipage/server-sent-events.html + for er.scanner.Scan() { // Scan while no error + line := er.scanner.Text() // Get the line & interpret the event stream: + + if line == "" || line[0] == ':' { // If the line is blank or is a comment, skip it + continue + } + + if strings.Contains(line, ":") { // If the line contains a U+003A COLON character (:), process the field + tokens := strings.SplitN(line, ":", 2) + tokens[0], tokens[1] = strings.TrimSpace(tokens[0]), strings.TrimSpace(tokens[1]) + var data T + switch tokens[0] { + case "data": // return the deserialized JSON object + if tokens[1] == "[DONE]" { // If data is [DONE], end of stream was reached + return data, io.EOF + } + err := json.Unmarshal([]byte(tokens[1]), &data) + return data, err + + default: // Any other event type is an unexpected + return data, errors.New("Unexpected event type: " + tokens[0]) + } + // Unreachable + } + } + return *new(T), er.scanner.Err() +} + +// Close closes the EventReader and any applicable inner stream state. +func (er *EventReader[T]) Close() { + if closer, ok := er.reader.(io.Closer); ok { + closer.Close() + } +} diff --git a/sdk/cognitiveservices/azopenai/examples_client_test.go b/sdk/cognitiveservices/azopenai/examples_client_test.go new file mode 100644 index 000000000000..ea9ebf1e7ef0 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/examples_client_test.go @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai_test + +import ( + "context" + "errors" + "fmt" + "io" + "os" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" +) + +func ExampleNewClientForOpenAI() { + // NOTE: this constructor creates a client that connects to the public OpenAI endpoint. + // To connect to an Azure OpenAI endpoint, use azopenai.NewClient() or azopenai.NewClientWithyKeyCredential. + keyCredential := azopenai.KeyCredential{ + APIKey: "open-ai-apikey", + } + + client, err := azopenai.NewClientForOpenAI("https://api.openai.com/v1", keyCredential, nil) + + if err != nil { + panic(err) + } + + _ = client +} + +func ExampleNewClient() { + // NOTE: this constructor creates a client that connects to an Azure OpenAI endpoint. + // To connect to the public OpenAI endpoint, use azopenai.NewClientForOpenAI + dac, err := azidentity.NewDefaultAzureCredential(nil) + + if err != nil { + panic(err) + } + + modelDeploymentID := "model deployment ID" + client, err := azopenai.NewClient("https://.openai.azure.com", dac, modelDeploymentID, nil) + + if err != nil { + panic(err) + } + + _ = client +} + +func ExampleNewClientWithKeyCredential() { + // NOTE: this constructor creates a client that connects to an Azure OpenAI endpoint. + // To connect to the public OpenAI endpoint, use azopenai.NewClientForOpenAI + keyCredential := azopenai.KeyCredential{ + APIKey: "Azure OpenAI apikey", + } + + modelDeploymentID := "model deployment ID" + client, err := azopenai.NewClientWithKeyCredential("https://.openai.azure.com", keyCredential, modelDeploymentID, nil) + + if err != nil { + panic(err) + } + + _ = client +} + +func ExampleClient_GetCompletionsStream() { + azureOpenAIKey := os.Getenv("AOAI_API_KEY") + modelDeploymentID := os.Getenv("AOAI_STREAMING_MODEL_DEPLOYMENT") + + // Ex: "https://.openai.azure.com" + azureOpenAIEndpoint := os.Getenv("AOAI_ENDPOINT") + + if azureOpenAIKey == "" || modelDeploymentID == "" || azureOpenAIEndpoint == "" { + return + } + + keyCredential := azopenai.KeyCredential{ + APIKey: azureOpenAIKey, + } + + client, err := azopenai.NewClientWithKeyCredential(azureOpenAIEndpoint, keyCredential, modelDeploymentID, nil) + + if err != nil { + panic(err) + } + + resp, err := client.GetCompletionsStream(context.TODO(), azopenai.CompletionsOptions{ + Prompt: []*string{to.Ptr("What is Azure OpenAI?")}, + MaxTokens: to.Ptr(int32(2048 - 127)), + Temperature: to.Ptr(float32(0.0)), + }, nil) + + if err != nil { + panic(err) + } + + for { + entry, err := resp.CompletionsStream.Read() + + if errors.Is(err, io.EOF) { + fmt.Printf("More more completions") + break + } + + if err != nil { + panic(err) + } + + for _, choice := range entry.Choices { + fmt.Printf("%s", *choice.Text) + } + } +} diff --git a/sdk/cognitiveservices/azopenai/go.mod b/sdk/cognitiveservices/azopenai/go.mod new file mode 100644 index 000000000000..710379f09aa8 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/go.mod @@ -0,0 +1,29 @@ +module github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai + +go 1.18 + +require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 + github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 + github.com/google/go-cmp v0.5.9 + github.com/joho/godotenv v1.3.0 + github.com/stretchr/testify v1.7.0 +) + +require ( + github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dnaeon/go-vcr v1.2.0 // indirect + github.com/golang-jwt/jwt/v4 v4.5.0 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/crypto v0.7.0 // indirect + golang.org/x/net v0.8.0 // indirect + golang.org/x/sys v0.6.0 // indirect + golang.org/x/text v0.8.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/sdk/cognitiveservices/azopenai/go.sum b/sdk/cognitiveservices/azopenai/go.sum new file mode 100644 index 000000000000..d3d94d1ae54a --- /dev/null +++ b/sdk/cognitiveservices/azopenai/go.sum @@ -0,0 +1,48 @@ +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1 h1:SEy2xmstIphdPwNBUi7uhvjyjhVKISfwjfOJmuy7kg4= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 h1:OBhqkivkhkMqLPymWEppkm7vgPQY2XsHoEkaMQ0AdZY= +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= +github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= +github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/sdk/cognitiveservices/azopenai/models.go b/sdk/cognitiveservices/azopenai/models.go new file mode 100644 index 000000000000..42c0ddb088b3 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/models.go @@ -0,0 +1,353 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. +// Code generated by Microsoft (R) AutoRest Code Generator. DO NOT EDIT. +// Changes may cause incorrect behavior and will be lost if the code is regenerated. + +package azopenai + +// ChatChoice - The representation of a single prompt completion as part of an overall chat completions request. Generally, +// n choices are generated per provided prompt with a default value of 1. Token limits and +// other settings may limit the number of choices generated. +type ChatChoice struct { + // REQUIRED; The reason that this chat completions choice completed its generated. + FinishReason *CompletionsFinishReason + + // REQUIRED; The ordered index associated with this chat completions choice. + Index *int32 + + // The delta message content for a streaming response. + Delta *ChatChoiceDelta + + // The chat message for a given chat completions prompt. + Message *ChatChoiceMessage +} + +// ChatChoiceDelta - The delta message content for a streaming response. +type ChatChoiceDelta struct { + // REQUIRED; The role associated with this message payload. + Role *ChatRole + + // The text associated with this message payload. + Content *string +} + +// ChatChoiceMessage - The chat message for a given chat completions prompt. +type ChatChoiceMessage struct { + // REQUIRED; The role associated with this message payload. + Role *ChatRole + + // The text associated with this message payload. + Content *string +} + +// ChatCompletions - Representation of the response data from a chat completions request. Completions support a wide variety +// of tasks and generate text that continues from or "completes" provided prompt data. +type ChatCompletions struct { + // REQUIRED; The collection of completions choices associated with this completions response. Generally, n choices are generated + // per provided prompt with a default value of 1. Token limits and other settings may + // limit the number of choices generated. + Choices []*ChatChoice + + // REQUIRED; The first timestamp associated with generation activity for this completions response, represented as seconds + // since the beginning of the Unix epoch of 00:00 on 1 Jan 1970. + Created *int32 + + // REQUIRED; A unique identifier associated with this chat completions response. + ID *string + + // REQUIRED; Usage information for tokens processed and generated as part of this completions operation. + Usage *CompletionsUsage +} + +// ChatCompletionsOptions - The configuration information for a chat completions request. Completions support a wide variety +// of tasks and generate text that continues from or "completes" provided prompt data. +type ChatCompletionsOptions struct { + // REQUIRED; The collection of context messages associated with this chat completions request. Typical usage begins with a + // chat message for the System role that provides instructions for the behavior of the + // assistant, followed by alternating messages between the User and Assistant roles. + Messages []*ChatMessage + + // A value that influences the probability of generated tokens appearing based on their cumulative frequency in generated + // text. Positive values will make tokens less likely to appear as their frequency + // increases and decrease the likelihood of the model repeating the same statements verbatim. + FrequencyPenalty *float32 + + // A map between GPT token IDs and bias scores that influences the probability of specific tokens appearing in a completions + // response. Token IDs are computed via external tokenizer tools, while bias + // scores reside in the range of -100 to 100 with minimum and maximum values corresponding to a full ban or exclusive selection + // of a token, respectively. The exact behavior of a given bias score varies + // by model. + LogitBias map[string]*int32 + + // The maximum number of tokens to generate. + MaxTokens *int32 + + // The model name to provide as part of this completions request. Not applicable to Azure OpenAI, where deployment information + // should be included in the Azure resource URI that's connected to. + Model *string + + // The number of chat completions choices that should be generated for a chat completions response. Because this setting can + // generate many completions, it may quickly consume your token quota. Use + // carefully and ensure reasonable settings for max_tokens and stop. + N *int32 + + // A value that influences the probability of generated tokens appearing based on their existing presence in generated text. + // Positive values will make tokens less likely to appear when they already exist + // and increase the model's likelihood to output new topics. + PresencePenalty *float32 + + // A collection of textual sequences that will end completions generation. + Stop []*string + + // The sampling temperature to use that controls the apparent creativity of generated completions. Higher values will make + // output more random while lower values will make results more focused and + // deterministic. It is not recommended to modify temperature and top_p for the same completions request as the interaction + // of these two settings is difficult to predict. + Temperature *float32 + + // An alternative to sampling with temperature called nucleus sampling. This value causes the model to consider the results + // of tokens with the provided probability mass. As an example, a value of 0.15 + // will cause only the tokens comprising the top 15% of probability mass to be considered. It is not recommended to modify + // temperature and top_p for the same completions request as the interaction of + // these two settings is difficult to predict. + TopP *float32 + + // An identifier for the caller or end user of the operation. This may be used for tracking or rate-limiting purposes. + User *string +} + +// ChatMessage - A single, role-attributed message within a chat completion interaction. +type ChatMessage struct { + // REQUIRED; The role associated with this message payload. + Role *ChatRole + + // The text associated with this message payload. + Content *string +} + +// Choice - The representation of a single prompt completion as part of an overall completions request. Generally, n choices +// are generated per provided prompt with a default value of 1. Token limits and other +// settings may limit the number of choices generated. +type Choice struct { + // REQUIRED; Reason for finishing + FinishReason *CompletionsFinishReason + + // REQUIRED; The ordered index associated with this completions choice. + Index *int32 + + // REQUIRED; The log probabilities model for tokens associated with this completions choice. + Logprobs *ChoiceLogprobs + + // REQUIRED; The generated text for a given completions prompt. + Text *string +} + +// ChoiceLogprobs - The log probabilities model for tokens associated with this completions choice. +type ChoiceLogprobs struct { + // REQUIRED; The text offsets associated with tokens in this completions data. + TextOffset []*int32 + + // REQUIRED; A collection of log probability values for the tokens in this completions data. + TokenLogprobs []*float32 + + // REQUIRED; The textual forms of tokens evaluated in this probability model. + Tokens []*string + + // REQUIRED; A mapping of tokens to maximum log probability values in this completions data. + TopLogprobs []any +} + +// GetChatCompletionsOptions contains the optional parameters for the Client.GetChatCompletions method. +type GetChatCompletionsOptions struct { + // placeholder for future optional parameters +} + +// GetCompletionsOptions contains the optional parameters for the Client.GetCompletions method. +type GetCompletionsOptions struct { + // placeholder for future optional parameters +} + +// GetEmbeddingsOptions contains the optional parameters for the Client.GetEmbeddings method. +type GetEmbeddingsOptions struct { + // placeholder for future optional parameters +} + +// Completions - Representation of the response data from a completions request. Completions support a wide variety of tasks +// and generate text that continues from or "completes" provided prompt data. +type Completions struct { + // REQUIRED; The collection of completions choices associated with this completions response. Generally, n choices are generated + // per provided prompt with a default value of 1. Token limits and other settings may + // limit the number of choices generated. + Choices []*Choice + + // REQUIRED; The first timestamp associated with generation activity for this completions response, represented as seconds + // since the beginning of the Unix epoch of 00:00 on 1 Jan 1970. + Created *int32 + + // REQUIRED; A unique identifier associated with this completions response. + ID *string + + // REQUIRED; Usage information for tokens processed and generated as part of this completions operation. + Usage *CompletionsUsage +} + +// CompletionsLogProbabilityModel - Representation of a log probabilities model for a completions generation. +type CompletionsLogProbabilityModel struct { + // REQUIRED; The text offsets associated with tokens in this completions data. + TextOffset []*int32 + + // REQUIRED; A collection of log probability values for the tokens in this completions data. + TokenLogprobs []*float32 + + // REQUIRED; The textual forms of tokens evaluated in this probability model. + Tokens []*string + + // REQUIRED; A mapping of tokens to maximum log probability values in this completions data. + TopLogprobs []any +} + +// CompletionsOptions - The configuration information for a completions request. Completions support a wide variety of tasks +// and generate text that continues from or "completes" provided prompt data. +type CompletionsOptions struct { + // REQUIRED; The prompts to generate completions from. + Prompt []*string + + // A value that controls how many completions will be internally generated prior to response formulation. When used together + // with n, bestof controls the number of candidate completions and must be + // greater than n. Because this setting can generate many completions, it may quickly consume your token quota. Use carefully + // and ensure reasonable settings for maxtokens and stop. + BestOf *int32 + + // A value specifying whether completions responses should include input prompts as prefixes to their generated output. + Echo *bool + + // A value that influences the probability of generated tokens appearing based on their cumulative frequency in generated + // text. Positive values will make tokens less likely to appear as their frequency + // increases and decrease the likelihood of the model repeating the same statements verbatim. + FrequencyPenalty *float32 + + // A map between GPT token IDs and bias scores that influences the probability of specific tokens appearing in a completions + // response. Token IDs are computed via external tokenizer tools, while bias + // scores reside in the range of -100 to 100 with minimum and maximum values corresponding to a full ban or exclusive selection + // of a token, respectively. The exact behavior of a given bias score varies + // by model. + LogitBias map[string]*int32 + + // A value that controls the emission of log probabilities for the provided number of most likely tokens within a completions + // response. + Logprobs *int32 + + // The maximum number of tokens to generate. + MaxTokens *int32 + + // The model name to provide as part of this completions request. Not applicable to Azure OpenAI, where deployment information + // should be included in the Azure resource URI that's connected to. + Model *string + + // The number of completions choices that should be generated per provided prompt as part of an overall completions response. + // Because this setting can generate many completions, it may quickly consume + // your token quota. Use carefully and ensure reasonable settings for max_tokens and stop. + N *int32 + + // A value that influences the probability of generated tokens appearing based on their existing presence in generated text. + // Positive values will make tokens less likely to appear when they already exist + // and increase the model's likelihood to output new topics. + PresencePenalty *float32 + + // A collection of textual sequences that will end completions generation. + Stop []*string + + // The sampling temperature to use that controls the apparent creativity of generated completions. Higher values will make + // output more random while lower values will make results more focused and + // deterministic. It is not recommended to modify temperature and top_p for the same completions request as the interaction + // of these two settings is difficult to predict. + Temperature *float32 + + // An alternative to sampling with temperature called nucleus sampling. This value causes the model to consider the results + // of tokens with the provided probability mass. As an example, a value of 0.15 + // will cause only the tokens comprising the top 15% of probability mass to be considered. It is not recommended to modify + // temperature and top_p for the same completions request as the interaction of + // these two settings is difficult to predict. + TopP *float32 + + // An identifier for the caller or end user of the operation. This may be used for tracking or rate-limiting purposes. + User *string +} + +// CompletionsUsage - Representation of the token counts processed for a completions request. Counts consider all tokens across +// prompts, choices, choice alternates, best_of generations, and other consumers. +type CompletionsUsage struct { + // REQUIRED; The number of tokens generated across all completions emissions. + CompletionTokens *int32 + + // REQUIRED; The number of tokens in the provided prompts for the completions request. + PromptTokens *int32 + + // REQUIRED; The total number of tokens processed for the completions request and response. + TotalTokens *int32 +} + +// Deployment - A specific deployment +type Deployment struct { + // READ-ONLY; deployment id of the deployed model + DeploymentID *string +} + +// EmbeddingItem - Representation of a single embeddings relatedness comparison. +type EmbeddingItem struct { + // REQUIRED; List of embeddings value for the input prompt. These represent a measurement of the vector-based relatedness + // of the provided input. + Embedding []*float32 + + // REQUIRED; Index of the prompt to which the EmbeddingItem corresponds. + Index *int32 +} + +// Embeddings - Representation of the response data from an embeddings request. Embeddings measure the relatedness of text +// strings and are commonly used for search, clustering, recommendations, and other similar +// scenarios. +type Embeddings struct { + // REQUIRED; Embedding values for the prompts submitted in the request. + Data []*EmbeddingItem + + // REQUIRED; Usage counts for tokens input using the embeddings API. + Usage *EmbeddingsUsage +} + +// EmbeddingsOptions - The configuration information for an embeddings request. Embeddings measure the relatedness of text +// strings and are commonly used for search, clustering, recommendations, and other similar scenarios. +type EmbeddingsOptions struct { + // REQUIRED; Input text to get embeddings for, encoded as a string. To get embeddings for multiple inputs in a single request, + // pass an array of strings. Each input must not exceed 2048 tokens in length. + // Unless you are embedding code, we suggest replacing newlines (\n) in your input with a single space, as we have observed + // inferior results when newlines are present. + Input any + + // The model name to provide as part of this embeddings request. Not applicable to Azure OpenAI, where deployment information + // should be included in the Azure resource URI that's connected to. + Model *string + + // An identifier for the caller or end user of the operation. This may be used for tracking or rate-limiting purposes. + User *string +} + +// EmbeddingsUsage - Usage counts for tokens input using the embeddings API. +type EmbeddingsUsage struct { + // REQUIRED; Number of tokens sent in the original request. + PromptTokens *int32 + + // REQUIRED; Total number of tokens transacted in this request/response. + TotalTokens *int32 +} + +// EmbeddingsUsageAutoGenerated - Measurement of the amount of tokens used in this request and response. +type EmbeddingsUsageAutoGenerated struct { + // REQUIRED; Number of tokens sent in the original request. + PromptTokens *int32 + + // REQUIRED; Total number of tokens transacted in this request/response. + TotalTokens *int32 +} diff --git a/sdk/cognitiveservices/azopenai/models_serde.go b/sdk/cognitiveservices/azopenai/models_serde.go new file mode 100644 index 000000000000..a757d9b7f13c --- /dev/null +++ b/sdk/cognitiveservices/azopenai/models_serde.go @@ -0,0 +1,741 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. +// Code generated by Microsoft (R) AutoRest Code Generator. DO NOT EDIT. +// Changes may cause incorrect behavior and will be lost if the code is regenerated. + +package azopenai + +import ( + "encoding/json" + "fmt" + "reflect" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// MarshalJSON implements the json.Marshaller interface for type ChatChoice. +func (c ChatChoice) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "delta", c.Delta) + populate(objectMap, "finish_reason", c.FinishReason) + populate(objectMap, "index", c.Index) + populate(objectMap, "message", c.Message) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChatChoice. +func (c *ChatChoice) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "delta": + err = unpopulate(val, "Delta", &c.Delta) + delete(rawMsg, key) + case "finish_reason": + err = unpopulate(val, "FinishReason", &c.FinishReason) + delete(rawMsg, key) + case "index": + err = unpopulate(val, "Index", &c.Index) + delete(rawMsg, key) + case "message": + err = unpopulate(val, "Message", &c.Message) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type ChatChoiceDelta. +func (c ChatChoiceDelta) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "content", c.Content) + populate(objectMap, "role", c.Role) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChatChoiceDelta. +func (c *ChatChoiceDelta) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "content": + err = unpopulate(val, "Content", &c.Content) + delete(rawMsg, key) + case "role": + err = unpopulate(val, "Role", &c.Role) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type ChatChoiceMessage. +func (c ChatChoiceMessage) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "content", c.Content) + populate(objectMap, "role", c.Role) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChatChoiceMessage. +func (c *ChatChoiceMessage) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "content": + err = unpopulate(val, "Content", &c.Content) + delete(rawMsg, key) + case "role": + err = unpopulate(val, "Role", &c.Role) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type ChatCompletions. +func (c ChatCompletions) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "choices", c.Choices) + populate(objectMap, "created", c.Created) + populate(objectMap, "id", c.ID) + populate(objectMap, "usage", c.Usage) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChatCompletions. +func (c *ChatCompletions) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "choices": + err = unpopulate(val, "Choices", &c.Choices) + delete(rawMsg, key) + case "created": + err = unpopulate(val, "Created", &c.Created) + delete(rawMsg, key) + case "id": + err = unpopulate(val, "ID", &c.ID) + delete(rawMsg, key) + case "usage": + err = unpopulate(val, "Usage", &c.Usage) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type ChatCompletionsOptions. +func (c ChatCompletionsOptions) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "frequency_penalty", c.FrequencyPenalty) + populate(objectMap, "logit_bias", c.LogitBias) + populate(objectMap, "max_tokens", c.MaxTokens) + populate(objectMap, "messages", c.Messages) + populate(objectMap, "model", c.Model) + populate(objectMap, "n", c.N) + populate(objectMap, "presence_penalty", c.PresencePenalty) + populate(objectMap, "stop", c.Stop) + populate(objectMap, "temperature", c.Temperature) + populate(objectMap, "top_p", c.TopP) + populate(objectMap, "user", c.User) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChatCompletionsOptions. +func (c *ChatCompletionsOptions) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "frequency_penalty": + err = unpopulate(val, "FrequencyPenalty", &c.FrequencyPenalty) + delete(rawMsg, key) + case "logit_bias": + err = unpopulate(val, "LogitBias", &c.LogitBias) + delete(rawMsg, key) + case "max_tokens": + err = unpopulate(val, "MaxTokens", &c.MaxTokens) + delete(rawMsg, key) + case "messages": + err = unpopulate(val, "Messages", &c.Messages) + delete(rawMsg, key) + case "model": + err = unpopulate(val, "Model", &c.Model) + delete(rawMsg, key) + case "n": + err = unpopulate(val, "N", &c.N) + delete(rawMsg, key) + case "presence_penalty": + err = unpopulate(val, "PresencePenalty", &c.PresencePenalty) + delete(rawMsg, key) + case "stop": + err = unpopulate(val, "Stop", &c.Stop) + delete(rawMsg, key) + case "temperature": + err = unpopulate(val, "Temperature", &c.Temperature) + delete(rawMsg, key) + case "top_p": + err = unpopulate(val, "TopP", &c.TopP) + delete(rawMsg, key) + case "user": + err = unpopulate(val, "User", &c.User) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type ChatMessage. +func (c ChatMessage) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "content", c.Content) + populate(objectMap, "role", c.Role) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChatMessage. +func (c *ChatMessage) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "content": + err = unpopulate(val, "Content", &c.Content) + delete(rawMsg, key) + case "role": + err = unpopulate(val, "Role", &c.Role) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type Choice. +func (c Choice) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "finish_reason", c.FinishReason) + populate(objectMap, "index", c.Index) + populate(objectMap, "logprobs", c.Logprobs) + populate(objectMap, "text", c.Text) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type Choice. +func (c *Choice) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "finish_reason": + err = unpopulate(val, "FinishReason", &c.FinishReason) + delete(rawMsg, key) + case "index": + err = unpopulate(val, "Index", &c.Index) + delete(rawMsg, key) + case "logprobs": + err = unpopulate(val, "Logprobs", &c.Logprobs) + delete(rawMsg, key) + case "text": + err = unpopulate(val, "Text", &c.Text) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type ChoiceLogprobs. +func (c ChoiceLogprobs) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "text_offset", c.TextOffset) + populate(objectMap, "token_logprobs", c.TokenLogprobs) + populate(objectMap, "tokens", c.Tokens) + populate(objectMap, "top_logprobs", c.TopLogprobs) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type ChoiceLogprobs. +func (c *ChoiceLogprobs) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "text_offset": + err = unpopulate(val, "TextOffset", &c.TextOffset) + delete(rawMsg, key) + case "token_logprobs": + err = unpopulate(val, "TokenLogprobs", &c.TokenLogprobs) + delete(rawMsg, key) + case "tokens": + err = unpopulate(val, "Tokens", &c.Tokens) + delete(rawMsg, key) + case "top_logprobs": + err = unpopulate(val, "TopLogprobs", &c.TopLogprobs) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type Completions. +func (c Completions) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "choices", c.Choices) + populate(objectMap, "created", c.Created) + populate(objectMap, "id", c.ID) + populate(objectMap, "usage", c.Usage) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type Completions. +func (c *Completions) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "choices": + err = unpopulate(val, "Choices", &c.Choices) + delete(rawMsg, key) + case "created": + err = unpopulate(val, "Created", &c.Created) + delete(rawMsg, key) + case "id": + err = unpopulate(val, "ID", &c.ID) + delete(rawMsg, key) + case "usage": + err = unpopulate(val, "Usage", &c.Usage) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type CompletionsLogProbabilityModel. +func (c CompletionsLogProbabilityModel) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "text_offset", c.TextOffset) + populate(objectMap, "token_logprobs", c.TokenLogprobs) + populate(objectMap, "tokens", c.Tokens) + populate(objectMap, "top_logprobs", c.TopLogprobs) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type CompletionsLogProbabilityModel. +func (c *CompletionsLogProbabilityModel) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "text_offset": + err = unpopulate(val, "TextOffset", &c.TextOffset) + delete(rawMsg, key) + case "token_logprobs": + err = unpopulate(val, "TokenLogprobs", &c.TokenLogprobs) + delete(rawMsg, key) + case "tokens": + err = unpopulate(val, "Tokens", &c.Tokens) + delete(rawMsg, key) + case "top_logprobs": + err = unpopulate(val, "TopLogprobs", &c.TopLogprobs) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type CompletionsOptions. +func (c CompletionsOptions) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "best_of", c.BestOf) + populate(objectMap, "echo", c.Echo) + populate(objectMap, "frequency_penalty", c.FrequencyPenalty) + populate(objectMap, "logit_bias", c.LogitBias) + populate(objectMap, "logprobs", c.Logprobs) + populate(objectMap, "max_tokens", c.MaxTokens) + populate(objectMap, "model", c.Model) + populate(objectMap, "n", c.N) + populate(objectMap, "presence_penalty", c.PresencePenalty) + populate(objectMap, "prompt", c.Prompt) + populate(objectMap, "stop", c.Stop) + populate(objectMap, "temperature", c.Temperature) + populate(objectMap, "top_p", c.TopP) + populate(objectMap, "user", c.User) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type CompletionsOptions. +func (c *CompletionsOptions) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "best_of": + err = unpopulate(val, "BestOf", &c.BestOf) + delete(rawMsg, key) + case "echo": + err = unpopulate(val, "Echo", &c.Echo) + delete(rawMsg, key) + case "frequency_penalty": + err = unpopulate(val, "FrequencyPenalty", &c.FrequencyPenalty) + delete(rawMsg, key) + case "logit_bias": + err = unpopulate(val, "LogitBias", &c.LogitBias) + delete(rawMsg, key) + case "logprobs": + err = unpopulate(val, "Logprobs", &c.Logprobs) + delete(rawMsg, key) + case "max_tokens": + err = unpopulate(val, "MaxTokens", &c.MaxTokens) + delete(rawMsg, key) + case "model": + err = unpopulate(val, "Model", &c.Model) + delete(rawMsg, key) + case "n": + err = unpopulate(val, "N", &c.N) + delete(rawMsg, key) + case "presence_penalty": + err = unpopulate(val, "PresencePenalty", &c.PresencePenalty) + delete(rawMsg, key) + case "prompt": + err = unpopulate(val, "Prompt", &c.Prompt) + delete(rawMsg, key) + case "stop": + err = unpopulate(val, "Stop", &c.Stop) + delete(rawMsg, key) + case "temperature": + err = unpopulate(val, "Temperature", &c.Temperature) + delete(rawMsg, key) + case "top_p": + err = unpopulate(val, "TopP", &c.TopP) + delete(rawMsg, key) + case "user": + err = unpopulate(val, "User", &c.User) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type CompletionsUsage. +func (c CompletionsUsage) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "completion_tokens", c.CompletionTokens) + populate(objectMap, "prompt_tokens", c.PromptTokens) + populate(objectMap, "total_tokens", c.TotalTokens) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type CompletionsUsage. +func (c *CompletionsUsage) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "completion_tokens": + err = unpopulate(val, "CompletionTokens", &c.CompletionTokens) + delete(rawMsg, key) + case "prompt_tokens": + err = unpopulate(val, "PromptTokens", &c.PromptTokens) + delete(rawMsg, key) + case "total_tokens": + err = unpopulate(val, "TotalTokens", &c.TotalTokens) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", c, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type Deployment. +func (d Deployment) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "deploymentId", d.DeploymentID) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type Deployment. +func (d *Deployment) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", d, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "deploymentId": + err = unpopulate(val, "DeploymentID", &d.DeploymentID) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", d, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type EmbeddingItem. +func (e EmbeddingItem) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "embedding", e.Embedding) + populate(objectMap, "index", e.Index) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type EmbeddingItem. +func (e *EmbeddingItem) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", e, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "embedding": + err = unpopulate(val, "Embedding", &e.Embedding) + delete(rawMsg, key) + case "index": + err = unpopulate(val, "Index", &e.Index) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", e, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type Embeddings. +func (e Embeddings) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "data", e.Data) + populate(objectMap, "usage", e.Usage) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type Embeddings. +func (e *Embeddings) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", e, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "data": + err = unpopulate(val, "Data", &e.Data) + delete(rawMsg, key) + case "usage": + err = unpopulate(val, "Usage", &e.Usage) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", e, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type EmbeddingsOptions. +func (e EmbeddingsOptions) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populateAny(objectMap, "input", e.Input) + populate(objectMap, "model", e.Model) + populate(objectMap, "user", e.User) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type EmbeddingsOptions. +func (e *EmbeddingsOptions) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", e, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "input": + err = unpopulate(val, "Input", &e.Input) + delete(rawMsg, key) + case "model": + err = unpopulate(val, "Model", &e.Model) + delete(rawMsg, key) + case "user": + err = unpopulate(val, "User", &e.User) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", e, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type EmbeddingsUsage. +func (e EmbeddingsUsage) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "prompt_tokens", e.PromptTokens) + populate(objectMap, "total_tokens", e.TotalTokens) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type EmbeddingsUsage. +func (e *EmbeddingsUsage) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", e, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "prompt_tokens": + err = unpopulate(val, "PromptTokens", &e.PromptTokens) + delete(rawMsg, key) + case "total_tokens": + err = unpopulate(val, "TotalTokens", &e.TotalTokens) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", e, err) + } + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for type EmbeddingsUsageAutoGenerated. +func (e EmbeddingsUsageAutoGenerated) MarshalJSON() ([]byte, error) { + objectMap := make(map[string]any) + populate(objectMap, "prompt_tokens", e.PromptTokens) + populate(objectMap, "total_tokens", e.TotalTokens) + return json.Marshal(objectMap) +} + +// UnmarshalJSON implements the json.Unmarshaller interface for type EmbeddingsUsageAutoGenerated. +func (e *EmbeddingsUsageAutoGenerated) UnmarshalJSON(data []byte) error { + var rawMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return fmt.Errorf("unmarshalling type %T: %v", e, err) + } + for key, val := range rawMsg { + var err error + switch key { + case "prompt_tokens": + err = unpopulate(val, "PromptTokens", &e.PromptTokens) + delete(rawMsg, key) + case "total_tokens": + err = unpopulate(val, "TotalTokens", &e.TotalTokens) + delete(rawMsg, key) + } + if err != nil { + return fmt.Errorf("unmarshalling type %T: %v", e, err) + } + } + return nil +} + +func populate(m map[string]any, k string, v any) { + if v == nil { + return + } else if azcore.IsNullValue(v) { + m[k] = nil + } else if !reflect.ValueOf(v).IsNil() { + m[k] = v + } +} + +func populateAny(m map[string]any, k string, v any) { + if v == nil { + return + } else if azcore.IsNullValue(v) { + m[k] = nil + } else { + m[k] = v + } +} + +func unpopulate(data json.RawMessage, fn string, v any) error { + if data == nil { + return nil + } + if err := json.Unmarshal(data, v); err != nil { + return fmt.Errorf("struct field %s: %v", fn, err) + } + return nil +} diff --git a/sdk/cognitiveservices/azopenai/policy_apikey.go b/sdk/cognitiveservices/azopenai/policy_apikey.go new file mode 100644 index 000000000000..ad353946801b --- /dev/null +++ b/sdk/cognitiveservices/azopenai/policy_apikey.go @@ -0,0 +1,40 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// KeyCredential is used when doing APIKey-based authentication. +type KeyCredential struct { + // APIKey is the api key for the client. + APIKey string +} + +// apiKeyPolicy authorizes requests with an API key acquired from a KeyCredential. +type apiKeyPolicy struct { + header string + cred KeyCredential +} + +// newAPIKeyPolicy creates a policy object that authorizes requests with an API Key. +// cred: a KeyCredential implementation. +func newAPIKeyPolicy(cred KeyCredential, header string) *apiKeyPolicy { + return &apiKeyPolicy{ + header: header, + cred: cred, + } +} + +// Do returns a function which authorizes req with a token from the policy's credential +func (b *apiKeyPolicy) Do(req *policy.Request) (*http.Response, error) { + req.Raw().Header.Set(b.header, b.cred.APIKey) + return req.Next() +} diff --git a/sdk/cognitiveservices/azopenai/policy_apikey_test.go b/sdk/cognitiveservices/azopenai/policy_apikey_test.go new file mode 100644 index 000000000000..b5945449c503 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/policy_apikey_test.go @@ -0,0 +1,80 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +import ( + "context" + "net/http" + "reflect" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +func TestNewAPIKeyPolicy(t *testing.T) { + type args struct { + header string + cred KeyCredential + } + simpleCred := KeyCredential{APIKey: "apiKey"} + simpleHeader := "headerName" + tests := []struct { + name string + args args + want *apiKeyPolicy + }{ + { + name: "simple", + args: args{ + cred: simpleCred, + header: simpleHeader, + }, + want: &apiKeyPolicy{ + header: simpleHeader, + cred: simpleCred, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := newAPIKeyPolicy(tt.args.cred, tt.args.header); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewAPIKeyPolicy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAPIKeyPolicy_Success(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + cred := KeyCredential{ + APIKey: "secret", + } + authPolicy := newAPIKeyPolicy(cred, "api-key") + pipeline := runtime.NewPipeline( + "testmodule", + "v0.1.0", + runtime.PipelineOptions{PerRetry: []policy.Policy{authPolicy}}, + &policy.ClientOptions{ + Transport: srv, + }) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + resp, err := pipeline.Do(req) + if err != nil { + t.Fatalf("Expected nil error but received one") + } + if hdrValue := resp.Request.Header.Get("api-key"); hdrValue != "secret" { + t.Fatalf("expected api-key '%s', got '%s'", "secret", hdrValue) + } +} diff --git a/sdk/cognitiveservices/azopenai/response_types.go b/sdk/cognitiveservices/azopenai/response_types.go new file mode 100644 index 000000000000..8847d0af3da3 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/response_types.go @@ -0,0 +1,24 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. +// Code generated by Microsoft (R) AutoRest Code Generator. DO NOT EDIT. +// Changes may cause incorrect behavior and will be lost if the code is regenerated. + +package azopenai + +// GetChatCompletionsResponse contains the response from method Client.GetChatCompletions. +type GetChatCompletionsResponse struct { + ChatCompletions +} + +// GetCompletionsResponse contains the response from method Client.GetCompletions. +type GetCompletionsResponse struct { + Completions +} + +// GetEmbeddingsResponse contains the response from method Client.GetEmbeddings. +type GetEmbeddingsResponse struct { + Embeddings +} diff --git a/sdk/cognitiveservices/azopenai/version.go b/sdk/cognitiveservices/azopenai/version.go new file mode 100644 index 000000000000..cecc80db67a5 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/version.go @@ -0,0 +1,11 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +const ( + version = "v0.1.0" +)