Skip to content

Commit d483898

Browse files
committed
Support multi with re-auth
Defer the re-auth operation in case there is ongoing Tx Tx in lettuce need to be externaly syncronised when used in multi threaded env. Since re-auth happens from different thread we need to make sure it does not happen while there is ongoing transaction.
1 parent 6f46022 commit d483898

9 files changed

+183
-120
lines changed

src/main/java/io/lettuce/core/ConnectionBuilder.java

-11
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,6 @@ public void apply(RedisURI redisURI) {
113113
bootstrap.attr(REDIS_URI, redisURI.toString());
114114
}
115115

116-
public void registerAuthenticationHandler(RedisCredentialsProvider credentialsProvider, ConnectionState state,
117-
Boolean isPubSubConnection) {
118-
LettuceAssert.assertState(endpoint != null, "Endpoint must be set");
119-
LettuceAssert.assertState(connection != null, "Connection must be set");
120-
LettuceAssert.assertState(clientResources != null, "ClientResources must be set");
121-
122-
RedisAuthenticationHandler authenticationHandler = new RedisAuthenticationHandler(connection, credentialsProvider,
123-
state, clientResources.eventBus(), isPubSubConnection);
124-
endpoint.registerAuthenticationHandler(authenticationHandler);
125-
}
126-
127116
protected List<ChannelHandler> buildHandlers() {
128117

129118
LettuceAssert.assertState(channelGroup != null, "ChannelGroup must be set");

src/main/java/io/lettuce/core/RedisAuthenticationHandler.java

+7-51
Original file line numberDiff line numberDiff line change
@@ -46,26 +46,18 @@ public class RedisAuthenticationHandler {
4646

4747
private static final InternalLogger log = InternalLoggerFactory.getInstance(RedisAuthenticationHandler.class);
4848

49-
private final RedisChannelHandler<?, ?> connection;
50-
51-
private final ConnectionState state;
52-
53-
private final RedisCommandBuilder<String, String> commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8);
49+
private final StatefulRedisConnectionImpl<?, ?> connection;
5450

5551
private final RedisCredentialsProvider credentialsProvider;
5652

5753
private final AtomicReference<Disposable> credentialsSubscription = new AtomicReference<>();
5854

59-
private final EventBus eventBus;
60-
6155
private final Boolean isPubSubConnection;
6256

63-
public RedisAuthenticationHandler(RedisChannelHandler<?, ?> connection, RedisCredentialsProvider credentialsProvider,
64-
ConnectionState state, EventBus eventBus, Boolean isPubSubConnection) {
57+
public RedisAuthenticationHandler(StatefulRedisConnectionImpl<?, ?> connection,
58+
RedisCredentialsProvider credentialsProvider, Boolean isPubSubConnection) {
6559
this.connection = connection;
66-
this.state = state;
6760
this.credentialsProvider = credentialsProvider;
68-
this.eventBus = eventBus;
6961
this.isPubSubConnection = isPubSubConnection;
7062
}
7163

@@ -125,55 +117,19 @@ protected void onError(Throwable e) {
125117
* @param credentials the new credentials
126118
*/
127119
protected void reauthenticate(RedisCredentials credentials) {
128-
CharSequence password = CharBuffer.wrap(credentials.getPassword());
129-
130-
AsyncCommand<String, String, String> authCmd;
131-
if (credentials.hasUsername()) {
132-
authCmd = new AsyncCommand<>(commandBuilder.auth(credentials.getUsername(), password));
133-
} else {
134-
authCmd = new AsyncCommand<>(commandBuilder.auth(password));
135-
}
136-
137-
dispatchAuth(authCmd).thenRun(() -> {
138-
publishReauthEvent();
139-
log.info("Re-authentication succeeded for endpoint {}.", getEpid());
140-
}).exceptionally(throwable -> {
141-
publishReauthFailedEvent(throwable);
142-
log.error("Re-authentication failed for endpoint {}.", getEpid(), throwable);
143-
return null;
144-
});
145-
}
146-
147-
private AsyncCommand<?, ?, ?> dispatchAuth(RedisCommand<?, ?, ?> authCommand) {
148-
AsyncCommand asyncCommand = new AsyncCommand<>(authCommand);
149-
RedisCommand<?, ?, ?> dispatched = connection.dispatch(asyncCommand);
150-
if (dispatched instanceof AsyncCommand) {
151-
return (AsyncCommand<?, ?, ?>) dispatched;
152-
}
153-
return asyncCommand;
154-
}
155-
156-
private void publishReauthEvent() {
157-
eventBus.publish(new ReauthenticateEvent(getEpid()));
158-
}
159-
160-
private void publishReauthFailedEvent(Throwable throwable) {
161-
eventBus.publish(new ReauthenticateFailedEvent(getEpid(), throwable));
120+
connection.setCredentials(credentials);
162121
}
163122

164123
protected boolean isSupportedConnection() {
165-
if (isPubSubConnection && ProtocolVersion.RESP2 == state.getNegotiatedProtocolVersion()) {
124+
if (isPubSubConnection && ProtocolVersion.RESP2 == connection.getConnectionState().getNegotiatedProtocolVersion()) {
166125
log.warn("Renewable credentials are not supported with RESP2 protocol on a pub/sub connection.");
167126
return false;
168127
}
169128
return true;
170129
}
171130

172-
private String getEpid() {
173-
if (connection.getChannelWriter() instanceof Endpoint) {
174-
return ((Endpoint) connection.getChannelWriter()).getId();
175-
}
176-
return "unknown";
131+
private void publishReauthFailedEvent(Throwable throwable) {
132+
connection.getResources().eventBus().publish(new ReauthenticateFailedEvent(throwable));
177133
}
178134

179135
public static boolean isSupported(ClientOptions clientOptions) {

src/main/java/io/lettuce/core/RedisClient.java

+4-6
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,10 @@ private <K, V, S> ConnectionFuture<S> connectStatefulAsync(StatefulRedisConnecti
317317
ConnectionState state = connection.getConnectionState();
318318
state.apply(redisURI);
319319
state.setDb(redisURI.getDatabase());
320-
320+
if (RedisAuthenticationHandler.isSupported(getOptions())) {
321+
connection.setAuthenticationHandler(
322+
new RedisAuthenticationHandler(connection, redisURI.getCredentialsProvider(), isPubSub));
323+
}
321324
connectionBuilder.connection(connection);
322325
connectionBuilder.clientOptions(getOptions());
323326
connectionBuilder.clientResources(getResources());
@@ -326,11 +329,6 @@ private <K, V, S> ConnectionFuture<S> connectStatefulAsync(StatefulRedisConnecti
326329
connectionBuilder(getSocketAddressSupplier(redisURI), connectionBuilder, connection.getConnectionEvents(), redisURI);
327330
connectionBuilder.connectionInitializer(createHandshake(state));
328331

329-
if (RedisAuthenticationHandler.isSupported(getOptions())) {
330-
connectionBuilder.registerAuthenticationHandler(redisURI.getCredentialsProvider(), connection.getConnectionState(),
331-
isPubSub);
332-
}
333-
334332
ConnectionFuture<RedisChannelHandler<K, V>> future = initializeChannelAsync(connectionBuilder);
335333

336334
return future.thenApply(channelHandler -> (S) connection);

src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java

+115-4
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,14 @@
2222
import static io.lettuce.core.ClientOptions.DEFAULT_JSON_PARSER;
2323
import static io.lettuce.core.protocol.CommandType.*;
2424

25+
import java.nio.CharBuffer;
2526
import java.time.Duration;
2627
import java.util.ArrayList;
2728
import java.util.Collection;
2829
import java.util.List;
30+
import java.util.concurrent.atomic.AtomicBoolean;
31+
import java.util.concurrent.atomic.AtomicReference;
32+
import java.util.concurrent.locks.ReentrantLock;
2933
import java.util.function.Consumer;
3034
import java.util.stream.Collectors;
3135

@@ -37,10 +41,14 @@
3741
import io.lettuce.core.cluster.api.sync.RedisClusterCommands;
3842
import io.lettuce.core.codec.RedisCodec;
3943
import io.lettuce.core.codec.StringCodec;
44+
import io.lettuce.core.event.connection.ReauthenticateEvent;
45+
import io.lettuce.core.event.connection.ReauthenticateFailedEvent;
4046
import io.lettuce.core.json.JsonParser;
4147
import io.lettuce.core.output.MultiOutput;
4248
import io.lettuce.core.output.StatusOutput;
4349
import io.lettuce.core.protocol.*;
50+
import io.netty.util.internal.logging.InternalLogger;
51+
import io.netty.util.internal.logging.InternalLoggerFactory;
4452
import reactor.core.publisher.Mono;
4553

4654
/**
@@ -55,6 +63,8 @@
5563
*/
5664
public class StatefulRedisConnectionImpl<K, V> extends RedisChannelHandler<K, V> implements StatefulRedisConnection<K, V> {
5765

66+
private static final InternalLogger logger = InternalLoggerFactory.getInstance(StatefulRedisConnectionImpl.class);
67+
5868
protected final RedisCodec<K, V> codec;
5969

6070
protected final RedisCommands<K, V> sync;
@@ -71,6 +81,14 @@ public class StatefulRedisConnectionImpl<K, V> extends RedisChannelHandler<K, V>
7181

7282
protected MultiOutput<K, V> multi;
7383

84+
private RedisAuthenticationHandler authHandler;
85+
86+
private AtomicReference<RedisCredentials> credentialsRef = new AtomicReference<>();
87+
88+
private final ReentrantLock reAuthSafety = new ReentrantLock();
89+
90+
private AtomicBoolean inTransaction = new AtomicBoolean(false);
91+
7492
/**
7593
* Initialize a new connection.
7694
*
@@ -181,20 +199,38 @@ public boolean isMulti() {
181199
public <T> RedisCommand<K, V, T> dispatch(RedisCommand<K, V, T> command) {
182200

183201
RedisCommand<K, V, T> toSend = preProcessCommand(command);
184-
return super.dispatch(toSend);
202+
RedisCommand<K, V, T> result = super.dispatch(toSend);
203+
if (toSend.getType() == EXEC || toSend.getType() == DISCARD) {
204+
inTransaction.set(false);
205+
setCredentials(credentialsRef.getAndSet(null));
206+
}
207+
208+
return result;
185209
}
186210

187211
@Override
188212
public Collection<RedisCommand<K, V, ?>> dispatch(Collection<? extends RedisCommand<K, V, ?>> commands) {
189213

190214
List<RedisCommand<K, V, ?>> sentCommands = new ArrayList<>(commands.size());
191215

192-
commands.forEach(o -> {
216+
boolean transactionComplete = false;
217+
for (RedisCommand<K, V, ?> o : commands) {
193218
RedisCommand<K, V, ?> command = preProcessCommand(o);
194219
sentCommands.add(command);
195-
});
220+
if (command.getType() == EXEC) {
221+
transactionComplete = true;
222+
}
223+
if (command.getType() == MULTI || command.getType() == DISCARD) {
224+
transactionComplete = false;
225+
}
226+
}
196227

197-
return super.dispatch(sentCommands);
228+
Collection<RedisCommand<K, V, ?>> result = super.dispatch(sentCommands);
229+
if (transactionComplete) {
230+
inTransaction.set(false);
231+
setCredentials(credentialsRef.getAndSet(null));
232+
}
233+
return result;
198234
}
199235

200236
// TODO [tihomir.mateev] Refactor to include as part of the Command interface
@@ -273,12 +309,20 @@ protected <T> RedisCommand<K, V, T> preProcessCommand(RedisCommand<K, V, T> comm
273309

274310
if (commandType.equals(MULTI.name())) {
275311

312+
reAuthSafety.lock();
313+
try {
314+
inTransaction.set(true);
315+
} finally {
316+
reAuthSafety.unlock();
317+
}
276318
multi = (multi == null ? new MultiOutput<>(codec) : multi);
277319

278320
if (command instanceof CompleteableCommand) {
279321
((CompleteableCommand<?>) command).onComplete((ignored, e) -> {
280322
if (e != null) {
281323
multi = null;
324+
inTransaction.set(false);
325+
setCredentials(credentialsRef.getAndSet(null));
282326
}
283327
});
284328
}
@@ -318,11 +362,78 @@ public ConnectionState getConnectionState() {
318362
@Override
319363
public void activated() {
320364
super.activated();
365+
if (authHandler != null) {
366+
authHandler.subscribe();
367+
}
321368
}
322369

323370
@Override
324371
public void deactivated() {
372+
if (authHandler != null) {
373+
authHandler.unsubscribe();
374+
}
325375
super.deactivated();
326376
}
327377

378+
public void setAuthenticationHandler(RedisAuthenticationHandler handler) {
379+
if (authHandler != null) {
380+
authHandler.unsubscribe();
381+
}
382+
authHandler = handler;
383+
if (isOpen()) {
384+
authHandler.subscribe();
385+
}
386+
}
387+
388+
public void setCredentials(RedisCredentials credentials) {
389+
if (credentials == null) {
390+
return;
391+
}
392+
reAuthSafety.lock();
393+
try {
394+
credentialsRef.set(credentials);
395+
if (!inTransaction.get()) {
396+
dispatchAuthCommand(credentialsRef.getAndSet(null));
397+
}
398+
} finally {
399+
reAuthSafety.unlock();
400+
}
401+
}
402+
403+
private void dispatchAuthCommand(RedisCredentials credentials) {
404+
if (credentials == null) {
405+
return;
406+
}
407+
408+
RedisFuture<String> auth;
409+
if (credentials.getUsername() != null) {
410+
auth = async().auth(credentials.getUsername(), CharBuffer.wrap(credentials.getPassword()));
411+
} else {
412+
auth = async().auth(CharBuffer.wrap(credentials.getPassword()));
413+
}
414+
auth.thenRun(() -> {
415+
publishReauthEvent();
416+
logger.info("Re-authentication succeeded for endpoint {}.", getEpid());
417+
}).exceptionally(throwable -> {
418+
publishReauthFailedEvent(throwable);
419+
logger.error("Re-authentication failed for endpoint {}.", getEpid(), throwable);
420+
return null;
421+
});
422+
}
423+
424+
private void publishReauthEvent() {
425+
getResources().eventBus().publish(new ReauthenticateEvent(getEpid()));
426+
}
427+
428+
private void publishReauthFailedEvent(Throwable throwable) {
429+
getResources().eventBus().publish(new ReauthenticateFailedEvent(getEpid(), throwable));
430+
}
431+
432+
private String getEpid() {
433+
if (getChannelWriter() instanceof Endpoint) {
434+
return ((Endpoint) getChannelWriter()).getId();
435+
}
436+
return "";
437+
}
438+
328439
}

src/main/java/io/lettuce/core/api/StatefulRedisConnection.java

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.lettuce.core.api;
22

3+
import io.lettuce.core.RedisCredentials;
34
import io.lettuce.core.api.async.RedisAsyncCommands;
45
import io.lettuce.core.api.push.PushListener;
56
import io.lettuce.core.api.reactive.RedisReactiveCommands;

0 commit comments

Comments
 (0)