Skip to content

Commit

Permalink
Add early rejection from RestHandler for unauthorized requests (#3418)
Browse files Browse the repository at this point in the history
Previously unauthorized requests were fully processed and rejected once
they reached the RestHandler. This allocations more memory and resources
for these requests that might not be useful if they are already detected
as unauthorized. Using the headerVerifer and decompressor customization
from [1], perform an early authorization check when only the headers are
available, save an 'early response' for transmission and do not perform
the decompression on the request to speed up closing out the connection.

- Resolves opensearch-project/OpenSearch#10260

Signed-off-by: Peter Nied <petern@amazon.com>
Signed-off-by: Craig Perkins <cwperx@amazon.com>
Signed-off-by: Craig Perkins <craig5008@gmail.com>
Co-authored-by: Peter Nied <petern@amazon.com>
  • Loading branch information
cwperks and peternied authored Oct 6, 2023
1 parent 0e4b8a4 commit 6b0b682
Show file tree
Hide file tree
Showing 30 changed files with 947 additions and 275 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ on:

env:
GRADLE_OPTS: -Dhttp.keepAlive=false
CI_ENVIRONMENT: normal

jobs:
generate-test-list:
Expand Down Expand Up @@ -108,6 +109,32 @@ jobs:
arguments: |
integrationTest -Dbuild.snapshot=false
resource-tests:
env:
CI_ENVIRONMENT: resource-test
strategy:
fail-fast: false
matrix:
jdk: [17]
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}

steps:
- name: Set up JDK for build and test
uses: actions/setup-java@v3
with:
distribution: temurin # Temurin is a distribution of adoptium
java-version: ${{ matrix.jdk }}

- name: Checkout security
uses: actions/checkout@v4

- name: Build and Test
uses: gradle/gradle-build-action@v2
with:
cache-disabled: true
arguments: |
integrationTest -Dbuild.snapshot=false --tests org.opensearch.security.ResourceFocusedTests
backward-compatibility-build:
runs-on: ubuntu-latest
steps:
Expand Down
21 changes: 16 additions & 5 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -459,16 +459,27 @@ sourceSets {

//add new task that runs integration tests
task integrationTest(type: Test) {
doFirst {
// Only run resources tests on resource-test CI environments or locally
if (System.getenv('CI_ENVIRONMENT') == 'resource-test' || System.getenv('CI_ENVIRONMENT') == null) {
include '**/ResourceFocusedTests.class'
} else {
exclude '**/ResourceFocusedTests.class'
}
// Only run with retries while in CI systems
if (System.getenv('CI_ENVIRONMENT') == 'normal') {
retry {
failOnPassedAfterRetry = false
maxRetries = 2
maxFailures = 10
}
}
}
description = 'Run integration tests.'
group = 'verification'
systemProperty "java.util.logging.manager", "org.apache.logging.log4j.jul.LogManager"
testClassesDirs = sourceSets.integrationTest.output.classesDirs
classpath = sourceSets.integrationTest.runtimeClasspath
retry {
failOnPassedAfterRetry = false
maxRetries = 2
maxFailures = 10
}
//run the integrationTest task after the test task
shouldRunAfter test
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
package org.opensearch.security;

import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE;
import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL;
import static org.opensearch.test.framework.TestSecurityConfig.Role.ALL_ACCESS;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryPoolMXBean;
import java.lang.management.MemoryUsage;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.zip.GZIPOutputStream;

import org.apache.hc.client5.http.classic.methods.HttpPost;
import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.io.entity.ByteArrayEntity;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.client.Client;
import org.opensearch.test.framework.TestSecurityConfig;
import org.opensearch.test.framework.TestSecurityConfig.User;
import org.opensearch.test.framework.cluster.ClusterManager;
import org.opensearch.test.framework.cluster.LocalCluster;
import org.opensearch.test.framework.cluster.TestRestClient;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;

@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class)
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class ResourceFocusedTests {
private static final User ADMIN_USER = new User("admin").roles(ALL_ACCESS);
private static final User LIMITED_USER = new User("limited_user").roles(
new TestSecurityConfig.Role("limited-role").clusterPermissions(
"indices:data/read/mget",
"indices:data/read/msearch",
"indices:data/read/scroll",
"cluster:monitor/state",
"cluster:monitor/health"
)
.indexPermissions(
"indices:data/read/search",
"indices:data/read/mget*",
"indices:monitor/settings/get",
"indices:monitor/stats"
)
.on("*")
);

@ClassRule
public static LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.THREE_CLUSTER_MANAGERS)
.authc(AUTHC_HTTPBASIC_INTERNAL)
.users(ADMIN_USER, LIMITED_USER)
.anonymousAuth(false)
.doNotFailOnForbidden(true)
.build();

@BeforeClass
public static void createTestData() {
try (Client client = cluster.getInternalNodeClient()) {
client.index(new IndexRequest().setRefreshPolicy(IMMEDIATE).index("document").source(Map.of("foo", "bar", "abc", "xyz")))
.actionGet();
}
}

@Test
public void testUnauthenticatedFewBig() {
// Tweaks:
final RequestBodySize size = RequestBodySize.XLarge;
final String requestPath = "/*/_search";
final int parrallelism = 5;
final int totalNumberOfRequests = 100;
final boolean statsPrinter = false;

runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter);
}

@Test
public void testUnauthenticatedManyMedium() {
// Tweaks:
final RequestBodySize size = RequestBodySize.Medium;
final String requestPath = "/*/_search";
final int parrallelism = 20;
final int totalNumberOfRequests = 10_000;
final boolean statsPrinter = false;

runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter);
}

@Test
public void testUnauthenticatedTonsSmall() {
// Tweaks:
final RequestBodySize size = RequestBodySize.Small;
final String requestPath = "/*/_search";
final int parrallelism = 100;
final int totalNumberOfRequests = 1_000_000;
final boolean statsPrinter = false;

runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter);
}

private Long runResourceTest(
final RequestBodySize size,
final String requestPath,
final int parrallelism,
final int totalNumberOfRequests,
final boolean statsPrinter
) {
final byte[] compressedRequestBody = createCompressedRequestBody(size);
try (final TestRestClient client = cluster.getRestClient(new BasicHeader("Content-Encoding", "gzip"))) {

if (statsPrinter) {
printStats();
}
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));

final ForkJoinPool forkJoinPool = new ForkJoinPool(parrallelism);

final List<CompletableFuture<Void>> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests)
.boxed()
.map(i -> CompletableFuture.runAsync(() -> client.executeRequest(post), forkJoinPool))
.collect(Collectors.toList());
Supplier<Long> getCount = () -> waitingOn.stream().filter(cf -> cf.isDone() && !cf.isCompletedExceptionally()).count();

CompletableFuture<Void> statPrinter = statsPrinter ? CompletableFuture.runAsync(() -> {
while (true) {
printStats();
System.out.println(" & Succesful completions: " + getCount.get());
try {
Thread.sleep(500);
} catch (Exception e) {
break;
}
}
}, forkJoinPool) : CompletableFuture.completedFuture(null);

final CompletableFuture<Void> allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0]));

try {
allOfThem.get(30, TimeUnit.SECONDS);
statPrinter.cancel(true);
} catch (final Exception e) {
// Ignored
}

if (statsPrinter) {
printStats();
System.out.println(" & Succesful completions: " + getCount.get());
}
return getCount.get();
}
}

static enum RequestBodySize {
Small(1),
Medium(1_000),
XLarge(1_000_000);

public final int elementCount;

private RequestBodySize(final int elementCount) {
this.elementCount = elementCount;
}
}

private byte[] createCompressedRequestBody(final RequestBodySize size) {
final int repeatCount = size.elementCount;
final String prefix = "{ \"items\": [";
final String repeatedElement = IntStream.range(0, 20)
.mapToObj(n -> ('a' + n) + "")
.map(n -> '"' + n + '"' + ": 123")
.collect(Collectors.joining(",", "{", "}"));
final String postfix = "]}";
long uncompressedBytesSize = 0;

try (
final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
final GZIPOutputStream gzipOutputStream = new GZIPOutputStream(byteArrayOutputStream)
) {

final byte[] prefixBytes = prefix.getBytes(StandardCharsets.UTF_8);
final byte[] repeatedElementBytes = repeatedElement.getBytes(StandardCharsets.UTF_8);
final byte[] postfixBytes = postfix.getBytes(StandardCharsets.UTF_8);

gzipOutputStream.write(prefixBytes);
uncompressedBytesSize = uncompressedBytesSize + prefixBytes.length;
for (int i = 0; i < repeatCount; i++) {
gzipOutputStream.write(repeatedElementBytes);
uncompressedBytesSize = uncompressedBytesSize + repeatedElementBytes.length;
}
gzipOutputStream.write(postfixBytes);
uncompressedBytesSize = uncompressedBytesSize + postfixBytes.length;
gzipOutputStream.finish();

final byte[] compressedRequestBody = byteArrayOutputStream.toByteArray();
System.out.println(
"^^^"
+ String.format(
"Original size was %,d bytes, compressed to %,d bytes, ratio %,.2f",
uncompressedBytesSize,
compressedRequestBody.length,
((double) uncompressedBytesSize / compressedRequestBody.length)
)
);
return compressedRequestBody;
} catch (final IOException ioe) {
throw new RuntimeException(ioe);
}
}

private void printStats() {
System.out.println("** Stats ");
printMemory();
printMemoryPools();
printGCPools();
}

private void printMemory() {
final Runtime runtime = Runtime.getRuntime();

final long totalMemory = runtime.totalMemory(); // Total allocated memory
final long freeMemory = runtime.freeMemory(); // Amount of free memory
final long usedMemory = totalMemory - freeMemory; // Amount of used memory

System.out.println(" Memory Total: " + totalMemory + " Free:" + freeMemory + " Used:" + usedMemory);
}

private void printMemoryPools() {
List<MemoryPoolMXBean> memoryPools = ManagementFactory.getMemoryPoolMXBeans();
for (MemoryPoolMXBean memoryPool : memoryPools) {
MemoryUsage usage = memoryPool.getUsage();
System.out.println(" " + memoryPool.getName() + " USED: " + usage.getUsed() + " MAX: " + usage.getMax());
}
}

private void printGCPools() {
List<GarbageCollectorMXBean> garbageCollectors = ManagementFactory.getGarbageCollectorMXBeans();
for (GarbageCollectorMXBean garbageCollector : garbageCollectors) {
System.out.println(" " + garbageCollector.getName() + " COLLECTION TIME: " + garbageCollector.getCollectionTime());
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ public void createRoleMapping(String backendRoleName, String roleName) {
response.assertStatusCode(201);
}

protected final String getHttpServerUri() {
public final String getHttpServerUri() {
return "http" + (enableHTTPClientSSL ? "s" : "") + "://" + nodeHttpAddress.getHostString() + ":" + nodeHttpAddress.getPort();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public class HTTPSamlAuthenticator implements HTTPAuthenticator, Destroyable {
public static final String IDP_METADATA_FILE = "idp.metadata_file";
public static final String IDP_METADATA_CONTENT = "idp.metadata_content";

private static final String API_AUTHTOKEN_SUFFIX = "api/authtoken";
public static final String API_AUTHTOKEN_SUFFIX = "api/authtoken";
private static final String AUTHINFO_SUFFIX = "authinfo";
private static final String REGEX_PATH_PREFIX = "/(" + LEGACY_OPENDISTRO_PREFIX + "|" + PLUGINS_PREFIX + ")/" + "(.*)";
private static final Pattern PATTERN_PATH_PREFIX = Pattern.compile(REGEX_PATH_PREFIX);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ public final class OpenSearchSecurityPlugin extends OpenSearchSecuritySSLPlugin
public static final String PLUGINS_PREFIX = "_plugins/_security";

private boolean sslCertReloadEnabled;
private volatile SecurityRestFilter securityRestHandler;
private volatile SecurityInterceptor si;
private volatile PrivilegesEvaluator evaluator;
private volatile UserService userService;
Expand Down Expand Up @@ -914,7 +913,8 @@ public Map<String, Supplier<HttpServerTransport>> getHttpTransports(
validatingDispatcher,
clusterSettings,
sharedGroupFactory,
tracer
tracer,
securityRestHandler
);

return Collections.singletonMap("org.opensearch.security.http.SecurityHttpServerTransport", () -> odshst);
Expand All @@ -930,7 +930,8 @@ public Map<String, Supplier<HttpServerTransport>> getHttpTransports(
dispatcher,
clusterSettings,
sharedGroupFactory,
tracer
tracer,
securityRestHandler
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,11 @@
import org.apache.hc.client5.http.ssl.NoopHostnameVerifier;
import org.apache.hc.client5.http.ssl.SSLConnectionSocketFactory;
import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.HttpStatus;
import org.apache.hc.core5.http.io.SocketConfig;
import org.apache.hc.core5.http.io.entity.StringEntity;
import org.apache.hc.core5.ssl.SSLContextBuilder;
import org.apache.hc.core5.ssl.TrustStrategy;

import org.apache.http.HttpStatus;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.Strings;
import org.opensearch.security.auditlog.impl.AuditMessage;
Expand Down
Loading

0 comments on commit 6b0b682

Please sign in to comment.