Skip to content

Commit

Permalink
fix(nats): drain connections upon error
Browse files Browse the repository at this point in the history
When failing to connect, drain connection rather than closing directly.
This is safer than calling close and will wait for connections before
closing, gracefully handling ongoing reconnects and avoid leaking
goroutines in test
  • Loading branch information
hspedro committed Feb 7, 2025
1 parent 289b204 commit 1c2f000
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 161 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ init-submodules:
@git submodule init

setup-ci:
@go get github.com/mattn/goveralls
@go install github.com/mattn/goveralls
@go get -u github.com/wadey/gocovmerge

setup-protobuf-macos:
Expand Down
1 change: 1 addition & 0 deletions cluster/nats_rpc_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ func TestNatsRPCClientCall(t *testing.T) {
t.Run(table.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
defer conn.Close()
assert.NoError(t, err)

sv2 := getServer()
Expand Down
48 changes: 45 additions & 3 deletions cluster/nats_rpc_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,46 @@ func getChannel(serverType, serverID string) string {
return fmt.Sprintf("pitaya/servers/%s/%s", serverType, serverID)
}

func drainAndClose(nc *nats.Conn) error {
if nc == nil {
return nil
}
// Drain connection (this will flush any pending messages and prevent new ones)
err := nc.Drain()
if err != nil {
logger.Log.Warnf("error draining nats connection: %v", err)
// Even if drain fails, try to close
nc.Close()
return err
}

// Wait for drain to complete with timeout
timeout := time.After(5 * time.Second)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()

for nc.IsDraining() {
select {
case <-ticker.C:
continue
case <-timeout:
logger.Log.Warn("drain timeout exceeded, forcing close")
nc.Close()
return fmt.Errorf("drain timeout exceeded")
}
}

// Close will happen automatically after drain completes
return nil
}

func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.Option) (*nats.Conn, error) {
connectedCh := make(chan bool)
initialConnectErrorCh := make(chan error)
natsOptions := append(
options,
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
logger.Log.Warnf("disconnected from nats! Reason: %q\n", err)
nats.DisconnectErrHandler(func(nc *nats.Conn, err error) {
logger.Log.Warnf("disconnected from nats (%s)! Reason: %q\n", nc.ConnectedAddr(), err)
}),
nats.ReconnectHandler(func(nc *nats.Conn) {
logger.Log.Warnf("reconnected to nats server %s with address %s in cluster %s!", nc.ConnectedServerName(), nc.ConnectedAddr(), nc.ConnectedClusterName())
Expand Down Expand Up @@ -78,7 +111,8 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O
logger.Log.Errorf(err.Error())
}
}),
nats.ConnectHandler(func(*nats.Conn) {
nats.ConnectHandler(func(nc *nats.Conn) {
logger.Log.Infof("connected to nats on %s", nc.ConnectedAddr())
connectedCh <- true
}),
)
Expand All @@ -104,8 +138,16 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O
case <-connectedCh:
return nc, nil
case err := <-initialConnectErrorCh:
drainErr := drainAndClose(nc)
if drainErr != nil {
logger.Log.Warnf("failed to drain and close: %s", drainErr)
}
return nil, err
case <-time.After(maxConnTimeout * 2):
drainErr := drainAndClose(nc)
if drainErr != nil {
logger.Log.Warnf("failed to drain and close: %s", drainErr)
}
return nil, fmt.Errorf("timeout setting up nats connection")
}
}
230 changes: 100 additions & 130 deletions cluster/nats_rpc_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"testing"
"time"

"github.com/nats-io/nats-server/v2/test"
nats "github.com/nats-io/nats.go"
"github.com/stretchr/testify/assert"
"github.com/topfreegames/pitaya/v2/helpers"
Expand All @@ -47,167 +46,138 @@ func TestNatsRPCCommonGetChannel(t *testing.T) {

func TestNatsRPCCommonSetupNatsConn(t *testing.T) {
t.Parallel()
var conn *nats.Conn
s := helpers.GetTestNatsServer(t)
defer s.Shutdown()
defer func() {
drainAndClose(conn)
s.Shutdown()
s.WaitForShutdown()
}()
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
assert.NoError(t, err)
assert.NotNil(t, conn)
}

func TestNatsRPCCommonSetupNatsConnShouldError(t *testing.T) {
t.Parallel()
conn, err := setupNatsConn("nats://localhost:1234", nil)
conn, err := setupNatsConn("nats://invalid:1234", nil)
assert.Error(t, err)
assert.Nil(t, conn)
}

func TestNatsRPCCommonCloseHandler(t *testing.T) {
t.Parallel()
var conn *nats.Conn
s := helpers.GetTestNatsServer(t)
defer func() {
drainAndClose(conn)
s.Shutdown()
s.WaitForShutdown()
}()

dieChan := make(chan bool)

go func() {
value, ok := <-dieChan
assert.True(t, ok)
assert.True(t, value)
}()

conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), dieChan, nats.MaxReconnects(1),
nats.ReconnectWait(1*time.Millisecond))
assert.NoError(t, err)
assert.NotNil(t, conn)
}

s.Shutdown()

value, ok := <-dieChan
assert.True(t, ok)
assert.True(t, value)
func TestNatsRPCCommonWaitReconnections(t *testing.T) {
var conn *nats.Conn
ts := helpers.GetTestNatsServer(t)
defer func() {
drainAndClose(conn)
ts.Shutdown()
ts.WaitForShutdown()
}()

invalidAddr := "nats://invalid:4222"
validAddr := ts.ClientURL()

urls := fmt.Sprintf("%s,%s", invalidAddr, validAddr)

// Setup connection with retry enabled
appDieCh := make(chan bool)
conn, err := setupNatsConn(
urls,
appDieCh,
nats.ReconnectWait(10*time.Millisecond),
nats.MaxReconnects(5),
nats.RetryOnFailedConnect(true),
)
assert.NoError(t, err)
assert.NotNil(t, conn)
assert.True(t, conn.IsConnected())
}

func TestSetupNatsConnReconnection(t *testing.T) {
t.Run("waits for reconnection on initial failure", func(t *testing.T) {
// Use an invalid address first to force initial connection failure
invalidAddr := "nats://invalid:4222"
validAddr := "nats://localhost:4222"
func TestNatsRPCCommonDoNotBlockOnConnectionFail(t *testing.T) {
invalidAddr := "nats://invalid:4222"

urls := fmt.Sprintf("%s,%s", invalidAddr, validAddr)
appDieCh := make(chan bool)
done := make(chan any)

go func() {
time.Sleep(50 * time.Millisecond)
ts := test.RunDefaultServer()
defer ts.Shutdown()
<-time.After(200 * time.Millisecond)
}()
var conn *nats.Conn
ts := helpers.GetTestNatsServer(t)
defer func() {
drainAndClose(conn)
ts.Shutdown()
ts.WaitForShutdown()
}()

// Setup connection with retry enabled
appDieCh := make(chan bool)
go func() {
conn, err := setupNatsConn(
urls,
invalidAddr,
appDieCh,
nats.ReconnectWait(10*time.Millisecond),
nats.MaxReconnects(5),
nats.MaxReconnects(2),
nats.RetryOnFailedConnect(true),
)
assert.Error(t, err)
assert.Nil(t, conn)
close(done)
close(appDieCh)
}()

select {
case <-appDieCh:
case <-done:
case <-time.After(250 * time.Millisecond):
t.Fail()
}
}

assert.NoError(t, err)
assert.NotNil(t, conn)
assert.True(t, conn.IsConnected())

conn.Close()
})

t.Run("does not block indefinitely if all connect attempts fail", func(t *testing.T) {
invalidAddr := "nats://invalid:4222"

appDieCh := make(chan bool)
done := make(chan any)

ts := test.RunDefaultServer()
defer ts.Shutdown()

go func() {
conn, err := setupNatsConn(
invalidAddr,
appDieCh,
nats.ReconnectWait(10*time.Millisecond),
nats.MaxReconnects(2),
nats.RetryOnFailedConnect(true),
)
assert.Error(t, err)
assert.Nil(t, conn)
close(done)
close(appDieCh)
}()

select {
case <-appDieCh:
case <-done:
case <-time.After(250 * time.Millisecond):
t.Fail()
}
})

t.Run("if it fails to connect, exit with error even if appDieChan is not ready to listen", func(t *testing.T) {
invalidAddr := "nats://invalid:4222"

appDieCh := make(chan bool)
done := make(chan any)

ts := test.RunDefaultServer()
defer ts.Shutdown()

go func() {
conn, err := setupNatsConn(invalidAddr, appDieCh)
assert.Error(t, err)
assert.Nil(t, conn)
close(done)
close(appDieCh)
}()

select {
case <-done:
case <-time.After(50 * time.Millisecond):
t.Fail()
}
})

t.Run("if connection takes too long, exit with error after waiting maxReconnTimeout", func(t *testing.T) {
invalidAddr := "nats://invalid:4222"

appDieCh := make(chan bool)
done := make(chan any)

initialConnectionTimeout := time.Nanosecond
maxReconnectionAtetmpts := 1
reconnectWait := time.Nanosecond
reconnectJitter := time.Nanosecond
maxReconnectionTimeout := reconnectWait + reconnectJitter + initialConnectionTimeout
maxReconnTimeout := initialConnectionTimeout + (time.Duration(maxReconnectionAtetmpts) * maxReconnectionTimeout)

maxTestTimeout := 100 * time.Millisecond

// Assert that if it fails because of connection timeout the test will capture
assert.Greater(t, maxTestTimeout, maxReconnTimeout)

ts := test.RunDefaultServer()
defer ts.Shutdown()

go func() {
conn, err := setupNatsConn(
invalidAddr,
appDieCh,
nats.Timeout(initialConnectionTimeout),
nats.ReconnectWait(reconnectWait),
nats.MaxReconnects(maxReconnectionAtetmpts),
nats.ReconnectJitter(reconnectJitter, reconnectJitter),
nats.RetryOnFailedConnect(true),
)
assert.Error(t, err)
assert.ErrorContains(t, err, "timeout setting up nats connection")
assert.Nil(t, conn)
close(done)
close(appDieCh)
}()

select {
case <-done:
case <-time.After(maxTestTimeout):
t.Fail()
}
})
func TestNatsRPCCommonFailWithoutAppDieChan(t *testing.T) {
invalidAddr := "nats://invalid:4222"

appDieCh := make(chan bool)
done := make(chan any)

var conn *nats.Conn
ts := helpers.GetTestNatsServer(t)
defer func() {
drainAndClose(conn)
ts.Shutdown()
ts.WaitForShutdown()
}()

go func() {
conn, err := setupNatsConn(invalidAddr, appDieCh)
assert.Error(t, err)
assert.Nil(t, conn)
close(done)
close(appDieCh)
}()

select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Fail()
}
}
Loading

0 comments on commit 1c2f000

Please sign in to comment.