diff --git a/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java b/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java index 0fdb16adaac..244e95921b1 100644 --- a/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java +++ b/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java @@ -251,9 +251,17 @@ private void handleRequest(ChannelHandlerContext ctx, FullHttpRequest req) throw final Service service = serviceCfg.service(); final ServiceCodec codec = service.codec(); final Promise promise = ctx.executor().newPromise(); - final DecodeResult decodeResult = codec.decodeRequest( - serviceCfg, ctx.channel(), protocol, - hostname, path, mappedPath, req.content(), req, promise); + + final DecodeResult decodeResult; + try { + decodeResult = codec.decodeRequest( + serviceCfg, ctx.channel(), protocol, + hostname, path, mappedPath, req.content(), req, promise); + } catch (Exception e) { + logger.warn("{} Unexpected exception from a decoder:", ctx.channel(), e); + respond(ctx, reqSeq, req, HttpResponseStatus.BAD_REQUEST, e); + return; + } switch (decodeResult.type()) { case SUCCESS: { diff --git a/src/main/java/com/linecorp/armeria/server/thrift/ThriftServiceCodec.java b/src/main/java/com/linecorp/armeria/server/thrift/ThriftServiceCodec.java index 9d726854114..df6c045ce47 100644 --- a/src/main/java/com/linecorp/armeria/server/thrift/ThriftServiceCodec.java +++ b/src/main/java/com/linecorp/armeria/server/thrift/ThriftServiceCodec.java @@ -19,6 +19,7 @@ import static java.util.Objects.requireNonNull; import java.lang.reflect.Constructor; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -44,6 +45,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.net.MediaType; + import com.linecorp.armeria.common.Scheme; import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.ServiceInvocationContext; @@ -56,6 +59,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; @@ -249,8 +253,7 @@ public DecodeResult decodeRequest( try { serializationFormat = validateRequestAndDetermineSerializationFormat(originalRequest); } catch (InvalidHttpRequestException e) { - return new DefaultDecodeResult( - new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, e.httpResponseStatus), e.getCause()); + return new DefaultDecodeResult(errorResponse(e.httpResponseStatus), e.getCause()); } final TProtocol inProto = FORMAT_TO_THREAD_LOCAL_IN_PROTOCOL.get(serializationFormat).get(); @@ -259,7 +262,13 @@ public DecodeResult decodeRequest( inTransport.reset(in); try { - final TMessage header = inProto.readMessageBegin(); + final TMessage header; + try { + header = inProto.readMessageBegin(); + } catch (TException e) { + return new DefaultDecodeResult(errorResponse(HttpResponseStatus.BAD_REQUEST), e.getCause()); + } + final byte typeValue = header.type; final int seqId = header.seqid; final String methodName = header.name; @@ -324,6 +333,15 @@ public DecodeResult decodeRequest( } } + private static DefaultFullHttpResponse errorResponse(HttpResponseStatus status) { + final DefaultFullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, status, + Unpooled.copiedBuffer(status.toString(), StandardCharsets.UTF_8)); + + res.headers().set(HttpHeaderNames.CONTENT_TYPE, MediaType.PLAIN_TEXT_UTF_8.toString()); + return res; + } + @Override public boolean failureResponseFailsSession(ServiceInvocationContext ctx) { return false;