From 15653fd13586f7c5ba07fa52a2f6c6f80e509f60 Mon Sep 17 00:00:00 2001 From: vvatanabe Date: Thu, 1 Jun 2023 08:54:52 +0900 Subject: [PATCH] move error_accumulator into internal pkg (#304) --- chat_stream.go | 2 +- chat_stream_test.go | 7 ++-- error_accumulator.go | 44 -------------------- error_accumulator_test.go | 64 ------------------------------ internal/error_accumulator.go | 44 ++++++++++++++++++++ internal/error_accumulator_test.go | 41 +++++++++++++++++++ internal/test/failer.go | 21 ++++++++++ stream.go | 2 +- stream_reader.go | 6 +-- stream_reader_test.go | 15 ++++++- 10 files changed, 128 insertions(+), 118 deletions(-) delete mode 100644 error_accumulator.go delete mode 100644 error_accumulator_test.go create mode 100644 internal/error_accumulator.go create mode 100644 internal/error_accumulator_test.go create mode 100644 internal/test/failer.go diff --git a/chat_stream.go b/chat_stream.go index 842835e15..9378c7124 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -66,7 +66,7 @@ func (c *Client) CreateChatCompletionStream( emptyMessagesLimit: c.config.EmptyMessagesLimit, reader: bufio.NewReader(resp.Body), response: resp, - errAccumulator: newErrorAccumulator(), + errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &utils.JSONUnmarshaler{}, }, } diff --git a/chat_stream_test.go b/chat_stream_test.go index 9c15f0c56..77d373c6a 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,6 +1,7 @@ package openai //nolint:testpackage // testing private field import ( + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -273,12 +274,12 @@ func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) { stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) checks.NoError(t, err) - stream.errAccumulator = &defaultErrorAccumulator{ - buffer: &failingErrorBuffer{}, + stream.errAccumulator = &utils.DefaultErrorAccumulator{ + Buffer: &test.FailingErrorBuffer{}, } _, err = stream.Recv() - checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) + checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error()) } // Helper funcs. diff --git a/error_accumulator.go b/error_accumulator.go deleted file mode 100644 index b54292090..000000000 --- a/error_accumulator.go +++ /dev/null @@ -1,44 +0,0 @@ -package openai - -import ( - "bytes" - "fmt" - "io" -) - -type errorAccumulator interface { - write(p []byte) error - bytes() []byte -} - -type errorBuffer interface { - io.Writer - Len() int - Bytes() []byte -} - -type defaultErrorAccumulator struct { - buffer errorBuffer -} - -func newErrorAccumulator() errorAccumulator { - return &defaultErrorAccumulator{ - buffer: &bytes.Buffer{}, - } -} - -func (e *defaultErrorAccumulator) write(p []byte) error { - _, err := e.buffer.Write(p) - if err != nil { - return fmt.Errorf("error accumulator write error, %w", err) - } - return nil -} - -func (e *defaultErrorAccumulator) bytes() (errBytes []byte) { - if e.buffer.Len() == 0 { - return - } - errBytes = e.buffer.Bytes() - return -} diff --git a/error_accumulator_test.go b/error_accumulator_test.go deleted file mode 100644 index 4c7999f1c..000000000 --- a/error_accumulator_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package openai //nolint:testpackage // testing private field - -import ( - "bytes" - "errors" - "testing" -) - -var ( - errTestUnmarshalerFailed = errors.New("test unmarshaler failed") - errTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed") -) - -type ( - failingUnMarshaller struct{} - failingErrorBuffer struct{} -) - -func (b *failingErrorBuffer) Write(_ []byte) (n int, err error) { - return 0, errTestErrorAccumulatorWriteFailed -} - -func (b *failingErrorBuffer) Len() int { - return 0 -} - -func (b *failingErrorBuffer) Bytes() []byte { - return []byte{} -} - -func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error { - return errTestUnmarshalerFailed -} - -func TestErrorAccumulatorBytes(t *testing.T) { - accumulator := &defaultErrorAccumulator{ - buffer: &bytes.Buffer{}, - } - - errBytes := accumulator.bytes() - if len(errBytes) != 0 { - t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes)) - } - - err := accumulator.write([]byte("{}")) - if err != nil { - t.Fatalf("%+v", err) - } - - errBytes = accumulator.bytes() - if len(errBytes) == 0 { - t.Fatalf("Did not return error bytes when has error: %s", string(errBytes)) - } -} - -func TestErrorByteWriteErrors(t *testing.T) { - accumulator := &defaultErrorAccumulator{ - buffer: &failingErrorBuffer{}, - } - err := accumulator.write([]byte("{")) - if !errors.Is(err, errTestErrorAccumulatorWriteFailed) { - t.Fatalf("Did not return error when write failed: %v", err) - } -} diff --git a/internal/error_accumulator.go b/internal/error_accumulator.go new file mode 100644 index 000000000..3d3e805fe --- /dev/null +++ b/internal/error_accumulator.go @@ -0,0 +1,44 @@ +package openai + +import ( + "bytes" + "fmt" + "io" +) + +type ErrorAccumulator interface { + Write(p []byte) error + Bytes() []byte +} + +type errorBuffer interface { + io.Writer + Len() int + Bytes() []byte +} + +type DefaultErrorAccumulator struct { + Buffer errorBuffer +} + +func NewErrorAccumulator() ErrorAccumulator { + return &DefaultErrorAccumulator{ + Buffer: &bytes.Buffer{}, + } +} + +func (e *DefaultErrorAccumulator) Write(p []byte) error { + _, err := e.Buffer.Write(p) + if err != nil { + return fmt.Errorf("error accumulator write error, %w", err) + } + return nil +} + +func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) { + if e.Buffer.Len() == 0 { + return + } + errBytes = e.Buffer.Bytes() + return +} diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go new file mode 100644 index 000000000..d48f28177 --- /dev/null +++ b/internal/error_accumulator_test.go @@ -0,0 +1,41 @@ +package openai_test + +import ( + "bytes" + "errors" + "testing" + + utils "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test" +) + +func TestErrorAccumulatorBytes(t *testing.T) { + accumulator := &utils.DefaultErrorAccumulator{ + Buffer: &bytes.Buffer{}, + } + + errBytes := accumulator.Bytes() + if len(errBytes) != 0 { + t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes)) + } + + err := accumulator.Write([]byte("{}")) + if err != nil { + t.Fatalf("%+v", err) + } + + errBytes = accumulator.Bytes() + if len(errBytes) == 0 { + t.Fatalf("Did not return error bytes when has error: %s", string(errBytes)) + } +} + +func TestErrorByteWriteErrors(t *testing.T) { + accumulator := &utils.DefaultErrorAccumulator{ + Buffer: &test.FailingErrorBuffer{}, + } + err := accumulator.Write([]byte("{")) + if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) { + t.Fatalf("Did not return error when write failed: %v", err) + } +} diff --git a/internal/test/failer.go b/internal/test/failer.go new file mode 100644 index 000000000..10ca64e34 --- /dev/null +++ b/internal/test/failer.go @@ -0,0 +1,21 @@ +package test + +import "errors" + +var ( + ErrTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed") +) + +type FailingErrorBuffer struct{} + +func (b *FailingErrorBuffer) Write(_ []byte) (n int, err error) { + return 0, ErrTestErrorAccumulatorWriteFailed +} + +func (b *FailingErrorBuffer) Len() int { + return 0 +} + +func (b *FailingErrorBuffer) Bytes() []byte { + return []byte{} +} diff --git a/stream.go b/stream.go index b9e784acf..d4e352314 100644 --- a/stream.go +++ b/stream.go @@ -55,7 +55,7 @@ func (c *Client) CreateCompletionStream( emptyMessagesLimit: c.config.EmptyMessagesLimit, reader: bufio.NewReader(resp.Body), response: resp, - errAccumulator: newErrorAccumulator(), + errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &utils.JSONUnmarshaler{}, }, } diff --git a/stream_reader.go b/stream_reader.go index a6256d803..a9940b0ae 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -20,7 +20,7 @@ type streamReader[T streamable] struct { reader *bufio.Reader response *http.Response - errAccumulator errorAccumulator + errAccumulator utils.ErrorAccumulator unmarshaler utils.Unmarshaler } @@ -45,7 +45,7 @@ waitForData: var headerData = []byte("data: ") line = bytes.TrimSpace(line) if !bytes.HasPrefix(line, headerData) { - if writeErr := stream.errAccumulator.write(line); writeErr != nil { + if writeErr := stream.errAccumulator.Write(line); writeErr != nil { err = writeErr return } @@ -70,7 +70,7 @@ waitForData: } func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { - errBytes := stream.errAccumulator.bytes() + errBytes := stream.errAccumulator.Bytes() if len(errBytes) == 0 { return } diff --git a/stream_reader_test.go b/stream_reader_test.go index 5d9a203f5..e6e8d0dbd 100644 --- a/stream_reader_test.go +++ b/stream_reader_test.go @@ -1,12 +1,23 @@ package openai //nolint:testpackage // testing private field import ( + "errors" "testing" + + utils "github.com/sashabaranov/go-openai/internal" ) +var errTestUnmarshalerFailed = errors.New("test unmarshaler failed") + +type failingUnMarshaller struct{} + +func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error { + return errTestUnmarshalerFailed +} + func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) { stream := &streamReader[ChatCompletionStreamResponse]{ - errAccumulator: newErrorAccumulator(), + errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &failingUnMarshaller{}, } @@ -15,7 +26,7 @@ func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) { t.Fatalf("Did not return nil with empty buffer: %v", respErr) } - err := stream.errAccumulator.write([]byte("{")) + err := stream.errAccumulator.Write([]byte("{")) if err != nil { t.Fatalf("%+v", err) }