Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UNDERTOW-1701: Fix race condition in GracefulShutdownHandler #72

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.function.LongUnaryOperator;

import io.undertow.UndertowMessages;
import io.undertow.server.ExchangeCompletionListener;
Expand All @@ -41,25 +42,49 @@
*/
public class GracefulShutdownHandler implements HttpHandler {

private volatile boolean shutdown = false;
private static final long SHUTDOWN_MASK = 1L << 63;
private static final long ACTIVE_COUNT_MASK = (1L << 63) - 1;

private static final LongUnaryOperator incrementActive = current -> {
long incrementedActiveCount = activeCount(current) + 1;
return incrementedActiveCount | (current & ~ACTIVE_COUNT_MASK);
};

private static final LongUnaryOperator incrementActiveAndShutdown =
incrementActive.andThen(current -> current | SHUTDOWN_MASK);

private static final LongUnaryOperator decrementActive = current -> {
long decrementedActiveCount = activeCount(current) - 1;
return decrementedActiveCount | (current & ~ACTIVE_COUNT_MASK);
};

private final GracefulShutdownListener listener = new GracefulShutdownListener();
private final List<ShutdownListener> shutdownListeners = new ArrayList<>();

private final Object lock = new Object();

private volatile long activeRequests = 0;
private static final AtomicLongFieldUpdater<GracefulShutdownHandler> activeRequestsUpdater = AtomicLongFieldUpdater.newUpdater(GracefulShutdownHandler.class, "activeRequests");
private volatile long state = 0;
private static final AtomicLongFieldUpdater<GracefulShutdownHandler> stateUpdater =
AtomicLongFieldUpdater.newUpdater(GracefulShutdownHandler.class, "state");

private final HttpHandler next;

public GracefulShutdownHandler(HttpHandler next) {
this.next = next;
}

private static boolean isShutdown(long state) {
return (state & SHUTDOWN_MASK) != 0;
}

private static long activeCount(long state) {
return state & ACTIVE_COUNT_MASK;
}

@Override
public void handleRequest(HttpServerExchange exchange) throws Exception {
activeRequestsUpdater.incrementAndGet(this);
if (shutdown) {
long snapshot = stateUpdater.updateAndGet(this, incrementActive);
if (isShutdown(snapshot)) {
decrementRequests();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case could be optimized to avoid CAS increment+decrement after the shutdown flag has been set by extracting the LongUnaryOperator into a loop until CAS succeeds similarly to the jboss-threads view executor pattern: https://github.com/jbossas/jboss-threads/blob/a203e13fd739902f850b55a14dc9c77e46a361a6/src/main/java/org/jboss/threads/EnhancedViewExecutor.java#L250-L280

Unclear if it's helpful to optimize this case though, I'm most interested in the hot common case in which the server is actively accepting requests.

e.g.

for (;;) {
    long snapshot = state;
    if (isShutdown(snapshot)) {
        exchange.setStatusCode(StatusCodes.SERVICE_UNAVAILABLE);
        exchange.endExchange();
        return;
    }
    if (stateUpdater.compareAndSet(this, snapshot, snapshot + 1)) {
        exchange.addExchangeCompleteListener(listener);
        next.handleRequest(exchange);
        break;
   }
}

exchange.setStatusCode(StatusCodes.SERVICE_UNAVAILABLE);
exchange.endExchange();
Expand All @@ -71,15 +96,14 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {


public void shutdown() {
activeRequestsUpdater.incrementAndGet(this);
//the request count is never zero when shutdown is set to true
shutdown = true;
stateUpdater.updateAndGet(this, incrementActiveAndShutdown);
decrementRequests();
}

public void start() {
synchronized (lock) {
shutdown = false;
stateUpdater.updateAndGet(this, current -> current & ACTIVE_COUNT_MASK);
for (ShutdownListener listener : shutdownListeners) {
listener.shutdown(false);
}
Expand All @@ -88,23 +112,24 @@ public void start() {
}

private void shutdownComplete() {
assert Thread.holdsLock(lock);
lock.notifyAll();
for (ShutdownListener listener : shutdownListeners) {
listener.shutdown(true);
synchronized (lock) {
lock.notifyAll();
for (ShutdownListener listener : shutdownListeners) {
listener.shutdown(true);
}
shutdownListeners.clear();
}
shutdownListeners.clear();
}

/**
* Waits for the handler to shutdown.
*/
public void awaitShutdown() throws InterruptedException {
synchronized (lock) {
if (!shutdown) {
if (!isShutdown(stateUpdater.get(this))) {
throw UndertowMessages.MESSAGES.handlerNotShutdown();
}
while (activeRequestsUpdater.get(this) > 0) {
while (activeCount(stateUpdater.get(this)) > 0) {
lock.wait();
}
}
Expand All @@ -118,18 +143,16 @@ public void awaitShutdown() throws InterruptedException {
*/
public boolean awaitShutdown(long millis) throws InterruptedException {
synchronized (lock) {
if (!shutdown) {
if (!isShutdown(stateUpdater.get(this))) {
throw UndertowMessages.MESSAGES.handlerNotShutdown();
}
long end = System.currentTimeMillis() + millis;
int count = (int) activeRequestsUpdater.get(this);
while (count != 0) {
while (activeCount(stateUpdater.get(this)) != 0) {
long left = end - System.currentTimeMillis();
if (left <= 0) {
return false;
}
lock.wait(left);
count = (int) activeRequestsUpdater.get(this);
}
return true;
}
Expand All @@ -143,10 +166,10 @@ public boolean awaitShutdown(long millis) throws InterruptedException {
*/
public void addShutdownListener(final ShutdownListener shutdownListener) {
synchronized (lock) {
if (!shutdown) {
if (!isShutdown(stateUpdater.get(this))) {
throw UndertowMessages.MESSAGES.handlerNotShutdown();
}
long count = activeRequestsUpdater.get(this);
long count = activeCount(stateUpdater.get(this));
if (count == 0) {
shutdownListener.shutdown(true);
} else {
Expand All @@ -155,20 +178,11 @@ public void addShutdownListener(final ShutdownListener shutdownListener) {
}
}


private void decrementRequests() {
if (shutdown) {
//we don't read the request count until after checking the shutdown variable
//otherwise we could read the request count as zero, a new request could state, and then we shutdown
//see https://issues.jboss.org/browse/UNDERTOW-1099
long active = activeRequestsUpdater.decrementAndGet(this);
synchronized (lock) {
if (active == 0) {
shutdownComplete();
}
}
} else {
activeRequestsUpdater.decrementAndGet(this);
long snapshot = stateUpdater.updateAndGet(this, decrementActive);
// Shutdown has completed when the activeCount portion is zero, and shutdown is set.
if (snapshot == SHUTDOWN_MASK) {
shutdownComplete();
}
}

Expand Down
Loading