Skip to content

fix race between close and reconnect #97

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main/java/org/tarantool/SocketChannelProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ public interface SocketChannelProvider {
/**
* Provides socket channel to init restore connection.
* You could change hosts on fail and sleep between retries in this method
* @param retryNumber number of current retry. -1 on initial connect.
* @param retryNumber number of current retry. Reset after successful connect.
* @param lastError the last error occurs when reconnecting
* @return the result of SocketChannel open(SocketAddress remote) call
*/
Expand Down
198 changes: 159 additions & 39 deletions src/main/java/org/tarantool/TarantoolClientImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;
Expand All @@ -29,7 +30,6 @@ public class TarantoolClientImpl extends TarantoolBase<Future<List<?>>> implemen
*/
protected SocketChannelProvider socketProvider;
protected volatile Exception thumbstone;
protected volatile CountDownLatch alive;

protected Map<Long, FutureImpl<List<?>>> futures;
protected AtomicInteger wait = new AtomicInteger();
Expand All @@ -54,17 +54,18 @@ public class TarantoolClientImpl extends TarantoolBase<Future<List<?>>> implemen
* Inner
*/
protected TarantoolClientStats stats;
protected CountDownLatch stopIO;
protected StateHelper state = new StateHelper(StateHelper.RECONNECT);
protected Thread reader;
protected Thread writer;


protected Thread connector = new Thread(new Runnable() {
@Override
public void run() {
while (!Thread.currentThread().isInterrupted()) {
reconnect(0, thumbstone);
LockSupport.park();
if (state.compareAndSet(StateHelper.RECONNECT, 0)) {
reconnect(0, thumbstone);
}
LockSupport.park(state);
}
}
});
Expand All @@ -74,7 +75,6 @@ public TarantoolClientImpl(SocketChannelProvider socketProvider, TarantoolClient
this.thumbstone = NOT_INIT_EXCEPTION;
this.config = config;
this.initialRequestSize = config.defaultRequestSize;
this.alive = new CountDownLatch(1);
this.socketProvider = socketProvider;
this.stats = new TarantoolClientStats();
this.futures = new ConcurrentHashMap<Long, FutureImpl<List<?>>>(config.predictedFutures);
Expand All @@ -92,25 +92,36 @@ public TarantoolClientImpl(SocketChannelProvider socketProvider, TarantoolClient
connector.start();
try {
if (!waitAlive(config.initTimeoutMillis, TimeUnit.MILLISECONDS)) {
close();
throw new CommunicationException(config.initTimeoutMillis+"ms is exceeded when waiting for client initialization. You could configure init timeout in TarantoolConfig");
CommunicationException e = new CommunicationException(config.initTimeoutMillis +
"ms is exceeded when waiting for client initialization. " +
"You could configure init timeout in TarantoolConfig");

close(e);
throw e;
}
} catch (InterruptedException e) {
close();
close(e);
throw new IllegalStateException(e);
}
}

protected void reconnect(int retry, Throwable lastError) {
SocketChannel channel;
while (!Thread.interrupted()) {
channel = socketProvider.get(retry++, lastError == NOT_INIT_EXCEPTION ? null : lastError);
while (!Thread.currentThread().isInterrupted()) {
try {
channel = socketProvider.get(retry++, lastError == NOT_INIT_EXCEPTION ? null : lastError);
} catch (Exception e) {
close(e);
return;
}
try {
connect(channel);
return;
} catch (Exception e) {
closeChannel(channel);
lastError = e;
if (e instanceof InterruptedException)
Thread.currentThread().interrupt();
}
}
}
Expand All @@ -122,8 +133,11 @@ protected void connect(final SocketChannel channel) throws Exception {
is.readFully(bytes);
String firstLine = new String(bytes);
if (!firstLine.startsWith("Tarantool")) {
close();
throw new CommunicationException("Welcome message should starts with tarantool but starts with '" + firstLine + "'", new IllegalStateException("Invalid welcome packet"));
CommunicationException e = new CommunicationException("Welcome message should starts with tarantool " +
"but starts with '" + firstLine + "'", new IllegalStateException("Invalid welcome packet"));

close(e);
throw e;
}
is.readFully(bytes);
this.salt = new String(bytes);
Expand Down Expand Up @@ -157,32 +171,43 @@ protected void connect(final SocketChannel channel) throws Exception {
} finally {
bufferLock.unlock();
}
startThreads(channel.socket().getRemoteSocketAddress().toString());
this.thumbstone = null;
alive.countDown();
startThreads(channel.socket().getRemoteSocketAddress().toString());
}

protected void startThreads(String threadName) throws IOException, InterruptedException {
protected void startThreads(String threadName) throws InterruptedException {
final CountDownLatch init = new CountDownLatch(2);
stopIO = new CountDownLatch(2);
reader = new Thread(new Runnable() {
@Override
public void run() {
init.countDown();
readThread();
stopIO.countDown();
if (state.acquire(StateHelper.READING)) {
try {
readThread();
} finally {
state.release(StateHelper.READING);
if (state.compareAndSet(0, StateHelper.RECONNECT))
LockSupport.unpark(connector);
}
}
}
});
writer = new Thread(new Runnable() {
@Override
public void run() {
init.countDown();
writeThread();
stopIO.countDown();
if (state.acquire(StateHelper.WRITING)) {
try {
writeThread();
} finally {
state.release(StateHelper.WRITING);
if (state.compareAndSet(0, StateHelper.RECONNECT))
LockSupport.unpark(connector);
}
}
}
});


configureThreads(threadName);
reader.start();
writer.start();
Expand Down Expand Up @@ -217,13 +242,11 @@ public Future<List<?>> exec(Code code, Object... args) {
return q;
}


protected synchronized void die(String message, Exception cause) {
if (thumbstone != null) {
return;
}
this.thumbstone = new CommunicationException(message, cause);
this.alive = new CountDownLatch(1);
while (!futures.isEmpty()) {
Iterator<Map.Entry<Long, FutureImpl<List<?>>>> iterator = futures.entrySet().iterator();
while (iterator.hasNext()) {
Expand All @@ -244,9 +267,6 @@ protected synchronized void die(String message, Exception cause) {
bufferLock.unlock();
}
stopIO();
if (connector.getState() == Thread.State.WAITING) {
LockSupport.unpark(connector);
}
}


Expand Down Expand Up @@ -426,10 +446,20 @@ protected void writeFully(SocketChannel channel, ByteBuffer buffer) throws IOExc

@Override
public void close() {
if (connector != null) {
close(new Exception("Connection is closed."));
try {
state.awaitState(StateHelper.CLOSED);
} catch (InterruptedException ignored) {
Thread.currentThread().interrupt();
}
}

protected void close(Exception e) {
if (state.close()) {
connector.interrupt();

die(e.getMessage(), e);
}
stopIO();
}

protected void stopIO() {
Expand All @@ -454,28 +484,21 @@ protected void stopIO() {
}
}
closeChannel(channel);
try {
stopIO.await();
} catch (InterruptedException ignored) {

}
}

@Override
public boolean isAlive() {
return thumbstone == null;
return state.getState() == StateHelper.ALIVE && thumbstone == null;
}

@Override
public void waitAlive() throws InterruptedException {
while(!isAlive()) {
alive.await();
}
state.awaitState(StateHelper.ALIVE);
}

@Override
public boolean waitAlive(long timeout, TimeUnit unit) throws InterruptedException {
return alive.await(timeout, unit);
return state.awaitState(StateHelper.ALIVE, timeout, unit);
}

@Override
Expand Down Expand Up @@ -545,4 +568,101 @@ public TarantoolClientStats getStats() {
return stats;
}

/**
* Manages state changes.
*/
protected final class StateHelper {
static final int READING = 1;
static final int WRITING = 2;
static final int ALIVE = READING | WRITING;
static final int RECONNECT = 4;
static final int CLOSED = 8;

private final AtomicInteger state;

private final AtomicReference<CountDownLatch> nextAliveLatch =
new AtomicReference<CountDownLatch>(new CountDownLatch(1));

private final CountDownLatch closedLatch = new CountDownLatch(1);

protected StateHelper(int state) {
this.state = new AtomicInteger(state);
}

protected int getState() {
return state.get();
}

protected boolean close() {
for (;;) {
int st = getState();
if ((st & CLOSED) == CLOSED)
return false;
if (compareAndSet(st, (st & ~RECONNECT) | CLOSED))
return true;
}
}

protected boolean acquire(int mask) {
for (;;) {
int st = getState();
if ((st & CLOSED) == CLOSED)
return false;

if ((st & mask) != 0)
throw new IllegalStateException("State is already " + mask);

if (compareAndSet(st, st | mask))
return true;
}
}

protected void release(int mask) {
for (;;) {
int st = getState();
if (compareAndSet(st, st & ~mask))
return;
}
}

protected boolean compareAndSet(int expect, int update) {
if (!state.compareAndSet(expect, update)) {
return false;
}

if (update == ALIVE) {
CountDownLatch latch = nextAliveLatch.getAndSet(new CountDownLatch(1));
latch.countDown();
} else if (update == CLOSED) {
closedLatch.countDown();
}
return true;
}

protected void awaitState(int state) throws InterruptedException {
CountDownLatch latch = getStateLatch(state);
if (latch != null) {
latch.await();
}
}

protected boolean awaitState(int state, long timeout, TimeUnit timeUnit) throws InterruptedException {
CountDownLatch latch = getStateLatch(state);
return (latch == null) || latch.await(timeout, timeUnit);
}

private CountDownLatch getStateLatch(int state) {
if (state == CLOSED) {
return closedLatch;
}
if (state == ALIVE) {
if (getState() == CLOSED) {
throw new IllegalStateException("State is CLOSED.");
}
CountDownLatch latch = nextAliveLatch.get();
return (getState() == ALIVE) ? null : latch;
}
return null;
}
}
}
Loading