From 171cdfe1bcf24d8c8fb745496ec9cb774bd29856 Mon Sep 17 00:00:00 2001
From: Rich Coggins
Date: Mon, 16 Oct 2023 18:56:12 -0400
Subject: [PATCH] Sync with upstream

---
 CONTRIBUTING.md         |  88 +++++++++++++++++++++
 README.md               | 122 ++++++++++++++++++++++------
 audio.go                |  21 ++++-
 chat.go                 |  11 ++-
 chat_stream_test.go     |  89 ++++++++++++++++++++-
 chat_test.go            | 114 +++++++++++++++++++++++++++
 client.go               |  43 +++++++---
 client_test.go          |  24 ++++--
 completion.go           |   4 +-
 edits.go                |   4 +-
 embeddings.go           | 140 ++++++++++++++++++++++++++++++---
 embeddings_test.go      | 169 ++++++++++++++++++++++++++++++++++++++--
 engines.go              |   8 +-
 error.go                |  25 ++++--
 error_test.go           |  78 +++++++++++++++++++
 files.go                |  27 ++++---
 files_api_test.go       |   1 -
 fine_tunes.go           |  60 ++++++++++++--
 fine_tuning_job.go      | 157 +++++++++++++++++++++++++++++++
 fine_tuning_job_test.go | 105 +++++++++++++++++++++++++
 image.go                |  10 ++-
 models.go               |  30 ++++++-
 models_test.go          |  15 ++++
 moderation.go           |   4 +-
 ratelimit.go            |  43 ++++++++++
 stream_reader.go        |   2 +
 26 files changed, 1304 insertions(+), 90 deletions(-)
 create mode 100644 CONTRIBUTING.md
 create mode 100644 fine_tuning_job.go
 create mode 100644 fine_tuning_job_test.go
 create mode 100644 ratelimit.go

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 000000000..4dd184042
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,88 @@
+# Contributing Guidelines
+
+## Overview
+Thank you for your interest in contributing to the "Go OpenAI" project! By following these guidelines, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://github.com/sashabaranov/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests.
+
+## Reporting Bugs
+If you discover a bug, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it.
+
+## Suggesting Features
+If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion.
+
+## Reporting Vulnerabilities
+If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to share the details. This report will only be viewable by repository maintainers. You will be credited if the advisory is published.
+
+## Questions for Users
+If you have questions, please use [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://github.com/sashabaranov/go-openai/discussions).
+
+## Contributing Code
+There might already be a similar pull request submitted! Please search for [pull requests](https://github.com/sashabaranov/go-openai/pulls) before creating one.
+
+### Requirements for Merging a Pull Request
+
+The requirements to accept a pull request are as follows:
+
+- Features not provided by the OpenAI API will not be accepted.
+- The functionality of the feature must match that of the official OpenAI API.
+- All pull requests should be written in Go according to common conventions, formatted with `goimports`, and free of warnings from tools like `golangci-lint`. +- Include tests and ensure all tests pass. +- Maintain test coverage without any reduction. +- All pull requests require approval from at least one Go OpenAI maintainer. + +**Note:** +The merging method for pull requests in this repository is squash merge. + +### Creating a Pull Request +- Fork the repository. +- Create a new branch and commit your changes. +- Push that branch to GitHub. +- Start a new Pull Request on GitHub. (Please use the pull request template to provide detailed information.) + +**Note:** +If your changes introduce breaking changes, please prefix your pull request title with "[BREAKING_CHANGES]". + +### Code Style +In this project, we adhere to the standard coding style of Go. Your code should maintain consistency with the rest of the codebase. To achieve this, please format your code using tools like `goimports` and resolve any syntax or style issues with `golangci-lint`. + +**Run goimports:** +``` +go install golang.org/x/tools/cmd/goimports@latest +``` + +``` +goimports -w . +``` + +**Run golangci-lint:** +``` +go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest +``` + +``` +golangci-lint run --out-format=github-actions +``` + +### Unit Test +Please create or update tests relevant to your changes. Ensure all tests run successfully to verify that your modifications do not adversely affect other functionalities. + +**Run test:** +``` +go test -v ./... +``` + +### Integration Test +Integration tests are requested against the production version of the OpenAI API. These tests will verify that the library is properly coded against the actual behavior of the API, and will fail upon any incompatible change in the API. + +**Notes:** +These tests send real network traffic to the OpenAI API and may reach rate limits. Temporary network problems may also cause the test to fail. + +**Run integration test:** +``` +OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go +``` + +If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. + +--- + +We wholeheartedly welcome your active participation. Let's build an amazing project together! diff --git a/README.md b/README.md index 33d8214a1..17bcd5a5c 100644 --- a/README.md +++ b/README.md @@ -10,12 +10,16 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op * DALL·E 2 * Whisper -### Installation: +## Installation + ``` go get github.com/coggsfl/go-openai ``` Currently, go-openai requires Go version 1.18 or greater. + +## Usage + ### ChatGPT example usage: ```go @@ -479,6 +483,62 @@ func main() { ``` +
+Embedding Semantic Similarity
+
+```go
+package main
+
+import (
+	"context"
+	"log"
+
+	openai "github.com/sashabaranov/go-openai"
+)
+
+func main() {
+	client := openai.NewClient("your-token")
+
+	// Create an EmbeddingRequest for the user query
+	queryReq := openai.EmbeddingRequest{
+		Input: []string{"How many chucks would a woodchuck chuck"},
+		Model: openai.AdaEmbeddingV2,
+	}
+
+	// Create an embedding for the user query
+	queryResponse, err := client.CreateEmbeddings(context.Background(), queryReq)
+	if err != nil {
+		log.Fatal("Error creating query embedding:", err)
+	}
+
+	// Create an EmbeddingRequest for the target text
+	targetReq := openai.EmbeddingRequest{
+		Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"},
+		Model: openai.AdaEmbeddingV2,
+	}
+
+	// Create an embedding for the target text
+	targetResponse, err := client.CreateEmbeddings(context.Background(), targetReq)
+	if err != nil {
+		log.Fatal("Error creating target embedding:", err)
+	}
+
+	// Now that we have the embeddings for the user query and the target text, we
+	// can calculate their similarity.
+	queryEmbedding := queryResponse.Data[0]
+	targetEmbedding := targetResponse.Data[0]
+
+	similarity, err := queryEmbedding.DotProduct(&targetEmbedding)
+	if err != nil {
+		log.Fatal("Error calculating dot product:", err)
+	}
+
+	log.Printf("The similarity score between the query and the target is %f", similarity)
+}
+```
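+
+Because `text-embedding-ada-002` embeddings are normalized to unit length, the dot product above is
+also the cosine similarity of the two texts. For vectors that are not unit length, divide the dot
+product by the product of the vector norms. A minimal sketch (the `cosineSimilarity` helper is
+illustrative, not part of this library, and assumes the `math` package is imported):
+
+```go
+func cosineSimilarity(a, b *openai.Embedding) (float32, error) {
+	dot, err := a.DotProduct(b) // also verifies the vectors have the same length
+	if err != nil {
+		return 0, err
+	}
+	var na, nb float64
+	for i := range a.Embedding {
+		na += float64(a.Embedding[i]) * float64(a.Embedding[i])
+		nb += float64(b.Embedding[i]) * float64(b.Embedding[i])
+	}
+	return dot / float32(math.Sqrt(na)*math.Sqrt(nb)), nil
+}
+```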
+
Azure OpenAI Embeddings @@ -666,11 +726,16 @@ func main() { client := openai.NewClient("your token") ctx := context.Background() - // create a .jsonl file with your training data + // create a .jsonl file with your training data for conversational model // {"prompt": "", "completion": ""} // {"prompt": "", "completion": ""} // {"prompt": "", "completion": ""} + // chat models are trained using the following file format: + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]} + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]} + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]} + // you can use openai cli tool to validate the data // For more info - https://platform.openai.com/docs/guides/fine-tuning @@ -683,29 +748,29 @@ func main() { return } - // create a fine tune job + // create a fine tuning job // Streams events until the job is done (this often takes minutes, but can take hours if there are many jobs in the queue or your dataset is large) // use below get method to know the status of your model - tune, err := client.CreateFineTune(ctx, openai.FineTuneRequest{ + fineTuningJob, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{ TrainingFile: file.ID, - Model: "ada", // babbage, curie, davinci, or a fine-tuned model created after 2022-04-21. + Model: "davinci-002", // gpt-3.5-turbo-0613, babbage-002. }) if err != nil { fmt.Printf("Creating new fine tune model error: %v\n", err) return } - getTune, err := client.GetFineTune(ctx, tune.ID) + fineTuningJob, err = client.RetrieveFineTuningJob(ctx, fineTuningJob.ID) if err != nil { fmt.Printf("Getting fine tune model error: %v\n", err) return } - fmt.Println(getTune.FineTunedModel) + fmt.Println(fineTuningJob.FineTunedModel) - // once the status of getTune is `succeeded`, you can use your fine tune model in Completion Request + // once the status of fineTuningJob is `succeeded`, you can use your fine tune model in Completion Request or Chat Completion Request // resp, err := client.CreateCompletion(ctx, openai.CompletionRequest{ - // Model: getTune.FineTunedModel, + // Model: fineTuningJob.FineTunedModel, // Prompt: "your prompt", // }) // if err != nil { @@ -719,19 +784,40 @@ func main() {
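+
+Fine-tuning jobs can take a while, so one option is to poll `RetrieveFineTuningJob` until the job
+reaches a terminal state. A minimal sketch, reusing `client`, `ctx`, and `fineTuningJob` from the
+example above (the status strings follow the OpenAI fine-tuning documentation):
+
+```go
+for {
+	job, err := client.RetrieveFineTuningJob(ctx, fineTuningJob.ID)
+	if err != nil {
+		fmt.Printf("Getting fine tuning job error: %v\n", err)
+		return
+	}
+	if job.Status == "succeeded" || job.Status == "failed" || job.Status == "cancelled" {
+		fmt.Println("job finished with status:", job.Status)
+		break
+	}
+	time.Sleep(30 * time.Second) // requires the "time" package
+}
+```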
See the `examples/` folder for more. -### Integration tests: +## Frequently Asked Questions -Integration tests are requested against the production version of the OpenAI API. These tests will verify that the library is properly coded against the actual behavior of the API, and will fail upon any incompatible change in the API. +### Why don't we get the same answer when specifying a temperature field of 0 and asking the same question? -**Notes:** -These tests send real network traffic to the OpenAI API and may reach rate limits. Temporary network problems may also cause the test to fail. +Even when specifying a temperature field of 0, it doesn't guarantee that you'll always get the same response. Several factors come into play. -**Run tests using:** -``` -OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go -``` +1. Go OpenAI Behavior: When you specify a temperature field of 0 in Go OpenAI, the omitempty tag causes that field to be removed from the request. Consequently, the OpenAI API applies the default value of 1. +2. Token Count for Input/Output: If there's a large number of tokens in the input and output, setting the temperature to 0 can still result in non-deterministic behavior. In particular, when using around 32k tokens, the likelihood of non-deterministic behavior becomes highest even with a temperature of 0. + +Due to the factors mentioned above, different answers may be returned even for the same question. + +**Workarounds:** +1. Using `math.SmallestNonzeroFloat32`: By specifying `math.SmallestNonzeroFloat32` in the temperature field instead of 0, you can mimic the behavior of setting it to 0. +2. Limiting Token Count: By limiting the number of tokens in the input and output and especially avoiding large requests close to 32k tokens, you can reduce the risk of non-deterministic behavior. + +By adopting these strategies, you can expect more consistent results. + +**Related Issues:** +[omitempty option of request struct will generate incorrect request when parameter is 0.](https://github.com/sashabaranov/go-openai/issues/9) + +### Does Go OpenAI provide a method to count tokens? + +No, Go OpenAI does not offer a feature to count tokens, and there are no plans to provide such a feature in the future. However, if there's a way to implement a token counting feature with zero dependencies, it might be possible to merge that feature into Go OpenAI. Otherwise, it would be more appropriate to implement it in a dedicated library or repository. + +For counting tokens, you might find the following links helpful: +- [Counting Tokens For Chat API Calls](https://github.com/pkoukk/tiktoken-go#counting-tokens-for-chat-api-calls) +- [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) + +**Related Issues:** +[Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62) + +## Contributing -If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. +By following [Contributing Guidelines](https://github.com/sashabaranov/go-openai/blob/master/CONTRIBUTING.md), we hope to ensure that your contributions are made smoothly and efficiently. 
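+
+As a concrete illustration of the first workaround in the FAQ above, a request that mimics a
+temperature of 0 can be written as follows (a sketch; `math.SmallestNonzeroFloat32` is non-zero,
+so it survives the `omitempty` check):
+
+```go
+resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
+	Model:       openai.GPT3Dot5Turbo,
+	Temperature: math.SmallestNonzeroFloat32,
+	Messages: []openai.ChatCompletionMessage{
+		{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
+	},
+})
+```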
## Thank you diff --git a/audio.go b/audio.go index 34595e9c2..766d30ed0 100644 --- a/audio.go +++ b/audio.go @@ -63,6 +63,21 @@ type AudioResponse struct { Transient bool `json:"transient"` } `json:"segments"` Text string `json:"text"` + + httpHeader +} + +type audioTextResponse struct { + Text string `json:"text"` + + httpHeader +} + +func (r *audioTextResponse) ToAudioResponse() AudioResponse { + return AudioResponse{ + Text: r.Text, + httpHeader: r.httpHeader, + } } // CreateTranscription — API call to create a transcription. Returns transcribed text. @@ -102,9 +117,11 @@ func (c *Client) callAudioAPI( } if request.HasJSONResponse() { - err = c.sendRequest(ctx, req, &response) + err = c.sendRequest(req, &response) } else { - err = c.sendRequest(ctx, req, &response.Text) + var textResponse audioTextResponse + err = c.sendRequest(req, &textResponse) + response = textResponse.ToAudioResponse() } if err != nil { return AudioResponse{}, err diff --git a/chat.go b/chat.go index 0c73bed56..df0e5f970 100644 --- a/chat.go +++ b/chat.go @@ -114,6 +114,13 @@ const ( FinishReasonNull FinishReason = "null" ) +func (r FinishReason) MarshalJSON() ([]byte, error) { + if r == FinishReasonNull || r == "" { + return []byte("null"), nil + } + return []byte(`"` + string(r) + `"`), nil // best effort to not break future API changes +} + type ChatCompletionChoice struct { Index int `json:"index"` Message ChatCompletionMessage `json:"message"` @@ -135,6 +142,8 @@ type ChatCompletionResponse struct { Model string `json:"model"` Choices []ChatCompletionChoice `json:"choices"` Usage Usage `json:"usage"` + + httpHeader } // CreateChatCompletion — API call to Create a completion for the chat message. @@ -158,6 +167,6 @@ func (c *Client) CreateChatCompletion( return } - err = c.sendRequest(ctx, req, &response) + err = c.sendRequest(req, &response) return } diff --git a/chat_stream_test.go b/chat_stream_test.go index e1c5fb30c..9ce4f26b6 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,15 +1,17 @@ package openai_test import ( - . "github.com/coggsfl/go-openai" - "github.com/coggsfl/go-openai/internal/test/checks" - "context" "encoding/json" "errors" + "fmt" "io" "net/http" + "strconv" "testing" + + . "github.com/coggsfl/go-openai" + "github.com/coggsfl/go-openai/internal/test/checks" ) func TestChatCompletionsStreamWrongModel(t *testing.T) { @@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set(xCustomHeader, xCustomHeaderValue) + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) 
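+		// Each SSE event is delimited by a blank line, and the trailing
+		// "data: [DONE]" sentinel tells the stream reader to stop.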
+ + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + value := stream.Header().Get(xCustomHeader) + if value != xCustomHeaderValue { + t.Errorf("expected %s to be %s", xCustomHeaderValue, value) + } +} + +func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + headers := stream.GetRateLimitHeaders() + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/chat_test.go b/chat_test.go index e2f13ce9e..2c7bfaff7 100644 --- a/chat_test.go +++ b/chat_test.go @@ -16,6 +16,22 @@ import ( "github.com/coggsfl/go-openai/jsonschema" ) +const ( + xCustomHeader = "X-CUSTOM-HEADER" + xCustomHeaderValue = "test" +) + +var ( + rateLimitHeaders = map[string]any{ + "x-ratelimit-limit-requests": 60, + "x-ratelimit-limit-tokens": 150000, + "x-ratelimit-remaining-requests": 59, + "x-ratelimit-remaining-tokens": 149984, + "x-ratelimit-reset-requests": "1s", + "x-ratelimit-reset-tokens": "6m0s", + } +) + func TestChatCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" @@ -68,6 +84,64 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestCompletions Tests the completions endpoint of the API using the mocked server. 
+func TestChatCompletionsWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") + + a := resp.Header().Get(xCustomHeader) + _ = a + if resp.Header().Get(xCustomHeader) != xCustomHeaderValue { + t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue) + } +} + +// TestChatCompletionsWithRateLimitHeaders Tests the completions endpoint of the API using the mocked server. +func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") + + headers := resp.GetRateLimitHeaders() + resetRequests := headers.ResetRequests.String() + if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] { + t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"]) + } + resetRequestsTime := headers.ResetRequests.Time() + if resetRequestsTime.Before(time.Now()) { + t.Errorf("unexpected reset requetsts: %v", resetRequestsTime) + } + + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + // TestChatCompletionsFunctions tests including a function call. 
func TestChatCompletionsFunctions(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -281,6 +355,15 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { TotalTokens: inputTokens + completionTokens, } resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } fmt.Fprintln(w, string(resBytes)) } @@ -298,3 +381,34 @@ func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { } return completion, nil } + +func TestFinishReason(t *testing.T) { + c := &ChatCompletionChoice{ + FinishReason: FinishReasonNull, + } + resBytes, _ := json.Marshal(c) + if !strings.Contains(string(resBytes), `"finish_reason":null`) { + t.Error("null should not be quoted") + } + + c.FinishReason = "" + + resBytes, _ = json.Marshal(c) + if !strings.Contains(string(resBytes), `"finish_reason":null`) { + t.Error("null should not be quoted") + } + + otherReasons := []FinishReason{ + FinishReasonStop, + FinishReasonLength, + FinishReasonFunctionCall, + FinishReasonContentFilter, + } + for _, r := range otherReasons { + c.FinishReason = r + resBytes, _ = json.Marshal(c) + if !strings.Contains(string(resBytes), fmt.Sprintf(`"finish_reason":"%s"`, r)) { + t.Errorf("%s should be quoted", r) + } + } +} diff --git a/client.go b/client.go index 2fb437645..91869ed18 100644 --- a/client.go +++ b/client.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "strings" "time" @@ -33,9 +32,11 @@ type Client struct { type CBData []struct { URL string `json:"url"` } + type CBResult struct { Data CBData `json:"data"` } + type CallBackResponse struct { Created int64 `json:"created"` Expires int64 `json:"expires"` @@ -44,6 +45,24 @@ type CallBackResponse struct { Status string `json:"status"` } +type Response interface { + SetHeader(http.Header) +} + +type httpHeader http.Header + +func (h *httpHeader) SetHeader(header http.Header) { + *h = httpHeader(header) +} + +func (h *httpHeader) Header() http.Header { + return http.Header(*h) +} + +func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders { + return newRateLimitHeaders(h.Header()) +} + // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { config := DefaultConfig(authToken) @@ -106,7 +125,7 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ... return req, nil } -func (c *Client) sendRequest(ctx context.Context, req *http.Request, v any) error { +func (c *Client) sendRequest(req *http.Request, v Response) error { req.Header.Set("Accept", "application/json; charset=utf-8") // Check whether Content-Type is already set, Upload Files API requires @@ -130,13 +149,16 @@ func (c *Client) sendRequest(ctx context.Context, req *http.Request, v any) erro if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { // Special handling for initial call to Azure DALL-E API. if strings.Contains(req.URL.Path, "openai/images/generations") { - return c.requestImage(ctx, res, v) + return c.requestImage(res, v) } // Special handling for callBack to Azure DALL-E API. 
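		// Azure first returns a callback URL in the Operation-Location header;
		// the client then polls it (see imageRequestCallback below) every few
		// seconds until the operation status becomes "succeeded".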
if strings.Contains(req.URL.Path, "openai/operations/images") { - return c.imageRequestCallback(ctx, req, v, res) + return c.imageRequestCallback(req, v, res) } } + if v != nil { + v.SetHeader(res.Header) + } return decodeResponse(res.Body, v) } @@ -173,6 +195,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream response: resp, errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &utils.JSONUnmarshaler{}, + httpHeader: httpHeader(resp.Header), }, nil } @@ -194,8 +217,8 @@ func isFailureStatusCode(resp *http.Response) bool { return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest } -func (c *Client) requestImage(ctx context.Context, res *http.Response, v any) error { - _, err := io.Copy(ioutil.Discard, res.Body) +func (c *Client) requestImage(res *http.Response, v Response) error { + _, err := io.Copy(io.Discard, res.Body) if err != nil { return err } @@ -203,16 +226,16 @@ func (c *Client) requestImage(ctx context.Context, res *http.Response, v any) er if callBackURL == "" { return ErrClientEmptyCallbackURL } - newReq, err := c.newRequest(ctx, http.MethodGet, callBackURL) + newReq, err := c.newRequest(context.Background(), http.MethodGet, callBackURL) if err != nil { return err } - return c.sendRequest(ctx, newReq, v) + return c.sendRequest(newReq, v) } // Handle image callback response from Azure DALL-E API. -func (c *Client) imageRequestCallback(ctx context.Context, req *http.Request, v any, res *http.Response) error { +func (c *Client) imageRequestCallback(req *http.Request, v Response, res *http.Response) error { // Retry Sleep seconds for Azure DALL-E 2 callback URL. var callBackWaitTime = 3 @@ -228,7 +251,7 @@ func (c *Client) imageRequestCallback(ctx context.Context, req *http.Request, v if result.Status != "succeeded" { time.Sleep(time.Duration(callBackWaitTime) * time.Second) req.Header.Add("Retry", "true") - return c.sendRequest(ctx, req, v) + return c.sendRequest(req, v) } // Convert the callBack response to the OpenAI ImageResponse diff --git a/client_test.go b/client_test.go index 81d22acd5..9e4f942c0 100644 --- a/client_test.go +++ b/client_test.go @@ -228,6 +228,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ListFineTuneEvents", func() (any, error) { return client.ListFineTuneEvents(ctx, "") }}, + {"CreateFineTuningJob", func() (any, error) { + return client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + }}, + {"CancelFineTuningJob", func() (any, error) { + return client.CancelFineTuningJob(ctx, "") + }}, + {"RetrieveFineTuningJob", func() (any, error) { + return client.RetrieveFineTuningJob(ctx, "") + }}, + {"ListFineTuningJobEvents", func() (any, error) { + return client.ListFineTuningJobEvents(ctx, "") + }}, {"Moderations", func() (any, error) { return client.Moderations(ctx, ModerationRequest{}) }}, @@ -264,6 +276,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"GetModel", func() (any, error) { return client.GetModel(ctx, "text-davinci-003") }}, + {"DeleteFineTuneModel", func() (any, error) { + return client.DeleteFineTuneModel(ctx, "") + }}, } for _, testCase := range testCases { @@ -305,8 +320,7 @@ func TestRequestImageErrors(t *testing.T) { Body: ioutil.NopCloser(bytes.NewBufferString("")), } v := &ImageRequest{} - ctx := context.Background() - err = client.requestImage(ctx, res, v) + err = client.requestImage(res, v) if !errors.Is(err, ErrClientEmptyCallbackURL) { t.Fatalf("%s did not return error. 
requestImage failed: %v", testCase, err) @@ -317,9 +331,9 @@ func TestRequestImageErrors(t *testing.T) { res = &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Operation-Location": []string{"hxxp://localhost:8080/openai/operations/images/request-id"}}, - Body: ioutil.NopCloser(bytes.NewBufferString("")), + Body: io.NopCloser(bytes.NewBufferString("")), } - err = client.requestImage(ctx, res, v) + err = client.requestImage(res, v) if err == nil { t.Fatalf("%s did not return error. requestImage failed: %v", testCase, err) } @@ -361,7 +375,7 @@ func TestImageRequestCallbackErrors(t *testing.T) { Body: ioutil.NopCloser(bytes.NewBufferString(cbResponseBytes.String())), } v := &ImageRequest{} - err = client.imageRequestCallback(ctx, req, v, res) + err = client.imageRequestCallback(req, v, res) if !errors.Is(err, ErrClientRetievingCallbackResponse) { t.Fatalf("%s did not return error. imageRequestCallback failed: %v", testCase, err) diff --git a/completion.go b/completion.go index 8879680d5..c7ff94afc 100644 --- a/completion.go +++ b/completion.go @@ -154,6 +154,8 @@ type CompletionResponse struct { Model string `json:"model"` Choices []CompletionChoice `json:"choices"` Usage Usage `json:"usage"` + + httpHeader } // CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well @@ -186,6 +188,6 @@ func (c *Client) CreateCompletion( return } - err = c.sendRequest(ctx, req, &response) + err = c.sendRequest(req, &response) return } diff --git a/edits.go b/edits.go index bd93fa8ca..97d026029 100644 --- a/edits.go +++ b/edits.go @@ -28,6 +28,8 @@ type EditsResponse struct { Created int64 `json:"created"` Usage Usage `json:"usage"` Choices []EditsChoice `json:"choices"` + + httpHeader } // Edits Perform an API call to the Edits endpoint. @@ -41,6 +43,6 @@ func (c *Client) Edits(ctx context.Context, request EditsRequest) (response Edit return } - err = c.sendRequest(ctx, req, &response) + err = c.sendRequest(req, &response) return } diff --git a/embeddings.go b/embeddings.go index b96329076..7e2aa7eb0 100644 --- a/embeddings.go +++ b/embeddings.go @@ -2,9 +2,15 @@ package openai import ( "context" + "encoding/base64" + "encoding/binary" + "errors" + "math" "net/http" ) +var ErrVectorLengthMismatch = errors.New("vector length mismatch") + // EmbeddingModel enumerates the models which can be used // to generate Embedding vectors. type EmbeddingModel int @@ -121,12 +127,90 @@ type Embedding struct { Index int `json:"index"` } +// DotProduct calculates the dot product of the embedding vector with another +// embedding vector. Both vectors must have the same length; otherwise, an +// ErrVectorLengthMismatch is returned. The method returns the calculated dot +// product as a float32 value. +func (e *Embedding) DotProduct(other *Embedding) (float32, error) { + if len(e.Embedding) != len(other.Embedding) { + return 0, ErrVectorLengthMismatch + } + + var dotProduct float32 + for i := range e.Embedding { + dotProduct += e.Embedding[i] * other.Embedding[i] + } + + return dotProduct, nil +} + // EmbeddingResponse is the response from a Create embeddings request. 
type EmbeddingResponse struct { Object string `json:"object"` Data []Embedding `json:"data"` Model EmbeddingModel `json:"model"` Usage Usage `json:"usage"` + + httpHeader +} + +type base64String string + +func (b base64String) Decode() ([]float32, error) { + decodedData, err := base64.StdEncoding.DecodeString(string(b)) + if err != nil { + return nil, err + } + + const sizeOfFloat32 = 4 + floats := make([]float32, len(decodedData)/sizeOfFloat32) + for i := 0; i < len(floats); i++ { + floats[i] = math.Float32frombits(binary.LittleEndian.Uint32(decodedData[i*4 : (i+1)*4])) + } + + return floats, nil +} + +// Base64Embedding is a container for base64 encoded embeddings. +type Base64Embedding struct { + Object string `json:"object"` + Embedding base64String `json:"embedding"` + Index int `json:"index"` +} + +// EmbeddingResponseBase64 is the response from a Create embeddings request with base64 encoding format. +type EmbeddingResponseBase64 struct { + Object string `json:"object"` + Data []Base64Embedding `json:"data"` + Model EmbeddingModel `json:"model"` + Usage Usage `json:"usage"` + + httpHeader +} + +// ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse. +func (r *EmbeddingResponseBase64) ToEmbeddingResponse() (EmbeddingResponse, error) { + data := make([]Embedding, len(r.Data)) + + for i, base64Embedding := range r.Data { + embedding, err := base64Embedding.Embedding.Decode() + if err != nil { + return EmbeddingResponse{}, err + } + + data[i] = Embedding{ + Object: base64Embedding.Object, + Embedding: embedding, + Index: base64Embedding.Index, + } + } + + return EmbeddingResponse{ + Object: r.Object, + Model: r.Model, + Data: data, + Usage: r.Usage, + }, nil } type EmbeddingRequestConverter interface { @@ -134,10 +218,21 @@ type EmbeddingRequestConverter interface { Convert() EmbeddingRequest } +// EmbeddingEncodingFormat is the format of the embeddings data. +// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. +// If not specified OpenAI will use "float". +type EmbeddingEncodingFormat string + +const ( + EmbeddingEncodingFormatFloat EmbeddingEncodingFormat = "float" + EmbeddingEncodingFormatBase64 EmbeddingEncodingFormat = "base64" +) + type EmbeddingRequest struct { - Input any `json:"input"` - Model EmbeddingModel `json:"model"` - User string `json:"user"` + Input any `json:"input"` + Model EmbeddingModel `json:"model"` + User string `json:"user"` + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -158,13 +253,18 @@ type EmbeddingRequestStrings struct { Model EmbeddingModel `json:"model"` // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. User string `json:"user"` + // EmbeddingEncodingFormat is the format of the embeddings data. + // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. + // If not specified OpenAI will use "float". 
+ EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { return EmbeddingRequest{ - Input: r.Input, - Model: r.Model, - User: r.User, + Input: r.Input, + Model: r.Model, + User: r.User, + EncodingFormat: r.EncodingFormat, } } @@ -181,13 +281,18 @@ type EmbeddingRequestTokens struct { Model EmbeddingModel `json:"model"` // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. User string `json:"user"` + // EmbeddingEncodingFormat is the format of the embeddings data. + // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. + // If not specified OpenAI will use "float". + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { return EmbeddingRequest{ - Input: r.Input, - Model: r.Model, - User: r.User, + Input: r.Input, + Model: r.Model, + User: r.User, + EncodingFormat: r.EncodingFormat, } } @@ -196,14 +301,27 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { // // Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens // for embedding groups of text already converted to tokens. -func (c *Client) CreateEmbeddings(ctx context.Context, conv EmbeddingRequestConverter) (res EmbeddingResponse, err error) { //nolint:lll +func (c *Client) CreateEmbeddings( + ctx context.Context, + conv EmbeddingRequestConverter, +) (res EmbeddingResponse, err error) { baseReq := conv.Convert() req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) if err != nil { return } - err = c.sendRequest(ctx, req, &res) + if baseReq.EncodingFormat != EmbeddingEncodingFormatBase64 { + err = c.sendRequest(req, &res) + return + } + + base64Response := &EmbeddingResponseBase64{} + err = c.sendRequest(req, base64Response) + if err != nil { + return + } + res, err = base64Response.ToEmbeddingResponse() return } diff --git a/embeddings_test.go b/embeddings_test.go index 7179116d5..56f496442 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -1,15 +1,18 @@ package openai_test import ( - . "github.com/coggsfl/go-openai" - "github.com/coggsfl/go-openai/internal/test/checks" - "bytes" "context" "encoding/json" + "errors" "fmt" + "math" "net/http" + "reflect" "testing" + + . 
"github.com/coggsfl/go-openai" + "github.com/coggsfl/go-openai/internal/test/checks" ) func TestEmbedding(t *testing.T) { @@ -97,22 +100,174 @@ func TestEmbeddingModel(t *testing.T) { func TestEmbeddingEndpoint(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() + + sampleEmbeddings := []Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + } + + sampleBase64Embeddings := []Base64Embedding{ + {Embedding: "pHCdP4XrkUDhevxA"}, + {Embedding: "/1jku0G/rLvA/EI8"}, + } + server.RegisterHandler( "/v1/embeddings", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(EmbeddingResponse{}) + var req struct { + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"` + User string `json:"user"` + } + _ = json.NewDecoder(r.Body).Decode(&req) + + var resBytes []byte + switch { + case req.User == "invalid": + w.WriteHeader(http.StatusBadRequest) + return + case req.EncodingFormat == EmbeddingEncodingFormatBase64: + resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings}) + default: + resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings}) + } fmt.Fprintln(w, string(resBytes)) }, ) // test create embeddings with strings (simple embedding request) - _, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) + res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test create embeddings with strings (simple embedding request) + res, err = client.CreateEmbeddings( + context.Background(), + EmbeddingRequest{ + EncodingFormat: EmbeddingEncodingFormatBase64, + }, + ) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } // test create embeddings with strings - _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) + res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) checks.NoError(t, err, "CreateEmbeddings strings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } // test create embeddings with tokens - _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) + res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) checks.NoError(t, err, "CreateEmbeddings tokens error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test failed sendRequest + _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{ + User: "invalid", + EncodingFormat: EmbeddingEncodingFormatBase64, + }) + checks.HasError(t, err, "CreateEmbeddings error") +} + +func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { + type fields struct { + Object string + Data []Base64Embedding + Model EmbeddingModel + Usage Usage + } + tests := []struct { + name string + fields fields + want EmbeddingResponse + wantErr bool + }{ + { + name: "test embedding response base64 to embedding response", + fields: fields{ + Data: []Base64Embedding{ + {Embedding: "pHCdP4XrkUDhevxA"}, + {Embedding: "/1jku0G/rLvA/EI8"}, + }, + }, + 
want: EmbeddingResponse{ + Data: []Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + }, + }, + wantErr: false, + }, + { + name: "Invalid embedding", + fields: fields{ + Data: []Base64Embedding{ + { + Embedding: "----", + }, + }, + }, + want: EmbeddingResponse{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &EmbeddingResponseBase64{ + Object: tt.fields.Object, + Data: tt.fields.Data, + Model: tt.fields.Model, + Usage: tt.fields.Usage, + } + got, err := r.ToEmbeddingResponse() + if (err != nil) != tt.wantErr { + t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDotProduct(t *testing.T) { + v1 := &Embedding{Embedding: []float32{1, 2, 3}} + v2 := &Embedding{Embedding: []float32{2, 4, 6}} + expected := float32(28.0) + + result, err := v1.DotProduct(v2) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } + + v1 = &Embedding{Embedding: []float32{1, 0, 0}} + v2 = &Embedding{Embedding: []float32{0, 1, 0}} + expected = float32(0.0) + + result, err = v1.DotProduct(v2) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } + + // Test for VectorLengthMismatchError + v1 = &Embedding{Embedding: []float32{1, 0, 0}} + v2 = &Embedding{Embedding: []float32{0, 1}} + _, err = v1.DotProduct(v2) + if !errors.Is(err, ErrVectorLengthMismatch) { + t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err) + } } diff --git a/engines.go b/engines.go index 45c502bb8..5a0dba858 100644 --- a/engines.go +++ b/engines.go @@ -12,11 +12,15 @@ type Engine struct { Object string `json:"object"` Owner string `json:"owner"` Ready bool `json:"ready"` + + httpHeader } // EnginesList is a list of engines. type EnginesList struct { Engines []Engine `json:"data"` + + httpHeader } // ListEngines Lists the currently available engines, and provides basic @@ -27,7 +31,7 @@ func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err erro return } - err = c.sendRequest(ctx, req, &engines) + err = c.sendRequest(req, &engines) return } @@ -43,6 +47,6 @@ func (c *Client) GetEngine( return } - err = c.sendRequest(ctx, req, &engine) + err = c.sendRequest(req, &engine) return } diff --git a/error.go b/error.go index 523510afe..eacdd1aa7 100644 --- a/error.go +++ b/error.go @@ -7,12 +7,20 @@ import ( ) // APIError provides error information returned by the OpenAI API. +// InnerError struct is only valid for Azure OpenAI Service. type APIError struct { - Code any `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` - HTTPStatusCode int `json:"-"` + Code any `json:"code,omitempty"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` + HTTPStatusCode int `json:"-"` + InnerError *InnerError `json:"innererror,omitempty"` +} + +// InnerError Azure Content filtering. Only valid for Azure OpenAI Service. 
+type InnerError struct { + Code string `json:"code,omitempty"` + ContentFilterResults ContentFilterResults `json:"content_filter_result,omitempty"` } // RequestError provides informations about generic request errors. @@ -61,6 +69,13 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { } } + if _, ok := rawMap["innererror"]; ok { + err = json.Unmarshal(rawMap["innererror"], &e.InnerError) + if err != nil { + return + } + } + // optional fields if _, ok := rawMap["param"]; ok { err = json.Unmarshal(rawMap["param"], &e.Param) diff --git a/error_test.go b/error_test.go index 25b488c92..fe862701b 100644 --- a/error_test.go +++ b/error_test.go @@ -3,6 +3,7 @@ package openai_test import ( "errors" "net/http" + "reflect" "testing" . "github.com/coggsfl/go-openai" @@ -57,6 +58,77 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { assertAPIErrorMessage(t, apiErr, "") }, }, + { + name: "parse succeeds when the innerError is not exists (Azure Openai)", + response: `{ + "message": "test message", + "type": null, + "param": "prompt", + "code": "content_filter", + "status": 400, + "innererror": { + "code": "ResponsibleAIPolicyViolation", + "content_filter_result": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": true, + "severity": "medium" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + } + } + }`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorInnerError(t, apiErr, &InnerError{ + Code: "ResponsibleAIPolicyViolation", + ContentFilterResults: ContentFilterResults{ + Hate: Hate{ + Filtered: false, + Severity: "safe", + }, + SelfHarm: SelfHarm{ + Filtered: false, + Severity: "safe", + }, + Sexual: Sexual{ + Filtered: true, + Severity: "medium", + }, + Violence: Violence{ + Filtered: false, + Severity: "safe", + }, + }, + }) + }, + }, + { + name: "parse succeeds when the innerError is empty (Azure Openai)", + response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorInnerError(t, apiErr, &InnerError{}) + }, + }, + { + name: "parse succeeds when the innerError is not InnerError struct (Azure Openai)", + response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`, + hasError: true, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorInnerError(t, apiErr, &InnerError{}) + }, + }, { name: "parse failed when the message is object", response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, @@ -152,6 +224,12 @@ func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) { } } +func assertAPIErrorInnerError(t *testing.T, apiErr APIError, expected interface{}) { + if !reflect.DeepEqual(apiErr.InnerError, expected) { + t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected) + } +} + func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { switch v := apiErr.Code.(type) { case int: diff --git a/files.go b/files.go index 571e503f5..9e521fbbe 100644 --- a/files.go +++ b/files.go @@ -17,18 +17,23 @@ type FileRequest struct { // File struct represents an OpenAPI file. 
type File struct { - Bytes int `json:"bytes"` - CreatedAt int64 `json:"created_at"` - ID string `json:"id"` - FileName string `json:"filename"` - Object string `json:"object"` - Owner string `json:"owner"` - Purpose string `json:"purpose"` + Bytes int `json:"bytes"` + CreatedAt int64 `json:"created_at"` + ID string `json:"id"` + FileName string `json:"filename"` + Object string `json:"object"` + Status string `json:"status"` + Purpose string `json:"purpose"` + StatusDetails string `json:"status_details"` + + httpHeader } // FilesList is a list of files that belong to the user or organization. type FilesList struct { Files []File `json:"data"` + + httpHeader } // CreateFile uploads a jsonl file to GPT3 @@ -63,7 +68,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File return } - err = c.sendRequest(ctx, req, &file) + err = c.sendRequest(req, &file) return } @@ -74,7 +79,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { return } - err = c.sendRequest(ctx, req, nil) + err = c.sendRequest(req, nil) return } @@ -86,7 +91,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { return } - err = c.sendRequest(ctx, req, &files) + err = c.sendRequest(req, &files) return } @@ -99,7 +104,7 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err return } - err = c.sendRequest(ctx, req, &file) + err = c.sendRequest(req, &file) return } diff --git a/files_api_test.go b/files_api_test.go index d85bf0668..be5b0d021 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -64,7 +64,6 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { Purpose: purpose, CreatedAt: time.Now().Unix(), Object: "test-objecct", - Owner: "test-owner", } resBytes, _ = json.Marshal(fileReq) diff --git a/fine_tunes.go b/fine_tunes.go index a0f1e0143..ca840781c 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -6,6 +6,9 @@ import ( "net/http" ) +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneRequest struct { TrainingFile string `json:"training_file"` ValidationFile string `json:"validation_file,omitempty"` @@ -21,6 +24,9 @@ type FineTuneRequest struct { Suffix string `json:"suffix,omitempty"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTune struct { ID string `json:"id"` Object string `json:"object"` @@ -35,8 +41,13 @@ type FineTune struct { ValidationFiles []File `json:"validation_files"` TrainingFiles []File `json:"training_files"` UpdatedAt int64 `json:"updated_at"` + + httpHeader } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneEvent struct { Object string `json:"object"` CreatedAt int64 `json:"created_at"` @@ -44,6 +55,9 @@ type FineTuneEvent struct { Message string `json:"message"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. 
+// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneHyperParams struct { BatchSize int `json:"batch_size"` LearningRateMultiplier float64 `json:"learning_rate_multiplier"` @@ -51,21 +65,40 @@ type FineTuneHyperParams struct { PromptLossWeight float64 `json:"prompt_loss_weight"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneList struct { Object string `json:"object"` Data []FineTune `json:"data"` + + httpHeader } + +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` + + httpHeader } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` + + httpHeader } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { urlSuffix := "/fine-tunes" req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) @@ -73,31 +106,40 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r return } - err = c.sendRequest(ctx, req, &response) + err = c.sendRequest(req, &response) return } // CancelFineTune cancel a fine-tune job. +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { return } - err = c.sendRequest(ctx, req, &response) + err = c.sendRequest(req, &response) return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes")) if err != nil { return } - err = c.sendRequest(ctx, req, &response) + err = c.sendRequest(req, &response) return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. 
+// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) @@ -105,26 +147,32 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F return } - err = c.sendRequest(ctx, req, &response) + err = c.sendRequest(req, &response) return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID)) if err != nil { return } - err = c.sendRequest(ctx, req, &response) + err = c.sendRequest(req, &response) return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events")) if err != nil { return } - err = c.sendRequest(ctx, req, &response) + err = c.sendRequest(req, &response) return } diff --git a/fine_tuning_job.go b/fine_tuning_job.go new file mode 100644 index 000000000..9dcb49de1 --- /dev/null +++ b/fine_tuning_job.go @@ -0,0 +1,157 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +type FineTuningJob struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + FinishedAt int64 `json:"finished_at"` + Model string `json:"model"` + FineTunedModel string `json:"fine_tuned_model,omitempty"` + OrganizationID string `json:"organization_id"` + Status string `json:"status"` + Hyperparameters Hyperparameters `json:"hyperparameters"` + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + ResultFiles []string `json:"result_files"` + TrainedTokens int `json:"trained_tokens"` + + httpHeader +} + +type Hyperparameters struct { + Epochs any `json:"n_epochs,omitempty"` +} + +type FineTuningJobRequest struct { + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + Model string `json:"model,omitempty"` + Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"` + Suffix string `json:"suffix,omitempty"` +} + +type FineTuningJobEventList struct { + Object string `json:"object"` + Data []FineTuneEvent `json:"data"` + HasMore bool `json:"has_more"` + + httpHeader +} + +type FineTuningJobEvent struct { + Object string `json:"object"` + ID string `json:"id"` + CreatedAt int `json:"created_at"` + Level string `json:"level"` + Message string `json:"message"` + Data any `json:"data"` + Type string `json:"type"` +} + +// CreateFineTuningJob create a fine tuning job. 
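+//
+// A minimal usage sketch (the file ID and model name are illustrative):
+//
+//	job, err := c.CreateFineTuningJob(ctx, FineTuningJobRequest{
+//	    TrainingFile: "file-abc123",
+//	    Model:        "gpt-3.5-turbo-0613",
+//	})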
diff --git a/fine_tuning_job.go b/fine_tuning_job.go
new file mode 100644
index 000000000..9dcb49de1
--- /dev/null
+++ b/fine_tuning_job.go
@@ -0,0 +1,157 @@
+package openai
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"net/url"
+)
+
+type FineTuningJob struct {
+	ID              string          `json:"id"`
+	Object          string          `json:"object"`
+	CreatedAt       int64           `json:"created_at"`
+	FinishedAt      int64           `json:"finished_at"`
+	Model           string          `json:"model"`
+	FineTunedModel  string          `json:"fine_tuned_model,omitempty"`
+	OrganizationID  string          `json:"organization_id"`
+	Status          string          `json:"status"`
+	Hyperparameters Hyperparameters `json:"hyperparameters"`
+	TrainingFile    string          `json:"training_file"`
+	ValidationFile  string          `json:"validation_file,omitempty"`
+	ResultFiles     []string        `json:"result_files"`
+	TrainedTokens   int             `json:"trained_tokens"`
+
+	httpHeader
+}
+
+type Hyperparameters struct {
+	Epochs any `json:"n_epochs,omitempty"`
+}
+
+type FineTuningJobRequest struct {
+	TrainingFile    string           `json:"training_file"`
+	ValidationFile  string           `json:"validation_file,omitempty"`
+	Model           string           `json:"model,omitempty"`
+	Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"`
+	Suffix          string           `json:"suffix,omitempty"`
+}
+
+type FineTuningJobEventList struct {
+	Object  string          `json:"object"`
+	Data    []FineTuneEvent `json:"data"`
+	HasMore bool            `json:"has_more"`
+
+	httpHeader
+}
+
+type FineTuningJobEvent struct {
+	Object    string `json:"object"`
+	ID        string `json:"id"`
+	CreatedAt int    `json:"created_at"`
+	Level     string `json:"level"`
+	Message   string `json:"message"`
+	Data      any    `json:"data"`
+	Type      string `json:"type"`
+}
+
+// CreateFineTuningJob creates a fine-tuning job.
+func (c *Client) CreateFineTuningJob(
+	ctx context.Context,
+	request FineTuningJobRequest,
+) (response FineTuningJob, err error) {
+	urlSuffix := "/fine_tuning/jobs"
+	req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// CancelFineTuningJob cancels a fine-tuning job.
+func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) {
+	req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/cancel"))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// RetrieveFineTuningJob retrieves a fine-tuning job.
+func (c *Client) RetrieveFineTuningJob(
+	ctx context.Context,
+	fineTuningJobID string,
+) (response FineTuningJob, err error) {
+	urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID)
+	req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+type listFineTuningJobEventsParameters struct {
+	after *string
+	limit *int
+}
+
+type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters)
+
+func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter {
+	return func(args *listFineTuningJobEventsParameters) {
+		args.after = &after
+	}
+}
+
+func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter {
+	return func(args *listFineTuningJobEventsParameters) {
+		args.limit = &limit
+	}
+}
+
+// ListFineTuningJobEvents lists events for a fine-tuning job.
+func (c *Client) ListFineTuningJobEvents(
+	ctx context.Context,
+	fineTuningJobID string,
+	setters ...ListFineTuningJobEventsParameter,
+) (response FineTuningJobEventList, err error) {
+	parameters := &listFineTuningJobEventsParameters{
+		after: nil,
+		limit: nil,
+	}
+
+	for _, setter := range setters {
+		setter(parameters)
+	}
+
+	urlValues := url.Values{}
+	if parameters.after != nil {
+		urlValues.Add("after", *parameters.after)
+	}
+	if parameters.limit != nil {
+		urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit))
+	}
+
+	encodedValues := ""
+	if len(urlValues) > 0 {
+		encodedValues = "?" + urlValues.Encode()
+	}
+
+	req, err := c.newRequest(
+		ctx,
+		http.MethodGet,
+		c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+encodedValues),
+	)
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
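ListFineTuningJobEvents takes functional options for the optional `after` and `limit` query parameters. A usage sketch (the job and event IDs are placeholders):

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	openai "github.com/coggsfl/go-openai"
)

func main() {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

	// Fetch up to 10 events, starting after a known event ID.
	events, err := client.ListFineTuningJobEvents(
		context.Background(),
		"ftjob-abc123", // placeholder job ID
		openai.ListFineTuningJobEventsWithAfter("ft-event-abc123"), // placeholder event ID
		openai.ListFineTuningJobEventsWithLimit(10),
	)
	if err != nil {
		log.Fatalf("ListFineTuningJobEvents error: %v", err)
	}
	for _, e := range events.Data {
		fmt.Printf("%+v\n", e)
	}
}
```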
diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go
new file mode 100644
index 000000000..2299df2f0
--- /dev/null
+++ b/fine_tuning_job_test.go
@@ -0,0 +1,105 @@
+package openai_test
+
+import (
+	"context"
+
+	. "github.com/coggsfl/go-openai"
+	"github.com/coggsfl/go-openai/internal/test/checks"
+
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"testing"
+)
+
+const testFineTuningJobID = "fine-tuning-job-id"
+
+// TestFineTuningJob tests the fine-tuning job endpoints of the API using the mocked server.
+func TestFineTuningJob(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler(
+		"/v1/fine_tuning/jobs",
+		func(w http.ResponseWriter, r *http.Request) {
+			resBytes, _ := json.Marshal(FineTuningJob{
+				Object:         "fine_tuning.job",
+				ID:             testFineTuningJobID,
+				Model:          "davinci-002",
+				CreatedAt:      1692661014,
+				FinishedAt:     1692661190,
+				FineTunedModel: "ft:davinci-002:my-org:custom_suffix:7q8mpxmy",
+				OrganizationID: "org-123",
+				ResultFiles:    []string{"file-abc123"},
+				Status:         "succeeded",
+				ValidationFile: "",
+				TrainingFile:   "file-abc123",
+				Hyperparameters: Hyperparameters{
+					Epochs: "auto",
+				},
+				TrainedTokens: 5768,
+			})
+			fmt.Fprintln(w, string(resBytes))
+		},
+	)
+
+	server.RegisterHandler(
+		"/v1/fine_tuning/jobs/"+testFineTuningJobID+"/cancel",
+		func(w http.ResponseWriter, r *http.Request) {
+			resBytes, _ := json.Marshal(FineTuningJob{})
+			fmt.Fprintln(w, string(resBytes))
+		},
+	)
+
+	server.RegisterHandler(
+		"/v1/fine_tuning/jobs/"+testFineTuningJobID,
+		func(w http.ResponseWriter, r *http.Request) {
+			var resBytes []byte
+			resBytes, _ = json.Marshal(FineTuningJob{})
+			fmt.Fprintln(w, string(resBytes))
+		},
+	)
+
+	server.RegisterHandler(
+		"/v1/fine_tuning/jobs/"+testFineTuningJobID+"/events",
+		func(w http.ResponseWriter, r *http.Request) {
+			resBytes, _ := json.Marshal(FineTuningJobEventList{})
+			fmt.Fprintln(w, string(resBytes))
+		},
+	)
+
+	ctx := context.Background()
+
+	_, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{})
+	checks.NoError(t, err, "CreateFineTuningJob error")
+
+	_, err = client.CancelFineTuningJob(ctx, testFineTuningJobID)
+	checks.NoError(t, err, "CancelFineTuningJob error")
+
+	_, err = client.RetrieveFineTuningJob(ctx, testFineTuningJobID)
+	checks.NoError(t, err, "RetrieveFineTuningJob error")
+
+	_, err = client.ListFineTuningJobEvents(ctx, testFineTuningJobID)
+	checks.NoError(t, err, "ListFineTuningJobEvents error")
+
+	_, err = client.ListFineTuningJobEvents(
+		ctx,
+		testFineTuningJobID,
+		ListFineTuningJobEventsWithAfter("last-event-id"),
+	)
+	checks.NoError(t, err, "ListFineTuningJobEvents error")
+
+	_, err = client.ListFineTuningJobEvents(
+		ctx,
+		testFineTuningJobID,
+		ListFineTuningJobEventsWithLimit(10),
+	)
+	checks.NoError(t, err, "ListFineTuningJobEvents error")
+
+	_, err = client.ListFineTuningJobEvents(
+		ctx,
+		testFineTuningJobID,
+		ListFineTuningJobEventsWithAfter("last-event-id"),
+		ListFineTuningJobEventsWithLimit(10),
+	)
+	checks.NoError(t, err, "ListFineTuningJobEvents error")
+}
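Hyperparameters.Epochs is typed `any` because the API accepts either the literal string "auto" (as exercised in the test above) or an integer epoch count. A small sketch of the two encodings (not part of the patch):

```go
package main

import (
	"encoding/json"
	"fmt"

	openai "github.com/coggsfl/go-openai"
)

func main() {
	// Marshal errors are ignored here for brevity; both values encode cleanly.
	auto, _ := json.Marshal(openai.Hyperparameters{Epochs: "auto"})
	fixed, _ := json.Marshal(openai.Hyperparameters{Epochs: 4})
	fmt.Println(string(auto))  // {"n_epochs":"auto"}
	fmt.Println(string(fixed)) // {"n_epochs":4}
}
```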
diff --git a/image.go b/image.go
index 8e059d9fc..ee2abf1e5 100644
--- a/image.go
+++ b/image.go
@@ -27,12 +27,16 @@ type ImageRequest struct {
 	Size           string `json:"size,omitempty"`
 	ResponseFormat string `json:"response_format,omitempty"`
 	User           string `json:"user,omitempty"`
+
+	httpHeader
 }
 
 // ImageResponse represents a response structure for image API.
 type ImageResponse struct {
 	Created int64                    `json:"created,omitempty"`
 	Data    []ImageResponseDataInner `json:"data,omitempty"`
+
+	httpHeader
 }
 
 // ImageResponseDataInner represents a response data structure for image API.
@@ -49,7 +53,7 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
 		return
 	}
 
-	err = c.sendRequest(ctx, req, &response)
+	err = c.sendRequest(req, &response)
 	return
 }
 
@@ -113,7 +117,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
 		return
 	}
 
-	err = c.sendRequest(ctx, req, &response)
+	err = c.sendRequest(req, &response)
 	return
 }
 
@@ -163,6 +167,6 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
 		return
 	}
 
-	err = c.sendRequest(ctx, req, &response)
+	err = c.sendRequest(req, &response)
 	return
 }
diff --git a/models.go b/models.go
index 0eca4b436..d94f98836 100644
--- a/models.go
+++ b/models.go
@@ -15,6 +15,8 @@ type Model struct {
 	Permission []Permission `json:"permission"`
 	Root       string       `json:"root"`
 	Parent     string       `json:"parent"`
+
+	httpHeader
 }
 
 // Permission struct represents an OpenAPI permission.
@@ -33,9 +35,20 @@ type Permission struct {
 	IsBlocking bool `json:"is_blocking"`
 }
 
+// FineTuneModelDeleteResponse represents the deletion status of a fine-tuned model.
+type FineTuneModelDeleteResponse struct {
+	ID      string `json:"id"`
+	Object  string `json:"object"`
+	Deleted bool   `json:"deleted"`
+
+	httpHeader
+}
+
 // ModelsList is a list of models, including those that belong to the user or organization.
 type ModelsList struct {
 	Models []Model `json:"data"`
+
+	httpHeader
 }
 
 // ListModels Lists the currently available models,
@@ -46,7 +59,7 @@ func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error)
 		return
 	}
 
-	err = c.sendRequest(ctx, req, &models)
+	err = c.sendRequest(req, &models)
 	return
 }
 
@@ -59,6 +72,19 @@ func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err
 		return
 	}
 
-	err = c.sendRequest(ctx, req, &model)
+	err = c.sendRequest(req, &model)
+	return
+}
+
+// DeleteFineTuneModel deletes a fine-tuned model. You must have the Owner
+// role in your organization to delete a model.
+func (c *Client) DeleteFineTuneModel(ctx context.Context, modelID string) (
+	response FineTuneModelDeleteResponse, err error) {
+	req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/models/"+modelID))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
 	return
 }
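A usage sketch for the new DeleteFineTuneModel helper (not part of the patch); the fine-tuned model ID is a placeholder, and the call requires the Owner role in the organization:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	openai "github.com/coggsfl/go-openai"
)

func main() {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

	resp, err := client.DeleteFineTuneModel(
		context.Background(),
		"ft:davinci-002:my-org:custom_suffix:7q8mpxmy", // placeholder fine-tuned model ID
	)
	if err != nil {
		log.Fatalf("DeleteFineTuneModel error: %v", err)
	}
	fmt.Println("deleted:", resp.Deleted)
}
```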
diff --git a/models_test.go b/models_test.go
index e9db5b67f..d2e5e1994 100644
--- a/models_test.go
+++ b/models_test.go
@@ -14,6 +14,8 @@ import (
 	"testing"
 )
 
+const testFineTuneModelID = "fine-tune-model-id"
+
 // TestListModels Tests the list models endpoint of the API using the mocked server.
 func TestListModels(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
@@ -78,3 +80,16 @@ func TestGetModelReturnTimeoutError(t *testing.T) {
 		t.Fatal("Did not return timeout error")
 	}
 }
+
+func TestDeleteFineTuneModel(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/models/"+testFineTuneModelID, handleDeleteFineTuneModelEndpoint)
+	_, err := client.DeleteFineTuneModel(context.Background(), testFineTuneModelID)
+	checks.NoError(t, err, "DeleteFineTuneModel error")
+}
+
+func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) {
+	resBytes, _ := json.Marshal(FineTuneModelDeleteResponse{})
+	fmt.Fprintln(w, string(resBytes))
+}
diff --git a/moderation.go b/moderation.go
index 4847d6486..f8d20ee51 100644
--- a/moderation.go
+++ b/moderation.go
@@ -69,6 +69,8 @@ type ModerationResponse struct {
 	ID      string   `json:"id"`
 	Model   string   `json:"model"`
 	Results []Result `json:"results"`
+
+	httpHeader
 }
 
 // Moderations — perform a moderation api call over a string.
@@ -83,6 +85,6 @@ func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (re
 		return
 	}
 
-	err = c.sendRequest(ctx, req, &response)
+	err = c.sendRequest(req, &response)
 	return
 }
diff --git a/ratelimit.go b/ratelimit.go
new file mode 100644
index 000000000..e8953f716
--- /dev/null
+++ b/ratelimit.go
@@ -0,0 +1,43 @@
+package openai
+
+import (
+	"net/http"
+	"strconv"
+	"time"
+)
+
+// RateLimitHeaders struct represents the OpenAI rate limit headers.
+type RateLimitHeaders struct {
+	LimitRequests     int       `json:"x-ratelimit-limit-requests"`
+	LimitTokens       int       `json:"x-ratelimit-limit-tokens"`
+	RemainingRequests int       `json:"x-ratelimit-remaining-requests"`
+	RemainingTokens   int       `json:"x-ratelimit-remaining-tokens"`
+	ResetRequests     ResetTime `json:"x-ratelimit-reset-requests"`
+	ResetTokens       ResetTime `json:"x-ratelimit-reset-tokens"`
+}
+
+type ResetTime string
+
+func (r ResetTime) String() string {
+	return string(r)
+}
+
+func (r ResetTime) Time() time.Time {
+	d, _ := time.ParseDuration(string(r))
+	return time.Now().Add(d)
+}
+
+func newRateLimitHeaders(h http.Header) RateLimitHeaders {
+	limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests"))
+	limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens"))
+	remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests"))
+	remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens"))
+	return RateLimitHeaders{
+		LimitRequests:     limitReq,
+		LimitTokens:       limitTokens,
+		RemainingRequests: remainingReq,
+		RemainingTokens:   remainingTokens,
+		ResetRequests:     ResetTime(h.Get("x-ratelimit-reset-requests")),
+		ResetTokens:       ResetTime(h.Get("x-ratelimit-reset-tokens")),
+	}
+}
diff --git a/stream_reader.go b/stream_reader.go
index d62f7afd6..36accae2d 100644
--- a/stream_reader.go
+++ b/stream_reader.go
@@ -27,6 +27,8 @@ type streamReader[T streamable] struct {
 	response       *http.Response
 	errAccumulator utils.ErrorAccumulator
 	unmarshaler    utils.Unmarshaler
+
+	httpHeader
 }
 
 func (stream *streamReader[T]) Recv() (response T, err error) {