Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 44628c5

Browse files
committedApr 22, 2025··
add tests
1 parent e0c224d commit 44628c5

File tree

7 files changed

+619
-28
lines changed

7 files changed

+619
-28
lines changed
 

‎auth/auth_test.go

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
package auth
2+
3+
import (
4+
"errors"
5+
"testing"
6+
"time"
7+
)
8+
9+
type mockStreamingProvider struct {
10+
credentials Credentials
11+
err error
12+
updates chan Credentials
13+
}
14+
15+
func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
16+
return &mockStreamingProvider{
17+
credentials: initialCreds,
18+
updates: make(chan Credentials, 10),
19+
}
20+
}
21+
22+
func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
23+
if m.err != nil {
24+
return nil, nil, m.err
25+
}
26+
27+
// Send initial credentials
28+
listener.OnNext(m.credentials)
29+
30+
// Start goroutine to handle updates
31+
go func() {
32+
for creds := range m.updates {
33+
listener.OnNext(creds)
34+
}
35+
}()
36+
37+
return m.credentials, func() error {
38+
close(m.updates)
39+
return nil
40+
}, nil
41+
}
42+
43+
func TestStreamingCredentialsProvider(t *testing.T) {
44+
t.Run("successful subscription", func(t *testing.T) {
45+
initialCreds := NewBasicCredentials("user1", "pass1")
46+
provider := newMockStreamingProvider(initialCreds)
47+
48+
var receivedCreds []Credentials
49+
var receivedErrors []error
50+
51+
listener := NewReAuthCredentialsListener(
52+
func(creds Credentials) error {
53+
receivedCreds = append(receivedCreds, creds)
54+
return nil
55+
},
56+
func(err error) {
57+
receivedErrors = append(receivedErrors, err)
58+
},
59+
)
60+
61+
creds, cancel, err := provider.Subscribe(listener)
62+
if err != nil {
63+
t.Fatalf("unexpected error: %v", err)
64+
}
65+
if cancel == nil {
66+
t.Fatal("expected cancel function to be non-nil")
67+
}
68+
if creds != initialCreds {
69+
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
70+
}
71+
if len(receivedCreds) != 1 {
72+
t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
73+
}
74+
if receivedCreds[0] != initialCreds {
75+
t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
76+
}
77+
if len(receivedErrors) != 0 {
78+
t.Fatalf("expected no errors, got %d", len(receivedErrors))
79+
}
80+
81+
// Send an update
82+
newCreds := NewBasicCredentials("user2", "pass2")
83+
provider.updates <- newCreds
84+
85+
// Wait for update to be processed
86+
time.Sleep(100 * time.Millisecond)
87+
if len(receivedCreds) != 2 {
88+
t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
89+
}
90+
if receivedCreds[1] != newCreds {
91+
t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
92+
}
93+
94+
// Cancel subscription
95+
if err := cancel(); err != nil {
96+
t.Fatalf("unexpected error cancelling subscription: %v", err)
97+
}
98+
})
99+
100+
t.Run("subscription error", func(t *testing.T) {
101+
provider := &mockStreamingProvider{
102+
err: errors.New("subscription failed"),
103+
}
104+
105+
var receivedCreds []Credentials
106+
var receivedErrors []error
107+
108+
listener := NewReAuthCredentialsListener(
109+
func(creds Credentials) error {
110+
receivedCreds = append(receivedCreds, creds)
111+
return nil
112+
},
113+
func(err error) {
114+
receivedErrors = append(receivedErrors, err)
115+
},
116+
)
117+
118+
creds, cancel, err := provider.Subscribe(listener)
119+
if err == nil {
120+
t.Fatal("expected error, got nil")
121+
}
122+
if cancel != nil {
123+
t.Fatal("expected cancel function to be nil")
124+
}
125+
if creds != nil {
126+
t.Fatalf("expected nil credentials, got %v", creds)
127+
}
128+
if len(receivedCreds) != 0 {
129+
t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
130+
}
131+
if len(receivedErrors) != 0 {
132+
t.Fatalf("expected no errors, got %d", len(receivedErrors))
133+
}
134+
})
135+
136+
t.Run("re-auth error", func(t *testing.T) {
137+
initialCreds := NewBasicCredentials("user1", "pass1")
138+
provider := newMockStreamingProvider(initialCreds)
139+
140+
reauthErr := errors.New("re-auth failed")
141+
var receivedErrors []error
142+
143+
listener := NewReAuthCredentialsListener(
144+
func(creds Credentials) error {
145+
return reauthErr
146+
},
147+
func(err error) {
148+
receivedErrors = append(receivedErrors, err)
149+
},
150+
)
151+
152+
creds, cancel, err := provider.Subscribe(listener)
153+
if err != nil {
154+
t.Fatalf("unexpected error: %v", err)
155+
}
156+
if cancel == nil {
157+
t.Fatal("expected cancel function to be non-nil")
158+
}
159+
if creds != initialCreds {
160+
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
161+
}
162+
if len(receivedErrors) != 1 {
163+
t.Fatalf("expected 1 error, got %d", len(receivedErrors))
164+
}
165+
if receivedErrors[0] != reauthErr {
166+
t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
167+
}
168+
169+
if err := cancel(); err != nil {
170+
t.Fatalf("unexpected error cancelling subscription: %v", err)
171+
}
172+
})
173+
}
174+
175+
func TestBasicCredentials(t *testing.T) {
176+
t.Run("basic auth", func(t *testing.T) {
177+
creds := NewBasicCredentials("user1", "pass1")
178+
username, password := creds.BasicAuth()
179+
if username != "user1" {
180+
t.Fatalf("expected username 'user1', got '%s'", username)
181+
}
182+
if password != "pass1" {
183+
t.Fatalf("expected password 'pass1', got '%s'", password)
184+
}
185+
})
186+
187+
t.Run("raw credentials", func(t *testing.T) {
188+
creds := NewBasicCredentials("user1", "pass1")
189+
raw := creds.RawCredentials()
190+
expected := "user1:pass1"
191+
if raw != expected {
192+
t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw)
193+
}
194+
})
195+
196+
t.Run("empty username", func(t *testing.T) {
197+
creds := NewBasicCredentials("", "pass1")
198+
username, password := creds.BasicAuth()
199+
if username != "" {
200+
t.Fatalf("expected empty username, got '%s'", username)
201+
}
202+
if password != "pass1" {
203+
t.Fatalf("expected password 'pass1', got '%s'", password)
204+
}
205+
})
206+
}
207+
208+
func TestReAuthCredentialsListener(t *testing.T) {
209+
t.Run("successful re-auth", func(t *testing.T) {
210+
var reAuthCalled bool
211+
var onErrCalled bool
212+
var receivedCreds Credentials
213+
214+
listener := NewReAuthCredentialsListener(
215+
func(creds Credentials) error {
216+
reAuthCalled = true
217+
receivedCreds = creds
218+
return nil
219+
},
220+
func(err error) {
221+
onErrCalled = true
222+
},
223+
)
224+
225+
creds := NewBasicCredentials("user1", "pass1")
226+
listener.OnNext(creds)
227+
228+
if !reAuthCalled {
229+
t.Fatal("expected reAuth to be called")
230+
}
231+
if onErrCalled {
232+
t.Fatal("expected onErr not to be called")
233+
}
234+
if receivedCreds != creds {
235+
t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
236+
}
237+
})
238+
239+
t.Run("re-auth error", func(t *testing.T) {
240+
var reAuthCalled bool
241+
var onErrCalled bool
242+
var receivedErr error
243+
expectedErr := errors.New("re-auth failed")
244+
245+
listener := NewReAuthCredentialsListener(
246+
func(creds Credentials) error {
247+
reAuthCalled = true
248+
return expectedErr
249+
},
250+
func(err error) {
251+
onErrCalled = true
252+
receivedErr = err
253+
},
254+
)
255+
256+
creds := NewBasicCredentials("user1", "pass1")
257+
listener.OnNext(creds)
258+
259+
if !reAuthCalled {
260+
t.Fatal("expected reAuth to be called")
261+
}
262+
if !onErrCalled {
263+
t.Fatal("expected onErr to be called")
264+
}
265+
if receivedErr != expectedErr {
266+
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
267+
}
268+
})
269+
270+
t.Run("on error", func(t *testing.T) {
271+
var onErrCalled bool
272+
var receivedErr error
273+
expectedErr := errors.New("provider error")
274+
275+
listener := NewReAuthCredentialsListener(
276+
func(creds Credentials) error {
277+
return nil
278+
},
279+
func(err error) {
280+
onErrCalled = true
281+
receivedErr = err
282+
},
283+
)
284+
285+
listener.OnError(expectedErr)
286+
287+
if !onErrCalled {
288+
t.Fatal("expected onErr to be called")
289+
}
290+
if receivedErr != expectedErr {
291+
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
292+
}
293+
})
294+
295+
t.Run("nil callbacks", func(t *testing.T) {
296+
listener := NewReAuthCredentialsListener(nil, nil)
297+
298+
// Should not panic
299+
listener.OnNext(NewBasicCredentials("user1", "pass1"))
300+
listener.OnError(errors.New("test error"))
301+
})
302+
}

‎command_recorder_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package redis_test
2+
3+
import (
4+
"context"
5+
"strings"
6+
"sync"
7+
8+
"github.com/redis/go-redis/v9"
9+
)
10+
11+
// commandRecorder records the last N commands executed by a Redis client.
12+
type commandRecorder struct {
13+
mu sync.Mutex
14+
commands []string
15+
maxSize int
16+
}
17+
18+
// newCommandRecorder creates a new command recorder with the specified maximum size.
19+
func newCommandRecorder(maxSize int) *commandRecorder {
20+
return &commandRecorder{
21+
commands: make([]string, 0, maxSize),
22+
maxSize: maxSize,
23+
}
24+
}
25+
26+
// Record adds a command to the recorder.
27+
func (r *commandRecorder) Record(cmd string) {
28+
cmd = strings.ToLower(cmd)
29+
r.mu.Lock()
30+
defer r.mu.Unlock()
31+
32+
r.commands = append(r.commands, cmd)
33+
if len(r.commands) > r.maxSize {
34+
r.commands = r.commands[1:]
35+
}
36+
}
37+
38+
// LastCommands returns a copy of the recorded commands.
39+
func (r *commandRecorder) LastCommands() []string {
40+
r.mu.Lock()
41+
defer r.mu.Unlock()
42+
return append([]string(nil), r.commands...)
43+
}
44+
45+
// Contains checks if the recorder contains a specific command.
46+
func (r *commandRecorder) Contains(cmd string) bool {
47+
cmd = strings.ToLower(cmd)
48+
r.mu.Lock()
49+
defer r.mu.Unlock()
50+
for _, c := range r.commands {
51+
if strings.Contains(c, cmd) {
52+
return true
53+
}
54+
}
55+
return false
56+
}
57+
58+
// Hook returns a Redis hook that records commands.
59+
func (r *commandRecorder) Hook() redis.Hook {
60+
return &commandHook{recorder: r}
61+
}
62+
63+
// commandHook implements the redis.Hook interface to record commands.
64+
type commandHook struct {
65+
recorder *commandRecorder
66+
}
67+
68+
func (h *commandHook) DialHook(next redis.DialHook) redis.DialHook {
69+
return next
70+
}
71+
72+
func (h *commandHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
73+
return func(ctx context.Context, cmd redis.Cmder) error {
74+
h.recorder.Record(cmd.String())
75+
return next(ctx, cmd)
76+
}
77+
}
78+
79+
func (h *commandHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
80+
return func(ctx context.Context, cmds []redis.Cmder) error {
81+
for _, cmd := range cmds {
82+
h.recorder.Record(cmd.String())
83+
}
84+
return next(ctx, cmds)
85+
}
86+
}

‎internal/internal.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"github.com/redis/go-redis/v9/internal/rand"
77
)
88

9+
type ParentHooksMixinKey struct{}
10+
911
func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration {
1012
if retry < 0 {
1113
panic("not reached")

‎osscluster_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ func (s *clusterScenario) newClusterClient(
8989
func (s *clusterScenario) Close() error {
9090
ctx := context.TODO()
9191
for _, master := range s.masters() {
92+
if master == nil {
93+
continue
94+
}
9295
err := master.FlushAll(ctx).Err()
9396
if err != nil {
9497
return err

‎probabilistic_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
298298
})
299299

300300
It("should CFCount", Label("cuckoo", "cfcount"), func() {
301-
err := client.CFAdd(ctx, "testcf1", "item1").Err()
301+
client.CFAdd(ctx, "testcf1", "item1")
302302
cnt, err := client.CFCount(ctx, "testcf1", "item1").Result()
303303
Expect(err).NotTo(HaveOccurred())
304304
Expect(cnt).To(BeEquivalentTo(int64(1)))
@@ -394,15 +394,15 @@ var _ = Describe("Probabilistic commands", Label("probabilistic"), func() {
394394
NoCreate: true,
395395
}
396396

397-
result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
397+
_, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
398398
Expect(err).To(HaveOccurred())
399399

400400
args = &redis.CFInsertOptions{
401401
Capacity: 3000,
402402
NoCreate: false,
403403
}
404404

405-
result, err = client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
405+
result, err := client.CFInsert(ctx, "testcf1", args, "item1", "item2", "item3").Result()
406406
Expect(err).NotTo(HaveOccurred())
407407
Expect(len(result)).To(BeEquivalentTo(3))
408408
})

‎redis.go

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"log"
78
"net"
89
"sync"
910
"sync/atomic"
@@ -308,8 +309,15 @@ func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err
308309
// we can get it from the *Conn and remove it from the clients pool.
309310
if err != nil {
310311
if isBadConn(err, false, c.opt.Addr) {
311-
poolCn, _ := cn.connPool.Get(ctx)
312-
c.connPool.Remove(ctx, poolCn, err)
312+
poolCn, getErr := cn.connPool.Get(ctx)
313+
if getErr == nil {
314+
c.connPool.Remove(ctx, poolCn, err)
315+
} else {
316+
// if we can't get the pool connection, we can only close the connection
317+
if err := cn.Close(); err != nil {
318+
log.Printf("failed to close connection: %v", err)
319+
}
320+
}
313321
}
314322
}
315323
}
@@ -344,36 +352,51 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
344352
var err error
345353
cn.Inited = true
346354
connPool := pool.NewSingleConnPool(c.connPool, cn)
347-
conn := newConn(c.opt, connPool)
355+
var parentHooks hooksMixin
356+
pH := ctx.Value(internal.ParentHooksMixinKey{})
357+
switch pH := pH.(type) {
358+
case nil:
359+
parentHooks = hooksMixin{}
360+
case hooksMixin:
361+
parentHooks = pH.clone()
362+
case *hooksMixin:
363+
parentHooks = (*pH).clone()
364+
default:
365+
parentHooks = hooksMixin{}
366+
}
367+
368+
conn := newConn(c.opt, connPool, parentHooks)
348369

349370
protocol := c.opt.Protocol
350371
// By default, use RESP3 in current version.
351372
if protocol < 2 {
352373
protocol = 3
353374
}
354375

355-
var authenticated bool
356-
username, password := c.opt.Username, c.opt.Password
376+
username, password := "", ""
357377
if c.opt.StreamingCredentialsProvider != nil {
358378
credentials, cancelCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
359379
Subscribe(c.newReAuthCredentialsListener(ctx, conn))
360380
if err != nil {
361-
return err
381+
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
362382
}
363383
c.onClose = c.wrappedOnClose(cancelCredentialsProvider)
364384
username, password = credentials.BasicAuth()
365385
} else if c.opt.CredentialsProviderContext != nil {
366-
if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil {
367-
return err
386+
username, password, err = c.opt.CredentialsProviderContext(ctx)
387+
if err != nil {
388+
return fmt.Errorf("failed to get credentials from context provider: %w", err)
368389
}
369390
} else if c.opt.CredentialsProvider != nil {
370391
username, password = c.opt.CredentialsProvider()
392+
} else if c.opt.Username != "" || c.opt.Password != "" {
393+
username, password = c.opt.Username, c.opt.Password
371394
}
372395

373396
// for redis-server versions that do not support the HELLO command,
374397
// RESP2 will continue to be used.
375398
if err = conn.Hello(ctx, protocol, username, password, c.opt.ClientName).Err(); err == nil {
376-
authenticated = true
399+
// Authentication successful with HELLO command
377400
} else if !isRedisError(err) {
378401
// When the server responds with the RESP protocol and the result is not a normal
379402
// execution result of the HELLO command, we consider it to be an indication that
@@ -382,15 +405,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
382405
// or it could be DragonflyDB or a third-party redis-proxy. They all respond
383406
// with different error string results for unsupported commands, making it
384407
// difficult to rely on error strings to determine all results.
385-
return err
386-
}
387-
388-
if !authenticated && password != "" {
408+
return fmt.Errorf("failed to initialize connection: %w", err)
409+
} else if password != "" {
410+
// Try legacy AUTH command if HELLO failed
389411
err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password))
390412
if err != nil {
391-
return err
413+
return fmt.Errorf("failed to authenticate: %w", err)
392414
}
393415
}
416+
394417
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
395418
if c.opt.DB > 0 {
396419
pipe.Select(ctx, c.opt.DB)
@@ -407,7 +430,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
407430
return nil
408431
})
409432
if err != nil {
410-
return err
433+
return fmt.Errorf("failed to initialize connection options: %w", err)
411434
}
412435

413436
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
@@ -422,13 +445,14 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
422445
// Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid
423446
// out of order responses later on.
424447
if _, err = p.Exec(ctx); err != nil && !isRedisError(err) {
425-
return err
448+
return fmt.Errorf("failed to set client identity: %w", err)
426449
}
427450
}
428451

429452
if c.opt.OnConnect != nil {
430453
return c.opt.OnConnect(ctx, conn)
431454
}
455+
432456
return nil
433457
}
434458

@@ -547,6 +571,16 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
547571
return c.opt.ReadTimeout
548572
}
549573

574+
// context returns the context for the current connection.
575+
// If the context timeout is enabled, it returns the original context.
576+
// Otherwise, it returns a new background context.
577+
func (c *baseClient) context(ctx context.Context) context.Context {
578+
if c.opt.ContextTimeoutEnabled {
579+
return ctx
580+
}
581+
return context.Background()
582+
}
583+
550584
// Close closes the client, releasing any open resources.
551585
//
552586
// It is rare to Close a Client, as the Client is meant to be
@@ -699,13 +733,6 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
699733
return nil
700734
}
701735

702-
func (c *baseClient) context(ctx context.Context) context.Context {
703-
if c.opt.ContextTimeoutEnabled {
704-
return ctx
705-
}
706-
return context.Background()
707-
}
708-
709736
//------------------------------------------------------------------------------
710737

711738
// Client is a Redis client representing a pool of zero or more underlying connections.
@@ -752,7 +779,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
752779
}
753780

754781
func (c *Client) Conn() *Conn {
755-
return newConn(c.opt, pool.NewStickyConnPool(c.connPool))
782+
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), c.hooksMixin.clone())
756783
}
757784

758785
// Do create a Cmd from the args and processes the cmd.
@@ -763,6 +790,7 @@ func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
763790
}
764791

765792
func (c *Client) Process(ctx context.Context, cmd Cmder) error {
793+
ctx = context.WithValue(ctx, internal.ParentHooksMixinKey{}, c.hooksMixin)
766794
err := c.processHook(ctx, cmd)
767795
cmd.SetErr(err)
768796
return err
@@ -888,7 +916,7 @@ type Conn struct {
888916
hooksMixin
889917
}
890918

891-
func newConn(opt *Options, connPool pool.Pooler) *Conn {
919+
func newConn(opt *Options, connPool pool.Pooler, parentHooks hooksMixin) *Conn {
892920
c := Conn{
893921
baseClient: baseClient{
894922
opt: opt,
@@ -898,6 +926,7 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn {
898926

899927
c.cmdable = c.Process
900928
c.statefulCmdable = c.Process
929+
c.hooksMixin = parentHooks
901930
c.initHooks(hooks{
902931
dial: c.baseClient.dial,
903932
process: c.baseClient.process,

‎redis_test.go

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
. "github.com/bsm/gomega"
1515

1616
"github.com/redis/go-redis/v9"
17+
"github.com/redis/go-redis/v9/auth"
1718
)
1819

1920
type redisHookError struct{}
@@ -727,3 +728,171 @@ var _ = Describe("Dialer connection timeouts", func() {
727728
Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
728729
})
729730
})
731+
732+
var _ = Describe("Credentials Provider Priority", func() {
733+
var client *redis.Client
734+
var opt *redis.Options
735+
var recorder *commandRecorder
736+
737+
BeforeEach(func() {
738+
recorder = newCommandRecorder(10)
739+
})
740+
741+
AfterEach(func() {
742+
if client != nil {
743+
Expect(client.Close()).NotTo(HaveOccurred())
744+
}
745+
})
746+
747+
It("should use streaming provider when available", func() {
748+
streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass")
749+
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
750+
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
751+
752+
opt = &redis.Options{
753+
Username: "field_user",
754+
Password: "field_pass",
755+
CredentialsProvider: func() (string, string) {
756+
username, password := providerCreds.BasicAuth()
757+
return username, password
758+
},
759+
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
760+
username, password := ctxCreds.BasicAuth()
761+
return username, password, nil
762+
},
763+
StreamingCredentialsProvider: &mockStreamingProvider{
764+
credentials: streamingCreds,
765+
updates: make(chan auth.Credentials, 1),
766+
},
767+
}
768+
769+
client = redis.NewClient(opt)
770+
client.AddHook(recorder.Hook())
771+
// wrongpass
772+
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
773+
Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue())
774+
})
775+
776+
It("should use context provider when streaming provider is not available", func() {
777+
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
778+
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
779+
780+
opt = &redis.Options{
781+
Username: "field_user",
782+
Password: "field_pass",
783+
CredentialsProvider: func() (string, string) {
784+
username, password := providerCreds.BasicAuth()
785+
return username, password
786+
},
787+
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
788+
username, password := ctxCreds.BasicAuth()
789+
return username, password, nil
790+
},
791+
}
792+
793+
client = redis.NewClient(opt)
794+
client.AddHook(recorder.Hook())
795+
// wrongpass
796+
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
797+
Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue())
798+
})
799+
800+
It("should use regular provider when streaming and context providers are not available", func() {
801+
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
802+
803+
opt = &redis.Options{
804+
Username: "field_user",
805+
Password: "field_pass",
806+
CredentialsProvider: func() (string, string) {
807+
username, password := providerCreds.BasicAuth()
808+
return username, password
809+
},
810+
}
811+
812+
client = redis.NewClient(opt)
813+
client.AddHook(recorder.Hook())
814+
// wrongpass
815+
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
816+
Expect(recorder.Contains("AUTH provider_user")).To(BeTrue())
817+
})
818+
819+
It("should use username/password fields when no providers are set", func() {
820+
opt = &redis.Options{
821+
Username: "field_user",
822+
Password: "field_pass",
823+
}
824+
825+
client = redis.NewClient(opt)
826+
client.AddHook(recorder.Hook())
827+
// wrongpass
828+
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
829+
Expect(recorder.Contains("AUTH field_user")).To(BeTrue())
830+
})
831+
832+
It("should use empty credentials when nothing is set", func() {
833+
opt = &redis.Options{}
834+
835+
client = redis.NewClient(opt)
836+
client.AddHook(recorder.Hook())
837+
// no pass, ok
838+
Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred())
839+
Expect(recorder.Contains("AUTH")).To(BeFalse())
840+
})
841+
842+
It("should handle credential updates from streaming provider", func() {
843+
initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass")
844+
updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass")
845+
846+
opt = &redis.Options{
847+
StreamingCredentialsProvider: &mockStreamingProvider{
848+
credentials: initialCreds,
849+
updates: make(chan auth.Credentials, 1),
850+
},
851+
}
852+
853+
client = redis.NewClient(opt)
854+
client.AddHook(recorder.Hook())
855+
// wrongpass
856+
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
857+
Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())
858+
859+
// Update credentials
860+
opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
861+
// wrongpass
862+
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
863+
Expect(recorder.Contains("AUTH updated_user")).To(BeTrue())
864+
})
865+
})
866+
867+
type mockStreamingProvider struct {
868+
credentials auth.Credentials
869+
err error
870+
updates chan auth.Credentials
871+
}
872+
873+
func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
874+
if m.err != nil {
875+
return nil, nil, m.err
876+
}
877+
878+
// Send initial credentials
879+
listener.OnNext(m.credentials)
880+
881+
// Start goroutine to handle updates
882+
go func() {
883+
for creds := range m.updates {
884+
listener.OnNext(creds)
885+
}
886+
}()
887+
888+
return m.credentials, func() (err error) {
889+
defer func() {
890+
if r := recover(); r != nil {
891+
// this is just a mock:
892+
// allow multiple closes from multiple listeners
893+
}
894+
}()
895+
close(m.updates)
896+
return
897+
}, nil
898+
}

0 commit comments

Comments
 (0)
Please sign in to comment.