Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds message counting to protect against malicious overflow #1067

Merged
merged 6 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ final class RequestChannelRequesterFlux extends Flux<Payload>
Context cachedContext;
CoreSubscriber<? super Payload> inboundSubscriber;
boolean inboundDone;
long requested;
long produced;

CompositeByteBuf frames;

Expand Down Expand Up @@ -138,6 +140,8 @@ public final void request(long n) {
return;
}

this.requested = Operators.addCap(this.requested, n);

long previousState = addRequestN(STATE, this, n, this.requesterLeaseTracker == null);
if (isTerminated(previousState)) {
return;
Expand Down Expand Up @@ -706,6 +710,27 @@ public final void handlePayload(Payload value) {
return;
}

final long produced = this.produced;
if (this.requested == produced) {
value.release();
if (!tryCancel()) {
return;
}

final Throwable cause =
Exceptions.failWithOverflow(
"The number of messages received exceeds the number requested");
final RequestInterceptor requestInterceptor = this.requestInterceptor;
if (requestInterceptor != null) {
requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, cause);
}

this.inboundSubscriber.onError(cause);
return;
}

this.produced = produced + 1;

this.inboundSubscriber.onNext(value);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ final class RequestChannelResponderSubscriber extends Flux<Payload>

boolean inboundDone;
boolean outboundDone;
long requested;
long produced;

public RequestChannelResponderSubscriber(
int streamId,
Expand Down Expand Up @@ -179,6 +181,8 @@ public void request(long n) {
return;
}

this.requested = Operators.addCap(this.requested, n);

long previousState = StateUtils.addRequestN(STATE, this, n);
if (isTerminated(previousState)) {
// full termination can be the result of both sides completion / cancelFrame / remote or local
Expand All @@ -196,6 +200,9 @@ public void request(long n) {
Payload firstPayload = this.firstPayload;
if (firstPayload != null) {
this.firstPayload = null;

this.produced++;

inboundSubscriber.onNext(firstPayload);
}

Expand All @@ -216,6 +223,8 @@ public void request(long n) {
final Payload firstPayload = this.firstPayload;
this.firstPayload = null;

this.produced++;

inboundSubscriber.onNext(firstPayload);
inboundSubscriber.onComplete();

Expand All @@ -238,6 +247,9 @@ public void request(long n) {

final Payload firstPayload = this.firstPayload;
this.firstPayload = null;

this.produced++;

inboundSubscriber.onNext(firstPayload);

previousState = markFirstFrameSent(STATE, this);
Expand Down Expand Up @@ -416,6 +428,58 @@ final void handlePayload(Payload p) {
return;
}

final long produced = this.produced;
if (this.requested == produced) {
p.release();

this.inboundDone = true;

final Throwable cause =
Exceptions.failWithOverflow(
"The number of messages received exceeds the number requested");
boolean wasThrowableAdded = Exceptions.addThrowable(INBOUND_ERROR, this, cause);

long previousState = markTerminated(STATE, this);
if (isTerminated(previousState)) {
if (!wasThrowableAdded) {
Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext());
}
return;
}

this.requesterResponderSupport.remove(this.streamId, this);

this.connection.sendFrame(
streamId,
ErrorFrameCodec.encode(
this.allocator, streamId, new CanceledException(cause.getMessage())));

if (!isSubscribed(previousState)) {
final Payload firstPayload = this.firstPayload;
this.firstPayload = null;
firstPayload.release();
} else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) {
Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this);
if (inboundError != TERMINATED) {
//noinspection ConstantConditions
this.inboundSubscriber.onError(inboundError);
}
}

// this is downstream subscription so need to cancel it just in case error signal has not
// reached it
// needs for disconnected upstream and downstream case
this.outboundSubscription.cancel();

final RequestInterceptor interceptor = requestInterceptor;
if (interceptor != null) {
interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, cause);
}
return;
}

this.produced = produced + 1;

this.inboundSubscriber.onNext(p);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ final class RequestStreamRequesterFlux extends Flux<Payload>
CoreSubscriber<? super Payload> inboundSubscriber;
CompositeByteBuf frames;
boolean done;
long requested;
long produced;

RequestStreamRequesterFlux(Payload payload, RequesterResponderSupport requesterResponderSupport) {
this.allocator = requesterResponderSupport.getAllocator();
Expand Down Expand Up @@ -134,6 +136,8 @@ public final void request(long n) {
return;
}

this.requested = Operators.addCap(this.requested, n);

final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker;
final boolean leaseEnabled = requesterLeaseTracker != null;
final long previousState = addRequestN(STATE, this, n, !leaseEnabled);
Expand Down Expand Up @@ -295,6 +299,34 @@ public final void handlePayload(Payload p) {
return;
}

final long produced = this.produced;
if (this.requested == produced) {
p.release();

long previousState = markTerminated(STATE, this);
if (isTerminated(previousState)) {
return;
}

final int streamId = this.streamId;
this.requesterResponderSupport.remove(streamId, this);

final IllegalStateException cause =
Exceptions.failWithOverflow(
"The number of messages received exceeds the number requested");
this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId));

final RequestInterceptor requestInterceptor = this.requestInterceptor;
if (requestInterceptor != null) {
requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, cause);
}

this.inboundSubscriber.onError(cause);
return;
}

this.produced = produced + 1;

this.inboundSubscriber.onNext(p);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.rsocket.core;

import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK;
import static io.rsocket.frame.FrameType.CANCEL;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
Expand All @@ -40,6 +41,7 @@
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
Expand Down Expand Up @@ -513,6 +515,77 @@ public void errorShouldTerminateExecution(String terminationMode) {
stateAssert.isTerminated();
}

@Test
public void failOnOverflow() {
final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client();
final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator();
final TestDuplexConnection sender = activeStreams.getDuplexConnection();
final TestPublisher<Payload> publisher = TestPublisher.create();

final RequestChannelRequesterFlux requestChannelRequesterFlux =
new RequestChannelRequesterFlux(publisher, activeStreams);
final StateAssert<RequestChannelRequesterFlux> stateAssert =
StateAssert.assertThat(requestChannelRequesterFlux);

// state machine check

stateAssert.isUnsubscribed();
activeStreams.assertNoActiveStreams();

final AssertSubscriber<Payload> assertSubscriber =
requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0));
activeStreams.assertNoActiveStreams();

// state machine check
stateAssert.hasSubscribedFlagOnly();

assertSubscriber.request(1);
stateAssert.hasSubscribedFlag().hasRequestN(1).hasNoFirstFrameSentFlag();
activeStreams.assertNoActiveStreams();

Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator);

publisher.next(payload1.retain());

FrameAssert.assertThat(sender.awaitFrame())
.typeOf(FrameType.REQUEST_CHANNEL)
.hasPayload(payload1)
.hasRequestN(1)
.hasNoLeaks();
payload1.release();

stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag();
activeStreams.assertHasStream(1, requestChannelRequesterFlux);

publisher.assertMaxRequested(1);

Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator);
requestChannelRequesterFlux.handlePayload(nextPayload);

Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator);
requestChannelRequesterFlux.handlePayload(unrequestedPayload);

final ByteBuf cancelFrame = sender.awaitFrame();
FrameAssert.assertThat(cancelFrame)
.isNotNull()
.typeOf(CANCEL)
.hasClientSideStreamId()
.hasStreamId(1)
.hasNoLeaks();

assertSubscriber
.assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks())
.assertError()
.assertErrorMessage("The number of messages received exceeds the number requested");

publisher.assertWasCancelled();

activeStreams.assertNoActiveStreams();
// state machine check
stateAssert.isTerminated();
Assertions.assertThat(sender.isEmpty()).isTrue();
}

/*
* +--------------------------------+
* | Racing Test Cases |
Expand Down
Loading