Skip to content

Commit

Permalink
move error_accumulator into internal pkg (sashabaranov#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
vvatanabe committed May 31, 2023
1 parent 2c9d003 commit 15653fd
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 118 deletions.
2 changes: 1 addition & 1 deletion chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (c *Client) CreateChatCompletionStream(
emptyMessagesLimit: c.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: newErrorAccumulator(),
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
},
}
Expand Down
7 changes: 4 additions & 3 deletions chat_stream_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openai //nolint:testpackage // testing private field

import (
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

Expand Down Expand Up @@ -273,12 +274,12 @@ func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) {
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
checks.NoError(t, err)

stream.errAccumulator = &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
stream.errAccumulator = &utils.DefaultErrorAccumulator{
Buffer: &test.FailingErrorBuffer{},
}

_, err = stream.Recv()
checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error())
}

// Helper funcs.
Expand Down
44 changes: 0 additions & 44 deletions error_accumulator.go

This file was deleted.

64 changes: 0 additions & 64 deletions error_accumulator_test.go

This file was deleted.

44 changes: 44 additions & 0 deletions internal/error_accumulator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package openai

import (
"bytes"
"fmt"
"io"
)

type ErrorAccumulator interface {
Write(p []byte) error
Bytes() []byte
}

type errorBuffer interface {
io.Writer
Len() int
Bytes() []byte
}

type DefaultErrorAccumulator struct {
Buffer errorBuffer
}

func NewErrorAccumulator() ErrorAccumulator {
return &DefaultErrorAccumulator{
Buffer: &bytes.Buffer{},
}
}

func (e *DefaultErrorAccumulator) Write(p []byte) error {
_, err := e.Buffer.Write(p)
if err != nil {
return fmt.Errorf("error accumulator write error, %w", err)
}
return nil
}

func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) {
if e.Buffer.Len() == 0 {
return
}
errBytes = e.Buffer.Bytes()
return
}
41 changes: 41 additions & 0 deletions internal/error_accumulator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package openai_test

import (
"bytes"
"errors"
"testing"

utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
)

func TestErrorAccumulatorBytes(t *testing.T) {
accumulator := &utils.DefaultErrorAccumulator{
Buffer: &bytes.Buffer{},
}

errBytes := accumulator.Bytes()
if len(errBytes) != 0 {
t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes))
}

err := accumulator.Write([]byte("{}"))
if err != nil {
t.Fatalf("%+v", err)
}

errBytes = accumulator.Bytes()
if len(errBytes) == 0 {
t.Fatalf("Did not return error bytes when has error: %s", string(errBytes))
}
}

func TestErrorByteWriteErrors(t *testing.T) {
accumulator := &utils.DefaultErrorAccumulator{
Buffer: &test.FailingErrorBuffer{},
}
err := accumulator.Write([]byte("{"))
if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) {
t.Fatalf("Did not return error when write failed: %v", err)
}
}
21 changes: 21 additions & 0 deletions internal/test/failer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package test

import "errors"

var (
ErrTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed")
)

type FailingErrorBuffer struct{}

func (b *FailingErrorBuffer) Write(_ []byte) (n int, err error) {
return 0, ErrTestErrorAccumulatorWriteFailed
}

func (b *FailingErrorBuffer) Len() int {
return 0
}

func (b *FailingErrorBuffer) Bytes() []byte {
return []byte{}
}
2 changes: 1 addition & 1 deletion stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (c *Client) CreateCompletionStream(
emptyMessagesLimit: c.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: newErrorAccumulator(),
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
},
}
Expand Down
6 changes: 3 additions & 3 deletions stream_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type streamReader[T streamable] struct {

reader *bufio.Reader
response *http.Response
errAccumulator errorAccumulator
errAccumulator utils.ErrorAccumulator
unmarshaler utils.Unmarshaler
}

Expand All @@ -45,7 +45,7 @@ waitForData:
var headerData = []byte("data: ")
line = bytes.TrimSpace(line)
if !bytes.HasPrefix(line, headerData) {
if writeErr := stream.errAccumulator.write(line); writeErr != nil {
if writeErr := stream.errAccumulator.Write(line); writeErr != nil {
err = writeErr
return
}
Expand All @@ -70,7 +70,7 @@ waitForData:
}

func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) {
errBytes := stream.errAccumulator.bytes()
errBytes := stream.errAccumulator.Bytes()
if len(errBytes) == 0 {
return
}
Expand Down
15 changes: 13 additions & 2 deletions stream_reader_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
package openai //nolint:testpackage // testing private field

import (
"errors"
"testing"

utils "github.com/sashabaranov/go-openai/internal"
)

var errTestUnmarshalerFailed = errors.New("test unmarshaler failed")

type failingUnMarshaller struct{}

func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error {
return errTestUnmarshalerFailed
}

func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) {
stream := &streamReader[ChatCompletionStreamResponse]{
errAccumulator: newErrorAccumulator(),
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &failingUnMarshaller{},
}

Expand All @@ -15,7 +26,7 @@ func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) {
t.Fatalf("Did not return nil with empty buffer: %v", respErr)
}

err := stream.errAccumulator.write([]byte("{"))
err := stream.errAccumulator.Write([]byte("{"))
if err != nil {
t.Fatalf("%+v", err)
}
Expand Down

0 comments on commit 15653fd

Please sign in to comment.