diff --git a/README.md b/README.md index e2b80b44..a5b72a61 100644 --- a/README.md +++ b/README.md @@ -368,7 +368,7 @@ modules](https://go.dev/ref/mod) for dependency resolution. You can run unit tests as follows: ```shell -go test -tags=internal_testkit -short ./... +go test -tags internal_testkit,internal_time_mock -short ./... ``` ### Integration and Benchmark Testing diff --git a/hooks/pre-commit b/hooks/pre-commit index 6ed154cb..e9b6326d 100755 --- a/hooks/pre-commit +++ b/hooks/pre-commit @@ -13,11 +13,11 @@ fi echo "# pre-commit hook" printf '%-15s' "## staticcheck " cd "$(mktemp -d)" && go install honnef.co/go/tools/cmd/staticcheck@"${staticcheck_version}" && cd - > /dev/null -"${GOBIN:-$(go env GOPATH)/bin}"/staticcheck -tags internal_testkit ./... +"${GOBIN:-$(go env GOPATH)/bin}"/staticcheck -tags internal_testkit,internal_time_mock ./... echo "✅" printf '%-15s' "## go vet " -go vet -tags internal_testkit ./... +go vet -tags internal_testkit,internal_time_mock ./... echo "✅" printf '%-15s' "## go test " diff --git a/neo4j/auth/auth.go b/neo4j/auth/auth.go index aff2bd12..4c0b1aed 100644 --- a/neo4j/auth/auth.go +++ b/neo4j/auth/auth.go @@ -19,12 +19,14 @@ package auth import ( "context" + "reflect" + "time" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/collections" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" - "reflect" - "time" + itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time" ) // TokenManager is an interface for components that can provide auth tokens. @@ -63,7 +65,6 @@ type neo4jAuthTokenManager struct { token *auth.Token expiration *time.Time mutex racing.Mutex - now *func() time.Time handledSecurityCodes collections.Set[string] } @@ -73,7 +74,7 @@ func (m *neo4jAuthTokenManager) GetAuthToken(ctx context.Context) (auth.Token, e "could not acquire lock in time when getting token in neo4jAuthTokenManager") } defer m.mutex.Unlock() - if m.token == nil || m.expiration != nil && (*m.now)().After(*m.expiration) { + if m.token == nil || m.expiration != nil && itime.Now().After(*m.expiration) { token, expiration, err := m.provider(ctx) if err != nil { return auth.Token{}, err @@ -111,11 +112,9 @@ func (m *neo4jAuthTokenManager) HandleSecurityException(ctx context.Context, tok // The provider function must only ever return auth information belonging to the same identity. // Switching identities is undefined behavior. func BasicTokenManager(provider authTokenProvider) TokenManager { - now := time.Now return &neo4jAuthTokenManager{ provider: wrapWithNilExpiration(provider), mutex: racing.NewMutex(), - now: &now, handledSecurityCodes: collections.NewSet([]string{ "Neo.ClientError.Security.Unauthorized", }), @@ -135,11 +134,9 @@ func BasicTokenManager(provider authTokenProvider) TokenManager { // The provider function must only ever return auth information belonging to the same identity. // Switching identities is undefined behavior. 
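For illustration, a minimal sketch of using the BearerTokenManager described above. This is not part of the patch: fetchToken is a hypothetical stand-in for an identity-provider call, and the URI and 30-minute expiry are placeholders. The point of the change is that the returned expiration is now compared against the driver's package-internal clock (itime.Now) instead of a timer injected into the manager:

```go
package example

import (
	"context"
	"time"

	"github.com/neo4j/neo4j-go-driver/v5/neo4j"
	"github.com/neo4j/neo4j-go-driver/v5/neo4j/auth"
)

// fetchToken is a hypothetical helper standing in for an identity-provider call.
func fetchToken(ctx context.Context) (string, error) {
	return "a-bearer-token", nil
}

func newDriver() (neo4j.DriverWithContext, error) {
	// The manager invokes the provider again once the returned expiration has
	// passed, as measured by the driver's internal clock.
	manager := auth.BearerTokenManager(func(ctx context.Context) (neo4j.AuthToken, *time.Time, error) {
		token, err := fetchToken(ctx)
		if err != nil {
			return neo4j.AuthToken{}, nil, err
		}
		expiry := time.Now().Add(30 * time.Minute)
		return neo4j.BearerAuth(token), &expiry, nil
	})
	return neo4j.NewDriverWithContext("neo4j://localhost:7687", manager)
}
```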
func BearerTokenManager(provider authTokenWithExpirationProvider) TokenManager { - now := time.Now return &neo4jAuthTokenManager{ provider: provider, mutex: racing.NewMutex(), - now: &now, handledSecurityCodes: collections.NewSet([]string{ "Neo.ClientError.Security.TokenExpired", "Neo.ClientError.Security.Unauthorized", diff --git a/neo4j/config.go b/neo4j/config.go index 51b3f0c1..f438d0f5 100644 --- a/neo4j/config.go +++ b/neo4j/config.go @@ -19,6 +19,7 @@ package neo4j import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/pool" "github.com/neo4j/neo4j-go-driver/v5/neo4j/notifications" "math" "net/url" @@ -41,6 +42,7 @@ func defaultConfig() *Config { MaxConnectionPoolSize: 100, MaxConnectionLifetime: 1 * time.Hour, ConnectionAcquisitionTimeout: 1 * time.Minute, + ConnectionLivenessCheckTimeout: pool.DefaultConnectionLivenessCheckTimeout, SocketConnectTimeout: 5 * time.Second, SocketKeepalive: true, RootCAs: nil, @@ -77,6 +79,11 @@ func validateAndNormaliseConfig(config *Config) error { config.ConnectionAcquisitionTimeout = -1 } + // Connection Liveness Check Timeout + if config.ConnectionLivenessCheckTimeout < 0 { + return &UsageError{Message: "Connection liveness check timeout cannot be negative"} + } + // Socket Connect Timeout if config.SocketConnectTimeout < 0 { config.SocketConnectTimeout = 0 diff --git a/neo4j/config/driver.go b/neo4j/config/driver.go index 3ca58512..4cdf4705 100644 --- a/neo4j/config/driver.go +++ b/neo4j/config/driver.go @@ -103,6 +103,21 @@ type Config struct { // // default: 1 * time.Minute ConnectionAcquisitionTimeout time.Duration + // ConnectionLivenessCheckTimeout sets the timeout duration for idle connections in the pool. + // Connections idle longer than this timeout will be tested for liveness before reuse. A low timeout value + // can increase network requests when acquiring a connection, impacting performance. Conversely, a high + // timeout may result in using connections that are no longer active, causing errors in your application. + // These errors typically resolve with a retry, or by using a driver API with automatic + // retries, assuming the database is operational. + // + // The parameter balances the likelihood of encountering connection issues against performance. + // Typically, adjustment of this parameter is not necessary. + // + // By default, no liveness check is performed. A value of 0 ensures connections are always tested for + // validity, and negative values are not permitted. + // + // default: pool.DefaultConnectionLivenessCheckTimeout + ConnectionLivenessCheckTimeout time.Duration // Connect timeout that will be set on underlying sockets. Values less than // or equal to 0 results in no timeout being applied. 
// diff --git a/neo4j/config_test.go b/neo4j/config_test.go index dfcacccb..0911aca1 100644 --- a/neo4j/config_test.go +++ b/neo4j/config_test.go @@ -99,6 +99,16 @@ func TestValidateAndNormaliseConfig(rt *testing.T) { } }) + rt.Run("ConnectionLivenessCheckTimeout less than zero", func(t *testing.T) { + config := defaultConfig() + + config.ConnectionLivenessCheckTimeout = -1 * time.Second + err := validateAndNormaliseConfig(config) + if err == nil { + t.Errorf("expected an error for negative ConnectionLivenessCheckTimeout, got none") + } + }) + rt.Run("SocketConnectTimeout less than zero", func(t *testing.T) { config := defaultConfig() diff --git a/neo4j/driver_with_context.go b/neo4j/driver_with_context.go index a5cf2233..f97fdd03 100644 --- a/neo4j/driver_with_context.go +++ b/neo4j/driver_with_context.go @@ -1,3 +1,5 @@ +// Package neo4j provides required functionality to connect and execute statements against a Neo4j Database. + /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] @@ -15,25 +17,22 @@ * limitations under the License. */ -// Package neo4j provides required functionality to connect and execute statements against a Neo4j Database. package neo4j import ( "context" "fmt" "github.com/neo4j/neo4j-go-driver/v5/neo4j/auth" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/connector" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/pool" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/router" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" "net/url" "strings" "sync" - "time" - - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/connector" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/pool" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/router" ) // AccessMode defines modes that routing driver decides to which cluster member @@ -145,7 +144,7 @@ func NewDriverWithContext(target string, auth auth.TokenManager, configurers ... return nil, err } - d := driverWithContext{target: parsed, mut: racing.NewMutex(), now: time.Now, auth: auth} + d := driverWithContext{target: parsed, mut: racing.NewMutex(), auth: auth} routing := true d.connector.Network = "tcp" @@ -220,10 +219,9 @@ func NewDriverWithContext(target string, auth auth.TokenManager, configurers ... d.connector.Log = d.log d.connector.RoutingContext = routingContext d.connector.Config = d.config - d.connector.Now = &d.now // Let the pool use the same log ID as the driver to simplify log reading. - d.pool = pool.New(d.config, d.connector.Connect, d.log, d.logId, &d.now) + d.pool = pool.New(d.config, d.connector.Connect, d.log, d.logId) if !routing { d.router = &directRouter{address: address} @@ -241,7 +239,15 @@ func NewDriverWithContext(target string, auth auth.TokenManager, configurers ... } } // Let the router use the same log ID as the driver to simplify log reading. 
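Tying the new configuration field together, a minimal sketch (not part of the patch; URI and credentials are placeholders) of opting into liveness checking through the public config:

```go
package example

import (
	"time"

	"github.com/neo4j/neo4j-go-driver/v5/neo4j"
	"github.com/neo4j/neo4j-go-driver/v5/neo4j/config"
)

func newDriverWithLivenessCheck() (neo4j.DriverWithContext, error) {
	// Idle connections older than five minutes get a round-trip liveness
	// test before reuse; leaving the field at its default disables the check.
	return neo4j.NewDriverWithContext(
		"neo4j://localhost:7687",
		neo4j.BasicAuth("neo4j", "password", ""),
		func(c *config.Config) {
			c.ConnectionLivenessCheckTimeout = 5 * time.Minute
		},
	)
}
```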
- d.router = router.New(address, routersResolver, routingContext, d.pool, d.log, d.logId, &d.now) + d.router = router.New( + address, + routersResolver, + routingContext, + d.pool, + d.config.ConnectionLivenessCheckTimeout, + d.log, + d.logId, + ) } d.pool.SetRouter(d.router) @@ -324,7 +330,6 @@ type driverWithContext struct { // this is *not* used by default by user-created session (see NewSession) executeQueryBookmarkManager BookmarkManager auth auth.TokenManager - now func() time.Time } func (d *driverWithContext) Target() url.URL { @@ -360,7 +365,7 @@ func (d *driverWithContext) NewSession(ctx context.Context, config SessionConfig return &erroredSessionWithContext{ err: &UsageError{Message: "Trying to create session on closed driver"}} } - return newSessionWithContext(d.config, config, d.router, d.pool, d.log, reAuthToken, &d.now) + return newSessionWithContext(d.config, config, d.router, d.pool, d.log, reAuthToken) } func (d *driverWithContext) VerifyConnectivity(ctx context.Context) error { diff --git a/neo4j/driver_with_context_testkit.go b/neo4j/driver_with_context_testkit.go index 298ff177..5e645bce 100644 --- a/neo4j/driver_with_context_testkit.go +++ b/neo4j/driver_with_context_testkit.go @@ -1,4 +1,4 @@ -//go:build internal_testkit +//go:build internal_testkit && internal_time_mock /* * Copyright (c) "Neo4j" @@ -25,22 +25,12 @@ import ( idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/router" + itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" - "time" ) type RoutingTable = idb.RoutingTable -func SetTimer(d DriverWithContext, timer func() time.Time) { - driver := d.(*driverWithContext) - driver.now = timer -} - -func ResetTime(d DriverWithContext) { - driver := d.(*driverWithContext) - driver.now = time.Now -} - func ForceRoutingTableUpdate(d DriverWithContext, database string, bookmarks []string, logger log.BoltLogger) error { driver := d.(*driverWithContext) ctx := context.Background() @@ -70,3 +60,8 @@ func GetRoutingTable(d DriverWithContext, database string) (*RoutingTable, error table := router.GetTable(database) return table, nil } + +var Now = itime.Now +var FreezeTime = itime.FreezeTime +var TickTime = itime.TickTime +var UnfreezeTime = itime.UnfreezeTime diff --git a/neo4j/internal/bolt/bolt3.go b/neo4j/internal/bolt/bolt3.go index a83951e0..0ff8e772 100644 --- a/neo4j/internal/bolt/bolt3.go +++ b/neo4j/internal/bolt/bolt3.go @@ -26,6 +26,7 @@ import ( idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/telemetry" + itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time" "net" "reflect" "time" @@ -95,18 +96,16 @@ type bolt3 struct { authManager auth.TokenManager resetAuth bool errorListener ConnectionErrorListener - now *func() time.Time } func NewBolt3( serverName string, conn net.Conn, errorListener ConnectionErrorListener, - timer *func() time.Time, logger log.Logger, boltLog log.BoltLogger, ) *bolt3 { - now := (*timer)() + now := itime.Now() b := &bolt3{ state: bolt3_unauthorized, conn: conn, @@ -123,7 +122,6 @@ func NewBolt3( idleDate: now, log: logger, errorListener: errorListener, - now: timer, } b.out = &outgoing{ chunker: newChunker(), @@ -166,7 +164,7 @@ func (b *bolt3) receiveMsg(ctx context.Context) any { b.state = bolt3_dead return nil } - 
b.idleDate = (*b.now)() + b.idleDate = itime.Now() return msg } diff --git a/neo4j/internal/bolt/bolt3_test.go b/neo4j/internal/bolt/bolt3_test.go index 1ed38f8e..f16a272e 100644 --- a/neo4j/internal/bolt/bolt3_test.go +++ b/neo4j/internal/bolt/bolt3_test.go @@ -104,7 +104,6 @@ func TestBolt3(outer *testing.T) { tcpConn, srv, cleanup := setupBolt3Pipe(t) go serverJob(srv) - timer := time.Now c, err := Connect( context.Background(), "serverName", @@ -116,7 +115,6 @@ func TestBolt3(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) if err != nil { t.Fatal(err) @@ -158,7 +156,6 @@ func TestBolt3(outer *testing.T) { srv.waitForHello() srv.rejectHelloUnauthorized() }() - timer := time.Now bolt, err := Connect( context.Background(), "serverName", @@ -170,7 +167,6 @@ func TestBolt3(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertNil(t, bolt) AssertError(t, err) diff --git a/neo4j/internal/bolt/bolt4.go b/neo4j/internal/bolt/bolt4.go index 48b19491..f5a77bd7 100644 --- a/neo4j/internal/bolt/bolt4.go +++ b/neo4j/internal/bolt/bolt4.go @@ -21,18 +21,19 @@ import ( "context" "errors" "fmt" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/auth" - iauth "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/collections" - idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/telemetry" "net" "reflect" "time" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/auth" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" + iauth "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/collections" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/packstream" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/telemetry" + itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -111,18 +112,16 @@ type bolt4 struct { authManager auth.TokenManager resetAuth bool errorListener ConnectionErrorListener - now *func() time.Time } func NewBolt4( serverName string, conn net.Conn, errorListener ConnectionErrorListener, - timer *func() time.Time, logger log.Logger, boltLog log.BoltLogger, ) *bolt4 { - now := (*timer)() + now := itime.Now() b := &bolt4{ state: bolt4_unauthorized, conn: conn, @@ -133,7 +132,6 @@ func NewBolt4( streams: openstreams{}, lastQid: -1, errorListener: errorListener, - now: timer, } b.queue = newMessageQueue( conn, @@ -1135,7 +1133,7 @@ func (b *bolt4) expectedSuccessHandler(onSuccess func(*success)) responseHandler } func (b *bolt4) onNextMessage() { - b.idleDate = (*b.now)() + b.idleDate = itime.Now() } func (b *bolt4) onFailure(ctx context.Context, failure *db.Neo4jError) { diff --git a/neo4j/internal/bolt/bolt4_test.go b/neo4j/internal/bolt/bolt4_test.go index ae6ea1c7..1816aca0 100644 --- a/neo4j/internal/bolt/bolt4_test.go +++ b/neo4j/internal/bolt/bolt4_test.go @@ -108,7 +108,6 @@ func TestBolt4(outer *testing.T) { tcpConn, srv, cleanup := setupBolt4Pipe(t) go serverJob(srv) - timer := time.Now c, err := Connect(context.Background(), "serverName", tcpConn, @@ -119,7 +118,6 @@ func TestBolt4(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) if err != nil { t.Fatal(err) @@ -221,7 +219,6 @@ func TestBolt4(outer 
*testing.T) { } srv.acceptHello() }() - timer := time.Now bolt, err := Connect( context.Background(), "serverName", @@ -233,7 +230,6 @@ func TestBolt4(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertNoError(t, err) bolt.Close(context.Background()) @@ -252,7 +248,6 @@ func TestBolt4(outer *testing.T) { } srv.acceptHello() }() - timer := time.Now bolt, err := Connect( context.Background(), "serverName", @@ -264,7 +259,6 @@ func TestBolt4(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertNoError(t, err) bolt.Close(context.Background()) @@ -284,7 +278,6 @@ func TestBolt4(outer *testing.T) { } srv.acceptHello() }() - timer := time.Now bolt, err := Connect( context.Background(), "serverName", @@ -296,7 +289,6 @@ func TestBolt4(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertNoError(t, err) bolt.Close(context.Background()) @@ -312,7 +304,6 @@ func TestBolt4(outer *testing.T) { srv.waitForHello() srv.rejectHelloUnauthorized() }() - timer := time.Now bolt, err := Connect( context.Background(), "serverName", @@ -324,7 +315,6 @@ func TestBolt4(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertNil(t, bolt) AssertError(t, err) diff --git a/neo4j/internal/bolt/bolt5.go b/neo4j/internal/bolt/bolt5.go index 67b170a2..db054a00 100644 --- a/neo4j/internal/bolt/bolt5.go +++ b/neo4j/internal/bolt/bolt5.go @@ -21,18 +21,19 @@ import ( "context" "errors" "fmt" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/auth" - iauth "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/boltagent" - idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/telemetry" "net" "reflect" "time" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/auth" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" + iauth "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/boltagent" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/packstream" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/telemetry" + itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -113,7 +114,6 @@ type bolt5 struct { authManager auth.TokenManager resetAuth bool errorListener ConnectionErrorListener - now *func() time.Time telemetryEnabled bool } @@ -121,11 +121,10 @@ func NewBolt5( serverName string, conn net.Conn, errorListener ConnectionErrorListener, - timer *func() time.Time, logger log.Logger, boltLog log.BoltLogger, ) *bolt5 { - now := (*timer)() + now := itime.Now() b := &bolt5{ state: bolt5Unauthorized, conn: conn, @@ -136,7 +135,6 @@ func NewBolt5( streams: openstreams{}, lastQid: -1, errorListener: errorListener, - now: timer, } b.queue = newMessageQueue( conn, @@ -1132,7 +1130,7 @@ func (b *bolt5) onCommitSuccess(commitSuccess *success) { } func (b *bolt5) onNextMessage() { - b.idleDate = (*b.now)() + b.idleDate = itime.Now() } func (b *bolt5) onFailure(ctx context.Context, failure *db.Neo4jError) { diff --git a/neo4j/internal/bolt/bolt5_test.go b/neo4j/internal/bolt/bolt5_test.go index caee6442..bc1c9e3f 100644 --- a/neo4j/internal/bolt/bolt5_test.go +++ b/neo4j/internal/bolt/bolt5_test.go @@ -108,7 +108,6 @@ func 
TestBolt5(outer *testing.T) { tcpConn, srv, cleanup := setupBolt5Pipe(t) go serverJob(srv) - timer := time.Now c, err := Connect( context.Background(), "serverName", @@ -120,7 +119,6 @@ func TestBolt5(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) if err != nil { t.Fatal(err) @@ -278,7 +276,6 @@ func TestBolt5(outer *testing.T) { } srv.acceptHello() }() - timer := time.Now bolt, err := Connect( context.Background(), "serverName", @@ -290,7 +287,6 @@ func TestBolt5(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertNoError(t, err) bolt.Close(context.Background()) @@ -312,7 +308,6 @@ func TestBolt5(outer *testing.T) { srv.waitForLogon() srv.acceptLogon() }() - timer := time.Now bolt, err := Connect( context.Background(), "serverName", @@ -324,7 +319,6 @@ func TestBolt5(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertNoError(t, err) bolt.Close(context.Background()) @@ -343,7 +337,6 @@ func TestBolt5(outer *testing.T) { } srv.acceptHello() }() - timer := time.Now bolt, err := Connect( context.Background(), "serverName", @@ -355,7 +348,6 @@ func TestBolt5(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertNoError(t, err) bolt.Close(context.Background()) @@ -376,7 +368,6 @@ func TestBolt5(outer *testing.T) { srv.waitForLogon() srv.acceptLogon() }() - timer := time.Now bolt, err := Connect( context.Background(), "serverName", @@ -388,7 +379,6 @@ func TestBolt5(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertNoError(t, err) bolt.Close(context.Background()) @@ -404,7 +394,6 @@ func TestBolt5(outer *testing.T) { srv.waitForHello() srv.rejectHelloUnauthorized() }() - timer := time.Now bolt, err := Connect( context.Background(), "serverName", @@ -416,7 +405,6 @@ func TestBolt5(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertNil(t, bolt) AssertError(t, err) @@ -441,7 +429,6 @@ func TestBolt5(outer *testing.T) { srv.waitForLogon() srv.rejectLogonWithoutAuthToken() }() - timer := time.Now bolt, err := Connect( context.Background(), "serverName", @@ -453,7 +440,6 @@ func TestBolt5(outer *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertNil(t, bolt) AssertError(t, err) diff --git a/neo4j/internal/bolt/connect.go b/neo4j/internal/bolt/connect.go index d48bb1a2..9036d118 100644 --- a/neo4j/internal/bolt/connect.go +++ b/neo4j/internal/bolt/connect.go @@ -21,12 +21,11 @@ package bolt import ( "context" "fmt" + "net" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" - "net" - "time" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -56,7 +55,7 @@ func Connect(ctx context.Context, logger log.Logger, boltLogger log.BoltLogger, notificationConfig db.NotificationConfig, - timer *func() time.Time) (db.Connection, error) { +) (db.Connection, error) { // Perform Bolt handshake to negotiate version // Send handshake to server handshake := []byte{ @@ -93,11 +92,11 @@ func Connect(ctx context.Context, var boltConn db.Connection switch major { case 3: - boltConn = NewBolt3(serverName, conn, errorListener, timer, logger, boltLogger) + boltConn = NewBolt3(serverName, conn, errorListener, logger, boltLogger) case 4: - boltConn = NewBolt4(serverName, conn, errorListener, timer, logger, boltLogger) + boltConn = NewBolt4(serverName, conn, errorListener, logger, boltLogger) case 5: - boltConn = 
NewBolt5(serverName, conn, errorListener, timer, logger, boltLogger) + boltConn = NewBolt5(serverName, conn, errorListener, logger, boltLogger) case 0: return nil, fmt.Errorf("server did not accept any of the requested Bolt versions (%#v)", versions) default: diff --git a/neo4j/internal/bolt/connect_test.go b/neo4j/internal/bolt/connect_test.go index 20fc8d1e..c934303a 100644 --- a/neo4j/internal/bolt/connect_test.go +++ b/neo4j/internal/bolt/connect_test.go @@ -19,11 +19,10 @@ package bolt import ( "context" - iauth "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" - idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "testing" - "time" + iauth "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" . "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -54,7 +53,6 @@ func TestConnect(ot *testing.T) { srv.closeConnection() }() - timer := time.Now _, err := Connect( context.Background(), "servername", @@ -66,7 +64,6 @@ func TestConnect(ot *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertError(t, err) }) @@ -82,7 +79,6 @@ func TestConnect(ot *testing.T) { srv.acceptVersion(1, 0) }() - timer := time.Now boltconn, err := Connect( context.Background(), "servername", @@ -94,7 +90,6 @@ func TestConnect(ot *testing.T) { logger, nil, idb.NotificationConfig{}, - &timer, ) AssertError(t, err) if boltconn != nil { diff --git a/neo4j/internal/connector/connector.go b/neo4j/internal/connector/connector.go index 8e86d875..2477adce 100644 --- a/neo4j/internal/connector/connector.go +++ b/neo4j/internal/connector/connector.go @@ -22,14 +22,14 @@ import ( "context" "crypto/tls" "errors" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "io" "net" "time" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -41,7 +41,6 @@ type Connector struct { Network string Config *config.Config SupplyConnection func(context.Context, string) (net.Conn, error) - Now *func() time.Time } func (c Connector) Connect( @@ -87,7 +86,6 @@ func (c Connector) Connect( c.Log, boltLogger, notificationConfig, - c.Now, ) if err != nil { return nil, err @@ -122,7 +120,6 @@ func (c Connector) Connect( c.Log, boltLogger, notificationConfig, - c.Now, ) if err != nil { return nil, err diff --git a/neo4j/internal/connector/connector_test.go b/neo4j/internal/connector/connector_test.go index 188de200..01c945ef 100644 --- a/neo4j/internal/connector/connector_test.go +++ b/neo4j/internal/connector/connector_test.go @@ -19,15 +19,16 @@ package connector_test import ( "context" + "io" + "net" + "testing" + "time" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/connector" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" . 
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" - "io" - "net" - "testing" - "time" ) type noopErrorListener struct{} @@ -51,12 +52,10 @@ func TestConnect(outer *testing.T) { server.acceptVersion(1, 0) }() connectionDelegate := &ConnDelegate{Delegate: clientConnection} - timer := time.Now connector := &connector.Connector{ SupplyConnection: supplyThis(connectionDelegate), SkipEncryption: true, Config: &config.Config{}, - Now: &timer, } connection, err := connector.Connect(ctx, "irrelevant", nil, noopErrorListener{}, nil) @@ -72,12 +71,10 @@ func TestConnect(outer *testing.T) { server.failAcceptingVersion() }() connectionDelegate := &ConnDelegate{Delegate: clientConnection} - timer := time.Now connector := &connector.Connector{ SupplyConnection: supplyThis(connectionDelegate), SkipEncryption: true, Config: &config.Config{}, - Now: &timer, } connection, err := connector.Connect(ctx, "irrelevant", nil, noopErrorListener{}, nil) diff --git a/neo4j/internal/pool/pool.go b/neo4j/internal/pool/pool.go index fe9f9473..f7f8757d 100644 --- a/neo4j/internal/pool/pool.go +++ b/neo4j/internal/pool/pool.go @@ -23,22 +23,23 @@ package pool import ( "container/list" "context" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt" - idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "math" "sort" "sync" "time" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" + itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) -// DefaultLivenessCheckThreshold disables the liveness check of connections -// Liveness checks are performed before a connection is deemed idle enough to be reset -const DefaultLivenessCheckThreshold = math.MaxInt64 +// DefaultConnectionLivenessCheckTimeout disables the liveness check of connections. +// Liveness checks are performed before a connection is deemed idle enough to be reset. 
+const DefaultConnectionLivenessCheckTimeout = math.MaxInt64 type Connect func(context.Context, string, *idb.ReAuthToken, bolt.ConnectionErrorListener, log.BoltLogger) (idb.Connection, error) @@ -60,7 +61,6 @@ type Pool struct { serversMut sync.Mutex queueMut sync.Mutex queue list.List - now *func() time.Time closed bool log log.Logger logId string @@ -71,7 +71,7 @@ type serverPenalty struct { penalty uint32 } -func New(config *config.Config, connect Connect, logger log.Logger, logId string, now *func() time.Time) *Pool { +func New(config *config.Config, connect Connect, logger log.Logger, logId string) *Pool { // Means infinite life, simplifies checking later on p := &Pool{ @@ -81,7 +81,6 @@ func New(config *config.Config, connect Connect, logger log.Logger, logId string servers: make(map[string]*server), serversMut: sync.Mutex{}, queueMut: sync.Mutex{}, - now: now, logId: logId, log: logger, } @@ -137,7 +136,7 @@ func (p *Pool) getServers() map[string]*server { func (p *Pool) CleanUp(ctx context.Context) { p.serversMut.Lock() defer p.serversMut.Unlock() - now := (*p.now)() + now := itime.Now() for n, s := range p.servers { s.removeIdleOlderThan(ctx, now, p.config.MaxConnectionLifetime) if s.size() == 0 && !s.hasFailedConnect(now) { @@ -146,17 +145,13 @@ func (p *Pool) CleanUp(ctx context.Context) { } } -func (p *Pool) Now() time.Time { - return (*p.now)() -} - func (p *Pool) getPenaltiesForServers(ctx context.Context, serverNames []string) []serverPenalty { p.serversMut.Lock() defer p.serversMut.Unlock() // Retrieve penalty for each server penalties := make([]serverPenalty, len(serverNames)) - now := (*p.now)() + now := itime.Now() for i, n := range serverNames { s := p.servers[n] penalties[i].name = n @@ -171,7 +166,7 @@ func (p *Pool) getPenaltiesForServers(ctx context.Context, serverNames []string) return penalties } -func (p *Pool) tryAnyIdle(ctx context.Context, serverNames []string, idlenessThreshold time.Duration, auth *idb.ReAuthToken, logger log.BoltLogger) (idb.Connection, error) { +func (p *Pool) tryAnyIdle(ctx context.Context, serverNames []string, idlenessTimeout time.Duration, auth *idb.ReAuthToken, logger log.BoltLogger) (idb.Connection, error) { p.serversMut.Lock() var unlock = new(sync.Once) defer unlock.Do(p.serversMut.Unlock) @@ -185,11 +180,11 @@ serverLoop: continue serverLoop } unlock.Do(p.serversMut.Unlock) - healthy, err := srv.healthCheck(ctx, conn, idlenessThreshold, auth, logger) + healthy, err := srv.healthCheck(ctx, conn, idlenessTimeout, auth, logger) if healthy { return conn, nil } - p.unreg(ctx, serverName, conn, p.Now()) + p.unreg(ctx, serverName, conn, itime.Now()) if err != nil { p.log.Debugf(log.Pool, p.logId, "Health check failed for %s: %s", serverName, err) return nil, err @@ -202,7 +197,7 @@ serverLoop: return nil, nil } -func (p *Pool) Borrow(ctx context.Context, getServerNames func() []string, wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) { +func (p *Pool) Borrow(ctx context.Context, getServerNames func() []string, wait bool, boltLogger log.BoltLogger, idlenessTimeout time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) { for { if p.closed { return nil, &errorutil.PoolClosed{} @@ -223,7 +218,7 @@ func (p *Pool) Borrow(ctx context.Context, getServerNames func() []string, wait var conn idb.Connection for _, s := range penalties { - conn, err = p.tryBorrow(ctx, s.name, boltLogger, idlenessThreshold, auth) + conn, err = p.tryBorrow(ctx, s.name, boltLogger, 
idlenessTimeout, auth) if conn != nil { return conn, nil } @@ -252,7 +247,7 @@ func (p *Pool) Borrow(ctx context.Context, getServerNames func() []string, wait // Ok, now that we own the queue we can add the item there but between getting the lock // and above check for an existing connection another thread might have returned a connection // so check again to avoid potentially starving this thread. - conn, err = p.tryAnyIdle(ctx, serverNames, idlenessThreshold, auth, boltLogger) + conn, err = p.tryAnyIdle(ctx, serverNames, idlenessTimeout, auth, boltLogger) if err != nil { p.queueMut.Unlock() return nil, err @@ -284,7 +279,7 @@ func (p *Pool) Borrow(ctx context.Context, getServerNames func() []string, wait } } -func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) { +func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log.BoltLogger, idlenessTimeout time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) { // For now, lock complete servers map to avoid over connecting but with the downside // that long connect times will block connects to other servers as well. To fix this // we would need to add a pending connect to the server and lock per server. @@ -304,11 +299,11 @@ func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log. break } unlock.Do(p.serversMut.Unlock) - healthy, err := srv.healthCheck(ctx, connection, idlenessThreshold, auth, boltLogger) + healthy, err := srv.healthCheck(ctx, connection, idlenessTimeout, auth, boltLogger) if healthy { return connection, nil } - p.unreg(ctx, serverName, connection, p.Now()) + p.unreg(ctx, serverName, connection, itime.Now()) if err != nil { p.log.Debugf(log.Pool, p.logId, "Health check failed for %s: %s", serverName, err) return nil, err @@ -337,7 +332,7 @@ func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log. 
p.log.Warnf(log.Pool, p.logId, "Failed to connect to %s: %s", serverName, err) // FeatureNotSupportedError is not the server fault, don't penalize it if _, ok := err.(*db.FeatureNotSupportedError); !ok { - srv.notifyFailedConnect((*p.now)()) + srv.notifyFailedConnect(itime.Now()) } return nil, err } @@ -397,7 +392,7 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) { // If the connection is dead, remove all other idle connections on the same server that older // or of the same age as the dead connection, otherwise perform normal cleanup of old connections maxAge := p.config.MaxConnectionLifetime - now := (*p.now)() + now := itime.Now() age := now.Sub(c.Birthdate()) if !isAlive { // Since this connection has died all other connections that connected before this one diff --git a/neo4j/internal/pool/pool_test.go b/neo4j/internal/pool/pool_test.go index 7b998424..39101f75 100644 --- a/neo4j/internal/pool/pool_test.go +++ b/neo4j/internal/pool/pool_test.go @@ -1,3 +1,5 @@ +//go:build internal_time_mock + /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] @@ -20,18 +22,19 @@ package pool import ( "context" "errors" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" - db "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" - iauth "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt" - idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "math/rand" "sync" "testing" "time" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" + iauth "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" . 
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" + itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -53,14 +56,15 @@ func TestPoolBorrowReturn(outer *testing.T) { } outer.Run("Single thread borrow+return", func(t *testing.T) { - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") defer func() { p.Close(ctx) }() serverNames := []string{"srv1"} - conn, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + conn, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, conn, err) p.Return(ctx, conn) @@ -75,9 +79,10 @@ func TestPoolBorrowReturn(outer *testing.T) { }) outer.Run("First thread borrows, second thread blocks on borrow", func(t *testing.T) { - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") defer func() { p.Close(ctx) }() @@ -86,14 +91,14 @@ func TestPoolBorrowReturn(outer *testing.T) { wg.Add(1) // First thread borrows - c1, err1 := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c1, err1 := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c1, err1) // Second thread tries to borrow the only allowed connection on the same server go func() { // Will block here until first thread detects me in the queue and returns the // connection which will unblock here. 
- c2, err2 := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c2, err2 := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c2, err2) wg.Done() }() @@ -106,20 +111,21 @@ func TestPoolBorrowReturn(outer *testing.T) { }) outer.Run("First thread borrows, second thread should not block on borrow without wait", func(t *testing.T) { - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") defer func() { p.Close(ctx) }() serverNames := []string{"srv1"} // First thread borrows - c1, err1 := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c1, err1 := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c1, err1) // Actually don't need a thread here since we shouldn't block - c2, err2 := p.Borrow(ctx, getServers(serverNames), false, nil, DefaultLivenessCheckThreshold, reAuthToken) + c2, err2 := p.Borrow(ctx, getServers(serverNames), false, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertNoConnection(t, c2, err2) // Error should be pool full _ = err2.(*errorutil.PoolFull) @@ -127,9 +133,10 @@ func TestPoolBorrowReturn(outer *testing.T) { outer.Run("Multiple threads borrows and returns randomly", func(t *testing.T) { maxConnections := 2 - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: maxConnections} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") serverNames := []string{"srv1"} numWorkers := 5 wg := sync.WaitGroup{} @@ -137,7 +144,7 @@ func TestPoolBorrowReturn(outer *testing.T) { worker := func() { for i := 0; i < 5; i++ { - c, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c, err) time.Sleep(time.Duration(rand.Int()%7) * time.Millisecond) p.Return(ctx, c) @@ -160,12 +167,13 @@ func TestPoolBorrowReturn(outer *testing.T) { }) outer.Run("Failing connect", func(t *testing.T) { - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 2} - p := New(&conf, failingConnect, logger, "pool id", &timer) + p := New(&conf, failingConnect, logger, "pool id") p.SetRouter(&RouterFake{}) serverNames := []string{"srv1"} - c, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertNoConnection(t, c, err) // Should get the connect error back if err != failingError { @@ -174,16 +182,17 @@ func TestPoolBorrowReturn(outer *testing.T) { }) outer.Run("Cancel Borrow", func(t *testing.T) { - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, 
MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) - c1, _ := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + p := New(&conf, succeedingConnect, logger, "pool id") + c1, _ := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) cancelableCtx, cancel := context.WithCancel(ctx) wg := sync.WaitGroup{} var err error wg.Add(1) go func() { - _, err = p.Borrow(cancelableCtx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + _, err = p.Borrow(cancelableCtx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) wg.Done() }() @@ -206,9 +215,8 @@ func TestPoolBorrowReturn(outer *testing.T) { whatATimeToBeAlive := &ConnFake{Alive: true, Idle: idleness, Name: "whatATimeToBeAlive", ForceResetHook: func() { t.Errorf("y u call me?") }} - timer := time.Now conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - pool := New(&conf, nil, logger, "pool id", &timer) + pool := New(&conf, nil, logger, "pool id") setIdleConnections(pool, map[string][]idb.Connection{"a server": { deadAfterReset, stayingAlive, @@ -230,9 +238,8 @@ func TestPoolBorrowReturn(outer *testing.T) { healthyConnection := &ConnFake{Name: "healthy", ForceResetHook: func() { t.Errorf("force reset should not be called on new connections") }} - timer := time.Now conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - pool := New(&conf, connectTo(healthyConnection), logger, "pool id", &timer) + pool := New(&conf, connectTo(healthyConnection), logger, "pool id") setIdleConnections(pool, map[string][]idb.Connection{serverName: {deadAfterReset1, deadAfterReset2}}) result, err := pool.tryBorrow(ctx, serverName, nil, idlenessThreshold, reAuthToken) @@ -244,16 +251,17 @@ func TestPoolBorrowReturn(outer *testing.T) { }) outer.Run("Waiting borrow does not receive returned broken connection", func(t *testing.T) { - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) - c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + p := New(&conf, succeedingConnect, logger, "pool id") + c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c1, err) ctx = context.Background() wg := sync.WaitGroup{} wg.Add(1) go func() { - c2, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c2, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c2, err) AssertNotDeepEquals(t, c1, c2) wg.Done() @@ -272,16 +280,17 @@ func TestPoolBorrowReturn(outer *testing.T) { // sanity check AssertNotDeepEquals(t, reAuthToken.Manager, token2) reAuthToken2 := &idb.ReAuthToken{FromSession: false, Manager: token2} - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) - c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + p := New(&conf, 
succeedingConnect, logger, "pool id") + c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c1, err) ctx = context.Background() wg := sync.WaitGroup{} wg.Add(1) go func() { - c2, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken2) + c2, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken2) assertConnection(t, c2, err) AssertDeepEquals(t, c1, c2) wg.Done() @@ -311,28 +320,30 @@ func TestPoolResourceUsage(ot *testing.T) { } ot.Run("Use order of named servers as priority when creating new servers", func(t *testing.T) { - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") defer func() { p.Close(ctx) }() serverNames := []string{"srvA", "srvB", "srvC", "srvD"} - c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) if c.ServerName() != serverNames[0] { t.Errorf("Should have created server for first server but created for %s", c.ServerName()) } }) ot.Run("Do not put dead connection back to server", func(t *testing.T) { - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 2} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") defer func() { p.Close(ctx) }() serverNames := []string{"srvA"} - c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) c.(*ConnFake).Alive = false p.Return(ctx, c) servers := p.getServers() @@ -342,14 +353,16 @@ func TestPoolResourceUsage(ot *testing.T) { }) ot.Run("Do not put too old connection back to server", func(t *testing.T) { - timer := func() time.Time { return birthdate.Add(maxAge * 2) } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 2} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") defer func() { p.Close(ctx) }() serverNames := []string{"srvA"} - c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) + itime.ForceTickTime(2 * maxAge) p.Return(ctx, c) servers := p.getServers() if len(servers) > 0 && servers[serverNames[0]].size() > 0 { @@ -358,15 +371,16 @@ func TestPoolResourceUsage(ot *testing.T) { }) ot.Run("Returning dead connection to server should remove older idle connections", func(t *testing.T) { - timer := time.Now + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: 1<<63 - 1, MaxConnectionPoolSize: 3} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") // Trigger creation 
of three connections on the same server - c1, _ := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) - c2, _ := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) - c3, _ := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c1, _ := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) + c2, _ := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) + c3, _ := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) // Manipulate birthdate on the connections - nowTime := timer() + nowTime := itime.Now() c1.(*ConnFake).Birth = nowTime.Add(-1 * time.Second) c1.(*ConnFake).Id = 1 c2.(*ConnFake).Birth = nowTime @@ -385,43 +399,37 @@ func TestPoolResourceUsage(ot *testing.T) { }) ot.Run("Do not borrow too old connections", func(t *testing.T) { - nowMut := sync.Mutex{} - now := birthdate - timer := func() time.Time { - nowMut.Lock() - defer nowMut.Unlock() - return now - } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") defer func() { p.Close(ctx) }() serverNames := []string{"srvA"} - c1, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c1, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) c1.(*ConnFake).Id = 123 // It's alive when returning it p.Return(ctx, c1) - nowMut.Lock() - now = now.Add(2 * maxAge) - nowMut.Unlock() + itime.ForceTickTime(2 * maxAge) // Shouldn't get the same one back! 
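The clock-control pattern these converted tests follow is worth isolating. A sketch under the internal_time_mock tag, assuming (as the conversions above imply) that ForceTickTime advances the frozen clock by the given duration:

```go
//go:build internal_time_mock

package pool

import (
	"testing"
	"time"

	itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time"
)

// Freeze the shared clock, advance it deterministically, restore it on exit.
func TestFrozenClockSketch(t *testing.T) {
	itime.ForceFreezeTime()
	defer itime.ForceUnfreezeTime()

	before := itime.Now()
	itime.ForceTickTime(2 * time.Hour) // stand-in for "waiting" past MaxConnectionLifetime
	if got := itime.Now().Sub(before); got != 2*time.Hour {
		t.Errorf("expected the frozen clock to advance exactly 2h, got %v", got)
	}
}
```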
- c2, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c2, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) if c2.(*ConnFake).Id == 123 { t.Errorf("Got the old connection back!") } }) ot.Run("Add servers when existing servers are full", func(t *testing.T) { - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") defer func() { p.Close(ctx) }() - c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c1, err) - c2, err := p.Borrow(ctx, getServers([]string{"B"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c2, err := p.Borrow(ctx, getServers([]string{"B"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c2, err) assertNumberOfServers(t, p, 2) }) @@ -436,17 +444,18 @@ func TestPoolCleanup(ot *testing.T) { // Borrows a connection in server A and another in server B borrowConnections := func(t *testing.T, p *Pool) (idb.Connection, idb.Connection) { - c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c1, err) - c2, err := p.Borrow(ctx, getServers([]string{"B"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c2, err := p.Borrow(ctx, getServers([]string{"B"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c2, err) return c1, c2 } ot.Run("Should remove servers with only idle too old connections", func(t *testing.T) { - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 0} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") defer func() { p.Close(ctx) }() @@ -458,15 +467,16 @@ func TestPoolCleanup(ot *testing.T) { assertNumberOfIdle(t, p, "B", 1) // Now go into the future and cleanup, should remove both servers and close the connections - timer = func() time.Time { return birthdate.Add(maxLife).Add(1 * time.Second) } + itime.ForceTickTime(maxLife + 1*time.Second) p.CleanUp(ctx) assertNumberOfServers(t, p, 0) }) ot.Run("Should not remove servers with busy connections", func(t *testing.T) { - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 0} - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") defer func() { p.Close(ctx) }() @@ -477,55 +487,55 @@ func TestPoolCleanup(ot *testing.T) { assertNumberOfIdle(t, p, "B", 1) // Now go into the future and cleanup, should only remove B - timer = func() time.Time { return birthdate.Add(maxLife).Add(1 * time.Second) } + itime.ForceTickTime(maxLife + 1*time.Second) p.CleanUp(ctx) assertNumberOfServers(t, p, 1) }) ot.Run("Should not remove 
servers with only idle connections but with recent connect failures ", func(t *testing.T) { + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() failingConnect := func(_ context.Context, s string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) { return nil, errors.New("an error") } - timer := time.Now conf := config.Config{MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 0} - p := New(&conf, failingConnect, logger, "pool id", &timer) + p := New(&conf, failingConnect, logger, "pool id") p.SetRouter(&RouterFake{}) defer func() { p.Close(ctx) }() - c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultLivenessCheckThreshold, reAuthToken) + c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertNoConnection(t, c1, err) assertNumberOfServers(t, p, 1) assertNumberOfIdle(t, p, "A", 0) // Now go into the future and cleanup, should not remove server A even if it has no connections since // we should remember the failure a bit longer - timer = func() time.Time { return birthdate.Add(maxLife).Add(1 * time.Second) } + itime.ForceTickTime(maxLife + 1*time.Second) p.CleanUp(ctx) assertNumberOfServers(t, p, 1) // Further in the future, the failure should have been forgotten - timer = func() time.Time { - return birthdate.Add(maxLife).Add(rememberFailedConnectDuration).Add(1 * time.Second) - } + itime.ForceTickTime(rememberFailedConnectDuration) p.CleanUp(ctx) assertNumberOfServers(t, p, 0) }) ot.Run("wakes up borrowers when closing", func(t *testing.T) { - timer := func() time.Time { return birthdate } + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() conf := config.Config{ ConnectionAcquisitionTimeout: 10 * time.Second, MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 1, } - p := New(&conf, succeedingConnect, logger, "pool id", &timer) + p := New(&conf, succeedingConnect, logger, "pool id") servers := getServers([]string{"example.com"}) - conn, err := p.Borrow(ctx, servers, false, nil, DefaultLivenessCheckThreshold, reAuthToken) + conn, err := p.Borrow(ctx, servers, false, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, conn, err) borrowErrChan := make(chan error) go func() { - _, err := p.Borrow(ctx, servers, true, nil, DefaultLivenessCheckThreshold, reAuthToken) + _, err := p.Borrow(ctx, servers, true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) borrowErrChan <- err }() waitForBorrowers(p, 1) @@ -620,12 +630,11 @@ func TestPoolErrorHanding(ot *testing.T) { return &connection, nil } - now := time.Now router := RouterFake{} - p := New(&config.Config{}, succeedingConnect, logger, "pool id", &now) + p := New(&config.Config{}, succeedingConnect, logger, "pool id") p.SetRouter(&router) defer p.Close(ctx) - conn, err := p.Borrow(ctx, getServers([]string{ServerName}), false, nil, DefaultLivenessCheckThreshold, reAuthToken) + conn, err := p.Borrow(ctx, getServers([]string{ServerName}), false, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, conn, err) AssertLen(t, errorListeners, 1) AssertLen(t, connections, 1) diff --git a/neo4j/internal/pool/server.go b/neo4j/internal/pool/server.go index 5ae14021..070a442f 100644 --- a/neo4j/internal/pool/server.go +++ b/neo4j/internal/pool/server.go @@ -20,10 +20,12 @@ package pool import ( "container/list" "context" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" "sync/atomic" "time" + + 
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) // Represents a server with a number of connections that either is in use (borrowed) or @@ -66,12 +68,12 @@ func (s *server) getIdle() db.Connection { func (s *server) healthCheck( ctx context.Context, connection db.Connection, - idlenessThreshold time.Duration, + idlenessTimeout time.Duration, auth *db.ReAuthToken, boltLogger log.BoltLogger) (healthy bool, _ error) { connection.SetBoltLogger(boltLogger) - if time.Since(connection.IdleDate()) > idlenessThreshold { + if itime.Since(connection.IdleDate()) > idlenessTimeout { connection.ForceReset(ctx) if !connection.IsAlive() { return false, nil diff --git a/neo4j/internal/pool/server_test.go b/neo4j/internal/pool/server_test.go index ce59c769..67878b5d 100644 --- a/neo4j/internal/pool/server_test.go +++ b/neo4j/internal/pool/server_test.go @@ -158,7 +158,7 @@ func TestServerPenalty(t *testing.T) { // Get the connection from srv1 and return it, now srv1 should have higher penalty. ctx := context.Background() idle := srv1.getIdle() - _, _ = srv1.healthCheck(ctx, idle, DefaultLivenessCheckThreshold, nil, nil) + _, _ = srv1.healthCheck(ctx, idle, DefaultConnectionLivenessCheckTimeout, nil, nil) testutil.AssertDeepEquals(t, idle, c11) srv1.returnBusy(context.Background(), c11) assertPenaltiesGreaterThan(srv1, srv2, now) @@ -173,18 +173,18 @@ func TestServerPenalty(t *testing.T) { assertPenaltiesGreaterThan(srv2, srv1, now) // Get both idle connections from srv1 idle = srv1.getIdle() - _, _ = srv1.healthCheck(ctx, idle, DefaultLivenessCheckThreshold, nil, nil) + _, _ = srv1.healthCheck(ctx, idle, DefaultConnectionLivenessCheckTimeout, nil, nil) idle = srv1.getIdle() - _, _ = srv1.healthCheck(ctx, idle, DefaultLivenessCheckThreshold, nil, nil) + _, _ = srv1.healthCheck(ctx, idle, DefaultConnectionLivenessCheckTimeout, nil, nil) // Get one idle connection from srv2 idle = srv2.getIdle() - _, _ = srv2.healthCheck(ctx, idle, DefaultLivenessCheckThreshold, nil, nil) + _, _ = srv2.healthCheck(ctx, idle, DefaultConnectionLivenessCheckTimeout, nil, nil) // Since more connections are in use on srv1, it should have higher penalty even though // srv2 was last used assertPenaltiesGreaterThan(srv1, srv2, now) // Return the connections idle = srv2.getIdle() - _, _ = srv2.healthCheck(ctx, idle, DefaultLivenessCheckThreshold, nil, nil) + _, _ = srv2.healthCheck(ctx, idle, DefaultConnectionLivenessCheckTimeout, nil, nil) srv2.returnBusy(context.Background(), c21) srv2.returnBusy(context.Background(), c22) srv1.returnBusy(context.Background(), c11) @@ -200,9 +200,9 @@ func TestServerPenalty(t *testing.T) { testutil.AssertFalse(t, srv2.hasFailedConnect(now)) // Use srv2 to the max idle = srv2.getIdle() - _, _ = srv2.healthCheck(ctx, idle, DefaultLivenessCheckThreshold, nil, nil) + _, _ = srv2.healthCheck(ctx, idle, DefaultConnectionLivenessCheckTimeout, nil, nil) idle = srv2.getIdle() - _, _ = srv2.healthCheck(ctx, idle, DefaultLivenessCheckThreshold, nil, nil) + _, _ = srv2.healthCheck(ctx, idle, DefaultConnectionLivenessCheckTimeout, nil, nil) // Even at this point we should prefer srv2 assertPenaltiesGreaterThan(srv1, srv2, now) diff --git a/neo4j/internal/retry/state.go b/neo4j/internal/retry/state.go index 3f6e0914..1fd5c451 100644 --- a/neo4j/internal/retry/state.go +++ b/neo4j/internal/retry/state.go @@ -22,11 +22,12 @@ import ( "context" "errors" "fmt" - idb 
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" + itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -36,7 +37,6 @@ type State struct { Log log.Logger LogName string LogId string - Now *func() time.Time Sleep func(time.Duration) Throttle Throttler MaxDeadConnections int @@ -68,7 +68,7 @@ func (s *State) OnFailure(_ context.Context, err error, conn idb.Connection, isC func (s *State) Continue() bool { if s.start.IsZero() { - s.start = (*s.Now)() + s.start = itime.Now() } if len(s.Errs) == 0 { @@ -80,7 +80,7 @@ func (s *State) Continue() bool { return false } - if (*s.Now)().Sub(s.start) > s.MaxTransactionRetryTime { + if itime.Since(s.start) > s.MaxTransactionRetryTime { s.Errs = []error{&errorutil.TransactionExecutionLimit{ Cause: fmt.Sprintf("timeout (exceeded max retry time: %s)", s.MaxTransactionRetryTime.String()), Errors: s.Errs, diff --git a/neo4j/internal/retry/state_test.go b/neo4j/internal/retry/state_test.go index bc57ea8f..a42758c2 100644 --- a/neo4j/internal/retry/state_test.go +++ b/neo4j/internal/retry/state_test.go @@ -1,3 +1,5 @@ +//go:build internal_time_mock + /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] @@ -20,15 +22,16 @@ package retry import ( "context" "errors" - idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "io" "reflect" "testing" "time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" + itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -36,7 +39,8 @@ type TStateInvocation struct { conn idb.Connection err error isCommitting bool - now time.Time + freezeTime bool + tick time.Duration expectContinued bool expectLastErrWasRetryable bool expectLastErrType error @@ -44,10 +48,8 @@ type TStateInvocation struct { func TestState(outer *testing.T) { var ( - baseTime = time.Now() maxRetryTime = time.Second * 10 - overTime = baseTime.Add(maxRetryTime).Add(1 * time.Second) - halfTime = baseTime.Add(maxRetryTime / 2) + halfTime = maxRetryTime / 2 maxDead = 2 dbName = "thedb" // a single server can be reused here since the router is a fake impl @@ -63,11 +65,11 @@ func TestState(outer *testing.T) { expectLastErrWasRetryable: true, expectLastErrType: &errorutil.PoolTimeout{}}, }, "Retry connect timeout": { - {conn: nil, err: dbTransientErr, expectContinued: true, now: baseTime, + {conn: nil, err: dbTransientErr, expectContinued: true, freezeTime: true, expectLastErrWasRetryable: true}, - {conn: nil, err: dbTransientErr, expectContinued: true, now: halfTime, + {conn: nil, err: dbTransientErr, expectContinued: true, tick: halfTime, expectLastErrWasRetryable: true}, - {conn: nil, err: dbTransientErr, expectContinued: false, now: overTime, + {conn: nil, err: dbTransientErr, expectContinued: false, tick: halfTime + 1*time.Second, expectLastErrWasRetryable: true}, }, "Retry dead connection": { @@ -80,9 +82,9 @@ func TestState(outer *testing.T) { }, "Retry dead connection timeout": { {conn: &testutil.ConnFake{Name: serverName, 
Alive: false}, err: dbTransientErr, - expectContinued: true, now: baseTime, expectLastErrWasRetryable: true}, + expectContinued: true, freezeTime: true, expectLastErrWasRetryable: true}, {conn: &testutil.ConnFake{Name: serverName, Alive: false}, err: errors.New("some error 2"), - expectContinued: false, now: overTime, + expectContinued: false, tick: 2*halfTime + 1*time.Second, expectLastErrWasRetryable: false}, }, "Retry dead connection max": { @@ -103,8 +105,8 @@ func TestState(outer *testing.T) { }, "Database transient error timeout": { {conn: &testutil.ConnFake{Alive: true}, err: dbTransientErr, expectContinued: true, - expectLastErrWasRetryable: true}, - {conn: &testutil.ConnFake{Alive: true}, err: dbTransientErr, expectContinued: false, now: overTime, + expectLastErrWasRetryable: true, freezeTime: true}, + {conn: &testutil.ConnFake{Alive: true}, err: dbTransientErr, expectContinued: false, tick: 2*halfTime + 1*time.Second, expectLastErrWasRetryable: true}, }, "User defined error": { @@ -141,10 +143,7 @@ func TestState(outer *testing.T) { ctx := context.Background() for i, testCase := range testCases { outer.Run(i, func(t *testing.T) { - now := baseTime - timer := func() time.Time { return now } state := State{ - Now: &timer, Log: &log.Void{}, LogName: "TEST", LogId: "State", @@ -154,9 +153,12 @@ func TestState(outer *testing.T) { DatabaseName: dbName, } for _, invocation := range testCase { - // Update now if a value has been provided - if !invocation.now.IsZero() { - now = invocation.now + if invocation.freezeTime { + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() + } + if invocation.tick > 0 { + itime.ForceTickTime(invocation.tick) } state.OnFailure(ctx, invocation.err, invocation.conn, invocation.isCommitting) diff --git a/neo4j/internal/router/readtable.go b/neo4j/internal/router/readtable.go index bf6785c1..8efa1c1a 100644 --- a/neo4j/internal/router/readtable.go +++ b/neo4j/internal/router/readtable.go @@ -19,9 +19,10 @@ package router import ( "context" + "time" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/pool" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -32,6 +33,7 @@ func readTable( connectionPool Pool, routers []string, routerContext map[string]string, + idlenessTimeout time.Duration, bookmarks []string, database, impersonatedUser string, @@ -46,7 +48,7 @@ func readTable( // another db. 
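With `idlenessTimeout` now part of `readTable`'s signature, discovery no longer hardcodes `pool.DefaultLivenessCheckThreshold` when borrowing a connection from each candidate router; callers pass the configured timeout down instead. A condensed sketch of the expected call shape (not part of the patch; parameter names and order mirror the signature above, wiring elided):

```go
// Hypothetical helper showing how a caller threads the configured timeout
// through discovery; all types come from the router package in this diff.
func discoverOnce(
	ctx context.Context,
	connectionPool Pool,
	routers []string,
	routerContext map[string]string,
	idlenessTimeout time.Duration,
	bookmarks []string,
	database, impersonatedUser string,
	auth *db.ReAuthToken,
	boltLogger log.BoltLogger,
) (*db.RoutingTable, error) {
	return readTable(ctx, connectionPool, routers, routerContext,
		idlenessTimeout, // new: replaces the hardcoded pool.DefaultLivenessCheckThreshold
		bookmarks, database, impersonatedUser, auth, boltLogger)
}
```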
for _, router := range routers { var conn db.Connection - if conn, err = connectionPool.Borrow(ctx, getStaticServer(router), true, boltLogger, pool.DefaultLivenessCheckThreshold, auth); err != nil { + if conn, err = connectionPool.Borrow(ctx, getStaticServer(router), true, boltLogger, idlenessTimeout, auth); err != nil { // Check if failed due to context timing out if ctx.Err() != nil { return nil, wrapError(router, ctx.Err()) diff --git a/neo4j/internal/router/readtable_test.go b/neo4j/internal/router/readtable_test.go index 42715e04..05524576 100644 --- a/neo4j/internal/router/readtable_test.go +++ b/neo4j/internal/router/readtable_test.go @@ -23,6 +23,7 @@ import ( iauth "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/pool" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" "testing" @@ -180,6 +181,7 @@ func TestReadTableTable(ot *testing.T) { c.pool, c.routers, nil, + pool.DefaultConnectionLivenessCheckTimeout, nil, "dbname", "", diff --git a/neo4j/internal/router/router.go b/neo4j/internal/router/router.go index fcddedab..a257221f 100644 --- a/neo4j/internal/router/router.go +++ b/neo4j/internal/router/router.go @@ -27,6 +27,7 @@ import ( "sync" "time" + itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -40,41 +41,41 @@ type databaseRouter struct { // Router is thread safe type Router struct { - routerContext map[string]string - pool Pool - dbRouters map[string]*databaseRouter - updating map[string][]chan struct{} - dbRoutersMut sync.Mutex - now *func() time.Time - sleep func(time.Duration) - rootRouter string - getRouters func() []string - log log.Logger - logId string + routerContext map[string]string + pool Pool + idlenessTimeout time.Duration + dbRouters map[string]*databaseRouter + updating map[string][]chan struct{} + dbRoutersMut sync.Mutex + sleep func(time.Duration) + rootRouter string + getRouters func() []string + log log.Logger + logId string } type Pool interface { // Borrow acquires a connection from the provided list of servers // If all connections are busy and the pool is full, calls to Borrow may wait for a connection to become idle - // If a connection has been idle for longer than idlenessThreshold, it will be reset + // If a connection has been idle for longer than idlenessTimeout, it will be reset // to check if it's still alive. 
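The timeout that arrives here is the user-facing `ConnectionLivenessCheckTimeout` added to the driver config earlier in this patch. A minimal sketch of how an application would opt in (URI and credentials are placeholders, not part of this change):

```go
package main

import (
	"context"
	"time"

	"github.com/neo4j/neo4j-go-driver/v5/neo4j"
)

func main() {
	ctx := context.Background()
	driver, err := neo4j.NewDriverWithContext(
		"neo4j://localhost:7687", // placeholder address
		neo4j.BasicAuth("neo4j", "password", ""),
		func(c *neo4j.Config) {
			// Connections idle longer than this are reset before reuse to
			// verify they are still alive.
			c.ConnectionLivenessCheckTimeout = 5 * time.Minute
		},
	)
	if err != nil {
		panic(err)
	}
	defer driver.Close(ctx)
}
```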
-	Borrow(ctx context.Context, getServers func() []string, wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) +	Borrow(ctx context.Context, getServers func() []string, wait bool, boltLogger log.BoltLogger, idlenessTimeout time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) Return(ctx context.Context, c idb.Connection) } -func New(rootRouter string, getRouters func() []string, routerContext map[string]string, pool Pool, logger log.Logger, logId string, timer *func() time.Time) *Router { +func New(rootRouter string, getRouters func() []string, routerContext map[string]string, pool Pool, idlenessTimeout time.Duration, logger log.Logger, logId string) *Router { r := &Router{ - rootRouter: rootRouter, - getRouters: getRouters, - routerContext: routerContext, - pool: pool, - dbRouters: make(map[string]*databaseRouter), - updating: make(map[string][]chan struct{}), - dbRoutersMut: sync.Mutex{}, - now: timer, - sleep: time.Sleep, - log: logger, - logId: logId, + rootRouter: rootRouter, + getRouters: getRouters, + routerContext: routerContext, + pool: pool, + idlenessTimeout: idlenessTimeout, + dbRouters: make(map[string]*databaseRouter), + updating: make(map[string][]chan struct{}), + dbRoutersMut: sync.Mutex{}, + sleep: time.Sleep, + log: logger, + logId: logId, } r.log.Infof(log.Router, r.logId, "Created {context: %v}", routerContext) return r @@ -98,7 +99,7 @@ func (r *Router) readTable( if dbRouter != nil && len(dbRouter.table.Routers) > 0 { routers := dbRouter.table.Routers r.log.Infof(log.Router, r.logId, "Reading routing table for '%s' from previously known routers: %v", database, routers) - table, err = readTable(ctx, r.pool, routers, r.routerContext, bookmarks, database, impersonatedUser, auth, boltLogger) + table, err = readTable(ctx, r.pool, routers, r.routerContext, r.idlenessTimeout, bookmarks, database, impersonatedUser, auth, boltLogger) } if errorutil.IsFatalDuringDiscovery(err) { r.log.Error(log.Router, r.logId, err) @@ -108,7 +109,7 @@ // Try initial router if no routers or failed if table == nil { r.log.Infof(log.Router, r.logId, "Reading routing table from initial router: %s", r.rootRouter) - table, err = readTable(ctx, r.pool, []string{r.rootRouter}, r.routerContext, bookmarks, database, impersonatedUser, auth, boltLogger) + table, err = readTable(ctx, r.pool, []string{r.rootRouter}, r.routerContext, r.idlenessTimeout, bookmarks, database, impersonatedUser, auth, boltLogger) } if errorutil.IsFatalDuringDiscovery(err) { r.log.Error(log.Router, r.logId, err) @@ -119,7 +120,7 @@ // Use hook to retrieve possibly different set of routers and retry if table == nil && r.getRouters != nil { routers := r.getRouters() r.log.Infof(log.Router, r.logId, "Reading routing table for '%s' from custom routers: %v", database, routers) - table, err = readTable(ctx, r.pool, routers, r.routerContext, bookmarks, database, impersonatedUser, auth, boltLogger) + table, err = readTable(ctx, r.pool, routers, r.routerContext, r.idlenessTimeout, bookmarks, database, impersonatedUser, auth, boltLogger) } if errorutil.IsFatalDuringDiscovery(err) { r.log.Error(log.Router, r.logId, err) @@ -189,7 +190,7 @@ func (r *Router) getOrUpdateTable(ctx context.Context, bookmarksFn func(context.
} func (r *Router) getTableLocked(dbRouter *databaseRouter) *idb.RoutingTable { - now := (*r.now)() + now := itime.Now() if dbRouter != nil && now.Unix() < dbRouter.dueUnix { return dbRouter.table } @@ -206,7 +207,7 @@ func (r *Router) updateTable(ctx context.Context, bookmarksFn func(context.Conte return nil, err } - err = r.storeRoutingTable(ctx, database, table, (*r.now)()) + err = r.storeRoutingTable(ctx, database, table, itime.Now()) if err != nil { return nil, err } @@ -293,7 +294,7 @@ func (r *Router) GetNameOfDefaultDatabase(ctx context.Context, bookmarks []strin return "", err } // Store the fresh routing table as well to avoid another roundtrip to receive servers from session. - now := (*r.now)() + now := itime.Now() err = r.storeRoutingTable(ctx, table.DatabaseName, table, now) if err != nil { return "", err @@ -360,7 +361,7 @@ func removeServerFromList(list []string, server string) []string { func (r *Router) CleanUp() { r.log.Debugf(log.Router, r.logId, "Cleaning up") - now := (*r.now)().Unix() + now := itime.Now().Unix() r.dbRoutersMut.Lock() defer r.dbRoutersMut.Unlock() diff --git a/neo4j/internal/router/router_test.go b/neo4j/internal/router/router_test.go index 8e4c5e95..b50e74da 100644 --- a/neo4j/internal/router/router_test.go +++ b/neo4j/internal/router/router_test.go @@ -1,3 +1,5 @@ +//go:build internal_time_mock + /* * Copyright (c) "Neo4j" * Neo4j Sweden AB [https://neo4j.com] @@ -21,13 +23,15 @@ package router import ( "context" "errors" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" "reflect" "sync" "testing" "time" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + pool2 "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/pool" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" + itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) @@ -36,7 +40,6 @@ var logger = &log.Void{} // Verifies that concurrent access works as expected relying on the race detector to // report suspicious behavior. func TestMultithreading(t *testing.T) { - // Set up a router that needs to read the routing table essentially on every access to // stress threading a bit more. 
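The hand-rolled, mutex-guarded fake timer this test used to build is replaced below by the shared frozen clock from `neo4j/internal/time` (the `internal_time_mock` implementation appears later in this diff). A sketch of the pattern, runnable inside the driver module with that build tag; the duration and assertion are illustrative:

```go
//go:build internal_time_mock

package example

import (
	"testing"
	"time"

	itime "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/time"
)

func TestWithFrozenClock(t *testing.T) {
	itime.ForceFreezeTime()         // freeze the package clock; panics if already frozen
	defer itime.ForceUnfreezeTime() // restore real time and drop registered tickers

	// Tickers run on every Now() read, under the package's own mutex, so the
	// test needs no extra locking to satisfy the race detector.
	itime.ForceAddTicker(func(now *time.Time) {
		*now = now.Add(30 * time.Second) // e.g. jump past a routing table's TTL
	})

	first := itime.Now()
	second := itime.Now()
	if !second.After(first) {
		t.Fatal("expected the ticker to advance the frozen clock")
	}
}
```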
num := 0 @@ -47,16 +50,14 @@ func TestMultithreading(t *testing.T) { return &testutil.ConnFake{Table: table}, nil }, } - n := time.Now() - mut := sync.Mutex{} - timer := func() time.Time { + + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() + itime.ForceAddTicker(func(now *time.Time) { // Need to lock here to make race detector happy - mut.Lock() - defer mut.Unlock() - n = n.Add(time.Duration(table.TimeToLive) * time.Second * 2) - return n - } - router := New("router", func() []string { return []string{} }, nil, pool, logger, "routerid", &timer) + *now = now.Add(time.Duration(table.TimeToLive) * time.Second * 2) + }) + router := New("router", func() []string { return []string{} }, nil, pool, pool2.DefaultConnectionLivenessCheckTimeout, logger, "routerid") dbName := "dbname" wg := sync.WaitGroup{} @@ -103,12 +104,9 @@ func TestRespectsTimeToLiveAndInvalidate(t *testing.T) { return &testutil.ConnFake{Table: table}, nil }, } - nzero := time.Now() - n := nzero - timer := func() time.Time { - return n - } - router := New("router", func() []string { return []string{} }, nil, pool, logger, "routerid", &timer) + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() + router := New("router", func() []string { return []string{} }, nil, pool, pool2.DefaultConnectionLivenessCheckTimeout, logger, "routerid") dbName := "dbname" // First access should trigger initial table read @@ -125,7 +123,7 @@ assertNum(t, numfetch, 1, "Should not have fetched") // Third access with time passed table due should trigger fetch - n = n.Add(2 * time.Second) + itime.ForceTickTime(2 * time.Second) if _, err := router.GetOrUpdateReaders(ctx, nilBookmarks, dbName, nil, nil); err != nil { testutil.AssertNoError(t, err) } @@ -160,12 +158,9 @@ func TestUsesRootRouterWhenPreviousRoutersFails(t *testing.T) { return conn, err }, } - nzero := time.Now() - n := nzero - timer := func() time.Time { - return n - } - router := New("rootRouter", func() []string { return []string{} }, nil, pool, logger, "routerid", &timer) + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() + router := New("rootRouter", func() []string { return []string{} }, nil, pool, pool2.DefaultConnectionLivenessCheckTimeout, logger, "routerid") dbName := "dbname" // First access should trigger initial table read from root router @@ -176,7 +171,7 @@ t.Errorf("Should have connected to root upon first router request") } // Next access should go to otherRouter - n = n.Add(2 * time.Second) + itime.ForceTickTime(2 * time.Second) if _, err := router.GetOrUpdateReaders(context.Background(), nilBookmarks, dbName, nil, nil); err != nil { testutil.AssertNoError(t, err) } @@ -203,7 +198,7 @@ requestedRoot = true return &testutil.ConnFake{Table: &db.RoutingTable{TimeToLive: 1, Readers: []string{"aReader"}}}, nil } - n = n.Add(2 * time.Second) + itime.ForceTickTime(2 * time.Second) readers, err := router.GetOrUpdateReaders(context.Background(), nilBookmarks, dbName, nil, nil) if err != nil { t.Error(err) @@ -228,8 +223,7 @@ func TestUseGetRoutersHookWhenInitialRouterFails(t *testing.T) { } rootRouter := "rootRouter" backupRouters := []string{"bup1", "bup2"} - timer := time.Now - router := New(rootRouter, func() []string { return backupRouters }, nil, pool, logger, "routerid", &timer) + router := New(rootRouter, func() []string { return backupRouters }, nil, pool, 
pool2.DefaultConnectionLivenessCheckTimeout, logger, "routerid") dbName := "dbname" // Trigger read of routing table @@ -255,8 +249,7 @@ func TestWritersFailAfterNRetries(t *testing.T) { }, } numsleep := 0 - timer := time.Now - router := New("router", func() []string { return []string{} }, nil, pool, logger, "routerid", &timer) + router := New("router", func() []string { return []string{} }, nil, pool, pool2.DefaultConnectionLivenessCheckTimeout, logger, "routerid") router.sleep = func(time.Duration) { numsleep++ } @@ -293,8 +286,7 @@ func TestWritersRetriesWhenNoWriters(t *testing.T) { }, } numsleep := 0 - timer := time.Now - router := New("router", func() []string { return []string{} }, nil, pool, logger, "routerid", &timer) + router := New("router", func() []string { return []string{} }, nil, pool, pool2.DefaultConnectionLivenessCheckTimeout, logger, "routerid") router.sleep = func(time.Duration) { numsleep++ } @@ -332,8 +324,7 @@ func TestReadersRetriesWhenNoReaders(t *testing.T) { }, } numsleep := 0 - timer := time.Now - router := New("router", func() []string { return []string{} }, nil, pool, logger, "routerid", &timer) + router := New("router", func() []string { return []string{} }, nil, pool, pool2.DefaultConnectionLivenessCheckTimeout, logger, "routerid") router.sleep = func(time.Duration) { numsleep++ } @@ -365,9 +356,9 @@ func TestCleanUp(t *testing.T) { return &testutil.ConnFake{Table: table}, nil }, } - now := time.Now() - timer := func() time.Time { return now } - router := New("router", func() []string { return []string{} }, nil, pool, logger, "routerid", &timer) + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() + router := New("router", func() []string { return []string{} }, nil, pool, pool2.DefaultConnectionLivenessCheckTimeout, logger, "routerid") ctx := context.Background() if _, err := router.GetOrUpdateReaders(ctx, nilBookmarks, "db1", nil, nil); err != nil { @@ -388,7 +379,7 @@ func TestCleanUp(t *testing.T) { t.Fatal("Should not have removed routing tables") } - timer = func() time.Time { return now.Add(1 * time.Minute) } + itime.ForceTickTime(1 * time.Minute) router.CleanUp() if len(router.dbRouters) != 0 { t.Fatal("Should have cleaned up") diff --git a/neo4j/auth/auth_testkit.go b/neo4j/internal/time/time.go similarity index 69% rename from neo4j/auth/auth_testkit.go rename to neo4j/internal/time/time.go index 017227ff..12222763 100644 --- a/neo4j/auth/auth_testkit.go +++ b/neo4j/internal/time/time.go @@ -1,4 +1,4 @@ -//go:build internal_testkit +//go:build !internal_time_mock /* * Copyright (c) "Neo4j" @@ -17,19 +17,9 @@ * limitations under the License. */ -package auth +package time import "time" -func SetTimer(t TokenManager, timer func() time.Time) { - if t, ok := t.(*neo4jAuthTokenManager); ok { - t.now = &timer - } -} - -func ResetTime(t TokenManager) { - if t, ok := t.(*neo4jAuthTokenManager); ok { - now := time.Now - t.now = &now - } -} +var Now = time.Now +var Since = time.Since diff --git a/neo4j/internal/time/time_mockable.go b/neo4j/internal/time/time_mockable.go new file mode 100644 index 00000000..cec855ad --- /dev/null +++ b/neo4j/internal/time/time_mockable.go @@ -0,0 +1,130 @@ +//go:build internal_time_mock + +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package time + +import ( + "errors" + "sync" + "time" +) + +type Ticker = func(*time.Time) + +var frozenNow *time.Time = nil +var tickerRegistry []Ticker = nil +var frozenNowMutex = &sync.Mutex{} + +func now() time.Time { + if frozenNow == nil { + return time.Now() + } + for _, ticker := range tickerRegistry { + ticker(frozenNow) + } + return *frozenNow +} + +func Now() time.Time { + frozenNowMutex.Lock() + defer frozenNowMutex.Unlock() + + return now() +} + +func Since(t time.Time) time.Duration { + frozenNowMutex.Lock() + defer frozenNowMutex.Unlock() + + if frozenNow == nil { + return time.Since(t) + } + return frozenNow.Sub(t) +} + +func FreezeTime() error { + frozenNowMutex.Lock() + defer frozenNowMutex.Unlock() + + if frozenNow != nil { + return errors.New("time already frozen") + } + now := now() + frozenNow = &now + return nil +} + +func ForceFreezeTime() { + if err := FreezeTime(); err != nil { + panic(err) + } +} + +func TickTime(d time.Duration) error { + frozenNowMutex.Lock() + defer frozenNowMutex.Unlock() + + if frozenNow == nil { + return errors.New("time not frozen") + } + newNow := frozenNow.Add(d) + frozenNow = &newNow + return nil +} + +func ForceTickTime(d time.Duration) { + if err := TickTime(d); err != nil { + panic(err) + } +} + +func AddTicker(ticker Ticker) error { + frozenNowMutex.Lock() + defer frozenNowMutex.Unlock() + + if frozenNow == nil { + return errors.New("time not frozen") + } + tickerRegistry = append(tickerRegistry, ticker) + return nil +} + +func ForceAddTicker(ticker Ticker) { + if err := AddTicker(ticker); err != nil { + panic(err) + } +} + +func UnfreezeTime() error { + frozenNowMutex.Lock() + defer frozenNowMutex.Unlock() + + if frozenNow == nil { + return errors.New("time not frozen") + } + frozenNow = nil + tickerRegistry = nil + return nil +} + +func ForceUnfreezeTime() { + if err := UnfreezeTime(); err != nil { + panic(err) + } +} diff --git a/neo4j/session_with_context.go b/neo4j/session_with_context.go index 1e2533a4..7467db1f 100644 --- a/neo4j/session_with_context.go +++ b/neo4j/session_with_context.go @@ -20,17 +20,16 @@ package neo4j import ( "context" "fmt" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/collections" - idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/pool" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/telemetry" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/notifications" "math" "time" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/collections" + idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/retry" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/telemetry" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/notifications" ) // TransactionWork represents a unit of work that will be executed against the provided @@ -186,10 +185,9 @@ const FetchDefault = 0 // Connection pool 
as seen by the session. type sessionPool interface { - Borrow(ctx context.Context, getServerNames func() []string, wait bool, boltLogger log.BoltLogger, livenessCheckThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) + Borrow(ctx context.Context, getServerNames func() []string, wait bool, boltLogger log.BoltLogger, livenessCheckTimeout time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) Return(ctx context.Context, c idb.Connection) CleanUp(ctx context.Context) - Now() time.Time } type sessionWithContext struct { @@ -202,7 +200,6 @@ type sessionWithContext struct { explicitTx *explicitTransaction autocommitTx *autocommitTransaction sleep func(d time.Duration) - now *func() time.Time logId string log log.Logger throttleTime time.Duration @@ -218,7 +215,6 @@ func newSessionWithContext( pool sessionPool, logger log.Logger, token *idb.ReAuthToken, - now *func() time.Time, ) *sessionWithContext { logId := log.NewId() logger.Debugf(log.Session, logId, "Created") @@ -237,7 +233,6 @@ func newSessionWithContext( config: sessConfig, resolveHomeDb: sessConfig.DatabaseName == "", sleep: time.Sleep, - now: now, log: logger, logId: logId, throttleTime: time.Second * 1, @@ -302,7 +297,7 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . } // Get a connection from the pool. This could fail in clustered environment. - conn, err := s.getConnection(ctx, s.defaultMode, pool.DefaultLivenessCheckThreshold) + conn, err := s.getConnection(ctx, s.defaultMode, s.driverConfig.ConnectionLivenessCheckTimeout) if err != nil { return nil, errorutil.WrapError(err) } @@ -416,7 +411,6 @@ func (s *sessionWithContext) runRetriable( Log: s.log, LogName: log.Session, LogId: s.logId, - Now: s.now, Sleep: s.sleep, Throttle: retry.Throttler(s.throttleTime), MaxDeadConnections: s.driverConfig.MaxConnectionPoolSize, @@ -442,7 +436,7 @@ func (s *sessionWithContext) executeTransactionFunction( blockingTxBegin bool, api telemetry.API) (bool, any) { - conn, err := s.getConnection(ctx, mode, pool.DefaultLivenessCheckThreshold) + conn, err := s.getConnection(ctx, mode, s.driverConfig.ConnectionLivenessCheckTimeout) if err != nil { state.OnFailure(ctx, err, conn, false) return false, nil @@ -525,7 +519,7 @@ func (s *sessionWithContext) getServers(mode idb.AccessMode) func() []string { } } -func (s *sessionWithContext) getConnection(ctx context.Context, mode idb.AccessMode, livenessCheckThreshold time.Duration) (idb.Connection, error) { +func (s *sessionWithContext) getConnection(ctx context.Context, mode idb.AccessMode, livenessCheckTimeout time.Duration) (idb.Connection, error) { timeout := s.driverConfig.ConnectionAcquisitionTimeout if timeout > 0 { var cancel context.CancelFunc @@ -552,7 +546,7 @@ func (s *sessionWithContext) getConnection(ctx context.Context, mode idb.AccessM s.getServers(mode), timeout != 0, s.config.BoltLogger, - livenessCheckThreshold, + livenessCheckTimeout, s.auth) if err != nil { return nil, errorutil.WrapError(err) @@ -606,7 +600,7 @@ func (s *sessionWithContext) Run(ctx context.Context, return nil, err } - conn, err := s.getConnection(ctx, s.defaultMode, pool.DefaultLivenessCheckThreshold) + conn, err := s.getConnection(ctx, s.defaultMode, s.driverConfig.ConnectionLivenessCheckTimeout) if err != nil { return nil, errorutil.WrapError(err) } diff --git a/neo4j/session_with_context_test.go b/neo4j/session_with_context_test.go index 5a788ee1..d9b35870 100644 --- a/neo4j/session_with_context_test.go +++ b/neo4j/session_with_context_test.go @@ -47,13 
+47,12 @@ func TestSession(outer *testing.T) { } } - now := time.Now createSession := func() (*RouterFake, *PoolFake, *sessionWithContext) { conf := Config{MaxTransactionRetryTime: 3 * time.Millisecond, MaxConnectionPoolSize: 100} router := RouterFake{} pool := PoolFake{} sessConfig := SessionConfig{AccessMode: AccessModeRead, BoltLogger: boltLogger} - sess := newSessionWithContext(&conf, sessConfig, &router, &pool, logger, nil, &now) + sess := newSessionWithContext(&conf, sessConfig, &router, &pool, logger, nil) sess.throttleTime = time.Millisecond * 1 return &router, &pool, sess } @@ -62,7 +61,7 @@ func TestSession(outer *testing.T) { conf := Config{MaxTransactionRetryTime: 3 * time.Millisecond} router := RouterFake{} pool := PoolFake{} - sess := newSessionWithContext(&conf, sessConfig, &router, &pool, logger, nil, &now) + sess := newSessionWithContext(&conf, sessConfig, &router, &pool, logger, nil) sess.throttleTime = time.Millisecond * 1 return &router, &pool, sess } diff --git a/neo4j/test-integration/dbconn_test.go b/neo4j/test-integration/dbconn_test.go index c77aabf7..d2a17d7b 100644 --- a/neo4j/test-integration/dbconn_test.go +++ b/neo4j/test-integration/dbconn_test.go @@ -78,7 +78,6 @@ func makeRawConnection(ctx context.Context, logger log.Logger, boltLogger log.Bo }, } - timer := time.Now boltConn, err := bolt.Connect( context.Background(), parsedUri.Host, @@ -90,7 +89,6 @@ func makeRawConnection(ctx context.Context, logger log.Logger, boltLogger log.Bo logger, boltLogger, idb.NotificationConfig{}, - &timer, ) if err != nil { panic(err) diff --git a/testkit-backend/backend.go b/testkit-backend/backend.go index 4e41ee1f..42c64e0f 100644 --- a/testkit-backend/backend.go +++ b/testkit-backend/backend.go @@ -60,19 +60,6 @@ type backend struct { suppliedBookmarks map[string]neo4j.Bookmarks consumedBookmarks map[string]struct{} bookmarkManagers map[string]neo4j.BookmarkManager - timer *Timer -} - -type Timer struct { - now time.Time -} - -func (t *Timer) Now() time.Time { - return t.now -} - -func (t *Timer) Tick(duration time.Duration) { - t.now = t.now.Add(duration) } // To implement transactional functions a bit of extra state is needed on the @@ -498,6 +485,9 @@ func (b *backend) handleRequest(req map[string]any) { if data["connectionAcquisitionTimeoutMs"] != nil { c.ConnectionAcquisitionTimeout = time.Millisecond * time.Duration(asInt64(data["connectionAcquisitionTimeoutMs"].(json.Number))) } + if data["livenessCheckTimeoutMs"] != nil { + c.ConnectionLivenessCheckTimeout = time.Millisecond * time.Duration(asInt64(data["livenessCheckTimeoutMs"].(json.Number))) + } if data["maxConnectionPoolSize"] != nil { c.MaxConnectionPoolSize = asInt(data["maxConnectionPoolSize"].(json.Number)) } @@ -535,9 +525,6 @@ func (b *backend) handleRequest(req map[string]any) { b.writeError(err) return } - if b.timer != nil { - neo4j.SetTimer(driver, b.timer.Now) - } idKey := b.nextId() b.drivers[idKey] = driver b.writeResponse("Driver", map[string]any{"id": idKey}) @@ -964,30 +951,25 @@ func (b *backend) handleRequest(req map[string]any) { b.writeResponse("Driver", map[string]any{"id": driverId}) case "FakeTimeInstall": - b.timer = &Timer{ - now: time.Unix(0, 0), - } - for _, driver := range b.drivers { - neo4j.SetTimer(driver, b.timer.Now) - } - for _, manager := range b.authTokenManagers { - auth.SetTimer(manager, b.timer.Now) + if err := neo4j.FreezeTime(); err != nil { + b.writeError(err) + return } b.writeResponse("FakeTimeAck", nil) case "FakeTimeUninstall": - b.timer = nil - for _, driver := 
range b.drivers { - neo4j.ResetTime(driver) - } - for _, manager := range b.authTokenManagers { - auth.ResetTime(manager) + if err := neo4j.UnfreezeTime(); err != nil { + b.writeError(err) + return } b.writeResponse("FakeTimeAck", nil) case "FakeTimeTick": milliseconds := asInt64(data["incrementMs"].(json.Number)) - b.timer.Tick(time.Duration(milliseconds) * time.Millisecond) + if err := neo4j.TickTime(time.Duration(milliseconds) * time.Millisecond); err != nil { + b.writeError(err) + return + } b.writeResponse("FakeTimeAck", nil) case "VerifyAuthentication": @@ -1084,9 +1066,6 @@ func (b *backend) handleRequest(req map[string]any) { } } }) - if b.timer != nil { - auth.SetTimer(manager, b.timer.Now) - } b.authTokenManagers[managerId] = manager b.writeResponse("BasicAuthTokenManager", map[string]any{"id": managerId}) case "BasicAuthTokenProviderCompleted": @@ -1113,9 +1092,6 @@ func (b *backend) handleRequest(req map[string]any) { } } }) - if b.timer != nil { - auth.SetTimer(manager, b.timer.Now) - } b.authTokenManagers[managerId] = manager b.writeResponse("BearerAuthTokenManager", map[string]any{"id": managerId}) case "BearerAuthTokenProviderCompleted": @@ -1127,16 +1103,10 @@ func (b *backend) handleRequest(req map[string]any) { return } var expiration *time.Time - var now func() time.Time - if b.timer != nil { - now = b.timer.Now - } else { - now = time.Now - } expiresInRaw := bearerToken["expiresInMs"] if expiresInRaw != nil { expiresIn := time.Millisecond * time.Duration(asInt64(bearerToken["expiresInMs"].(json.Number))) - expirationTime := now().Add(expiresIn) + expirationTime := neo4j.Now().Add(expiresIn) expiration = &expirationTime } b.resolvedBearerTokens[id] = AuthTokenAndExpiration{token, expiration} @@ -1158,8 +1128,7 @@ func (b *backend) handleRequest(req map[string]any) { "Feature:API:Driver.VerifyAuthentication", "Feature:API:Driver.VerifyConnectivity", //"Feature:API:Driver.SupportsSessionAuth", - // Go driver does not support LivenessCheckTimeout yet - //"Feature:API:Liveness.Check", + "Feature:API:Liveness.Check", "Feature:API:Result.List", "Feature:API:Result.Peek", //"Feature:API:Result.Single", @@ -1549,6 +1518,7 @@ func testSkips() map[string]string { "stub.routing.test_routing_v*.RoutingV*.test_should_accept_routing_table_without_writers_and_then_rediscover": "Driver retries to fetch a routing table up to 100 times if it's empty", "stub.routing.test_routing_v*.RoutingV*.test_should_fail_on_routing_table_with_no_reader": "Driver retries to fetch a routing table up to 100 times if it's empty", "stub.routing.test_routing_v*.RoutingV*.test_should_fail_discovery_when_router_fails_with_unknown_code": "Unify: other drivers have a list of fast failing errors during discover: on anything else, the driver will try the next router", + "stub.routing.test_routing_v*.RoutingV*.test_should_drop_connections_failing_liveness_check": "Liveness check error handling is not (yet) unified: https://github.com/neo-technology/drivers-adr/pull/83", "stub.*.test_0_timeout": "Fixme: driver omits 0 as tx timeout value", "stub.summary.test_summary.TestSummary.test_server_info": "pending unification: should the server address be pre or post DNS resolution?", } diff --git a/testkit/backend.py b/testkit/backend.py index 3386b56d..d0269f11 100644 --- a/testkit/backend.py +++ b/testkit/backend.py @@ -11,5 +11,10 @@ if __name__ == "__main__": backend_path = os.path.join(".", "testkit-backend") - subprocess.check_call(["go", "run", "-tags", "internal_testkit", "-buildvcs=false", backend_path], - 
stdout=sys.stdout, stderr=sys.stderr) + subprocess.check_call( + [ + "go", "run", "-tags", "internal_testkit,internal_time_mock", + "-buildvcs=false", backend_path + ], + stdout=sys.stdout, stderr=sys.stderr + ) diff --git a/testkit/build.py b/testkit/build.py index 2c7166ed..2b6a72cf 100644 --- a/testkit/build.py +++ b/testkit/build.py @@ -21,7 +21,13 @@ def run(args, env=None): defaultEnv["GOFLAGS"] = "-buildvcs=false" print("Building for current target", flush=True) - run(["go", "build", "-tags", "internal_testkit", "-v", "./..."], env=defaultEnv) + run( + [ + "go", "build", "-tags", "internal_testkit,internal_time_mock", + "-v", "./..." + ], + env=defaultEnv + ) # Compile for 32 bits ARM to make sure it builds print("Building for 32 bits", flush=True) @@ -32,13 +38,27 @@ def run(args, env=None): run(["go", "build", "./neo4j/..."], env=arm32Env) print("Vet sources", flush=True) - run(["go", "vet", "-tags", "internal_testkit", "./..."], env=defaultEnv) + run( + [ + "go", "vet", "-tags", "internal_testkit,internal_time_mock", + "./..." + ], + env=defaultEnv + ) print("Install staticcheck", flush=True) - run(["go", "install", "honnef.co/go/tools/cmd/staticcheck@v0.3.3"], env=defaultEnv) + run(["go", "install", "honnef.co/go/tools/cmd/staticcheck@v0.3.3"], + env=defaultEnv) print("Run staticcheck", flush=True) gopath = Path( subprocess.check_output(["go", "env", "GOPATH"]).decode("utf-8").strip() ) - run([str(gopath / "bin" / "staticcheck"), "-tags", "internal_testkit", "./..."], env=defaultEnv) + run( + [ + str(gopath / "bin" / "staticcheck"), + "-tags", "internal_testkit,internal_time_mock", + "./..." + ], + env=defaultEnv + ) diff --git a/testkit/unittests.py b/testkit/unittests.py index 14504c12..088ffb5d 100644 --- a/testkit/unittests.py +++ b/testkit/unittests.py @@ -19,12 +19,17 @@ def run(args): if __name__ == "__main__": # Run explicit set of unit tests to avoid running integration tests # Specify -v -json to make TeamCity pick up the tests - cmd = ["go", "test"] - if os.environ.get("TEST_IN_TEAMCITY", False): - cmd = cmd + ["-v", "-json"] - path = os.path.join(".", "neo4j", "...") - run(cmd + ["-buildvcs=false", "-short", path]) + path = os.path.join(".", "neo4j", "...") + + for extra_args in ( + (), ("-tags", "internal_time_mock") + ): + cmd = ["go", "test", *extra_args] + if os.environ.get("TEST_IN_TEAMCITY", False): + cmd = cmd + ["-v", "-json"] + + run(cmd + ["-buildvcs=false", "-short", path]) # Repeat racing tests run(cmd + ["-buildvcs=false", "-race", "-count", "50",