Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move error_accumulator into internal pkg (#304) #335

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (c *Client) CreateChatCompletionStream(
emptyMessagesLimit: c.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: newErrorAccumulator(),
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
},
}
Expand Down
49 changes: 38 additions & 11 deletions chat_stream_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package openai_test
package openai //nolint:testpackage // testing private field

import (
. "github.com/sashabaranov/go-openai"
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

Expand Down Expand Up @@ -63,9 +63,9 @@ func TestCreateChatCompletionStream(t *testing.T) {
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}

client := NewClientWithConfig(config)
Expand Down Expand Up @@ -170,9 +170,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}

client := NewClientWithConfig(config)
Expand Down Expand Up @@ -227,9 +227,9 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}

client := NewClientWithConfig(config)
Expand All @@ -255,6 +255,33 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
t.Logf("%+v\n", apiErr)
}

func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) {
var err error
server := test.NewTestServer()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "error", 200)
})
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)

ctx := context.Background()

stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
checks.NoError(t, err)

stream.errAccumulator = &utils.DefaultErrorAccumulator{
Buffer: &test.FailingErrorBuffer{},
}

_, err = stream.Recv()
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error())
}

// Helper funcs.
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
Expand Down
53 changes: 0 additions & 53 deletions error_accumulator.go

This file was deleted.

100 changes: 0 additions & 100 deletions error_accumulator_test.go

This file was deleted.

44 changes: 44 additions & 0 deletions internal/error_accumulator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package openai

import (
"bytes"
"fmt"
"io"
)

type ErrorAccumulator interface {
Write(p []byte) error
Bytes() []byte
}

type errorBuffer interface {
io.Writer
Len() int
Bytes() []byte
}

type DefaultErrorAccumulator struct {
Buffer errorBuffer
}

func NewErrorAccumulator() ErrorAccumulator {
return &DefaultErrorAccumulator{
Buffer: &bytes.Buffer{},
}
}

func (e *DefaultErrorAccumulator) Write(p []byte) error {
_, err := e.Buffer.Write(p)
if err != nil {
return fmt.Errorf("error accumulator write error, %w", err)
}
return nil
}

func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) {
if e.Buffer.Len() == 0 {
return
}
errBytes = e.Buffer.Bytes()
return
}
41 changes: 41 additions & 0 deletions internal/error_accumulator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package openai_test

import (
"bytes"
"errors"
"testing"

utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
)

func TestErrorAccumulatorBytes(t *testing.T) {
accumulator := &utils.DefaultErrorAccumulator{
Buffer: &bytes.Buffer{},
}

errBytes := accumulator.Bytes()
if len(errBytes) != 0 {
t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes))
}

err := accumulator.Write([]byte("{}"))
if err != nil {
t.Fatalf("%+v", err)
}

errBytes = accumulator.Bytes()
if len(errBytes) == 0 {
t.Fatalf("Did not return error bytes when has error: %s", string(errBytes))
}
}

func TestErrorByteWriteErrors(t *testing.T) {
accumulator := &utils.DefaultErrorAccumulator{
Buffer: &test.FailingErrorBuffer{},
}
err := accumulator.Write([]byte("{"))
if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) {
t.Fatalf("Did not return error when write failed: %v", err)
}
}
21 changes: 21 additions & 0 deletions internal/test/failer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package test

import "errors"

var (
ErrTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed")
)

type FailingErrorBuffer struct{}

func (b *FailingErrorBuffer) Write(_ []byte) (n int, err error) {
return 0, ErrTestErrorAccumulatorWriteFailed
}

func (b *FailingErrorBuffer) Len() int {
return 0
}

func (b *FailingErrorBuffer) Bytes() []byte {
return []byte{}
}
24 changes: 24 additions & 0 deletions internal/test/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package test
import (
"github.com/sashabaranov/go-openai/internal/test/checks"

"net/http"
"os"
"testing"
)
Expand All @@ -27,3 +28,26 @@ func CreateTestDirectory(t *testing.T) (path string, cleanup func()) {

return path, func() { os.RemoveAll(path) }
}

// TokenRoundTripper is a struct that implements the RoundTripper
// interface, specifically to handle the authentication token by adding a token
// to the request header. We need this because the API requires that each
// request include a valid API token in the headers for authentication and
// authorization.
type TokenRoundTripper struct {
Token string
Fallback http.RoundTripper
}

// RoundTrip takes an *http.Request as input and returns an
// *http.Response and an error.
//
// It is expected to use the provided request to create a connection to an HTTP
// server and return the response, or an error if one occurred. The returned
// Response should have its Body closed. If the RoundTrip method returns an
// error, the Client's Get, Head, Post, and PostForm methods return the same
// error.
func (t *TokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("Authorization", "Bearer "+t.Token)
return t.Fallback.RoundTrip(req)
}
2 changes: 1 addition & 1 deletion stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (c *Client) CreateCompletionStream(
emptyMessagesLimit: c.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: newErrorAccumulator(),
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
},
}
Expand Down
Loading