From ce03ff4333de56c3facf515bf0587245ca36aa4d Mon Sep 17 00:00:00 2001 From: Trustin Lee Date: Fri, 7 Apr 2023 14:21:48 +0900 Subject: [PATCH] Handle a fragment in a client-side URI properly (#4789) Motivation: When a user sends a request whose path contains a fragment (e.g. `#foo`), Armeria behaves inconsistently depending on whether a user specified an absolute URI in `:path` or not. On an absolute URI, we rely on `URI` for parsing, which takes a fragment into account. Otherwise, we use `PathAndQuery`, which doesn't treat a fragment as a fragment and just normalizes `#` into `%2A` as a part of path and query. Modifications: - Evolve `PathAndQuery` into `RequestTarget` that is capable of parsing and normalizing a `:path` header. - `RequestTarget` now understands a fragment as well as an absolute URI. - Added `RequestTargetForm` - When normalizing a client-side path, `RequestTarget` doesn't clean up consecutive slashes anymore, e.g. `foo///bar` is *not* normalized into `foo/bar` on the client side. - Replaced `path`, `query` and `fragment` fields and parameters with `RequestTarget` where applicable, including: - `RequestContext` implementations - `AbstractRequestContextBuilder` and its subtypes - `UserClient` and its subtypes - `RoutingContext` implementations - Removed `RoutingStatus.INVALID_PATH` because `RequestTarget` always ensures that the path is valid now. - Split the path cache metrics into client-side and server-side ones - Old meter name: `armeria.server.parsed.path.cache` - New meter names: - `armeria.path.cache{type=client}` - `armeria.path.cache{type=server}` - `HttpClientDelegate` now makes sure `ctx.request() == req` to prevent the loophole where a decorator can send a request with an invalid path. - A user must call `ctx.updateRequest()` to validate the request first. - Renamed `PathParsingBenchmark` to `RequestTargetBenchmark` - Added client-side benchmarks - Fixed the incorrect JVM system property name Result: - Armeria client now handles the fragment part of URI consistently. - (Defect) Closed the loophole that allowed a decorator to send a different request than `ctx.request()`. - A decorator now must call `ctx.updateRequest()` when it replaces the current request. - (Improvement) Armeria client doesn't normalize consecutive slashes in a client request path anymore, giving a user freedom to send such a request. - Note: Please make sure that your server or service handles consecutive slashes (e.g. `foo//bar`) properly before an upgrade. Armeria server always cleans up such a path for you, so you don't need to worry. - (Breaking) `RoutingStatus.INVALID_PATH` has been removed because Armeria doesn't leak a request with an invalid state into router. - (Breaking) The signatures of `UserClient.execute()` have been changed. - (Breaking) The names of path cache meters have been changed. - Old meter name: `armeria.server.parsed.path.cache` - New meter names: - `armeria.path.cache{type=client}` - `armeria.path.cache{type=server}` --- .../internal/common/PathParsingBenchmark.java | 79 -- .../common/RequestTargetBenchmark.java | 119 +++ .../armeria/server/RoutersBenchmark.java | 19 +- .../client/ClientRequestContextBuilder.java | 6 +- .../armeria/client/DefaultWebClient.java | 104 +- .../armeria/client/HttpClientDelegate.java | 15 +- .../armeria/client/HttpClientFactory.java | 2 + .../linecorp/armeria/client/UserClient.java | 39 +- .../logging/ContentPreviewingClient.java | 1 + .../common/AbstractRequestContextBuilder.java | 60 +- .../armeria/common/RequestTarget.java | 150 +++ .../armeria/common/RequestTargetForm.java | 41 + .../armeria/internal/client/ClientUtil.java | 16 - .../client/DefaultClientRequestContext.java | 116 +-- .../internal/common/ArmeriaHttpUtil.java | 117 ++- .../internal/common/DefaultRequestTarget.java | 964 ++++++++++++++++++ .../common/NonWrappingRequestContext.java | 63 +- .../armeria/internal/common/PathAndQuery.java | 654 ------------ .../internal/common/RequestTargetCache.java | 131 +++ .../server/DefaultServiceRequestContext.java | 13 +- .../armeria/server/DefaultRoutingContext.java | 50 +- .../armeria/server/Http1RequestDecoder.java | 80 +- .../armeria/server/Http2RequestDecoder.java | 93 +- .../armeria/server/HttpServerHandler.java | 25 +- .../armeria/server/RoutingContext.java | 40 +- .../armeria/server/RoutingContextWrapper.java | 15 +- .../armeria/server/RoutingStatus.java | 7 +- .../com/linecorp/armeria/server/Server.java | 9 +- .../server/ServiceRequestContextBuilder.java | 12 +- .../armeria/server/ServiceRouteUtil.java | 25 +- .../armeria/server/docs/MethodInfo.java | 14 +- .../client/ClientRequestContextTest.java | 28 +- .../client/HttpClientContextCaptorTest.java | 2 +- .../client/HttpClientWithRequestLogTest.java | 8 +- .../DefaultClientRequestContextTest.java | 31 +- .../internal/common/ArmeriaHttpUtilTest.java | 25 +- .../common/DefaultRequestTargetTest.java | 601 +++++++++++ .../internal/common/PathAndQueryTest.java | 572 ----------- .../server/CachingRoutingContextTest.java | 46 +- .../armeria/server/GlobPathMappingTest.java | 8 - .../armeria/server/HttpServerTest.java | 10 +- .../linecorp/armeria/server/RouteTest.java | 23 +- .../linecorp/armeria/server/RouterTest.java | 3 +- .../armeria/server/RoutingContextTest.java | 21 +- .../armeria/server/ServiceRouteUtilTest.java | 18 +- .../server/VirtualHostBuilderTest.java | 7 +- .../armeria/server/file/FileServiceTest.java | 12 +- .../server/InvalidPathWithDataTest.java | 9 +- .../client/grpc/protocol/UnaryGrpcClient.java | 4 +- .../internal/client/grpc/ArmeriaChannel.java | 11 +- .../armeria/client/grpc/GrpcClientTest.java | 12 + .../server/grpc/GrpcServiceServerTest.java | 6 +- .../client/thrift/DefaultTHttpClient.java | 17 +- 53 files changed, 2612 insertions(+), 1941 deletions(-) delete mode 100644 benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/PathParsingBenchmark.java create mode 100644 benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/RequestTargetBenchmark.java create mode 100644 core/src/main/java/com/linecorp/armeria/common/RequestTarget.java create mode 100644 core/src/main/java/com/linecorp/armeria/common/RequestTargetForm.java create mode 100644 core/src/main/java/com/linecorp/armeria/internal/common/DefaultRequestTarget.java delete mode 100644 core/src/main/java/com/linecorp/armeria/internal/common/PathAndQuery.java create mode 100644 core/src/main/java/com/linecorp/armeria/internal/common/RequestTargetCache.java create mode 100644 core/src/test/java/com/linecorp/armeria/internal/common/DefaultRequestTargetTest.java delete mode 100644 core/src/test/java/com/linecorp/armeria/internal/common/PathAndQueryTest.java diff --git a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/PathParsingBenchmark.java b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/PathParsingBenchmark.java deleted file mode 100644 index afb24c76562..00000000000 --- a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/PathParsingBenchmark.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2017 LINE Corporation - * - * LINE Corporation licenses this file to you under the Apache License, - * version 2.0 (the "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at: - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - */ - -package com.linecorp.armeria.internal.common; - -import org.openjdk.jmh.annotations.Benchmark; -import org.openjdk.jmh.annotations.Fork; -import org.openjdk.jmh.annotations.Level; -import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.infra.Blackhole; - -/** - * Microbenchmarks for the {@link PathAndQuery#parse(String)} method. - */ -@State(Scope.Thread) -public class PathParsingBenchmark { - - private String path1; - private String path2; - - @Setup(Level.Invocation) - @SuppressWarnings("StringOperationCanBeSimplified") - public void setUp() { - // Create a new String for paths every time to avoid constant folding. - path1 = new String("/armeria/services/hello-world"); - path2 = new String("/armeria/services/goodbye-world"); - } - - @Benchmark - public PathAndQuery normal() { - return doNormal(); - } - - @Benchmark - @Fork(jvmArgsAppend = "-Dcom.linecorp.armeria.parsedPathCache=off") - public PathAndQuery normal_cacheDisabled() { - return doNormal(); - } - - private PathAndQuery doNormal() { - final PathAndQuery parsed = PathAndQuery.parse(path1); - parsed.storeInCache(path1); - return parsed; - } - - @Benchmark - public PathAndQuery cachedAndNotCached(Blackhole bh) { - return doCachedAndNotCached(bh); - } - - @Benchmark - @Fork(jvmArgsAppend = "-Dcom.linecorp.armeria.parsedPathCache=off") - public PathAndQuery cachedAndNotCached_cacheDisabled(Blackhole bh) { - return doCachedAndNotCached(bh); - } - - private PathAndQuery doCachedAndNotCached(Blackhole bh) { - final PathAndQuery parsed = PathAndQuery.parse(path1); - parsed.storeInCache(path1); - final PathAndQuery parsed2 = PathAndQuery.parse(path2); - bh.consume(parsed2); - return parsed; - } -} diff --git a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/RequestTargetBenchmark.java b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/RequestTargetBenchmark.java new file mode 100644 index 00000000000..a8efe75766b --- /dev/null +++ b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/internal/common/RequestTargetBenchmark.java @@ -0,0 +1,119 @@ +/* + * Copyright 2017 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.common; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.infra.Blackhole; + +import com.linecorp.armeria.common.RequestTarget; + +/** + * Microbenchmarks for {@link RequestTarget}. + */ +@State(Scope.Thread) +public class RequestTargetBenchmark { + + private static final String NO_CACHE_JVM_OPTS = "-Dcom.linecorp.armeria.parsedPathCacheSpec=off"; + + private String path1; + private String path2; + + @Setup(Level.Invocation) + @SuppressWarnings("StringOperationCanBeSimplified") + public void setUp() { + // Create a new String for paths every time to avoid constant folding. + path1 = new String("/armeria/services/hello-world"); + path2 = new String("/armeria/services/goodbye-world"); + } + + @Benchmark + public RequestTarget serverCached() { + return doServer(); + } + + @Benchmark + @Fork(jvmArgsAppend = NO_CACHE_JVM_OPTS) + public RequestTarget serverUncached() { + return doServer(); + } + + private RequestTarget doServer() { + final RequestTarget parsed = RequestTarget.forServer(path1); + RequestTargetCache.putForServer(path1, parsed); + return parsed; + } + + @Benchmark + public RequestTarget serverCachedAndUncached(Blackhole bh) { + return doServerCachedAndUncached(bh); + } + + @Benchmark + @Fork(jvmArgsAppend = NO_CACHE_JVM_OPTS) + public RequestTarget serverUncachedAndUncached(Blackhole bh) { + return doServerCachedAndUncached(bh); + } + + private RequestTarget doServerCachedAndUncached(Blackhole bh) { + final RequestTarget parsed = RequestTarget.forServer(path1); + RequestTargetCache.putForServer(path1, parsed); + final RequestTarget parsed2 = RequestTarget.forServer(path2); + bh.consume(parsed2); + return parsed; + } + + @Benchmark + public RequestTarget clientCached() { + return doServer(); + } + + @Benchmark + @Fork(jvmArgsAppend = NO_CACHE_JVM_OPTS) + public RequestTarget clientUncached() { + return doClient(); + } + + private RequestTarget doClient() { + final RequestTarget parsed = RequestTarget.forClient(path1); + RequestTargetCache.putForClient(path1, parsed); + return parsed; + } + + @Benchmark + public RequestTarget clientCachedAndUncached(Blackhole bh) { + return doClientCachedAndUncached(bh); + } + + @Benchmark + @Fork(jvmArgsAppend = NO_CACHE_JVM_OPTS) + public RequestTarget clientUncachedAndUncached(Blackhole bh) { + return doClientCachedAndUncached(bh); + } + + private RequestTarget doClientCachedAndUncached(Blackhole bh) { + final RequestTarget parsed = RequestTarget.forClient(path1); + RequestTargetCache.putForClient(path1, parsed); + final RequestTarget parsed2 = RequestTarget.forClient(path2); + bh.consume(parsed2); + return parsed; + } +} diff --git a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java index e2a8453bc8c..c284d155d75 100644 --- a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java +++ b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java @@ -32,6 +32,7 @@ import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.RequestId; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.SuccessFunction; import com.linecorp.armeria.server.logging.AccessLogWriter; @@ -48,6 +49,8 @@ public class RoutersBenchmark { private static final RequestHeaders METHOD1_HEADERS = RequestHeaders.of(HttpMethod.POST, "/grpc.package.Service/Method1"); + private static final RequestTarget METHOD1_REQ_TARGET = RequestTarget.forServer(METHOD1_HEADERS.path()); + static { final String defaultLogName = null; final String defaultServiceName = null; @@ -61,32 +64,32 @@ public class RoutersBenchmark { SERVICE, defaultLogName, defaultServiceName, defaultServiceNaming, 0, 0, false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), multipartUploadsLocation, ImmutableList.of(), - HttpHeaders.of(), (ctx) -> RequestId.random(), serviceErrorHandler), + HttpHeaders.of(), ctx -> RequestId.random(), serviceErrorHandler), new ServiceConfig(route2, route2, SERVICE, defaultLogName, defaultServiceName, defaultServiceNaming, 0, 0, false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), multipartUploadsLocation, ImmutableList.of(), - HttpHeaders.of(), (ctx) -> RequestId.random(), serviceErrorHandler)); + HttpHeaders.of(), ctx -> RequestId.random(), serviceErrorHandler)); FALLBACK_SERVICE = new ServiceConfig(Route.ofCatchAll(), Route.ofCatchAll(), SERVICE, defaultLogName, defaultServiceName, defaultServiceNaming, 0, 0, false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), multipartUploadsLocation, - ImmutableList.of(), HttpHeaders.of(), (ctx) -> RequestId.random(), + ImmutableList.of(), HttpHeaders.of(), ctx -> RequestId.random(), serviceErrorHandler); HOST = new VirtualHost( "localhost", "localhost", 0, null, SERVICES, FALLBACK_SERVICE, RejectedRouteHandler.DISABLED, unused -> NOPLogger.NOP_LOGGER, defaultServiceNaming, 0, 0, false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), multipartUploadsLocation, ImmutableList.of(), - (ctx) -> RequestId.random()); + ctx -> RequestId.random()); ROUTER = Routers.ofVirtualHost(HOST, SERVICES, RejectedRouteHandler.DISABLED); } @Benchmark public Routed exactMatch() { - final RoutingContext ctx = DefaultRoutingContext.of(HOST, "localhost", METHOD1_HEADERS.path(), - null, METHOD1_HEADERS, RoutingStatus.OK); + final RoutingContext ctx = DefaultRoutingContext.of(HOST, "localhost", METHOD1_REQ_TARGET, + METHOD1_HEADERS, RoutingStatus.OK); final Routed routed = ROUTER.find(ctx); if (routed.value() != SERVICES.get(0)) { throw new IllegalStateException("Routing error"); @@ -97,8 +100,8 @@ public Routed exactMatch() { @Benchmark public Routed exactMatch_wrapped() { final RoutingContext ctx = new RoutingContextWrapper( - DefaultRoutingContext.of(HOST, "localhost", METHOD1_HEADERS.path(), - null, METHOD1_HEADERS, RoutingStatus.OK)); + DefaultRoutingContext.of(HOST, "localhost", METHOD1_REQ_TARGET, + METHOD1_HEADERS, RoutingStatus.OK)); final Routed routed = ROUTER.find(ctx); if (routed.value() != SERVICES.get(0)) { throw new IllegalStateException("Routing error"); diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextBuilder.java b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextBuilder.java index 0ded24cac8e..d8ffe00b81c 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextBuilder.java @@ -70,8 +70,6 @@ public void run(Throwable cause) { /* no-op */ } noopResponseCancellationScheduler.finishNow(); } - @Nullable - private final String fragment; @Nullable private Endpoint endpoint; private ClientOptions options = ClientOptions.of(); @@ -81,12 +79,10 @@ public void run(Throwable cause) { /* no-op */ } ClientRequestContextBuilder(HttpRequest request) { super(false, request); - fragment = null; } ClientRequestContextBuilder(RpcRequest request, URI uri) { super(false, request, uri); - fragment = uri.getRawFragment(); } @Override @@ -157,7 +153,7 @@ public ClientRequestContext build() { final DefaultClientRequestContext ctx = new DefaultClientRequestContext( eventLoop(), meterRegistry(), sessionProtocol(), - id(), method(), path(), query(), fragment, options, request(), rpcRequest(), + id(), method(), requestTarget(), options, request(), rpcRequest(), requestOptions, responseCancellationScheduler, isRequestStartTimeSet() ? requestStartTimeNanos() : System.nanoTime(), isRequestStartTimeSet() ? requestStartTimeMicros() : SystemInfo.currentTimeMicros()); diff --git a/core/src/main/java/com/linecorp/armeria/client/DefaultWebClient.java b/core/src/main/java/com/linecorp/armeria/client/DefaultWebClient.java index 767a0737042..d2c222ee4ea 100644 --- a/core/src/main/java/com/linecorp/armeria/client/DefaultWebClient.java +++ b/core/src/main/java/com/linecorp/armeria/client/DefaultWebClient.java @@ -16,21 +16,19 @@ package com.linecorp.armeria.client; -import static com.linecorp.armeria.internal.client.ClientUtil.pathWithQuery; -import static com.linecorp.armeria.internal.common.ArmeriaHttpUtil.concatPaths; -import static com.linecorp.armeria.internal.common.ArmeriaHttpUtil.isAbsoluteUri; import static java.util.Objects.requireNonNull; -import java.net.URI; +import com.google.common.base.Strings; import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.common.ExchangeType; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.RequestTarget; +import com.linecorp.armeria.common.RequestTargetForm; import com.linecorp.armeria.common.Scheme; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; -import com.linecorp.armeria.internal.common.PathAndQuery; import io.micrometer.core.instrument.MeterRegistry; @@ -58,69 +56,69 @@ public HttpResponse execute(HttpRequest req, RequestOptions requestOptions) { requireNonNull(req, "req"); requireNonNull(requestOptions, "requestOptions"); + final String originalPath = req.path(); + final String prefix = Strings.emptyToNull(uri().getRawPath()); + final RequestTarget reqTarget = RequestTarget.forClient(originalPath, prefix); + if (reqTarget == null) { + return abortRequestAndReturnFailureResponse( + req, new IllegalArgumentException("Invalid path: " + originalPath)); + } + + final EndpointGroup endpointGroup; + final SessionProtocol protocol; + if (Clients.isUndefinedUri(uri())) { - final URI uri; - if (isAbsoluteUri(req.path())) { - try { - uri = URI.create(req.path()); - } catch (Exception ex) { + final String scheme; + final String authority; + if (reqTarget.form() == RequestTargetForm.ABSOLUTE) { + scheme = reqTarget.scheme(); + authority = reqTarget.authority(); + assert scheme != null; + assert authority != null; + } else { + scheme = req.scheme(); + authority = req.authority(); + + if (scheme == null || authority == null) { return abortRequestAndReturnFailureResponse(req, new IllegalArgumentException( - "Failed to create a URI: " + req.path(), ex)); + "Scheme and authority must be specified in \":path\" or " + + "in \":scheme\" and \":authority\". :path=" + + originalPath + ", :scheme=" + req.scheme() + ", :authority=" + req.authority())); } - } else if (req.scheme() != null && req.authority() != null) { - uri = req.uri(); - } else { - return abortRequestAndReturnFailureResponse(req, new IllegalArgumentException( - "Scheme and authority must be specified in \":path\" or " + - "in \":scheme\" and \":authority\". :path=" + - req.path() + ", :scheme=" + req.scheme() + ", :authority=" + req.authority())); } - final SessionProtocol protocol; + + endpointGroup = Endpoint.parse(authority); try { - protocol = Scheme.parse(uri.getScheme()).sessionProtocol(); + protocol = Scheme.parse(scheme).sessionProtocol(); } catch (Exception e) { return abortRequestAndReturnFailureResponse(req, new IllegalArgumentException( - "Failed to parse a scheme: " + uri.getScheme(), e)); + "Failed to parse a scheme: " + reqTarget.scheme(), e)); + } + } else { + if (reqTarget.form() == RequestTargetForm.ABSOLUTE) { + return abortRequestAndReturnFailureResponse(req, new IllegalArgumentException( + "Cannot send a request with a \":path\" header that contains an authority, " + + "because the client was created with a base URI. path: " + originalPath)); } - final Endpoint endpoint = Endpoint.parse(uri.getAuthority()); - final String query = uri.getRawQuery(); - final String path = pathWithQuery(uri, query); - final HttpRequest newReq = req.withHeaders(req.headers().toBuilder().path(path)); - return execute(endpoint, newReq, protocol, requestOptions); - } - - if (isAbsoluteUri(req.path())) { - return abortRequestAndReturnFailureResponse(req, new IllegalArgumentException( - "Cannot send a request with a \":path\" header that contains a URI with the authority, " + - "because the client was created with a base URI. path: " + req.path())); + endpointGroup = endpointGroup(); + protocol = scheme().sessionProtocol(); } - final String originalPath = req.path(); - final String newPath = concatPaths(uri().getRawPath(), originalPath); + final String newPath = reqTarget.pathAndQuery(); final HttpRequest newReq; - // newPath and originalPath should be the same reference if uri().getRawPath() can be ignorable - if (newPath != originalPath) { - newReq = req.withHeaders(req.headers().toBuilder().path(newPath)); - } else { + if (newPath.equals(originalPath)) { newReq = req; + } else { + newReq = req.withHeaders(req.headers().toBuilder().path(newPath)); } - return execute(endpointGroup(), newReq, scheme().sessionProtocol(), requestOptions); - } - private HttpResponse execute(EndpointGroup endpointGroup, HttpRequest req, SessionProtocol protocol, - RequestOptions requestOptions) { - final PathAndQuery pathAndQuery = PathAndQuery.parse(req.path()); - if (pathAndQuery == null) { - final IllegalArgumentException cause = new IllegalArgumentException("invalid path: " + req.path()); - return abortRequestAndReturnFailureResponse(req, cause); - } - final String newPath = pathAndQuery.toString(); - if (!newPath.equals(req.path())) { - req = req.withHeaders(req.headers().toBuilder().path(newPath)); - } - return execute(protocol, endpointGroup, req.method(), - pathAndQuery.path(), pathAndQuery.query(), null, req, requestOptions); + return execute(protocol, + endpointGroup, + newReq.method(), + reqTarget, + newReq, + requestOptions); } private static HttpResponse abortRequestAndReturnFailureResponse( diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java b/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java index 48b51d65cc8..a255b61dbae 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java @@ -38,7 +38,6 @@ import com.linecorp.armeria.internal.client.DecodedHttpResponse; import com.linecorp.armeria.internal.client.HttpSession; import com.linecorp.armeria.internal.client.PooledChannel; -import com.linecorp.armeria.internal.common.PathAndQuery; import com.linecorp.armeria.internal.common.RequestContextUtil; import com.linecorp.armeria.server.ProxiedAddresses; import com.linecorp.armeria.server.ServiceRequestContext; @@ -65,6 +64,12 @@ public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Ex if (throwable != null) { return earlyFailedResponse(throwable, ctx, req); } + if (req != ctx.request()) { + return earlyFailedResponse( + new IllegalStateException("ctx.request() does not match the actual request; " + + "did you forget to call ctx.updateRequest() in your decorator?"), + ctx, req); + } final Endpoint endpoint = ctx.endpoint(); if (endpoint == null) { @@ -81,10 +86,6 @@ public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Ex return earlyFailedResponse(EmptyEndpointGroupException.get(ctx.endpointGroup()), ctx, req); } - if (!isValidPath(req)) { - return earlyFailedResponse(new IllegalArgumentException("invalid path: " + req.path()), ctx, req); - } - final SessionProtocol protocol = ctx.sessionProtocol(); final ProxyConfig proxyConfig; try { @@ -221,10 +222,6 @@ private static void logSession(ClientRequestContext ctx, @Nullable PooledChannel } } - private static boolean isValidPath(HttpRequest req) { - return PathAndQuery.parse(req.path()) != null; - } - private static HttpResponse earlyFailedResponse(Throwable t, ClientRequestContext ctx, HttpRequest req) { final UnprocessedRequestException cause = UnprocessedRequestException.of(t); handleEarlyRequestException(ctx, req, cause); diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java b/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java index f124b0d0e93..64e7ed4c16d 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java @@ -50,6 +50,7 @@ import com.linecorp.armeria.common.util.ReleasableHolder; import com.linecorp.armeria.common.util.ShutdownHooks; import com.linecorp.armeria.common.util.TransportType; +import com.linecorp.armeria.internal.common.RequestTargetCache; import com.linecorp.armeria.internal.common.util.SslContextUtil; import io.micrometer.core.instrument.MeterRegistry; @@ -166,6 +167,7 @@ final class HttpClientFactory implements ClientFactory { this.options = options; clientDelegate = new HttpClientDelegate(this, addressResolverGroup); + RequestTargetCache.registerClientMetrics(meterRegistry); } /** diff --git a/core/src/main/java/com/linecorp/armeria/client/UserClient.java b/core/src/main/java/com/linecorp/armeria/client/UserClient.java index dac77d9921c..dc87b286849 100644 --- a/core/src/main/java/com/linecorp/armeria/client/UserClient.java +++ b/core/src/main/java/com/linecorp/armeria/client/UserClient.java @@ -32,12 +32,12 @@ import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.Request; import com.linecorp.armeria.common.RequestId; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.Response; import com.linecorp.armeria.common.RpcRequest; import com.linecorp.armeria.common.RpcResponse; import com.linecorp.armeria.common.Scheme; import com.linecorp.armeria.common.SessionProtocol; -import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.AbstractUnwrappable; import com.linecorp.armeria.common.util.SystemInfo; import com.linecorp.armeria.internal.client.DefaultClientRequestContext; @@ -123,14 +123,11 @@ public final ClientOptions options() { * * @param protocol the {@link SessionProtocol} to use * @param method the method of the {@link Request} - * @param path the path part of the {@link Request} URI - * @param query the query part of the {@link Request} URI - * @param fragment the fragment part of the {@link Request} URI + * @param reqTarget the {@link RequestTarget} of the {@link Request} * @param req the {@link Request} */ - protected final O execute(SessionProtocol protocol, HttpMethod method, String path, - @Nullable String query, @Nullable String fragment, I req) { - return execute(protocol, method, path, query, fragment, req, RequestOptions.of()); + protected final O execute(SessionProtocol protocol, HttpMethod method, RequestTarget reqTarget, I req) { + return execute(protocol, method, reqTarget, req, RequestOptions.of()); } /** @@ -138,16 +135,13 @@ protected final O execute(SessionProtocol protocol, HttpMethod method, String pa * * @param protocol the {@link SessionProtocol} to use * @param method the method of the {@link Request} - * @param path the path part of the {@link Request} URI - * @param query the query part of the {@link Request} URI - * @param fragment the fragment part of the {@link Request} URI + * @param reqTarget the {@link RequestTarget} of the {@link Request} * @param req the {@link Request} * @param requestOptions the {@link RequestOptions} of the {@link Request} */ - protected final O execute(SessionProtocol protocol, HttpMethod method, String path, - @Nullable String query, @Nullable String fragment, I req, - RequestOptions requestOptions) { - return execute(protocol, endpointGroup(), method, path, query, fragment, req, requestOptions); + protected final O execute(SessionProtocol protocol, HttpMethod method, RequestTarget reqTarget, + I req, RequestOptions requestOptions) { + return execute(protocol, endpointGroup(), method, reqTarget, req, requestOptions); } /** @@ -156,14 +150,12 @@ protected final O execute(SessionProtocol protocol, HttpMethod method, String pa * @param protocol the {@link SessionProtocol} to use * @param endpointGroup the {@link EndpointGroup} of the {@link Request} * @param method the method of the {@link Request} - * @param path the path part of the {@link Request} URI - * @param query the query part of the {@link Request} URI - * @param fragment the fragment part of the {@link Request} URI + * @param reqTarget the {@link RequestTarget} of the {@link Request} * @param req the {@link Request} */ protected final O execute(SessionProtocol protocol, EndpointGroup endpointGroup, HttpMethod method, - String path, @Nullable String query, @Nullable String fragment, I req) { - return execute(protocol, endpointGroup, method, path, query, fragment, req, RequestOptions.of()); + RequestTarget reqTarget, I req) { + return execute(protocol, endpointGroup, method, reqTarget, req, RequestOptions.of()); } /** @@ -172,15 +164,12 @@ protected final O execute(SessionProtocol protocol, EndpointGroup endpointGroup, * @param protocol the {@link SessionProtocol} to use * @param endpointGroup the {@link EndpointGroup} of the {@link Request} * @param method the method of the {@link Request} - * @param path the path part of the {@link Request} URI - * @param query the query part of the {@link Request} URI - * @param fragment the fragment part of the {@link Request} URI + * @param reqTarget the {@link RequestTarget} of the {@link Request} * @param req the {@link Request} * @param requestOptions the {@link RequestOptions} of the {@link Request} */ protected final O execute(SessionProtocol protocol, EndpointGroup endpointGroup, HttpMethod method, - String path, @Nullable String query, @Nullable String fragment, I req, - RequestOptions requestOptions) { + RequestTarget reqTarget, I req, RequestOptions requestOptions) { final HttpRequest httpReq; final RpcRequest rpcReq; @@ -195,7 +184,7 @@ protected final O execute(SessionProtocol protocol, EndpointGroup endpointGroup, } final DefaultClientRequestContext ctx = new DefaultClientRequestContext( - meterRegistry, protocol, id, method, path, query, fragment, options(), httpReq, rpcReq, + meterRegistry, protocol, id, method, reqTarget, options(), httpReq, rpcReq, requestOptions, System.nanoTime(), SystemInfo.currentTimeMicros()); return initContextAndExecuteWithFallback(unwrap(), ctx, endpointGroup, diff --git a/core/src/main/java/com/linecorp/armeria/client/logging/ContentPreviewingClient.java b/core/src/main/java/com/linecorp/armeria/client/logging/ContentPreviewingClient.java index dd53c48ea3f..faec01dc62b 100644 --- a/core/src/main/java/com/linecorp/armeria/client/logging/ContentPreviewingClient.java +++ b/core/src/main/java/com/linecorp/armeria/client/logging/ContentPreviewingClient.java @@ -151,6 +151,7 @@ public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Ex final ContentPreviewer requestContentPreviewer = contentPreviewerFactory.requestContentPreviewer(ctx, req.headers()); req = setUpRequestContentPreviewer(ctx, req, requestContentPreviewer, requestPreviewSanitizer); + ctx.updateRequest(req); } else { // Set empty String. ctx.logBuilder().requestContentPreview(""); diff --git a/core/src/main/java/com/linecorp/armeria/common/AbstractRequestContextBuilder.java b/core/src/main/java/com/linecorp/armeria/common/AbstractRequestContextBuilder.java index 1286ed60605..2e6343dd545 100644 --- a/core/src/main/java/com/linecorp/armeria/common/AbstractRequestContextBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/common/AbstractRequestContextBuilder.java @@ -31,7 +31,7 @@ import com.linecorp.armeria.client.ClientRequestContextBuilder; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.metric.NoopMeterRegistry; -import com.linecorp.armeria.internal.common.PathAndQuery; +import com.linecorp.armeria.internal.common.DefaultRequestTarget; import com.linecorp.armeria.server.Service; import com.linecorp.armeria.server.ServiceRequestContextBuilder; @@ -71,9 +71,7 @@ public abstract class AbstractRequestContextBuilder { private RequestId id; private HttpMethod method; private final String authority; - private final String path; - @Nullable - private final String query; + private final RequestTarget reqTarget; private MeterRegistry meterRegistry = NoopMeterRegistry.get(); @Nullable @@ -99,19 +97,31 @@ public abstract class AbstractRequestContextBuilder { * @param req the {@link HttpRequest}. */ protected AbstractRequestContextBuilder(boolean server, HttpRequest req) { + requireNonNull(req, "req"); this.server = server; - this.req = requireNonNull(req, "req"); rpcReq = null; sessionProtocol = SessionProtocol.H2C; method = req.headers().method(); authority = firstNonNull(req.headers().authority(), FALLBACK_AUTHORITY); - final String pathAndQueryStr = req.headers().path(); - final PathAndQuery pathAndQuery = PathAndQuery.parse(pathAndQueryStr); - checkArgument(pathAndQuery != null, "request.path is not valid: %s", req); - path = pathAndQuery.path(); - query = pathAndQuery.query(); + final String rawPath = req.headers().path(); + final RequestTarget reqTarget = server ? RequestTarget.forServer(rawPath) + : RequestTarget.forClient(rawPath); + checkArgument(reqTarget != null, "request.path is not valid: %s", rawPath); + checkArgument(reqTarget.form() != RequestTargetForm.ABSOLUTE, + "request.path must not contain scheme or authority: %s", rawPath); + + final String newRawPath = reqTarget.pathAndQuery(); + if (newRawPath.equals(rawPath)) { + this.req = req; + } else { + this.req = req.withHeaders(req.headers() + .toBuilder() + .path(newRawPath)); + } + + this.reqTarget = reqTarget; } /** @@ -131,15 +141,15 @@ protected AbstractRequestContextBuilder(boolean server, RpcRequest rpcReq, URI u authority = firstNonNull(uri.getRawAuthority(), FALLBACK_AUTHORITY); sessionProtocol = getSessionProtocol(uri); - final PathAndQuery pathAndQuery; - if (uri.getRawQuery() != null) { - pathAndQuery = PathAndQuery.parse(uri.getRawPath() + '?' + uri.getRawQuery()); + if (server) { + reqTarget = DefaultRequestTarget.createWithoutValidation( + RequestTargetForm.ORIGIN, null, null, + uri.getRawPath(), uri.getRawQuery(), null); } else { - pathAndQuery = PathAndQuery.parse(uri.getRawPath()); + reqTarget = DefaultRequestTarget.createWithoutValidation( + RequestTargetForm.ORIGIN, null, null, + uri.getRawPath(), uri.getRawQuery(), uri.getRawFragment()); } - checkArgument(pathAndQuery != null, "uri.path or uri.query is not valid: %s", uri); - path = pathAndQuery.path(); - query = pathAndQuery.query(); } private static SessionProtocol getSessionProtocol(URI uri) { @@ -443,10 +453,10 @@ protected final String authority() { } /** - * Returns the path of the request, excluding the query part. + * Returns the {@link RequestTarget}. */ - protected final String path() { - return path; + protected final RequestTarget requestTarget() { + return reqTarget; } /** @@ -468,16 +478,6 @@ protected final RequestId id() { return id; } - /** - * Returns the query part of the request, excluding the leading question mark ({@code '?'}). - * - * @return the query string, or {@code null} if there is no query. - */ - @Nullable - protected final String query() { - return query; - } - /** * Returns a fake {@link Channel} which is required internally when creating a context. */ diff --git a/core/src/main/java/com/linecorp/armeria/common/RequestTarget.java b/core/src/main/java/com/linecorp/armeria/common/RequestTarget.java new file mode 100644 index 00000000000..3dbd9d1be7b --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/RequestTarget.java @@ -0,0 +1,150 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.common; + +import static java.util.Objects.requireNonNull; + +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.internal.common.DefaultRequestTarget; +import com.linecorp.armeria.internal.common.util.TemporaryThreadLocals; + +/** + * An HTTP request target, as defined in + * Section 3.2, RFC 9112. + * + *

Note: This interface doesn't support the + * authority form. + */ +@UnstableApi +public interface RequestTarget { + + /** + * Returns a {@link RequestTarget} parsed and normalized from the specified request target string + * in the context of server-side application. It rejects an absolute or authority form request target. + * It also normalizes {@code '#'} into {@code '%2A'} instead of parsing a fragment. + * Use {@link #forClient(String)} if you want to parse an absolute form request target or a fragment. + * + * @param reqTarget the request target string + * @return a {@link RequestTarget} if parsed and normalized successfully, or {@code null} otherwise. + */ + @Nullable + static RequestTarget forServer(String reqTarget) { + requireNonNull(reqTarget, "reqTarget"); + return DefaultRequestTarget.forServer(reqTarget, Flags.allowDoubleDotsInQueryString()); + } + + /** + * Returns a {@link RequestTarget} parsed and normalized from the specified request target string + * in the context of client-side application. It rejects an authority form request target. + * + * @param reqTarget the request target string + * @return a {@link RequestTarget} if parsed and normalized successfully, or {@code null} otherwise. + * @see #forServer(String) + */ + @Nullable + static RequestTarget forClient(String reqTarget) { + return forClient(reqTarget, null); + } + + /** + * Returns a {@link RequestTarget} parsed and normalized from the specified request target string + * in the context of client-side application. It rejects an authority form request target. + * + * @param reqTarget the request target string + * @param prefix the prefix to add to {@code reqTarget}. No prefix is added if {@code null} or empty. + * @return a {@link RequestTarget} if parsed and normalized successfully, or {@code null} otherwise. + * @see #forServer(String) + */ + @Nullable + static RequestTarget forClient(String reqTarget, @Nullable String prefix) { + return DefaultRequestTarget.forClient(reqTarget, prefix); + } + + /** + * Returns the form of this {@link RequestTarget}. + */ + RequestTargetForm form(); + + /** + * Returns the scheme of this {@link RequestTarget}. + * + * @return a non-empty string if {@link #form()} is {@link RequestTargetForm#ABSOLUTE}. + * {@code null} otherwise. + */ + @Nullable + String scheme(); + + /** + * Returns the authority of this {@link RequestTarget}. + * + * @return a non-empty string if {@link #form()} is {@link RequestTargetForm#ABSOLUTE}. + * {@code null} otherwise. + */ + @Nullable + String authority(); + + /** + * Returns the path of this {@link RequestTarget}, which always starts with {@code '/'}. + */ + String path(); + + /** + * Returns the query of this {@link RequestTarget}. + */ + @Nullable + String query(); + + /** + * Returns the string that combines {@link #path()} and {@link #query()}. + * + * @return {@link #path()} + '?' + {@link #query()} if {@link #query()} is non-{@code null}. + * {@link #path()} if {@link #query()} is {@code null}. + */ + default String pathAndQuery() { + if (query() == null) { + return path(); + } + + try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) { + return tmp.stringBuilder() + .append(path()) + .append('?') + .append(query()) + .toString(); + } + } + + /** + * Returns the fragment of this {@link RequestTarget}. + */ + @Nullable + String fragment(); + + /** + * Returns the string representation of this {@link RequestTarget}. + * + * @return One of the following:

    + *
  • An absolute URI if {@link #form()} is {@link RequestTargetForm#ABSOLUTE}, e.g. + * {@code "https://example.com/foo?bar#baz}
  • + *
  • Path with query and fragment if {@link #form()} is {@link RequestTargetForm#ORIGIN}, e.g. + * {@code "/foo?bar#baz"}
  • + *
  • {@code "*"} if {@link #form()} is {@link RequestTargetForm#ASTERISK}
  • + *
+ */ + @Override + String toString(); +} diff --git a/core/src/main/java/com/linecorp/armeria/common/RequestTargetForm.java b/core/src/main/java/com/linecorp/armeria/common/RequestTargetForm.java new file mode 100644 index 00000000000..2fb2d177dd6 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/RequestTargetForm.java @@ -0,0 +1,41 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.common; + +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * {@link RequestTarget} form, as defined in + * Section 3.2, RFC 9112. + * + *

Note: This enum doesn't support the + * authority form. + */ +@UnstableApi +public enum RequestTargetForm { + /** + * An absolute path followed by a query and a fragment. + */ + ORIGIN, + /** + * An absolute URI that has scheme, authority and absolute path followed by a query and a fragment. + */ + ABSOLUTE, + /** + * {@code "*"}, used for a server-side {@code OPTIONS} request. + */ + ASTERISK +} diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/ClientUtil.java b/core/src/main/java/com/linecorp/armeria/internal/client/ClientUtil.java index f9dd9156e0c..661fd471326 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/ClientUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/ClientUtil.java @@ -18,13 +18,10 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static java.util.Objects.requireNonNull; -import java.net.URI; import java.util.concurrent.CompletableFuture; import java.util.function.BiFunction; import java.util.function.Function; -import com.google.common.base.Strings; - import com.linecorp.armeria.client.Client; import com.linecorp.armeria.client.ClientRequestContext; import com.linecorp.armeria.client.Endpoint; @@ -231,18 +228,5 @@ public static ClientRequestContext newDerivedContext(ClientRequestContext ctx, return derived; } - public static String pathWithQuery(URI uri, @Nullable String query) { - return pathWithQuery(uri.getRawPath(), query); - } - - public static String pathWithQuery(String path, @Nullable String query) { - if (Strings.isNullOrEmpty(path)) { - path = query == null ? "/" : "/?" + query; - } else if (query != null) { - path = path + '?' + query; - } - return path; - } - private ClientUtil() {} } diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java index e7c6e5b9ef7..356819f9b97 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java @@ -18,8 +18,6 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static com.linecorp.armeria.internal.client.ClientUtil.pathWithQuery; -import static com.linecorp.armeria.internal.common.ArmeriaHttpUtil.isAbsoluteUri; import static com.linecorp.armeria.internal.common.HttpHeadersUtil.getScheme; import static java.util.Objects.requireNonNull; @@ -57,6 +55,8 @@ import com.linecorp.armeria.common.RequestContext; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.RequestId; +import com.linecorp.armeria.common.RequestTarget; +import com.linecorp.armeria.common.RequestTargetForm; import com.linecorp.armeria.common.Response; import com.linecorp.armeria.common.RpcRequest; import com.linecorp.armeria.common.Scheme; @@ -72,7 +72,6 @@ import com.linecorp.armeria.common.util.UnmodifiableFuture; import com.linecorp.armeria.internal.common.CancellationScheduler; import com.linecorp.armeria.internal.common.NonWrappingRequestContext; -import com.linecorp.armeria.internal.common.PathAndQuery; import com.linecorp.armeria.internal.common.RequestContextExtension; import com.linecorp.armeria.internal.common.util.TemporaryThreadLocals; import com.linecorp.armeria.server.ServiceRequestContext; @@ -115,8 +114,6 @@ public final class DefaultClientRequestContext @Nullable private ContextAwareEventLoop contextAwareEventLoop; @Nullable - private final String fragment; - @Nullable private final ServiceRequestContext root; private final ClientOptions options; @@ -161,12 +158,12 @@ public final class DefaultClientRequestContext */ public DefaultClientRequestContext( EventLoop eventLoop, MeterRegistry meterRegistry, SessionProtocol sessionProtocol, - RequestId id, HttpMethod method, String path, @Nullable String query, @Nullable String fragment, + RequestId id, HttpMethod method, RequestTarget reqTarget, ClientOptions options, @Nullable HttpRequest req, @Nullable RpcRequest rpcReq, RequestOptions requestOptions, CancellationScheduler responseCancellationScheduler, long requestStartTimeNanos, long requestStartTimeMicros) { this(eventLoop, meterRegistry, sessionProtocol, - id, method, path, query, fragment, options, req, rpcReq, requestOptions, serviceRequestContext(), + id, method, reqTarget, options, req, rpcReq, requestOptions, serviceRequestContext(), responseCancellationScheduler, requestStartTimeNanos, requestStartTimeMicros); } @@ -185,30 +182,29 @@ id, method, path, query, fragment, options, req, rpcReq, requestOptions, service */ public DefaultClientRequestContext( MeterRegistry meterRegistry, SessionProtocol sessionProtocol, - RequestId id, HttpMethod method, String path, @Nullable String query, @Nullable String fragment, + RequestId id, HttpMethod method, RequestTarget reqTarget, ClientOptions options, @Nullable HttpRequest req, @Nullable RpcRequest rpcReq, RequestOptions requestOptions, long requestStartTimeNanos, long requestStartTimeMicros) { this(null, meterRegistry, sessionProtocol, - id, method, path, query, fragment, options, req, rpcReq, requestOptions, + id, method, reqTarget, options, req, rpcReq, requestOptions, serviceRequestContext(), /* responseCancellationScheduler */ null, requestStartTimeNanos, requestStartTimeMicros); } private DefaultClientRequestContext( @Nullable EventLoop eventLoop, MeterRegistry meterRegistry, - SessionProtocol sessionProtocol, RequestId id, HttpMethod method, String path, - @Nullable String query, @Nullable String fragment, ClientOptions options, + SessionProtocol sessionProtocol, RequestId id, HttpMethod method, + RequestTarget reqTarget, ClientOptions options, @Nullable HttpRequest req, @Nullable RpcRequest rpcReq, RequestOptions requestOptions, @Nullable ServiceRequestContext root, @Nullable CancellationScheduler responseCancellationScheduler, long requestStartTimeNanos, long requestStartTimeMicros) { - super(meterRegistry, sessionProtocol, id, method, path, query, + super(meterRegistry, sessionProtocol, id, method, reqTarget, firstNonNull(requestOptions.exchangeType(), ExchangeType.BIDI_STREAMING), req, rpcReq, getAttributes(root)); this.eventLoop = eventLoop; this.options = requireNonNull(options, "options"); - this.fragment = fragment; this.root = root; log = RequestLog.builder(this); @@ -460,8 +456,8 @@ private DefaultClientRequestContext(DefaultClientRequestContext ctx, @Nullable RpcRequest rpcReq, @Nullable Endpoint endpoint, @Nullable EndpointGroup endpointGroup, SessionProtocol sessionProtocol, HttpMethod method, - String path, @Nullable String query, @Nullable String fragment) { - super(ctx.meterRegistry(), sessionProtocol, id, method, path, query, ctx.exchangeType(), + RequestTarget reqTarget) { + super(ctx.meterRegistry(), sessionProtocol, id, method, reqTarget, ctx.exchangeType(), req, rpcReq, getAttributes(ctx.root())); // The new requests cannot be null if it was previously non-null. @@ -475,7 +471,6 @@ private DefaultClientRequestContext(DefaultClientRequestContext ctx, eventLoop = ctx.eventLoop().withoutContext(); options = ctx.options(); - this.fragment = fragment; root = ctx.root(); log = RequestLog.builder(this); @@ -533,81 +528,44 @@ public ClientRequestContext newDerivedContext(RequestId id, @Nullable Endpoint endpoint) { if (req != null) { final RequestHeaders newHeaders = req.headers(); + final String oldPath = requestTarget().pathAndQuery(); final String newPath = newHeaders.path(); - if (!path().equals(newPath)) { + if (!oldPath.equals(newPath)) { // path is changed. + final RequestTarget reqTarget = RequestTarget.forClient(newPath); + checkArgument(reqTarget != null, "invalid path: %s", newPath); - if (!isAbsoluteUri(newPath)) { - return newDerivedContext(id, req, rpcReq, newHeaders, sessionProtocol(), endpoint, newPath); + if (reqTarget.form() != RequestTargetForm.ABSOLUTE) { + // Not an absolute URI. + return new DefaultClientRequestContext(this, id, req, rpcReq, endpoint, null, + sessionProtocol(), newHeaders.method(), reqTarget); } - final URI uri = URI.create(req.path()); - final Scheme scheme = Scheme.parse(uri.getScheme()); - final SessionProtocol protocol = scheme.sessionProtocol(); - final Endpoint newEndpoint = Endpoint.parse(uri.getAuthority()); - final String rawQuery = uri.getRawQuery(); - final String pathWithQuery = pathWithQuery(uri, rawQuery); - final HttpRequest newReq = req.withHeaders(req.headers().toBuilder().path(pathWithQuery)); - return newDerivedContext(id, newReq, rpcReq, newHeaders, protocol, - newEndpoint, pathWithQuery); + + // Recalculate protocol and endpoint from the absolute URI. + final String scheme = reqTarget.scheme(); + final String authority = reqTarget.authority(); + assert scheme != null; + assert authority != null; + + final SessionProtocol protocol = Scheme.parse(scheme).sessionProtocol(); + final Endpoint newEndpoint = Endpoint.parse(authority); + final HttpRequest newReq = req.withHeaders(req.headers() + .toBuilder() + .path(reqTarget.pathAndQuery())); + return new DefaultClientRequestContext(this, id, newReq, rpcReq, newEndpoint, null, + protocol, newHeaders.method(), reqTarget); } } return new DefaultClientRequestContext(this, id, req, rpcReq, endpoint, endpointGroup(), - sessionProtocol(), method(), path(), query(), fragment()); - } - - private ClientRequestContext newDerivedContext(RequestId id, HttpRequest req, @Nullable RpcRequest rpcReq, - RequestHeaders newHeaders, SessionProtocol protocol, - @Nullable Endpoint endpoint, String pathWithQuery) { - final PathAndQuery pathAndQuery = PathAndQuery.parse(pathWithQuery); - if (pathAndQuery == null) { - throw new IllegalArgumentException("invalid path: " + req.path()); - } - return new DefaultClientRequestContext(this, id, req, rpcReq, endpoint, null, - protocol, newHeaders.method(), pathAndQuery.path(), - pathAndQuery.query(), null); + sessionProtocol(), method(), requestTarget()); } @Override - protected void validateHeaders(RequestHeaders headers) { + protected RequestTarget validateHeaders(RequestHeaders headers) { // no need to validate since internal headers will contain // the default host and session protocol headers set by endpoints. - } - - @Override - protected void unsafeUpdateRequest(HttpRequest req) { - final PathAndQuery pathAndQuery; - final SessionProtocol sessionProtocol; - final String authority; - if (isAbsoluteUri(req.path())) { - final URI uri = URI.create(req.path()); - checkArgument(uri.getScheme() != null, "missing scheme"); - checkArgument(uri.getAuthority() != null, "missing authority"); - checkArgument(!uri.getAuthority().isEmpty(), "empty authority"); - final String rawQuery = uri.getRawQuery(); - final String pathWithQuery = pathWithQuery(uri, rawQuery); - pathAndQuery = PathAndQuery.parse(pathWithQuery); - sessionProtocol = Scheme.parse(uri.getScheme()).sessionProtocol(); - authority = uri.getAuthority(); - } else { - pathAndQuery = PathAndQuery.parse(req.path()); - sessionProtocol = null; - authority = null; - } - if (pathAndQuery == null) { - throw new IllegalArgumentException("invalid path: " + req.path()); - } - - // all validation is complete at this point - super.unsafeUpdateRequest(req); - path(pathAndQuery.path()); - query(pathAndQuery.query()); - if (sessionProtocol != null) { - sessionProtocol(sessionProtocol); - } - if (authority != null) { - updateEndpoint(Endpoint.parse(authority)); - } + return RequestTarget.forClient(headers.path()); } @Override @@ -663,7 +621,7 @@ public Endpoint endpoint() { @Override @Nullable public String fragment() { - return fragment; + return requestTarget().fragment(); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/ArmeriaHttpUtil.java b/core/src/main/java/com/linecorp/armeria/internal/common/ArmeriaHttpUtil.java index 0ac89afe50a..2842c0bc2b6 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/ArmeriaHttpUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/ArmeriaHttpUtil.java @@ -30,6 +30,8 @@ */ package com.linecorp.armeria.internal.common; +import static com.google.common.base.MoreObjects.firstNonNull; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.netty.util.AsciiString.EMPTY_STRING; import static io.netty.util.ByteProcessor.FIND_COMMA; @@ -47,9 +49,6 @@ import java.util.StringJoiner; import java.util.function.BiConsumer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.github.benmanes.caffeine.cache.Caffeine; import com.github.benmanes.caffeine.cache.LoadingCache; import com.google.common.annotations.VisibleForTesting; @@ -69,6 +68,7 @@ import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.RequestHeadersBuilder; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.ResponseHeadersBuilder; import com.linecorp.armeria.common.annotation.Nullable; @@ -102,8 +102,6 @@ public final class ArmeriaHttpUtil { // Forked from Netty 4.1.34 at 4921f62c8ab8205fd222439dcd1811760b05daf1 - private static final Logger logger = LoggerFactory.getLogger(ArmeriaHttpUtil.class); - /** * The default case-insensitive {@link AsciiString} hasher and comparator for HTTP/2 headers. */ @@ -236,12 +234,6 @@ public boolean equals(AsciiString a, AsciiString b) { HttpHeaderNames.HOST); } - /** - * rfc7540, 8.1.2.3 - * states the path must not be empty, and instead should be {@code /}. - */ - private static final String EMPTY_REQUEST_PATH = "/"; - private static final Splitter COOKIE_SPLITTER = Splitter.on(';').trimResults().omitEmptyStrings(); private static final String COOKIE_SEPARATOR = "; "; private static final Joiner COOKIE_JOINER = Joiner.on(COOKIE_SEPARATOR); @@ -257,51 +249,72 @@ private static LoadingCache buildCache(String spec) { } /** - * Concatenates two path strings. + * Concatenates the specified {@code prefix} and {@code path} into an absolute path. + * + * @throws IllegalArgumentException if {@code prefix} is not an absolute path prefix */ - public static String concatPaths(@Nullable String path1, @Nullable String path2) { - path2 = path2 == null ? "" : path2; - - if (path1 == null || path1.isEmpty() || EMPTY_REQUEST_PATH.equals(path1)) { - if (path2.isEmpty()) { - return EMPTY_REQUEST_PATH; - } + public static String concatPaths(String prefix, @Nullable String path) { + requireNonNull(prefix, "prefix"); + checkArgument(!prefix.isEmpty() && prefix.charAt(0) == '/', + "prefix: %s (expected: an absolute path starting with '/')", prefix); + + path = firstNonNull(path, ""); + if (path.isEmpty()) { + return prefix; + } - if (path2.charAt(0) == '/') { - return path2; // Most requests will land here. + if (prefix.length() == 1) { // means "/".equals(prefix) + if (path.charAt(0) == '/') { + return path; // Most requests will land here. } - - return '/' + path2; + return simpleConcat("/", path); } - // At this point, we are sure path1 is neither empty nor null. - if (path2.isEmpty()) { - // Only path1 is non-empty. No need to concatenate. - return path1; - } + return slowConcatPaths(prefix, path); + } - if (path1.charAt(path1.length() - 1) == '/') { - if (path2.charAt(0) == '/') { - // path1 ends with '/' and path2 starts with '/'. - // Avoid double-slash by stripping the first slash of path2. - return new StringBuilder(path1.length() + path2.length() - 1) - .append(path1).append(path2, 1, path2.length()).toString(); + private static String slowConcatPaths(String prefix, String path) { + if (prefix.charAt(prefix.length() - 1) == '/') { + if (path.charAt(0) == '/') { + // `prefix` ends with '/' and `path` starts with '/'. + // Avoid double-slash by stripping the first slash of `path`. + try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) { + return tmp.stringBuilder() + .append(prefix) + .append(path, 1, path.length()) + .toString(); + } } - // path1 ends with '/' and path2 does not start with '/'. + // `prefix` ends with '/' and `path` does not start with '/'. // Simple concatenation would suffice. - return path1 + path2; + return simpleConcat(prefix, path); } - if (path2.charAt(0) == '/' || path2.charAt(0) == '?') { - // path1 does not end with '/' and path2 starts with '/' or '?' + if (path.charAt(0) == '/' || path.charAt(0) == '?') { + // `prefix` does not end with '/' and `path` starts with '/' or '?' // Simple concatenation would suffice. - return path1 + path2; + return simpleConcat(prefix, path); + } + + // `prefix` does not end with '/' and `path` does not start with '/' or '?'. + // Need to insert '/' in-between. + try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) { + return tmp.stringBuilder() + .append(prefix) + .append('/') + .append(path) + .toString(); } + } - // path1 does not end with '/' and path2 does not start with '/' or '?'. - // Need to insert '/' between path1 and path2. - return path1 + '/' + path2; + private static String simpleConcat(String prefix, String path) { + try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) { + return tmp.stringBuilder() + .append(prefix) + .append(path) + .toString(); + } } /** @@ -557,7 +570,7 @@ public static long parseDirectiveValueAsSeconds(@Nullable String value) { public static RequestHeaders toArmeriaRequestHeaders(ChannelHandlerContext ctx, Http2Headers headers, boolean endOfStream, String scheme, ServerConfig cfg, - @Nullable PathAndQuery pathAndQuery) { + RequestTarget reqTarget) { assert headers instanceof ArmeriaHttp2Headers; final HttpHeadersBuilder builder = ((ArmeriaHttp2Headers) headers).delegate(); builder.endOfStream(endOfStream); @@ -565,15 +578,12 @@ public static RequestHeaders toArmeriaRequestHeaders(ChannelHandlerContext ctx, if (!builder.contains(HttpHeaderNames.SCHEME)) { builder.add(HttpHeaderNames.SCHEME, scheme); } - // if pathAndQuery == null, then either the path is invalid or *, and will be handled later. - if (pathAndQuery != null) { - builder.set(HttpHeaderNames.PATH, pathAndQuery.toString()); - } if (builder.get(HttpHeaderNames.AUTHORITY) == null && builder.get(HttpHeaderNames.HOST) == null) { final String defaultHostname = cfg.defaultVirtualHost().defaultHostname(); final int port = ((InetSocketAddress) ctx.channel().localAddress()).getPort(); builder.add(HttpHeaderNames.AUTHORITY, defaultHostname + ':' + port); } + builder.set(HttpHeaderNames.PATH, reqTarget.toString()); final List cookies = builder.getAll(HttpHeaderNames.COOKIE); if (cookies.size() > 1) { // Cookies must be concatenated into a single octet string. @@ -616,19 +626,14 @@ public static HttpHeaders toArmeria(Http2Headers http2Headers, boolean request, */ public static RequestHeaders toArmeria( ChannelHandlerContext ctx, HttpRequest in, - ServerConfig cfg, String scheme, @Nullable PathAndQuery pathAndQuery) throws URISyntaxException { + ServerConfig cfg, String scheme, RequestTarget reqTarget) throws URISyntaxException { final io.netty.handler.codec.http.HttpHeaders inHeaders = in.headers(); final RequestHeadersBuilder out = RequestHeaders.builder(); out.sizeHint(inHeaders.size()); - out.method(HttpMethod.valueOf(in.method().name())) - .scheme(scheme); - // if pathAndQuery == null, then either the path is invalid or *, and will be handled later. - if (pathAndQuery == null) { - out.path(in.uri()); - } else { - out.path(pathAndQuery.toString()); - } + out.method(firstNonNull(HttpMethod.tryParse(in.method().name()), HttpMethod.UNKNOWN)) + .scheme(scheme) + .path(reqTarget.toString()); // Add the HTTP headers which have not been consumed above toArmeria(inHeaders, out); diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultRequestTarget.java b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultRequestTarget.java new file mode 100644 index 00000000000..2fb0eaa0f08 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultRequestTarget.java @@ -0,0 +1,964 @@ +/* + * Copyright 2017 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.internal.common; + +import static io.netty.util.internal.StringUtil.decodeHexNibble; +import static java.util.Objects.requireNonNull; + +import java.net.URI; +import java.util.BitSet; +import java.util.Objects; + +import com.linecorp.armeria.common.RequestTarget; +import com.linecorp.armeria.common.RequestTargetForm; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.internal.common.util.TemporaryThreadLocals; + +import it.unimi.dsi.fastutil.Arrays; +import it.unimi.dsi.fastutil.bytes.ByteArrays; + +public final class DefaultRequestTarget implements RequestTarget { + + private static final String ALLOWED_COMMON_CHARS = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~:/?@!$&'()*,;="; + + /** + * The lookup table for the characters allowed in a path. + */ + private static final BitSet PATH_ALLOWED = toBitSet(ALLOWED_COMMON_CHARS + '+'); + + /** + * The lookup table for the characters allowed in a query. + */ + private static final BitSet QUERY_ALLOWED = toBitSet(ALLOWED_COMMON_CHARS + "[]"); + + /** + * The lookup table for the characters allowed in a fragment. + */ + private static final BitSet FRAGMENT_ALLOWED = PATH_ALLOWED; + + /** + * The lookup table for the characters that whose percent encoding must be preserved + * when used in a path because whether they are percent-encoded or not affects + * their semantics. We do not normalize '%2F' and '%2f' in the path to '/' for compatibility with + * other implementations in the ecosystem, e.g. HTTP/JSON to gRPC transcoding. See + * http.proto. + */ + private static final BitSet PATH_MUST_PRESERVE_ENCODING = toBitSet("/"); + + /** + * The lookup table for the characters that whose percent encoding must be preserved + * when used in a query because whether they are percent-encoded or not affects + * their semantics. For example, 'A%3dB=1' should NOT be normalized into 'A=B=1' because + * 'A=B=1` means 'A' is 'B=1' but 'A%3dB=1' means 'A=B' is '1'. + */ + private static final BitSet QUERY_MUST_PRESERVE_ENCODING = toBitSet(":/?[]@!$&'()*+,;="); + + /** + * The lookup table for the characters that whose percent encoding must be preserved when used + * in a fragment. We currently use the same table with {@link #PATH_MUST_PRESERVE_ENCODING}. + */ + private static final BitSet FRAGMENT_MUST_PRESERVE_ENCODING = PATH_MUST_PRESERVE_ENCODING; + + private static BitSet toBitSet(String chars) { + final BitSet bitSet = new BitSet(); + for (int i = 0; i < chars.length(); i++) { + bitSet.set(chars.charAt(i)); + } + return bitSet; + } + + private enum ComponentType { + CLIENT_PATH(PATH_ALLOWED, PATH_MUST_PRESERVE_ENCODING), + SERVER_PATH(PATH_ALLOWED, PATH_MUST_PRESERVE_ENCODING), + QUERY(QUERY_ALLOWED, QUERY_MUST_PRESERVE_ENCODING), + FRAGMENT(FRAGMENT_ALLOWED, FRAGMENT_MUST_PRESERVE_ENCODING); + + private final BitSet allowed; + private final BitSet mustPreserveEncoding; + + ComponentType(BitSet allowed, BitSet mustPreserveEncoding) { + this.allowed = allowed; + this.mustPreserveEncoding = mustPreserveEncoding; + } + + boolean isAllowed(int cp) { + return allowed.get(cp); + } + + boolean mustPreserveEncoding(int cp) { + return mustPreserveEncoding.get(cp); + } + } + + /** + * The table that converts a byte into a percent-encoded chars, e.g. 'A' -> "%41". + */ + private static final char[][] TO_PERCENT_ENCODED_CHARS = new char[256][]; + + static { + for (int i = 0; i < TO_PERCENT_ENCODED_CHARS.length; i++) { + TO_PERCENT_ENCODED_CHARS[i] = String.format("%%%02X", i).toCharArray(); + } + } + + private static final Bytes EMPTY_BYTES = new Bytes(0); + private static final Bytes SLASH_BYTES = new Bytes(new byte[] { '/' }); + + private static final RequestTarget INSTANCE_ASTERISK = createWithoutValidation( + RequestTargetForm.ASTERISK, + null, + null, + "*", + null, + null); + + /** + * The main implementation of {@link RequestTarget#forServer(String)}. + */ + @Nullable + public static RequestTarget forServer(String reqTarget, boolean allowDoubleDotsInQueryString) { + final RequestTarget cached = RequestTargetCache.getForServer(reqTarget); + if (cached != null) { + return cached; + } + + return slowForServer(reqTarget, allowDoubleDotsInQueryString); + } + + /** + * The main implementation of {@link RequestTarget#forClient(String, String)}. + */ + @Nullable + public static RequestTarget forClient(String reqTarget, @Nullable String prefix) { + requireNonNull(reqTarget, "reqTarget"); + + final int authorityPos = findAuthority(reqTarget); + if (authorityPos >= 0) { + // Note: For an absolute URI, we don't use `prefix` at all, + // so we can just use `reqTarget` in verbatim as a cache key. + final RequestTarget cached = RequestTargetCache.getForClient(reqTarget); + if (cached != null) { + return cached; + } + + // reqTarget is an absolute URI with scheme and authority. + return slowAbsoluteFormForClient(reqTarget, authorityPos); + } + + // Concatenate `prefix` and `reqTarget` if necessary. + final String actualReqTarget; + if (prefix == null || "*".equals(reqTarget)) { + // No prefix was given or request target is `*`. + actualReqTarget = reqTarget; + } else { + actualReqTarget = ArmeriaHttpUtil.concatPaths(prefix, reqTarget); + } + + final RequestTarget cached = RequestTargetCache.getForClient(actualReqTarget); + if (cached != null) { + return cached; + } + + // reqTarget is not an absolute URI; split by the first '?'. + return slowForClient(actualReqTarget, null, 0); + } + + /** + * (Advanced users only) Returns a newly created {@link RequestTarget} filled with the specified + * properties without any validation. + */ + public static RequestTarget createWithoutValidation( + RequestTargetForm form, @Nullable String scheme, @Nullable String authority, + String path, @Nullable String query, @Nullable String fragment) { + return new DefaultRequestTarget(form, scheme, authority, path, query, fragment); + } + + private final RequestTargetForm form; + @Nullable + private final String scheme; + @Nullable + private final String authority; + private final String path; + @Nullable + private final String query; + @Nullable + private final String fragment; + private boolean cached; + + private DefaultRequestTarget(RequestTargetForm form, @Nullable String scheme, @Nullable String authority, + String path, @Nullable String query, @Nullable String fragment) { + + assert (scheme != null && authority != null) || + (scheme == null && authority == null) : "scheme: " + scheme + ", authority: " + authority; + + this.form = form; + this.scheme = scheme; + this.authority = authority; + this.path = path; + this.query = query; + this.fragment = fragment; + } + + @Override + public RequestTargetForm form() { + return form; + } + + @Override + public String scheme() { + return scheme; + } + + @Override + public String authority() { + return authority; + } + + @Override + public String path() { + return path; + } + + @Override + public String query() { + return query; + } + + @Override + public String fragment() { + return fragment; + } + + /** + * Returns {@code true} if this {@link RequestTarget} is already stored in {@link RequestTargetCache}. + */ + public boolean isCached() { + return cached; + } + + /** + * Marks this {@link RequestTarget} as stored in {@link RequestTargetCache} so that it doesn't + * try to store again. + */ + public void setCached() { + cached = true; + } + + /** + * Returns a copy of this {@link RequestTarget} with its {@link #path()} overridden with + * the specified {@code path}. + */ + public RequestTarget withPath(String path) { + if (this.path == path) { + return this; + } + + return new DefaultRequestTarget(form, scheme, authority, path, query, fragment); + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof DefaultRequestTarget)) { + return false; + } + + final DefaultRequestTarget that = (DefaultRequestTarget) o; + return path.equals(that.path) && + Objects.equals(query, that.query) && + Objects.equals(fragment, that.fragment) && + Objects.equals(authority, that.authority) && + Objects.equals(scheme, that.scheme); + } + + @Override + public int hashCode() { + return Objects.hash(scheme, authority, path, query, fragment); + } + + @Override + public String toString() { + try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) { + final StringBuilder buf = tmp.stringBuilder(); + if (scheme != null) { + buf.append(scheme).append("://").append(authority); + } + buf.append(path); + if (query != null) { + buf.append('?').append(query); + } + if (fragment != null) { + buf.append('#').append(fragment); + } + return buf.toString(); + } + } + + @Nullable + private static RequestTarget slowForServer(String reqTarget, boolean allowDoubleDotsInQueryString) { + final Bytes path; + final Bytes query; + + // Split by the first '?'. + final int queryPos = reqTarget.indexOf('?'); + if (queryPos >= 0) { + if ((path = decodePercentsAndEncodeToUtf8( + reqTarget, 0, queryPos, + ComponentType.SERVER_PATH, null)) == null) { + return null; + } + if ((query = decodePercentsAndEncodeToUtf8( + reqTarget, queryPos + 1, reqTarget.length(), + ComponentType.QUERY, EMPTY_BYTES)) == null) { + return null; + } + } else { + if ((path = decodePercentsAndEncodeToUtf8( + reqTarget, 0, reqTarget.length(), + ComponentType.SERVER_PATH, null)) == null) { + return null; + } + query = null; + } + + // Reject a relative path and accept an asterisk (e.g. OPTIONS * HTTP/1.1). + if (isRelativePath(path)) { + if (query == null && path.length == 1 && path.data[0] == '*') { + return INSTANCE_ASTERISK; + } else { + // Do not accept a relative path. + return null; + } + } + + // Reject the prohibited patterns. + if (pathContainsDoubleDots(path)) { + return null; + } + if (!allowDoubleDotsInQueryString && queryContainsDoubleDots(query)) { + return null; + } + + return new DefaultRequestTarget(RequestTargetForm.ORIGIN, + null, + null, + encodePathToPercents(path), + encodeQueryToPercents(query), + null); + } + + @Nullable + private static RequestTarget slowAbsoluteFormForClient(String reqTarget, int authorityPos) { + // Extract scheme and authority while looking for path. + final URI schemeAndAuthority; + final String scheme = reqTarget.substring(0, authorityPos - 3); + final int nextPos = findNextComponent(reqTarget, authorityPos); + final String authority; + if (nextPos < 0) { + // Found no other components after authority + authority = reqTarget.substring(authorityPos); + } else { + // Path, query or fragment exists. + authority = reqTarget.substring(authorityPos, nextPos); + } + + // Reject a URI with an empty authority. + if (authority.isEmpty()) { + return null; + } + + // Normalize scheme and authority. + schemeAndAuthority = normalizeSchemeAndAuthority(scheme, authority); + if (schemeAndAuthority == null) { + // Invalid scheme or authority. + return null; + } + + if (nextPos < 0) { + return new DefaultRequestTarget(RequestTargetForm.ABSOLUTE, + schemeAndAuthority.getScheme(), + schemeAndAuthority.getRawAuthority(), + "/", + null, + null); + } + + return slowForClient(reqTarget, schemeAndAuthority, nextPos); + } + + private static int findNextComponent(String reqTarget, int startPos) { + for (int i = startPos; i < reqTarget.length(); i++) { + switch (reqTarget.charAt(i)) { + case '/': + case '?': + case '#': + return i; + } + } + + return -1; + } + + @Nullable + private static RequestTarget slowForClient(String reqTarget, + @Nullable URI schemeAndAuthority, + int pathPos) { + final Bytes fragment; + final Bytes path; + final Bytes query; + // Find where a query string and a fragment starts. + final int queryPos; + final int fragmentPos; + // Note: We don't start from `pathPos + 1` but from `pathPos` just in case path is empty. + final int maybeQueryPos = reqTarget.indexOf('?', pathPos); + final int maybeFragmentPos = reqTarget.indexOf('#', pathPos); + if (maybeQueryPos >= 0) { + // Found '?'. + if (maybeFragmentPos >= 0) { + // Found '#', too. + fragmentPos = maybeFragmentPos; + if (maybeQueryPos < maybeFragmentPos) { + // '#' appeared after '?', e.g. ?foo#bar + queryPos = maybeQueryPos; + } else { + // '#' appeared before '?', e.g. #foo?bar. + // It means the '?' we found is not a part of query string. + queryPos = -1; + } + } else { + // No '#' in reqTarget. + queryPos = maybeQueryPos; + fragmentPos = -1; + } + } else { + // No '?'. + queryPos = -1; + fragmentPos = maybeFragmentPos; + } + + // Split into path, query and fragment. + if (queryPos >= 0) { + if ((path = decodePercentsAndEncodeToUtf8( + reqTarget, pathPos, queryPos, + ComponentType.CLIENT_PATH, SLASH_BYTES)) == null) { + return null; + } + + if (fragmentPos >= 0) { + // path?query#fragment + if ((query = decodePercentsAndEncodeToUtf8( + reqTarget, queryPos + 1, fragmentPos, + ComponentType.QUERY, EMPTY_BYTES)) == null) { + return null; + } + if ((fragment = decodePercentsAndEncodeToUtf8( + reqTarget, fragmentPos + 1, reqTarget.length(), + ComponentType.FRAGMENT, EMPTY_BYTES)) == null) { + return null; + } + } else { + // path?query + if ((query = decodePercentsAndEncodeToUtf8( + reqTarget, queryPos + 1, reqTarget.length(), + ComponentType.QUERY, EMPTY_BYTES)) == null) { + return null; + } + fragment = null; + } + } else { + if (fragmentPos >= 0) { + // path#fragment + if ((path = decodePercentsAndEncodeToUtf8( + reqTarget, pathPos, fragmentPos, + ComponentType.CLIENT_PATH, EMPTY_BYTES)) == null) { + return null; + } + query = null; + if ((fragment = decodePercentsAndEncodeToUtf8( + reqTarget, fragmentPos + 1, reqTarget.length(), + ComponentType.FRAGMENT, EMPTY_BYTES)) == null) { + return null; + } + } else { + // path + if ((path = decodePercentsAndEncodeToUtf8( + reqTarget, pathPos, reqTarget.length(), + ComponentType.CLIENT_PATH, EMPTY_BYTES)) == null) { + return null; + } + query = null; + fragment = null; + } + } + + // Accept an asterisk (e.g. OPTIONS * HTTP/1.1). + if (query == null && path.length == 1 && path.data[0] == '*') { + return INSTANCE_ASTERISK; + } + + final String encodedPath; + if (isRelativePath(path)) { + // Turn a relative path into an absolute one. + encodedPath = '/' + encodePathToPercents(path); + } else { + encodedPath = encodePathToPercents(path); + } + + final String encodedQuery = encodeQueryToPercents(query); + final String encodedFragment = encodeFragmentToPercents(fragment); + + if (schemeAndAuthority != null) { + return new DefaultRequestTarget(RequestTargetForm.ABSOLUTE, + schemeAndAuthority.getScheme(), + schemeAndAuthority.getRawAuthority(), + encodedPath, + encodedQuery, + encodedFragment); + } else { + return new DefaultRequestTarget(RequestTargetForm.ORIGIN, + null, + null, + encodedPath, + encodedQuery, + encodedFragment); + } + } + + /** + * Returns the index of the authority part if the specified {@code reqTarget} is an absolute URI. + * Returns {@code -1} otherwise. + */ + private static int findAuthority(String reqTarget) { + final int firstColonIdx = reqTarget.indexOf(':'); + if (firstColonIdx <= 0 || reqTarget.length() <= firstColonIdx + 3) { + return -1; + } + final int firstSlashIdx = reqTarget.indexOf('/'); + if (firstSlashIdx <= 0 || firstSlashIdx < firstColonIdx) { + return -1; + } + + if (reqTarget.charAt(firstColonIdx + 1) == '/' && reqTarget.charAt(firstColonIdx + 2) == '/') { + return firstColonIdx + 3; + } + + return -1; + } + + @Nullable + private static URI normalizeSchemeAndAuthority(String scheme, String authority) { + try { + return new URI(scheme, authority, null, null, null); + } catch (Exception unused) { + return null; + } + } + + private static boolean isRelativePath(Bytes path) { + return path.length == 0 || path.data[0] != '/' || path.isEncoded(0); + } + + @Nullable + private static Bytes decodePercentsAndEncodeToUtf8(String value, int start, int end, + ComponentType type, @Nullable Bytes whenEmpty) { + final int length = end - start; + if (length == 0) { + return whenEmpty; + } + + final Bytes buf = new Bytes(Math.max(length * 3 / 2, 4)); + boolean wasSlash = false; + for (final CodePointIterator i = new CodePointIterator(value, start, end); + i.hasNextCodePoint();/* noop */) { + final int pos = i.position(); + final int cp = i.nextCodePoint(); + + if (cp == '%') { + final int hexEnd = pos + 3; + if (hexEnd > end) { + // '%' or '%x' (must be followed by two hexadigits) + return null; + } + + final int digit1 = decodeHexNibble(value.charAt(pos + 1)); + final int digit2 = decodeHexNibble(value.charAt(pos + 2)); + if (digit1 < 0 || digit2 < 0) { + // The first or second digit is not hexadecimal. + return null; + } + + final int decoded = (digit1 << 4) | digit2; + if (type.mustPreserveEncoding(decoded)) { + buf.ensure(1); + buf.addEncoded((byte) decoded); + wasSlash = false; + } else if (appendOneByte(buf, decoded, wasSlash, type)) { + wasSlash = decoded == '/'; + } else { + return null; + } + + i.position(hexEnd); + continue; + } + + if (cp == '+' && type == ComponentType.QUERY) { + buf.ensure(1); + buf.addEncoded((byte) ' '); + wasSlash = false; + continue; + } + + if (cp <= 0x7F) { + if (!appendOneByte(buf, cp, wasSlash, type)) { + return null; + } + wasSlash = cp == '/'; + continue; + } + + if (cp <= 0x7ff) { + buf.ensure(2); + buf.addEncoded((byte) ((cp >>> 6) | 0b110_00000)); + buf.addEncoded((byte) (cp & 0b111111 | 0b10_000000)); + } else if (cp <= 0xffff) { + buf.ensure(3); + buf.addEncoded((byte) ((cp >>> 12) | 0b1110_0000)); + buf.addEncoded((byte) (((cp >>> 6) & 0b111111) | 0b10_000000)); + buf.addEncoded((byte) ((cp & 0b111111) | 0b10_000000)); + } else if (cp <= 0x1fffff) { + buf.ensure(4); + buf.addEncoded((byte) ((cp >>> 18) | 0b11110_000)); + buf.addEncoded((byte) (((cp >>> 12) & 0b111111) | 0b10_000000)); + buf.addEncoded((byte) (((cp >>> 6) & 0b111111) | 0b10_000000)); + buf.addEncoded((byte) ((cp & 0b111111) | 0b10_000000)); + } else if (cp <= 0x3ffffff) { + // A valid unicode character will never reach here, but for completeness. + // http://unicode.org/mail-arch/unicode-ml/Archives-Old/UML018/0330.html + buf.ensure(5); + buf.addEncoded((byte) ((cp >>> 24) | 0b111110_00)); + buf.addEncoded((byte) (((cp >>> 18) & 0b111111) | 0b10_000000)); + buf.addEncoded((byte) (((cp >>> 12) & 0b111111) | 0b10_000000)); + buf.addEncoded((byte) (((cp >>> 6) & 0b111111) | 0b10_000000)); + buf.addEncoded((byte) ((cp & 0b111111) | 0b10_000000)); + } else { + // A valid unicode character will never reach here, but for completeness. + // http://unicode.org/mail-arch/unicode-ml/Archives-Old/UML018/0330.html + buf.ensure(6); + buf.addEncoded((byte) ((cp >>> 30) | 0b1111110_0)); + buf.addEncoded((byte) (((cp >>> 24) & 0b111111) | 0b10_000000)); + buf.addEncoded((byte) (((cp >>> 18) & 0b111111) | 0b10_000000)); + buf.addEncoded((byte) (((cp >>> 12) & 0b111111) | 0b10_000000)); + buf.addEncoded((byte) (((cp >>> 6) & 0b111111) | 0b10_000000)); + buf.addEncoded((byte) ((cp & 0b111111) | 0b10_000000)); + } + + wasSlash = false; + } + + return buf; + } + + private static boolean appendOneByte(Bytes buf, int cp, boolean wasSlash, ComponentType type) { + if (cp == 0x7F) { + // Reject the control character: 0x7F + return false; + } + + if (cp >>> 5 == 0) { + // Reject the control characters: 0x00..0x1F + if (type != ComponentType.QUERY) { + return false; + } + + if (cp != 0x0A && cp != 0x0D && cp != 0x09) { + // .. except 0x0A (LF), 0x0D (CR) and 0x09 (TAB) because they are used in a form. + return false; + } + } + + if (cp == '/' && type == ComponentType.SERVER_PATH) { + if (!wasSlash) { + buf.ensure(1); + buf.add((byte) '/'); + } else { + // Remove the consecutive slashes: '/path//with///consecutive////slashes'. + } + } else { + buf.ensure(1); + if (type.isAllowed(cp)) { + buf.add((byte) cp); + } else { + buf.addEncoded((byte) cp); + } + } + + return true; + } + + private static boolean pathContainsDoubleDots(Bytes path) { + final int length = path.length; + byte b0 = 0; + byte b1 = 0; + byte b2 = '/'; + for (int i = 1; i < length; i++) { + final byte b3 = path.data[i]; + // Flag if the last four bytes are `/../`. + if (b1 == '.' && b2 == '.' && isSlash(b0) && isSlash(b3)) { + return true; + } + b0 = b1; + b1 = b2; + b2 = b3; + } + + // Flag if the last three bytes are `/..`. + return b1 == '.' && b2 == '.' && isSlash(b0); + } + + private static boolean queryContainsDoubleDots(@Nullable Bytes query) { + if (query == null) { + return false; + } + + final int length = query.length; + boolean lookingForEquals = true; + byte b0 = 0; + byte b1 = 0; + byte b2 = '/'; + for (int i = 0; i < length; i++) { + byte b3 = query.data[i]; + + // Treat the delimiters as `/` so that we can use isSlash() for matching them. + switch (b3) { + case '=': + // Treat only the first `=` as `/`, e.g. + // - `foo=..` and `foo=../` should be flagged. + // - `foo=..=` shouldn't be flagged because `..=` is not a relative path. + if (lookingForEquals) { + lookingForEquals = false; + b3 = '/'; + } + break; + case '&': + case ';': + b3 = '/'; + lookingForEquals = true; + break; + } + + // Flag if the last four bytes are `/../` or `/..&` + if (b1 == '.' && b2 == '.' && isSlash(b0) && isSlash(b3)) { + return true; + } + + b0 = b1; + b1 = b2; + b2 = b3; + } + + return b1 == '.' && b2 == '.' && isSlash(b0); + } + + private static boolean isSlash(byte b) { + switch (b) { + case '/': + case '\\': + return true; + default: + return false; + } + } + + private static String encodePathToPercents(Bytes value) { + if (!value.hasEncodedBytes()) { + // Deprecated, but it fits perfect for our use case. + // noinspection deprecation + return new String(value.data, 0, 0, value.length); + } + + // Slow path: some percent-encoded chars. + return slowEncodePathToPercents(value); + } + + @Nullable + private static String encodeQueryToPercents(@Nullable Bytes value) { + if (value == null) { + return null; + } + + if (!value.hasEncodedBytes()) { + // Deprecated, but it fits perfect for our use case. + // noinspection deprecation + return new String(value.data, 0, 0, value.length); + } + + // Slow path: some percent-encoded chars. + return slowEncodeQueryToPercents(value); + } + + @Nullable + private static String encodeFragmentToPercents(@Nullable Bytes value) { + if (value == null) { + return null; + } + + if (!value.hasEncodedBytes()) { + // Deprecated, but it fits perfect for our use case. + // noinspection deprecation + return new String(value.data, 0, 0, value.length); + } + + // Slow path: some percent-encoded chars. + return slowEncodePathToPercents(value); + } + + private static String slowEncodePathToPercents(Bytes value) { + final int length = value.length; + final StringBuilder buf = new StringBuilder(length + value.numEncodedBytes() * 2); + for (int i = 0; i < length; i++) { + final int b = value.data[i] & 0xFF; + + if (value.isEncoded(i)) { + buf.append(TO_PERCENT_ENCODED_CHARS[b]); + continue; + } + + buf.append((char) b); + } + + return buf.toString(); + } + + private static String slowEncodeQueryToPercents(Bytes value) { + final int length = value.length; + final StringBuilder buf = new StringBuilder(length + value.numEncodedBytes() * 2); + for (int i = 0; i < length; i++) { + final int b = value.data[i] & 0xFF; + + if (value.isEncoded(i)) { + if (b == ' ') { + buf.append('+'); + } else { + buf.append(TO_PERCENT_ENCODED_CHARS[b]); + } + continue; + } + + buf.append((char) b); + } + + return buf.toString(); + } + + private static final class Bytes { + byte[] data; + int length; + @Nullable + private BitSet encoded; + private int numEncodedBytes; + + Bytes(int initialCapacity) { + data = new byte[initialCapacity]; + } + + Bytes(byte[] data) { + this.data = data; + length = data.length; + } + + void add(byte b) { + data[length++] = b; + } + + void addEncoded(byte b) { + if (encoded == null) { + encoded = new BitSet(); + } + encoded.set(length); + data[length++] = b; + numEncodedBytes++; + } + + boolean isEncoded(int index) { + return encoded != null && encoded.get(index); + } + + boolean hasEncodedBytes() { + return encoded != null; + } + + int numEncodedBytes() { + return numEncodedBytes; + } + + void ensure(int numBytes) { + int newCapacity = length + numBytes; + if (newCapacity <= data.length) { + return; + } + + newCapacity = + (int) Math.max(Math.min((long) data.length + (data.length >> 1), Arrays.MAX_ARRAY_SIZE), + newCapacity); + + data = ByteArrays.forceCapacity(data, newCapacity, length); + } + } + + private static final class CodePointIterator { + private final CharSequence str; + private final int end; + private int pos; + + CodePointIterator(CharSequence str, int start, int end) { + this.str = str; + this.end = end; + pos = start; + } + + int position() { + return pos; + } + + void position(int pos) { + this.pos = pos; + } + + boolean hasNextCodePoint() { + return pos < end; + } + + int nextCodePoint() { + assert pos < end; + + final char c1 = str.charAt(pos++); + if (Character.isHighSurrogate(c1) && pos < end) { + final char c2 = str.charAt(pos); + if (Character.isLowSurrogate(c2)) { + pos++; + return Character.toCodePoint(c1, c2); + } + } + + return c1; + } + } +} diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/NonWrappingRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/common/NonWrappingRequestContext.java index 183389144f0..ba9ef361f38 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/NonWrappingRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/NonWrappingRequestContext.java @@ -16,7 +16,6 @@ package com.linecorp.armeria.internal.common; -import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; import java.net.SocketAddress; @@ -34,6 +33,8 @@ import com.linecorp.armeria.common.RequestContextStorage; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.RequestId; +import com.linecorp.armeria.common.RequestTarget; +import com.linecorp.armeria.common.RequestTargetForm; import com.linecorp.armeria.common.RpcRequest; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; @@ -59,14 +60,12 @@ public abstract class NonWrappingRequestContext implements RequestContextExtensi private SessionProtocol sessionProtocol; private final RequestId id; private final HttpMethod method; - private String path; + private RequestTarget reqTarget; private final ExchangeType exchangeType; @Nullable private String decodedPath; @Nullable - private String query; - @Nullable private volatile HttpRequest req; @Nullable private volatile RpcRequest rpcReq; @@ -83,7 +82,7 @@ public abstract class NonWrappingRequestContext implements RequestContextExtensi */ protected NonWrappingRequestContext( MeterRegistry meterRegistry, SessionProtocol sessionProtocol, - RequestId id, HttpMethod method, String path, @Nullable String query, ExchangeType exchangeType, + RequestId id, HttpMethod method, RequestTarget reqTarget, ExchangeType exchangeType, @Nullable HttpRequest req, @Nullable RpcRequest rpcReq, @Nullable AttributesGetters rootAttributeMap) { @@ -97,8 +96,7 @@ protected NonWrappingRequestContext( this.sessionProtocol = requireNonNull(sessionProtocol, "sessionProtocol"); this.id = requireNonNull(id, "id"); this.method = requireNonNull(method, "method"); - this.path = requireNonNull(path, "path"); - this.query = query; + this.reqTarget = requireNonNull(reqTarget, "reqTarget"); this.exchangeType = requireNonNull(exchangeType, "exchangeType"); this.req = req; this.rpcReq = rpcReq; @@ -115,10 +113,22 @@ public final RpcRequest rpcRequest() { } @Override - public void updateRequest(HttpRequest req) { + public final void updateRequest(HttpRequest req) { requireNonNull(req, "req"); - validateHeaders(req.headers()); - unsafeUpdateRequest(req); + final RequestHeaders headers = req.headers(); + final RequestTarget reqTarget = validateHeaders(headers); + + if (reqTarget == null) { + throw new IllegalArgumentException("invalid path: " + headers.path()); + } + if (reqTarget.form() == RequestTargetForm.ABSOLUTE) { + throw new IllegalArgumentException("invalid path: " + headers.path() + + " (must not contain scheme or authority)"); + } + + this.req = req; + this.reqTarget = reqTarget; + decodedPath = null; } @Override @@ -128,22 +138,11 @@ public final void updateRpcRequest(RpcRequest rpcReq) { } /** - * Validates the specified {@link RequestHeaders}. By default, this method will raise - * an {@link IllegalArgumentException} if it does not have {@code ":scheme"} or {@code ":authority"} - * header. + * Validates the specified {@link RequestHeaders} and returns the {@link RequestTarget} + * returned by {@link RequestTarget#forClient(String)} or {@link RequestTarget#forServer(String)}. */ - protected void validateHeaders(RequestHeaders headers) { - checkArgument(headers.scheme() != null && headers.authority() != null, - "must set ':scheme' and ':authority' headers"); - } - - /** - * Replaces the {@link HttpRequest} associated with this context with the specified one - * without any validation. Internal use only. Use it at your own risk. - */ - protected void unsafeUpdateRequest(HttpRequest req) { - this.req = req; - } + @Nullable + protected abstract RequestTarget validateHeaders(RequestHeaders headers); @Override public final SessionProtocol sessionProtocol() { @@ -189,11 +188,11 @@ public final HttpMethod method() { @Override public final String path() { - return path; + return reqTarget.path(); } - protected void path(String path) { - this.path = requireNonNull(path, "path"); + protected final RequestTarget requestTarget() { + return reqTarget; } @Override @@ -203,16 +202,12 @@ public final String decodedPath() { return decodedPath; } - return this.decodedPath = ArmeriaHttpUtil.decodePath(path); + return this.decodedPath = ArmeriaHttpUtil.decodePath(path()); } @Override public final String query() { - return query; - } - - protected void query(@Nullable String query) { - this.query = query; + return reqTarget.query(); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/PathAndQuery.java b/core/src/main/java/com/linecorp/armeria/internal/common/PathAndQuery.java deleted file mode 100644 index 7444657e5e0..00000000000 --- a/core/src/main/java/com/linecorp/armeria/internal/common/PathAndQuery.java +++ /dev/null @@ -1,654 +0,0 @@ -/* - * Copyright 2017 LINE Corporation - * - * LINE Corporation licenses this file to you under the Apache License, - * version 2.0 (the "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at: - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - */ - -package com.linecorp.armeria.internal.common; - -import static io.netty.util.internal.StringUtil.decodeHexNibble; -import static java.util.Objects.requireNonNull; - -import java.util.BitSet; -import java.util.Objects; -import java.util.Set; - -import com.github.benmanes.caffeine.cache.Cache; -import com.github.benmanes.caffeine.cache.Caffeine; -import com.google.common.annotations.VisibleForTesting; - -import com.linecorp.armeria.common.Flags; -import com.linecorp.armeria.common.annotation.Nullable; -import com.linecorp.armeria.common.metric.MeterIdPrefix; -import com.linecorp.armeria.internal.common.metric.CaffeineMetricSupport; - -import io.micrometer.core.instrument.MeterRegistry; -import it.unimi.dsi.fastutil.Arrays; -import it.unimi.dsi.fastutil.bytes.ByteArrays; - -/** - * A parser of the raw path and query components of an HTTP path. Performs validation and allows caching of - * results. - */ -public final class PathAndQuery { - - private static final PathAndQuery ROOT_PATH_QUERY = new PathAndQuery("/", null); - - /** - * The lookup table for the characters allowed in a path. - */ - private static final BitSet ALLOWED_PATH_CHARS = new BitSet(); - - /** - * The lookup table for the characters allowed in a query string. - */ - private static final BitSet ALLOWED_QUERY_CHARS = new BitSet(); - - /** - * The lookup table for the characters that whose percent encoding must be preserved - * when used in a query string because whether they are percent-encoded or not affects - * their semantics. For example, 'A%3dB=1' should NOT be normalized into 'A=B=1' because - * 'A=B=1` means 'A' is 'B=1' but 'A%3dB=1' means 'A=B' is '1'. - */ - private static final BitSet MUST_PRESERVE_ENCODING_IN_QUERY = new BitSet(); - - /** - * The table that converts a byte into a percent-encoded chars, e.g. 'A' -> "%41". - */ - private static final char[][] TO_PERCENT_ENCODED_CHARS = new char[256][]; - - static { - final String commonAllowedChars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~:/?@!$&'()*,;="; - final String allowedPathChars = commonAllowedChars + '+'; - for (int i = 0; i < allowedPathChars.length(); i++) { - ALLOWED_PATH_CHARS.set(allowedPathChars.charAt(i)); - } - - final String allowedQueryChars = commonAllowedChars + "[]"; - for (int i = 0; i < allowedQueryChars.length(); i++) { - ALLOWED_QUERY_CHARS.set(allowedQueryChars.charAt(i)); - } - - final String mustPreserveEncodingInQuery = ":/?[]@!$&'()*+,;="; - for (int i = 0; i < mustPreserveEncodingInQuery.length(); i++) { - MUST_PRESERVE_ENCODING_IN_QUERY.set(mustPreserveEncodingInQuery.charAt(i)); - } - - for (int i = 0; i < TO_PERCENT_ENCODED_CHARS.length; i++) { - TO_PERCENT_ENCODED_CHARS[i] = String.format("%%%02X", i).toCharArray(); - } - } - - private static final Bytes EMPTY_QUERY = new Bytes(0); - private static final Bytes ROOT_PATH = new Bytes(new byte[] { '/' }); - - @Nullable - private static final Cache CACHE = - Flags.parsedPathCacheSpec() != null ? buildCache(Flags.parsedPathCacheSpec()) : null; - - private static Cache buildCache(String spec) { - return Caffeine.from(spec).build(); - } - - public static void registerMetrics(MeterRegistry registry, MeterIdPrefix idPrefix) { - if (CACHE != null) { - CaffeineMetricSupport.setup(registry, idPrefix, CACHE); - } - } - - /** - * Clears the currently cached parsed paths. Only for use in tests. - */ - @VisibleForTesting - public static void clearCachedPaths() { - requireNonNull(CACHE, "CACHE"); - CACHE.asMap().clear(); - } - - /** - * Returns paths that have had their parse result cached. Only for use in tests. - */ - @VisibleForTesting - public static Set cachedPaths() { - requireNonNull(CACHE, "CACHE"); - return CACHE.asMap().keySet(); - } - - /** - * Validates the {@link String} that contains an absolute path and a query, and splits them into - * the path part and the query part. If the path is usable (e.g., can be served a successful response from - * the server and doesn't have variable path parameters), {@link PathAndQuery#storeInCache(String)} should - * be called to cache the parsing result for faster future invocations. - * - * @return a {@link PathAndQuery} with the absolute path and query, or {@code null} if the specified - * {@link String} is not an absolute path or invalid. - */ - @Nullable - public static PathAndQuery parse(@Nullable String rawPath) { - return parse(rawPath, Flags.allowDoubleDotsInQueryString()); - } - - @VisibleForTesting - @Nullable - static PathAndQuery parse(@Nullable String rawPath, boolean allowDoubleDotsInQueryString) { - if (CACHE != null && rawPath != null) { - final PathAndQuery parsed = CACHE.getIfPresent(rawPath); - if (parsed != null) { - return parsed; - } - } - return splitPathAndQuery(rawPath, allowDoubleDotsInQueryString); - } - - /** - * Stores this {@link PathAndQuery} into cache for the given raw path. This should be used by callers when - * the parsed result was valid (e.g., when a server is able to successfully handle the parsed path). - */ - public void storeInCache(@Nullable String rawPath) { - if (CACHE != null && !cached && rawPath != null) { - cached = true; - CACHE.put(rawPath, this); - } - } - - private final String path; - @Nullable - private final String query; - - private boolean cached; - - private PathAndQuery(String path, @Nullable String query) { - this.path = path; - this.query = query; - } - - public String path() { - return path; - } - - @Nullable - public String query() { - return query; - } - - @Override - public boolean equals(@Nullable Object o) { - if (this == o) { - return true; - } - - if (!(o instanceof PathAndQuery)) { - return false; - } - - final PathAndQuery that = (PathAndQuery) o; - return Objects.equals(path, that.path) && - Objects.equals(query, that.query); - } - - @Override - public int hashCode() { - return Objects.hash(path, query); - } - - @Override - public String toString() { - if (query == null) { - return path; - } - return path + '?' + query; - } - - @Nullable - private static PathAndQuery splitPathAndQuery(@Nullable String pathAndQuery, - boolean allowDoubleDotsInQueryString) { - final Bytes path; - final Bytes query; - - if (pathAndQuery == null) { - return ROOT_PATH_QUERY; - } - - // Split by the first '?'. - final int queryPos = pathAndQuery.indexOf('?'); - if (queryPos >= 0) { - if ((path = decodePercentsAndEncodeToUtf8( - pathAndQuery, 0, queryPos, true)) == null) { - return null; - } - if ((query = decodePercentsAndEncodeToUtf8( - pathAndQuery, queryPos + 1, pathAndQuery.length(), false)) == null) { - return null; - } - } else { - if ((path = decodePercentsAndEncodeToUtf8( - pathAndQuery, 0, pathAndQuery.length(), true)) == null) { - return null; - } - query = null; - } - - if (path.data[0] != '/' || path.isEncoded(0)) { - // Do not accept a relative path. - return null; - } - - // Reject the prohibited patterns. - if (pathContainsDoubleDots(path)) { - return null; - } - if (!allowDoubleDotsInQueryString && queryContainsDoubleDots(query)) { - return null; - } - - return new PathAndQuery(encodePathToPercents(path), encodeQueryToPercents(query)); - } - - /** - * Decodes a percent-encoded query string. This method is only used for {@code PathAndQueryTest}. - */ - @Nullable - @VisibleForTesting - static String decodePercentEncodedQuery(String query) { - final Bytes bytes = decodePercentsAndEncodeToUtf8(query, 0, query.length(), false); - return encodeQueryToPercents(bytes); - } - - @Nullable - private static Bytes decodePercentsAndEncodeToUtf8(String value, int start, int end, boolean isPath) { - final int length = end - start; - if (length == 0) { - return isPath ? ROOT_PATH : EMPTY_QUERY; - } - - final Bytes buf = new Bytes(Math.max(length * 3 / 2, 4)); - boolean wasSlash = false; - for (final CodePointIterator i = new CodePointIterator(value, start, end); - i.hasNextCodePoint();/* noop */) { - final int pos = i.position(); - final int cp = i.nextCodePoint(); - - if (cp == '%') { - final int hexEnd = pos + 3; - if (hexEnd > end) { - // '%' or '%x' (must be followed by two hexadigits) - return null; - } - - final int digit1 = decodeHexNibble(value.charAt(pos + 1)); - final int digit2 = decodeHexNibble(value.charAt(pos + 2)); - if (digit1 < 0 || digit2 < 0) { - // The first or second digit is not hexadecimal. - return null; - } - - final int decoded = (digit1 << 4) | digit2; - if (isPath) { - if (decoded == '/') { - // Do not decode '%2F' and '%2f' in the path to '/' for compatibility with - // other implementations in the ecosystem, e.g. HTTP/JSON to gRPC transcoding. - // https://github.com/googleapis/googleapis/blob/02710fa0ea5312d79d7fb986c9c9823fb41049a9/google/api/http.proto#L257-L258 - buf.ensure(1); - buf.addEncoded((byte) '/'); - wasSlash = false; - } else { - if (appendOneByte(buf, decoded, wasSlash, isPath)) { - wasSlash = false; - } else { - return null; - } - } - } else { - // If query: - if (MUST_PRESERVE_ENCODING_IN_QUERY.get(decoded)) { - buf.ensure(1); - buf.addEncoded((byte) decoded); - wasSlash = false; - } else if (appendOneByte(buf, decoded, wasSlash, isPath)) { - wasSlash = decoded == '/'; - } else { - return null; - } - } - - i.position(hexEnd); - continue; - } - - if (cp == '+' && !isPath) { - buf.ensure(1); - buf.addEncoded((byte) ' '); - wasSlash = false; - continue; - } - - if (cp <= 0x7F) { - if (!appendOneByte(buf, cp, wasSlash, isPath)) { - return null; - } - wasSlash = cp == '/'; - continue; - } - - if (cp <= 0x7ff) { - buf.ensure(2); - buf.addEncoded((byte) ((cp >>> 6) | 0b110_00000)); - buf.addEncoded((byte) (cp & 0b111111 | 0b10_000000)); - } else if (cp <= 0xffff) { - buf.ensure(3); - buf.addEncoded((byte) ((cp >>> 12) | 0b1110_0000)); - buf.addEncoded((byte) (((cp >>> 6) & 0b111111) | 0b10_000000)); - buf.addEncoded((byte) ((cp & 0b111111) | 0b10_000000)); - } else if (cp <= 0x1fffff) { - buf.ensure(4); - buf.addEncoded((byte) ((cp >>> 18) | 0b11110_000)); - buf.addEncoded((byte) (((cp >>> 12) & 0b111111) | 0b10_000000)); - buf.addEncoded((byte) (((cp >>> 6) & 0b111111) | 0b10_000000)); - buf.addEncoded((byte) ((cp & 0b111111) | 0b10_000000)); - } else if (cp <= 0x3ffffff) { - // A valid unicode character will never reach here, but for completeness. - // http://unicode.org/mail-arch/unicode-ml/Archives-Old/UML018/0330.html - buf.ensure(5); - buf.addEncoded((byte) ((cp >>> 24) | 0b111110_00)); - buf.addEncoded((byte) (((cp >>> 18) & 0b111111) | 0b10_000000)); - buf.addEncoded((byte) (((cp >>> 12) & 0b111111) | 0b10_000000)); - buf.addEncoded((byte) (((cp >>> 6) & 0b111111) | 0b10_000000)); - buf.addEncoded((byte) ((cp & 0b111111) | 0b10_000000)); - } else { - // A valid unicode character will never reach here, but for completeness. - // http://unicode.org/mail-arch/unicode-ml/Archives-Old/UML018/0330.html - buf.ensure(6); - buf.addEncoded((byte) ((cp >>> 30) | 0b1111110_0)); - buf.addEncoded((byte) (((cp >>> 24) & 0b111111) | 0b10_000000)); - buf.addEncoded((byte) (((cp >>> 18) & 0b111111) | 0b10_000000)); - buf.addEncoded((byte) (((cp >>> 12) & 0b111111) | 0b10_000000)); - buf.addEncoded((byte) (((cp >>> 6) & 0b111111) | 0b10_000000)); - buf.addEncoded((byte) ((cp & 0b111111) | 0b10_000000)); - } - - wasSlash = false; - } - - return buf; - } - - private static boolean appendOneByte(Bytes buf, int cp, boolean wasSlash, boolean isPath) { - if (cp == 0x7F) { - // Reject the control character: 0x7F - return false; - } - - if (cp >>> 5 == 0) { - // Reject the control characters: 0x00..0x1F - if (isPath) { - return false; - } else if (cp != 0x0A && cp != 0x0D && cp != 0x09) { - // .. except 0x0A (LF), 0x0D (CR) and 0x09 (TAB) because they are used in a form. - return false; - } - } - - if (cp == '/' && isPath) { - if (!wasSlash) { - buf.ensure(1); - buf.add((byte) '/'); - } else { - // Remove the consecutive slashes: '/path//with///consecutive////slashes'. - } - } else { - final BitSet allowedChars = isPath ? ALLOWED_PATH_CHARS : ALLOWED_QUERY_CHARS; - buf.ensure(1); - if (allowedChars.get(cp)) { - buf.add((byte) cp); - } else { - buf.addEncoded((byte) cp); - } - } - - return true; - } - - private static boolean pathContainsDoubleDots(Bytes path) { - final int length = path.length; - byte b0 = 0; - byte b1 = 0; - byte b2 = '/'; - for (int i = 1; i < length; i++) { - final byte b3 = path.data[i]; - // Flag if the last four bytes are `/../`. - if (b1 == '.' && b2 == '.' && isSlash(b0) && isSlash(b3)) { - return true; - } - b0 = b1; - b1 = b2; - b2 = b3; - } - - // Flag if the last three bytes are `/..`. - return b1 == '.' && b2 == '.' && isSlash(b0); - } - - private static boolean queryContainsDoubleDots(@Nullable Bytes query) { - if (query == null) { - return false; - } - - final int length = query.length; - boolean lookingForEquals = true; - byte b0 = 0; - byte b1 = 0; - byte b2 = '/'; - for (int i = 0; i < length; i++) { - byte b3 = query.data[i]; - - // Treat the delimiters as `/` so that we can use isSlash() for matching them. - switch (b3) { - case '=': - // Treat only the first `=` as `/`, e.g. - // - `foo=..` and `foo=../` should be flagged. - // - `foo=..=` shouldn't be flagged because `..=` is not a relative path. - if (lookingForEquals) { - lookingForEquals = false; - b3 = '/'; - } - break; - case '&': - case ';': - b3 = '/'; - lookingForEquals = true; - break; - } - - // Flag if the last four bytes are `/../` or `/..&` - if (b1 == '.' && b2 == '.' && isSlash(b0) && isSlash(b3)) { - return true; - } - - b0 = b1; - b1 = b2; - b2 = b3; - } - - return b1 == '.' && b2 == '.' && isSlash(b0); - } - - private static boolean isSlash(byte b) { - switch (b) { - case '/': - case '\\': - return true; - default: - return false; - } - } - - private static String encodePathToPercents(Bytes value) { - if (!value.hasEncodedBytes()) { - // Deprecated, but it fits perfect for our use case. - // noinspection deprecation - return new String(value.data, 0, 0, value.length); - } - - // Slow path: some percent-encoded chars. - return slowEncodePathToPercents(value); - } - - @Nullable - private static String encodeQueryToPercents(@Nullable Bytes value) { - if (value == null) { - return null; - } - - if (!value.hasEncodedBytes()) { - // Deprecated, but it fits perfect for our use case. - // noinspection deprecation - return new String(value.data, 0, 0, value.length); - } - - // Slow path: some percent-encoded chars. - return slowEncodeQueryToPercents(value); - } - - private static String slowEncodePathToPercents(Bytes value) { - final int length = value.length; - final StringBuilder buf = new StringBuilder(length + value.numEncodedBytes() * 2); - for (int i = 0; i < length; i++) { - final int b = value.data[i] & 0xFF; - - if (value.isEncoded(i)) { - buf.append(TO_PERCENT_ENCODED_CHARS[b]); - continue; - } - - buf.append((char) b); - } - - return buf.toString(); - } - - private static String slowEncodeQueryToPercents(Bytes value) { - final int length = value.length; - final StringBuilder buf = new StringBuilder(length + value.numEncodedBytes() * 2); - for (int i = 0; i < length; i++) { - final int b = value.data[i] & 0xFF; - - if (value.isEncoded(i)) { - if (b == ' ') { - buf.append('+'); - } else { - buf.append(TO_PERCENT_ENCODED_CHARS[b]); - } - continue; - } - - buf.append((char) b); - } - - return buf.toString(); - } - - private static final class Bytes { - byte[] data; - int length; - @Nullable - private BitSet encoded; - private int numEncodedBytes; - - Bytes(int initialCapacity) { - data = new byte[initialCapacity]; - } - - Bytes(byte[] data) { - this.data = data; - length = data.length; - } - - void add(byte b) { - data[length++] = b; - } - - void addEncoded(byte b) { - if (encoded == null) { - encoded = new BitSet(); - } - encoded.set(length); - data[length++] = b; - numEncodedBytes++; - } - - boolean isEncoded(int index) { - return encoded != null && encoded.get(index); - } - - boolean hasEncodedBytes() { - return encoded != null; - } - - int numEncodedBytes() { - return numEncodedBytes; - } - - void ensure(int numBytes) { - int newCapacity = length + numBytes; - if (newCapacity <= data.length) { - return; - } - - newCapacity = - (int) Math.max(Math.min((long) data.length + (data.length >> 1), Arrays.MAX_ARRAY_SIZE), - newCapacity); - - data = ByteArrays.forceCapacity(data, newCapacity, length); - } - } - - private static final class CodePointIterator { - private final CharSequence str; - private final int end; - private int pos; - - CodePointIterator(CharSequence str, int start, int end) { - this.str = str; - this.end = end; - pos = start; - } - - int position() { - return pos; - } - - void position(int pos) { - this.pos = pos; - } - - boolean hasNextCodePoint() { - return pos < end; - } - - int nextCodePoint() { - assert pos < end; - - final char c1 = str.charAt(pos++); - if (Character.isHighSurrogate(c1) && pos < end) { - final char c2 = str.charAt(pos); - if (Character.isLowSurrogate(c2)) { - pos++; - return Character.toCodePoint(c1, c2); - } - } - - return c1; - } - } -} diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/RequestTargetCache.java b/core/src/main/java/com/linecorp/armeria/internal/common/RequestTargetCache.java new file mode 100644 index 00000000000..62fb40ecbf4 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/internal/common/RequestTargetCache.java @@ -0,0 +1,131 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.internal.common; + +import java.util.Set; + +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import com.google.common.annotations.VisibleForTesting; + +import com.linecorp.armeria.common.Flags; +import com.linecorp.armeria.common.RequestTarget; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.metric.MeterIdPrefix; +import com.linecorp.armeria.internal.common.metric.CaffeineMetricSupport; + +import io.micrometer.core.instrument.MeterRegistry; + +public final class RequestTargetCache { + + private static final MeterIdPrefix METER_ID_PREFIX = new MeterIdPrefix("armeria.path.cache"); + + @Nullable + private static final Cache SERVER_CACHE = + Flags.parsedPathCacheSpec() != null ? buildCache(Flags.parsedPathCacheSpec()) : null; + + @Nullable + private static final Cache CLIENT_CACHE = + Flags.parsedPathCacheSpec() != null ? buildCache(Flags.parsedPathCacheSpec()) : null; + + private static Cache buildCache(String spec) { + return Caffeine.from(spec).build(); + } + + public static void registerServerMetrics(MeterRegistry registry) { + if (SERVER_CACHE != null) { + CaffeineMetricSupport.setup(registry, METER_ID_PREFIX.withTags("type", "server"), SERVER_CACHE); + } + } + + public static void registerClientMetrics(MeterRegistry registry) { + if (CLIENT_CACHE != null) { + CaffeineMetricSupport.setup(registry, METER_ID_PREFIX.withTags("type", "client"), CLIENT_CACHE); + } + } + + @Nullable + public static RequestTarget getForServer(String reqTarget) { + return get(reqTarget, SERVER_CACHE); + } + + @Nullable + public static RequestTarget getForClient(String reqTarget) { + return get(reqTarget, CLIENT_CACHE); + } + + @Nullable + private static RequestTarget get(String reqTarget, @Nullable Cache cache) { + if (cache != null) { + return cache.getIfPresent(reqTarget); + } else { + return null; + } + } + + public static void putForServer(String reqTarget, RequestTarget normalized) { + put(reqTarget, normalized, SERVER_CACHE); + } + + public static void putForClient(String reqTarget, RequestTarget normalized) { + put(reqTarget, normalized, CLIENT_CACHE); + } + + private static void put(String reqTarget, RequestTarget normalized, + @Nullable Cache cache) { + assert reqTarget != null; + assert normalized != null; + + if (cache != null && normalized instanceof DefaultRequestTarget) { + final DefaultRequestTarget value = (DefaultRequestTarget) normalized; + if (!value.isCached()) { + value.setCached(); + cache.put(reqTarget, normalized); + } + } + } + + /** + * Clears the currently cached parsed paths. Only for use in tests. + */ + @VisibleForTesting + public static void clearCachedPaths() { + assert CLIENT_CACHE != null : "CLIENT_CACHE"; + assert SERVER_CACHE != null : "SERVER_CACHE"; + CLIENT_CACHE.asMap().clear(); + SERVER_CACHE.asMap().clear(); + } + + /** + * Returns server-side paths that have had their parse result cached. Only for use in tests. + */ + @VisibleForTesting + public static Set cachedServerPaths() { + assert SERVER_CACHE != null : "SERVER_CACHE"; + return SERVER_CACHE.asMap().keySet(); + } + + /** + * Returns client-side paths that have had their parse result cached. Only for use in tests. + */ + @VisibleForTesting + public static Set cachedClientPaths() { + assert CLIENT_CACHE != null : "CLIENT_CACHE"; + return CLIENT_CACHE.asMap().keySet(); + } + + private RequestTargetCache() {} +} diff --git a/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java index 9cb92aff722..bdd484bfd02 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java @@ -46,7 +46,9 @@ import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.QueryParams; import com.linecorp.armeria.common.Request; +import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.RequestId; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.Response; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; @@ -160,8 +162,8 @@ public DefaultServiceRequestContext( HttpHeaders additionalResponseHeaders, HttpHeaders additionalResponseTrailers) { super(meterRegistry, sessionProtocol, id, - requireNonNull(routingContext, "routingContext").method(), routingContext.path(), - requireNonNull(routingResult, "routingResult").query(), exchangeType, + requireNonNull(routingContext, "routingContext").method(), + routingContext.requestTarget(), exchangeType, requireNonNull(req, "req"), null, null); this.ch = requireNonNull(ch, "ch"); @@ -193,6 +195,13 @@ public DefaultServiceRequestContext( this.additionalResponseTrailers = additionalResponseTrailers; } + @Override + protected RequestTarget validateHeaders(RequestHeaders headers) { + checkArgument(headers.scheme() != null && headers.authority() != null, + "must set ':scheme' and ':authority' headers"); + return RequestTarget.forServer(headers.path()); + } + @Nullable @Override public V attr(AttributeKey key) { diff --git a/core/src/main/java/com/linecorp/armeria/server/DefaultRoutingContext.java b/core/src/main/java/com/linecorp/armeria/server/DefaultRoutingContext.java index d7415a038a8..f59503ad891 100644 --- a/core/src/main/java/com/linecorp/armeria/server/DefaultRoutingContext.java +++ b/core/src/main/java/com/linecorp/armeria/server/DefaultRoutingContext.java @@ -29,8 +29,8 @@ import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.QueryParams; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.annotation.Nullable; -import com.linecorp.armeria.internal.common.PathAndQuery; /** * Holds the parameters which are required to find a service available to handle the request. @@ -40,33 +40,16 @@ final class DefaultRoutingContext implements RoutingContext { /** * Returns a new {@link RoutingContext} instance. */ - static RoutingContext of(VirtualHost virtualHost, String hostname, - String path, @Nullable String query, + static RoutingContext of(VirtualHost virtualHost, String hostname, RequestTarget reqTarget, RequestHeaders headers, RoutingStatus routingStatus) { - return new DefaultRoutingContext(virtualHost, hostname, headers, path, query, null, - routingStatus); - } - - /** - * Returns a new {@link RoutingContext} instance. - */ - static RoutingContext of(VirtualHost virtualHost, String hostname, - PathAndQuery pathAndQuery, - RequestHeaders headers, RoutingStatus routingStatus) { - requireNonNull(pathAndQuery, "pathAndQuery"); - return new DefaultRoutingContext(virtualHost, hostname, headers, pathAndQuery.path(), - pathAndQuery.query(), pathAndQuery, routingStatus); + return new DefaultRoutingContext(virtualHost, hostname, headers, reqTarget, routingStatus); } private final VirtualHost virtualHost; private final String hostname; private final HttpMethod method; private final RequestHeaders headers; - private final String path; - @Nullable - private final String query; - @Nullable - private final PathAndQuery pathAndQuery; + private final RequestTarget reqTarget; @Nullable private final MediaType contentType; private final List acceptTypes; @@ -81,14 +64,11 @@ static RoutingContext of(VirtualHost virtualHost, String hostname, private final int hashCode; DefaultRoutingContext(VirtualHost virtualHost, String hostname, RequestHeaders headers, - String path, @Nullable String query, @Nullable PathAndQuery pathAndQuery, - RoutingStatus routingStatus) { + RequestTarget reqTarget, RoutingStatus routingStatus) { this.virtualHost = requireNonNull(virtualHost, "virtualHost"); this.hostname = requireNonNull(hostname, "hostname"); this.headers = requireNonNull(headers, "headers"); - this.path = requireNonNull(path, "path"); - this.query = query; - this.pathAndQuery = pathAndQuery; + this.reqTarget = requireNonNull(reqTarget, "reqTarget"); this.routingStatus = routingStatus; method = headers.method(); contentType = headers.contentType(); @@ -112,25 +92,26 @@ public HttpMethod method() { } @Override - public String path() { - return path; + public RequestTarget requestTarget() { + return reqTarget; } - @Nullable @Override - public String query() { - return query; + public String path() { + return reqTarget.path(); } @Nullable - PathAndQuery pathAndQuery() { - return pathAndQuery; + @Override + public String query() { + return reqTarget.query(); } @Override public QueryParams params() { QueryParams queryParams = this.queryParams; if (queryParams == null) { + final String query = reqTarget.query(); if (query == null) { queryParams = QueryParams.of(); } else { @@ -269,8 +250,7 @@ static String toString(RoutingContext routingCtx) { if (!routingCtx.acceptTypes().isEmpty()) { helper.add("acceptTypes", routingCtx.acceptTypes()); } - helper.add("isCorsPreflight", routingCtx.isCorsPreflight()) - .add("requiresMatchingParamsPredicates", routingCtx.requiresMatchingParamsPredicates()) + helper.add("requiresMatchingParamsPredicates", routingCtx.requiresMatchingParamsPredicates()) .add("requiresMatchingHeadersPredicates", routingCtx.requiresMatchingHeadersPredicates()); return helper.toString(); } diff --git a/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java b/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java index 30f5fa38f6f..632c14f4d4f 100644 --- a/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java @@ -33,6 +33,7 @@ import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.ProtocolViolationException; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.internal.common.ArmeriaHttpUtil; @@ -40,7 +41,6 @@ import com.linecorp.armeria.internal.common.InitiateConnectionShutdown; import com.linecorp.armeria.internal.common.KeepAliveHandler; import com.linecorp.armeria.internal.common.NoopKeepAliveHandler; -import com.linecorp.armeria.internal.common.PathAndQuery; import com.linecorp.armeria.server.HttpServerUpgradeHandler.UpgradeEvent; import io.netty.buffer.ByteBuf; @@ -138,26 +138,16 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception if (!nettyReq.decoderResult().isSuccess()) { final Throwable cause = nettyReq.decoderResult().cause(); if (cause instanceof TooLongHttpLineException) { - fail(id, null, HttpStatus.REQUEST_URI_TOO_LONG, Http2Error.FRAME_SIZE_ERROR, - "Too Long URI", cause); + fail(id, null, HttpStatus.REQUEST_URI_TOO_LONG, "Too Long URI", cause); } else if (cause instanceof TooLongHttpHeaderException) { fail(id, null, HttpStatus.REQUEST_HEADER_FIELDS_TOO_LARGE, - Http2Error.FRAME_SIZE_ERROR, "Request header fields too large", cause); + "Request header fields too large", cause); } else { - fail(id, null, HttpStatus.BAD_REQUEST, Http2Error.PROTOCOL_ERROR, - "Decoder failure", cause); + fail(id, null, HttpStatus.BAD_REQUEST, "Decoder failure", cause); } return; } - // Do not accept unsupported methods. - final io.netty.handler.codec.http.HttpMethod nettyMethod = nettyReq.method(); - if (!HttpMethod.isSupported(nettyMethod.name())) { - fail(id, null, HttpStatus.METHOD_NOT_ALLOWED, Http2Error.PROTOCOL_ERROR, - "Unsupported method", null); - return; - } - // Handle `expect: 100-continue` first to give `handle100Continue()` a chance to remove // the `expect` header before converting the Netty HttpHeaders into Armeria RequestHeaders. // This is because removing a header from RequestHeaders is more expensive due to its @@ -166,16 +156,29 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception final String path = HttpHeaderUtil .maybeTransformAbsoluteUri(nettyReq.uri(), cfg.absoluteUriTransformer()); - final PathAndQuery pathAndQuery = PathAndQuery.parse(path); + + // Parse and normalize the request path. + final RequestTarget reqTarget = RequestTarget.forServer(path); + if (reqTarget == null) { + failWithInvalidRequestPath(id, null); + return; + } // Convert the Netty HttpHeaders into Armeria RequestHeaders. final RequestHeaders headers = - ArmeriaHttpUtil.toArmeria(ctx, nettyReq, cfg, scheme.toString(), pathAndQuery); + ArmeriaHttpUtil.toArmeria(ctx, nettyReq, cfg, scheme.toString(), reqTarget); + // Do not accept unsupported methods. + final HttpMethod method = headers.method(); + switch (method) { + case CONNECT: + case UNKNOWN: + fail(id, headers, HttpStatus.METHOD_NOT_ALLOWED, "Unsupported method", null); + return; + } - // Do not accept a CONNECT request. - if (headers.method() == HttpMethod.CONNECT) { - fail(id, headers, HttpStatus.METHOD_NOT_ALLOWED, Http2Error.CONNECT_ERROR, - "Unsupported method", null); + // Do not accept the request path '*' for a non-OPTIONS request. + if (method != HttpMethod.OPTIONS && "*".equals(path)) { + failWithInvalidRequestPath(id, headers); return; } @@ -190,8 +193,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception contentLength = -1; } if (contentLength < 0) { - fail(id, headers, HttpStatus.BAD_REQUEST, Http2Error.FRAME_SIZE_ERROR, - "Invalid content length", null); + fail(id, headers, HttpStatus.BAD_REQUEST, "Invalid content length", null); return; } @@ -203,13 +205,13 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception // Reject the requests with an `expect` header whose value is not `100-continue`. if (hasInvalidExpectHeader) { ctx.pipeline().fireUserEventTriggered(HttpExpectationFailedEvent.INSTANCE); - fail(id, headers, HttpStatus.EXPECTATION_FAILED, Http2Error.PROTOCOL_ERROR, null, null); + fail(id, headers, HttpStatus.EXPECTATION_FAILED, null, null); return; } // Close the request early when it is certain there will be neither content nor trailers. final RoutingContext routingCtx = newRoutingContext(cfg, ctx.channel(), - headers, pathAndQuery); + headers, reqTarget); if (routingCtx.status().routeMustExist()) { try { // Find the service that matches the path. @@ -218,8 +220,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception assert routed.isPresent(); } catch (Throwable cause) { logger.warn("{} Unexpected exception: {}", ctx.channel(), headers, cause); - fail(id, headers, HttpStatus.INTERNAL_SERVER_ERROR, Http2Error.INTERNAL_ERROR, - null, cause); + fail(id, headers, HttpStatus.INTERNAL_SERVER_ERROR, null, cause); return; } } @@ -235,8 +236,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception ctx.fireChannelRead(req); } } else { - fail(id, null, HttpStatus.BAD_REQUEST, Http2Error.PROTOCOL_ERROR, - "Invalid decoder state", null); + fail(id, null, HttpStatus.BAD_REQUEST, "Invalid decoder state", null); return; } } @@ -254,8 +254,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception final HttpContent content = (HttpContent) msg; final DecoderResult decoderResult = content.decoderResult(); if (!decoderResult.isSuccess()) { - fail(id, decodedReq.headers(), HttpStatus.BAD_REQUEST, Http2Error.PROTOCOL_ERROR, - "Decoder failure", null); + fail(id, decodedReq.headers(), HttpStatus.BAD_REQUEST, "Decoder failure", null); final ProtocolViolationException cause = new ProtocolViolationException(decoderResult.cause()); decodedReq.close(HttpStatusException.of(HttpStatus.BAD_REQUEST, cause)); @@ -275,8 +274,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception .contentLength(req.headers()) .transferred(transferredLength) .build(); - fail(id, decodedReq.headers(), HttpStatus.REQUEST_ENTITY_TOO_LARGE, - Http2Error.CANCEL, null, cause); + fail(id, decodedReq.headers(), HttpStatus.REQUEST_ENTITY_TOO_LARGE, null, cause); // Wrap the cause with the returned status to let LoggingService correctly log the // status. decodedReq.abortResponse( @@ -306,17 +304,17 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } } catch (URISyntaxException e) { if (req != null) { - fail(id, req.headers(), HttpStatus.BAD_REQUEST, Http2Error.CANCEL, "Invalid request path", e); + fail(id, req.headers(), HttpStatus.BAD_REQUEST, "Invalid request path", e); req.close(HttpStatusException.of(HttpStatus.BAD_REQUEST, e)); } else { - fail(id, null, HttpStatus.BAD_REQUEST, Http2Error.CANCEL, "Invalid request path", e); + fail(id, null, HttpStatus.BAD_REQUEST, "Invalid request path", e); } } catch (Throwable t) { if (req != null) { - fail(id, req.headers(), HttpStatus.INTERNAL_SERVER_ERROR, Http2Error.INTERNAL_ERROR, null, t); + fail(id, req.headers(), HttpStatus.INTERNAL_SERVER_ERROR, null, t); req.close(HttpStatusException.of(HttpStatus.INTERNAL_SERVER_ERROR, t)); } else { - fail(id, null, HttpStatus.INTERNAL_SERVER_ERROR, Http2Error.INTERNAL_ERROR, null, t); + fail(id, null, HttpStatus.INTERNAL_SERVER_ERROR, null, t); logger.warn("Unexpected exception:", t); } } finally { @@ -350,12 +348,16 @@ private boolean handle100Continue(int id, HttpRequest nettyReq) { return true; } - private void fail(int id, @Nullable RequestHeaders headers, HttpStatus status, Http2Error error, + private void failWithInvalidRequestPath(int id, @Nullable RequestHeaders headers) { + fail(id, headers, HttpStatus.BAD_REQUEST, "Invalid request path", null); + } + + private void fail(int id, @Nullable RequestHeaders headers, HttpStatus status, @Nullable String message, @Nullable Throwable cause) { if (encoder.isResponseHeadersSent(id, 1)) { - // The response is sent or being sent by HttpResponseSubscriber so we cannot send + // The response is sent or being sent by HttpResponseSubscriber, so we cannot send // the error response. - encoder.writeReset(id, 1, error); + encoder.writeReset(id, 1, Http2Error.PROTOCOL_ERROR); } else { discarding = true; req = null; diff --git a/core/src/main/java/com/linecorp/armeria/server/Http2RequestDecoder.java b/core/src/main/java/com/linecorp/armeria/server/Http2RequestDecoder.java index ea02e6ec443..e70f8acbc70 100644 --- a/core/src/main/java/com/linecorp/armeria/server/Http2RequestDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/server/Http2RequestDecoder.java @@ -31,6 +31,7 @@ import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.stream.ClosedStreamException; @@ -38,7 +39,6 @@ import com.linecorp.armeria.internal.common.Http2GoAwayHandler; import com.linecorp.armeria.internal.common.InboundTrafficController; import com.linecorp.armeria.internal.common.KeepAliveHandler; -import com.linecorp.armeria.internal.common.PathAndQuery; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; @@ -119,25 +119,38 @@ public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers return; } - // Reject a request with an unsupported method. - final HttpMethod method = HttpMethod.tryParse(methodText.toString()); - if (method == null) { - writeErrorResponse(streamId, null, HttpStatus.METHOD_NOT_ALLOWED, "Unsupported method", null); + // Parse and normalize the request path. + final String path = nettyHeaders.path().toString(); + final RequestTarget reqTarget = RequestTarget.forServer(path); + if (reqTarget == null) { + writeInvalidRequestPathResponse(streamId, null); return; } - final PathAndQuery pathAndQuery = PathAndQuery.parse(nettyHeaders.path().toString()); - // Convert the Netty Http2Headers into Armeria RequestHeaders. final RequestHeaders headers = ArmeriaHttpUtil.toArmeriaRequestHeaders(ctx, nettyHeaders, endOfStream, - scheme, cfg, pathAndQuery); + scheme, cfg, reqTarget); - // Accept a CONNECT request only when it has a :protocol header, as defined in: - // https://datatracker.ietf.org/doc/html/rfc8441#section-4 - if (method == HttpMethod.CONNECT && !nettyHeaders.contains(HttpHeaderNames.PROTOCOL)) { - writeErrorResponse(streamId, headers, HttpStatus.METHOD_NOT_ALLOWED, - "Unsupported method", null); + // Reject a request with an unsupported method. + final HttpMethod method = headers.method(); + switch (method) { + case CONNECT: + // Accept a CONNECT request only when it has a :protocol header, as defined in: + // https://datatracker.ietf.org/doc/html/rfc8441#section-4 + if (!nettyHeaders.contains(HttpHeaderNames.PROTOCOL)) { + writeUnsupportedMethodResponse(streamId, headers); + return; + } + break; + case UNKNOWN: + writeUnsupportedMethodResponse(streamId, headers); + return; + } + + // Do not accept the request path '*' for a non-OPTIONS request. + if (method != HttpMethod.OPTIONS && "*".equals(path)) { + writeInvalidRequestPathResponse(streamId, headers); return; } @@ -162,7 +175,7 @@ public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers return; } - final RoutingContext routingCtx = newRoutingContext(cfg, ctx.channel(), headers, pathAndQuery); + final RoutingContext routingCtx = newRoutingContext(cfg, ctx.channel(), headers, reqTarget); if (routingCtx.status().routeMustExist()) { try { // Find the service that matches the path. @@ -188,9 +201,9 @@ public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers } else { if (!(req instanceof DecodedHttpRequestWriter)) { // Silently ignore the following HEADERS Frames of non-DecodedHttpRequestWriter. The request - // stream is closed when receiving the first HEADERS Frame and some responses might be sent - // already. - logger.debug("{} received a HEADERS Frame for an invalid stream: {}", ctx.channel(), streamId); + // stream is closed when receiving the first HEADERS frame, but the client might send + // more frames before realizing it. + logger.debug("{} Received a HEADERS frame for a finished stream: {}", ctx.channel(), streamId); return; } final HttpHeaders trailers = ArmeriaHttpUtil.toArmeria(nettyHeaders, true, endOfStream); @@ -255,18 +268,27 @@ public int onDataRead( int padding, boolean endOfStream) throws Http2Exception { keepAliveChannelRead(false); + final int dataLength = data.readableBytes(); final DecodedHttpRequest req = requests.get(streamId); + final boolean logInvalidStream; if (req == null) { - throw connectionError(PROTOCOL_ERROR, "received a DATA Frame for an unknown stream: %d", - streamId); + if (encoder == null || encoder.findStream(streamId) == null) { + throw connectionError(PROTOCOL_ERROR, "received a DATA frame for an unknown stream: %d", + streamId); + } else { + // Received a frame for the stream we rejected. + logInvalidStream = true; + } + } else { + // Silently ignore the following DATA Frames of non-DecodedHttpRequestWriter. + // The request stream is closed when receiving the HEADERS frame, but the client might send + // more frames before realizing it. + logInvalidStream = !(req instanceof DecodedHttpRequestWriter); } - final int dataLength = data.readableBytes(); - if (!(req instanceof DecodedHttpRequestWriter)) { - // Silently ignore the following DATA Frames of non-DecodedHttpRequestWriter. The request stream is - // closed when receiving the HEADERS Frame and some responses might be sent already. - logger.debug("{} received a DATA Frame for an invalid stream: {}. headers: {}", - ctx.channel(), streamId, req.headers()); + if (logInvalidStream) { + logger.debug("{} Received a DATA frame for a finished stream: {} / headers: {}", + ctx.channel(), streamId, req != null ? req.headers() : ""); return dataLength + padding; } @@ -335,6 +357,16 @@ private static boolean isWritable(@Nullable Http2Stream stream) { } } + private void writeInvalidRequestPathResponse(int streamId, @Nullable RequestHeaders headers) { + writeErrorResponse(streamId, headers, HttpStatus.BAD_REQUEST, + "Invalid request path", null); + } + + private void writeUnsupportedMethodResponse(int streamId, RequestHeaders headers) { + writeErrorResponse(streamId, headers, HttpStatus.METHOD_NOT_ALLOWED, + "Unsupported method", null); + } + private void writeErrorResponse(int streamId, @Nullable RequestHeaders headers, HttpStatus status, @Nullable String message, @Nullable Throwable cause) { @@ -349,8 +381,15 @@ public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorC keepAliveChannelRead(false); final DecodedHttpRequest req = requests.get(streamId); if (req == null) { - throw connectionError(PROTOCOL_ERROR, - "received a RST_STREAM frame for an unknown stream: %d", streamId); + if (encoder == null || encoder.findStream(streamId) == null) { + throw connectionError(PROTOCOL_ERROR, + "received a RST_STREAM frame for an unknown stream: %d", streamId); + } else { + // Received a frame for the stream we rejected. + logger.debug("{} Received a RST_STREAM frame for a finished stream: {}", + ctx.channel(), streamId); + return; + } } final ClosedStreamException cause = diff --git a/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java b/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java index c30a3ce5637..6dbe41e631b 100644 --- a/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java @@ -63,8 +63,8 @@ import com.linecorp.armeria.common.util.SystemInfo; import com.linecorp.armeria.internal.common.AbstractHttp2ConnectionHandler; import com.linecorp.armeria.internal.common.Http1ObjectEncoder; -import com.linecorp.armeria.internal.common.PathAndQuery; import com.linecorp.armeria.internal.common.RequestContextUtil; +import com.linecorp.armeria.internal.common.RequestTargetCache; import com.linecorp.armeria.internal.server.DefaultServiceRequestContext; import io.netty.buffer.Unpooled; @@ -328,17 +328,14 @@ private void handleRequest(ChannelHandlerContext ctx, DecodedHttpRequest req) th if (!routingStatus.routeMustExist()) { final ServiceRequestContext reqCtx = newEarlyRespondingRequestContext(channel, req, proxiedAddresses, clientAddress, routingCtx); - switch (routingStatus) { - case OPTIONS: - // Handle 'OPTIONS * HTTP/1.1'. - handleOptions(ctx, reqCtx); - return; - case INVALID_PATH: - rejectInvalidPath(ctx, reqCtx); - return; - default: - throw new Error(); // Should never reach here. + + // Handle 'OPTIONS * HTTP/1.1'. + if (routingStatus == RoutingStatus.OPTIONS) { + handleOptions(ctx, reqCtx); + return; } + + throw new Error(); // Should never reach here. } // Find the service that matches the path. @@ -393,10 +390,8 @@ private void handleRequest(ChannelHandlerContext ctx, DecodedHttpRequest req) th if (service.shouldCachePath(routingCtx.path(), routingCtx.query(), routed.route())) { reqCtx.log().whenComplete().thenAccept(log -> { final int statusCode = log.responseHeaders().status().code(); - if (statusCode >= 200 && statusCode < 400 && routingCtx instanceof DefaultRoutingContext) { - final PathAndQuery pathAndQuery = ((DefaultRoutingContext) routingCtx).pathAndQuery(); - assert pathAndQuery != null; - pathAndQuery.storeInCache(req.path()); + if (statusCode >= 200 && statusCode < 400) { + RequestTargetCache.putForServer(req.path(), routingCtx.requestTarget()); } }); } diff --git a/core/src/main/java/com/linecorp/armeria/server/RoutingContext.java b/core/src/main/java/com/linecorp/armeria/server/RoutingContext.java index 084192e7e30..eb08575a197 100644 --- a/core/src/main/java/com/linecorp/armeria/server/RoutingContext.java +++ b/core/src/main/java/com/linecorp/armeria/server/RoutingContext.java @@ -25,8 +25,11 @@ import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.QueryParams; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestTarget; +import com.linecorp.armeria.common.RequestTargetForm; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.internal.common.DefaultRequestTarget; /** * Holds the parameters which are required to find a service available to handle the request. @@ -65,18 +68,29 @@ public HttpMethod method() { }; } + /** + * Returns the {@link RequestTarget} of the request. The form of the returned {@link RequestTarget} is + * never {@link RequestTargetForm#ABSOLUTE}, which means it is always {@link RequestTargetForm#ORIGIN} or + * {@link RequestTargetForm#ASTERISK}. + */ + RequestTarget requestTarget(); + /** * Returns the absolute path retrieved from the request, * as defined in RFC3986. */ - String path(); + default String path() { + return requestTarget().path(); + } /** * Returns the query retrieved from the request, * as defined in RFC3986. */ @Nullable - String query(); + default String query() { + return requestTarget().query(); + } /** * Returns the query parameters retrieved from the request path. @@ -120,18 +134,26 @@ public HttpMethod method() { HttpStatusException deferredStatusException(); /** - * Returns a wrapped {@link RoutingContext} which holds the specified {@code path}. + * (Advanced users only) Returns a wrapped {@link RoutingContext} which holds the specified {@code path}. * It is usually used to find an {@link HttpService} with a prefix-stripped path. + * Note that specifying a malformed or relative path will lead to unspecified behavior. */ default RoutingContext withPath(String path) { requireNonNull(path, "path"); - if (path.equals(path())) { - return this; - } + final RequestTarget oldReqTarget = requestTarget(); + final RequestTarget newReqTarget = + DefaultRequestTarget.createWithoutValidation( + oldReqTarget.form(), + oldReqTarget.scheme(), + oldReqTarget.authority(), + path, + oldReqTarget.query(), + oldReqTarget.fragment()); + return new RoutingContextWrapper(this) { @Override - public String path() { - return path; + public RequestTarget requestTarget() { + return newReqTarget; } }; } @@ -140,7 +162,7 @@ public String path() { * Returns a wrapped {@link RoutingContext} which holds the specified {@code path}. * It is usually used to find an {@link HttpService} with a prefix-stripped path. * - * @deprecated Use {@link #withPath}. + * @deprecated Use {@link #withPath(String)}. */ @Deprecated default RoutingContext overridePath(String path) { diff --git a/core/src/main/java/com/linecorp/armeria/server/RoutingContextWrapper.java b/core/src/main/java/com/linecorp/armeria/server/RoutingContextWrapper.java index 10a01d68fee..f8a32fc06e6 100644 --- a/core/src/main/java/com/linecorp/armeria/server/RoutingContextWrapper.java +++ b/core/src/main/java/com/linecorp/armeria/server/RoutingContextWrapper.java @@ -22,6 +22,7 @@ import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.QueryParams; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.annotation.Nullable; class RoutingContextWrapper implements RoutingContext { @@ -51,14 +52,19 @@ public HttpMethod method() { } @Override - public String path() { - return delegate.path(); + public RequestTarget requestTarget() { + return delegate.requestTarget(); + } + + @Override + public final String path() { + return RoutingContext.super.path(); } @Nullable @Override - public String query() { - return delegate.query(); + public final String query() { + return RoutingContext.super.query(); } @Override @@ -103,6 +109,7 @@ public RoutingContext withPath(String path) { } @Override + @Deprecated public boolean isCorsPreflight() { return delegate.isCorsPreflight(); } diff --git a/core/src/main/java/com/linecorp/armeria/server/RoutingStatus.java b/core/src/main/java/com/linecorp/armeria/server/RoutingStatus.java index 41b2c776884..08e49d07c7c 100644 --- a/core/src/main/java/com/linecorp/armeria/server/RoutingStatus.java +++ b/core/src/main/java/com/linecorp/armeria/server/RoutingStatus.java @@ -36,12 +36,7 @@ public enum RoutingStatus { /** * An {@code "OPTIONS * HTTP/1.1"} request. */ - OPTIONS(false), - - /** - * The request specified an invalid path. - */ - INVALID_PATH(false); + OPTIONS(false); private final boolean routeMustExist; diff --git a/core/src/main/java/com/linecorp/armeria/server/Server.java b/core/src/main/java/com/linecorp/armeria/server/Server.java index 2698b030eb9..0940239a067 100644 --- a/core/src/main/java/com/linecorp/armeria/server/Server.java +++ b/core/src/main/java/com/linecorp/armeria/server/Server.java @@ -69,14 +69,13 @@ import com.linecorp.armeria.common.Flags; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; -import com.linecorp.armeria.common.metric.MeterIdPrefix; import com.linecorp.armeria.common.util.EventLoopGroups; import com.linecorp.armeria.common.util.Exceptions; import com.linecorp.armeria.common.util.ListenableAsyncCloseable; import com.linecorp.armeria.common.util.ShutdownHooks; import com.linecorp.armeria.common.util.StartStopSupport; import com.linecorp.armeria.common.util.Version; -import com.linecorp.armeria.internal.common.PathAndQuery; +import com.linecorp.armeria.internal.common.RequestTargetCache; import com.linecorp.armeria.internal.common.util.ChannelUtil; import io.micrometer.core.instrument.Gauge; @@ -130,10 +129,8 @@ public static ServerBuilder builder() { startStop = new ServerStartStopSupport(config.startStopExecutor()); connectionLimitingHandler = new ConnectionLimitingHandler(config.maxNumConnections()); - // Server-wide cache metrics. - final MeterIdPrefix idPrefix = new MeterIdPrefix("armeria.server.parsed.path.cache"); - PathAndQuery.registerMetrics(config.meterRegistry(), idPrefix); - + // Server-wide metrics. + RequestTargetCache.registerServerMetrics(config.meterRegistry()); setupVersionMetrics(); for (VirtualHost virtualHost : config().virtualHosts()) { diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContextBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContextBuilder.java index 14683bbfbc7..0e76560bc97 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContextBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContextBuilder.java @@ -210,7 +210,7 @@ public ServiceRequestContext build() { if (route != null) { serviceBindingBuilder = serverBuilder.route().addRoute(route); } else { - serviceBindingBuilder = serverBuilder.route().path(path()); + serviceBindingBuilder = serverBuilder.route().path(requestTarget().path()); } if (defaultServiceNaming != null) { @@ -236,15 +236,17 @@ public ServiceRequestContext build() { final RoutingContext routingCtx = DefaultRoutingContext.of( server.config().defaultVirtualHost(), ((InetSocketAddress) localAddress()).getHostString(), - path(), - query(), + requestTarget(), req.headers(), RoutingStatus.OK); final RoutingResult routingResult = this.routingResult != null ? this.routingResult - : RoutingResult.builder().path(path()).query(query()).build(); - final Route route = Route.builder().path(path()).build(); + : RoutingResult.builder() + .path(requestTarget().path()) + .query(requestTarget().query()) + .build(); + final Route route = Route.builder().path(requestTarget().path()).build(); final Routed routed = Routed.of(route, routingResult, serviceCfg); routingCtx.setResult(routed); final ExchangeType exchangeType = service.exchangeType(routingCtx); diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceRouteUtil.java b/core/src/main/java/com/linecorp/armeria/server/ServiceRouteUtil.java index c1e03733155..fca51477580 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceRouteUtil.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceRouteUtil.java @@ -22,43 +22,34 @@ import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.RequestHeaders; -import com.linecorp.armeria.common.annotation.Nullable; -import com.linecorp.armeria.internal.common.PathAndQuery; +import com.linecorp.armeria.common.RequestTarget; import io.netty.channel.Channel; final class ServiceRouteUtil { static RoutingContext newRoutingContext(ServerConfig serverConfig, Channel channel, - RequestHeaders headers, - @Nullable PathAndQuery pathAndQuery) { + RequestHeaders headers, RequestTarget reqTarget) { final String hostname = hostname(headers); final int port = ((InetSocketAddress) channel.localAddress()).getPort(); final String originalPath = headers.path(); final RoutingStatus routingStatus; - if (pathAndQuery != null) { + if (headers.method() == HttpMethod.OPTIONS) { if (isCorsPreflightRequest(headers)) { routingStatus = RoutingStatus.CORS_PREFLIGHT; + } else if ("*".equals(originalPath)) { + routingStatus = RoutingStatus.OPTIONS; } else { routingStatus = RoutingStatus.OK; } } else { - if (headers.method() == HttpMethod.OPTIONS && "*".equals(originalPath)) { - routingStatus = RoutingStatus.OPTIONS; - } else { - routingStatus = RoutingStatus.INVALID_PATH; - } + routingStatus = RoutingStatus.OK; } - final VirtualHost virtualHost = serverConfig.findVirtualHost(hostname, port); - if (pathAndQuery == null) { - return DefaultRoutingContext.of(virtualHost, hostname, headers.path(), /* query */ null, headers, - routingStatus); - } else { - return DefaultRoutingContext.of(virtualHost, hostname, pathAndQuery, headers, routingStatus); - } + return DefaultRoutingContext.of(serverConfig.findVirtualHost(hostname, port), + hostname, reqTarget, headers, routingStatus); } private static String hostname(RequestHeaders headers) { diff --git a/core/src/main/java/com/linecorp/armeria/server/docs/MethodInfo.java b/core/src/main/java/com/linecorp/armeria/server/docs/MethodInfo.java index 3bcc7be30f4..b1a5cb5628f 100644 --- a/core/src/main/java/com/linecorp/armeria/server/docs/MethodInfo.java +++ b/core/src/main/java/com/linecorp/armeria/server/docs/MethodInfo.java @@ -33,9 +33,9 @@ import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; -import com.linecorp.armeria.internal.common.PathAndQuery; import com.linecorp.armeria.server.Service; /** @@ -147,9 +147,9 @@ public MethodInfo(String serviceName, String name, final ImmutableList.Builder examplePathsBuilder = ImmutableList.builderWithExpectedSize(Iterables.size(examplePaths)); for (String path : examplePaths) { - final PathAndQuery pathAndQuery = PathAndQuery.parse(path); - checkArgument(pathAndQuery != null, "examplePaths contains an invalid path: %s", path); - examplePathsBuilder.add(pathAndQuery.path()); + final RequestTarget reqTarget = RequestTarget.forServer(path); + checkArgument(reqTarget != null, "examplePaths contains an invalid path: %s", path); + examplePathsBuilder.add(reqTarget.path()); } this.examplePaths = examplePathsBuilder.build(); @@ -157,9 +157,9 @@ public MethodInfo(String serviceName, String name, final ImmutableList.Builder exampleQueriesBuilder = ImmutableList.builderWithExpectedSize(Iterables.size(exampleQueries)); for (String query : exampleQueries) { - final PathAndQuery pathAndQuery = PathAndQuery.parse('?' + query); - checkArgument(pathAndQuery != null, "exampleQueries contains an invalid query string: %s", query); - exampleQueriesBuilder.add(pathAndQuery.query()); + final RequestTarget reqTarget = RequestTarget.forServer("/?" + query); + checkArgument(reqTarget != null, "exampleQueries contains an invalid query string: %s", query); + exampleQueriesBuilder.add(reqTarget.query()); } this.exampleQueries = exampleQueriesBuilder.build(); diff --git a/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java b/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java index d89365e6478..be006e0fea2 100644 --- a/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java @@ -18,7 +18,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import java.net.URI; import java.util.function.Function; import org.junit.jupiter.api.Test; @@ -265,29 +264,7 @@ void hasOwnAttr() { } @ParameterizedTest - @ValueSource(strings = {"https://host.com/path?a=b", "http://host.com/path?a=b", - "http://1.2.3.4:8080/path?a=b"}) - void updateRequestWithAbsolutePath(String path) { - final ClientRequestContext clientRequestContext = clientRequestContext(); - assertThat(clientRequestContext.path()).isEqualTo("/"); - final HttpRequest request = - HttpRequest.of(RequestHeaders.of(HttpMethod.GET, path)); - - final URI uri = URI.create(path); - - clientRequestContext.updateRequest(request); - - // absolute path updates the authority, session protocol - assertThat(clientRequestContext.authority()).isEqualTo(uri.getAuthority()); - assertThat(clientRequestContext.sessionProtocol().toString()).isEqualTo(uri.getScheme()); - assertThat(clientRequestContext.path()).isEqualTo("/path"); - assertThat(clientRequestContext.query()).isEqualTo("a=b"); - assertThat(clientRequestContext.uri().toString()).isEqualTo(path); - assertThat(clientRequestContext.endpoint().authority()).isEqualTo(uri.getAuthority()); - } - - @ParameterizedTest - @ValueSource(strings = {"https:/path?a=b", "http:///"}) + @ValueSource(strings = {"%", "http:///", "http://foo.com/bar"}) void updateRequestWithInvalidPath(String path) { final ClientRequestContext clientRequestContext = clientRequestContext(); assertThat(clientRequestContext.path()).isEqualTo("/"); @@ -295,7 +272,8 @@ void updateRequestWithInvalidPath(String path) { HttpRequest.of(RequestHeaders.of(HttpMethod.GET, path)); assertThatThrownBy(() -> clientRequestContext.updateRequest(request)) - .isInstanceOf(IllegalArgumentException.class); + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("invalid path"); } private static void assertUnwrapAllCurrentCtx(@Nullable RequestContext ctx) { diff --git a/core/src/test/java/com/linecorp/armeria/client/HttpClientContextCaptorTest.java b/core/src/test/java/com/linecorp/armeria/client/HttpClientContextCaptorTest.java index 4103b2c0e43..db6d008a46f 100644 --- a/core/src/test/java/com/linecorp/armeria/client/HttpClientContextCaptorTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/HttpClientContextCaptorTest.java @@ -62,7 +62,7 @@ void connectionRefused() { void badPath() { try (ClientRequestContextCaptor ctxCaptor = Clients.newContextCaptor()) { // Send a request with a bad path. - final HttpResponse res = WebClient.of().get("http://127.0.0.1:1/|"); + final HttpResponse res = WebClient.of().get("http://127.0.0.1:1/%"); assertThatThrownBy(ctxCaptor::get).isInstanceOf(NoSuchElementException.class) .hasMessageContaining("no request was made"); res.aggregate(); diff --git a/core/src/test/java/com/linecorp/armeria/client/HttpClientWithRequestLogTest.java b/core/src/test/java/com/linecorp/armeria/client/HttpClientWithRequestLogTest.java index 67d25055aff..b5f2005e989 100644 --- a/core/src/test/java/com/linecorp/armeria/client/HttpClientWithRequestLogTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/HttpClientWithRequestLogTest.java @@ -76,6 +76,7 @@ void invalidPath() { WebClient.builder(LOCAL_HOST) .decorator((delegate, ctx, req) -> { final HttpRequest badReq = req.withHeaders(req.headers().toBuilder().path("/%")); + ctx.updateRequest(badReq); return delegate.execute(ctx, badReq); }) .decorator(new ExceptionHoldingDecorator()) @@ -83,14 +84,13 @@ void invalidPath() { final HttpRequest req = HttpRequest.of(HttpMethod.GET, "/"); assertThatThrownBy(() -> client.execute(req).aggregate().get()) - .hasCauseInstanceOf(UnprocessedRequestException.class) - .hasRootCauseExactlyInstanceOf(IllegalArgumentException.class) + .hasCauseInstanceOf(IllegalArgumentException.class) .hasMessageContaining("invalid path"); await().untilAsserted(() -> assertThat( - requestCauseHolder.get()).hasRootCauseExactlyInstanceOf(IllegalArgumentException.class)); + requestCauseHolder.get()).isExactlyInstanceOf(IllegalArgumentException.class)); await().untilAsserted(() -> assertThat( - responseCauseHolder.get()).hasRootCauseExactlyInstanceOf(IllegalArgumentException.class)); + responseCauseHolder.get()).isExactlyInstanceOf(IllegalArgumentException.class)); await().untilAsserted(() -> assertThat(req.isComplete()).isTrue()); } diff --git a/core/src/test/java/com/linecorp/armeria/internal/client/DefaultClientRequestContextTest.java b/core/src/test/java/com/linecorp/armeria/internal/client/DefaultClientRequestContextTest.java index b98989c2fe2..ad0caae1917 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/client/DefaultClientRequestContextTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/client/DefaultClientRequestContextTest.java @@ -43,8 +43,8 @@ import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.RequestId; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.SessionProtocol; -import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.metric.NoopMeterRegistry; import com.linecorp.armeria.common.util.SafeCloseable; import com.linecorp.armeria.common.util.SystemInfo; @@ -263,26 +263,13 @@ void requestUpdateAllComponents() { assertThat(ctx.uri().toString()).isEqualTo("http://example.com:8080/foo"); final HttpRequest request = HttpRequest.of(RequestHeaders.of( - HttpMethod.POST, "https://path.com/a/b/c", + HttpMethod.POST, "/a/b/c?q1=p1&q2=p2#fragment1", HttpHeaderNames.SCHEME, "http", HttpHeaderNames.AUTHORITY, "request.com")); ctx.updateRequest(request); - assertThat(ctx.sessionProtocol()).isEqualTo(SessionProtocol.HTTPS); assertThat(ctx.authority()).isEqualTo("request.com"); - assertThat(ctx.uri().toString()).isEqualTo("https://request.com/a/b/c"); - assertThat(ctx.endpoint().authority()).isEqualTo("path.com"); - } - - @Test - void uriIncludesAllComponents() { - final HttpRequest request = HttpRequest.of(RequestHeaders.of( - HttpMethod.POST, "https://path.com/a/b/c?q1=p1&q2=p2", - HttpHeaderNames.SCHEME, "http", - HttpHeaderNames.AUTHORITY, "request.com")); - final DefaultClientRequestContext ctx = newContext(ClientOptions.of(), request, "fragment1"); - ctx.updateRequest(request); - assertThat(ctx.uri().toString()).isEqualTo("https://request.com/a/b/c?q1=p1&q2=p2#fragment1"); - assertThat(ctx.endpoint().authority()).isEqualTo("path.com"); + assertThat(ctx.uri().toString()).isEqualTo("http://request.com/a/b/c?q1=p1&q2=p2#fragment1"); + assertThat(ctx.endpoint().authority()).isEqualTo("example.com:8080"); } @Test @@ -307,16 +294,12 @@ private static DefaultClientRequestContext newContext() { private static DefaultClientRequestContext newContext(ClientOptions clientOptions, HttpRequest httpRequest) { - return newContext(clientOptions, httpRequest, null); - } + final RequestTarget reqTarget = RequestTarget.forClient(httpRequest.path()); + assertThat(reqTarget).isNotNull(); - private static DefaultClientRequestContext newContext(ClientOptions clientOptions, - HttpRequest httpRequest, - @Nullable String fragment) { return new DefaultClientRequestContext( mock(EventLoop.class), NoopMeterRegistry.get(), SessionProtocol.H2C, - RequestId.random(), HttpMethod.POST, "/foo", null, fragment, - clientOptions, httpRequest, + RequestId.random(), HttpMethod.POST, reqTarget, clientOptions, httpRequest, null, RequestOptions.of(), new CancellationScheduler(0), System.nanoTime(), SystemInfo.currentTimeMicros()); } diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/ArmeriaHttpUtilTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/ArmeriaHttpUtilTest.java index cd7f2736d8f..64d523f2146 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/common/ArmeriaHttpUtilTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/common/ArmeriaHttpUtilTest.java @@ -23,6 +23,7 @@ import static com.linecorp.armeria.internal.common.ArmeriaHttpUtil.toNettyHttp1ServerHeaders; import static com.linecorp.armeria.internal.common.ArmeriaHttpUtil.toNettyHttp2ClientHeaders; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -49,6 +50,7 @@ import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.ResponseHeadersBuilder; import com.linecorp.armeria.server.Server; @@ -70,12 +72,6 @@ class ArmeriaHttpUtilTest { @Test void testConcatPaths() throws Exception { - assertThat(concatPaths(null, "a")).isEqualTo("/a"); - assertThat(concatPaths(null, "/a")).isEqualTo("/a"); - - assertThat(concatPaths("", "a")).isEqualTo("/a"); - assertThat(concatPaths("", "/a")).isEqualTo("/a"); - assertThat(concatPaths("/", "a")).isEqualTo("/a"); assertThat(concatPaths("/", "/a")).isEqualTo("/a"); assertThat(concatPaths("/", "/")).isEqualTo("/"); @@ -88,6 +84,11 @@ void testConcatPaths() throws Exception { assertThat(concatPaths("/a/", "")).isEqualTo("/a/"); assertThat(concatPaths("/a", "?foo=bar")).isEqualTo("/a?foo=bar"); assertThat(concatPaths("/a/", "?foo=bar")).isEqualTo("/a/?foo=bar"); + + // Bad prefixes + assertThatThrownBy(() -> concatPaths(null, "a")).isInstanceOf(NullPointerException.class); + assertThatThrownBy(() -> concatPaths("", "b")).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> concatPaths("relative", "c")).isInstanceOf(IllegalArgumentException.class); } @ParameterizedTest @@ -260,9 +261,9 @@ void inboundCookiesMustBeMergedForHttp2() { in.add(HttpHeaderNames.COOKIE, "i=j"); in.add(HttpHeaderNames.COOKIE, "k=l;"); - final PathAndQuery pathAndQuery = PathAndQuery.parse(in.path().toString()); + final RequestTarget reqTarget = RequestTarget.forServer(in.path().toString()); final RequestHeaders out = ArmeriaHttpUtil.toArmeriaRequestHeaders( - null, in, false, "http", null, pathAndQuery); + null, in, false, "http", null, reqTarget); assertThat(out.getAll(HttpHeaderNames.COOKIE)) .containsExactly("a=b; c=d; e=f;g=h; i=j; k=l;"); @@ -279,7 +280,7 @@ void addHostHeaderIfMissing() throws URISyntaxException { final ChannelHandlerContext ctx = mockChannelHandlerContext(); RequestHeaders armeriaHeaders = toArmeria(ctx, originReq, serverConfig(), "http", - PathAndQuery.parse(originReq.uri())); + RequestTarget.forServer(originReq.uri())); assertThat(armeriaHeaders.get(HttpHeaderNames.HOST)).isEqualTo("bar"); assertThat(armeriaHeaders.authority()).isEqualTo("bar"); assertThat(armeriaHeaders.scheme()).isEqualTo("http"); @@ -288,7 +289,7 @@ void addHostHeaderIfMissing() throws URISyntaxException { // Remove Host header. headers.remove(HttpHeaderNames.HOST); armeriaHeaders = toArmeria(ctx, originReq, serverConfig(), "https", - PathAndQuery.parse(originReq.uri())); + RequestTarget.forServer(originReq.uri())); assertThat(armeriaHeaders.get(HttpHeaderNames.HOST)).isEqualTo("foo:36462"); // The default hostname. assertThat(armeriaHeaders.authority()).isEqualTo("foo:36462"); assertThat(armeriaHeaders.scheme()).isEqualTo("https"); @@ -523,10 +524,10 @@ void toArmeriaRequestHeaders() { in.set(HttpHeaderNames.METHOD, "GET") .set(HttpHeaderNames.PATH, "/"); // Request headers without pseudo headers. - final PathAndQuery pathAndQuery = PathAndQuery.parse(in.path().toString()); + final RequestTarget reqTarget = RequestTarget.forServer(in.path().toString()); final RequestHeaders headers = ArmeriaHttpUtil.toArmeriaRequestHeaders(ctx, in, false, "https", - serverConfig(), pathAndQuery); + serverConfig(), reqTarget); assertThat(headers.scheme()).isEqualTo("https"); assertThat(headers.authority()).isEqualTo("foo:36462"); } diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/DefaultRequestTargetTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/DefaultRequestTargetTest.java new file mode 100644 index 00000000000..9bf83e68fd5 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/internal/common/DefaultRequestTargetTest.java @@ -0,0 +1,601 @@ +/* + * Copyright 2018 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.internal.common; + +import static com.google.common.base.Strings.emptyToNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Set; +import java.util.stream.Stream; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Ascii; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; + +import com.linecorp.armeria.common.QueryParams; +import com.linecorp.armeria.common.RequestTarget; +import com.linecorp.armeria.common.annotation.Nullable; + +class DefaultRequestTargetTest { + + private static final Logger logger = LoggerFactory.getLogger(DefaultRequestTargetTest.class); + + private static final Set QUERY_SEPARATORS = ImmutableSet.of("&", ";"); + + @Test + @SuppressWarnings("DataFlowIssue") + void shouldThrowNpeOnNull() { + assertThatThrownBy(() -> RequestTarget.forServer(null)) + .isInstanceOf(NullPointerException.class); + assertThatThrownBy(() -> RequestTarget.forClient(null)) + .isInstanceOf(NullPointerException.class); + } + + @Test + void serverShouldRejectEmptyPath() { + assertRejected(forServer("")); + } + + @Test + void serverShouldRejectRelativePath() { + assertRejected(forServer("foo")); + assertRejected(forServer("?")); + assertRejected(forServer("#")); + assertRejected(forServer("%2f")); // percent-encoded slash + assertRejected(forServer("%2F")); // percent-encoded slash + } + + @Test + void clientShouldAcceptRelativePath() { + assertAccepted(forClient(""), "/"); + assertAccepted(forClient("foo"), "/foo"); + assertAccepted(forClient("?foo"), "/", "foo"); + assertAccepted(forClient("#foo"), "/", null, "foo"); + assertAccepted(forClient("%2f"), "/%2F"); // percent-encoded slash + assertAccepted(forClient("%2F"), "/%2F"); // percent-encoded slash + } + + @Test + void clientShouldPrependPrefix() { + assertAccepted(forClient("", "/"), "/"); + assertAccepted(forClient("foo", "/"), "/foo"); + assertAccepted(forClient("foo", "/bar"), "/bar/foo"); + assertAccepted(forClient("foo", "/bar/"), "/bar/foo"); + assertAccepted(forClient("/foo", "/bar"), "/bar/foo"); + assertAccepted(forClient("/foo", "/bar/"), "/bar/foo"); + } + + @Test + void clientThrowsOnNonAbsolutePrefix() { + assertThatThrownBy(() -> forClient("/", "")) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> forClient("/", "relative-prefix")) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> forClient("/", "relative-prefix/")) + .isInstanceOf(IllegalArgumentException.class); + } + + @ParameterizedTest + @MethodSource("badDoubleDotPatterns") + void serverShouldRejectBadDoubleDotPatterns(String pattern) { + assertRejected(forServer(toAbsolutePath(pattern))); // in a path + assertRejected(forServer("/?" + pattern)); // in a free-form query + assertRejected(forServer("/?" + pattern + "=foo")); // in a query name + assertRejected(forServer("/?foo=" + pattern)); // in a query value + + QUERY_SEPARATORS.forEach(qs -> { + // Query names and values that appear in the middle: + assertRejected(forServer("/?a=b" + qs + pattern + "=c" + qs + "d=e")); + assertRejected(forServer("/?a=b" + qs + "c=" + pattern + qs + "d=e")); + // Query names and values appear lastly: + assertRejected(forServer("/?a=b" + qs + pattern + "=c")); + assertRejected(forServer("/?a=b" + qs + "c=" + pattern)); + }); + } + + @ParameterizedTest + @MethodSource("goodDoubleDotPatterns") + void serverShouldAcceptGoodDoubleDotPatterns(String pattern) { + assertThat(forServer(toAbsolutePath(pattern))).isNotNull(); // in a path + assertThat(forServer("/?" + pattern)).isNotNull(); // in a free-form query + assertThat(forServer("/?" + pattern + "=foo")).isNotNull(); // in a query name + assertThat(forServer("/?foo=" + pattern)).isNotNull(); // in a query value + + QUERY_SEPARATORS.forEach(qs -> { + // Query names and values that appear in the middle: + assertThat(forServer("/?a=b" + qs + pattern + "=c" + qs + "d=e")).isNotNull(); + assertThat(forServer("/?a=b" + qs + "c=" + pattern + qs + "d=e")).isNotNull(); + // Query names and values appear lastly: + assertThat(forServer("/?a=b" + qs + pattern + "=c")).isNotNull(); + assertThat(forServer("/?a=b" + qs + "c=" + pattern)).isNotNull(); + }); + } + + /** + * {@link RequestTarget} treats the first `=` in a query parameter as `/` internally to simplify + * the detection the logic. This test makes sure the `=` appeared later is not treated as `/`. + */ + @Test + void dotsAndEqualsInNameValueQuery() { + QUERY_SEPARATORS.forEach(qs -> { + assertThat(forServer("/?a=..=" + qs + "b=..=")).satisfies(res -> { + assertThat(res).isNotNull(); + assertThat(res.query()).isEqualTo("a=..=" + qs + "b=..="); + assertThat(QueryParams.fromQueryString(res.query(), true)).containsExactly( + Maps.immutableEntry("a", "..="), + Maps.immutableEntry("b", "..=") + ); + }); + + assertThat(forServer("/?a==.." + qs + "b==..")).satisfies(res -> { + assertThat(res).isNotNull(); + assertThat(res.query()).isEqualTo("a==.." + qs + "b==.."); + assertThat(QueryParams.fromQueryString(res.query(), true)).containsExactly( + Maps.immutableEntry("a", "=.."), + Maps.immutableEntry("b", "=..") + ); + }); + + assertThat(forServer("/?a==..=" + qs + "b==..=")).satisfies(res -> { + assertThat(res).isNotNull(); + assertThat(res.query()).isEqualTo("a==..=" + qs + "b==..="); + assertThat(QueryParams.fromQueryString(res.query(), true)).containsExactly( + Maps.immutableEntry("a", "=..="), + Maps.immutableEntry("b", "=..=") + ); + }); + }); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldRejectInvalidPercentEncoding(Mode mode) { + assertRejected(parse(mode, "/%")); + assertRejected(parse(mode, "/%0")); + assertRejected(parse(mode, "/%0X")); + assertRejected(parse(mode, "/%X0")); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldRejectControlChars(Mode mode) { + assertRejected(parse(mode, "/\0")); + assertRejected(parse(mode, "/a\nb")); + assertRejected(parse(mode, "/a\u007fb")); + + // Escaped + assertRejected(parse(mode, "/%00")); + assertRejected(parse(mode, "/a%09b")); + assertRejected(parse(mode, "/a%0ab")); + assertRejected(parse(mode, "/a%0db")); + assertRejected(parse(mode, "/a%7fb")); + + // With query string + assertRejected(parse(mode, "/\0?c")); + assertRejected(parse(mode, "/a\tb?c")); + assertRejected(parse(mode, "/a\nb?c")); + assertRejected(parse(mode, "/a\rb?c")); + assertRejected(parse(mode, "/a\u007fb?c")); + + // With query string with control chars + assertRejected(parse(mode, "/?\0")); + assertRejected(parse(mode, "/?%00")); + assertRejected(parse(mode, "/?a\u007fb")); + assertRejected(parse(mode, "/?a%7Fb")); + + // However, 0x0A, 0x0D, 0x09 should be accepted in a query string. + assertAccepted(parse(mode, "/?a\tb"), "/", "a%09b"); + assertAccepted(parse(mode, "/?a\nb"), "/", "a%0Ab"); + assertAccepted(parse(mode, "/?a\rb"), "/", "a%0Db"); + assertAccepted(parse(mode, "/?a%09b"), "/", "a%09b"); + assertAccepted(parse(mode, "/?a%0Ab"), "/", "a%0Ab"); + assertAccepted(parse(mode, "/?a%0Db"), "/", "a%0Db"); + + if (mode == Mode.CLIENT) { + // All sort of control characters should be rejected in a fragment. + assertRejected(forClient("/#\0")); + assertRejected(forClient("/#%00")); + assertRejected(forClient("/#a\u007fb")); + assertRejected(forClient("/#a%7Fb")); + assertRejected(forClient("/#a\tb")); + assertRejected(forClient("/#a\nb")); + assertRejected(forClient("/#a\rb")); + assertRejected(forClient("/#a%09b")); + assertRejected(forClient("/#a%0Ab")); + assertRejected(forClient("/#a%0Db")); + } + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldAcceptPercentEncodedPercent(Mode mode) { + assertAccepted(parse(mode, "/%25"), "/%25"); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldNotDecodeSlash(Mode mode) { + assertAccepted(parse(mode, "/%2F?%2F"), "/%2F", "%2F"); // path & query + assertAccepted(parse(mode, "/foo%2F"), "/foo%2F"); // path only + assertAccepted(parse(mode, "/?%2f=%2F"), "/", "%2F=%2F"); // query only + } + + @Test + void serverShouldCleanUpConsecutiveSlashes() { + assertAccepted( + forServer("/path//with///consecutive////slashes" + + "?/query//with///consecutive////slashes"), + "/path/with/consecutive/slashes", + "/query//with///consecutive////slashes"); + + // Encoded slashes should be retained. + assertAccepted( + forServer("/path%2F/with/%2F/consecutive//%2F%2Fslashes" + + "?/query%2F/with/%2F/consecutive//%2F%2Fslashes"), + "/path%2F/with/%2F/consecutive/%2F%2Fslashes", + "/query%2F/with/%2F/consecutive//%2F%2Fslashes"); + } + + @Test + void clientShouldNotCleanUpConsecutiveSlashes() { + assertAccepted( + forClient("/path//with///consecutive////slashes" + + "?/query//with///consecutive////slashes" + + "#/fragment//with///consecutive////slashes"), + "/path//with///consecutive////slashes", + "/query//with///consecutive////slashes", + "/fragment//with///consecutive////slashes"); + + // Encoded slashes should be retained. + assertAccepted( + forClient("/path%2F/with/%2F/consecutive//%2F%2Fslashes" + + "?/query%2F/with/%2F/consecutive//%2F%2Fslashes" + + "#/fragment%2F/with/%2F/consecutive//%2F%2Fslashes"), + "/path%2F/with/%2F/consecutive//%2F%2Fslashes", + "/query%2F/with/%2F/consecutive//%2F%2Fslashes", + "/fragment%2F/with/%2F/consecutive//%2F%2Fslashes"); + } + + @Test + void clientShouldRetainConsecutiveSlashesInFragment() { + assertAccepted(forClient("/#/////"), "/", null, "/////"); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldAcceptColon(Mode mode) { + assertThat(parse(mode, "/:")).isNotNull(); + assertThat(parse(mode, "/:/")).isNotNull(); + assertThat(parse(mode, "/a/:")).isNotNull(); + assertThat(parse(mode, "/a/:/")).isNotNull(); + } + + @ParameterizedTest + @EnumSource(Mode.class) + @SuppressWarnings("checkstyle:AvoidEscapedUnicodeCharacters") + void shouldNormalizeUnicode(Mode mode) { + // 2- and 3-byte UTF-8 + assertAccepted(parse(mode, "/\u00A2?\u20AC"), "/%C2%A2", "%E2%82%AC"); + + // 4-byte UTF-8 + assertAccepted(parse(mode, "/\uD800\uDF48"), "/%F0%90%8D%88"); + + // 5- and 6-byte forms are only theoretically possible, so we won't test them here. + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldNormalizeEncodedUnicode(Mode mode) { + final String encodedPath = "/%ec%95%88"; + final String encodedQuery = "%eb%85%95"; + assertAccepted(parse(mode, encodedPath + '?' + encodedQuery), + Ascii.toUpperCase(encodedPath), + Ascii.toUpperCase(encodedQuery)); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldNotEncodeWhenUnnecessary(Mode mode) { + assertAccepted(parse(mode, "/a?b=c"), "/a", "b=c"); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldNormalizeSpace(Mode mode) { + assertAccepted(parse(mode, "/ ? "), "/%20", "+"); + assertAccepted(parse(mode, "/%20?%20"), "/%20", "+"); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldNormalizePlusSign(Mode mode) { + assertAccepted(parse(mode, "/+?a+b=c+d"), "/+", "a+b=c+d"); + assertAccepted(parse(mode, "/%2b?a%2bb=c%2bd"), "/+", "a%2Bb=c%2Bd"); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldNormalizeAmpersand(Mode mode) { + assertAccepted(parse(mode, "/&?a=1&a=2&b=3"), "/&", "a=1&a=2&b=3"); + assertAccepted(parse(mode, "/%26?a=1%26a=2&b=3"), "/&", "a=1%26a=2&b=3"); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldNormalizeSemicolon(Mode mode) { + assertAccepted(parse(mode, "/;?a=b;c=d"), "/;", "a=b;c=d"); + // '%3B' in a query string should never be decoded into ';'. + assertAccepted(parse(mode, "/%3b?a=b%3Bc=d"), "/;", "a=b%3Bc=d"); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldNormalizeEqualSign(Mode mode) { + assertAccepted(parse(mode, "/=?a=b=1"), "/=", "a=b=1"); + // '%3D' in a query string should never be decoded into '='. + assertAccepted(parse(mode, "/%3D?a%3db=1"), "/=", "a%3Db=1"); + } + + @Test + void serverShouldNormalizePoundSign() { + // '#' must be encoded into '%23'. + assertAccepted(forServer("/#?a=b#1"), "/%23", "a=b%231"); + + // '%23' should never be decoded into '#'. + assertAccepted(forServer("/%23?a=b%231"), "/%23", "a=b%231"); + } + + @Test + void clientShouldTreatPoundSignAsFragment() { + // '#' must be treated as a fragment marker. + assertAccepted(forClient("/?a=b#1"), "/", "a=b", "1"); + assertAccepted(forClient("/#?a=b#1"), "/", null, "?a=b%231"); + + // '%23' should never be treated as a fragment marker. + assertAccepted(forClient("/%23?a=b%231"), "/%23", "a=b%231"); + } + + @Test + void serverShouldHandleReservedCharacters() { + assertAccepted(forServer("/#/:@!$&'()*+,;=?a=/#/:[]@!$&'()*+,;="), + "/%23/:@!$&'()*+,;=", + "a=/%23/:[]@!$&'()*+,;="); + assertAccepted(forServer("/%23%2F%3A%40%21%24%26%27%28%29%2A%2B%2C%3B%3D%3F" + + "?a=%23%2F%3A%5B%5D%40%21%24%26%27%28%29%2A%2B%2C%3B%3D%3F"), + "/%23%2F:@!$&'()*+,;=?", + "a=%23%2F%3A%5B%5D%40%21%24%26%27%28%29%2A%2B%2C%3B%3D%3F"); + } + + @Test + void clientShouldHandleReservedCharacters() { + assertAccepted(forClient("/:@!$&'()*+,;=?a=/:[]@!$&'()*+,;=#/:@!$&'()*+,;="), + "/:@!$&'()*+,;=", + "a=/:[]@!$&'()*+,;=", + "/:@!$&'()*+,;="); + assertAccepted(forClient("/%23%2F%3A%40%21%24%26%27%28%29%2A%2B%2C%3B%3D%3F" + + "?a=%23%2F%3A%5B%5D%40%21%24%26%27%28%29%2A%2B%2C%3B%3D%3F" + + "#%23%2F%3A%40%21%24%26%27%28%29%2A%2B%2C%3B%3D%3F"), + "/%23%2F:@!$&'()*+,;=?", + "a=%23%2F%3A%5B%5D%40%21%24%26%27%28%29%2A%2B%2C%3B%3D%3F", + "%23%2F:@!$&'()*+,;=?"); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldHandleDoubleQuote(Mode mode) { + assertAccepted(parse(mode, "/\"?\""), "/%22", "%22"); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldHandleSquareBracketsInPath(Mode mode) { + assertAccepted(parse(mode, "/@/:[]!$&'()*+,;="), "/@/:%5B%5D!$&'()*+,;="); + assertAccepted(parse(mode, "/%40%2F%3A%5B%5D%21%24%26%27%28%29%2A%2B%2C%3B%3D%3F"), + "/@%2F:%5B%5D!$&'()*+,;=?"); + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldAcceptAsteriskPath(Mode mode) { + assertAccepted(parse(mode, "*"), "*"); + } + + @ParameterizedTest + @CsvSource({ + "a://b/, a, b, /,,", + "a://b/c, a, b, /c,,", + "a://b/c?d, a, b, /c, d,", + "a://b/c#d, a, b, /c,, d", + "a://b/c?d#e, a, b, /c, d, e", + "a://b/c#d?e, a, b, /c,, d?e", + // Empty path + "a://b, a, b, /,,", + "a://b?c, a, b, /, c,", + "a://b#c, a, b, /,, c", + "a://b?c#d, a, b, /, c, d", + "a://b#c?d, a, b, /,, c?d", + // Userinfo and port in authority + "a://b@c:80, a, b@c:80, /,,", + // IP addresses + "a://127.0.0.1/, a, 127.0.0.1, /,,", + "a://[::1]:80/, a, [::1]:80, /,,", + }) + void clientShouldAcceptAbsoluteUri(String uri, + String expectedScheme, String expectedAuthority, String expectedPath, + @Nullable String expectedQuery, @Nullable String expectedFragment) { + + final RequestTarget res = forClient(uri); + assertThat(res.scheme()).isEqualTo(expectedScheme); + assertThat(res.authority()).isEqualTo(expectedAuthority); + assertAccepted(res, expectedPath, emptyToNull(expectedQuery), emptyToNull(expectedFragment)); + } + + @Test + void serverShouldRejectAbsoluteUri() { + assertRejected(forServer("http://foo/bar")); + } + + @Test + void clientShouldRejectInvalidSchemeOrAuthority() { + assertRejected(forClient("ht%tp://acme")); // bad scheme + assertRejected(forClient("http://[acme")); // bad authority + assertRejected(forClient("http:///")); // empty authority + } + + @ParameterizedTest + @EnumSource(Mode.class) + void shouldYieldEmptyStringForEmptyQueryAndFragment(Mode mode) { + assertAccepted(parse(mode, "/?"), "/", ""); + if (mode == Mode.CLIENT) { + assertAccepted(forClient("/#"), "/", null, ""); + assertAccepted(forClient("/?#"), "/", "", ""); + } + } + + @ParameterizedTest + @EnumSource(Mode.class) + void testToString(Mode mode) { + assertThat(parse(mode, "/")).asString().isEqualTo("/"); + assertThat(parse(mode, "/?")).asString().isEqualTo("/?"); + assertThat(parse(mode, "/?a=b")).asString().isEqualTo("/?a=b"); + + if (mode == Mode.CLIENT) { + assertThat(forClient("/#")).asString().isEqualTo("/#"); + assertThat(forClient("/?#")).asString().isEqualTo("/?#"); + assertThat(forClient("/?a=b#c=d")).asString().isEqualTo("/?a=b#c=d"); + assertThat(forClient("http://foo/bar?a=b#c=d")).asString().isEqualTo("http://foo/bar?a=b#c=d"); + } + } + + private static void assertAccepted(@Nullable RequestTarget res, String expectedPath) { + assertAccepted(res, expectedPath, null, null); + } + + private static void assertAccepted(@Nullable RequestTarget res, + String expectedPath, + @Nullable String expectedQuery) { + assertAccepted(res, expectedPath, expectedQuery, null); + } + + private static void assertAccepted(@Nullable RequestTarget res, + String expectedPath, + @Nullable String expectedQuery, + @Nullable String expectedFragment) { + assertThat(res).isNotNull(); + assertThat(res.path()).isEqualTo(expectedPath); + assertThat(res.query()).isEqualTo(expectedQuery); + assertThat(res.fragment()).isEqualTo(expectedFragment); + } + + private static void assertRejected(@Nullable RequestTarget res) { + assertThat(res).isNull(); + } + + @Nullable + private static RequestTarget parse(Mode mode, String rawPath) { + switch (mode) { + case SERVER: + return forServer(rawPath); + case CLIENT: + return forClient(rawPath); + default: + throw new Error(); + } + } + + @Nullable + private static RequestTarget forServer(String rawPath) { + return forServer(rawPath, false); + } + + @Nullable + private static RequestTarget forServer(String rawPath, boolean allowDoubleDotsInQueryString) { + final RequestTarget res = DefaultRequestTarget.forServer(rawPath, allowDoubleDotsInQueryString); + if (res != null) { + logger.info("forServer({}) => path: {}, query: {}", rawPath, res.path(), res.query()); + } else { + logger.info("forServer({}) => null", rawPath); + } + return res; + } + + @Nullable + private static RequestTarget forClient(String rawPath) { + return forClient(rawPath, null); + } + + @Nullable + private static RequestTarget forClient(String rawPath, @Nullable String prefix) { + final RequestTarget res = DefaultRequestTarget.forClient(rawPath, prefix); + if (res != null) { + logger.info("forClient({}, {}) => path: {}, query: {}, fragment: {}", rawPath, prefix, res.path(), + res.query(), res.fragment()); + } else { + logger.info("forClient({}, {}) => null", rawPath, prefix); + } + return res; + } + + private static String toAbsolutePath(String pattern) { + return pattern.startsWith("/") ? pattern : '/' + pattern; + } + + private enum Mode { + SERVER, + CLIENT + } + + private static Stream badDoubleDotPatterns() { + return Stream.of( + "..", "/..", "../", "/../", + "../foo", "/../foo", + "foo/..", "/foo/..", + "foo/../", "/foo/../", + "foo/../bar", "/foo/../bar", + + // Dots escaped + ".%2e", "/.%2e", "%2E./", "/%2E./", ".%2E/", "/.%2E/", + "foo/.%2e", "/foo/.%2e", + "foo/%2E./", "/foo/%2E./", + "foo/%2E./bar", "/foo/%2E./bar", + + // Slashes escaped + "%2f..", "..%2F", "/..%2F", "%2F../", "%2f..%2f", + "/foo%2f..", "/foo%2f../", "/foo/..%2f", "/foo%2F..%2F", + + // Dots and slashes escaped + ".%2E%2F" + ); + } + + private static Stream goodDoubleDotPatterns() { + return Stream.of( + "..a", "a..", "a..b", + "/..a", "/a..", "/a..b", + "..a/", "a../", "a..b/", + "/..a/", "/a../", "/a..b/" + ); + } +} diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/PathAndQueryTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/PathAndQueryTest.java deleted file mode 100644 index f0f8b1647ce..00000000000 --- a/core/src/test/java/com/linecorp/armeria/internal/common/PathAndQueryTest.java +++ /dev/null @@ -1,572 +0,0 @@ -/* - * Copyright 2018 LINE Corporation - * - * LINE Corporation licenses this file to you under the Apache License, - * version 2.0 (the "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at: - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - */ -package com.linecorp.armeria.internal.common; - -import static com.linecorp.armeria.internal.common.PathAndQuery.decodePercentEncodedQuery; -import static org.assertj.core.api.Assertions.assertThat; - -import java.util.Set; - -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.google.common.base.Ascii; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Maps; - -import com.linecorp.armeria.common.QueryParams; -import com.linecorp.armeria.common.annotation.Nullable; - -class PathAndQueryTest { - - private static final Logger logger = LoggerFactory.getLogger(PathAndQueryTest.class); - - private static final Set QUERY_SEPARATORS = ImmutableSet.of("&", ";"); - - private static final Set BAD_DOUBLE_DOT_PATTERNS = ImmutableSet.of( - "..", "/..", "../", "/../", - "../foo", "/../foo", - "foo/..", "/foo/..", - "foo/../", "/foo/../", - "foo/../bar", "/foo/../bar", - - // Dots escaped - ".%2e", "/.%2e", "%2E./", "/%2E./", ".%2E/", "/.%2E/", - "foo/.%2e", "/foo/.%2e", - "foo/%2E./", "/foo/%2E./", - "foo/%2E./bar", "/foo/%2E./bar", - - // Slashes escaped - "%2f..", "..%2F", "/..%2F", "%2F../", "%2f..%2f", - "/foo%2f..", "/foo%2f../", "/foo/..%2f", "/foo%2F..%2F", - - // Dots and slashes escaped - ".%2E%2F" - ); - - private static final Set GOOD_DOUBLE_DOT_PATTERNS = ImmutableSet.of( - "..a", "a..", "a..b", - "/..a", "/a..", "/a..b", - "..a/", "a../", "a..b/", - "/..a/", "/a../", "/a..b/" - ); - - @Test - void empty() { - final PathAndQuery res = parse(null); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/"); - assertThat(res.query()).isNull(); - - final PathAndQuery res2 = parse(""); - assertThat(res2).isNotNull(); - assertThat(res2.path()).isEqualTo("/"); - assertThat(res2.query()).isNull(); - - final PathAndQuery res3 = parse("?"); - assertThat(res3).isNotNull(); - assertThat(res3.path()).isEqualTo("/"); - assertThat(res3.query()).isEqualTo(""); - } - - @Test - void relative() { - assertThat(parse("foo")).isNull(); - } - - @Test - void doubleDotsInPath() { - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> assertProhibited(pattern)); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - final String path = pattern.startsWith("/") ? pattern : '/' + pattern; - final PathAndQuery res = parse(path); - assertThat(res).as("Ensure %s is allowed.", path).isNotNull(); - assertThat(res.path()).as("Ensure %s is parsed as-is.", path).isEqualTo(path); - }); - } - - @Test - void doubleDotsInFreeFormQuery() { - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertProhibited("/?" + pattern); - }); - - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?" + pattern, pattern); - }); - } - - @Test - void prohibitDoubleDotsInNameValueQuery() { - // Dots in a query param name. - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertProhibited("/?" + pattern + "=foo"); - }); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?" + pattern + "=foo"); - }); - - // Dots in a query param value. - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertProhibited("/?foo=" + pattern); - }); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?foo=" + pattern); - }); - - QUERY_SEPARATORS.forEach(qs -> { - // Dots in the second query param name. - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertProhibited("/?a=b" + qs + pattern + "=c"); - }); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?a=b" + qs + pattern + "=c"); - }); - - // Dots in the second query param value. - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertProhibited("/?a=b" + qs + "c=" + pattern); - }); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?a=b" + qs + "c=" + pattern); - }); - - // Dots in the name of the query param in the middle. - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertProhibited("/?a=b" + qs + pattern + "=c" + qs + "d=e"); - }); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?a=b" + qs + pattern + "=c" + qs + "d=e"); - }); - - // Dots in the value of the query param in the middle. - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertProhibited("/?a=b" + qs + "c=" + pattern + qs + "d=e"); - }); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?a=b" + qs + "c=" + pattern + qs + "d=e"); - }); - }); - } - - @Test - void allowDoubleDotsInNameValueQuery() { - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?" + pattern, decodePercentEncodedQuery(pattern), true); - }); - - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?" + pattern, pattern, true); - }); - - // Dots in a query param name. - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - final String query = pattern + "=foo"; - assertQueryStringAllowed("/?" + query, decodePercentEncodedQuery(query), true); - }); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?" + pattern + "=foo", true); - }); - - // Dots in a query param value. - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - final String query = "foo=" + pattern; - assertQueryStringAllowed("/?" + query, decodePercentEncodedQuery(query), true); - }); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?foo=" + pattern, true); - }); - - QUERY_SEPARATORS.forEach(qs -> { - // Dots in the second query param name. - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - final String query = "a=b" + qs + pattern + "=c"; - assertQueryStringAllowed("/?" + query, decodePercentEncodedQuery(query), true); - }); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?a=b" + qs + pattern + "=c", true); - }); - - // Dots in the second query param value. - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - final String query = "a=b" + qs + "c=" + pattern; - assertQueryStringAllowed("/?" + query, decodePercentEncodedQuery(query), true); - }); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?a=b" + qs + "c=" + pattern, true); - }); - - // Dots in the name of the query param in the middle. - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - final String query = "a=b" + qs + pattern + "=c" + qs + "d=e"; - assertQueryStringAllowed("/?" + query, decodePercentEncodedQuery(query), true); - }); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?a=b" + qs + pattern + "=c" + qs + "d=e"); - }); - - // Dots in the value of the query param in the middle. - BAD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - final String query = "a=b" + qs + "c=" + pattern + qs + "d=e"; - assertQueryStringAllowed("/?" + query, decodePercentEncodedQuery(query), true); - }); - GOOD_DOUBLE_DOT_PATTERNS.forEach(pattern -> { - assertQueryStringAllowed("/?a=b" + qs + "c=" + pattern + qs + "d=e", true); - }); - }); - } - - /** - * {@link PathAndQuery} treats the first `=` in a query parameter as `/` internally to simplify - * the detection the logic. This test makes sure the `=` appeared later is not treated as `/`. - */ - @Test - void dotsAndEqualsInNameValueQuery() { - QUERY_SEPARATORS.forEach(qs -> { - final PathAndQuery res = parse("/?a=..=" + qs + "b=..="); - assertThat(res).isNotNull(); - assertThat(res.query()).isEqualTo("a=..=" + qs + "b=..="); - assertThat(QueryParams.fromQueryString(res.query(), true)).containsExactly( - Maps.immutableEntry("a", "..="), - Maps.immutableEntry("b", "..=") - ); - - final PathAndQuery res2 = parse("/?a==.." + qs + "b==.."); - assertThat(res2).isNotNull(); - assertThat(res2.query()).isEqualTo("a==.." + qs + "b==.."); - assertThat(QueryParams.fromQueryString(res2.query(), true)).containsExactly( - Maps.immutableEntry("a", "=.."), - Maps.immutableEntry("b", "=..") - ); - - final PathAndQuery res3 = parse("/?a==..=" + qs + "b==..="); - assertThat(res3).isNotNull(); - assertThat(res3.query()).isEqualTo("a==..=" + qs + "b==..="); - assertThat(QueryParams.fromQueryString(res3.query(), true)).containsExactly( - Maps.immutableEntry("a", "=..="), - Maps.immutableEntry("b", "=..=") - ); - }); - } - - @Test - void hexadecimal() { - assertThat(parse("/%")).isNull(); - assertThat(parse("/%0")).isNull(); - assertThat(parse("/%0X")).isNull(); - assertThat(parse("/%X0")).isNull(); - } - - @Test - void controlChars() { - assertThat(parse("/\0")).isNull(); - assertThat(parse("/a\nb")).isNull(); - assertThat(parse("/a\u007fb")).isNull(); - - // Escaped - assertThat(parse("/%00")).isNull(); - assertThat(parse("/a%09b")).isNull(); - assertThat(parse("/a%0ab")).isNull(); - assertThat(parse("/a%0db")).isNull(); - assertThat(parse("/a%7fb")).isNull(); - - // With query string - assertThat(parse("/\0?c")).isNull(); - assertThat(parse("/a\tb?c")).isNull(); - assertThat(parse("/a\nb?c")).isNull(); - assertThat(parse("/a\rb?c")).isNull(); - assertThat(parse("/a\u007fb?c")).isNull(); - - // With query string with control chars - assertThat(parse("/?\0")).isNull(); - assertThat(parse("/?%00")).isNull(); - assertThat(parse("/?a\u007fb")).isNull(); - assertThat(parse("/?a%7Fb")).isNull(); - // However, 0x0A, 0x0D, 0x09 should be accepted in a query string. - assertThat(parse("/?a\tb").query()).isEqualTo("a%09b"); - assertThat(parse("/?a\nb").query()).isEqualTo("a%0Ab"); - assertThat(parse("/?a\rb").query()).isEqualTo("a%0Db"); - assertThat(parse("/?a%09b").query()).isEqualTo("a%09b"); - assertThat(parse("/?a%0Ab").query()).isEqualTo("a%0Ab"); - assertThat(parse("/?a%0Db").query()).isEqualTo("a%0Db"); - } - - @Test - void percent() { - final PathAndQuery res = parse("/%25"); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/%25"); - assertThat(res.query()).isNull(); - } - - @Test - void shouldNotDecodeSlash() { - final PathAndQuery res = parse("%2F?%2F"); - // Do not accept a relative path. - assertThat(res).isNull(); - final PathAndQuery res1 = parse("/%2F?%2F"); - assertThat(res1).isNotNull(); - assertThat(res1.path()).isEqualTo("/%2F"); - assertThat(res1.query()).isEqualTo("%2F"); - - final PathAndQuery pathOnly = parse("/foo%2F"); - assertThat(pathOnly).isNotNull(); - assertThat(pathOnly.path()).isEqualTo("/foo%2F"); - assertThat(pathOnly.query()).isNull(); - - final PathAndQuery queryOnly = parse("/?%2f=%2F"); - assertThat(queryOnly).isNotNull(); - assertThat(queryOnly.path()).isEqualTo("/"); - assertThat(queryOnly.query()).isEqualTo("%2F=%2F"); - } - - @Test - void consecutiveSlashes() { - final PathAndQuery res = parse( - "/path//with///consecutive////slashes?/query//with///consecutive////slashes"); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/path/with/consecutive/slashes"); - assertThat(res.query()).isEqualTo("/query//with///consecutive////slashes"); - - // Encoded slashes - final PathAndQuery res2 = parse( - "/path%2F/with/%2F/consecutive//%2F%2Fslashes?/query%2F/with/%2F/consecutive//%2F%2Fslashes"); - assertThat(res2).isNotNull(); - assertThat(res2.path()).isEqualTo("/path%2F/with/%2F/consecutive/%2F%2Fslashes"); - assertThat(res2.query()).isEqualTo("/query%2F/with/%2F/consecutive//%2F%2Fslashes"); - } - - @Test - void colon() { - assertThat(parse("/:")).isNotNull(); - assertThat(parse("/:/")).isNotNull(); - assertThat(parse("/a/:")).isNotNull(); - assertThat(parse("/a/:/")).isNotNull(); - } - - @Test - void rawUnicode() { - // 2- and 3-byte UTF-8 - final PathAndQuery res1 = parse("/\u00A2?\u20AC"); // ¢ and € - assertThat(res1).isNotNull(); - assertThat(res1.path()).isEqualTo("/%C2%A2"); - assertThat(res1.query()).isEqualTo("%E2%82%AC"); - - // 4-byte UTF-8 - final PathAndQuery res2 = parse("/\uD800\uDF48"); // 𐍈 - assertThat(res2).isNotNull(); - assertThat(res2.path()).isEqualTo("/%F0%90%8D%88"); - assertThat(res2.query()).isNull(); - - // 5- and 6-byte forms are only theoretically possible, so we won't test them here. - } - - @Test - void encodedUnicode() { - final String encodedPath = "/%ec%95%88"; - final String encodedQuery = "%eb%85%95"; - final PathAndQuery res = parse(encodedPath + '?' + encodedQuery); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo(Ascii.toUpperCase(encodedPath)); - assertThat(res.query()).isEqualTo(Ascii.toUpperCase(encodedQuery)); - } - - @Test - void noEncoding() { - final PathAndQuery res = parse("/a?b=c"); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/a"); - assertThat(res.query()).isEqualTo("b=c"); - } - - @Test - void space() { - final PathAndQuery res = parse("/ ? "); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/%20"); - assertThat(res.query()).isEqualTo("+"); - - final PathAndQuery res2 = parse("/%20?%20"); - assertThat(res2).isNotNull(); - assertThat(res2.path()).isEqualTo("/%20"); - assertThat(res2.query()).isEqualTo("+"); - } - - @Test - void plus() { - final PathAndQuery res = parse("/+?a+b=c+d"); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/+"); - assertThat(res.query()).isEqualTo("a+b=c+d"); - - final PathAndQuery res2 = parse("/%2b?a%2bb=c%2bd"); - assertThat(res2).isNotNull(); - assertThat(res2.path()).isEqualTo("/+"); - assertThat(res2.query()).isEqualTo("a%2Bb=c%2Bd"); - } - - @Test - void ampersand() { - final PathAndQuery res = parse("/&?a=1&a=2&b=3"); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/&"); - assertThat(res.query()).isEqualTo("a=1&a=2&b=3"); - - // '%26' in a query string should never be decoded into '&'. - final PathAndQuery res2 = parse("/%26?a=1%26a=2&b=3"); - assertThat(res2).isNotNull(); - assertThat(res2.path()).isEqualTo("/&"); - assertThat(res2.query()).isEqualTo("a=1%26a=2&b=3"); - } - - @Test - void semicolon() { - final PathAndQuery res = parse("/;?a=b;c=d"); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/;"); - assertThat(res.query()).isEqualTo("a=b;c=d"); - - // '%3B' in a query string should never be decoded into ';'. - final PathAndQuery res2 = parse("/%3b?a=b%3Bc=d"); - assertThat(res2).isNotNull(); - assertThat(res2.path()).isEqualTo("/;"); - assertThat(res2.query()).isEqualTo("a=b%3Bc=d"); - } - - @Test - void equal() { - final PathAndQuery res = parse("/=?a=b=1"); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/="); - assertThat(res.query()).isEqualTo("a=b=1"); - - // '%3D' in a query string should never be decoded into '='. - final PathAndQuery res2 = parse("/%3D?a%3db=1"); - assertThat(res2).isNotNull(); - assertThat(res2.path()).isEqualTo("/="); - assertThat(res2.query()).isEqualTo("a%3Db=1"); - } - - @Test - void sharp() { - // '#' must be encoded into '%23'. - final PathAndQuery res = parse("/#?a=b#1"); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/%23"); - assertThat(res.query()).isEqualTo("a=b%231"); - - // '%23' should never be decoded into '#'. - final PathAndQuery res2 = parse("/%23?a=b%231"); - assertThat(res2).isNotNull(); - assertThat(res2.path()).isEqualTo("/%23"); - assertThat(res2.query()).isEqualTo("a=b%231"); - } - - @Test - void allReservedCharacters() { - final PathAndQuery res = parse("/#/:@!$&'()*+,;=?a=/#/:[]@!$&'()*+,;="); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/%23/:@!$&'()*+,;="); - assertThat(res.query()).isEqualTo("a=/%23/:[]@!$&'()*+,;="); - - final PathAndQuery res2 = - parse("/%23%2F%3A%40%21%24%26%27%28%29%2A%2B%2C%3B%3D%3F" + - "?a=%23%2F%3A%5B%5D%40%21%24%26%27%28%29%2A%2B%2C%3B%3D%3F"); - assertThat(res2).isNotNull(); - assertThat(res2.path()).isEqualTo("/%23%2F:@!$&'()*+,;=?"); - assertThat(res2.query()).isEqualTo("a=%23%2F%3A%5B%5D%40%21%24%26%27%28%29%2A%2B%2C%3B%3D%3F"); - } - - @Test - void doubleQuote() { - final PathAndQuery res = parse("/\"?\""); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/%22"); - assertThat(res.query()).isEqualTo("%22"); - } - - private static void assertProhibited(String rawPath) { - assertThat(parse(rawPath)) - .as("Ensure parse(\"%s\") returns null.", rawPath) - .isNull(); - } - - private static void assertQueryStringAllowed(String rawPath) { - assertQueryStringAllowed(rawPath, false); - } - - private static void assertQueryStringAllowed(String rawPath, boolean allowDoubleDotsInQueryString) { - assertThat(rawPath).startsWith("/?"); - final String expectedQuery = rawPath.substring(2); - assertQueryStringAllowed(rawPath, expectedQuery, allowDoubleDotsInQueryString); - } - - private static void assertQueryStringAllowed(String rawPath, String expectedQuery) { - assertQueryStringAllowed(rawPath, expectedQuery, false); - } - - private static void assertQueryStringAllowed(String rawPath, String expectedQuery, - boolean allowDoubleDotsInQueryString) { - final PathAndQuery res = parse(rawPath, allowDoubleDotsInQueryString); - assertThat(res) - .as("parse(\"%s\") must return non-null.", rawPath) - .isNotNull(); - assertThat(res.query()) - .as("parse(\"%s\").query() must return \"%s\".", rawPath, expectedQuery) - .isEqualTo(expectedQuery); - } - - @Nullable - private static PathAndQuery parse(@Nullable String rawPath) { - return parse(rawPath, false); - } - - @Nullable - private static PathAndQuery parse(@Nullable String rawPath, boolean allowDoubleDotsInQueryString) { - final PathAndQuery res = PathAndQuery.parse(rawPath, allowDoubleDotsInQueryString); - if (res != null) { - logger.info("parse({}) => path: {}, query: {}", rawPath, res.path(), res.query()); - } else { - logger.info("parse({}) => null", rawPath); - } - return res; - } - - @Test - void assertSquareBracketsInPath() { - final PathAndQuery res = parse("/@/:[]!$&'()*+,;="); - assertThat(res).isNotNull(); - assertThat(res.path()).isEqualTo("/@/:%5B%5D!$&'()*+,;="); - - final PathAndQuery res2 = - parse("/%40%2F%3A%5B%5D%21%24%26%27%28%29%2A%2B%2C%3B%3D%3F"); - assertThat(res2).isNotNull(); - assertThat(res2.path()).isEqualTo("/@%2F:%5B%5D!$&'()*+,;=?"); - } - - @Test - void toStringAppendsQueryCorrectly() { - PathAndQuery res = parse("/"); - assertThat(res.toString()).isEqualTo("/"); - - res = parse("/?"); - assertThat(res.toString()).isEqualTo("/?"); - - res = parse("/?a=b"); - assertThat(res.toString()).isEqualTo("/?a=b"); - } -} diff --git a/core/src/test/java/com/linecorp/armeria/server/CachingRoutingContextTest.java b/core/src/test/java/com/linecorp/armeria/server/CachingRoutingContextTest.java index f59b8410617..02c23c1d08e 100644 --- a/core/src/test/java/com/linecorp/armeria/server/CachingRoutingContextTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/CachingRoutingContextTest.java @@ -18,14 +18,13 @@ import static com.linecorp.armeria.server.RoutingContextTest.virtualHost; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import org.junit.jupiter.api.Test; import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.QueryParams; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.server.RouteCache.CachingRoutingContext; class CachingRoutingContextTest { @@ -39,13 +38,22 @@ void disableMatchingQueryParamsByCachingRoutingContext() { .matchesParams("foo=bar") .build(); - final RoutingContext context = mock(RoutingContext.class); - when(context.path()).thenReturn("/test"); - when(context.method()).thenReturn(HttpMethod.GET); - when(context.params()).thenReturn(QueryParams.of("foo", "qux")); - when(context.requiresMatchingParamsPredicates()).thenReturn(true); - when(context.virtualHost()).thenReturn(virtualHost); + final RequestTarget reqTarget = RequestTarget.forServer("/test?foo=qux"); + assertThat(reqTarget).isNotNull(); + final RoutingContext context = + new RoutingContextWrapper(DefaultRoutingContext.of( + virtualHost, virtualHost.defaultHostname(), + reqTarget, + RequestHeaders.of(HttpMethod.GET, reqTarget.pathAndQuery()), + RoutingStatus.OK)) { + @Override + public boolean requiresMatchingParamsPredicates() { + return true; + } + }; + + assertThat(context.params()).isEqualTo(QueryParams.of("foo", "qux")); assertThat(route.apply(context, false).isPresent()).isFalse(); // Because of the query parameters. final CachingRoutingContext cachingContext = new CachingRoutingContext(context); @@ -61,13 +69,23 @@ void disableMatchingHeadersByCachingRoutingContext() { .matchesHeaders("foo=bar") .build(); - final RoutingContext context = mock(RoutingContext.class); - when(context.path()).thenReturn("/test"); - when(context.method()).thenReturn(HttpMethod.GET); - when(context.headers()).thenReturn(RequestHeaders.of(HttpMethod.GET, "/test", "foo", "qux")); - when(context.requiresMatchingHeadersPredicates()).thenReturn(true); - when(context.virtualHost()).thenReturn(virtualHost); + final RequestTarget reqTarget = RequestTarget.forServer("/test"); + assertThat(reqTarget).isNotNull(); + + final RoutingContext context = + new RoutingContextWrapper(DefaultRoutingContext.of( + virtualHost, virtualHost.defaultHostname(), + reqTarget, + RequestHeaders.of(HttpMethod.GET, reqTarget.pathAndQuery(), + "foo", "qux"), + RoutingStatus.OK)) { + @Override + public boolean requiresMatchingHeadersPredicates() { + return true; + } + }; + assertThat(context.headers().contains("foo", "qux")).isTrue(); assertThat(route.apply(context, false).isPresent()).isFalse(); // Because of HTTP headers. final CachingRoutingContext cachingContext = new CachingRoutingContext(context); diff --git a/core/src/test/java/com/linecorp/armeria/server/GlobPathMappingTest.java b/core/src/test/java/com/linecorp/armeria/server/GlobPathMappingTest.java index bca7ddc7366..ce36f50678e 100644 --- a/core/src/test/java/com/linecorp/armeria/server/GlobPathMappingTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/GlobPathMappingTest.java @@ -18,7 +18,6 @@ import static com.linecorp.armeria.server.RoutingContextTest.create; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.fail; import org.junit.jupiter.api.Test; @@ -75,13 +74,6 @@ void testRelativePattern() { mustFail("bar/baz", "/bar/baz/", "/foo/bar/baz/", "/foo/bar/baz/quo"); } - @Test - void testPathValidation() { - final Route route = glob("**"); - assertThatThrownBy(() -> route.apply(create("not/an/absolute/path"), false)) - .isInstanceOf(IllegalArgumentException.class); - } - @Test void params() throws Exception { Route route = glob("baz"); diff --git a/core/src/test/java/com/linecorp/armeria/server/HttpServerTest.java b/core/src/test/java/com/linecorp/armeria/server/HttpServerTest.java index e432539559a..1b37c4762c2 100644 --- a/core/src/test/java/com/linecorp/armeria/server/HttpServerTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/HttpServerTest.java @@ -86,7 +86,7 @@ import com.linecorp.armeria.common.stream.ClosedStreamException; import com.linecorp.armeria.common.util.EventLoopGroups; import com.linecorp.armeria.common.util.TimeoutMode; -import com.linecorp.armeria.internal.common.PathAndQuery; +import com.linecorp.armeria.internal.common.RequestTargetCache; import com.linecorp.armeria.server.encoding.EncodingService; import com.linecorp.armeria.testing.junit5.server.ServerExtension; @@ -457,7 +457,7 @@ void resetOptions() { serverMaxRequestLength = MAX_CONTENT_LENGTH; clientMaxResponseLength = MAX_CONTENT_LENGTH; - PathAndQuery.clearCachedPaths(); + RequestTargetCache.clearCachedPaths(); } @AfterEach @@ -900,7 +900,7 @@ void testTrailers(WebClient client) throws Exception { void testExactPathCached(WebClient client) throws Exception { assertThat(client.get("/cached-exact-path") .aggregate().get().status()).isEqualTo(HttpStatus.OK); - assertThat(PathAndQuery.cachedPaths()).contains("/cached-exact-path"); + assertThat(RequestTargetCache.cachedServerPaths()).contains("/cached-exact-path"); } @ParameterizedTest @@ -908,7 +908,7 @@ void testExactPathCached(WebClient client) throws Exception { void testPrefixPathNotCached(WebClient client) throws Exception { assertThat(client.get("/not-cached-paths/hoge") .aggregate().get().status()).isEqualTo(HttpStatus.OK); - assertThat(PathAndQuery.cachedPaths()).doesNotContain("/not-cached-paths/hoge"); + assertThat(RequestTargetCache.cachedServerPaths()).doesNotContain("/not-cached-paths/hoge"); } @ParameterizedTest @@ -916,7 +916,7 @@ void testPrefixPathNotCached(WebClient client) throws Exception { void testPrefixPath_cacheForced(WebClient client) throws Exception { assertThat(client.get("/cached-paths/hoge") .aggregate().get().status()).isEqualTo(HttpStatus.OK); - assertThat(PathAndQuery.cachedPaths()).contains("/cached-paths/hoge"); + assertThat(RequestTargetCache.cachedServerPaths()).contains("/cached-paths/hoge"); } @ParameterizedTest diff --git a/core/src/test/java/com/linecorp/armeria/server/RouteTest.java b/core/src/test/java/com/linecorp/armeria/server/RouteTest.java index 3b08326acc4..1798e03cd0e 100644 --- a/core/src/test/java/com/linecorp/armeria/server/RouteTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/RouteTest.java @@ -30,11 +30,14 @@ import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestTarget; class RouteTest { private static final String PATH = "/test"; + private static final RequestTarget REQ_TARGET = RequestTarget.forServer(PATH); + @Test void route() { Route route; @@ -407,30 +410,36 @@ void testRouteExclusionIsEvaluatedAtLast() { private static RoutingContext withMethod(HttpMethod method) { return DefaultRoutingContext.of(virtualHost(), "example.com", - PATH, null, RequestHeaders.of(method, PATH), RoutingStatus.OK); + REQ_TARGET, RequestHeaders.of(method, PATH), RoutingStatus.OK); } private static RoutingContext withConsumeType(HttpMethod method, MediaType contentType) { final RequestHeaders headers = RequestHeaders.of(method, PATH, HttpHeaderNames.CONTENT_TYPE, contentType); - return DefaultRoutingContext.of(virtualHost(), "example.com", PATH, null, headers, RoutingStatus.OK); + return DefaultRoutingContext.of(virtualHost(), "example.com", + REQ_TARGET, headers, RoutingStatus.OK); } private static RoutingContext withAcceptHeader(HttpMethod method, String acceptHeader) { final RequestHeaders headers = RequestHeaders.of(method, PATH, HttpHeaderNames.ACCEPT, acceptHeader); - return DefaultRoutingContext.of(virtualHost(), "example.com", PATH, null, headers, RoutingStatus.OK); + return DefaultRoutingContext.of(virtualHost(), "example.com", + REQ_TARGET, headers, RoutingStatus.OK); } private static RoutingContext withPath(String path) { + final RequestTarget reqTarget = RequestTarget.forServer(path); + assertThat(reqTarget).isNotNull(); + return DefaultRoutingContext.of(virtualHost(), "example.com", - path, null, RequestHeaders.of(HttpMethod.GET, path), RoutingStatus.OK); + reqTarget, RequestHeaders.of(HttpMethod.GET, path), + RoutingStatus.OK); } private static RoutingContext withRequestHeaders(RequestHeaders headers) { - final String[] pathAndQuery = headers.path().split("\\?", 2); + final RequestTarget reqTarget = RequestTarget.forServer(headers.path()); + assertThat(reqTarget).isNotNull(); return DefaultRoutingContext.of(virtualHost(), "example.com", - pathAndQuery[0], pathAndQuery.length == 2 ? pathAndQuery[1] : null, - headers, RoutingStatus.OK); + reqTarget, headers, RoutingStatus.OK); } } diff --git a/core/src/test/java/com/linecorp/armeria/server/RouterTest.java b/core/src/test/java/com/linecorp/armeria/server/RouterTest.java index 01472989b8a..772104878f4 100644 --- a/core/src/test/java/com/linecorp/armeria/server/RouterTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/RouterTest.java @@ -44,6 +44,7 @@ import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestTarget; class RouterTest { private static final Logger logger = LoggerFactory.getLogger(RouterTest.class); @@ -118,7 +119,7 @@ void testFindAllMatchedRouters(String path, int expectForFind, List exp private static DefaultRoutingContext routingCtx(String path) { return new DefaultRoutingContext(virtualHost(), "example.com", RequestHeaders.of(HttpMethod.GET, path), - path, null, null, RoutingStatus.OK); + RequestTarget.forServer(path), RoutingStatus.OK); } static Stream generateRouteMatchData() { diff --git a/core/src/test/java/com/linecorp/armeria/server/RoutingContextTest.java b/core/src/test/java/com/linecorp/armeria/server/RoutingContextTest.java index 2994a390522..a8b81583fdf 100644 --- a/core/src/test/java/com/linecorp/armeria/server/RoutingContextTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/RoutingContextTest.java @@ -27,6 +27,7 @@ import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.annotation.Nullable; class RoutingContextTest { @@ -50,6 +51,8 @@ void testAcceptTypes() { @Test void testEquals() { final VirtualHost virtualHost = virtualHost(); + final RequestTarget reqTarget = RequestTarget.forServer("/hello"); + assertThat(reqTarget).isNotNull(); final RoutingContext ctx1 = new DefaultRoutingContext(virtualHost, "example.com", @@ -58,7 +61,7 @@ void testEquals() { HttpHeaderNames.ACCEPT, MediaType.JSON_UTF_8 + ", " + MediaType.XML_UTF_8 + "; q=0.8"), - "/hello", null, null, RoutingStatus.OK); + reqTarget, RoutingStatus.OK); final RoutingContext ctx2 = new DefaultRoutingContext(virtualHost, "example.com", RequestHeaders.of(HttpMethod.GET, "/hello", @@ -66,7 +69,7 @@ void testEquals() { HttpHeaderNames.ACCEPT, MediaType.JSON_UTF_8 + ", " + MediaType.XML_UTF_8 + "; q=0.8"), - "/hello", null, null, RoutingStatus.OK); + reqTarget, RoutingStatus.OK); final RoutingContext ctx3 = new DefaultRoutingContext(virtualHost, "example.com", RequestHeaders.of(HttpMethod.GET, "/hello", @@ -74,7 +77,7 @@ void testEquals() { HttpHeaderNames.ACCEPT, MediaType.XML_UTF_8 + ", " + MediaType.JSON_UTF_8 + "; q=0.8"), - "/hello", null, null, RoutingStatus.OK); + reqTarget, RoutingStatus.OK); assertThat(ctx1.hashCode()).isEqualTo(ctx2.hashCode()); assertThat(ctx1).isEqualTo(ctx2); @@ -93,6 +96,9 @@ void queryDoesNotMatterWhenComparing() { @Test void hashcodeRecalculateWhenMethodChange() { final VirtualHost virtualHost = virtualHost(); + final RequestTarget reqTarget = RequestTarget.forServer("/hello"); + assertThat(reqTarget).isNotNull(); + final RoutingContext ctx1 = new DefaultRoutingContext(virtualHost, "example.com", RequestHeaders.of(HttpMethod.GET, "/hello", @@ -100,7 +106,7 @@ void hashcodeRecalculateWhenMethodChange() { HttpHeaderNames.ACCEPT, MediaType.JSON_UTF_8 + ", " + MediaType.XML_UTF_8 + "; q=0.8"), - "/hello", null, null, RoutingStatus.OK); + reqTarget, RoutingStatus.OK); final RoutingContext ctx2 = new DefaultRoutingContext(virtualHost, "example.com", RequestHeaders.of(HttpMethod.POST, "/hello", @@ -108,7 +114,7 @@ void hashcodeRecalculateWhenMethodChange() { HttpHeaderNames.ACCEPT, MediaType.JSON_UTF_8 + ", " + MediaType.XML_UTF_8 + "; q=0.8"), - "/hello", null, null, RoutingStatus.OK); + reqTarget, RoutingStatus.OK); final RoutingContext ctx3 = ctx1.withMethod(HttpMethod.POST); assertThat(ctx1.hashCode()).isNotEqualTo(ctx3.hashCode()); assertThat(ctx2.hashCode()).isEqualTo(ctx3.hashCode()); @@ -124,9 +130,12 @@ static RoutingContext create(String path, @Nullable String query) { static RoutingContext create(VirtualHost virtualHost, String path, @Nullable String query) { final String requestPath = query != null ? path + '?' + query : path; + final RequestTarget reqTarget = RequestTarget.forServer(requestPath); final RequestHeaders headers = RequestHeaders.of(HttpMethod.GET, requestPath); + assertThat(reqTarget).isNotNull(); + return DefaultRoutingContext.of(virtualHost, "example.com", - path, query, headers, RoutingStatus.OK); + reqTarget, headers, RoutingStatus.OK); } static VirtualHost virtualHost() { diff --git a/core/src/test/java/com/linecorp/armeria/server/ServiceRouteUtilTest.java b/core/src/test/java/com/linecorp/armeria/server/ServiceRouteUtilTest.java index 1145e7c3f9c..16e6e9268c3 100644 --- a/core/src/test/java/com/linecorp/armeria/server/ServiceRouteUtilTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/ServiceRouteUtilTest.java @@ -24,7 +24,7 @@ import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.RequestHeaders; -import com.linecorp.armeria.internal.common.PathAndQuery; +import com.linecorp.armeria.common.RequestTarget; import io.netty.channel.Channel; @@ -39,7 +39,7 @@ void optionRequest() { .authority("foo.com") .build(); final RoutingContext routingContext = ServiceRouteUtil.newRoutingContext( - config, channel, headers, PathAndQuery.parse(headers.path())); + config, channel, headers, RequestTarget.forServer(headers.path())); assertThat(routingContext.status()).isEqualTo(RoutingStatus.OPTIONS); } @@ -49,20 +49,10 @@ void normalRequest() { .authority("foo.com") .build(); final RoutingContext routingContext = ServiceRouteUtil.newRoutingContext( - config, channel, headers, PathAndQuery.parse(headers.path())); + config, channel, headers, RequestTarget.forServer(headers.path())); assertThat(routingContext.status()).isEqualTo(RoutingStatus.OK); } - @Test - void invalidPath() { - final RequestHeaders headers = RequestHeaders.builder(HttpMethod.GET, "abc/def") - .authority("foo.com") - .build(); - final RoutingContext routingContext = ServiceRouteUtil.newRoutingContext( - config, channel, headers, PathAndQuery.parse(headers.path())); - assertThat(routingContext.status()).isEqualTo(RoutingStatus.INVALID_PATH); - } - @Test void cors() { final RequestHeaders headers = @@ -74,7 +64,7 @@ void cors() { "X-PINGOTHER, Content-Type") .build(); final RoutingContext routingContext = ServiceRouteUtil.newRoutingContext( - config, channel, headers, PathAndQuery.parse(headers.path())); + config, channel, headers, RequestTarget.forServer(headers.path())); assertThat(routingContext.status()).isEqualTo(RoutingStatus.CORS_PREFLIGHT); } } diff --git a/core/src/test/java/com/linecorp/armeria/server/VirtualHostBuilderTest.java b/core/src/test/java/com/linecorp/armeria/server/VirtualHostBuilderTest.java index 96ab080d373..7bdfebf29ef 100644 --- a/core/src/test/java/com/linecorp/armeria/server/VirtualHostBuilderTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/VirtualHostBuilderTest.java @@ -37,6 +37,7 @@ import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestTarget; import io.netty.handler.ssl.SslContextBuilder; @@ -318,6 +319,9 @@ void virtualHostWithMismatch2() { void precedenceOfDuplicateRoute() { final Route routeA = Route.builder().path("/").build(); final Route routeB = Route.builder().path("/").build(); + final RequestTarget reqTarget = RequestTarget.forServer("/"); + assertThat(reqTarget).isNotNull(); + final VirtualHost virtualHost = new VirtualHostBuilder(Server.builder(), true) .service(routeA, (ctx, req) -> HttpResponse.of(OK)) .service(routeB, (ctx, req) -> HttpResponse.of(OK)) @@ -325,8 +329,7 @@ void precedenceOfDuplicateRoute() { assertThat(virtualHost.serviceConfigs().size()).isEqualTo(2); final RoutingContext routingContext = new DefaultRoutingContext(virtualHost(), "example.com", RequestHeaders.of(HttpMethod.GET, "/"), - "/", null, null, - RoutingStatus.OK); + reqTarget, RoutingStatus.OK); final Routed serviceConfig = virtualHost.findServiceConfig(routingContext); final Route route = serviceConfig.route(); assertThat(route).isSameAs(routeA); diff --git a/core/src/test/java/com/linecorp/armeria/server/file/FileServiceTest.java b/core/src/test/java/com/linecorp/armeria/server/file/FileServiceTest.java index 89401e5d554..b94681ea489 100644 --- a/core/src/test/java/com/linecorp/armeria/server/file/FileServiceTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/file/FileServiceTest.java @@ -60,7 +60,7 @@ import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.OsType; import com.linecorp.armeria.common.util.SystemInfo; -import com.linecorp.armeria.internal.common.PathAndQuery; +import com.linecorp.armeria.internal.common.RequestTargetCache; import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.server.logging.LoggingService; import com.linecorp.armeria.testing.junit5.server.ServerExtension; @@ -188,7 +188,7 @@ static void stopSynchronously() { @BeforeEach void setUp() { - PathAndQuery.clearCachedPaths(); + RequestTargetCache.clearCachedPaths(); } @ParameterizedTest @@ -207,7 +207,7 @@ void testClassPathGet(String baseUri) throws Exception { // Confirm file service paths are cached when cache is enabled. if (baseUri.contains("/cached")) { - assertThat(PathAndQuery.cachedPaths()).contains("/cached/foo.txt"); + assertThat(RequestTargetCache.cachedServerPaths()).contains("/cached/foo.txt"); } } } @@ -391,8 +391,7 @@ void testGetPreCompressedSupportsNone(String baseUri) throws Exception { assertThat(new String(content, StandardCharsets.UTF_8)).isEqualTo("foo"); // Confirm path not cached when cache disabled. - assertThat(PathAndQuery.cachedPaths()) - .doesNotContain("/compressed/foo.txt"); + assertThat(RequestTargetCache.cachedServerPaths()).doesNotContain("/compressed/foo.txt"); } } } @@ -411,8 +410,7 @@ void testGetWithoutPreCompression(String baseUri) throws Exception { assertThat(new String(content, StandardCharsets.UTF_8)).isEqualTo("foo_alone"); // Confirm path not cached when cache disabled. - assertThat(PathAndQuery.cachedPaths()) - .doesNotContain("/compressed/foo_alone.txt"); + assertThat(RequestTargetCache.cachedServerPaths()).doesNotContain("/compressed/foo_alone.txt"); } } } diff --git a/core/src/test/java12/com/linecorp/armeria/server/InvalidPathWithDataTest.java b/core/src/test/java12/com/linecorp/armeria/server/InvalidPathWithDataTest.java index 5360b1f4fe6..5642ba543b6 100644 --- a/core/src/test/java12/com/linecorp/armeria/server/InvalidPathWithDataTest.java +++ b/core/src/test/java12/com/linecorp/armeria/server/InvalidPathWithDataTest.java @@ -28,7 +28,7 @@ import org.slf4j.LoggerFactory; import com.linecorp.armeria.common.HttpResponse; -import com.linecorp.armeria.internal.common.PathAndQuery; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.server.logging.LoggingService; import com.linecorp.armeria.testing.junit5.server.ServerExtension; @@ -54,8 +54,8 @@ protected void configure(ServerBuilder sb) { @Test void invalidPath() throws Exception { final String invalidPath = "/foo?download=../../secret.txt"; - final PathAndQuery pathAndQuery = PathAndQuery.parse(invalidPath); - assertThat(pathAndQuery).isNull(); + final RequestTarget reqTarget = RequestTarget.forServer(invalidPath); + assertThat(reqTarget).isNull(); final HttpClient client = HttpClient.newHttpClient(); @@ -94,8 +94,7 @@ void invalidPath() throws Exception { .anyMatch(event -> { final String logMessage = event.getFormattedMessage(); return event.getLevel().equals(Level.DEBUG) && - logMessage.contains("received a DATA Frame for an invalid stream") && - logMessage.contains(invalidPath); + logMessage.contains("Received a DATA frame for a finished stream"); }); }); } diff --git a/grpc-protocol/src/main/java/com/linecorp/armeria/client/grpc/protocol/UnaryGrpcClient.java b/grpc-protocol/src/main/java/com/linecorp/armeria/client/grpc/protocol/UnaryGrpcClient.java index ce252e5e1c2..f921b75fd7a 100644 --- a/grpc-protocol/src/main/java/com/linecorp/armeria/client/grpc/protocol/UnaryGrpcClient.java +++ b/grpc-protocol/src/main/java/com/linecorp/armeria/client/grpc/protocol/UnaryGrpcClient.java @@ -193,7 +193,9 @@ public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) { } try { - return unwrap().execute(ctx, HttpRequest.of(req.headers(), framed)) + final HttpRequest framedReq = HttpRequest.of(req.headers(), framed); + ctx.updateRequest(framedReq); + return unwrap().execute(ctx, framedReq) .aggregate(aggregationOptions); } catch (Exception e) { throw new ArmeriaStatusException(StatusCodes.INTERNAL, diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaChannel.java b/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaChannel.java index 7adb741c1c0..fe4ca2e8676 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaChannel.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaChannel.java @@ -38,6 +38,7 @@ import com.linecorp.armeria.common.HttpRequestWriter; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.RequestHeadersBuilder; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.Scheme; import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; @@ -47,6 +48,7 @@ import com.linecorp.armeria.common.util.SystemInfo; import com.linecorp.armeria.common.util.Unwrappable; import com.linecorp.armeria.internal.client.DefaultClientRequestContext; +import com.linecorp.armeria.internal.common.RequestTargetCache; import io.grpc.CallCredentials; import io.grpc.CallOptions; @@ -221,14 +223,17 @@ public T as(Class type) { private DefaultClientRequestContext newContext(HttpMethod method, HttpRequest req, MethodDescriptor methodDescriptor) { + final String path = req.path(); + final RequestTarget reqTarget = RequestTarget.forClient(path); + assert reqTarget != null : path; + RequestTargetCache.putForClient(path, reqTarget); + return new DefaultClientRequestContext( meterRegistry, sessionProtocol, options().requestIdGenerator().get(), method, - req.path(), - null, - null, + reqTarget, options(), req, null, diff --git a/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientTest.java b/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientTest.java index 5ccb0a82b78..3c01d6a1bb7 100644 --- a/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/client/grpc/GrpcClientTest.java @@ -94,6 +94,7 @@ import com.linecorp.armeria.grpc.testing.UnitTestServiceGrpc.UnitTestServiceBlockingStub; import com.linecorp.armeria.grpc.testing.UnitTestServiceGrpc.UnitTestServiceImplBase; import com.linecorp.armeria.grpc.testing.UnitTestServiceGrpc.UnitTestServiceStub; +import com.linecorp.armeria.internal.common.RequestTargetCache; import com.linecorp.armeria.internal.common.grpc.GrpcLogUtil; import com.linecorp.armeria.internal.common.grpc.GrpcStatus; import com.linecorp.armeria.internal.common.grpc.MetadataUtil; @@ -236,6 +237,7 @@ public void close(Status status, Metadata trailers) { @BeforeEach void setUp() { + RequestTargetCache.clearCachedPaths(); requestLogQueue.clear(); final DecoratingHttpClientFunction requestLogRecorder = (delegate, ctx, req) -> { ctx.log().whenComplete().thenAccept(requestLogQueue::add); @@ -272,6 +274,8 @@ void emptyUnary() throws Exception { assertThat(rpcReq.params()).containsExactly(EMPTY); assertThat(rpcRes.get()).isEqualTo(EMPTY); }); + await().untilAsserted(() -> assertThat(RequestTargetCache.cachedClientPaths()) + .contains("/armeria.grpc.testing.TestService/EmptyCall")); } @Test @@ -322,6 +326,8 @@ void largeUnary() throws Exception { assertThat(rpcReq.params()).containsExactly(request); assertThat(rpcRes.get()).isEqualTo(goldenResponse); }); + await().untilAsserted(() -> assertThat(RequestTargetCache.cachedClientPaths()) + .contains("/armeria.grpc.testing.TestService/UnaryCall")); } @Test @@ -473,6 +479,8 @@ void serverStreaming() throws Exception { assertThat(rpcReq.params()).containsExactly(request); assertThat(rpcRes.get()).isEqualTo(goldenResponses.get(0)); }); + await().untilAsserted(() -> assertThat(RequestTargetCache.cachedClientPaths()) + .contains("/armeria.grpc.testing.TestService/StreamingOutputCall")); } @Test @@ -552,6 +560,8 @@ void clientStreaming() throws Exception { assertThat(rpcReq.params()).containsExactly(requests.get(0)); assertThat(rpcRes.get()).isEqualTo(goldenResponse); }); + await().untilAsserted(() -> assertThat(RequestTargetCache.cachedClientPaths()) + .contains("/armeria.grpc.testing.TestService/StreamingInputCall")); } @Test @@ -634,6 +644,8 @@ public void onCompleted() { assertThat(rpcReq.params()).containsExactly(requests.get(0)); assertThat(rpcRes.get()).isEqualTo(goldenResponses.get(0)); }); + await().untilAsserted(() -> assertThat(RequestTargetCache.cachedClientPaths()) + .contains("/armeria.grpc.testing.TestService/FullDuplexCall")); } @Test diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcServiceServerTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcServiceServerTest.java index 511298b8ef8..cb8e99b1131 100644 --- a/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcServiceServerTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/GrpcServiceServerTest.java @@ -89,7 +89,7 @@ import com.linecorp.armeria.grpc.testing.UnitTestServiceGrpc.UnitTestServiceBlockingStub; import com.linecorp.armeria.grpc.testing.UnitTestServiceGrpc.UnitTestServiceImplBase; import com.linecorp.armeria.grpc.testing.UnitTestServiceGrpc.UnitTestServiceStub; -import com.linecorp.armeria.internal.common.PathAndQuery; +import com.linecorp.armeria.internal.common.RequestTargetCache; import com.linecorp.armeria.internal.common.grpc.GrpcLogUtil; import com.linecorp.armeria.internal.common.grpc.GrpcTestUtil; import com.linecorp.armeria.internal.common.grpc.StreamRecorder; @@ -545,7 +545,7 @@ void setUp() { COMPLETED.set(false); CLIENT_CLOSED.set(false); - PathAndQuery.clearCachedPaths(); + RequestTargetCache.clearCachedPaths(); } @AfterEach @@ -564,7 +564,7 @@ void unary_normal(UnitTestServiceBlockingStub blockingClient) throws Exception { assertThat(blockingClient.staticUnaryCall(REQUEST_MESSAGE)).isEqualTo(RESPONSE_MESSAGE); // Confirm gRPC paths are cached despite using serviceUnder - await().untilAsserted(() -> assertThat(PathAndQuery.cachedPaths()) + await().untilAsserted(() -> assertThat(RequestTargetCache.cachedServerPaths()) .contains("/armeria.grpc.testing.UnitTestService/StaticUnaryCall")); checkRequestLog((rpcReq, rpcRes, grpcStatus) -> { diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/DefaultTHttpClient.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/DefaultTHttpClient.java index 4bf823db376..7149ce3c327 100644 --- a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/DefaultTHttpClient.java +++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/DefaultTHttpClient.java @@ -17,7 +17,6 @@ package com.linecorp.armeria.internal.client.thrift; import static com.linecorp.armeria.internal.client.thrift.THttpClientDelegate.decodeException; -import static com.linecorp.armeria.internal.common.ArmeriaHttpUtil.concatPaths; import static java.util.Objects.requireNonNull; import org.apache.thrift.transport.TTransportException; @@ -29,10 +28,11 @@ import com.linecorp.armeria.client.thrift.THttpClient; import com.linecorp.armeria.common.ExchangeType; import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.RpcRequest; import com.linecorp.armeria.common.RpcResponse; import com.linecorp.armeria.common.annotation.Nullable; -import com.linecorp.armeria.internal.common.PathAndQuery; +import com.linecorp.armeria.internal.common.RequestTargetCache; import io.micrometer.core.instrument.MeterRegistry; @@ -63,19 +63,22 @@ public RpcResponse executeMultiplexed( private RpcResponse execute0( String path, Class serviceType, @Nullable String serviceName, String method, Object[] args) { - path = concatPaths(uri().getRawPath(), path); - final PathAndQuery pathAndQuery = PathAndQuery.parse(path); - if (pathAndQuery == null) { + if (serviceName != null) { + path = path + '#' + serviceName; + } + + final RequestTarget reqTarget = RequestTarget.forClient(path, uri().getRawPath()); + if (reqTarget == null) { return RpcResponse.ofFailure(new TTransportException( new IllegalArgumentException("invalid path: " + path))); } // A thrift path is always good to cache as it cannot have non-fixed parameters. - pathAndQuery.storeInCache(path); + RequestTargetCache.putForClient(path, reqTarget); final RpcRequest call = RpcRequest.of(serviceType, method, args); return execute(scheme().sessionProtocol(), HttpMethod.POST, - pathAndQuery.path(), null, serviceName, call, UNARY_REQUEST_OPTIONS); + reqTarget, call, UNARY_REQUEST_OPTIONS); } @Override