From 061c97ef7ee75e3ad782797cb46925949e548746 Mon Sep 17 00:00:00 2001 From: Stephen Young Date: Fri, 14 Apr 2023 14:09:40 -0400 Subject: [PATCH] Implement Unmarshaller interface. Resolves #244 (#248) --- api_test.go | 102 ++++++++++++++++++++++++++++++++++++++++++++++++++-- error.go | 48 +++++++++++++++++++++++-- 2 files changed, 145 insertions(+), 5 deletions(-) diff --git a/api_test.go b/api_test.go index 478a274d4..d6ad78932 100644 --- a/api_test.go +++ b/api_test.go @@ -1,6 +1,8 @@ package openai_test import ( + "encoding/json" + . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -110,7 +112,7 @@ func TestAPIError(t *testing.T) { c := NewClient(apiToken + "_invalid") ctx := context.Background() _, err = c.ListEngines(ctx) - checks.NoError(t, err, "ListEngines did not fail") + checks.HasError(t, err, "ListEngines should fail with an invalid key") var apiErr *APIError if !errors.As(err, &apiErr) { @@ -120,14 +122,108 @@ func TestAPIError(t *testing.T) { if apiErr.StatusCode != 401 { t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode) } - if *apiErr.Code != "invalid_api_key" { - t.Fatalf("Unexpected API error code: %s", *apiErr.Code) + + switch v := apiErr.Code.(type) { + case string: + if v != "invalid_api_key" { + t.Fatalf("Unexpected API error code: %s", v) + } + default: + t.Fatalf("Unexpected API error code type: %T", v) } + if apiErr.Error() == "" { t.Fatal("Empty error message occurred") } } +func TestAPIErrorUnmarshalJSONInteger(t *testing.T) { + var apiErr APIError + response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` + err := json.Unmarshal([]byte(response), &apiErr) + checks.NoError(t, err, "Unexpected Unmarshal API response error") + + switch v := apiErr.Code.(type) { + case int: + if v != 418 { + t.Fatalf("Unexpected API code integer: %d; expected 418", v) + } + default: + t.Fatalf("Unexpected API error code type: %T", v) + } +} + +func TestAPIErrorUnmarshalJSONString(t *testing.T) { + var apiErr APIError + response := `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}` + err := json.Unmarshal([]byte(response), &apiErr) + checks.NoError(t, err, "Unexpected Unmarshal API response error") + + switch v := apiErr.Code.(type) { + case string: + if v != "teapot" { + t.Fatalf("Unexpected API code string: %s; expected `teapot`", v) + } + default: + t.Fatalf("Unexpected API error code type: %T", v) + } +} + +func TestAPIErrorUnmarshalJSONNoCode(t *testing.T) { + // test integer code + response := `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` + var apiErr APIError + err := json.Unmarshal([]byte(response), &apiErr) + checks.NoError(t, err, "Unexpected Unmarshal API response error") + + switch v := apiErr.Code.(type) { + case nil: + default: + t.Fatalf("Unexpected API error code type: %T", v) + } +} + +func TestAPIErrorUnmarshalInvalidData(t *testing.T) { + apiErr := APIError{} + data := []byte(`--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`) + err := apiErr.UnmarshalJSON(data) + checks.HasError(t, err, "Expected error when unmarshaling invalid data") + + if apiErr.Code != nil { + t.Fatalf("Expected nil code, got %q", apiErr.Code) + } + if apiErr.Message != "" { + t.Fatalf("Expected empty message, got %q", apiErr.Message) + } + if apiErr.Param != nil { + t.Fatalf("Expected nil param, got %q", *apiErr.Param) + } + if apiErr.Type != "" { + t.Fatalf("Expected empty type, got %q", apiErr.Type) + } +} + +func TestAPIErrorUnmarshalJSONInvalidParam(t *testing.T) { + var apiErr APIError + response := `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}` + err := json.Unmarshal([]byte(response), &apiErr) + checks.HasError(t, err, "Param should be a string") +} + +func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) { + var apiErr APIError + response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}` + err := json.Unmarshal([]byte(response), &apiErr) + checks.HasError(t, err, "Type should be a string") +} + +func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) { + var apiErr APIError + response := `{"code":418,"message":false,"param":"prompt","type":"teapot_error"}` + err := json.Unmarshal([]byte(response), &apiErr) + checks.HasError(t, err, "Message should be a string") +} + func TestRequestError(t *testing.T) { var err error diff --git a/error.go b/error.go index d041da23b..32ffa6cc8 100644 --- a/error.go +++ b/error.go @@ -1,10 +1,13 @@ package openai -import "fmt" +import ( + "encoding/json" + "fmt" +) // APIError provides error information returned by the OpenAI API. type APIError struct { - Code *string `json:"code,omitempty"` + Code any `json:"code,omitempty"` Message string `json:"message"` Param *string `json:"param,omitempty"` Type string `json:"type"` @@ -25,6 +28,47 @@ func (e *APIError) Error() string { return e.Message } +func (e *APIError) UnmarshalJSON(data []byte) (err error) { + var rawMap map[string]json.RawMessage + err = json.Unmarshal(data, &rawMap) + if err != nil { + return + } + + err = json.Unmarshal(rawMap["message"], &e.Message) + if err != nil { + return + } + + err = json.Unmarshal(rawMap["type"], &e.Type) + if err != nil { + return + } + + // optional fields + if _, ok := rawMap["param"]; ok { + err = json.Unmarshal(rawMap["param"], &e.Param) + if err != nil { + return + } + } + + if _, ok := rawMap["code"]; !ok { + return nil + } + + // if the api returned a number, we need to force an integer + // since the json package defaults to float64 + var intCode int + err = json.Unmarshal(rawMap["code"], &intCode) + if err == nil { + e.Code = intCode + return nil + } + + return json.Unmarshal(rawMap["code"], &e.Code) +} + func (e *RequestError) Error() string { if e.Err != nil { return e.Err.Error()