diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index fc3175b15..70b2a1889 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -22,7 +22,9 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; import io.netty.util.collection.IntObjectMap; import io.rsocket.DuplexConnection; import io.rsocket.Payload; @@ -77,6 +79,16 @@ class RSocketRequester implements RSocket { AtomicReferenceFieldUpdater.newUpdater( RSocketRequester.class, Throwable.class, "terminationError"); private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); + private static final Consumer DROPPED_ELEMENTS_CONSUMER = + referenceCounted -> { + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + }; static { CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]); @@ -259,7 +271,7 @@ public void doOnTerminal( }); receivers.put(streamId, receiver); - return receiver; + return receiver.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); } private Flux handleRequestStream(final Payload payload) { @@ -323,7 +335,8 @@ public void accept(long n) { sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); } }) - .doFinally(s -> removeStreamReceiver(streamId)); + .doFinally(s -> removeStreamReceiver(streamId)) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); } private Flux handleChannel(Flux request) { @@ -424,7 +437,10 @@ public void accept(long n) { senders.put(streamId, upstreamSubscriber); receivers.put(streamId, receiver); - inboundFlux.limitRate(Queues.SMALL_BUFFER_SIZE).subscribe(upstreamSubscriber); + inboundFlux + .limitRate(Queues.SMALL_BUFFER_SIZE) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) + .subscribe(upstreamSubscriber); if (!payloadReleasedFlag.getAndSet(true)) { ByteBuf frame = RequestChannelFrameFlyweight.encode( @@ -461,7 +477,8 @@ public void accept(long n) { sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); upstreamSubscriber.cancel(); } - }); + }) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); } private Mono handleMetadataPush(Payload payload) { diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index 6f235587a..e01000e49 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -20,7 +20,9 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; import io.netty.util.collection.IntObjectMap; import io.rsocket.DuplexConnection; import io.rsocket.Payload; @@ -45,6 +47,16 @@ /** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ class RSocketResponder implements ResponderRSocket { + private static final Consumer DROPPED_ELEMENTS_CONSUMER = + referenceCounted -> { + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + }; private final DuplexConnection connection; private final RSocket requestHandler; @@ -418,7 +430,7 @@ protected void hookFinally(SignalType type) { }; sendingSubscriptions.put(streamId, subscriber); - response.subscribe(subscriber); + response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber); } private void handleStream(int streamId, Flux response, int initialRequestN) { @@ -471,7 +483,10 @@ protected void hookFinally(SignalType type) { }; sendingSubscriptions.put(streamId, subscriber); - response.limitRate(Queues.SMALL_BUFFER_SIZE).subscribe(subscriber); + response + .limitRate(Queues.SMALL_BUFFER_SIZE) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) + .subscribe(subscriber); } private void handleChannel(int streamId, Payload payload, int initialRequestN) { @@ -499,7 +514,8 @@ public void accept(long l) { sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n)); } }) - .doFinally(signalType -> channelProcessors.remove(streamId)); + .doFinally(signalType -> channelProcessors.remove(streamId)) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); // not chained, as the payload should be enqueued in the Unicast processor before this method // returns diff --git a/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java index 6f2aa7150..328fb8435 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java @@ -99,7 +99,7 @@ private static char[] checkCharSequenceBounds(char[] seq, int start, int end) { } /** - * Encode a {@link char[]} in UTF-8 and write it + * Encode a {@code char[]} in UTF-8 and write it * into {@link ByteBuf}. * *

This method returns the actual number of bytes written. @@ -109,9 +109,8 @@ public static int writeUtf8(ByteBuf buf, char[] seq) { } /** - * Equivalent to {@link #writeUtf8(ByteBuf, char[]) - * writeUtf8(buf, seq.subSequence(start, end), reserveBytes)} but avoids subsequence object - * allocation if possible. + * Equivalent to {@link #writeUtf8(ByteBuf, char[]) writeUtf8(buf, seq.subSequence(start, end), + * reserveBytes)} but avoids subsequence object allocation if possible. * * @return actual number of bytes written */ diff --git a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java new file mode 100644 index 000000000..2044779ef --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java @@ -0,0 +1,167 @@ +package io.rsocket.buffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import java.util.List; +import java.util.concurrent.ConcurrentLinkedQueue; +import org.assertj.core.api.Assertions; + +/** + * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created + * ByteBuffs + */ +public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { + return new LeaksTrackingByteBufAllocator(allocator); + } + + final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); + + final ByteBufAllocator delegate; + + private LeaksTrackingByteBufAllocator(ByteBufAllocator delegate) { + this.delegate = delegate; + } + + public LeaksTrackingByteBufAllocator assertHasNoLeaks() { + try { + Assertions.assertThat(tracker) + .allSatisfy( + buf -> { + if (buf instanceof CompositeByteBuf) { + if (buf.refCnt() > 0) { + List decomposed = + ((CompositeByteBuf) buf).decompose(0, buf.readableBytes()); + for (int i = 0; i < decomposed.size(); i++) { + Assertions.assertThat(decomposed.get(i)) + .matches(bb -> bb.refCnt() == 0, "Got unreleased CompositeByteBuf"); + } + } + + } else { + Assertions.assertThat(buf) + .matches(bb -> bb.refCnt() == 0, "buffer should be released"); + } + }); + } finally { + tracker.clear(); + } + return this; + } + + // Delegating logic with tracking of buffers + + @Override + public ByteBuf buffer() { + return track(delegate.buffer()); + } + + @Override + public ByteBuf buffer(int initialCapacity) { + return track(delegate.buffer(initialCapacity)); + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + return track(delegate.buffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf ioBuffer() { + return track(delegate.ioBuffer()); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return track(delegate.ioBuffer(initialCapacity)); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.ioBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf heapBuffer() { + return track(delegate.heapBuffer()); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return track(delegate.heapBuffer(initialCapacity)); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.heapBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf directBuffer() { + return track(delegate.directBuffer()); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return track(delegate.directBuffer(initialCapacity)); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.directBuffer(initialCapacity, maxCapacity)); + } + + @Override + public CompositeByteBuf compositeBuffer() { + return track(delegate.compositeBuffer()); + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + return track(delegate.compositeBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return track(delegate.compositeHeapBuffer()); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return track(delegate.compositeHeapBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return track(delegate.compositeDirectBuffer()); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return track(delegate.compositeDirectBuffer(maxNumComponents)); + } + + @Override + public boolean isDirectBufferPooled() { + return delegate.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return delegate.calculateNewCapacity(minNewCapacity, maxCapacity); + } + + T track(T buffer) { + tracker.offer(buffer); + + return buffer; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java index dc01e7911..5a43838c7 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java +++ b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java @@ -16,7 +16,9 @@ package io.rsocket.core; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestSubscriber; import java.util.concurrent.ConcurrentLinkedQueue; @@ -32,6 +34,7 @@ public abstract class AbstractSocketRule extends ExternalReso protected Subscriber connectSub; protected T socket; protected ConcurrentLinkedQueue errors; + protected LeaksTrackingByteBufAllocator allocator; @Override public Statement apply(final Statement base, Description description) { @@ -41,6 +44,7 @@ public void evaluate() throws Throwable { connection = new TestDuplexConnection(); connectSub = TestSubscriber.create(); errors = new ConcurrentLinkedQueue<>(); + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); init(); base.evaluate(); } @@ -48,14 +52,22 @@ public void evaluate() throws Throwable { } protected void init() { - socket = newRSocket(); + socket = newRSocket(allocator); } - protected abstract T newRSocket(); + protected abstract T newRSocket(LeaksTrackingByteBufAllocator allocator); public void assertNoConnectionErrors() { if (errors.size() > 1) { Assert.fail("No connection errors expected: " + errors.peek().toString()); } } + + public ByteBufAllocator alloc() { + return allocator; + } + + public void assertHasNoLeaks() { + allocator.assertHasNoLeaks(); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index 101500da7..586c9cfd3 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -19,7 +19,6 @@ import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import static io.rsocket.frame.FrameHeaderFlyweight.frameType; import static io.rsocket.frame.FrameType.CANCEL; -import static io.rsocket.frame.FrameType.KEEPALIVE; import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; import static io.rsocket.frame.FrameType.REQUEST_STREAM; @@ -37,9 +36,12 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCounted; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.CustomRSocketException; import io.rsocket.exceptions.RejectedSetupException; import io.rsocket.frame.CancelFrameFlyweight; import io.rsocket.frame.ErrorFrameFlyweight; @@ -50,8 +52,11 @@ import io.rsocket.frame.RequestChannelFrameFlyweight; import io.rsocket.frame.RequestNFrameFlyweight; import io.rsocket.frame.RequestStreamFrameFlyweight; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; import io.rsocket.util.DefaultPayload; import io.rsocket.util.EmptyPayload; import io.rsocket.util.MultiSubscriberRSocket; @@ -60,12 +65,17 @@ import java.util.Iterator; import java.util.List; import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiConsumer; import java.util.function.BiFunction; -import java.util.stream.Collectors; +import java.util.function.Function; import java.util.stream.Stream; import org.assertj.core.api.Assertions; -import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.runners.model.Statement; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -75,23 +85,39 @@ import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.UnicastProcessor; import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; public class RSocketRequesterTest { - @Rule public final ClientSocketRule rule = new ClientSocketRule(); + ClientSocketRule rule; + + @BeforeEach + public void setUp() throws Throwable { + rule = new ClientSocketRule(); + rule.apply( + new Statement() { + @Override + public void evaluate() {} + }, + null) + .evaluate(); + } - @Test(timeout = 2_000) + @Test + @Timeout(2_000) public void testInvalidFrameOnStream0() { - rule.connection.addToReceivedBuffer( - RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 0, 10)); + rule.connection.addToReceivedBuffer(RequestNFrameFlyweight.encode(rule.alloc(), 0, 10)); assertThat("Unexpected errors.", rule.errors, hasSize(1)); assertThat( "Unexpected error received.", rule.errors, contains(instanceOf(IllegalStateException.class))); + rule.assertHasNoLeaks(); } - @Test(timeout = 2_000) + @Test + @Timeout(2_000) public void testStreamInitialN() { Flux stream = rule.socket.requestStream(EmptyPayload.INSTANCE); @@ -100,19 +126,15 @@ public void testStreamInitialN() { @Override protected void hookOnSubscribe(Subscription subscription) { // don't request here - // subscription.request(3); } }; stream.subscribe(subscriber); + Assertions.assertThat(rule.connection.getSent()).isEmpty(); + subscriber.request(5); - List sent = - rule.connection - .getSent() - .stream() - .filter(f -> frameType(f) != KEEPALIVE) - .collect(Collectors.toList()); + List sent = new ArrayList<>(rule.connection.getSent()); assertThat("sent frame count", sent.size(), is(1)); @@ -120,21 +142,25 @@ protected void hookOnSubscribe(Subscription subscription) { assertThat("initial frame", frameType(f), is(REQUEST_STREAM)); assertThat("initial request n", RequestStreamFrameFlyweight.initialRequestN(f), is(5)); + assertThat("should be released", f.release(), is(true)); + rule.assertHasNoLeaks(); } - @Test(timeout = 2_000) + @Test + @Timeout(2_000) public void testHandleSetupException() { rule.connection.addToReceivedBuffer( - ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("boom"))); + ErrorFrameFlyweight.encode(rule.alloc(), 0, new RejectedSetupException("boom"))); assertThat("Unexpected errors.", rule.errors, hasSize(1)); assertThat( "Unexpected error received.", rule.errors, contains(instanceOf(RejectedSetupException.class))); + rule.assertHasNoLeaks(); } - @Test(timeout = 2_000) + @Test + @Timeout(2_000) public void testHandleApplicationException() { rule.connection.clearSendReceiveBuffers(); Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); @@ -143,13 +169,20 @@ public void testHandleApplicationException() { int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); rule.connection.addToReceivedBuffer( - ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, new ApplicationErrorException("error"))); + ErrorFrameFlyweight.encode(rule.alloc(), streamId, new ApplicationErrorException("error"))); verify(responseSub).onError(any(ApplicationErrorException.class)); + + Assertions.assertThat(rule.connection.getSent()) + // requestResponseFrame FIXME + // .hasSize(1) + .allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); } - @Test(timeout = 2_000) + @Test + @Timeout(2_000) public void testHandleValidFrame() { Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); Subscriber sub = TestSubscriber.create(); @@ -157,13 +190,15 @@ public void testHandleValidFrame() { int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); rule.connection.addToReceivedBuffer( - PayloadFrameFlyweight.encodeNext( - ByteBufAllocator.DEFAULT, streamId, EmptyPayload.INSTANCE)); + PayloadFrameFlyweight.encodeNext(rule.alloc(), streamId, EmptyPayload.INSTANCE)); verify(sub).onComplete(); + Assertions.assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); } - @Test(timeout = 2_000) + @Test + @Timeout(2_000) public void testRequestReplyWithCancel() { Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); @@ -172,19 +207,18 @@ public void testRequestReplyWithCancel() { } catch (IllegalStateException ise) { } - List sent = - rule.connection - .getSent() - .stream() - .filter(f -> frameType(f) != KEEPALIVE) - .collect(Collectors.toList()); + List sent = new ArrayList<>(rule.connection.getSent()); assertThat( "Unexpected frame sent on the connection.", frameType(sent.get(0)), is(REQUEST_RESPONSE)); assertThat("Unexpected frame sent on the connection.", frameType(sent.get(1)), is(CANCEL)); + Assertions.assertThat(sent).hasSize(2).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); } - @Test(timeout = 2_000) + @Test + @Disabled("invalid") + @Timeout(2_000) public void testRequestReplyErrorOnSend() { rule.connection.setAvailability(0); // Fails send Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); @@ -195,21 +229,28 @@ public void testRequestReplyErrorOnSend() { verify(responseSub).onSubscribe(any(Subscription.class)); + rule.assertHasNoLeaks(); // TODO this should get the error reported through the response subscription // verify(responseSub).onError(any(RuntimeException.class)); } - @Test(timeout = 2_000) + @Test + @Timeout(2_000) public void testLazyRequestResponse() { Publisher response = new MultiSubscriberRSocket(rule.socket).requestResponse(EmptyPayload.INSTANCE); int streamId = sendRequestResponse(response); + Assertions.assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); rule.connection.clearSendReceiveBuffers(); int streamId2 = sendRequestResponse(response); assertThat("Stream ID reused.", streamId2, not(equalTo(streamId))); + Assertions.assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); } @Test + @Timeout(2_000) public void testChannelRequestCancellation() { MonoProcessor cancelled = MonoProcessor.create(); Flux request = Flux.never().doOnCancel(cancelled::onComplete); @@ -219,9 +260,11 @@ public void testChannelRequestCancellation() { Flux.error(new IllegalStateException("Channel request not cancelled")) .delaySubscription(Duration.ofSeconds(1))) .blockFirst(); + rule.assertHasNoLeaks(); } @Test + @Timeout(2_000) public void testChannelRequestCancellation2() { MonoProcessor cancelled = MonoProcessor.create(); Flux request = @@ -232,6 +275,8 @@ public void testChannelRequestCancellation2() { Flux.error(new IllegalStateException("Channel request not cancelled")) .delaySubscription(Duration.ofSeconds(1))) .blockFirst(); + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); } @Test @@ -241,10 +286,9 @@ public void testChannelRequestServerSideCancellation() { request.onNext(EmptyPayload.INSTANCE); rule.socket.requestChannel(request).subscribe(cancelled); int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + rule.connection.addToReceivedBuffer(CancelFrameFlyweight.encode(rule.alloc(), streamId)); rule.connection.addToReceivedBuffer( - CancelFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId)); - rule.connection.addToReceivedBuffer( - PayloadFrameFlyweight.encodeComplete(ByteBufAllocator.DEFAULT, streamId)); + PayloadFrameFlyweight.encodeComplete(rule.alloc(), streamId)); Flux.first( cancelled, Flux.error(new IllegalStateException("Channel request not cancelled")) @@ -252,6 +296,12 @@ public void testChannelRequestServerSideCancellation() { .blockFirst(); Assertions.assertThat(request.isDisposed()).isTrue(); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_CHANNEL) + .matches(ReferenceCounted::release); + rule.assertHasNoLeaks(); } @Test @@ -282,8 +332,10 @@ protected void hookOnSubscribe(Subscription subscription) {} Assertions.assertThat( RequestChannelFrameFlyweight.data(initialFrame).toString(CharsetUtil.UTF_8)) .isEqualTo("0"); + Assertions.assertThat(initialFrame.release()).isTrue(); Assertions.assertThat(iterator.hasNext()).isFalse(); + rule.assertHasNoLeaks(); } @Test @@ -304,9 +356,21 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen .isInstanceOf(IllegalArgumentException.class) .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) .verify(); + // FIXME: should be removed + Assertions.assertThat(rule.connection.getSent()).allMatch(bb -> bb.release()); + rule.assertHasNoLeaks(); }); } + static Stream>> prepareCalls() { + return Stream.of( + RSocket::fireAndForget, + RSocket::requestResponse, + RSocket::requestStream, + (rSocket, payload) -> rSocket.requestChannel(Flux.just(payload)), + RSocket::metadataPush); + } + @Test public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase() { @@ -322,24 +386,245 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen () -> rule.connection.addToReceivedBuffer( RequestNFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - rule.getStreamIdForRequestType(REQUEST_CHANNEL), - 2))) + rule.alloc(), rule.getStreamIdForRequestType(REQUEST_CHANNEL), 2))) .expectErrorSatisfies( t -> Assertions.assertThat(t) .isInstanceOf(IllegalArgumentException.class) .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) .verify(); + Assertions.assertThat(rule.connection.getSent()) + // expect to be sent RequestChannelFrame + // expect to be sent CancelFrame + .hasSize(2) + .allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); } - static Stream>> prepareCalls() { + @Test + @Disabled("Due to https://github.com/reactor/reactor-core/pull/2114") + @SuppressWarnings("unchecked") + public void checkNoLeaksOnRacingTest() { + + racingCases() + .forEach( + a -> { + ((Runnable) a.get()[0]).run(); + checkNoLeaksOnRacing( + (Function>) a.get()[1], + (BiConsumer, ClientSocketRule>) a.get()[2]); + }); + } + + public void checkNoLeaksOnRacing( + Function> initiator, + BiConsumer, ClientSocketRule> runner) { + for (int i = 0; i < 10000; i++) { + ClientSocketRule clientSocketRule = new ClientSocketRule(); + try { + clientSocketRule + .apply( + new Statement() { + @Override + public void evaluate() {} + }, + null) + .evaluate(); + } catch (Throwable throwable) { + throwable.printStackTrace(); + } + + Publisher payloadP = initiator.apply(clientSocketRule); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + if (payloadP instanceof Flux) { + ((Flux) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } else { + ((Mono) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } + + runner.accept(assertSubscriber, clientSocketRule); + + Assertions.assertThat(clientSocketRule.connection.getSent()) + .allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + private static Stream racingCases() { return Stream.of( - RSocket::fireAndForget, - RSocket::requestResponse, - RSocket::requestStream, - (rSocket, payload) -> rSocket.requestChannel(Flux.just(payload)), - RSocket::metadataPush); + Arguments.of( + (Runnable) () -> System.out.println("RequestStream downstream cancellation case"), + (Function>) + (rule) -> rule.socket.requestStream(EmptyPayload.INSTANCE), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + int streamId = rule.getStreamIdForRequestType(REQUEST_STREAM); + ByteBuf frame = + PayloadFrameFlyweight.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Runnable) () -> System.out.println("RequestChannel downstream cancellation case"), + (Function>) + (rule) -> rule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE)), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = + PayloadFrameFlyweight.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Runnable) () -> System.out.println("RequestChannel upstream cancellation 1"), + (Function>) + (rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + return rule.socket.requestChannel( + Flux.just(ByteBufPayload.create(data, metadata))); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = CancelFrameFlyweight.encode(allocator, streamId); + + RaceTestUtils.race( + () -> as.request(1), () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Runnable) () -> System.out.println("RequestChannel upstream cancellation 2"), + (Function>) + (rule) -> + rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + final Payload payload = + ByteBufPayload.create("d" + index, "m" + index); + sink.next(payload); + return ++index; + })), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = CancelFrameFlyweight.encode(allocator, streamId); + + as.request(1); + + RaceTestUtils.race( + () -> as.request(Long.MAX_VALUE), + () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Runnable) () -> System.out.println("RequestChannel remote error"), + (Function>) + (rule) -> + rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + final Payload payload = + ByteBufPayload.create("d" + index, "m" + index); + sink.next(payload); + return ++index; + })), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = + ErrorFrameFlyweight.encode(allocator, streamId, new RuntimeException("test")); + + as.request(1); + + RaceTestUtils.race( + () -> as.request(Long.MAX_VALUE), + () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Runnable) () -> System.out.println("RequestResponse downstream cancellation"), + (Function>) + (rule) -> rule.socket.requestResponse(EmptyPayload.INSTANCE), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + ByteBuf frame = + PayloadFrameFlyweight.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + })); + } + + @Test + public void simpleOnDiscardRequestChannelTest() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + TestPublisher testPublisher = TestPublisher.create(); + + Flux payloadFlux = rule.socket.requestChannel(testPublisher); + + payloadFlux.subscribe(assertSubscriber); + + testPublisher.next( + ByteBufPayload.create("d", "m"), + ByteBufPayload.create("d1", "m1"), + ByteBufPayload.create("d2", "m2")); + + assertSubscriber.cancel(); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + rule.assertHasNoLeaks(); + } + + @Test + public void simpleOnDiscardRequestChannelTest2() { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + TestPublisher testPublisher = TestPublisher.create(); + + Flux payloadFlux = rule.socket.requestChannel(testPublisher); + + payloadFlux.subscribe(assertSubscriber); + + testPublisher.next(ByteBufPayload.create("d", "m")); + + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + testPublisher.next(ByteBufPayload.create("d1", "m1"), ByteBufPayload.create("d2", "m2")); + + rule.connection.addToReceivedBuffer( + ErrorFrameFlyweight.encode( + allocator, streamId, new CustomRSocketException(0x00000404, "test"))); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + rule.assertHasNoLeaks(); } public int sendRequestResponse(Publisher response) { @@ -347,8 +632,7 @@ public int sendRequestResponse(Publisher response) { response.subscribe(sub); int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); rule.connection.addToReceivedBuffer( - PayloadFrameFlyweight.encodeNextComplete( - ByteBufAllocator.DEFAULT, streamId, EmptyPayload.INSTANCE)); + PayloadFrameFlyweight.encodeNextComplete(rule.alloc(), streamId, EmptyPayload.INSTANCE)); verify(sub).onNext(any(Payload.class)); verify(sub).onComplete(); return streamId; @@ -356,11 +640,11 @@ public int sendRequestResponse(Publisher response) { public static class ClientSocketRule extends AbstractSocketRule { @Override - protected RSocketRequester newRSocket() { + protected RSocketRequester newRSocket(LeaksTrackingByteBufAllocator allocator) { return new RSocketRequester( - ByteBufAllocator.DEFAULT, + allocator, connection, - DefaultPayload::create, + PayloadDecoder.ZERO_COPY, throwable -> errors.add(throwable), StreamIdSupplier.clientSupplier(), 0, diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index 5c147f46f..d31fc3bf7 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -18,43 +18,93 @@ import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import static io.rsocket.frame.FrameHeaderFlyweight.frameType; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.*; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.frame.*; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.CancelFrameFlyweight; +import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameFlyweight; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.RequestChannelFrameFlyweight; +import io.rsocket.frame.RequestNFrameFlyweight; +import io.rsocket.frame.RequestResponseFrameFlyweight; +import io.rsocket.frame.RequestStreamFrameFlyweight; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; import io.rsocket.lease.ResponderLeaseHandler; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; import io.rsocket.util.DefaultPayload; -import io.rsocket.util.EmptyPayload; import java.util.Collection; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; import org.assertj.core.api.Assertions; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.runners.model.Statement; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; +import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.util.RaceTestUtils; public class RSocketResponderTest { - @Rule public final ServerSocketRule rule = new ServerSocketRule(); + ServerSocketRule rule; - @Test(timeout = 2000) - @Ignore + @BeforeEach + public void setUp() throws Throwable { + rule = new ServerSocketRule(); + rule.apply( + new Statement() { + @Override + public void evaluate() {} + }, + null) + .evaluate(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + } + + @Test + @Timeout(2_000) + @Disabled public void testHandleKeepAlive() throws Exception { rule.connection.addToReceivedBuffer( - KeepAliveFrameFlyweight.encode(ByteBufAllocator.DEFAULT, true, 0, Unpooled.EMPTY_BUFFER)); + KeepAliveFrameFlyweight.encode(rule.alloc(), true, 0, Unpooled.EMPTY_BUFFER)); ByteBuf sent = rule.connection.awaitSend(); assertThat("Unexpected frame sent.", frameType(sent), is(FrameType.KEEPALIVE)); /*Keep alive ack must not have respond flag else, it will result in infinite ping-pong of keep alive frames.*/ @@ -64,8 +114,9 @@ public void testHandleKeepAlive() throws Exception { is(false)); } - @Test(timeout = 2000) - @Ignore + @Test + @Timeout(2_000) + @Disabled public void testHandleResponseFrameNoError() throws Exception { final int streamId = 4; rule.connection.clearSendReceiveBuffers(); @@ -82,8 +133,9 @@ public void testHandleResponseFrameNoError() throws Exception { anyOf(is(FrameType.COMPLETE), is(FrameType.NEXT_COMPLETE))); } - @Test(timeout = 2000) - @Ignore + @Test + @Timeout(2_000) + @Disabled public void testHandlerEmitsError() throws Exception { final int streamId = 4; rule.sendRequest(streamId, FrameType.REQUEST_STREAM); @@ -92,14 +144,17 @@ public void testHandlerEmitsError() throws Exception { "Unexpected frame sent.", frameType(rule.connection.awaitSend()), is(FrameType.ERROR)); } - @Test(timeout = 2_0000) + @Test + @Timeout(20_000) public void testCancel() { + ByteBufAllocator allocator = rule.alloc(); final int streamId = 4; final AtomicBoolean cancelled = new AtomicBoolean(); rule.setAcceptingSocket( new AbstractRSocket() { @Override public Mono requestResponse(Payload payload) { + payload.release(); return Mono.never().doOnCancel(() -> cancelled.set(true)); } }); @@ -108,14 +163,15 @@ public Mono requestResponse(Payload payload) { assertThat("Unexpected error.", rule.errors, is(empty())); assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); - rule.connection.addToReceivedBuffer( - CancelFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId)); + rule.connection.addToReceivedBuffer(CancelFrameFlyweight.encode(allocator, streamId)); assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + rule.assertHasNoLeaks(); } @Test + @Timeout(2_000) public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() { final int streamId = 4; final AtomicBoolean cancelled = new AtomicBoolean(); @@ -128,48 +184,429 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen new AbstractRSocket() { @Override public Mono requestResponse(Payload p) { + p.release(); return Mono.just(payload).doOnCancel(() -> cancelled.set(true)); } @Override public Flux requestStream(Payload p) { + p.release(); return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); } - - @Override - public Flux requestChannel(Publisher payloads) { - return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); - } + // FIXME + // @Override + // public Flux requestChannel(Publisher payloads) { + // Flux.from(payloads) + // .doOnNext(Payload::release) + // .subscribe( + // new BaseSubscriber() { + // @Override + // protected void hookOnSubscribe(Subscription subscription) { + // subscription.request(1); + // } + // }); + // return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + // } }; rule.setAcceptingSocket(acceptingSocket); final Runnable[] runnables = { () -> rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE), - () -> rule.sendRequest(streamId, FrameType.REQUEST_STREAM), - () -> rule.sendRequest(streamId, FrameType.REQUEST_CHANNEL) + () -> rule.sendRequest(streamId, FrameType.REQUEST_STREAM) /* FIXME, + () -> rule.sendRequest(streamId, FrameType.REQUEST_CHANNEL)*/ }; for (Runnable runnable : runnables) { + rule.connection.clearSendReceiveBuffers(); runnable.run(); Assertions.assertThat(rule.errors) .first() .isInstanceOf(IllegalArgumentException.class) .hasToString("java.lang.IllegalArgumentException: " + INVALID_PAYLOAD_ERROR_MESSAGE); Assertions.assertThat(rule.connection.getSent()) + .filteredOn(bb -> FrameHeaderFlyweight.frameType(bb) == FrameType.ERROR) .hasSize(1) .first() - .matches(bb -> FrameHeaderFlyweight.frameType(bb) == FrameType.ERROR) - .matches(bb -> ErrorFrameFlyweight.dataUtf8(bb).contains(INVALID_PAYLOAD_ERROR_MESSAGE)); + .matches(bb -> ErrorFrameFlyweight.dataUtf8(bb).contains(INVALID_PAYLOAD_ERROR_MESSAGE)) + .matches(ReferenceCounted::release); assertThat("Subscription not cancelled.", cancelled.get(), is(true)); - rule.init(); - rule.setAcceptingSocket(acceptingSocket); } + + rule.assertHasNoLeaks(); + } + + @Test + @Disabled("Due to https://github.com/reactor/reactor-core/pull/2114") + public void checkNoLeaksOnRacingCancelFromRequestChannelAndNextFromUpstream() { + + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.never(); + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_CHANNEL); + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata3, data3); + + RaceTestUtils.race( + () -> { + rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); + }, + assertSubscriber::cancel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + @Disabled("Due to https://github.com/reactor/reactor-core/pull/2114") + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + @Disabled("Due to https://github.com/reactor/reactor-core/pull/2114") + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); + ByteBuf requestNFrame = RequestNFrameFlyweight.encode(allocator, 1, Integer.MAX_VALUE); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(cancelFrame), + parallel), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + @Disabled("Due to https://github.com/reactor/reactor-core/pull/2114") + public void + checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromUpstreamOnErrorFromRequestChannelTest1() + throws InterruptedException { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + + return Flux.create( + sink -> { + sinks[0] = sink; + }, + FluxSink.OverflowStrategy.IGNORE) + .mergeWith(payloads); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata3, data3); + + ByteBuf requestNFrame = RequestNFrameFlyweight.encode(allocator, 1, Integer.MAX_VALUE); + + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), + parallel), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + sink.error(new RuntimeException()); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + @Disabled("Due to https://github.com/reactor/reactor-core/pull/2114") + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestStreamTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestResponseTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + Operators.MonoSubscriber[] sources = new Operators.MonoSubscriber[1]; + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + sources[0] = new Operators.MonoSubscriber<>(actual); + actual.onSubscribe(sources[0]); + } + }; + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_RESPONSE); + + ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sources[0].complete(ByteBufPayload.create("d1", "m1")); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void simpleDiscardRequestStreamTest() { + ByteBufAllocator allocator = rule.alloc(); + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); + FluxSink sink = sinks[0]; + + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + rule.connection.addToReceivedBuffer(cancelFrame); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @Test + public void simpleDiscardRequestChannelTest() { + ByteBufAllocator allocator = rule.alloc(); + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return (Flux) payloads; + } + }, + 1); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameFlyweight.encode(allocator, 1); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameFlyweight.encode(allocator, 1, false, false, true, metadata3, data3); + rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); + + rule.connection.addToReceivedBuffer(cancelFrame); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); } public static class ServerSocketRule extends AbstractSocketRule { private RSocket acceptingSocket; + private volatile int prefetch; @Override protected void init() { @@ -188,25 +625,26 @@ public void setAcceptingSocket(RSocket acceptingSocket) { connection = new TestDuplexConnection(); connectSub = TestSubscriber.create(); errors = new ConcurrentLinkedQueue<>(); + this.prefetch = Integer.MAX_VALUE; super.init(); } public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { this.acceptingSocket = acceptingSocket; connection = new TestDuplexConnection(); - connection.setInitialSendRequestN(prefetch); connectSub = TestSubscriber.create(); errors = new ConcurrentLinkedQueue<>(); + this.prefetch = prefetch; super.init(); } @Override - protected RSocketResponder newRSocket() { + protected RSocketResponder newRSocket(LeaksTrackingByteBufAllocator allocator) { return new RSocketResponder( - ByteBufAllocator.DEFAULT, + allocator, connection, acceptingSocket, - DefaultPayload::create, + PayloadDecoder.ZERO_COPY, throwable -> errors.add(throwable), ResponderLeaseHandler.None, 0); @@ -219,25 +657,34 @@ private void sendRequest(int streamId, FrameType frameType) { case REQUEST_CHANNEL: request = RequestChannelFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, false, false, 1, EmptyPayload.INSTANCE); + allocator, + streamId, + false, + false, + prefetch, + Unpooled.EMPTY_BUFFER, + Unpooled.EMPTY_BUFFER); break; case REQUEST_STREAM: request = RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, false, 1, EmptyPayload.INSTANCE); + allocator, + streamId, + false, + prefetch, + Unpooled.EMPTY_BUFFER, + Unpooled.EMPTY_BUFFER); break; case REQUEST_RESPONSE: request = RequestResponseFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, false, EmptyPayload.INSTANCE); + allocator, streamId, false, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); break; default: throw new IllegalArgumentException("unsupported type: " + frameType); } connection.addToReceivedBuffer(request); - connection.addToReceivedBuffer( - RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId, 2)); } } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java index b22a95c0b..5e94935c5 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java @@ -17,6 +17,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; +import io.netty.util.IllegalReferenceCountException; import org.assertj.core.presentation.StandardRepresentation; public final class ByteBufRepresentation extends StandardRepresentation { @@ -24,7 +25,11 @@ public final class ByteBufRepresentation extends StandardRepresentation { @Override protected String fallbackToStringOf(Object object) { if (object instanceof ByteBuf) { - return ByteBufUtil.prettyHexDump((ByteBuf) object); + try { + return ByteBufUtil.prettyHexDump((ByteBuf) object); + } catch (IllegalReferenceCountException e) { + // noops + } } return super.fallbackToStringOf(object);