Skip to content

Commit

Permalink
feat: add OwnershipSynchronizer to abstract consumer migration
Browse files Browse the repository at this point in the history
  • Loading branch information
okg-cxf committed Aug 31, 2024
1 parent 82ca33f commit e9cfc31
Showing 1 changed file with 153 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,25 +144,25 @@ protected static void cancelCommandOnEndpointClose(RedisCommand<?, ?, ?> cmd) {

private final boolean debugEnabled = logger.isDebugEnabled();

protected final CompletableFuture<Void> closeFuture = new CompletableFuture<>();
private final CompletableFuture<Void> closeFuture = new CompletableFuture<>();

private String logPrefix;

private boolean autoFlushCommands = true;

private boolean inActivation = false;

protected @Nullable ConnectionWatchdog connectionWatchdog;
private @Nullable ConnectionWatchdog connectionWatchdog;

private ConnectionFacade connectionFacade;

private final String cachedEndpointId;

protected final UnboundedOfferFirstQueue<Object> taskQueue;
private final UnboundedOfferFirstQueue<Object> taskQueue;

private final boolean canFire;
private final OwnershipSynchronizer taskQueueOwnerSync;

private volatile EventExecutor lastEventExecutor;
private final boolean canFire;

private volatile Throwable connectionError;

Expand All @@ -172,8 +172,6 @@ protected static void cancelCommandOnEndpointClose(RedisCommand<?, ?, ?> cmd) {

private final int batchSize;

private final boolean usesMpscQueue;

/**
* Create a new {@link AutoBatchFlushEndpoint}.
*
Expand All @@ -197,13 +195,14 @@ protected DefaultAutoBatchFlushEndpoint(ClientOptions clientOptions, ClientResou
this.rejectCommandsWhileDisconnected = isRejectCommand(clientOptions);
long endpointId = ENDPOINT_COUNTER.incrementAndGet();
this.cachedEndpointId = "0x" + Long.toHexString(endpointId);
this.usesMpscQueue = clientOptions.getAutoBatchFlushOptions().usesMpscQueue();
this.taskQueue = usesMpscQueue ? new JcToolsUnboundedMpscOfferFirstQueue<>() : new ConcurrentLinkedOfferFirstQueue<>();
this.taskQueue = clientOptions.getAutoBatchFlushOptions().usesMpscQueue() ? new JcToolsUnboundedMpscOfferFirstQueue<>()
: new ConcurrentLinkedOfferFirstQueue<>();
this.canFire = false;
this.callbackOnClose = callbackOnClose;
this.writeSpinCount = clientOptions.getAutoBatchFlushOptions().getWriteSpinCount();
this.batchSize = clientOptions.getAutoBatchFlushOptions().getBatchSize();
this.lastEventExecutor = clientResources.eventExecutorGroup().next();
this.taskQueueOwnerSync = new OwnershipSynchronizer(clientResources.eventExecutorGroup().next(),
Thread.currentThread().getName(), true/* allows to be preempted by first event loop thread */);
}

@Override
Expand Down Expand Up @@ -324,7 +323,8 @@ public void notifyChannelActive(Channel channel) {
return;
}

this.lastEventExecutor = channel.eventLoop();
this.taskQueueOwnerSync.preempt(channel.eventLoop(), Thread.currentThread().getName(),
false /* disallow preempt until reached quiescent point, see onEndpointQuiescence() */);
this.connectionError = null;
this.inProtectMode = false;
this.logPrefix = null;
Expand Down Expand Up @@ -379,7 +379,7 @@ public void notifyReconnectFailed(Throwable t) {
return;
}

syncAfterTerminated(() -> {
taskQueueOwnerSync.execute(() -> {
if (isClosed()) {
onEndpointClosed();
} else {
Expand Down Expand Up @@ -474,10 +474,10 @@ public void flushCommands() {
final ContextualChannel chan = this.channel;
switch (chan.context.initialState) {
case ENDPOINT_CLOSED:
syncAfterTerminated(this::onEndpointClosed);
taskQueueOwnerSync.execute(this::onEndpointClosed);
return;
case RECONNECT_FAILED:
syncAfterTerminated(() -> {
taskQueueOwnerSync.execute(() -> {
if (isClosed()) {
onEndpointClosed();
} else {
Expand Down Expand Up @@ -563,7 +563,6 @@ public void disconnect() {
*/
@Override
public void reset() {

if (debugEnabled) {
logger.debug("{} reset()", logPrefix());
}
Expand All @@ -572,10 +571,7 @@ public void reset() {
if (chan.context.initialState.isConnected()) {
chan.pipeline().fireUserEventTriggered(new ConnectionEvents.Reset());
}
if (!usesMpscQueue) {
cancelCommands("reset");
}
// Otherwise, unsafe to call cancelBufferedCommands() here.
taskQueueOwnerSync.execute(() -> cancelCommands("reset"));
}

private void resetInternal() {
Expand All @@ -587,7 +583,6 @@ private void resetInternal() {
if (chan.context.initialState.isConnected()) {
chan.pipeline().fireUserEventTriggered(new ConnectionEvents.Reset());
}
LettuceAssert.assertState(lastEventExecutor.inEventLoop(), "must be called in lastEventLoop thread");
cancelCommands("resetInternal");
}

Expand All @@ -596,10 +591,8 @@ private void resetInternal() {
*/
@Override
public void initialState() {
if (!usesMpscQueue) {
cancelCommands("initialState");
}
// Otherwise, unsafe to call cancelBufferedCommands() here.
taskQueueOwnerSync.execute(() -> cancelCommands("initialState"));

ContextualChannel currentChannel = this.channel;
if (currentChannel.context.initialState.isConnected()) {
ChannelFuture close = currentChannel.close();
Expand Down Expand Up @@ -637,8 +630,6 @@ public String getId() {
}

private void scheduleSendJobOnConnected(final ContextualChannel chan) {
LettuceAssert.assertState(chan.eventLoop().inEventLoop(), "must be called in event loop thread");

// Schedule directly
loopSend(chan, false);
}
Expand Down Expand Up @@ -758,7 +749,6 @@ private int pollBatch(final AutoBatchFlushEndPointContext autoBatchFlushEndPoint
private void trySetEndpointQuiescence(ContextualChannel chan) {
final EventLoop eventLoop = chan.eventLoop();
LettuceAssert.isTrue(eventLoop.inEventLoop(), "unexpected: not in event loop");
LettuceAssert.isTrue(eventLoop == lastEventExecutor, "unexpected: lastEventLoop not match");

final ConnectionContext connectionContext = chan.context;
final @Nullable ConnectionContext.CloseStatus closeStatus = connectionContext.getCloseStatus();
Expand Down Expand Up @@ -827,6 +817,8 @@ private void onWontReconnect(@Nonnull final ConnectionContext.CloseStatus closeS
}

private void onEndpointQuiescence() {
taskQueueOwnerSync.done(1); // allows preemption

if (channel.context.initialState == ConnectionContext.State.ENDPOINT_CLOSED) {
return;
}
Expand Down Expand Up @@ -864,7 +856,7 @@ private final void onEndpointClosed(Queue<RedisCommand<?, ?, ?>>... queues) {
fulfillCommands("endpoint closed", callbackOnClose, queues);
}

private final void onReconnectFailed() {
private void onReconnectFailed() {
fulfillCommands("reconnect failed", cmd -> cmd.completeExceptionally(getFailedToReconnectReason()));
}

Expand Down Expand Up @@ -996,7 +988,7 @@ private Throwable validateWrite(ContextualChannel chan, int commands, boolean is
private void onUnexpectedState(String caller, ConnectionContext.State exp) {
final ConnectionContext.State actual = this.channel.context.initialState;
logger.error("{}[{}][unexpected] : unexpected state: exp '{}' got '{}'", logPrefix(), caller, exp, actual);
syncAfterTerminated(
taskQueueOwnerSync.execute(
() -> cancelCommands(String.format("%s: state not match: expect '%s', got '%s'", caller, exp, actual)));
}

Expand All @@ -1017,23 +1009,6 @@ private ChannelFuture channelWrite(Channel channel, RedisCommand<?, ?, ?> comman
return channel.write(command);
}

/*
* Synchronize after the endpoint is terminated. This is to ensure only one thread can access the task queue after endpoint
* is terminated (state is RECONNECT_FAILED/ENDPOINT_CLOSED)
*/
private void syncAfterTerminated(Runnable runnable) {
final EventExecutor localLastEventExecutor = lastEventExecutor;
if (localLastEventExecutor.inEventLoop()) {
runnable.run();
} else {
localLastEventExecutor.execute(() -> {
runnable.run();
LettuceAssert.isTrue(lastEventExecutor == localLastEventExecutor,
"lastEventLoop must not be changed after terminated");
});
}
}

private enum Reliability {
AT_MOST_ONCE, AT_LEAST_ONCE
}
Expand Down Expand Up @@ -1103,7 +1078,7 @@ public void operationComplete(Future<Void> future) {

final Throwable retryableErr = checkSendResult(future);
if (retryableErr != null && autoBatchFlushEndPointContext.addRetryableFailedToSendCommand(cmd, retryableErr)) {
// Close connection on first transient write failure
// Close connection on first transient write failure.
internalCloseConnectionIfNeeded(retryableErr);
}

Expand Down Expand Up @@ -1163,6 +1138,7 @@ private void internalCloseConnectionIfNeeded(Throwable reason) {
return;
}

// It is really rare (maybe impossible?) that the connection is still active.
logger.error(
"[internalCloseConnectionIfNeeded][interesting][{}] close the connection due to write error, reason: '{}'",
endpoint.logPrefix(), reason.getMessage(), reason);
Expand All @@ -1184,4 +1160,134 @@ private void recycle() {

}

/**
 * Tracks which {@link EventExecutor} currently "owns" a shared resource and coordinates hand-over of that ownership
 * between executors.
 * <p>
 * The state is a single immutable {@link Owner} snapshot (executor, thread name, semaphore) that is only ever replaced by
 * compare-and-set on the volatile {@code owner} field. While the semaphore is positive, {@link #preempt} called for a
 * different executor keeps retrying instead of taking over; once it drops to zero the ownership may be taken.
 */
public static class OwnershipSynchronizer {

    /**
     * Immutable snapshot of the current owner. Every state change creates a new instance so that it can be published with
     * a single compare-and-set.
     */
    private static class Owner {

        private final EventExecutor thread;

        // Stored alongside the executor; not read by the code in this class.
        private final String threadName;

        // if positive, no other thread can preempt the ownership.
        private final int semaphore;

        /**
         * @param thread executor that owns the resource
         * @param threadName name recorded for the owner
         * @param semaphore number of outstanding preemption guards, must not be negative
         */
        public Owner(EventExecutor thread, String threadName, int semaphore) {
            LettuceAssert.assertState(semaphore >= 0, () -> String.format("negative semaphore: %d", semaphore));
            this.thread = thread;
            this.threadName = threadName;
            this.semaphore = semaphore;
        }

        /**
         * @return {@code true} if the calling thread is the event loop thread of the owning executor
         */
        public boolean isCurrentThread() {
            return thread.inEventLoop();
        }

        /**
         * @param n number of guards to add
         * @return a copy of this owner with the semaphore increased by {@code n}
         */
        public Owner toAdd(int n) {
            return new Owner(thread, threadName, semaphore + n);
        }

        /**
         * @param n number of guards to remove
         * @return a copy of this owner with the semaphore decreased by {@code n}; the constructor rejects a negative
         *         result
         */
        public Owner toDone(int n) {
            return new Owner(thread, threadName, semaphore - n);
        }

        /**
         * @return {@code true} if no guard is outstanding, i.e. the ownership may be preempted
         */
        public boolean isDone() {
            return semaphore == 0;
        }

    }

    private static final AtomicReferenceFieldUpdater<OwnershipSynchronizer, Owner> OWNER = AtomicReferenceFieldUpdater
            .newUpdater(OwnershipSynchronizer.class, Owner.class, "owner");

    // Only replaced through OWNER.compareAndSet after construction.
    private volatile Owner owner;

    /**
     * Create a synchronizer with an initial owner.
     *
     * @param thread executor that initially owns the resource
     * @param threadName name recorded for the initial owner
     * @param allowsPreemptByOtherThreads if {@code true} the initial semaphore is 0 (preemptible right away), otherwise 1
     */
    public OwnershipSynchronizer(EventExecutor thread, String threadName, boolean allowsPreemptByOtherThreads) {
        this.owner = new Owner(thread, threadName, allowsPreemptByOtherThreads ? 0 : 1);
    }

    /**
     * Fail unless the caller runs on the owner's event loop and at least one preemption guard is outstanding.
     */
    private void assertIsOwnerThreadAndPreemptPrevented(Owner cur) {
        LettuceAssert.assertState(isOwnerCurrentThreadAndPreemptPrevented(cur),
                () -> "[executeInOwnerWithPreemptPrevention] unexpected: "
                        + (cur.isCurrentThread() ? "preemption not prevented" : "owner is not this thread"));
    }

    private boolean isOwnerCurrentThreadAndPreemptPrevented(Owner owner) {
        return owner.isCurrentThread() && !owner.isDone();
    }

    /**
     * Make {@code thread} the owner.
     * <p>
     * If {@code thread} already is the owner, this only adds one guard (when {@code allowsPreemptByOtherThreads} is
     * {@code false}) or does nothing. Otherwise the call retries in a loop, without blocking or backing off, until the
     * current owner's semaphore has reached zero and the compare-and-set to the new owner succeeds.
     *
     * @param thread executor that takes over the ownership
     * @param threadName name recorded for the new owner and used in the debug log
     * @param allowsPreemptByOtherThreads if {@code false}, the new owner holds one guard until {@link #done(int)} is
     *        called
     */
    public void preempt(EventExecutor thread, String threadName, boolean allowsPreemptByOtherThreads) {
        Owner cur;
        Owner newOwner = null;
        while (true) {
            cur = this.owner;
            if (cur.thread == thread) {
                if (allowsPreemptByOtherThreads) {
                    return;
                }
                if (OWNER.compareAndSet(this, cur, cur.toAdd(1))) { // prevent preempt
                    return;
                }
                continue;
            }

            if (!cur.isDone()) {
                // unsafe to preempt
                continue;
            }

            // Allocate the replacement lazily and reuse it across retries.
            if (newOwner == null) {
                newOwner = new Owner(thread, threadName, allowsPreemptByOtherThreads ? 0 : 1);
            }
            if (OWNER.compareAndSet(this, cur, newOwner)) {
                logger.debug("ownership preempted by a new thread [{}]", threadName);
                // established happens-before with done()
                return;
            }
        }
    }

    /**
     * Release {@code n} guards. Must be called on the owner's event loop while at least one guard is outstanding;
     * otherwise the state assertion fails.
     *
     * @param n number of guards to release
     */
    public void done(int n) {
        Owner cur;
        do {
            cur = this.owner;
            assertIsOwnerThreadAndPreemptPrevented(cur);
        } while (!OWNER.compareAndSet(this, cur, cur.toDone(n)));
        // create happens-before with preempt()
    }

    /**
     * Safely run a task and release its memory effect to next owner thread.
     * <p>
     * If the caller already is the owner thread and holds a guard, the task runs inline. Otherwise one guard is reserved
     * first; the task then runs inline (caller is the owner thread) or is submitted to the owner executor, and the guard
     * is released after the task has run.
     *
     * @param task task to run
     * @throws java.util.concurrent.RejectedExecutionException if the owner executor refuses the task; the reserved guard
     *         is released before the exception is propagated
     */
    public void execute(Runnable task) {
        Owner cur;
        do {
            cur = this.owner;
            if (isOwnerCurrentThreadAndPreemptPrevented(cur)) {
                // already prevented preemption, safe to skip expensive add/done calls
                executeInOwnerWithPreemptPrevention(task, false);
                return;
            }
        } while (!OWNER.compareAndSet(this, cur, cur.toAdd(1)));

        if (cur.isCurrentThread()) {
            executeInOwnerWithPreemptPrevention(task, true);
            return;
        }

        try {
            cur.thread.execute(() -> executeInOwnerWithPreemptPrevention(task, true));
        } catch (java.util.concurrent.RejectedExecutionException e) {
            // The wrapped task will never run, so nobody would call done(1) for the guard reserved above. Without this
            // release the semaphore stays positive and a later preempt() from another executor would retry forever.
            releaseRejectedReservation();
            throw e;
        }
    }

    /**
     * Give back the single guard reserved by {@link #execute(Runnable)} when the hand-off to the owner executor was
     * rejected. Unlike {@link #done(int)} this does not assert the calling thread, because the caller is by definition
     * not the owner thread here.
     */
    private void releaseRejectedReservation() {
        Owner cur;
        do {
            cur = this.owner;
        } while (!OWNER.compareAndSet(this, cur, cur.toDone(1)));
    }

    /**
     * Run {@code task} and, if a guard was reserved for it, release that guard even when the task throws.
     *
     * @param task task to run
     * @param added whether {@link #execute(Runnable)} reserved a guard for this task
     */
    private void executeInOwnerWithPreemptPrevention(Runnable task, boolean added) {
        try {
            task.run();
        } finally {
            if (added) {
                done(1);
            }
        }
    }

}

}

0 comments on commit e9cfc31

Please sign in to comment.