diff --git a/src/main/java/io/reactivesocket/DefaultReactiveSocket.java b/src/main/java/io/reactivesocket/DefaultReactiveSocket.java index be2d4dfe6..40b8747de 100644 --- a/src/main/java/io/reactivesocket/DefaultReactiveSocket.java +++ b/src/main/java/io/reactivesocket/DefaultReactiveSocket.java @@ -29,6 +29,7 @@ import org.reactivestreams.Subscription; import java.io.IOException; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; @@ -55,6 +56,7 @@ public class DefaultReactiveSocket implements ReactiveSocket { private final RequestHandler clientRequestHandler; private final ConnectionSetupHandler responderConnectionHandler; private final LeaseGovernor leaseGovernor; + private final CopyOnWriteArrayList shutdownListeners; private DefaultReactiveSocket( DuplexConnection connection, @@ -72,6 +74,7 @@ private DefaultReactiveSocket( this.responderConnectionHandler = responderConnectionHandler; this.leaseGovernor = leaseGovernor; this.errorStream = errorStream; + this.shutdownListeners = new CopyOnWriteArrayList<>(); } /** @@ -439,15 +442,28 @@ public void addOutput(Frame f, Completable callback) { }; + @Override + public void onShutdown(Completable c) { + shutdownListeners.add(c); + } + @Override public void close() throws Exception { - connection.close(); - leaseGovernor.unregister(responder); - if (requester != null) { - requester.shutdown(); - } - if (responder != null) { - responder.shutdown(); + try { + connection.close(); + leaseGovernor.unregister(responder); + if (requester != null) { + requester.shutdown(); + } + if (responder != null) { + responder.shutdown(); + } + + shutdownListeners.forEach(Completable::success); + + } catch (Throwable t) { + shutdownListeners.forEach(c -> c.error(t)); + throw t; } } diff --git a/src/main/java/io/reactivesocket/DuplexConnection.java b/src/main/java/io/reactivesocket/DuplexConnection.java index 772205404..2a6c3f888 100644 --- a/src/main/java/io/reactivesocket/DuplexConnection.java +++ b/src/main/java/io/reactivesocket/DuplexConnection.java @@ -15,6 +15,7 @@ */ package io.reactivesocket; +import io.reactivesocket.internal.rx.EmptySubscription; import io.reactivesocket.rx.Completable; import io.reactivesocket.rx.Observable; import org.reactivestreams.Publisher; @@ -32,6 +33,7 @@ public interface DuplexConnection extends Closeable { default void addOutput(Frame frame, Completable callback) { addOutput(s -> { + s.onSubscribe(EmptySubscription.INSTANCE); s.onNext(frame); s.onComplete(); }, callback); diff --git a/src/main/java/io/reactivesocket/ReactiveSocket.java b/src/main/java/io/reactivesocket/ReactiveSocket.java index 8f8c2a10c..c862b0277 100644 --- a/src/main/java/io/reactivesocket/ReactiveSocket.java +++ b/src/main/java/io/reactivesocket/ReactiveSocket.java @@ -94,6 +94,11 @@ public void error(Throwable e) { */ void onRequestReady(Completable c); + /** + * Registers a completable to be run when an ReactiveSocket is closed + */ + void onShutdown(Completable c); + /** * Server granting new lease information to client * diff --git a/src/test/java/io/reactivesocket/ReactiveSocketTest.java b/src/test/java/io/reactivesocket/ReactiveSocketTest.java index 28229a777..4c66e99e7 100644 --- a/src/test/java/io/reactivesocket/ReactiveSocketTest.java +++ b/src/test/java/io/reactivesocket/ReactiveSocketTest.java @@ -16,6 +16,7 @@ package io.reactivesocket; import io.reactivesocket.lease.FairLeaseGovernor; +import io.reactivesocket.rx.Completable; import io.reactivex.disposables.Disposable; import io.reactivex.observables.ConnectableObservable; import io.reactivex.subscribers.TestSubscriber; @@ -256,6 +257,74 @@ private void awaitSocketAvailability(ReactiveSocket socket, long timeout, TimeUn assertTrue("client socket has positive avaibility", socket.availability() > 0.0); } + @Test(timeout = 2000) + public void testShutdownListener() throws Exception { + socketClient = DefaultReactiveSocket.fromClientConnection( + clientConnection, + ConnectionSetupPayload.create("UTF-8", "UTF-8", NO_FLAGS), + err -> err.printStackTrace() + ); + + CountDownLatch latch = new CountDownLatch(1); + + socketClient.onShutdown(new Completable() { + @Override + public void success() { + latch.countDown(); + } + + @Override + public void error(Throwable e) { + + } + }); + + socketClient.close(); + + latch.await(); + } + + @Test(timeout = 2000) + public void testMultipleShutdownListeners() throws Exception { + socketClient = DefaultReactiveSocket.fromClientConnection( + clientConnection, + ConnectionSetupPayload.create("UTF-8", "UTF-8", NO_FLAGS), + err -> err.printStackTrace() + ); + + CountDownLatch latch = new CountDownLatch(2); + + socketClient + .onShutdown(new Completable() { + @Override + public void success() { + latch.countDown(); + } + + @Override + public void error(Throwable e) { + + } + }); + + socketClient + .onShutdown(new Completable() { + @Override + public void success() { + latch.countDown(); + } + + @Override + public void error(Throwable e) { + + } + }); + + socketClient.close(); + + latch.await(); + } + @Test(timeout=2000) @Theory public void testRequestResponse(int setupFlag) throws InterruptedException { @@ -269,7 +338,7 @@ public void testRequestResponse(int setupFlag) throws InterruptedException { ts.assertNoErrors(); ts.assertValue(TestUtil.utf8EncodedPayload("hello world", null)); } - + @Test(timeout=2000, expected=IllegalStateException.class) public void testRequestResponsePremature() throws InterruptedException { socketClient = DefaultReactiveSocket.fromClientConnection(