From 1fc333877acb90acf5944f8cc0d08785d500177c Mon Sep 17 00:00:00 2001 From: Richard Park Date: Mon, 12 Aug 2024 13:12:43 -0700 Subject: [PATCH] Adding in Azure support --- README.md | 37 +++++++ azure/azure.go | 237 ++++++++++++++++++++++++++++++++++++++++++ azure/azure_test.go | 130 +++++++++++++++++++++++ azure/example_test.go | 47 +++++++++ go.mod | 20 +++- go.sum | 35 ++++++- 6 files changed, 501 insertions(+), 5 deletions(-) create mode 100644 azure/azure.go create mode 100644 azure/azure_test.go create mode 100644 azure/example_test.go diff --git a/README.md b/README.md index 4e6d648..9c06d03 100644 --- a/README.md +++ b/README.md @@ -396,6 +396,43 @@ You may also replace the default `http.Client` with accepted (this overwrites any previous client) and receives requests after any middleware has been applied. +## Microsoft Azure OpenAI + +To use this library with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/overview), use the option.RequestOption functions in the `azure` package. + +```go +package main + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/openai/openai-go" + "github.com/openai/openai-go/azure" +) + +func main() { + const azureOpenAIEndpoint = "https://.openai.azure.com" + + // The latest API versions, including previews, can be found here: + // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning + const azureOpenAIAPIVersion = "2024-06-01" + + tokenCredential, err := azidentity.NewDefaultAzureCredential(nil) + + if err != nil { + fmt.Printf("Failed to create the DefaultAzureCredential: %s", err) + os.Exit(1) + } + + client := openai.NewClient( + azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), + + // Choose between authenticating using a TokenCredential or an API Key + azure.WithTokenCredential(tokenCredential), + // or azure.WithAPIKey(azureOpenAIAPIKey), + ) +} +``` + ## Semantic versioning This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions: diff --git a/azure/azure.go b/azure/azure.go new file mode 100644 index 0000000..5d3156f --- /dev/null +++ b/azure/azure.go @@ -0,0 +1,237 @@ +// Package azure provides configuration options so you can connect and use Azure OpenAI using the [openai.Client]. +// +// Typical usage of this package will look like this: +// +// client := openai.NewClient( +// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), +// azure.WithTokenCredential(azureIdentityTokenCredential), +// // or azure.WithAPIKey(azureOpenAIAPIKey), +// ) +// +// Or, if you want to construct a specific service: +// +// client := openai.NewChatCompletionService( +// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), +// azure.WithTokenCredential(azureIdentityTokenCredential), +// // or azure.WithAPIKey(azureOpenAIAPIKey), +// ) +package azure + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "mime" + "mime/multipart" + "net/http" + "net/url" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" +) + +// WithEndpoint configures this client to connect to an Azure OpenAI endpoint. +// +// - endpoint - the Azure OpenAI endpoint to connect to. Ex: https://.openai.azure.com +// - apiVersion - the Azure OpenAI API version to target (ex: 2024-06-01). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty. +// +// This function should be paired with a call to authenticate, like [azure.WithAPIKey] or [azure.WithTokenCredential], similar to this: +// +// client := openai.NewClient( +// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), +// azure.WithTokenCredential(azureIdentityTokenCredential), +// // or azure.WithAPIKey(azureOpenAIAPIKey), +// ) +// +// [Azure OpenAI apiversions]: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning +func WithEndpoint(endpoint string, apiVersion string) option.RequestOption { + if !strings.HasSuffix(endpoint, "/") { + endpoint += "/" + } + + endpoint += "openai/" + + withQueryAdd := option.WithQueryAdd("api-version", apiVersion) + withEndpoint := option.WithBaseURL(endpoint) + + withModelMiddleware := option.WithMiddleware(func(r *http.Request, mn option.MiddlewareNext) (*http.Response, error) { + replacementPath, err := getReplacementPathWithDeployment(r) + + if err != nil { + return nil, err + } + + r.URL.Path = replacementPath + return mn(r) + }) + + return func(rc *requestconfig.RequestConfig) error { + if apiVersion == "" { + return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details.") + } + + if err := withQueryAdd(rc); err != nil { + return err + } + + if err := withEndpoint(rc); err != nil { + return err + } + + if err := withModelMiddleware(rc); err != nil { + return err + } + + return nil + } +} + +// WithTokenCredential configures this client to authenticate using an [Azure Identity] TokenCredential. +// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance. +// +// [Azure Identity]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity +func WithTokenCredential(tokenCredential azcore.TokenCredential) option.RequestOption { + bearerTokenPolicy := runtime.NewBearerTokenPolicy(tokenCredential, []string{"https://cognitiveservices.azure.com/.default"}, nil) + + // add in a middleware that uses the bearer token generated from the token credential + return option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + pipeline := runtime.NewPipeline("azopenai-extensions", version, runtime.PipelineOptions{}, &policy.ClientOptions{ + InsecureAllowCredentialWithHTTP: true, // allow for plain HTTP proxies, etc.. + PerRetryPolicies: []policy.Policy{ + bearerTokenPolicy, + policyAdapter(next), + }, + }) + + req2, err := runtime.NewRequestFromRequest(req) + + if err != nil { + return nil, err + } + + return pipeline.Do(req2) + }) +} + +// WithAPIKey configures this client to authenticate using an API key. +// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance. +func WithAPIKey(apiKey string) option.RequestOption { + // NOTE: there is an option.WithApiKey(), but that adds the value into + // the Authorization header instead so we're doing this instead. + return option.WithHeader("Api-Key", apiKey) +} + +// jsonRoutes have JSON payloads - we'll deserialize looking for a .model field in there +// so we won't have to worry about individual types for completions vs embeddings, etc... +var jsonRoutes = map[string]bool{ + "/openai/completions": true, + "/openai/chat/completions": true, + "/openai/embeddings": true, + "/openai/audio/speech": true, + "/openai/images/generations": true, +} + +// audioMultipartRoutes have mime/multipart payloads. These are less generic - we're very much +// expecting a transcription or translation payload for these. +var audioMultipartRoutes = map[string]bool{ + "/openai/audio/transcriptions": true, + "/openai/audio/translations": true, +} + +// getReplacementPathWithDeployment parses the request body to extract out the Model parameter (or equivalent) +// (note, the req.Body is fully read as part of this, and is replaced with a bytes.Reader) +func getReplacementPathWithDeployment(req *http.Request) (string, error) { + if jsonRoutes[req.URL.Path] { + return getJSONRoute(req) + } + + if audioMultipartRoutes[req.URL.Path] { + return getAudioMultipartRoute(req) + } + + // No need to relocate the path. We've already tacked on /openai when we setup the endpoint. + return req.URL.Path, nil +} + +func getJSONRoute(req *http.Request) (string, error) { + // we need to deserialize the body, partly, in order to read out the model field. + jsonBytes, err := io.ReadAll(req.Body) + + if err != nil { + return "", err + } + + // make sure we restore the body so it can be used in later middlewares. + req.Body = io.NopCloser(bytes.NewReader(jsonBytes)) + + var v *struct { + Model string `json:"model"` + } + + if err := json.Unmarshal(jsonBytes, &v); err != nil { + return "", err + } + + escapedDeployment := url.PathEscape(v.Model) + return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil +} + +func getAudioMultipartRoute(req *http.Request) (string, error) { + // body is a multipart/mime body type instead. + mimeBytes, err := io.ReadAll(req.Body) + + if err != nil { + return "", err + } + + // make sure we restore the body so it can be used in later middlewares. + req.Body = io.NopCloser(bytes.NewReader(mimeBytes)) + + _, mimeParams, err := mime.ParseMediaType(req.Header.Get("Content-Type")) + + if err != nil { + return "", err + } + + mimeReader := multipart.NewReader( + io.NopCloser(bytes.NewReader(mimeBytes)), + mimeParams["boundary"]) + + for { + mp, err := mimeReader.NextPart() + + if err != nil { + if errors.Is(err, io.EOF) { + return "", errors.New("unable to find the model part in multipart body") + } + + return "", err + } + + defer mp.Close() + + if mp.FormName() == "model" { + modelBytes, err := io.ReadAll(mp) + + if err != nil { + return "", err + } + + escapedDeployment := url.PathEscape(string(modelBytes)) + return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil + } + } +} + +type policyAdapter option.MiddlewareNext + +func (mp policyAdapter) Do(req *policy.Request) (*http.Response, error) { + return (option.MiddlewareNext)(mp)(req.Raw()) +} + +const version = "v.0.1.0" diff --git a/azure/azure_test.go b/azure/azure_test.go new file mode 100644 index 0000000..ac694cc --- /dev/null +++ b/azure/azure_test.go @@ -0,0 +1,130 @@ +package azure + +import ( + "bytes" + "mime/multipart" + "net/http" + "testing" + + "github.com/openai/openai-go" + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/shared" +) + +func TestJSONRoute(t *testing.T) { + chatCompletionParams := openai.ChatCompletionNewParams{ + Model: openai.F(openai.ChatModel("arbitraryDeployment")), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.ChatCompletionAssistantMessageParam{ + Role: openai.F(openai.ChatCompletionAssistantMessageParamRoleAssistant), + Content: openai.F[openai.ChatCompletionAssistantMessageParamContentUnion](shared.UnionString("You are a helpful assistant")), + }, + openai.ChatCompletionUserMessageParam{ + Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), + Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("Can you tell me another word for the universe?")), + }, + }), + } + + serializedBytes, err := apijson.MarshalRoot(chatCompletionParams) + + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(serializedBytes)) + + if err != nil { + t.Fatal(err) + } + + replacementPath, err := getReplacementPathWithDeployment(req) + + if err != nil { + t.Fatal(err) + } + + if replacementPath != "/openai/deployments/arbitraryDeployment/chat/completions" { + t.Fatalf("replacementpath didn't match: %s", replacementPath) + } +} + +func TestGetAudioMultipartRoute(t *testing.T) { + buff := &bytes.Buffer{} + mw := multipart.NewWriter(buff) + defer mw.Close() + + fw, err := mw.CreateFormFile("file", "test.mp3") + + if err != nil { + t.Fatal(err) + } + + if _, err = fw.Write([]byte("ignore me")); err != nil { + t.Fatal(err) + } + + if err := mw.WriteField("model", "arbitraryDeployment"); err != nil { + t.Fatal(err) + } + + if err := mw.Close(); err != nil { + t.Fatal(err) + } + + req, err := http.NewRequest("POST", "/openai/audio/transcriptions", bytes.NewReader(buff.Bytes())) + + if err != nil { + t.Fatal(err) + } + + req.Header.Set("Content-Type", mw.FormDataContentType()) + + replacementPath, err := getReplacementPathWithDeployment(req) + + if err != nil { + t.Fatal(err) + } + + if replacementPath != "/openai/deployments/arbitraryDeployment/audio/transcriptions" { + t.Fatalf("replacementpath didn't match: %s", replacementPath) + } +} + +func TestNoRouteChangeNeeded(t *testing.T) { + chatCompletionParams := openai.ChatCompletionNewParams{ + Model: openai.F(openai.ChatModel("arbitraryDeployment")), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.ChatCompletionAssistantMessageParam{ + Role: openai.F(openai.ChatCompletionAssistantMessageParamRoleAssistant), + Content: openai.F[openai.ChatCompletionAssistantMessageParamContentUnion](shared.UnionString("You are a helpful assistant")), + }, + openai.ChatCompletionUserMessageParam{ + Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), + Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("Can you tell me another word for the universe?")), + }, + }), + } + + serializedBytes, err := apijson.MarshalRoot(chatCompletionParams) + + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequest("POST", "/openai/does/not/need/a/deployment", bytes.NewReader(serializedBytes)) + + if err != nil { + t.Fatal(err) + } + + replacementPath, err := getReplacementPathWithDeployment(req) + + if err != nil { + t.Fatal(err) + } + + if replacementPath != "/openai/does/not/need/a/deployment" { + t.Fatalf("replacementpath didn't match: %s", replacementPath) + } +} diff --git a/azure/example_test.go b/azure/example_test.go new file mode 100644 index 0000000..3a8ef21 --- /dev/null +++ b/azure/example_test.go @@ -0,0 +1,47 @@ +package azure_test + +import ( + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/openai/openai-go" + "github.com/openai/openai-go/azure" +) + +func Example_authentication() { + // There are two ways to authenticate - using a TokenCredential (via the azidentity + // package), or using an API Key. + const azureOpenAIEndpoint = "https://.openai.azure.com" + const azureOpenAIAPIVersion = "" + + // Using a TokenCredential + { + // For a full list of credential types look at the documentation for the Azure Identity + // package: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity + tokenCredential, err := azidentity.NewDefaultAzureCredential(nil) + + if err != nil { + fmt.Printf("Failed to create TokenCredential: %s\n", err) + return + } + + client := openai.NewClient( + azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), + azure.WithTokenCredential(tokenCredential), + ) + + _ = client + } + + // Using an API Key + { + const azureOpenAIAPIKey = "" + + client := openai.NewClient( + azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), + azure.WithAPIKey(azureOpenAIAPIKey), + ) + + _ = client + } +} diff --git a/go.mod b/go.mod index a242a50..f0a0b76 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,23 @@ module github.com/openai/openai-go go 1.21 require ( - github.com/google/uuid v1.3.0 // indirect - github.com/tidwall/gjson v1.14.4 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 + github.com/tidwall/gjson v1.14.4 + github.com/tidwall/sjson v1.2.5 +) + +require ( + github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect + github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - github.com/tidwall/sjson v1.2.5 // indirect + golang.org/x/crypto v0.25.0 // indirect + golang.org/x/net v0.27.0 // indirect + golang.org/x/sys v0.22.0 // indirect + golang.org/x/text v0.16.0 // indirect ) diff --git a/go.sum b/go.sum index 569e555..77c3f76 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,25 @@ -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -10,3 +30,14 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=