Skip to content

Commit

Permalink
Fix a bug where INTERNAL_ERROR is returned when gRPC message length…
Browse files Browse the repository at this point in the history
… is exceeded (line#5824)

Motivation:

If a request or response message exceeds the maximum allowed message
length, `RESOURCE_EXHAUSTED` status should be returned to the client.

In `UnaryServerCall`, an exception raised `ArmeriaMessageFramer` is not
handled by `GrpcExceptionHandlerFunction`. As a result, the exception is
directly passed to the HTTP level `ServerErrorHandler` and a proper gRPC
status is not delivered to the client.

While I was fixing the bug, I found a lot of code that calls
`generateMetadataFromThrowable()` and then `fromThrowable()`. As the
same pattern was used repeatedly, I also refactored
`GrpcExceptionHandlerFunctionUtil` to return `StatusAndMetata` instead.

Modifications:

- Apply `GrpcExceptionHandlerFunction` when an exception is raised in
`UnaryServerCall.doClose()`
- Refactor `GrpcExceptionHandlerFunctionUtil` and its calling site.
- Add `InternalGrpcExceptionHandler` to wrap
`GrpcExceptionHandlerFunction`.
- Remove the incorrect `@Nullable` annotation for
`GrpcExceptionHandlerFunction`.
- Remove `ServerStatusAndMetadata.setResponseContent()` and use
`RESPONSE_CONTENT` property for simpliciy.

Result:

- `GrpcService` now correctly returns `RESOURCE_EXHAUSTED` when a
response message exceeds the maximum allowed length.
- Closes line#5818
  • Loading branch information
ikhoon authored Jul 26, 2024
1 parent cbe744e commit f58ffb3
Show file tree
Hide file tree
Showing 18 changed files with 415 additions and 231 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,6 @@ public Status apply(RequestContext ctx, Status status, Throwable cause, Metadata
if (status.getCode() != Code.UNKNOWN) {
return status;
}
final Status s = Status.fromThrowable(cause);
if (s.getCode() != Code.UNKNOWN) {
return s;
}

if (cause instanceof ClosedSessionException || cause instanceof ClosedChannelException) {
if (ctx instanceof ServiceRequestContext) {
// Upstream uses CANCELLED
Expand All @@ -63,7 +58,7 @@ public Status apply(RequestContext ctx, Status status, Throwable cause, Metadata
// ClosedChannelException is used any time the Netty channel is closed. Proper error
// processing requires remembering the error that occurred before this one and using it
// instead.
return s;
return status;
}
if (cause instanceof ClosedStreamException || cause instanceof RequestTimeoutException) {
return Status.CANCELLED.withCause(cause);
Expand All @@ -89,6 +84,6 @@ public Status apply(RequestContext ctx, Status status, Throwable cause, Metadata
if (cause instanceof ContentTooLargeException) {
return Status.RESOURCE_EXHAUSTED.withCause(cause);
}
return s;
return status;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ static GrpcExceptionHandlerFunction of() {
*/
default GrpcExceptionHandlerFunction orElse(GrpcExceptionHandlerFunction next) {
requireNonNull(next, "next");
if (this == next) {
return this;
}
return (ctx, status, cause, metadata) -> {
final Status newStatus = apply(ctx, status, cause, metadata);
if (newStatus != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.grpc.GrpcCallOptions;
import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction;
import com.linecorp.armeria.common.grpc.GrpcJsonMarshaller;
import com.linecorp.armeria.common.logging.RequestLogProperty;
import com.linecorp.armeria.common.util.SystemInfo;
import com.linecorp.armeria.common.util.Unwrappable;
import com.linecorp.armeria.internal.client.DefaultClientRequestContext;
import com.linecorp.armeria.internal.common.RequestTargetCache;
import com.linecorp.armeria.internal.common.grpc.InternalGrpcExceptionHandler;

import io.grpc.CallCredentials;
import io.grpc.CallOptions;
Expand Down Expand Up @@ -98,7 +98,7 @@ final class ArmeriaChannel extends Channel implements ClientBuilderParams, Unwra
private final Compressor compressor;
private final DecompressorRegistry decompressorRegistry;
private final CallCredentials credentials0;
private final GrpcExceptionHandlerFunction exceptionHandler;
private final InternalGrpcExceptionHandler exceptionHandler;
private final boolean useMethodMarshaller;

ArmeriaChannel(ClientBuilderParams params,
Expand All @@ -124,7 +124,7 @@ final class ArmeriaChannel extends Channel implements ClientBuilderParams, Unwra
compressor = options.get(GrpcClientOptions.COMPRESSOR);
decompressorRegistry = options.get(GrpcClientOptions.DECOMPRESSOR_REGISTRY);
credentials0 = options.get(GrpcClientOptions.CALL_CREDENTIALS);
exceptionHandler = options.get(GrpcClientOptions.EXCEPTION_HANDLER);
exceptionHandler = new InternalGrpcExceptionHandler(options.get(GrpcClientOptions.EXCEPTION_HANDLER));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

import static com.linecorp.armeria.internal.client.ClientUtil.initContextAndExecuteWithFallback;
import static com.linecorp.armeria.internal.client.grpc.protocol.InternalGrpcWebUtil.messageBuf;
import static com.linecorp.armeria.internal.common.grpc.GrpcExceptionHandlerFunctionUtil.fromThrowable;
import static com.linecorp.armeria.internal.common.grpc.GrpcExceptionHandlerFunctionUtil.generateMetadataFromThrowable;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

Expand Down Expand Up @@ -50,7 +48,6 @@
import com.linecorp.armeria.common.RequestHeadersBuilder;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction;
import com.linecorp.armeria.common.grpc.GrpcJsonMarshaller;
import com.linecorp.armeria.common.grpc.GrpcSerializationFormats;
import com.linecorp.armeria.common.grpc.protocol.ArmeriaMessageFramer;
Expand All @@ -71,6 +68,7 @@
import com.linecorp.armeria.internal.common.grpc.GrpcMessageMarshaller;
import com.linecorp.armeria.internal.common.grpc.GrpcStatus;
import com.linecorp.armeria.internal.common.grpc.HttpStreamDeframer;
import com.linecorp.armeria.internal.common.grpc.InternalGrpcExceptionHandler;
import com.linecorp.armeria.internal.common.grpc.MetadataUtil;
import com.linecorp.armeria.internal.common.grpc.StatusAndMetadata;
import com.linecorp.armeria.internal.common.grpc.TimeoutHeaderUtil;
Expand Down Expand Up @@ -127,7 +125,7 @@ final class ArmeriaClientCall<I, O> extends ClientCall<I, O>
private final int maxInboundMessageSizeBytes;
private final boolean grpcWebText;
private final Compressor compressor;
private final GrpcExceptionHandlerFunction exceptionHandler;
private final InternalGrpcExceptionHandler exceptionHandler;

private boolean endpointInitialized;
@Nullable
Expand Down Expand Up @@ -162,7 +160,7 @@ final class ArmeriaClientCall<I, O> extends ClientCall<I, O>
SerializationFormat serializationFormat,
@Nullable GrpcJsonMarshaller jsonMarshaller,
boolean unsafeWrapResponseBuffers,
GrpcExceptionHandlerFunction exceptionHandler,
InternalGrpcExceptionHandler exceptionHandler,
boolean useMethodMarshaller) {
this.ctx = ctx;
this.endpointGroup = endpointGroup;
Expand Down Expand Up @@ -251,8 +249,8 @@ public void start(Listener<O> responseListener, Metadata metadata) {

final BiFunction<ClientRequestContext, Throwable, HttpResponse> errorResponseFactory =
(unused, cause) -> {
final Metadata responseMetadata = generateMetadataFromThrowable(cause);
Status status = fromThrowable(ctx, exceptionHandler, cause, responseMetadata);
final StatusAndMetadata statusAndMetadata = exceptionHandler.handle(ctx, cause);
Status status = statusAndMetadata.status();
if (status.getDescription() == null) {
status = status.withDescription(cause.getMessage());
}
Expand Down Expand Up @@ -460,8 +458,8 @@ public void onNext(DeframedMessage message) {
}
});
} catch (Throwable t) {
final Metadata metadata = generateMetadataFromThrowable(t);
close(fromThrowable(ctx, exceptionHandler, t, metadata), metadata);
final StatusAndMetadata statusAndMetadata = exceptionHandler.handle(ctx, t);
close(statusAndMetadata.status(), statusAndMetadata.metadata());
}
}

Expand Down Expand Up @@ -517,8 +515,8 @@ private void prepareHeaders(Compressor compressor, Metadata metadata, long remai
}

private void closeWhenListenerThrows(Throwable t) {
final Metadata metadata = generateMetadataFromThrowable(t);
closeWhenEos(fromThrowable(ctx, exceptionHandler, t, metadata), metadata);
final StatusAndMetadata statusAndMetadata = exceptionHandler.handle(ctx, t);
closeWhenEos(statusAndMetadata.status(), statusAndMetadata.metadata());
}

private void closeWhenEos(Status status, Metadata metadata) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
package com.linecorp.armeria.internal.common.grpc;

import static com.google.common.base.Preconditions.checkState;
import static com.linecorp.armeria.internal.common.grpc.GrpcExceptionHandlerFunctionUtil.fromThrowable;
import static com.linecorp.armeria.internal.common.grpc.GrpcExceptionHandlerFunctionUtil.generateMetadataFromThrowable;
import static java.util.Objects.requireNonNull;

import com.linecorp.armeria.common.HttpHeaderNames;
Expand All @@ -28,7 +26,6 @@
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction;
import com.linecorp.armeria.common.grpc.protocol.ArmeriaMessageDeframer;
import com.linecorp.armeria.common.grpc.protocol.Decompressor;
import com.linecorp.armeria.common.grpc.protocol.DeframedMessage;
Expand All @@ -46,7 +43,7 @@ public final class HttpStreamDeframer extends ArmeriaMessageDeframer {
private final RequestContext ctx;
private final DecompressorRegistry decompressorRegistry;
private final TransportStatusListener transportStatusListener;
private final GrpcExceptionHandlerFunction exceptionHandler;
private final InternalGrpcExceptionHandler exceptionHandler;

@Nullable
private StreamMessage<DeframedMessage> deframedStreamMessage;
Expand All @@ -57,7 +54,7 @@ public HttpStreamDeframer(
DecompressorRegistry decompressorRegistry,
RequestContext ctx,
TransportStatusListener transportStatusListener,
GrpcExceptionHandlerFunction exceptionHandler,
InternalGrpcExceptionHandler exceptionHandler,
int maxMessageLength, boolean grpcWebText, boolean server) {
super(maxMessageLength, ctx.alloc(), grpcWebText);
this.ctx = requireNonNull(ctx, "ctx");
Expand Down Expand Up @@ -121,9 +118,9 @@ public void processHeaders(HttpHeaders headers, StreamDecoderOutput<DeframedMess
try {
decompressor(ForwardingDecompressor.forGrpc(decompressor));
} catch (Throwable t) {
final Metadata metadata = generateMetadataFromThrowable(t);
transportStatusListener.transportReportStatus(
fromThrowable(ctx, exceptionHandler, t, metadata), metadata);
final StatusAndMetadata statusAndMetadata = exceptionHandler.handle(ctx, t);
transportStatusListener.transportReportStatus(statusAndMetadata.status(),
statusAndMetadata.metadata());
return;
}
}
Expand All @@ -149,9 +146,8 @@ public void processTrailers(HttpHeaders headers, StreamDecoderOutput<DeframedMes

@Override
public void processOnError(Throwable cause) {
final Metadata metadata = generateMetadataFromThrowable(cause);
transportStatusListener.transportReportStatus(
fromThrowable(ctx, exceptionHandler, cause, metadata), metadata);
final StatusAndMetadata statusAndMetadata = exceptionHandler.handle(ctx, cause);
transportStatusListener.transportReportStatus(statusAndMetadata.status(), statusAndMetadata.metadata());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright 2024 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.internal.common.grpc;

import static java.util.Objects.requireNonNull;

import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction;
import com.linecorp.armeria.common.grpc.protocol.ArmeriaStatusException;
import com.linecorp.armeria.common.util.Exceptions;

import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StatusException;
import io.grpc.StatusRuntimeException;

public final class InternalGrpcExceptionHandler {

private final GrpcExceptionHandlerFunction delegate;

public InternalGrpcExceptionHandler(GrpcExceptionHandlerFunction delegate) {
this.delegate = delegate;
}

public StatusAndMetadata handle(RequestContext ctx, Throwable t) {
final Throwable peeled = peelAndUnwrap(t);
Metadata metadata = Status.trailersFromThrowable(peeled);
if (metadata == null) {
metadata = new Metadata();
}
Status status = Status.fromThrowable(peeled);
status = handle0(ctx, status, peeled, metadata);
return new StatusAndMetadata(status, metadata);
}

public Status handle(RequestContext ctx, Status status, Throwable cause, Metadata metadata) {
final Throwable peeled = peelAndUnwrap(cause);
return handle0(ctx, status, peeled, metadata);
}

private Status handle0(RequestContext ctx, Status status, Throwable cause, Metadata metadata) {
if (status.getCode() == Code.UNKNOWN) {
// If ArmeriaStatusException is thrown, it is converted to UNKNOWN and passed through close(Status).
// So try to restore the original status.
Status newStatus = null;
if (cause instanceof StatusRuntimeException) {
newStatus = ((StatusRuntimeException) cause).getStatus();
} else if (cause instanceof StatusException) {
newStatus = ((StatusException) cause).getStatus();
}
if (newStatus != null && newStatus.getCode() != Code.UNKNOWN) {
status = newStatus;
}
}
status = delegate.apply(ctx, status, cause, metadata);
assert status != null;
return status;
}

private static Throwable peelAndUnwrap(Throwable t) {
requireNonNull(t, "t");
t = Exceptions.peel(t);
Throwable cause = t;
while (cause != null) {
if (cause instanceof ArmeriaStatusException) {
return StatusExceptionConverter.toGrpc((ArmeriaStatusException) cause);
}
cause = cause.getCause();
}
return t;
}
}
Loading

0 comments on commit f58ffb3

Please sign in to comment.