Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move request_builder into internal pkg (#304) #329

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (c *Client) CreateChatCompletion(
return
}

req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
if err != nil {
return
}
Expand Down
6 changes: 3 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
type Client struct {
config ClientConfig

requestBuilder requestBuilder
requestBuilder utils.RequestBuilder
createFormBuilder func(io.Writer) utils.FormBuilder
}

Expand All @@ -29,7 +29,7 @@ func NewClient(authToken string) *Client {
func NewClientWithConfig(config ClientConfig) *Client {
return &Client{
config: config,
requestBuilder: newRequestBuilder(),
requestBuilder: utils.NewRequestBuilder(),
createFormBuilder: func(body io.Writer) utils.FormBuilder {
return utils.NewFormBuilder(body)
},
Expand Down Expand Up @@ -135,7 +135,7 @@ func (c *Client) newStreamRequest(
urlSuffix string,
body any,
model string) (*http.Request, error) {
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body)
req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body)
if err != nil {
return nil, err
}
Expand Down
149 changes: 149 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@ package openai //nolint:testpackage // testing private field

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"testing"

"github.com/sashabaranov/go-openai/internal/test"
)

var errTestRequestBuilderFailed = errors.New("test request builder failed")

type failingRequestBuilder struct{}

func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
return nil, errTestRequestBuilderFailed
}

func TestClient(t *testing.T) {
const mockToken = "mock token"
client := NewClient(mockToken)
Expand Down Expand Up @@ -145,3 +156,141 @@ func TestHandleErrorResp(t *testing.T) {
})
}
}

func TestClientReturnsRequestBuilderErrors(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.requestBuilder = &failingRequestBuilder{}

ctx := context.Background()

_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateFineTune(ctx, FineTuneRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListFineTunes(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CancelFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.GetFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.DeleteFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListFineTuneEvents(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.Moderations(ctx, ModerationRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.Edits(ctx, EditsRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateImage(ctx, ImageRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

err = client.DeleteFile(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.GetFile(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListFiles(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListEngines(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.GetEngine(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListModels(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}

func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.requestBuilder = &failingRequestBuilder{}

ctx := context.Background()

_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}
2 changes: 1 addition & 1 deletion completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func (c *Client) CreateCompletion(
return
}

req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion edits.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type EditsResponse struct {

// Perform an API call to the Edits endpoint.
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ type EmbeddingRequest struct {
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
// https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
if err != nil {
return
}
Expand Down
4 changes: 2 additions & 2 deletions engines.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type EnginesList struct {
// ListEngines Lists the currently available engines, and provides basic
// information about each option such as the owner and availability.
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
if err != nil {
return
}
Expand All @@ -38,7 +38,7 @@ func (c *Client) GetEngine(
engineID string,
) (engine Engine, err error) {
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}
Expand Down
34 changes: 34 additions & 0 deletions engines_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package openai_test

import (
"context"
"encoding/json"
"fmt"
"net/http"
"testing"

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)

// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server.
func TestGetEngine(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(Engine{})
fmt.Fprintln(w, string(resBytes))
})
// create the test server
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

_, err := client.GetEngine(ctx, "text-davinci-003")
checks.NoError(t, err, "GetEngine error")
}
6 changes: 3 additions & 3 deletions files.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File

// DeleteFile deletes an existing file.
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
if err != nil {
return
}
Expand All @@ -82,7 +82,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
// ListFiles Lists the currently available files,
// and provides basic information about each file such as the file name and purpose.
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/files"), nil)
if err != nil {
return
}
Expand All @@ -95,7 +95,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
// such as the file name and purpose.
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
urlSuffix := fmt.Sprintf("/files/%s", fileID)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}
Expand Down
12 changes: 6 additions & 6 deletions fine_tunes.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ type FineTuneDeleteResponse struct {

func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
urlSuffix := "/fine-tunes"
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}
Expand All @@ -79,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r

// CancelFineTune cancel a fine-tune job.
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
if err != nil {
return
}
Expand All @@ -89,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons
}

func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
if err != nil {
return
}
Expand All @@ -100,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err

func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}
Expand All @@ -110,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F
}

func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
if err != nil {
return
}
Expand All @@ -120,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons
}

func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion image.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type ImageResponseDataInner struct {
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
urlSuffix := "/images/generations"
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}
Expand Down
Loading