Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ADR 019 password rotation #523

Merged
merged 6 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 75 additions & 25 deletions neo4j/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ package auth

import (
"context"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/collections"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing"
"reflect"
"time"
Expand Down Expand Up @@ -51,25 +53,29 @@ type TokenManager interface {
// The token returned must always belong to the same identity.
// Switching identities using the `TokenManager` is undefined behavior.
GetAuthToken(ctx context.Context) (auth.Token, error)
// OnTokenExpired is called by the driver when the provided token expires
// OnTokenExpired should invalidate the current token if it matches the provided one
OnTokenExpired(context.Context, auth.Token) error

// HandleSecurityException is called when the server returns any `Neo.ClientError.Security.*` error.
// It should return true if the error was handled, in which case the driver will mark the error as retryable.
HandleSecurityException(context.Context, auth.Token, *db.Neo4jError) (bool, error)
}

type authTokenProvider = func(context.Context) (auth.Token, error)

type authTokenWithExpirationProvider = func(context.Context) (auth.Token, *time.Time, error)

type expirationBasedTokenManager struct {
provider authTokenWithExpirationProvider
token *auth.Token
expiration *time.Time
mutex racing.Mutex
now *func() time.Time
type neo4jAuthTokenManager struct {
provider authTokenWithExpirationProvider
token *auth.Token
expiration *time.Time
mutex racing.Mutex
now *func() time.Time
handledSecurityCodes collections.Set[string]
}

func (m *expirationBasedTokenManager) GetAuthToken(ctx context.Context) (auth.Token, error) {
func (m *neo4jAuthTokenManager) GetAuthToken(ctx context.Context) (auth.Token, error) {
if !m.mutex.TryLock(ctx) {
return auth.Token{}, racing.LockTimeoutError(
"could not acquire lock in time when getting token in ExpirationBasedTokenManager")
"could not acquire lock in time when getting token in neo4jAuthTokenManager")
}
defer m.mutex.Unlock()
if m.token == nil || m.expiration != nil && (*m.now)().After(*m.expiration) {
Expand All @@ -83,34 +89,78 @@ func (m *expirationBasedTokenManager) GetAuthToken(ctx context.Context) (auth.To
return *m.token, nil
}

func (m *expirationBasedTokenManager) OnTokenExpired(ctx context.Context, token auth.Token) error {
func (m *neo4jAuthTokenManager) HandleSecurityException(ctx context.Context, token auth.Token, securityException *db.Neo4jError) (bool, error) {
if !m.handledSecurityCodes.Contains(securityException.Code) {
return false, nil
}
if !m.mutex.TryLock(ctx) {
return racing.LockTimeoutError(
"could not acquire lock in time when handling token expiration in ExpirationBasedTokenManager")
return false, racing.LockTimeoutError(
"could not acquire lock in time when handling security exception in neo4jAuthTokenManager")
}
defer m.mutex.Unlock()
if m.token != nil && reflect.DeepEqual(token.Tokens, m.token.Tokens) {
m.token = nil
}
return nil
return true, nil
}

// ExpirationBasedTokenManager creates a token manager for potentially expiring auth info.
// BasicTokenManager generates a TokenManager to manage basic auth password rotation.
// The provider is invoked solely when a new token instance is required, triggered by server
// rejection of the current token due to an authentication exception.
//
// WARNING:
//
// The first and only argument is a provider function that returns auth information and an optional expiration time.
// If the expiration time is nil, the auth info is assumed to never expire.
// The provider function *must not* interact with the driver in any way as this can cause deadlocks and undefined
// behaviour.
//
// The provider function must only ever return auth information belonging to the same identity.
// Switching identities is undefined behavior.
//
// BasicTokenManager is part of the re-authentication preview feature
// (see README on what it means in terms of support and compatibility guarantees)
func BasicTokenManager(provider authTokenProvider) TokenManager {
now := time.Now
return &neo4jAuthTokenManager{
provider: wrapWithNilExpiration(provider),
mutex: racing.NewMutex(),
now: &now,
handledSecurityCodes: collections.NewSet([]string{
"Neo.ClientError.Security.Unauthorized",
}),
}
}

// BearerTokenManager generates a TokenManager to manage possibly expiring authentication details.
//
// The provider is invoked when a new token instance is required, triggered by server
// rejection of the current token due to authentication or token expiration exceptions.
//
// WARNING:
//
// The provider function *must not* interact with the driver in any way as this can cause deadlocks and undefined
// behaviour.
// The provider function *must not* interact with the driver in any way as this can cause deadlocks and undefined
// behaviour.
//
// The provider function only ever return auth information belonging to the same identity.
// Switching identities is undefined behavior.
// The provider function must only ever return auth information belonging to the same identity.
// Switching identities is undefined behavior.
//
// ExpirationBasedTokenManager is part of the re-authentication preview feature
// BearerTokenManager is part of the re-authentication preview feature
// (see README on what it means in terms of support and compatibility guarantees)
func ExpirationBasedTokenManager(provider authTokenWithExpirationProvider) TokenManager {
func BearerTokenManager(provider authTokenWithExpirationProvider) TokenManager {
now := time.Now
return &expirationBasedTokenManager{provider: provider, mutex: racing.NewMutex(), now: &now}
return &neo4jAuthTokenManager{
provider: provider,
mutex: racing.NewMutex(),
now: &now,
handledSecurityCodes: collections.NewSet([]string{
"Neo.ClientError.Security.TokenExpired",
"Neo.ClientError.Security.Unauthorized",
}),
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
}
}

func wrapWithNilExpiration(provider authTokenProvider) authTokenWithExpirationProvider {
return func(ctx context.Context) (auth.Token, *time.Time, error) {
token, err := provider(ctx)
return token, nil, err
}
}
28 changes: 23 additions & 5 deletions neo4j/auth/auth_examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,41 @@ import (
"time"
)

func ExampleExpirationBasedTokenManager() {
myProvider := func(ctx context.Context) (neo4j.AuthToken, *time.Time, error) {
// some way to getting a token
func ExampleBasicTokenManager() {
fetchBasicAuthToken := func(ctx context.Context) (neo4j.AuthToken, error) {
// some way of getting basic authentication information
username, password, realm, err := getBasicAuth()
if err != nil {
return neo4j.AuthToken{}, err
}
// create and return a basic authentication token with provided username, password and realm
return neo4j.BasicAuth(username, password, realm), nil
}
// create a new driver with a basic token manager which uses provider to handle basic auth password rotation.
_, _ = neo4j.NewDriverWithContext(getUrl(), auth.BasicTokenManager(fetchBasicAuthToken))
}

func ExampleBearerTokenManager() {
fetchAuthTokenFromMyProvider := func(ctx context.Context) (neo4j.AuthToken, *time.Time, error) {
// some way of getting a token
token, err := getSsoToken(ctx)
if err != nil {
return neo4j.AuthToken{}, nil, err
}
// assume we know our tokens expire every 60 seconds

expiresIn := time.Now().Add(60 * time.Second)
// Include a little buffer so that we fetch a new token *before* the old one expires
expiresIn = expiresIn.Add(-10 * time.Second)
// or return nil instead of `&expiresIn` if we don't expect it to expire
return token, &expiresIn, nil
}
// create a new driver with a bearer token manager which uses provider to handle possibly expiring auth tokens.
_, _ = neo4j.NewDriverWithContext(getUrl(), auth.BearerTokenManager(fetchAuthTokenFromMyProvider))
}

_, _ = neo4j.NewDriverWithContext(getUrl(), auth.ExpirationBasedTokenManager(myProvider))
func getBasicAuth() (username, password, realm string, error error) {
username, password, realm = "username", "password", "realm"
return
}

func getSsoToken(context.Context) (neo4j.AuthToken, error) {
Expand Down
8 changes: 4 additions & 4 deletions neo4j/auth/auth_testkit.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build internal_testkit

/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [https://neo4j.com]
Expand All @@ -17,20 +19,18 @@
* limitations under the License.
*/

//go:build internal_testkit

package auth

import "time"

func SetTimer(t TokenManager, timer func() time.Time) {
if t, ok := t.(*expirationBasedTokenManager); ok {
if t, ok := t.(*neo4jAuthTokenManager); ok {
t.now = &timer
}
}

func ResetTime(t TokenManager) {
if t, ok := t.(*expirationBasedTokenManager); ok {
if t, ok := t.(*neo4jAuthTokenManager); ok {
now := time.Now
t.now = &now
}
Expand Down
4 changes: 4 additions & 0 deletions neo4j/db/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ func (e *Neo4jError) reclassify() {
}
}

func (e *Neo4jError) HasSecurityCode() bool {
return strings.HasPrefix(e.Code, "Neo.ClientError.Security.")
}

func (e *Neo4jError) IsAuthenticationFailed() bool {
return e.Code == "Neo.ClientError.Security.Unauthorized"
}
Expand Down
9 changes: 7 additions & 2 deletions neo4j/internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

package auth

import "context"
import (
"context"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/db"
)

type Token struct {
Tokens map[string]any
Expand All @@ -29,4 +32,6 @@ func (a Token) GetAuthToken(context.Context) (Token, error) {
return a, nil
}

func (a Token) OnTokenExpired(context.Context, Token) error { return nil }
func (a Token) HandleSecurityException(context.Context, Token, *db.Neo4jError) (bool, error) {
return false, nil
}
5 changes: 5 additions & 0 deletions neo4j/internal/collections/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,8 @@ func (set Set[T]) Copy() Set[T] {
}
return result
}

func (set Set[T]) Contains(value T) bool {
_, ok := set[value]
return ok
}
22 changes: 22 additions & 0 deletions neo4j/internal/collections/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,28 @@ func TestSet(outer *testing.T) {
t.Error(err)
}
})

outer.Run("contains", func(t *testing.T) {
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
strings := collections.NewSet([]string{
"golang",
"neo4j",
})
expected := "golang"
if found := strings.Contains(expected); !found {
t.Errorf("Set should have contained %v", expected)
}
})

outer.Run("does not contain", func(t *testing.T) {
strings := collections.NewSet([]string{
"golang",
"neo4j",
})
expected := "foobar"
if found := strings.Contains(expected); found {
t.Errorf("Set should not have contain %v", expected)
}
})
}

func containsExactlyOnce[T comparable](values collections.Set[T], search T) bool {
Expand Down
32 changes: 16 additions & 16 deletions neo4j/internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"context"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/config"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt"
idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil"
Expand Down Expand Up @@ -453,35 +452,36 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) {
}

func (p *Pool) OnNeo4jError(ctx context.Context, connection idb.Connection, error *db.Neo4jError) error {
switch error.Code {
case "Neo.ClientError.Security.AuthorizationExpired":
if error.Code == "Neo.ClientError.Security.AuthorizationExpired" {
serverName := connection.ServerName()
p.serversMut.Lock()
defer p.serversMut.Unlock()
server := p.servers[serverName]
server.executeForAllConnections(func(c idb.Connection) {
c.ResetAuth()
})
case "Neo.ClientError.Security.TokenExpired":
}
if error.Code == "Neo.TransientError.General.DatabaseUnavailable" {
p.deactivate(ctx, connection.ServerName())
}
if error.IsRetriableCluster() {
var database string
if dbSelector, ok := connection.(idb.DatabaseSelector); ok {
database = dbSelector.Database()
}
p.deactivateWriter(connection.ServerName(), database)
}
if error.HasSecurityCode() {
manager, token := connection.GetCurrentAuth()
if manager != nil {
if err := manager.OnTokenExpired(ctx, token); err != nil {
handled, err := manager.HandleSecurityException(ctx, token, error)
if err != nil {
return err
}
if _, isStaticToken := manager.(auth.Token); !isStaticToken {
if handled {
error.MarkRetriable()
}
}
case "Neo.TransientError.General.DatabaseUnavailable":
p.deactivate(ctx, connection.ServerName())
default:
if error.IsRetriableCluster() {
var database string
if dbSelector, ok := connection.(idb.DatabaseSelector); ok {
database = dbSelector.Database()
}
p.deactivateWriter(connection.ServerName(), database)
}
}

return nil
Expand Down
Loading