Skip to content

Commit

Permalink
Wake up connection borrowers when pool is closed (#489)
Browse files Browse the repository at this point in the history
This speeds up the pool responsiveness quite a bit.
Borrowers would only be woken up after the connection acquisition
timeout otherwise (if any).

This problem was noticed while trying, unsuccessfully, to
reproduce #451

Co-authored-by: Rouven Bauer <rouven.bauer@neo4j.com>
  • Loading branch information
fbiville and robsdedude authored May 24, 2023
1 parent 786866e commit 29a9396
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 73 deletions.
14 changes: 8 additions & 6 deletions neo4j/internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,14 @@ func New(config *config.Config, connect Connect, logger log.Logger, logId string

func (p *Pool) Close(ctx context.Context) error {
p.closed = true
// Cancel everything in the queue by just emptying it and let all callers time out
if !p.queueMut.TryLock(ctx) {
return racing.LockTimeoutError("could not acquire queue lock in time when closing pool")
}
p.queue.Init()
// Wake up every queued borrower. list.Remove clears the removed element's
// links, so the successor must be captured before removing e; advancing
// with e.Next() after the removal would yield nil and end the loop after
// the first waiter only.
for e := p.queue.Front(); e != nil; {
	next := e.Next()
	queuedRequest := e.Value.(*qitem)
	p.queue.Remove(e)
	queuedRequest.wakeup <- true
	e = next
}
p.queueMut.Unlock()
// Go through each server and close all connections to it
if !p.serversMut.TryLock(ctx) {
Expand Down Expand Up @@ -213,11 +216,10 @@ serverLoop:
}

func (p *Pool) Borrow(ctx context.Context, getServerNames func(context.Context) ([]string, error), wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) {
if p.closed {
return nil, &errorutil.PoolClosed{}
}

for {
if p.closed {
return nil, &errorutil.PoolClosed{}
}
serverNames, err := getServerNames(ctx)
if err != nil {
return nil, err
Expand Down
144 changes: 77 additions & 67 deletions neo4j/internal/pool/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
"testing"
"time"

"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil"
. "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/log"
)

Expand All @@ -45,7 +45,7 @@ func TestPoolBorrowReturn(outer *testing.T) {
birthdate := time.Now()

succeedingConnect := func(_ context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) {
return &testutil.ConnFake{Name: s, Alive: true, Birth: birthdate}, nil
return &ConnFake{Name: s, Alive: true, Birth: birthdate}, nil
}

failingError := errors.New("whatever")
Expand Down Expand Up @@ -105,14 +105,7 @@ func TestPoolBorrowReturn(outer *testing.T) {
wg.Done()
}()

// Wait until entered queue
for {
if size, err := p.queueSize(ctx); err != nil {
t.Errorf("should not fail computing queue size, got: %v", err)
} else if size > 0 {
break
}
}
waitForBorrowers(t, p, 1)

// Give back the connection
if err := p.Return(ctx, c1); err != nil {
Expand Down Expand Up @@ -209,14 +202,7 @@ func TestPoolBorrowReturn(outer *testing.T) {
wg.Done()
}()

// Wait until entered queue
for {
if size, err := p.queueSize(cancelableCtx); err != nil {
t.Errorf("should not fail computing queue size, got: %v", err)
} else if size > 0 {
break
}
}
waitForBorrowers(t, p, 1)
cancel()
wg.Wait()
if err := p.Return(ctx, c1); err != nil {
Expand All @@ -233,8 +219,8 @@ func TestPoolBorrowReturn(outer *testing.T) {
idlenessThreshold := 1 * time.Hour
idleness := time.Now().Add(-2 * idlenessThreshold)
deadAfterReset := deadConnectionAfterForceReset("deadAfterReset", idleness)
stayingAlive := &testutil.ConnFake{Alive: true, Idle: idleness, Name: "stayingAlive", ForceResetHook: func() {}}
whatATimeToBeAlive := &testutil.ConnFake{Alive: true, Idle: idleness, Name: "whatATimeToBeAlive", ForceResetHook: func() {
stayingAlive := &ConnFake{Alive: true, Idle: idleness, Name: "stayingAlive", ForceResetHook: func() {}}
whatATimeToBeAlive := &ConnFake{Alive: true, Idle: idleness, Name: "whatATimeToBeAlive", ForceResetHook: func() {
t.Errorf("y u call me?")
}}
timer := time.Now
Expand All @@ -248,8 +234,8 @@ func TestPoolBorrowReturn(outer *testing.T) {

result, err := pool.tryBorrow(ctx, "a server", nil, idlenessThreshold, reAuthToken)

testutil.AssertNil(t, err)
testutil.AssertDeepEquals(t, result, stayingAlive)
AssertNil(t, err)
AssertDeepEquals(t, result, stayingAlive)
})

outer.Run("Borrows new connection if resets of all long-idle connections fail", func(t *testing.T) {
Expand All @@ -258,7 +244,7 @@ func TestPoolBorrowReturn(outer *testing.T) {
idleness := time.Now().Add(-2 * idlenessThreshold)
deadAfterReset1 := deadConnectionAfterForceReset("deadAfterReset1", idleness)
deadAfterReset2 := deadConnectionAfterForceReset("deadAfterReset2", idleness)
healthyConnection := &testutil.ConnFake{Name: "healthy", ForceResetHook: func() {
healthyConnection := &ConnFake{Name: "healthy", ForceResetHook: func() {
t.Errorf("force reset should not be called on new connections")
}}
timer := time.Now
Expand All @@ -268,10 +254,10 @@ func TestPoolBorrowReturn(outer *testing.T) {

result, err := pool.tryBorrow(ctx, serverName, nil, idlenessThreshold, reAuthToken)

testutil.AssertNil(t, err)
testutil.AssertDeepEquals(t, result, healthyConnection)
testutil.AssertIntEqual(t, pool.servers[serverName].numIdle(), 0)
testutil.AssertIntEqual(t, pool.servers[serverName].numBusy(), 1)
AssertNil(t, err)
AssertDeepEquals(t, result, healthyConnection)
AssertIntEqual(t, pool.servers[serverName].numIdle(), 0)
AssertIntEqual(t, pool.servers[serverName].numBusy(), 1)
})

outer.Run("Waiting borrow does not receive returned broken connection", func(t *testing.T) {
Expand All @@ -286,29 +272,22 @@ func TestPoolBorrowReturn(outer *testing.T) {
go func() {
c2, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken)
assertConnection(t, c2, err)
testutil.AssertNotDeepEquals(t, c1, c2)
AssertNotDeepEquals(t, c1, c2)
wg.Done()
}()

// Wait until entered queue
for {
if size, err := p.queueSize(ctx); err != nil {
t.Errorf("should not fail computing queue size, got: %v", err)
} else if size > 0 {
break
}
}
waitForBorrowers(t, p, 1)
// break the connection. then it shouldn't be picked up by the waiting borrow
c1.(*testutil.ConnFake).Alive = false
c1.(*ConnFake).Alive = false
err = p.Return(ctx, c1)
testutil.AssertNoError(t, err)
AssertNoError(t, err)
wg.Wait()
})

outer.Run("Waiting borrow does re-auth", func(t *testing.T) {
token2 := iauth.Token{Tokens: map[string]any{"scheme": "foobar"}}
// sanity check
testutil.AssertNotDeepEquals(t, reAuthToken.Manager, token2)
AssertNotDeepEquals(t, reAuthToken.Manager, token2)
reAuthToken2 := &db.ReAuthToken{FromSession: false, Manager: token2}
timer := func() time.Time { return birthdate }
conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1}
Expand All @@ -321,28 +300,21 @@ func TestPoolBorrowReturn(outer *testing.T) {
go func() {
c2, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken2)
assertConnection(t, c2, err)
testutil.AssertDeepEquals(t, c1, c2)
AssertDeepEquals(t, c1, c2)
wg.Done()
}()

// Wait until entered queue
for {
if size, err := p.queueSize(ctx); err != nil {
t.Errorf("should not fail computing queue size, got: %v", err)
} else if size > 0 {
break
}
}
waitForBorrowers(t, p, 1)
reAuthCalled := false
c1.(*testutil.ConnFake).ReAuthHook = func(_ context.Context, token *db.ReAuthToken) error {
testutil.AssertDeepEquals(t, token.Manager, token2)
c1.(*ConnFake).ReAuthHook = func(_ context.Context, token *db.ReAuthToken) error {
AssertDeepEquals(t, token.Manager, token2)
reAuthCalled = true
return nil
}
err = p.Return(ctx, c1)
testutil.AssertNoError(t, err)
AssertNoError(t, err)
wg.Wait()
testutil.AssertTrue(t, reAuthCalled)
AssertTrue(t, reAuthCalled)
})
}

Expand All @@ -352,7 +324,7 @@ func TestPoolResourceUsage(ot *testing.T) {
birthdate := time.Now()

succeedingConnect := func(_ context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) {
return &testutil.ConnFake{Name: s, Alive: true, Birth: birthdate}, nil
return &ConnFake{Name: s, Alive: true, Birth: birthdate}, nil
}

ot.Run("Use order of named servers as priority when creating new servers", func(t *testing.T) {
Expand Down Expand Up @@ -382,7 +354,7 @@ func TestPoolResourceUsage(ot *testing.T) {
}()
serverNames := []string{"srvA"}
c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken)
c.(*testutil.ConnFake).Alive = false
c.(*ConnFake).Alive = false
if err := p.Return(ctx, c); err != nil {
t.Errorf("Should not fail returning connection to pool, but got: %v", err)
}
Expand Down Expand Up @@ -428,12 +400,12 @@ func TestPoolResourceUsage(ot *testing.T) {
c3, _ := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken)
// Manipulate birthdate on the connections
nowTime := timer()
c1.(*testutil.ConnFake).Birth = nowTime.Add(-1 * time.Second)
c1.(*testutil.ConnFake).Id = 1
c2.(*testutil.ConnFake).Birth = nowTime
c2.(*testutil.ConnFake).Id = 2
c3.(*testutil.ConnFake).Birth = nowTime.Add(1 * time.Second)
c3.(*testutil.ConnFake).Id = 3
c1.(*ConnFake).Birth = nowTime.Add(-1 * time.Second)
c1.(*ConnFake).Id = 1
c2.(*ConnFake).Birth = nowTime
c2.(*ConnFake).Id = 2
c3.(*ConnFake).Birth = nowTime.Add(1 * time.Second)
c3.(*ConnFake).Id = 3
// Return the old and young connections to make them idle
if err := p.Return(ctx, c1); err != nil {
t.Errorf("Should not fail returning connection to pool, but got: %v", err)
Expand All @@ -444,7 +416,7 @@ func TestPoolResourceUsage(ot *testing.T) {
assertNumberOfServers(t, ctx, p, 1)
assertNumberOfIdle(t, ctx, p, "A", 2)
// Kill the middle-aged connection and return it
c2.(*testutil.ConnFake).Alive = false
c2.(*ConnFake).Alive = false
if err := p.Return(ctx, c2); err != nil {
t.Errorf("Should not fail returning connection to pool, but got: %v", err)
}
Expand All @@ -468,7 +440,7 @@ func TestPoolResourceUsage(ot *testing.T) {
}()
serverNames := []string{"srvA"}
c1, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken)
c1.(*testutil.ConnFake).Id = 123
c1.(*ConnFake).Id = 123
// It's alive when returning it
if err := p.Return(ctx, c1); err != nil {
t.Errorf("Should not fail returning connection to pool, but got: %v", err)
Expand All @@ -478,7 +450,7 @@ func TestPoolResourceUsage(ot *testing.T) {
nowMut.Unlock()
// Shouldn't get the same one back!
c2, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken)
if c2.(*testutil.ConnFake).Id == 123 {
if c2.(*ConnFake).Id == 123 {
t.Errorf("Got the old connection back!")
}
})
Expand All @@ -504,7 +476,7 @@ func TestPoolCleanup(ot *testing.T) {
birthdate := time.Now()
maxLife := 1 * time.Second
succeedingConnect := func(_ context.Context, s string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) {
return &testutil.ConnFake{Name: s, Alive: true, Birth: birthdate}, nil
return &ConnFake{Name: s, Alive: true, Birth: birthdate}, nil
}

// Borrows a connection in server A and another in server B
Expand Down Expand Up @@ -603,9 +575,37 @@ func TestPoolCleanup(ot *testing.T) {
}
assertNumberOfServers(t, ctx, p, 0)
})

// Verifies that a borrower waiting in the pool's queue gets an error shortly
// after Close, instead of only after the connection acquisition timeout.
ot.Run("wakes up borrowers when closing", func(t *testing.T) {
	timer := func() time.Time { return birthdate }
	conf := config.Config{
		// Longer than the 5s this test waits below, so a borrower that is
		// only released by the acquisition timeout makes the test fail.
		ConnectionAcquisitionTimeout: 10 * time.Second,
		MaxConnectionLifetime:        maxLife,
		MaxConnectionPoolSize:        1,
	}
	p := New(&conf, succeedingConnect, logger, "pool id", &timer)
	servers := getServers([]string{"example.com"})
	// Take the only connection the pool allows (MaxConnectionPoolSize is 1)
	// so that the next borrow has to wait.
	conn, err := p.Borrow(ctx, servers, false, nil, DefaultLivenessCheckThreshold, reAuthToken)
	assertConnection(t, conn, err)
	// Buffered so the borrowing goroutine can always deliver its result and
	// exit, even if the select below has already given up on the timeout.
	borrowErrChan := make(chan error, 1)
	go func() {
		_, err := p.Borrow(ctx, servers, true, nil, DefaultLivenessCheckThreshold, reAuthToken)
		borrowErrChan <- err
	}()
	waitForBorrowers(t, p, 1)

	AssertNoError(t, p.Close(ctx))

	select {
	case err := <-borrowErrChan:
		AssertErrorMessageContains(t, err, "Pool closed")
	case <-time.After(5 * time.Second):
		t.Errorf("timed out waiting for borrow error")
	}
})
}

func connectTo(singleConnection *testutil.ConnFake) func(ctx context.Context, name string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) {
func connectTo(singleConnection *ConnFake) func(ctx context.Context, name string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) {
return func(ctx context.Context, name string, _ *db.ReAuthToken, _ bolt.Neo4jErrorCallback, _ log.BoltLogger) (db.Connection, error) {
return singleConnection, nil
}
Expand All @@ -625,8 +625,8 @@ func setIdleConnections(pool *Pool, servers map[string][]db.Connection) {
pool.servers = poolServers
}

func deadConnectionAfterForceReset(name string, idleness time.Time) *testutil.ConnFake {
result := &testutil.ConnFake{Alive: true, Idle: idleness, Name: name}
func deadConnectionAfterForceReset(name string, idleness time.Time) *ConnFake {
result := &ConnFake{Alive: true, Idle: idleness, Name: name}
result.ForceResetHook = func() {
result.Alive = false
}
Expand All @@ -638,3 +638,13 @@ func getServers(servers []string) func(context.Context) ([]string, error) {
return servers, nil
}
}

// waitForBorrowers polls p.queueSize (using the package-level test ctx) until
// it reports at least minBorrowers queued borrow requests.
//
// It reports a test error and returns early when queueSize fails or when the
// queue has not reached the requested size before the deadline, so a broken
// pool fails the test instead of hanging the whole test run.
func waitForBorrowers(t *testing.T, p *Pool, minBorrowers int) {
	t.Helper()
	const (
		timeout      = 10 * time.Second
		pollInterval = time.Millisecond
	)
	deadline := time.Now().Add(timeout)
	for {
		size, err := p.queueSize(ctx)
		if err != nil {
			t.Errorf("should not fail computing queue size, got: %v", err)
			return
		}
		if size >= minBorrowers {
			return
		}
		if time.Now().After(deadline) {
			t.Errorf("timed out waiting for %d queued borrower(s), queue size is %d", minBorrowers, size)
			return
		}
		// Yield between polls instead of spinning on the queue lock.
		time.Sleep(pollInterval)
	}
}

0 comments on commit 29a9396

Please sign in to comment.