Skip to content

Commit

Permalink
Fix a leak in HttpEncodedResponse (#5858)
Browse files Browse the repository at this point in the history
Motivation:

An `HttpData` produced in `HttpEncodedResponse.beforeComplete()` is not
collected by `CollectingSubscriberAndSubscription` but is leaked.

```java
  Hint: {10B, pooled, <unknown>}
  com.linecorp.armeria.common.HttpData.wrap(HttpData.java:110)
  com.linecorp.armeria.server.encoding.HttpEncodedResponse.beforeComplete(HttpEncodedResponse.java:163)
  com.linecorp.armeria.common.stream.FilteredStreamMessage.lambda$collect$0(FilteredStreamMessage.java:201)
  java.base/java.util.concurrent.CompletableFuture.uniHandle(CompletableFuture.java:934)
  java.base/java.util.concurrent.CompletableFuture.uniHandleStage(CompletableFuture.java:950)
  java.base/java.util.concurrent.CompletableFuture.handle(CompletableFuture.java:2340)
  com.linecorp.armeria.common.stream.FilteredStreamMessage.collect(FilteredStreamMessage.java:142)
```

`CollectingSubscriberAndSubscription` was designed to only apply
`filter()` to the `upstream.collect()`. I didn't consider that an object
could be published via `onNext()` in `beforeComplete()`. The purpose of
`CollectingSubscriberAndSubscription` was to provide an optimized code
path for unary calls. it didn't seem the code provides a trivial
performance improvement but the implementation was complex and
error-prone.

I was able to fix the code not to leak the data but I didn't want to
additional complexity to it. It might be cleaner to use the Reactive
Streams API instead of keeping the custom `collect()` implementation.
There will be no change in performance for normal message sizes.

Modifications:

- Remove the custom `collect()` implemtation in `FilteredStreamMessage`.

Result:

Fix a potential leak when sending compressed responses.
  • Loading branch information
ikhoon authored Aug 8, 2024
1 parent c7aca10 commit c1d5475
Show file tree
Hide file tree
Showing 12 changed files with 84 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,12 @@ public void onNext(T item) {
} catch (Throwable ex) {
StreamMessageUtil.closeOrAbort(item, ex);

// onError(ex) should be called before upstream.cancel() that may close downstream with
// CancelledSubscriptionException.
onError(ex);
final Subscription upstream = this.upstream;
assert upstream != null;
upstream.cancel();

onError(ex);
}
}

Expand All @@ -179,8 +180,8 @@ private void publishDownstream(@Nullable U item, @Nullable Throwable cause) {

try {
if (cause != null) {
upstream.cancel();
onError(cause);
upstream.cancel();
} else {
requireNonNull(item, "function.apply()'s future completed with null");
downstream.onNext(item);
Expand All @@ -205,8 +206,8 @@ private void publishDownstream(@Nullable U item, @Nullable Throwable cause) {
if (item != null) {
StreamMessageUtil.closeOrAbort(item, ex);
}
upstream.cancel();
onError(ex);
upstream.cancel();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,21 @@

package com.linecorp.armeria.common.stream;

import static com.linecorp.armeria.internal.common.stream.InternalStreamMessageUtil.EMPTY_OPTIONS;
import static com.linecorp.armeria.internal.common.stream.InternalStreamMessageUtil.POOLED_OBJECTS;
import static com.linecorp.armeria.internal.common.stream.InternalStreamMessageUtil.containsNotifyCancellation;
import static com.linecorp.armeria.internal.common.stream.InternalStreamMessageUtil.containsWithPooledObjects;
import static com.linecorp.armeria.internal.common.stream.InternalStreamMessageUtil.toSubscriptionOptions;
import static java.util.Objects.requireNonNull;

import java.util.List;
import java.util.concurrent.CompletableFuture;

import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.ImmutableList;

import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.annotation.UnstableApi;
import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.internal.common.stream.StreamMessageUtil;
import com.linecorp.armeria.unsafe.PooledObjects;

Expand Down Expand Up @@ -136,79 +130,6 @@ public final CompletableFuture<Void> whenComplete() {
return completionFuture;
}

@Override
public CompletableFuture<List<U>> collect(EventExecutor executor, SubscriptionOption... options) {
final SubscriptionOption[] filterOptions = filterSupportsPooledObjects ? POOLED_OBJECTS : EMPTY_OPTIONS;
return upstream.collect(executor, filterOptions).handle((result, cause) -> {
// CollectingSubscriberAndSubscription just captures cancel(), onComplete(), and onError() signals
// from the subclass of FilteredStreamMessage. So we need to follow regular Reactive Streams
// specifications.
final CollectingSubscriberAndSubscription<U> subscriberAndSubscription =
new CollectingSubscriberAndSubscription<>();
beforeSubscribe(subscriberAndSubscription, subscriberAndSubscription);
if (cause != null) {
beforeError(subscriberAndSubscription, cause);
completionFuture.completeExceptionally(cause);
return Exceptions.throwUnsafely(cause);
} else {
Throwable abortCause = null;
final ImmutableList.Builder<U> builder = ImmutableList.builderWithExpectedSize(result.size());
final boolean withPooledObjects = containsWithPooledObjects(options);
for (T t : result) {
if (abortCause != null) {
// This StreamMessage was aborted already. However, we need to release the remaining
// objects in result.
StreamMessageUtil.closeOrAbort(t, abortCause);
continue;
}

try {
U filtered = filter(t);

if (subscriberAndSubscription.completed || subscriberAndSubscription.cause != null ||
subscriberAndSubscription.cancelled) {
if (subscriberAndSubscription.cause != null) {
abortCause = cause;
} else {
abortCause = CancelledSubscriptionException.get();
}
StreamMessageUtil.closeOrAbort(filtered, abortCause);
} else {
requireNonNull(filtered, "filter() returned null");
if (!withPooledObjects) {
filtered = PooledObjects.copyAndClose(filtered);
}
builder.add(filtered);
}
} catch (Throwable ex) {
// Failed to filter the object.
StreamMessageUtil.closeOrAbort(t, abortCause);
abortCause = ex;
}
}

final List<U> elements = builder.build();
if (abortCause != null && !(abortCause instanceof CancelledSubscriptionException)) {
// The stream was aborted with an unsafe exception.
for (U element : elements) {
StreamMessageUtil.closeOrAbort(element, abortCause);
}
completionFuture.completeExceptionally(abortCause);
return Exceptions.throwUnsafely(abortCause);
}

try {
beforeComplete(subscriberAndSubscription);
completionFuture.complete(null);
} catch (Exception ex) {
completionFuture.completeExceptionally(ex);
throw ex;
}
return elements;
}
});
}

@Override
public final void subscribe(Subscriber<? super U> subscriber, EventExecutor executor) {
subscribe(subscriber, executor, false, false);
Expand Down Expand Up @@ -298,17 +219,21 @@ public void onNext(T o) {
try {
filtered = filter(o);
} catch (Throwable ex) {
StreamMessageUtil.closeOrAbort(o);
// onError(ex) should be called before upstream.cancel() to deliver the cause to downstream.
// upstream.cancel() and make downstream closed with CancelledSubscriptionException
// before sending the actual cause.
// upstream.cancel() may close downstream with CancelledSubscriptionException before sending
// the actual cause.
onError(ex);

assert upstream != null;
upstream.cancel();
return;
}

if (completed) {
// onError(Throwable) or onComplete() has been called in filter().
StreamMessageUtil.closeOrAbort(filtered);
return;
}
if (!subscribedWithPooledObjects) {
filtered = PooledObjects.copyAndClose(filtered);
}
Expand Down Expand Up @@ -351,42 +276,4 @@ public void onComplete() {
}
}
}

private static final class CollectingSubscriberAndSubscription<T> implements Subscriber<T>, Subscription {

private boolean completed;
private boolean cancelled;
@Nullable
private Throwable cause;

@Override
public void onSubscribe(Subscription s) {}

@Override
public void onNext(T o) {}

@Override
public void onError(Throwable t) {
if (completed) {
return;
}
cause = t;
}

@Override
public void onComplete() {
if (cause != null) {
return;
}
completed = true;
}

@Override
public void request(long n) {}

@Override
public void cancel() {
cancelled = true;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ public void onNext(Object item) {
if (result != null && item != result) {
StreamMessageUtil.closeOrAbort(result, ex);
}
upstream.cancel();
onError(ex);
upstream.cancel();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.linecorp.armeria.common.stream.StreamMessageUtil.createStreamMessageFrom;
import static com.linecorp.armeria.internal.common.stream.InternalStreamMessageUtil.EMPTY_OPTIONS;
import static com.linecorp.armeria.internal.common.stream.InternalStreamMessageUtil.containsNotifyCancellation;
import static java.util.Objects.requireNonNull;

import java.io.File;
Expand All @@ -43,6 +44,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.ObjectArrays;

import com.linecorp.armeria.common.CommonPools;
import com.linecorp.armeria.common.HttpData;
Expand Down Expand Up @@ -752,6 +754,11 @@ default CompletableFuture<List<T>> collect(EventExecutor executor, SubscriptionO
requireNonNull(executor, "executor");
requireNonNull(options, "options");
final StreamMessageCollector<T> collector = new StreamMessageCollector<>(options);
if (!containsNotifyCancellation(options)) {
// Make the return CompletableFuture completed exceptionally if the stream is cancelled while
// collecting the elements.
options = ObjectArrays.concat(options, SubscriptionOption.NOTIFY_CANCELLATION);
}
subscribe(collector, executor, options);
return collector.collect();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ public void onNext(T item) {
byteBufsInputStream.add(result.byteBuf());
} catch (Throwable ex) {
StreamMessageUtil.closeOrAbort(item, ex);
onError(ex);
final Subscription upstream = this.upstream;
assert upstream != null;
upstream.cancel();
onError(ex);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,11 +445,11 @@ private void close0(@Nullable Throwable cause) {
downstream.onComplete();
completionFuture.complete(null);
} else {
downstream.onError(cause);
final Subscription upstream = this.upstream;
if (upstream != null) {
upstream.cancel();
}
downstream.onError(cause);
completionFuture.completeExceptionally(cause);
}
release(cause);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ protected HttpObject filter(HttpObject obj) {
encodedBuf.readerIndex(encodedBuf.writerIndex());
return httpData;
} catch (IOException e) {
// An unreleased ByteBuf will be released by `beforeError()`
// An unreleased ByteBuf in `encodedStream` will be released by `beforeError()`
throw new IllegalStateException(
"Error encoding HttpData, this should not happen with byte arrays.",
e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,13 @@ void emptyStreamMessage() {
.hasCause(cause);
}

@CsvSource({ "1, true", "1, false",
"2, true", "2, false",
"3, true", "3, false",
"4, true", "4, false",
"100, true", "100, false" })
@CsvSource({
"1, true", "1, false",
"2, true", "2, false",
"3, true", "3, false",
"4, true", "4, false",
"100, true", "100, false"
})
@ParameterizedTest
void closeOrAbortAndCollect(int size, boolean fixedStream) {
Map<HttpData, ByteBuf> data = newHttpData(size);
Expand Down Expand Up @@ -151,6 +153,8 @@ protected HttpData filter(HttpData obj) {
if (count < 2) {
return obj;
} else {
// The ownership of `obj` belongs to this method.
obj.close();
return Exceptions.throwUnsafely(cause);
}
}
Expand Down Expand Up @@ -193,19 +197,16 @@ protected HttpData filter(HttpData obj) {
}
};

final List<HttpData> collected = filtered.collect(SubscriptionOption.WITH_POOLED_OBJECTS).join();
assertThat(collected).hasSize(2);
assertThatThrownBy(() -> {
filtered.collect(SubscriptionOption.WITH_POOLED_OBJECTS).join();
}).isInstanceOf(CompletionException.class)
.hasCauseInstanceOf(CancelledSubscriptionException.class);

final List<ByteBuf> bufs = ImmutableList.copyOf(data.values());

assertThat(bufs.get(0).refCnt()).isOne();
assertThat(bufs.get(1).refCnt()).isOne();
assertThat(bufs.get(2).refCnt()).isZero();
assertThat(bufs.get(3).refCnt()).isZero();
assertThat(bufs.get(4).refCnt()).isZero();

bufs.get(0).release();
bufs.get(1).release();
for (ByteBuf buf : bufs) {
assertThat(buf.refCnt()).isZero();
}
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ void abortedRequestShouldAlsoBeCompleted() {

ctx.logBuilder().endRequest();
ctx.logBuilder().endResponse();
ctx.logBuilder().ensureComplete();

final RequestLog log = ctx.log().whenComplete().join();
assertThat(log.requestContentPreview()).isEmpty();
Expand Down
Loading

0 comments on commit c1d5475

Please sign in to comment.