From d761e593592ea11b4e78e8ef3242cb74f1ceca21 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Mon, 13 Apr 2020 14:47:27 +0300 Subject: [PATCH] provides leaks tracking tests and tooling Signed-off-by: Oleh Dokuka --- .../buffer/LeaksTrackingByteBufAllocator.java | 5 +- .../io/rsocket/core/RSocketRequesterTest.java | 185 +++++++++++++++--- .../rsocket/frame/ByteBufRepresentation.java | 7 +- 3 files changed, 166 insertions(+), 31 deletions(-) diff --git a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java index de7f093c4..3b1f97111 100644 --- a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java +++ b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java @@ -64,10 +64,9 @@ private LeaksTrackingByteBufAllocator(ByteBufAllocator delegate) { public LeaksTrackingByteBufAllocator assertHasNoLeaks() { Assertions.assertThat(tracker) .allSatisfy(buf -> { - ByteBuf unwrap = buf.unwrap(); - if (unwrap instanceof CompositeByteBuf) { + if (buf instanceof CompositeByteBuf) { if (buf.refCnt() > 0) { - List decomposed = ((CompositeByteBuf) unwrap).decompose(0, unwrap.readableBytes()); + List decomposed = ((CompositeByteBuf) buf).decompose(0, buf.readableBytes()); for (int i = 0; i < decomposed.size(); i++) { Assertions.assertThat(decomposed.get(i)) .matches(bb -> bb.refCnt() == 0, "Got unreleased CompositeByteBuf"); diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index 101500da7..01de4c28c 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -16,29 +16,14 @@ package io.rsocket.core; -import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; -import static io.rsocket.frame.FrameHeaderFlyweight.frameType; -import static io.rsocket.frame.FrameType.CANCEL; -import static io.rsocket.frame.FrameType.KEEPALIVE; -import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; -import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; -import static io.rsocket.frame.FrameType.REQUEST_STREAM; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; - import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.exceptions.RejectedSetupException; import io.rsocket.frame.CancelFrameFlyweight; @@ -50,31 +35,59 @@ import io.rsocket.frame.RequestChannelFrameFlyweight; import io.rsocket.frame.RequestNFrameFlyweight; import io.rsocket.frame.RequestStreamFrameFlyweight; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; import io.rsocket.util.DefaultPayload; import io.rsocket.util.EmptyPayload; import io.rsocket.util.MultiSubscriberRSocket; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.concurrent.ThreadLocalRandom; -import java.util.function.BiFunction; -import java.util.stream.Collectors; -import java.util.stream.Stream; import org.assertj.core.api.Assertions; import org.junit.Rule; import org.junit.Test; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.runners.model.Statement; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.UnicastProcessor; import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameHeaderFlyweight.frameType; +import static io.rsocket.frame.FrameType.CANCEL; +import static io.rsocket.frame.FrameType.KEEPALIVE; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; public class RSocketRequesterTest { @@ -333,6 +346,124 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen .verify(); } + + private static Stream racingCases() { + return Stream.of( + Arguments.of( + (Runnable) () -> System.out.println("RequestChannel downstream cancellation case"), + (Function>) (rule) -> rule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE)), + (BiConsumer, ClientSocketRule>) (as, rule) -> { + LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrumentDefault(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = PayloadFrameFlyweight.encode(allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + } + )/*,*/ +// Arguments.of( +// (Runnable) () -> System.out.println("RequestChannel upstream cancellation 1"), +// (Function>) (rule) -> { +// LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrumentDefault(); +// ByteBuf metadata = allocator.buffer(); +// metadata.writeCharSequence("abc", CharsetUtil.UTF_8); +// ByteBuf data = allocator.buffer(); +// data.writeCharSequence("def", CharsetUtil.UTF_8); +// return rule.socket.requestChannel(Flux.just(ByteBufPayload.create(data, metadata))); +// }, +// (BiConsumer, ClientSocketRule>) (as, rule) -> { +// LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrumentDefault(); +// int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); +// ByteBuf frame = CancelFrameFlyweight.encode(allocator, streamId); +// +// RaceTestUtils.race(() -> as.request(1), () -> rule.connection.addToReceivedBuffer(frame)); +// } +// ), +// Arguments.of( +// (Runnable) () -> System.out.println("RequestChannel upstream cancellation 2"), +// (Function>) (rule) -> { +// return rule.socket.requestChannel(Flux.just(ByteBufPayload.create("a", "b"), ByteBufPayload.create("c", "d"), ByteBufPayload.create("e", "f"))); +// }, +// (BiConsumer, ClientSocketRule>) (as, rule) -> { +// LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrumentDefault(); +// int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); +// ByteBuf frame = CancelFrameFlyweight.encode(allocator, streamId); +// +// as.request(1); +// +// RaceTestUtils.race(() -> as.request(Long.MAX_VALUE), () -> rule.connection.addToReceivedBuffer(frame)); +// } +// ), +// Arguments.of( +// (Runnable) () -> System.out.println("RequestResponse downstream cancellation"), +// (Function>) (rule) -> rule.socket.requestResponse(EmptyPayload.INSTANCE), +// (BiConsumer, ClientSocketRule>) (as, rule) -> { +// LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrumentDefault(); +// ByteBuf metadata = allocator.buffer(); +// metadata.writeCharSequence("abc", CharsetUtil.UTF_8); +// ByteBuf data = allocator.buffer(); +// data.writeCharSequence("def", CharsetUtil.UTF_8); +// int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); +// ByteBuf frame = PayloadFrameFlyweight.encode(allocator, streamId, false, false, true, metadata, data); +// +// RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); +// } +// ) + ); + } + + @Test + @SuppressWarnings("unchecked") + public void checkNoLeaksOnRacingTest() { + + racingCases() + .forEach(a -> { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrumentDefault(); + ((Runnable)a.get()[0]).run(); + checkNoLeaksOnRacing(allocator, (Function>) a.get()[1], (BiConsumer, ClientSocketRule>) a.get()[2]); + + Hooks.resetOnNextDropped(); + LeaksTrackingByteBufAllocator.deinstrumentDefault(); + }); + } + + public void checkNoLeaksOnRacing(LeaksTrackingByteBufAllocator allocator, Function> initiator, BiConsumer, ClientSocketRule> runner) { + for (int i = 0; i < 100000; i++) { + System.out.println(i); + ClientSocketRule clientSocketRule = new ClientSocketRule(); + try { + clientSocketRule.apply(new Statement() { + @Override + public void evaluate() throws Throwable { + + } + }, null).evaluate(); + } catch (Throwable throwable) { + throwable.printStackTrace(); + } + + Publisher payloadP = initiator.apply(clientSocketRule); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + if (payloadP instanceof Flux) { + ((Flux)payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } else { + ((Mono)payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } + + runner.accept(assertSubscriber, clientSocketRule); + + Assertions.assertThat(clientSocketRule.connection.getSent()) + .allMatch(ReferenceCounted::release); + + allocator.assertHasNoLeaks(); + } + } + static Stream>> prepareCalls() { return Stream.of( RSocket::fireAndForget, @@ -360,7 +491,7 @@ protected RSocketRequester newRSocket() { return new RSocketRequester( ByteBufAllocator.DEFAULT, connection, - DefaultPayload::create, + PayloadDecoder.ZERO_COPY, throwable -> errors.add(throwable), StreamIdSupplier.clientSupplier(), 0, diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java index b22a95c0b..d2f51e6da 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java @@ -17,6 +17,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; +import io.netty.util.IllegalReferenceCountException; import org.assertj.core.presentation.StandardRepresentation; public final class ByteBufRepresentation extends StandardRepresentation { @@ -24,7 +25,11 @@ public final class ByteBufRepresentation extends StandardRepresentation { @Override protected String fallbackToStringOf(Object object) { if (object instanceof ByteBuf) { - return ByteBufUtil.prettyHexDump((ByteBuf) object); + try { + return ByteBufUtil.prettyHexDump((ByteBuf) object); + } catch (IllegalReferenceCountException e) { + //noops + } } return super.fallbackToStringOf(object);