diff --git a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java index 5191a8d85cef6..098a01410897c 100644 --- a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java +++ b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java @@ -168,12 +168,13 @@ private void addCookies(HttpResponse response) { // Determine if the request connection should be closed on completion. private boolean isCloseConnection() { - final boolean http10 = isHttp10(); - return CLOSE.equalsIgnoreCase(request.header(CONNECTION)) || (http10 && !KEEP_ALIVE.equalsIgnoreCase(request.header(CONNECTION))); - } - - // Determine if the request protocol version is HTTP 1.0 - private boolean isHttp10() { - return request.getHttpRequest().protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0; + try { + final boolean http10 = request.getHttpRequest().protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0; + return CLOSE.equalsIgnoreCase(request.header(CONNECTION)) + || (http10 && !KEEP_ALIVE.equalsIgnoreCase(request.header(CONNECTION))); + } catch (Exception e) { + // In case we fail to parse the http protocol version out of the request we always close the connection + return true; + } } } diff --git a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java index c58ec6a4becbb..3bfc649820fee 100644 --- a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java +++ b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java @@ -57,6 +57,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Supplier; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; @@ -272,8 +273,13 @@ public void testReleaseInListener() throws IOException { public void testConnectionClose() throws Exception { final Settings settings = Settings.builder().build(); final HttpRequest httpRequest; - final boolean close = randomBoolean(); - if (randomBoolean()) { + final boolean brokenRequest = randomBoolean(); + final boolean close = brokenRequest || randomBoolean(); + if (brokenRequest) { + httpRequest = new TestRequest(() -> { + throw new IllegalArgumentException("Can't parse HTTP version"); + }, RestRequest.Method.GET, "/"); + } else if (randomBoolean()) { httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); if (close) { httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.CLOSE)); @@ -399,18 +405,21 @@ private TestResponse executeRequest(final Settings settings, final String origin private static class TestRequest implements HttpRequest { - private final HttpVersion version; + private final Supplier version; private final RestRequest.Method method; private final String uri; private HashMap> headers = new HashMap<>(); - private TestRequest(HttpVersion version, RestRequest.Method method, String uri) { - - this.version = version; + private TestRequest(Supplier versionSupplier, RestRequest.Method method, String uri) { + this.version = versionSupplier; this.method = method; this.uri = uri; } + private TestRequest(HttpVersion version, RestRequest.Method method, String uri) { + this(() -> version, method, uri); + } + @Override public RestRequest.Method method() { return method; @@ -438,7 +447,7 @@ public List strictCookies() { @Override public HttpVersion protocolVersion() { - return version; + return version.get(); } @Override