Skip to content

Commit

Permalink
HTTP client factory for per-request clients
Browse files Browse the repository at this point in the history
  • Loading branch information
ash2k committed Sep 1, 2022
1 parent 493aa4c commit 56e7391
Showing 1 changed file with 66 additions and 21 deletions.
87 changes: 66 additions & 21 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,11 +368,24 @@ type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) t
// attempted. If overriding this, be sure to close the body if needed.
type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Response, error)

type HTTPClient interface {
// Do performs an HTTP request and returns an HTTP response.
Do(*http.Request) (*http.Response, error)
// Done is called when the client is no longer needed.
Done()
}

type HTTPClientFactory interface {
// New returns an HTTP client to use for a request, including retries.
New() HTTPClient
}

// Client is used to make HTTP requests. It adds additional functionality
// like automatic retries to tolerate minor outages.
type Client struct {
HTTPClient *http.Client // Internal HTTP client.
Logger interface{} // Customer logger instance. Can be either Logger or LeveledLogger
HTTPClient *http.Client // Internal HTTP client. This field is used if set, otherwise HTTPClientFactory is used.
HTTPClientFactory HTTPClientFactory
Logger interface{} // Customer logger instance. Can be either Logger or LeveledLogger

RetryWaitMin time.Duration // Minimum time to wait
RetryWaitMax time.Duration // Maximum time to wait
Expand All @@ -397,19 +410,18 @@ type Client struct {
ErrorHandler ErrorHandler

loggerInit sync.Once
clientInit sync.Once
}

// NewClient creates a new Client with default settings.
func NewClient() *Client {
return &Client{
HTTPClient: cleanhttp.DefaultPooledClient(),
Logger: defaultLogger,
RetryWaitMin: defaultRetryWaitMin,
RetryWaitMax: defaultRetryWaitMax,
RetryMax: defaultRetryMax,
CheckRetry: DefaultRetryPolicy,
Backoff: DefaultBackoff,
HTTPClientFactory: &CleanPooledClientFactory{},
Logger: defaultLogger,
RetryWaitMin: defaultRetryWaitMin,
RetryWaitMax: defaultRetryWaitMax,
RetryMax: defaultRetryMax,
CheckRetry: DefaultRetryPolicy,
Backoff: DefaultBackoff,
}
}

Expand Down Expand Up @@ -573,12 +585,6 @@ func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Respo

// Do wraps calling an HTTP method with retries.
func (c *Client) Do(req *Request) (*http.Response, error) {
c.clientInit.Do(func() {
if c.HTTPClient == nil {
c.HTTPClient = cleanhttp.DefaultPooledClient()
}
})

logger := c.logger()

if logger != nil {
Expand All @@ -590,6 +596,9 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
}
}

httpClient := c.getHTTPClient()
defer httpClient.Done()

var resp *http.Response
var attempt int
var shouldRetry bool
Expand All @@ -603,7 +612,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
if req.body != nil {
body, err := req.body()
if err != nil {
c.HTTPClient.CloseIdleConnections()
return resp, err
}
if c, ok := body.(io.ReadCloser); ok {
Expand All @@ -625,7 +633,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
}

// Attempt the request
resp, doErr = c.HTTPClient.Do(req.Request)

resp, doErr = httpClient.Do(req.Request)

// Check if we should continue with retries.
shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr)
Expand Down Expand Up @@ -694,7 +703,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
select {
case <-req.Context().Done():
timer.Stop()
c.HTTPClient.CloseIdleConnections()
return nil, req.Context().Err()
case <-timer.C:
}
Expand All @@ -710,8 +718,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
return resp, nil
}

defer c.HTTPClient.CloseIdleConnections()

var err error
if checkErr != nil {
err = checkErr
Expand Down Expand Up @@ -758,6 +764,19 @@ func (c *Client) drainBody(body io.ReadCloser) {
}
}

func (c *Client) getHTTPClient() HTTPClient {
if c.HTTPClient != nil {
return &idleConnectionsClosingClient{
httpClient: c.HTTPClient,
}
}
clientFactory := c.HTTPClientFactory
if clientFactory == nil {
clientFactory = &CleanPooledClientFactory{}
}
return clientFactory.New()
}

// Get is a shortcut for doing a GET request without making a new client.
func Get(url string) (*http.Response, error) {
return defaultClient.Get(url)
Expand Down Expand Up @@ -820,3 +839,29 @@ func (c *Client) StandardClient() *http.Client {
Transport: &RoundTripper{Client: c},
}
}

var (
_ HTTPClientFactory = &CleanPooledClientFactory{}
_ HTTPClient = &idleConnectionsClosingClient{}
)

type CleanPooledClientFactory struct {
}

func (f *CleanPooledClientFactory) New() HTTPClient {
return &idleConnectionsClosingClient{
httpClient: cleanhttp.DefaultPooledClient(),
}
}

type idleConnectionsClosingClient struct {
httpClient *http.Client
}

func (c *idleConnectionsClosingClient) Do(req *http.Request) (*http.Response, error) {
return c.httpClient.Do(req)
}

func (c *idleConnectionsClosingClient) Done() {
c.httpClient.CloseIdleConnections()
}

0 comments on commit 56e7391

Please sign in to comment.