22
22
import static io .lettuce .core .ClientOptions .DEFAULT_JSON_PARSER ;
23
23
import static io .lettuce .core .protocol .CommandType .*;
24
24
25
+ import java .nio .CharBuffer ;
25
26
import java .time .Duration ;
26
27
import java .util .ArrayList ;
27
28
import java .util .Collection ;
28
29
import java .util .List ;
30
+ import java .util .concurrent .atomic .AtomicBoolean ;
31
+ import java .util .concurrent .atomic .AtomicReference ;
32
+ import java .util .concurrent .locks .ReentrantLock ;
29
33
import java .util .function .Consumer ;
30
34
import java .util .stream .Collectors ;
31
35
37
41
import io .lettuce .core .cluster .api .sync .RedisClusterCommands ;
38
42
import io .lettuce .core .codec .RedisCodec ;
39
43
import io .lettuce .core .codec .StringCodec ;
44
+ import io .lettuce .core .event .connection .ReauthenticateEvent ;
45
+ import io .lettuce .core .event .connection .ReauthenticateFailedEvent ;
40
46
import io .lettuce .core .json .JsonParser ;
41
47
import io .lettuce .core .output .MultiOutput ;
42
48
import io .lettuce .core .output .StatusOutput ;
43
49
import io .lettuce .core .protocol .*;
50
+ import io .netty .util .internal .logging .InternalLogger ;
51
+ import io .netty .util .internal .logging .InternalLoggerFactory ;
44
52
import reactor .core .publisher .Mono ;
45
53
46
54
/**
55
63
*/
56
64
public class StatefulRedisConnectionImpl <K , V > extends RedisChannelHandler <K , V > implements StatefulRedisConnection <K , V > {
57
65
66
+ private static final InternalLogger logger = InternalLoggerFactory .getInstance (StatefulRedisConnectionImpl .class );
67
+
58
68
protected final RedisCodec <K , V > codec ;
59
69
60
70
protected final RedisCommands <K , V > sync ;
@@ -71,6 +81,14 @@ public class StatefulRedisConnectionImpl<K, V> extends RedisChannelHandler<K, V>
71
81
72
82
protected MultiOutput <K , V > multi ;
73
83
84
+ private RedisAuthenticationHandler authHandler ;
85
+
86
+ private AtomicReference <RedisCredentials > credentialsRef = new AtomicReference <>();
87
+
88
+ private final ReentrantLock reAuthSafety = new ReentrantLock ();
89
+
90
+ private AtomicBoolean inTransaction = new AtomicBoolean (false );
91
+
74
92
/**
75
93
* Initialize a new connection.
76
94
*
@@ -181,20 +199,38 @@ public boolean isMulti() {
181
199
public <T > RedisCommand <K , V , T > dispatch (RedisCommand <K , V , T > command ) {
182
200
183
201
RedisCommand <K , V , T > toSend = preProcessCommand (command );
184
- return super .dispatch (toSend );
202
+ RedisCommand <K , V , T > result = super .dispatch (toSend );
203
+ if (toSend .getType () == EXEC || toSend .getType () == DISCARD ) {
204
+ inTransaction .set (false );
205
+ setCredentials (credentialsRef .getAndSet (null ));
206
+ }
207
+
208
+ return result ;
185
209
}
186
210
187
211
@ Override
188
212
public Collection <RedisCommand <K , V , ?>> dispatch (Collection <? extends RedisCommand <K , V , ?>> commands ) {
189
213
190
214
List <RedisCommand <K , V , ?>> sentCommands = new ArrayList <>(commands .size ());
191
215
192
- commands .forEach (o -> {
216
+ boolean transactionComplete = false ;
217
+ for (RedisCommand <K , V , ?> o : commands ) {
193
218
RedisCommand <K , V , ?> command = preProcessCommand (o );
194
219
sentCommands .add (command );
195
- });
220
+ if (command .getType () == EXEC ) {
221
+ transactionComplete = true ;
222
+ }
223
+ if (command .getType () == MULTI || command .getType () == DISCARD ) {
224
+ transactionComplete = false ;
225
+ }
226
+ }
196
227
197
- return super .dispatch (sentCommands );
228
+ Collection <RedisCommand <K , V , ?>> result = super .dispatch (sentCommands );
229
+ if (transactionComplete ) {
230
+ inTransaction .set (false );
231
+ setCredentials (credentialsRef .getAndSet (null ));
232
+ }
233
+ return result ;
198
234
}
199
235
200
236
// TODO [tihomir.mateev] Refactor to include as part of the Command interface
@@ -273,12 +309,20 @@ protected <T> RedisCommand<K, V, T> preProcessCommand(RedisCommand<K, V, T> comm
273
309
274
310
if (commandType .equals (MULTI .name ())) {
275
311
312
+ reAuthSafety .lock ();
313
+ try {
314
+ inTransaction .set (true );
315
+ } finally {
316
+ reAuthSafety .unlock ();
317
+ }
276
318
multi = (multi == null ? new MultiOutput <>(codec ) : multi );
277
319
278
320
if (command instanceof CompleteableCommand ) {
279
321
((CompleteableCommand <?>) command ).onComplete ((ignored , e ) -> {
280
322
if (e != null ) {
281
323
multi = null ;
324
+ inTransaction .set (false );
325
+ setCredentials (credentialsRef .getAndSet (null ));
282
326
}
283
327
});
284
328
}
@@ -318,11 +362,78 @@ public ConnectionState getConnectionState() {
318
362
@ Override
319
363
public void activated () {
320
364
super .activated ();
365
+ if (authHandler != null ) {
366
+ authHandler .subscribe ();
367
+ }
321
368
}
322
369
323
370
@ Override
324
371
public void deactivated () {
372
+ if (authHandler != null ) {
373
+ authHandler .unsubscribe ();
374
+ }
325
375
super .deactivated ();
326
376
}
327
377
378
+ public void setAuthenticationHandler (RedisAuthenticationHandler handler ) {
379
+ if (authHandler != null ) {
380
+ authHandler .unsubscribe ();
381
+ }
382
+ authHandler = handler ;
383
+ if (isOpen ()) {
384
+ authHandler .subscribe ();
385
+ }
386
+ }
387
+
388
+ public void setCredentials (RedisCredentials credentials ) {
389
+ if (credentials == null ) {
390
+ return ;
391
+ }
392
+ reAuthSafety .lock ();
393
+ try {
394
+ credentialsRef .set (credentials );
395
+ if (!inTransaction .get ()) {
396
+ dispatchAuthCommand (credentialsRef .getAndSet (null ));
397
+ }
398
+ } finally {
399
+ reAuthSafety .unlock ();
400
+ }
401
+ }
402
+
403
+ private void dispatchAuthCommand (RedisCredentials credentials ) {
404
+ if (credentials == null ) {
405
+ return ;
406
+ }
407
+
408
+ RedisFuture <String > auth ;
409
+ if (credentials .getUsername () != null ) {
410
+ auth = async ().auth (credentials .getUsername (), CharBuffer .wrap (credentials .getPassword ()));
411
+ } else {
412
+ auth = async ().auth (CharBuffer .wrap (credentials .getPassword ()));
413
+ }
414
+ auth .thenRun (() -> {
415
+ publishReauthEvent ();
416
+ logger .info ("Re-authentication succeeded for endpoint {}." , getEpid ());
417
+ }).exceptionally (throwable -> {
418
+ publishReauthFailedEvent (throwable );
419
+ logger .error ("Re-authentication failed for endpoint {}." , getEpid (), throwable );
420
+ return null ;
421
+ });
422
+ }
423
+
424
+ private void publishReauthEvent () {
425
+ getResources ().eventBus ().publish (new ReauthenticateEvent (getEpid ()));
426
+ }
427
+
428
+ private void publishReauthFailedEvent (Throwable throwable ) {
429
+ getResources ().eventBus ().publish (new ReauthenticateFailedEvent (getEpid (), throwable ));
430
+ }
431
+
432
+ private String getEpid () {
433
+ if (getChannelWriter () instanceof Endpoint ) {
434
+ return ((Endpoint ) getChannelWriter ()).getId ();
435
+ }
436
+ return "" ;
437
+ }
438
+
328
439
}
0 commit comments