Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use doOnDiscard to release cached Payloads #777

Merged
merged 18 commits into from
Apr 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.IllegalReferenceCountException;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.netty.util.collection.IntObjectMap;
import io.rsocket.DuplexConnection;
import io.rsocket.Payload;
Expand Down Expand Up @@ -77,6 +79,16 @@ class RSocketRequester implements RSocket {
AtomicReferenceFieldUpdater.newUpdater(
RSocketRequester.class, Throwable.class, "terminationError");
private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
private static final Consumer<ReferenceCounted> DROPPED_ELEMENTS_CONSUMER =
referenceCounted -> {
if (referenceCounted.refCnt() > 0) {
try {
referenceCounted.release();
} catch (IllegalReferenceCountException e) {
// ignored
}
}
};

static {
CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]);
Expand Down Expand Up @@ -259,7 +271,7 @@ public void doOnTerminal(
});
receivers.put(streamId, receiver);

return receiver;
return receiver.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER);
}

private Flux<Payload> handleRequestStream(final Payload payload) {
Expand Down Expand Up @@ -323,7 +335,8 @@ public void accept(long n) {
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
}
})
.doFinally(s -> removeStreamReceiver(streamId));
.doFinally(s -> removeStreamReceiver(streamId))
.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER);
}

private Flux<Payload> handleChannel(Flux<Payload> request) {
Expand Down Expand Up @@ -424,7 +437,10 @@ public void accept(long n) {
senders.put(streamId, upstreamSubscriber);
receivers.put(streamId, receiver);

inboundFlux.limitRate(Queues.SMALL_BUFFER_SIZE).subscribe(upstreamSubscriber);
inboundFlux
.limitRate(Queues.SMALL_BUFFER_SIZE)
.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER)
.subscribe(upstreamSubscriber);
if (!payloadReleasedFlag.getAndSet(true)) {
ByteBuf frame =
RequestChannelFrameFlyweight.encode(
Expand Down Expand Up @@ -461,7 +477,8 @@ public void accept(long n) {
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
upstreamSubscriber.cancel();
}
});
})
.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER);
}

private Mono<Void> handleMetadataPush(Payload payload) {
Expand Down
22 changes: 19 additions & 3 deletions rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.IllegalReferenceCountException;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.netty.util.collection.IntObjectMap;
import io.rsocket.DuplexConnection;
import io.rsocket.Payload;
Expand All @@ -45,6 +47,16 @@

/** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */
class RSocketResponder implements ResponderRSocket {
private static final Consumer<ReferenceCounted> DROPPED_ELEMENTS_CONSUMER =
referenceCounted -> {
if (referenceCounted.refCnt() > 0) {
try {
referenceCounted.release();
} catch (IllegalReferenceCountException e) {
// ignored
}
}
};

private final DuplexConnection connection;
private final RSocket requestHandler;
Expand Down Expand Up @@ -418,7 +430,7 @@ protected void hookFinally(SignalType type) {
};

sendingSubscriptions.put(streamId, subscriber);
response.subscribe(subscriber);
response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber);
}

private void handleStream(int streamId, Flux<Payload> response, int initialRequestN) {
Expand Down Expand Up @@ -471,7 +483,10 @@ protected void hookFinally(SignalType type) {
};

sendingSubscriptions.put(streamId, subscriber);
response.limitRate(Queues.SMALL_BUFFER_SIZE).subscribe(subscriber);
response
.limitRate(Queues.SMALL_BUFFER_SIZE)
.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER)
.subscribe(subscriber);
}

private void handleChannel(int streamId, Payload payload, int initialRequestN) {
Expand Down Expand Up @@ -499,7 +514,8 @@ public void accept(long l) {
sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n));
}
})
.doFinally(signalType -> channelProcessors.remove(streamId));
.doFinally(signalType -> channelProcessors.remove(streamId))
.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER);

// not chained, as the payload should be enqueued in the Unicast processor before this method
// returns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ private static char[] checkCharSequenceBounds(char[] seq, int start, int end) {
}

/**
* Encode a {@link char[]} in <a href="http://en.wikipedia.org/wiki/UTF-8">UTF-8</a> and write it
* Encode a {@code char[]} in <a href="http://en.wikipedia.org/wiki/UTF-8">UTF-8</a> and write it
* into {@link ByteBuf}.
*
* <p>This method returns the actual number of bytes written.
Expand All @@ -109,9 +109,8 @@ public static int writeUtf8(ByteBuf buf, char[] seq) {
}

/**
* Equivalent to <code>{@link #writeUtf8(ByteBuf, char[])
* writeUtf8(buf, seq.subSequence(start, end), reserveBytes)}</code> but avoids subsequence object
* allocation if possible.
* Equivalent to {@link #writeUtf8(ByteBuf, char[]) writeUtf8(buf, seq.subSequence(start, end),
* reserveBytes)} but avoids subsequence object allocation if possible.
*
* @return actual number of bytes written
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package io.rsocket.buffer;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import org.assertj.core.api.Assertions;

/**
* Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created
* ByteBuffs
*/
public class LeaksTrackingByteBufAllocator implements ByteBufAllocator {

/**
* Allows to instrument any given the instance of ByteBufAllocator
*
* @param allocator
* @return
*/
public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) {
return new LeaksTrackingByteBufAllocator(allocator);
}

final ConcurrentLinkedQueue<ByteBuf> tracker = new ConcurrentLinkedQueue<>();

final ByteBufAllocator delegate;

private LeaksTrackingByteBufAllocator(ByteBufAllocator delegate) {
this.delegate = delegate;
}

public LeaksTrackingByteBufAllocator assertHasNoLeaks() {
try {
Assertions.assertThat(tracker)
.allSatisfy(
buf -> {
if (buf instanceof CompositeByteBuf) {
if (buf.refCnt() > 0) {
List<ByteBuf> decomposed =
((CompositeByteBuf) buf).decompose(0, buf.readableBytes());
for (int i = 0; i < decomposed.size(); i++) {
Assertions.assertThat(decomposed.get(i))
.matches(bb -> bb.refCnt() == 0, "Got unreleased CompositeByteBuf");
}
}

} else {
Assertions.assertThat(buf)
.matches(bb -> bb.refCnt() == 0, "buffer should be released");
}
});
} finally {
tracker.clear();
}
return this;
}

// Delegating logic with tracking of buffers

@Override
public ByteBuf buffer() {
return track(delegate.buffer());
}

@Override
public ByteBuf buffer(int initialCapacity) {
return track(delegate.buffer(initialCapacity));
}

@Override
public ByteBuf buffer(int initialCapacity, int maxCapacity) {
return track(delegate.buffer(initialCapacity, maxCapacity));
}

@Override
public ByteBuf ioBuffer() {
return track(delegate.ioBuffer());
}

@Override
public ByteBuf ioBuffer(int initialCapacity) {
return track(delegate.ioBuffer(initialCapacity));
}

@Override
public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) {
return track(delegate.ioBuffer(initialCapacity, maxCapacity));
}

@Override
public ByteBuf heapBuffer() {
return track(delegate.heapBuffer());
}

@Override
public ByteBuf heapBuffer(int initialCapacity) {
return track(delegate.heapBuffer(initialCapacity));
}

@Override
public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) {
return track(delegate.heapBuffer(initialCapacity, maxCapacity));
}

@Override
public ByteBuf directBuffer() {
return track(delegate.directBuffer());
}

@Override
public ByteBuf directBuffer(int initialCapacity) {
return track(delegate.directBuffer(initialCapacity));
}

@Override
public ByteBuf directBuffer(int initialCapacity, int maxCapacity) {
return track(delegate.directBuffer(initialCapacity, maxCapacity));
}

@Override
public CompositeByteBuf compositeBuffer() {
return track(delegate.compositeBuffer());
}

@Override
public CompositeByteBuf compositeBuffer(int maxNumComponents) {
return track(delegate.compositeBuffer(maxNumComponents));
}

@Override
public CompositeByteBuf compositeHeapBuffer() {
return track(delegate.compositeHeapBuffer());
}

@Override
public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) {
return track(delegate.compositeHeapBuffer(maxNumComponents));
}

@Override
public CompositeByteBuf compositeDirectBuffer() {
return track(delegate.compositeDirectBuffer());
}

@Override
public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) {
return track(delegate.compositeDirectBuffer(maxNumComponents));
}

@Override
public boolean isDirectBufferPooled() {
return delegate.isDirectBufferPooled();
}

@Override
public int calculateNewCapacity(int minNewCapacity, int maxCapacity) {
return delegate.calculateNewCapacity(minNewCapacity, maxCapacity);
}

<T extends ByteBuf> T track(T buffer) {
tracker.offer(buffer);

return buffer;
}
}
16 changes: 14 additions & 2 deletions rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

package io.rsocket.core;

import io.netty.buffer.ByteBufAllocator;
import io.rsocket.RSocket;
import io.rsocket.buffer.LeaksTrackingByteBufAllocator;
import io.rsocket.test.util.TestDuplexConnection;
import io.rsocket.test.util.TestSubscriber;
import java.util.concurrent.ConcurrentLinkedQueue;
Expand All @@ -32,6 +34,7 @@ public abstract class AbstractSocketRule<T extends RSocket> extends ExternalReso
protected Subscriber<Void> connectSub;
protected T socket;
protected ConcurrentLinkedQueue<Throwable> errors;
protected LeaksTrackingByteBufAllocator allocator;

@Override
public Statement apply(final Statement base, Description description) {
Expand All @@ -41,21 +44,30 @@ public void evaluate() throws Throwable {
connection = new TestDuplexConnection();
connectSub = TestSubscriber.create();
errors = new ConcurrentLinkedQueue<>();
allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT);
init();
base.evaluate();
}
};
}

protected void init() {
socket = newRSocket();
socket = newRSocket(allocator);
}

protected abstract T newRSocket();
protected abstract T newRSocket(LeaksTrackingByteBufAllocator allocator);

public void assertNoConnectionErrors() {
if (errors.size() > 1) {
Assert.fail("No connection errors expected: " + errors.peek().toString());
}
}

public ByteBufAllocator alloc() {
return allocator;
}

public void assertHasNoLeaks() {
allocator.assertHasNoLeaks();
}
}
Loading