Skip to content

Commit

Permalink
uds: implement a connect timeout option
Browse files Browse the repository at this point in the history
  • Loading branch information
iksaif committed Jan 17, 2024
1 parent 2f549e6 commit e33f9c9
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 61 deletions.
13 changes: 13 additions & 0 deletions statsd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ var (
defaultWorkerCount = 32
defaultSenderQueueSize = 0
defaultWriteTimeout = 100 * time.Millisecond
defaultConnectTimeout = 1000 * time.Millisecond
defaultTelemetry = true
defaultReceivingMode = mutexMode
defaultChannelModeBufferSize = 4096
Expand All @@ -40,6 +41,7 @@ type Options struct {
workersCount int
senderQueueSize int
writeTimeout time.Duration
connectTimeout time.Duration
telemetry bool
receiveMode receivingMode
channelModeBufferSize int
Expand All @@ -65,6 +67,7 @@ func resolveOptions(options []Option) (*Options, error) {
workersCount: defaultWorkerCount,
senderQueueSize: defaultSenderQueueSize,
writeTimeout: defaultWriteTimeout,
connectTimeout: defaultConnectTimeout,
telemetry: defaultTelemetry,
receiveMode: defaultReceivingMode,
channelModeBufferSize: defaultChannelModeBufferSize,
Expand Down Expand Up @@ -206,6 +209,16 @@ func WithWriteTimeout(writeTimeout time.Duration) Option {
}
}

// WithConnectTimeout sets the timeout for network connection with the Agent, after this interval the connection
// attempt is aborted. This is only used for UDS connection. This will also reset the connection if nothing can be
// written to it for this duration.
func WithConnectTimeout(connectTimeout time.Duration) Option {
return func(o *Options) error {
o.connectTimeout = connectTimeout
return nil
}
}

// WithChannelMode make the client use channels to receive metrics
//
// This determines how the client receive metrics from the app (for example when calling the `Gauge()` method).
Expand Down
12 changes: 6 additions & 6 deletions statsd/statsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ func parseAgentURL(agentURL string) string {
return ""
}

func createWriter(addr string, writeTimeout time.Duration) (Transport, string, error) {
func createWriter(addr string, writeTimeout time.Duration, connectTimeout time.Duration) (Transport, string, error) {
addr = resolveAddr(addr)
if addr == "" {
return nil, "", errors.New("No address passed and autodetection from environment failed")
Expand All @@ -379,13 +379,13 @@ func createWriter(addr string, writeTimeout time.Duration) (Transport, string, e
w, err := newWindowsPipeWriter(addr, writeTimeout)
return w, writerWindowsPipe, err
case strings.HasPrefix(addr, UnixAddressPrefix):
w, err := newUDSWriter(addr[len(UnixAddressPrefix):], writeTimeout, "")
w, err := newUDSWriter(addr[len(UnixAddressPrefix):], writeTimeout, connectTimeout, "")
return w, writerNameUDS, err
case strings.HasPrefix(addr, UnixAddressDatagramPrefix):
w, err := newUDSWriter(addr[len(UnixAddressDatagramPrefix):], writeTimeout, "unixgram")
w, err := newUDSWriter(addr[len(UnixAddressDatagramPrefix):], writeTimeout, connectTimeout, "unixgram")
return w, writerNameUDS, err
case strings.HasPrefix(addr, UnixAddressStreamPrefix):
w, err := newUDSWriter(addr[len(UnixAddressStreamPrefix):], writeTimeout, "unix")
w, err := newUDSWriter(addr[len(UnixAddressStreamPrefix):], writeTimeout, connectTimeout, "unix")
return w, writerNameUDS, err
default:
w, err := newUDPWriter(addr, writeTimeout)
Expand All @@ -401,7 +401,7 @@ func New(addr string, options ...Option) (*Client, error) {
return nil, err
}

w, writerType, err := createWriter(addr, o.writeTimeout)
w, writerType, err := createWriter(addr, o.writeTimeout, o.connectTimeout)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -542,7 +542,7 @@ func newWithWriter(w Transport, o *Options, writerName string) (*Client, error)
c.telemetryClient = newTelemetryClient(&c, c.agg != nil)
} else {
var err error
c.telemetryClient, err = newTelemetryClientWithCustomAddr(&c, o.telemetryAddr, c.agg != nil, bufferPool, o.writeTimeout)
c.telemetryClient, err = newTelemetryClientWithCustomAddr(&c, o.telemetryAddr, c.agg != nil, bufferPool, o.writeTimeout, o.connectTimeout)
if err != nil {
return nil, err
}
Expand Down
6 changes: 4 additions & 2 deletions statsd/telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,10 @@ func newTelemetryClient(c *Client, aggregationEnabled bool) *telemetryClient {
return t
}

func newTelemetryClientWithCustomAddr(c *Client, telemetryAddr string, aggregationEnabled bool, pool *bufferPool, writeTimeout time.Duration) (*telemetryClient, error) {
telemetryWriter, _, err := createWriter(telemetryAddr, writeTimeout)
func newTelemetryClientWithCustomAddr(c *Client, telemetryAddr string, aggregationEnabled bool, pool *bufferPool,
writeTimeout time.Duration, connectTimeout time.Duration,
) (*telemetryClient, error) {
telemetryWriter, _, err := createWriter(telemetryAddr, writeTimeout, connectTimeout)
if err != nil {
return nil, fmt.Errorf("Could not resolve telemetry address: %v", err)
}
Expand Down
71 changes: 26 additions & 45 deletions statsd/uds.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ type udsWriter struct {
conn net.Conn
// write timeout
writeTimeout time.Duration
sync.RWMutex // used to lock conn / writer can replace it
// connect timeout
connectTimeout time.Duration
sync.RWMutex // used to lock conn / writer can replace it
}

// newUDSWriter returns a pointer to a new udsWriter given a socket file path as addr.
func newUDSWriter(addr string, writeTimeout time.Duration, transport string) (*udsWriter, error) {
func newUDSWriter(addr string, writeTimeout time.Duration, connectTimeout time.Duration, transport string) (*udsWriter, error) {
// Defer connection to first Write
writer := &udsWriter{addr: addr, transport: transport, conn: nil, writeTimeout: writeTimeout}
writer := &udsWriter{addr: addr, transport: transport, conn: nil, writeTimeout: writeTimeout, connectTimeout: connectTimeout}
return writer, nil
}

Expand All @@ -43,56 +45,23 @@ func (w *udsWriter) GetTransportName() string {
}
}

// retryOnWriteErr returns true if we should retry writing after a write error
func (w *udsWriter) retryOnWriteErr(err error, stream bool) bool {
// Never retry when using unixgram (to preserve the historical behavior)
if !stream {
return false
}
// Otherwise we retry on timeout because we might have written a partial packet
if networkError, ok := err.(net.Error); ok && networkError.Timeout() {
func (w *udsWriter) shouldCloseConnection(err error, partialWrite bool) bool {
if err != nil && partialWrite {
// We can't recover from a partial write
return true
}
return false
}

func (w *udsWriter) shouldCloseConnection(err error) bool {
if err, isNetworkErr := err.(net.Error); err != nil && (!isNetworkErr || !err.Timeout()) {
// Statsd server disconnected, retry connecting at next packet
return true
}
return false
}

// writeFull writes the whole data to the UDS connection
func (w *udsWriter) writeFull(data []byte, stopIfNoneWritten bool, stream bool) (int, error) {
written := 0
for written < len(data) {
n, e := w.conn.Write(data[written:])
written += n

// If we haven't written anything, and we're supposed to stop if we can't write anything, return the error
if written == 0 && stopIfNoneWritten {
return written, e
}

// If there's an error, check if it is retryable
if e != nil && !w.retryOnWriteErr(e, stream) {
return written, e
}

// When using "unix" we need to be able to finish to write partially written packets once we have started.
if stream {
w.conn.SetWriteDeadline(time.Time{})
}
}
return written, nil
}

// Write data to the UDS connection with write timeout and minimal error handling:
// create the connection if nil, and destroy it if the statsd server has disconnected
func (w *udsWriter) Write(data []byte) (int, error) {
var n int
partialWrite := false
conn, err := w.ensureConnection()
if err != nil {
return 0, err
Expand All @@ -107,15 +76,26 @@ func (w *udsWriter) Write(data []byte) (int, error) {
if stream {
bs := []byte{0, 0, 0, 0}
binary.LittleEndian.PutUint32(bs, uint32(len(data)))
_, err = w.writeFull(bs, true, true)
_, err = w.conn.Write(bs)

partialWrite = true

// W need to be able to finish to write partially written packets once we have started.
// But we will reset the connection if we can't write anything at all for a long time.
w.conn.SetWriteDeadline(time.Now().Add(w.connectTimeout))

// Continue writing only if we've written the length of the packet
if err == nil {
n, err = w.writeFull(data, false, true)
n, err = w.conn.Write(data)
if err == nil {
partialWrite = false
}
}
} else {
n, err = w.writeFull(data, true, false)
n, err = w.conn.Write(data)
}

if w.shouldCloseConnection(err) {
if w.shouldCloseConnection(err, partialWrite) {
w.unsetConnection()
}
return n, err
Expand All @@ -133,7 +113,7 @@ func (w *udsWriter) tryToDial(network string) (net.Conn, error) {
if err != nil {
return nil, err
}
newConn, err := net.Dial(udsAddr.Network(), udsAddr.String())
newConn, err := net.DialTimeout(udsAddr.Network(), udsAddr.String(), w.connectTimeout)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -182,5 +162,6 @@ func (w *udsWriter) ensureConnection() (net.Conn, error) {
func (w *udsWriter) unsetConnection() {
w.Lock()
defer w.Unlock()
_ = w.conn.Close()
w.conn = nil
}
50 changes: 42 additions & 8 deletions statsd/uds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ package statsd

import (
"encoding/binary"
"golang.org/x/net/nettest"
"math/rand"
"net"
"os"
"testing"
"time"

"golang.org/x/net/nettest"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand All @@ -21,13 +22,13 @@ func init() {
}

func TestNewUDSWriter(t *testing.T) {
w, err := newUDSWriter("/tmp/test.socket", 100*time.Millisecond, "")
w, err := newUDSWriter("/tmp/test.socket", 100*time.Millisecond, 1000*time.Millisecond, "")
assert.NotNil(t, w)
assert.NoError(t, err)
w, err = newUDSWriter("/tmp/test.socket", 100*time.Millisecond, "unix")
w, err = newUDSWriter("/tmp/test.socket", 100*time.Millisecond, 1000*time.Millisecond, "unix")
assert.NotNil(t, w)
assert.NoError(t, err)
w, err = newUDSWriter("/tmp/test.socket", 100*time.Millisecond, "unixgram")
w, err = newUDSWriter("/tmp/test.socket", 100*time.Millisecond, 1000*time.Millisecond, "unixgram")
assert.NotNil(t, w)
assert.NoError(t, err)
}
Expand All @@ -44,7 +45,7 @@ func TestUDSDatagramWrite(t *testing.T) {
err = os.Chmod(socketPath, 0722)
require.NoError(t, err)

w, err := newUDSWriter(socketPath, 100*time.Millisecond, "")
w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "")
require.Nil(t, err)
require.NotNil(t, w)

Expand Down Expand Up @@ -74,7 +75,7 @@ func TestUDSDatagramWriteUnsetConnection(t *testing.T) {
err = os.Chmod(socketPath, 0722)
require.NoError(t, err)

w, err := newUDSWriter(socketPath, 100*time.Millisecond, "")
w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "")
require.Nil(t, err)
require.NotNil(t, w)

Expand Down Expand Up @@ -107,7 +108,7 @@ func TestUDSStreamWrite(t *testing.T) {
err = os.Chmod(socketPath, 0722)
require.NoError(t, err)

w, err := newUDSWriter(socketPath, 100*time.Millisecond, "")
w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "")
require.Nil(t, err)
require.NotNil(t, w)

Expand All @@ -120,6 +121,7 @@ func TestUDSStreamWrite(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, len(msg), n)

// This works because the kernel accepts sockets before the accept call
if conn == nil {
conn, err = listener.Accept()
require.NoError(t, err)
Expand Down Expand Up @@ -148,7 +150,7 @@ func TestUDSStreamWriteUnsetConnection(t *testing.T) {
err = os.Chmod(socketPath, 0722)
require.NoError(t, err)

w, err := newUDSWriter(socketPath, 100*time.Millisecond, "")
w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "")
require.Nil(t, err)
require.NotNil(t, w)

Expand All @@ -161,6 +163,7 @@ func TestUDSStreamWriteUnsetConnection(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, len(msg), n)

// This works because the kernel accepts sockets before the accept call
if conn == nil {
conn, err = listener.Accept()
require.NoError(t, err)
Expand All @@ -180,3 +183,34 @@ func TestUDSStreamWriteUnsetConnection(t *testing.T) {
conn = nil
}
}

func TestUDSStreamPartialWrite(t *testing.T) {
socketPath, err := nettest.LocalPath()
require.NoError(t, err)
defer os.Remove(socketPath)

address, err := net.ResolveUnixAddr("unix", socketPath)
require.NoError(t, err)
listener, err := net.ListenUnix("unix", address)
defer listener.Close()
require.NoError(t, err)
err = os.Chmod(socketPath, 0722)
require.NoError(t, err)

w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "")
require.Nil(t, err)
require.NotNil(t, w)

// Force a connection
w.ensureConnection()
// Set a very low buffer size to force a partial write, but still enough to write the header
w.conn.(*net.UnixConn).SetWriteBuffer(8)

msg := []byte("some data")
n, err := w.Write(msg)
require.Error(t, err)
assert.Lessf(t, n, len(msg), "n: %d, len(msg): %d", n, len(msg))

// The connection should be dropped
assert.Nil(t, w.conn)
}

0 comments on commit e33f9c9

Please sign in to comment.