Skip to content

Commit

Permalink
move error_accumulator into internal pkg (sashabaranov#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
vvatanabe committed May 31, 2023
1 parent 61ba5f3 commit 2c9d003
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 104 deletions.
48 changes: 37 additions & 11 deletions chat_stream_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package openai_test
package openai //nolint:testpackage // testing private field

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

Expand Down Expand Up @@ -63,9 +62,9 @@ func TestCreateChatCompletionStream(t *testing.T) {
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}

client := NewClientWithConfig(config)
Expand Down Expand Up @@ -170,9 +169,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}

client := NewClientWithConfig(config)
Expand Down Expand Up @@ -227,9 +226,9 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}

client := NewClientWithConfig(config)
Expand All @@ -255,6 +254,33 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
t.Logf("%+v\n", apiErr)
}

func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) {
var err error
server := test.NewTestServer()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "error", 200)
})
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)

ctx := context.Background()

stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
checks.NoError(t, err)

stream.errAccumulator = &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
}

_, err = stream.Recv()
checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
}

// Helper funcs.
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
Expand Down
19 changes: 5 additions & 14 deletions error_accumulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@ import (
"bytes"
"fmt"
"io"

utils "github.com/sashabaranov/go-openai/internal"
)

type errorAccumulator interface {
write(p []byte) error
unmarshalError() *ErrorResponse
bytes() []byte
}

type errorBuffer interface {
Expand All @@ -20,14 +18,12 @@ type errorBuffer interface {
}

type defaultErrorAccumulator struct {
buffer errorBuffer
unmarshaler utils.Unmarshaler
buffer errorBuffer
}

func newErrorAccumulator() errorAccumulator {
return &defaultErrorAccumulator{
buffer: &bytes.Buffer{},
unmarshaler: &utils.JSONUnmarshaler{},
buffer: &bytes.Buffer{},
}
}

Expand All @@ -39,15 +35,10 @@ func (e *defaultErrorAccumulator) write(p []byte) error {
return nil
}

func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) {
func (e *defaultErrorAccumulator) bytes() (errBytes []byte) {
if e.buffer.Len() == 0 {
return
}

err := e.unmarshaler.Unmarshal(e.buffer.Bytes(), &errResp)
if err != nil {
errResp = nil
}

errBytes = e.buffer.Bytes()
return
}
56 changes: 10 additions & 46 deletions error_accumulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,8 @@ package openai //nolint:testpackage // testing private field

import (
"bytes"
"context"
"errors"
"net/http"
"testing"

utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)

var (
Expand Down Expand Up @@ -38,63 +32,33 @@ func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error {
return errTestUnmarshalerFailed
}

func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
func TestErrorAccumulatorBytes(t *testing.T) {
accumulator := &defaultErrorAccumulator{
buffer: &bytes.Buffer{},
unmarshaler: &failingUnMarshaller{},
buffer: &bytes.Buffer{},
}

respErr := accumulator.unmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil with empty buffer: %v", respErr)
errBytes := accumulator.bytes()
if len(errBytes) != 0 {
t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes))
}

err := accumulator.write([]byte("{"))
err := accumulator.write([]byte("{}"))
if err != nil {
t.Fatalf("%+v", err)
}

respErr = accumulator.unmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr)
errBytes = accumulator.bytes()
if len(errBytes) == 0 {
t.Fatalf("Did not return error bytes when has error: %s", string(errBytes))
}
}

func TestErrorByteWriteErrors(t *testing.T) {
accumulator := &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
unmarshaler: &utils.JSONUnmarshaler{},
buffer: &failingErrorBuffer{},
}
err := accumulator.write([]byte("{"))
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
t.Fatalf("Did not return error when write failed: %v", err)
}
}

func TestErrorAccumulatorWriteErrors(t *testing.T) {
var err error
server := test.NewTestServer()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "error", 200)
})
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)

ctx := context.Background()

stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
checks.NoError(t, err)

stream.errAccumulator = &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
unmarshaler: &utils.JSONUnmarshaler{},
}

_, err = stream.Recv()
checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
}
24 changes: 24 additions & 0 deletions internal/test/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package test
import (
"github.com/sashabaranov/go-openai/internal/test/checks"

"net/http"
"os"
"testing"
)
Expand All @@ -27,3 +28,26 @@ func CreateTestDirectory(t *testing.T) (path string, cleanup func()) {

return path, func() { os.RemoveAll(path) }
}

// TokenRoundTripper is a struct that implements the RoundTripper
// interface, specifically to handle the authentication token by adding a token
// to the request header. We need this because the API requires that each
// request include a valid API token in the headers for authentication and
// authorization.
type TokenRoundTripper struct {
Token string
Fallback http.RoundTripper
}

// RoundTrip takes an *http.Request as input and returns an
// *http.Response and an error.
//
// It is expected to use the provided request to create a connection to an HTTP
// server and return the response, or an error if one occurred. The returned
// Response should have its Body closed. If the RoundTrip method returns an
// error, the Client's Get, Head, Post, and PostForm methods return the same
// error.
func (t *TokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("Authorization", "Bearer "+t.Token)
return t.Fallback.RoundTrip(req)
}
16 changes: 15 additions & 1 deletion stream_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (stream *streamReader[T]) Recv() (response T, err error) {
waitForData:
line, err := stream.reader.ReadBytes('\n')
if err != nil {
respErr := stream.errAccumulator.unmarshalError()
respErr := stream.unmarshalError()
if respErr != nil {
err = fmt.Errorf("error, %w", respErr.Error)
}
Expand Down Expand Up @@ -69,6 +69,20 @@ waitForData:
return
}

func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) {
errBytes := stream.errAccumulator.bytes()
if len(errBytes) == 0 {
return
}

err := stream.unmarshaler.Unmarshal(errBytes, &errResp)
if err != nil {
errResp = nil
}

return
}

func (stream *streamReader[T]) Close() {
stream.response.Body.Close()
}
27 changes: 27 additions & 0 deletions stream_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package openai //nolint:testpackage // testing private field

import (
"testing"
)

func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) {
stream := &streamReader[ChatCompletionStreamResponse]{
errAccumulator: newErrorAccumulator(),
unmarshaler: &failingUnMarshaller{},
}

respErr := stream.unmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil with empty buffer: %v", respErr)
}

err := stream.errAccumulator.write([]byte("{"))
if err != nil {
t.Fatalf("%+v", err)
}

respErr = stream.unmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr)
}
}
41 changes: 9 additions & 32 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ func TestCreateCompletionStream(t *testing.T) {
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}

client := NewClientWithConfig(config)
Expand Down Expand Up @@ -142,9 +142,9 @@ func TestCreateCompletionStreamError(t *testing.T) {
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}

client := NewClientWithConfig(config)
Expand Down Expand Up @@ -194,9 +194,9 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}

client := NewClientWithConfig(config)
Expand All @@ -217,29 +217,6 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
t.Logf("%+v\n", apiErr)
}

// A "tokenRoundTripper" is a struct that implements the RoundTripper
// interface, specifically to handle the authentication token by adding a token
// to the request header. We need this because the API requires that each
// request include a valid API token in the headers for authentication and
// authorization.
type tokenRoundTripper struct {
token string
fallback http.RoundTripper
}

// RoundTrip takes an *http.Request as input and returns an
// *http.Response and an error.
//
// It is expected to use the provided request to create a connection to an HTTP
// server and return the response, or an error if one occurred. The returned
// Response should have its Body closed. If the RoundTrip method returns an
// error, the Client's Get, Head, Post, and PostForm methods return the same
// error.
func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("Authorization", "Bearer "+t.token)
return t.fallback.RoundTrip(req)
}

// Helper funcs.
func compareResponses(r1, r2 CompletionResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
Expand Down

0 comments on commit 2c9d003

Please sign in to comment.