diff --git a/.github/wordlist.txt b/.github/wordlist.txt index 578616b9d..a922d99ba 100644 --- a/.github/wordlist.txt +++ b/.github/wordlist.txt @@ -65,4 +65,12 @@ RedisGears RedisTimeseries RediSearch RawResult -RawVal \ No newline at end of file +RawVal +entra +EntraID +Entra +OAuth +Azure +StreamingCredentialsProvider +oauth +entraid \ No newline at end of file diff --git a/.gitignore b/.gitignore index e9c8f5264..0d99709e3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ testdata/* redis8tests.sh coverage.txt **/coverage.txt -.vscode \ No newline at end of file +.vscode +tmp/* diff --git a/README.md b/README.md index 4487c6e9a..265712472 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ key value NoSQL database that uses RocksDB as storage engine and is compatible w - Redis commands except QUIT and SYNC. - Automatic connection pooling. +- [StreamingCredentialsProvider (e.g. entra id, oauth)](#1-streaming-credentials-provider-highest-priority) - [Pub/Sub](https://redis.uptrace.dev/guide/go-redis-pubsub.html). - [Pipelines and transactions](https://redis.uptrace.dev/guide/go-redis-pipelines.html). - [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html). @@ -136,17 +137,121 @@ func ExampleClient() { } ``` -The above can be modified to specify the version of the RESP protocol by adding the `protocol` -option to the `Options` struct: +### Authentication + +The Redis client supports multiple ways to provide authentication credentials, with a clear priority order. Here are the available options: + +#### 1. Streaming Credentials Provider (Highest Priority) + +The streaming credentials provider allows for dynamic credential updates during the connection lifetime. This is particularly useful for managed identity services and token-based authentication. ```go - rdb := redis.NewClient(&redis.Options{ - Addr: "localhost:6379", - Password: "", // no password set - DB: 0, // use default DB - Protocol: 3, // specify 2 for RESP 2 or 3 for RESP 3 - }) +type StreamingCredentialsProvider interface { + Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) +} + +type CredentialsListener interface { + OnNext(credentials Credentials) // Called when credentials are updated + OnError(err error) // Called when an error occurs +} + +type Credentials interface { + BasicAuth() (username string, password string) + RawCredentials() string +} +``` + +Example usage: +```go +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + StreamingCredentialsProvider: &MyCredentialsProvider{}, +}) +``` + +**Note:** The streaming credentials provider can be used with [go-redis-entraid](https://github.com/redis-developer/go-redis-entraid) to enable Entra ID (formerly Azure AD) authentication. This allows for seamless integration with Azure's managed identity services and token-based authentication. + +Example with Entra ID: +```go +import ( + "github.com/redis/go-redis/v9" + "github.com/redis-developer/go-redis-entraid" +) + +// Create an Entra ID credentials provider +provider := entraid.NewDefaultAzureIdentityProvider() + +// Configure Redis client with Entra ID authentication +rdb := redis.NewClient(&redis.Options{ + Addr: "your-redis-server.redis.cache.windows.net:6380", + StreamingCredentialsProvider: provider, + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, +}) +``` +#### 2. Context-based Credentials Provider + +The context-based provider allows credentials to be determined at the time of each operation, using the context. + +```go +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + CredentialsProviderContext: func(ctx context.Context) (string, string, error) { + // Return username, password, and any error + return "user", "pass", nil + }, +}) +``` + +#### 3. Regular Credentials Provider + +A simple function-based provider that returns static credentials. + +```go +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + CredentialsProvider: func() (string, string) { + // Return username and password + return "user", "pass" + }, +}) +``` + +#### 4. Username/Password Fields (Lowest Priority) + +The most basic way to provide credentials is through the `Username` and `Password` fields in the options. + +```go +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Username: "user", + Password: "pass", +}) +``` + +#### Priority Order + +The client will use credentials in the following priority order: +1. Streaming Credentials Provider (if set) +2. Context-based Credentials Provider (if set) +3. Regular Credentials Provider (if set) +4. Username/Password fields (if set) + +If none of these are set, the client will attempt to connect without authentication. + +### Protocol Version + +The client supports both RESP2 and RESP3 protocols. You can specify the protocol version in the options: + +```go +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", // no password set + DB: 0, // use default DB + Protocol: 3, // specify 2 for RESP 2 or 3 for RESP 3 +}) ``` ### Connecting via a redis url diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 000000000..1f5c80224 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,61 @@ +// Package auth package provides authentication-related interfaces and types. +// It also includes a basic implementation of credentials using username and password. +package auth + +// StreamingCredentialsProvider is an interface that defines the methods for a streaming credentials provider. +// It is used to provide credentials for authentication. +// The CredentialsListener is used to receive updates when the credentials change. +type StreamingCredentialsProvider interface { + // Subscribe subscribes to the credentials provider for updates. + // It returns the current credentials, a cancel function to unsubscribe from the provider, + // and an error if any. + // TODO(ndyakov): Should we add context to the Subscribe method? + Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) +} + +// UnsubscribeFunc is a function that is used to cancel the subscription to the credentials provider. +// It is used to unsubscribe from the provider when the credentials are no longer needed. +type UnsubscribeFunc func() error + +// CredentialsListener is an interface that defines the methods for a credentials listener. +// It is used to receive updates when the credentials change. +// The OnNext method is called when the credentials change. +// The OnError method is called when an error occurs while requesting the credentials. +type CredentialsListener interface { + OnNext(credentials Credentials) + OnError(err error) +} + +// Credentials is an interface that defines the methods for credentials. +// It is used to provide the credentials for authentication. +type Credentials interface { + // BasicAuth returns the username and password for basic authentication. + BasicAuth() (username string, password string) + // RawCredentials returns the raw credentials as a string. + // This can be used to extract the username and password from the raw credentials or + // additional information if present in the token. + RawCredentials() string +} + +type basicAuth struct { + username string + password string +} + +// RawCredentials returns the raw credentials as a string. +func (b *basicAuth) RawCredentials() string { + return b.username + ":" + b.password +} + +// BasicAuth returns the username and password for basic authentication. +func (b *basicAuth) BasicAuth() (username string, password string) { + return b.username, b.password +} + +// NewBasicCredentials creates a new Credentials object from the given username and password. +func NewBasicCredentials(username, password string) Credentials { + return &basicAuth{ + username: username, + password: password, + } +} diff --git a/auth/auth_test.go b/auth/auth_test.go new file mode 100644 index 000000000..be762a854 --- /dev/null +++ b/auth/auth_test.go @@ -0,0 +1,308 @@ +package auth + +import ( + "errors" + "sync" + "testing" + "time" +) + +type mockStreamingProvider struct { + credentials Credentials + err error + updates chan Credentials +} + +func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider { + return &mockStreamingProvider{ + credentials: initialCreds, + updates: make(chan Credentials, 10), + } +} + +func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) { + if m.err != nil { + return nil, nil, m.err + } + + // Send initial credentials + listener.OnNext(m.credentials) + + // Start goroutine to handle updates + go func() { + for creds := range m.updates { + listener.OnNext(creds) + } + }() + + return m.credentials, func() error { + close(m.updates) + return nil + }, nil +} + +func TestStreamingCredentialsProvider(t *testing.T) { + t.Run("successful subscription", func(t *testing.T) { + initialCreds := NewBasicCredentials("user1", "pass1") + provider := newMockStreamingProvider(initialCreds) + + var receivedCreds []Credentials + var receivedErrors []error + var mu sync.Mutex + + listener := NewReAuthCredentialsListener( + func(creds Credentials) error { + mu.Lock() + receivedCreds = append(receivedCreds, creds) + mu.Unlock() + return nil + }, + func(err error) { + receivedErrors = append(receivedErrors, err) + }, + ) + + creds, cancel, err := provider.Subscribe(listener) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cancel == nil { + t.Fatal("expected cancel function to be non-nil") + } + if creds != initialCreds { + t.Fatalf("expected credentials %v, got %v", initialCreds, creds) + } + if len(receivedCreds) != 1 { + t.Fatalf("expected 1 received credential, got %d", len(receivedCreds)) + } + if receivedCreds[0] != initialCreds { + t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0]) + } + if len(receivedErrors) != 0 { + t.Fatalf("expected no errors, got %d", len(receivedErrors)) + } + + // Send an update + newCreds := NewBasicCredentials("user2", "pass2") + provider.updates <- newCreds + + // Wait for update to be processed + time.Sleep(100 * time.Millisecond) + mu.Lock() + if len(receivedCreds) != 2 { + t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds)) + } + if receivedCreds[1] != newCreds { + t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1]) + } + mu.Unlock() + + // Cancel subscription + if err := cancel(); err != nil { + t.Fatalf("unexpected error cancelling subscription: %v", err) + } + }) + + t.Run("subscription error", func(t *testing.T) { + provider := &mockStreamingProvider{ + err: errors.New("subscription failed"), + } + + var receivedCreds []Credentials + var receivedErrors []error + + listener := NewReAuthCredentialsListener( + func(creds Credentials) error { + receivedCreds = append(receivedCreds, creds) + return nil + }, + func(err error) { + receivedErrors = append(receivedErrors, err) + }, + ) + + creds, cancel, err := provider.Subscribe(listener) + if err == nil { + t.Fatal("expected error, got nil") + } + if cancel != nil { + t.Fatal("expected cancel function to be nil") + } + if creds != nil { + t.Fatalf("expected nil credentials, got %v", creds) + } + if len(receivedCreds) != 0 { + t.Fatalf("expected no received credentials, got %d", len(receivedCreds)) + } + if len(receivedErrors) != 0 { + t.Fatalf("expected no errors, got %d", len(receivedErrors)) + } + }) + + t.Run("re-auth error", func(t *testing.T) { + initialCreds := NewBasicCredentials("user1", "pass1") + provider := newMockStreamingProvider(initialCreds) + + reauthErr := errors.New("re-auth failed") + var receivedErrors []error + + listener := NewReAuthCredentialsListener( + func(creds Credentials) error { + return reauthErr + }, + func(err error) { + receivedErrors = append(receivedErrors, err) + }, + ) + + creds, cancel, err := provider.Subscribe(listener) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cancel == nil { + t.Fatal("expected cancel function to be non-nil") + } + if creds != initialCreds { + t.Fatalf("expected credentials %v, got %v", initialCreds, creds) + } + if len(receivedErrors) != 1 { + t.Fatalf("expected 1 error, got %d", len(receivedErrors)) + } + if receivedErrors[0] != reauthErr { + t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0]) + } + + if err := cancel(); err != nil { + t.Fatalf("unexpected error cancelling subscription: %v", err) + } + }) +} + +func TestBasicCredentials(t *testing.T) { + t.Run("basic auth", func(t *testing.T) { + creds := NewBasicCredentials("user1", "pass1") + username, password := creds.BasicAuth() + if username != "user1" { + t.Fatalf("expected username 'user1', got '%s'", username) + } + if password != "pass1" { + t.Fatalf("expected password 'pass1', got '%s'", password) + } + }) + + t.Run("raw credentials", func(t *testing.T) { + creds := NewBasicCredentials("user1", "pass1") + raw := creds.RawCredentials() + expected := "user1:pass1" + if raw != expected { + t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw) + } + }) + + t.Run("empty username", func(t *testing.T) { + creds := NewBasicCredentials("", "pass1") + username, password := creds.BasicAuth() + if username != "" { + t.Fatalf("expected empty username, got '%s'", username) + } + if password != "pass1" { + t.Fatalf("expected password 'pass1', got '%s'", password) + } + }) +} + +func TestReAuthCredentialsListener(t *testing.T) { + t.Run("successful re-auth", func(t *testing.T) { + var reAuthCalled bool + var onErrCalled bool + var receivedCreds Credentials + + listener := NewReAuthCredentialsListener( + func(creds Credentials) error { + reAuthCalled = true + receivedCreds = creds + return nil + }, + func(err error) { + onErrCalled = true + }, + ) + + creds := NewBasicCredentials("user1", "pass1") + listener.OnNext(creds) + + if !reAuthCalled { + t.Fatal("expected reAuth to be called") + } + if onErrCalled { + t.Fatal("expected onErr not to be called") + } + if receivedCreds != creds { + t.Fatalf("expected credentials %v, got %v", creds, receivedCreds) + } + }) + + t.Run("re-auth error", func(t *testing.T) { + var reAuthCalled bool + var onErrCalled bool + var receivedErr error + expectedErr := errors.New("re-auth failed") + + listener := NewReAuthCredentialsListener( + func(creds Credentials) error { + reAuthCalled = true + return expectedErr + }, + func(err error) { + onErrCalled = true + receivedErr = err + }, + ) + + creds := NewBasicCredentials("user1", "pass1") + listener.OnNext(creds) + + if !reAuthCalled { + t.Fatal("expected reAuth to be called") + } + if !onErrCalled { + t.Fatal("expected onErr to be called") + } + if receivedErr != expectedErr { + t.Fatalf("expected error %v, got %v", expectedErr, receivedErr) + } + }) + + t.Run("on error", func(t *testing.T) { + var onErrCalled bool + var receivedErr error + expectedErr := errors.New("provider error") + + listener := NewReAuthCredentialsListener( + func(creds Credentials) error { + return nil + }, + func(err error) { + onErrCalled = true + receivedErr = err + }, + ) + + listener.OnError(expectedErr) + + if !onErrCalled { + t.Fatal("expected onErr to be called") + } + if receivedErr != expectedErr { + t.Fatalf("expected error %v, got %v", expectedErr, receivedErr) + } + }) + + t.Run("nil callbacks", func(t *testing.T) { + listener := NewReAuthCredentialsListener(nil, nil) + + // Should not panic + listener.OnNext(NewBasicCredentials("user1", "pass1")) + listener.OnError(errors.New("test error")) + }) +} diff --git a/auth/reauth_credentials_listener.go b/auth/reauth_credentials_listener.go new file mode 100644 index 000000000..40076a0b1 --- /dev/null +++ b/auth/reauth_credentials_listener.go @@ -0,0 +1,47 @@ +package auth + +// ReAuthCredentialsListener is a struct that implements the CredentialsListener interface. +// It is used to re-authenticate the credentials when they are updated. +// It contains: +// - reAuth: a function that takes the new credentials and returns an error if any. +// - onErr: a function that takes an error and handles it. +type ReAuthCredentialsListener struct { + reAuth func(credentials Credentials) error + onErr func(err error) +} + +// OnNext is called when the credentials are updated. +// It calls the reAuth function with the new credentials. +// If the reAuth function returns an error, it calls the onErr function with the error. +func (c *ReAuthCredentialsListener) OnNext(credentials Credentials) { + if c.reAuth == nil { + return + } + + err := c.reAuth(credentials) + if err != nil { + c.OnError(err) + } +} + +// OnError is called when an error occurs. +// It can be called from both the credentials provider and the reAuth function. +func (c *ReAuthCredentialsListener) OnError(err error) { + if c.onErr == nil { + return + } + + c.onErr(err) +} + +// NewReAuthCredentialsListener creates a new ReAuthCredentialsListener. +// Implements the auth.CredentialsListener interface. +func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, onErr func(err error)) *ReAuthCredentialsListener { + return &ReAuthCredentialsListener{ + reAuth: reAuth, + onErr: onErr, + } +} + +// Ensure ReAuthCredentialsListener implements the CredentialsListener interface. +var _ CredentialsListener = (*ReAuthCredentialsListener)(nil) diff --git a/command_recorder_test.go b/command_recorder_test.go new file mode 100644 index 000000000..2251df5ef --- /dev/null +++ b/command_recorder_test.go @@ -0,0 +1,86 @@ +package redis_test + +import ( + "context" + "strings" + "sync" + + "github.com/redis/go-redis/v9" +) + +// commandRecorder records the last N commands executed by a Redis client. +type commandRecorder struct { + mu sync.Mutex + commands []string + maxSize int +} + +// newCommandRecorder creates a new command recorder with the specified maximum size. +func newCommandRecorder(maxSize int) *commandRecorder { + return &commandRecorder{ + commands: make([]string, 0, maxSize), + maxSize: maxSize, + } +} + +// Record adds a command to the recorder. +func (r *commandRecorder) Record(cmd string) { + cmd = strings.ToLower(cmd) + r.mu.Lock() + defer r.mu.Unlock() + + r.commands = append(r.commands, cmd) + if len(r.commands) > r.maxSize { + r.commands = r.commands[1:] + } +} + +// LastCommands returns a copy of the recorded commands. +func (r *commandRecorder) LastCommands() []string { + r.mu.Lock() + defer r.mu.Unlock() + return append([]string(nil), r.commands...) +} + +// Contains checks if the recorder contains a specific command. +func (r *commandRecorder) Contains(cmd string) bool { + cmd = strings.ToLower(cmd) + r.mu.Lock() + defer r.mu.Unlock() + for _, c := range r.commands { + if strings.Contains(c, cmd) { + return true + } + } + return false +} + +// Hook returns a Redis hook that records commands. +func (r *commandRecorder) Hook() redis.Hook { + return &commandHook{recorder: r} +} + +// commandHook implements the redis.Hook interface to record commands. +type commandHook struct { + recorder *commandRecorder +} + +func (h *commandHook) DialHook(next redis.DialHook) redis.DialHook { + return next +} + +func (h *commandHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + h.recorder.Record(cmd.String()) + return next(ctx, cmd) + } +} + +func (h *commandHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + for _, cmd := range cmds { + h.recorder.Record(cmd.String()) + } + return next(ctx, cmds) + } +} diff --git a/doctests/lpush_lrange_test.go b/doctests/lpush_lrange_test.go index 1e69f4b0a..4c5a03a38 100644 --- a/doctests/lpush_lrange_test.go +++ b/doctests/lpush_lrange_test.go @@ -5,6 +5,7 @@ package example_commands_test import ( "context" "fmt" + "time" "github.com/redis/go-redis/v9" ) @@ -33,6 +34,7 @@ func ExampleClient_LPush_and_lrange() { } fmt.Println(listSize) + time.Sleep(10 * time.Millisecond) // Simulate some delay value, err := rdb.LRange(ctx, "my_bikes", 0, -1).Result() if err != nil { diff --git a/example_instrumentation_test.go b/example_instrumentation_test.go index a6069cf3f..36234ff09 100644 --- a/example_instrumentation_test.go +++ b/example_instrumentation_test.go @@ -23,38 +23,47 @@ func (redisHook) DialHook(hook redis.DialHook) redis.DialHook { func (redisHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook { return func(ctx context.Context, cmd redis.Cmder) error { - fmt.Printf("starting processing: <%s>\n", cmd) + fmt.Printf("starting processing: <%v>\n", cmd.Args()) err := hook(ctx, cmd) - fmt.Printf("finished processing: <%s>\n", cmd) + fmt.Printf("finished processing: <%v>\n", cmd.Args()) return err } } func (redisHook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { - fmt.Printf("pipeline starting processing: %v\n", cmds) + names := make([]string, 0, len(cmds)) + for _, cmd := range cmds { + names = append(names, fmt.Sprintf("%v", cmd.Args())) + } + fmt.Printf("pipeline starting processing: %v\n", names) err := hook(ctx, cmds) - fmt.Printf("pipeline finished processing: %v\n", cmds) + fmt.Printf("pipeline finished processing: %v\n", names) return err } } func Example_instrumentation() { rdb := redis.NewClient(&redis.Options{ - Addr: ":6379", + Addr: ":6379", + DisableIdentity: true, }) rdb.AddHook(redisHook{}) rdb.Ping(ctx) - // Output: starting processing: <ping: > + // Output: + // starting processing: <[ping]> // dialing tcp :6379 // finished dialing tcp :6379 - // finished processing: <ping: PONG> + // starting processing: <[hello 3]> + // finished processing: <[hello 3]> + // finished processing: <[ping]> } func ExamplePipeline_instrumentation() { rdb := redis.NewClient(&redis.Options{ - Addr: ":6379", + Addr: ":6379", + DisableIdentity: true, }) rdb.AddHook(redisHook{}) @@ -63,15 +72,19 @@ func ExamplePipeline_instrumentation() { pipe.Ping(ctx) return nil }) - // Output: pipeline starting processing: [ping: ping: ] + // Output: + // pipeline starting processing: [[ping] [ping]] // dialing tcp :6379 // finished dialing tcp :6379 - // pipeline finished processing: [ping: PONG ping: PONG] + // starting processing: <[hello 3]> + // finished processing: <[hello 3]> + // pipeline finished processing: [[ping] [ping]] } func ExampleClient_Watch_instrumentation() { rdb := redis.NewClient(&redis.Options{ - Addr: ":6379", + Addr: ":6379", + DisableIdentity: true, }) rdb.AddHook(redisHook{}) @@ -81,14 +94,16 @@ func ExampleClient_Watch_instrumentation() { return nil }, "foo") // Output: - // starting processing: <watch foo: > + // starting processing: <[watch foo]> // dialing tcp :6379 // finished dialing tcp :6379 - // finished processing: <watch foo: OK> - // starting processing: <ping: > - // finished processing: <ping: PONG> - // starting processing: <ping: > - // finished processing: <ping: PONG> - // starting processing: <unwatch: > - // finished processing: <unwatch: OK> + // starting processing: <[hello 3]> + // finished processing: <[hello 3]> + // finished processing: <[watch foo]> + // starting processing: <[ping]> + // finished processing: <[ping]> + // starting processing: <[ping]> + // finished processing: <[ping]> + // starting processing: <[unwatch]> + // finished processing: <[unwatch]> } diff --git a/internal_test.go b/internal_test.go index 516ada823..a61b5c02d 100644 --- a/internal_test.go +++ b/internal_test.go @@ -212,10 +212,10 @@ func TestRingShardsCleanup(t *testing.T) { }, NewClient: func(opt *Options) *Client { c := NewClient(opt) - c.baseClient.onClose = func() error { + c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error { closeCounter.increment(opt.Addr) return nil - } + }) return c }, }) @@ -261,10 +261,10 @@ func TestRingShardsCleanup(t *testing.T) { } createCounter.increment(opt.Addr) c := NewClient(opt) - c.baseClient.onClose = func() error { + c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error { closeCounter.increment(opt.Addr) return nil - } + }) return c }, }) diff --git a/options.go b/options.go index 0ebeec342..eb35353da 100644 --- a/options.go +++ b/options.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal/pool" ) @@ -29,10 +30,13 @@ type Limiter interface { // Options keeps the settings to set up redis connection. type Options struct { - // The network type, either tcp or unix. - // Default is tcp. + + // Network type, either tcp or unix. + // + // default: is tcp. Network string - // host:port address. + + // Addr is the address formated as host:port Addr string // ClientName will execute the `CLIENT SETNAME ClientName` command for each conn. @@ -46,17 +50,21 @@ type Options struct { OnConnect func(ctx context.Context, cn *Conn) error // Protocol 2 or 3. Use the version to negotiate RESP version with redis-server. - // Default is 3. + // + // default: 3. Protocol int - // Use the specified Username to authenticate the current connection + + // Username is used to authenticate the current connection // with one of the connections defined in the ACL list when connecting // to a Redis 6.0 instance, or greater, that is using the Redis ACL system. Username string - // Optional password. Must match the password specified in the - // requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower), + + // Password is an optional password. Must match the password specified in the + // `requirepass` server configuration option (if connecting to a Redis 5.0 instance, or lower), // or the User Password when connecting to a Redis 6.0 instance, or greater, // that is using the Redis ACL system. Password string + // CredentialsProvider allows the username and password to be updated // before reconnecting. It should return the current username and password. CredentialsProvider func() (username string, password string) @@ -67,85 +75,126 @@ type Options struct { // There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider. CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) - // Database to be selected after connecting to the server. + // StreamingCredentialsProvider is used to retrieve the credentials + // for the connection from an external source. Those credentials may change + // during the connection lifetime. This is useful for managed identity + // scenarios where the credentials are retrieved from an external source. + // + // Currently, this is a placeholder for the future implementation. + StreamingCredentialsProvider auth.StreamingCredentialsProvider + + // DB is the database to be selected after connecting to the server. DB int - // Maximum number of retries before giving up. - // Default is 3 retries; -1 (not 0) disables retries. + // MaxRetries is the maximum number of retries before giving up. + // -1 (not 0) disables retries. + // + // default: 3 retries MaxRetries int - // Minimum backoff between each retry. - // Default is 8 milliseconds; -1 disables backoff. + + // MinRetryBackoff is the minimum backoff between each retry. + // -1 disables backoff. + // + // default: 8 milliseconds MinRetryBackoff time.Duration - // Maximum backoff between each retry. - // Default is 512 milliseconds; -1 disables backoff. + + // MaxRetryBackoff is the maximum backoff between each retry. + // -1 disables backoff. + // default: 512 milliseconds; MaxRetryBackoff time.Duration - // Dial timeout for establishing new connections. - // Default is 5 seconds. + // DialTimeout for establishing new connections. + // + // default: 5 seconds DialTimeout time.Duration - // Timeout for socket reads. If reached, commands will fail + + // ReadTimeout for socket reads. If reached, commands will fail // with a timeout instead of blocking. Supported values: - // - `0` - default timeout (3 seconds). - // - `-1` - no timeout (block indefinitely). - // - `-2` - disables SetReadDeadline calls completely. + // + // - `-1` - no timeout (block indefinitely). + // - `-2` - disables SetReadDeadline calls completely. + // + // default: 3 seconds ReadTimeout time.Duration - // Timeout for socket writes. If reached, commands will fail + + // WriteTimeout for socket writes. If reached, commands will fail // with a timeout instead of blocking. Supported values: - // - `0` - default timeout (3 seconds). - // - `-1` - no timeout (block indefinitely). - // - `-2` - disables SetWriteDeadline calls completely. + // + // - `-1` - no timeout (block indefinitely). + // - `-2` - disables SetWriteDeadline calls completely. + // + // default: 3 seconds WriteTimeout time.Duration + // ContextTimeoutEnabled controls whether the client respects context timeouts and deadlines. // See https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts ContextTimeoutEnabled bool - // Type of connection pool. - // true for FIFO pool, false for LIFO pool. + // PoolFIFO type of connection pool. + // + // - true for FIFO pool + // - false for LIFO pool. + // // Note that FIFO has slightly higher overhead compared to LIFO, // but it helps closing idle connections faster reducing the pool size. PoolFIFO bool - // Base number of socket connections. + + // PoolSize is the base number of socket connections. // Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS. // If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize, // you can limit it through MaxActiveConns + // + // default: 10 * runtime.GOMAXPROCS(0) PoolSize int - // Amount of time client waits for connection if all connections + + // PoolTimeout is the amount of time client waits for connection if all connections // are busy before returning an error. - // Default is ReadTimeout + 1 second. + // + // default: ReadTimeout + 1 second PoolTimeout time.Duration - // Minimum number of idle connections which is useful when establishing - // new connection is slow. - // Default is 0. the idle connections are not closed by default. + + // MinIdleConns is the minimum number of idle connections which is useful when establishing + // new connection is slow. The idle connections are not closed by default. + // + // default: 0 MinIdleConns int - // Maximum number of idle connections. - // Default is 0. the idle connections are not closed by default. + + // MaxIdleConns is the maximum number of idle connections. + // The idle connections are not closed by default. + // + // default: 0 MaxIdleConns int - // Maximum number of connections allocated by the pool at a given time. + + // MaxActiveConns is the maximum number of connections allocated by the pool at a given time. // When zero, there is no limit on the number of connections in the pool. + // If the pool is full, the next call to Get() will block until a connection is released. MaxActiveConns int + // ConnMaxIdleTime is the maximum amount of time a connection may be idle. // Should be less than server's timeout. // // Expired connections may be closed lazily before reuse. // If d <= 0, connections are not closed due to a connection's idle time. + // -1 disables idle timeout check. // - // Default is 30 minutes. -1 disables idle timeout check. + // default: 30 minutes ConnMaxIdleTime time.Duration + // ConnMaxLifetime is the maximum amount of time a connection may be reused. // // Expired connections may be closed lazily before reuse. // If <= 0, connections are not closed due to a connection's age. // - // Default is to not close idle connections. + // default: 0 ConnMaxLifetime time.Duration - // TLS Config to use. When set, TLS will be negotiated. + // TLSConfig to use. When set, TLS will be negotiated. TLSConfig *tls.Config // Limiter interface used to implement circuit breaker or rate limiter. Limiter Limiter - // Enables read only queries on slave/follower nodes. + // readOnly enables read only queries on slave/follower nodes. readOnly bool // DisableIndentity - Disable set-lib on connect. @@ -161,9 +210,11 @@ type Options struct { DisableIdentity bool // Add suffix to client name. Default is empty. + // IdentitySuffix - add suffix to client name. IdentitySuffix string // UnstableResp3 enables Unstable mode for Redis Search module with RESP3. + // When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult UnstableResp3 bool } diff --git a/osscluster.go b/osscluster.go index 20180464e..be06f3f90 100644 --- a/osscluster.go +++ b/osscluster.go @@ -14,6 +14,7 @@ import ( "sync/atomic" "time" + "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hashtag" "github.com/redis/go-redis/v9/internal/pool" @@ -66,11 +67,12 @@ type ClusterOptions struct { OnConnect func(ctx context.Context, cn *Conn) error - Protocol int - Username string - Password string - CredentialsProvider func() (username string, password string) - CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) + Protocol int + Username string + Password string + CredentialsProvider func() (username string, password string) + CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) + StreamingCredentialsProvider auth.StreamingCredentialsProvider MaxRetries int MinRetryBackoff time.Duration @@ -291,11 +293,12 @@ func (opt *ClusterOptions) clientOptions() *Options { Dialer: opt.Dialer, OnConnect: opt.OnConnect, - Protocol: opt.Protocol, - Username: opt.Username, - Password: opt.Password, - CredentialsProvider: opt.CredentialsProvider, - CredentialsProviderContext: opt.CredentialsProviderContext, + Protocol: opt.Protocol, + Username: opt.Username, + Password: opt.Password, + CredentialsProvider: opt.CredentialsProvider, + CredentialsProviderContext: opt.CredentialsProviderContext, + StreamingCredentialsProvider: opt.StreamingCredentialsProvider, MaxRetries: opt.MaxRetries, MinRetryBackoff: opt.MinRetryBackoff, diff --git a/osscluster_test.go b/osscluster_test.go index ccf6daad8..6e214a719 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -89,6 +89,9 @@ func (s *clusterScenario) newClusterClient( func (s *clusterScenario) Close() error { ctx := context.TODO() for _, master := range s.masters() { + if master == nil { + continue + } err := master.FlushAll(ctx).Err() if err != nil { return err diff --git a/probabilistic_test.go b/probabilistic_test.go index a0a050e23..0a3f1a15c 100644 --- a/probabilistic_test.go +++ b/probabilistic_test.go @@ -298,7 +298,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() { }) It("should CFCount", Label("cuckoo", "cfcount"), func() { - err := client.CFAdd(ctx, "testcf1", "item1").Err() + client.CFAdd(ctx, "testcf1", "item1") cnt, err := client.CFCount(ctx, "testcf1", "item1").Result() Expect(err).NotTo(HaveOccurred()) Expect(cnt).To(BeEquivalentTo(int64(1))) @@ -394,7 +394,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() { NoCreate: true, } - result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result() + _, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result() Expect(err).To(HaveOccurred()) args = &redis.CFInsertOptions{ @@ -402,7 +402,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() { NoCreate: false, } - result, err = client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result() + result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result() Expect(err).NotTo(HaveOccurred()) Expect(len(result)).To(BeEquivalentTo(3)) }) diff --git a/redis.go b/redis.go index e0159294d..03f939ad2 100644 --- a/redis.go +++ b/redis.go @@ -4,11 +4,13 @@ import ( "context" "errors" "fmt" + "log" "net" "sync" "sync/atomic" "time" + "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hscan" "github.com/redis/go-redis/v9/internal/pool" @@ -203,6 +205,7 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e type baseClient struct { opt *Options connPool pool.Pooler + hooksMixin onClose func() error // hook called when client is closed } @@ -282,36 +285,107 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return cn, nil } +func (c *baseClient) newReAuthCredentialsListener(ctx context.Context, conn *Conn) auth.CredentialsListener { + return auth.NewReAuthCredentialsListener( + c.reAuthConnection(c.context(ctx), conn), + c.onAuthenticationErr(c.context(ctx), conn), + ) +} + +func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(credentials auth.Credentials) error { + return func(credentials auth.Credentials) error { + var err error + username, password := credentials.BasicAuth() + if username != "" { + err = cn.AuthACL(ctx, username, password).Err() + } else { + err = cn.Auth(ctx, password).Err() + } + return err + } +} +func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err error) { + return func(err error) { + // since the connection pool of the *Conn will actually return us the underlying pool.Conn, + // we can get it from the *Conn and remove it from the clients pool. + if err != nil { + if isBadConn(err, false, c.opt.Addr) { + poolCn, getErr := cn.connPool.Get(ctx) + if getErr == nil { + c.connPool.Remove(ctx, poolCn, err) + } else { + // if we can't get the pool connection, we can only close the connection + if err := cn.Close(); err != nil { + log.Printf("failed to close connection: %v", err) + } + } + } + } + } +} + +func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error { + onClose := c.onClose + return func() error { + var firstErr error + err := newOnClose() + // Even if we have an error we would like to execute the onClose hook + // if it exists. We will return the first error that occurred. + // This is to keep error handling consistent with the rest of the code. + if err != nil { + firstErr = err + } + if onClose != nil { + err = onClose() + if err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr + } +} + func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { if cn.Inited { return nil } - cn.Inited = true var err error - username, password := c.opt.Username, c.opt.Password - if c.opt.CredentialsProviderContext != nil { - if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil { - return err - } - } else if c.opt.CredentialsProvider != nil { - username, password = c.opt.CredentialsProvider() - } - + cn.Inited = true connPool := pool.NewSingleConnPool(c.connPool, cn) - conn := newConn(c.opt, connPool) - var auth bool + conn := newConn(c.opt, connPool, c.hooksMixin) + protocol := c.opt.Protocol // By default, use RESP3 in current version. if protocol < 2 { protocol = 3 } + username, password := "", "" + if c.opt.StreamingCredentialsProvider != nil { + credentials, cancelCredentialsProvider, err := c.opt.StreamingCredentialsProvider. + Subscribe(c.newReAuthCredentialsListener(ctx, conn)) + if err != nil { + return fmt.Errorf("failed to subscribe to streaming credentials: %w", err) + } + c.onClose = c.wrappedOnClose(cancelCredentialsProvider) + username, password = credentials.BasicAuth() + } else if c.opt.CredentialsProviderContext != nil { + username, password, err = c.opt.CredentialsProviderContext(ctx) + if err != nil { + return fmt.Errorf("failed to get credentials from context provider: %w", err) + } + } else if c.opt.CredentialsProvider != nil { + username, password = c.opt.CredentialsProvider() + } else if c.opt.Username != "" || c.opt.Password != "" { + username, password = c.opt.Username, c.opt.Password + } + // for redis-server versions that do not support the HELLO command, // RESP2 will continue to be used. if err = conn.Hello(ctx, protocol, username, password, c.opt.ClientName).Err(); err == nil { - auth = true + // Authentication successful with HELLO command } else if !isRedisError(err) { // When the server responds with the RESP protocol and the result is not a normal // execution result of the HELLO command, we consider it to be an indication that @@ -321,17 +395,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { // with different error string results for unsupported commands, making it // difficult to rely on error strings to determine all results. return err + } else if password != "" { + // Try legacy AUTH command if HELLO failed + err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password)) + if err != nil { + return fmt.Errorf("failed to authenticate: %w", err) + } } _, err = conn.Pipelined(ctx, func(pipe Pipeliner) error { - if !auth && password != "" { - if username != "" { - pipe.AuthACL(ctx, username, password) - } else { - pipe.Auth(ctx, password) - } - } - if c.opt.DB > 0 { pipe.Select(ctx, c.opt.DB) } @@ -347,7 +419,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return nil }) if err != nil { - return err + return fmt.Errorf("failed to initialize connection options: %w", err) } if !c.opt.DisableIdentity && !c.opt.DisableIndentity { @@ -369,6 +441,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { if c.opt.OnConnect != nil { return c.opt.OnConnect(ctx, conn) } + return nil } @@ -487,6 +560,16 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { return c.opt.ReadTimeout } +// context returns the context for the current connection. +// If the context timeout is enabled, it returns the original context. +// Otherwise, it returns a new background context. +func (c *baseClient) context(ctx context.Context) context.Context { + if c.opt.ContextTimeoutEnabled { + return ctx + } + return context.Background() +} + // Close closes the client, releasing any open resources. // // It is rare to Close a Client, as the Client is meant to be @@ -639,13 +722,6 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) return nil } -func (c *baseClient) context(ctx context.Context) context.Context { - if c.opt.ContextTimeoutEnabled { - return ctx - } - return context.Background() -} - //------------------------------------------------------------------------------ // Client is a Redis client representing a pool of zero or more underlying connections. @@ -656,7 +732,6 @@ func (c *baseClient) context(ctx context.Context) context.Context { type Client struct { *baseClient cmdable - hooksMixin } // NewClient returns a client to the Redis Server specified by Options. @@ -692,7 +767,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client { } func (c *Client) Conn() *Conn { - return newConn(c.opt, pool.NewStickyConnPool(c.connPool)) + return newConn(c.opt, pool.NewStickyConnPool(c.connPool), c.hooksMixin) } // Do create a Cmd from the args and processes the cmd. @@ -825,14 +900,17 @@ type Conn struct { baseClient cmdable statefulCmdable - hooksMixin } -func newConn(opt *Options, connPool pool.Pooler) *Conn { +// newConn is a helper func to create a new Conn instance. +// the Conn instance is not thread-safe and should not be shared between goroutines. +// the parentHooks will be cloned, no need to clone before passing it. +func newConn(opt *Options, connPool pool.Pooler, parentHooks hooksMixin) *Conn { c := Conn{ baseClient: baseClient{ - opt: opt, - connPool: connPool, + opt: opt, + connPool: connPool, + hooksMixin: parentHooks.clone(), }, } diff --git a/redis_test.go b/redis_test.go index 7d9bf1cef..089973e01 100644 --- a/redis_test.go +++ b/redis_test.go @@ -14,6 +14,7 @@ import ( . "github.com/bsm/gomega" "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/auth" ) type redisHookError struct{} @@ -727,3 +728,171 @@ var _ = Describe("Dialer connection timeouts", func() { Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay)) }) }) + +var _ = Describe("Credentials Provider Priority", func() { + var client *redis.Client + var opt *redis.Options + var recorder *commandRecorder + + BeforeEach(func() { + recorder = newCommandRecorder(10) + }) + + AfterEach(func() { + if client != nil { + Expect(client.Close()).NotTo(HaveOccurred()) + } + }) + + It("should use streaming provider when available", func() { + streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass") + ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass") + providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass") + + opt = &redis.Options{ + Username: "field_user", + Password: "field_pass", + CredentialsProvider: func() (string, string) { + username, password := providerCreds.BasicAuth() + return username, password + }, + CredentialsProviderContext: func(ctx context.Context) (string, string, error) { + username, password := ctxCreds.BasicAuth() + return username, password, nil + }, + StreamingCredentialsProvider: &mockStreamingProvider{ + credentials: streamingCreds, + updates: make(chan auth.Credentials, 1), + }, + } + + client = redis.NewClient(opt) + client.AddHook(recorder.Hook()) + // wrongpass + Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue()) + }) + + It("should use context provider when streaming provider is not available", func() { + ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass") + providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass") + + opt = &redis.Options{ + Username: "field_user", + Password: "field_pass", + CredentialsProvider: func() (string, string) { + username, password := providerCreds.BasicAuth() + return username, password + }, + CredentialsProviderContext: func(ctx context.Context) (string, string, error) { + username, password := ctxCreds.BasicAuth() + return username, password, nil + }, + } + + client = redis.NewClient(opt) + client.AddHook(recorder.Hook()) + // wrongpass + Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue()) + }) + + It("should use regular provider when streaming and context providers are not available", func() { + providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass") + + opt = &redis.Options{ + Username: "field_user", + Password: "field_pass", + CredentialsProvider: func() (string, string) { + username, password := providerCreds.BasicAuth() + return username, password + }, + } + + client = redis.NewClient(opt) + client.AddHook(recorder.Hook()) + // wrongpass + Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + Expect(recorder.Contains("AUTH provider_user")).To(BeTrue()) + }) + + It("should use username/password fields when no providers are set", func() { + opt = &redis.Options{ + Username: "field_user", + Password: "field_pass", + } + + client = redis.NewClient(opt) + client.AddHook(recorder.Hook()) + // wrongpass + Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + Expect(recorder.Contains("AUTH field_user")).To(BeTrue()) + }) + + It("should use empty credentials when nothing is set", func() { + opt = &redis.Options{} + + client = redis.NewClient(opt) + client.AddHook(recorder.Hook()) + // no pass, ok + Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred()) + Expect(recorder.Contains("AUTH")).To(BeFalse()) + }) + + It("should handle credential updates from streaming provider", func() { + initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass") + updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass") + + opt = &redis.Options{ + StreamingCredentialsProvider: &mockStreamingProvider{ + credentials: initialCreds, + updates: make(chan auth.Credentials, 1), + }, + } + + client = redis.NewClient(opt) + client.AddHook(recorder.Hook()) + // wrongpass + Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + Expect(recorder.Contains("AUTH initial_user")).To(BeTrue()) + + // Update credentials + opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds + // wrongpass + Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + Expect(recorder.Contains("AUTH updated_user")).To(BeTrue()) + }) +}) + +type mockStreamingProvider struct { + credentials auth.Credentials + err error + updates chan auth.Credentials +} + +func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) { + if m.err != nil { + return nil, nil, m.err + } + + // Send initial credentials + listener.OnNext(m.credentials) + + // Start goroutine to handle updates + go func() { + for creds := range m.updates { + listener.OnNext(creds) + } + }() + + return m.credentials, func() (err error) { + defer func() { + if r := recover(); r != nil { + // this is just a mock: + // allow multiple closes from multiple listeners + } + }() + close(m.updates) + return + }, nil +} diff --git a/ring_test.go b/ring_test.go index cfd545c17..599f6888a 100644 --- a/ring_test.go +++ b/ring_test.go @@ -357,13 +357,17 @@ var _ = Describe("Redis Ring", func() { ring.AddHook(&hook{ processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) + // skip the connection initialization + if cmds[0].Name() == "hello" || cmds[0].Name() == "client" { + return nil + } + Expect(len(cmds)).To(BeNumerically(">", 0)) Expect(cmds[0].String()).To(Equal("ping: ")) stack = append(stack, "ring.BeforeProcessPipeline") err := hook(ctx, cmds) - Expect(cmds).To(HaveLen(1)) + Expect(len(cmds)).To(BeNumerically(">", 0)) Expect(cmds[0].String()).To(Equal("ping: PONG")) stack = append(stack, "ring.AfterProcessPipeline") @@ -376,13 +380,17 @@ var _ = Describe("Redis Ring", func() { shard.AddHook(&hook{ processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) + // skip the connection initialization + if cmds[0].Name() == "hello" || cmds[0].Name() == "client" { + return nil + } + Expect(len(cmds)).To(BeNumerically(">", 0)) Expect(cmds[0].String()).To(Equal("ping: ")) stack = append(stack, "shard.BeforeProcessPipeline") err := hook(ctx, cmds) - Expect(cmds).To(HaveLen(1)) + Expect(len(cmds)).To(BeNumerically(">", 0)) Expect(cmds[0].String()).To(Equal("ping: PONG")) stack = append(stack, "shard.AfterProcessPipeline") @@ -416,14 +424,18 @@ var _ = Describe("Redis Ring", func() { processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { defer GinkgoRecover() + // skip the connection initialization + if cmds[0].Name() == "hello" || cmds[0].Name() == "client" { + return nil + } - Expect(cmds).To(HaveLen(3)) + Expect(len(cmds)).To(BeNumerically(">=", 3)) Expect(cmds[1].String()).To(Equal("ping: ")) stack = append(stack, "ring.BeforeProcessPipeline") err := hook(ctx, cmds) - Expect(cmds).To(HaveLen(3)) + Expect(len(cmds)).To(BeNumerically(">=", 3)) Expect(cmds[1].String()).To(Equal("ping: PONG")) stack = append(stack, "ring.AfterProcessPipeline") @@ -437,14 +449,18 @@ var _ = Describe("Redis Ring", func() { processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { defer GinkgoRecover() + // skip the connection initialization + if cmds[0].Name() == "hello" || cmds[0].Name() == "client" { + return nil + } - Expect(cmds).To(HaveLen(3)) + Expect(len(cmds)).To(BeNumerically(">=", 3)) Expect(cmds[1].String()).To(Equal("ping: ")) stack = append(stack, "shard.BeforeProcessPipeline") err := hook(ctx, cmds) - Expect(cmds).To(HaveLen(3)) + Expect(len(cmds)).To(BeNumerically(">=", 3)) Expect(cmds[1].String()).To(Equal("ping: PONG")) stack = append(stack, "shard.AfterProcessPipeline") diff --git a/sentinel.go b/sentinel.go index f5b9a52d1..a708dc982 100644 --- a/sentinel.go +++ b/sentinel.go @@ -258,7 +258,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { connPool = newConnPool(opt, rdb.dialHook) rdb.connPool = connPool - rdb.onClose = failover.Close + rdb.onClose = rdb.wrappedOnClose(failover.Close) failover.mu.Lock() failover.onFailover = func(ctx context.Context, addr string) { @@ -309,7 +309,6 @@ func masterReplicaDialer( // SentinelClient is a client for a Redis Sentinel. type SentinelClient struct { *baseClient - hooksMixin } func NewSentinelClient(opt *Options) *SentinelClient { diff --git a/tx.go b/tx.go index 039eaf351..0daa222e3 100644 --- a/tx.go +++ b/tx.go @@ -19,16 +19,15 @@ type Tx struct { baseClient cmdable statefulCmdable - hooksMixin } func (c *Client) newTx() *Tx { tx := Tx{ baseClient: baseClient{ - opt: c.opt, - connPool: pool.NewStickyConnPool(c.connPool), + opt: c.opt, + connPool: pool.NewStickyConnPool(c.connPool), + hooksMixin: c.hooksMixin.clone(), }, - hooksMixin: c.hooksMixin.clone(), } tx.init() return &tx