Skip to content

Commit

Permalink
Consolidate on functional options (#6)
Browse files Browse the repository at this point in the history
* Consolidate on functional options

This changes the signature of NewClient to only accept variadic Options,
and to return an error if one occurs while setting up the client.

This leaves us with the internal implementation details of the client
and its options fully encapsulated by the package, which should give us
substantial freedom to extend the API in future while minimising
breakage for users.

* Consolidate options into client

* Rename Option to ClientOption

* Add tests for initializing client from env

Extract env variable name into internal constant

---------

Co-authored-by: Mattt Zmuda <mattt@replicate.com>
  • Loading branch information
nickstenning and mattt authored Aug 6, 2023
1 parent fa2c77a commit 7c58cd1
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 44 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ import (
"github.com/replicate/replicate-go"
)

client := replicate.NewClient(os.Getenv("REPLICATE_API_TOKEN"))
// You can also provide a token directly with `replicate.NewClient(replicate.WithToken("r8_..."))`
client := replicate.NewClient(replicate.WithTokenFromEnv())

// https://replicate.com/stability-ai/stable-diffusion
version := "db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
Expand Down
111 changes: 82 additions & 29 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,114 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
)

var (
envAuthToken = "REPLICATE_API_TOKEN"

defaultUserAgent = "replicate/go" // TODO: embed version information
defaultBaseURL = "https://api.replicate.com/v1"

ErrNoAuth = errors.New(`no auth token or token source provided -- perhaps you forgot to pass replicate.WithToken("...")`)
)

// Client is a client for the Replicate API.
type Client struct {
Auth string
UserAgent *string
BaseURL string
HTTPClient *http.Client
options *options
c *http.Client
}

// ClientOption is a function that modifies a Client.
type ClientOption func(*Client)
type options struct {
auth string
baseURL string
httpClient *http.Client
userAgent *string
}

// NewClient creates a new Replicate API client.
func NewClient(auth string, options ...ClientOption) *Client {
defaultUserAgent := "replicate-go"
defaultBaseURL := "https://api.replicate.com/v1"
defaultClient := http.DefaultClient
// ClientOption is a function that modifies an options struct.
type ClientOption func(*options) error

// NewClient creates a new Replicate API client.
func NewClient(opts ...ClientOption) (*Client, error) {
c := &Client{
Auth: auth,
UserAgent: &defaultUserAgent,
BaseURL: defaultBaseURL,
HTTPClient: defaultClient,
options: &options{
userAgent: &defaultUserAgent,
baseURL: defaultBaseURL,
httpClient: http.DefaultClient,
},
}

for _, option := range options {
option(c)
var errs []error
for _, option := range opts {
err := option(c.options)
if err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return nil, errors.Join(errs...)
}

return c
if c.options.auth == "" {
return nil, ErrNoAuth
}

c.c = c.options.httpClient

return c, nil
}

// WithToken sets the auth token used by the client.
func WithToken(token string) ClientOption {
return func(o *options) error {
o.auth = token
return nil
}
}

// WithTokenFromEnv configures the client to use the auth token provided in the
// REPLICATE_API_TOKEN environment variable.
func WithTokenFromEnv() ClientOption {
return func(o *options) error {
token, ok := os.LookupEnv(envAuthToken)
if !ok {
return fmt.Errorf("%s environment variable not set", envAuthToken)
}
if token == "" {
return fmt.Errorf("%s environment variable is empty", envAuthToken)
}
o.auth = token
return nil
}
}

// WithUserAgent sets the User-Agent header on requests made by the client.
func WithUserAgent(userAgent string) ClientOption {
return func(c *Client) {
c.UserAgent = &userAgent
return func(o *options) error {
o.userAgent = &userAgent
return nil
}
}

// WithBaseURL sets the base URL for the client.
func WithBaseURL(baseURL string) ClientOption {
return func(c *Client) {
c.BaseURL = baseURL
return func(o *options) error {
o.baseURL = baseURL
return nil
}
}

// WithHTTPClient sets the HTTP client used by the client.
func WithHTTPClient(httpClient *http.Client) ClientOption {
return func(c *Client) {
c.HTTPClient = httpClient
return func(o *options) error {
o.httpClient = httpClient
return nil
}
}

Expand All @@ -73,19 +126,19 @@ func (r *Client) request(ctx context.Context, method, path string, body interfac
bodyBuffer = bytes.NewBuffer(bodyBytes)
}

url := constructURL(r.BaseURL, path)
url := constructURL(r.options.baseURL, path)
request, err := http.NewRequestWithContext(ctx, method, url, bodyBuffer)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}

request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", fmt.Sprintf("Token %s", r.Auth))
if r.UserAgent != nil {
request.Header.Set("User-Agent", *r.UserAgent)
request.Header.Set("Authorization", fmt.Sprintf("Token %s", r.options.auth))
if r.options.userAgent != nil {
request.Header.Set("User-Agent", *r.options.userAgent)
}

response, err := r.HTTPClient.Do(request)
response, err := r.c.Do(request)
if err != nil {
return fmt.Errorf("failed to make request: %w", err)
}
Expand Down
73 changes: 59 additions & 14 deletions replicate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,27 @@ import (

"github.com/replicate/replicate-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewClientNoAuth(t *testing.T) {
_, err := replicate.NewClient()

assert.ErrorIs(t, err, replicate.ErrNoAuth)
}

func TestNewClientBlankAuthTokenFromEnv(t *testing.T) {
t.Setenv("REPLICATE_API_TOKEN", "")
_, err := replicate.NewClient(replicate.WithTokenFromEnv())
require.ErrorContains(t, err, "REPLICATE_API_TOKEN")
}

func TestNewClientAuthTokenFromEnv(t *testing.T) {
t.Setenv("REPLICATE_API_TOKEN", "test-token")
_, err := replicate.NewClient(replicate.WithTokenFromEnv())
require.NoError(t, err)
}

func TestListCollections(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/collections", r.URL.Path)
Expand Down Expand Up @@ -53,9 +72,11 @@ func TestListCollections(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

initialPage, err := client.ListCollections(context.Background())
if err != nil {
Expand Down Expand Up @@ -105,9 +126,11 @@ func TestGetCollection(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down Expand Up @@ -145,9 +168,11 @@ func TestGetModel(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

model, err := client.GetModel(context.Background(), "replicate", "hello-world")
assert.NoError(t, err)
Expand All @@ -174,9 +199,11 @@ func TestListModelVersions(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

versionsPage, err := client.ListModelVersions(context.Background(), "replicate", "hello-world")
assert.NoError(t, err)
Expand All @@ -203,9 +230,11 @@ func TestGetModelVersion(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

version, err := client.GetModelVersion(context.Background(), "replicate", "hello-world", "version1")
assert.NoError(t, err)
Expand Down Expand Up @@ -261,9 +290,11 @@ func TestCreatePrediction(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

input := replicate.PredictionInput{"text": "Alice"}
webhook := replicate.Webhook{
Expand Down Expand Up @@ -325,9 +356,11 @@ func TestListPredictions(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

initialPage, err := client.ListPredictions(context.Background())
if err != nil {
Expand Down Expand Up @@ -379,9 +412,11 @@ func TestGetPrediction(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down Expand Up @@ -426,9 +461,11 @@ func TestWait(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

prediction := &replicate.Prediction{
ID: "ufawqhfynnddngldkgtslldrkq",
Expand All @@ -441,7 +478,7 @@ func TestWait(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

prediction, err := client.Wait(ctx, *prediction, time.Millisecond*100, 5)
prediction, err = client.Wait(ctx, *prediction, time.Millisecond*100, 5)
assert.NoError(t, err)
assert.Equal(t, "ufawqhfynnddngldkgtslldrkq", prediction.ID)
assert.Equal(t, replicate.Succeeded, prediction.Status)
Expand Down Expand Up @@ -472,9 +509,11 @@ func TestCreateTraining(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down Expand Up @@ -514,9 +553,11 @@ func TestGetTraining(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down Expand Up @@ -547,9 +588,11 @@ func TestCancelTraining(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down Expand Up @@ -582,9 +625,11 @@ func TestListTrainings(t *testing.T) {
}))
defer mockServer.Close()

client := replicate.NewClient("test-token",
client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down

0 comments on commit 7c58cd1

Please sign in to comment.