From f750d00d8f1a6f22aada1f7fbeae4509a3a74216 Mon Sep 17 00:00:00 2001 From: vvatanabe Date: Sun, 28 May 2023 13:00:00 +0900 Subject: [PATCH 1/3] move request_builder into internal pkg (#304) --- chat.go | 2 +- client.go | 6 +- client_test.go | 149 +++++++++++++++ completion.go | 2 +- edits.go | 2 +- embeddings.go | 2 +- engines.go | 4 +- files.go | 6 +- fine_tunes.go | 12 +- image.go | 2 +- .../request_builder.go | 18 +- internal/request_builder_test.go | 26 +++ models.go | 2 +- moderation.go | 2 +- request_builder_test.go | 177 ------------------ 15 files changed, 204 insertions(+), 208 deletions(-) rename request_builder.go => internal/request_builder.go (52%) create mode 100644 internal/request_builder_test.go delete mode 100644 request_builder_test.go diff --git a/chat.go b/chat.go index 312ef8e20..a7ce5486a 100644 --- a/chat.go +++ b/chat.go @@ -77,7 +77,7 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) if err != nil { return } diff --git a/client.go b/client.go index c55166aa6..2486e36b6 100644 --- a/client.go +++ b/client.go @@ -15,7 +15,7 @@ import ( type Client struct { config ClientConfig - requestBuilder requestBuilder + requestBuilder utils.RequestBuilder createFormBuilder func(io.Writer) utils.FormBuilder } @@ -29,7 +29,7 @@ func NewClient(authToken string) *Client { func NewClientWithConfig(config ClientConfig) *Client { return &Client{ config: config, - requestBuilder: newRequestBuilder(), + requestBuilder: utils.NewRequestBuilder(), createFormBuilder: func(body io.Writer) utils.FormBuilder { return utils.NewFormBuilder(body) }, @@ -135,7 +135,7 @@ func (c *Client) newStreamRequest( urlSuffix string, body any, model string) (*http.Request, error) { - req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body) + req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body) if err != nil { return nil, err } diff --git a/client_test.go b/client_test.go index e30fa399b..862cbe856 100644 --- a/client_test.go +++ b/client_test.go @@ -2,13 +2,24 @@ package openai //nolint:testpackage // testing private field import ( "bytes" + "context" "errors" "fmt" "io" "net/http" "testing" + + "github.com/sashabaranov/go-openai/internal/test" ) +var errTestRequestBuilderFailed = errors.New("test request builder failed") + +type failingRequestBuilder struct{} + +func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) { + return nil, errTestRequestBuilderFailed +} + func TestClient(t *testing.T) { const mockToken = "mock token" client := NewClient(mockToken) @@ -145,3 +156,141 @@ func TestHandleErrorResp(t *testing.T) { }) } } + +func TestClientReturnsRequestBuilderErrors(t *testing.T) { + var err error + ts := test.NewTestServer().OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + + ctx := context.Background() + + _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateFineTune(ctx, FineTuneRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListFineTunes(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CancelFineTune(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.GetFineTune(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.DeleteFineTune(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListFineTuneEvents(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.Moderations(ctx, ModerationRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.Edits(ctx, EditsRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateImage(ctx, ImageRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + err = client.DeleteFile(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.GetFile(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListFiles(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListEngines(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.GetEngine(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListModels(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } +} + +func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) { + var err error + ts := test.NewTestServer().OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + + ctx := context.Background() + + _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1}) + if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1}) + if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } +} diff --git a/completion.go b/completion.go index e3d1b85eb..de1360fd9 100644 --- a/completion.go +++ b/completion.go @@ -155,7 +155,7 @@ func (c *Client) CreateCompletion( return } - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) if err != nil { return } diff --git a/edits.go b/edits.go index c2c8db794..23b1a64f0 100644 --- a/edits.go +++ b/edits.go @@ -32,7 +32,7 @@ type EditsResponse struct { // Perform an API call to the Edits endpoint. func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index 7fb432ead..942f3ea3a 100644 --- a/embeddings.go +++ b/embeddings.go @@ -132,7 +132,7 @@ type EmbeddingRequest struct { // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. // https://beta.openai.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request) if err != nil { return } diff --git a/engines.go b/engines.go index bb6a66ce4..ac01a00ed 100644 --- a/engines.go +++ b/engines.go @@ -22,7 +22,7 @@ type EnginesList struct { // ListEngines Lists the currently available engines, and provides basic // information about each option such as the owner and availability. func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil) if err != nil { return } @@ -38,7 +38,7 @@ func (c *Client) GetEngine( engineID string, ) (engine Engine, err error) { urlSuffix := fmt.Sprintf("/engines/%s", engineID) - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) if err != nil { return } diff --git a/files.go b/files.go index 5667ec861..36c024365 100644 --- a/files.go +++ b/files.go @@ -70,7 +70,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File // DeleteFile deletes an existing file. func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { - req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) if err != nil { return } @@ -82,7 +82,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { // ListFiles Lists the currently available files, // and provides basic information about each file such as the file name and purpose. func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/files"), nil) if err != nil { return } @@ -95,7 +95,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { // such as the file name and purpose. func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { urlSuffix := fmt.Sprintf("/files/%s", fileID) - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) if err != nil { return } diff --git a/fine_tunes.go b/fine_tunes.go index a1218670f..069ddccfd 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -68,7 +68,7 @@ type FineTuneDeleteResponse struct { func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { urlSuffix := "/fine-tunes" - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return } @@ -79,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // CancelFineTune cancel a fine-tune job. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) if err != nil { return } @@ -89,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons } func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) if err != nil { return } @@ -100,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) if err != nil { return } @@ -110,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F } func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) if err != nil { return } @@ -120,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons } func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) if err != nil { return } diff --git a/image.go b/image.go index 87ffea25e..df7363865 100644 --- a/image.go +++ b/image.go @@ -44,7 +44,7 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return } diff --git a/request_builder.go b/internal/request_builder.go similarity index 52% rename from request_builder.go rename to internal/request_builder.go index b4db07c2f..0a9eabfde 100644 --- a/request_builder.go +++ b/internal/request_builder.go @@ -4,25 +4,23 @@ import ( "bytes" "context" "net/http" - - utils "github.com/sashabaranov/go-openai/internal" ) -type requestBuilder interface { - build(ctx context.Context, method, url string, request any) (*http.Request, error) +type RequestBuilder interface { + Build(ctx context.Context, method, url string, request any) (*http.Request, error) } -type httpRequestBuilder struct { - marshaller utils.Marshaller +type HTTPRequestBuilder struct { + marshaller Marshaller } -func newRequestBuilder() *httpRequestBuilder { - return &httpRequestBuilder{ - marshaller: &utils.JSONMarshaller{}, +func NewRequestBuilder() *HTTPRequestBuilder { + return &HTTPRequestBuilder{ + marshaller: &JSONMarshaller{}, } } -func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) { +func (b *HTTPRequestBuilder) Build(ctx context.Context, method, url string, request any) (*http.Request, error) { if request == nil { return http.NewRequestWithContext(ctx, method, url, nil) } diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go new file mode 100644 index 000000000..e981ca213 --- /dev/null +++ b/internal/request_builder_test.go @@ -0,0 +1,26 @@ +package openai //nolint:testpackage // testing private field + +import ( + "context" + "errors" + "testing" +) + +var errTestMarshallerFailed = errors.New("test marshaller failed") + +type failingMarshaller struct{} + +func (*failingMarshaller) Marshal(_ any) ([]byte, error) { + return []byte{}, errTestMarshallerFailed +} + +func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { + builder := HTTPRequestBuilder{ + marshaller: &failingMarshaller{}, + } + + _, err := builder.Build(context.Background(), "", "", struct{}{}) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } +} diff --git a/models.go b/models.go index 2be91aadb..485433b26 100644 --- a/models.go +++ b/models.go @@ -40,7 +40,7 @@ type ModelsList struct { // ListModels Lists the currently available models, // and provides basic information about each model such as the model id and parent. func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/models"), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/models"), nil) if err != nil { return } diff --git a/moderation.go b/moderation.go index ebd66afb9..bae788035 100644 --- a/moderation.go +++ b/moderation.go @@ -63,7 +63,7 @@ type ModerationResponse struct { // Moderations — perform a moderation api call over a string. // Input can be an array or slice but a string will reduce the complexity. func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request) if err != nil { return } diff --git a/request_builder_test.go b/request_builder_test.go deleted file mode 100644 index ed4b69113..000000000 --- a/request_builder_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package openai //nolint:testpackage // testing private field - -import ( - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "errors" - "net/http" - "testing" -) - -var ( - errTestMarshallerFailed = errors.New("test marshaller failed") - errTestRequestBuilderFailed = errors.New("test request builder failed") -) - -type ( - failingRequestBuilder struct{} - failingMarshaller struct{} -) - -func (*failingMarshaller) Marshal(_ any) ([]byte, error) { - return []byte{}, errTestMarshallerFailed -} - -func (*failingRequestBuilder) build(_ context.Context, _, _ string, _ any) (*http.Request, error) { - return nil, errTestRequestBuilderFailed -} - -func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { - builder := httpRequestBuilder{ - marshaller: &failingMarshaller{}, - } - - _, err := builder.build(context.Background(), "", "", struct{}{}) - if !errors.Is(err, errTestMarshallerFailed) { - t.Fatalf("Did not return error when marshaller failed: %v", err) - } -} - -func TestClientReturnsRequestBuilderErrors(t *testing.T) { - var err error - ts := test.NewTestServer().OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - client.requestBuilder = &failingRequestBuilder{} - - ctx := context.Background() - - _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateFineTune(ctx, FineTuneRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFineTunes(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CancelFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.DeleteFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFineTuneEvents(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.Moderations(ctx, ModerationRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.Edits(ctx, EditsRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateImage(ctx, ImageRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - err = client.DeleteFile(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetFile(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFiles(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListEngines(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetEngine(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListModels(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } -} - -func TestReturnsRequestBuilderErrorsAddtion(t *testing.T) { - var err error - ts := test.NewTestServer().OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - client.requestBuilder = &failingRequestBuilder{} - - ctx := context.Background() - - _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1}) - if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1}) - if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } -} From 7eb94ba68d0136beed1138c568ed237f7abfc7ca Mon Sep 17 00:00:00 2001 From: vvatanabe Date: Tue, 30 May 2023 09:09:56 +0900 Subject: [PATCH 2/3] add some test for internal.RequestBuilder --- internal/request_builder_test.go | 35 ++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index e981ca213..e47d0f6ca 100644 --- a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -1,8 +1,11 @@ package openai //nolint:testpackage // testing private field import ( + "bytes" "context" "errors" + "net/http" + "reflect" "testing" ) @@ -24,3 +27,35 @@ func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { t.Fatalf("Did not return error when marshaller failed: %v", err) } } + +func TestRequestBuilderReturnsRequest(t *testing.T) { + b := NewRequestBuilder() + var ( + ctx = context.Background() + method = http.MethodPost + url = "/foo" + request = map[string]string{"foo": "bar"} + reqBytes, _ = b.marshaller.Marshal(request) + want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes)) + ) + got, _ := b.Build(ctx, method, url, request) + if !reflect.DeepEqual(got.Body, want.Body) || + !reflect.DeepEqual(got.URL, want.URL) || + !reflect.DeepEqual(got.Method, want.Method) { + t.Errorf("Build() got = %v, want %v", got, want) + } +} + +func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { + var ( + ctx = context.Background() + method = http.MethodGet + url = "/foo" + want, _ = http.NewRequestWithContext(ctx, method, url, nil) + ) + b := NewRequestBuilder() + got, _ := b.Build(ctx, method, url, nil) + if !reflect.DeepEqual(got, want) { + t.Errorf("Build() got = %v, want %v", got, want) + } +} From a6d1c8f06dfdc0f283cba816eaf18cf1bf20b813 Mon Sep 17 00:00:00 2001 From: vvatanabe Date: Tue, 30 May 2023 12:54:51 +0900 Subject: [PATCH 3/3] add a test for openai.GetEngine --- engines_test.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 engines_test.go diff --git a/engines_test.go b/engines_test.go new file mode 100644 index 000000000..dfa3187cf --- /dev/null +++ b/engines_test.go @@ -0,0 +1,34 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server. +func TestGetEngine(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(Engine{}) + fmt.Fprintln(w, string(resBytes)) + }) + // create the test server + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.GetEngine(ctx, "text-davinci-003") + checks.NoError(t, err, "GetEngine error") +}