Skip to content

Commit

Permalink
add support for custom dial function with timeouts (#1669)
Browse files Browse the repository at this point in the history
* add support for custom dial function with timeouts

* fix linting

---------

Co-authored-by: Aviv Carmi <aviv@perimeterx.com>
  • Loading branch information
avivcarmis and avivpxi authored Nov 27, 2023
1 parent f196617 commit 8ca7a9c
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 34 deletions.
95 changes: 65 additions & 30 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,15 @@ type Client struct {

// Callback for establishing new connections to hosts.
//
// Default Dial is used if not set.
// Default DialTimeout is used if not set.
DialTimeout DialFuncWithTimeout

// Callback for establishing new connections to hosts.
//
// Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout.
// If you want the tcp dial process to account for request timeouts, use DialTimeout instead.
//
// If not set, DialTimeout is used.
Dial DialFunc

// Attempt to connect to both ipv4 and ipv6 addresses if set to true.
Expand Down Expand Up @@ -505,6 +513,7 @@ func (c *Client) Do(req *Request, resp *Response) error {
Name: c.Name,
NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader,
Dial: c.Dial,
DialTimeout: c.DialTimeout,
DialDualStack: c.DialDualStack,
IsTLS: isTLS,
TLSConfig: c.TLSConfig,
Expand Down Expand Up @@ -624,6 +633,21 @@ const DefaultMaxIdemponentCallAttempts = 5
// - foobar.com:8080
type DialFunc func(addr string) (net.Conn, error)

// DialFuncWithTimeout must establish connection to addr.
// Unlike DialFunc, it also accepts a timeout.
//
// There is no need in establishing TLS (SSL) connection for https.
// The client automatically converts connection to TLS
// if HostClient.IsTLS is set.
//
// TCP address passed to DialFuncWithTimeout always contains host and port.
// Example TCP addr values:
//
// - foobar.com:80
// - foobar.com:443
// - foobar.com:8080
type DialFuncWithTimeout func(addr string, timeout time.Duration) (net.Conn, error)

// RetryIfFunc signature of retry if function
//
// Request argument passed to RetryIfFunc, if there are any request errors.
Expand Down Expand Up @@ -656,7 +680,7 @@ type HostClient struct {
noCopy noCopy

// Comma-separated list of upstream HTTP server host addresses,
// which are passed to Dial in a round-robin manner.
// which are passed to Dial or DialTimeout in a round-robin manner.
//
// Each address may contain port if default dialer is used.
// For example,
Expand All @@ -673,16 +697,24 @@ type HostClient struct {
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool

// Callback for establishing new connection to the host.
// Callback for establishing new connections to hosts.
//
// Default Dial is used if not set.
// Default DialTimeout is used if not set.
DialTimeout DialFuncWithTimeout

// Callback for establishing new connections to hosts.
//
// Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout.
// If you want the tcp dial process to account for request timeouts, use DialTimeout instead.
//
// If not set, DialTimeout is used.
Dial DialFunc

// Attempt to connect to both ipv4 and ipv6 host addresses
// if set to true.
//
// This option is used only if default TCP dialer is used,
// i.e. if Dial is blank.
// i.e. if Dial and DialTimeout are blank.
//
// By default client connects only to ipv4 addresses,
// since unfortunately ipv6 remains broken in many networks worldwide :)
Expand Down Expand Up @@ -1827,7 +1859,8 @@ func (c *HostClient) nextAddr() string {
}

func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err error) {
// use dialTimeout to control the timeout of each dial. It does not work if dialTimeout is 0 or dial has been set.
// use dialTimeout to control the timeout of each dial. It does not work if dialTimeout is 0 or if
// c.DialTimeout has not been set and c.Dial has been set.
// attempt to dial all the available hosts before giving up.

c.addrsLock.Lock()
Expand All @@ -1839,16 +1872,6 @@ func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err
n = 1
}

dial := c.Dial
if dialTimeout != 0 && dial == nil {
dial = func(addr string) (net.Conn, error) {
if c.DialDualStack {
return DialDualStackTimeout(addr, dialTimeout)
}
return DialTimeout(addr, dialTimeout)
}
}

timeout := c.ReadTimeout + c.WriteTimeout
if timeout <= 0 {
timeout = DefaultDialTimeout
Expand All @@ -1857,7 +1880,7 @@ func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err
for n > 0 {
addr := c.nextAddr()
tlsConfig := c.cachedTLSConfig(addr)
conn, err = dialAddr(addr, dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
conn, err = dialAddr(addr, c.Dial, c.DialTimeout, c.DialDualStack, c.IsTLS, tlsConfig, dialTimeout, c.WriteTimeout)
if err == nil {
return conn, nil
}
Expand Down Expand Up @@ -1916,17 +1939,9 @@ func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, deadline time.T
return conn, nil
}

func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
deadline := time.Now().Add(timeout)
if dial == nil {
if dialDualStack {
dial = DialDualStack
} else {
dial = Dial
}
addr = AddMissingPort(addr, isTLS)
}
conn, err := dial(addr)
func dialAddr(addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, tlsConfig *tls.Config, dialTimeout, writeTimeout time.Duration) (net.Conn, error) {
deadline := time.Now().Add(writeTimeout)
conn, err := callDialFunc(addr, dial, dialWithTimeout, dialDualStack, isTLS, dialTimeout)
if err != nil {
return nil, err
}
Expand All @@ -1939,14 +1954,34 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *
_, isTLSAlready := conn.(interface{ Handshake() error })

if isTLS && !isTLSAlready {
if timeout == 0 {
if writeTimeout == 0 {
return tls.Client(conn, tlsConfig), nil
}
return tlsClientHandshake(conn, tlsConfig, deadline)
}
return conn, nil
}

func callDialFunc(addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, timeout time.Duration) (net.Conn, error) {
if dialWithTimeout != nil {
return dialWithTimeout(addr, timeout)
}
if dial != nil {
return dial(addr)
}
addr = AddMissingPort(addr, isTLS)
if timeout > 0 {
if dialDualStack {
return DialDualStackTimeout(addr, timeout)
}
return DialTimeout(addr, timeout)
}
if dialDualStack {
return DialDualStack(addr)
}
return Dial(addr)
}

// AddMissingPort adds a port to a host if it is missing.
// A literal IPv6 address in hostport must be enclosed in square
// brackets, as in "[::1]:80", "[::1%lo0]:80".
Expand Down Expand Up @@ -2591,7 +2626,7 @@ func (c *pipelineConnClient) init() {

func (c *pipelineConnClient) worker() error {
tlsConfig := c.cachedTLSConfig()
conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
conn, err := dialAddr(c.Addr, c.Dial, nil, c.DialDualStack, c.IsTLS, tlsConfig, 0, c.WriteTimeout)
if err != nil {
return err
}
Expand Down
92 changes: 92 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -3392,3 +3393,94 @@ func Test_getRedirectURL(t *testing.T) {
})
}
}

type clientDoTimeOuter interface {
DoTimeout(req *Request, resp *Response, timeout time.Duration) error
}

func TestDialTimeout(t *testing.T) {
t.Parallel()

tests := []struct {
name string
client clientDoTimeOuter
requestTimeout time.Duration
shouldFailFast bool
}{
{
name: "Client should fail after a millisecond due to request timeout",
client: &Client{
// should be ignored due to DialTimeout
Dial: func(addr string) (net.Conn, error) {
time.Sleep(time.Second)
return nil, errors.New("timeout")
},
// should be used
DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {
time.Sleep(timeout)
return nil, errors.New("timeout")
},
},
requestTimeout: time.Millisecond,
shouldFailFast: true,
},
{
name: "Client should fail after a second due to no DialTimeout set",
client: &Client{
Dial: func(addr string) (net.Conn, error) {
time.Sleep(time.Second)
return nil, errors.New("timeout")
},
},
requestTimeout: time.Millisecond,
shouldFailFast: false,
},
{
name: "HostClient should fail after a millisecond due to request timeout",
client: &HostClient{
// should be ignored due to DialTimeout
Dial: func(addr string) (net.Conn, error) {
time.Sleep(time.Second)
return nil, errors.New("timeout")
},
// should be used
DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {
time.Sleep(timeout)
return nil, errors.New("timeout")
},
},
requestTimeout: time.Millisecond,
shouldFailFast: true,
},
{
name: "HostClient should fail after a second due to no DialTimeout set",
client: &HostClient{
Dial: func(addr string) (net.Conn, error) {
time.Sleep(time.Second)
return nil, errors.New("timeout")
},
},
requestTimeout: time.Millisecond,
shouldFailFast: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
start := time.Now()
err := tt.client.DoTimeout(&Request{}, &Response{}, tt.requestTimeout)
if err == nil {
t.Fatal("expected error (timeout)")
}
if tt.shouldFailFast {
if time.Since(start) > time.Second {
t.Fatal("expected timeout after a millisecond")
}
} else {
if time.Since(start) < time.Second {
t.Fatal("expected timeout after a second")
}
}
})
}
}
8 changes: 4 additions & 4 deletions tcpdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func Dial(addr string) (net.Conn, error) {
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
Expand Down Expand Up @@ -102,7 +102,7 @@ func DialDualStack(addr string) (net.Conn, error) {
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
Expand Down Expand Up @@ -199,7 +199,7 @@ func (d *TCPDialer) Dial(addr string) (net.Conn, error) {
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
Expand Down Expand Up @@ -253,7 +253,7 @@ func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) {
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
Expand Down

0 comments on commit 8ca7a9c

Please sign in to comment.