Skip to content

Commit

Permalink
ADR 019 password rotation (#523)
Browse files Browse the repository at this point in the history
* ADR 019 Password Rotation
  • Loading branch information
StephenCathcart committed Aug 30, 2023
1 parent 553515c commit a8a15a5
Show file tree
Hide file tree
Showing 9 changed files with 261 additions and 116 deletions.
100 changes: 75 additions & 25 deletions neo4j/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ package auth

import (
"context"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/collections"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing"
"reflect"
"time"
Expand Down Expand Up @@ -51,25 +53,29 @@ type TokenManager interface {
// The token returned must always belong to the same identity.
// Switching identities using the `TokenManager` is undefined behavior.
GetAuthToken(ctx context.Context) (auth.Token, error)
// OnTokenExpired is called by the driver when the provided token expires
// OnTokenExpired should invalidate the current token if it matches the provided one
OnTokenExpired(context.Context, auth.Token) error

// HandleSecurityException is called when the server returns any `Neo.ClientError.Security.*` error.
// It should return true if the error was handled, in which case the driver will mark the error as retryable.
HandleSecurityException(context.Context, auth.Token, *db.Neo4jError) (bool, error)
}

type authTokenProvider = func(context.Context) (auth.Token, error)

type authTokenWithExpirationProvider = func(context.Context) (auth.Token, *time.Time, error)

type expirationBasedTokenManager struct {
provider authTokenWithExpirationProvider
token *auth.Token
expiration *time.Time
mutex racing.Mutex
now *func() time.Time
type neo4jAuthTokenManager struct {
provider authTokenWithExpirationProvider
token *auth.Token
expiration *time.Time
mutex racing.Mutex
now *func() time.Time
handledSecurityCodes collections.Set[string]
}

func (m *expirationBasedTokenManager) GetAuthToken(ctx context.Context) (auth.Token, error) {
func (m *neo4jAuthTokenManager) GetAuthToken(ctx context.Context) (auth.Token, error) {
if !m.mutex.TryLock(ctx) {
return auth.Token{}, racing.LockTimeoutError(
"could not acquire lock in time when getting token in ExpirationBasedTokenManager")
"could not acquire lock in time when getting token in neo4jAuthTokenManager")
}
defer m.mutex.Unlock()
if m.token == nil || m.expiration != nil && (*m.now)().After(*m.expiration) {
Expand All @@ -83,34 +89,78 @@ func (m *expirationBasedTokenManager) GetAuthToken(ctx context.Context) (auth.To
return *m.token, nil
}

func (m *expirationBasedTokenManager) OnTokenExpired(ctx context.Context, token auth.Token) error {
func (m *neo4jAuthTokenManager) HandleSecurityException(ctx context.Context, token auth.Token, securityException *db.Neo4jError) (bool, error) {
if !m.handledSecurityCodes.Contains(securityException.Code) {
return false, nil
}
if !m.mutex.TryLock(ctx) {
return racing.LockTimeoutError(
"could not acquire lock in time when handling token expiration in ExpirationBasedTokenManager")
return false, racing.LockTimeoutError(
"could not acquire lock in time when handling security exception in neo4jAuthTokenManager")
}
defer m.mutex.Unlock()
if m.token != nil && reflect.DeepEqual(token.Tokens, m.token.Tokens) {
m.token = nil
}
return nil
return true, nil
}

// ExpirationBasedTokenManager creates a token manager for potentially expiring auth info.
// BasicTokenManager generates a TokenManager to manage basic auth password rotation.
// The provider is invoked solely when a new token instance is required, triggered by server
// rejection of the current token due to an authentication exception.
//
// WARNING:
//
// The first and only argument is a provider function that returns auth information and an optional expiration time.
// If the expiration time is nil, the auth info is assumed to never expire.
// The provider function *must not* interact with the driver in any way as this can cause deadlocks and undefined
// behaviour.
//
// The provider function must only ever return auth information belonging to the same identity.
// Switching identities is undefined behavior.
//
// BasicTokenManager is part of the re-authentication preview feature
// (see README on what it means in terms of support and compatibility guarantees)
func BasicTokenManager(provider authTokenProvider) TokenManager {
now := time.Now
return &neo4jAuthTokenManager{
provider: wrapWithNilExpiration(provider),
mutex: racing.NewMutex(),
now: &now,
handledSecurityCodes: collections.NewSet([]string{
"Neo.ClientError.Security.Unauthorized",
}),
}
}

// BearerTokenManager generates a TokenManager to manage possibly expiring authentication details.
//
// The provider is invoked when a new token instance is required, triggered by server
// rejection of the current token due to authentication or token expiration exceptions.
//
// WARNING:
//
// The provider function *must not* interact with the driver in any way as this can cause deadlocks and undefined
// behaviour.
// The provider function *must not* interact with the driver in any way as this can cause deadlocks and undefined
// behaviour.
//
// The provider function only ever return auth information belonging to the same identity.
// Switching identities is undefined behavior.
// The provider function must only ever return auth information belonging to the same identity.
// Switching identities is undefined behavior.
//
// ExpirationBasedTokenManager is part of the re-authentication preview feature
// BearerTokenManager is part of the re-authentication preview feature
// (see README on what it means in terms of support and compatibility guarantees)
func ExpirationBasedTokenManager(provider authTokenWithExpirationProvider) TokenManager {
func BearerTokenManager(provider authTokenWithExpirationProvider) TokenManager {
now := time.Now
return &expirationBasedTokenManager{provider: provider, mutex: racing.NewMutex(), now: &now}
return &neo4jAuthTokenManager{
provider: provider,
mutex: racing.NewMutex(),
now: &now,
handledSecurityCodes: collections.NewSet([]string{
"Neo.ClientError.Security.TokenExpired",
"Neo.ClientError.Security.Unauthorized",
}),
}
}

func wrapWithNilExpiration(provider authTokenProvider) authTokenWithExpirationProvider {
return func(ctx context.Context) (auth.Token, *time.Time, error) {
token, err := provider(ctx)
return token, nil, err
}
}
28 changes: 23 additions & 5 deletions neo4j/auth/auth_examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,41 @@ import (
"time"
)

func ExampleExpirationBasedTokenManager() {
myProvider := func(ctx context.Context) (neo4j.AuthToken, *time.Time, error) {
// some way to getting a token
func ExampleBasicTokenManager() {
fetchBasicAuthToken := func(ctx context.Context) (neo4j.AuthToken, error) {
// some way of getting basic authentication information
username, password, realm, err := getBasicAuth()
if err != nil {
return neo4j.AuthToken{}, err
}
// create and return a basic authentication token with provided username, password and realm
return neo4j.BasicAuth(username, password, realm), nil
}
// create a new driver with a basic token manager which uses provider to handle basic auth password rotation.
_, _ = neo4j.NewDriverWithContext(getUrl(), auth.BasicTokenManager(fetchBasicAuthToken))
}

func ExampleBearerTokenManager() {
fetchAuthTokenFromMyProvider := func(ctx context.Context) (neo4j.AuthToken, *time.Time, error) {
// some way of getting a token
token, err := getSsoToken(ctx)
if err != nil {
return neo4j.AuthToken{}, nil, err
}
// assume we know our tokens expire every 60 seconds

expiresIn := time.Now().Add(60 * time.Second)
// Include a little buffer so that we fetch a new token *before* the old one expires
expiresIn = expiresIn.Add(-10 * time.Second)
// or return nil instead of `&expiresIn` if we don't expect it to expire
return token, &expiresIn, nil
}
// create a new driver with a bearer token manager which uses provider to handle possibly expiring auth tokens.
_, _ = neo4j.NewDriverWithContext(getUrl(), auth.BearerTokenManager(fetchAuthTokenFromMyProvider))
}

_, _ = neo4j.NewDriverWithContext(getUrl(), auth.ExpirationBasedTokenManager(myProvider))
func getBasicAuth() (username, password, realm string, error error) {
username, password, realm = "username", "password", "realm"
return
}

func getSsoToken(context.Context) (neo4j.AuthToken, error) {
Expand Down
8 changes: 4 additions & 4 deletions neo4j/auth/auth_testkit.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build internal_testkit

/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [https://neo4j.com]
Expand All @@ -17,20 +19,18 @@
* limitations under the License.
*/

//go:build internal_testkit

package auth

import "time"

func SetTimer(t TokenManager, timer func() time.Time) {
if t, ok := t.(*expirationBasedTokenManager); ok {
if t, ok := t.(*neo4jAuthTokenManager); ok {
t.now = &timer
}
}

func ResetTime(t TokenManager) {
if t, ok := t.(*expirationBasedTokenManager); ok {
if t, ok := t.(*neo4jAuthTokenManager); ok {
now := time.Now
t.now = &now
}
Expand Down
4 changes: 4 additions & 0 deletions neo4j/db/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ func (e *Neo4jError) reclassify() {
}
}

func (e *Neo4jError) HasSecurityCode() bool {
return strings.HasPrefix(e.Code, "Neo.ClientError.Security.")
}

func (e *Neo4jError) IsAuthenticationFailed() bool {
return e.Code == "Neo.ClientError.Security.Unauthorized"
}
Expand Down
9 changes: 7 additions & 2 deletions neo4j/internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

package auth

import "context"
import (
"context"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/db"
)

type Token struct {
Tokens map[string]any
Expand All @@ -29,4 +32,6 @@ func (a Token) GetAuthToken(context.Context) (Token, error) {
return a, nil
}

func (a Token) OnTokenExpired(context.Context, Token) error { return nil }
func (a Token) HandleSecurityException(context.Context, Token, *db.Neo4jError) (bool, error) {
return false, nil
}
5 changes: 5 additions & 0 deletions neo4j/internal/collections/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,8 @@ func (set Set[T]) Copy() Set[T] {
}
return result
}

func (set Set[T]) Contains(value T) bool {
_, ok := set[value]
return ok
}
22 changes: 22 additions & 0 deletions neo4j/internal/collections/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,28 @@ func TestSet(outer *testing.T) {
t.Error(err)
}
})

outer.Run("contains", func(t *testing.T) {
strings := collections.NewSet([]string{
"golang",
"neo4j",
})
expected := "golang"
if found := strings.Contains(expected); !found {
t.Errorf("Set should have contained %v", expected)
}
})

outer.Run("does not contain", func(t *testing.T) {
strings := collections.NewSet([]string{
"golang",
"neo4j",
})
expected := "foobar"
if found := strings.Contains(expected); found {
t.Errorf("Set should not have contain %v", expected)
}
})
}

func containsExactlyOnce[T comparable](values collections.Set[T], search T) bool {
Expand Down
32 changes: 16 additions & 16 deletions neo4j/internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"context"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/config"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt"
idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil"
Expand Down Expand Up @@ -453,35 +452,36 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) {
}

func (p *Pool) OnNeo4jError(ctx context.Context, connection idb.Connection, error *db.Neo4jError) error {
switch error.Code {
case "Neo.ClientError.Security.AuthorizationExpired":
if error.Code == "Neo.ClientError.Security.AuthorizationExpired" {
serverName := connection.ServerName()
p.serversMut.Lock()
defer p.serversMut.Unlock()
server := p.servers[serverName]
server.executeForAllConnections(func(c idb.Connection) {
c.ResetAuth()
})
case "Neo.ClientError.Security.TokenExpired":
}
if error.Code == "Neo.TransientError.General.DatabaseUnavailable" {
p.deactivate(ctx, connection.ServerName())
}
if error.IsRetriableCluster() {
var database string
if dbSelector, ok := connection.(idb.DatabaseSelector); ok {
database = dbSelector.Database()
}
p.deactivateWriter(connection.ServerName(), database)
}
if error.HasSecurityCode() {
manager, token := connection.GetCurrentAuth()
if manager != nil {
if err := manager.OnTokenExpired(ctx, token); err != nil {
handled, err := manager.HandleSecurityException(ctx, token, error)
if err != nil {
return err
}
if _, isStaticToken := manager.(auth.Token); !isStaticToken {
if handled {
error.MarkRetriable()
}
}
case "Neo.TransientError.General.DatabaseUnavailable":
p.deactivate(ctx, connection.ServerName())
default:
if error.IsRetriableCluster() {
var database string
if dbSelector, ok := connection.(idb.DatabaseSelector); ok {
database = dbSelector.Database()
}
p.deactivateWriter(connection.ServerName(), database)
}
}

return nil
Expand Down
Loading

0 comments on commit a8a15a5

Please sign in to comment.