Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix json marshaling error response of azure openai (#343) #345

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions chat_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,67 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
t.Logf("%+v\n", apiErr)
}

// TestAzureCreateChatCompletionStreamRateLimitError verifies that a 429
// response from an Azure OpenAI deployment endpoint surfaces as an *APIError
// carrying the HTTP status code, the Azure error code, and the message.
func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
	wantCode := "429"
	wantMessage := "Requests to the Creates a completion for the chat message Operation under Azure OpenAI API " +
		"version 2023-03-15-preview have exceeded token rate limit of your current OpenAI S0 pricing tier. " +
		"Please retry after 20 seconds. " +
		"Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit."

	srv := test.NewTestServer()
	srv.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions",
		func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusTooManyRequests)
			// Emit an Azure-style error payload (string "code" field).
			body := `{"error": { "code": "` + wantCode + `", "message": "` + wantMessage + `"}}`
			_, werr := w.Write([]byte(body))

			checks.NoError(t, werr, "Write error")
		})
	// Spin up the HTTP test server backing the client.
	ts := srv.OpenAITestServer()
	ts.Start()
	defer ts.Close()

	client := NewClientWithConfig(DefaultAzureConfig(test.GetTestToken(), ts.URL))

	req := ChatCompletionRequest{
		MaxTokens: 5,
		Model:     GPT3Dot5Turbo,
		Stream:    true,
		Messages: []ChatCompletionMessage{
			{Role: ChatMessageRoleUser, Content: "Hello!"},
		},
	}

	_, err := client.CreateChatCompletionStream(context.Background(), req)

	apiErr := &APIError{}
	if !errors.As(err, &apiErr) {
		t.Errorf("Did not return APIError: %+v\n", apiErr)
		return
	}
	if apiErr.HTTPStatusCode != http.StatusTooManyRequests {
		t.Errorf("Did not return HTTPStatusCode got = %d, want = %d\n", apiErr.HTTPStatusCode, http.StatusTooManyRequests)
		return
	}
	if code, ok := apiErr.Code.(string); !ok || code != wantCode {
		t.Errorf("Did not return Code. got = %v, want = %s\n", apiErr.Code, wantCode)
		return
	}
	if apiErr.Message != wantMessage {
		t.Errorf("Did not return Message. got = %s, want = %s\n", apiErr.Message, wantMessage)
		return
	}
}

func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) {
var err error
server := test.NewTestServer()
Expand Down
9 changes: 9 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ func TestHandleErrorResp(t *testing.T) {
}`)),
expected: "error, status code: 503, message: That model...",
},
{
name: "503 no message (Unknown response)",
httpCode: http.StatusServiceUnavailable,
body: bytes.NewReader([]byte(`
{
"error":{}
}`)),
expected: "error, status code: 503, message: ",
},
}

for _, tc := range testCases {
Expand Down
10 changes: 7 additions & 3 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,13 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
return
}

err = json.Unmarshal(rawMap["type"], &e.Type)
if err != nil {
return
// optional fields for azure openai
// refs: https://github.com/sashabaranov/go-openai/issues/343
if _, ok := rawMap["type"]; ok {
err = json.Unmarshal(rawMap["type"], &e.Type)
if err != nil {
return
}
}

// optional fields
Expand Down