diff --git a/.github/workflows/actions_build.yml b/.github/workflows/actions_build.yml index ed00fe1787b..f072a34d976 100644 --- a/.github/workflows/actions_build.yml +++ b/.github/workflows/actions_build.yml @@ -140,7 +140,9 @@ jobs: -PbuildJdkVersion=${{ env.BUILD_JDK_VERSION }} \ -PtestJavaVersion=${{ matrix.java }} \ ${{ matrix.min-java && format('-PminimumJavaVersion={0}', matrix.min-java) || '' }} \ - -Porg.gradle.java.installations.paths=${{ steps.setup-build-jdk.outputs.path }},${{ steps.setup-jdk.outputs.path }} + -Porg.gradle.java.installations.paths=${{ steps.setup-build-jdk.outputs.path }},${{ steps.setup-jdk.outputs.path }} \ + -PpreferShadedTests=${{ github.ref_name != 'main' }} + # Unshaded tests are skipped for PRs to avoid running the same tests twice. shell: bash env: COMMIT_SHA: ${{ github.event.pull_request.head.sha }} diff --git a/.github/workflows/gradle-enterprise-postjob.yml b/.github/workflows/gradle-enterprise-postjob.yml index c54c414cef5..c8ff55cad15 100644 --- a/.github/workflows/gradle-enterprise-postjob.yml +++ b/.github/workflows/gradle-enterprise-postjob.yml @@ -33,7 +33,7 @@ jobs: - name: Download artifact id: download-artifact - uses: dawidd6/action-download-artifact@v3 + uses: dawidd6/action-download-artifact@v6 with: workflow_conclusion: "" run_id: ${{ env.RUN_ID }} diff --git a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java index 9204452cc7c..9fc9badb900 100644 --- a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java +++ b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java @@ -60,7 +60,8 @@ public class RoutersBenchmark { SERVICES = ImmutableList.of(newServiceConfig(route1), newServiceConfig(route2)); FALLBACK_SERVICE = newServiceConfig(Route.ofCatchAll()); HOST = new VirtualHost( - "localhost", "localhost", 0, null, SERVICES, FALLBACK_SERVICE, RejectedRouteHandler.DISABLED, + "localhost", "localhost", 0, null, + null, SERVICES, FALLBACK_SERVICE, RejectedRouteHandler.DISABLED, unused -> NOPLogger.NOP_LOGGER, FALLBACK_SERVICE.defaultServiceNaming(), FALLBACK_SERVICE.defaultLogName(), 0, 0, false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), 0, SuccessFunction.ofDefault(), diff --git a/core/build.gradle b/core/build.gradle index 18a23cbf366..c01d1a9c3d2 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -31,19 +31,21 @@ mrJarVersions.each { version-> options.release.set(targetJavaVersion) } - task "testJava${version}"(type: Test, group: 'Verification', description: "Runs unit tests for Java ${version} source set") { + tasks.register("testJava${version}", Test) { + group = 'Verification' + description = "Runs unit tests for Java ${version} source set" testClassesDirs = sourceSets."java${version}Test".output.classesDirs classpath = sourceSets."java${version}Test".runtimeClasspath project.ext.configureCommonTestSettings(it) enabled = project.ext.testJavaVersion >= targetJavaVersion + + check.dependsOn it } configurations."java${version}Implementation".extendsFrom configurations.implementation configurations."java${version}TestImplementation".extendsFrom configurations.testImplementation configurations."java${version}TestRuntimeClasspath".extendsFrom configurations.testRuntimeClasspath - - check.dependsOn "testJava${version}" } tasks.withType(Jar) { @@ -63,7 +65,7 @@ tasks.trimShadedJar.doLast { def trimmed = tasks.trimShadedJar.outJarFiles[0].toPath() ant.jar(destfile: trimmed.toString(), update: true, duplicate: 'fail') { - zipfileset(src: tasks.shadedJar.archivePath) { + zipfileset(src: tasks.shadedJar.archiveFile.get().asFile) { include(name: 'META-INF/versions/**') } @@ -225,6 +227,9 @@ if (tasks.findByName('trimShadedJar')) { keep "class com.linecorp.armeria.internal.shaded.bouncycastle.jcajce.provider.asymmetric.ec.** { *; }" keep "class com.linecorp.armeria.internal.shaded.bouncycastle.jcajce.provider.asymmetric.rsa.** { *; }" keep "class com.linecorp.armeria.internal.shaded.bouncycastle.jcajce.provider.asymmetric.x509.** { *; }" + // Keep the Guava classes accessed during testing. + keep "class com.linecorp.armeria.internal.shaded.guava.net.HttpHeaders { *; }" + keep "class com.linecorp.armeria.internal.shaded.guava.net.MediaType { *; }" dontnote } } diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java index b613a2f9216..a25a5b9fd6b 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java @@ -33,8 +33,10 @@ import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpObject; +import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.ResponseCompleteException; +import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.logging.RequestLogBuilder; @@ -50,6 +52,7 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http2.Http2Error; import io.netty.handler.proxy.ProxyConnectException; @@ -61,6 +64,7 @@ enum State { NEEDS_TO_WRITE_FIRST_HEADER, NEEDS_DATA, NEEDS_DATA_OR_TRAILERS, + NEEDS_100_CONTINUE, DONE } @@ -143,6 +147,11 @@ public final void operationComplete(ChannelFuture future) throws Exception { responseWrapper.initTimeout(); } + if (state == State.NEEDS_100_CONTINUE) { + assert responseWrapper != null; + responseWrapper.initTimeout(); + } + onWriteSuccess(); return; } @@ -169,14 +178,14 @@ final boolean tryInitialize() { "Can't send requests. ID: " + id + ", session active: " + session.isAcquirable(responseDecoder.keepAliveHandler())); } - session.deactivate(); + session.markUnacquirable(); // No need to send RST because we didn't send any packet and this will be disconnected anyway. fail(UnprocessedRequestException.of(exception)); return false; } this.session = session; - responseWrapper = responseDecoder.addResponse(id, originalRes, ctx, ch.eventLoop()); + responseWrapper = responseDecoder.addResponse(this, id, originalRes, ctx, ch.eventLoop()); if (timeoutMillis > 0) { // The timer would be executed if the first message has not been sent out within the timeout. @@ -187,6 +196,18 @@ final boolean tryInitialize() { return true; } + RequestHeaders mergedRequestHeaders(RequestHeaders headers) { + final HttpHeaders internalHeaders; + final ClientRequestContextExtension ctxExtension = ctx.as(ClientRequestContextExtension.class); + if (ctxExtension == null) { + internalHeaders = HttpHeaders.of(); + } else { + internalHeaders = ctxExtension.internalRequestHeaders(); + } + return mergeRequestHeaders( + headers, ctx.defaultRequestHeaders(), ctx.additionalRequestHeaders(), internalHeaders); + } + /** * Writes the {@link RequestHeaders} to the {@link Channel}. * The {@link RequestHeaders} is merged with {@link ClientRequestContext#additionalRequestHeaders()} @@ -194,10 +215,12 @@ final boolean tryInitialize() { * Note that the written data is not flushed by this method. The caller should explicitly call * {@link Channel#flush()} when each write unit is done. */ - final void writeHeaders(RequestHeaders headers) { + final void writeHeaders(RequestHeaders headers, boolean needs100Continue) { final SessionProtocol protocol = session.protocol(); assert protocol != null; - if (headersOnly) { + if (needs100Continue) { + state = State.NEEDS_100_CONTINUE; + } else if (headersOnly) { state = State.DONE; } else if (allowTrailers) { state = State.NEEDS_DATA_OR_TRAILERS; @@ -205,16 +228,7 @@ final void writeHeaders(RequestHeaders headers) { state = State.NEEDS_DATA; } - final HttpHeaders internalHeaders; - final ClientRequestContextExtension ctxExtension = ctx.as(ClientRequestContextExtension.class); - if (ctxExtension == null) { - internalHeaders = HttpHeaders.of(); - } else { - internalHeaders = ctxExtension.internalRequestHeaders(); - } - final RequestHeaders merged = mergeRequestHeaders( - headers, ctx.defaultRequestHeaders(), ctx.additionalRequestHeaders(), internalHeaders); - logBuilder.requestHeaders(merged); + logBuilder.requestHeaders(headers); final String connectionOption = headers.get(HttpHeaderNames.CONNECTION); if (CLOSE_STRING.equalsIgnoreCase(connectionOption) || !keepAlive) { @@ -223,16 +237,44 @@ final void writeHeaders(RequestHeaders headers) { // connection by sending a GOAWAY frame that will be sent after receiving the corresponding // response from the remote peer. The "Connection: close" header is stripped when it is converted to // a Netty HTTP/2 header. - session.deactivate(); + session.markUnacquirable(); } final ChannelPromise promise = ch.newPromise(); // Attach a listener first to make the listener early handle a cause raised while writing headers // before any other callbacks like `onStreamClosed()` are invoked. promise.addListener(this); - encoder.writeHeaders(id, streamId(), merged, headersOnly, promise); + encoder.writeHeaders(id, streamId(), headers, headersOnly, promise); + } + + static boolean needs100Continue(RequestHeaders headers) { + return headers.contains(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE.toString()); + } + + void handle100Continue(ResponseHeaders responseHeaders) { + if (state != State.NEEDS_100_CONTINUE) { + return; + } + + if (responseHeaders.status() == HttpStatus.CONTINUE) { + state = State.NEEDS_DATA_OR_TRAILERS; + resume(); + // TODO(minwoox): reset the timeout + } else { + // We do not retry the request when HttpStatus.EXPECTATION_FAILED is received + // because: + // - Most servers support 100-continue. + // - It's much simpler to just fail the request and let the user retry. + state = State.DONE; + logBuilder.endRequest(); + discardRequestBody(); + } } + abstract void resume(); + + abstract void discardRequestBody(); + /** * Writes the {@link HttpData} to the {@link Channel}. * Note that the written data is not flushed by this method. The caller should explicitly call @@ -329,6 +371,12 @@ private void fail(Throwable cause) { } final void failAndReset(Throwable cause) { + if (cause instanceof WriteTimeoutException) { + final HttpSession session = HttpSession.get(ch); + // Mark the session as unhealthy so that subsequent requests do not use it. + session.markUnacquirable(); + } + if (cause instanceof ProxyConnectException || cause instanceof ResponseCompleteException) { // - ProxyConnectException is handled by HttpSessionHandler.exceptionCaught(). // - ResponseCompleteException means the response is successfully received. diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestSubscriber.java b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestSubscriber.java index 4ff251be998..a3e94085f53 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestSubscriber.java +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestSubscriber.java @@ -53,6 +53,7 @@ static AbstractHttpRequestSubscriber of(Channel channel, ClientHttpObjectEncoder } private final HttpRequest request; + private final boolean http1WebSocket; @Nullable private Subscription subscription; @@ -62,10 +63,11 @@ static AbstractHttpRequestSubscriber of(Channel channel, ClientHttpObjectEncoder HttpResponseDecoder responseDecoder, HttpRequest request, DecodedHttpResponse originalRes, ClientRequestContext ctx, long timeoutMillis, boolean allowTrailers, - boolean keepAlive) { + boolean keepAlive, boolean http1WebSocket) { super(ch, encoder, responseDecoder, originalRes, ctx, timeoutMillis, request.isEmpty(), allowTrailers, keepAlive); this.request = request; + this.http1WebSocket = http1WebSocket; } @Override @@ -77,6 +79,14 @@ public void onSubscribe(Subscription subscription) { return; } + final RequestHeaders headers = mergedRequestHeaders(mapHeaders(request.headers())); + final boolean needs100Continue = needs100Continue(headers); + if (needs100Continue && http1WebSocket) { + failRequest(new IllegalArgumentException( + "a WebSocket request is not allowed to have Expect: 100-continue header")); + return; + } + if (!tryInitialize()) { return; } @@ -84,8 +94,7 @@ public void onSubscribe(Subscription subscription) { // NB: This must be invoked at the end of this method because otherwise the callback methods in this // class can be called before the member fields (subscription, id, responseWrapper and // timeoutFuture) are initialized. - // It is because the successful write of the first headers will trigger subscription.request(1). - writeHeaders(mapHeaders(request.headers())); + writeHeaders(headers, needs100Continue(headers)); channel().flush(); } @@ -111,6 +120,13 @@ public void onComplete() { @Override void onWriteSuccess() { + if (state() == State.NEEDS_100_CONTINUE) { + return; + } + request(); + } + + private void request() { // Request more messages regardless whether the state is DONE. It makes the producer have // a chance to produce the last call such as 'onComplete' and 'onError' when there are // no more messages it can produce. @@ -126,4 +142,14 @@ void cancel() { assert subscription != null; subscription.cancel(); } + + @Override + final void resume() { + request(); + } + + @Override + void discardRequestBody() { + cancel(); + } } diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpResponseDecoder.java b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpResponseDecoder.java index 3f85c7eb180..e1ba730eb38 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpResponseDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpResponseDecoder.java @@ -58,10 +58,11 @@ public InboundTrafficController inboundTrafficController() { } @Override - public HttpResponseWrapper addResponse( - int id, DecodedHttpResponse res, ClientRequestContext ctx, EventLoop eventLoop) { + public HttpResponseWrapper addResponse(@Nullable AbstractHttpRequestHandler requestHandler, + int id, DecodedHttpResponse res, + ClientRequestContext ctx, EventLoop eventLoop) { final HttpResponseWrapper newRes = - new HttpResponseWrapper(res, eventLoop, ctx, + new HttpResponseWrapper(requestHandler, res, eventLoop, ctx, ctx.responseTimeoutMillis(), ctx.maxResponseLength()); final HttpResponseWrapper oldRes = responses.put(id, newRes); keepAliveHandler().increaseNumRequests(); diff --git a/core/src/main/java/com/linecorp/armeria/client/AggregatedHttpRequestHandler.java b/core/src/main/java/com/linecorp/armeria/client/AggregatedHttpRequestHandler.java index bc90fc4561c..4c6f8eb6784 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AggregatedHttpRequestHandler.java +++ b/core/src/main/java/com/linecorp/armeria/client/AggregatedHttpRequestHandler.java @@ -22,6 +22,7 @@ import com.linecorp.armeria.common.HttpData; import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.internal.client.DecodedHttpResponse; @@ -31,6 +32,8 @@ final class AggregatedHttpRequestHandler extends AbstractHttpRequestHandler implements BiFunction { + @Nullable + private AggregatedHttpRequest request; private boolean cancelled; AggregatedHttpRequestHandler(Channel ch, ClientHttpObjectEncoder encoder, @@ -58,32 +61,51 @@ private void apply0(@Nullable AggregatedHttpRequest request, @Nullable Throwable } assert request != null; + final RequestHeaders merged = mergedRequestHeaders(request.headers()); + final boolean needs100Continue = needs100Continue(merged); + final HttpData content = request.content(); + if (needs100Continue && content.isEmpty()) { + content.close(); + failRequest(new IllegalArgumentException( + "an empty content is not allowed with Expect: 100-continue header")); + return; + } + if (!tryInitialize()) { - request.content().close(); + content.close(); return; } - writeHeaders(request.headers()); + writeHeaders(merged, needs100Continue); if (cancelled) { - request.content().close(); + content.close(); // If the headers size exceeds the limit, the headers write fails immediately. return; } - HttpData content = request.content(); + if (!needs100Continue) { + writeDataAndTrailers(request); + } else { + this.request = request; + } + channel().flush(); + } + + private void writeDataAndTrailers(AggregatedHttpRequest request) { + final HttpData content = request.content(); final boolean contentEmpty = content.isEmpty(); final HttpHeaders trailers = request.trailers(); final boolean trailersEmpty = trailers.isEmpty(); if (!contentEmpty) { if (trailersEmpty) { - content = content.withEndOfStream(); + writeData(content.withEndOfStream()); + } else { + writeData(content); } - writeData(content); } if (!trailersEmpty) { writeTrailers(trailers); } - channel().flush(); } @Override @@ -95,4 +117,17 @@ void onWriteSuccess() { void cancel() { cancelled = true; } + + @Override + void resume() { + assert request != null; + writeDataAndTrailers(request); + channel().flush(); + } + + @Override + void discardRequestBody() { + assert request != null; + request.content().close(); + } } diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientFactoryBuilder.java b/core/src/main/java/com/linecorp/armeria/client/ClientFactoryBuilder.java index b32ac0413b7..6cc8ab3abdc 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientFactoryBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientFactoryBuilder.java @@ -67,6 +67,7 @@ import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.util.EventLoopGroups; +import com.linecorp.armeria.common.util.TlsEngineType; import com.linecorp.armeria.internal.common.RequestContextUtil; import com.linecorp.armeria.internal.common.util.ChannelUtil; @@ -476,6 +477,17 @@ public ClientFactoryBuilder tlsAllowUnsafeCiphers(boolean tlsAllowUnsafeCiphers) return this; } + /** + * Sets the {@link TlsEngineType} that will be used for processing TLS connections. + * + * @param tlsEngineType the {@link TlsEngineType} to use + */ + @UnstableApi + public ClientFactoryBuilder tlsEngineType(TlsEngineType tlsEngineType) { + option(ClientFactoryOptions.TLS_ENGINE_TYPE, tlsEngineType); + return this; + } + /** * Sets the factory that creates a {@link AddressResolverGroup} which resolves remote addresses into * {@link InetSocketAddress}es. diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientFactoryOptions.java b/core/src/main/java/com/linecorp/armeria/client/ClientFactoryOptions.java index b5cb839a586..b0a77ad00a2 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientFactoryOptions.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientFactoryOptions.java @@ -37,6 +37,7 @@ import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.util.AbstractOptions; +import com.linecorp.armeria.common.util.TlsEngineType; import com.linecorp.armeria.internal.common.util.ChannelUtil; import io.micrometer.core.instrument.MeterRegistry; @@ -98,6 +99,13 @@ public final class ClientFactoryOptions public static final ClientFactoryOption TLS_ALLOW_UNSAFE_CIPHERS = ClientFactoryOption.define("tlsAllowUnsafeCiphers", Flags.tlsAllowUnsafeCiphers()); + /** + * The {@link TlsEngineType} that will be used for processing TLS connections. + */ + @UnstableApi + public static final ClientFactoryOption TLS_ENGINE_TYPE = + ClientFactoryOption.define("tlsEngineType", Flags.tlsEngineType()); + /** * The factory that creates an {@link AddressResolverGroup} which resolves remote addresses into * {@link InetSocketAddress}es. @@ -620,6 +628,14 @@ public boolean tlsAllowUnsafeCiphers() { return get(TLS_ALLOW_UNSAFE_CIPHERS); } + /** + * Returns the {@link TlsEngineType} that will be used for processing TLS connections. + */ + @UnstableApi + public TlsEngineType tlsEngineType() { + return get(TLS_ENGINE_TYPE); + } + /** * The {@link Consumer} that customizes the Netty {@link ChannelPipeline}. * This customizer is run right before {@link ChannelPipeline#connect(SocketAddress)} diff --git a/core/src/main/java/com/linecorp/armeria/client/Http1ResponseDecoder.java b/core/src/main/java/com/linecorp/armeria/client/Http1ResponseDecoder.java index 3e0f571e4ae..5de46317b99 100644 --- a/core/src/main/java/com/linecorp/armeria/client/Http1ResponseDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/client/Http1ResponseDecoder.java @@ -183,7 +183,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } if (!HttpUtil.isKeepAlive(nettyRes)) { - session().deactivate(); + session().markUnacquirable(); } final HttpResponseWrapper res = getResponse(resId); @@ -196,6 +196,9 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception res.startResponse(); final ResponseHeaders responseHeaders = ArmeriaHttpUtil.toArmeria(nettyRes); + + res.handle100Continue(responseHeaders); + final boolean written; if (responseHeaders.status().codeClass() == HttpStatusClass.INFORMATIONAL) { state = State.NEED_INFORMATIONAL_DATA; diff --git a/core/src/main/java/com/linecorp/armeria/client/Http2ResponseDecoder.java b/core/src/main/java/com/linecorp/armeria/client/Http2ResponseDecoder.java index 147e976cfe3..130608b96c5 100644 --- a/core/src/main/java/com/linecorp/armeria/client/Http2ResponseDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/client/Http2ResponseDecoder.java @@ -159,14 +159,14 @@ public void onStreamRemoved(Http2Stream stream) {} @Override public void onGoAwaySent(int lastStreamId, long errorCode, ByteBuf debugData) { - session().deactivate(); + session().markUnacquirable(); goAwayHandler.onGoAwaySent(channel(), lastStreamId, errorCode, debugData); } @Override public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) { // Should not reuse a connection that received a GOAWAY frame. - session().deactivate(); + session().markUnacquirable(); goAwayHandler.onGoAwayReceived(channel(), lastStreamId, errorCode, debugData); } @@ -206,6 +206,7 @@ public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers if (converted instanceof ResponseHeaders) { res.startResponse(); final ResponseHeaders responseHeaders = (ResponseHeaders) converted; + res.handle100Continue(responseHeaders); if (responseHeaders.status().codeClass() == HttpStatusClass.INFORMATIONAL) { written = res.tryWrite(converted); } else { diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java b/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java index 1c76651a661..ee65a69f054 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java @@ -53,6 +53,7 @@ import com.linecorp.armeria.common.util.AsyncCloseableSupport; import com.linecorp.armeria.common.util.ReleasableHolder; import com.linecorp.armeria.common.util.ShutdownHooks; +import com.linecorp.armeria.common.util.TlsEngineType; import com.linecorp.armeria.common.util.TransportType; import com.linecorp.armeria.internal.common.RequestTargetCache; import com.linecorp.armeria.internal.common.util.ChannelUtil; @@ -179,12 +180,13 @@ private static void setupTlsMetrics(List certificates, MeterReg ImmutableList.of(options.tlsCustomizer()); final boolean tlsAllowUnsafeCiphers = options.tlsAllowUnsafeCiphers(); final List keyCertChainCaptor = new ArrayList<>(); + final TlsEngineType tlsEngineType = options.tlsEngineType(); sslCtxHttp1Or2 = SslContextUtil - .createSslContext(SslContextBuilder::forClient, false, tlsAllowUnsafeCiphers, tlsCustomizers, - keyCertChainCaptor); + .createSslContext(SslContextBuilder::forClient, false, tlsEngineType, + tlsAllowUnsafeCiphers, tlsCustomizers, keyCertChainCaptor); sslCtxHttp1Only = SslContextUtil - .createSslContext(SslContextBuilder::forClient, true, tlsAllowUnsafeCiphers, tlsCustomizers, - keyCertChainCaptor); + .createSslContext(SslContextBuilder::forClient, true, tlsEngineType, + tlsAllowUnsafeCiphers, tlsCustomizers, keyCertChainCaptor); setupTlsMetrics(keyCertChainCaptor, options.meterRegistry()); http2InitialConnectionWindowSize = options.http2InitialConnectionWindowSize(); diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java b/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java index 6da8aea4a18..1119c68b35b 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java @@ -570,7 +570,7 @@ public void onComplete() {} System.nanoTime(), SystemInfo.currentTimeMicros()); // NB: No need to set the response timeout because we have session creation timeout. - responseDecoder.addResponse(0, res, reqCtx, ctx.channel().eventLoop()); + responseDecoder.addResponse(null, 0, res, reqCtx, ctx.channel().eventLoop()); ctx.fireChannelActive(); } @@ -805,7 +805,7 @@ private static final class ReadSuppressingAndChannelDeactivatingHandler extends @Override public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { - HttpSession.get(ctx.channel()).deactivate(); + HttpSession.get(ctx.channel()).markUnacquirable(); super.close(ctx, promise); } } diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpRequestSubscriber.java b/core/src/main/java/com/linecorp/armeria/client/HttpRequestSubscriber.java index e4ddfb1b284..db19598b464 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpRequestSubscriber.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpRequestSubscriber.java @@ -31,7 +31,7 @@ class HttpRequestSubscriber extends AbstractHttpRequestSubscriber { HttpRequestSubscriber(Channel ch, ClientHttpObjectEncoder encoder, HttpResponseDecoder responseDecoder, HttpRequest request, DecodedHttpResponse originalRes, ClientRequestContext ctx, long timeoutMillis) { - super(ch, encoder, responseDecoder, request, originalRes, ctx, timeoutMillis, true, true); + super(ch, encoder, responseDecoder, request, originalRes, ctx, timeoutMillis, true, true, false); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpResponseDecoder.java b/core/src/main/java/com/linecorp/armeria/client/HttpResponseDecoder.java index 283124750f0..bbce2e5adb2 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpResponseDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpResponseDecoder.java @@ -31,8 +31,8 @@ interface HttpResponseDecoder { InboundTrafficController inboundTrafficController(); - HttpResponseWrapper addResponse( - int id, DecodedHttpResponse res, ClientRequestContext ctx, EventLoop eventLoop); + HttpResponseWrapper addResponse(@Nullable AbstractHttpRequestHandler requestHandler, int id, + DecodedHttpResponse res, ClientRequestContext ctx, EventLoop eventLoop); @Nullable HttpResponseWrapper getResponse(int id); diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpResponseWrapper.java b/core/src/main/java/com/linecorp/armeria/client/HttpResponseWrapper.java index a2db00adfa7..d9adc43ac95 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpResponseWrapper.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpResponseWrapper.java @@ -50,6 +50,8 @@ class HttpResponseWrapper implements StreamWriter { private static final Logger logger = LoggerFactory.getLogger(HttpResponseWrapper.class); + @Nullable + private final AbstractHttpRequestHandler requestHandler; private final DecodedHttpResponse delegate; private final EventLoop eventLoop; private final ClientRequestContext ctx; @@ -62,8 +64,10 @@ class HttpResponseWrapper implements StreamWriter { private boolean done; private boolean closed; - HttpResponseWrapper(DecodedHttpResponse delegate, EventLoop eventLoop, ClientRequestContext ctx, + HttpResponseWrapper(@Nullable AbstractHttpRequestHandler requestHandler, + DecodedHttpResponse delegate, EventLoop eventLoop, ClientRequestContext ctx, long responseTimeoutMillis, long maxContentLength) { + this.requestHandler = requestHandler; this.delegate = delegate; this.eventLoop = eventLoop; this.ctx = ctx; @@ -71,12 +75,14 @@ class HttpResponseWrapper implements StreamWriter { this.responseTimeoutMillis = responseTimeoutMillis; } - DecodedHttpResponse delegate() { - return delegate; + void handle100Continue(ResponseHeaders responseHeaders) { + if (requestHandler != null) { + requestHandler.handle100Continue(responseHeaders); + } } - EventLoop eventLoop() { - return eventLoop; + DecodedHttpResponse delegate() { + return delegate; } long maxContentLength() { diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpSessionHandler.java b/core/src/main/java/com/linecorp/armeria/client/HttpSessionHandler.java index 6fcb5c129c5..01375220af8 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpSessionHandler.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpSessionHandler.java @@ -327,7 +327,7 @@ public boolean isAcquirable(KeepAliveHandler keepAliveHandler) { } @Override - public void deactivate() { + public void markUnacquirable() { isAcquirable = false; } diff --git a/core/src/main/java/com/linecorp/armeria/client/WebClient.java b/core/src/main/java/com/linecorp/armeria/client/WebClient.java index 8c4bcb0b900..c7a3ca8df9d 100644 --- a/core/src/main/java/com/linecorp/armeria/client/WebClient.java +++ b/core/src/main/java/com/linecorp/armeria/client/WebClient.java @@ -16,6 +16,7 @@ package com.linecorp.armeria.client; +import static com.linecorp.armeria.client.DefaultWebClient.RESPONSE_STREAMING_REQUEST_OPTIONS; import static java.util.Objects.requireNonNull; import java.net.URI; @@ -249,7 +250,7 @@ default HttpResponse execute(HttpRequest req) { @CheckReturnValue default HttpResponse execute(AggregatedHttpRequest aggregatedReq) { requireNonNull(aggregatedReq, "aggregatedReq"); - return execute(aggregatedReq.toHttpRequest()); + return execute(aggregatedReq.toHttpRequest(), RESPONSE_STREAMING_REQUEST_OPTIONS); } /** @@ -257,7 +258,7 @@ default HttpResponse execute(AggregatedHttpRequest aggregatedReq) { */ @CheckReturnValue default HttpResponse execute(RequestHeaders headers) { - return execute(HttpRequest.of(headers)); + return execute(HttpRequest.of(headers), RESPONSE_STREAMING_REQUEST_OPTIONS); } /** @@ -265,7 +266,7 @@ default HttpResponse execute(RequestHeaders headers) { */ @CheckReturnValue default HttpResponse execute(RequestHeaders headers, HttpData content) { - return execute(HttpRequest.of(headers, content)); + return execute(HttpRequest.of(headers, content), RESPONSE_STREAMING_REQUEST_OPTIONS); } /** @@ -273,7 +274,7 @@ default HttpResponse execute(RequestHeaders headers, HttpData content) { */ @CheckReturnValue default HttpResponse execute(RequestHeaders headers, byte[] content) { - return execute(HttpRequest.of(headers, HttpData.wrap(content))); + return execute(HttpRequest.of(headers, HttpData.wrap(content)), RESPONSE_STREAMING_REQUEST_OPTIONS); } /** @@ -281,7 +282,7 @@ default HttpResponse execute(RequestHeaders headers, byte[] content) { */ @CheckReturnValue default HttpResponse execute(RequestHeaders headers, String content) { - return execute(HttpRequest.of(headers, HttpData.ofUtf8(content))); + return execute(HttpRequest.of(headers, HttpData.ofUtf8(content)), RESPONSE_STREAMING_REQUEST_OPTIONS); } /** @@ -289,7 +290,8 @@ default HttpResponse execute(RequestHeaders headers, String content) { */ @CheckReturnValue default HttpResponse execute(RequestHeaders headers, String content, Charset charset) { - return execute(HttpRequest.of(headers, HttpData.of(charset, content))); + return execute(HttpRequest.of(headers, HttpData.of(charset, content)), + RESPONSE_STREAMING_REQUEST_OPTIONS); } /** diff --git a/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ClientChannelHandler.java b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ClientChannelHandler.java index baf4518234c..f2c30cc37e0 100644 --- a/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ClientChannelHandler.java +++ b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ClientChannelHandler.java @@ -96,7 +96,8 @@ public InboundTrafficController inboundTrafficController() { } @Override - public HttpResponseWrapper addResponse(int id, DecodedHttpResponse decodedHttpResponse, + public HttpResponseWrapper addResponse(@Nullable AbstractHttpRequestHandler requestHandler, + int id, DecodedHttpResponse decodedHttpResponse, ClientRequestContext ctx, EventLoop eventLoop) { assert res == null; res = new WebSocketHttp1ResponseWrapper(decodedHttpResponse, eventLoop, ctx, @@ -180,7 +181,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } if (!HttpUtil.isKeepAlive(nettyRes)) { - session().deactivate(); + session().markUnacquirable(); } if (res == null && ArmeriaHttpUtil.isRequestTimeoutResponse(nettyRes)) { diff --git a/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1RequestSubscriber.java b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1RequestSubscriber.java index 83b733a60ac..63f291b0244 100644 --- a/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1RequestSubscriber.java +++ b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1RequestSubscriber.java @@ -29,7 +29,7 @@ final class WebSocketHttp1RequestSubscriber extends AbstractHttpRequestSubscribe HttpResponseDecoder responseDecoder, HttpRequest request, DecodedHttpResponse originalRes, ClientRequestContext ctx, long timeoutMillis) { - super(ch, encoder, responseDecoder, request, originalRes, ctx, timeoutMillis, false, false); + super(ch, encoder, responseDecoder, request, originalRes, ctx, timeoutMillis, false, false, true); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ResponseWrapper.java b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ResponseWrapper.java index 735e640ae8c..6d920f7116f 100644 --- a/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ResponseWrapper.java +++ b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ResponseWrapper.java @@ -28,7 +28,7 @@ final class WebSocketHttp1ResponseWrapper extends HttpResponseWrapper { WebSocketHttp1ResponseWrapper(DecodedHttpResponse delegate, EventLoop eventLoop, ClientRequestContext ctx, long responseTimeoutMillis, long maxContentLength) { - super(delegate, eventLoop, ctx, responseTimeoutMillis, maxContentLength); + super(null, delegate, eventLoop, ctx, responseTimeoutMillis, maxContentLength); WebSocketClientUtil.setClosingResponseTask(ctx, cause -> { super.close(cause, false); }); diff --git a/core/src/main/java/com/linecorp/armeria/client/endpoint/AbstractEndpointSelector.java b/core/src/main/java/com/linecorp/armeria/client/endpoint/AbstractEndpointSelector.java index 9bccf9cb9d8..3761e42dc32 100644 --- a/core/src/main/java/com/linecorp/armeria/client/endpoint/AbstractEndpointSelector.java +++ b/core/src/main/java/com/linecorp/armeria/client/endpoint/AbstractEndpointSelector.java @@ -148,7 +148,6 @@ protected final void initialize() { private void refreshEndpoints(List endpoints) { // Allow subclasses to update the endpoints first. updateNewEndpoints(endpoints); - lock.lock(); try { pendingFutures.removeIf(ListeningFuture::tryComplete); diff --git a/core/src/main/java/com/linecorp/armeria/client/endpoint/WeightRampingUpStrategy.java b/core/src/main/java/com/linecorp/armeria/client/endpoint/WeightRampingUpStrategy.java index 4560ed4d0da..9721ce10a60 100644 --- a/core/src/main/java/com/linecorp/armeria/client/endpoint/WeightRampingUpStrategy.java +++ b/core/src/main/java/com/linecorp/armeria/client/endpoint/WeightRampingUpStrategy.java @@ -49,6 +49,7 @@ import com.linecorp.armeria.common.CommonPools; import com.linecorp.armeria.common.util.ListenableAsyncCloseable; import com.linecorp.armeria.common.util.Ticker; +import com.linecorp.armeria.internal.common.util.ReentrantShortLock; import io.netty.util.concurrent.EventExecutor; import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap; @@ -133,6 +134,7 @@ final class RampingUpEndpointWeightSelector extends AbstractEndpointSelector { @VisibleForTesting final Map rampingUpWindowsMap = new HashMap<>(); private Object2LongOpenHashMap endpointCreatedTimestamps = new Object2LongOpenHashMap<>(); + private final ReentrantShortLock lock = new ReentrantShortLock(true); RampingUpEndpointWeightSelector(EndpointGroup endpointGroup, EventExecutor executor) { super(endpointGroup); @@ -145,8 +147,13 @@ final class RampingUpEndpointWeightSelector extends AbstractEndpointSelector { @Override protected void updateNewEndpoints(List endpoints) { - // Use the executor so the order of endpoints change is guaranteed. - executor.execute(() -> updateEndpoints(endpoints)); + // Use a lock so the order of endpoints change is guaranteed. + lock.lock(); + try { + updateEndpoints(endpoints); + } finally { + lock.unlock(); + } } private long computeCreateTimestamp(Endpoint endpoint) { @@ -244,14 +251,19 @@ private long initialDelayNanos(long windowIndex) { } private void updateWeightAndStep(long window) { - final EndpointsRampingUpEntry entry = rampingUpWindowsMap.get(window); - assert entry != null; - final Set endpointAndSteps = entry.endpointAndSteps(); - updateWeightAndStep(endpointAndSteps); - if (endpointAndSteps.isEmpty()) { - rampingUpWindowsMap.remove(window).scheduledFuture.cancel(true); + lock.lock(); + try { + final EndpointsRampingUpEntry entry = rampingUpWindowsMap.get(window); + assert entry != null; + final Set endpointAndSteps = entry.endpointAndSteps(); + updateWeightAndStep(endpointAndSteps); + if (endpointAndSteps.isEmpty()) { + rampingUpWindowsMap.remove(window).scheduledFuture.cancel(true); + } + buildEndpointSelector(); + } finally { + lock.unlock(); } - buildEndpointSelector(); } private void updateWeightAndStep(Set endpointAndSteps) { @@ -267,7 +279,12 @@ private void updateWeightAndStep(Set endpointAndSteps) { } private void close() { - rampingUpWindowsMap.values().forEach(e -> e.scheduledFuture.cancel(true)); + lock.lock(); + try { + rampingUpWindowsMap.values().forEach(e -> e.scheduledFuture.cancel(true)); + } finally { + lock.unlock(); + } } } diff --git a/core/src/main/java/com/linecorp/armeria/common/DefaultFlagsProvider.java b/core/src/main/java/com/linecorp/armeria/common/DefaultFlagsProvider.java index 53e40f59bde..2d4acd9f64f 100644 --- a/core/src/main/java/com/linecorp/armeria/common/DefaultFlagsProvider.java +++ b/core/src/main/java/com/linecorp/armeria/common/DefaultFlagsProvider.java @@ -60,6 +60,9 @@ final class DefaultFlagsProvider implements FlagsProvider { static final long DEFAULT_CONNECT_TIMEOUT_MILLIS = 3200; // 3.2 seconds static final long DEFAULT_WRITE_TIMEOUT_MILLIS = 1000; // 1 second + // Use the fragmentation size as the default. https://datatracker.ietf.org/doc/html/rfc5246#section-6.2.1 + static final int DEFAULT_MAX_CLIENT_HELLO_LENGTH = 16384; // 16KiB + // Use slightly greater value than the default request timeout so that clients have a higher chance of // getting proper 503 Service Unavailable response when server-side timeout occurs. static final long DEFAULT_RESPONSE_TIMEOUT_MILLIS = 15 * 1000; // 15 seconds @@ -437,6 +440,11 @@ public Boolean tlsAllowUnsafeCiphers() { return false; } + @Override + public Integer defaultMaxClientHelloLength() { + return DEFAULT_MAX_CLIENT_HELLO_LENGTH; + } + @Override public Set transientServiceOptions() { return ImmutableSet.of(); diff --git a/core/src/main/java/com/linecorp/armeria/common/Flags.java b/core/src/main/java/com/linecorp/armeria/common/Flags.java index c15f3307d03..e0f327c82f7 100644 --- a/core/src/main/java/com/linecorp/armeria/common/Flags.java +++ b/core/src/main/java/com/linecorp/armeria/common/Flags.java @@ -400,6 +400,11 @@ private static boolean validateTransportType(TransportType transportType, String private static final boolean TLS_ALLOW_UNSAFE_CIPHERS = getValue(FlagsProvider::tlsAllowUnsafeCiphers, "tlsAllowUnsafeCiphers"); + // Maximum 16MiB https://datatracker.ietf.org/doc/html/rfc5246#section-7.4 + private static final int DEFAULT_MAX_CLIENT_HELLO_LENGTH = + getValue(FlagsProvider::defaultMaxClientHelloLength, "defaultMaxClientHelloLength", + value -> value >= 0 && value <= 16777216); // 16MiB + private static final Set TRANSIENT_SERVICE_OPTIONS = getValue(FlagsProvider::transientServiceOptions, "transientServiceOptions"); @@ -629,6 +634,7 @@ private static void detectTlsEngineAndDumpOpenSslInfo() { final SSLEngine engine = SslContextUtil.createSslContext( SslContextBuilder::forClient, /* forceHttp1 */ false, + tlsEngineType, /* tlsAllowUnsafeCiphers */ false, ImmutableList.of(), null).newEngine(ByteBufAllocator.DEFAULT); logger.info("All available SSL protocols: {}", @@ -1426,6 +1432,19 @@ public static boolean tlsAllowUnsafeCiphers() { return TLS_ALLOW_UNSAFE_CIPHERS; } + /** + * Returns the default maximum client hello length that a server allows. + * The length shouldn't exceed 16MiB as described in + * Handshake Protocol. + * + *

The default value of this flag is {@value DefaultFlagsProvider#DEFAULT_MAX_CLIENT_HELLO_LENGTH}. + * Specify the {@code -Dcom.linecorp.armeria.defaultMaxClientHelloLength=} JVM option to + * override the default value. + */ + public static int defaultMaxClientHelloLength() { + return DEFAULT_MAX_CLIENT_HELLO_LENGTH; + } + /** * Returns the {@link Set} of {@link TransientServiceOption}s that are enabled for a * {@link TransientService}. diff --git a/core/src/main/java/com/linecorp/armeria/common/FlagsProvider.java b/core/src/main/java/com/linecorp/armeria/common/FlagsProvider.java index 3c38e422bf3..72dfb4f8d6e 100644 --- a/core/src/main/java/com/linecorp/armeria/common/FlagsProvider.java +++ b/core/src/main/java/com/linecorp/armeria/common/FlagsProvider.java @@ -217,6 +217,7 @@ default Boolean useOpenSsl() { * the default.

*/ @Nullable + @UnstableApi default TlsEngineType tlsEngineType() { return null; } @@ -1012,6 +1013,20 @@ default Boolean tlsAllowUnsafeCiphers() { return null; } + /** + * Returns the default maximum client hello length that a server allows. + * The length shouldn't exceed 16MiB as described in + * Handshake Protocol. + * + *

The default value of this flag is {@value DefaultFlagsProvider#DEFAULT_MAX_CLIENT_HELLO_LENGTH}. + * Specify the {@code -Dcom.linecorp.armeria.defaultMaxClientHelloLength=} JVM option to + * override the default value. + */ + @Nullable + default Integer defaultMaxClientHelloLength() { + return null; + } + /** * Returns the {@link Set} of {@link TransientServiceOption}s that are enabled for a * {@link TransientService}. diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpHeaderNames.java b/core/src/main/java/com/linecorp/armeria/common/HttpHeaderNames.java index 861cb5252eb..42892439e15 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpHeaderNames.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpHeaderNames.java @@ -40,6 +40,7 @@ import com.google.common.math.IntMath; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; import io.netty.util.AsciiString; @@ -237,6 +238,7 @@ public final class HttpHeaderNames { * The HTTP {@code "Git-Protocol"} header field name, as described in * HTTP Transport. */ + @UnstableApi public static final AsciiString GIT_PROTOCOL = create("Git-Protocol"); /** * The HTTP {@code "Host"} header field name. @@ -743,6 +745,12 @@ public final class HttpHeaderNames { * header field name. */ public static final AsciiString PERMISSIONS_POLICY = create("Permissions-Policy"); + /** + * The HTTP {@code + * Permissions-Policy-Report-Only} header field name. + */ + public static final AsciiString PERMISSIONS_POLICY_REPORT_ONLY = create("Permissions-Policy-Report-Only"); /** * The HTTP {@code @@ -902,24 +910,33 @@ public final class HttpHeaderNames { * Observe-Browsing-Topics} header field name. */ public static final AsciiString OBSERVE_BROWSING_TOPICS = create("Observe-Browsing-Topics"); - /** - * The HTTP {@code CDN-Loop} header field name. - */ - public static final AsciiString CDN_LOOP = create("CDN-Loop"); - /** * The HTTP {@code * Sec-Ad-Auction-Fetch} header field name. */ public static final AsciiString SEC_AD_AUCTION_FETCH = create("Sec-Ad-Auction-Fetch"); - + /** + * The HTTP {@code + * Sec-GPC} header field name. + */ + public static final AsciiString SEC_GPC = create("Sec-GPC"); /** * The HTTP {@code * Ad-Auction-Signals} header field name. */ public static final AsciiString AD_AUCTION_SIGNALS = create("Ad-Auction-Signals"); + /** + * The HTTP {@code + * Ad-Auction-Allowed} header field name. + */ + public static final AsciiString AD_AUCTION_ALLOWED = create("Ad-Auction-Allowed"); + /** + * The HTTP {@code CDN-Loop} header field name. + */ + public static final AsciiString CDN_LOOP = create("CDN-Loop"); private static final Map map; private static final Map inverseMap; diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java b/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java index 8967b6cc0c0..1d75eb565fe 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java @@ -50,6 +50,7 @@ import com.linecorp.armeria.common.stream.SubscriptionOption; import com.linecorp.armeria.internal.common.DefaultHttpRequest; import com.linecorp.armeria.internal.common.DefaultSplitHttpRequest; +import com.linecorp.armeria.internal.common.HeaderOverridingHttpRequest; import com.linecorp.armeria.internal.common.stream.SurroundingPublisher; import com.linecorp.armeria.unsafe.PooledObjects; @@ -478,8 +479,7 @@ default HttpRequest withHeaders(RequestHeaders newHeaders) { // Just check the reference only to avoid heavy comparison. return this; } - - return new HeaderOverridingHttpRequest(this, newHeaders); + return HeaderOverridingHttpRequest.of(this, newHeaders); } /** diff --git a/core/src/main/java/com/linecorp/armeria/common/MediaType.java b/core/src/main/java/com/linecorp/armeria/common/MediaType.java index 125121c8513..2ee4c7461cd 100644 --- a/core/src/main/java/com/linecorp/armeria/common/MediaType.java +++ b/core/src/main/java/com/linecorp/armeria/common/MediaType.java @@ -496,6 +496,7 @@ private static MediaType addKnownType(MediaType mediaType) { * This constant is used for advertising the capabilities of a Git server, * as described in Smart Clients. */ + @UnstableApi public static final MediaType GIT_UPLOAD_PACK_ADVERTISEMENT = createConstant(APPLICATION_TYPE, "x-git-upload-pack-advertisement"); @@ -504,6 +505,7 @@ private static MediaType addKnownType(MediaType mediaType) { * * Smart Service git-upload-pack. */ + @UnstableApi public static final MediaType GIT_UPLOAD_PACK_REQUEST = createConstant(APPLICATION_TYPE, "x-git-upload-pack-request"); @@ -512,6 +514,7 @@ private static MediaType addKnownType(MediaType mediaType) { * * Smart Service git-upload-pack. */ + @UnstableApi public static final MediaType GIT_UPLOAD_PACK_RESULT = createConstant(APPLICATION_TYPE, "x-git-upload-pack-result"); diff --git a/core/src/main/java/com/linecorp/armeria/common/MediaTypeNames.java b/core/src/main/java/com/linecorp/armeria/common/MediaTypeNames.java index dee041246a7..b6ffe6a27ba 100644 --- a/core/src/main/java/com/linecorp/armeria/common/MediaTypeNames.java +++ b/core/src/main/java/com/linecorp/armeria/common/MediaTypeNames.java @@ -16,6 +16,8 @@ package com.linecorp.armeria.common; +import com.linecorp.armeria.common.annotation.UnstableApi; + /** * String constants defined in {@link MediaType} class. */ @@ -348,14 +350,17 @@ public final class MediaTypeNames { /** * {@value #GIT_UPLOAD_PACK_ADVERTISEMENT}. */ + @UnstableApi public static final String GIT_UPLOAD_PACK_ADVERTISEMENT = "application/x-git-upload-pack-advertisement"; /** * {@value #GIT_UPLOAD_PACK_REQUEST}. */ + @UnstableApi public static final String GIT_UPLOAD_PACK_REQUEST = "application/x-git-upload-pack-request"; /** * {@value #GIT_UPLOAD_PACK_RESULT}. */ + @UnstableApi public static final String GIT_UPLOAD_PACK_RESULT = "application/x-git-upload-pack-result"; /** * {@value #GZIP}. diff --git a/core/src/main/java/com/linecorp/armeria/common/SystemPropertyFlagsProvider.java b/core/src/main/java/com/linecorp/armeria/common/SystemPropertyFlagsProvider.java index 6d2094caa1b..572b6711a1d 100644 --- a/core/src/main/java/com/linecorp/armeria/common/SystemPropertyFlagsProvider.java +++ b/core/src/main/java/com/linecorp/armeria/common/SystemPropertyFlagsProvider.java @@ -442,6 +442,11 @@ public Boolean tlsAllowUnsafeCiphers() { return getBoolean("tlsAllowUnsafeCiphers"); } + @Override + public Integer defaultMaxClientHelloLength() { + return getInt("defaultMaxClientHelloLength"); + } + @Override public Set transientServiceOptions() { final String val = getNormalized("transientServiceOptions"); diff --git a/core/src/main/java/com/linecorp/armeria/common/encoding/AbstractStreamDecoder.java b/core/src/main/java/com/linecorp/armeria/common/encoding/AbstractStreamDecoder.java index 5b184b770e7..6286ba02204 100644 --- a/core/src/main/java/com/linecorp/armeria/common/encoding/AbstractStreamDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/common/encoding/AbstractStreamDecoder.java @@ -53,6 +53,8 @@ public HttpData decode(HttpData obj) { .maxContentLength(maxLength) .cause(ex) .build(); + } else { + throw ex; } } return fetchDecoderOutput(); diff --git a/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java b/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java index 13e8c6679f9..b91bc58a247 100644 --- a/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java +++ b/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java @@ -82,6 +82,8 @@ final class DefaultRequestLog implements RequestLog, RequestLogBuilder { private static final ResponseHeaders DUMMY_RESPONSE_HEADERS = ResponseHeaders.of(HttpStatus.UNKNOWN); private final RequestContext ctx; + private int currentAttempt; + private final CompleteRequestLog notCheckingAccessor = new CompleteRequestLog(); @Nullable @@ -545,7 +547,11 @@ public void addChild(RequestLogAccess child) { children = new ArrayList<>(); propagateRequestSideLog(child); } + children.add(child); + if (child instanceof DefaultRequestLog) { + ((DefaultRequestLog) child).currentAttempt = children.size(); + } } private void propagateRequestSideLog(RequestLogAccess child) { @@ -1025,6 +1031,11 @@ public void requestTrailers(HttpHeaders requestTrailers) { updateFlags(RequestLogProperty.REQUEST_TRAILERS); } + @Override + public int currentAttempt() { + return currentAttempt; + } + @Override public void endRequest() { endRequest0(null); @@ -1779,6 +1790,11 @@ public HttpHeaders requestTrailers() { return requestTrailers; } + @Override + public int currentAttempt() { + return currentAttempt; + } + @Override public long responseStartTimeMicros() { return responseStartTimeMicros; diff --git a/core/src/main/java/com/linecorp/armeria/common/logging/JsonLogFormatter.java b/core/src/main/java/com/linecorp/armeria/common/logging/JsonLogFormatter.java index e6a25507999..9b28df9ae23 100644 --- a/core/src/main/java/com/linecorp/armeria/common/logging/JsonLogFormatter.java +++ b/core/src/main/java/com/linecorp/armeria/common/logging/JsonLogFormatter.java @@ -131,6 +131,14 @@ public String formatRequest(RequestOnlyLog log) { objectNode.put("startTime", TextFormatter.epochMicros(log.requestStartTimeMicros()).toString()); + if (RequestLogProperty.SESSION.isAvailable(flags)) { + final ObjectNode connectionNode = + maybeCreateConnectionTimings(log.connectionTimings(), objectMapper); + if (connectionNode != null) { + objectNode.set("connection", connectionNode); + } + } + if (RequestLogProperty.REQUEST_LENGTH.isAvailable(flags)) { objectNode.put("length", TextFormatter.size(log.requestLength()).toString()); } @@ -170,6 +178,12 @@ public String formatRequest(RequestOnlyLog log) { if (sanitizedTrailers != null) { objectNode.set("trailers", sanitizedTrailers); } + + final int currentAttempt = log.currentAttempt(); + if (currentAttempt > 0) { + objectNode.put("currentAttempt", currentAttempt); + } + return objectMapper.writeValueAsString(objectNode); } catch (Exception e) { logger.warn("Unexpected exception while formatting a request log: {}", log, e); @@ -177,6 +191,59 @@ public String formatRequest(RequestOnlyLog log) { } } + @Nullable + private static ObjectNode maybeCreateConnectionTimings(@Nullable ClientConnectionTimings timings, + ObjectMapper objectMapper) { + if (timings == null) { + return null; + } + + final ObjectNode objectNode = objectMapper.createObjectNode(); + final ObjectNode connectionObjectNode = + startTimeAndDuration(objectMapper, + timings.connectionAcquisitionDurationNanos(), + timings.connectionAcquisitionStartTimeMicros()); + objectNode.set("total", connectionObjectNode); + + if (timings.dnsResolutionDurationNanos() >= 0) { + final ObjectNode dnsObjectNode = + startTimeAndDuration(objectMapper, + timings.dnsResolutionDurationNanos(), + timings.dnsResolutionStartTimeMicros()); + objectNode.set("dns", dnsObjectNode); + } + if (timings.pendingAcquisitionDurationNanos() >= 0) { + final ObjectNode pendingObjectNode = + startTimeAndDuration(objectMapper, + timings.pendingAcquisitionDurationNanos(), + timings.pendingAcquisitionStartTimeMicros()); + objectNode.set("pending", pendingObjectNode); + } + if (timings.socketConnectDurationNanos() >= 0) { + final ObjectNode socketObjectNode = + startTimeAndDuration(objectMapper, + timings.socketConnectDurationNanos(), + timings.socketConnectStartTimeMicros()); + objectNode.set("socket", socketObjectNode); + } + if (timings.tlsHandshakeDurationNanos() >= 0) { + final ObjectNode tlsObjectNode = + startTimeAndDuration(objectMapper, + timings.tlsHandshakeDurationNanos(), + timings.tlsHandshakeStartTimeMicros()); + objectNode.set("tls", tlsObjectNode); + } + return objectNode; + } + + static ObjectNode startTimeAndDuration(ObjectMapper objectMapper, long durationNanos, + long startTimeMicros) { + final ObjectNode objectNode = objectMapper.createObjectNode(); + objectNode.put("durationNanos", durationNanos); + objectNode.put("startTimeMicros", startTimeMicros); + return objectNode; + } + @Override public String formatResponse(RequestLog log) { requireNonNull(log, "log"); diff --git a/core/src/main/java/com/linecorp/armeria/common/logging/RequestOnlyLog.java b/core/src/main/java/com/linecorp/armeria/common/logging/RequestOnlyLog.java index 2dcbb37ccbb..56c22c888e7 100644 --- a/core/src/main/java/com/linecorp/armeria/common/logging/RequestOnlyLog.java +++ b/core/src/main/java/com/linecorp/armeria/common/logging/RequestOnlyLog.java @@ -271,6 +271,13 @@ default long requestDurationNanos() { */ HttpHeaders requestTrailers(); + /** + * Returns the current attempt number of the {@link Request}. + * It returns {@code 0} for the very first request. It returns {@code 1} for the first retry. + * It returns {@code 2} for the second retry, and so forth. + */ + int currentAttempt(); + /** * Returns the string representation of the {@link Request}, with no sanitization of headers or content. * This method is a shortcut for: diff --git a/core/src/main/java/com/linecorp/armeria/common/logging/TextLogFormatter.java b/core/src/main/java/com/linecorp/armeria/common/logging/TextLogFormatter.java index ce39d9a7f95..bf293de165b 100644 --- a/core/src/main/java/com/linecorp/armeria/common/logging/TextLogFormatter.java +++ b/core/src/main/java/com/linecorp/armeria/common/logging/TextLogFormatter.java @@ -140,6 +140,10 @@ public String formatRequest(RequestOnlyLog log) { buf.append("Request: {startTime="); TextFormatter.appendEpochMicros(buf, log.requestStartTimeMicros()); + if (RequestLogProperty.SESSION.isAvailable(flags)) { + maybeAppendConnectionTimings(log.connectionTimings(), buf); + } + if (RequestLogProperty.REQUEST_LENGTH.isAvailable(flags)) { buf.append(", length="); TextFormatter.appendSize(buf, log.requestLength()); @@ -182,12 +186,47 @@ public String formatRequest(RequestOnlyLog log) { if (sanitizedTrailers != null) { buf.append(", trailers=").append(sanitizedTrailers); } - buf.append('}'); + final int currentAttempt = log.currentAttempt(); + if (currentAttempt > 0) { + buf.append(", currentAttempt=").append(currentAttempt); + } + buf.append('}'); return buf.toString(); } } + private static void maybeAppendConnectionTimings(@Nullable ClientConnectionTimings timings, + StringBuilder buf) { + if (timings == null) { + return; + } + buf.append(", Connection: {total="); + TextFormatter.appendEpochAndElapsed(buf, timings.connectionAcquisitionStartTimeMicros(), + timings.connectionAcquisitionDurationNanos()); + if (timings.dnsResolutionDurationNanos() >= 0) { + buf.append(", dns="); + TextFormatter.appendEpochAndElapsed(buf, timings.dnsResolutionStartTimeMicros(), + timings.dnsResolutionDurationNanos()); + } + if (timings.pendingAcquisitionDurationNanos() >= 0) { + buf.append(", pending="); + TextFormatter.appendEpochAndElapsed(buf, timings.pendingAcquisitionStartTimeMicros(), + timings.pendingAcquisitionDurationNanos()); + } + if (timings.socketConnectDurationNanos() >= 0) { + buf.append(", socket="); + TextFormatter.appendEpochAndElapsed(buf, timings.socketConnectStartTimeMicros(), + timings.socketConnectDurationNanos()); + } + if (timings.tlsHandshakeDurationNanos() >= 0) { + buf.append(", tls="); + TextFormatter.appendEpochAndElapsed(buf, timings.tlsHandshakeStartTimeMicros(), + timings.tlsHandshakeDurationNanos()); + } + buf.append('}'); + } + @Override public String formatResponse(RequestLog log) { requireNonNull(log, "log"); diff --git a/core/src/main/java/com/linecorp/armeria/common/stream/DefaultStreamMessageDuplicator.java b/core/src/main/java/com/linecorp/armeria/common/stream/DefaultStreamMessageDuplicator.java index 5fb1ed5f5b9..1e07659d434 100644 --- a/core/src/main/java/com/linecorp/armeria/common/stream/DefaultStreamMessageDuplicator.java +++ b/core/src/main/java/com/linecorp/armeria/common/stream/DefaultStreamMessageDuplicator.java @@ -285,12 +285,7 @@ void subscribe(DownstreamSubscription subscription) { private void doSubscribe(DownstreamSubscription subscription) { if (state == State.ABORTED) { - final EventExecutor executor = subscription.executor; - if (executor.inEventLoop()) { - failLateProcessorSubscriber(subscription); - } else { - executor.execute(() -> failLateProcessorSubscriber(subscription)); - } + subscription.failLateProcessorSubscriber(); return; } @@ -301,63 +296,11 @@ private void doSubscribe(DownstreamSubscription subscription) { } } - private static void failLateProcessorSubscriber(DownstreamSubscription subscription) { - final Subscriber lateSubscriber = subscription.subscriber(); - try { - lateSubscriber.onSubscribe(NoopSubscription.get()); - lateSubscriber.onError( - new IllegalStateException("duplicator is closed or no more downstream can be added.")); - } catch (Throwable t) { - throwIfFatal(t); - logger.warn("Subscriber should not throw an exception. subscriber: {}", lateSubscriber, t); - } - } - - void unsubscribe(DownstreamSubscription subscription, @Nullable Throwable cause) { + private void cleanupIfLastSubscription() { if (executor.inEventLoop()) { - doUnsubscribe(subscription, cause); - } else { - executor.execute(() -> doUnsubscribe(subscription, cause)); - } - } - - private void doUnsubscribe(DownstreamSubscription subscription, @Nullable Throwable cause) { - if (!downstreamSubscriptions.remove(subscription)) { - return; - } - - final Subscriber subscriber = subscription.subscriber(); - subscription.clearSubscriber(); - - final CompletableFuture completionFuture = subscription.whenComplete(); - if (cause == null) { - try { - subscriber.onComplete(); - completionFuture.complete(null); - } catch (Throwable t) { - completionFuture.completeExceptionally(t); - throwIfFatal(t); - logger.warn("Subscriber.onComplete() should not raise an exception. subscriber: {}", - subscriber, t); - } finally { - doCleanupIfLastSubscription(); - } - return; - } - - try { - if (subscription.notifyCancellation || !(cause instanceof CancelledSubscriptionException)) { - subscriber.onError(cause); - } - completionFuture.completeExceptionally(cause); - } catch (Throwable t) { - final Exception composite = new CompositeException(t, cause); - completionFuture.completeExceptionally(composite); - throwIfFatal(t); - logger.warn("Subscriber.onError() should not raise an exception. subscriber: {}", - subscriber, composite); - } finally { doCleanupIfLastSubscription(); + } else { + executor.execute(this::doCleanupIfLastSubscription); } } @@ -613,7 +556,7 @@ static final class DownstreamSubscription implements Subscription { private final StreamMessage streamMessage; private Subscriber subscriber; private final StreamMessageProcessor processor; - private final EventExecutor executor; + private final EventExecutor downstreamExecutor; private final boolean withPooledObjects; private final boolean notifyCancellation; @@ -640,7 +583,7 @@ static final class DownstreamSubscription implements Subscription { this.streamMessage = streamMessage; this.subscriber = subscriber; this.processor = processor; - this.executor = executor; + downstreamExecutor = executor; this.withPooledObjects = withPooledObjects; this.notifyCancellation = notifyCancellation; } @@ -661,6 +604,25 @@ void clearSubscriber() { } } + void failLateProcessorSubscriber() { + if (downstreamExecutor.inEventLoop()) { + failLateProcessorSubscriber0(); + } else { + downstreamExecutor.execute(this::failLateProcessorSubscriber0); + } + } + + private void failLateProcessorSubscriber0() { + try { + subscriber.onSubscribe(NoopSubscription.get()); + subscriber.onError( + new IllegalStateException("duplicator is closed or no more downstream can be added.")); + } catch (Throwable t) { + throwIfFatal(t); + logger.warn("Subscriber should not throw an exception. subscriber: {}", subscriber, t); + } + } + // Called from processor.processorExecutor void invokeOnSubscribe() { if (invokedOnSubscribe) { @@ -668,10 +630,10 @@ void invokeOnSubscribe() { } invokedOnSubscribe = true; - if (executor.inEventLoop()) { + if (downstreamExecutor.inEventLoop()) { invokeOnSubscribe0(); } else { - executor.execute(this::invokeOnSubscribe0); + downstreamExecutor.execute(this::invokeOnSubscribe0); } } @@ -680,7 +642,7 @@ void invokeOnSubscribe0() { try { subscriber.onSubscribe(this); } catch (Throwable t) { - processor.unsubscribe(this, t); + unsubscribe(t); throwIfFatal(t); logger.warn("Subscriber.onSubscribe() should not raise an exception. subscriber: {}", subscriber, t); @@ -692,7 +654,7 @@ public void request(long n) { if (n <= 0) { final Throwable cause = new IllegalArgumentException( "n: " + n + " (expected: > 0, see Reactive Streams specification rule 3.9)"); - processor.unsubscribe(this, cause); + unsubscribe(cause); return; } @@ -726,10 +688,10 @@ private void accumulateDemand(long n) { } void signal() { - if (executor.inEventLoop()) { + if (downstreamExecutor.inEventLoop()) { doSignal(); } else { - executor.execute(this::doSignal); + downstreamExecutor.execute(this::doSignal); } } @@ -757,7 +719,7 @@ private boolean doSignalSingle(SignalQueue signals) { if (cancelledOrAborted != null) { // Stream ended due to cancellation or abortion. - processor.unsubscribe(this, cancelledOrAborted); + unsubscribe(cancelledOrAborted); return false; } @@ -771,7 +733,7 @@ private boolean doSignalSingle(SignalQueue signals) { if (signal instanceof CloseEvent) { // The stream has reached at its end. offset++; - processor.unsubscribe(this, ((CloseEvent) signal).cause); + unsubscribe(((CloseEvent) signal).cause); return false; } @@ -812,7 +774,7 @@ private boolean doSignalSingle(SignalQueue signals) { // If an exception such as IllegalReferenceCountException is raised while operating // on the ByteBuf, catch it and notify the subscriber with it. So the // subscriber does not hang forever. - processor.unsubscribe(this, thrown); + unsubscribe(thrown); return false; } @@ -832,7 +794,7 @@ private boolean doSignalSingle(SignalQueue signals) { try { subscriber.onNext(obj); } catch (Throwable t) { - processor.unsubscribe(this, t); + unsubscribe(t); throwIfFatal(t); logger.warn("Subscriber.onNext({}) should not raise an exception. subscriber: {}", obj, subscriber, t); @@ -844,6 +806,54 @@ private boolean doSignalSingle(SignalQueue signals) { } } + void unsubscribe(@Nullable Throwable cause) { + if (downstreamExecutor.inEventLoop()) { + doUnsubscribe(cause); + } else { + downstreamExecutor.execute(() -> doUnsubscribe(cause)); + } + } + + private void doUnsubscribe(@Nullable Throwable cause) { + if (!processor.downstreamSubscriptions.remove(this)) { + return; + } + + final Subscriber subscriber = this.subscriber; + clearSubscriber(); + + final CompletableFuture completionFuture = whenComplete(); + if (cause == null) { + try { + subscriber.onComplete(); + completionFuture.complete(null); + } catch (Throwable t) { + completionFuture.completeExceptionally(t); + throwIfFatal(t); + logger.warn("Subscriber.onComplete() should not raise an exception. subscriber: {}", + subscriber, t); + } finally { + processor.cleanupIfLastSubscription(); + } + return; + } + + try { + if (notifyCancellation || !(cause instanceof CancelledSubscriptionException)) { + subscriber.onError(cause); + } + completionFuture.completeExceptionally(cause); + } catch (Throwable t) { + final Exception composite = new CompositeException(t, cause); + completionFuture.completeExceptionally(composite); + throwIfFatal(t); + logger.warn("Subscriber.onError() should not raise an exception. subscriber: {}", + subscriber, composite); + } finally { + processor.cleanupIfLastSubscription(); + } + } + @Override public void cancel() { abort(subscriber instanceof AbortingSubscriber ? ((AbortingSubscriber) subscriber).cause() diff --git a/core/src/main/java/com/linecorp/armeria/common/util/TextFormatter.java b/core/src/main/java/com/linecorp/armeria/common/util/TextFormatter.java index 1b25ca8e56b..6227e497c60 100644 --- a/core/src/main/java/com/linecorp/armeria/common/util/TextFormatter.java +++ b/core/src/main/java/com/linecorp/armeria/common/util/TextFormatter.java @@ -22,10 +22,13 @@ import java.time.ZoneId; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; +import java.time.temporal.ChronoUnit; import java.util.Locale; import java.util.concurrent.TimeUnit; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.common.logging.ClientConnectionTimings; /** * A utility class to format things as a {@link String} with ease. @@ -130,11 +133,32 @@ public static void appendElapsedAndSize( appendSize(buf, size); } + /** + * Formats the given epoch time in microseconds and duration in nanos to the format + * "epochMicros[elapsedNanos]" and appends it to the specified {@link StringBuilder}. + * This may be useful to record high-resolution timings such as {@link ClientConnectionTimings}. + */ + @UnstableApi + public static void appendEpochAndElapsed(StringBuilder buf, long epochMicros, long elapsedNanos) { + buf.append(dateTimeMicrosecondFormatter.format(getInstantFromMicros(epochMicros))).append('['); + appendElapsed(buf, elapsedNanos); + buf.append(']'); + } + + private static Instant getInstantFromMicros(long microsSinceEpoch) { + return Instant.EPOCH.plus(microsSinceEpoch, ChronoUnit.MICROS); + } + private static final DateTimeFormatter dateTimeFormatter = new DateTimeFormatterBuilder().appendPattern("yyyy-MM-dd'T'HH:mm:ss.SSSX") .toFormatter(Locale.ENGLISH) .withZone(ZoneId.of("GMT")); + private static final DateTimeFormatter dateTimeMicrosecondFormatter = + new DateTimeFormatterBuilder().appendPattern("yyyy-MM-dd'T'HH:mm:ss.SSSSSSX") + .toFormatter(Locale.ENGLISH) + .withZone(ZoneId.of("GMT")); + /** * Formats the given epoch time in milliseconds to typical human-readable format * "yyyy-MM-dd'T'HH:mm:ss.SSSX". diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java index 16bb7f4c01f..f6b0478c4bb 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java @@ -70,6 +70,7 @@ import com.linecorp.armeria.common.util.TimeoutMode; import com.linecorp.armeria.common.util.UnmodifiableFuture; import com.linecorp.armeria.internal.common.CancellationScheduler; +import com.linecorp.armeria.internal.common.HeaderOverridingHttpRequest; import com.linecorp.armeria.internal.common.NonWrappingRequestContext; import com.linecorp.armeria.internal.common.RequestContextExtension; import com.linecorp.armeria.internal.common.SchemeAndAuthority; @@ -285,6 +286,11 @@ private static ExchangeType guessExchangeType(RequestOptions requestOptions, @Nu if (req instanceof FixedStreamMessage) { return ExchangeType.RESPONSE_STREAMING; } + if (req instanceof HeaderOverridingHttpRequest) { + if (((HeaderOverridingHttpRequest) req).delegate() instanceof FixedStreamMessage) { + return ExchangeType.RESPONSE_STREAMING; + } + } return ExchangeType.BIDI_STREAMING; } @@ -1013,7 +1019,7 @@ public CompletableFuture initiateConnectionShutdown() { }); // To deactivate the channel when initiateShutdown is called after the RequestHeaders is sent. // The next request will trigger shutdown. - HttpSession.get(ch).deactivate(); + HttpSession.get(ch).markUnacquirable(); } }); return completableFuture; diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/HttpSession.java b/core/src/main/java/com/linecorp/armeria/internal/client/HttpSession.java index 675400205ad..2b69160e015 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/HttpSession.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/HttpSession.java @@ -17,6 +17,7 @@ package com.linecorp.armeria.internal.client; import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.client.WriteTimeoutException; import com.linecorp.armeria.common.ClosedSessionException; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.SerializationFormat; @@ -90,7 +91,7 @@ public boolean isAcquirable(KeepAliveHandler keepAliveHandler) { } @Override - public void deactivate() {} + public void markUnacquirable() {} @Override public int incrementAndGetNumRequestsSent() { @@ -137,9 +138,10 @@ static HttpSession get(Channel ch) { *

  • A connection is closed.
  • *
  • "Connection: close" header is sent or received.
  • *
  • A GOAWAY frame is sent or received.
  • + *
  • A {@link WriteTimeoutException} is raised
  • * */ - void deactivate(); + void markUnacquirable(); /** * Returns {@code true} if a new request can be sent with this {@link HttpSession}. diff --git a/core/src/main/java/com/linecorp/armeria/common/HeaderOverridingHttpRequest.java b/core/src/main/java/com/linecorp/armeria/internal/common/HeaderOverridingHttpRequest.java similarity index 80% rename from core/src/main/java/com/linecorp/armeria/common/HeaderOverridingHttpRequest.java rename to core/src/main/java/com/linecorp/armeria/internal/common/HeaderOverridingHttpRequest.java index eaf9c244231..3d286869199 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HeaderOverridingHttpRequest.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/HeaderOverridingHttpRequest.java @@ -13,7 +13,7 @@ * License for the specific language governing permissions and limitations * under the License. */ -package com.linecorp.armeria.common; +package com.linecorp.armeria.internal.common; import static java.util.Objects.requireNonNull; @@ -25,6 +25,13 @@ import com.google.common.base.MoreObjects; +import com.linecorp.armeria.common.AggregatedHttpRequest; +import com.linecorp.armeria.common.AggregationOptions; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpObject; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.stream.SubscriptionOption; @@ -33,16 +40,30 @@ /** * An {@link HttpRequest} that overrides the {@link RequestHeaders}. */ -final class HeaderOverridingHttpRequest implements HttpRequest { +public final class HeaderOverridingHttpRequest implements HttpRequest { private final HttpRequest delegate; private final RequestHeaders headers; + public static HeaderOverridingHttpRequest of(HttpRequest delegate, RequestHeaders headers) { + requireNonNull(delegate, "delegate"); + requireNonNull(headers, "headers"); + if (delegate instanceof HeaderOverridingHttpRequest) { + return new HeaderOverridingHttpRequest( + ((HeaderOverridingHttpRequest) delegate).delegate(), headers); + } + return new HeaderOverridingHttpRequest(delegate, headers); + } + HeaderOverridingHttpRequest(HttpRequest delegate, RequestHeaders headers) { this.delegate = delegate; this.headers = headers; } + public HttpRequest delegate() { + return delegate; + } + @Override public HttpRequest withHeaders(RequestHeaders newHeaders) { requireNonNull(newHeaders, "newHeaders"); diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/util/ReentrantShortLock.java b/core/src/main/java/com/linecorp/armeria/internal/common/util/ReentrantShortLock.java index 626ea48245a..fee115a5b07 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/util/ReentrantShortLock.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/util/ReentrantShortLock.java @@ -28,6 +28,12 @@ public class ReentrantShortLock extends ReentrantLock { private static final long serialVersionUID = 8999619612996643502L; + public ReentrantShortLock() {} + + public ReentrantShortLock(boolean fair) { + super(fair); + } + @Override public void lock() { super.lock(); diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/util/SslContextUtil.java b/core/src/main/java/com/linecorp/armeria/internal/common/util/SslContextUtil.java index f78ca18c29d..2a67d63f5e4 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/util/SslContextUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/util/SslContextUtil.java @@ -37,8 +37,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.linecorp.armeria.common.Flags; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.util.TlsEngineType; import io.netty.buffer.PooledByteBufAllocator; import io.netty.handler.codec.http2.Http2SecurityUtil; @@ -97,13 +97,13 @@ public final class SslContextUtil { */ public static SslContext createSslContext( Supplier builderSupplier, boolean forceHttp1, - boolean tlsAllowUnsafeCiphers, + TlsEngineType tlsEngineType, boolean tlsAllowUnsafeCiphers, Iterable> userCustomizers, @Nullable List keyCertChainCaptor) { return MinifiedBouncyCastleProvider.call(() -> { final SslContextBuilder builder = builderSupplier.get(); - final SslProvider provider = Flags.tlsEngineType().sslProvider(); + final SslProvider provider = tlsEngineType.sslProvider(); builder.sslProvider(provider); final Set supportedProtocols = supportedProtocols(builder); diff --git a/core/src/main/java/com/linecorp/armeria/internal/server/annotation/DefaultAnnotatedService.java b/core/src/main/java/com/linecorp/armeria/internal/server/annotation/DefaultAnnotatedService.java index 14150bd2ba0..24143652354 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/server/annotation/DefaultAnnotatedService.java +++ b/core/src/main/java/com/linecorp/armeria/internal/server/annotation/DefaultAnnotatedService.java @@ -61,6 +61,9 @@ import com.linecorp.armeria.server.HttpService; import com.linecorp.armeria.server.Route; import com.linecorp.armeria.server.RoutingContext; +import com.linecorp.armeria.server.ServiceOption; +import com.linecorp.armeria.server.ServiceOptions; +import com.linecorp.armeria.server.ServiceOptionsBuilder; import com.linecorp.armeria.server.ServiceRequestContext; import com.linecorp.armeria.server.SimpleDecoratingHttpService; import com.linecorp.armeria.server.annotation.AnnotatedService; @@ -111,6 +114,8 @@ final class DefaultAnnotatedService implements AnnotatedService { @Nullable private final String name; + private final ServiceOptions options; + DefaultAnnotatedService(Object object, Method method, int overloadId, List resolvers, List exceptionHandlers, @@ -176,6 +181,16 @@ final class DefaultAnnotatedService implements AnnotatedService { this.method.setAccessible(true); // following must be called only after method.setAccessible(true) methodHandle = asMethodHandle(method, object); + + ServiceOption serviceOption = AnnotationUtil.findFirst(method, ServiceOption.class); + if (serviceOption == null) { + serviceOption = AnnotationUtil.findFirst(object.getClass(), ServiceOption.class); + } + if (serviceOption != null) { + options = buildServiceOptions(serviceOption); + } else { + options = ServiceOptions.of(); + } } private static Type getActualReturnType(Method method) { @@ -221,6 +236,20 @@ private static void warnIfHttpResponseArgumentExists(Type returnType, } } + private static ServiceOptions buildServiceOptions(ServiceOption serviceOption) { + final ServiceOptionsBuilder builder = ServiceOptions.builder(); + if (serviceOption.requestTimeoutMillis() >= 0) { + builder.requestTimeoutMillis(serviceOption.requestTimeoutMillis()); + } + if (serviceOption.maxRequestLength() >= 0) { + builder.maxRequestLength(serviceOption.maxRequestLength()); + } + if (serviceOption.requestAutoAbortDelayMillis() >= 0) { + builder.requestAutoAbortDelayMillis(serviceOption.requestAutoAbortDelayMillis()); + } + return builder.build(); + } + @Override public String name() { return name; @@ -486,6 +515,11 @@ public ExchangeType exchangeType(RoutingContext routingContext) { } } + @Override + public ServiceOptions options() { + return options; + } + /** * An {@link ExceptionHandlerFunction} which wraps a list of {@link ExceptionHandlerFunction}s. */ diff --git a/core/src/main/java/com/linecorp/armeria/internal/server/websocket/DefaultWebSocketService.java b/core/src/main/java/com/linecorp/armeria/internal/server/websocket/DefaultWebSocketService.java index 151f8c0f8e1..b3628103d30 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/server/websocket/DefaultWebSocketService.java +++ b/core/src/main/java/com/linecorp/armeria/internal/server/websocket/DefaultWebSocketService.java @@ -45,11 +45,13 @@ import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.stream.ClosedStreamException; import com.linecorp.armeria.common.stream.StreamMessage; +import com.linecorp.armeria.common.util.TimeoutMode; import com.linecorp.armeria.common.websocket.WebSocket; import com.linecorp.armeria.internal.common.websocket.WebSocketFrameEncoder; import com.linecorp.armeria.internal.common.websocket.WebSocketWrapper; import com.linecorp.armeria.server.HttpService; import com.linecorp.armeria.server.ServiceConfig; +import com.linecorp.armeria.server.ServiceOptions; import com.linecorp.armeria.server.ServiceRequestContext; import com.linecorp.armeria.server.websocket.WebSocketProtocolHandler; import com.linecorp.armeria.server.websocket.WebSocketService; @@ -97,12 +99,13 @@ public final class DefaultWebSocketService implements WebSocketService, WebSocke @Nullable private final Predicate originPredicate; private final boolean aggregateContinuation; + private final ServiceOptions serviceOptions; public DefaultWebSocketService(WebSocketServiceHandler handler, @Nullable HttpService fallbackService, int maxFramePayloadLength, boolean allowMaskMismatch, Set subprotocols, boolean allowAnyOrigin, @Nullable Predicate originPredicate, - boolean aggregateContinuation) { + boolean aggregateContinuation, ServiceOptions serviceOptions) { this.handler = handler; this.fallbackService = fallbackService; this.maxFramePayloadLength = maxFramePayloadLength; @@ -111,6 +114,7 @@ public DefaultWebSocketService(WebSocketServiceHandler handler, @Nullable HttpSe this.allowAnyOrigin = allowAnyOrigin; this.originPredicate = originPredicate; this.aggregateContinuation = aggregateContinuation; + this.serviceOptions = serviceOptions; } @Override @@ -205,6 +209,27 @@ private WebSocketUpgradeResult upgradeHttp1(ServiceRequestContext ctx, HttpReque private HttpResponse failOrFallback(ServiceRequestContext ctx, HttpRequest req, Supplier invalidResponse) throws Exception { if (fallbackService != null) { + // Try to apply ServiceOptions from fallbackService first. If not set, use the settings of the + // virtual host. + final ServiceOptions options = fallbackService.options(); + long requestTimeoutMillis = options.requestTimeoutMillis(); + if (requestTimeoutMillis < 0) { + requestTimeoutMillis = ctx.config().virtualHost().requestTimeoutMillis(); + } + ctx.setRequestTimeoutMillis(TimeoutMode.SET_FROM_START, requestTimeoutMillis); + + long maxRequestLength = options.maxRequestLength(); + if (maxRequestLength < 0) { + maxRequestLength = ctx.config().virtualHost().maxRequestLength(); + } + ctx.setMaxRequestLength(maxRequestLength); + + long requestAutoAbortDelayMillis = options.requestAutoAbortDelayMillis(); + if (requestAutoAbortDelayMillis < 0) { + requestAutoAbortDelayMillis = ctx.config().virtualHost().requestAutoAbortDelayMillis(); + } + ctx.setRequestAutoAbortDelayMillis(requestAutoAbortDelayMillis); + return fallbackService.serve(ctx, req); } else { return invalidResponse.get(); @@ -397,4 +422,9 @@ public HttpResponse encode(ServiceRequestContext ctx, WebSocket out) { public WebSocketProtocolHandler protocolHandler() { return this; } + + @Override + public ServiceOptions options() { + return serviceOptions; + } } diff --git a/core/src/main/java/com/linecorp/armeria/server/AbstractContextPathAnnotatedServiceConfigSetters.java b/core/src/main/java/com/linecorp/armeria/server/AbstractContextPathAnnotatedServiceConfigSetters.java index 55be27575ea..2fbcb36865d 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AbstractContextPathAnnotatedServiceConfigSetters.java +++ b/core/src/main/java/com/linecorp/armeria/server/AbstractContextPathAnnotatedServiceConfigSetters.java @@ -23,8 +23,7 @@ abstract class AbstractContextPathAnnotatedServiceConfigSetters , T extends AbstractContextPathServicesBuilder> - extends AbstractAnnotatedServiceConfigSetters< - AbstractContextPathAnnotatedServiceConfigSetters> { + extends AbstractAnnotatedServiceConfigSetters { private final T builder; private final Set contextPaths; @@ -42,7 +41,7 @@ abstract class AbstractContextPathAnnotatedServiceConfigSetters * If path prefix is not set then this service is registered to handle requests matching * {@code /} */ - T build(Object service) { + public T build(Object service) { requireNonNull(service, "service"); service(service); contextPaths(contextPaths); diff --git a/core/src/main/java/com/linecorp/armeria/server/AbstractContextPathServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/AbstractContextPathServiceBindingBuilder.java index 99a022bffbb..5d457bede56 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AbstractContextPathServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/AbstractContextPathServiceBindingBuilder.java @@ -21,7 +21,7 @@ abstract class AbstractContextPathServiceBindingBuilder , T extends AbstractContextPathServicesBuilder> - extends AbstractServiceBindingBuilder> { + extends AbstractServiceBindingBuilder { private final T contextPathServicesBuilder; diff --git a/core/src/main/java/com/linecorp/armeria/server/AbstractServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/AbstractServiceBindingBuilder.java index dbb3f38eafc..34edc943299 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AbstractServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/AbstractServiceBindingBuilder.java @@ -44,8 +44,7 @@ * @see VirtualHostServiceBindingBuilder */ abstract class AbstractServiceBindingBuilder> - extends AbstractBindingBuilder - implements ServiceConfigSetters> { + extends AbstractBindingBuilder implements ServiceConfigSetters { private final DefaultServiceConfigSetters defaultServiceConfigSetters = new DefaultServiceConfigSetters(); @@ -261,14 +260,4 @@ final void build0(HttpService service) { } } } - - final void build0(HttpService service, Route mappedRoute) { - final List routes = buildRouteList(ImmutableSet.of()); - assert routes.size() == 1; // Only one route is set via addRoute(). - final HttpService decoratedService = defaultServiceConfigSetters.decorator().apply(service); - final ServiceConfigBuilder serviceConfigBuilder = - defaultServiceConfigSetters.toServiceConfigBuilder(routes.get(0), "/", decoratedService); - serviceConfigBuilder.addMappedRoute(mappedRoute); - serviceConfigBuilder(serviceConfigBuilder); - } } diff --git a/core/src/main/java/com/linecorp/armeria/server/ConnectionLimitingHandler.java b/core/src/main/java/com/linecorp/armeria/server/ConnectionLimitingHandler.java index 4167a7a7e3e..e750c42972e 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ConnectionLimitingHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/ConnectionLimitingHandler.java @@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.LongAdder; import org.slf4j.Logger; @@ -47,29 +46,30 @@ final class ConnectionLimitingHandler extends ChannelInboundHandlerAdapter { private final Set childChannels = Collections.newSetFromMap(new ConcurrentHashMap<>()); private final Set unmodifiableChildChannels = Collections.unmodifiableSet(childChannels); private final int maxNumConnections; - private final AtomicInteger numConnections = new AtomicInteger(); + private final ServerMetrics serverMetrics; private final AtomicBoolean loggingScheduled = new AtomicBoolean(); private final LongAdder numDroppedConnections = new LongAdder(); - ConnectionLimitingHandler(int maxNumConnections) { + ConnectionLimitingHandler(int maxNumConnections, ServerMetrics serverMetrics) { this.maxNumConnections = validateMaxNumConnections(maxNumConnections); + this.serverMetrics = serverMetrics; } @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { final Channel child = (Channel) msg; - final int conn = numConnections.incrementAndGet(); + final int conn = serverMetrics.increaseActiveConnectionsAndGet(); if (conn > 0 && conn <= maxNumConnections) { childChannels.add(child); child.closeFuture().addListener(future -> { childChannels.remove(child); - numConnections.decrementAndGet(); + serverMetrics.decreaseActiveConnections(); }); super.channelRead(ctx, msg); } else { - numConnections.decrementAndGet(); + serverMetrics.decreaseActiveConnections(); // Set linger option to 0 so that the server doesn't get too many TIME_WAIT states. child.config().setOption(ChannelOption.SO_LINGER, 0); @@ -104,7 +104,7 @@ public int maxNumConnections() { * Returns the number of open connections. */ public int numConnections() { - return numConnections.get(); + return serverMetrics.activeConnections(); } /** diff --git a/core/src/main/java/com/linecorp/armeria/server/CorsServerErrorHandler.java b/core/src/main/java/com/linecorp/armeria/server/CorsServerErrorHandler.java index 94c3a8baad8..d8ea993a6f1 100644 --- a/core/src/main/java/com/linecorp/armeria/server/CorsServerErrorHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/CorsServerErrorHandler.java @@ -43,12 +43,13 @@ final class CorsServerErrorHandler implements ServerErrorHandler { this.serverErrorHandler = serverErrorHandler; } + @Nullable @Override - public @Nullable AggregatedHttpResponse renderStatus(@Nullable ServiceRequestContext ctx, - ServiceConfig serviceConfig, - @Nullable RequestHeaders headers, - HttpStatus status, @Nullable String description, - @Nullable Throwable cause) { + public AggregatedHttpResponse renderStatus(@Nullable ServiceRequestContext ctx, + ServiceConfig serviceConfig, + @Nullable RequestHeaders headers, + HttpStatus status, @Nullable String description, + @Nullable Throwable cause) { if (ctx == null) { return serverErrorHandler.renderStatus(null, serviceConfig, headers, status, description, cause); @@ -73,8 +74,9 @@ final class CorsServerErrorHandler implements ServerErrorHandler { return AggregatedHttpResponse.of(updatedResponseHeaders, res.content()); } + @Nullable @Override - public @Nullable HttpResponse onServiceException(ServiceRequestContext ctx, Throwable cause) { + public HttpResponse onServiceException(ServiceRequestContext ctx, Throwable cause) { if (cause instanceof HttpResponseException) { final HttpResponse oldRes = serverErrorHandler.onServiceException(ctx, cause); if (oldRes == null) { @@ -84,14 +86,21 @@ final class CorsServerErrorHandler implements ServerErrorHandler { if (corsService == null) { return oldRes; } - return oldRes - .recover(HttpResponseException.class, - ex -> ex.httpResponse() - .mapHeaders(oldHeaders -> addCorsHeaders(ctx, - corsService.config(), - oldHeaders))); + return oldRes.recover(HttpResponseException.class, ex -> { + return ex.httpResponse() + .mapHeaders(oldHeaders -> addCorsHeaders(ctx, corsService.config(), oldHeaders)); + }); } else { return serverErrorHandler.onServiceException(ctx, cause); } } + + @Nullable + @Override + public AggregatedHttpResponse onProtocolViolation(ServiceConfig config, + @Nullable RequestHeaders headers, + HttpStatus status, @Nullable String description, + @Nullable Throwable cause) { + return serverErrorHandler.onProtocolViolation(config, headers, status, description, cause); + } } diff --git a/core/src/main/java/com/linecorp/armeria/server/DefaultServerConfig.java b/core/src/main/java/com/linecorp/armeria/server/DefaultServerConfig.java index ed2296bdc1e..02cdad33d55 100644 --- a/core/src/main/java/com/linecorp/armeria/server/DefaultServerConfig.java +++ b/core/src/main/java/com/linecorp/armeria/server/DefaultServerConfig.java @@ -119,6 +119,7 @@ final class DefaultServerConfig implements ServerConfig { @Nullable private final Mapping sslContexts; + private final ServerMetrics serverMetrics = new ServerMetrics(); @Nullable private String strVal; @@ -681,6 +682,11 @@ public long unloggedExceptionsReportIntervalMillis() { return unloggedExceptionsReportIntervalMillis; } + @Override + public ServerMetrics serverMetrics() { + return serverMetrics; + } + List shutdownSupports() { return shutdownSupports; } @@ -702,7 +708,8 @@ public String toString() { clientAddressSources(), clientAddressTrustedProxyFilter(), clientAddressFilter(), clientAddressMapper(), isServerHeaderEnabled(), isDateHeaderEnabled(), - dependencyInjector(), absoluteUriTransformer(), unloggedExceptionsReportIntervalMillis()); + dependencyInjector(), absoluteUriTransformer(), unloggedExceptionsReportIntervalMillis(), + serverMetrics()); } return strVal; @@ -727,8 +734,8 @@ static String toString( boolean serverHeaderEnabled, boolean dateHeaderEnabled, @Nullable DependencyInjector dependencyInjector, Function absoluteUriTransformer, - long unloggedExceptionsReportIntervalMillis) { - + long unloggedExceptionsReportIntervalMillis, + ServerMetrics serverMetrics) { final StringBuilder buf = new StringBuilder(); if (type != null) { buf.append(type.getSimpleName()); @@ -828,6 +835,8 @@ static String toString( buf.append(absoluteUriTransformer); buf.append(", unloggedExceptionsReportIntervalMillis: "); buf.append(unloggedExceptionsReportIntervalMillis); + buf.append(", serverMetrics: "); + buf.append(serverMetrics); buf.append(')'); return buf.toString(); diff --git a/core/src/main/java/com/linecorp/armeria/server/DefaultServerErrorHandler.java b/core/src/main/java/com/linecorp/armeria/server/DefaultServerErrorHandler.java index 740a72e3148..8e8d5c4dd74 100644 --- a/core/src/main/java/com/linecorp/armeria/server/DefaultServerErrorHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/DefaultServerErrorHandler.java @@ -19,9 +19,6 @@ import javax.annotation.Nonnull; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.linecorp.armeria.common.AggregatedHttpResponse; import com.linecorp.armeria.common.ContentTooLargeException; import com.linecorp.armeria.common.HttpData; @@ -50,8 +47,6 @@ enum DefaultServerErrorHandler implements ServerErrorHandler { INSTANCE; - private static final Logger logger = LoggerFactory.getLogger(DefaultServerErrorHandler.class); - /** * Converts the specified {@link Throwable} to an {@link HttpResponse}. */ diff --git a/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java b/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java index edfea9224ae..cc35c9d52f3 100644 --- a/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java @@ -261,7 +261,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception eventLoop, id, 1, headers, false, inboundTrafficController, serviceConfig.maxRequestLength(), routingCtx, ExchangeType.BIDI_STREAMING, - System.nanoTime(), SystemInfo.currentTimeMicros(), true, false); + System.nanoTime(), SystemInfo.currentTimeMicros(), true, false + ); assert encoder instanceof ServerHttp1ObjectEncoder; ((ServerHttp1ObjectEncoder) encoder).webSocketUpgrading(); final ChannelPipeline pipeline = ctx.pipeline(); @@ -270,6 +271,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception if (pipeline.get(HttpServerUpgradeHandler.class) != null) { pipeline.remove(HttpServerUpgradeHandler.class); } + cfg.serverMetrics().increasePendingHttp1Requests(); ctx.fireChannelRead(webSocketRequest); return; } @@ -278,7 +280,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception final boolean endOfStream = contentEmpty && !transferEncodingChunked; this.req = req = DecodedHttpRequest.of(endOfStream, eventLoop, id, 1, headers, keepAlive, inboundTrafficController, routingCtx); - + cfg.serverMetrics().increasePendingHttp1Requests(); ctx.fireChannelRead(req); } else { fail(id, null, HttpStatus.BAD_REQUEST, "Invalid decoder state", null); diff --git a/core/src/main/java/com/linecorp/armeria/server/Http2RequestDecoder.java b/core/src/main/java/com/linecorp/armeria/server/Http2RequestDecoder.java index 785cf0a087a..a68207501ad 100644 --- a/core/src/main/java/com/linecorp/armeria/server/Http2RequestDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/server/Http2RequestDecoder.java @@ -207,6 +207,7 @@ public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers req = DecodedHttpRequest.of(endOfStream, eventLoop, id, streamId, headers, true, inboundTrafficController, routingCtx); requests.put(streamId, req); + cfg.serverMetrics().increasePendingHttp2Requests(); ctx.fireChannelRead(req); } else { if (!(req instanceof DecodedHttpRequestWriter)) { diff --git a/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java b/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java index 777a5433843..30029c87c7c 100644 --- a/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java @@ -329,6 +329,8 @@ private void handleRequest(ChannelHandlerContext ctx, DecodedHttpRequest req) th // Ignore the request received after the last request, // because we are going to close the connection after sending the last response. if (handledLastRequest) { + req.abort(); + decreasePendingRequests(); return; } @@ -357,6 +359,7 @@ private void handleRequest(ChannelHandlerContext ctx, DecodedHttpRequest req) th // Handle 'OPTIONS * HTTP/1.1'. if (routingStatus == RoutingStatus.OPTIONS) { handleOptions(ctx, reqCtx); + decreasePendingRequests(); return; } @@ -388,23 +391,25 @@ private void handleRequest(ChannelHandlerContext ctx, DecodedHttpRequest req) th req.init(reqCtx); final CompletableFuture whenAggregated = req.whenAggregated(); if (whenAggregated != null) { - res = HttpResponse.of(req.whenAggregated().thenApply(ignored -> { + res = HttpResponse.of(whenAggregated.thenApply(ignored -> { if (serviceEventLoop.inEventLoop()) { - return serve0(req, service, reqCtx); + return serve0(req, service, reqCtx, req.isHttp1WebSocket()); } - return serveInServiceEventLoop(req, service, reqCtx, serviceEventLoop); + return serveInServiceEventLoop(req, service, reqCtx, serviceEventLoop, req.isHttp1WebSocket()); })); } else { if (serviceEventLoop.inEventLoop()) { - res = serve0(req, service, reqCtx); + res = serve0(req, service, reqCtx, req.isHttp1WebSocket()); } else { - res = serveInServiceEventLoop(req, service, reqCtx, serviceEventLoop); + res = serveInServiceEventLoop(req, service, reqCtx, serviceEventLoop, req.isHttp1WebSocket()); } } res = res.recover(cause -> { reqCtx.logBuilder().responseCause(cause); // Recover the failed response with the error handler. - return serviceCfg.errorHandler().onServiceException(reqCtx, cause); + try (SafeCloseable ignored = reqCtx.push()) { + return serviceCfg.errorHandler().onServiceException(reqCtx, cause); + } }); // Keep track of the number of unfinished requests and @@ -456,12 +461,33 @@ private void handleRequest(ChannelHandlerContext ctx, DecodedHttpRequest req) th } } - private static HttpResponse serve0(HttpRequest req, HttpService service, - DefaultServiceRequestContext reqCtx) { + private void decreasePendingRequests() { + if (protocol.isExplicitHttp1()) { + config.serverMetrics().decreasePendingHttp1Requests(); + } else { + assert protocol.isExplicitHttp2(); + config.serverMetrics().decreasePendingHttp2Requests(); + } + } + + private void increaseActiveRequests(boolean isHttp1WebSocket) { + if (isHttp1WebSocket) { + config.serverMetrics().increaseActiveHttp1WebSocketRequests(); + } else if (protocol.isExplicitHttp1()) { + config.serverMetrics().increaseActiveHttp1Requests(); + } else { + assert protocol.isExplicitHttp2(); + config.serverMetrics().increaseActiveHttp2Requests(); + } + } + + private HttpResponse serve0(HttpRequest req, HttpService service, DefaultServiceRequestContext reqCtx, + boolean isHttp1WebSocket) { try (SafeCloseable ignored = reqCtx.push()) { - HttpResponse serviceResponse; try { - serviceResponse = service.serve(reqCtx, req); + decreasePendingRequests(); + increaseActiveRequests(isHttp1WebSocket); + return service.serve(reqCtx, req); } catch (Throwable cause) { // No need to consume further since the response is ready. if (cause instanceof HttpResponseException || cause instanceof HttpStatusException) { @@ -469,18 +495,18 @@ private static HttpResponse serve0(HttpRequest req, HttpService service, } else { req.abort(cause); } - serviceResponse = HttpResponse.ofFailure(cause); + return HttpResponse.ofFailure(cause); } - - return serviceResponse; } } - private static HttpResponse serveInServiceEventLoop(DecodedHttpRequest req, - HttpService service, - DefaultServiceRequestContext reqCtx, - EventLoop serviceEventLoop) { - return HttpResponse.of(() -> serve0(req.subscribeOn(serviceEventLoop), service, reqCtx), + private HttpResponse serveInServiceEventLoop(DecodedHttpRequest req, + HttpService service, + DefaultServiceRequestContext reqCtx, + EventLoop serviceEventLoop, + boolean isHttp1WebSocket) { + return HttpResponse.of(() -> serve0(req.subscribeOn(serviceEventLoop), service, + reqCtx, isHttp1WebSocket), serviceEventLoop) .subscribeOn(serviceEventLoop); } @@ -767,6 +793,14 @@ private void handleRequestOrResponseComplete() { requestOrResponseComplete = true; return; } + if (req.isHttp1WebSocket()) { + config.serverMetrics().decreaseActiveHttp1WebSocketRequests(); + } else if (protocol.isExplicitHttp1()) { + config.serverMetrics().decreaseActiveHttp1Requests(); + } else if (protocol.isExplicitHttp2()) { + config.serverMetrics().decreaseActiveHttp2Requests(); + } + // NB: logBuilder.endResponse() is called by HttpResponseSubscriber. if (!isTransientService) { gracefulShutdownSupport.dec(); diff --git a/core/src/main/java/com/linecorp/armeria/server/HttpServerPipelineConfigurator.java b/core/src/main/java/com/linecorp/armeria/server/HttpServerPipelineConfigurator.java index 85c9c84551c..2c23dc5f357 100644 --- a/core/src/main/java/com/linecorp/armeria/server/HttpServerPipelineConfigurator.java +++ b/core/src/main/java/com/linecorp/armeria/server/HttpServerPipelineConfigurator.java @@ -45,6 +45,7 @@ import com.google.common.collect.ImmutableList; +import com.linecorp.armeria.common.Flags; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.metric.MoreMeters; @@ -109,7 +110,6 @@ final class HttpServerPipelineConfigurator extends ChannelInitializer { private static final Logger logger = LoggerFactory.getLogger(HttpServerPipelineConfigurator.class); private static final int SSL_RECORD_HEADER_LENGTH = 5; - private static final int MAX_CLIENT_HELLO_LENGTH = 4096; // 4KiB should be more than enough. static final AsciiString SCHEME_HTTP = AsciiString.cached("http"); static final AsciiString SCHEME_HTTPS = AsciiString.cached("https"); @@ -232,7 +232,7 @@ private Timer newKeepAliveTimer(SessionProtocol protocol) { private void configureHttps(ChannelPipeline p, @Nullable ProxiedAddresses proxiedAddresses) { final Mapping sslContexts = requireNonNull(config.sslContextMapping(), "config.sslContextMapping() returned null"); - p.addLast(new SniHandler(sslContexts, MAX_CLIENT_HELLO_LENGTH, config.idleTimeoutMillis())); + p.addLast(new SniHandler(sslContexts, Flags.defaultMaxClientHelloLength(), config.idleTimeoutMillis())); p.addLast(TrafficLoggingHandler.SERVER); p.addLast(new Http2OrHttpHandler(proxiedAddresses)); } diff --git a/core/src/main/java/com/linecorp/armeria/server/HttpService.java b/core/src/main/java/com/linecorp/armeria/server/HttpService.java index 7a783c6f37a..c5714cf85f7 100644 --- a/core/src/main/java/com/linecorp/armeria/server/HttpService.java +++ b/core/src/main/java/com/linecorp/armeria/server/HttpService.java @@ -71,4 +71,12 @@ default HttpService decorate(DecoratingHttpServiceFunction function) { default ExchangeType exchangeType(RoutingContext routingContext) { return ExchangeType.BIDI_STREAMING; } + + /** + * Returns the {@link ServiceOptions} of this {@link HttpService}. + */ + @UnstableApi + default ServiceOptions options() { + return ServiceOptions.of(); + } } diff --git a/core/src/main/java/com/linecorp/armeria/server/Server.java b/core/src/main/java/com/linecorp/armeria/server/Server.java index 2123ee9b14c..ce7636bc6ac 100644 --- a/core/src/main/java/com/linecorp/armeria/server/Server.java +++ b/core/src/main/java/com/linecorp/armeria/server/Server.java @@ -73,6 +73,7 @@ import com.linecorp.armeria.common.util.ListenableAsyncCloseable; import com.linecorp.armeria.common.util.ShutdownHooks; import com.linecorp.armeria.common.util.StartStopSupport; +import com.linecorp.armeria.common.util.TlsEngineType; import com.linecorp.armeria.common.util.TransportType; import com.linecorp.armeria.common.util.Version; import com.linecorp.armeria.internal.common.RequestTargetCache; @@ -130,7 +131,8 @@ public static ServerBuilder builder() { serverConfig.setServer(this); config = new UpdatableServerConfig(requireNonNull(serverConfig, "serverConfig")); startStop = new ServerStartStopSupport(config.startStopExecutor()); - connectionLimitingHandler = new ConnectionLimitingHandler(config.maxNumConnections()); + connectionLimitingHandler = new ConnectionLimitingHandler(config.maxNumConnections(), + config.serverMetrics()); // Server-wide metrics. RequestTargetCache.registerServerMetrics(config.meterRegistry()); @@ -138,7 +140,9 @@ public static ServerBuilder builder() { for (VirtualHost virtualHost : config().virtualHosts()) { if (virtualHost.sslContext() != null) { - setupTlsMetrics(virtualHost.sslContext(), virtualHost.hostnamePattern()); + assert virtualHost.tlsEngineType() != null; + setupTlsMetrics(virtualHost.sslContext(), virtualHost.tlsEngineType(), + virtualHost.hostnamePattern()); } } @@ -417,10 +421,10 @@ void setupVersionMetrics() { /** * Sets up gauge metric for each server certificate. */ - private void setupTlsMetrics(SslContext sslContext, String hostnamePattern) { + private void setupTlsMetrics(SslContext sslContext, TlsEngineType tlsEngineType, String hostnamePattern) { final MeterRegistry meterRegistry = config().meterRegistry(); - final SSLSession sslSession = validateSslContext(sslContext); + final SSLSession sslSession = validateSslContext(sslContext, tlsEngineType); final MeterIdPrefix meterIdPrefix = new MeterIdPrefix("armeria.server", "hostname.pattern", hostnamePattern); for (Certificate certificate : sslSession.getLocalCertificates()) { @@ -574,14 +578,13 @@ private ChannelFuture doStart(ServerPort port) { } private void setupServerMetrics() { - final MeterRegistry meterRegistry = config().meterRegistry(); + final MeterRegistry meterRegistry = config.meterRegistry(); final GracefulShutdownSupport gracefulShutdownSupport = this.gracefulShutdownSupport; assert gracefulShutdownSupport != null; meterRegistry.gauge("armeria.server.pending.responses", gracefulShutdownSupport, GracefulShutdownSupport::pendingResponses); - meterRegistry.gauge("armeria.server.connections", connectionLimitingHandler, - ConnectionLimitingHandler::numConnections); + config.serverMetrics().bindTo(meterRegistry); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java index 7bd5a932592..5674d209f67 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java @@ -249,6 +249,7 @@ public final class ServerBuilder implements TlsSetters, ServiceConfigsBuilder LoggerFactory.getLogger(defaultAccessLoggerName(host.hostnamePattern()))); virtualHostTemplate.tlsSelfSigned(false); virtualHostTemplate.tlsAllowUnsafeCiphers(false); + virtualHostTemplate.tlsEngineType(Flags.tlsEngineType()); virtualHostTemplate.annotatedServiceExtensions(ImmutableList.of(), ImmutableList.of(), ImmutableList.of()); virtualHostTemplate.blockingTaskExecutor(CommonPools.blockingTaskExecutor(), false); @@ -1193,6 +1194,17 @@ public ServerBuilder tlsAllowUnsafeCiphers(boolean tlsAllowUnsafeCiphers) { return this; } + /** + * Sets {@link TlsEngineType} that will be used for processing TLS connections. + * + * @param tlsEngineType the {@link TlsEngineType} to use + */ + @UnstableApi + public ServerBuilder tlsEngineType(TlsEngineType tlsEngineType) { + virtualHostTemplate.tlsEngineType(tlsEngineType); + return this; + } + /** * Returns a {@link ContextPathServicesBuilder} which binds {@link HttpService}s under the * specified context paths. diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerConfig.java b/core/src/main/java/com/linecorp/armeria/server/ServerConfig.java index 3115cdff6f0..9e3de9a8e26 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerConfig.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerConfig.java @@ -347,4 +347,9 @@ default boolean shutdownBlockingTaskExecutorOnStop() { * Returns the interval between reporting unlogged exceptions in milliseconds. */ long unloggedExceptionsReportIntervalMillis(); + + /** + * Returns the {@link ServerMetrics} that collects metrics related server. + */ + ServerMetrics serverMetrics(); } diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerMetrics.java b/core/src/main/java/com/linecorp/armeria/server/ServerMetrics.java new file mode 100644 index 00000000000..9a57bf65888 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/server/ServerMetrics.java @@ -0,0 +1,190 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.LongAdder; + +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; + +import com.linecorp.armeria.common.annotation.UnstableApi; + +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.binder.MeterBinder; + +/** + * A class that holds metrics related server. + */ +@UnstableApi +public final class ServerMetrics implements MeterBinder { + + private final LongAdder pendingHttp1Requests = new LongAdder(); + private final LongAdder pendingHttp2Requests = new LongAdder(); + private final LongAdder activeHttp1WebSocketRequests = new LongAdder(); + private final LongAdder activeHttp1Requests = new LongAdder(); + private final LongAdder activeHttp2Requests = new LongAdder(); + + /** + * AtomicInteger is used to read the number of active connections frequently. + */ + private final AtomicInteger activeConnections = new AtomicInteger(); + + ServerMetrics() {} + + /** + * Returns the number of all pending requests. + */ + public long pendingRequests() { + return pendingHttp1Requests() + pendingHttp2Requests(); + } + + /** + * Returns the number of pending http1 requests. + */ + public long pendingHttp1Requests() { + return pendingHttp1Requests.longValue(); + } + + /** + * Returns the number of pending http2 requests. + */ + public long pendingHttp2Requests() { + return pendingHttp2Requests.longValue(); + } + + /** + * Returns the number of all active requests. + */ + public long activeRequests() { + return activeHttp1WebSocketRequests() + + activeHttp1Requests() + + activeHttp2Requests(); + } + + /** + * Returns the number of active http1 web socket requests. + */ + public long activeHttp1WebSocketRequests() { + return activeHttp1WebSocketRequests.longValue(); + } + + /** + * Returns the number of active http1 requests. + */ + public long activeHttp1Requests() { + return activeHttp1Requests.longValue(); + } + + /** + * Returns the number of active http2 requests. + */ + public long activeHttp2Requests() { + return activeHttp2Requests.longValue(); + } + + /** + * Returns the number of open connections. + */ + public int activeConnections() { + return activeConnections.get(); + } + + void increasePendingHttp1Requests() { + pendingHttp1Requests.increment(); + } + + void decreasePendingHttp1Requests() { + pendingHttp1Requests.decrement(); + } + + void increasePendingHttp2Requests() { + pendingHttp2Requests.increment(); + } + + void decreasePendingHttp2Requests() { + pendingHttp2Requests.decrement(); + } + + void increaseActiveHttp1Requests() { + activeHttp1Requests.increment(); + } + + void decreaseActiveHttp1Requests() { + activeHttp1Requests.decrement(); + } + + void increaseActiveHttp1WebSocketRequests() { + activeHttp1WebSocketRequests.increment(); + } + + void decreaseActiveHttp1WebSocketRequests() { + activeHttp1WebSocketRequests.decrement(); + } + + void increaseActiveHttp2Requests() { + activeHttp2Requests.increment(); + } + + void decreaseActiveHttp2Requests() { + activeHttp2Requests.decrement(); + } + + int increaseActiveConnectionsAndGet() { + return activeConnections.incrementAndGet(); + } + + void decreaseActiveConnections() { + activeConnections.decrementAndGet(); + } + + @Override + public void bindTo(MeterRegistry meterRegistry) { + meterRegistry.gauge("armeria.server.connections", activeConnections); + // pending requests + final String allRequestsMeterName = "armeria.server.all.requests"; + meterRegistry.gauge(allRequestsMeterName, + ImmutableList.of(Tag.of("protocol", "http1"), Tag.of("state", "pending")), + pendingHttp1Requests); + meterRegistry.gauge(allRequestsMeterName, + ImmutableList.of(Tag.of("protocol", "http2"), Tag.of("state", "pending")), + pendingHttp2Requests); + // Active requests + meterRegistry.gauge(allRequestsMeterName, + ImmutableList.of(Tag.of("protocol", "http1"), Tag.of("state", "active")), + activeHttp1Requests); + meterRegistry.gauge(allRequestsMeterName, + ImmutableList.of(Tag.of("protocol", "http2"), Tag.of("state", "active")), + activeHttp2Requests); + meterRegistry.gauge(allRequestsMeterName, + ImmutableList.of(Tag.of("protocol", "http1.websocket"), Tag.of("state", "active")), + activeHttp1WebSocketRequests); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("pendingHttp1Requests", pendingHttp1Requests) + .add("activeHttp1WebSocketRequests", activeHttp1WebSocketRequests) + .add("activeHttp1Requests", activeHttp1Requests) + .add("pendingHttp2Requests", pendingHttp2Requests) + .add("activeHttp2Requests", activeHttp2Requests) + .add("activeConnections", activeConnections) + .toString(); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerSslContextUtil.java b/core/src/main/java/com/linecorp/armeria/server/ServerSslContextUtil.java index a7570258bb0..0d3c080b5dc 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerSslContextUtil.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerSslContextUtil.java @@ -28,6 +28,7 @@ import com.google.common.collect.ImmutableList; +import com.linecorp.armeria.common.util.TlsEngineType; import com.linecorp.armeria.internal.common.util.SslContextUtil; import io.netty.buffer.ByteBufAllocator; @@ -46,7 +47,7 @@ final class ServerSslContextUtil { * key store password is not given to key store when {@link SslContext} was created using * {@link KeyManagerFactory}, the validation will fail and an {@link IllegalStateException} will be raised. */ - static SSLSession validateSslContext(SslContext sslContext) { + static SSLSession validateSslContext(SslContext sslContext, TlsEngineType tlsEngineType) { if (!sslContext.isServer()) { throw new IllegalArgumentException("sslContext: " + sslContext + " (expected: server context)"); } @@ -64,7 +65,7 @@ static SSLSession validateSslContext(SslContext sslContext) { final SslContext sslContextClient = buildSslContext(() -> SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE), - true, ImmutableList.of()); + tlsEngineType, true, ImmutableList.of()); clientEngine = sslContextClient.newEngine(ByteBufAllocator.DEFAULT); clientEngine.setUseClientMode(true); clientEngine.setEnabledProtocols(clientEngine.getSupportedProtocols()); @@ -96,11 +97,12 @@ static SSLSession validateSslContext(SslContext sslContext) { static SslContext buildSslContext( Supplier sslContextBuilderSupplier, + TlsEngineType tlsEngineType, boolean tlsAllowUnsafeCiphers, Iterable> tlsCustomizers) { return SslContextUtil - .createSslContext(sslContextBuilderSupplier, - /* forceHttp1 */ false, tlsAllowUnsafeCiphers, tlsCustomizers, null); + .createSslContext(sslContextBuilderSupplier,/* forceHttp1 */ false, tlsEngineType, + tlsAllowUnsafeCiphers, tlsCustomizers, null); } private static void unwrap(SSLEngine engine, ByteBuffer packetBuf) throws SSLException { diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java index 2c4911fc85a..993bb1c623f 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java @@ -35,7 +35,6 @@ import com.google.common.base.MoreObjects; import com.google.common.collect.ImmutableList; -import com.linecorp.armeria.common.Flags; import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpHeadersBuilder; import com.linecorp.armeria.common.RequestId; @@ -43,8 +42,6 @@ import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.common.util.EventLoopGroups; -import com.linecorp.armeria.internal.common.websocket.WebSocketUtil; -import com.linecorp.armeria.internal.server.websocket.DefaultWebSocketService; import com.linecorp.armeria.server.logging.AccessLogWriter; import io.netty.channel.EventLoopGroup; @@ -93,6 +90,17 @@ final class ServiceConfigBuilder implements ServiceConfigSetters mergedContextHook = mergeHooks(contextHook, this.contextHook); diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceOption.java b/core/src/main/java/com/linecorp/armeria/server/ServiceOption.java new file mode 100644 index 00000000000..b479fb24d3f --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceOption.java @@ -0,0 +1,49 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; + +/** + * An annotation used to configure {@link ServiceOptions} of an {@link HttpService}. + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ ElementType.METHOD, ElementType.TYPE }) +public @interface ServiceOption { + + /** + * Server-side timeout of a request in milliseconds. + */ + long requestTimeoutMillis() default -1; + + /** + * Server-side maximum length of a request. + */ + long maxRequestLength() default -1; + + /** + * The amount of time to wait before aborting an {@link HttpRequest} when its corresponding + * {@link HttpResponse} is complete. + */ + long requestAutoAbortDelayMillis() default -1; +} diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceOptions.java b/core/src/main/java/com/linecorp/armeria/server/ServiceOptions.java new file mode 100644 index 00000000000..75588db259e --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceOptions.java @@ -0,0 +1,111 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server; + +import java.util.Objects; + +import com.google.common.base.MoreObjects; + +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * Options for configuring an {@link HttpService}. + * You can override the default options by implementing {@link HttpService#options()}. + */ +@UnstableApi +public final class ServiceOptions { + private static final ServiceOptions DEFAULT_OPTIONS = builder().build(); + + /** + * Returns the default {@link ServiceOptions}. + */ + public static ServiceOptions of() { + return DEFAULT_OPTIONS; + } + + /** + * Returns a new {@link ServiceOptionsBuilder}. + */ + public static ServiceOptionsBuilder builder() { + return new ServiceOptionsBuilder(); + } + + private final long requestTimeoutMillis; + private final long maxRequestLength; + private final long requestAutoAbortDelayMillis; + + ServiceOptions(long requestTimeoutMillis, long maxRequestLength, long requestAutoAbortDelayMillis) { + this.requestTimeoutMillis = requestTimeoutMillis; + this.maxRequestLength = maxRequestLength; + this.requestAutoAbortDelayMillis = requestAutoAbortDelayMillis; + } + + /** + * Returns the server-side timeout of a request in milliseconds. {@code -1} if not set. + */ + public long requestTimeoutMillis() { + return requestTimeoutMillis; + } + + /** + * Returns the server-side maximum length of a request. {@code -1} if not set. + */ + public long maxRequestLength() { + return maxRequestLength; + } + + /** + * Returns the amount of time to wait before aborting an {@link HttpRequest} when its corresponding + * {@link HttpResponse} is complete. {@code -1} if not set. + */ + public long requestAutoAbortDelayMillis() { + return requestAutoAbortDelayMillis; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + final ServiceOptions that = (ServiceOptions) o; + + return requestTimeoutMillis == that.requestTimeoutMillis && + maxRequestLength == that.maxRequestLength && + requestAutoAbortDelayMillis == that.requestAutoAbortDelayMillis; + } + + @Override + public int hashCode() { + return Objects.hash(requestTimeoutMillis, maxRequestLength, requestAutoAbortDelayMillis); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("requestTimeoutMillis", requestTimeoutMillis) + .add("maxRequestLength", maxRequestLength) + .add("requestAutoAbortDelayMillis", requestAutoAbortDelayMillis) + .toString(); + } +} + diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceOptionsBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServiceOptionsBuilder.java new file mode 100644 index 00000000000..463b0412832 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceOptionsBuilder.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * Creates a new {@link ServiceOptions} with the specified parameters. + */ +@UnstableApi +public final class ServiceOptionsBuilder { + private long requestTimeoutMillis = -1; + private long maxRequestLength = -1; + private long requestAutoAbortDelayMillis = -1; + + ServiceOptionsBuilder() {} + + /** + * Returns the server-side timeout of a request in milliseconds. + */ + public ServiceOptionsBuilder requestTimeoutMillis(long requestTimeoutMillis) { + checkArgument(requestTimeoutMillis >= 0, "requestTimeoutMillis: %s (expected: >= 0)", + requestTimeoutMillis); + this.requestTimeoutMillis = requestTimeoutMillis; + return this; + } + + /** + * Returns the server-side maximum length of a request. + */ + public ServiceOptionsBuilder maxRequestLength(long maxRequestLength) { + checkArgument(maxRequestLength >= 0, "maxRequestLength: %s (expected: >= 0)", maxRequestLength); + this.maxRequestLength = maxRequestLength; + return this; + } + + /** + * Sets the amount of time to wait before aborting an {@link HttpRequest} when its corresponding + * {@link HttpResponse} is complete. + */ + public ServiceOptionsBuilder requestAutoAbortDelayMillis(long requestAutoAbortDelayMillis) { + checkArgument(requestAutoAbortDelayMillis >= 0, "requestAutoAbortDelayMillis: %s (expected: >= 0)", + requestAutoAbortDelayMillis); + this.requestAutoAbortDelayMillis = requestAutoAbortDelayMillis; + return this; + } + + /** + * Returns a newly created {@link ServiceOptions} based on the properties of this builder. + */ + public ServiceOptions build() { + return new ServiceOptions(requestTimeoutMillis, maxRequestLength, requestAutoAbortDelayMillis); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/server/SimpleDecoratingHttpService.java b/core/src/main/java/com/linecorp/armeria/server/SimpleDecoratingHttpService.java index ff60d0c81aa..3327c478871 100644 --- a/core/src/main/java/com/linecorp/armeria/server/SimpleDecoratingHttpService.java +++ b/core/src/main/java/com/linecorp/armeria/server/SimpleDecoratingHttpService.java @@ -39,4 +39,9 @@ protected SimpleDecoratingHttpService(HttpService delegate) { public ExchangeType exchangeType(RoutingContext routingContext) { return ((HttpService) unwrap()).exchangeType(routingContext); } + + @Override + public ServiceOptions options() { + return ((HttpService) unwrap()).options(); + } } diff --git a/core/src/main/java/com/linecorp/armeria/server/UpdatableServerConfig.java b/core/src/main/java/com/linecorp/armeria/server/UpdatableServerConfig.java index 54f1ea7a068..1b82f01a805 100644 --- a/core/src/main/java/com/linecorp/armeria/server/UpdatableServerConfig.java +++ b/core/src/main/java/com/linecorp/armeria/server/UpdatableServerConfig.java @@ -320,6 +320,11 @@ public long unloggedExceptionsReportIntervalMillis() { return delegate.unloggedExceptionsReportIntervalMillis(); } + @Override + public ServerMetrics serverMetrics() { + return delegate.serverMetrics(); + } + @Override public String toString() { return delegate.toString(); diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java index 5f4eba56736..72559bd1520 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java @@ -45,6 +45,7 @@ import com.linecorp.armeria.common.logging.RequestLogBuilder; import com.linecorp.armeria.common.metric.MeterIdPrefix; import com.linecorp.armeria.common.util.BlockingTaskExecutor; +import com.linecorp.armeria.common.util.TlsEngineType; import com.linecorp.armeria.server.logging.AccessLogWriter; import io.micrometer.core.instrument.MeterRegistry; @@ -82,6 +83,8 @@ public final class VirtualHost { private final int port; @Nullable private final SslContext sslContext; + @Nullable + private final TlsEngineType tlsEngineType; private final Router router; private final List serviceConfigs; private final ServiceConfig fallbackServiceConfig; @@ -105,6 +108,7 @@ public final class VirtualHost { VirtualHost(String defaultHostname, String hostnamePattern, int port, @Nullable SslContext sslContext, + @Nullable TlsEngineType tlsEngineType, Iterable serviceConfigs, ServiceConfig fallbackServiceConfig, RejectedRouteHandler rejectionHandler, @@ -133,6 +137,7 @@ public final class VirtualHost { } this.port = port; this.sslContext = sslContext; + this.tlsEngineType = tlsEngineType; this.defaultServiceNaming = defaultServiceNaming; this.defaultLogName = defaultLogName; this.requestTimeoutMillis = requestTimeoutMillis; @@ -167,9 +172,9 @@ public final class VirtualHost { VirtualHost withNewSslContext(SslContext sslContext) { return new VirtualHost(originalDefaultHostname, originalHostnamePattern, port, sslContext, - serviceConfigs, fallbackServiceConfig, RejectedRouteHandler.DISABLED, - host -> accessLogger, defaultServiceNaming, defaultLogName, requestTimeoutMillis, - maxRequestLength, verboseResponses, + tlsEngineType, serviceConfigs, fallbackServiceConfig, + RejectedRouteHandler.DISABLED, host -> accessLogger, defaultServiceNaming, + defaultLogName, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, requestAutoAbortDelayMillis, successFunction, multipartUploadsLocation, multipartRemovalStrategy, serviceWorkerGroup, @@ -321,6 +326,15 @@ public SslContext sslContext() { return sslContext; } + /** + * Returns the {@link TlsEngineType} of this virtual host. + */ + @Nullable + @UnstableApi + public TlsEngineType tlsEngineType() { + return tlsEngineType; + } + /** * Returns the information about the {@link HttpService}s bound to this virtual host. */ @@ -575,12 +589,12 @@ VirtualHost decorate(@Nullable Function accessLogger, defaultServiceNaming, defaultLogName, requestTimeoutMillis, - maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, - requestAutoAbortDelayMillis, successFunction, multipartUploadsLocation, - multipartRemovalStrategy, serviceWorkerGroup, shutdownSupports, - requestIdGenerator); + tlsEngineType, serviceConfigs, fallbackServiceConfig, + RejectedRouteHandler.DISABLED, host -> accessLogger, defaultServiceNaming, + defaultLogName, requestTimeoutMillis, maxRequestLength, verboseResponses, + accessLogWriter, blockingTaskExecutor, requestAutoAbortDelayMillis, + successFunction, multipartUploadsLocation, multipartRemovalStrategy, + serviceWorkerGroup, shutdownSupports, requestIdGenerator); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java index 221a1355e03..ea946e27a19 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java @@ -84,6 +84,7 @@ import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.common.util.EventLoopGroups; import com.linecorp.armeria.common.util.SystemInfo; +import com.linecorp.armeria.common.util.TlsEngineType; import com.linecorp.armeria.internal.common.util.SelfSignedCertificate; import com.linecorp.armeria.internal.server.RouteDecoratingService; import com.linecorp.armeria.internal.server.RouteUtil; @@ -136,6 +137,8 @@ public final class VirtualHostBuilder implements TlsSetters, ServiceConfigsBuild private final List> tlsCustomizers = new ArrayList<>(); @Nullable private Boolean tlsAllowUnsafeCiphers; + @Nullable + private TlsEngineType tlsEngineType; private final LinkedList routeDecoratingServices = new LinkedList<>(); @Nullable private Function accessLoggerMapper; @@ -427,6 +430,16 @@ public VirtualHostBuilder tlsAllowUnsafeCiphers(boolean tlsAllowUnsafeCiphers) { return this; } + /** + * The {@link TlsEngineType} that will be used for processing TLS connections. + */ + @UnstableApi + public VirtualHostBuilder tlsEngineType(TlsEngineType tlsEngineType) { + requireNonNull(tlsEngineType, "tlsEngineType"); + this.tlsEngineType = tlsEngineType; + return this; + } + /** * Returns a {@link VirtualHostContextPathServicesBuilder} which binds {@link HttpService}s under the * specified context paths. @@ -1436,9 +1449,12 @@ VirtualHost build(VirtualHostBuilder template, DependencyInjector dependencyInje builder.addAll(shutdownSupports); builder.addAll(template.shutdownSupports); + final TlsEngineType tlsEngineType = this.tlsEngineType != null ? + this.tlsEngineType : template.tlsEngineType; + assert tlsEngineType != null; final VirtualHost virtualHost = - new VirtualHost(defaultHostname, hostnamePattern, port, sslContext(template), - serviceConfigs, fallbackServiceConfig, rejectedRouteHandler, + new VirtualHost(defaultHostname, hostnamePattern, port, sslContext(template, tlsEngineType), + tlsEngineType, serviceConfigs, fallbackServiceConfig, rejectedRouteHandler, accessLoggerMapper, defaultServiceNaming, defaultLogName, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, requestAutoAbortDelayMillis, successFunction, multipartUploadsLocation, @@ -1470,7 +1486,7 @@ static HttpHeaders mergeDefaultHeaders(HttpHeadersBuilder lowPriorityHeaders, } @Nullable - private SslContext sslContext(VirtualHostBuilder template) { + private SslContext sslContext(VirtualHostBuilder template, TlsEngineType tlsEngineType) { if (portBased) { return null; } @@ -1486,13 +1502,13 @@ private SslContext sslContext(VirtualHostBuilder template) { // Build a new SslContext or use a user-specified one for backward compatibility. if (sslContextBuilderSupplier != null) { - sslContext = buildSslContext(sslContextBuilderSupplier, tlsAllowUnsafeCiphers, tlsCustomizers); + sslContext = buildSslContext(sslContextBuilderSupplier, tlsEngineType, tlsAllowUnsafeCiphers, + tlsCustomizers); sslContextFromThis = true; releaseSslContextOnFailure = true; } else if (template.sslContextBuilderSupplier != null) { - sslContext = buildSslContext(template.sslContextBuilderSupplier, - tlsAllowUnsafeCiphers, - template.tlsCustomizers); + sslContext = buildSslContext(template.sslContextBuilderSupplier, tlsEngineType, + tlsAllowUnsafeCiphers, template.tlsCustomizers); releaseSslContextOnFailure = true; } @@ -1519,6 +1535,7 @@ private SslContext sslContext(VirtualHostBuilder template) { sslContext = buildSslContext(() -> SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()), + tlsEngineType, tlsAllowUnsafeCiphers, tlsCustomizers); releaseSslContextOnFailure = true; @@ -1531,7 +1548,7 @@ private SslContext sslContext(VirtualHostBuilder template) { // Validate the built `SslContext`. if (sslContext != null) { - validateSslContext(sslContext); + validateSslContext(sslContext, tlsEngineType); checkState(sslContext.isServer(), "sslContextBuilder built a client SSL context."); } releaseSslContextOnFailure = false; diff --git a/core/src/main/java/com/linecorp/armeria/server/WrappingTransientHttpService.java b/core/src/main/java/com/linecorp/armeria/server/WrappingTransientHttpService.java index 77d0327f9ba..9ce420361d2 100644 --- a/core/src/main/java/com/linecorp/armeria/server/WrappingTransientHttpService.java +++ b/core/src/main/java/com/linecorp/armeria/server/WrappingTransientHttpService.java @@ -22,7 +22,7 @@ import com.linecorp.armeria.common.HttpResponse; /** - * Decorates a {@link HttpService} to be treated as {@link TransientService} without inheritance. + * Decorates an {@link HttpService} to be treated as {@link TransientService} without inheritance. */ final class WrappingTransientHttpService extends SimpleDecoratingHttpService implements TransientHttpService { diff --git a/core/src/main/java/com/linecorp/armeria/server/docs/JsonSchemaGenerator.java b/core/src/main/java/com/linecorp/armeria/server/docs/JsonSchemaGenerator.java index 81a746382dc..96447af1506 100644 --- a/core/src/main/java/com/linecorp/armeria/server/docs/JsonSchemaGenerator.java +++ b/core/src/main/java/com/linecorp/armeria/server/docs/JsonSchemaGenerator.java @@ -136,7 +136,7 @@ private JsonSchemaGenerator(ServiceSpecification serviceSpecification, Boolean u ImmutableMap.builderWithExpectedSize(serviceSpecification.structs().size()); for (StructInfo struct : serviceSpecification.structs()) { typeSignatureToStructMappingBuilder.put(struct.name(), struct); - if (struct.alias() != null) { + if (struct.alias() != null && !struct.alias().equals(struct.name())) { // TypeSignature.signature() could be StructInfo.alias() if the type is a protobuf Message. typeSignatureToStructMappingBuilder.put(struct.alias(), struct); } diff --git a/core/src/main/java/com/linecorp/armeria/server/file/FileService.java b/core/src/main/java/com/linecorp/armeria/server/file/FileService.java index 09cf83e3ead..3d89dc95299 100644 --- a/core/src/main/java/com/linecorp/armeria/server/file/FileService.java +++ b/core/src/main/java/com/linecorp/armeria/server/file/FileService.java @@ -23,6 +23,7 @@ import java.nio.file.Path; import java.util.EnumSet; import java.util.Iterator; +import java.util.List; import java.util.Objects; import java.util.Set; import java.util.concurrent.CompletableFuture; @@ -257,35 +258,18 @@ private HttpFile findFile(ServiceRequestContext ctx, HttpRequest req) { }); }); } else { - // Redirect to the slash appended path if: - // 1) /index.html exists or - // 2) it has a directory listing. - final String indexPath = decodedMappedPath + "/index.html"; - return findFile(ctx, indexPath, encodings, decompress).thenCompose(indexFile -> { - if (indexFile != null) { - return UnmodifiableFuture.completedFuture(true); - } + final List fallbackExtensions = config.fallbackFileExtensions(); + if (fallbackExtensions.isEmpty()) { + return findFileWithIndexPath(ctx, decodedMappedPath, encodings, decompress); + } - if (!config.autoIndex()) { - return UnmodifiableFuture.completedFuture(false); - } - - return config.vfs().canList(ctx.blockingTaskExecutor(), decodedMappedPath); - }).thenApply(canList -> { - if (canList) { - try (TemporaryThreadLocals ttl = TemporaryThreadLocals.acquire()) { - final StringBuilder locationBuilder = ttl.stringBuilder() - .append(ctx.path()) - .append('/'); - if (ctx.query() != null) { - locationBuilder.append('?') - .append(ctx.query()); - } - return HttpFile.ofRedirect(locationBuilder.toString()); - } - } else { - return HttpFile.nonExistent(); + // Try appending file extensions if it was a file access and file extensions are configured. + return findFileWithExtensions(ctx, fallbackExtensions.iterator(), decodedMappedPath, + encodings, decompress).thenCompose(fileWithExtension -> { + if (fileWithExtension != null) { + return UnmodifiableFuture.completedFuture(fileWithExtension); } + return findFileWithIndexPath(ctx, decodedMappedPath, encodings, decompress); }); } })); @@ -385,6 +369,58 @@ private HttpFile findFile(ServiceRequestContext ctx, HttpRequest req) { }); } + private CompletableFuture<@Nullable HttpFile> findFileWithIndexPath( + ServiceRequestContext ctx, String decodedMappedPath, + Set encodings, boolean decompress) { + // Redirect to the slash appended path if: + // 1) /index.html exists or + // 2) it has a directory listing. + final String indexPath = decodedMappedPath + "/index.html"; + return findFile(ctx, indexPath, encodings, decompress).thenCompose(indexFile -> { + if (indexFile != null) { + return UnmodifiableFuture.completedFuture(true); + } + + if (!config.autoIndex()) { + return UnmodifiableFuture.completedFuture(false); + } + + return config.vfs().canList(ctx.blockingTaskExecutor(), decodedMappedPath); + }).thenApply(canList -> { + if (canList) { + try (TemporaryThreadLocals ttl = TemporaryThreadLocals.acquire()) { + final StringBuilder locationBuilder = ttl.stringBuilder() + .append(ctx.path()) + .append('/'); + if (ctx.query() != null) { + locationBuilder.append('?') + .append(ctx.query()); + } + return HttpFile.ofRedirect(locationBuilder.toString()); + } + } else { + return HttpFile.nonExistent(); + } + }); + } + + private CompletableFuture<@Nullable HttpFile> findFileWithExtensions( + ServiceRequestContext ctx, @Nullable Iterator extensionIterator, String path, + Set supportedEncodings, boolean decompress) { + if (extensionIterator == null || !extensionIterator.hasNext()) { + return UnmodifiableFuture.completedFuture(null); + } + + final String extension = extensionIterator.next(); + return findFile(ctx, path + '.' + extension, supportedEncodings, decompress).thenCompose(file -> { + if (file != null) { + return UnmodifiableFuture.completedFuture(file); + } + + return findFileWithExtensions(ctx, extensionIterator, path, supportedEncodings, decompress); + }); + } + private CompletableFuture<@Nullable HttpFile> findFileAndDecompress( ServiceRequestContext ctx, String path, Set supportedEncodings) { // Look up a non-compressed file first to avoid extra decompression diff --git a/core/src/main/java/com/linecorp/armeria/server/file/FileServiceBuilder.java b/core/src/main/java/com/linecorp/armeria/server/file/FileServiceBuilder.java index 9b303191edf..1366e6ab26d 100644 --- a/core/src/main/java/com/linecorp/armeria/server/file/FileServiceBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/file/FileServiceBuilder.java @@ -16,6 +16,7 @@ package com.linecorp.armeria.server.file; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.linecorp.armeria.server.file.FileServiceConfig.validateEntryCacheSpec; import static com.linecorp.armeria.server.file.FileServiceConfig.validateMaxCacheEntrySizeBytes; @@ -23,9 +24,11 @@ import static java.util.Objects.requireNonNull; import java.time.Clock; +import java.util.List; import java.util.Map.Entry; import com.github.benmanes.caffeine.cache.CaffeineSpec; +import com.google.common.collect.ImmutableList; import com.linecorp.armeria.common.CacheControl; import com.linecorp.armeria.common.Flags; @@ -35,6 +38,7 @@ import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; /** * Builds a new {@link FileService} and its {@link FileServiceConfig}. Use the factory methods in @@ -60,6 +64,9 @@ public final class FileServiceBuilder { HttpHeadersBuilder headers; MediaTypeResolver mediaTypeResolver = MediaTypeResolver.ofDefault(); + @Nullable + private ImmutableList.Builder fallbackFileExtensions; + FileServiceBuilder(HttpVfs vfs) { this.vfs = requireNonNull(vfs, "vfs"); } @@ -153,6 +160,46 @@ public FileServiceBuilder autoIndex(boolean autoIndex) { return this; } + /** + * Adds the file extensions to be considered when resolving file names. + * This method allows specifying alternative file names by appending the provided extensions + * to the requested file name if the initially requested resource is not found. + * + *

    For instance, if {@code "/index"} is requested and {@code "html"} is an added extension, + * {@link FileService} will attempt to serve {@code "/index.html"} if {@code "/index"} is not found. + */ + @UnstableApi + public FileServiceBuilder fallbackFileExtensions(String... extensions) { + requireNonNull(extensions, "extensions"); + return fallbackFileExtensions(ImmutableList.copyOf(extensions)); + } + + /** + * Adds the file extensions to be considered when resolving file names. + * This method allows specifying alternative file names by appending the provided extensions + * to the requested file name if the initially requested resource is not found. + * + *

    For instance, if {@code "/index"} is requested and {@code "html"} is an added extension, + * {@link FileService} will attempt to serve {@code "/index.html"} if {@code "/index"} is not found. + */ + @UnstableApi + public FileServiceBuilder fallbackFileExtensions(Iterable extensions) { + requireNonNull(extensions, "extensions"); + for (String extension : extensions) { + checkArgument(!extension.isEmpty(), "extension is empty"); + checkArgument(extension.charAt(0) != '.', "extension: %s (expected: without a dot)", extension); + } + if (fallbackFileExtensions == null) { + fallbackFileExtensions = ImmutableList.builder(); + } + fallbackFileExtensions.addAll(extensions); + return this; + } + + private List fallbackFileExtensions() { + return fallbackFileExtensions != null ? fallbackFileExtensions.build() : ImmutableList.of(); + } + /** * Returns the immutable additional {@link HttpHeaders} which will be set when building an * {@link HttpResponse}. @@ -248,12 +295,13 @@ public FileService build() { return new FileService(new FileServiceConfig( vfs, clock, entryCacheSpec, maxCacheEntrySizeBytes, serveCompressedFiles, autoDecompress, autoIndex, buildHeaders(), - mediaTypeResolver.orElse(MediaTypeResolver.ofDefault()))); + mediaTypeResolver.orElse(MediaTypeResolver.ofDefault()), fallbackFileExtensions())); } @Override public String toString() { return FileServiceConfig.toString(this, vfs, clock, entryCacheSpec, maxCacheEntrySizeBytes, - serveCompressedFiles, autoIndex, headers, mediaTypeResolver); + serveCompressedFiles, autoIndex, headers, mediaTypeResolver, + fallbackFileExtensions()); } } diff --git a/core/src/main/java/com/linecorp/armeria/server/file/FileServiceConfig.java b/core/src/main/java/com/linecorp/armeria/server/file/FileServiceConfig.java index 401438092e4..8766acd95a7 100644 --- a/core/src/main/java/com/linecorp/armeria/server/file/FileServiceConfig.java +++ b/core/src/main/java/com/linecorp/armeria/server/file/FileServiceConfig.java @@ -19,6 +19,7 @@ import static java.util.Objects.requireNonNull; import java.time.Clock; +import java.util.List; import java.util.Map.Entry; import com.github.benmanes.caffeine.cache.CaffeineSpec; @@ -28,6 +29,7 @@ import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; import io.netty.util.AsciiString; @@ -46,10 +48,12 @@ public final class FileServiceConfig { private final boolean autoIndex; private final HttpHeaders headers; private final MediaTypeResolver mediaTypeResolver; + private final List fallbackFileExtensions; FileServiceConfig(HttpVfs vfs, Clock clock, @Nullable String entryCacheSpec, int maxCacheEntrySizeBytes, boolean serveCompressedFiles, boolean autoDecompress, boolean autoIndex, - HttpHeaders headers, MediaTypeResolver mediaTypeResolver) { + HttpHeaders headers, MediaTypeResolver mediaTypeResolver, + List fallbackFileExtensions) { this.vfs = requireNonNull(vfs, "vfs"); this.clock = requireNonNull(clock, "clock"); this.entryCacheSpec = validateEntryCacheSpec(entryCacheSpec); @@ -59,6 +63,7 @@ public final class FileServiceConfig { this.autoIndex = autoIndex; this.headers = requireNonNull(headers, "headers"); this.mediaTypeResolver = requireNonNull(mediaTypeResolver, "mediaTypeResolver"); + this.fallbackFileExtensions = requireNonNull(fallbackFileExtensions, "fallbackFileExtensions"); } @Nullable @@ -152,17 +157,26 @@ public MediaTypeResolver mediaTypeResolver() { return mediaTypeResolver; } + /** + * Returns the file extensions that are appended to the file name when the file is not found. + */ + @UnstableApi + public List fallbackFileExtensions() { + return fallbackFileExtensions; + } + @Override public String toString() { return toString(this, vfs(), clock(), entryCacheSpec(), maxCacheEntrySizeBytes(), - serveCompressedFiles(), autoIndex(), headers(), mediaTypeResolver()); + serveCompressedFiles(), autoIndex(), headers(), mediaTypeResolver(), + fallbackFileExtensions()); } static String toString(Object holder, HttpVfs vfs, Clock clock, @Nullable String entryCacheSpec, int maxCacheEntrySizeBytes, boolean serveCompressedFiles, boolean autoIndex, @Nullable Iterable> headers, - MediaTypeResolver mediaTypeResolver) { + MediaTypeResolver mediaTypeResolver, @Nullable List fallbackFileExtensions) { return MoreObjects.toStringHelper(holder).omitNullValues() .add("vfs", vfs) @@ -173,6 +187,7 @@ static String toString(Object holder, HttpVfs vfs, Clock clock, .add("autoIndex", autoIndex) .add("headers", headers) .add("mediaTypeResolver", mediaTypeResolver) + .add("fallbackFileExtensions", fallbackFileExtensions) .toString(); } } diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java index ea6b36c628e..8c0186adf72 100644 --- a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java @@ -22,6 +22,7 @@ import com.linecorp.armeria.common.websocket.WebSocket; import com.linecorp.armeria.server.HttpService; import com.linecorp.armeria.server.ServiceConfig; +import com.linecorp.armeria.server.ServiceOptions; import com.linecorp.armeria.server.ServiceRequestContext; /** @@ -68,4 +69,9 @@ default HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Ex * Returns the {@link WebSocketProtocolHandler} of this service. */ WebSocketProtocolHandler protocolHandler(); + + @Override + default ServiceOptions options() { + return WebSocketServiceBuilder.DEFAULT_OPTIONS; + } } diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java index 7a09381f130..a0054b20424 100644 --- a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java @@ -39,6 +39,7 @@ import com.linecorp.armeria.internal.server.websocket.DefaultWebSocketService; import com.linecorp.armeria.server.HttpService; import com.linecorp.armeria.server.ServiceConfig; +import com.linecorp.armeria.server.ServiceOptions; /** * Builds a {@link WebSocketService}. @@ -61,6 +62,13 @@ public final class WebSocketServiceBuilder { static final int DEFAULT_MAX_FRAME_PAYLOAD_LENGTH = 65535; // 64 * 1024 -1 + static final ServiceOptions DEFAULT_OPTIONS = ServiceOptions + .builder() + .requestTimeoutMillis(WebSocketUtil.DEFAULT_REQUEST_RESPONSE_TIMEOUT_MILLIS) + .maxRequestLength(WebSocketUtil.DEFAULT_MAX_REQUEST_RESPONSE_LENGTH) + .requestAutoAbortDelayMillis(WebSocketUtil.DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS) + .build(); + private final WebSocketServiceHandler handler; private int maxFramePayloadLength = DEFAULT_MAX_FRAME_PAYLOAD_LENGTH; @@ -73,6 +81,7 @@ public final class WebSocketServiceBuilder { private boolean aggregateContinuation; @Nullable private HttpService fallbackService; + private ServiceOptions serviceOptions = DEFAULT_OPTIONS; WebSocketServiceBuilder(WebSocketServiceHandler handler) { this.handler = requireNonNull(handler, "handler"); @@ -202,12 +211,24 @@ private static Set validateOrigins(Iterable allowedOrigins) { return copied; } + /** + * Sets the {@link ServiceOptions} for the {@link WebSocketService}. + * If not set, {@link WebSocketService#options()} is used. + */ + public WebSocketServiceBuilder serviceOptions(ServiceOptions serviceOptions) { + requireNonNull(serviceOptions, "serviceOptions"); + this.serviceOptions = serviceOptions; + return this; + } + /** * Sets the fallback {@link HttpService} to use when the request is not a valid WebSocket upgrade request. * This is useful when you want to serve both WebSocket and HTTP requests at the same path. */ public WebSocketServiceBuilder fallbackService(HttpService fallbackService) { this.fallbackService = requireNonNull(fallbackService, "fallbackService"); + checkArgument(!(fallbackService instanceof WebSocketService), + "fallbackService must not be a WebSocketService."); return this; } @@ -225,7 +246,7 @@ public WebSocketService build() { originPredicate = this.originPredicate; } return new DefaultWebSocketService(handler, fallbackService, maxFramePayloadLength, allowMaskMismatch, - subprotocols, allowAnyOrigin, - originPredicate, aggregateContinuation); + subprotocols, allowAnyOrigin, originPredicate, aggregateContinuation, + serviceOptions); } } diff --git a/core/src/main/resources/com/linecorp/armeria/public_suffixes.txt b/core/src/main/resources/com/linecorp/armeria/public_suffixes.txt index 22b9d2f6226..a72cc4b5638 100644 --- a/core/src/main/resources/com/linecorp/armeria/public_suffixes.txt +++ b/core/src/main/resources/com/linecorp/armeria/public_suffixes.txt @@ -11,13 +11,16 @@ *.advisor.ws *.af-south-1.airflow.amazonaws.com *.alces.network -*.amplifyapp.com *.ap-east-1.airflow.amazonaws.com *.ap-northeast-1.airflow.amazonaws.com *.ap-northeast-2.airflow.amazonaws.com +*.ap-northeast-3.airflow.amazonaws.com *.ap-south-1.airflow.amazonaws.com +*.ap-south-2.airflow.amazonaws.com *.ap-southeast-1.airflow.amazonaws.com *.ap-southeast-2.airflow.amazonaws.com +*.ap-southeast-3.airflow.amazonaws.com +*.ap-southeast-4.airflow.amazonaws.com *.awdev.ca *.awsapprunner.com *.azurecontainer.io @@ -30,6 +33,7 @@ *.bzz.dapps.earth *.c.ts.net *.ca-central-1.airflow.amazonaws.com +*.ca-west-1.airflow.amazonaws.com *.ck *.cloud.metacentrum.cz *.cloudera.site @@ -41,9 +45,9 @@ *.compute.amazonaws.com *.compute.amazonaws.com.cn *.compute.estate -*.cprapid.com *.cryptonomic.net *.customer-oci.com +*.d.crm.dev *.dapps.earth *.database.run *.dev-builder.code.com @@ -57,14 +61,17 @@ *.elb.amazonaws.com.cn *.er *.eu-central-1.airflow.amazonaws.com +*.eu-central-2.airflow.amazonaws.com *.eu-north-1.airflow.amazonaws.com *.eu-south-1.airflow.amazonaws.com +*.eu-south-2.airflow.amazonaws.com *.eu-west-1.airflow.amazonaws.com *.eu-west-2.airflow.amazonaws.com *.eu-west-3.airflow.amazonaws.com *.ewp.live *.ex.futurecms.at *.ex.ortsinfo.at +*.experiments.sagemaker.aws *.firenet.ch *.fk *.frusky.de @@ -74,6 +81,7 @@ *.hosting.myjino.ru *.hosting.ovh.net *.id.pub +*.il-central-1.airflow.amazonaws.com *.in.futurecms.at *.jm *.kawasaki.jp @@ -88,6 +96,7 @@ *.lclstage.dev *.linodeobjects.com *.magentosite.cloud +*.me-central-1.airflow.amazonaws.com *.me-south-1.airflow.amazonaws.com *.migration.run *.mm @@ -151,9 +160,16 @@ *.usercontent.goog *.vps.myjino.ru *.vultrobjects.com +*.w.crm.dev +*.wa.crm.dev *.wadl.top +*.wb.crm.dev +*.wc.crm.dev +*.wd.crm.dev +*.we.crm.dev *.webhare.dev *.webpaas.ovh.net +*.wf.crm.dev *.xmit.co *.yokohama.jp 0.bg @@ -478,9 +494,9 @@ ami.ibaraki.jp amica amli.no amot.no +amplifyapp.com amscompute.com amsterdam -amusement.aero an.it analytics analytics-gateway.ap-northeast-1.amazonaws.com @@ -607,6 +623,7 @@ arts.ro arts.ve arvo.network as +as.sh.cn as.us asago.hyogo.jp asahi.chiba.jp @@ -696,6 +713,7 @@ auth-fips.us-gov-west-1.amazoncognito.com auth-fips.us-west-1.amazoncognito.com auth-fips.us-west-2.amazoncognito.com auth.af-south-1.amazoncognito.com +auth.ap-east-1.amazoncognito.com auth.ap-northeast-1.amazoncognito.com auth.ap-northeast-2.amazoncognito.com auth.ap-northeast-3.amazoncognito.com @@ -706,6 +724,7 @@ auth.ap-southeast-2.amazoncognito.com auth.ap-southeast-3.amazoncognito.com auth.ap-southeast-4.amazoncognito.com auth.ca-central-1.amazoncognito.com +auth.ca-west-1.amazoncognito.com auth.eu-central-1.amazoncognito.com auth.eu-central-2.amazoncognito.com auth.eu-north-1.amazoncognito.com @@ -823,7 +842,9 @@ barsy.ca barsy.club barsy.co.uk barsy.de +barsy.dev barsy.eu +barsy.gr barsy.in barsy.info barsy.io @@ -836,13 +857,17 @@ barsy.org barsy.pro barsy.pub barsy.ro +barsy.rs barsy.shop barsy.site +barsy.store barsy.support barsy.uk barsycenter.com barsyonline.co.uk barsyonline.com +barsyonline.menu +barsyonline.shop barueri.br barum.no bas.it @@ -1361,6 +1386,7 @@ cf cf-ipfs.com cfa cfd +cfolks.pl cg ch ch.eu.org @@ -1516,7 +1542,6 @@ cn.eu.org cn.in cn.it cn.ua -cn.vu cng.br cnpy.gdn cnt.br @@ -1808,11 +1833,13 @@ courses coz.br cpa cpa.pro +cprapid.com cpserver.com cq.cn cr cr.it cr.ua +craft.me crafting.xyz cranky.jp crap.jp @@ -1981,6 +2008,7 @@ development.run devices.resinstaging.io df.gov.br df.leg.br +dfirma.pl dgca.aero dh.bytemark.co.uk dhl @@ -2014,6 +2042,7 @@ diy dj dk dk.eu.org +dkonto.pl dlugoleka.pl dm dn.ua @@ -2088,6 +2117,7 @@ duckdns.org dunlop dupont durban +durumis.com dvag dvr dvrcam.info @@ -2783,6 +2813,7 @@ freemyip.com freesite.host freetls.fastly.net frei.no +freight.aero frenchkiss.jp fresenius friuli-v-giulia.it @@ -2930,7 +2961,6 @@ fylkesbibl.no fyresdal.no g.bg g.se -g.vbrplsbx.io g12.br ga ga.us @@ -3423,6 +3453,12 @@ hasuda.saitama.jp hasura-app.io hasura.app hasvik.no +hateblo.jp +hatenablog.com +hatenablog.jp +hatenadiary.com +hatenadiary.jp +hatenadiary.org hatinh.vn hatogaya.saitama.jp hatoyama.saitama.jp @@ -3447,6 +3483,7 @@ health.vn healthcare heavy.jp heguri.nara.jp +heiyu.space hekinan.aichi.jp helioho.st heliohost.us @@ -3798,6 +3835,7 @@ ind.br ind.gt ind.in ind.kw +ind.mom ind.tn independent-commission.uk independent-inquest.uk @@ -3858,7 +3896,6 @@ ink ino.kochi.jp instance.datadetect.com instances.spawn.cc -instantcloud.cn institute insurance insurance.aero @@ -4953,6 +4990,7 @@ macapa.br maceio.br macerata.it machida.tokyo.jp +madethis.site madrid maebashi.gunma.jp magazine.aero @@ -4994,6 +5032,7 @@ maringa.br marker.no market marketing +marketplace.aero markets marnardal.no marriott @@ -5078,6 +5117,7 @@ media media.aero media.hu media.pl +media.strapiapp.com mediatech.by mediatech.dev medicina.bo @@ -5499,6 +5539,7 @@ mypsx.net myqnapcloud.cn myqnapcloud.com myradweb.net +myrdbx.io mysecuritycamera.com mysecuritycamera.net mysecuritycamera.org @@ -5648,7 +5689,6 @@ nasushiobara.tochigi.jp nat.tn natal.br natori.miyagi.jp -natura natural.bo naturbruksgymn.se naustdal.no @@ -5824,6 +5864,7 @@ net.za net.zm netbank netflix +netfy.app netgamers.jp netlify.app network @@ -6106,6 +6147,7 @@ obihiro.hokkaido.jp obira.hokkaido.jp objects.lpg.cloudscale.ch objects.rma.cloudscale.ch +obl.ong obninsk.su observablehq.cloud observer @@ -6935,6 +6977,8 @@ reit reklam.hu rel.ht rel.pl +relay.evervault.app +relay.evervault.dev reliance remotewd.com ren @@ -7015,7 +7059,6 @@ rn.leg.br ro ro.eu.org ro.gov.br -ro.im ro.it ro.leg.br roan.no @@ -7650,7 +7693,6 @@ shakotan.hokkaido.jp shangrila shari.hokkaido.jp sharp -shaw sheezy.games shell shia @@ -7791,6 +7833,7 @@ siracusa.it sirdal.no sisko.replit.dev site +site.rb-hosting.io site.tb-hosting.com site.transip.me siteleaf.net @@ -7987,6 +8030,7 @@ storipress.app storj.farm strand.no stranda.no +strapiapp.com streak-link.com streaklinks.com streakusercontent.com @@ -8220,6 +8264,7 @@ tattoo tawaramoto.nara.jp tax taxi +taxi.aero taxi.br tayninh.vn tc @@ -9094,11 +9139,13 @@ works.aero world worse-than.tv wow +wp2.host wpdevcloud.com wpenginepowered.com wphostedmail.com wpmucdn.com wpmudev.host +wpsquared.site writesthisblog.com wroc.pl wroclaw.pl @@ -9700,6 +9747,7 @@ yoshinogari.saga.jp yoshioka.gunma.jp yotsukaido.chiba.jp you +you2.pl youtube yt yuasa.wakayama.jp diff --git a/core/src/test/java/com/linecorp/armeria/client/Http2ClientSettingsTest.java b/core/src/test/java/com/linecorp/armeria/client/Http2ClientSettingsTest.java index 64a3ffe5577..87dbed6f26b 100644 --- a/core/src/test/java/com/linecorp/armeria/client/Http2ClientSettingsTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/Http2ClientSettingsTest.java @@ -259,7 +259,7 @@ private static void readHeadersFrame(InputStream in) throws IOException { readBytes(in, payloadLength); } - private static int payloadLength(byte[] buf) { + static int payloadLength(byte[] buf) { return (buf[0] & 0xff) << 16 | (buf[1] & 0xff) << 8 | (buf[2] & 0xff); } diff --git a/core/src/test/java/com/linecorp/armeria/client/HttpClientExpect100HeaderTest.java b/core/src/test/java/com/linecorp/armeria/client/HttpClientExpect100HeaderTest.java new file mode 100644 index 00000000000..c2c7115f2fd --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/HttpClientExpect100HeaderTest.java @@ -0,0 +1,580 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static com.linecorp.armeria.client.Http2ClientSettingsTest.payloadLength; +import static com.linecorp.armeria.client.Http2ClientSettingsTest.readBytes; +import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; + +import java.io.BufferedOutputStream; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.ServerSocket; +import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; + +import org.apache.hc.core5.http2.hpack.HPackEncoder; +import org.apache.hc.core5.util.ByteArrayBuffer; +import org.jetbrains.annotations.Nullable; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketSession; +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpHeaderNames; +import com.linecorp.armeria.common.HttpHeaders; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpRequestWriter; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.stream.CancelledSubscriptionException; +import com.linecorp.armeria.server.Server; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http2.DefaultHttp2HeadersDecoder; +import io.netty.handler.codec.http2.Http2CodecUtil; +import io.netty.handler.codec.http2.Http2Flags; +import io.netty.handler.codec.http2.Http2FrameTypes; +import io.netty.handler.codec.http2.Http2Headers; + +/** + * This test is to check the behavior of the HttpClient when the 'Expect: 100-continue' header is set. + */ +final class HttpClientExpect100HeaderTest { + + /////////////////// + // Empty content // + /////////////////// + @Test + void sendRequestWithEmptyContent() throws Exception { + try (ServerSocket ss = new ServerSocket(0)) { + final int port = ss.getLocalPort(); + final WebClient client = WebClient.of("h1c://127.0.0.1:" + port); + final CompletableFuture future = + client.prepare() + .get("/") + .header(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE) + .execute() + .aggregate(); + + assertThatThrownBy(future::join) + .hasCauseInstanceOf(UnprocessedRequestException.class) + .cause() + .hasCauseInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Expect: 100-continue header"); + } + } + + ///////////////////////////////////ㄷ + // Response Status: 100 Continue // + /////////////////////////////////// + @Test + void continueToSendRequestOnHttp1() throws Exception { + try (ServerSocket ss = new ServerSocket(0)) { + ss.setSoTimeout(10000); + + final int port = ss.getLocalPort(); + final WebClient client = WebClient.of("h1c://127.0.0.1:" + port); + final CompletableFuture future = + client.prepare() + .post("/") + .content("foo\n") + .header(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE) + .execute() + .aggregate(); + + try (Socket s = ss.accept()) { + final InputStream inputStream = s.getInputStream(); + final BufferedReader in = new BufferedReader( + new InputStreamReader(inputStream, StandardCharsets.US_ASCII)); + final OutputStream out = s.getOutputStream(); + + assertThat(in.readLine()).isEqualTo("POST / HTTP/1.1"); + assertThat(in.readLine()).startsWith("host: 127.0.0.1:"); + assertThat(in.readLine()).isEqualTo("content-type: text/plain; charset=utf-8"); + assertThat(in.readLine()).isEqualTo("expect: 100-continue"); + assertThat(in.readLine()).isEqualTo("content-length: 4"); + assertThat(in.readLine()).startsWith("user-agent: armeria/"); + assertThat(in.readLine()).isEmpty(); + + // Check that the data is not sent until sending 100-continue response. + Thread.sleep(1000); + assertThat(inputStream.available()).isZero(); + + out.write("HTTP/1.1 100 Continue\r\n\r\n".getBytes(StandardCharsets.US_ASCII)); + + assertThat(in.readLine()).isEqualTo("foo"); + + out.write(("HTTP/1.1 201 Created\r\n" + + "Connection: close\r\n" + + "Content-Length: 0\r\n" + + "\r\n").getBytes(StandardCharsets.US_ASCII)); + + assertThat(in.readLine()).isNull(); + + final AggregatedHttpResponse res = future.join(); + assertThat(res.status()).isEqualTo(HttpStatus.CREATED); + } + } + } + + @Test + void continueToSendRequestOnHttp2() throws Exception { + try (ServerSocket ss = new ServerSocket(0); + ClientFactory clientFactory = + ClientFactory.builder() + .useHttp2Preface(true) + .http2InitialConnectionWindowSize(Http2CodecUtil.DEFAULT_WINDOW_SIZE) + .http2InitialStreamWindowSize(Http2CodecUtil.DEFAULT_WINDOW_SIZE) + .build()) { + ss.setSoTimeout(10000); + + final int port = ss.getLocalPort(); + final WebClient client = WebClient.builder("http://127.0.0.1:" + port) + .factory(clientFactory) + .build(); + final CompletableFuture future = + client.prepare() + .post("/") + .content("foo") + .header(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE) + .execute() + .aggregate(); + + try (Socket s = ss.accept()) { + final InputStream in = s.getInputStream(); + final BufferedOutputStream bos = new BufferedOutputStream(s.getOutputStream()); + + // Read the connection preface and discard it. + readBytes(in, connectionPrefaceBuf().capacity()); + + // Read a SETTINGS frame and validate it. + readSettingsFrame(in); + sendEmptySettingsAndAckFrame(bos); + + readBytes(in, 9); // Read a SETTINGS_ACK frame and discard it. + + // Read a HEADERS frame and validate it. + readHeadersFrame(new DefaultHttp2HeadersDecoder(), in, true); + // Check that the data is not sent until sending 100-continue response. + Thread.sleep(1000); + assertThat(in.available()).isZero(); + // Send a CONTINUE response. + sendFrameHeaders(bos, HttpStatus.CONTINUE, false, 3); + + // Read a DATA frame. + readDataFrame(in); + // Send a response. + sendFrameHeaders(bos, HttpStatus.CREATED, true, 3); + + final AggregatedHttpResponse res = future.join(); + assertThat(res.status()).isEqualTo(HttpStatus.CREATED); + } + } + } + + ///////////////////////////////////////////// + // Response Status: 417 Expectation Failed // + ///////////////////////////////////////////// + @Test + void expectationFailedHttp1() throws Exception { + try (ServerSocket ss = new ServerSocket(0)) { + ss.setSoTimeout(10000); + + final int port = ss.getLocalPort(); + final WebClient client = WebClient.of("h1c://127.0.0.1:" + port); + final CompletableFuture future = + client.prepare() + .post("/") + .content("foo\n") + .header(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE) + .execute() + .aggregate(); + + try (Socket s = ss.accept()) { + final InputStream inputStream = s.getInputStream(); + final BufferedReader in = new BufferedReader( + new InputStreamReader(inputStream, StandardCharsets.US_ASCII)); + final OutputStream out = s.getOutputStream(); + + assertThat(in.readLine()).isEqualTo("POST / HTTP/1.1"); + assertThat(in.readLine()).startsWith("host: 127.0.0.1:"); + assertThat(in.readLine()).isEqualTo("content-type: text/plain; charset=utf-8"); + assertThat(in.readLine()).isEqualTo("expect: 100-continue"); + assertThat(in.readLine()).isEqualTo("content-length: 4"); + assertThat(in.readLine()).startsWith("user-agent: armeria/"); + assertThat(in.readLine()).isEmpty(); + + // Check that the data is not sent until sending 100-continue response. + Thread.sleep(1000); + assertThat(inputStream.available()).isZero(); + + out.write(("HTTP/1.1 417 Expectation Failed\r\n" + + "Content-Length: 0\r\n" + + "\r\n").getBytes(StandardCharsets.US_ASCII)); + + final AggregatedHttpResponse res = future.join(); + assertThat(res.status()).isEqualTo(HttpStatus.EXPECTATION_FAILED); + } + } + } + + @Test + void expectationFailedHttp2() throws Exception { + try (ServerSocket ss = new ServerSocket(0); + ClientFactory clientFactory = + ClientFactory.builder() + .useHttp2Preface(true) + .http2InitialConnectionWindowSize(Http2CodecUtil.DEFAULT_WINDOW_SIZE) + .http2InitialStreamWindowSize(Http2CodecUtil.DEFAULT_WINDOW_SIZE) + .build()) { + ss.setSoTimeout(10000); + + final int port = ss.getLocalPort(); + final WebClient client = WebClient.builder("http://127.0.0.1:" + port) + .factory(clientFactory) + .build(); + final CompletableFuture future = + client.prepare() + .post("/") + .content("foo") + .header(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE) + .execute() + .aggregate(); + + try (Socket s = ss.accept()) { + final InputStream in = s.getInputStream(); + final BufferedOutputStream bos = new BufferedOutputStream(s.getOutputStream()); + + // Read the connection preface and discard it. + readBytes(in, connectionPrefaceBuf().capacity()); + + // Read a SETTINGS frame and validate it. + readSettingsFrame(in); + sendEmptySettingsAndAckFrame(bos); + + readBytes(in, 9); // Read a SETTINGS_ACK frame and discard it. + + final DefaultHttp2HeadersDecoder headersDecoder = new DefaultHttp2HeadersDecoder(); + // Read a HEADERS frame and validate it. + readHeadersFrame(headersDecoder, in, true); + // Check that the data is not sent until sending 100-continue response. + Thread.sleep(1000); + assertThat(in.available()).isZero(); + // Send a 417 Expectation Failed response. + sendFrameHeaders(bos, HttpStatus.EXPECTATION_FAILED, true, 3); + + final AggregatedHttpResponse res = future.join(); + assertThat(res.status()).isEqualTo(HttpStatus.EXPECTATION_FAILED); + } + } + } + + ///////////////////////////// + // Response Status: Others // + ///////////////////////////// + @Test + void receiveResponseWithStandardFlowOnHttp1() throws Exception { + try (ServerSocket ss = new ServerSocket(0)) { + ss.setSoTimeout(10000); + + final int port = ss.getLocalPort(); + final WebClient client = WebClient.of("h1c://127.0.0.1:" + port); + final CompletableFuture future = + client.prepare() + .post("/") + .content("foo\n") + .header(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE) + .execute() + .aggregate(); + + try (Socket s = ss.accept()) { + final InputStream inputStream = s.getInputStream(); + final BufferedReader in = new BufferedReader( + new InputStreamReader(inputStream, StandardCharsets.US_ASCII)); + final OutputStream out = s.getOutputStream(); + + assertThat(in.readLine()).isEqualTo("POST / HTTP/1.1"); + assertThat(in.readLine()).startsWith("host: 127.0.0.1:"); + assertThat(in.readLine()).isEqualTo("content-type: text/plain; charset=utf-8"); + assertThat(in.readLine()).isEqualTo("expect: 100-continue"); + assertThat(in.readLine()).isEqualTo("content-length: 4"); + assertThat(in.readLine()).startsWith("user-agent: armeria/"); + assertThat(in.readLine()).isEmpty(); + + // Check that the data is not sent until sending 100-continue response. + Thread.sleep(1000); + assertThat(inputStream.available()).isZero(); + + out.write(("HTTP/1.1 201 Created\r\n" + + "Connection: close\r\n" + + "Content-Length: 0\r\n" + + "\r\n").getBytes(StandardCharsets.US_ASCII)); + + assertThat(in.readLine()).isNull(); + + // Receive the response with the standard flow. + final AggregatedHttpResponse res = future.join(); + assertThat(res.status()).isEqualTo(HttpStatus.CREATED); + } + } + } + + @Test + void receiveResponseWithStandardFlowOnHttp2() throws Exception { + try (ServerSocket ss = new ServerSocket(0); + ClientFactory clientFactory = + ClientFactory.builder() + .useHttp2Preface(true) + .http2InitialConnectionWindowSize(Http2CodecUtil.DEFAULT_WINDOW_SIZE) + .http2InitialStreamWindowSize(Http2CodecUtil.DEFAULT_WINDOW_SIZE) + .build()) { + ss.setSoTimeout(10000); + + final int port = ss.getLocalPort(); + final WebClient client = WebClient.builder("http://127.0.0.1:" + port) + .factory(clientFactory) + .build(); + final CompletableFuture future = + client.prepare() + .post("/") + .content("foo") + .header(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE) + .execute() + .aggregate(); + + try (Socket s = ss.accept()) { + final InputStream in = s.getInputStream(); + final BufferedOutputStream bos = new BufferedOutputStream(s.getOutputStream()); + + // Read the connection preface and discard it. + readBytes(in, connectionPrefaceBuf().capacity()); + + // Read a SETTINGS frame and validate it. + readSettingsFrame(in); + sendEmptySettingsAndAckFrame(bos); + + readBytes(in, 9); // Read a SETTINGS_ACK frame and discard it. + + // Read a HEADERS frame and validate it. + readHeadersFrame(new DefaultHttp2HeadersDecoder(), in, true); + // Check that the data is not sent until sending 100-continue response. + Thread.sleep(1000); + assertThat(in.available()).isZero(); + // Send a CREATED response. + sendFrameHeaders(bos, HttpStatus.CREATED, true, 3); + + final AggregatedHttpResponse res = future.join(); + assertThat(res.status()).isEqualTo(HttpStatus.CREATED); + } + } + } + + @Test + void timeoutFor100Continue() throws Exception { + try (ServerSocket ss = new ServerSocket(0)) { + final int port = ss.getLocalPort(); + final WebClient client = WebClient.builder("h1c://127.0.0.1:" + port) + .responseTimeoutMillis(500) + .build(); + final CompletableFuture future = + client.prepare() + .post("/") + .content("foo\n") + .header(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE) + .execute() + .aggregate(); + + try (Socket s = ss.accept()) { + final InputStream inputStream = s.getInputStream(); + final BufferedReader in = new BufferedReader( + new InputStreamReader(inputStream, StandardCharsets.US_ASCII)); + assertThat(in.readLine()).isEqualTo("POST / HTTP/1.1"); + assertThat(in.readLine()).startsWith("host: 127.0.0.1:"); + assertThat(in.readLine()).isEqualTo("content-type: text/plain; charset=utf-8"); + assertThat(in.readLine()).isEqualTo("expect: 100-continue"); + assertThat(in.readLine()).isEqualTo("content-length: 4"); + assertThat(in.readLine()).startsWith("user-agent: armeria/"); + assertThat(in.readLine()).isEmpty(); + + // Do not send response so that the client will time out. + + await().atMost(Duration.ofSeconds(2)).untilAsserted(() -> { + assertThatThrownBy(future::join).hasCauseInstanceOf(ResponseTimeoutException.class); + }); + } + } + } + + @CsvSource({ "true", "false" }) + @ParameterizedTest + void streamingRequest(boolean send100Continue) throws Exception { + try (ServerSocket ss = new ServerSocket(0)) { + final int port = ss.getLocalPort(); + final WebClient client = WebClient.of("h1c://127.0.0.1:" + port); + final RequestHeaders headers = + RequestHeaders.builder(HttpMethod.POST, "/") + .contentType(MediaType.PLAIN_TEXT_UTF_8) + .add(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE.toString()) + .build(); + final HttpRequestWriter req = HttpRequest.streaming(headers); + + final CompletableFuture future = client.execute(req).aggregate(); + + req.write(HttpData.ofUtf8("foo")); + req.close(); + + try (Socket s = ss.accept()) { + final InputStream inputStream = s.getInputStream(); + final BufferedReader in = new BufferedReader( + new InputStreamReader(inputStream, StandardCharsets.US_ASCII)); + final OutputStream out = s.getOutputStream(); + + assertThat(in.readLine()).isEqualTo("POST / HTTP/1.1"); + assertThat(in.readLine()).startsWith("host: 127.0.0.1:"); + assertThat(in.readLine()).isEqualTo("content-type: text/plain; charset=utf-8"); + assertThat(in.readLine()).isEqualTo("expect: 100-continue"); + assertThat(in.readLine()).startsWith("user-agent: armeria/"); + assertThat(in.readLine()).isEqualTo("transfer-encoding: chunked"); + assertThat(in.readLine()).isEmpty(); + + if (send100Continue) { + out.write("HTTP/1.1 100 Continue\r\n\r\n".getBytes(StandardCharsets.US_ASCII)); + + assertThat(in.readLine()).isEqualTo("3"); + assertThat(in.readLine()).isEqualTo("foo"); + } + out.write(("HTTP/1.1 201 Created\r\n" + + "Connection: close\r\n" + + "Content-Length: 0\r\n" + + "\r\n").getBytes(StandardCharsets.US_ASCII)); + final AggregatedHttpResponse res = future.join(); + assertThat(res.status()).isEqualTo(HttpStatus.CREATED); + if (!send100Continue) { + // request body wasn't sent so cancelled. + assertThatThrownBy(() -> req.whenComplete().join()).hasCauseInstanceOf( + CancelledSubscriptionException.class); + } + } + } + } + + @Test + void webSocketFails() { + final Server server = Server.builder().service("/", (ctx, req) -> HttpResponse.of(200)).build(); + server.start().join(); + final WebSocketClient webSocketClient = WebSocketClient.of( + "h1c://127.0.0.1:" + server.activePort().localAddress().getPort()); + final CompletableFuture future = webSocketClient.connect( + "/", HttpHeaders.of(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE.toString())); + assertThatThrownBy(future::join) + .hasCauseInstanceOf(UnprocessedRequestException.class) + .cause() + .hasCauseInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("a WebSocket request is not allowed to have Expect: 100-continue header"); + server.stop().join(); + } + + private static void readSettingsFrame(InputStream in) throws Exception { + final byte[] expected = { + 0x00, 0x00, 0x0c, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, // SETTINGS_ENABLE_PUSH = 0 (disabled) + 0x00, 0x06, 0x00, 0x00, 0x20, 0x00 // MAX_HEADER_LIST_SIZE = 8192 + }; + assertThat(readBytes(in, expected.length)).containsExactly(expected); + } + + private static void sendEmptySettingsAndAckFrame(BufferedOutputStream bos) throws IOException { + // Send an empty SETTINGS frame. + bos.write(new byte[] { 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00 }); + // Send a SETTINGS_ACK frame. + bos.write(new byte[] { 0x00, 0x00, 0x00, 0x04, 0x01, 0x00, 0x00, 0x00, 0x00 }); + bos.flush(); + } + + private static void readHeadersFrame(DefaultHttp2HeadersDecoder headersDecoder, InputStream in, + boolean hasExpect100ContinueHeader) throws Exception { + final byte[] frameHeader = readBytes(in, 9); + final int payloadLength = payloadLength(frameHeader); + final byte[] headersPayload = readBytes(in, payloadLength); + + final ByteBuf payloadBuf = Unpooled.wrappedBuffer(headersPayload); + final Http2Headers headers = headersDecoder.decodeHeaders(0, payloadBuf); + + assertThat(get(headers, HttpHeaderNames.METHOD)).isEqualTo("POST"); + assertThat(get(headers, HttpHeaderNames.PATH)).isEqualTo("/"); + assertThat(get(headers, HttpHeaderNames.SCHEME)).isEqualTo("http"); + assertThat(get(headers, HttpHeaderNames.AUTHORITY)).startsWith("127.0.0.1"); + assertThat(get(headers, HttpHeaderNames.USER_AGENT)).startsWith("armeria/"); + assertThat(get(headers, HttpHeaderNames.CONTENT_TYPE)).isEqualTo("text/plain; charset=utf-8"); + if (hasExpect100ContinueHeader) { + assertThat(get(headers, HttpHeaderNames.EXPECT)).isEqualTo(HttpHeaderValues.CONTINUE.toString()); + } + } + + @Nullable + private static String get(Http2Headers headers, CharSequence name) { + final CharSequence value = headers.get(name); + return value != null ? value.toString() : null; + } + + private static void sendFrameHeaders(BufferedOutputStream bos, + HttpStatus status, + boolean endOfStream, int streamId) throws Exception { + final HPackEncoder encoder = new HPackEncoder(StandardCharsets.UTF_8); + final ByteArrayBuffer buffer = new ByteArrayBuffer(1024); + encoder.encodeHeader(buffer, ":status", status.codeAsText(), false); + final byte[] headersPayload = buffer.toByteArray(); + + final ByteBuf buf = Unpooled.buffer(FRAME_HEADER_LENGTH + headersPayload.length); + buf.writeMedium(headersPayload.length); + buf.writeByte(Http2FrameTypes.HEADERS); + buf.writeByte(new Http2Flags().endOfHeaders(true).endOfStream(endOfStream).value()); + buf.writeInt(streamId); + buf.writeBytes(headersPayload); + + bos.write(buf.array()); + bos.flush(); + } + + private static void readDataFrame(InputStream in) throws Exception { + final byte[] frameHeader = readBytes(in, 9); + final int payloadLength = payloadLength(frameHeader); + final byte[] payloadBuf = readBytes(in, payloadLength); + + assertThat(payloadBuf).containsExactly("foo".getBytes(StandardCharsets.UTF_8)); + } + + private HttpClientExpect100HeaderTest() {} +} diff --git a/core/src/test/java/com/linecorp/armeria/client/HttpResponseWrapperTest.java b/core/src/test/java/com/linecorp/armeria/client/HttpResponseWrapperTest.java index 092094e43fe..9b5a9e10047 100644 --- a/core/src/test/java/com/linecorp/armeria/client/HttpResponseWrapperTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/HttpResponseWrapperTest.java @@ -160,7 +160,7 @@ private static HttpResponseWrapper httpResponseWrapper(DecodedHttpResponse res) final TestHttpResponseDecoder decoder = new TestHttpResponseDecoder(channel, controller); res.init(controller); - return decoder.addResponse(1, res, cctx, cctx.eventLoop()); + return decoder.addResponse(null, 1, res, cctx, cctx.eventLoop()); } private static class TestHttpResponseDecoder extends AbstractHttpResponseDecoder { diff --git a/core/src/test/java/com/linecorp/armeria/client/WebClientExchangeTypeTest.java b/core/src/test/java/com/linecorp/armeria/client/WebClientExchangeTypeTest.java index ff73a1d8e93..6f161d017ca 100644 --- a/core/src/test/java/com/linecorp/armeria/client/WebClientExchangeTypeTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/WebClientExchangeTypeTest.java @@ -67,6 +67,18 @@ void fixedMessage() { }).isEqualTo(ExchangeType.RESPONSE_STREAMING); } + @Test + void headerOverridingFixedMessage() { + assertExchangeType(() -> { + client.execute(HttpRequest.of(HttpMethod.POST, "/", + MediaType.PLAIN_TEXT, "foo") + .withHeaders(RequestHeaders.builder(HttpMethod.POST, "/") + .add("foo", "bar") + .build())) + .aggregate(); + }).isEqualTo(ExchangeType.RESPONSE_STREAMING); + } + @Test void fixedMessageWithCustomRequestOptions() { assertExchangeType(() -> { diff --git a/core/src/test/java/com/linecorp/armeria/client/WriteTimeoutTest.java b/core/src/test/java/com/linecorp/armeria/client/WriteTimeoutTest.java index ace7c823ff8..3890ed8f262 100644 --- a/core/src/test/java/com/linecorp/armeria/client/WriteTimeoutTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/WriteTimeoutTest.java @@ -16,8 +16,11 @@ package com.linecorp.armeria.client; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import java.util.concurrent.CompletionException; + import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; @@ -29,6 +32,8 @@ import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.RequestHeadersBuilder; import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.logging.RequestLog; +import com.linecorp.armeria.internal.client.HttpSession; import com.linecorp.armeria.server.Route; import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.testing.junit5.server.ServerExtension; @@ -63,12 +68,20 @@ void testWriteTimeout() { headersBuilder.add("header1", Strings.repeat("a", 2048)); // set a header over 1KB // using h1c since http2 compresses headers - assertThatThrownBy(() -> WebClient.builder(SessionProtocol.H1C, server.httpEndpoint()) - .factory(clientFactory) - .writeTimeoutMillis(1000) - .build() - .blocking() - .execute(headersBuilder.build(), "content")) - .isInstanceOf(WriteTimeoutException.class); + try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) { + final HttpResponse res = WebClient.builder(SessionProtocol.H1C, server.httpEndpoint()) + .factory(clientFactory) + .writeTimeoutMillis(1000) + .build() + .execute(headersBuilder.build(), "content"); + final ClientRequestContext ctx = captor.get(); + assertThatThrownBy(() -> res.aggregate().join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(WriteTimeoutException.class); + + final RequestLog log = ctx.log().whenComplete().join(); + // Make sure that the session is deactivated after the write timeout. + assertThat(HttpSession.get(log.channel()).isAcquirable()).isFalse(); + } } } diff --git a/core/src/test/java/com/linecorp/armeria/client/endpoint/AbstractEndpointSelectorTest.java b/core/src/test/java/com/linecorp/armeria/client/endpoint/AbstractEndpointSelectorTest.java index 12c59716627..2c68acd34b9 100644 --- a/core/src/test/java/com/linecorp/armeria/client/endpoint/AbstractEndpointSelectorTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/endpoint/AbstractEndpointSelectorTest.java @@ -116,6 +116,17 @@ void testSelectionTimeoutException() { .hasRootCauseInstanceOf(EndpointSelectionTimeoutException.class); } + @Test + void testRampingUpInitialSelection() { + final DynamicEndpointGroup endpointGroup = + new DynamicEndpointGroup(EndpointSelectionStrategy.rampingUp()); + final Endpoint endpoint = Endpoint.of("foo.com"); + endpointGroup.setEndpoints(ImmutableList.of(endpoint)); + final ClientRequestContext ctx = ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + final Endpoint selected = endpointGroup.select(ctx, ctx.eventLoop()).join(); + assertThat(selected).isEqualTo(endpoint); + } + private static AbstractEndpointSelector newSelector(EndpointGroup endpointGroup) { final AbstractEndpointSelector selector = new AbstractEndpointSelector(endpointGroup) { diff --git a/core/src/test/java/com/linecorp/armeria/client/logging/LoggingClientTest.java b/core/src/test/java/com/linecorp/armeria/client/logging/LoggingClientTest.java index 2db15b1f270..73b115f53de 100644 --- a/core/src/test/java/com/linecorp/armeria/client/logging/LoggingClientTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/logging/LoggingClientTest.java @@ -378,7 +378,7 @@ void shouldLogFailedResponseWhenFailureSamplingRateIsAlways() throws Exception { // verify request log verify(logger).warn(argThat((String actLog) -> actLog.contains("Request:") && - actLog.endsWith("headers=[:method=GET, :path=/]}"))); + actLog.endsWith("headers=[:method=GET, :path=/]}"))); // verify response log verify(logger).warn(argThat((String actLog) -> actLog.contains("Response:") && diff --git a/core/src/test/java/com/linecorp/armeria/client/retry/RetryingClientWithEmptyEndpointGroupTest.java b/core/src/test/java/com/linecorp/armeria/client/retry/RetryingClientWithEmptyEndpointGroupTest.java index 00b0a7118bf..d70acd122c0 100644 --- a/core/src/test/java/com/linecorp/armeria/client/retry/RetryingClientWithEmptyEndpointGroupTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/retry/RetryingClientWithEmptyEndpointGroupTest.java @@ -86,6 +86,10 @@ void shouldRetryEvenIfEndpointGroupIsEmpty() { assertEmptyEndpointGroupException(log); assertThat(log.children()).hasSize(numAttempts); + for (int i = 0; i < log.children().size(); i++) { + assertThat(log.children().get(i).partial().currentAttempt()).isEqualTo(i + 1); + } + log.children().stream() .map(RequestLogAccess::ensureComplete) .forEach(RetryingClientWithEmptyEndpointGroupTest::assertEmptyEndpointGroupException); diff --git a/core/src/test/java/com/linecorp/armeria/common/AggregationOptionsTest.java b/core/src/test/java/com/linecorp/armeria/common/AggregationOptionsTest.java index ff64fa1ff0f..f7c38dd993e 100644 --- a/core/src/test/java/com/linecorp/armeria/common/AggregationOptionsTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/AggregationOptionsTest.java @@ -37,6 +37,7 @@ import com.google.common.collect.ImmutableList; import com.linecorp.armeria.common.stream.StreamMessage; +import com.linecorp.armeria.internal.common.HeaderOverridingHttpRequest; import io.netty.buffer.ByteBufAllocator; import reactor.core.publisher.Flux; diff --git a/core/src/test/java/com/linecorp/armeria/client/encoding/DefaultHttpDecodedResponseTest.java b/core/src/test/java/com/linecorp/armeria/common/encoding/DefaultHttpDecodedResponseTest.java similarity index 96% rename from core/src/test/java/com/linecorp/armeria/client/encoding/DefaultHttpDecodedResponseTest.java rename to core/src/test/java/com/linecorp/armeria/common/encoding/DefaultHttpDecodedResponseTest.java index 661ed31aba1..fe7a4fd4a02 100644 --- a/core/src/test/java/com/linecorp/armeria/client/encoding/DefaultHttpDecodedResponseTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/encoding/DefaultHttpDecodedResponseTest.java @@ -14,7 +14,7 @@ * under the License. */ -package com.linecorp.armeria.client.encoding; +package com.linecorp.armeria.common.encoding; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -46,6 +46,7 @@ import com.linecorp.armeria.client.ClientRequestContext; import com.linecorp.armeria.client.RequestOptions; +import com.linecorp.armeria.client.encoding.StreamDecoderFactory; import com.linecorp.armeria.common.AggregationOptions; import com.linecorp.armeria.common.ContentTooLargeException; import com.linecorp.armeria.common.HttpData; @@ -57,7 +58,6 @@ import com.linecorp.armeria.common.HttpResponseWriter; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.ResponseHeaders; -import com.linecorp.armeria.common.encoding.StreamDecoder; import com.linecorp.armeria.common.stream.AbortedStreamException; import com.linecorp.armeria.common.stream.SubscriptionOption; import com.linecorp.armeria.common.util.CompositeException; @@ -66,6 +66,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.handler.codec.compression.DecompressionException; +import io.netty.handler.codec.compression.SnappyFrameDecoder; import reactor.test.StepVerifier; class DefaultHttpDecodedResponseTest { @@ -311,6 +312,15 @@ public void onComplete() {} .hasCauseInstanceOf(CompositeException.class); } + @Test + void shouldExposeReasonWhenEncounterUnexpectedDecodeException() { + final HttpData httpData = HttpData.of(StandardCharsets.UTF_8, "Hello"); + final StreamDecoder decoder = new AbstractStreamDecoder(new SnappyFrameDecoder(), + ByteBufAllocator.DEFAULT, 100); + assertThatThrownBy(() -> decoder.decode(httpData)) + .isInstanceOf(DecompressionException.class); + } + private static HttpResponse newFailingDecodedResponse() { final HttpResponse delegate = HttpResponse.of(RESPONSE_HEADERS, HttpData.ofUtf8("Hello")); final ClientRequestContext ctx = diff --git a/core/src/test/java/com/linecorp/armeria/common/logging/JsonLogFormatterTest.java b/core/src/test/java/com/linecorp/armeria/common/logging/JsonLogFormatterTest.java index ffb70b48b3a..41a31695f22 100644 --- a/core/src/test/java/com/linecorp/armeria/common/logging/JsonLogFormatterTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/logging/JsonLogFormatterTest.java @@ -16,23 +16,33 @@ package com.linecorp.armeria.common.logging; +import static net.javacrumbs.jsonunit.fluent.JsonFluentAssert.assertThatJson; import static org.assertj.core.api.Assertions.assertThat; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Stream; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.NullSource; import com.fasterxml.jackson.databind.JsonNode; +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.client.ClientRequestContextBuilder; +import com.linecorp.armeria.client.Endpoint; import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestId; import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.server.ServiceRequestContext; class JsonLogFormatterTest { @@ -62,6 +72,23 @@ void formatResponse() { "\"duration\":\".+\",\"totalDuration\":\".+\",\"headers\":\\{\".+\"}}$"); } + @Test + void derivedLog() { + final LogFormatter logFormatter = LogFormatter.ofJson(); + final HttpRequest request = HttpRequest.of(HttpMethod.GET, "/format"); + final ClientRequestContext ctx = ClientRequestContext.of(request); + final ClientRequestContext derivedCtx = + ctx.newDerivedContext(RequestId.of(1), request, null, Endpoint.of("127.0.0.1")); + final DefaultRequestLog log = (DefaultRequestLog) derivedCtx.log(); + ctx.logBuilder().addChild(log); + log.endRequest(); + final String requestLog = logFormatter.formatRequest(log); + assertThat(requestLog) + .matches("^\\{\"type\":\"request\",\"startTime\":\".+\",\"length\":\".+\"," + + "\"duration\":\".+\",\"scheme\":\".+\",\"name\":\".+\",\"headers\":\\{\".+\"}" + + ",\"currentAttempt\":1}$"); + } + @Test void maskSensitiveHeadersByDefault() { final LogFormatter logFormatter = LogFormatter.builderForJson() @@ -212,4 +239,83 @@ void removeSensitiveHeaders() { assertThat(matcher3.find()).isTrue(); assertThat(matcher3.group(1)).isEqualTo("no-cache"); } + + static Stream connectionTimingsAreLoggedIfExistParams() { + return Stream.of( + Arguments.of(ClientConnectionTimings.builder() + .build()), + Arguments.of(ClientConnectionTimings.builder() + .dnsResolutionEnd() + .build()), + Arguments.of(ClientConnectionTimings.builder() + .tlsHandshakeStart() + .tlsHandshakeEnd() + .build()), + Arguments.of(ClientConnectionTimings.builder() + .dnsResolutionEnd() + .pendingAcquisitionStart() + .pendingAcquisitionEnd() + .socketConnectStart() + .socketConnectEnd() + .tlsHandshakeStart() + .tlsHandshakeEnd() + .build()) + ); + } + + @ParameterizedTest + @NullSource + @MethodSource("connectionTimingsAreLoggedIfExistParams") + void connectionTimingsAreLoggedIfExist(@Nullable ClientConnectionTimings timings) { + final LogFormatter logFormatter = JsonLogFormatter.DEFAULT_INSTANCE; + final HttpRequest req = HttpRequest.of(RequestHeaders.of(HttpMethod.GET, "/")); + final ClientRequestContextBuilder builder = ClientRequestContext.builder(req); + if (timings != null) { + builder.connectionTimings(timings); + } + final ClientRequestContext ctx = builder.build(); + final RequestLogBuilder logBuilder = ctx.logBuilder(); + logBuilder.endRequest(); + final String formatted = logFormatter.formatRequest(logBuilder.partial()); + if (timings == null) { + assertThatJson(formatted).node("connection").isAbsent(); + return; + } + + assertThatJson(formatted) + .node("connection.total.durationNanos") + .isEqualTo(timings.connectionAcquisitionDurationNanos()); + + if (timings.dnsResolutionDurationNanos() >= 0) { + assertThatJson(formatted) + .node("connection.dns.durationNanos") + .isEqualTo(timings.dnsResolutionDurationNanos()); + } else { + assertThatJson(formatted).node("connection.dns.durationNanos").isAbsent(); + } + + if (timings.pendingAcquisitionDurationNanos() >= 0) { + assertThatJson(formatted) + .node("connection.pending.durationNanos") + .isEqualTo(timings.pendingAcquisitionDurationNanos()); + } else { + assertThatJson(formatted).node("connection.pending.durationNanos").isAbsent(); + } + + if (timings.socketConnectDurationNanos() >= 0) { + assertThatJson(formatted) + .node("connection.socket.durationNanos") + .isEqualTo(timings.socketConnectDurationNanos()); + } else { + assertThatJson(formatted).node("connection.socket.durationNanos").isAbsent(); + } + + if (timings.tlsHandshakeDurationNanos() >= 0) { + assertThatJson(formatted) + .node("connection.tls.durationNanos") + .isEqualTo(timings.tlsHandshakeDurationNanos()); + } else { + assertThatJson(formatted).node("connection.tls.durationNanos").isAbsent(); + } + } } diff --git a/core/src/test/java/com/linecorp/armeria/common/logging/TextLogFormatterTest.java b/core/src/test/java/com/linecorp/armeria/common/logging/TextLogFormatterTest.java index 3eeeafff8d5..cd01a6062cd 100644 --- a/core/src/test/java/com/linecorp/armeria/common/logging/TextLogFormatterTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/logging/TextLogFormatterTest.java @@ -20,17 +20,27 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Stream; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.NullSource; +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.client.ClientRequestContextBuilder; +import com.linecorp.armeria.client.Endpoint; import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestId; import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.util.TextFormatter; import com.linecorp.armeria.server.ServiceRequestContext; class TextLogFormatterTest { @@ -69,6 +79,23 @@ void formatResponse(boolean containContext) { } } + @Test + void derivedLog() { + final LogFormatter logFormatter = LogFormatter.builderForText().build(); + final HttpRequest request = HttpRequest.of(HttpMethod.GET, "/format"); + final ClientRequestContext ctx = ClientRequestContext.of(request); + final ClientRequestContext derivedCtx = + ctx.newDerivedContext(RequestId.of(1), request, null, Endpoint.of("127.0.0.1")); + final DefaultRequestLog log = (DefaultRequestLog) derivedCtx.log(); + ctx.logBuilder().addChild(log); + log.endRequest(); + final String requestLog = logFormatter.formatRequest(log); + final String regex = + ".*Request: .*\\{startTime=.+, length=.+, duration=.+, scheme=.+, name=.+, headers=.+" + + "currentAttempt=1}$"; + assertThat(requestLog).matches(regex); + } + @Test void maskSensitiveHeadersByDefault() { final LogFormatter logFormatter = LogFormatter.builderForText() @@ -131,7 +158,6 @@ void maskRequestHeaders() { final DefaultRequestLog log = (DefaultRequestLog) ctx.log(); log.endRequest(); final String requestLog = logFormatter.formatRequest(log); - System.out.println(requestLog); final Matcher matcher1 = Pattern.compile("cookie=(.*?)[,\\]]").matcher(requestLog); assertThat(matcher1.find()).isTrue(); assertThat(matcher1.group(1)).isEqualTo( @@ -235,4 +261,106 @@ void removeSensitiveHeaders() { assertThat(matcher3.find()).isTrue(); assertThat(matcher3.group(1)).isEqualTo("no-cache"); } + + static Stream connectionTimingsAreLoggedIfExistParams() { + return Stream.of( + Arguments.of(ClientConnectionTimings.builder() + .build()), + Arguments.of(ClientConnectionTimings.builder() + .dnsResolutionEnd() + .build()), + Arguments.of(ClientConnectionTimings.builder() + .tlsHandshakeStart() + .tlsHandshakeEnd() + .build()), + Arguments.of(ClientConnectionTimings.builder() + .dnsResolutionEnd() + .pendingAcquisitionStart() + .pendingAcquisitionEnd() + .socketConnectStart() + .socketConnectEnd() + .tlsHandshakeStart() + .tlsHandshakeEnd() + .build()) + ); + } + + @ParameterizedTest + @NullSource + @MethodSource("connectionTimingsAreLoggedIfExistParams") + void connectionTimingsAreLoggedIfExist(@Nullable ClientConnectionTimings timings) { + final LogFormatter logFormatter = TextLogFormatter.DEFAULT_INSTANCE; + final HttpRequest req = HttpRequest.of(RequestHeaders.of(HttpMethod.GET, "/")); + final ClientRequestContextBuilder builder = ClientRequestContext.builder(req); + if (timings != null) { + builder.connectionTimings(timings); + } + final ClientRequestContext ctx = builder.build(); + final RequestLogBuilder logBuilder = ctx.logBuilder(); + logBuilder.endRequest(); + final String formatted = logFormatter.formatRequest(logBuilder.partial()); + + final Matcher connStartMatcher = Pattern.compile("Connection: \\{total=([^\\s,}]+)").matcher(formatted); + if (timings == null) { + assertThat(connStartMatcher.find()).isFalse(); + return; + } + assertThat(connStartMatcher.find()).isTrue(); + assertThat(connStartMatcher.group(1)).isEqualTo( + epochAndElapsed(timings.connectionAcquisitionStartTimeMicros(), + timings.connectionAcquisitionDurationNanos())); + + final Matcher dnsMatcher = Pattern.compile("dns=([^\\s,}]+)").matcher(formatted); + if (timings.dnsResolutionDurationNanos() >= 0) { + assertThat(dnsMatcher.find()).isTrue(); + assertThat(dnsMatcher.group(1)).isEqualTo( + epochAndElapsed(timings.dnsResolutionStartTimeMicros(), + timings.dnsResolutionDurationNanos())); + } else { + assertThat(dnsMatcher.find()).isFalse(); + } + + final Matcher pendingMatcher = Pattern.compile("pending=([^\\s,}]+)") + .matcher(formatted); + if (timings.pendingAcquisitionDurationNanos() >= 0) { + assertThat(pendingMatcher.find()).isTrue(); + assertThat(pendingMatcher.group(1)).isEqualTo( + epochAndElapsed(timings.pendingAcquisitionStartTimeMicros(), + timings.pendingAcquisitionDurationNanos())); + } else { + assertThat(pendingMatcher.find()).isFalse(); + } + + final Matcher socketMatcher = Pattern.compile("socket=([^\\s,}]+)").matcher(formatted); + if (timings.pendingAcquisitionDurationNanos() >= 0) { + assertThat(socketMatcher.find()).isTrue(); + assertThat(socketMatcher.group(1)).isEqualTo( + epochAndElapsed(timings.socketConnectStartTimeMicros(), + timings.socketConnectDurationNanos())); + } else { + assertThat(socketMatcher.find()).isFalse(); + } + + final Matcher tlsMatcher = Pattern.compile("tls=([^\\s,}]+)").matcher(formatted); + if (timings.tlsHandshakeDurationNanos() >= 0) { + assertThat(tlsMatcher.find()).isTrue(); + assertThat(tlsMatcher.group(1)).isEqualTo( + epochAndElapsed(timings.tlsHandshakeStartTimeMicros(), + timings.tlsHandshakeDurationNanos())); + } else { + assertThat(tlsMatcher.find()).isFalse(); + } + } + + @Test + void epochAndElapsedTest() { + assertThat(epochAndElapsed(1717987526233123L, 456L)) + .isEqualTo("2024-06-10T02:45:26.233123Z[456ns]"); + } + + private static String epochAndElapsed(long epochMicros, long durationNanos) { + final StringBuilder sb = new StringBuilder(); + TextFormatter.appendEpochAndElapsed(sb, epochMicros, durationNanos); + return sb.toString(); + } } diff --git a/core/src/test/java/com/linecorp/armeria/common/stream/StreamMessageDuplicatorChildSubscriberTest.java b/core/src/test/java/com/linecorp/armeria/common/stream/StreamMessageDuplicatorChildSubscriberTest.java new file mode 100644 index 00000000000..07571d481db --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/common/stream/StreamMessageDuplicatorChildSubscriberTest.java @@ -0,0 +1,104 @@ +/* + * Copyright 2021 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.common.stream; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.concurrent.CountDownLatch; + +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.linecorp.armeria.testing.junit5.common.EventLoopExtension; + +import io.netty.channel.EventLoop; + +class StreamMessageDuplicatorChildSubscriberTest { + + @RegisterExtension + static final EventLoopExtension eventLoop1 = new EventLoopExtension(); + + @RegisterExtension + static final EventLoopExtension eventLoop2 = new EventLoopExtension(); + + @RegisterExtension + static final EventLoopExtension eventLoop3 = new EventLoopExtension(); + + @CsvSource({ "true", "false" }) + @ParameterizedTest + void childSubscriberMethodsMustBeCalledByExecutors(boolean close) throws InterruptedException { + final StreamWriter publisher = StreamMessage.streaming(); + publisher.write("foo"); + if (close) { + publisher.close(); + } else { + publisher.abort(); + } + + final StreamMessageDuplicator duplicator = + publisher.toDuplicator(eventLoop1.get()); + + final StreamMessage first = duplicator.duplicate(); + final StreamMessage second = duplicator.duplicate(); + + duplicator.close(); + + final CountDownLatch latch = new CountDownLatch(2); + final EventLoop executor2 = eventLoop2.get(); + first.subscribe(new ChildSubscriber(executor2, latch), executor2); + + final EventLoop executor3 = eventLoop3.get(); + second.subscribe(new ChildSubscriber(executor3, latch), executor3); + latch.await(); + } + + private static final class ChildSubscriber implements Subscriber { + + private final EventLoop eventLoop; + private final CountDownLatch latch; + + ChildSubscriber(EventLoop eventLoop, CountDownLatch latch) { + this.eventLoop = eventLoop; + this.latch = latch; + } + + @Override + public void onSubscribe(Subscription s) { + assertThat(eventLoop.inEventLoop()).isTrue(); + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(String data) { + assertThat(eventLoop.inEventLoop()).isTrue(); + } + + @Override + public void onError(Throwable t) { + assertThat(eventLoop.inEventLoop()).isTrue(); + latch.countDown(); + } + + @Override + public void onComplete() { + assertThat(eventLoop.inEventLoop()).isTrue(); + latch.countDown(); + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/common/stream/StreamMessageDuplicatorTest.java b/core/src/test/java/com/linecorp/armeria/common/stream/StreamMessageDuplicatorTest.java index 9f935e4e590..9e4850b20f8 100644 --- a/core/src/test/java/com/linecorp/armeria/common/stream/StreamMessageDuplicatorTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/stream/StreamMessageDuplicatorTest.java @@ -42,8 +42,6 @@ import org.mockito.ArgumentCaptor; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import com.google.common.base.Charsets; @@ -61,8 +59,6 @@ class StreamMessageDuplicatorTest { - private static final Logger logger = LoggerFactory.getLogger(StreamMessageDuplicatorTest.class); - private static final List byteBufs = new ArrayList<>(); @AfterEach diff --git a/core/src/test/java/com/linecorp/armeria/internal/client/dns/DefaultDnsResolverTest.java b/core/src/test/java/com/linecorp/armeria/internal/client/dns/DefaultDnsResolverTest.java index 4c7190894dd..4328afba130 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/client/dns/DefaultDnsResolverTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/client/dns/DefaultDnsResolverTest.java @@ -155,8 +155,6 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception NoopDnsCache.INSTANCE, eventLoop, ImmutableList.of(), 1, queryTimeoutMillis, HostsFileEntriesResolver.DEFAULT); - final DnsQuestionContext ctx = new DnsQuestionContext(eventLoop, queryTimeoutMillis); - final Stopwatch stopwatch = Stopwatch.createStarted(); final List questions; if (resolvedAddressType == ResolvedAddressTypes.IPV4_PREFERRED) { @@ -169,9 +167,9 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception new DefaultDnsQuestion("foo.com.", DnsRecordType.A)); } - // resolver.resolveAll() should be executed by the event loop set to DnsNameResolver. + // resolver.resolve() should be executed by the event loop set to DnsNameResolver. final CompletableFuture> result = eventLoop.submit(() -> { - return resolver.resolveAll(ctx, questions, ""); + return resolver.resolve(questions, ""); }).get(); final List records = result.join(); diff --git a/core/src/test/java/com/linecorp/armeria/server/ConnectionLimitingHandlerTest.java b/core/src/test/java/com/linecorp/armeria/server/ConnectionLimitingHandlerTest.java index 174fd06028f..c814a955770 100644 --- a/core/src/test/java/com/linecorp/armeria/server/ConnectionLimitingHandlerTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/ConnectionLimitingHandlerTest.java @@ -26,7 +26,9 @@ class ConnectionLimitingHandlerTest { @Test void testExceedMaxNumConnections() { - final ConnectionLimitingHandler handler = new ConnectionLimitingHandler(1); + final ServerMetrics serverMetrics = new ServerMetrics(); + final ConnectionLimitingHandler handler = + new ConnectionLimitingHandler(1, serverMetrics); final EmbeddedChannel ch1 = new EmbeddedChannel(handler); ch1.writeInbound(ch1); @@ -44,13 +46,15 @@ void testExceedMaxNumConnections() { @Test void testMaxNumConnectionsRange() { - final ConnectionLimitingHandler handler = new ConnectionLimitingHandler(Integer.MAX_VALUE); + final ServerMetrics serverMetrics = new ServerMetrics(); + final ConnectionLimitingHandler handler = new ConnectionLimitingHandler(Integer.MAX_VALUE, + serverMetrics); assertThat(handler.maxNumConnections()).isEqualTo(Integer.MAX_VALUE); - assertThatThrownBy(() -> new ConnectionLimitingHandler(0)) + assertThatThrownBy(() -> new ConnectionLimitingHandler(0, serverMetrics)) .isInstanceOf(IllegalArgumentException.class); - assertThatThrownBy(() -> new ConnectionLimitingHandler(-1)) + assertThatThrownBy(() -> new ConnectionLimitingHandler(-1, serverMetrics)) .isInstanceOf(IllegalArgumentException.class); } } diff --git a/core/src/test/java/com/linecorp/armeria/server/EmptyContentDecodedHttpRequestTest.java b/core/src/test/java/com/linecorp/armeria/server/EmptyContentDecodedHttpRequestTest.java index 3f0b739ede5..837a6ced875 100644 --- a/core/src/test/java/com/linecorp/armeria/server/EmptyContentDecodedHttpRequestTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/EmptyContentDecodedHttpRequestTest.java @@ -17,6 +17,7 @@ package com.linecorp.armeria.server; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; @@ -35,9 +36,10 @@ class EmptyContentDecodedHttpRequestTest { @Test void emptyContent() { + final RoutingContext routingContext = mock(RoutingContext.class); final RequestHeaders headers = RequestHeaders.of(HttpMethod.GET, "/"); final EmptyContentDecodedHttpRequest req = - new EmptyContentDecodedHttpRequest(eventLoop.get(), 1, 3, headers, true, null, + new EmptyContentDecodedHttpRequest(eventLoop.get(), 1, 3, headers, true, routingContext, ExchangeType.BIDI_STREAMING, 0, 0); StepVerifier.create(req) diff --git a/core/src/test/java/com/linecorp/armeria/server/ProtocolViolationHandlingTest.java b/core/src/test/java/com/linecorp/armeria/server/ProtocolViolationHandlingTest.java new file mode 100644 index 00000000000..58fc88eb313 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/server/ProtocolViolationHandlingTest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.linecorp.armeria.client.BlockingWebClient; +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import joptsimple.internal.Strings; + +class ProtocolViolationHandlingTest { + + @RegisterExtension + static ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + sb.http1MaxInitialLineLength(100); + sb.service("/", (ctx, req) -> HttpResponse.of(HttpStatus.OK)); + sb.errorHandler(new ServerErrorHandler() { + @Override + public @Nullable HttpResponse onServiceException(ServiceRequestContext ctx, Throwable cause) { + return null; + } + + @Override + public @Nullable AggregatedHttpResponse onProtocolViolation(ServiceConfig config, + @Nullable RequestHeaders headers, + HttpStatus status, + @Nullable String description, + @Nullable Throwable cause) { + return AggregatedHttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT, + "Custom response"); + } + }); + } + }; + + @Test + void shouldHandleInvalidHttp1Request() { + final BlockingWebClient client = BlockingWebClient.of(server.uri(SessionProtocol.H1C)); + final AggregatedHttpResponse res = client.get("/?" + Strings.repeat('a', 100)); + assertThat(res.status()).isEqualTo(HttpStatus.BAD_REQUEST); + assertThat(res.contentUtf8()).isEqualTo("Custom response"); + } + + @Test + void shouldHandleInvalidHttp2Request() { + final BlockingWebClient client = BlockingWebClient.of(server.uri(SessionProtocol.H2C)); + final AggregatedHttpResponse res = client.get("*"); + assertThat(res.status()).isEqualTo(HttpStatus.BAD_REQUEST); + assertThat(res.contentUtf8()).isEqualTo("Custom response"); + } +} diff --git a/core/src/test/java/com/linecorp/armeria/server/RoutingContextTest.java b/core/src/test/java/com/linecorp/armeria/server/RoutingContextTest.java index a1fe3e7d0d3..d0fa916e080 100644 --- a/core/src/test/java/com/linecorp/armeria/server/RoutingContextTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/RoutingContextTest.java @@ -18,6 +18,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.util.List; @@ -141,6 +142,7 @@ static RoutingContext create(VirtualHost virtualHost, String path, @Nullable Str static VirtualHost virtualHost() { final HttpService service = mock(HttpService.class); + when(service.options()).thenReturn(ServiceOptions.of()); final Server server = Server.builder() .virtualHost("example.com") .serviceUnder("/", service) diff --git a/core/src/test/java/com/linecorp/armeria/server/ServerBuilderTest.java b/core/src/test/java/com/linecorp/armeria/server/ServerBuilderTest.java index fba4af0b045..11912c93a7c 100644 --- a/core/src/test/java/com/linecorp/armeria/server/ServerBuilderTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/ServerBuilderTest.java @@ -54,6 +54,7 @@ import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.prometheus.PrometheusMeterRegistries; import com.linecorp.armeria.common.util.DomainSocketAddress; +import com.linecorp.armeria.common.util.TlsEngineType; import com.linecorp.armeria.common.util.TransportType; import com.linecorp.armeria.internal.common.util.MinifiedBouncyCastleProvider; import com.linecorp.armeria.internal.testing.MockAddressResolverGroup; @@ -555,6 +556,37 @@ void tlsPkcsPrivateKeysWithCustomizer(String privateKeyFileName) { .build(); } + @Test + void tlsEngineType() { + final Server sb1 = Server.builder() + .service("/example", (ctx, req) -> HttpResponse.of(HttpStatus.OK)) + .build(); + assertThat(sb1.config().defaultVirtualHost().tlsEngineType()).isEqualTo(TlsEngineType.OPENSSL); + + final Server sb2 = Server.builder() + .tlsSelfSigned() + .service("/example", (ctx, req) -> HttpResponse.of(HttpStatus.OK)) + .tlsEngineType(TlsEngineType.OPENSSL) + .virtualHost("*.example1.com") + .service("/example", (ctx, req) -> HttpResponse.of(HttpStatus.OK)) + .tlsSelfSigned() + .tlsEngineType(TlsEngineType.JDK) + .and() + .virtualHost("*.example2.com") + .service("/example", (ctx, req) -> HttpResponse.of(HttpStatus.OK)) + .tlsSelfSigned() + .and() + .virtualHost("*.example3.com") + .service("/example", (ctx, req) -> HttpResponse.of(HttpStatus.OK)) + .and() + .build(); + assertThat(sb2.config().defaultVirtualHost().tlsEngineType()).isEqualTo(TlsEngineType.OPENSSL); + assertThat(sb2.config().findVirtualHost("*.example1.com", 8080).tlsEngineType()) + .isEqualTo(TlsEngineType.JDK); + assertThat(sb2.config().findVirtualHost("*.example2.com", 8080).tlsEngineType()) + .isEqualTo(TlsEngineType.OPENSSL); + } + @Test void monitorBlockingTaskExecutorAndSchedulersTogetherWithPrometheus() { final PrometheusMeterRegistry registry = PrometheusMeterRegistries.newRegistry(); diff --git a/core/src/test/java/com/linecorp/armeria/server/ServerErrorHandlerTest.java b/core/src/test/java/com/linecorp/armeria/server/ServerErrorHandlerTest.java index 9addaac8473..fa493b69e3f 100644 --- a/core/src/test/java/com/linecorp/armeria/server/ServerErrorHandlerTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/ServerErrorHandlerTest.java @@ -42,11 +42,17 @@ class ServerErrorHandlerTest { protected void configure(ServerBuilder sb) throws Exception { sb.route() .get("/foo") - .errorHandler((ctx, cause) -> null) + .errorHandler((ctx, cause) -> { + assertThat(ServiceRequestContext.current()).isSameAs(ctx); + return null; + }) .build((ctx, req) -> { throw new RuntimeException(); }); - sb.errorHandler((ctx, cause) -> HttpResponse.of(HttpStatus.BAD_REQUEST)); + sb.errorHandler((ctx, cause) -> { + assertThat(ServiceRequestContext.current()).isSameAs(ctx); + return HttpResponse.of(HttpStatus.BAD_REQUEST); + }); } }; diff --git a/core/src/test/java/com/linecorp/armeria/server/ServerMetricsTest.java b/core/src/test/java/com/linecorp/armeria/server/ServerMetricsTest.java new file mode 100644 index 00000000000..43106ff8cf3 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/server/ServerMetricsTest.java @@ -0,0 +1,305 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.server; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.assertj.core.api.Condition; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import com.linecorp.armeria.client.BlockingWebClient; +import com.linecorp.armeria.client.ClientFactory; +import com.linecorp.armeria.client.ClientOptions; +import com.linecorp.armeria.client.WebClient; +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.ExchangeType; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpRequestWriter; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.metric.MeterIdPrefixFunction; +import com.linecorp.armeria.common.metric.MoreMeters; +import com.linecorp.armeria.common.prometheus.PrometheusMeterRegistries; +import com.linecorp.armeria.server.metric.MetricCollectingService; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import io.micrometer.prometheusmetrics.PrometheusMeterRegistry; + +class ServerMetricsTest { + + @RegisterExtension + final ServerExtension server = new ServerExtension() { + @Override + protected boolean runForEachTest() { + return true; + } + + @Override + protected void configure(ServerBuilder sb) throws Exception { + final PrometheusMeterRegistry prometheusMeterRegistry = PrometheusMeterRegistries.newRegistry(); + sb.meterRegistry(prometheusMeterRegistry); + // Use 'armeria.server' to make sure that the metric names are not conflicted with `ServerMetrics`. + sb.decorator(MetricCollectingService.newDecorator( + MeterIdPrefixFunction.ofDefault("armeria.server"))); + + sb.requestTimeoutMillis(0) + .requestAutoAbortDelayMillis(0) + .service("/ok/http", new HttpService() { + @Override + public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exception { + final ServerMetrics serverMetrics = server.server().config().serverMetrics(); + assertThat(serverMetrics.pendingRequests()).isZero(); + if (ctx.sessionProtocol().isMultiplex()) { + assertThat(serverMetrics.activeHttp2Requests()).isOne(); + } else { + assertThat(serverMetrics.activeHttp1Requests()).isOne(); + } + assertThat(serverMetrics.activeRequests()).isOne(); + return HttpResponse.of("Hello, world!"); + } + + @Override + public ExchangeType exchangeType(RoutingContext routingContext) { + return ExchangeType.UNARY; + } + }) + .service("/server-error/http1", new HttpService() { + @Override + public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exception { + final ServerMetrics serverMetrics = server.server().config().serverMetrics(); + assertThat(serverMetrics.pendingRequests()).isZero(); + assertThat(serverMetrics.activeHttp1Requests()).isOne(); + assertThat(serverMetrics.activeRequests()).isOne(); + throw new IllegalArgumentException("Oops!"); + } + + @Override + public ExchangeType exchangeType(RoutingContext routingContext) { + return ExchangeType.UNARY; + } + }).service("/server-error/http2", new HttpService() { + @Override + public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exception { + final ServerMetrics serverMetrics = server.server().config().serverMetrics(); + assertThat(serverMetrics.pendingRequests()).isZero(); + assertThat(serverMetrics.activeHttp2Requests()).isOne(); + assertThat(serverMetrics.activeRequests()).isOne(); + throw new IllegalArgumentException("Oops!"); + } + + @Override + public ExchangeType exchangeType(RoutingContext routingContext) { + return ExchangeType.UNARY; + } + }).service("/request-timeout/http", new HttpService() { + @Override + public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exception { + final ServerMetrics serverMetrics = server.server().config().serverMetrics(); + assertThat(serverMetrics.pendingRequests()).isZero(); + if (ctx.sessionProtocol().isMultiplex()) { + assertThat(serverMetrics.activeHttp2Requests()).isOne(); + } else { + assertThat(serverMetrics.activeHttp1Requests()).isOne(); + } + assertThat(serverMetrics.activeRequests()).isOne(); + ctx.timeoutNow(); + return HttpResponse.delayed(HttpResponse.of(200), Duration.ofSeconds(1)); + } + + @Override + public ExchangeType exchangeType(RoutingContext routingContext) { + return ExchangeType.UNARY; + } + }); + } + }; + + @Test + void pendingRequests() { + final ServerMetrics serverMetrics = new ServerMetrics(); + + serverMetrics.increasePendingHttp1Requests(); + assertThat(serverMetrics.pendingRequests()).isEqualTo(1); + + serverMetrics.increasePendingHttp2Requests(); + assertThat(serverMetrics.pendingRequests()).isEqualTo(2); + + serverMetrics.decreasePendingHttp1Requests(); + assertThat(serverMetrics.pendingRequests()).isEqualTo(1); + + serverMetrics.decreasePendingHttp2Requests(); + assertThat(serverMetrics.pendingRequests()).isZero(); + } + + @Test + void activeRequests() { + final ServerMetrics serverMetrics = new ServerMetrics(); + + serverMetrics.increaseActiveHttp1Requests(); + assertThat(serverMetrics.activeRequests()).isEqualTo(1); + + serverMetrics.increaseActiveHttp1WebSocketRequests(); + assertThat(serverMetrics.activeRequests()).isEqualTo(2); + + serverMetrics.increaseActiveHttp2Requests(); + assertThat(serverMetrics.activeRequests()).isEqualTo(3); + + serverMetrics.decreaseActiveHttp1WebSocketRequests(); + assertThat(serverMetrics.activeRequests()).isEqualTo(2); + + serverMetrics.decreaseActiveHttp1Requests(); + assertThat(serverMetrics.activeRequests()).isEqualTo(1); + + serverMetrics.decreaseActiveHttp2Requests(); + assertThat(serverMetrics.activeRequests()).isZero(); + } + + @CsvSource({ "H1C, 1, 0", "H2C, 0, 1" }) + @ParameterizedTest + void checkWhenOk(SessionProtocol sessionProtocol, long expectedPendingHttp1Request, + long expectedPendingHttp2Request) throws InterruptedException { + // maxConnectionAgeMillis() method is for testing whether activeConnections is decreased. + try (ClientFactory clientFactory = ClientFactory.builder() + .maxConnectionAgeMillis(1000) + .build()) { + final WebClient webClient = WebClient.builder(server.uri(sessionProtocol)) + .factory(clientFactory) + .build(); + + final HttpRequestWriter request = HttpRequest.streaming(HttpMethod.POST, "/ok/http"); + final CompletableFuture response = webClient.execute(request) + .aggregate(); + + final ServerMetrics serverMetrics = server.server() + .config() + .serverMetrics(); + await().until(() -> serverMetrics.pendingRequests() == 1); + assertThat(serverMetrics.pendingHttp1Requests()).isEqualTo(expectedPendingHttp1Request); + assertThat(serverMetrics.pendingHttp2Requests()).isEqualTo(expectedPendingHttp2Request); + assertThat(serverMetrics.activeConnections()).isOne(); + request.close(); + + final AggregatedHttpResponse result = response.join(); + + assertThat(result.status()).isSameAs(HttpStatus.OK); + assertThat(serverMetrics.pendingRequests()).isZero(); + await().untilAsserted(() -> assertThat(serverMetrics.activeRequests()).isZero()); + await().until(() -> serverMetrics.activeConnections() == 0); + } + } + + @CsvSource({ "H1C, /server-error/http1, 1, 0", "H2C, /server-error/http2, 0, 1" }) + @ParameterizedTest + void checkWhenServerError(SessionProtocol sessionProtocol, String path, long expectedPendingHttp1Request, + long expectedPendingHttp2Request) throws InterruptedException { + try (ClientFactory clientFactory = ClientFactory.builder() + .maxConnectionAgeMillis(1000) + .build()) { + final WebClient webClient = WebClient.builder(server.uri(sessionProtocol)) + .factory(clientFactory) + .build(); + + final HttpRequestWriter request = HttpRequest.streaming(HttpMethod.POST, path); + final CompletableFuture response = webClient.execute(request) + .aggregate(); + + final ServerMetrics serverMetrics = server.server() + .config() + .serverMetrics(); + await().until(() -> serverMetrics.pendingRequests() == 1); + assertThat(serverMetrics.pendingHttp1Requests()).isEqualTo(expectedPendingHttp1Request); + assertThat(serverMetrics.pendingHttp2Requests()).isEqualTo(expectedPendingHttp2Request); + assertThat(serverMetrics.activeConnections()).isOne(); + request.close(); + + final AggregatedHttpResponse result = response.join(); + + assertThat(result.status()).isSameAs(HttpStatus.INTERNAL_SERVER_ERROR); + assertThat(serverMetrics.pendingRequests()).isZero(); + assertThat(serverMetrics.activeRequests()).isZero(); + await().until(() -> serverMetrics.activeConnections() == 0); + } + } + + @CsvSource({ "H1C, 1, 0", "H2C, 0, 1" }) + @ParameterizedTest + void checkWhenRequestTimeout(SessionProtocol sessionProtocol, long expectedPendingHttp1Request, + long expectedPendingHttp2Request) throws InterruptedException { + try (ClientFactory clientFactory = ClientFactory.builder() + .maxConnectionAgeMillis(1000) + .build()) { + final WebClient webClient = WebClient.builder(server.uri(sessionProtocol)) + .option(ClientOptions.RESPONSE_TIMEOUT_MILLIS.newValue(0L)) + .factory(clientFactory) + .build(); + + final HttpRequestWriter request = HttpRequest.streaming(HttpMethod.POST, "/request-timeout/http"); + final CompletableFuture response = webClient.execute(request) + .aggregate(); + + final ServerMetrics serverMetrics = server.server() + .config() + .serverMetrics(); + await().until(() -> serverMetrics.pendingRequests() == 1); + assertThat(serverMetrics.pendingHttp1Requests()).isEqualTo(expectedPendingHttp1Request); + assertThat(serverMetrics.pendingHttp2Requests()).isEqualTo(expectedPendingHttp2Request); + assertThat(serverMetrics.activeConnections()).isOne(); + request.close(); + + final AggregatedHttpResponse result = response.join(); + + assertThat(result.status()).isSameAs(HttpStatus.SERVICE_UNAVAILABLE); + assertThat(serverMetrics.pendingRequests()).isZero(); + await().untilAsserted(() -> assertThat(serverMetrics.activeRequests()).isZero()); + await().until(() -> serverMetrics.activeConnections() == 0); + } + } + + @CsvSource({ "H1C", "H2C" }) + @ParameterizedTest + void meterNames(SessionProtocol protocol) { + final BlockingWebClient client = BlockingWebClient.of(server.uri(protocol)); + assertThat(client.get("/ok/http").status()).isEqualTo(HttpStatus.OK); + + await().untilAsserted(() -> { + final Map meters = MoreMeters.measureAll(server.server().meterRegistry()); + // armeria.server.active.requests#value is measured by MetricCollectingService + assertThat(meters).hasKeySatisfying(new Condition("armeria.server.active.requests#value") { + @Override + public boolean matches(String key) { + return key.startsWith("armeria.server.active.requests#value{hostname.pattern="); + } + }); + + final String protocolName = protocol == SessionProtocol.H1C ? "http1" : "http2"; + // armeria.server.active.requests.all#value is measured by ServerMetrics + assertThat(meters).containsKey("armeria.server.all.requests#value{protocol=" + protocolName + + ",state=active}"); + assertThat(meters).containsKey("armeria.server.all.requests#value{protocol=" + protocolName + + ",state=pending}"); + }); + } +} diff --git a/core/src/test/java/com/linecorp/armeria/server/ServiceOptionsTest.java b/core/src/test/java/com/linecorp/armeria/server/ServiceOptionsTest.java new file mode 100644 index 00000000000..5c64de1a1cf --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/server/ServiceOptionsTest.java @@ -0,0 +1,136 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import org.junit.jupiter.api.Test; + +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; + +/** + * The priority of configurations from highest to lowest: + * 1. ServiceBinderBuilder + * 2. ServiceOptions (if exists) + * 3. VirtualHostBuilder + * 4. ServerBuilder + */ +class ServiceOptionsTest { + + private static final ServiceOptions SERVICE_OPTIONS = + ServiceOptions.builder().requestTimeoutMillis(5000) + .maxRequestLength(1024) + .requestAutoAbortDelayMillis(1000) + .build(); + + @Test + void serviceOptionsShouldNotOverrideServiceBindingBuilder() { + final HttpService httpService = new HttpService() { + @Override + public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) { + return HttpResponse.of("OK"); + } + + @Override + public ServiceOptions options() { + return SERVICE_OPTIONS; + } + }; + + try (Server server = Server.builder().route().path("/test") + .requestTimeoutMillis(20001) + .maxRequestLength(20002) + .requestAutoAbortDelayMillis(20003) + .build(httpService) + .build()) { + final ServiceConfig sc = server.serviceConfigs().get(0); + + assertThat(sc.requestTimeoutMillis()).isEqualTo(20001); + assertThat(sc.maxRequestLength()).isEqualTo(20002); + assertThat(sc.requestAutoAbortDelayMillis()).isEqualTo(20003); + } + } + + @Test + void serviceOptionsShouldOverrideVirtualHostTemplate() { + final HttpService httpService = new HttpService() { + @Override + public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) { + return HttpResponse.of("OK"); + } + + @Override + public ServiceOptions options() { + return SERVICE_OPTIONS; + } + }; + try (Server server = Server.builder() + .virtualHost("example.com") + .requestTimeoutMillis(20001) + .maxRequestLength(20002) + .requestAutoAbortDelayMillis(20003) + .service("/test", httpService) + .and() + .build()) { + final ServiceConfig sc = server.serviceConfigs() + .stream() + .filter(s -> s.route().paths().contains("/test")) + .findFirst().get(); + + assertThat(sc.requestTimeoutMillis()).isEqualTo(SERVICE_OPTIONS.requestTimeoutMillis()); + assertThat(sc.maxRequestLength()).isEqualTo(SERVICE_OPTIONS.maxRequestLength()); + assertThat(sc.requestAutoAbortDelayMillis()).isEqualTo( + SERVICE_OPTIONS.requestAutoAbortDelayMillis()); + } + } + + @Test + void serviceOptionsShouldOverrideServerBuilder() { + final HttpService httpService1 = new HttpService() { + @Override + public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) { + return HttpResponse.of("OK"); + } + + @Override + public ServiceOptions options() { + return SERVICE_OPTIONS; + } + }; + + final long defaultRequestTimeoutMillis = 30001; + final long defaultMaxRequestLength = 30002; + final long defaultRequestAutoAbortDelayMillis = 30003; + try (Server server = Server.builder() + .service("/test", httpService1) + .requestTimeoutMillis(defaultRequestTimeoutMillis) + .maxRequestLength(defaultMaxRequestLength) + .requestAutoAbortDelayMillis(defaultRequestAutoAbortDelayMillis) + .build()) { + + final ServiceConfig sc1 = server.serviceConfigs() + .stream() + .filter(s -> s.route().paths().contains("/test")) + .findFirst().get(); + assertThat(sc1.requestTimeoutMillis()).isEqualTo(SERVICE_OPTIONS.requestTimeoutMillis()); + assertThat(sc1.maxRequestLength()).isEqualTo(SERVICE_OPTIONS.maxRequestLength()); + assertThat(sc1.requestAutoAbortDelayMillis()).isEqualTo( + SERVICE_OPTIONS.requestAutoAbortDelayMillis()); + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/server/VirtualHostAndServiceConfigConsistencyTest.java b/core/src/test/java/com/linecorp/armeria/server/VirtualHostAndServiceConfigConsistencyTest.java index 2b0d0175c6c..8b7df684640 100644 --- a/core/src/test/java/com/linecorp/armeria/server/VirtualHostAndServiceConfigConsistencyTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/VirtualHostAndServiceConfigConsistencyTest.java @@ -42,6 +42,7 @@ void testApiConsistencyBetweenVirtualHostAndServiceConfig() { final Set ignorableVirtualHostMethods = ImmutableSet.of( "defaultHostname", "sslContext", + "tlsEngineType", "accessLogger", "port", "hostnamePattern", diff --git a/core/src/test/java/com/linecorp/armeria/server/annotation/ServiceOptionTest.java b/core/src/test/java/com/linecorp/armeria/server/annotation/ServiceOptionTest.java new file mode 100644 index 00000000000..632015f5f51 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/server/annotation/ServiceOptionTest.java @@ -0,0 +1,144 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.annotation; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import org.junit.jupiter.api.Test; + +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.server.Server; +import com.linecorp.armeria.server.ServiceConfig; +import com.linecorp.armeria.server.ServiceOption; + +public class ServiceOptionTest { + @Test + void serviceOptionAnnotationShouldBeAppliedWhenConfiguredAtMethodLevel() { + final class TestAnnotatedService { + @ServiceOption( + requestTimeoutMillis = 11111, + maxRequestLength = 1111 + ) + @Get("/test1") + public HttpResponse test1() { + return HttpResponse.of("OK"); + } + + @Get("/test2") + public HttpResponse test2() { + return HttpResponse.of("OK"); + } + } + + final long DEFAULT_REQUEST_TIMEOUT_MILLIS = 30001; + final long DEFAULT_MAX_REQUEST_LENGTH = 30002; + final long DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS = 30003; + + try (Server server = Server.builder().annotatedService(new TestAnnotatedService()) + .requestTimeoutMillis(DEFAULT_REQUEST_TIMEOUT_MILLIS) + .maxRequestLength(DEFAULT_MAX_REQUEST_LENGTH) + .requestAutoAbortDelayMillis(DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS) + .build()) { + final ServiceConfig sc1 = server.serviceConfigs() + .stream() + .filter(s -> s.route().paths().contains("/test1")) + .findFirst().orElse(null); + + assertThat(sc1).isNotNull(); + assertThat(sc1.requestTimeoutMillis()).isEqualTo(11111); + assertThat(sc1.maxRequestLength()).isEqualTo(1111); + assertThat(sc1.requestAutoAbortDelayMillis()).isEqualTo(DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS); + + // default values should be applied + final ServiceConfig sc2 = server.serviceConfigs() + .stream() + .filter(s -> s.route().paths().contains("/test2")) + .findFirst().get(); + assertThat(sc2.requestTimeoutMillis()).isEqualTo(DEFAULT_REQUEST_TIMEOUT_MILLIS); + assertThat(sc2.maxRequestLength()).isEqualTo(DEFAULT_MAX_REQUEST_LENGTH); + assertThat(sc2.requestAutoAbortDelayMillis()).isEqualTo(DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS); + } + } + + @Test + void serviceOptionAnnotationShouldBeAppliedWhenConfiguredAtClassLevel() { + @ServiceOption( + requestTimeoutMillis = 11111, + maxRequestLength = 1111 + ) + final class TestAnnotatedService { + @Get("/test") + public HttpResponse test() { + return HttpResponse.of("OK"); + } + } + + final long DEFAULT_REQUEST_TIMEOUT_MILLIS = 30001; + final long DEFAULT_MAX_REQUEST_LENGTH = 30002; + final long DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS = 30003; + + try (Server server = Server.builder().annotatedService(new TestAnnotatedService()) + .requestTimeoutMillis(DEFAULT_REQUEST_TIMEOUT_MILLIS) + .maxRequestLength(DEFAULT_MAX_REQUEST_LENGTH) + .requestAutoAbortDelayMillis(DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS) + .build()) { + final ServiceConfig sc = server.serviceConfigs() + .stream() + .filter(s -> s.route().paths().contains("/test")) + .findFirst().orElse(null); + + assertThat(sc).isNotNull(); + assertThat(sc.requestTimeoutMillis()).isEqualTo(11111); + assertThat(sc.maxRequestLength()).isEqualTo(1111); + assertThat(sc.requestAutoAbortDelayMillis()).isEqualTo(DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS); + } + } + + @Test + void serviceOptionAnnotationAtMethodLevelShouldOverrideServiceOptionAtClassLevel() { + @ServiceOption( + requestTimeoutMillis = 11111, + maxRequestLength = 1111, + requestAutoAbortDelayMillis = 111 + ) + final class TestAnnotatedService { + @ServiceOption( + requestTimeoutMillis = 22222, + maxRequestLength = 2222, + requestAutoAbortDelayMillis = 222 + ) + @Get("/test") + public HttpResponse test() { + return HttpResponse.of("OK"); + } + } + + try (Server server = Server.builder() + .annotatedService(new TestAnnotatedService()) + .build()) { + final ServiceConfig sc = server.serviceConfigs() + .stream() + .filter(s -> s.route().paths().contains("/test")) + .findFirst().orElse(null); + + assertThat(sc).isNotNull(); + assertThat(sc.requestTimeoutMillis()).isEqualTo(22222); + assertThat(sc.maxRequestLength()).isEqualTo(2222); + assertThat(sc.requestAutoAbortDelayMillis()).isEqualTo(222); + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/server/contextpath/test/ServiceBuilderSelfTypeTest.java b/core/src/test/java/com/linecorp/armeria/server/contextpath/test/ServiceBuilderSelfTypeTest.java new file mode 100644 index 00000000000..988790efad9 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/server/contextpath/test/ServiceBuilderSelfTypeTest.java @@ -0,0 +1,92 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.contextpath.test; + +import org.junit.jupiter.api.Test; + +import com.linecorp.armeria.server.ContextPathAnnotatedServiceConfigSetters; +import com.linecorp.armeria.server.ContextPathServiceBindingBuilder; +import com.linecorp.armeria.server.ContextPathServicesBuilder; +import com.linecorp.armeria.server.Server; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServiceBindingBuilder; +import com.linecorp.armeria.server.VirtualHostBuilder; +import com.linecorp.armeria.server.VirtualHostContextPathAnnotatedServiceConfigSetters; +import com.linecorp.armeria.server.VirtualHostContextPathServiceBindingBuilder; +import com.linecorp.armeria.server.VirtualHostContextPathServicesBuilder; + +class ServiceBuilderSelfTypeTest { + + // A non-existent package is used to check if the API is exposed publicly. + + @Test + void contextPathAnnotatedServiceConfigSetters() { + final ContextPathAnnotatedServiceConfigSetters setters = + Server.builder() + .contextPath("/foo") + .annotatedService() + .addHeader("X-foo", "bar"); + final ContextPathServicesBuilder contextPathServicesBuilder = setters.build(new Object()); + final ServerBuilder serverBuilder = contextPathServicesBuilder.and(); + } + + @Test + void virtualHostContextPathAnnotatedServiceConfigSetters() { + final VirtualHostContextPathAnnotatedServiceConfigSetters setters = + Server.builder() + .virtualHost("foo.com") + .contextPath("/foo") + .annotatedService() + .addHeader("X-foo", "bar"); + final VirtualHostContextPathServicesBuilder contextPathServicesBuilder = setters.build(new Object()); + final VirtualHostBuilder serverBuilder = contextPathServicesBuilder.and(); + } + + @Test + void serviceBindingBuilder() { + final ServiceBindingBuilder serviceBindingBuilder = + Server.builder() + .route() + .path("/"); + final ServiceBindingBuilder serviceBindingBuilder1 = serviceBindingBuilder.decorator( + (delegate, ctx, req) -> null); + serviceBindingBuilder1.build((ctx, req) -> null); + } + + @Test + void contextPathServiceBindingBuilder() { + final ContextPathServiceBindingBuilder builder = + Server.builder() + .contextPath("/foo") + .route() + .path("/") + .addHeader("X-foo", "bar"); + builder.build((ctx, req) -> null); + } + + @Test + void virtualHostContextPathServiceBindingBuilder() { + final VirtualHostContextPathServiceBindingBuilder builder = + Server.builder() + .virtualHost("foo.com") + .contextPath("/foo") + .route() + .path("/") + .addHeader("X-foo", "bar"); + builder.build((ctx, req) -> null); + } +} diff --git a/core/src/test/java/com/linecorp/armeria/server/file/FileServiceTest.java b/core/src/test/java/com/linecorp/armeria/server/file/FileServiceTest.java index b686b4a75dd..890599a4687 100644 --- a/core/src/test/java/com/linecorp/armeria/server/file/FileServiceTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/file/FileServiceTest.java @@ -41,6 +41,7 @@ import org.apache.hc.core5.http.io.entity.EntityUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.io.TempDir; @@ -54,8 +55,10 @@ import com.google.common.io.ByteStreams; import com.google.common.io.Resources; +import com.linecorp.armeria.client.BlockingWebClient; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.OsType; @@ -173,6 +176,23 @@ protected void configure(ServerBuilder sb) { .maxCacheEntries(0) .build())); + sb.serviceUnder( + "/no-extension", + FileService.builder(classLoader, baseResourceDir + "foo") + .build()); + sb.serviceUnder( + "/extension", + FileService.builder(classLoader, baseResourceDir + "foo") + .fallbackFileExtensions("txt") + .build()); + sb.serviceUnder( + "/extension/decompress", + FileService.builder(classLoader, baseResourceDir + "foo") + .fallbackFileExtensions("txt") + .serveCompressedFiles(true) + .autoDecompress(true) + .build()); + sb.decorator(LoggingService.newDecorator()); } }; @@ -625,6 +645,33 @@ void testFileSystemGetUtf8(String baseUri) throws Exception { } } + @Test + void useFileExtensionsToFindFile() { + final BlockingWebClient client = server.blockingWebClient(); + AggregatedHttpResponse response = client.get("/extension/foo.txt"); + assertThat(response.status()).isEqualTo(HttpStatus.OK); + assertThat(response.contentUtf8()).isEqualTo("foo"); + + // Without .txt extension + response = client.get("/extension/foo"); + assertThat(response.status()).isEqualTo(HttpStatus.OK); + assertThat(response.contentUtf8()).isEqualTo("foo"); + // Make sure that the existing operation is not affected by the fileExtensions option. + response = client.get("/extension/"); + assertThat(response.status()).isEqualTo(HttpStatus.OK); + assertThat(response.contentUtf8()).isEqualTo("\n"); + + response = client.get("/no-extension/foo.txt"); + assertThat(response.status()).isEqualTo(HttpStatus.OK); + assertThat(response.contentUtf8()).isEqualTo("foo"); + response = client.get("/no-extension/foo"); + assertThat(response.status()).isEqualTo(HttpStatus.NOT_FOUND); + + response = client.get("/extension/decompress/foo"); + assertThat(response.status()).isEqualTo(HttpStatus.OK); + assertThat(response.contentUtf8()).isEqualTo("foo"); + } + private static void writeFile(Path path, String content) throws Exception { // Retry to work around the `AccessDeniedException` in Windows. for (int i = 9; i >= 0; i--) { diff --git a/core/src/test/java/com/linecorp/armeria/server/websocket/DelegatingWebSocketServiceTest.java b/core/src/test/java/com/linecorp/armeria/server/websocket/DelegatingWebSocketServiceTest.java index 0628138c968..a7dce0ca3a2 100644 --- a/core/src/test/java/com/linecorp/armeria/server/websocket/DelegatingWebSocketServiceTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/websocket/DelegatingWebSocketServiceTest.java @@ -30,7 +30,6 @@ import com.linecorp.armeria.client.websocket.WebSocketClient; import com.linecorp.armeria.client.websocket.WebSocketSession; import com.linecorp.armeria.common.AggregatedHttpResponse; -import com.linecorp.armeria.common.Flags; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.stream.StreamMessage; @@ -39,7 +38,6 @@ import com.linecorp.armeria.common.websocket.WebSocketFrameType; import com.linecorp.armeria.common.websocket.WebSocketWriter; import com.linecorp.armeria.server.ServerBuilder; -import com.linecorp.armeria.server.ServiceConfig; import com.linecorp.armeria.server.ServiceRequestContext; import com.linecorp.armeria.testing.junit5.server.ServerExtension; @@ -83,14 +81,6 @@ void shouldReturnFallbackResponse() { assertThat(response.contentUtf8()).isEqualTo("fallback"); } - @Test - void shouldNotSetDefaultSettings() { - final ServiceConfig serviceConfig = server.server().serviceConfigs().get(0); - assertThat(serviceConfig.service().as(DelegatingWebSocketService.class)).isNotNull(); - // The default settings for `WebSocketService` should be applied only to `DefaultWebSocketService`. - assertThat(serviceConfig.requestTimeoutMillis()).isEqualTo(Flags.defaultRequestTimeoutMillis()); - } - private static class EchoWebSocketHandler implements WebSocketServiceHandler { @Override diff --git a/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceConfigTest.java b/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceConfigTest.java index 31862ce50c4..820e1ab1e1c 100644 --- a/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceConfigTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceConfigTest.java @@ -16,11 +16,13 @@ package com.linecorp.armeria.server.websocket; +import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.DEFAULT_MAX_REQUEST_RESPONSE_LENGTH; +import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS; +import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.DEFAULT_REQUEST_RESPONSE_TIMEOUT_MILLIS; import static org.assertj.core.api.Assertions.assertThat; import org.junit.jupiter.api.Test; -import com.linecorp.armeria.internal.common.websocket.WebSocketUtil; import com.linecorp.armeria.server.Server; import com.linecorp.armeria.server.ServiceConfig; import com.linecorp.armeria.server.websocket.WebSocketServiceTest.AbstractWebSocketHandler; @@ -32,25 +34,34 @@ void webSocketServiceDefaultConfigValues() { final WebSocketService webSocketService = WebSocketService.of(new AbstractWebSocketHandler()); final Server server = Server.builder().service("/", webSocketService).build(); assertThat(server.config().serviceConfigs()).hasSize(1); - ServiceConfig serviceConfig = server.config().serviceConfigs().get(0); + final ServiceConfig serviceConfig = server.config().serviceConfigs().get(0); assertThat(serviceConfig.requestTimeoutMillis()).isEqualTo( - WebSocketUtil.DEFAULT_REQUEST_RESPONSE_TIMEOUT_MILLIS); + DEFAULT_REQUEST_RESPONSE_TIMEOUT_MILLIS); assertThat(serviceConfig.maxRequestLength()).isEqualTo( - WebSocketUtil.DEFAULT_MAX_REQUEST_RESPONSE_LENGTH); + DEFAULT_MAX_REQUEST_RESPONSE_LENGTH); assertThat(serviceConfig.requestAutoAbortDelayMillis()).isEqualTo( - WebSocketUtil.DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS); + DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS); + } - server.reconfigure(sb -> sb.requestAutoAbortDelayMillis(1000) - .route() - .get("/") - .requestTimeoutMillis(2000) - .build(webSocketService)); + @Test + void webSocketServiceOptionsPriority() { + final WebSocketService webSocketService = WebSocketService.of(new AbstractWebSocketHandler()); + try (Server server = Server.builder() + .requestAutoAbortDelayMillis(1500) + .service("/", webSocketService) + .build()) { + final ServiceConfig sc = server.config().serviceConfigs().get(0); + assertThat(sc.requestAutoAbortDelayMillis()).isEqualTo(DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS); + } - assertThat(server.config().serviceConfigs()).hasSize(1); - serviceConfig = server.config().serviceConfigs().get(0); - assertThat(serviceConfig.requestTimeoutMillis()).isEqualTo(2000); - assertThat(serviceConfig.maxRequestLength()).isEqualTo( - WebSocketUtil.DEFAULT_MAX_REQUEST_RESPONSE_LENGTH); - assertThat(serviceConfig.requestAutoAbortDelayMillis()).isEqualTo(1000); + try (Server server = Server.builder() + .route() + .path("/") + .requestAutoAbortDelayMillis(1500) + .build(webSocketService) + .build()) { + final ServiceConfig sc = server.config().serviceConfigs().get(0); + assertThat(sc.requestAutoAbortDelayMillis()).isEqualTo(1500); + } } } diff --git a/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceOptionsTest.java b/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceOptionsTest.java new file mode 100644 index 00000000000..72a977c8889 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceOptionsTest.java @@ -0,0 +1,102 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.websocket; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.linecorp.armeria.client.BlockingWebClient; +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketSession; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.websocket.WebSocketFrame; +import com.linecorp.armeria.common.websocket.WebSocketWriter; +import com.linecorp.armeria.server.HttpService; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServiceOptions; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +class WebSocketServiceOptionsTest { + private static final ServiceOptions webSocketServiceOptions = + ServiceOptions.builder() + .requestTimeoutMillis(100001) + .maxRequestLength(10002) + .requestAutoAbortDelayMillis(10003) + .build(); + + private static final ServiceOptions serviceOptions = + ServiceOptions.builder() + .requestTimeoutMillis(50001) + .build(); + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final WebSocketService webSocketService = + WebSocketService + .builder((ctx, in) -> in) // echo back + .serviceOptions(webSocketServiceOptions) + .fallbackService(new HttpService() { + @Override + public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) + throws Exception { + return HttpResponse.of("fallback"); + } + + @Override + public ServiceOptions options() { + return serviceOptions; + } + }) + .build(); + sb.service("/ws-or-rest", webSocketService); + } + }; + + @Test + void overrideServiceOptions() throws InterruptedException { + final WebSocketClient webSocketClient = WebSocketClient.of(server.httpUri()); + final WebSocketSession session = webSocketClient.connect("/ws-or-rest").join(); + final WebSocketWriter out = session.outbound(); + out.write("hello"); + out.write("world"); + out.close(); + assertThat(session.inbound().collect().join().stream().map(WebSocketFrame::text)) + .contains("hello", "world"); + + final ServiceRequestContext wsCtx = server.requestContextCaptor().take(); + assertThat(wsCtx.requestTimeoutMillis()).isEqualTo(webSocketServiceOptions.requestTimeoutMillis()); + assertThat(wsCtx.maxRequestLength()).isEqualTo(webSocketServiceOptions.maxRequestLength()); + assertThat(wsCtx.requestAutoAbortDelayMillis()).isEqualTo( + webSocketServiceOptions.requestAutoAbortDelayMillis()); + + final BlockingWebClient restClient = server.blockingWebClient(); + assertThat(restClient.get("/ws-or-rest").contentUtf8()).isEqualTo("fallback"); + final ServiceRequestContext restCtx = server.requestContextCaptor().take(); + assertThat(restCtx.requestTimeoutMillis()).isEqualTo(serviceOptions.requestTimeoutMillis()); + // Respect the virtual host's configurations if no value is set in the ServiceOptions of the + // fallback service. + assertThat(restCtx.maxRequestLength()).isEqualTo(restCtx.config().virtualHost().maxRequestLength()); + assertThat(restCtx.requestAutoAbortDelayMillis()) + .isEqualTo(restCtx.config().virtualHost().requestAutoAbortDelayMillis()); + } +} diff --git a/docs-client/package-lock.json b/docs-client/package-lock.json index 62aa6d4d101..8a3e209c2ec 100644 --- a/docs-client/package-lock.json +++ b/docs-client/package-lock.json @@ -3644,12 +3644,12 @@ } }, "node_modules/braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "dependencies": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" }, "engines": { "node": ">=8" @@ -6159,9 +6159,9 @@ } }, "node_modules/fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "dependencies": { "to-regex-range": "^5.0.1" @@ -12072,9 +12072,9 @@ } }, "node_modules/webpack-bundle-analyzer/node_modules/ws": { - "version": "7.5.9", - "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.9.tgz", - "integrity": "sha512-F+P9Jil7UiSKSkppIiD94dN07AwvFixvLIj1Og1Rl9GGMuNipJnV9JzjD6XuqmAeiswGvUmNLjr5cFuXwNS77Q==", + "version": "7.5.10", + "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.10.tgz", + "integrity": "sha512-+dbF1tHwZpXcbOJdVOkzLDxZP1ailvSxM6ZweXTegylPny803bFhA+vqBYw4s31NSAk4S2Qz+AKXK9a4wkdjcQ==", "dev": true, "engines": { "node": ">=8.3.0" @@ -12447,16 +12447,16 @@ "dev": true }, "node_modules/ws": { - "version": "8.5.0", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.5.0.tgz", - "integrity": "sha512-BWX0SWVgLPzYwF8lTzEy1egjhS4S4OEAHfsO8o65WOVsrnSRGaSiUaa9e0ggGlkMTtBlmOpEXiie9RUcBO86qg==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", "dev": true, "engines": { "node": ">=10.0.0" }, "peerDependencies": { "bufferutil": "^4.0.1", - "utf-8-validate": "^5.0.2" + "utf-8-validate": ">=5.0.2" }, "peerDependenciesMeta": { "bufferutil": { diff --git a/examples/dropwizard/build.gradle b/examples/dropwizard/build.gradle index 0057ce8412a..87023189c52 100644 --- a/examples/dropwizard/build.gradle +++ b/examples/dropwizard/build.gradle @@ -15,6 +15,6 @@ dependencies { task runDropwizardExample(type: JavaExec) { classpath = sourceSets.main.runtimeClasspath - main = application.mainClass.get() + mainClass = application.mainClass.get() args = ['server', 'server.yaml'] } diff --git a/examples/grpc-envoy/build.gradle b/examples/grpc-envoy/build.gradle new file mode 100644 index 00000000000..1203c0986af --- /dev/null +++ b/examples/grpc-envoy/build.gradle @@ -0,0 +1,19 @@ +plugins { + id 'application' +} + +dependencies { + implementation project(':core') + implementation project(':grpc') + implementation libs.testcontainers.junit.jupiter + compileOnly libs.javax.annotation + runtimeOnly libs.slf4j.simple + + testImplementation project(':junit5') + testImplementation libs.assertj + testImplementation libs.junit5.jupiter.api +} + +application { + mainClass.set('example.armeria.grpc.envoy.Main') +} diff --git a/examples/grpc-envoy/src/main/java/example/armeria/grpc/envoy/EnvoyContainer.java b/examples/grpc-envoy/src/main/java/example/armeria/grpc/envoy/EnvoyContainer.java new file mode 100644 index 00000000000..d62d55767fb --- /dev/null +++ b/examples/grpc-envoy/src/main/java/example/armeria/grpc/envoy/EnvoyContainer.java @@ -0,0 +1,62 @@ +package example.armeria.grpc.envoy; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.BindMode; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.output.Slf4jLogConsumer; + +import com.github.dockerjava.api.command.InspectContainerResponse; + +import com.linecorp.armeria.common.annotation.Nullable; + +// https://github.com/envoyproxy/java-control-plane/blob/eaca1a4380e53b4b6339db4e9ffe0ada5e0b7f8f/server/src/test/java/io/envoyproxy/controlplane/server/EnvoyContainer.java +class EnvoyContainer extends GenericContainer { + + private static final Logger LOGGER = LoggerFactory.getLogger(EnvoyContainer.class); + + private static final String CONFIG_DEST = "/etc/envoy/envoy.yaml"; + private static final String LAUNCH_ENVOY_SCRIPT = "envoy/launch_envoy.sh"; + private static final String LAUNCH_ENVOY_SCRIPT_DEST = "/usr/local/bin/launch_envoy.sh"; + + static final int ADMIN_PORT = 9901; + + private final String config; + @Nullable + private final String sedCommand; + + /** + * A {@link GenericContainer} implementation for envoy containers. + * + * @param sedCommand optional sed command which may be used to postprocess the provided {@param config}. + * This parameter will be fed into the command {@code sed -e }. + * An example command may be {@code "s/foo/bar/g;s/abc/def/g"}. + */ + EnvoyContainer(String config, @Nullable String sedCommand) { + super("envoyproxy/envoy:v1.30.1"); + this.config = config; + this.sedCommand = sedCommand; + } + + @Override + protected void configure() { + super.configure(); + + withClasspathResourceMapping(LAUNCH_ENVOY_SCRIPT, LAUNCH_ENVOY_SCRIPT_DEST, BindMode.READ_ONLY); + withClasspathResourceMapping(config, CONFIG_DEST, BindMode.READ_ONLY); + + if (sedCommand != null) { + withCommand("/bin/bash", "/usr/local/bin/launch_envoy.sh", + sedCommand, CONFIG_DEST, "-l", "debug"); + } + + addExposedPort(ADMIN_PORT); + } + + @Override + protected void containerIsStarting(InspectContainerResponse containerInfo) { + followOutput(new Slf4jLogConsumer(LOGGER).withPrefix("ENVOY")); + + super.containerIsStarting(containerInfo); + } +} diff --git a/examples/grpc-envoy/src/main/java/example/armeria/grpc/envoy/HelloService.java b/examples/grpc-envoy/src/main/java/example/armeria/grpc/envoy/HelloService.java new file mode 100644 index 00000000000..9622c5e6d90 --- /dev/null +++ b/examples/grpc-envoy/src/main/java/example/armeria/grpc/envoy/HelloService.java @@ -0,0 +1,29 @@ +package example.armeria.grpc.envoy; + +import example.armeria.grpc.envoy.Hello.HelloReply; +import example.armeria.grpc.envoy.Hello.HelloRequest; +import example.armeria.grpc.envoy.HelloServiceGrpc.HelloServiceImplBase; +import io.grpc.Status; +import io.grpc.stub.StreamObserver; + +public class HelloService extends HelloServiceImplBase { + + @Override + public void hello(HelloRequest request, StreamObserver responseObserver) { + if (request.getName().isEmpty()) { + responseObserver.onError( + Status.FAILED_PRECONDITION.withDescription("Name cannot be empty").asRuntimeException()); + } else { + responseObserver.onNext(buildReply(toMessage(request.getName()))); + responseObserver.onCompleted(); + } + } + + static String toMessage(String name) { + return "Hello, " + name + '!'; + } + + private static HelloReply buildReply(Object message) { + return HelloReply.newBuilder().setMessage(String.valueOf(message)).build(); + } +} diff --git a/examples/grpc-envoy/src/main/java/example/armeria/grpc/envoy/Main.java b/examples/grpc-envoy/src/main/java/example/armeria/grpc/envoy/Main.java new file mode 100644 index 00000000000..c8d1f142475 --- /dev/null +++ b/examples/grpc-envoy/src/main/java/example/armeria/grpc/envoy/Main.java @@ -0,0 +1,52 @@ +package example.armeria.grpc.envoy; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.DockerClientFactory; + +import com.linecorp.armeria.common.util.ShutdownHooks; +import com.linecorp.armeria.server.Server; +import com.linecorp.armeria.server.grpc.GrpcService; + +public final class Main { + + private static final Logger logger = LoggerFactory.getLogger(Main.class); + + private static final int serverPort = 8080; + // the port envoy binds to within the container + private static final int envoyPort = 10000; + + public static void main(String[] args) { + if (!DockerClientFactory.instance().isDockerAvailable()) { + throw new IllegalStateException("Docker is not available"); + } + + final Server backendServer = startBackendServer(serverPort); + backendServer.closeOnJvmShutdown(); + backendServer.start().join(); + logger.info("Serving backend at http://127.0.0.1:{}/", backendServer.activePort()); + + final EnvoyContainer envoyProxy = configureEnvoy(serverPort, envoyPort); + ShutdownHooks.addClosingTask(envoyProxy::stop); + envoyProxy.start(); + final Integer mappedEnvoyPort = envoyProxy.getMappedPort(envoyPort); + logger.info("Serving envoy at http://127.0.0.1:{}/", mappedEnvoyPort); + } + + private static Server startBackendServer(int serverPort) { + return Server.builder() + .http(serverPort) + .service(GrpcService.builder() + .addService(new HelloService()) + .build()) + .build(); + } + + static EnvoyContainer configureEnvoy(int serverPort, int envoyPort) { + final String sedPattern = String.format("s/SERVER_PORT/%s/g;s/ENVOY_PORT/%s/g", serverPort, envoyPort); + return new EnvoyContainer("envoy/envoy.yaml", sedPattern) + .withExposedPorts(envoyPort); + } + + private Main() {} +} diff --git a/examples/grpc-envoy/src/main/proto/hello.proto b/examples/grpc-envoy/src/main/proto/hello.proto new file mode 100644 index 00000000000..f5c4ac2c9a7 --- /dev/null +++ b/examples/grpc-envoy/src/main/proto/hello.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package example.grpc.hello; +option java_package = "example.armeria.grpc.envoy"; +option java_multiple_files = false; + +import "google/api/annotations.proto"; + +service HelloService { + rpc Hello (HelloRequest) returns (HelloReply); +} + +message HelloRequest { + string name = 1; +} + +message HelloReply { + string message = 1; +} diff --git a/examples/grpc-envoy/src/main/resources/.gitkeep b/examples/grpc-envoy/src/main/resources/.gitkeep new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/grpc-envoy/src/main/resources/envoy/envoy.yaml b/examples/grpc-envoy/src/main/resources/envoy/envoy.yaml new file mode 100644 index 00000000000..983a7c0f775 --- /dev/null +++ b/examples/grpc-envoy/src/main/resources/envoy/envoy.yaml @@ -0,0 +1,52 @@ +admin: + address: + socket_address: { address: 0.0.0.0, port_value: 9901 } +static_resources: + listeners: + - name: listener_0 + address: + socket_address: + address: 0.0.0.0 + port_value: ENVOY_PORT + filter_chains: + - filters: + - name: envoy.filters.network.http_connection_manager + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager + stat_prefix: ingress_http + http_protocol_options: + enable_trailers: true + codec_type: AUTO + route_config: + name: local_route + virtual_hosts: + - name: local_service + domains: ["*"] + routes: + - match: + prefix: "/" + route: + cluster: grpc_service + http_filters: + - name: envoy.filters.http.router + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router + clusters: + - name: grpc_service + type: STRICT_DNS + lb_policy: ROUND_ROBIN + typed_extension_protocol_options: + envoy.extensions.upstreams.http.v3.HttpProtocolOptions: + "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions + explicit_http_config: + http_protocol_options: + enable_trailers: true + load_assignment: + cluster_name: grpc_service + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: host.testcontainers.internal + port_value: SERVER_PORT diff --git a/examples/grpc-envoy/src/main/resources/envoy/launch_envoy.sh b/examples/grpc-envoy/src/main/resources/envoy/launch_envoy.sh new file mode 100755 index 00000000000..34581990ee6 --- /dev/null +++ b/examples/grpc-envoy/src/main/resources/envoy/launch_envoy.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -Eeuo pipefail + +SED_COMMAND=$1 + +CONFIG=$(cat $2) +CONFIG_DIR=$(mktemp -d) +CONFIG_FILE="$CONFIG_DIR/envoy.yaml" + +echo "${CONFIG}" | sed -e "${SED_COMMAND}" > "${CONFIG_FILE}" + + +shift 2 +/usr/local/bin/envoy --drain-time-s 1 -c "${CONFIG_FILE}" "$@" + +rm -rf "${CONFIG_DIR}" diff --git a/examples/grpc-envoy/src/test/java/example/armeria/grpc/envoy/GrpcEnvoyProxyTest.java b/examples/grpc-envoy/src/test/java/example/armeria/grpc/envoy/GrpcEnvoyProxyTest.java new file mode 100644 index 00000000000..da800025caf --- /dev/null +++ b/examples/grpc-envoy/src/test/java/example/armeria/grpc/envoy/GrpcEnvoyProxyTest.java @@ -0,0 +1,54 @@ +package example.armeria.grpc.envoy; + +import static example.armeria.grpc.envoy.Main.configureEnvoy; +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.testcontainers.junit.jupiter.Testcontainers; + +import com.linecorp.armeria.client.grpc.GrpcClients; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.grpc.GrpcService; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import example.armeria.grpc.envoy.Hello.HelloReply; +import example.armeria.grpc.envoy.Hello.HelloRequest; + +@Testcontainers(disabledWithoutDocker = true) +class GrpcEnvoyProxyTest { + + // the port envoy binds to within the container + private static final int ENVOY_PORT = 10000; + + @RegisterExtension + static ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) throws Exception { + sb.service(GrpcService.builder() + .addService(new HelloService()) + .build()); + } + }; + + @ParameterizedTest + @EnumSource(value = SessionProtocol.class, names = {"H1C", "H2C"}) + void reverseProxy(SessionProtocol sessionProtocol) { + org.testcontainers.Testcontainers.exposeHostPorts(server.httpPort()); + try (EnvoyContainer envoy = configureEnvoy(server.httpPort(), ENVOY_PORT)) { + envoy.start(); + final String uri = sessionProtocol.uriText() + "://" + envoy.getHost() + + ':' + envoy.getMappedPort(ENVOY_PORT); + final HelloServiceGrpc.HelloServiceBlockingStub helloService = + GrpcClients.builder(uri) + .build(HelloServiceGrpc.HelloServiceBlockingStub.class); + final HelloReply reply = + helloService.hello(HelloRequest.newBuilder() + .setName("Armeria") + .build()); + assertThat(reply.getMessage()).isEqualTo("Hello, Armeria!"); + } + } +} diff --git a/examples/tutorials/grpc/src/main/java/example/armeria/server/blog/grpc/GrpcExceptionHandler.java b/examples/tutorials/grpc/src/main/java/example/armeria/server/blog/grpc/GrpcExceptionHandler.java index c24b301194e..91feda5fd1e 100644 --- a/examples/tutorials/grpc/src/main/java/example/armeria/server/blog/grpc/GrpcExceptionHandler.java +++ b/examples/tutorials/grpc/src/main/java/example/armeria/server/blog/grpc/GrpcExceptionHandler.java @@ -11,7 +11,7 @@ class GrpcExceptionHandler implements GrpcExceptionHandlerFunction { @Nullable @Override - public Status apply(RequestContext ctx, Throwable cause, Metadata metadata) { + public Status apply(RequestContext ctx, Status status, Throwable cause, Metadata metadata) { if (cause instanceof IllegalArgumentException) { return Status.INVALID_ARGUMENT.withCause(cause); } diff --git a/gradle.properties b/gradle.properties index 8720cb98061..bcc699d8a57 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,5 +1,5 @@ group=com.linecorp.armeria -version=1.28.2-SNAPSHOT +version=1.30.0-SNAPSHOT projectName=Armeria projectUrl=https://armeria.dev/ projectDescription=Asynchronous HTTP/2 RPC/REST client/server library built on top of Java 8, Netty, Thrift and gRPC diff --git a/gradle/scripts/.gitrepo b/gradle/scripts/.gitrepo index ea36b8a9f2e..bd6cfd448cc 100644 --- a/gradle/scripts/.gitrepo +++ b/gradle/scripts/.gitrepo @@ -6,7 +6,7 @@ [subrepo] remote = https://github.com/line/gradle-scripts branch = main - commit = d31f74478150e01781adafb43ed51aec6c830126 - parent = 33f1eeba8058f126922efffa0dd0ec87a06fd8be + commit = a3211a7ec874b42fc7dc5a84b3960a705d5fc34c + parent = c66b9211afdd86ce388b5b77b99fc7ffad6c0888 method = merge - cmdver = 0.4.5 + cmdver = 0.4.6 diff --git a/gradle/scripts/README.md b/gradle/scripts/README.md index 71b9903cad3..29bc8c81f0b 100644 --- a/gradle/scripts/README.md +++ b/gradle/scripts/README.md @@ -34,7 +34,6 @@ sensible defaults. By applying them, you can: - [Shading a multi-module project with `relocate` flag](#shading-a-multi-module-project-with-relocate-flag) - [Setting a Java target version with the `java(\\d+)` flag](#setting-a-java-target-version-with-the-javad-flag) - [Setting a Kotlin target version with the `kotlin(\\d+\\.\\d+)` flag](#setting-a-koltin-target-version-with-the-kotlindd-flag) -- [Automatic module names](#automatic-module-names) - [Tagging conveniently with `release` task](#tagging-conveniently-with-release-task) @@ -99,7 +98,6 @@ sensible defaults. By applying them, you can: ``` group=com.doe.john.myexample version=0.0.1-SNAPSHOT - versionPattern=^[0-9]+\\.[0-9]+\\.[0-9]+$ projectName=My Example projectUrl=https://www.example.com/ projectDescription=My example project @@ -120,7 +118,6 @@ sensible defaults. By applying them, you can: googleAnalyticsId=UA-XXXXXXXX javaSourceCompatibility=1.8 javaTargetCompatibility=1.8 - automaticModuleNames=false ``` 5. That's all. You now have two Java subprojects with sensible defaults. @@ -576,6 +573,9 @@ relocations [ { from: "com.google.common", to: "com.doe.john.myproject.shaded.gu { from: "com.google.thirdparty.publicsuffix", to: "com.doe.john.myproject.shaded.publicsuffix" } ] ``` +Unshaded tests are disabled by default when a shading task is configured. If you want to run unshaded tests, +you can specify `-PpreferShadedTests=false` option. + ### Trimming a shaded JAR with `trim` flag If you shade many dependencies, your JAR will grow huge, even if you only use @@ -675,31 +675,6 @@ However, if you want to compile a Kotlin module with a different language versio For example, `kotlin1.6` flag makes your Kotlin module compatible with language version 1.6 and API version 1.6. -## Automatic module names - -By specifying the `automaticModuleNames=true` property in `settings.gradle`, every `java` project's JAR -file will contain the `Automatic-Module-Name` property in its `MANIFEST.MF`, auto-generated from the group ID -and artifact ID. For example: - -- groupId: `com.example`, artifactId: `foo-bar` - - module name: `com.example.foo.bar` -- groupId: `com.example.foo`, artifactId: `foo-bar` - - module name: `com.example.foo.bar` - -If enabled, each project with `java` flag will have the `automaticModuleName` property. - -You can override the automatic module name of a certain project via the `automaticModuleNameOverrides` -extension property: - - ```groovy - ext { - // Change the automatic module name of project ':bar' to 'com.example.fubar'. - automaticModuleNameOverrides = [ - ':bar': 'com.example.fubar' - ] - } - ``` - ## Tagging conveniently with `release` task The task called `release` is added at the top level project. It will update the diff --git a/gradle/scripts/lib/bom.gradle b/gradle/scripts/lib/bom.gradle index c955df5f5d0..c971c9f08d1 100644 --- a/gradle/scripts/lib/bom.gradle +++ b/gradle/scripts/lib/bom.gradle @@ -26,8 +26,8 @@ configure(projectsWithFlags('bom')) { "Please check bomGroups property", project.name) } subs = bomGroups.get(project.path) - if (!(subs.value instanceof List)) { - throw new IllegalStateException("bomGroups' value must be a List: ${subs.value}") + if (!(subs instanceof List)) { + throw new IllegalStateException("bomGroups' value must be a List: ${subs}") } } diff --git a/gradle/scripts/lib/common-info.gradle b/gradle/scripts/lib/common-info.gradle index 0f29c25407e..13666f21cf7 100644 --- a/gradle/scripts/lib/common-info.gradle +++ b/gradle/scripts/lib/common-info.gradle @@ -42,9 +42,17 @@ allprojects { ext { artifactId = { // Use the overridden one if available. - def overriddenArtifactId = findOverridden('artifactIdOverrides', project) - if (overriddenArtifactId != null) { - return overriddenArtifactId + if (rootProject.ext.has('artifactIdOverrides')) { + def overrides = rootProject.ext.artifactIdOverrides + if (!(overrides instanceof Map)) { + throw new IllegalStateException("artifactIdOverrides must be a Map: ${overrides}") + } + + for (Map.Entry e : overrides.entrySet()) { + if (rootProject.project(e.key) == project) { + return e.value + } + } } // Generate from the project names otherwise. @@ -62,49 +70,3 @@ allprojects { }.call() } } - -// Check whether to enable automatic module names. -def isAutomaticModuleNameEnabled = 'true' == rootProject.findProperty('automaticModuleNames') - -allprojects { - ext { - automaticModuleName = { - if (!isAutomaticModuleNameEnabled) { - return null - } - - // Use the overridden one if available. - def overriddenAutomaticModuleName = findOverridden('automaticModuleNameOverrides', project) - if (overriddenAutomaticModuleName != null) { - return overriddenAutomaticModuleName - } - - // Generate from the groupId and artifactId otherwise. - def groupIdComponents = String.valueOf(rootProject.group).split("\\.").toList() - def artifactIdComponents = - String.valueOf(project.ext.artifactId).replace('-', '.').split("\\.").toList() - if (groupIdComponents.last() == artifactIdComponents.first()) { - return String.join('.', groupIdComponents + artifactIdComponents.drop(1)) - } else { - return String.join('.', groupIdComponents + artifactIdComponents) - } - }.call() - } -} - -def findOverridden(String overridesPropertyName, Project project) { - if (rootProject.ext.has(overridesPropertyName)) { - def overrides = rootProject.ext.get(overridesPropertyName) - if (!(overrides instanceof Map)) { - throw new IllegalStateException("rootProject.ext.${overridesPropertyName} must be a Map: ${overrides}") - } - - for (Map.Entry e : overrides.entrySet()) { - if (rootProject.project(e.key) == project) { - return String.valueOf(e.value) - } - } - } - - return null -} diff --git a/gradle/scripts/lib/common-publish.gradle b/gradle/scripts/lib/common-publish.gradle index 0a89858b617..07d1f6b3c27 100644 --- a/gradle/scripts/lib/common-publish.gradle +++ b/gradle/scripts/lib/common-publish.gradle @@ -61,6 +61,11 @@ if (publishToStaging) { password = project.findProperty(publishPasswordProperty) } } + + transitionCheckOptions { + maxRetries.set(100) + delayBetween.set(Duration.ofSeconds(20)) + } } } diff --git a/gradle/scripts/lib/java-javadoc.gradle b/gradle/scripts/lib/java-javadoc.gradle index 7c30d24b783..92fbeb45589 100644 --- a/gradle/scripts/lib/java-javadoc.gradle +++ b/gradle/scripts/lib/java-javadoc.gradle @@ -496,7 +496,11 @@ class DownloadJavadocPackageListTask extends DefaultTask { conn.disconnect() } catch (e) { tmpListFile.delete() - logger.log(LogLevel.WARN, "Download failed: ${e}", e) + if (project.gradle.startParameter.showStacktrace == ShowStacktrace.ALWAYS_FULL) { + logger.log(LogLevel.WARN, "Download failed: ${e}", e) + } else { + logger.log(LogLevel.WARN, "Download failed: ${e}") + } } return success diff --git a/gradle/scripts/lib/java-publish.gradle b/gradle/scripts/lib/java-publish.gradle index 8bcf3c25236..76031617731 100644 --- a/gradle/scripts/lib/java-publish.gradle +++ b/gradle/scripts/lib/java-publish.gradle @@ -16,7 +16,7 @@ configure(projectsWithFlags('publish', 'java')) { jarOverrideFile = tasks.trimShadedJar.outJarFiles.find() as File jarOverrideTask = tasks.trimShadedJar } else if (tasks.findByName('shadedJar')) { - jarOverrideFile = tasks.shadedJar.archivePath + jarOverrideFile = tasks.shadedJar.archiveFile.get().asFile jarOverrideTask = tasks.shadedJar } if (jarOverrideFile != null) { diff --git a/gradle/scripts/lib/java-shade.gradle b/gradle/scripts/lib/java-shade.gradle index 1a2f41d5236..061b1304470 100644 --- a/gradle/scripts/lib/java-shade.gradle +++ b/gradle/scripts/lib/java-shade.gradle @@ -27,23 +27,12 @@ configure(relocatedProjects) { configureShadowTask(project, delegate, true) archiveBaseName.set("${project.archivesBaseName}-shaded") - // Exclude the legacy file listing. - exclude '/META-INF/INDEX.LIST' // Exclude the class signature files. exclude '/META-INF/*.SF' exclude '/META-INF/*.DSA' exclude '/META-INF/*.RSA' // Exclude the files generated by Maven exclude '/META-INF/maven/**' - // Exclude the module metadata that'll become invalid after relocation. - exclude '**/module-info.class' - - // Set the 'Automatic-Module-Name' property in MANIFEST.MF. - if (project.ext.automaticModuleName != null) { - manifest { - attributes('Automatic-Module-Name': project.ext.automaticModuleName) - } - } } tasks.assemble.dependsOn tasks.shadedJar @@ -57,7 +46,7 @@ configure(relocatedProjects) { description: 'Extracts the shaded main JAR.', dependsOn: tasks.shadedJar) { - from(zipTree(tasks.shadedJar.archivePath)) + from(zipTree(tasks.shadedJar.archiveFile.get().asFile)) from(sourceSets.main.output.classesDirs) { // Add the JAR resources excluded in the 'shadedJar' task. include '**/*.jar' @@ -81,7 +70,7 @@ configure(relocatedProjects) { description: 'Extracts the shaded test JAR.', dependsOn: tasks.shadedTestJar) { - from(zipTree(tasks.shadedTestJar.archivePath)) + from(zipTree(tasks.shadedTestJar.archiveFile.get().asFile)) from(sourceSets.test.output.classesDirs) { // Add the JAR resources excluded in the 'shadedTestJar' task. include '**/*.jar' @@ -107,7 +96,7 @@ configure(relocatedProjects) { dependsOn it.tasks.shadedTestJar.path } - def shadedFile = tasks.shadedJar.archivePath + def shadedFile = tasks.shadedJar.archiveFile.get().asFile def shadedAndTrimmedFile = file(shadedFile.path.replaceFirst('-untrimmed-', '-shaded-')) injars shadedFile @@ -118,11 +107,11 @@ configure(relocatedProjects) { // Include all other shaded JARs so that ProGuard does not trim the classes and methods // that are used actually. - injars tasks.shadedTestJar.archivePath + injars tasks.shadedTestJar.archiveFile.get().asFile relocatedProjects.each { if (it != project && !it.hasFlags('no_aggregation')) { - injars it.tasks.shadedJar.archivePath - injars it.tasks.shadedTestJar.archivePath + injars it.tasks.shadedJar.archiveFile.get().asFile + injars it.tasks.shadedTestJar.archiveFile.get().asFile } } @@ -208,7 +197,6 @@ configure(relocatedProjects) { group = 'Verification' description = 'Runs the unit tests with the shaded classes.' - project.ext.configureFlakyTests(it) project.ext.configureCommonTestSettings(it) dependsOn tasks.copyShadedTestClasses @@ -266,8 +254,13 @@ configure(relocatedProjects) { gradle.taskGraph.whenReady { // Skip unshaded tests if shaded tests will run. + // To enable, set the property 'preferShadedTests' to 'false'. + boolean runUnshadedTests = false + if (rootProject.hasProperty('preferShadedTests') && "false" == rootProject.property('preferShadedTests')) { + runUnshadedTests = true + } if (gradle.taskGraph.hasTask(tasks.shadedTest)) { - tasks.test.onlyIf { false } + tasks.test.onlyIf { runUnshadedTests } } } } @@ -336,7 +329,7 @@ private Configuration configureShadedTestImplementConfiguration( if (recursedProject.tasks.findByName('trimShadedJar')) { project.dependencies.add(shadedJarTestImplementation.name, files(recursedProject.tasks.trimShadedJar.outJarFiles)) } else if (recursedProject.tasks.findByName('shadedJar')) { - project.dependencies.add(shadedJarTestImplementation.name, files(recursedProject.tasks.shadedJar.archivePath)) + project.dependencies.add(shadedJarTestImplementation.name, files(recursedProject.tasks.shadedJar.archiveFile.get().asFile)) } def shadedDependencyNames = project.ext.relocations.collect { it['name'] } diff --git a/gradle/scripts/lib/java.gradle b/gradle/scripts/lib/java.gradle index 13ddbe58e2c..b9ebc8cae0d 100644 --- a/gradle/scripts/lib/java.gradle +++ b/gradle/scripts/lib/java.gradle @@ -1,6 +1,5 @@ import java.util.regex.Pattern -// Determine which version of JDK should be used for builds. def buildJdkVersion = Integer.parseInt(JavaVersion.current().getMajorVersion()) if (rootProject.hasProperty('buildJdkVersion')) { def jdkVersion = Integer.parseInt(String.valueOf(rootProject.findProperty('buildJdkVersion'))) @@ -50,7 +49,9 @@ configure(projectsWithFlags('java')) { apply plugin: 'idea' apply plugin: 'jvm-test-suite' - archivesBaseName = project.ext.artifactId + base { + archivesName = project.ext.artifactId + } // Delete the generated source directory on clean. ext { @@ -140,12 +141,6 @@ configure(projectsWithFlags('java')) { registerFeature('optional') { usingSourceSet(sourceSets.main) } - - // Do not let Gradle infer the module path if automatic module name is enabled, - // because it means the JAR will rely on JDK's automatic module metadata generation. - if (project.ext.automaticModuleName != null) { - modularity.inferModulePath = false - } } // Set the sensible compiler options. @@ -161,15 +156,6 @@ configure(projectsWithFlags('java')) { options.compilerArgs += '-parameters' } - // Set the 'Automatic-Module-Name' property in 'MANIFEST.MF' if `automaticModuleName` is not null. - if (project.ext.automaticModuleName != null) { - tasks.named('jar') { - manifest { - attributes('Automatic-Module-Name': project.ext.automaticModuleName) - } - } - } - project.ext.configureFlakyTests = { Test testTask -> def flakyTests = rootProject.findProperty('flakyTests') if (flakyTests == 'true') { @@ -201,6 +187,9 @@ configure(projectsWithFlags('java')) { "targetJavaVersion(${project.ext.targetJavaVersion})") testTask.enabled = false } + if (testTask.enabled) { + project.ext.configureFlakyTests(testTask) + } } testing.suites { @@ -210,7 +199,6 @@ configure(projectsWithFlags('java')) { targets.configureEach { testTask.configure { - project.ext.configureFlakyTests(it) project.ext.configureCommonTestSettings(it) } } diff --git a/gradle/scripts/lib/prerequisite.gradle b/gradle/scripts/lib/prerequisite.gradle index 713deb42f7b..7a94d9b4f8e 100644 --- a/gradle/scripts/lib/prerequisite.gradle +++ b/gradle/scripts/lib/prerequisite.gradle @@ -9,15 +9,12 @@ plugins { ''') } -['group', 'version', 'projectName', 'projectUrl', 'inceptionYear', 'licenseName', 'licenseUrl', 'scmUrl', 'scmConnection', +['projectName', 'projectUrl', 'inceptionYear', 'licenseName', 'licenseUrl', 'scmUrl', 'scmConnection', 'scmDeveloperConnection', 'publishUrlForRelease', 'publishUrlForSnapshot', 'publishUsernameProperty', 'publishPasswordProperty'].each { if (rootProject.findProperty(it) == null) { throw new IllegalStateException('''Add project info properties to gradle.properties: -group=com.doe.john.myexample -version=0.0.1-SNAPSHOT -versionPattern=^[0-9]+\\\\.[0-9]+\\\\.[0-9]+$ projectName=My Example projectUrl=https://www.example.com/ projectDescription=My example project @@ -34,12 +31,7 @@ publishUrlForRelease=https://oss.sonatype.org/service/local/staging/deploy/maven publishUrlForSnapshot=https://oss.sonatype.org/content/repositories/snapshots/ publishUsernameProperty=ossrhUsername publishPasswordProperty=ossrhPassword -publishSignatureRequired=true versionPattern=^[0-9]+\\\\.[0-9]+\\\\.[0-9]+$ -googleAnalyticsId=UA-XXXXXXXX -javaSourceCompatibility=1.8 -javaTargetCompatibility=1.8 -automaticModuleNames=false ''') } } diff --git a/gradle/scripts/settings-flags.gradle b/gradle/scripts/settings-flags.gradle index fc0b8972f9c..fbe5147ecbd 100644 --- a/gradle/scripts/settings-flags.gradle +++ b/gradle/scripts/settings-flags.gradle @@ -187,5 +187,9 @@ static void afterProjectsWithFlags(Project project, Iterable flags, Closure dataLoaderRegistryFunction; + ? extends DataLoaderRegistry> dataLoaderRegistryFunction; private final boolean useBlockingTaskExecutor; @@ -111,25 +112,35 @@ public CompletableFuture executeGraphql(ServiceRequestContext c private HttpResponse execute( ServiceRequestContext ctx, ExecutionInput input, MediaType produceType) { - final CompletableFuture future = executeGraphql(ctx, input); - return HttpResponse.of( - future.handle((executionResult, cause) -> { - if (executionResult.getData() instanceof Publisher) { - logger.warn("executionResult.getData() returns a {} that is not supported yet.", - executionResult.getData().toString()); - - return HttpResponse.ofJson(HttpStatus.NOT_IMPLEMENTED, - produceType, - toSpecification( - "Use GraphQL over WebSocket for subscription")); - } - - if (executionResult.getErrors().isEmpty() && cause == null) { - return HttpResponse.ofJson(produceType, executionResult.toSpecification()); - } - - return errorHandler.handle(ctx, input, executionResult, cause); - })); + try { + final CompletableFuture future = executeGraphql(ctx, input); + return HttpResponse.of( + future.handle((executionResult, cause) -> { + if (cause != null) { + cause = Exceptions.peel(cause); + return errorHandler.handle(ctx, input, null, cause); + } + + if (executionResult.getData() instanceof Publisher) { + logger.warn("Use GraphQL over WebSocket for subscription. " + + "executionResult.getData(): {}", executionResult.getData().toString()); + + return HttpResponse.ofJson(HttpStatus.NOT_IMPLEMENTED, + produceType, + toSpecification( + "Use GraphQL over WebSocket for subscription")); + } + + if (executionResult.getErrors().isEmpty()) { + return HttpResponse.ofJson(produceType, executionResult.toSpecification()); + } + + return errorHandler.handle(ctx, input, executionResult, null); + })); + } catch (Throwable cause) { + cause = Exceptions.peel(cause); + return errorHandler.handle(ctx, input, null, cause); + } } static Map toSpecification(String message) { diff --git a/graphql/src/main/java/com/linecorp/armeria/server/graphql/GraphqlErrorHandler.java b/graphql/src/main/java/com/linecorp/armeria/server/graphql/GraphqlErrorHandler.java index 02d7afa4c1b..2273db905a8 100644 --- a/graphql/src/main/java/com/linecorp/armeria/server/graphql/GraphqlErrorHandler.java +++ b/graphql/src/main/java/com/linecorp/armeria/server/graphql/GraphqlErrorHandler.java @@ -46,7 +46,8 @@ static GraphqlErrorHandler of() { */ @Nullable HttpResponse handle( - ServiceRequestContext ctx, ExecutionInput input, ExecutionResult result, @Nullable Throwable cause); + ServiceRequestContext ctx, ExecutionInput input, @Nullable ExecutionResult result, + @Nullable Throwable cause); /** * Returns a composed {@link GraphqlErrorHandler} that applies this first and the specified diff --git a/graphql/src/main/java/com/linecorp/armeria/server/graphql/GraphqlWebSocketService.java b/graphql/src/main/java/com/linecorp/armeria/server/graphql/GraphqlWebSocketService.java index ffe15121764..4f7dda740a5 100644 --- a/graphql/src/main/java/com/linecorp/armeria/server/graphql/GraphqlWebSocketService.java +++ b/graphql/src/main/java/com/linecorp/armeria/server/graphql/GraphqlWebSocketService.java @@ -24,6 +24,7 @@ import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.websocket.WebSocket; import com.linecorp.armeria.common.websocket.WebSocketWriter; +import com.linecorp.armeria.server.ServiceOptions; import com.linecorp.armeria.server.ServiceRequestContext; import com.linecorp.armeria.server.websocket.WebSocketProtocolHandler; import com.linecorp.armeria.server.websocket.WebSocketService; @@ -74,4 +75,9 @@ public WebSocket handle(ServiceRequestContext ctx, WebSocket in) { in.subscribe(new GraphqlWebSocketSubscriber(protocol, outgoing)); return outgoing; } + + @Override + public ServiceOptions options() { + return delegate.options(); + } } diff --git a/graphql/src/test/java/com/linecorp/armeria/server/graphql/GraphqlErrorHandlerTest.java b/graphql/src/test/java/com/linecorp/armeria/server/graphql/GraphqlErrorHandlerTest.java index 5938af365b5..ac7e7a6732f 100644 --- a/graphql/src/test/java/com/linecorp/armeria/server/graphql/GraphqlErrorHandlerTest.java +++ b/graphql/src/test/java/com/linecorp/armeria/server/graphql/GraphqlErrorHandlerTest.java @@ -20,83 +20,163 @@ import java.io.File; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import com.linecorp.armeria.common.AggregatedHttpResponse; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.internal.testing.AnticipatedException; import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.server.ServiceRequestContext; import com.linecorp.armeria.testing.junit5.server.ServerExtension; +import graphql.GraphQL; import graphql.GraphQLError; import graphql.GraphqlErrorException; +import graphql.execution.instrumentation.Instrumentation; +import graphql.execution.instrumentation.InstrumentationState; +import graphql.execution.instrumentation.parameters.InstrumentationCreateStateParameters; import graphql.schema.DataFetcher; +import graphql.schema.GraphQLSchema; +import graphql.schema.idl.RuntimeWiring; +import graphql.schema.idl.SchemaGenerator; +import graphql.schema.idl.SchemaParser; +import graphql.schema.idl.TypeDefinitionRegistry; class GraphqlErrorHandlerTest { + private static final AtomicBoolean shouldFailRequests = new AtomicBoolean(); + + private static GraphQL newGraphQL() throws Exception { + final File graphqlSchemaFile = + new File(GraphqlErrorHandlerTest.class.getResource("/testing/graphql/test.graphqls").toURI()); + final SchemaParser schemaParser = new SchemaParser(); + final SchemaGenerator schemaGenerator = new SchemaGenerator(); + final TypeDefinitionRegistry typeRegistry = new TypeDefinitionRegistry(); + typeRegistry.merge(schemaParser.parse(graphqlSchemaFile)); + final RuntimeWiring.Builder runtimeWiringBuilder = RuntimeWiring.newRuntimeWiring(); + final DataFetcher foo = dataFetcher("foo"); + runtimeWiringBuilder.type("Query", + typeWiring -> typeWiring.dataFetcher("foo", foo)); + final DataFetcher error = dataFetcher("error"); + runtimeWiringBuilder.type("Query", + typeWiring -> typeWiring.dataFetcher("error", error)); + + final GraphQLSchema graphQLSchema = schemaGenerator.makeExecutableSchema(typeRegistry, + runtimeWiringBuilder.build()); + final Instrumentation instrumentation = new Instrumentation() { + @Override + public InstrumentationState createState( + InstrumentationCreateStateParameters parameters) { + if (shouldFailRequests.get()) { + throw new AnticipatedException("external exception"); + } else { + return Instrumentation.super.createState(parameters); + } + } + }; + + return new GraphQL.Builder(graphQLSchema) + .instrumentation(instrumentation) + .build(); + } + + private static final GraphqlErrorHandler errorHandler + = (ctx, input, result, cause) -> { + if (result == null) { + assertThat(cause).isNotNull(); + return HttpResponse.of(HttpStatus.INTERNAL_SERVER_ERROR, MediaType.PLAIN_TEXT, + cause.getMessage()); + } + final List errors = result.getErrors(); + if (errors.stream().map(GraphQLError::getMessage).anyMatch(m -> m.endsWith("foo"))) { + return HttpResponse.of(HttpStatus.BAD_REQUEST); + } + return null; + }; + + private static DataFetcher dataFetcher(String value) { + return environment -> { + final ServiceRequestContext ctx = GraphqlServiceContexts.get(environment); + // Make sure that a ServiceRequestContext is available + assertThat(ServiceRequestContext.current()).isSameAs(ctx); + throw GraphqlErrorException.newErrorException().message(value).build(); + }; + } + @RegisterExtension static ServerExtension server = new ServerExtension() { @Override protected void configure(ServerBuilder sb) throws Exception { - final File graphqlSchemaFile = - new File(getClass().getResource("/testing/graphql/test.graphqls").toURI()); - - final GraphqlErrorHandler errorHandler - = (ctx, input, result, cause) -> { - final List errors = result.getErrors(); - if (errors.stream().map(GraphQLError::getMessage).anyMatch(m -> m.endsWith("foo"))) { - return HttpResponse.of(HttpStatus.BAD_REQUEST); - } - return null; - }; final GraphqlService service = GraphqlService.builder() - .schemaFile(graphqlSchemaFile) - .runtimeWiring(c -> { - final DataFetcher foo = dataFetcher("foo"); - c.type("Query", - typeWiring -> typeWiring.dataFetcher("foo", foo)); - final DataFetcher error = dataFetcher("error"); - c.type("Query", - typeWiring -> typeWiring.dataFetcher("error", error)); - }) + .graphql(newGraphQL()) .errorHandler(errorHandler) .build(); sb.service("/graphql", service); } }; - private static DataFetcher dataFetcher(String value) { - return environment -> { - final ServiceRequestContext ctx = GraphqlServiceContexts.get(environment); - assertThat(ctx.eventLoop().inEventLoop()).isTrue(); - // Make sure that a ServiceRequestContext is available - assertThat(ServiceRequestContext.current()).isSameAs(ctx); - throw GraphqlErrorException.newErrorException().message(value).build(); - }; + @RegisterExtension + static ServerExtension blockingServer = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) throws Exception { + + final GraphqlService service = + GraphqlService.builder() + .graphql(newGraphQL()) + .useBlockingTaskExecutor(true) + .errorHandler(errorHandler) + .build(); + sb.service("/graphql", service); + } + }; + + @BeforeEach + void setUp() { + shouldFailRequests.set(false); } - @Test - void handledError() { + @ValueSource(booleans = { true, false }) + @ParameterizedTest + void handledError(boolean blocking) { final HttpRequest request = HttpRequest.builder().post("/graphql") .content(MediaType.GRAPHQL, "{foo}") .build(); + final ServerExtension server = blocking ? blockingServer : GraphqlErrorHandlerTest.server; final AggregatedHttpResponse response = server.blockingWebClient().execute(request); assertThat(response.status()).isEqualTo(HttpStatus.BAD_REQUEST); } - @Test - void unhandledError() { + @ValueSource(booleans = { true, false }) + @ParameterizedTest + void unhandledGraphqlError(boolean blocking) { final HttpRequest request = HttpRequest.builder().post("/graphql") .content(MediaType.GRAPHQL, "{error}") .build(); + final ServerExtension server = blocking ? blockingServer : GraphqlErrorHandlerTest.server; final AggregatedHttpResponse response = server.blockingWebClient().execute(request); assertThat(response.status()).isEqualTo(HttpStatus.OK); } + + @ValueSource(booleans = { true, false }) + @ParameterizedTest + void unhandledException(boolean blocking) { + shouldFailRequests.set(true); + final HttpRequest request = HttpRequest.builder().post("/graphql") + .content(MediaType.GRAPHQL, "{error}") + .build(); + final ServerExtension server = blocking ? blockingServer : GraphqlErrorHandlerTest.server; + final AggregatedHttpResponse response = server.blockingWebClient().execute(request); + assertThat(response.status()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR); + assertThat(response.contentUtf8()).isEqualTo("external exception"); + } } diff --git a/grpc-kotlin/src/test/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptorTest.kt b/grpc-kotlin/src/test/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptorTest.kt index baf008c14f9..6573dc1f63b 100644 --- a/grpc-kotlin/src/test/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptorTest.kt +++ b/grpc-kotlin/src/test/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptorTest.kt @@ -228,7 +228,12 @@ internal class CoroutineServerInterceptorTest { object : ServerExtension() { override fun configure(sb: ServerBuilder) { val exceptionHandler = - GrpcExceptionHandlerFunction { _: RequestContext, throwable: Throwable, _: Metadata -> + GrpcExceptionHandlerFunction { + _: RequestContext, + _: Status, + throwable: Throwable, + _: Metadata, + -> if (throwable is AnticipatedException && throwable.message == "Invalid access") { return@GrpcExceptionHandlerFunction Status.UNAUTHENTICATED } diff --git a/grpc-protocol/src/main/java/com/linecorp/armeria/client/grpc/protocol/UnaryGrpcClient.java b/grpc-protocol/src/main/java/com/linecorp/armeria/client/grpc/protocol/UnaryGrpcClient.java index 74eae7c80bd..b8ecd121efd 100644 --- a/grpc-protocol/src/main/java/com/linecorp/armeria/client/grpc/protocol/UnaryGrpcClient.java +++ b/grpc-protocol/src/main/java/com/linecorp/armeria/client/grpc/protocol/UnaryGrpcClient.java @@ -31,10 +31,12 @@ import com.linecorp.armeria.client.ClientRequestContext; import com.linecorp.armeria.client.Clients; import com.linecorp.armeria.client.HttpClient; +import com.linecorp.armeria.client.RequestOptions; import com.linecorp.armeria.client.SimpleDecoratingHttpClient; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.common.AggregatedHttpResponse; import com.linecorp.armeria.common.AggregationOptions; +import com.linecorp.armeria.common.ExchangeType; import com.linecorp.armeria.common.HttpData; import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpHeaders; @@ -72,12 +74,17 @@ */ @UnstableApi public final class UnaryGrpcClient { + + private static final Logger logger = LoggerFactory.getLogger(UnaryGrpcClient.class); + private static final Set SUPPORTED_SERIALIZATION_FORMATS = UnaryGrpcSerializationFormats.values(); + private static final RequestOptions REQUEST_OPTIONS = + RequestOptions.builder().exchangeType(ExchangeType.UNARY).build(); + private final SerializationFormat serializationFormat; private final WebClient webClient; - private static final Logger logger = LoggerFactory.getLogger(UnaryGrpcClient.class); /** * Constructs a {@link UnaryGrpcClient} for the given {@link WebClient}. @@ -131,7 +138,7 @@ public CompletableFuture execute(String uri, byte[] payload) { RequestHeaders.builder(HttpMethod.POST, uri).contentType(serializationFormat.mediaType()) .add(HttpHeaderNames.TE, HttpHeaderValues.TRAILERS.toString()).build(), HttpData.wrap(payload)); - return webClient.execute(request).aggregate( + return webClient.execute(request, REQUEST_OPTIONS).aggregate( AggregationOptions.builder() .usePooledObjects(PooledByteBufAllocator.DEFAULT) .build()) diff --git a/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientBuilder.java b/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientBuilder.java index 039761f0511..7b9c59e5baa 100644 --- a/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientBuilder.java +++ b/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientBuilder.java @@ -75,7 +75,6 @@ import com.linecorp.armeria.common.grpc.GrpcSerializationFormats; import com.linecorp.armeria.common.grpc.protocol.ArmeriaMessageDeframer; import com.linecorp.armeria.common.grpc.protocol.ArmeriaMessageFramer; -import com.linecorp.armeria.internal.common.grpc.UnwrappingGrpcExceptionHandleFunction; import com.linecorp.armeria.unsafe.grpc.GrpcUnsafeBufferUtil; import io.grpc.CallCredentials; @@ -419,8 +418,7 @@ public T build(Class clientType) { option(INTERCEPTORS.newValue(clientInterceptors)); } if (exceptionHandler != null) { - option(EXCEPTION_HANDLER.newValue(new UnwrappingGrpcExceptionHandleFunction(exceptionHandler.orElse( - GrpcExceptionHandlerFunction.of())))); + option(EXCEPTION_HANDLER.newValue(exceptionHandler.orElse(GrpcExceptionHandlerFunction.of()))); } final Object client; diff --git a/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientOptions.java b/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientOptions.java index 09df88b849c..278b19844a5 100644 --- a/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientOptions.java +++ b/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientOptions.java @@ -37,7 +37,6 @@ import com.linecorp.armeria.common.grpc.protocol.ArmeriaMessageFramer; import com.linecorp.armeria.internal.client.grpc.NullCallCredentials; import com.linecorp.armeria.internal.client.grpc.NullGrpcClientStubFactory; -import com.linecorp.armeria.internal.common.grpc.UnwrappingGrpcExceptionHandleFunction; import com.linecorp.armeria.unsafe.grpc.GrpcUnsafeBufferUtil; import io.grpc.CallCredentials; @@ -174,8 +173,7 @@ public final class GrpcClientOptions { * to a gRPC {@link Status}. */ public static final ClientOption EXCEPTION_HANDLER = - ClientOption.define("EXCEPTION_HANDLER", new UnwrappingGrpcExceptionHandleFunction( - GrpcExceptionHandlerFunction.of())); + ClientOption.define("EXCEPTION_HANDLER", GrpcExceptionHandlerFunction.of()); /** * Sets whether to respect the marshaller specified in gRPC {@link MethodDescriptor}. diff --git a/grpc/src/main/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunction.java b/grpc/src/main/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunction.java index ce1c35ea51d..4a0203b19e4 100644 --- a/grpc/src/main/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunction.java +++ b/grpc/src/main/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunction.java @@ -45,7 +45,10 @@ enum DefaultGrpcExceptionHandlerFunction implements GrpcExceptionHandlerFunction * well and the protocol package. */ @Override - public Status apply(RequestContext ctx, Throwable cause, Metadata metadata) { + public Status apply(RequestContext ctx, Status status, Throwable cause, Metadata metadata) { + if (status.getCode() != Code.UNKNOWN) { + return status; + } final Status s = Status.fromThrowable(cause); if (s.getCode() != Code.UNKNOWN) { return s; diff --git a/grpc/src/main/java/com/linecorp/armeria/common/grpc/GoogleGrpcExceptionHandlerFunction.java b/grpc/src/main/java/com/linecorp/armeria/common/grpc/GoogleGrpcExceptionHandlerFunction.java index 0cbc2bcd7fe..9d32de1003e 100644 --- a/grpc/src/main/java/com/linecorp/armeria/common/grpc/GoogleGrpcExceptionHandlerFunction.java +++ b/grpc/src/main/java/com/linecorp/armeria/common/grpc/GoogleGrpcExceptionHandlerFunction.java @@ -40,7 +40,7 @@ public interface GoogleGrpcExceptionHandlerFunction extends GrpcExceptionHandler @Nullable @Override - default Status apply(RequestContext ctx, Throwable throwable, Metadata metadata) { + default Status apply(RequestContext ctx, Status status, Throwable throwable, Metadata metadata) { return handleException(ctx, throwable, metadata, this::applyStatusProto); } @@ -48,7 +48,7 @@ default Status apply(RequestContext ctx, Throwable throwable, Metadata metadata) * Maps the specified {@link Throwable} to a {@link com.google.rpc.Status}, * and mutates the specified {@link Metadata}. * The `grpc-status-details-bin` key is ignored since it will be overwritten - * by {@link GoogleGrpcExceptionHandlerFunction#apply(RequestContext, Throwable, Metadata)}. + * by {@link GrpcExceptionHandlerFunction#apply(RequestContext, Status, Throwable, Metadata)}. * If {@code null} is returned, the built-in mapping rule is used by default. */ com.google.rpc.@Nullable Status applyStatusProto(RequestContext ctx, Throwable throwable, diff --git a/grpc/src/main/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunction.java b/grpc/src/main/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunction.java index c7853f05844..6e9f150e7d0 100644 --- a/grpc/src/main/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunction.java +++ b/grpc/src/main/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunction.java @@ -48,12 +48,14 @@ static GrpcExceptionHandlerFunction of() { } /** - * Maps the specified {@link Throwable} to a gRPC {@link Status}, - * and mutates the specified {@link Metadata}. - * If {@code null} is returned, the built-in mapping rule is used by default. + * Maps the specified {@link Throwable} to a gRPC {@link Status} and mutates the specified {@link Metadata}. + * If {@code null} is returned, {@link #of()} will be used to return {@link Status} as the default. + * + *

    The specified {@link Status} parameter was created via {@link Status#fromThrowable(Throwable)}. + * You can return the {@link Status} or any other {@link Status} as needed. */ @Nullable - Status apply(RequestContext ctx, Throwable cause, Metadata metadata); + Status apply(RequestContext ctx, Status status, Throwable cause, Metadata metadata); /** * Returns a {@link GrpcExceptionHandlerFunction} that returns the result of this function @@ -63,12 +65,12 @@ static GrpcExceptionHandlerFunction of() { */ default GrpcExceptionHandlerFunction orElse(GrpcExceptionHandlerFunction next) { requireNonNull(next, "next"); - return (ctx, cause, metadata) -> { - final Status status = apply(ctx, cause, metadata); - if (status != null) { - return status; + return (ctx, status, cause, metadata) -> { + final Status newStatus = apply(ctx, status, cause, metadata); + if (newStatus != null) { + return newStatus; } - return next.apply(ctx, cause, metadata); + return next.apply(ctx, status, cause, metadata); }; } } diff --git a/grpc/src/main/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunctionBuilder.java b/grpc/src/main/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunctionBuilder.java index bca36bc8066..551f857d0b1 100644 --- a/grpc/src/main/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunctionBuilder.java +++ b/grpc/src/main/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunctionBuilder.java @@ -54,7 +54,7 @@ public final class GrpcExceptionHandlerFunctionBuilder { */ public GrpcExceptionHandlerFunctionBuilder on(Class exceptionType, Status status) { requireNonNull(status, "status"); - return on(exceptionType, (ctx, cause, metadata) -> status); + return on(exceptionType, (ctx, unused, cause, metadata) -> status); } /** @@ -66,7 +66,7 @@ public GrpcExceptionHandlerFunctionBuilder on( requireNonNull(exceptionType, "exceptionType"); requireNonNull(exceptionHandler, "exceptionHandler"); //noinspection unchecked - return on(exceptionType, (ctx, cause, metadata) -> exceptionHandler.apply((T) cause, metadata)); + return on(exceptionType, (ctx, status, cause, metadata) -> exceptionHandler.apply((T) cause, metadata)); } /** @@ -107,11 +107,11 @@ public GrpcExceptionHandlerFunction build() { final List, GrpcExceptionHandlerFunction>> mappings = ImmutableList.copyOf(exceptionMappings); - return (ctx, cause, metadata) -> { + return (ctx, status, cause, metadata) -> { for (Map.Entry, GrpcExceptionHandlerFunction> mapping : mappings) { if (mapping.getKey().isInstance(cause)) { - final Status status = mapping.getValue().apply(ctx, cause, metadata); - return status == null ? null : status.withCause(cause); + final Status newStatus = mapping.getValue().apply(ctx, status, cause, metadata); + return newStatus == null ? null : newStatus.withCause(cause); } } return null; diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaClientCall.java b/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaClientCall.java index 87aa120d312..a395d37fce7 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaClientCall.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaClientCall.java @@ -17,6 +17,8 @@ import static com.linecorp.armeria.internal.client.ClientUtil.initContextAndExecuteWithFallback; import static com.linecorp.armeria.internal.client.grpc.protocol.InternalGrpcWebUtil.messageBuf; +import static com.linecorp.armeria.internal.common.grpc.GrpcExceptionHandlerFunctionUtil.fromThrowable; +import static com.linecorp.armeria.internal.common.grpc.GrpcExceptionHandlerFunctionUtil.generateMetadataFromThrowable; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -248,9 +250,14 @@ public void start(Listener responseListener, Metadata metadata) { prepareHeaders(compressor, metadata, remainingNanos); final BiFunction errorResponseFactory = - (unused, cause) -> HttpResponse.ofFailure(exceptionHandler.apply(ctx, cause, metadata) - .withDescription(cause.getMessage()) - .asRuntimeException()); + (unused, cause) -> { + final Metadata responseMetadata = generateMetadataFromThrowable(cause); + Status status = fromThrowable(ctx, exceptionHandler, cause, responseMetadata); + if (status.getDescription() == null) { + status = status.withDescription(cause.getMessage()); + } + return HttpResponse.ofFailure(status.asRuntimeException()); + }; final HttpResponse res = initContextAndExecuteWithFallback( httpClient, ctx, endpointGroup, HttpResponse::of, errorResponseFactory); @@ -453,8 +460,8 @@ public void onNext(DeframedMessage message) { } }); } catch (Throwable t) { - final Metadata metadata = new Metadata(); - close(exceptionHandler.apply(ctx, t, metadata), metadata); + final Metadata metadata = generateMetadataFromThrowable(t); + close(fromThrowable(ctx, exceptionHandler, t, metadata), metadata); } } @@ -510,8 +517,8 @@ private void prepareHeaders(Compressor compressor, Metadata metadata, long remai } private void closeWhenListenerThrows(Throwable t) { - final Metadata metadata = new Metadata(); - closeWhenEos(exceptionHandler.apply(ctx, t, metadata), metadata); + final Metadata metadata = generateMetadataFromThrowable(t); + closeWhenEos(fromThrowable(ctx, exceptionHandler, t, metadata), metadata); } private void closeWhenEos(Status status, Metadata metadata) { diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/UnwrappingGrpcExceptionHandleFunction.java b/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/GrpcExceptionHandlerFunctionUtil.java similarity index 51% rename from grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/UnwrappingGrpcExceptionHandleFunction.java rename to grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/GrpcExceptionHandlerFunctionUtil.java index f8d8996efe0..5a87733d532 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/UnwrappingGrpcExceptionHandleFunction.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/GrpcExceptionHandlerFunctionUtil.java @@ -13,13 +13,11 @@ * License for the specific language governing permissions and limitations * under the License. */ - package com.linecorp.armeria.internal.common.grpc; import static java.util.Objects.requireNonNull; import com.linecorp.armeria.common.RequestContext; -import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction; import com.linecorp.armeria.common.grpc.protocol.ArmeriaStatusException; import com.linecorp.armeria.common.util.Exceptions; @@ -27,22 +25,36 @@ import io.grpc.Metadata; import io.grpc.Status; -public final class UnwrappingGrpcExceptionHandleFunction implements GrpcExceptionHandlerFunction { - private final GrpcExceptionHandlerFunction delegate; +public final class GrpcExceptionHandlerFunctionUtil { - public UnwrappingGrpcExceptionHandleFunction(GrpcExceptionHandlerFunction handlerFunction) { - delegate = handlerFunction; + public static Metadata generateMetadataFromThrowable(Throwable exception) { + final Metadata metadata = Status.trailersFromThrowable(peelAndUnwrap(exception)); + return metadata != null ? metadata : new Metadata(); } - @Override - public @Nullable Status apply(RequestContext ctx, Throwable cause, Metadata metadata) { - final Throwable t = peelAndUnwrap(cause); - return delegate.apply(ctx, t, metadata); + public static Status fromThrowable(RequestContext ctx, GrpcExceptionHandlerFunction exceptionHandler, + Throwable t, Metadata metadata) { + final Status status = Status.fromThrowable(peelAndUnwrap(t)); + final Throwable cause = status.getCause(); + if (cause == null) { + return status; + } + return applyExceptionHandler(ctx, exceptionHandler, status, cause, metadata); + } + + public static Status applyExceptionHandler(RequestContext ctx, + GrpcExceptionHandlerFunction exceptionHandler, + Status status, Throwable cause, Metadata metadata) { + final Throwable peeled = peelAndUnwrap(cause); + status = exceptionHandler.apply(ctx, status, peeled, metadata); + assert status != null; + return status; } private static Throwable peelAndUnwrap(Throwable t) { requireNonNull(t, "t"); - Throwable cause = Exceptions.peel(t); + t = Exceptions.peel(t); + Throwable cause = t; while (cause != null) { if (cause instanceof ArmeriaStatusException) { return StatusExceptionConverter.toGrpc((ArmeriaStatusException) cause); @@ -51,4 +63,6 @@ private static Throwable peelAndUnwrap(Throwable t) { } return t; } + + private GrpcExceptionHandlerFunctionUtil() {} } diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframer.java b/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframer.java index f8c28fb6c5a..4fb84f701c6 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframer.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframer.java @@ -17,6 +17,8 @@ package com.linecorp.armeria.internal.common.grpc; import static com.google.common.base.Preconditions.checkState; +import static com.linecorp.armeria.internal.common.grpc.GrpcExceptionHandlerFunctionUtil.fromThrowable; +import static com.linecorp.armeria.internal.common.grpc.GrpcExceptionHandlerFunctionUtil.generateMetadataFromThrowable; import static java.util.Objects.requireNonNull; import com.linecorp.armeria.common.HttpHeaderNames; @@ -119,9 +121,9 @@ public void processHeaders(HttpHeaders headers, StreamDecoderOutput void startCall( } } - private void startCall(ServerMethodDefinition methodDef, ServiceRequestContext ctx, - HttpRequest req, MethodDescriptor methodDescriptor, - AbstractServerCall call) { + private static void startCall(ServerMethodDefinition methodDef, ServiceRequestContext ctx, + HttpRequest req, MethodDescriptor methodDescriptor, + AbstractServerCall call) { final Listener listener; final Metadata headers = MetadataUtil.copyFromHeaders(req.headers()); try { @@ -320,9 +326,10 @@ private void startCall(ServerMethodDefinition methodDef, ServiceReq call.setListener(listener); call.startDeframing(); ctx.whenRequestCancelling().handle((cancellationCause, unused) -> { - final Status status = call.exceptionHandler().apply(ctx, cancellationCause, headers); - assert status != null; - call.close(new ServerStatusAndMetadata(status, new Metadata(), true, true)); + final Metadata metadata = generateMetadataFromThrowable(cancellationCause); + call.close(new ServerStatusAndMetadata( + fromThrowable(ctx, call.exceptionHandler(), cancellationCause, metadata), + metadata, true, true)); return null; }); } diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilder.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilder.java index 431980bc37d..b7d285b0a9d 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilder.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilder.java @@ -51,7 +51,6 @@ import com.linecorp.armeria.common.grpc.GrpcStatusFunction; import com.linecorp.armeria.common.grpc.protocol.AbstractMessageDeframer; import com.linecorp.armeria.common.grpc.protocol.ArmeriaMessageFramer; -import com.linecorp.armeria.internal.common.grpc.UnwrappingGrpcExceptionHandleFunction; import com.linecorp.armeria.server.HttpService; import com.linecorp.armeria.server.HttpServiceWithRoutes; import com.linecorp.armeria.server.Server; @@ -878,7 +877,8 @@ public GrpcServiceBuilder exceptionHandler(GrpcExceptionHandlerFunction exceptio @Deprecated public GrpcServiceBuilder exceptionMapping(GrpcStatusFunction statusFunction) { requireNonNull(statusFunction, "statusFunction"); - return exceptionHandler(statusFunction::apply); + return exceptionHandler( + (ctx, status, throwable, metadata) -> statusFunction.apply(ctx, throwable, metadata)); } /** @@ -943,7 +943,9 @@ public GrpcServiceBuilder addExceptionMapping(Class excepti checkState(exceptionHandler == null, "addExceptionMapping() and exceptionMapping() are mutually exclusive."); - exceptionMappingsBuilder().on(exceptionType, statusFunction::apply); + exceptionMappingsBuilder().on(exceptionType, + (ctx, status, throwable, metadata) -> + statusFunction.apply(ctx, throwable, metadata)); return this; } @@ -997,7 +999,7 @@ public GrpcService build() { registryBuilder.addService(grpcHealthCheckService.bindService(), null, ImmutableList.of()); } - GrpcExceptionHandlerFunction grpcExceptionHandler; + final GrpcExceptionHandlerFunction grpcExceptionHandler; if (exceptionMappingsBuilder != null) { grpcExceptionHandler = exceptionMappingsBuilder.build().orElse(GrpcExceptionHandlerFunction.of()); } else if (exceptionHandler != null) { @@ -1005,7 +1007,6 @@ public GrpcService build() { } else { grpcExceptionHandler = GrpcExceptionHandlerFunction.of(); } - grpcExceptionHandler = new UnwrappingGrpcExceptionHandleFunction(grpcExceptionHandler); registryBuilder.setDefaultExceptionHandler(grpcExceptionHandler); if (interceptors != null) { diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/HandlerRegistry.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/HandlerRegistry.java index 47b8d63e9c0..6044ef4c969 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/HandlerRegistry.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/HandlerRegistry.java @@ -71,7 +71,6 @@ import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction; import com.linecorp.armeria.internal.common.ReflectiveDependencyInjector; -import com.linecorp.armeria.internal.common.grpc.UnwrappingGrpcExceptionHandleFunction; import com.linecorp.armeria.internal.server.annotation.AnnotationUtil; import com.linecorp.armeria.internal.server.annotation.DecoratorAnnotationUtil; import com.linecorp.armeria.internal.server.annotation.DecoratorAnnotationUtil.DecoratorAndOrder; @@ -282,8 +281,7 @@ private static void putGrpcExceptionHandlerIfPresent( grpcExceptionHandler.ifPresent(exceptionHandler -> { GrpcExceptionHandlerFunction grpcExceptionHandler0 = exceptionHandler; if (defaultExceptionHandler != null) { - grpcExceptionHandler0 = new UnwrappingGrpcExceptionHandleFunction( - exceptionHandler.orElse(defaultExceptionHandler)); + grpcExceptionHandler0 = exceptionHandler.orElse(defaultExceptionHandler); } grpcExceptionHandlersBuilder.put(methodDefinition, grpcExceptionHandler0); }); diff --git a/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientBuilderTest.java b/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientBuilderTest.java index 15a38c3bf5d..56f144f9377 100644 --- a/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientBuilderTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientBuilderTest.java @@ -290,7 +290,7 @@ public O parse(InputStream inputStream) { @Test void useDefaultGrpcExceptionHandlerFunctionAsFallback() { - final GrpcExceptionHandlerFunction noopExceptionHandler = (ctx, cause, metadata) -> null; + final GrpcExceptionHandlerFunction noopExceptionHandler = (ctx, status, cause, metadata) -> null; final GrpcExceptionHandlerFunction exceptionHandler = GrpcExceptionHandlerFunction.builder() .on(ContentTooLargeException.class, noopExceptionHandler) diff --git a/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientExceptionHandlerTest.java b/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientExceptionHandlerTest.java index eb7b564f908..6858dce322a 100644 --- a/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientExceptionHandlerTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientExceptionHandlerTest.java @@ -85,15 +85,15 @@ void chaining() { final RuntimeException exception = new RuntimeException(); final TestServiceBlockingStub stub = GrpcClients.builder(server.httpUri()) - .exceptionHandler(((ctx, cause, metadata) -> { + .exceptionHandler(((ctx, status, cause, metadata) -> { stringDeque.add("1"); return null; })) - .exceptionHandler(((ctx, cause, metadata) -> { + .exceptionHandler(((ctx, status, cause, metadata) -> { stringDeque.add("2"); return null; })) - .exceptionHandler(((ctx, cause, metadata) -> { + .exceptionHandler(((ctx, status, cause, metadata) -> { if (cause == exception) { stringDeque.add("3"); return Status.DATA_LOSS; diff --git a/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientTest.java b/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientTest.java index fec8017a51b..1c0580317b2 100644 --- a/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientTest.java @@ -747,7 +747,7 @@ void cancelAfterBegin() throws Exception { responseObserver.awaitCompletion(); assertThat(responseObserver.getValues()).isEmpty(); assertThat(GrpcExceptionHandlerFunction.of() - .apply(null, responseObserver.getError(), null) + .apply(null, Status.UNKNOWN, responseObserver.getError(), null) .getCode()).isEqualTo(Code.CANCELLED); final RequestLog log = requestLogQueue.take(); @@ -783,7 +783,7 @@ void cancelAfterFirstResponse() throws Exception { responseObserver.awaitCompletion(operationTimeoutMillis(), TimeUnit.MILLISECONDS); assertThat(responseObserver.getValues()).hasSize(1); assertThat(GrpcExceptionHandlerFunction.of() - .apply(null, responseObserver.getError(), null) + .apply(null, Status.UNKNOWN, responseObserver.getError(), null) .getCode()).isEqualTo(Code.CANCELLED); checkRequestLog((rpcReq, rpcRes, grpcStatus) -> { @@ -1418,7 +1418,7 @@ void deadlineExceededServerStreaming() throws Exception { assertThat(recorder.getError()).isNotNull(); assertThat(GrpcExceptionHandlerFunction.of() - .apply(null, recorder.getError(), null) + .apply(null, Status.UNKNOWN, recorder.getError(), null) .getCode()) .isEqualTo(Status.DEADLINE_EXCEEDED.getCode()); @@ -1618,10 +1618,10 @@ void statusCodeAndMessage() throws Exception { verify(responseObserver, timeout(operationTimeoutMillis())).onError(captor.capture()); assertThat(GrpcExceptionHandlerFunction.of() - .apply(null, captor.getValue(), null) + .apply(null, Status.UNKNOWN, captor.getValue(), null) .getCode()).isEqualTo(Status.UNKNOWN.getCode()); assertThat(GrpcExceptionHandlerFunction.of() - .apply(null, captor.getValue(), null) + .apply(null, Status.UNKNOWN, captor.getValue(), null) .getDescription()).isEqualTo(errorMessage); verifyNoMoreInteractions(responseObserver); diff --git a/grpc/src/test/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunctionTest.java b/grpc/src/test/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunctionTest.java index 17a20ceb32f..0287cd53de8 100644 --- a/grpc/src/test/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunctionTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/common/grpc/DefaultGrpcExceptionHandlerFunctionTest.java @@ -33,7 +33,8 @@ class DefaultGrpcExceptionHandlerFunctionTest { void failFastExceptionToUnavailableCode() { assertThat(GrpcExceptionHandlerFunction .of() - .apply(null, new FailFastException(CircuitBreaker.ofDefaultName()), null) + .apply(null, Status.UNKNOWN, new FailFastException(CircuitBreaker.ofDefaultName()), + null) .getCode()).isEqualTo(Status.Code.UNAVAILABLE); } @@ -41,7 +42,8 @@ void failFastExceptionToUnavailableCode() { void invalidProtocolBufferExceptionToInvalidArgumentCode() { assertThat(GrpcExceptionHandlerFunction .of() - .apply(null, new InvalidProtocolBufferException("Failed to parse message"), null) + .apply(null, Status.UNKNOWN, + new InvalidProtocolBufferException("Failed to parse message"), null) .getCode()).isEqualTo(Status.Code.INVALID_ARGUMENT); } } diff --git a/grpc/src/test/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunctionBuilderTest.java b/grpc/src/test/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunctionBuilderTest.java index 5fc512e8858..5d9e3b1de27 100644 --- a/grpc/src/test/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunctionBuilderTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/common/grpc/GrpcExceptionHandlerFunctionBuilderTest.java @@ -47,7 +47,7 @@ void duplicatedExceptionHandlers() { builder.on(A1Exception.class, Status.RESOURCE_EXHAUSTED); assertThatThrownBy(() -> { - builder.on(A1Exception.class, (ctx, throwable, metadata) -> Status.UNIMPLEMENTED); + builder.on(A1Exception.class, (ctx, status, throwable, metadata) -> Status.UNIMPLEMENTED); }).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("is already added with"); @@ -60,26 +60,26 @@ void duplicatedExceptionHandlers() { @Test void sortExceptionHandler() { final GrpcExceptionHandlerFunctionBuilder builder = GrpcExceptionHandlerFunction.builder(); - builder.on(A1Exception.class, (ctx, throwable, metadata) -> Status.RESOURCE_EXHAUSTED); - builder.on(A2Exception.class, (ctx, throwable, metadata) -> Status.UNIMPLEMENTED); + builder.on(A1Exception.class, (ctx, status, throwable, metadata) -> Status.RESOURCE_EXHAUSTED); + builder.on(A2Exception.class, (ctx, status, throwable, metadata) -> Status.UNIMPLEMENTED); assertThat(builder.exceptionMappings.stream().map(it -> (Class) it.getKey())) .containsExactly(A2Exception.class, A1Exception.class); - builder.on(B1Exception.class, (ctx, throwable, metadata) -> Status.UNAUTHENTICATED); + builder.on(B1Exception.class, (ctx, status, throwable, metadata) -> Status.UNAUTHENTICATED); assertThat(builder.exceptionMappings.stream().map(it -> (Class) it.getKey())) .containsExactly(A2Exception.class, A1Exception.class, B1Exception.class); - builder.on(A3Exception.class, (ctx, throwable, metadata) -> Status.UNAUTHENTICATED); + builder.on(A3Exception.class, (ctx, status, throwable, metadata) -> Status.UNAUTHENTICATED); assertThat(builder.exceptionMappings.stream().map(it -> (Class) it.getKey())) .containsExactly(A3Exception.class, A2Exception.class, A1Exception.class, B1Exception.class); - builder.on(B2Exception.class, (ctx, throwable, metadata) -> Status.NOT_FOUND); + builder.on(B2Exception.class, (ctx, status, throwable, metadata) -> Status.NOT_FOUND); assertThat(builder.exceptionMappings.stream().map(it -> (Class) it.getKey())) .containsExactly(A3Exception.class, A2Exception.class, @@ -89,19 +89,19 @@ void sortExceptionHandler() { final GrpcExceptionHandlerFunction exceptionHandler = builder.build().orElse( GrpcExceptionHandlerFunction.of()); - Status status = exceptionHandler.apply(ctx, new A3Exception(), new Metadata()); + Status status = exceptionHandler.apply(ctx, Status.UNKNOWN, new A3Exception(), new Metadata()); assertThat(status.getCode()).isEqualTo(Code.UNAUTHENTICATED); - status = exceptionHandler.apply(ctx, new A2Exception(), new Metadata()); + status = exceptionHandler.apply(ctx, Status.UNKNOWN, new A2Exception(), new Metadata()); assertThat(status.getCode()).isEqualTo(Code.UNIMPLEMENTED); - status = exceptionHandler.apply(ctx, new A1Exception(), new Metadata()); + status = exceptionHandler.apply(ctx, Status.UNKNOWN, new A1Exception(), new Metadata()); assertThat(status.getCode()).isEqualTo(Code.RESOURCE_EXHAUSTED); - status = exceptionHandler.apply(ctx, new B2Exception(), new Metadata()); + status = exceptionHandler.apply(ctx, Status.UNKNOWN, new B2Exception(), new Metadata()); assertThat(status.getCode()).isEqualTo(Code.NOT_FOUND); - status = exceptionHandler.apply(ctx, new B1Exception(), new Metadata()); + status = exceptionHandler.apply(ctx, Status.UNKNOWN, new B1Exception(), new Metadata()); assertThat(status.getCode()).isEqualTo(Code.UNAUTHENTICATED); } @@ -110,13 +110,13 @@ void mapStatus() { final GrpcExceptionHandlerFunction exceptionHandler = GrpcExceptionHandlerFunction .builder() - .on(A2Exception.class, (ctx, throwable, metadata) -> Status.PERMISSION_DENIED) - .on(A1Exception.class, (ctx1, cause, metadata) -> Status.DEADLINE_EXCEEDED) + .on(A2Exception.class, (ctx, status, throwable, metadata) -> Status.PERMISSION_DENIED) + .on(A1Exception.class, (ctx1, status, cause, metadata) -> Status.DEADLINE_EXCEEDED) .build(); for (Throwable ex : ImmutableList.of(new A2Exception(), new A3Exception())) { final Metadata metadata = new Metadata(); - final Status newStatus = exceptionHandler.apply(ctx, ex, metadata); + final Status newStatus = exceptionHandler.apply(ctx, Status.UNKNOWN, ex, metadata); assertThat(newStatus.getCode()).isEqualTo(Code.PERMISSION_DENIED); assertThat(newStatus.getCause()).isEqualTo(ex); assertThat(metadata.keys()).isEmpty(); @@ -124,7 +124,7 @@ void mapStatus() { final A1Exception cause = new A1Exception(); final Metadata metadata = new Metadata(); - final Status newStatus = exceptionHandler.apply(ctx, cause, metadata); + final Status newStatus = exceptionHandler.apply(ctx, Status.UNKNOWN, cause, metadata); assertThat(newStatus.getCode()).isEqualTo(Code.DEADLINE_EXCEEDED); assertThat(newStatus.getCause()).isEqualTo(cause); @@ -136,7 +136,7 @@ void mapStatusAndMetadata() { final GrpcExceptionHandlerFunction exceptionHandler = GrpcExceptionHandlerFunction .builder() - .on(B1Exception.class, (ctx, throwable, metadata) -> { + .on(B1Exception.class, (ctx, status, throwable, metadata) -> { metadata.put(TEST_KEY, throwable.getClass().getSimpleName()); return Status.ABORTED; }) @@ -144,14 +144,14 @@ void mapStatusAndMetadata() { final B1Exception cause = new B1Exception(); final Metadata metadata1 = new Metadata(); - final Status newStatus1 = exceptionHandler.apply(ctx, cause, metadata1); + final Status newStatus1 = exceptionHandler.apply(ctx, Status.UNKNOWN, cause, metadata1); assertThat(newStatus1.getCode()).isEqualTo(Code.ABORTED); assertThat(metadata1.get(TEST_KEY)).isEqualTo("B1Exception"); assertThat(metadata1.keys()).containsOnly(TEST_KEY.name()); final Metadata metadata2 = new Metadata(); metadata2.put(TEST_KEY2, "test"); - final Status newStatus2 = exceptionHandler.apply(ctx, cause, metadata2); + final Status newStatus2 = exceptionHandler.apply(ctx, Status.UNKNOWN, cause, metadata2); assertThat(newStatus2.getCode()).isEqualTo(Code.ABORTED); assertThat(metadata2.get(TEST_KEY)).isEqualTo("B1Exception"); diff --git a/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframerTest.java b/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframerTest.java index d3ad2170874..5d68db290d4 100644 --- a/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframerTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/HttpStreamDeframerTest.java @@ -58,8 +58,7 @@ void setUp() { final ServiceRequestContext ctx = ServiceRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); final TransportStatusListener statusListener = (status, metadata) -> statusRef.set(status); deframer = new HttpStreamDeframer(DecompressorRegistry.getDefaultInstance(), ctx, statusListener, - new UnwrappingGrpcExceptionHandleFunction( - GrpcExceptionHandlerFunction.of()), Integer.MAX_VALUE, + GrpcExceptionHandlerFunction.of(), Integer.MAX_VALUE, false, true); } diff --git a/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/TestServiceImpl.java b/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/TestServiceImpl.java index 889ad6f781f..f4cd411c314 100644 --- a/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/TestServiceImpl.java +++ b/grpc/src/test/java/com/linecorp/armeria/internal/common/grpc/TestServiceImpl.java @@ -367,7 +367,8 @@ private synchronized void dispatchChunk() { } catch (Throwable e) { failure = e; if (GrpcExceptionHandlerFunction.of() - .apply(ServiceRequestContext.current(), e, new Metadata()) + .apply(ServiceRequestContext.current(), Status.UNKNOWN, + e, new Metadata()) .getCode() == Status.CANCELLED.getCode()) { // Stream was cancelled by client, responseStream.onError() might be called already or // will be called soon by inbounding StreamObserver. diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/AsyncServerInterceptorTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/AsyncServerInterceptorTest.java index 67c499e9627..8e779eaf380 100644 --- a/grpc/src/test/java/com/linecorp/armeria/server/grpc/AsyncServerInterceptorTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/AsyncServerInterceptorTest.java @@ -55,7 +55,7 @@ class AsyncServerInterceptorTest { static ServerExtension server = new ServerExtension() { @Override protected void configure(ServerBuilder sb) { - final GrpcExceptionHandlerFunction exceptionHandler = (ctx, throwable, metadata) -> { + final GrpcExceptionHandlerFunction exceptionHandler = (ctx, status, throwable, metadata) -> { exceptionCounter.getAndIncrement(); if (throwable instanceof AnticipatedException && "Invalid access".equals(throwable.getMessage())) { diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerAnnotationOnlyTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerAnnotationOnlyTest.java index 558a34405af..fe6394647ec 100644 --- a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerAnnotationOnlyTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerAnnotationOnlyTest.java @@ -132,8 +132,9 @@ void exceptionHandler() { private static class FirstGrpcExceptionHandler implements GrpcExceptionHandlerFunction { + @Nullable @Override - public @Nullable Status apply(RequestContext ctx, Throwable cause, Metadata metadata) { + public Status apply(RequestContext ctx, Status status, Throwable cause, Metadata metadata) { exceptionHandler.add("first"); if (Objects.equals(cause.getMessage(), "first")) { return Status.UNAUTHENTICATED; @@ -144,8 +145,9 @@ private static class FirstGrpcExceptionHandler implements GrpcExceptionHandlerFu private static class SecondGrpcExceptionHandler implements GrpcExceptionHandlerFunction { + @Nullable @Override - public @Nullable Status apply(RequestContext ctx, Throwable cause, Metadata metadata) { + public Status apply(RequestContext ctx, Status status, Throwable cause, Metadata metadata) { exceptionHandler.add("second"); if (Objects.equals(cause.getMessage(), "second")) { return Status.INVALID_ARGUMENT; diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerFunctionUtilTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerFunctionUtilTest.java new file mode 100644 index 00000000000..5d4e594799e --- /dev/null +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerFunctionUtilTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.grpc; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.concurrent.CompletionException; + +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import com.linecorp.armeria.client.grpc.GrpcClients; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.stub.StreamObserver; +import testing.grpc.Messages.SimpleRequest; +import testing.grpc.Messages.SimpleRequest.NestedRequest; +import testing.grpc.Messages.SimpleResponse; +import testing.grpc.TestServiceGrpc.TestServiceBlockingStub; +import testing.grpc.TestServiceGrpc.TestServiceImplBase; + +class GrpcExceptionHandlerFunctionUtilTest { + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) throws Exception { + sb.service(GrpcService.builder() + .addService(new TestServiceImpl()) + .exceptionHandler((ctx, status, throwable, metadata) -> { + assertThat(throwable).isInstanceOf(RuntimeException.class); + return Status.INTERNAL; + }) + .build()); + } + }; + + @CsvSource({ "onError", "throw", "onErrorStatus", "throwStatus" }) + @ParameterizedTest + void classAndMethodHaveMultipleExceptionHandlers(String exceptionType) { + final TestServiceBlockingStub client = + GrpcClients.newClient(server.httpUri(), TestServiceBlockingStub.class); + + final SimpleRequest globalRequest = + SimpleRequest.newBuilder() + .setNestedRequest(NestedRequest.newBuilder().setNestedPayload(exceptionType) + .build()) + .build(); + assertThatThrownBy(() -> client.unaryCall(globalRequest)) + .isInstanceOfSatisfying(StatusRuntimeException.class, + e -> assertThat(e.getStatus()).isEqualTo(Status.INTERNAL)); + } + + private static class TestServiceImpl extends TestServiceImplBase { + + @Override + public void unaryCall(SimpleRequest request, StreamObserver responseObserver) { + final CompletionException exception = new CompletionException(new RuntimeException()); + switch (request.getNestedRequest().getNestedPayload()) { + case "onError": + responseObserver.onError(exception); + break; + case "throw": + throw exception; + case "onErrorStatus": + responseObserver.onError(Status.INTERNAL.withCause(exception).asRuntimeException()); + break; + case "throwStatus": + throw Status.INTERNAL.withCause(exception).asRuntimeException(); + default: + throw new IllegalArgumentException("unknown payload"); + } + } + } +} diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerTest.java index 71cd9ef3428..18f85dbf5ed 100644 --- a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcExceptionHandlerTest.java @@ -74,7 +74,7 @@ protected void configure(ServerBuilder sb) throws Exception { .addService("/foo", new FooTestServiceImpl()) .addService("/bar", new BarTestServiceImpl(), TestServiceGrpc.getUnaryCallMethod()) - .exceptionHandler((ctx, throwable, metadata) -> { + .exceptionHandler((ctx, status, throwable, metadata) -> { exceptionHandler.add("global"); return Status.INTERNAL; }) @@ -88,7 +88,7 @@ protected void configure(ServerBuilder sb) throws Exception { protected void configure(ServerBuilder sb) throws Exception { sb.requestTimeoutMillis(5000) .service(GrpcService.builder() - .addService(new TestServiceIOException()) + .addService(new ErrorTestServiceImpl()) .build()); } }; @@ -459,12 +459,18 @@ void defaultGrpcExceptionHandlerConvertIOExceptionToUnavailable() { .isInstanceOfSatisfying(StatusRuntimeException.class, e -> { assertThat(e.getStatus().getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); }); + assertThatThrownBy(() -> client.unaryCall2(globalRequest)) + .isInstanceOfSatisfying(StatusRuntimeException.class, e -> { + assertThat(e.getStatus().getCode()).isEqualTo(Status.INVALID_ARGUMENT.getCode()); + assertThat(e.getStatus().getCause().getMessage()).contains("IllegalArgumentException"); + }); } private static class FirstGrpcExceptionHandler implements GrpcExceptionHandlerFunction { + @Nullable @Override - public @Nullable Status apply(RequestContext ctx, Throwable cause, Metadata metadata) { + public Status apply(RequestContext ctx, Status status, Throwable cause, Metadata metadata) { exceptionHandler.add("first"); if (Objects.equals(cause.getMessage(), "first")) { return Status.UNAUTHENTICATED; @@ -475,8 +481,9 @@ private static class FirstGrpcExceptionHandler implements GrpcExceptionHandlerFu private static class SecondGrpcExceptionHandler implements GrpcExceptionHandlerFunction { + @Nullable @Override - public @Nullable Status apply(RequestContext ctx, Throwable cause, Metadata metadata) { + public Status apply(RequestContext ctx, Status status, Throwable cause, Metadata metadata) { exceptionHandler.add("second"); if (Objects.equals(cause.getMessage(), "second")) { return Status.INVALID_ARGUMENT; @@ -487,8 +494,9 @@ private static class SecondGrpcExceptionHandler implements GrpcExceptionHandlerF private static class ThirdGrpcExceptionHandler implements GrpcExceptionHandlerFunction { + @Nullable @Override - public @Nullable Status apply(RequestContext ctx, Throwable cause, Metadata metadata) { + public Status apply(RequestContext ctx, Status status, Throwable cause, Metadata metadata) { exceptionHandler.add("third"); if (Objects.equals(cause.getMessage(), "third")) { return Status.NOT_FOUND; @@ -499,8 +507,9 @@ private static class ThirdGrpcExceptionHandler implements GrpcExceptionHandlerFu private static class ForthGrpcExceptionHandler implements GrpcExceptionHandlerFunction { + @Nullable @Override - public @Nullable Status apply(RequestContext ctx, Throwable cause, Metadata metadata) { + public Status apply(RequestContext ctx, Status status, Throwable cause, Metadata metadata) { exceptionHandler.add("forth"); if (Objects.equals(cause.getMessage(), "forth")) { return Status.UNAVAILABLE; @@ -594,11 +603,16 @@ public void unaryCall(SimpleRequest request, StreamObserver resp } } - // TestServiceIOException has DefaultGRPCExceptionHandlerFunction as fallback exception handler - private static class TestServiceIOException extends TestServiceImpl { + private static class ErrorTestServiceImpl extends TestServiceImpl { @Override public void unaryCall(SimpleRequest request, StreamObserver responseObserver) { responseObserver.onError(new IOException()); } + + @Override + public void unaryCall2(SimpleRequest request, StreamObserver responseObserver) { + responseObserver.onError(new StatusRuntimeException(Status.INVALID_ARGUMENT.withCause( + new IllegalArgumentException()))); + } } } diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilderTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilderTest.java index 18e10472f0e..fa28f2494ef 100644 --- a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilderTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilderTest.java @@ -132,13 +132,13 @@ void mixExceptionMappingAndGrpcExceptionHandlerFunctions() { assertThatThrownBy(() -> GrpcService.builder() .addExceptionMapping(A1Exception.class, Status.RESOURCE_EXHAUSTED) .exceptionHandler( - (ctx, cause, metadata) -> Status.PERMISSION_DENIED)) + (ctx, status, cause, metadata) -> Status.PERMISSION_DENIED)) .isInstanceOf(IllegalStateException.class) .hasMessageContaining("addExceptionMapping() and exceptionHandler() are mutually exclusive."); assertThatThrownBy(() -> GrpcService.builder() .exceptionHandler( - (ctx, cause, metadata) -> Status.PERMISSION_DENIED) + (ctx, status, cause, metadata) -> Status.PERMISSION_DENIED) .addExceptionMapping(A1Exception.class, Status.RESOURCE_EXHAUSTED)) .isInstanceOf(IllegalStateException.class) .hasMessageContaining("addExceptionMapping() and exceptionHandler() are mutually exclusive."); diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcStatusMappingTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcStatusMappingTest.java index d10ebdb9f71..486183a4f6a 100644 --- a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcStatusMappingTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcStatusMappingTest.java @@ -102,7 +102,7 @@ protected void configure(ServerBuilder sb) throws Exception { sb.service( GrpcService.builder() .addService(new TestServiceImpl()) - .exceptionHandler((ctx, cause, metadata) -> { + .exceptionHandler((ctx, status, cause, metadata) -> { final String attr = ctx.attr(METHOD_ATTR); if (attr != null) { metadata.put(METHOD_KEY, attr); diff --git a/it/grpc/kotlin/src/test/kotlin/com/linecorp/armeria/grpc/kotlin/TestServiceTest.kt b/it/grpc/kotlin/src/test/kotlin/com/linecorp/armeria/grpc/kotlin/TestServiceTest.kt index c93c2f553fd..942b9e6b115 100644 --- a/it/grpc/kotlin/src/test/kotlin/com/linecorp/armeria/grpc/kotlin/TestServiceTest.kt +++ b/it/grpc/kotlin/src/test/kotlin/com/linecorp/armeria/grpc/kotlin/TestServiceTest.kt @@ -206,7 +206,7 @@ class TestServiceTest { .service( GrpcService.builder() .addService(TestServiceImpl()) - .exceptionHandler { _, throwable, _ -> + .exceptionHandler { _, _, throwable, _ -> when (throwable) { is AuthError -> { Status.UNAUTHENTICATED diff --git a/it/grpc/reactor/src/test/java/com/linecorp/armeria/grpc/reactor/TestServiceTest.java b/it/grpc/reactor/src/test/java/com/linecorp/armeria/grpc/reactor/TestServiceTest.java index e4b4df146ce..c6b5af5d00c 100644 --- a/it/grpc/reactor/src/test/java/com/linecorp/armeria/grpc/reactor/TestServiceTest.java +++ b/it/grpc/reactor/src/test/java/com/linecorp/armeria/grpc/reactor/TestServiceTest.java @@ -60,7 +60,7 @@ private static Server newServer(int httpPort) { final HttpServiceWithRoutes grpcService = GrpcService.builder() .addService(new TestServiceImpl()) - .exceptionHandler((ctx, throwable, metadata) -> { + .exceptionHandler((ctx, status, throwable, metadata) -> { if (throwable instanceof TestServiceImpl.AuthException) { return Status.UNAUTHENTICATED.withDescription(throwable.getMessage()) .withCause(throwable); diff --git a/it/grpc/scala/src/test/scala/com/linecorp/armeria/grpc/scala/TestServiceTest.scala b/it/grpc/scala/src/test/scala/com/linecorp/armeria/grpc/scala/TestServiceTest.scala index 5f6e6481d3b..dc28b216d6d 100644 --- a/it/grpc/scala/src/test/scala/com/linecorp/armeria/grpc/scala/TestServiceTest.scala +++ b/it/grpc/scala/src/test/scala/com/linecorp/armeria/grpc/scala/TestServiceTest.scala @@ -88,7 +88,7 @@ object TestServiceTest { .builder() .addService(TestServiceGrpc.bindService(new TestServiceImpl, ExecutionContext.global)) .exceptionHandler { - case (_, e: AuthError, _) => + case (_, _, e: AuthError, _) => Status.UNAUTHENTICATED.withDescription(e.getMessage).withCause(e) case _ => null } diff --git a/junit4/src/main/java/com/linecorp/armeria/testing/junit4/server/ServerRule.java b/junit4/src/main/java/com/linecorp/armeria/testing/junit4/server/ServerRule.java index 778186f9962..3eb826ba27c 100644 --- a/junit4/src/main/java/com/linecorp/armeria/testing/junit4/server/ServerRule.java +++ b/junit4/src/main/java/com/linecorp/armeria/testing/junit4/server/ServerRule.java @@ -31,6 +31,8 @@ import com.linecorp.armeria.client.RestClient; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.WebClientBuilder; +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketClientBuilder; import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.UnstableApi; @@ -68,6 +70,12 @@ public void configure(ServerBuilder sb) throws Exception { public void configureWebClient(WebClientBuilder wcb) throws Exception { ServerRule.this.configureWebClient(wcb); } + + @Override + public void configureWebSocketClient(WebSocketClientBuilder wscb) + throws Exception { + ServerRule.this.configureWebSocketClient(wscb); + } }; } @@ -109,6 +117,12 @@ public Server start() { */ protected void configureWebClient(WebClientBuilder webClientBuilder) throws Exception {} + /** + * Configures the {@link WebSocketClient} with the given {@link WebSocketClientBuilder}. + * You can get the configured {@link WebSocketClient} using {@link #webSocketClient()}. + */ + protected void configureWebSocketClient(WebSocketClientBuilder webSocketClientBuilder) throws Exception {} + /** * Stops the {@link Server} asynchronously. * @@ -344,4 +358,13 @@ public RestClient restClient(Consumer webClientCustomizer) { requireNonNull(webClientCustomizer, "webClientCustomizer"); return delegate.restClient(webClientCustomizer); } + + /** + * Returns the {@link WebSocketClient} configured + * by {@link #configureWebSocketClient(WebSocketClientBuilder)}. + */ + @UnstableApi + public WebSocketClient webSocketClient() { + return delegate.webSocketClient(); + } } diff --git a/junit5/src/main/java/com/linecorp/armeria/internal/testing/ServerRuleDelegate.java b/junit5/src/main/java/com/linecorp/armeria/internal/testing/ServerRuleDelegate.java index 0ea684d795e..3b6e661d309 100644 --- a/junit5/src/main/java/com/linecorp/armeria/internal/testing/ServerRuleDelegate.java +++ b/junit5/src/main/java/com/linecorp/armeria/internal/testing/ServerRuleDelegate.java @@ -32,6 +32,8 @@ import com.linecorp.armeria.client.RestClient; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.WebClientBuilder; +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketClientBuilder; import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; @@ -49,6 +51,7 @@ public abstract class ServerRuleDelegate { private final boolean autoStart; private final AtomicReference webClient = new AtomicReference<>(); + private final AtomicReference webSocketClient = new AtomicReference<>(); /** * Creates a new instance. @@ -114,6 +117,14 @@ public Server start() { */ public abstract void configureWebClient(WebClientBuilder webClientBuilder) throws Exception; + /** + * Configures the {@link WebSocketClient} with the given {@link WebSocketClientBuilder}. + * You can get the configured {@link WebSocketClient} using {@link #webSocketClient()}. + */ + @UnstableApi + public abstract void configureWebSocketClient(WebSocketClientBuilder webSocketClientBuilder) + throws Exception; + /** * Stops the {@link Server} asynchronously. * @@ -404,6 +415,25 @@ public RestClient restClient(Consumer webClientCustomizer) { return webClient(webClientCustomizer).asRestClient(); } + /** + * Returns the {@link WebSocketClient} configured + * by {@link #configureWebSocketClient(WebSocketClientBuilder)}. + */ + @UnstableApi + public WebSocketClient webSocketClient() { + final WebSocketClient webSocketClient = this.webSocketClient.get(); + if (webSocketClient != null) { + return webSocketClient; + } + + final WebSocketClient newWebSocketClient = webSocketClientBuilder().build(); + if (this.webSocketClient.compareAndSet(null, newWebSocketClient)) { + return newWebSocketClient; + } else { + return this.webSocketClient.get(); + } + } + private void ensureStarted() { // This will ensure that the server has started. server(); @@ -422,4 +452,21 @@ private WebClientBuilder webClientBuilder() { } return webClientBuilder; } + + private WebSocketClientBuilder webSocketClientBuilder() { + final boolean hasHttps = hasHttps(); + final String hostAndPort = hasHttps ? "wss://" + httpsUri().getAuthority() + : "ws://" + httpUri().getAuthority(); + final WebSocketClientBuilder webSocketClientBuilder = WebSocketClient.builder(hostAndPort); + if (hasHttps) { + webSocketClientBuilder.factory(ClientFactory.insecure()); + } + + try { + configureWebSocketClient(webSocketClientBuilder); + } catch (Exception e) { + throw new IllegalStateException("failed to configure a WebSocketClient", e); + } + return webSocketClientBuilder; + } } diff --git a/junit5/src/main/java/com/linecorp/armeria/testing/junit5/server/ServerExtension.java b/junit5/src/main/java/com/linecorp/armeria/testing/junit5/server/ServerExtension.java index 8ba44a31911..1c44cdf221a 100644 --- a/junit5/src/main/java/com/linecorp/armeria/testing/junit5/server/ServerExtension.java +++ b/junit5/src/main/java/com/linecorp/armeria/testing/junit5/server/ServerExtension.java @@ -31,6 +31,8 @@ import com.linecorp.armeria.client.RestClient; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.WebClientBuilder; +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketClientBuilder; import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.UnstableApi; @@ -75,6 +77,12 @@ public void configure(ServerBuilder sb) throws Exception { public void configureWebClient(WebClientBuilder wcb) throws Exception { ServerExtension.this.configureWebClient(wcb); } + + @Override + public void configureWebSocketClient(WebSocketClientBuilder wscb) + throws Exception { + ServerExtension.this.configureWebSocketClient(wscb); + } }; } @@ -126,6 +134,12 @@ public Server start() { */ protected void configureWebClient(WebClientBuilder webClientBuilder) throws Exception {} + /** + * Configures the {@link WebSocketClient} with the given {@link WebSocketClientBuilder}. + * You can get the configured {@link WebSocketClient} using {@link #webSocketClient()}. + */ + protected void configureWebSocketClient(WebSocketClientBuilder webSocketClientBuilder) throws Exception {} + /** * Stops the {@link Server} asynchronously. * @@ -370,6 +384,15 @@ public RestClient restClient(Consumer webClientCustomizer) { return delegate.restClient(webClientCustomizer); } + /** + * Returns the {@link WebSocketClient} configured + * by {@link #configureWebSocketClient(WebSocketClientBuilder)}. + */ + @UnstableApi + public WebSocketClient webSocketClient() { + return delegate.webSocketClient(); + } + /** * Determines whether the {@link ServiceRequestContext} should be captured or not. * This method returns {@code true} by default. Override it to capture the contexts diff --git a/junit5/src/test/java/com/linecorp/armeria/testing/junit5/server/ServerExtensionWithWebSocketClientTest.java b/junit5/src/test/java/com/linecorp/armeria/testing/junit5/server/ServerExtensionWithWebSocketClientTest.java new file mode 100644 index 00000000000..2ee442877d6 --- /dev/null +++ b/junit5/src/test/java/com/linecorp/armeria/testing/junit5/server/ServerExtensionWithWebSocketClientTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.testing.junit5.server; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketSession; +import com.linecorp.armeria.common.websocket.WebSocket; +import com.linecorp.armeria.common.websocket.WebSocketFrame; +import com.linecorp.armeria.common.websocket.WebSocketWriter; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.websocket.WebSocketService; +import com.linecorp.armeria.server.websocket.WebSocketServiceHandler; + +class ServerExtensionWithWebSocketClientTest { + + @RegisterExtension + static ServerExtension wsServer = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + sb.service("/chat", WebSocketService.builder(new WebSocketEchoHandler()) + .build()); + } + }; + + @RegisterExtension + static ServerExtension wssServer = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + sb.tlsSelfSigned(); + sb.service("/chat", WebSocketService.builder(new WebSocketEchoHandler()) + .build()); + } + }; + + @CsvSource({ "true", "false" }) + @ParameterizedTest + void webSocketClient(boolean useTls) { + final WebSocketClient webSocketClient = useTls ? wssServer.webSocketClient() + : wsServer.webSocketClient(); + final WebSocketSession wsSession = webSocketClient.connect("/chat").join(); + assertThat(wsSession).isNotNull(); + final WebSocketWriter outbound = wsSession.outbound(); + outbound.write("hello"); + final String message = useTls ? "wss" : "ws"; + outbound.write(message); + outbound.close(); + final List responses = wsSession.inbound().collect().join().stream().map(WebSocketFrame::text) + .collect(toImmutableList()); + assertThat(responses).contains("hello", message); + } + + static final class WebSocketEchoHandler implements WebSocketServiceHandler { + + @Override + public WebSocket handle(ServiceRequestContext ctx, WebSocket in) { + final WebSocketWriter writer = WebSocket.streaming(); + in.subscribe(new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(WebSocketFrame webSocketFrame) { + writer.write(webSocketFrame); + } + + @Override + public void onError(Throwable t) { + writer.close(t); + } + + @Override + public void onComplete() { + writer.close(); + } + }); + return writer; + } + } +} diff --git a/native-image-config/build.gradle.kts b/native-image-config/build.gradle.kts index 4e477ba5609..5a9c31188b8 100644 --- a/native-image-config/build.gradle.kts +++ b/native-image-config/build.gradle.kts @@ -12,9 +12,9 @@ buildscript { dependencies { // TODO(trustin): Use platform(libs.boms.jackson) once Gradle supports platform() for build dependencies. // https://github.com/gradle/gradle/issues/21788 - val jacksonDatabind = libs.jackson.databind.get() + val jacksonDatabind: ModuleVersionSelector = libs.jackson.databind.get() classpath( - group = jacksonDatabind.group!!, + group = jacksonDatabind.group, name = jacksonDatabind.name, version = libs.boms.jackson.get().version ) @@ -29,12 +29,10 @@ plugins { // If `-Pscratch` is specified, do not source the previously generated config at core/src/main/resources/META-INF/native-image // otherwise, the previously generated config will be merged into the newly generated config. val shouldGenerateFromScratch = project.findProperty("scratch").let { - if (it == null) { - false - } else if (it == "") { - true - } else { - throw IllegalArgumentException("-Pscratch option must be specified without any value.") + when (it) { + null -> false + "" -> true + else -> throw IllegalArgumentException("-Pscratch option must be specified without any value.") } } @@ -50,6 +48,7 @@ val nativeImageConfigToolPath = "${graalHome.resolve("lib/svm/bin/native-image-c val thisProject = project val callerFilterFile = projectDir.resolve("src/trace-filters/caller-filter.json") +val buildDir = layout.buildDirectory.asFile.get() val processNativeImageTracesOutputDir = buildDir.resolve("step-1-process-native-image-traces") val simplifyNativeImageConfigOutputDir = buildDir.resolve("step-2-simplify-native-image-config") val nativeImageConfigOutputDir = buildDir.resolve("step-3-final-native-image-config") diff --git a/oauth2/src/main/java/com/linecorp/armeria/internal/common/auth/oauth2/OAuth2Endpoint.java b/oauth2/src/main/java/com/linecorp/armeria/internal/common/auth/oauth2/OAuth2Endpoint.java index 4873c6ac063..73b96317d67 100644 --- a/oauth2/src/main/java/com/linecorp/armeria/internal/common/auth/oauth2/OAuth2Endpoint.java +++ b/oauth2/src/main/java/com/linecorp/armeria/internal/common/auth/oauth2/OAuth2Endpoint.java @@ -18,7 +18,9 @@ import java.util.concurrent.CompletableFuture; +import com.linecorp.armeria.client.RequestOptions; import com.linecorp.armeria.client.WebClient; +import com.linecorp.armeria.common.ExchangeType; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.QueryParams; import com.linecorp.armeria.common.auth.oauth2.OAuth2Request; @@ -29,6 +31,11 @@ */ public final class OAuth2Endpoint { + private static final RequestOptions UNARY_REQUEST_OPTIONS = + RequestOptions.builder() + .exchangeType(ExchangeType.UNARY) + .build(); + private final WebClient endpoint; private final String endpointPath; private final OAuth2ResponseHandler responseHandler; @@ -43,7 +50,7 @@ public OAuth2Endpoint(WebClient endpoint, String endpointPath, public CompletableFuture execute(OAuth2Request oAuth2Request) { final HttpRequest request = oAuth2Request.asHttpRequest(endpointPath); final QueryParams requestParams = oAuth2Request.bodyParams(); - return endpoint.execute(request) + return endpoint.execute(request, UNARY_REQUEST_OPTIONS) .aggregate() .thenApply(response -> responseHandler.handle(response, requestParams)); } diff --git a/oauth2/src/test/java/com/linecorp/armeria/common/auth/oauth2/GrantedOAuth2AccessTokenTest.java b/oauth2/src/test/java/com/linecorp/armeria/common/auth/oauth2/GrantedOAuth2AccessTokenTest.java index 5b378bdb84d..91627b48238 100644 --- a/oauth2/src/test/java/com/linecorp/armeria/common/auth/oauth2/GrantedOAuth2AccessTokenTest.java +++ b/oauth2/src/test/java/com/linecorp/armeria/common/auth/oauth2/GrantedOAuth2AccessTokenTest.java @@ -173,7 +173,6 @@ void testToString() throws Exception { .scope(scope) .build(); - System.out.println(token); assertThat(token.toString()).isEqualTo(toString); } diff --git a/oauth2/src/test/java/com/linecorp/armeria/common/auth/oauth2/OAuth2TokenDescriptorTest.java b/oauth2/src/test/java/com/linecorp/armeria/common/auth/oauth2/OAuth2TokenDescriptorTest.java index 9152f2f02b6..0d0f35bbed1 100644 --- a/oauth2/src/test/java/com/linecorp/armeria/common/auth/oauth2/OAuth2TokenDescriptorTest.java +++ b/oauth2/src/test/java/com/linecorp/armeria/common/auth/oauth2/OAuth2TokenDescriptorTest.java @@ -197,8 +197,6 @@ void testToString() throws Exception { .extras(extras) .scope(scope) .build(); - - System.out.println(descriptor); assertThat(descriptor.toString()).isEqualTo(rawResponse); } diff --git a/oauth2/src/test/java/com/linecorp/armeria/internal/common/auth/oauth2/TokenRevocationRequestTest.java b/oauth2/src/test/java/com/linecorp/armeria/internal/common/auth/oauth2/TokenRevocationRequestTest.java index 6feb16288d8..3706e49158b 100644 --- a/oauth2/src/test/java/com/linecorp/armeria/internal/common/auth/oauth2/TokenRevocationRequestTest.java +++ b/oauth2/src/test/java/com/linecorp/armeria/internal/common/auth/oauth2/TokenRevocationRequestTest.java @@ -69,7 +69,6 @@ public void testRevoke() throws Exception { requestHeaders1, "token=" + token.grantedToken().accessToken() + "&token_type_hint=access_token").aggregate().join(); assertThat(response1.status()).isEqualTo(HttpStatus.OK); - System.out.println(response1.contentUtf8()); assertThat(response1.contentUtf8()).isEqualTo(HttpStatus.OK.toString()); final RequestHeaders requestHeaders2 = RequestHeaders.of( diff --git a/resteasy/src/test/java/com/linecorp/armeria/server/resteasy/BookServiceClientServerTest.java b/resteasy/src/test/java/com/linecorp/armeria/server/resteasy/BookServiceClientServerTest.java index eaf5c23aaf3..993ccb77637 100644 --- a/resteasy/src/test/java/com/linecorp/armeria/server/resteasy/BookServiceClientServerTest.java +++ b/resteasy/src/test/java/com/linecorp/armeria/server/resteasy/BookServiceClientServerTest.java @@ -112,7 +112,6 @@ void testBooks() throws Exception { assertThatThrownBy(getBooks::hasEntity) .isInstanceOf(IllegalStateException.class) .hasMessage("RESTEASY003765: Response is closed."); - System.out.println(getBooksEntry); assertThat(getBooksEntry).contains("John Doe"); assertThat(getBooksEntry).contains("Java"); final Book[] getBooksEntryArray = JSON.readValue(getBooksEntry, Book[].class); @@ -134,7 +133,6 @@ void testBooks() throws Exception { .isInstanceOf(IllegalStateException.class) .hasMessage("RESTEASY003765: Response is closed."); final String getBooksEntry2 = JSON.writeValueAsString(getBooksEntryArray2); - System.out.println(getBooksEntry2); assertThat(getBooksEntry2).isEqualTo(getBooksEntry); final Response getAllBooks = webTarget.path(booksPath) @@ -150,7 +148,6 @@ void testBooks() throws Exception { assertThat(getAllBooksEntryArray[0]).isInstanceOf(Book.class); assertThat(getAllBooksEntryArray).contains(getBooksEntryArray2); final String getAllBooksEntry = JSON.writeValueAsString(getBooksEntryArray2); - System.out.println(getAllBooksEntry); final String getBookPath = "/resteasy/app/books/978-3-16-148410-0"; final Response getBook = webTarget.path(getBookPath) diff --git a/sangria/sangria_2.13/src/test/scala/com/linecorp/armeria/server/sangria/ServerSuite.scala b/sangria/sangria_2.13/src/test/scala/com/linecorp/armeria/server/sangria/ServerSuite.scala index fd5139128df..87f192d696f 100644 --- a/sangria/sangria_2.13/src/test/scala/com/linecorp/armeria/server/sangria/ServerSuite.scala +++ b/sangria/sangria_2.13/src/test/scala/com/linecorp/armeria/server/sangria/ServerSuite.scala @@ -16,8 +16,8 @@ package com.linecorp.armeria.server.sangria -import com.linecorp.armeria.client.{WebClient, WebClientBuilder} -import com.linecorp.armeria.client.logging.LoggingClient +import com.linecorp.armeria.client.WebClientBuilder +import com.linecorp.armeria.client.websocket.WebSocketClientBuilder import com.linecorp.armeria.internal.testing.ServerRuleDelegate import com.linecorp.armeria.server.ServerBuilder import munit.Suite @@ -33,6 +33,8 @@ trait ServerSuite { protected def configureWebClient: WebClientBuilder => Unit = _ => () + protected def configureWebSocketClient: WebSocketClientBuilder => Unit = _ => () + protected def server: ServerRuleDelegate = delegate /** @@ -46,6 +48,9 @@ trait ServerSuite { override def configure(sb: ServerBuilder): Unit = configureServer(sb) override def configureWebClient(wcb: WebClientBuilder): Unit = self.configureWebClient(wcb) + + override def configureWebSocketClient(webSocketClientBuilder: WebSocketClientBuilder): Unit = + self.configureWebSocketClient(webSocketClientBuilder) } if (!runServerForEachTest) { diff --git a/scala/scala_2.13/src/test/scala/com/linecorp/armeria/server/ServerSuite.scala b/scala/scala_2.13/src/test/scala/com/linecorp/armeria/server/ServerSuite.scala index 5f95891f4e2..ca78316f58d 100644 --- a/scala/scala_2.13/src/test/scala/com/linecorp/armeria/server/ServerSuite.scala +++ b/scala/scala_2.13/src/test/scala/com/linecorp/armeria/server/ServerSuite.scala @@ -17,6 +17,7 @@ package com.linecorp.armeria.server import com.linecorp.armeria.client.WebClientBuilder +import com.linecorp.armeria.client.websocket.WebSocketClientBuilder import com.linecorp.armeria.internal.testing.ServerRuleDelegate import munit.Suite @@ -29,6 +30,8 @@ trait ServerSuite { protected def configureWebClient: WebClientBuilder => Unit = _ => () + protected def configureWebSocketClient: WebSocketClientBuilder => Unit = _ => () + protected def server: ServerRuleDelegate = delegate /** @@ -42,6 +45,9 @@ trait ServerSuite { override def configure(sb: ServerBuilder): Unit = configureServer(sb) override def configureWebClient(wcb: WebClientBuilder): Unit = self.configureWebClient(wcb) + + override def configureWebSocketClient(wscb: WebSocketClientBuilder): Unit = + self.configureWebSocketClient(wscb) } if (!runServerForEachTest) { diff --git a/scalapb/scalapb_2.13/src/test/scala/com/linecorp/armeria/server/scalapb/ServerSuite.scala b/scalapb/scalapb_2.13/src/test/scala/com/linecorp/armeria/server/scalapb/ServerSuite.scala index 509e0a4c0e7..d6965bcc07a 100644 --- a/scalapb/scalapb_2.13/src/test/scala/com/linecorp/armeria/server/scalapb/ServerSuite.scala +++ b/scalapb/scalapb_2.13/src/test/scala/com/linecorp/armeria/server/scalapb/ServerSuite.scala @@ -17,6 +17,7 @@ package com.linecorp.armeria.server.scalapb import com.linecorp.armeria.client.WebClientBuilder +import com.linecorp.armeria.client.websocket.WebSocketClientBuilder import com.linecorp.armeria.internal.testing.ServerRuleDelegate import com.linecorp.armeria.server.ServerBuilder import munit.Suite @@ -30,6 +31,8 @@ trait ServerSuite { protected def configureWebClient: WebClientBuilder => Unit = _ => () + protected def configureWebSocketClient: WebSocketClientBuilder => Unit = _ => () + protected def server: ServerRuleDelegate = delegate /** @@ -43,6 +46,9 @@ trait ServerSuite { override def configure(sb: ServerBuilder): Unit = configureServer(sb) override def configureWebClient(wcb: WebClientBuilder): Unit = self.configureWebClient(wcb) + + override def configureWebSocketClient(webSocketClientBuilder: WebSocketClientBuilder): Unit = + self.configureWebSocketClient(webSocketClientBuilder) } if (!runServerForEachTest) { diff --git a/settings.gradle b/settings.gradle index 749297933e3..72a889f1cfb 100644 --- a/settings.gradle +++ b/settings.gradle @@ -6,9 +6,9 @@ plugins { // automatically download one based on the foojay Disco API. // https://docs.gradle.org/8.1.1/userguide/toolchains.html#sec:provisioning id 'org.gradle.toolchains.foojay-resolver-convention' version '0.6.0' - id 'com.gradle.develocity' version '3.17.4' + id 'com.gradle.develocity' version '3.17.5' // adds additional metadata to build scans - id 'com.gradle.common-custom-user-data-gradle-plugin' version '2.0.1' + id 'com.gradle.common-custom-user-data-gradle-plugin' version '2.0.2' } import com.gradle.develocity.agent.gradle.scan.PublishedBuildScan @@ -71,10 +71,22 @@ rootProject.name = 'armeria' apply from: "${rootDir}/gradle/scripts/settings-flags.gradle" +def virtualProjectsPath = "${rootDir}/build/virtual-projects" + // Published BOM projects includeWithFlags ':bom', 'bom' +project(':bom').with { + projectDir = file("${virtualProjectsPath}/bom") + projectDir.mkdirs() + buildFileName = 'gradle/scripts/lib/bom.gradle' +} // Published version catalog projects includeWithFlags ':version-catalog', 'version-catalog' +project(':version-catalog').with { + projectDir = file("${virtualProjectsPath}/version-catalog") + projectDir.mkdirs() + buildFileName = 'gradle/scripts/version-catalog.gradle' +} // Published Java projects includeWithFlags ':annotation-processor', 'java', 'publish', 'relocate' @@ -253,6 +265,7 @@ includeWithFlags ':examples:graphql-kotlin-example', 'java17', 'ko project(':examples:graphql-kotlin-example').projectDir = file('examples/graphql-kotlin') includeWithFlags ':examples:graphql-sangria-example', 'java11', 'scala_2.13' project(':examples:graphql-sangria-example').projectDir = file('examples/graphql-sangria') +includeWithFlags ':examples:grpc-envoy', 'java11' includeWithFlags ':examples:grpc-example', 'java11' project(':examples:grpc-example').projectDir = file('examples/grpc') includeWithFlags ':examples:grpc-kotlin', 'java11', 'kotlin-grpc', 'kotlin' diff --git a/site/src/pages/community/articles.mdx b/site/src/pages/community/articles.mdx index 7fc7412d491..5fcfd73c752 100644 --- a/site/src/pages/community/articles.mdx +++ b/site/src/pages/community/articles.mdx @@ -13,6 +13,19 @@ Send a pull request by editing [this page](https://github.com/line/armeria/edit/ ### Slides and videos +

    +
    +
    + +
    + +
    + +
    + +
    + +
    + +
    + +
    + +
    +
    . #5652 #5653 + ```java + Server + .builder() + .multipartRemovalStrategy( + MultipartRemovalStrategy.ON_RESPONSE_COMPLETION) + ... + ``` + - The default value is now which removes + the temporary files when the response is completed. +- A now includes the time spent on the TLS handshake. #3647 #5647 + ```java + ClientConnectionTimings timings = ... + assert timings.tlsHandshakeDurationNanos() > 0; + ``` +- You can now configure using or . #4962 + ```java + ClientFactory + .builder() + .tlsEngineType(TlsEngineType.OPENSSL) // 👈👈👈 + .build(); + ``` +- You can now easily find both dynamic and static decorators that handle a request + via . #5670 + ```java + ServerBuilder sb = ... + sb.decorator(CorsService.builderForAnyOrigin().newDecorator()); + ... + ServiceRequestContext ctx = ... + CorsService corsService = ctx.findService(CorsService.class); + assert corsService.config().isAnyOriginSupported(); + ``` +- You can now use the marshaller specified by gRPC `MethodDescriptor` by setting + and + . #5103 #5630 + ```java + GrpcClientBuilder builder = ... + builder.useMethodMarshaller(true); // 👈👈👈 + + GrpcServiceBuilder builder = ... + builder.useMethodMarshaller(true); // 👈👈👈 + ``` +- You can now programmatically retrieve the server side metrics via . #4992 #5627 + ```java + ServerMetrics metrics = serverConfig.serverMetrics(); + metrics.activeConnections(); + metrics.pendingRequests(); + metrics.activeRequests(); + ``` +- You can now use a that runs your service in a coroutine scope. #5442 #5603 + ```kotlin + ServerBuilder sb = ... + sb.service( + "/hello", + CoroutineHttpService { ctx, req -> + HttpResponse.of("hello world") + }) + ``` +- You can now specify options for an via . #5071 #5574 + ```java + static final ServiceOptions SERVICE_OPTIONS = + ServiceOptions + .builder() + .requestTimeoutMillis(5000) + .maxRequestLength(1024) + .requestAutoAbortDelayMillis(1000) + .build(); + + HttpService httpService = new HttpService() { + ... + @Override + public ServiceOptions options() { + return SERVICE_OPTIONS; + } + }; + + // Or use annotation for an annotated service. + class MyService { + @ServiceOption(requestTimeoutMillis = 5000, maxRequestLength = 1024) + @Get("/hello") + public HttpResponse hello() { + return HttpResponse.of("Hello!"); + } + } + ``` +- You can now inject a custom attribute from a to an annotated service + using annotation. #5514 #5547 + ```java + class MyAttributes { + public static final AttributeKey USERNAME = + AttributeKey.valueOf(MyAttributes.class, "USERNAME"); + } + + class MyAnnotatedService { + + @Get("/hello") + public HttpResponse hello( + @Attribute(prefix = MyAttributes.class, value = "USERNAME") // 👈👈👈 + String username) { ... } + } + ``` +- You can now specify a graceful shutdown timeout for an HTTP/2 connection in . #5470 #5489 + ```java + ClientFactory + .builder() + .http2GracefulShutdownTimeoutMillis(1000) + .build(); + ``` +- You can now specify an when building a . #5292 #5298 + ```java + ClientRequestContext ctx = + ClientRequestContext + .builder(request) + .endpointGroup(endpointGroup) + .build(); + ``` +- You can now set a `GraphQL` instance directly to the . #5269 + ```java + GraphQL graphQL = new GraphQL.Builder(graphqlSchema).build(); + GraphqlServiceBuilder builder = ... + builder.graphql(graphQL); +- You can now specify a delay to close an HTTP/1.1 connection from the server side, allowing the client an + opportunity for active closing. #4849 #5616 +- You can now configure the maximum length of a TLS client hello message that a server allows by using the + `-Dcom.linecorp.armeria.defaultMaxClientHelloLength=` JVM option. #5747 + +## 📈 Improvements + +- now supports [Slow start mode](https://www.envoyproxy.io/docs/envoy/latest/intro/arch_overview/upstream/load_balancing/slow_start). #5688 #5693 +- Priority and locality based load balancing is now available for an . #5610 +- You can now view connection related information in the request logs. #5730 +- CORS headers are added to responses even if an exception is raised while handling a request. #5632 +- is now available through a public API. #5382 #5628 +- The metrics for failed requests while retrying now include the cause of the failure. #5405 #5583 +- The from a retrying client now includes the current attempt number. #5716 #5719 +- A colon is now allowed in the path for binding a service. #4577 #5676 +- Timer histogram metrics now repects user-provided `serviceLevelObjectives`. #5661 +- + can now access the current if available. #5634 +- + can now access the for error handling if available. #5401 #5622 +- The copied cURL command from a doc service, is now quoted correctly. #5566 + +## 🛠️ Bug fixes + +- Fixed a bug where gets to deadlock when fails to obtain + an access token. #5715 +- Armeria client always uses HTTP/2 connection preface for `h2c`, regardless of + the value of . #5706 +- now properly overrides the properties set by . #5009 #5692 +- Armeria server now returns a 408 status if a service didn't receive the request fully and + the request times out. #5579 #5680 +- A `NullPointerException` isn't raised anymore when running Armeria with a self-signed certificate. #5669 + It happened when the following conditions are all met: + - DEBUG-level logging of JDK security event is enabled. + - A user has no Bouncy Castle dependency. +- annotations at the super class or interface are now used for the parameter description. #5195 #5562 +- The specfied `ObjectMapper` to isn't + ignored anymore. #5454 #5512 +- is now applied to convert a gRPC cancellations. #5329 #5398 +- In Spring integration, the default is now gracefully closed after the server shuts down. #5742 +- A connection is not reused anymore after a is raised. #5738 +- is now correctly propagated when + is invoked. #5746 + +## 🏚️ Deprecations + +- and are deprecated. #5698 + - Use the same classes in the `armeria-prometheus1` module. +- is deprecated. Use instead. #4910 #5586 +- and various setter methods for building a `GraphQL` instance are deprecated. + - Use instead. #5269 + +## ☢️ Breaking changes + +- We updated Micrometer to 1.13.0 that has breaking changes in its Prometheus support. #5698 + - If you want to keep the old behavior that uses Prometheus Java client 0.x, + use the `io.micrometer:micrometer-registry-prometheus-simpleclient:1.13.0` module. + - If you want to use Prometheus Java client 1.x, add `com.linecorp.armeria:armeria-prometheus1` module. + - More details can be found in the [Micrometer migration guide](https://github.com/micrometer-metrics/micrometer/wiki/1.13-Migration-Guide). +- now returns a `CompletableFuture`. #5752 +- The following builder classes now have the `SELF` type parameter. #5733 + - + - + - + - + - + - + - + +## ⛓ Dependencies + +- Blockhound 1.0.8.RELEASE → 1.0.9.RELEASE +- Control Plane 1.0.44 → 1.0.45 +- GraphQL Kotlin 7.0.2 → 7.1.1 +- gRPC Java 1.63.0 → 1.64.0 +- Jackson 2.17.0 → 2.17.1 +- Kotlin Coroutine 1.8.0 → 1.8.1 +- Kubernetes client 6.11.0 → 6.12.1 +- Mircometer 1.12.4 → 1.13.0 +- Netty 4.1.108.Final → 4.1.110.Final +- Reactor 3.6.4 → 3.6.6 +- Scala2.13 2.13.13 → 2.13.14 +- Scala Collection compat 2.11.0 → 2.12.0 +- Spring 6.1.5 → 6.1.8 +- Spring Boot 3.2.4 → 3.3.0 + +## 🙇 Thank you + + \ No newline at end of file diff --git a/site/src/pages/release-notes/1.29.1.mdx b/site/src/pages/release-notes/1.29.1.mdx new file mode 100644 index 00000000000..bd61e5404fe --- /dev/null +++ b/site/src/pages/release-notes/1.29.1.mdx @@ -0,0 +1,27 @@ +--- +date: 2024-06-28 +--- + +## 🛠️ Bug fixes + +- The default now properly handles and returns + the correct gRPC `Status`. #5786 +- The is now correctly detected for the default . #5787 +- A duplicate key exception isn't raised anymore when building a JSON schema. #5788 +- The duplicator's child stream subscriber methods are now called by the correct executor. #5783 + +## ☢️ Breaking changes + +- The signatures of + has been changed. #5786 + +## 🙇 Thank you + + diff --git a/site/src/pages/release-notes/1.29.2.mdx b/site/src/pages/release-notes/1.29.2.mdx new file mode 100644 index 00000000000..0c8e562d387 --- /dev/null +++ b/site/src/pages/release-notes/1.29.2.mdx @@ -0,0 +1,22 @@ +--- +date: 2024-07-10 +--- + +## 🛠️ Bug fixes + +- Service binding builder methods now correctly return the self-types. #5797 +- The peeled exception is now correctly passed to + method. #5796 +- selection calls to the ramping up succeeds + for . #5799 +- The meter names in no longer conflict with . #5804 + +## 🙇 Thank you + + \ No newline at end of file diff --git a/site/src/pages/release-notes/1.29.3.mdx b/site/src/pages/release-notes/1.29.3.mdx new file mode 100644 index 00000000000..cec39e7f221 --- /dev/null +++ b/site/src/pages/release-notes/1.29.3.mdx @@ -0,0 +1,19 @@ +--- +date: 2024-07-19 +--- + +## 🛠️ Bug fixes + +- `NullPointerException` is no longer raised when handles errors. #5815 #5816 +- Fixed a regression where a protocol violation error is not handled + by #5811 + +## 🙇 Thank you + + diff --git a/spring/boot3-autoconfigure/src/main/java/com/linecorp/armeria/spring/InternalServices.java b/spring/boot3-autoconfigure/src/main/java/com/linecorp/armeria/spring/InternalServices.java index ee5134be816..e046de8b0fd 100644 --- a/spring/boot3-autoconfigure/src/main/java/com/linecorp/armeria/spring/InternalServices.java +++ b/spring/boot3-autoconfigure/src/main/java/com/linecorp/armeria/spring/InternalServices.java @@ -23,10 +23,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.boot.SpringApplication; import com.google.common.base.MoreObjects; import com.google.common.base.Strings; +import com.linecorp.armeria.client.ClientFactory; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.internal.common.util.PortUtil; @@ -60,6 +62,18 @@ private static boolean hasAllClasses(String... classNames) { return true; } + static { + // InternalServices is the only class that both boot-starter and boot-webflux-starter always depend on. + + // Disable the default shutdown hook to gracefully close the client factory after the server is + // shut down. + ClientFactory.disableShutdownHook(); + // The shutdown hooks are invoked after all other contexts are closed. + // The server is closed by ConfigurableApplicationContext.closeAndWait(). + // https://github.com/spring-projects/spring-boot/blame/781d7b0394c71e20f098f64a3261a18346ccd915/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplicationShutdownHook.java#L114-L116 + SpringApplication.getShutdownHandlers().add(ClientFactory::closeDefault); + } + /** * Returns a newly created {@link InternalServices} from the specified properties. */ diff --git a/thrift/thrift0.13/build.gradle b/thrift/thrift0.13/build.gradle index 5eb0d477e92..7bfd1e1ffd6 100644 --- a/thrift/thrift0.13/build.gradle +++ b/thrift/thrift0.13/build.gradle @@ -35,7 +35,7 @@ ext { // NB: Keep this same with 'armeria-thrift0.9'. tasks.shadedJar.exclude 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*' tasks.shadedJar.doLast { - ant.jar(update: true, destfile: tasks.shadedJar.archivePath) { + ant.jar(update: true, destfile: tasks.shadedJar.archiveFile.get().asFile) { sourceSets.main.output.classesDirs.each { classesDir -> fileset(dir: "${classesDir}", includes: 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*') diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/THttpServiceBuilder.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/THttpServiceBuilder.java index f5e27a5935d..39f3acf156d 100644 --- a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/THttpServiceBuilder.java +++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/THttpServiceBuilder.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.Executors; import java.util.function.BiFunction; import java.util.function.Function; @@ -84,6 +85,7 @@ public final class THttpServiceBuilder { // -1 means to use the default request length of the Server. private int maxRequestStringLength = -1; private int maxRequestContainerLength = -1; + private boolean useBlockingTaskExecutor; THttpServiceBuilder() {} @@ -190,6 +192,17 @@ public THttpServiceBuilder maxRequestContainerLength(int maxRequestContainerLeng return this; } + /** + * Sets whether the service executes service methods using the blocking executor. By default, service + * methods are executed directly on the event loop for implementing fully asynchronous services. If your + * service uses blocking logic, you should either execute such logic in a separate thread using something + * like {@link Executors#newCachedThreadPool()} or enable this setting. + */ + public THttpServiceBuilder useBlockingTaskExecutor(boolean useBlockingTaskExecutor) { + this.useBlockingTaskExecutor = useBlockingTaskExecutor; + return this; + } + /** * Sets the {@link BiFunction} that returns an {@link RpcResponse} using the given {@link Throwable} * and {@link ServiceRequestContext}. @@ -225,10 +238,11 @@ private RpcService decorate(RpcService service) { * Builds a new instance of {@link THttpService}. */ public THttpService build() { - @SuppressWarnings("UnstableApiUsage") final Map> implementations = Multimaps.asMap(implementationsBuilder.build()); - - final ThriftCallService tcs = ThriftCallService.of(implementations); + final ThriftCallService tcs = new ThriftCallServiceBuilder() + .addServices(implementations) + .useBlockingTaskExecutor(useBlockingTaskExecutor) + .build(); return build0(tcs); } @@ -244,7 +258,9 @@ private THttpService build0(RpcService tcs) { builder.add(defaultSerializationFormat); builder.addAll(otherSerializationFormats); - return new THttpService(decorate(tcs), defaultSerializationFormat, builder.build(), - maxRequestStringLength, maxRequestContainerLength, exceptionHandler); + return new THttpService( + decorate(tcs), defaultSerializationFormat, builder.build(), + maxRequestStringLength, maxRequestContainerLength, exceptionHandler + ); } } diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallService.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallService.java index b3cd8d1c1dc..76241c09f17 100644 --- a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallService.java +++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallService.java @@ -16,7 +16,7 @@ package com.linecorp.armeria.server.thrift; -import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; import java.util.List; @@ -31,14 +31,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; - import com.linecorp.armeria.common.CompletableRpcResponse; import com.linecorp.armeria.common.RequestContext; import com.linecorp.armeria.common.RpcRequest; import com.linecorp.armeria.common.RpcResponse; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.internal.common.thrift.ThriftFunction; import com.linecorp.armeria.server.RpcService; import com.linecorp.armeria.server.ServiceRequestContext; @@ -70,7 +68,7 @@ public void onError(Exception e) { */ public static ThriftCallService of(Object implementation) { requireNonNull(implementation, "implementation"); - return new ThriftCallService(ImmutableMap.of("", ImmutableList.of(implementation))); + return builder().addService(implementation).build(); } /** @@ -82,19 +80,27 @@ public static ThriftCallService of(Object implementation) { */ public static ThriftCallService of(Map> implementations) { requireNonNull(implementations, "implementations"); - return new ThriftCallService(implementations); + checkArgument(!implementations.isEmpty(), "implementations is empty"); + + return builder().addServices(implementations).build(); + } + + /** + * Creates a new instance of {@link ThriftCallServiceBuilder} which can build + * an instance of {@link ThriftCallService} fluently. + */ + @UnstableApi + public static ThriftCallServiceBuilder builder() { + return new ThriftCallServiceBuilder(); } private final Map entries; - private ThriftCallService(Map> implementations) { - requireNonNull(implementations, "implementations"); - if (implementations.isEmpty()) { - throw new IllegalArgumentException("empty implementations"); - } + private final boolean useBlockingTaskExecutor; - entries = implementations.entrySet().stream().collect( - toImmutableMap(Map.Entry::getKey, ThriftServiceEntry::new)); + ThriftCallService(Map entries, boolean useBlockingTaskExecutor) { + this.entries = entries; + this.useBlockingTaskExecutor = useBlockingTaskExecutor; } /** @@ -140,14 +146,24 @@ public RpcResponse serve(ServiceRequestContext ctx, RpcRequest call) throws Exce TApplicationException.UNKNOWN_METHOD, "unknown method: " + call.method())); } - private static void invoke( + private void invoke( ServiceRequestContext ctx, Object impl, ThriftFunction func, List args, CompletableRpcResponse reply) { try { final TBase tArgs = func.newArgs(args); if (func.isAsync()) { - invokeAsynchronously(impl, func, tArgs, reply); + if (useBlockingTaskExecutor) { + ctx.blockingTaskExecutor().execute(() -> { + try { + invokeAsynchronously(impl, func, tArgs, reply); + } catch (Throwable t) { + reply.completeExceptionally(t); + } + }); + } else { + invokeAsynchronously(impl, func, tArgs, reply); + } } else { invokeSynchronously(ctx, impl, func, tArgs, reply); } diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallServiceBuilder.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallServiceBuilder.java new file mode 100644 index 00000000000..ff2df8be448 --- /dev/null +++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallServiceBuilder.java @@ -0,0 +1,153 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.thrift; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; + +import java.util.Map; +import java.util.concurrent.Executors; + +import com.google.common.collect.ImmutableListMultimap; + +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * A fluent builder to build an instance of {@link ThriftCallService}. + * + *

    Example

    + *
    {@code
    + * ThriftCallService service = ThriftCallService
    + *                 .builder()
    + *                 .addService(defaultServiceImpl) // Adds a service
    + *                 .addService("foo", fooServiceImpl) // Adds a service with a key
    + *                 .addService("foobar", fooServiceImpl)  // Adds multiple services to the same key
    + *                 .addService("foobar", barServiceImpl)
    + *                  // Adds multiple services at once
    + *                 .addServices("foobarOnce", fooServiceImpl, barServiceImpl)
    + *                  // Adds multiple services by list
    + *                 .addServices("foobarList", ImmutableList.of(fooServiceImpl, barServiceImpl))
    + *                  // Adds multiple services by map
    + *                 .addServices(ImmutableMap.of("fooMap", fooServiceImpl, "barMap", barServiceImpl))
    + *                  // Adds multiple services by map
    + *                 .addServices(ImmutableMap.of("fooIterableMap",
    + *                                              ImmutableList.of(fooServiceImpl, barServiceImpl)))
    + *                 .build();
    + * }
    + * + * @see ThriftCallService + */ +@UnstableApi +public final class ThriftCallServiceBuilder { + private final ImmutableListMultimap.Builder servicesBuilder = + ImmutableListMultimap.builder(); + + private boolean useBlockingTaskExecutor; + + ThriftCallServiceBuilder() {} + + /** + * Adds a service for {@link ThriftServiceEntry}. + */ + public ThriftCallServiceBuilder addService(Object service) { + requireNonNull(service, "service"); + servicesBuilder.put("", service); + return this; + } + + /** + * Adds a service with a key for {@link ThriftServiceEntry}. + */ + public ThriftCallServiceBuilder addService(String key, Object service) { + requireNonNull(key, "key"); + requireNonNull(service, "service"); + servicesBuilder.put(key, service); + return this; + } + + /** + * Adds a service for {@link ThriftServiceEntry}. + */ + public ThriftCallServiceBuilder addServices(Object... services) { + requireNonNull(services, "services"); + checkArgument(services.length != 0, "service should not be empty"); + servicesBuilder.putAll("", services); + return this; + } + + /** + * Adds a service with a key for {@link ThriftServiceEntry}. + */ + public ThriftCallServiceBuilder addServices(String key, Object... services) { + requireNonNull(key, "key"); + requireNonNull(services, "service"); + checkArgument(services.length != 0, "service should not be empty"); + servicesBuilder.putAll(key, services); + return this; + } + + /** + * Adds services with key by iterable for {@link ThriftServiceEntry}. + */ + public ThriftCallServiceBuilder addServices(String key, Iterable services) { + requireNonNull(key, "key"); + requireNonNull(services, "services"); + checkArgument(services.iterator().hasNext(), "service should not be empty"); + servicesBuilder.putAll(key, services); + return this; + } + + /** + * Adds multiple services by map for {@link ThriftServiceEntry}. + */ + public ThriftCallServiceBuilder addServices(Map services) { + requireNonNull(services, "services"); + checkArgument(!services.isEmpty(), "service should not be empty"); + + services.forEach((k, v) -> { + if (v instanceof Iterable) { + servicesBuilder.putAll(k, (Iterable) v); + } else { + servicesBuilder.put(k, v); + } + }); + return this; + } + + /** + * Sets whether the service executes service methods using the blocking executor. By default, service + * methods are executed directly on the event loop for implementing fully asynchronous services. If your + * service uses blocking logic, you should either execute such logic in a separate thread using something + * like {@link Executors#newCachedThreadPool()} or enable this setting. + */ + public ThriftCallServiceBuilder useBlockingTaskExecutor(boolean useBlockingTaskExecutor) { + this.useBlockingTaskExecutor = useBlockingTaskExecutor; + return this; + } + + /** + * Builds a new instance of {@link ThriftCallService}. + */ + public ThriftCallService build() { + return new ThriftCallService( + servicesBuilder.build().asMap().entrySet().stream().collect( + toImmutableMap(Map.Entry::getKey, ThriftServiceEntry::new)), + useBlockingTaskExecutor + ); + } +} diff --git a/thrift/thrift0.13/src/test/java/com/linecorp/armeria/server/thrift/THttpServiceBlockingTest.java b/thrift/thrift0.13/src/test/java/com/linecorp/armeria/server/thrift/THttpServiceBlockingTest.java new file mode 100644 index 00000000000..5c53eca17f7 --- /dev/null +++ b/thrift/thrift0.13/src/test/java/com/linecorp/armeria/server/thrift/THttpServiceBlockingTest.java @@ -0,0 +1,127 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.thrift; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.thrift.TException; +import org.apache.thrift.async.AsyncMethodCallback; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.linecorp.armeria.client.thrift.ThriftClients; +import com.linecorp.armeria.common.util.ThreadFactories; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import testing.thrift.main.HelloService; + +class THttpServiceBlockingTest { + private static final AtomicReference currentThreadName = new AtomicReference<>(""); + + private static final String BLOCKING_EXECUTOR_PREFIX = "blocking-test"; + private static final ScheduledExecutorService executor = + new ScheduledThreadPoolExecutor(1, + ThreadFactories.newThreadFactory(BLOCKING_EXECUTOR_PREFIX, true)); + + @BeforeEach + void clearDetector() { + currentThreadName.set(""); + } + + @AfterAll + public static void shutdownExecutor() { + executor.shutdown(); + } + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) throws Exception { + + sb.service("/", + THttpService.builder().addService(new HelloServiceAsyncImpl()).build()); + sb.service("/blocking", THttpService.builder() + .useBlockingTaskExecutor(true) + .addService(new HelloServiceAsyncImpl()) + .build()); + sb.service("/blocking-iface", + THttpService.builder().addService(new HelloServiceImpl()).build()); + + sb.blockingTaskExecutor(executor, true); + } + }; + + @Test + void nonBlocking() throws Exception { + final HelloService.Iface client = ThriftClients.newClient(server.httpUri(), HelloService.Iface.class); + + final String message = "nonBlockingTest"; + final String response = client.hello(message); + + assertThat(response).isEqualTo(message); + assertThat(currentThreadName.get().startsWith(BLOCKING_EXECUTOR_PREFIX)).isFalse(); + } + + @Test + void blocking() throws Exception { + final HelloService.Iface client = + ThriftClients.builder(server.httpUri()) + .path("/blocking") + .build(HelloService.Iface.class); + final String message = "blockingTest"; + final String response = client.hello(message); + + assertThat(response).isEqualTo(message); + assertThat(currentThreadName.get().startsWith(BLOCKING_EXECUTOR_PREFIX)).isTrue(); + } + + @Test + void blockingIface() throws Exception { + final HelloService.Iface client = + ThriftClients.builder(server.httpUri()) + .path("/blocking-iface") + .build(HelloService.Iface.class); + final String message = "blockingTest"; + final String response = client.hello(message); + + assertThat(response).isEqualTo(message); + assertThat(currentThreadName.get().startsWith(BLOCKING_EXECUTOR_PREFIX)).isTrue(); + } + + static class HelloServiceAsyncImpl implements HelloService.AsyncIface { + @Override + public void hello(String name, AsyncMethodCallback resultHandler) throws TException { + currentThreadName.set(Thread.currentThread().getName()); + resultHandler.onComplete(name); + } + } + + static class HelloServiceImpl implements HelloService.Iface { + @Override + public String hello(String name) throws TException { + currentThreadName.set(Thread.currentThread().getName()); + return name; + } + } +} diff --git a/thrift/thrift0.13/src/test/java/com/linecorp/armeria/server/thrift/ThriftCallServiceBuilderTest.java b/thrift/thrift0.13/src/test/java/com/linecorp/armeria/server/thrift/ThriftCallServiceBuilderTest.java new file mode 100644 index 00000000000..419629ebfe4 --- /dev/null +++ b/thrift/thrift0.13/src/test/java/com/linecorp/armeria/server/thrift/ThriftCallServiceBuilderTest.java @@ -0,0 +1,98 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.thrift; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import testing.thrift.main.FooService.AsyncIface; + +/** + * Test for {@link ThriftCallServiceBuilder}. + */ +class ThriftCallServiceBuilderTest { + @Test + void nullAndEmptyCases() { + assertThrows(NullPointerException.class, () -> + ThriftCallService.builder().addService(null) + ); + assertThrows(NullPointerException.class, () -> + ThriftCallService.builder().addService("", null) + ); + assertThrows(NullPointerException.class, () -> + ThriftCallService.builder().addServices("", (Object) null) + ); + assertThrows(NullPointerException.class, () -> + ThriftCallService.builder().addServices("", null, null) + ); + assertThrows(NullPointerException.class, () -> + ThriftCallService.builder().addServices("", null, null, null) + ); + assertThrows(NullPointerException.class, () -> + ThriftCallService.builder().addServices("", (Iterable) null) + ); + assertThrows(IllegalArgumentException.class, () -> + ThriftCallService.builder().addServices("", new ArrayList<>()) + ); + } + + @Test + void testBuilder() { + final AsyncIface defaultServiceImpl = mock(AsyncIface.class); + final AsyncIface fooServiceImpl = mock(AsyncIface.class); + final AsyncIface barServiceImpl = mock(AsyncIface.class); + final ThriftCallService service = ThriftCallService + .builder() + .addService(defaultServiceImpl) + .addService("foo", fooServiceImpl) + .addService("foobar", fooServiceImpl) + .addService("foobar", barServiceImpl) + .addServices("foobarOnce", fooServiceImpl, barServiceImpl) + .addServices("foobarList", ImmutableList.of(fooServiceImpl, barServiceImpl)) + .addServices(ImmutableMap.of("fooMap", fooServiceImpl, "barMap", barServiceImpl)) + .addServices(ImmutableMap.of("fooIterableMap", + ImmutableList.of(fooServiceImpl, barServiceImpl))) + .build(); + final Map> actualEntries = + service.entries().entrySet().stream() + .collect(ImmutableMap.toImmutableMap( + Map.Entry::getKey, + e -> ImmutableList.copyOf(e.getValue().implementations))); + + final Map> expectedEntries = ImmutableMap.of( + "", ImmutableList.of(defaultServiceImpl), + "foo", ImmutableList.of(fooServiceImpl), + "foobar", ImmutableList.of(fooServiceImpl, barServiceImpl), + "foobarOnce", ImmutableList.of(fooServiceImpl, barServiceImpl), + "foobarList", ImmutableList.of(fooServiceImpl, barServiceImpl), + "fooMap", ImmutableList.of(fooServiceImpl), + "barMap", ImmutableList.of(barServiceImpl), + "fooIterableMap", ImmutableList.of(fooServiceImpl, barServiceImpl)); + + assertThat(actualEntries).isEqualTo(expectedEntries); + } +} diff --git a/thrift/thrift0.14/build.gradle b/thrift/thrift0.14/build.gradle index a93c09493cd..54050bcd18d 100644 --- a/thrift/thrift0.14/build.gradle +++ b/thrift/thrift0.14/build.gradle @@ -66,7 +66,7 @@ tasks.sourcesJar.from "${thrift013ProjectDir}/src/main/resources" // NB: Keep this same with ':thrift0.13'. tasks.shadedJar.exclude 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*' tasks.shadedJar.doLast { - ant.jar(update: true, destfile: tasks.shadedJar.archivePath) { + ant.jar(update: true, destfile: tasks.shadedJar.archiveFile.get().asFile) { sourceSets.main.output.classesDirs.each { classesDir -> fileset(dir: "$classesDir", includes: 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*') diff --git a/thrift/thrift0.15/build.gradle b/thrift/thrift0.15/build.gradle index 5b4dde0ede3..7d8af98451e 100644 --- a/thrift/thrift0.15/build.gradle +++ b/thrift/thrift0.15/build.gradle @@ -72,7 +72,7 @@ tasks.sourcesJar.from "${thrift013ProjectDir}/src/main/resources" // NB: Keep this same with ':thrift0.13'. tasks.shadedJar.exclude 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*' tasks.shadedJar.doLast { - ant.jar(update: true, destfile: tasks.shadedJar.archivePath) { + ant.jar(update: true, destfile: tasks.shadedJar.archiveFile.get().asFile) { sourceSets.main.output.classesDirs.each { classesDir -> fileset(dir: "$classesDir", includes: 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*') diff --git a/thrift/thrift0.16/build.gradle b/thrift/thrift0.16/build.gradle index 9d73fcfa1f2..3be0448c7f7 100644 --- a/thrift/thrift0.16/build.gradle +++ b/thrift/thrift0.16/build.gradle @@ -73,7 +73,7 @@ tasks.sourcesJar.from "${thrift013ProjectDir}/src/main/resources" // NB: Keep this same with ':thrift0.13'. tasks.shadedJar.exclude 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*' tasks.shadedJar.doLast { - ant.jar(update: true, destfile: tasks.shadedJar.archivePath) { + ant.jar(update: true, destfile: tasks.shadedJar.archiveFile.get().asFile) { sourceSets.main.output.classesDirs.each { classesDir -> fileset(dir: "$classesDir", includes: 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*') diff --git a/thrift/thrift0.17/build.gradle b/thrift/thrift0.17/build.gradle index 9d0c8053322..b40b1d48741 100644 --- a/thrift/thrift0.17/build.gradle +++ b/thrift/thrift0.17/build.gradle @@ -74,7 +74,7 @@ tasks.sourcesJar.from "${thrift013ProjectDir}/src/main/resources" // NB: Keep this same with ':thrift0.13'. tasks.shadedJar.exclude 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*' tasks.shadedJar.doLast { - ant.jar(update: true, destfile: tasks.shadedJar.archivePath) { + ant.jar(update: true, destfile: tasks.shadedJar.archiveFile.get().asFile) { sourceSets.main.output.classesDirs.each { classesDir -> fileset(dir: "$classesDir", includes: 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*') diff --git a/thrift/thrift0.18/build.gradle b/thrift/thrift0.18/build.gradle index 89bf208388e..28b139ab86a 100644 --- a/thrift/thrift0.18/build.gradle +++ b/thrift/thrift0.18/build.gradle @@ -76,7 +76,7 @@ tasks.sourcesJar.from "${thrift013ProjectDir}/src/main/resources" // NB: Keep this same with ':thrift0.13'. tasks.shadedJar.exclude 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*' tasks.shadedJar.doLast { - ant.jar(update: true, destfile: tasks.shadedJar.archivePath) { + ant.jar(update: true, destfile: tasks.shadedJar.archiveFile.get().asFile) { sourceSets.main.output.classesDirs.each { classesDir -> fileset(dir: "$classesDir", includes: 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*') diff --git a/thrift/thrift0.9/build.gradle b/thrift/thrift0.9/build.gradle index 3af131edb65..b87c27ffc23 100644 --- a/thrift/thrift0.9/build.gradle +++ b/thrift/thrift0.9/build.gradle @@ -56,7 +56,7 @@ ext { // NB: Keep this same with ':thrift0.13'. tasks.shadedJar.exclude 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*' tasks.shadedJar.doLast { - ant.jar(update: true, destfile: tasks.shadedJar.archivePath) { + ant.jar(update: true, destfile: tasks.shadedJar.archiveFile.get().asFile) { sourceSets.main.output.classesDirs.each { classesDir -> fileset(dir: "$classesDir", includes: 'com/linecorp/armeria/common/thrift/ThriftListenableFuture*') diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/ClusterEntry.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/ClusterEntry.java index cc46d949662..ea09f43ca4a 100644 --- a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/ClusterEntry.java +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/ClusterEntry.java @@ -21,6 +21,9 @@ import java.util.List; import java.util.concurrent.CompletableFuture; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import com.google.common.base.MoreObjects; import com.google.common.collect.ImmutableList; @@ -35,8 +38,11 @@ final class ClusterEntry implements AsyncCloseable { + private static final Logger logger = LoggerFactory.getLogger(ClusterEntry.class); + private final EndpointsPool endpointsPool; - private final LoadBalancer loadBalancer = new SubsetLoadBalancer(); + @Nullable + private volatile LoadBalancer loadBalancer; private final ClusterManager clusterManager; private final EventExecutor eventExecutor; private List endpoints = ImmutableList.of(); @@ -51,6 +57,10 @@ final class ClusterEntry implements AsyncCloseable { @Nullable Endpoint selectNow(ClientRequestContext ctx) { + final LoadBalancer loadBalancer = this.loadBalancer; + if (loadBalancer == null) { + return null; + } return loadBalancer.selectNow(ctx); } @@ -64,9 +74,16 @@ void updateClusterSnapshot(ClusterSnapshot clusterSnapshot) { void accept(ClusterSnapshot clusterSnapshot, List endpoints) { assert eventExecutor.inEventLoop(); - this.endpoints = endpoints; - final PrioritySet prioritySet = new PrioritySet(endpoints, clusterSnapshot); - loadBalancer.prioritySetUpdated(prioritySet); + this.endpoints = ImmutableList.copyOf(endpoints); + final PrioritySet prioritySet = new PriorityStateManager(clusterSnapshot, endpoints).build(); + if (logger.isTraceEnabled()) { + logger.trace("XdsEndpointGroup is using a new PrioritySet({})", prioritySet); + } + if (clusterSnapshot.xdsResource().resource().hasLbSubsetConfig()) { + loadBalancer = new SubsetLoadBalancer(prioritySet); + } else { + loadBalancer = new DefaultLoadBalancer(prioritySet); + } clusterManager.notifyListeners(); } diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/DefaultLbStateFactory.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/DefaultLbStateFactory.java new file mode 100644 index 00000000000..a7eb723996c --- /dev/null +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/DefaultLbStateFactory.java @@ -0,0 +1,396 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.xds.client.endpoint; + +import java.util.Collections; +import java.util.Map; +import java.util.SortedSet; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Streams; +import com.google.common.math.IntMath; +import com.google.common.math.LongMath; + +import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.xds.client.endpoint.DefaultLoadBalancer.DistributeLoadState; +import com.linecorp.armeria.xds.client.endpoint.DefaultLoadBalancer.HostAvailability; +import com.linecorp.armeria.xds.client.endpoint.DefaultLoadBalancer.PriorityAndAvailability; + +import it.unimi.dsi.fastutil.ints.Int2IntMap; +import it.unimi.dsi.fastutil.ints.Int2IntMaps; +import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap; + +final class DefaultLbStateFactory { + + private static final Logger logger = LoggerFactory.getLogger(DefaultLbStateFactory.class); + + static DefaultLbState newInstance(PrioritySet prioritySet) { + PerPriorityLoad perPriorityLoad = calculatePerPriorityLoad(prioritySet); + final PerPriorityPanic perPriorityPanic = + recalculatePerPriorityPanic(prioritySet, + perPriorityLoad.normalizedTotalAvailability()); + + logger.debug("XdsEndpointGroup load balancer priorities for cluster({}) has been updated with" + + " perPriorityLoad({}), perPriorityPanic({}).", + prioritySet.cluster().getName(), perPriorityLoad, perPriorityPanic); + + if (perPriorityPanic.totalPanic()) { + perPriorityLoad = recalculateLoadInTotalPanic(prioritySet); + logger.debug("XdsEndpointGroup load balancer in panic for cluster({}) with perPriorityLoad({}).", + prioritySet.cluster().getName(), perPriorityLoad); + } + return new DefaultLbState(prioritySet, perPriorityLoad, perPriorityPanic); + } + + private static PerPriorityLoad calculatePerPriorityLoad(PrioritySet prioritySet) { + final Int2IntMap perPriorityHealth = new Int2IntOpenHashMap(prioritySet.priorities().size()); + final Int2IntMap perPriorityDegraded = new Int2IntOpenHashMap(prioritySet.priorities().size()); + for (int priority: prioritySet.priorities()) { + final HealthAndDegraded healthAndDegraded = + recalculatePerPriorityState(priority, prioritySet); + perPriorityHealth.put(priority, healthAndDegraded.healthWeight); + perPriorityDegraded.put(priority, healthAndDegraded.degradedWeight); + } + return buildLoads(prioritySet, + Int2IntMaps.unmodifiable(perPriorityHealth), + Int2IntMaps.unmodifiable(perPriorityDegraded)); + } + + private static HealthAndDegraded recalculatePerPriorityState( + int priority, PrioritySet prioritySet) { + final HostSet hostSet = prioritySet.hostSets().get(priority); + final int hostCount = hostSet.hosts().size(); + + if (hostCount <= 0) { + return HealthAndDegraded.ZERO; + } + + long healthyWeight = 0; + long degradedWeight = 0; + long totalWeight = 0; + if (hostSet.weightedPriorityHealth()) { + for (Endpoint host : hostSet.healthyHosts()) { + healthyWeight += host.weight(); + } + for (Endpoint host : hostSet.degradedHosts()) { + degradedWeight += host.weight(); + } + for (Endpoint host : hostSet.hosts()) { + totalWeight += host.weight(); + } + } else { + healthyWeight = hostSet.healthyHosts().size(); + degradedWeight = hostSet.degradedHosts().size(); + totalWeight = hostCount; + } + final int health = (int) Math.min(100L, LongMath.saturatedMultiply( + hostSet.overProvisioningFactor(), healthyWeight) / totalWeight); + final int degraded = (int) Math.min(100L, LongMath.saturatedMultiply( + hostSet.overProvisioningFactor(), degradedWeight) / totalWeight); + return new HealthAndDegraded(health, degraded); + } + + private static PerPriorityLoad buildLoads(PrioritySet prioritySet, + Map perPriorityHealth, + Map perPriorityDegraded) { + final int normalizedTotalAvailability = + normalizedTotalAvailability(perPriorityHealth, perPriorityDegraded); + if (normalizedTotalAvailability == 0) { + return PerPriorityLoad.INVALID; + } + + final Map healthyPriorityLoad = new Int2IntOpenHashMap(); + final Map degradedPriorityLoad = new Int2IntOpenHashMap(); + final DistributeLoadState firstHealthyAndRemaining = + distributeLoad(prioritySet.priorities(), healthyPriorityLoad, perPriorityHealth, + 100, normalizedTotalAvailability); + final DistributeLoadState firstDegradedAndRemaining = + distributeLoad(prioritySet.priorities(), degradedPriorityLoad, perPriorityDegraded, + firstHealthyAndRemaining.totalLoad, normalizedTotalAvailability); + final int remainingLoad = firstDegradedAndRemaining.totalLoad; + if (remainingLoad > 0) { + final int firstHealthy = firstHealthyAndRemaining.firstAvailablePriority; + final int firstDegraded = firstDegradedAndRemaining.firstAvailablePriority; + if (firstHealthy != -1) { + healthyPriorityLoad.computeIfPresent(firstHealthy, (k, v) -> v + remainingLoad); + } else { + assert firstDegraded != -1; + degradedPriorityLoad.computeIfPresent(firstDegraded, (k, v) -> v + remainingLoad); + } + } + + assert priorityLoadSum(healthyPriorityLoad, degradedPriorityLoad) == 100; + return new PerPriorityLoad(healthyPriorityLoad, degradedPriorityLoad, + normalizedTotalAvailability); + } + + private static int normalizedTotalAvailability(Map perPriorityHealth, + Map perPriorityDegraded) { + final int totalAvailability = Streams.concat(perPriorityHealth.values().stream(), + perPriorityDegraded.values().stream()) + .reduce(0, IntMath::saturatedAdd).intValue(); + return Math.min(totalAvailability, 100); + } + + private static int priorityLoadSum(Map healthyPriorityLoad, + Map degradedPriorityLoad) { + return Streams.concat(healthyPriorityLoad.values().stream(), + degradedPriorityLoad.values().stream()) + .reduce(0, IntMath::saturatedAdd).intValue(); + } + + private static DistributeLoadState distributeLoad(SortedSet priorities, + Map perPriorityLoad, + Map perPriorityAvailability, + int totalLoad, int normalizedTotalAvailability) { + int firstAvailablePriority = -1; + for (Integer priority: priorities) { + final long availability = perPriorityAvailability.getOrDefault(priority, 0); + if (firstAvailablePriority < 0 && availability > 0) { + firstAvailablePriority = priority; + } + final int load = (int) Math.min(totalLoad, availability * 100 / normalizedTotalAvailability); + perPriorityLoad.put(priority, load); + totalLoad -= load; + } + return new DistributeLoadState(totalLoad, firstAvailablePriority); + } + + private static PerPriorityPanic recalculatePerPriorityPanic(PrioritySet prioritySet, + int normalizedTotalAvailability) { + final int panicThreshold = prioritySet.panicThreshold(); + if (normalizedTotalAvailability == 0 && panicThreshold == 0) { + // there are no hosts available and panic mode is disabled. + // we should always return a null Endpoint for this case. + return PerPriorityPanic.INVALID; + } + boolean totalPanic = true; + final ImmutableMap.Builder perPriorityPanicBuilder = ImmutableMap.builder(); + for (Integer priority : prioritySet.priorities()) { + final HostSet hostSet = prioritySet.hostSets().get(priority); + final boolean isPanic = + normalizedTotalAvailability == 100 ? false : isHostSetInPanic(hostSet, panicThreshold); + perPriorityPanicBuilder.put(priority, isPanic); + totalPanic &= isPanic; + } + return new PerPriorityPanic(perPriorityPanicBuilder.build(), totalPanic); + } + + private static PerPriorityLoad recalculateLoadInTotalPanic(PrioritySet prioritySet) { + final int totalHostsCount = prioritySet.hostSets().values().stream() + .map(hostSet -> hostSet.hosts().size()) + .reduce(0, IntMath::saturatedAdd) + .intValue(); + if (totalHostsCount == 0) { + return PerPriorityLoad.INVALID; + } + int totalLoad = 100; + int firstNoEmpty = -1; + final Map healthyPriorityLoad = + new Int2IntOpenHashMap(prioritySet.priorities().size()); + final Map degradedPriorityLoad = + new Int2IntOpenHashMap(prioritySet.priorities().size()); + for (Integer priority: prioritySet.priorities()) { + final HostSet hostSet = prioritySet.hostSets().get(priority); + final int hostsSize = hostSet.hosts().size(); + if (firstNoEmpty == -1 && hostsSize > 0) { + firstNoEmpty = priority; + } + final int load = 100 * hostsSize / totalHostsCount; + healthyPriorityLoad.put(priority, load); + degradedPriorityLoad.put(priority, 0); + totalLoad -= load; + } + final int remainingLoad = totalLoad; + healthyPriorityLoad.computeIfPresent(firstNoEmpty, (k, v) -> v + remainingLoad); + final int priorityLoadSum = priorityLoadSum(healthyPriorityLoad, degradedPriorityLoad); + assert priorityLoadSum == 100 : "The priority loads not summing up to 100 (" + priorityLoadSum + + ") for cluster (" + prioritySet.cluster().getName() + ')'; + return new PerPriorityLoad(healthyPriorityLoad, degradedPriorityLoad, 100); + } + + private static boolean isHostSetInPanic(HostSet hostSet, int panicThreshold) { + final int hostCount = hostSet.hosts().size(); + final double healthyPercent = + hostCount == 0 ? 0 : 100.0 * hostSet.healthyHosts().size() / hostCount; + final double degradedPercent = + hostCount == 0 ? 0 : 100.0 * hostSet.degradedHosts().size() / hostCount; + return healthyPercent + degradedPercent < panicThreshold; + } + + static class PerPriorityLoad { + final Map healthyPriorityLoad; + final Map degradedPriorityLoad; + private final int normalizedTotalAvailability; + private final boolean forceEmptyEndpoint; + + private static final PerPriorityLoad INVALID = new PerPriorityLoad(); + + private PerPriorityLoad() { + healthyPriorityLoad = Collections.emptyMap(); + degradedPriorityLoad = Collections.emptyMap(); + normalizedTotalAvailability = 0; + forceEmptyEndpoint = true; + } + + PerPriorityLoad(Map healthyPriorityLoad, + Map degradedPriorityLoad, + int normalizedTotalAvailability) { + this.healthyPriorityLoad = ImmutableMap.copyOf(healthyPriorityLoad); + this.degradedPriorityLoad = ImmutableMap.copyOf(degradedPriorityLoad); + this.normalizedTotalAvailability = normalizedTotalAvailability; + forceEmptyEndpoint = false; + } + + int normalizedTotalAvailability() { + return normalizedTotalAvailability; + } + + int getHealthy(int priority) { + return healthyPriorityLoad.getOrDefault(priority, 0); + } + + int getDegraded(int priority) { + return degradedPriorityLoad.getOrDefault(priority, 0); + } + + boolean forceEmptyEndpoint() { + return forceEmptyEndpoint; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("healthyPriorityLoad", healthyPriorityLoad) + .add("degradedPriorityLoad", degradedPriorityLoad) + .add("normalizedTotalAvailability", normalizedTotalAvailability) + .add("forceEmptyEndpoint", forceEmptyEndpoint) + .toString(); + } + } + + static class PerPriorityPanic { + final Map perPriorityPanic; + private final boolean totalPanic; + private final boolean forceEmptyEndpoint; + + static final PerPriorityPanic INVALID = new PerPriorityPanic(); + + private PerPriorityPanic() { + perPriorityPanic = Collections.emptyMap(); + forceEmptyEndpoint = true; + totalPanic = false; + } + + PerPriorityPanic(Map perPriorityPanic, boolean totalPanic) { + this.perPriorityPanic = ImmutableMap.copyOf(perPriorityPanic); + this.totalPanic = totalPanic; + forceEmptyEndpoint = false; + } + + boolean get(int priority) { + return perPriorityPanic.getOrDefault(priority, true); + } + + boolean totalPanic() { + return totalPanic; + } + + boolean forceEmptyEndpoint() { + return forceEmptyEndpoint; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("perPriorityPanic", perPriorityPanic) + .add("totalPanic", totalPanic) + .add("forceEmptyEndpoint", forceEmptyEndpoint) + .toString(); + } + } + + static class DefaultLbState { + private final PrioritySet prioritySet; + private final PerPriorityLoad perPriorityLoad; + private final PerPriorityPanic perPriorityPanic; + + DefaultLbState(PrioritySet prioritySet, + PerPriorityLoad perPriorityLoad, PerPriorityPanic perPriorityPanic) { + this.prioritySet = prioritySet; + this.perPriorityLoad = perPriorityLoad; + this.perPriorityPanic = perPriorityPanic; + } + + PerPriorityPanic perPriorityPanic() { + return perPriorityPanic; + } + + PrioritySet prioritySet() { + return prioritySet; + } + + PerPriorityLoad perPriorityLoad() { + return perPriorityLoad; + } + + @Nullable + PriorityAndAvailability choosePriority(int hash) { + if (perPriorityLoad.forceEmptyEndpoint() || perPriorityPanic.forceEmptyEndpoint()) { + return null; + } + hash = hash % 100 + 1; + int aggregatePercentageLoad = 0; + final PerPriorityLoad perPriorityLoad = perPriorityLoad(); + for (Integer priority: prioritySet.priorities()) { + aggregatePercentageLoad += perPriorityLoad.getHealthy(priority); + if (hash <= aggregatePercentageLoad) { + return new PriorityAndAvailability(priority, HostAvailability.HEALTHY); + } + } + for (Integer priority: prioritySet.priorities()) { + aggregatePercentageLoad += perPriorityLoad.getDegraded(priority); + if (hash <= aggregatePercentageLoad) { + return new PriorityAndAvailability(priority, HostAvailability.DEGRADED); + } + } + // Shouldn't reach here + throw new IllegalStateException("Unable to select a priority for cluster(" + + prioritySet.cluster().getName() + "), hash(" + hash + ')'); + } + } + + private static class HealthAndDegraded { + + static final HealthAndDegraded ZERO = new HealthAndDegraded(0, 0); + + private final int healthWeight; + private final int degradedWeight; + + HealthAndDegraded(int healthWeight, int degradedWeight) { + this.healthWeight = healthWeight; + this.degradedWeight = degradedWeight; + } + } + + private DefaultLbStateFactory() {} +} diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/DefaultLoadBalancer.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/DefaultLoadBalancer.java new file mode 100644 index 00000000000..0176817bdcf --- /dev/null +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/DefaultLoadBalancer.java @@ -0,0 +1,215 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.xds.client.endpoint; + +import static com.google.common.base.Preconditions.checkArgument; + +import java.util.Map; + +import com.google.common.base.MoreObjects; + +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.xds.client.endpoint.DefaultLbStateFactory.DefaultLbState; + +import io.envoyproxy.envoy.config.core.v3.Locality; + +final class DefaultLoadBalancer implements LoadBalancer { + + private final DefaultLbStateFactory.DefaultLbState lbState; + + DefaultLoadBalancer(PrioritySet prioritySet) { + lbState = DefaultLbStateFactory.newInstance(prioritySet); + } + + @Override + @Nullable + public Endpoint selectNow(ClientRequestContext ctx) { + final PrioritySet prioritySet = lbState.prioritySet(); + if (prioritySet.priorities().isEmpty()) { + return null; + } + final int hash = EndpointUtil.hash(ctx); + final HostsSource hostsSource = hostSourceToUse(lbState, hash); + if (hostsSource == null) { + return null; + } + final HostSet hostSet = prioritySet.hostSets().get(hostsSource.priority); + if (hostSet == null) { + // shouldn't reach here + throw new IllegalStateException("Unable to select a priority for cluster(" + + prioritySet.cluster().getName() + "), hostsSource(" + + hostsSource + ')'); + } + switch (hostsSource.sourceType) { + case ALL_HOSTS: + return hostSet.hostsEndpointGroup().selectNow(ctx); + case HEALTHY_HOSTS: + return hostSet.healthyHostsEndpointGroup().selectNow(ctx); + case DEGRADED_HOSTS: + return hostSet.degradedHostsEndpointGroup().selectNow(ctx); + case LOCALITY_HEALTHY_HOSTS: + final Map healthyLocalities = + hostSet.healthyEndpointGroupPerLocality(); + final EndpointGroup healthyEndpointGroup = healthyLocalities.get(hostsSource.locality); + if (healthyEndpointGroup != null) { + return healthyEndpointGroup.selectNow(ctx); + } + break; + case LOCALITY_DEGRADED_HOSTS: + final Map degradedLocalities = + hostSet.degradedEndpointGroupPerLocality(); + final EndpointGroup degradedEndpointGroup = degradedLocalities.get(hostsSource.locality); + if (degradedEndpointGroup != null) { + return degradedEndpointGroup.selectNow(ctx); + } + break; + default: + throw new Error(); + } + return null; + } + + @Nullable + HostsSource hostSourceToUse(DefaultLbState lbState, int hash) { + final PriorityAndAvailability priorityAndAvailability = lbState.choosePriority(hash); + if (priorityAndAvailability == null) { + return null; + } + final PrioritySet prioritySet = lbState.prioritySet(); + final int priority = priorityAndAvailability.priority; + final HostSet hostSet = prioritySet.hostSets().get(priority); + final HostAvailability hostAvailability = priorityAndAvailability.hostAvailability; + if (lbState.perPriorityPanic().get(priority)) { + if (prioritySet.failTrafficOnPanic()) { + return null; + } else { + return new HostsSource(priority, SourceType.ALL_HOSTS); + } + } + + if (prioritySet.localityWeightedBalancing()) { + final Locality locality; + if (hostAvailability == HostAvailability.DEGRADED) { + locality = hostSet.chooseDegradedLocality(); + } else { + locality = hostSet.chooseHealthyLocality(); + } + if (locality != null) { + return new HostsSource(priority, localitySourceType(hostAvailability), locality); + } + } + + // don't do zone aware routing for now + return new HostsSource(priority, sourceType(hostAvailability), null); + } + + private static SourceType localitySourceType(HostAvailability hostAvailability) { + final SourceType sourceType; + switch (hostAvailability) { + case HEALTHY: + sourceType = SourceType.LOCALITY_HEALTHY_HOSTS; + break; + case DEGRADED: + sourceType = SourceType.LOCALITY_DEGRADED_HOSTS; + break; + default: + throw new Error(); + } + return sourceType; + } + + private static SourceType sourceType(HostAvailability hostAvailability) { + final SourceType sourceType; + switch (hostAvailability) { + case HEALTHY: + sourceType = SourceType.HEALTHY_HOSTS; + break; + case DEGRADED: + sourceType = SourceType.DEGRADED_HOSTS; + break; + default: + throw new Error(); + } + return sourceType; + } + + static class PriorityAndAvailability { + final int priority; + final HostAvailability hostAvailability; + + PriorityAndAvailability(int priority, HostAvailability hostAvailability) { + this.priority = priority; + this.hostAvailability = hostAvailability; + } + } + + static class HostsSource { + final int priority; + final SourceType sourceType; + @Nullable + final Locality locality; + + HostsSource(int priority, SourceType sourceType) { + this(priority, sourceType, null); + } + + HostsSource(int priority, SourceType sourceType, @Nullable Locality locality) { + if (sourceType == SourceType.LOCALITY_HEALTHY_HOSTS || + sourceType == SourceType.LOCALITY_DEGRADED_HOSTS) { + checkArgument(locality != null, "Locality must be non-null for %s", sourceType); + } + this.priority = priority; + this.sourceType = sourceType; + this.locality = locality; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("priority", priority) + .add("sourceType", sourceType) + .add("locality", locality) + .toString(); + } + } + + enum SourceType { + ALL_HOSTS, + HEALTHY_HOSTS, + DEGRADED_HOSTS, + LOCALITY_HEALTHY_HOSTS, + LOCALITY_DEGRADED_HOSTS, + } + + enum HostAvailability { + HEALTHY, + DEGRADED, + } + + static class DistributeLoadState { + final int totalLoad; + final int firstAvailablePriority; + + DistributeLoadState(int totalLoad, int firstAvailablePriority) { + this.totalLoad = totalLoad; + this.firstAvailablePriority = firstAvailablePriority; + } + } +} diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/EndpointGroupUtil.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/EndpointGroupUtil.java new file mode 100644 index 00000000000..049b75b77a3 --- /dev/null +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/EndpointGroupUtil.java @@ -0,0 +1,78 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.xds.client.endpoint; + +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import com.google.common.collect.ImmutableMap; + +import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.client.endpoint.EndpointSelectionStrategy; + +import io.envoyproxy.envoy.config.core.v3.Locality; + +final class EndpointGroupUtil { + + static Map> endpointsByLocality(List endpoints) { + return endpoints.stream().collect(Collectors.groupingBy(EndpointUtil::locality)); + } + + static EndpointGroup filter(List endpoints, EndpointSelectionStrategy strategy, + Predicate predicate) { + final List filteredEndpoints = + endpoints.stream().filter(predicate).collect(Collectors.toList()); + return EndpointGroup.of(strategy, filteredEndpoints); + } + + static EndpointGroup filter(EndpointGroup origEndpointGroup, Predicate predicate) { + return filter(origEndpointGroup.endpoints(), origEndpointGroup.selectionStrategy(), predicate); + } + + static Map filterByLocality(Map> endpointsMap, + EndpointSelectionStrategy strategy, + Predicate predicate) { + final ImmutableMap.Builder filteredLocality = ImmutableMap.builder(); + for (Entry> entry: endpointsMap.entrySet()) { + final EndpointGroup endpointGroup = filter(entry.getValue(), strategy, predicate); + if (endpointGroup.endpoints().isEmpty()) { + continue; + } + filteredLocality.put(entry.getKey(), endpointGroup); + } + return filteredLocality.build(); + } + + static Map filterByLocality(Map origLocality, + Predicate predicate) { + final ImmutableMap.Builder filteredLocality = ImmutableMap.builder(); + for (Entry entry: origLocality.entrySet()) { + final EndpointGroup endpointGroup = filter(entry.getValue(), predicate); + if (endpointGroup.endpoints().isEmpty()) { + continue; + } + filteredLocality.put(entry.getKey(), endpointGroup); + } + return filteredLocality.build(); + } + + private EndpointGroupUtil() {} +} diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/EndpointUtil.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/EndpointUtil.java index e2c96bd7a1e..af42ad1cc44 100644 --- a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/EndpointUtil.java +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/EndpointUtil.java @@ -16,19 +16,29 @@ package com.linecorp.armeria.xds.client.endpoint; +import java.util.concurrent.ThreadLocalRandom; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.protobuf.Duration; +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.client.Endpoint; import com.linecorp.armeria.client.endpoint.EndpointSelectionStrategy; import com.linecorp.armeria.client.endpoint.EndpointWeightTransition; import com.linecorp.armeria.client.endpoint.WeightRampingUpStrategyBuilder; import com.linecorp.armeria.common.annotation.Nullable; import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.CommonLbConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy; import io.envoyproxy.envoy.config.cluster.v3.Cluster.SlowStartConfig; +import io.envoyproxy.envoy.config.core.v3.Locality; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment.Policy; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; final class EndpointUtil { @@ -92,5 +102,89 @@ private static EndpointSelectionStrategy rampingUpSelectionStrategy(SlowStartCon return builder.build(); } + static Locality locality(Endpoint endpoint) { + final LocalityLbEndpoints localityLbEndpoints = localityLbEndpoints(endpoint); + return localityLbEndpoints.hasLocality() ? localityLbEndpoints.getLocality() + : Locality.getDefaultInstance(); + } + + static CoarseHealth coarseHealth(Endpoint endpoint) { + final LbEndpoint lbEndpoint = lbEndpoint(endpoint); + switch (lbEndpoint.getHealthStatus()) { + // Assume UNKNOWN means health check wasn't performed + case UNKNOWN: + case HEALTHY: + return CoarseHealth.HEALTHY; + case DEGRADED: + return CoarseHealth.DEGRADED; + default: + return CoarseHealth.UNHEALTHY; + } + } + + static int hash(ClientRequestContext ctx) { + if (ctx.hasAttr(XdsAttributeKeys.SELECTION_HASH)) { + final Integer selectionHash = ctx.attr(XdsAttributeKeys.SELECTION_HASH); + assert selectionHash != null; + return Math.max(0, selectionHash); + } + return ThreadLocalRandom.current().nextInt(0, Integer.MAX_VALUE); + } + + static int priority(Endpoint endpoint) { + return localityLbEndpoints(endpoint).getPriority(); + } + + static boolean hasLocalityLoadBalancingWeight(Endpoint endpoint) { + return localityLbEndpoints(endpoint).hasLoadBalancingWeight(); + } + + static int localityLoadBalancingWeight(Endpoint endpoint) { + return localityLbEndpoints(endpoint).getLoadBalancingWeight().getValue(); + } + + private static LbEndpoint lbEndpoint(Endpoint endpoint) { + final LbEndpoint lbEndpoint = endpoint.attr(XdsAttributeKeys.LB_ENDPOINT_KEY); + assert lbEndpoint != null; + return lbEndpoint; + } + + private static LocalityLbEndpoints localityLbEndpoints(Endpoint endpoint) { + final LocalityLbEndpoints localityLbEndpoints = endpoint.attr( + XdsAttributeKeys.LOCALITY_LB_ENDPOINTS_KEY); + assert localityLbEndpoints != null; + return localityLbEndpoints; + } + + static int overProvisionFactor(ClusterLoadAssignment clusterLoadAssignment) { + if (!clusterLoadAssignment.hasPolicy()) { + return 140; + } + final Policy policy = clusterLoadAssignment.getPolicy(); + return policy.hasOverprovisioningFactor() ? policy.getOverprovisioningFactor().getValue() : 140; + } + + static boolean weightedPriorityHealth(ClusterLoadAssignment clusterLoadAssignment) { + return clusterLoadAssignment.hasPolicy() ? + clusterLoadAssignment.getPolicy().getWeightedPriorityHealth() : false; + } + + static int panicThreshold(Cluster cluster) { + if (!cluster.hasCommonLbConfig()) { + return 50; + } + final CommonLbConfig commonLbConfig = cluster.getCommonLbConfig(); + if (!commonLbConfig.hasHealthyPanicThreshold()) { + return 50; + } + return Math.min((int) Math.round(commonLbConfig.getHealthyPanicThreshold().getValue()), 100); + } + + enum CoarseHealth { + HEALTHY, + DEGRADED, + UNHEALTHY, + } + private EndpointUtil() {} } diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/HostSet.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/HostSet.java new file mode 100644 index 00000000000..a21993a59e0 --- /dev/null +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/HostSet.java @@ -0,0 +1,190 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.xds.client.endpoint; + +import java.util.List; +import java.util.Map; + +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Ints; + +import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.internal.client.endpoint.WeightedRandomDistributionSelector; + +import io.envoyproxy.envoy.config.core.v3.Locality; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; + +final class HostSet { + + private final boolean weightedPriorityHealth; + private final int overProvisioningFactor; + + private final WeightedRandomDistributionSelector healthyLocalitySelector; + private final WeightedRandomDistributionSelector degradedLocalitySelector; + + private final EndpointGroup hostsEndpointGroup; + private final EndpointGroup healthyHostsEndpointGroup; + private final Map healthyEndpointGroupPerLocality; + private final EndpointGroup degradedHostsEndpointGroup; + private final Map degradedEndpointGroupPerLocality; + + HostSet(UpdateHostsParam params, ClusterLoadAssignment clusterLoadAssignment) { + weightedPriorityHealth = EndpointUtil.weightedPriorityHealth(clusterLoadAssignment); + overProvisioningFactor = EndpointUtil.overProvisionFactor(clusterLoadAssignment); + + healthyLocalitySelector = rebuildLocalityScheduler( + params.healthyHostsPerLocality(), params.hostsPerLocality(), + params.localityWeightsMap(), overProvisioningFactor); + degradedLocalitySelector = rebuildLocalityScheduler( + params.degradedHostsPerLocality(), params.hostsPerLocality(), + params.localityWeightsMap(), overProvisioningFactor); + + hostsEndpointGroup = params.hosts(); + healthyHostsEndpointGroup = params.healthyHosts(); + degradedHostsEndpointGroup = params.degradedHosts(); + healthyEndpointGroupPerLocality = params.healthyHostsPerLocality(); + degradedEndpointGroupPerLocality = params.degradedHostsPerLocality(); + } + + List hosts() { + return hostsEndpointGroup.endpoints(); + } + + EndpointGroup hostsEndpointGroup() { + return hostsEndpointGroup; + } + + List healthyHosts() { + return healthyHostsEndpointGroup.endpoints(); + } + + EndpointGroup healthyHostsEndpointGroup() { + return healthyHostsEndpointGroup; + } + + Map healthyEndpointGroupPerLocality() { + return healthyEndpointGroupPerLocality; + } + + List degradedHosts() { + return degradedHostsEndpointGroup.endpoints(); + } + + EndpointGroup degradedHostsEndpointGroup() { + return degradedHostsEndpointGroup; + } + + Map degradedEndpointGroupPerLocality() { + return degradedEndpointGroupPerLocality; + } + + boolean weightedPriorityHealth() { + return weightedPriorityHealth; + } + + int overProvisioningFactor() { + return overProvisioningFactor; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("hostsEndpointGroup", hostsEndpointGroup) + .add("healthyHostsEndpointGroup", healthyHostsEndpointGroup) + .add("healthyEndpointGroupPerLocality", healthyEndpointGroupPerLocality) + .add("degradedHostsEndpointGroup", degradedHostsEndpointGroup) + .add("degradedEndpointGroupPerLocality", degradedEndpointGroupPerLocality) + .add("weightedPriorityHealth", weightedPriorityHealth) + .add("overProvisioningFactor", overProvisioningFactor) + .toString(); + } + + private static WeightedRandomDistributionSelector rebuildLocalityScheduler( + Map eligibleHostsPerLocality, + Map allHostsPerLocality, + Map localityWeightsMap, + int overProvisioningFactor) { + final ImmutableList.Builder localityWeightsBuilder = ImmutableList.builder(); + for (Locality locality : allHostsPerLocality.keySet()) { + final double effectiveWeight = + effectiveLocalityWeight(locality, eligibleHostsPerLocality, allHostsPerLocality, + localityWeightsMap, overProvisioningFactor); + if (effectiveWeight > 0) { + localityWeightsBuilder.add(new LocalityEntry(locality, effectiveWeight)); + } + } + return new WeightedRandomDistributionSelector<>(localityWeightsBuilder.build()); + } + + static double effectiveLocalityWeight(Locality locality, + Map eligibleHostsPerLocality, + Map allHostsPerLocality, + Map localityWeightsMap, + int overProvisioningFactor) { + final EndpointGroup localityEligibleHosts = + eligibleHostsPerLocality.getOrDefault(locality, EndpointGroup.of()); + final int hostCount = allHostsPerLocality.getOrDefault(locality, EndpointGroup.of()).endpoints().size(); + if (hostCount == 0) { + return 0; + } + // We compute the availability of a locality via: + // (overProvisioningFactor) * (# healthy/degraded of hosts) / (# total hosts) + // https://www.envoyproxy.io/docs/envoy/latest/intro/arch_overview/upstream/load_balancing/locality_weight.html + final double localityAvailabilityRatio = (double) localityEligibleHosts.endpoints().size() / hostCount; + final int weight = localityWeightsMap.getOrDefault(locality, 0); + final double effectiveLocalityAvailabilityRatio = + Math.min(1.0, (overProvisioningFactor / 100.0) * localityAvailabilityRatio); + return weight * effectiveLocalityAvailabilityRatio; + } + + @Nullable + Locality chooseDegradedLocality() { + final LocalityEntry localityEntry = degradedLocalitySelector.select(); + if (localityEntry == null) { + return null; + } + return localityEntry.locality; + } + + @Nullable + Locality chooseHealthyLocality() { + final LocalityEntry localityEntry = healthyLocalitySelector.select(); + if (localityEntry == null) { + return null; + } + return localityEntry.locality; + } + + static class LocalityEntry extends WeightedRandomDistributionSelector.AbstractEntry { + + private final Locality locality; + private final int weight; + + LocalityEntry(Locality locality, double weight) { + this.locality = locality; + this.weight = Ints.saturatedCast(Math.round(weight)); + } + + @Override + public int weight() { + return weight; + } + } +} diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/LoadBalancer.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/LoadBalancer.java index 1b382a33764..8c55870c63a 100644 --- a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/LoadBalancer.java +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/LoadBalancer.java @@ -24,6 +24,4 @@ interface LoadBalancer { @Nullable Endpoint selectNow(ClientRequestContext ctx); - - void prioritySetUpdated(PrioritySet prioritySet); } diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/PrioritySet.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/PrioritySet.java index c3944fb7bd6..a62ca703187 100644 --- a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/PrioritySet.java +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/PrioritySet.java @@ -17,26 +17,126 @@ package com.linecorp.armeria.xds.client.endpoint; import java.util.List; +import java.util.Map; +import java.util.SortedSet; +import java.util.TreeSet; -import com.google.common.collect.ImmutableList; +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableMap; import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.xds.ClusterSnapshot; +import com.linecorp.armeria.xds.EndpointSnapshot; + +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.CommonLbConfig; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; final class PrioritySet { - private final List endpoints; + private final Map hostSets; + private final SortedSet priorities; + private final List origEndpoints; private final ClusterSnapshot clusterSnapshot; + private final Cluster cluster; + private final int panicThreshold; - PrioritySet(List endpoints, ClusterSnapshot clusterSnapshot) { - this.endpoints = ImmutableList.copyOf(endpoints); + PrioritySet(ClusterSnapshot clusterSnapshot, Map hostSets, List origEndpoints) { this.clusterSnapshot = clusterSnapshot; + cluster = clusterSnapshot.xdsResource().resource(); + panicThreshold = EndpointUtil.panicThreshold(cluster); + this.hostSets = hostSets; + priorities = new TreeSet<>(hostSets.keySet()); + this.origEndpoints = origEndpoints; + } + + boolean failTrafficOnPanic() { + final CommonLbConfig commonLbConfig = commonLbConfig(); + if (commonLbConfig == null) { + return false; + } + if (!commonLbConfig.hasZoneAwareLbConfig()) { + return false; + } + return commonLbConfig.getZoneAwareLbConfig().getFailTrafficOnPanic(); + } + + @Nullable + private CommonLbConfig commonLbConfig() { + if (!cluster.hasCommonLbConfig()) { + return null; + } + return cluster.getCommonLbConfig(); + } + + boolean localityWeightedBalancing() { + final CommonLbConfig commonLbConfig = commonLbConfig(); + if (commonLbConfig == null) { + return false; + } + return commonLbConfig.hasLocalityWeightedLbConfig(); + } + + int panicThreshold() { + return panicThreshold; } + SortedSet priorities() { + return priorities; + } + + Map hostSets() { + return hostSets; + } + + /** + * Returns the original list of endpoints this priority set was created with. + * This method acts as a temporary measure to keep backwards compatibility with + * {@link SubsetLoadBalancer}. It will be removed once {@link SubsetLoadBalancer} + * is fully implemented. + */ List endpoints() { - return endpoints; + return origEndpoints; + } + + Cluster cluster() { + return cluster; } ClusterSnapshot clusterSnapshot() { return clusterSnapshot; } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("hostSets", hostSets) + .add("cluster", cluster) + .toString(); + } + + static final class PrioritySetBuilder { + + private final ImmutableMap.Builder hostSetsBuilder = ImmutableMap.builder(); + private final ClusterSnapshot clusterSnapshot; + private final List origEndpoints; + private final ClusterLoadAssignment clusterLoadAssignment; + + PrioritySetBuilder(ClusterSnapshot clusterSnapshot, List origEndpoints) { + this.clusterSnapshot = clusterSnapshot; + this.origEndpoints = origEndpoints; + final EndpointSnapshot endpointSnapshot = clusterSnapshot.endpointSnapshot(); + assert endpointSnapshot != null; + clusterLoadAssignment = endpointSnapshot.xdsResource().resource(); + } + + void createHostSet(int priority, UpdateHostsParam params) { + final HostSet hostSet = new HostSet(params, clusterLoadAssignment); + hostSetsBuilder.put(priority, hostSet); + } + + PrioritySet build() { + return new PrioritySet(clusterSnapshot, hostSetsBuilder.build(), origEndpoints); + } + } } diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/PriorityState.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/PriorityState.java new file mode 100644 index 00000000000..4194d7739ff --- /dev/null +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/PriorityState.java @@ -0,0 +1,73 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.xds.client.endpoint; + +import static com.linecorp.armeria.xds.client.endpoint.EndpointGroupUtil.endpointsByLocality; +import static com.linecorp.armeria.xds.client.endpoint.EndpointUtil.locality; +import static com.linecorp.armeria.xds.client.endpoint.EndpointUtil.localityLoadBalancingWeight; +import static com.linecorp.armeria.xds.client.endpoint.EndpointUtil.selectionStrategy; + +import java.util.List; +import java.util.Map; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.xds.ClusterSnapshot; + +import io.envoyproxy.envoy.config.core.v3.Locality; + +final class PriorityState { + private final UpdateHostsParam param; + + PriorityState(List hosts, Map localityWeightsMap, + ClusterSnapshot clusterSnapshot) { + final Map> endpointsPerLocality = endpointsByLocality(hosts); + param = new UpdateHostsParam(hosts, endpointsPerLocality, localityWeightsMap, + selectionStrategy(clusterSnapshot.xdsResource().resource())); + } + + UpdateHostsParam param() { + return param; + } + + static final class PriorityStateBuilder { + + private final ImmutableList.Builder hostsBuilder = ImmutableList.builder(); + private final ImmutableMap.Builder localityWeightsBuilder = + ImmutableMap.builder(); + private final ClusterSnapshot clusterSnapshot; + + PriorityStateBuilder(ClusterSnapshot clusterSnapshot) { + this.clusterSnapshot = clusterSnapshot; + } + + void addEndpoint(Endpoint endpoint) { + hostsBuilder.add(endpoint); + if (locality(endpoint) != Locality.getDefaultInstance() && + EndpointUtil.hasLocalityLoadBalancingWeight(endpoint)) { + localityWeightsBuilder.put(locality(endpoint), localityLoadBalancingWeight(endpoint)); + } + } + + PriorityState build() { + return new PriorityState(hostsBuilder.build(), localityWeightsBuilder.buildKeepingLast(), + clusterSnapshot); + } + } +} diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/PriorityStateManager.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/PriorityStateManager.java new file mode 100644 index 00000000000..30a94e0123f --- /dev/null +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/PriorityStateManager.java @@ -0,0 +1,66 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.xds.client.endpoint; + +import static com.linecorp.armeria.xds.client.endpoint.EndpointUtil.priority; + +import java.util.List; +import java.util.Map.Entry; +import java.util.SortedMap; + +import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.xds.ClusterSnapshot; + +import it.unimi.dsi.fastutil.ints.Int2ReferenceAVLTreeMap; + +final class PriorityStateManager { + + private final SortedMap priorityStateMap = + new Int2ReferenceAVLTreeMap<>(); + private final ClusterSnapshot clusterSnapshot; + private final List origEndpoints; + + PriorityStateManager(ClusterSnapshot clusterSnapshot, List origEndpoints) { + this.clusterSnapshot = clusterSnapshot; + this.origEndpoints = origEndpoints; + for (Endpoint endpoint : origEndpoints) { + registerEndpoint(endpoint); + } + } + + private void registerEndpoint(Endpoint endpoint) { + final int priority = priority(endpoint); + PriorityState.PriorityStateBuilder builder = priorityStateMap.get(priority); + if (builder == null) { + builder = priorityStateMap.computeIfAbsent( + priority(endpoint), + ignored -> new PriorityState.PriorityStateBuilder(clusterSnapshot)); + } + builder.addEndpoint(endpoint); + } + + PrioritySet build() { + final PrioritySet.PrioritySetBuilder prioritySetBuilder = + new PrioritySet.PrioritySetBuilder(clusterSnapshot, origEndpoints); + for (Entry entry: priorityStateMap.entrySet()) { + final Integer priority = entry.getKey(); + final PriorityState priorityState = entry.getValue().build(); + prioritySetBuilder.createHostSet(priority, priorityState.param()); + } + return prioritySetBuilder.build(); + } +} diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/SubsetLoadBalancer.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/SubsetLoadBalancer.java index 74bbfadd07e..3aa92b4534a 100644 --- a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/SubsetLoadBalancer.java +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/SubsetLoadBalancer.java @@ -43,24 +43,18 @@ final class SubsetLoadBalancer implements LoadBalancer { private static final Logger logger = LoggerFactory.getLogger(SubsetLoadBalancer.class); - @Nullable - private volatile EndpointGroup endpointGroup; + private final EndpointGroup endpointGroup; + + SubsetLoadBalancer(PrioritySet prioritySet) { + endpointGroup = createEndpointGroup(prioritySet); + } @Override @Nullable public Endpoint selectNow(ClientRequestContext ctx) { - final EndpointGroup endpointGroup = this.endpointGroup; - if (endpointGroup == null) { - return null; - } return endpointGroup.selectNow(ctx); } - @Override - public void prioritySetUpdated(PrioritySet prioritySet) { - endpointGroup = createEndpointGroup(prioritySet); - } - private static EndpointGroup createEndpointGroup(PrioritySet prioritySet) { final ClusterSnapshot clusterSnapshot = prioritySet.clusterSnapshot(); final Struct filterMetadata = filterMetadata(clusterSnapshot); diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/UpdateHostsParam.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/UpdateHostsParam.java new file mode 100644 index 00000000000..77f62ed639d --- /dev/null +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/UpdateHostsParam.java @@ -0,0 +1,91 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.xds.client.endpoint; + +import static com.linecorp.armeria.xds.client.endpoint.EndpointGroupUtil.filter; +import static com.linecorp.armeria.xds.client.endpoint.EndpointGroupUtil.filterByLocality; +import static com.linecorp.armeria.xds.client.endpoint.EndpointUtil.coarseHealth; + +import java.util.List; +import java.util.Map; + +import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.client.endpoint.EndpointSelectionStrategy; +import com.linecorp.armeria.xds.client.endpoint.EndpointUtil.CoarseHealth; + +import io.envoyproxy.envoy.config.core.v3.Locality; + +/** + * Hosts per partition. + */ +final class UpdateHostsParam { + + private final EndpointGroup hosts; + private final EndpointGroup healthyHosts; + private final EndpointGroup degradedHosts; + private final Map hostsPerLocality; + private final Map healthyHostsPerLocality; + private final Map degradedHostsPerLocality; + private final Map localityWeightsMap; + + UpdateHostsParam(List endpoints, + Map> endpointsPerLocality, + Map localityWeightsMap, + EndpointSelectionStrategy strategy) { + hosts = EndpointGroup.of(strategy, endpoints); + hostsPerLocality = filterByLocality(endpointsPerLocality, strategy, ignored -> true); + healthyHosts = filter(endpoints, strategy, + endpoint -> coarseHealth(endpoint) == CoarseHealth.HEALTHY); + healthyHostsPerLocality = filterByLocality(endpointsPerLocality, strategy, + endpoint -> coarseHealth(endpoint) == CoarseHealth.HEALTHY); + degradedHosts = filter(endpoints, strategy, + endpoint -> coarseHealth(endpoint) == CoarseHealth.DEGRADED); + degradedHostsPerLocality = filterByLocality( + endpointsPerLocality, strategy, + endpoint -> coarseHealth(endpoint) == CoarseHealth.DEGRADED); + this.localityWeightsMap = localityWeightsMap; + } + + EndpointGroup hosts() { + return hosts; + } + + Map hostsPerLocality() { + return hostsPerLocality; + } + + EndpointGroup healthyHosts() { + return healthyHosts; + } + + Map healthyHostsPerLocality() { + return healthyHostsPerLocality; + } + + EndpointGroup degradedHosts() { + return degradedHosts; + } + + Map degradedHostsPerLocality() { + return degradedHostsPerLocality; + } + + Map localityWeightsMap() { + return localityWeightsMap; + } +} diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsAttributeAssigningEndpointGroup.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsAttributeAssigningEndpointGroup.java index 8c9316e9986..5b3b27d2192 100644 --- a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsAttributeAssigningEndpointGroup.java +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsAttributeAssigningEndpointGroup.java @@ -16,8 +16,8 @@ package com.linecorp.armeria.xds.client.endpoint; -import static com.linecorp.armeria.xds.client.endpoint.XdsAttributesKeys.LB_ENDPOINT_KEY; -import static com.linecorp.armeria.xds.client.endpoint.XdsAttributesKeys.LOCALITY_LB_ENDPOINTS_KEY; +import static com.linecorp.armeria.xds.client.endpoint.XdsAttributeKeys.LB_ENDPOINT_KEY; +import static com.linecorp.armeria.xds.client.endpoint.XdsAttributeKeys.LOCALITY_LB_ENDPOINTS_KEY; import java.util.List; import java.util.function.Consumer; diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsAttributesKeys.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsAttributeKeys.java similarity index 73% rename from xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsAttributesKeys.java rename to xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsAttributeKeys.java index 9b2a531c861..ee4c507590c 100644 --- a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsAttributesKeys.java +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsAttributeKeys.java @@ -20,12 +20,14 @@ import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; import io.netty.util.AttributeKey; -final class XdsAttributesKeys { +final class XdsAttributeKeys { static final AttributeKey LB_ENDPOINT_KEY = - AttributeKey.valueOf(XdsAttributesKeys.class, "LB_ENDPOINT_KEY"); + AttributeKey.valueOf(XdsAttributeKeys.class, "LB_ENDPOINT_KEY"); static final AttributeKey LOCALITY_LB_ENDPOINTS_KEY = - AttributeKey.valueOf(XdsAttributesKeys.class, "LOCALITY_LB_ENDPOINTS_KEY"); + AttributeKey.valueOf(XdsAttributeKeys.class, "LOCALITY_LB_ENDPOINTS_KEY"); + static final AttributeKey SELECTION_HASH = + AttributeKey.valueOf(XdsAttributeKeys.class, "SELECTION_HASH"); - private XdsAttributesKeys() {} + private XdsAttributeKeys() {} } diff --git a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsEndpointUtil.java b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsEndpointUtil.java index 3b908bb1b42..34692f97cdd 100644 --- a/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsEndpointUtil.java +++ b/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/XdsEndpointUtil.java @@ -57,7 +57,7 @@ static List convertEndpoints(List endpoints, Struct filterMe checkArgument(filterMetadata.getFieldsCount() > 0, "filterMetadata.getFieldsCount(): %s (expected: > 0)", filterMetadata.getFieldsCount()); final Predicate lbEndpointPredicate = endpoint -> { - final LbEndpoint lbEndpoint = endpoint.attr(XdsAttributesKeys.LB_ENDPOINT_KEY); + final LbEndpoint lbEndpoint = endpoint.attr(XdsAttributeKeys.LB_ENDPOINT_KEY); assert lbEndpoint != null; final Struct endpointMetadata = lbEndpoint.getMetadata().getFilterMetadataOrDefault( SUBSET_LOAD_BALANCING_FILTER_NAME, Struct.getDefaultInstance()); @@ -201,13 +201,13 @@ private static Endpoint convertToEndpoint(LocalityLbEndpoints localityLbEndpoint if (!Strings.isNullOrEmpty(hostname)) { endpoint = Endpoint.of(hostname) .withIpAddr(socketAddress.getAddress()) - .withAttr(XdsAttributesKeys.LB_ENDPOINT_KEY, lbEndpoint) - .withAttr(XdsAttributesKeys.LOCALITY_LB_ENDPOINTS_KEY, localityLbEndpoints) + .withAttr(XdsAttributeKeys.LB_ENDPOINT_KEY, lbEndpoint) + .withAttr(XdsAttributeKeys.LOCALITY_LB_ENDPOINTS_KEY, localityLbEndpoints) .withWeight(weight); } else { endpoint = Endpoint.of(socketAddress.getAddress()) - .withAttr(XdsAttributesKeys.LB_ENDPOINT_KEY, lbEndpoint) - .withAttr(XdsAttributesKeys.LOCALITY_LB_ENDPOINTS_KEY, localityLbEndpoints) + .withAttr(XdsAttributeKeys.LB_ENDPOINT_KEY, lbEndpoint) + .withAttr(XdsAttributeKeys.LOCALITY_LB_ENDPOINTS_KEY, localityLbEndpoints) .withWeight(weight); } if (socketAddress.hasPortValue()) { diff --git a/xds/src/test/java/com/linecorp/armeria/xds/XdsTestResources.java b/xds/src/test/java/com/linecorp/armeria/xds/XdsTestResources.java index 8dca7588ee3..0355afbe801 100644 --- a/xds/src/test/java/com/linecorp/armeria/xds/XdsTestResources.java +++ b/xds/src/test/java/com/linecorp/armeria/xds/XdsTestResources.java @@ -66,6 +66,7 @@ import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; +import io.envoyproxy.envoy.type.v3.Percent; public final class XdsTestResources { @@ -74,7 +75,8 @@ public final class XdsTestResources { private XdsTestResources() {} public static LbEndpoint endpoint(String address, int port) { - return endpoint(address, port, Metadata.getDefaultInstance()); + return endpoint(address, port, Metadata.getDefaultInstance(), 1, + HealthStatus.HEALTHY); } public static LbEndpoint endpoint(String address, int port, int weight) { @@ -82,6 +84,15 @@ public static LbEndpoint endpoint(String address, int port, int weight) { HealthStatus.HEALTHY); } + public static LbEndpoint endpoint(String address, int port, HealthStatus healthStatus) { + return endpoint(address, port, Metadata.getDefaultInstance(), 1, healthStatus); + } + + public static LbEndpoint endpoint(String address, int port, HealthStatus healthStatus, + int weight) { + return endpoint(address, port, Metadata.getDefaultInstance(), weight, healthStatus); + } + public static LbEndpoint endpoint(String address, int port, Metadata metadata) { return endpoint(address, port, metadata, 1, HealthStatus.HEALTHY); } @@ -104,6 +115,16 @@ public static LbEndpoint endpoint(String address, int port, Metadata metadata, i .build()).build(); } + public static Locality locality(String region) { + return Locality.newBuilder() + .setRegion(region) + .build(); + } + + public static Percent percent(int percent) { + return Percent.newBuilder().setValue(percent).build(); + } + public static ClusterLoadAssignment loadAssignment(String clusterName, URI uri) { return loadAssignment(clusterName, uri.getHost(), uri.getPort()); } @@ -385,10 +406,30 @@ public static Bootstrap staticBootstrap(Listener listener, Cluster cluster) { public static LocalityLbEndpoints localityLbEndpoints(Locality locality, Collection endpoints) { - return LocalityLbEndpoints.newBuilder() - .addAllLbEndpoints(endpoints) - .setLocality(locality) - .build(); + return localityLbEndpoints(locality, endpoints, -1, 0); + } + + public static LocalityLbEndpoints localityLbEndpoints(Locality locality, + Collection endpoints, + Integer priority) { + return localityLbEndpoints(locality, endpoints, priority, 0); + } + + public static LocalityLbEndpoints localityLbEndpoints(Locality locality, + Collection endpoints, + int priority, + int loadBalancingWeight) { + final LocalityLbEndpoints.Builder builder = LocalityLbEndpoints.newBuilder() + .addAllLbEndpoints(endpoints) + .setLocality(locality); + if (priority >= 0) { + builder.setPriority(priority); + } + if (loadBalancingWeight > 0) { + builder.setLoadBalancingWeight(UInt32Value.of(loadBalancingWeight)); + } + + return builder.build(); } public static LocalityLbEndpoints localityLbEndpoints(Locality locality, LbEndpoint... endpoints) { diff --git a/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/LocalityTest.java b/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/LocalityTest.java new file mode 100644 index 00000000000..68e903ac739 --- /dev/null +++ b/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/LocalityTest.java @@ -0,0 +1,170 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.xds.client.endpoint; + +import static com.linecorp.armeria.xds.XdsTestResources.createStaticCluster; +import static com.linecorp.armeria.xds.XdsTestResources.endpoint; +import static com.linecorp.armeria.xds.XdsTestResources.locality; +import static com.linecorp.armeria.xds.XdsTestResources.localityLbEndpoints; +import static com.linecorp.armeria.xds.XdsTestResources.staticBootstrap; +import static com.linecorp.armeria.xds.XdsTestResources.staticResourceListener; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.xds.ListenerRoot; +import com.linecorp.armeria.xds.XdsBootstrap; + +import io.envoyproxy.envoy.config.bootstrap.v3.Bootstrap; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.CommonLbConfig; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.CommonLbConfig.Builder; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.CommonLbConfig.LocalityWeightedLbConfig; +import io.envoyproxy.envoy.config.core.v3.HealthStatus; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.listener.v3.Listener; + +class LocalityTest { + + private static final Builder LOCALITY_LB_CONFIG = + CommonLbConfig.newBuilder() + .setLocalityWeightedLbConfig(LocalityWeightedLbConfig.getDefaultInstance()); + + @Test + void basicCase() { + final Listener listener = staticResourceListener(); + + final List lbEndpointsA = + ImmutableList.of(endpoint("127.0.0.1", 8080, 1000)); + final List lbEndpointsB = + ImmutableList.of(endpoint("127.0.0.1", 8081, 1)); + final ClusterLoadAssignment loadAssignment = + ClusterLoadAssignment + .newBuilder() + .addEndpoints(localityLbEndpoints(locality("regionA"), lbEndpointsA, 0, 9)) + .addEndpoints(localityLbEndpoints(locality("regionB"), lbEndpointsB, 0, 1)) + .build(); + final Cluster cluster = createStaticCluster("cluster", loadAssignment) + .toBuilder().setCommonLbConfig(LOCALITY_LB_CONFIG).build(); + + final Bootstrap bootstrap = staticBootstrap(listener, cluster); + final Map countsMap = new HashMap<>(); + try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap)) { + final ListenerRoot listenerRoot = xdsBootstrap.listenerRoot("listener"); + final EndpointGroup endpointGroup = XdsEndpointGroup.of(listenerRoot); + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + // Regardless of the endpoint weight, the locality weight will be used + // to determine which endpoint group to use + final ClientRequestContext ctx = + ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + for (int i = 0; i < 10; i++) { + final Endpoint selected = endpointGroup.selectNow(ctx); + assertThat(selected).isNotNull(); + countsMap.compute(selected, (k, v) -> v == null ? 1 : v + 1); + } + assertThat(countsMap) + .containsExactlyInAnyOrderEntriesOf( + ImmutableMap.of(Endpoint.of("127.0.0.1", 8080).withWeight(1000), 9, + Endpoint.of("127.0.0.1", 8081).withWeight(1), 1)); + } + } + + @Test + void emptyLocality() { + final Listener listener = staticResourceListener(); + + final List lbEndpointsA = ImmutableList.of(); + final List lbEndpointsB = + ImmutableList.of(endpoint("127.0.0.1", 8081), + endpoint("127.0.0.1", 8081)); + final ClusterLoadAssignment loadAssignment = + ClusterLoadAssignment.newBuilder() + .addEndpoints(localityLbEndpoints(locality("regionA"), lbEndpointsA)) + .addEndpoints(localityLbEndpoints(locality("regionB"), lbEndpointsB)) + .build(); + final Cluster cluster = createStaticCluster("cluster", loadAssignment) + .toBuilder().setCommonLbConfig(LOCALITY_LB_CONFIG).build(); + + final Bootstrap bootstrap = staticBootstrap(listener, cluster); + try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap)) { + final ListenerRoot listenerRoot = xdsBootstrap.listenerRoot("listener"); + final EndpointGroup endpointGroup = XdsEndpointGroup.of(listenerRoot); + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + // regionA won't be selected at all since it is empty + final ClientRequestContext ctx = + ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8081)); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8081)); + } + } + + @Test + void multiPriorityAndLocality() { + final Listener listener = staticResourceListener(); + + final List lbEndpointsA = + ImmutableList.of(endpoint("127.0.0.1", 8080, HealthStatus.HEALTHY)); + // the unhealthy endpoint won't be selected due to priority selection + final List lbEndpointsB = + ImmutableList.of(endpoint("127.0.0.1", 8081, HealthStatus.UNHEALTHY)); + final List lbEndpointsC = + ImmutableList.of(endpoint("127.0.0.1", 8082, HealthStatus.HEALTHY)); + final ClusterLoadAssignment loadAssignment = + ClusterLoadAssignment + .newBuilder() + .addEndpoints(localityLbEndpoints(locality("regionA"), lbEndpointsA, 0, 9)) + .addEndpoints(localityLbEndpoints(locality("regionB"), lbEndpointsB, 0, 1000)) + .addEndpoints(localityLbEndpoints(locality("regionC"), lbEndpointsC, 0, 1)) + .build(); + final Cluster cluster = createStaticCluster("cluster", loadAssignment) + .toBuilder().setCommonLbConfig(LOCALITY_LB_CONFIG).build(); + + final Bootstrap bootstrap = staticBootstrap(listener, cluster); + try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap)) { + final ListenerRoot listenerRoot = xdsBootstrap.listenerRoot("listener"); + final EndpointGroup endpointGroup = XdsEndpointGroup.of(listenerRoot); + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + + final ClientRequestContext ctx = + ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + final Map countsMap = new HashMap<>(); + for (int i = 0; i < 10; i++) { + final Endpoint selected = endpointGroup.selectNow(ctx); + assertThat(selected).isNotNull(); + countsMap.compute(selected, (k, v) -> v == null ? 1 : v + 1); + } + assertThat(countsMap) + .containsExactlyInAnyOrderEntriesOf( + ImmutableMap.of(Endpoint.of("127.0.0.1", 8080), 9, + Endpoint.of("127.0.0.1", 8082), 1)); + } + } +} diff --git a/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/PriorityTest.java b/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/PriorityTest.java new file mode 100644 index 00000000000..9921e008fc7 --- /dev/null +++ b/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/PriorityTest.java @@ -0,0 +1,388 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.xds.client.endpoint; + +import static com.linecorp.armeria.xds.XdsTestResources.createStaticCluster; +import static com.linecorp.armeria.xds.XdsTestResources.endpoint; +import static com.linecorp.armeria.xds.XdsTestResources.localityLbEndpoints; +import static com.linecorp.armeria.xds.XdsTestResources.percent; +import static com.linecorp.armeria.xds.XdsTestResources.staticBootstrap; +import static com.linecorp.armeria.xds.XdsTestResources.staticResourceListener; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Stream; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.UInt32Value; + +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.xds.ListenerRoot; +import com.linecorp.armeria.xds.XdsBootstrap; + +import io.envoyproxy.envoy.config.bootstrap.v3.Bootstrap; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.CommonLbConfig; +import io.envoyproxy.envoy.config.core.v3.HealthStatus; +import io.envoyproxy.envoy.config.core.v3.Locality; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment.Policy; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.listener.v3.Listener; +import io.envoyproxy.envoy.type.v3.Percent; + +class PriorityTest { + + @Test + void basicCase() { + final Listener listener = staticResourceListener(); + + final List lbEndpoints = + ImmutableList.of(endpoint("127.0.0.1", 8080), + endpoint("127.0.0.1", 8081), + endpoint("127.0.0.1", 8082)); + final ClusterLoadAssignment loadAssignment = + ClusterLoadAssignment.newBuilder() + .addEndpoints(localityLbEndpoints( + Locality.getDefaultInstance(), lbEndpoints)) + .build(); + final Cluster cluster = createStaticCluster("cluster", loadAssignment); + + final Bootstrap bootstrap = staticBootstrap(listener, cluster); + try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap)) { + final ListenerRoot listenerRoot = xdsBootstrap.listenerRoot("listener"); + final EndpointGroup endpointGroup = XdsEndpointGroup.of(listenerRoot); + + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + final ClientRequestContext ctx = ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8080)); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8081)); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8082)); + } + } + + @Test + void differentWeights() { + final Listener listener = staticResourceListener(); + + final List lbEndpoints = + ImmutableList.of(endpoint("127.0.0.1", 8080, 1), + endpoint("127.0.0.1", 8081, 1), + endpoint("127.0.0.1", 8082, 2)); + final ClusterLoadAssignment loadAssignment = + ClusterLoadAssignment.newBuilder() + .addEndpoints(localityLbEndpoints( + Locality.getDefaultInstance(), lbEndpoints)) + .build(); + final Cluster cluster = createStaticCluster("cluster", loadAssignment); + + final Bootstrap bootstrap = staticBootstrap(listener, cluster); + try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap)) { + final ListenerRoot listenerRoot = xdsBootstrap.listenerRoot("listener"); + final EndpointGroup endpointGroup = XdsEndpointGroup.of(listenerRoot); + + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + final ClientRequestContext ctx = ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8080)); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8081)); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8082)); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8082)); + } + } + + @Test + void differentPriorities() { + final Listener listener = staticResourceListener(); + + final List lbEndpoints0 = + ImmutableList.of(endpoint("127.0.0.1", 8080, HealthStatus.HEALTHY), + endpoint("127.0.0.1", 8081, HealthStatus.DEGRADED)); + final List lbEndpoints1 = + ImmutableList.of(endpoint("127.0.0.1", 8082, HealthStatus.HEALTHY), + endpoint("127.0.0.1", 8083, HealthStatus.DEGRADED)); + final ClusterLoadAssignment loadAssignment = + ClusterLoadAssignment + .newBuilder() + .addEndpoints(localityLbEndpoints(Locality.getDefaultInstance(), lbEndpoints0, 0)) + .addEndpoints(localityLbEndpoints(Locality.getDefaultInstance(), lbEndpoints1, 1)) + .build(); + final Cluster cluster = createStaticCluster("cluster", loadAssignment); + + final Bootstrap bootstrap = staticBootstrap(listener, cluster); + try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap)) { + final ListenerRoot listenerRoot = xdsBootstrap.listenerRoot("listener"); + final EndpointGroup endpointGroup = XdsEndpointGroup.of(listenerRoot); + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + final ClientRequestContext ctx = + ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + + // default overprovisioning factor (140) * 0.5 = 70 will be routed + // to healthy endpoints for priority 0 + ctx.setAttr(XdsAttributeKeys.SELECTION_HASH, 0); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8080)); + ctx.setAttr(XdsAttributeKeys.SELECTION_HASH, 68); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8080)); + + // 100 - 70 (priority 0) = 30 will be routed to healthy endpoints for priority 1 + ctx.setAttr(XdsAttributeKeys.SELECTION_HASH, 70); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8082)); + ctx.setAttr(XdsAttributeKeys.SELECTION_HASH, 99); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8082)); + } + } + + @Test + void degradedEndpoints() { + final Listener listener = staticResourceListener(); + + final List lbEndpoints0 = + ImmutableList.of(endpoint("127.0.0.1", 8080, HealthStatus.HEALTHY, 1), + endpoint("127.0.0.1", 8081, HealthStatus.UNHEALTHY, 9)); + final List lbEndpoints1 = + ImmutableList.of(endpoint("127.0.0.1", 8082, HealthStatus.HEALTHY, 1), + endpoint("127.0.0.1", 8083, HealthStatus.DEGRADED, 9)); + final ClusterLoadAssignment loadAssignment = + ClusterLoadAssignment + .newBuilder() + .addEndpoints(localityLbEndpoints(Locality.getDefaultInstance(), lbEndpoints0, 0)) + .addEndpoints(localityLbEndpoints(Locality.getDefaultInstance(), lbEndpoints1, 1)) + // set overprovisioning factor to 100 for simpler calculation + .setPolicy(Policy.newBuilder() + .setOverprovisioningFactor(UInt32Value.of(100)) + .setWeightedPriorityHealth(true)) + .build(); + final Cluster cluster = createStaticCluster("cluster", loadAssignment) + .toBuilder() + .setCommonLbConfig(CommonLbConfig.newBuilder() + .setHealthyPanicThreshold(Percent.newBuilder() + .setValue(0))) + .build(); + + final Bootstrap bootstrap = staticBootstrap(listener, cluster); + try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap)) { + final ListenerRoot listenerRoot = xdsBootstrap.listenerRoot("listener"); + final EndpointGroup endpointGroup = XdsEndpointGroup.of(listenerRoot); + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + final ClientRequestContext ctx = + ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + + // 0 ~ 9 for priority 0 HEALTHY + ctx.setAttr(XdsAttributeKeys.SELECTION_HASH, 0); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8080)); + + // 10 ~ 19 for priority 1 HEALTHY + ctx.setAttr(XdsAttributeKeys.SELECTION_HASH, 10); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8082)); + + // 20 ~ 99 for priority 1 DEGRADED + ctx.setAttr(XdsAttributeKeys.SELECTION_HASH, 20); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8083)); + } + } + + @Test + void noHosts() { + final Listener listener = staticResourceListener(); + final List lbEndpoints0 = ImmutableList.of(); + final ClusterLoadAssignment loadAssignment = + ClusterLoadAssignment + .newBuilder() + .addEndpoints(localityLbEndpoints(Locality.getDefaultInstance(), lbEndpoints0, 0)) + .build(); + final Cluster cluster = createStaticCluster("cluster", loadAssignment) + .toBuilder() + .setCommonLbConfig(CommonLbConfig.newBuilder() + .setHealthyPanicThreshold(Percent.newBuilder() + .setValue(50))) + .build(); + + final Bootstrap bootstrap = staticBootstrap(listener, cluster); + try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap)) { + final ListenerRoot listenerRoot = xdsBootstrap.listenerRoot("listener"); + final EndpointGroup endpointGroup = XdsEndpointGroup.of(listenerRoot, true); + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + + final ClientRequestContext ctx = + ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + await().pollDelay(3, TimeUnit.SECONDS) + .untilAsserted(() -> assertThat(endpointGroup.selectNow(ctx)).isNull()); + } + } + + @Test + void partialPanic() { + final Listener listener = staticResourceListener(); + + // there are no healthy endpoints in priority0 + final List lbEndpoints0 = + ImmutableList.of(endpoint("127.0.0.1", 8080, HealthStatus.UNHEALTHY), + endpoint("127.0.0.1", 8081, HealthStatus.UNHEALTHY), + endpoint("127.0.0.1", 8082, HealthStatus.UNHEALTHY)); + final List lbEndpoints1 = + ImmutableList.of(endpoint("127.0.0.1", 8083, HealthStatus.HEALTHY)); + final ClusterLoadAssignment loadAssignment = + ClusterLoadAssignment + .newBuilder() + .addEndpoints(localityLbEndpoints(Locality.getDefaultInstance(), lbEndpoints0, 0)) + .addEndpoints(localityLbEndpoints(Locality.getDefaultInstance(), lbEndpoints1, 1)) + .build(); + final Cluster cluster = createStaticCluster("cluster", loadAssignment) + .toBuilder().setCommonLbConfig(CommonLbConfig.newBuilder() + .setHealthyPanicThreshold(Percent.newBuilder() + .setValue(50))) + .build(); + + final Bootstrap bootstrap = staticBootstrap(listener, cluster); + try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap)) { + final ListenerRoot listenerRoot = xdsBootstrap.listenerRoot("listener"); + final EndpointGroup endpointGroup = XdsEndpointGroup.of(listenerRoot); + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + + final ClientRequestContext ctx = + ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + ctx.setAttr(XdsAttributeKeys.SELECTION_HASH, 0); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8083)); + ctx.setAttr(XdsAttributeKeys.SELECTION_HASH, 99); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8083)); + } + } + + @Test + void totalPanic() { + final Listener listener = staticResourceListener(); + + // 0.33 (healthy) * 140 (overprovisioning factor) < 50 (healthyPanicThreshold) + final List lbEndpoints0 = + ImmutableList.of(endpoint("127.0.0.1", 8080, HealthStatus.HEALTHY), + endpoint("127.0.0.1", 8081, HealthStatus.UNHEALTHY), + endpoint("127.0.0.1", 8082, HealthStatus.UNHEALTHY)); + final ClusterLoadAssignment loadAssignment = + ClusterLoadAssignment + .newBuilder() + .addEndpoints(localityLbEndpoints(Locality.getDefaultInstance(), lbEndpoints0, 0)) + .build(); + final Cluster cluster = createStaticCluster("cluster", loadAssignment) + .toBuilder().setCommonLbConfig(CommonLbConfig.newBuilder() + .setHealthyPanicThreshold(Percent.newBuilder() + .setValue(50))) + .build(); + + final Bootstrap bootstrap = staticBootstrap(listener, cluster); + try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap)) { + final ListenerRoot listenerRoot = xdsBootstrap.listenerRoot("listener"); + final EndpointGroup endpointGroup = XdsEndpointGroup.of(listenerRoot); + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + + // When in panic mode, all endpoints are selected regardless of health status + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + final ClientRequestContext ctx = ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8080)); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8081)); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(Endpoint.of("127.0.0.1", 8082)); + } + } + + @Test + void onlyUnhealthyPanicDisabled() { + final Listener listener = staticResourceListener(); + + final List lbEndpoints0 = + ImmutableList.of(endpoint("127.0.0.1", 8080, HealthStatus.UNHEALTHY), + endpoint("127.0.0.1", 8081, HealthStatus.UNHEALTHY), + endpoint("127.0.0.1", 8082, HealthStatus.UNHEALTHY)); + final ClusterLoadAssignment loadAssignment = + ClusterLoadAssignment + .newBuilder() + .addEndpoints(localityLbEndpoints(Locality.getDefaultInstance(), lbEndpoints0)) + .build(); + final Cluster cluster = createStaticCluster("cluster", loadAssignment) + .toBuilder() + .setCommonLbConfig(CommonLbConfig.newBuilder().setHealthyPanicThreshold(percent(0))) + .build(); + + final Bootstrap bootstrap = staticBootstrap(listener, cluster); + try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap)) { + final ListenerRoot listenerRoot = xdsBootstrap.listenerRoot("listener"); + final EndpointGroup endpointGroup = XdsEndpointGroup.of(listenerRoot); + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + + // When in panic mode, all endpoints are selected regardless of health status + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + final ClientRequestContext ctx = ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + assertThat(endpointGroup.selectNow(ctx)).isNull(); + assertThat(endpointGroup.selectNow(ctx)).isNull(); + assertThat(endpointGroup.selectNow(ctx)).isNull(); + } + } + + private static Stream healthyLoadZeroArgs() { + return Stream.of( + // panic mode routes traffic to all endpoints + Arguments.of(51, Endpoint.of("127.0.0.1", 8080), Endpoint.of("127.0.0.1", 8081)), + // non-panic mode doesn't route traffic + Arguments.of(49, null, null) + ); + } + + @ParameterizedTest + @MethodSource("healthyLoadZeroArgs") + void healthyLoadZero(int healthyPanicThreshold, @Nullable Endpoint endpoint1, + @Nullable Endpoint endpoint2) { + final Listener listener = staticResourceListener(); + final List lbEndpoints0 = + ImmutableList.of(endpoint("127.0.0.1", 8080, HealthStatus.HEALTHY, 1), + endpoint("127.0.0.1", 8081, HealthStatus.UNHEALTHY, 10000)); + + final ClusterLoadAssignment loadAssignment = + ClusterLoadAssignment + .newBuilder() + .addEndpoints(localityLbEndpoints(Locality.getDefaultInstance(), lbEndpoints0)) + .setPolicy(Policy.newBuilder() + .setWeightedPriorityHealth(true)) + .build(); + final Cluster cluster = createStaticCluster("cluster", loadAssignment) + .toBuilder() + .setCommonLbConfig(CommonLbConfig.newBuilder() + .setHealthyPanicThreshold(percent(healthyPanicThreshold))) + .build(); + + final Bootstrap bootstrap = staticBootstrap(listener, cluster); + try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap)) { + final ListenerRoot listenerRoot = xdsBootstrap.listenerRoot("listener"); + final EndpointGroup endpointGroup = XdsEndpointGroup.of(listenerRoot); + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + + // When in panic mode, all endpoints are selected regardless of health status + await().untilAsserted(() -> assertThat(endpointGroup.whenReady()).isDone()); + final ClientRequestContext ctx = ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(endpoint1); + assertThat(endpointGroup.selectNow(ctx)).isEqualTo(endpoint2); + } + } +} diff --git a/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/RampingUpTest.java b/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/RampingUpTest.java index b87f39a9dc6..21c4b7bb3d6 100644 --- a/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/RampingUpTest.java +++ b/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/RampingUpTest.java @@ -41,6 +41,7 @@ import com.linecorp.armeria.client.Endpoint; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.common.CommonPools; import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; @@ -176,8 +177,8 @@ void checkEndpointsAreRampedUp() throws Exception { private static Set selectEndpoints(int weight, EndpointGroup xdsEndpointGroup) { final Set selectedEndpoints = new HashSet<>(); for (int i = 0; i < weight * 2; i++) { - selectedEndpoints.add(xdsEndpointGroup.selectNow(ctx())); - selectedEndpoints.add(xdsEndpointGroup.selectNow(ctx())); + selectedEndpoints.add(xdsEndpointGroup.select(ctx(), CommonPools.workerGroup()).join()); + selectedEndpoints.add(xdsEndpointGroup.select(ctx(), CommonPools.workerGroup()).join()); } return selectedEndpoints; } diff --git a/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/XdsConverterUtilTest.java b/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/XdsConverterUtilTest.java index cb7a4621145..b4ead223195 100644 --- a/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/XdsConverterUtilTest.java +++ b/xds/src/test/java/com/linecorp/armeria/xds/client/endpoint/XdsConverterUtilTest.java @@ -48,15 +48,15 @@ void convertEndpointsWithFilterMetadata() { final Metadata metadata1 = metadata(ImmutableMap.of("foo", "foo1")); final LbEndpoint lbEndpoint1 = endpoint("127.0.0.1", 8080, metadata1); final Endpoint endpoint1 = Endpoint.of("127.0.0.1", 8080) - .withAttr(XdsAttributesKeys.LB_ENDPOINT_KEY, lbEndpoint1); + .withAttr(XdsAttributeKeys.LB_ENDPOINT_KEY, lbEndpoint1); final Metadata metadata2 = metadata(ImmutableMap.of("foo", "foo1", "bar", "bar2")); final LbEndpoint lbEndpoint2 = endpoint("127.0.0.1", 8081, metadata2); final Endpoint endpoint2 = Endpoint.of("127.0.0.1", 8081) - .withAttr(XdsAttributesKeys.LB_ENDPOINT_KEY, lbEndpoint2); + .withAttr(XdsAttributeKeys.LB_ENDPOINT_KEY, lbEndpoint2); final Metadata metadata3 = metadata(ImmutableMap.of("foo", "foo1", "bar", "bar1", "baz", "baz1")); final LbEndpoint lbEndpoint3 = endpoint("127.0.0.1", 8082, metadata3); final Endpoint endpoint3 = Endpoint.of("127.0.0.1", 8082) - .withAttr(XdsAttributesKeys.LB_ENDPOINT_KEY, lbEndpoint3); + .withAttr(XdsAttributeKeys.LB_ENDPOINT_KEY, lbEndpoint3); final List endpoints = convertEndpoints(ImmutableList.of(endpoint1, endpoint2, endpoint3), Struct.newBuilder() .putFields("foo", stringValue("foo1"))