From 2c9d003ac495cde222373227eecb0e60083f7ee5 Mon Sep 17 00:00:00 2001 From: vvatanabe Date: Thu, 1 Jun 2023 06:36:52 +0900 Subject: [PATCH] move error_accumulator into internal pkg (#304) --- chat_stream_test.go | 48 +++++++++++++++++++++++++-------- error_accumulator.go | 19 ++++--------- error_accumulator_test.go | 56 +++++++-------------------------------- internal/test/helpers.go | 24 +++++++++++++++++ stream_reader.go | 16 ++++++++++- stream_reader_test.go | 27 +++++++++++++++++++ stream_test.go | 41 +++++++--------------------- 7 files changed, 127 insertions(+), 104 deletions(-) create mode 100644 stream_reader_test.go diff --git a/chat_stream_test.go b/chat_stream_test.go index afcb86d5e..9c15f0c56 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,7 +1,6 @@ -package openai_test +package openai //nolint:testpackage // testing private field import ( - . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -63,9 +62,9 @@ func TestCreateChatCompletionStream(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -170,9 +169,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -227,9 +226,9 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -255,6 +254,33 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) { + var err error + server := test.NewTestServer() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "error", 200) + }) + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + + ctx := context.Background() + + stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) + checks.NoError(t, err) + + stream.errAccumulator = &defaultErrorAccumulator{ + buffer: &failingErrorBuffer{}, + } + + _, err = stream.Recv() + checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) +} + // Helper funcs. func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { diff --git a/error_accumulator.go b/error_accumulator.go index 568afdbcd..b54292090 100644 --- a/error_accumulator.go +++ b/error_accumulator.go @@ -4,13 +4,11 @@ import ( "bytes" "fmt" "io" - - utils "github.com/sashabaranov/go-openai/internal" ) type errorAccumulator interface { write(p []byte) error - unmarshalError() *ErrorResponse + bytes() []byte } type errorBuffer interface { @@ -20,14 +18,12 @@ type errorBuffer interface { } type defaultErrorAccumulator struct { - buffer errorBuffer - unmarshaler utils.Unmarshaler + buffer errorBuffer } func newErrorAccumulator() errorAccumulator { return &defaultErrorAccumulator{ - buffer: &bytes.Buffer{}, - unmarshaler: &utils.JSONUnmarshaler{}, + buffer: &bytes.Buffer{}, } } @@ -39,15 +35,10 @@ func (e *defaultErrorAccumulator) write(p []byte) error { return nil } -func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) { +func (e *defaultErrorAccumulator) bytes() (errBytes []byte) { if e.buffer.Len() == 0 { return } - - err := e.unmarshaler.Unmarshal(e.buffer.Bytes(), &errResp) - if err != nil { - errResp = nil - } - + errBytes = e.buffer.Bytes() return } diff --git a/error_accumulator_test.go b/error_accumulator_test.go index 821eb21b4..4c7999f1c 100644 --- a/error_accumulator_test.go +++ b/error_accumulator_test.go @@ -2,14 +2,8 @@ package openai //nolint:testpackage // testing private field import ( "bytes" - "context" "errors" - "net/http" "testing" - - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" ) var ( @@ -38,63 +32,33 @@ func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error { return errTestUnmarshalerFailed } -func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) { +func TestErrorAccumulatorBytes(t *testing.T) { accumulator := &defaultErrorAccumulator{ - buffer: &bytes.Buffer{}, - unmarshaler: &failingUnMarshaller{}, + buffer: &bytes.Buffer{}, } - respErr := accumulator.unmarshalError() - if respErr != nil { - t.Fatalf("Did not return nil with empty buffer: %v", respErr) + errBytes := accumulator.bytes() + if len(errBytes) != 0 { + t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes)) } - err := accumulator.write([]byte("{")) + err := accumulator.write([]byte("{}")) if err != nil { t.Fatalf("%+v", err) } - respErr = accumulator.unmarshalError() - if respErr != nil { - t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr) + errBytes = accumulator.bytes() + if len(errBytes) == 0 { + t.Fatalf("Did not return error bytes when has error: %s", string(errBytes)) } } func TestErrorByteWriteErrors(t *testing.T) { accumulator := &defaultErrorAccumulator{ - buffer: &failingErrorBuffer{}, - unmarshaler: &utils.JSONUnmarshaler{}, + buffer: &failingErrorBuffer{}, } err := accumulator.write([]byte("{")) if !errors.Is(err, errTestErrorAccumulatorWriteFailed) { t.Fatalf("Did not return error when write failed: %v", err) } } - -func TestErrorAccumulatorWriteErrors(t *testing.T) { - var err error - server := test.NewTestServer() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "error", 200) - }) - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - ctx := context.Background() - - stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) - checks.NoError(t, err) - - stream.errAccumulator = &defaultErrorAccumulator{ - buffer: &failingErrorBuffer{}, - unmarshaler: &utils.JSONUnmarshaler{}, - } - - _, err = stream.Recv() - checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) -} diff --git a/internal/test/helpers.go b/internal/test/helpers.go index 8461e5374..0e63ae82f 100644 --- a/internal/test/helpers.go +++ b/internal/test/helpers.go @@ -3,6 +3,7 @@ package test import ( "github.com/sashabaranov/go-openai/internal/test/checks" + "net/http" "os" "testing" ) @@ -27,3 +28,26 @@ func CreateTestDirectory(t *testing.T) (path string, cleanup func()) { return path, func() { os.RemoveAll(path) } } + +// TokenRoundTripper is a struct that implements the RoundTripper +// interface, specifically to handle the authentication token by adding a token +// to the request header. We need this because the API requires that each +// request include a valid API token in the headers for authentication and +// authorization. +type TokenRoundTripper struct { + Token string + Fallback http.RoundTripper +} + +// RoundTrip takes an *http.Request as input and returns an +// *http.Response and an error. +// +// It is expected to use the provided request to create a connection to an HTTP +// server and return the response, or an error if one occurred. The returned +// Response should have its Body closed. If the RoundTrip method returns an +// error, the Client's Get, Head, Post, and PostForm methods return the same +// error. +func (t *TokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("Authorization", "Bearer "+t.Token) + return t.Fallback.RoundTrip(req) +} diff --git a/stream_reader.go b/stream_reader.go index 5eb6df7b8..a6256d803 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -35,7 +35,7 @@ func (stream *streamReader[T]) Recv() (response T, err error) { waitForData: line, err := stream.reader.ReadBytes('\n') if err != nil { - respErr := stream.errAccumulator.unmarshalError() + respErr := stream.unmarshalError() if respErr != nil { err = fmt.Errorf("error, %w", respErr.Error) } @@ -69,6 +69,20 @@ waitForData: return } +func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { + errBytes := stream.errAccumulator.bytes() + if len(errBytes) == 0 { + return + } + + err := stream.unmarshaler.Unmarshal(errBytes, &errResp) + if err != nil { + errResp = nil + } + + return +} + func (stream *streamReader[T]) Close() { stream.response.Body.Close() } diff --git a/stream_reader_test.go b/stream_reader_test.go new file mode 100644 index 000000000..5d9a203f5 --- /dev/null +++ b/stream_reader_test.go @@ -0,0 +1,27 @@ +package openai //nolint:testpackage // testing private field + +import ( + "testing" +) + +func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + errAccumulator: newErrorAccumulator(), + unmarshaler: &failingUnMarshaller{}, + } + + respErr := stream.unmarshalError() + if respErr != nil { + t.Fatalf("Did not return nil with empty buffer: %v", respErr) + } + + err := stream.errAccumulator.write([]byte("{")) + if err != nil { + t.Fatalf("%+v", err) + } + + respErr = stream.unmarshalError() + if respErr != nil { + t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr) + } +} diff --git a/stream_test.go b/stream_test.go index a5c591fde..589fc9e26 100644 --- a/stream_test.go +++ b/stream_test.go @@ -57,9 +57,9 @@ func TestCreateCompletionStream(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -142,9 +142,9 @@ func TestCreateCompletionStreamError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -194,9 +194,9 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -217,29 +217,6 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } -// A "tokenRoundTripper" is a struct that implements the RoundTripper -// interface, specifically to handle the authentication token by adding a token -// to the request header. We need this because the API requires that each -// request include a valid API token in the headers for authentication and -// authorization. -type tokenRoundTripper struct { - token string - fallback http.RoundTripper -} - -// RoundTrip takes an *http.Request as input and returns an -// *http.Response and an error. -// -// It is expected to use the provided request to create a connection to an HTTP -// server and return the response, or an error if one occurred. The returned -// Response should have its Body closed. If the RoundTrip method returns an -// error, the Client's Get, Head, Post, and PostForm methods return the same -// error. -func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - req.Header.Set("Authorization", "Bearer "+t.token) - return t.fallback.RoundTrip(req) -} - // Helper funcs. func compareResponses(r1, r2 CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {