Skip to content

Commit 1a1dff5

Browse files
ztarvos Totktonada
authored and committed
fix race between close and reconnect
Fixed issue with concurrent close and reconnect. Closes #93
1 parent 9f72d0c commit 1a1dff5

File tree

5 files changed

+455
-55
lines changed

5 files changed

+455
-55
lines changed

src/main/java/org/tarantool/SocketChannelProvider.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ public interface SocketChannelProvider {
77
/**
88
* Provides a socket channel to initialize or restore the connection.
99
* You could change hosts on fail and sleep between retries in this method
10-
* @param retryNumber number of current retry. -1 on initial connect.
10+
* @param retryNumber number of current retry. Reset after successful connect.
1111
* @param lastError the last error that occurred while reconnecting
1212
* @return the result of SocketChannel open(SocketAddress remote) call
1313
*/

src/main/java/org/tarantool/TarantoolClientImpl.java

+159-39
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import java.util.concurrent.TimeUnit;
1616
import java.util.concurrent.TimeoutException;
1717
import java.util.concurrent.atomic.AtomicInteger;
18+
import java.util.concurrent.atomic.AtomicReference;
1819
import java.util.concurrent.locks.Condition;
1920
import java.util.concurrent.locks.LockSupport;
2021
import java.util.concurrent.locks.ReentrantLock;
@@ -29,7 +30,6 @@ public class TarantoolClientImpl extends TarantoolBase<Future<List<?>>> implemen
2930
*/
3031
protected SocketChannelProvider socketProvider;
3132
protected volatile Exception thumbstone;
32-
protected volatile CountDownLatch alive;
3333

3434
protected Map<Long, FutureImpl<List<?>>> futures;
3535
protected AtomicInteger wait = new AtomicInteger();
@@ -54,17 +54,18 @@ public class TarantoolClientImpl extends TarantoolBase<Future<List<?>>> implemen
5454
* Inner
5555
*/
5656
protected TarantoolClientStats stats;
57-
protected CountDownLatch stopIO;
57+
protected StateHelper state = new StateHelper(StateHelper.RECONNECT);
5858
protected Thread reader;
5959
protected Thread writer;
6060

61-
6261
protected Thread connector = new Thread(new Runnable() {
6362
@Override
6463
public void run() {
6564
while (!Thread.currentThread().isInterrupted()) {
66-
reconnect(0, thumbstone);
67-
LockSupport.park();
65+
if (state.compareAndSet(StateHelper.RECONNECT, 0)) {
66+
reconnect(0, thumbstone);
67+
}
68+
LockSupport.park(state);
6869
}
6970
}
7071
});
@@ -74,7 +75,6 @@ public TarantoolClientImpl(SocketChannelProvider socketProvider, TarantoolClient
7475
this.thumbstone = NOT_INIT_EXCEPTION;
7576
this.config = config;
7677
this.initialRequestSize = config.defaultRequestSize;
77-
this.alive = new CountDownLatch(1);
7878
this.socketProvider = socketProvider;
7979
this.stats = new TarantoolClientStats();
8080
this.futures = new ConcurrentHashMap<Long, FutureImpl<List<?>>>(config.predictedFutures);
@@ -92,25 +92,36 @@ public TarantoolClientImpl(SocketChannelProvider socketProvider, TarantoolClient
9292
connector.start();
9393
try {
9494
if (!waitAlive(config.initTimeoutMillis, TimeUnit.MILLISECONDS)) {
95-
close();
96-
throw new CommunicationException(config.initTimeoutMillis+"ms is exceeded when waiting for client initialization. You could configure init timeout in TarantoolConfig");
95+
CommunicationException e = new CommunicationException(config.initTimeoutMillis +
96+
"ms is exceeded when waiting for client initialization. " +
97+
"You could configure init timeout in TarantoolConfig");
98+
99+
close(e);
100+
throw e;
97101
}
98102
} catch (InterruptedException e) {
99-
close();
103+
close(e);
100104
throw new IllegalStateException(e);
101105
}
102106
}
103107

104108
protected void reconnect(int retry, Throwable lastError) {
105109
SocketChannel channel;
106-
while (!Thread.interrupted()) {
107-
channel = socketProvider.get(retry++, lastError == NOT_INIT_EXCEPTION ? null : lastError);
110+
while (!Thread.currentThread().isInterrupted()) {
111+
try {
112+
channel = socketProvider.get(retry++, lastError == NOT_INIT_EXCEPTION ? null : lastError);
113+
} catch (Exception e) {
114+
close(e);
115+
return;
116+
}
108117
try {
109118
connect(channel);
110119
return;
111120
} catch (Exception e) {
112121
closeChannel(channel);
113122
lastError = e;
123+
if (e instanceof InterruptedException)
124+
Thread.currentThread().interrupt();
114125
}
115126
}
116127
}
@@ -122,8 +133,11 @@ protected void connect(final SocketChannel channel) throws Exception {
122133
is.readFully(bytes);
123134
String firstLine = new String(bytes);
124135
if (!firstLine.startsWith("Tarantool")) {
125-
close();
126-
throw new CommunicationException("Welcome message should starts with tarantool but starts with '" + firstLine + "'", new IllegalStateException("Invalid welcome packet"));
136+
CommunicationException e = new CommunicationException("Welcome message should starts with tarantool " +
137+
"but starts with '" + firstLine + "'", new IllegalStateException("Invalid welcome packet"));
138+
139+
close(e);
140+
throw e;
127141
}
128142
is.readFully(bytes);
129143
this.salt = new String(bytes);
@@ -157,32 +171,43 @@ protected void connect(final SocketChannel channel) throws Exception {
157171
} finally {
158172
bufferLock.unlock();
159173
}
160-
startThreads(channel.socket().getRemoteSocketAddress().toString());
161174
this.thumbstone = null;
162-
alive.countDown();
175+
startThreads(channel.socket().getRemoteSocketAddress().toString());
163176
}
164177

165-
protected void startThreads(String threadName) throws IOException, InterruptedException {
178+
protected void startThreads(String threadName) throws InterruptedException {
166179
final CountDownLatch init = new CountDownLatch(2);
167-
stopIO = new CountDownLatch(2);
168180
reader = new Thread(new Runnable() {
169181
@Override
170182
public void run() {
171183
init.countDown();
172-
readThread();
173-
stopIO.countDown();
184+
if (state.acquire(StateHelper.READING)) {
185+
try {
186+
readThread();
187+
} finally {
188+
state.release(StateHelper.READING);
189+
if (state.compareAndSet(0, StateHelper.RECONNECT))
190+
LockSupport.unpark(connector);
191+
}
192+
}
174193
}
175194
});
176195
writer = new Thread(new Runnable() {
177196
@Override
178197
public void run() {
179198
init.countDown();
180-
writeThread();
181-
stopIO.countDown();
199+
if (state.acquire(StateHelper.WRITING)) {
200+
try {
201+
writeThread();
202+
} finally {
203+
state.release(StateHelper.WRITING);
204+
if (state.compareAndSet(0, StateHelper.RECONNECT))
205+
LockSupport.unpark(connector);
206+
}
207+
}
182208
}
183209
});
184210

185-
186211
configureThreads(threadName);
187212
reader.start();
188213
writer.start();
@@ -217,13 +242,11 @@ public Future<List<?>> exec(Code code, Object... args) {
217242
return q;
218243
}
219244

220-
221245
protected synchronized void die(String message, Exception cause) {
222246
if (thumbstone != null) {
223247
return;
224248
}
225249
this.thumbstone = new CommunicationException(message, cause);
226-
this.alive = new CountDownLatch(1);
227250
while (!futures.isEmpty()) {
228251
Iterator<Map.Entry<Long, FutureImpl<List<?>>>> iterator = futures.entrySet().iterator();
229252
while (iterator.hasNext()) {
@@ -244,9 +267,6 @@ protected synchronized void die(String message, Exception cause) {
244267
bufferLock.unlock();
245268
}
246269
stopIO();
247-
if (connector.getState() == Thread.State.WAITING) {
248-
LockSupport.unpark(connector);
249-
}
250270
}
251271

252272

@@ -426,10 +446,20 @@ protected void writeFully(SocketChannel channel, ByteBuffer buffer) throws IOExc
426446

427447
@Override
428448
public void close() {
429-
if (connector != null) {
449+
close(new Exception("Connection is closed."));
450+
try {
451+
state.awaitState(StateHelper.CLOSED);
452+
} catch (InterruptedException ignored) {
453+
Thread.currentThread().interrupt();
454+
}
455+
}
456+
457+
protected void close(Exception e) {
458+
if (state.close()) {
430459
connector.interrupt();
460+
461+
die(e.getMessage(), e);
431462
}
432-
stopIO();
433463
}
434464

435465
protected void stopIO() {
@@ -454,28 +484,21 @@ protected void stopIO() {
454484
}
455485
}
456486
closeChannel(channel);
457-
try {
458-
stopIO.await();
459-
} catch (InterruptedException ignored) {
460-
461-
}
462487
}
463488

464489
@Override
465490
public boolean isAlive() {
466-
return thumbstone == null;
491+
return state.getState() == StateHelper.ALIVE && thumbstone == null;
467492
}
468493

469494
@Override
470495
public void waitAlive() throws InterruptedException {
471-
while(!isAlive()) {
472-
alive.await();
473-
}
496+
state.awaitState(StateHelper.ALIVE);
474497
}
475498

476499
@Override
477500
public boolean waitAlive(long timeout, TimeUnit unit) throws InterruptedException {
478-
return alive.await(timeout, unit);
501+
return state.awaitState(StateHelper.ALIVE, timeout, unit);
479502
}
480503

481504
@Override
@@ -545,4 +568,101 @@ public TarantoolClientStats getStats() {
545568
return stats;
546569
}
547570

571+
/**
572+
* Manages state changes.
573+
*/
574+
protected final class StateHelper {
575+
static final int READING = 1;
576+
static final int WRITING = 2;
577+
static final int ALIVE = READING | WRITING;
578+
static final int RECONNECT = 4;
579+
static final int CLOSED = 8;
580+
581+
private final AtomicInteger state;
582+
583+
private final AtomicReference<CountDownLatch> nextAliveLatch =
584+
new AtomicReference<CountDownLatch>(new CountDownLatch(1));
585+
586+
private final CountDownLatch closedLatch = new CountDownLatch(1);
587+
588+
protected StateHelper(int state) {
589+
this.state = new AtomicInteger(state);
590+
}
591+
592+
protected int getState() {
593+
return state.get();
594+
}
595+
596+
protected boolean close() {
597+
for (;;) {
598+
int st = getState();
599+
if ((st & CLOSED) == CLOSED)
600+
return false;
601+
if (compareAndSet(st, (st & ~RECONNECT) | CLOSED))
602+
return true;
603+
}
604+
}
605+
606+
protected boolean acquire(int mask) {
607+
for (;;) {
608+
int st = getState();
609+
if ((st & CLOSED) == CLOSED)
610+
return false;
611+
612+
if ((st & mask) != 0)
613+
throw new IllegalStateException("State is already " + mask);
614+
615+
if (compareAndSet(st, st | mask))
616+
return true;
617+
}
618+
}
619+
620+
protected void release(int mask) {
621+
for (;;) {
622+
int st = getState();
623+
if (compareAndSet(st, st & ~mask))
624+
return;
625+
}
626+
}
627+
628+
protected boolean compareAndSet(int expect, int update) {
629+
if (!state.compareAndSet(expect, update)) {
630+
return false;
631+
}
632+
633+
if (update == ALIVE) {
634+
CountDownLatch latch = nextAliveLatch.getAndSet(new CountDownLatch(1));
635+
latch.countDown();
636+
} else if (update == CLOSED) {
637+
closedLatch.countDown();
638+
}
639+
return true;
640+
}
641+
642+
protected void awaitState(int state) throws InterruptedException {
643+
CountDownLatch latch = getStateLatch(state);
644+
if (latch != null) {
645+
latch.await();
646+
}
647+
}
648+
649+
protected boolean awaitState(int state, long timeout, TimeUnit timeUnit) throws InterruptedException {
650+
CountDownLatch latch = getStateLatch(state);
651+
return (latch == null) || latch.await(timeout, timeUnit);
652+
}
653+
654+
private CountDownLatch getStateLatch(int state) {
655+
if (state == CLOSED) {
656+
return closedLatch;
657+
}
658+
if (state == ALIVE) {
659+
if (getState() == CLOSED) {
660+
throw new IllegalStateException("State is CLOSED.");
661+
}
662+
CountDownLatch latch = nextAliveLatch.get();
663+
return (getState() == ALIVE) ? null : latch;
664+
}
665+
return null;
666+
}
667+
}
548668
}

0 commit comments

Comments
 (0)