diff --git a/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java b/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java index 48db5d5c5e..b43f3b0a05 100644 --- a/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java +++ b/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java @@ -1,3 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + package org.opensearch.security; import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; diff --git a/src/integrationTest/java/org/opensearch/security/rest/CompressionTests.java b/src/integrationTest/java/org/opensearch/security/rest/CompressionTests.java new file mode 100644 index 0000000000..0898638414 --- /dev/null +++ b/src/integrationTest/java/org/opensearch/security/rest/CompressionTests.java @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.security.rest; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.apache.http.Header; +import org.apache.http.HttpStatus; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.entity.ContentType; +import org.apache.http.message.BasicHeader; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.anyOf; +import static org.hamcrest.MatcherAssert.assertThat; +import org.opensearch.test.framework.TestSecurityConfig; +import org.opensearch.test.framework.cluster.ClusterManager; +import org.opensearch.test.framework.cluster.LocalCluster; +import org.opensearch.test.framework.cluster.TestRestClient; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.zip.GZIPOutputStream; +import org.opensearch.test.framework.cluster.TestRestClient.HttpResponse; + +import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL; +import static org.opensearch.test.framework.TestSecurityConfig.Role.ALL_ACCESS; +import static org.opensearch.test.framework.cluster.TestRestClientConfiguration.getBasicAuthHeader; + +@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class CompressionTests { + private static final TestSecurityConfig.User ADMIN_USER = new TestSecurityConfig.User("admin").roles(ALL_ACCESS); + + @ClassRule + public static LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.THREE_CLUSTER_MANAGERS) + .authc(AUTHC_HTTPBASIC_INTERNAL) + .users(ADMIN_USER) + .anonymousAuth(false) + .build(); + + @Test + public void testAuthenticatedGzippedRequests() throws Exception { + final String requestPath = "/*/_search"; + final int parallelism = 10; + final int totalNumberOfRequests = 100; + + final String rawBody = "{ \"query\": { \"match\": { \"foo\": \"bar\" }}}"; + + final byte[] compressedRequestBody = createCompressedRequestBody(rawBody); + try (final TestRestClient client = cluster.getRestClient(ADMIN_USER, new BasicHeader("Content-Encoding", "gzip"))) { + + final ForkJoinPool forkJoinPool = new ForkJoinPool(parallelism); + + final List> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests) + .boxed() + .map(i -> CompletableFuture.supplyAsync(() -> { + final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath); + post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON)); + return client.executeRequest(post); + }, forkJoinPool)) + .collect(Collectors.toList()); + + final CompletableFuture allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0])); + + allOfThem.get(30, TimeUnit.SECONDS); + + waitingOn.stream().forEach(future -> { + try { + final HttpResponse response = future.get(); + response.assertStatusCode(HttpStatus.SC_OK); + } catch (final Exception ex) { + throw new RuntimeException(ex); + } + }); + ; + } + } + + @Test + public void testMixOfAuthenticatedAndUnauthenticatedGzippedRequests() throws Exception { + final String requestPath = "/*/_search"; + final int parallelism = 10; + final int totalNumberOfRequests = 100; + + final String rawBody = "{ \"query\": { \"match\": { \"foo\": \"bar\" }}}"; + + final byte[] compressedRequestBody = createCompressedRequestBody(rawBody); + try (final TestRestClient client = cluster.getRestClient(new BasicHeader("Content-Encoding", "gzip"))) { + + final ForkJoinPool forkJoinPool = new ForkJoinPool(parallelism); + + final Header basicAuthHeader = getBasicAuthHeader(ADMIN_USER.getName(), ADMIN_USER.getPassword()); + + final List> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests) + .boxed() + .map(i -> CompletableFuture.supplyAsync(() -> { + final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath); + post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON)); + return i % 2 == 0 ? client.executeRequest(post) : client.executeRequest(post, basicAuthHeader); + }, forkJoinPool)) + .collect(Collectors.toList()); + + final CompletableFuture allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0])); + + allOfThem.get(30, TimeUnit.SECONDS); + + waitingOn.stream().forEach(future -> { + try { + final HttpResponse response = future.get(); + assertThat(response.getBody(), not(containsString("json_parse_exception"))); + assertThat(response.getStatusCode(), anyOf(equalTo(HttpStatus.SC_UNAUTHORIZED), equalTo(HttpStatus.SC_OK))); + } catch (final Exception ex) { + throw new RuntimeException(ex); + } + }); + ; + } + } + + static byte[] createCompressedRequestBody(final String rawBody) { + try ( + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + final GZIPOutputStream gzipOutputStream = new GZIPOutputStream(byteArrayOutputStream) + ) { + gzipOutputStream.write(rawBody.getBytes(StandardCharsets.UTF_8)); + gzipOutputStream.finish(); + + final byte[] compressedRequestBody = byteArrayOutputStream.toByteArray(); + return compressedRequestBody; + } catch (final IOException ioe) { + throw new RuntimeException(ioe); + } + } +} diff --git a/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java b/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java index 71586a2dff..cca1df9b46 100644 --- a/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java +++ b/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java @@ -47,7 +47,6 @@ public class SecurityNonSslHttpServerTransport extends Netty4HttpServerTransport { private final ChannelInboundHandlerAdapter headerVerifier; - private final ChannelInboundHandlerAdapter conditionalDecompressor; public SecurityNonSslHttpServerTransport( final Settings settings, @@ -73,7 +72,6 @@ public SecurityNonSslHttpServerTransport( tracer ); headerVerifier = new Netty4HttpRequestHeaderVerifier(restFilter, threadPool, settings); - conditionalDecompressor = new Netty4ConditionalDecompressor(); } @Override @@ -100,6 +98,6 @@ protected ChannelInboundHandlerAdapter createHeaderVerifier() { @Override protected ChannelInboundHandlerAdapter createDecompressor() { - return conditionalDecompressor; + return new Netty4ConditionalDecompressor(); } } diff --git a/src/main/java/org/opensearch/security/ssl/http/netty/Netty4ConditionalDecompressor.java b/src/main/java/org/opensearch/security/ssl/http/netty/Netty4ConditionalDecompressor.java index c8059fad5d..1eec49add0 100644 --- a/src/main/java/org/opensearch/security/ssl/http/netty/Netty4ConditionalDecompressor.java +++ b/src/main/java/org/opensearch/security/ssl/http/netty/Netty4ConditionalDecompressor.java @@ -8,7 +8,6 @@ package org.opensearch.security.ssl.http.netty; -import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.HttpContentDecompressor; @@ -17,7 +16,6 @@ import org.opensearch.security.filter.NettyAttribute; -@Sharable public class Netty4ConditionalDecompressor extends HttpContentDecompressor { @Override diff --git a/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java b/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java index 3eee278083..0d218acd09 100644 --- a/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java +++ b/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java @@ -46,7 +46,6 @@ public class SecuritySSLNettyHttpServerTransport extends Netty4HttpServerTranspo private final SecurityKeyStore sks; private final SslExceptionHandler errorHandler; private final ChannelInboundHandlerAdapter headerVerifier; - private final ChannelInboundHandlerAdapter conditionalDecompressor; public SecuritySSLNettyHttpServerTransport( final Settings settings, @@ -76,7 +75,6 @@ public SecuritySSLNettyHttpServerTransport( this.sks = sks; this.errorHandler = errorHandler; headerVerifier = new Netty4HttpRequestHeaderVerifier(restFilter, threadPool, settings); - conditionalDecompressor = new Netty4ConditionalDecompressor(); } @Override @@ -123,6 +121,6 @@ protected ChannelInboundHandlerAdapter createHeaderVerifier() { @Override protected ChannelInboundHandlerAdapter createDecompressor() { - return conditionalDecompressor; + return new Netty4ConditionalDecompressor(); } }