diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e2439da2ad..11bf02a396 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,6 +10,7 @@ on: env: GRADLE_OPTS: -Dhttp.keepAlive=false + CI_ENVIRONMENT: normal jobs: generate-test-list: @@ -107,6 +108,51 @@ jobs: arguments: | integrationTest -Dbuild.snapshot=false + resource-tests: + env: + CI_ENVIRONMENT: resource-test + strategy: + fail-fast: false + matrix: + jdk: [17] + platform: [ubuntu-latest] + runs-on: ${{ matrix.platform }} + + steps: + - name: Set up JDK for build and test + uses: actions/setup-java@v3 + with: + distribution: temurin # Temurin is a distribution of adoptium + java-version: ${{ matrix.jdk }} + + - name: Checkout security + uses: actions/checkout@v4 + + - name: Build and Test + uses: gradle/gradle-build-action@v2 + with: + cache-disabled: true + arguments: | + integrationTest -Dbuild.snapshot=false --tests org.opensearch.security.ResourceFocusedTests + + backward-compatibility-build: + runs-on: ubuntu-latest + steps: + - uses: actions/setup-java@v3 + with: + distribution: temurin # Temurin is a distribution of adoptium + java-version: 17 + + - name: Checkout Security Repo + uses: actions/checkout@v4 + + - name: Build BWC tests + uses: gradle/gradle-build-action@v2 + with: + cache-disabled: true + arguments: | + -p bwc-test build -x test -x integTest + backward-compatibility: strategy: fail-fast: false diff --git a/build.gradle b/build.gradle index c4d74d66a0..0451684c56 100644 --- a/build.gradle +++ b/build.gradle @@ -470,17 +470,27 @@ sourceSets { //add new task that runs integration tests task integrationTest(type: Test) { + doFirst { + // Only run resources tests on resource-test CI environments or locally + if (System.getenv('CI_ENVIRONMENT') == 'resource-test' || System.getenv('CI_ENVIRONMENT') == null) { + include '**/ResourceFocusedTests.class' + } else { + exclude '**/ResourceFocusedTests.class' + } + // Only run with retries while in CI systems + if (System.getenv('CI_ENVIRONMENT') == 'normal') { + retry { + failOnPassedAfterRetry = false + maxRetries = 2 + maxFailures = 10 + } + } + } description = 'Run integration tests.' group = 'verification' systemProperty "java.util.logging.manager", "org.apache.logging.log4j.jul.LogManager" testClassesDirs = sourceSets.integrationTest.output.classesDirs classpath = sourceSets.integrationTest.runtimeClasspath - jvmArgs += "-Djdk.internal.httpclient.disableHostnameVerification" - retry { - failOnPassedAfterRetry = false - maxRetries = 2 - maxFailures = 10 - } //run the integrationTest task after the test task shouldRunAfter test } diff --git a/bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java b/bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java deleted file mode 100644 index 8b13789179..0000000000 --- a/bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java b/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java new file mode 100644 index 0000000000..48db5d5c5e --- /dev/null +++ b/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java @@ -0,0 +1,257 @@ +package org.opensearch.security; + +import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; +import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL; +import static org.opensearch.test.framework.TestSecurityConfig.Role.ALL_ACCESS; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.lang.management.GarbageCollectorMXBean; +import java.lang.management.ManagementFactory; +import java.lang.management.MemoryPoolMXBean; +import java.lang.management.MemoryUsage; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.zip.GZIPOutputStream; + +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.entity.ContentType; +import org.apache.http.message.BasicHeader; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.client.Client; +import org.opensearch.test.framework.TestSecurityConfig; +import org.opensearch.test.framework.TestSecurityConfig.User; +import org.opensearch.test.framework.cluster.ClusterManager; +import org.opensearch.test.framework.cluster.LocalCluster; +import org.opensearch.test.framework.cluster.TestRestClient; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + +@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class ResourceFocusedTests { + private static final User ADMIN_USER = new User("admin").roles(ALL_ACCESS); + private static final User LIMITED_USER = new User("limited_user").roles( + new TestSecurityConfig.Role("limited-role").clusterPermissions( + "indices:data/read/mget", + "indices:data/read/msearch", + "indices:data/read/scroll", + "cluster:monitor/state", + "cluster:monitor/health" + ) + .indexPermissions( + "indices:data/read/search", + "indices:data/read/mget*", + "indices:monitor/settings/get", + "indices:monitor/stats" + ) + .on("*") + ); + + @ClassRule + public static LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.THREE_CLUSTER_MANAGERS) + .authc(AUTHC_HTTPBASIC_INTERNAL) + .users(ADMIN_USER, LIMITED_USER) + .anonymousAuth(false) + .doNotFailOnForbidden(true) + .build(); + + @BeforeClass + public static void createTestData() { + try (Client client = cluster.getInternalNodeClient()) { + client.index(new IndexRequest().setRefreshPolicy(IMMEDIATE).index("document").source(Map.of("foo", "bar", "abc", "xyz"))) + .actionGet(); + } + } + + @Test + public void testUnauthenticatedFewBig() { + // Tweaks: + final RequestBodySize size = RequestBodySize.XLarge; + final String requestPath = "/*/_search"; + final int parrallelism = 5; + final int totalNumberOfRequests = 100; + final boolean statsPrinter = false; + + runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter); + } + + @Test + public void testUnauthenticatedManyMedium() { + // Tweaks: + final RequestBodySize size = RequestBodySize.Medium; + final String requestPath = "/*/_search"; + final int parrallelism = 20; + final int totalNumberOfRequests = 10_000; + final boolean statsPrinter = false; + + runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter); + } + + @Test + public void testUnauthenticatedTonsSmall() { + // Tweaks: + final RequestBodySize size = RequestBodySize.Small; + final String requestPath = "/*/_search"; + final int parrallelism = 100; + final int totalNumberOfRequests = 1_000_000; + final boolean statsPrinter = false; + + runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter); + } + + private Long runResourceTest( + final RequestBodySize size, + final String requestPath, + final int parrallelism, + final int totalNumberOfRequests, + final boolean statsPrinter + ) { + final byte[] compressedRequestBody = createCompressedRequestBody(size); + try (final TestRestClient client = cluster.getRestClient(new BasicHeader("Content-Encoding", "gzip"))) { + + if (statsPrinter) { + printStats(); + } + final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath); + post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON)); + + final ForkJoinPool forkJoinPool = new ForkJoinPool(parrallelism); + + final List> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests) + .boxed() + .map(i -> CompletableFuture.runAsync(() -> client.executeRequest(post), forkJoinPool)) + .collect(Collectors.toList()); + Supplier getCount = () -> waitingOn.stream().filter(cf -> cf.isDone() && !cf.isCompletedExceptionally()).count(); + + CompletableFuture statPrinter = statsPrinter ? CompletableFuture.runAsync(() -> { + while (true) { + printStats(); + System.err.println(" & Succesful completions: " + getCount.get()); + try { + Thread.sleep(500); + } catch (Exception e) { + break; + } + } + }, forkJoinPool) : CompletableFuture.completedFuture(null); + + final CompletableFuture allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0])); + + try { + allOfThem.get(30, TimeUnit.SECONDS); + statPrinter.cancel(true); + } catch (final Exception e) { + // Ignored + } + + if (statsPrinter) { + printStats(); + System.err.println(" & Succesful completions: " + getCount.get()); + } + return getCount.get(); + } + } + + static enum RequestBodySize { + Small(1), + Medium(1_000), + XLarge(1_000_000); + + public final int elementCount; + + private RequestBodySize(final int elementCount) { + this.elementCount = elementCount; + } + } + + private byte[] createCompressedRequestBody(final RequestBodySize size) { + final int repeatCount = size.elementCount; + final String prefix = "{ \"items\": ["; + final String repeatedElement = IntStream.range(0, 20) + .mapToObj(n -> ('a' + n) + "") + .map(n -> '"' + n + '"' + ": 123") + .collect(Collectors.joining(",", "{", "}")); + final String postfix = "]}"; + long uncompressedBytesSize = 0; + + try ( + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + final GZIPOutputStream gzipOutputStream = new GZIPOutputStream(byteArrayOutputStream) + ) { + + final byte[] prefixBytes = prefix.getBytes(StandardCharsets.UTF_8); + final byte[] repeatedElementBytes = repeatedElement.getBytes(StandardCharsets.UTF_8); + final byte[] postfixBytes = postfix.getBytes(StandardCharsets.UTF_8); + + gzipOutputStream.write(prefixBytes); + uncompressedBytesSize = uncompressedBytesSize + prefixBytes.length; + for (int i = 0; i < repeatCount; i++) { + gzipOutputStream.write(repeatedElementBytes); + uncompressedBytesSize = uncompressedBytesSize + repeatedElementBytes.length; + } + gzipOutputStream.write(postfixBytes); + uncompressedBytesSize = uncompressedBytesSize + postfixBytes.length; + gzipOutputStream.finish(); + + final byte[] compressedRequestBody = byteArrayOutputStream.toByteArray(); + System.err.println( + "^^^" + + String.format( + "Original size was %,d bytes, compressed to %,d bytes, ratio %,.2f", + uncompressedBytesSize, + compressedRequestBody.length, + ((double) uncompressedBytesSize / compressedRequestBody.length) + ) + ); + return compressedRequestBody; + } catch (final IOException ioe) { + throw new RuntimeException(ioe); + } + } + + private void printStats() { + System.err.println("** Stats "); + printMemory(); + printMemoryPools(); + printGCPools(); + } + + private void printMemory() { + final Runtime runtime = Runtime.getRuntime(); + + final long totalMemory = runtime.totalMemory(); // Total allocated memory + final long freeMemory = runtime.freeMemory(); // Amount of free memory + final long usedMemory = totalMemory - freeMemory; // Amount of used memory + + System.err.println(" Memory Total: " + totalMemory + " Free:" + freeMemory + " Used:" + usedMemory); + } + + private void printMemoryPools() { + List memoryPools = ManagementFactory.getMemoryPoolMXBeans(); + for (MemoryPoolMXBean memoryPool : memoryPools) { + MemoryUsage usage = memoryPool.getUsage(); + System.err.println(" " + memoryPool.getName() + " USED: " + usage.getUsed() + " MAX: " + usage.getMax()); + } + } + + private void printGCPools() { + List garbageCollectors = ManagementFactory.getGarbageCollectorMXBeans(); + for (GarbageCollectorMXBean garbageCollector : garbageCollectors) { + System.err.println(" " + garbageCollector.getName() + " COLLECTION TIME: " + garbageCollector.getCollectionTime()); + } + } + +} diff --git a/src/integrationTest/java/org/opensearch/security/http/LdapTlsAuthenticationTest.java b/src/integrationTest/java/org/opensearch/security/http/LdapTlsAuthenticationTest.java index f00007e4fc..1f68a415b8 100644 --- a/src/integrationTest/java/org/opensearch/security/http/LdapTlsAuthenticationTest.java +++ b/src/integrationTest/java/org/opensearch/security/http/LdapTlsAuthenticationTest.java @@ -372,6 +372,7 @@ public void shouldImpersonateUser_negativeJean() { response.assertStatusCode(403); String expectedMessage = String.format("'%s' is not allowed to impersonate as '%s'", USER_KIRK, USER_JEAN); + System.out.println("&&&& " + response.getBody()); assertThat(response.getTextFromJsonBody(POINTER_ERROR_REASON), equalTo(expectedMessage)); } } diff --git a/src/integrationTest/java/org/opensearch/test/framework/cluster/TestRestClient.java b/src/integrationTest/java/org/opensearch/test/framework/cluster/TestRestClient.java index 59d2f57f4a..21cfd7fdaf 100644 --- a/src/integrationTest/java/org/opensearch/test/framework/cluster/TestRestClient.java +++ b/src/integrationTest/java/org/opensearch/test/framework/cluster/TestRestClient.java @@ -275,7 +275,7 @@ public void createRoleMapping(String backendRoleName, String roleName) { response.assertStatusCode(201); } - protected final String getHttpServerUri() { + public final String getHttpServerUri() { return "http" + (enableHTTPClientSSL ? "s" : "") + "://" + nodeHttpAddress.getHostString() + ":" + nodeHttpAddress.getPort(); } @@ -403,7 +403,7 @@ private JsonNode getJsonNodeAt(String jsonPointer) { try { return toJsonNode().at(jsonPointer); } catch (IOException e) { - throw new IllegalArgumentException("Cound not convert response body to JSON node ", e); + throw new IllegalArgumentException("Cound not convert response body to JSON node '" + getBody() + "'", e); } } diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java index 8bfc801217..0a6f1649b7 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java @@ -239,7 +239,7 @@ private Optional handleLowLevel(RestRequest restRequest) throw return Optional.of(new SecurityResponse(HttpStatus.SC_OK, SecurityResponse.CONTENT_TYPE_APP_JSON, responseBodyString)); } catch (JsonProcessingException e) { log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e); - return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, null, "JSON could not be parsed")); + return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, new Exception("JSON could not be parsed"))); } } diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java index a3f37ba46e..846573289f 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java @@ -79,7 +79,7 @@ public class HTTPSamlAuthenticator implements HTTPAuthenticator, Destroyable { public static final String IDP_METADATA_FILE = "idp.metadata_file"; public static final String IDP_METADATA_CONTENT = "idp.metadata_content"; - private static final String API_AUTHTOKEN_SUFFIX = "api/authtoken"; + public static final String API_AUTHTOKEN_SUFFIX = "api/authtoken"; private static final String AUTHINFO_SUFFIX = "authinfo"; private static final String REGEX_PATH_PREFIX = "/(" + LEGACY_OPENDISTRO_PREFIX + "|" + PLUGINS_PREFIX + ")/" + "(.*)"; private static final Pattern PATTERN_PATH_PREFIX = Pattern.compile(REGEX_PATH_PREFIX); diff --git a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java index ac68088387..667f7f9a28 100644 --- a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java +++ b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java @@ -208,7 +208,6 @@ public final class OpenSearchSecurityPlugin extends OpenSearchSecuritySSLPlugin public static final String PLUGINS_PREFIX = "_plugins/_security"; private boolean sslCertReloadEnabled; - private volatile SecurityRestFilter securityRestHandler; private volatile SecurityInterceptor si; private volatile PrivilegesEvaluator evaluator; private volatile UserService userService; @@ -898,7 +897,8 @@ public Map> getHttpTransports( validatingDispatcher, clusterSettings, sharedGroupFactory, - tracer + tracer, + securityRestHandler ); return Collections.singletonMap("org.opensearch.security.http.SecurityHttpServerTransport", () -> odshst); @@ -914,7 +914,8 @@ public Map> getHttpTransports( dispatcher, clusterSettings, sharedGroupFactory, - tracer + tracer, + securityRestHandler ) ); } diff --git a/src/main/java/org/opensearch/security/auth/BackendRegistry.java b/src/main/java/org/opensearch/security/auth/BackendRegistry.java index 6e73476bdb..980dc64094 100644 --- a/src/main/java/org/opensearch/security/auth/BackendRegistry.java +++ b/src/main/java/org/opensearch/security/auth/BackendRegistry.java @@ -51,6 +51,7 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auditlog.AuditLog; @@ -144,11 +145,14 @@ public BackendRegistry( this.auditLog = auditLog; this.threadPool = threadPool; this.userInjector = new UserInjector(settings, threadPool, auditLog, xffResolver); + this.restAuthDomains = Collections.emptySortedSet(); + this.ipAuthFailureListeners = Collections.emptyList(); this.ttlInMin = settings.getAsInt(ConfigConstants.SECURITY_CACHE_TTL_MINUTES, 60); // This is going to be defined in the opensearch.yml, so it's best suited to be initialized once. this.injectedUserEnabled = opensearchSettings.getAsBoolean(ConfigConstants.SECURITY_UNSUPPORTED_INJECT_USER_ENABLED, false); + initialized = this.injectedUserEnabled; createCaches(); } @@ -185,7 +189,6 @@ public void onDynamicConfigModelChanged(DynamicConfigModel dcm) { /** * * @param request - * @param channel * @return The authenticated user, null means another roundtrip * @throws OpenSearchSecurityException */ @@ -200,15 +203,17 @@ public boolean authenticate(final SecurityRequestChannel request) { log.debug("Rejecting REST request because of blocked address: {}", request.getRemoteAddress().orElse(null)); } - request.queueForSending(new SecurityResponse(SC_UNAUTHORIZED, null, "Authentication finally failed")); + request.queueForSending(new SecurityResponse(SC_UNAUTHORIZED, new Exception("Authentication finally failed"))); return false; } - final String sslPrincipal = (String) threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_SSL_PRINCIPAL); + ThreadContext threadContext = this.threadPool.getThreadContext(); + + final String sslPrincipal = (String) threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_SSL_PRINCIPAL); if (adminDns.isAdminDN(sslPrincipal)) { // PKI authenticated REST call - threadPool.getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, new User(sslPrincipal)); + threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, new User(sslPrincipal)); auditLog.logSucceededLogin(sslPrincipal, true, null, request); return true; } @@ -220,7 +225,7 @@ public boolean authenticate(final SecurityRequestChannel request) { if (!isInitialized()) { log.error("Not yet initialized (you may need to run securityadmin)"); - request.queueForSending(new SecurityResponse(SC_SERVICE_UNAVAILABLE, null, "OpenSearch Security not initialized.")); + request.queueForSending(new SecurityResponse(SC_SERVICE_UNAVAILABLE, new Exception("OpenSearch Security not initialized."))); return false; } diff --git a/src/main/java/org/opensearch/security/filter/NettyAttribute.java b/src/main/java/org/opensearch/security/filter/NettyAttribute.java new file mode 100644 index 0000000000..685e94e199 --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/NettyAttribute.java @@ -0,0 +1,49 @@ +package org.opensearch.security.filter; + +import java.util.Optional; + +import org.opensearch.http.netty4.Netty4HttpChannel; +import org.opensearch.rest.RestRequest; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.AttributeKey; + +public class NettyAttribute { + + /** + * Gets an attribute value from the request context and clears it from that context + */ + public static Optional popFrom(final RestRequest request, final AttributeKey attribute) { + if (request.getHttpChannel() instanceof Netty4HttpChannel) { + Channel nettyChannel = ((Netty4HttpChannel) request.getHttpChannel()).getNettyChannel(); + return Optional.ofNullable(nettyChannel.attr(attribute).getAndSet(null)); + } + return Optional.empty(); + } + + /** + * Gets an attribute value from the channel handler context and clears it from that context + */ + public static Optional popFrom(final ChannelHandlerContext ctx, final AttributeKey attribute) { + return Optional.ofNullable(ctx.channel().attr(attribute).getAndSet(null)); + } + + /** + * Gets an attribute value from the channel handler context + */ + public static Optional peekFrom(final ChannelHandlerContext ctx, final AttributeKey attribute) { + return Optional.ofNullable(ctx.channel().attr(attribute).get()); + } + + /** + * Clears an attribute value from the channel handler context + */ + public static void clearAttribute(final RestRequest request, final AttributeKey attribute) { + if (request.getHttpChannel() instanceof Netty4HttpChannel) { + Channel nettyChannel = ((Netty4HttpChannel) request.getHttpChannel()).getNettyChannel(); + nettyChannel.attr(attribute).set(null); + } + } + +} diff --git a/src/main/java/org/opensearch/security/filter/NettyRequest.java b/src/main/java/org/opensearch/security/filter/NettyRequest.java new file mode 100644 index 0000000000..4ef17b9dc7 --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/NettyRequest.java @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.TreeMap; + +import javax.net.ssl.SSLEngine; + +import io.netty.handler.ssl.SslHandler; +import org.opensearch.http.netty4.Netty4HttpChannel; +import org.opensearch.rest.RestRequest.Method; + +import io.netty.handler.codec.http.HttpRequest; +import org.opensearch.rest.RestUtils; + +/** + * Wraps the functionality of HttpRequest for use in the security plugin + */ +public class NettyRequest implements SecurityRequest { + + protected final HttpRequest underlyingRequest; + protected final Netty4HttpChannel underlyingChannel; + + NettyRequest(final HttpRequest request, final Netty4HttpChannel channel) { + this.underlyingRequest = request; + this.underlyingChannel = channel; + } + + @Override + public Map> getHeaders() { + final Map> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + underlyingRequest.headers().forEach(h -> headers.put(h.getKey(), List.of(h.getValue()))); + return headers; + } + + @Override + public SSLEngine getSSLEngine() { + // We look for Ssl_handler called `ssl_http` in the outbound pipeline of Netty channel first, and if its not + // present we look for it in inbound channel. If its present in neither we return null, else we return the sslHandler. + SslHandler sslhandler = (SslHandler) underlyingChannel.getNettyChannel().pipeline().get("ssl_http"); + return sslhandler != null ? sslhandler.engine() : null; + } + + @Override + public String path() { + String rawPath = SecurityRestUtils.path(underlyingRequest.uri()); + return RestUtils.decodeComponent(rawPath); + } + + @Override + public Method method() { + return Method.valueOf(underlyingRequest.method().name()); + } + + @Override + public Optional getRemoteAddress() { + return Optional.ofNullable(this.underlyingChannel.getRemoteAddress()); + } + + @Override + public String uri() { + return underlyingRequest.uri(); + } + + @Override + public Map params() { + return params(underlyingRequest.uri()); + } + + private static Map params(String uri) { + // Sourced from + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java#L419-L422 + final Map params = new HashMap<>(); + final int index = uri.indexOf(63); + if (index >= 0) { + try { + RestUtils.decodeQueryString(uri, index + 1, params); + } catch (IllegalArgumentException var4) { + return Collections.emptyMap(); + } + } + + return params; + } +} diff --git a/src/main/java/org/opensearch/security/filter/NettyRequestChannel.java b/src/main/java/org/opensearch/security/filter/NettyRequestChannel.java new file mode 100644 index 0000000000..a83ecdea8a --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/NettyRequestChannel.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import io.netty.handler.codec.http.HttpRequest; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.http.netty4.Netty4HttpChannel; + +public class NettyRequestChannel extends NettyRequest implements SecurityRequestChannel { + private final Logger log = LogManager.getLogger(NettyRequestChannel.class); + + private AtomicBoolean hasCompleted = new AtomicBoolean(false); + private final AtomicReference responseRef = new AtomicReference(null); + + NettyRequestChannel(final HttpRequest request, Netty4HttpChannel channel) { + super(request, channel); + } + + @Override + public void queueForSending(SecurityResponse response) { + if (underlyingChannel == null) { + throw new UnsupportedOperationException("Channel was not defined"); + } + + if (hasCompleted.get()) { + throw new UnsupportedOperationException("This channel has already completed"); + } + + if (getQueuedResponse().isPresent()) { + throw new UnsupportedOperationException("Another response was already queued"); + } + + responseRef.set(response); + } + + @Override + public Optional getQueuedResponse() { + return Optional.ofNullable(responseRef.get()); + } +} diff --git a/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java b/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java index 45035e0d83..24c90488cb 100644 --- a/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java +++ b/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java @@ -12,22 +12,14 @@ package org.opensearch.security.filter; import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; public class OpenSearchRequestChannel extends OpenSearchRequest implements SecurityRequestChannel { - private final Logger log = LogManager.getLogger(OpenSearchRequest.class); - private final AtomicReference responseRef = new AtomicReference(null); - private final AtomicBoolean hasCompleted = new AtomicBoolean(false); private final RestChannel underlyingChannel; OpenSearchRequestChannel(final RestRequest request, final RestChannel channel) { @@ -46,10 +38,6 @@ public void queueForSending(final SecurityResponse response) { throw new UnsupportedOperationException("Channel was not defined"); } - if (hasCompleted.get()) { - throw new UnsupportedOperationException("This channel has already completed"); - } - if (getQueuedResponse().isPresent()) { throw new UnsupportedOperationException("Another response was already queued"); } @@ -61,37 +49,4 @@ public void queueForSending(final SecurityResponse response) { public Optional getQueuedResponse() { return Optional.ofNullable(responseRef.get()); } - - @Override - public boolean sendResponse() { - if (underlyingChannel == null) { - throw new UnsupportedOperationException("Channel was not defined"); - } - - if (hasCompleted.get()) { - throw new UnsupportedOperationException("This channel has already completed"); - } - - if (getQueuedResponse().isEmpty()) { - throw new UnsupportedOperationException("No response has been associated with this channel"); - } - - final SecurityResponse response = responseRef.get(); - - try { - final BytesRestResponse restResponse = new BytesRestResponse(RestStatus.fromCode(response.getStatus()), response.getBody()); - if (response.getHeaders() != null) { - response.getHeaders().forEach(restResponse::addHeader); - } - underlyingChannel.sendResponse(restResponse); - - return true; - } catch (final Exception e) { - log.error("Error when attempting to send response", e); - throw new RuntimeException(e); - } finally { - hasCompleted.set(true); - } - - } } diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java b/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java index 1eec754c08..66744d01dd 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java @@ -19,11 +19,8 @@ public interface SecurityRequestChannel extends SecurityRequest { /** Associate a response with this channel */ - public void queueForSending(final SecurityResponse response); + void queueForSending(final SecurityResponse response); /** Acess the queued response */ - public Optional getQueuedResponse(); - - /** Send the response through the channel */ - public boolean sendResponse(); + Optional getQueuedResponse(); } diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java b/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java index de74df01ff..0b64d0220d 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java @@ -11,6 +11,8 @@ package org.opensearch.security.filter; +import io.netty.handler.codec.http.HttpRequest; +import org.opensearch.http.netty4.Netty4HttpChannel; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; @@ -24,6 +26,11 @@ public static SecurityRequest from(final RestRequest request) { return new OpenSearchRequest(request); } + /** Creates a security request from a netty HttpRequest object */ + public static SecurityRequestChannel from(HttpRequest request, Netty4HttpChannel channel) { + return new NettyRequestChannel(request, channel); + } + /** Creates a security request channel from a RestRequest & RestChannel */ public static SecurityRequestChannel from(final RestRequest request, final RestChannel channel) { return new OpenSearchRequestChannel(request, channel); diff --git a/src/main/java/org/opensearch/security/filter/SecurityResponse.java b/src/main/java/org/opensearch/security/filter/SecurityResponse.java index 8618be3aab..14c21a9385 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityResponse.java +++ b/src/main/java/org/opensearch/security/filter/SecurityResponse.java @@ -11,9 +11,14 @@ package org.opensearch.security.filter; +import java.io.IOException; import java.util.Map; import org.apache.http.HttpHeaders; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestResponse; public class SecurityResponse { @@ -23,6 +28,12 @@ public class SecurityResponse { private final Map headers; private final String body; + public SecurityResponse(final int status, final Exception e) { + this.status = status; + this.headers = CONTENT_TYPE_APP_JSON; + this.body = generateFailureMessage(e); + } + public SecurityResponse(final int status, final Map headers, final String body) { this.status = status; this.headers = headers; @@ -41,4 +52,26 @@ public String getBody() { return body; } + public RestResponse asRestResponse() { + final RestResponse restResponse = new BytesRestResponse(RestStatus.fromCode(getStatus()), getBody()); + if (getHeaders() != null) { + getHeaders().forEach(restResponse::addHeader); + } + return restResponse; + } + + protected String generateFailureMessage(final Exception e) { + try { + return XContentFactory.jsonBuilder() + .startObject() + .startObject("error") + .field("status", "error") + .field("reason", e.getMessage()) + .endObject() + .endObject() + .toString(); + } catch (final IOException ioe) { + throw new RuntimeException(ioe); + } + } } diff --git a/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java b/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java index ce80a9e143..b6475e5ec3 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java @@ -63,10 +63,14 @@ import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.support.HTTPHelper; import org.opensearch.security.user.User; +import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import static org.opensearch.security.OpenSearchSecurityPlugin.LEGACY_OPENDISTRO_PREFIX; import static org.opensearch.security.OpenSearchSecurityPlugin.PLUGINS_PREFIX; +import static org.opensearch.security.http.SecurityHttpServerTransport.CONTEXT_TO_RESTORE; +import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE; +import static org.opensearch.security.http.SecurityHttpServerTransport.IS_AUTHENTICATED; public class SecurityRestFilter { @@ -83,11 +87,11 @@ public class SecurityRestFilter { private WhitelistingSettings whitelistingSettings; private AllowlistingSettings allowlistingSettings; - private static final String HEALTH_SUFFIX = "health"; - private static final String WHO_AM_I_SUFFIX = "whoami"; + public static final String HEALTH_SUFFIX = "health"; + public static final String WHO_AM_I_SUFFIX = "whoami"; - private static final String REGEX_PATH_PREFIX = "/(" + LEGACY_OPENDISTRO_PREFIX + "|" + PLUGINS_PREFIX + ")/" + "(.*)"; - private static final Pattern PATTERN_PATH_PREFIX = Pattern.compile(REGEX_PATH_PREFIX); + public static final String REGEX_PATH_PREFIX = "/(" + LEGACY_OPENDISTRO_PREFIX + "|" + PLUGINS_PREFIX + ")/" + "(.*)"; + public static final Pattern PATTERN_PATH_PREFIX = Pattern.compile(REGEX_PATH_PREFIX); public SecurityRestFilter( final BackendRegistry registry, @@ -127,13 +131,33 @@ public SecurityRestFilter( */ public RestHandler wrap(RestHandler original, AdminDNs adminDNs) { return (request, channel, client) -> { - org.apache.logging.log4j.ThreadContext.clearAll(); + + final Optional maybeSavedResponse = NettyAttribute.popFrom(request, EARLY_RESPONSE); + if (maybeSavedResponse.isPresent()) { + NettyAttribute.clearAttribute(request, CONTEXT_TO_RESTORE); + NettyAttribute.clearAttribute(request, IS_AUTHENTICATED); + channel.sendResponse(maybeSavedResponse.get().asRestResponse()); + return; + } + + NettyAttribute.popFrom(request, CONTEXT_TO_RESTORE).ifPresent(storedContext -> { + // X_OPAQUE_ID will be overritten on restore - save to apply after restoring the saved context + final String xOpaqueId = threadContext.getHeader(Task.X_OPAQUE_ID); + storedContext.restore(); + if (xOpaqueId != null) { + threadContext.putHeader(Task.X_OPAQUE_ID, xOpaqueId); + } + }); + final SecurityRequestChannel requestChannel = SecurityRequestFactory.from(request, channel); // Authenticate request - checkAndAuthenticateRequest(requestChannel); + if (!NettyAttribute.popFrom(request, IS_AUTHENTICATED).orElse(false)) { + // we aren't authenticated so we should skip this step + checkAndAuthenticateRequest(requestChannel); + } if (requestChannel.getQueuedResponse().isPresent()) { - requestChannel.sendResponse(); + channel.sendResponse(requestChannel.getQueuedResponse().get().asRestResponse()); return; } @@ -149,14 +173,13 @@ public RestHandler wrap(RestHandler original, AdminDNs adminDNs) { .or(() -> allowlistingSettings.checkRequestIsAllowed(requestChannel)); if (deniedResponse.isPresent()) { - requestChannel.queueForSending(deniedResponse.orElseThrow()); - requestChannel.sendResponse(); + channel.sendResponse(deniedResponse.get().asRestResponse()); return; } authorizeRequest(original, requestChannel, user); if (requestChannel.getQueuedResponse().isPresent()) { - requestChannel.sendResponse(); + channel.sendResponse(requestChannel.getQueuedResponse().get().asRestResponse()); return; } @@ -168,11 +191,11 @@ public RestHandler wrap(RestHandler original, AdminDNs adminDNs) { /** * Checks if a given user is a SuperAdmin */ - private boolean userIsSuperAdmin(User user, AdminDNs adminDNs) { + boolean userIsSuperAdmin(User user, AdminDNs adminDNs) { return user != null && adminDNs.isAdmin(user); } - private void authorizeRequest(RestHandler original, SecurityRequestChannel request, User user) { + void authorizeRequest(RestHandler original, SecurityRequestChannel request, User user) { List restRoutes = original.routes(); Optional handler = restRoutes.stream() .filter(rh -> rh.getMethod().equals(request.method())) @@ -224,7 +247,7 @@ public void checkAndAuthenticateRequest(SecurityRequestChannel requestChannel) t log.error(exception.toString()); auditLog.logBadHeaders(requestChannel); - requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, exception.toString())); + requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, exception)); return; } @@ -233,7 +256,7 @@ public void checkAndAuthenticateRequest(SecurityRequestChannel requestChannel) t log.error(exception.toString()); auditLog.logBadHeaders(requestChannel); - requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, exception.toString())); + requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, exception)); return; } @@ -253,7 +276,7 @@ public void checkAndAuthenticateRequest(SecurityRequestChannel requestChannel) t } catch (SSLPeerUnverifiedException e) { log.error("No ssl info", e); auditLog.logSSLException(requestChannel, e); - requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, null)); + requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, new Exception("No ssl info"))); return; } diff --git a/src/main/java/org/opensearch/security/filter/SecurityRestUtils.java b/src/main/java/org/opensearch/security/filter/SecurityRestUtils.java new file mode 100644 index 0000000000..1599346b90 --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/SecurityRestUtils.java @@ -0,0 +1,12 @@ +package org.opensearch.security.filter; + +public class SecurityRestUtils { + public static String path(final String uri) { + final int index = uri.indexOf('?'); + if (index >= 0) { + return uri.substring(0, index); + } else { + return uri; + } + } +} diff --git a/src/main/java/org/opensearch/security/http/SecurityHttpServerTransport.java b/src/main/java/org/opensearch/security/http/SecurityHttpServerTransport.java index fc36e2411b..3b70a5ebda 100644 --- a/src/main/java/org/opensearch/security/http/SecurityHttpServerTransport.java +++ b/src/main/java/org/opensearch/security/http/SecurityHttpServerTransport.java @@ -26,11 +26,15 @@ package org.opensearch.security.http; +import io.netty.util.AttributeKey; import org.opensearch.common.network.NetworkService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.security.filter.SecurityResponse; +import org.opensearch.security.filter.SecurityRestFilter; import org.opensearch.security.ssl.SecurityKeyStore; import org.opensearch.security.ssl.SslExceptionHandler; import org.opensearch.security.ssl.http.netty.SecuritySSLNettyHttpServerTransport; @@ -41,6 +45,13 @@ public class SecurityHttpServerTransport extends SecuritySSLNettyHttpServerTransport { + public static final AttributeKey EARLY_RESPONSE = AttributeKey.newInstance("opensearch-http-early-response"); + public static final AttributeKey CONTEXT_TO_RESTORE = AttributeKey.newInstance( + "opensearch-http-request-thread-context" + ); + public static final AttributeKey SHOULD_DECOMPRESS = AttributeKey.newInstance("opensearch-http-should-decompress"); + public static final AttributeKey IS_AUTHENTICATED = AttributeKey.newInstance("opensearch-http-is-authenticated"); + public SecurityHttpServerTransport( final Settings settings, final NetworkService networkService, @@ -52,7 +63,8 @@ public SecurityHttpServerTransport( final ValidatingDispatcher dispatcher, final ClusterSettings clusterSettings, SharedGroupFactory sharedGroupFactory, - Tracer tracer + Tracer tracer, + SecurityRestFilter restFilter ) { super( settings, @@ -65,7 +77,8 @@ public SecurityHttpServerTransport( sslExceptionHandler, clusterSettings, sharedGroupFactory, - tracer + tracer, + restFilter ); } } diff --git a/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java b/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java index a8e675ec74..71586a2dff 100644 --- a/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java +++ b/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java @@ -29,6 +29,7 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; import org.opensearch.common.network.NetworkService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -36,12 +37,18 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.http.HttpHandlingSettings; import org.opensearch.http.netty4.Netty4HttpServerTransport; +import org.opensearch.security.filter.SecurityRestFilter; +import org.opensearch.security.ssl.http.netty.Netty4ConditionalDecompressor; +import org.opensearch.security.ssl.http.netty.Netty4HttpRequestHeaderVerifier; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.SharedGroupFactory; public class SecurityNonSslHttpServerTransport extends Netty4HttpServerTransport { + private final ChannelInboundHandlerAdapter headerVerifier; + private final ChannelInboundHandlerAdapter conditionalDecompressor; + public SecurityNonSslHttpServerTransport( final Settings settings, final NetworkService networkService, @@ -49,9 +56,10 @@ public SecurityNonSslHttpServerTransport( final ThreadPool threadPool, final NamedXContentRegistry namedXContentRegistry, final Dispatcher dispatcher, - ClusterSettings clusterSettings, - SharedGroupFactory sharedGroupFactory, - Tracer tracer + final ClusterSettings clusterSettings, + final SharedGroupFactory sharedGroupFactory, + final Tracer tracer, + final SecurityRestFilter restFilter ) { super( settings, @@ -64,6 +72,8 @@ public SecurityNonSslHttpServerTransport( sharedGroupFactory, tracer ); + headerVerifier = new Netty4HttpRequestHeaderVerifier(restFilter, threadPool, settings); + conditionalDecompressor = new Netty4ConditionalDecompressor(); } @Override @@ -82,4 +92,14 @@ protected void initChannel(Channel ch) throws Exception { super.initChannel(ch); } } + + @Override + protected ChannelInboundHandlerAdapter createHeaderVerifier() { + return headerVerifier; + } + + @Override + protected ChannelInboundHandlerAdapter createDecompressor() { + return conditionalDecompressor; + } } diff --git a/src/main/java/org/opensearch/security/http/XFFResolver.java b/src/main/java/org/opensearch/security/http/XFFResolver.java index e9ad412831..64c5f4b60c 100644 --- a/src/main/java/org/opensearch/security/http/XFFResolver.java +++ b/src/main/java/org/opensearch/security/http/XFFResolver.java @@ -34,11 +34,8 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.http.netty4.Netty4HttpChannel; -import org.opensearch.rest.RestRequest; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.security.filter.SecurityRequest; -import org.opensearch.security.filter.OpenSearchRequest; import org.opensearch.security.securityconf.DynamicConfigModel; import org.opensearch.security.support.ConfigConstants; import org.opensearch.threadpool.ThreadPool; @@ -61,15 +58,7 @@ public TransportAddress resolve(final SecurityRequest request) throws OpenSearch log.trace("resolve {}", request.getRemoteAddress().orElse(null)); } - boolean requestFromNetty = false; - if (request instanceof OpenSearchRequest) { - final OpenSearchRequest securityRequestChannel = (OpenSearchRequest) request; - final RestRequest restRequest = securityRequestChannel.breakEncapsulationForRequest(); - - requestFromNetty = restRequest.getHttpChannel() instanceof Netty4HttpChannel; - } - - if (enabled && request.getRemoteAddress().isPresent() && requestFromNetty) { + if (enabled && request.getRemoteAddress().isPresent()) { final InetSocketAddress remoteAddress = request.getRemoteAddress().get(); final InetSocketAddress isa = new InetSocketAddress(detector.detect(request, threadContext), remoteAddress.getPort()); diff --git a/src/main/java/org/opensearch/security/ssl/OpenSearchSecuritySSLPlugin.java b/src/main/java/org/opensearch/security/ssl/OpenSearchSecuritySSLPlugin.java index 18ec7457e9..722e55370e 100644 --- a/src/main/java/org/opensearch/security/ssl/OpenSearchSecuritySSLPlugin.java +++ b/src/main/java/org/opensearch/security/ssl/OpenSearchSecuritySSLPlugin.java @@ -71,6 +71,7 @@ import org.opensearch.script.ScriptService; import org.opensearch.security.DefaultObjectMapper; import org.opensearch.security.NonValidatingObjectMapper; +import org.opensearch.security.filter.SecurityRestFilter; import org.opensearch.security.ssl.http.netty.SecuritySSLNettyHttpServerTransport; import org.opensearch.security.ssl.http.netty.ValidatingDispatcher; import org.opensearch.security.ssl.rest.SecuritySSLInfoAction; @@ -96,12 +97,13 @@ public class OpenSearchSecuritySSLPlugin extends Plugin implements SystemIndexPl ); public static final boolean OPENSSL_SUPPORTED = (PlatformDependent.javaVersion() < 12) && USE_NETTY_DEFAULT_ALLOCATOR; protected final Logger log = LogManager.getLogger(this.getClass()); - protected static final String CLIENT_TYPE = "client.type"; + public static final String CLIENT_TYPE = "client.type"; protected final boolean client; protected final boolean httpSSLEnabled; protected final boolean transportSSLEnabled; protected final boolean extendedKeyUsageEnabled; protected final Settings settings; + protected volatile SecurityRestFilter securityRestHandler; protected final SharedGroupFactory sharedGroupFactory; protected final SecurityKeyStore sks; protected PrincipalExtractor principalExtractor; @@ -267,7 +269,8 @@ public Map> getHttpTransports( NOOP_SSL_EXCEPTION_HANDLER, clusterSettings, sharedGroupFactory, - tracer + tracer, + securityRestHandler ); return Collections.singletonMap("org.opensearch.security.ssl.http.netty.SecuritySSLNettyHttpServerTransport", () -> sgsnht); diff --git a/src/main/java/org/opensearch/security/ssl/http/netty/Netty4ConditionalDecompressor.java b/src/main/java/org/opensearch/security/ssl/http/netty/Netty4ConditionalDecompressor.java new file mode 100644 index 0000000000..c8059fad5d --- /dev/null +++ b/src/main/java/org/opensearch/security/ssl/http/netty/Netty4ConditionalDecompressor.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.security.ssl.http.netty; + +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.HttpContentDecompressor; + +import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE; +import static org.opensearch.security.http.SecurityHttpServerTransport.SHOULD_DECOMPRESS; + +import org.opensearch.security.filter.NettyAttribute; + +@Sharable +public class Netty4ConditionalDecompressor extends HttpContentDecompressor { + + @Override + protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Exception { + final boolean hasAnEarlyReponse = NettyAttribute.peekFrom(ctx, EARLY_RESPONSE).isPresent(); + final boolean shouldDecompress = NettyAttribute.popFrom(ctx, SHOULD_DECOMPRESS).orElse(false); + if (hasAnEarlyReponse || !shouldDecompress) { + // If there was an error prompting an early response,... don't decompress + // If there is no explicit decompress flag,... don't decompress + // If there is a decompress flag and it is false,... don't decompress + return super.newContentDecoder("identity"); + } + + // Decompresses the content based on its encoding + return super.newContentDecoder(contentEncoding); + } +} diff --git a/src/main/java/org/opensearch/security/ssl/http/netty/Netty4HttpRequestHeaderVerifier.java b/src/main/java/org/opensearch/security/ssl/http/netty/Netty4HttpRequestHeaderVerifier.java new file mode 100644 index 0000000000..5112ceced3 --- /dev/null +++ b/src/main/java/org/opensearch/security/ssl/http/netty/Netty4HttpRequestHeaderVerifier.java @@ -0,0 +1,139 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.security.ssl.http.netty; + +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.util.ReferenceCountUtil; +import org.opensearch.ExceptionsHelper; +import org.opensearch.common.util.concurrent.ThreadContext; + +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import org.opensearch.http.netty4.Netty4HttpChannel; +import org.opensearch.http.netty4.Netty4HttpServerTransport; +import org.opensearch.rest.RestUtils; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityRequestChannelUnsupported; +import org.opensearch.security.filter.SecurityRequestFactory; +import org.opensearch.security.filter.SecurityResponse; +import org.opensearch.security.filter.SecurityRestFilter; +import org.opensearch.security.filter.SecurityRestUtils; +import org.opensearch.security.ssl.transport.SSLConfig; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.security.support.ConfigConstants; +import org.opensearch.security.ssl.OpenSearchSecuritySSLPlugin; +import org.opensearch.common.settings.Settings; +import org.opensearch.OpenSearchSecurityException; + +import java.util.regex.Matcher; + +import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.API_AUTHTOKEN_SUFFIX; +import static org.opensearch.security.filter.SecurityRestFilter.HEALTH_SUFFIX; +import static org.opensearch.security.filter.SecurityRestFilter.PATTERN_PATH_PREFIX; +import static org.opensearch.security.filter.SecurityRestFilter.WHO_AM_I_SUFFIX; +import static org.opensearch.security.http.SecurityHttpServerTransport.CONTEXT_TO_RESTORE; +import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE; +import static org.opensearch.security.http.SecurityHttpServerTransport.SHOULD_DECOMPRESS; +import static org.opensearch.security.http.SecurityHttpServerTransport.IS_AUTHENTICATED; + +@Sharable +public class Netty4HttpRequestHeaderVerifier extends SimpleChannelInboundHandler { + private final SecurityRestFilter restFilter; + private final ThreadPool threadPool; + private final SSLConfig sslConfig; + private final boolean injectUserEnabled; + private final boolean passthrough; + + public Netty4HttpRequestHeaderVerifier(SecurityRestFilter restFilter, ThreadPool threadPool, Settings settings) { + this.restFilter = restFilter; + this.threadPool = threadPool; + + this.injectUserEnabled = settings.getAsBoolean(ConfigConstants.SECURITY_UNSUPPORTED_INJECT_USER_ENABLED, false); + boolean disabled = settings.getAsBoolean(ConfigConstants.SECURITY_DISABLED, false); + if (disabled) { + sslConfig = new SSLConfig(false, false); + } else { + sslConfig = new SSLConfig(settings); + } + boolean client = !"node".equals(settings.get(OpenSearchSecuritySSLPlugin.CLIENT_TYPE)); + this.passthrough = client || disabled || sslConfig.isSslOnlyMode(); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) throws Exception { + // DefaultHttpRequest should always be first and contain headers + ReferenceCountUtil.retain(msg); + + if (passthrough) { + ctx.fireChannelRead(msg); + return; + } + + // Start by setting this value to false, only requests that meet all the criteria will be decompressed + ctx.channel().attr(SHOULD_DECOMPRESS).set(Boolean.FALSE); + ctx.channel().attr(IS_AUTHENTICATED).set(Boolean.FALSE); + + final Netty4HttpChannel httpChannel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get(); + String rawPath = SecurityRestUtils.path(msg.uri()); + String path = RestUtils.decodeComponent(rawPath); + Matcher matcher = PATTERN_PATH_PREFIX.matcher(path); + final String suffix = matcher.matches() ? matcher.group(2) : null; + if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { + ctx.fireChannelRead(msg); + return; + } + + final SecurityRequestChannel requestChannel = SecurityRequestFactory.from(msg, httpChannel); + ThreadContext threadContext = threadPool.getThreadContext(); + try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) { + injectUser(msg, threadContext); + + boolean shouldSkipAuthentication = HttpMethod.OPTIONS.equals(msg.method()) + || HEALTH_SUFFIX.equals(suffix) + || WHO_AM_I_SUFFIX.equals(suffix); + + if (!shouldSkipAuthentication) { + // If request channel is completed and a response is sent, then there was a failure during authentication + restFilter.checkAndAuthenticateRequest(requestChannel); + } + + ThreadContext.StoredContext contextToRestore = threadPool.getThreadContext().newStoredContext(false); + ctx.channel().attr(CONTEXT_TO_RESTORE).set(contextToRestore); + + requestChannel.getQueuedResponse().ifPresent(response -> ctx.channel().attr(EARLY_RESPONSE).set(response)); + + boolean shouldDecompress = !shouldSkipAuthentication && requestChannel.getQueuedResponse().isEmpty(); + + if (requestChannel.getQueuedResponse().isEmpty() || shouldSkipAuthentication) { + // Only allow decompression on authenticated requests that also aren't one of those ^ + ctx.channel().attr(SHOULD_DECOMPRESS).set(Boolean.valueOf(shouldDecompress)); + ctx.channel().attr(IS_AUTHENTICATED).set(Boolean.TRUE); + } + } catch (final OpenSearchSecurityException e) { + final SecurityResponse earlyResponse = new SecurityResponse(ExceptionsHelper.status(e).getStatus(), e); + ctx.channel().attr(EARLY_RESPONSE).set(earlyResponse); + } catch (final SecurityRequestChannelUnsupported srcu) { + // Use defaults for unsupported channels + } finally { + ctx.fireChannelRead(msg); + } + } + + private void injectUser(HttpRequest request, ThreadContext threadContext) { + if (this.injectUserEnabled) { + threadContext.putTransient( + ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, + request.headers().get(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER) + ); + } + } +} diff --git a/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java b/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java index 04f71485ba..3eee278083 100644 --- a/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java +++ b/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java @@ -19,6 +19,7 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.codec.DecoderException; import io.netty.handler.ssl.SslHandler; import org.apache.logging.log4j.LogManager; @@ -32,6 +33,7 @@ import org.opensearch.http.HttpChannel; import org.opensearch.http.HttpHandlingSettings; import org.opensearch.http.netty4.Netty4HttpServerTransport; +import org.opensearch.security.filter.SecurityRestFilter; import org.opensearch.security.ssl.SecurityKeyStore; import org.opensearch.security.ssl.SslExceptionHandler; import org.opensearch.telemetry.tracing.Tracer; @@ -43,6 +45,8 @@ public class SecuritySSLNettyHttpServerTransport extends Netty4HttpServerTranspo private static final Logger logger = LogManager.getLogger(SecuritySSLNettyHttpServerTransport.class); private final SecurityKeyStore sks; private final SslExceptionHandler errorHandler; + private final ChannelInboundHandlerAdapter headerVerifier; + private final ChannelInboundHandlerAdapter conditionalDecompressor; public SecuritySSLNettyHttpServerTransport( final Settings settings, @@ -55,7 +59,8 @@ public SecuritySSLNettyHttpServerTransport( final SslExceptionHandler errorHandler, ClusterSettings clusterSettings, SharedGroupFactory sharedGroupFactory, - Tracer tracer + Tracer tracer, + SecurityRestFilter restFilter ) { super( settings, @@ -70,6 +75,8 @@ public SecuritySSLNettyHttpServerTransport( ); this.sks = sks; this.errorHandler = errorHandler; + headerVerifier = new Netty4HttpRequestHeaderVerifier(restFilter, threadPool, settings); + conditionalDecompressor = new Netty4ConditionalDecompressor(); } @Override @@ -108,4 +115,14 @@ protected void initChannel(Channel ch) throws Exception { ch.pipeline().addFirst("ssl_http", sslHandler); } } + + @Override + protected ChannelInboundHandlerAdapter createHeaderVerifier() { + return headerVerifier; + } + + @Override + protected ChannelInboundHandlerAdapter createDecompressor() { + return conditionalDecompressor; + } } diff --git a/src/test/java/org/opensearch/security/SystemIntegratorsTests.java b/src/test/java/org/opensearch/security/SystemIntegratorsTests.java index 18670e6f4d..265a3ccb6e 100644 --- a/src/test/java/org/opensearch/security/SystemIntegratorsTests.java +++ b/src/test/java/org/opensearch/security/SystemIntegratorsTests.java @@ -44,10 +44,7 @@ public class SystemIntegratorsTests extends SingleClusterTest { @Test public void testInjectedUserMalformed() throws Exception { - final Settings settings = Settings.builder() - .put(ConfigConstants.SECURITY_UNSUPPORTED_INJECT_USER_ENABLED, true) - .put("http.type", "org.opensearch.security.http.UserInjectingServerTransport") - .build(); + final Settings settings = Settings.builder().put(ConfigConstants.SECURITY_UNSUPPORTED_INJECT_USER_ENABLED, true).build(); setup(settings, ClusterConfiguration.USERINJECTOR); @@ -115,10 +112,7 @@ public void testInjectedUserMalformed() throws Exception { @Test public void testInjectedUser() throws Exception { - final Settings settings = Settings.builder() - .put(ConfigConstants.SECURITY_UNSUPPORTED_INJECT_USER_ENABLED, true) - .put("http.type", "org.opensearch.security.http.UserInjectingServerTransport") - .build(); + final Settings settings = Settings.builder().put(ConfigConstants.SECURITY_UNSUPPORTED_INJECT_USER_ENABLED, true).build(); setup(settings, ClusterConfiguration.USERINJECTOR); @@ -250,7 +244,7 @@ public void testInjectedUser() throws Exception { @Test public void testInjectedUserDisabled() throws Exception { - final Settings settings = Settings.builder().put("http.type", "org.opensearch.security.http.UserInjectingServerTransport").build(); + final Settings settings = Settings.builder().build(); setup(settings, ClusterConfiguration.USERINJECTOR); @@ -276,7 +270,6 @@ public void testInjectedAdminUser() throws Exception { ConfigConstants.SECURITY_AUTHCZ_ADMIN_DN, Lists.newArrayList("CN=kirk,OU=client,O=client,L=Test,C=DE", "injectedadmin") ) - .put("http.type", "org.opensearch.security.http.UserInjectingServerTransport") .build(); setup(settings, ClusterConfiguration.USERINJECTOR); @@ -312,7 +305,6 @@ public void testInjectedAdminUserAdminInjectionDisabled() throws Exception { ConfigConstants.SECURITY_AUTHCZ_ADMIN_DN, Lists.newArrayList("CN=kirk,OU=client,O=client,L=Test,C=DE", "injectedadmin") ) - .put("http.type", "org.opensearch.security.http.UserInjectingServerTransport") .build(); setup(settings, ClusterConfiguration.USERINJECTOR); diff --git a/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java b/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java index ef7eadc7da..f98ab04cd1 100644 --- a/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java +++ b/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java @@ -510,12 +510,15 @@ public void testUpdateSettings() throws Exception { + "}" + "}"; + String expectedRequestBodyLog = + "{\\\"persistent_settings\\\":{\\\"indices\\\":{\\\"recovery\\\":{\\\"*\\\":null}}},\\\"transient_settings\\\":{\\\"indices\\\":{\\\"recovery\\\":{\\\"*\\\":null}}}}"; + HttpResponse response = rh.executePutRequest("_cluster/settings", json, encodeBasicHeader("admin", "admin")); Assert.assertEquals(HttpStatus.SC_OK, response.getStatusCode()); String auditLogImpl = TestAuditlogImpl.sb.toString(); Assert.assertTrue(auditLogImpl.contains("AUTHENTICATED")); Assert.assertTrue(auditLogImpl.contains("cluster:admin/settings/update")); - Assert.assertTrue(auditLogImpl.contains("indices.recovery.*")); + Assert.assertTrue(auditLogImpl.contains(expectedRequestBodyLog)); // may vary because we log may hit cluster manager directly or not Assert.assertTrue(TestAuditlogImpl.messages.size() > 1); Assert.assertTrue(validateMsgs(TestAuditlogImpl.messages)); diff --git a/src/test/java/org/opensearch/security/filter/SecurityRestFilterUnitTests.java b/src/test/java/org/opensearch/security/filter/SecurityRestFilterUnitTests.java new file mode 100644 index 0000000000..96d08659c4 --- /dev/null +++ b/src/test/java/org/opensearch/security/filter/SecurityRestFilterUnitTests.java @@ -0,0 +1,111 @@ +package org.opensearch.security.filter; + +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.security.auditlog.AuditLog; +import org.opensearch.security.auth.BackendRegistry; +import org.opensearch.security.configuration.AdminDNs; +import org.opensearch.security.configuration.CompatConfig; +import org.opensearch.security.privileges.RestLayerPrivilegesEvaluator; +import org.opensearch.security.ssl.transport.PrincipalExtractor; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.ThreadPool; + +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +public class SecurityRestFilterUnitTests { + + SecurityRestFilter sf; + RestHandler testRestHandler; + + class TestRestHandler implements RestHandler { + + @Override + public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { + channel.sendResponse(new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY)); + } + } + + @Before + public void setUp() throws NoSuchMethodException { + testRestHandler = new TestRestHandler(); + + ThreadPool tp = spy(new ThreadPool(Settings.builder().put("node.name", "mock").build())); + doReturn(new ThreadContext(Settings.EMPTY)).when(tp).getThreadContext(); + + sf = new SecurityRestFilter( + mock(BackendRegistry.class), + mock(RestLayerPrivilegesEvaluator.class), + mock(AuditLog.class), + tp, + mock(PrincipalExtractor.class), + Settings.EMPTY, + mock(Path.class), + mock(CompatConfig.class) + ); + } + + @Ignore + @Test + public void testDoesCallDelegateOnSuccessfulAuthorization() throws Exception { + SecurityRestFilter filterSpy = spy(sf); + AdminDNs adminDNs = mock(AdminDNs.class); + + RestHandler testRestHandlerSpy = spy(testRestHandler); + RestHandler wrappedRestHandler = filterSpy.wrap(testRestHandlerSpy, adminDNs); + + doReturn(false).when(filterSpy).userIsSuperAdmin(any(), any()); + // doReturn(true).when(filterSpy).authorizeRequest(any(), any(), any()); + + FakeRestRequest fakeRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withPath("/test") + .withMethod(RestRequest.Method.POST) + .withHeaders(Map.of("Content-Type", List.of("application/json"))) + .build(); + + wrappedRestHandler.handleRequest(fakeRequest, mock(RestChannel.class), mock(NodeClient.class)); + + verify(testRestHandlerSpy).handleRequest(any(), any(), any()); + } + + @Ignore + @Test + public void testDoesNotCallDelegateOnUnauthorized() throws Exception { + SecurityRestFilter filterSpy = spy(sf); + AdminDNs adminDNs = mock(AdminDNs.class); + + RestHandler testRestHandlerSpy = spy(testRestHandler); + RestHandler wrappedRestHandler = filterSpy.wrap(testRestHandlerSpy, adminDNs); + + doReturn(false).when(filterSpy).userIsSuperAdmin(any(), any()); + // doReturn(false).when(filterSpy).authorizeRequest(any(), any(), any()); + + FakeRestRequest fakeRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withPath("/test") + .withMethod(RestRequest.Method.POST) + .withHeaders(Map.of("Content-Type", List.of("application/json"))) + .build(); + + wrappedRestHandler.handleRequest(fakeRequest, mock(RestChannel.class), mock(NodeClient.class)); + + verify(testRestHandlerSpy, never()).handleRequest(any(), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/security/test/helper/cluster/ClusterConfiguration.java b/src/test/java/org/opensearch/security/test/helper/cluster/ClusterConfiguration.java index 456a9f310c..f8bea4c476 100644 --- a/src/test/java/org/opensearch/security/test/helper/cluster/ClusterConfiguration.java +++ b/src/test/java/org/opensearch/security/test/helper/cluster/ClusterConfiguration.java @@ -41,7 +41,6 @@ import org.opensearch.script.mustache.MustachePlugin; import org.opensearch.search.aggregations.matrix.MatrixAggregationPlugin; import org.opensearch.security.OpenSearchSecurityPlugin; -import org.opensearch.security.test.plugin.UserInjectorPlugin; import org.opensearch.transport.Netty4Plugin; public enum ClusterConfiguration { @@ -79,11 +78,7 @@ public enum ClusterConfiguration { CLIENTNODE(new NodeSettings(true, false), new NodeSettings(false, true), new NodeSettings(false, true), new NodeSettings(false, false)), // 3 nodes (1m, 2d) plus additional UserInjectorPlugin - USERINJECTOR( - new NodeSettings(true, false, Lists.newArrayList(UserInjectorPlugin.class)), - new NodeSettings(false, true, Lists.newArrayList(UserInjectorPlugin.class)), - new NodeSettings(false, true, Lists.newArrayList(UserInjectorPlugin.class)) - ); + USERINJECTOR(new NodeSettings(true, false), new NodeSettings(false, true), new NodeSettings(false, true)); private List nodeSettings = new LinkedList<>(); diff --git a/src/test/java/org/opensearch/security/test/plugin/UserInjectorPlugin.java b/src/test/java/org/opensearch/security/test/plugin/UserInjectorPlugin.java deleted file mode 100644 index 73ede93651..0000000000 --- a/src/test/java/org/opensearch/security/test/plugin/UserInjectorPlugin.java +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright 2015-2018 _floragunn_ GmbH - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.security.test.plugin; - -import java.nio.file.Path; -import java.util.Map; -import java.util.function.Supplier; - -import com.google.common.collect.ImmutableMap; - -import org.opensearch.common.network.NetworkService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.BigArrays; -import org.opensearch.common.util.PageCacheRecycler; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.core.indices.breaker.CircuitBreakerService; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.http.HttpServerTransport; -import org.opensearch.http.HttpServerTransport.Dispatcher; -import org.opensearch.http.netty4.Netty4HttpServerTransport; -import org.opensearch.plugins.NetworkPlugin; -import org.opensearch.plugins.Plugin; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; -import org.opensearch.security.support.ConfigConstants; -import org.opensearch.telemetry.tracing.Tracer; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.SharedGroupFactory; - -/** - * Mimics the behavior of system integrators that run their own plugins (i.e. server transports) - * in front of OpenSearch Security. This transport just copies the user string from the - * REST headers to the ThreadContext to test user injection. - * @author jkressin - */ -public class UserInjectorPlugin extends Plugin implements NetworkPlugin { - - Settings settings; - private final SharedGroupFactory sharedGroupFactory; - ThreadPool threadPool; - - public UserInjectorPlugin(final Settings settings, final Path configPath) { - this.settings = settings; - sharedGroupFactory = new SharedGroupFactory(settings); - } - - @Override - public Map> getHttpTransports( - Settings settings, - ThreadPool threadPool, - BigArrays bigArrays, - PageCacheRecycler pageCacheRecycler, - CircuitBreakerService circuitBreakerService, - NamedXContentRegistry xContentRegistry, - NetworkService networkService, - Dispatcher dispatcher, - ClusterSettings clusterSettings, - Tracer tracer - ) { - - final UserInjectingDispatcher validatingDispatcher = new UserInjectingDispatcher(dispatcher); - return ImmutableMap.of( - "org.opensearch.security.http.UserInjectingServerTransport", - () -> new UserInjectingServerTransport( - settings, - networkService, - bigArrays, - threadPool, - xContentRegistry, - validatingDispatcher, - clusterSettings, - sharedGroupFactory, - tracer - ) - ); - } - - class UserInjectingServerTransport extends Netty4HttpServerTransport { - - public UserInjectingServerTransport( - final Settings settings, - final NetworkService networkService, - final BigArrays bigArrays, - final ThreadPool threadPool, - final NamedXContentRegistry namedXContentRegistry, - final Dispatcher dispatcher, - ClusterSettings clusterSettings, - SharedGroupFactory sharedGroupFactory, - Tracer tracer - ) { - super( - settings, - networkService, - bigArrays, - threadPool, - namedXContentRegistry, - dispatcher, - clusterSettings, - sharedGroupFactory, - tracer - ); - } - } - - class UserInjectingDispatcher implements Dispatcher { - - private Dispatcher originalDispatcher; - - public UserInjectingDispatcher(final Dispatcher originalDispatcher) { - super(); - this.originalDispatcher = originalDispatcher; - } - - @Override - public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) { - threadContext.putTransient( - ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, - request.header(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER) - ); - originalDispatcher.dispatchRequest(request, channel, threadContext); - - } - - @Override - public void dispatchBadRequest(RestChannel channel, ThreadContext threadContext, Throwable cause) { - threadContext.putTransient( - ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, - channel.request().header(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER) - ); - originalDispatcher.dispatchBadRequest(channel, threadContext, cause); - } - } - -}