From 89f81e0c61b958f4a5257f927530420b79747913 Mon Sep 17 00:00:00 2001 From: Sergei Kalashnikov Date: Wed, 21 Nov 2018 19:47:17 +0300 Subject: [PATCH] fix race between close and reconnect Fixed issue with concurrent close and reconnect. Closes #93 --- .../org/tarantool/SocketChannelProvider.java | 2 +- .../org/tarantool/TarantoolClientImpl.java | 198 ++++++++++--- .../AbstractTarantoolConnectorIT.java | 47 +++- .../java/org/tarantool/ClientReconnectIT.java | 260 +++++++++++++++++- .../tarantool/TestSocketChannelProvider.java | 3 +- 5 files changed, 455 insertions(+), 55 deletions(-) diff --git a/src/main/java/org/tarantool/SocketChannelProvider.java b/src/main/java/org/tarantool/SocketChannelProvider.java index 2c7b405b..09112dec 100644 --- a/src/main/java/org/tarantool/SocketChannelProvider.java +++ b/src/main/java/org/tarantool/SocketChannelProvider.java @@ -7,7 +7,7 @@ public interface SocketChannelProvider { /** * Provides socket channel to init restore connection. * You could change hosts on fail and sleep between retries in this method - * @param retryNumber number of current retry. -1 on initial connect. + * @param retryNumber number of current retry. Reset after successful connect. 
* @param lastError the last error occurs when reconnecting * @return the result of SocketChannel open(SocketAddress remote) call */ diff --git a/src/main/java/org/tarantool/TarantoolClientImpl.java b/src/main/java/org/tarantool/TarantoolClientImpl.java index 67c176c7..1c0e3a21 100644 --- a/src/main/java/org/tarantool/TarantoolClientImpl.java +++ b/src/main/java/org/tarantool/TarantoolClientImpl.java @@ -15,6 +15,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.LockSupport; import java.util.concurrent.locks.ReentrantLock; @@ -29,7 +30,6 @@ public class TarantoolClientImpl extends TarantoolBase>> implemen */ protected SocketChannelProvider socketProvider; protected volatile Exception thumbstone; - protected volatile CountDownLatch alive; protected Map>> futures; protected AtomicInteger wait = new AtomicInteger(); @@ -54,17 +54,18 @@ public class TarantoolClientImpl extends TarantoolBase>> implemen * Inner */ protected TarantoolClientStats stats; - protected CountDownLatch stopIO; + protected StateHelper state = new StateHelper(StateHelper.RECONNECT); protected Thread reader; protected Thread writer; - protected Thread connector = new Thread(new Runnable() { @Override public void run() { while (!Thread.currentThread().isInterrupted()) { - reconnect(0, thumbstone); - LockSupport.park(); + if (state.compareAndSet(StateHelper.RECONNECT, 0)) { + reconnect(0, thumbstone); + } + LockSupport.park(state); } } }); @@ -74,7 +75,6 @@ public TarantoolClientImpl(SocketChannelProvider socketProvider, TarantoolClient this.thumbstone = NOT_INIT_EXCEPTION; this.config = config; this.initialRequestSize = config.defaultRequestSize; - this.alive = new CountDownLatch(1); this.socketProvider = socketProvider; this.stats = new TarantoolClientStats(); this.futures = new 
ConcurrentHashMap>>(config.predictedFutures); @@ -92,25 +92,36 @@ public TarantoolClientImpl(SocketChannelProvider socketProvider, TarantoolClient connector.start(); try { if (!waitAlive(config.initTimeoutMillis, TimeUnit.MILLISECONDS)) { - close(); - throw new CommunicationException(config.initTimeoutMillis+"ms is exceeded when waiting for client initialization. You could configure init timeout in TarantoolConfig"); + CommunicationException e = new CommunicationException(config.initTimeoutMillis + + "ms is exceeded when waiting for client initialization. " + + "You could configure init timeout in TarantoolConfig"); + + close(e); + throw e; } } catch (InterruptedException e) { - close(); + close(e); throw new IllegalStateException(e); } } protected void reconnect(int retry, Throwable lastError) { SocketChannel channel; - while (!Thread.interrupted()) { - channel = socketProvider.get(retry++, lastError == NOT_INIT_EXCEPTION ? null : lastError); + while (!Thread.currentThread().isInterrupted()) { + try { + channel = socketProvider.get(retry++, lastError == NOT_INIT_EXCEPTION ? 
null : lastError); + } catch (Exception e) { + close(e); + return; + } try { connect(channel); return; } catch (Exception e) { closeChannel(channel); lastError = e; + if (e instanceof InterruptedException) + Thread.currentThread().interrupt(); } } } @@ -122,8 +133,11 @@ protected void connect(final SocketChannel channel) throws Exception { is.readFully(bytes); String firstLine = new String(bytes); if (!firstLine.startsWith("Tarantool")) { - close(); - throw new CommunicationException("Welcome message should starts with tarantool but starts with '" + firstLine + "'", new IllegalStateException("Invalid welcome packet")); + CommunicationException e = new CommunicationException("Welcome message should starts with tarantool " + + "but starts with '" + firstLine + "'", new IllegalStateException("Invalid welcome packet")); + + close(e); + throw e; } is.readFully(bytes); this.salt = new String(bytes); @@ -157,32 +171,43 @@ protected void connect(final SocketChannel channel) throws Exception { } finally { bufferLock.unlock(); } - startThreads(channel.socket().getRemoteSocketAddress().toString()); this.thumbstone = null; - alive.countDown(); + startThreads(channel.socket().getRemoteSocketAddress().toString()); } - protected void startThreads(String threadName) throws IOException, InterruptedException { + protected void startThreads(String threadName) throws InterruptedException { final CountDownLatch init = new CountDownLatch(2); - stopIO = new CountDownLatch(2); reader = new Thread(new Runnable() { @Override public void run() { init.countDown(); - readThread(); - stopIO.countDown(); + if (state.acquire(StateHelper.READING)) { + try { + readThread(); + } finally { + state.release(StateHelper.READING); + if (state.compareAndSet(0, StateHelper.RECONNECT)) + LockSupport.unpark(connector); + } + } } }); writer = new Thread(new Runnable() { @Override public void run() { init.countDown(); - writeThread(); - stopIO.countDown(); + if (state.acquire(StateHelper.WRITING)) { + try { + 
writeThread(); + } finally { + state.release(StateHelper.WRITING); + if (state.compareAndSet(0, StateHelper.RECONNECT)) + LockSupport.unpark(connector); + } + } } }); - configureThreads(threadName); reader.start(); writer.start(); @@ -217,13 +242,11 @@ public Future> exec(Code code, Object... args) { return q; } - protected synchronized void die(String message, Exception cause) { if (thumbstone != null) { return; } this.thumbstone = new CommunicationException(message, cause); - this.alive = new CountDownLatch(1); while (!futures.isEmpty()) { Iterator>>> iterator = futures.entrySet().iterator(); while (iterator.hasNext()) { @@ -244,9 +267,6 @@ protected synchronized void die(String message, Exception cause) { bufferLock.unlock(); } stopIO(); - if (connector.getState() == Thread.State.WAITING) { - LockSupport.unpark(connector); - } } @@ -426,10 +446,20 @@ protected void writeFully(SocketChannel channel, ByteBuffer buffer) throws IOExc @Override public void close() { - if (connector != null) { + close(new Exception("Connection is closed.")); + try { + state.awaitState(StateHelper.CLOSED); + } catch (InterruptedException ignored) { + Thread.currentThread().interrupt(); + } + } + + protected void close(Exception e) { + if (state.close()) { connector.interrupt(); + + die(e.getMessage(), e); } - stopIO(); } protected void stopIO() { @@ -454,28 +484,21 @@ protected void stopIO() { } } closeChannel(channel); - try { - stopIO.await(); - } catch (InterruptedException ignored) { - - } } @Override public boolean isAlive() { - return thumbstone == null; + return state.getState() == StateHelper.ALIVE && thumbstone == null; } @Override public void waitAlive() throws InterruptedException { - while(!isAlive()) { - alive.await(); - } + state.awaitState(StateHelper.ALIVE); } @Override public boolean waitAlive(long timeout, TimeUnit unit) throws InterruptedException { - return alive.await(timeout, unit); + return state.awaitState(StateHelper.ALIVE, timeout, unit); } @Override @@ -545,4 
+568,101 @@ public TarantoolClientStats getStats() { return stats; } + /** + * Manages state changes. + */ + protected final class StateHelper { + static final int READING = 1; + static final int WRITING = 2; + static final int ALIVE = READING | WRITING; + static final int RECONNECT = 4; + static final int CLOSED = 8; + + private final AtomicInteger state; + + private final AtomicReference nextAliveLatch = + new AtomicReference(new CountDownLatch(1)); + + private final CountDownLatch closedLatch = new CountDownLatch(1); + + protected StateHelper(int state) { + this.state = new AtomicInteger(state); + } + + protected int getState() { + return state.get(); + } + + protected boolean close() { + for (;;) { + int st = getState(); + if ((st & CLOSED) == CLOSED) + return false; + if (compareAndSet(st, (st & ~RECONNECT) | CLOSED)) + return true; + } + } + + protected boolean acquire(int mask) { + for (;;) { + int st = getState(); + if ((st & CLOSED) == CLOSED) + return false; + + if ((st & mask) != 0) + throw new IllegalStateException("State is already " + mask); + + if (compareAndSet(st, st | mask)) + return true; + } + } + + protected void release(int mask) { + for (;;) { + int st = getState(); + if (compareAndSet(st, st & ~mask)) + return; + } + } + + protected boolean compareAndSet(int expect, int update) { + if (!state.compareAndSet(expect, update)) { + return false; + } + + if (update == ALIVE) { + CountDownLatch latch = nextAliveLatch.getAndSet(new CountDownLatch(1)); + latch.countDown(); + } else if (update == CLOSED) { + closedLatch.countDown(); + } + return true; + } + + protected void awaitState(int state) throws InterruptedException { + CountDownLatch latch = getStateLatch(state); + if (latch != null) { + latch.await(); + } + } + + protected boolean awaitState(int state, long timeout, TimeUnit timeUnit) throws InterruptedException { + CountDownLatch latch = getStateLatch(state); + return (latch == null) || latch.await(timeout, timeUnit); + } + + private 
CountDownLatch getStateLatch(int state) { + if (state == CLOSED) { + return closedLatch; + } + if (state == ALIVE) { + if (getState() == CLOSED) { + throw new IllegalStateException("State is CLOSED."); + } + CountDownLatch latch = nextAliveLatch.get(); + return (getState() == ALIVE) ? null : latch; + } + return null; + } + } } diff --git a/src/test/java/org/tarantool/AbstractTarantoolConnectorIT.java b/src/test/java/org/tarantool/AbstractTarantoolConnectorIT.java index 38772915..b608c843 100644 --- a/src/test/java/org/tarantool/AbstractTarantoolConnectorIT.java +++ b/src/test/java/org/tarantool/AbstractTarantoolConnectorIT.java @@ -2,11 +2,17 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; +import org.opentest4j.AssertionFailedError; import java.io.IOException; import java.net.InetSocketAddress; import java.net.Socket; import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -93,10 +99,13 @@ public static void setupEnv() { @AfterAll public static void cleanupEnv() { - executeLua(cleanScript); + try { + executeLua(cleanScript); - console.close(); - control.stop("jdk-testing"); + console.close(); + } finally { + control.stop("jdk-testing"); + } } private static void executeLua(String[] exprs) { @@ -116,13 +125,16 @@ protected void checkTupleResult(Object res, List tuple) { } protected TarantoolClient makeClient() { + return new TarantoolClientImpl(socketChannelProvider, makeClientConfig()); + } + + protected TarantoolClientConfig makeClientConfig() { TarantoolClientConfig config = new TarantoolClientConfig(); config.username = username; config.password = password; - config.initTimeoutMillis = 1000; + config.initTimeoutMillis = 
RESTART_TIMEOUT; config.sharedBufferSize = 128; - - return new TarantoolClientImpl(socketChannelProvider, config); + return config; } protected static TarantoolConsole openConsole() { @@ -184,4 +196,27 @@ protected void stopTarantool(String instance) { protected void startTarantool(String instance) { control.start(instance); } + + /** + * Asserts that execution of the Runnable completes before the given timeout is exceeded. + * + * @param timeout Timeout in ms. + * @param message Error message. + * @param r Runnable. + */ + protected void assertTimeoutPreemptively(int timeout, String message, Runnable r) { + ExecutorService executorService = Executors.newSingleThreadExecutor(); + + Future future = executorService.submit(r); + + try { + future.get(timeout, TimeUnit.MILLISECONDS); + } catch (TimeoutException ex) { + throw new AssertionFailedError(message); + } catch (Exception ex) { + throw new RuntimeException(ex); + } finally { + executorService.shutdownNow(); + } + } } diff --git a/src/test/java/org/tarantool/ClientReconnectIT.java b/src/test/java/org/tarantool/ClientReconnectIT.java index 46b02b01..2472bf05 100644 --- a/src/test/java/org/tarantool/ClientReconnectIT.java +++ b/src/test/java/org/tarantool/ClientReconnectIT.java @@ -1,30 +1,48 @@ package org.tarantool; +import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; +import java.nio.channels.SocketChannel; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.concurrent.locks.LockSupport; 
+import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; public class ClientReconnectIT extends AbstractTarantoolConnectorIT { private static final String INSTANCE_NAME = "jdk-testing"; private TarantoolClient client; - @BeforeEach - public void setup() { - client = makeClient(); - } - @AfterEach public void tearDown() { - client.close(); + if (client != null) { + assertTimeoutPreemptively(RESTART_TIMEOUT, "Close is stuck.", new Runnable() { + @Override + public void run() { + client.close(); + } + }); + } + } + @AfterAll + public static void tearDownEnv() { // Re-open console for cleanupEnv() to work. console.close(); console = openConsole(); @@ -32,6 +50,8 @@ public void tearDown() { @Test public void testReconnect() throws Exception { + client = makeClient(); + client.syncOps().ping(); stopTarantool(INSTANCE_NAME); @@ -56,4 +76,230 @@ public void execute() { client.syncOps().ping(); } + + /** + * Spurious return from LockSupport.park() must not lead to reconnect. + * The implementation must check some invariant to tell a spurious + * return from the intended one. + */ + @Test + public void testSpuriousReturnFromPark() { + final CountDownLatch latch = new CountDownLatch(2); + SocketChannelProvider provider = new SocketChannelProvider() { + @Override + public SocketChannel get(int retryNumber, Throwable lastError) { + if (lastError == null) { + latch.countDown(); + } + return socketChannelProvider.get(retryNumber, lastError); + } + }; + + client = new TarantoolClientImpl(provider, makeClientConfig()); + client.syncOps().ping(); + + // The park() will return inside connector thread. 
+ LockSupport.unpark(((TarantoolClientImpl)client).connector); + + // Wait on latch as a proof that reconnect did not happen. + // In case of a failure, latch will reach 0 before timeout occurs. + try { + assertFalse(latch.await(TIMEOUT, TimeUnit.MILLISECONDS)); + } catch (InterruptedException e) { + fail(); + } + } + + /** + * When the client is closed, all outstanding operations must fail. + * Otherwise, synchronous wait on such operations will block forever. + */ + @Test + public void testCloseWhileOperationsAreInProgress() { + client = new TarantoolClientImpl(socketChannelProvider, makeClientConfig()) { + @Override + protected void write(Code code, Long syncId, Long schemaId, Object... args) { + // Skip write. + } + }; + + final Future> res = client.asyncOps().select(SPACE_ID, PK_INDEX_ID, Collections.singletonList(1), + 0, 1, Iterator.EQ); + + client.close(); + + ExecutionException e = assertThrows(ExecutionException.class, new Executable() { + @Override + public void execute() throws Throwable { + res.get(); + } + }); + assertEquals("Connection is closed.", e.getCause().getMessage()); + } + + /** + * When a reconnection happens, the outstanding operations must fail. + * Otherwise, synchronous wait on such operations will block forever. + */ + @Test + public void testReconnectWhileOperationsAreInProgress() { + final AtomicBoolean writeEnabled = new AtomicBoolean(false); + client = new TarantoolClientImpl(socketChannelProvider, makeClientConfig()) { + @Override + protected void write(Code code, Long syncId, Long schemaId, Object... 
args) throws Exception { + if (writeEnabled.get()) { + super.write(code, syncId, schemaId, args); + } + } + }; + + final Future> mustFail = client.asyncOps().select(SPACE_ID, PK_INDEX_ID, Collections.singletonList(1), + 0, 1, Iterator.EQ); + + stopTarantool(INSTANCE_NAME); + + assertThrows(ExecutionException.class, new Executable() { + @Override + public void execute() throws Throwable { + mustFail.get(); + } + }); + + startTarantool(INSTANCE_NAME); + + writeEnabled.set(true); + + try { + client.waitAlive(RESTART_TIMEOUT, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + fail(); + } + + Future> res = client.asyncOps().select(SPACE_ID, PK_INDEX_ID, Collections.singletonList(1), + 0, 1, Iterator.EQ); + + try { + res.get(TIMEOUT, TimeUnit.MILLISECONDS); + } catch (Exception e) { + fail(e); + } + } + + @Test + public void testConcurrentCloseAndReconnect() { + final CountDownLatch latch = new CountDownLatch(2); + client = new TarantoolClientImpl(socketChannelProvider, makeClientConfig()) { + @Override + protected void connect(final SocketChannel channel) throws Exception { + latch.countDown(); + super.connect(channel); + } + }; + + stopTarantool(INSTANCE_NAME); + startTarantool(INSTANCE_NAME); + + try { + assertTrue(latch.await(RESTART_TIMEOUT, TimeUnit.MILLISECONDS)); + } catch (InterruptedException e) { + fail(e); + } + + assertTimeoutPreemptively(RESTART_TIMEOUT, "Close is stuck.", new Runnable() { + @Override + public void run() { + client.close(); + } + }); + } + + /** + * Test concurrent operations, reconnects and close. + * Expected situation is nothing gets stuck. 
+ */ + @Test + public void testLongParallelCloseReconnects() { + int numThreads = 4; + int numClients = 4; + int timeBudget = 30*1000; + + final AtomicReferenceArray clients = + new AtomicReferenceArray(numClients); + + for (int idx = 0; idx < clients.length(); idx++) { + clients.set(idx, makeClient()); + } + + final Random rnd = new Random(); + final AtomicInteger cnt = new AtomicInteger(); + + // Start background threads that do operations. + final CountDownLatch latch = new CountDownLatch(numThreads); + final long deadline = System.currentTimeMillis() + timeBudget; + Thread[] threads = new Thread[numThreads]; + for (int idx = 0; idx < threads.length; idx++) { + threads[idx] = new Thread(new Runnable() { + @Override + public void run() { + while (!Thread.currentThread().isInterrupted() && + deadline > System.currentTimeMillis()) { + + int idx = rnd.nextInt(clients.length()); + + try { + TarantoolClient cli = clients.get(idx); + + int maxOps = rnd.nextInt(100); + for (int n = 0; n < maxOps; n++) { + cli.syncOps().ping(); + } + + cli.close(); + + TarantoolClient next = makeClient(); + if (!clients.compareAndSet(idx, cli, next)) { + next.close(); + } + cnt.incrementAndGet(); + } catch (Exception ignored) { + // No-op. + } + } + latch.countDown(); + } + }); + } + + for (int idx = 0; idx < threads.length; idx++) { + threads[idx].start(); + } + + // Restart tarantool several times in the foreground. + while (deadline > System.currentTimeMillis()) { + stopTarantool(INSTANCE_NAME); + startTarantool(INSTANCE_NAME); + try { + Thread.sleep(RESTART_TIMEOUT * 2); + } catch (InterruptedException e) { + fail(e); + } + if (deadline > System.currentTimeMillis()) { + System.out.println("" + (deadline - System.currentTimeMillis())/1000 + "s remains."); + } + } + + // Wait for all threads to finish. + try { + assertTrue(latch.await(RESTART_TIMEOUT, TimeUnit.MILLISECONDS)); + } catch (InterruptedException e) { + fail(e); + } + + // Close outstanding clients. 
+ for (int idx = 0; idx < clients.length(); idx++) { + clients.get(idx).close(); + } + + assertTrue(cnt.get() > threads.length); + } } diff --git a/src/test/java/org/tarantool/TestSocketChannelProvider.java b/src/test/java/org/tarantool/TestSocketChannelProvider.java index 924097a6..469bc77c 100644 --- a/src/test/java/org/tarantool/TestSocketChannelProvider.java +++ b/src/test/java/org/tarantool/TestSocketChannelProvider.java @@ -34,7 +34,6 @@ public SocketChannel get(int retryNumber, Throwable lastError) { } } } - throw new RuntimeException("Test failure due to invalid environment. " + - "Timeout connecting to " + host + ":" + port); + throw new RuntimeException(new InterruptedException()); } }