From 7062a48b01663491dfd91ecdb3edfe3dc0985002 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Fri, 22 Nov 2024 20:18:20 +0100 Subject: [PATCH 1/9] Delay creation of the next SearchPhase in executeNextPhase (#116061) (#117369) Delaying the creation of the next phase to only when we actually need it makes this a lot easier to reason about and should set up further simplications. Eager creation of the next phase forced a lot of needlessly complicated safety logic around resources on us. Since we never "close" the `nextPhase` on failure all its resources need to be tracked in via `context.addReleasable`. This isn't as much of an issue with some recent refactorings leaving very little resource creation in the constructors but still, delaying things saves memory and makes reasoning about failure cases far easier. --- .../action/search/AbstractSearchAsyncAction.java | 6 ++++-- .../org/elasticsearch/action/search/DfsQueryPhase.java | 2 +- .../elasticsearch/action/search/ExpandSearchPhase.java | 2 +- .../org/elasticsearch/action/search/FetchSearchPhase.java | 8 +++++--- .../org/elasticsearch/action/search/RankFeaturePhase.java | 2 +- .../elasticsearch/action/search/SearchPhaseContext.java | 3 ++- .../action/search/MockSearchPhaseContext.java | 4 +++- 7 files changed, 17 insertions(+), 10 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index d62f16079088f..beb056d5e43c9 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -56,6 +56,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.Supplier; import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; @@ -368,7 +369,7 @@ protected abstract void executePhaseOnShard( ); @Override - public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase) { + public final void executeNextPhase(SearchPhase currentPhase, Supplier nextPhaseSupplier) { /* This is the main search phase transition where we move to the next phase. If all shards * failed or if there was a failure and partial results are not allowed, then we immediately * fail. Otherwise we continue to the next phase. @@ -412,6 +413,7 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha } return; } + var nextPhase = nextPhaseSupplier.get(); if (logger.isTraceEnabled()) { final String resultsFrom = results.getSuccessfulResults() .map(r -> r.getSearchShardTarget().toString()) @@ -722,7 +724,7 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) { * @see #onShardResult(SearchPhaseResult, SearchShardIterator) */ final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim() - executeNextPhase(this, getNextPhase(results, this)); + executeNextPhase(this, () -> getNextPhase(results, this)); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index e0e240be0377a..aeef5ff7ce94c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -78,7 +78,7 @@ public void run() { final CountedCollector counter = new CountedCollector<>( queryResult, searchResults.size(), - () -> context.executeNextPhase(this, nextPhaseFactory.apply(queryResult)), + () -> context.executeNextPhase(this, () -> nextPhaseFactory.apply(queryResult)), context ); diff --git a/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java index 5457ca60d0da4..968d9dac958fa 100644 --- a/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java @@ -164,6 +164,6 @@ private static SearchSourceBuilder buildExpandSearchSourceBuilder(InnerHitBuilde } private void onPhaseDone() { - context.executeNextPhase(this, nextPhase.get()); + context.executeNextPhase(this, nextPhase); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index 99b24bd483fb4..e3007b1c5b617 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -260,9 +260,11 @@ private void moveToNextPhase( AtomicArray fetchResultsArr, SearchPhaseController.ReducedQueryPhase reducedQueryPhase ) { - var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr); - context.addReleasable(resp::decRef); - context.executeNextPhase(this, nextPhaseFactory.apply(resp, searchPhaseShardResults)); + context.executeNextPhase(this, () -> { + var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr); + context.addReleasable(resp::decRef); + return nextPhaseFactory.apply(resp, searchPhaseShardResults); + }); } private boolean shouldExplainRankScores(SearchRequest request) { diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index 81053a70eca9f..31179423fa0dc 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -233,6 +233,6 @@ private float maxScore(ScoreDoc[] scoreDocs) { } void moveToNextPhase(SearchPhaseResults phaseResults, SearchPhaseController.ReducedQueryPhase reducedQueryPhase) { - context.executeNextPhase(this, new FetchSearchPhase(phaseResults, aggregatedDfs, context, reducedQueryPhase)); + context.executeNextPhase(this, () -> new FetchSearchPhase(phaseResults, aggregatedDfs, context, reducedQueryPhase)); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java index 871be0a349a7f..d048887b69c97 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java @@ -19,6 +19,7 @@ import org.elasticsearch.transport.Transport; import java.util.concurrent.Executor; +import java.util.function.Supplier; /** * This class provide contextual state and access to resources across multiple search phases. @@ -120,7 +121,7 @@ default void sendReleaseSearchContext( * of the next phase. If there are no successful operations in the context when this method is executed the search is aborted and * a response is returned to the user indicating that all shards have failed. */ - void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase); + void executeNextPhase(SearchPhase currentPhase, Supplier nextPhaseSupplier); /** * Registers a {@link Releasable} that will be closed when the search request finishes or fails. diff --git a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java index 5d4d60f6805b1..5395e4569901a 100644 --- a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java +++ b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java @@ -30,6 +30,7 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; /** * SearchPhaseContext for tests @@ -132,7 +133,8 @@ public SearchTransportService getSearchTransport() { } @Override - public void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase) { + public void executeNextPhase(SearchPhase currentPhase, Supplier nextPhaseSupplier) { + var nextPhase = nextPhaseSupplier.get(); try { nextPhase.run(); } catch (Exception e) { From 4220d0bd5be9b930cd496ac1531c18738ba4f217 Mon Sep 17 00:00:00 2001 From: Oleksandr Kolomiiets Date: Fri, 22 Nov 2024 11:22:42 -0800 Subject: [PATCH 2/9] Fix constand_keyword test run and properly test recent behavior change (#117284) (#117371) --- .../index/mapper/MapperFeatures.java | 7 ++++- .../mapper-constant-keyword/build.gradle | 2 +- .../ConstantKeywordClientYamlTestSuiteIT.java | 10 +++++++ .../test/20_synthetic_source.yml | 26 +++++++++++++++++-- 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index de7eeb4b180aa..365919d7852db 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -51,6 +51,10 @@ public Set getFeatures() { ); } + public static final NodeFeature CONSTANT_KEYWORD_SYNTHETIC_SOURCE_WRITE_FIX = new NodeFeature( + "mapper.constant_keyword.synthetic_source_write_fix" + ); + @Override public Set getTestFeatures() { return Set.of( @@ -59,7 +63,8 @@ public Set getTestFeatures() { SourceFieldMapper.REMOVE_SYNTHETIC_SOURCE_ONLY_VALIDATION, SourceFieldMapper.SOURCE_MODE_FROM_INDEX_SETTING, IgnoredSourceFieldMapper.ALWAYS_STORE_OBJECT_ARRAYS_IN_NESTED_OBJECTS, - MapperService.LOGSDB_DEFAULT_IGNORE_DYNAMIC_BEYOND_LIMIT + MapperService.LOGSDB_DEFAULT_IGNORE_DYNAMIC_BEYOND_LIMIT, + CONSTANT_KEYWORD_SYNTHETIC_SOURCE_WRITE_FIX ); } } diff --git a/x-pack/plugin/mapper-constant-keyword/build.gradle b/x-pack/plugin/mapper-constant-keyword/build.gradle index ad9d3c2f86637..ca7202f2fbbff 100644 --- a/x-pack/plugin/mapper-constant-keyword/build.gradle +++ b/x-pack/plugin/mapper-constant-keyword/build.gradle @@ -1,7 +1,7 @@ import org.elasticsearch.gradle.internal.info.BuildParams apply plugin: 'elasticsearch.internal-es-plugin' -apply plugin: 'elasticsearch.legacy-yaml-rest-test' +apply plugin: 'elasticsearch.internal-yaml-rest-test' esplugin { name 'constant-keyword' diff --git a/x-pack/plugin/mapper-constant-keyword/src/yamlRestTest/java/org/elasticsearch/xpack/constantkeyword/ConstantKeywordClientYamlTestSuiteIT.java b/x-pack/plugin/mapper-constant-keyword/src/yamlRestTest/java/org/elasticsearch/xpack/constantkeyword/ConstantKeywordClientYamlTestSuiteIT.java index 789059d9e11c0..5b6048b481abf 100644 --- a/x-pack/plugin/mapper-constant-keyword/src/yamlRestTest/java/org/elasticsearch/xpack/constantkeyword/ConstantKeywordClientYamlTestSuiteIT.java +++ b/x-pack/plugin/mapper-constant-keyword/src/yamlRestTest/java/org/elasticsearch/xpack/constantkeyword/ConstantKeywordClientYamlTestSuiteIT.java @@ -10,8 +10,10 @@ import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; +import org.junit.ClassRule; /** Runs yaml rest tests */ public class ConstantKeywordClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase { @@ -24,4 +26,12 @@ public ConstantKeywordClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidat public static Iterable parameters() throws Exception { return ESClientYamlSuiteTestCase.createParameters(); } + + @ClassRule + public static ElasticsearchCluster cluster = ElasticsearchCluster.local().module("constant-keyword").build(); + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } } diff --git a/x-pack/plugin/mapper-constant-keyword/src/yamlRestTest/resources/rest-api-spec/test/20_synthetic_source.yml b/x-pack/plugin/mapper-constant-keyword/src/yamlRestTest/resources/rest-api-spec/test/20_synthetic_source.yml index d40f69f483dbb..012b1006b8d20 100644 --- a/x-pack/plugin/mapper-constant-keyword/src/yamlRestTest/resources/rest-api-spec/test/20_synthetic_source.yml +++ b/x-pack/plugin/mapper-constant-keyword/src/yamlRestTest/resources/rest-api-spec/test/20_synthetic_source.yml @@ -1,7 +1,7 @@ constant_keyword: - requires: - cluster_features: [ "mapper.source.mode_from_index_setting" ] - reason: "Source mode configured through index setting" + cluster_features: [ "mapper.constant_keyword.synthetic_source_write_fix" ] + reason: "Behavior fix" - do: indices.create: @@ -26,6 +26,15 @@ constant_keyword: body: kwd: foo + - do: + index: + index: test + id: 2 + refresh: true + body: + kwd: foo + const_kwd: bar + - do: search: index: test @@ -33,6 +42,19 @@ constant_keyword: query: ids: values: [1] + + - match: + hits.hits.0._source: + kwd: foo + + - do: + search: + index: test + body: + query: + ids: + values: [2] + - match: hits.hits.0._source: kwd: foo From 3699811cb14bc2b374093804e9043d73bd67abc9 Mon Sep 17 00:00:00 2001 From: Stanislav Malyshev Date: Fri, 22 Nov 2024 12:27:08 -0700 Subject: [PATCH 3/9] FIx async search tests - do not warn on the presence of .async-search (#117301) (#117372) (cherry picked from commit f325c1541088995f35e7d39cf181a9b970d3c90a) # Conflicts: # muted-tests.yml # test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java --- .../test/rest/ESRestTestCase.java | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java index 5076f97ec96d6..fce766fe070bb 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java @@ -1170,6 +1170,7 @@ protected static void wipeAllIndices(boolean preserveSecurityIndices) throws IOE } final Request deleteRequest = new Request("DELETE", Strings.collectionToCommaDelimitedString(indexPatterns)); deleteRequest.addParameter("expand_wildcards", "open,closed" + (includeHidden ? ",hidden" : "")); + deleteRequest.setOptions(deleteRequest.getOptions().toBuilder().setWarningsHandler(ignoreAsyncSearchWarning()).build()); final Response response = adminClient().performRequest(deleteRequest); try (InputStream is = response.getEntity().getContent()) { assertTrue((boolean) XContentHelper.convertToMap(XContentType.JSON.xContent(), is, true).get("acknowledged")); @@ -1182,6 +1183,30 @@ protected static void wipeAllIndices(boolean preserveSecurityIndices) throws IOE } } + // Make warnings handler that ignores the .async-search warning since .async-search may randomly appear when async requests are slow + // See: https://github.com/elastic/elasticsearch/issues/117099 + protected static WarningsHandler ignoreAsyncSearchWarning() { + return new WarningsHandler() { + @Override + public boolean warningsShouldFailRequest(List warnings) { + if (warnings.isEmpty()) { + return false; + } + return warnings.equals( + List.of( + "this request accesses system indices: [.async-search], " + + "but in a future major version, direct access to system indices will be prevented by default" + ) + ) == false; + } + + @Override + public String toString() { + return "ignore .async-search warning"; + } + }; + } + protected static void wipeDataStreams() throws IOException { try { if (hasXPack()) { From bde7828eb7875139f87035a9acf78dd563a19505 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Fri, 22 Nov 2024 20:33:44 +0100 Subject: [PATCH 4/9] Catch and handle disconnect exceptions in search (#115836) (#117373) Getting a connection can throw an exception for a disconnected node. We failed to handle these in the adjusted spots, leading to a phase failure (and possible memory leaks for outstanding operations) instead of correctly recording a per-shard failure. --- docs/changelog/115836.yaml | 5 ++ .../action/search/DfsQueryPhase.java | 32 +++++++--- .../action/search/FetchSearchPhase.java | 61 +++++++++++-------- .../action/search/RankFeaturePhase.java | 55 ++++++++++------- .../SearchDfsQueryThenFetchAsyncAction.java | 14 +++-- .../SearchQueryThenFetchAsyncAction.java | 9 ++- .../SearchQueryThenFetchAsyncActionTests.java | 16 +++-- 7 files changed, 121 insertions(+), 71 deletions(-) create mode 100644 docs/changelog/115836.yaml diff --git a/docs/changelog/115836.yaml b/docs/changelog/115836.yaml new file mode 100644 index 0000000000000..f6da638f1feff --- /dev/null +++ b/docs/changelog/115836.yaml @@ -0,0 +1,5 @@ +pr: 115836 +summary: Catch and handle disconnect exceptions in search +area: Search +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index aeef5ff7ce94c..b8963df85b8e3 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -84,15 +84,20 @@ public void run() { for (final DfsSearchResult dfsResult : searchResults) { final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget(); - Transport.Connection connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()); - ShardSearchRequest shardRequest = rewriteShardSearchRequest(dfsResult.getShardSearchRequest()); + final int shardIndex = dfsResult.getShardIndex(); QuerySearchRequest querySearchRequest = new QuerySearchRequest( - context.getOriginalIndices(dfsResult.getShardIndex()), + context.getOriginalIndices(shardIndex), dfsResult.getContextId(), - shardRequest, + rewriteShardSearchRequest(dfsResult.getShardSearchRequest()), dfs ); - final int shardIndex = dfsResult.getShardIndex(); + final Transport.Connection connection; + try { + connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()); + } catch (Exception e) { + shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter); + return; + } searchTransportService.sendExecuteQuery( connection, querySearchRequest, @@ -112,10 +117,7 @@ protected void innerOnResponse(QuerySearchResult response) { @Override public void onFailure(Exception exception) { try { - context.getLogger() - .debug(() -> "[" + querySearchRequest.contextId() + "] Failed to execute query phase", exception); - progressListener.notifyQueryFailure(shardIndex, shardTarget, exception); - counter.onFailure(shardIndex, shardTarget, exception); + shardFailure(exception, querySearchRequest, shardIndex, shardTarget, counter); } finally { if (context.isPartOfPointInTime(querySearchRequest.contextId()) == false) { // the query might not have been executed at all (for example because thread pool rejected @@ -134,6 +136,18 @@ public void onFailure(Exception exception) { } } + private void shardFailure( + Exception exception, + QuerySearchRequest querySearchRequest, + int shardIndex, + SearchShardTarget shardTarget, + CountedCollector counter + ) { + context.getLogger().debug(() -> "[" + querySearchRequest.contextId() + "] Failed to execute query phase", exception); + progressListener.notifyQueryFailure(shardIndex, shardTarget, exception); + counter.onFailure(shardIndex, shardTarget, exception); + } + // package private for testing ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) { SearchSourceBuilder source = request.source(); diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index e3007b1c5b617..d7b847d835b83 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -21,6 +21,7 @@ import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankDocShardInfo; +import org.elasticsearch.transport.Transport; import java.util.ArrayList; import java.util.HashMap; @@ -214,9 +215,41 @@ private void executeFetch( final ShardSearchContextId contextId = shardPhaseResult.queryResult() != null ? shardPhaseResult.queryResult().getContextId() : shardPhaseResult.rankFeatureResult().getContextId(); + var listener = new SearchActionListener(shardTarget, shardIndex) { + @Override + public void innerOnResponse(FetchSearchResult result) { + try { + progressListener.notifyFetchResult(shardIndex); + counter.onResult(result); + } catch (Exception e) { + context.onPhaseFailure(FetchSearchPhase.this, "", e); + } + } + + @Override + public void onFailure(Exception e) { + try { + logger.debug(() -> "[" + contextId + "] Failed to execute fetch phase", e); + progressListener.notifyFetchFailure(shardIndex, shardTarget, e); + counter.onFailure(shardIndex, shardTarget, e); + } finally { + // the search context might not be cleared on the node where the fetch was executed for example + // because the action was rejected by the thread pool. in this case we need to send a dedicated + // request to clear the search context. + releaseIrrelevantSearchContext(shardPhaseResult, context); + } + } + }; + final Transport.Connection connection; + try { + connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()); + } catch (Exception e) { + listener.onFailure(e); + return; + } context.getSearchTransport() .sendExecuteFetch( - context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()), + connection, new ShardFetchSearchRequest( context.getOriginalIndices(shardPhaseResult.getShardIndex()), contextId, @@ -228,31 +261,7 @@ private void executeFetch( aggregatedDfs ), context.getTask(), - new SearchActionListener<>(shardTarget, shardIndex) { - @Override - public void innerOnResponse(FetchSearchResult result) { - try { - progressListener.notifyFetchResult(shardIndex); - counter.onResult(result); - } catch (Exception e) { - context.onPhaseFailure(FetchSearchPhase.this, "", e); - } - } - - @Override - public void onFailure(Exception e) { - try { - logger.debug(() -> "[" + contextId + "] Failed to execute fetch phase", e); - progressListener.notifyFetchFailure(shardIndex, shardTarget, e); - counter.onFailure(shardIndex, shardTarget, e); - } finally { - // the search context might not be cleared on the node where the fetch was executed for example - // because the action was rejected by the thread pool. in this case we need to send a dedicated - // request to clear the search context. - releaseIrrelevantSearchContext(shardPhaseResult, context); - } - } - } + listener ); } diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index 31179423fa0dc..66faf07dd81a5 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -24,6 +24,7 @@ import org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; +import org.elasticsearch.transport.Transport; import java.util.List; @@ -136,9 +137,38 @@ private void executeRankFeatureShardPhase( final SearchShardTarget shardTarget = queryResult.queryResult().getSearchShardTarget(); final ShardSearchContextId contextId = queryResult.queryResult().getContextId(); final int shardIndex = queryResult.getShardIndex(); + var listener = new SearchActionListener(shardTarget, shardIndex) { + @Override + protected void innerOnResponse(RankFeatureResult response) { + try { + progressListener.notifyRankFeatureResult(shardIndex); + rankRequestCounter.onResult(response); + } catch (Exception e) { + context.onPhaseFailure(RankFeaturePhase.this, "", e); + } + } + + @Override + public void onFailure(Exception e) { + try { + logger.debug(() -> "[" + contextId + "] Failed to execute rank phase", e); + progressListener.notifyRankFeatureFailure(shardIndex, shardTarget, e); + rankRequestCounter.onFailure(shardIndex, shardTarget, e); + } finally { + releaseIrrelevantSearchContext(queryResult, context); + } + } + }; + final Transport.Connection connection; + try { + connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()); + } catch (Exception e) { + listener.onFailure(e); + return; + } context.getSearchTransport() .sendExecuteRankFeature( - context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()), + connection, new RankFeatureShardRequest( context.getOriginalIndices(queryResult.getShardIndex()), queryResult.getContextId(), @@ -146,28 +176,7 @@ private void executeRankFeatureShardPhase( entry ), context.getTask(), - new SearchActionListener<>(shardTarget, shardIndex) { - @Override - protected void innerOnResponse(RankFeatureResult response) { - try { - progressListener.notifyRankFeatureResult(shardIndex); - rankRequestCounter.onResult(response); - } catch (Exception e) { - context.onPhaseFailure(RankFeaturePhase.this, "", e); - } - } - - @Override - public void onFailure(Exception e) { - try { - logger.debug(() -> "[" + contextId + "] Failed to execute rank phase", e); - progressListener.notifyRankFeatureFailure(shardIndex, shardTarget, e); - rankRequestCounter.onFailure(shardIndex, shardTarget, e); - } finally { - releaseIrrelevantSearchContext(queryResult, context); - } - } - } + listener ); } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 87b16da2bb78a..3a476d8799d0e 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -87,12 +87,14 @@ protected void executePhaseOnShard( final SearchShardTarget shard, final SearchActionListener listener ) { - getSearchTransport().sendExecuteDfs( - getConnection(shard.getClusterAlias(), shard.getNodeId()), - buildShardSearchRequest(shardIt, listener.requestIndex), - getTask(), - listener - ); + final Transport.Connection connection; + try { + connection = getConnection(shard.getClusterAlias(), shard.getNodeId()); + } catch (Exception e) { + listener.onFailure(e); + return; + } + getSearchTransport().sendExecuteDfs(connection, buildShardSearchRequest(shardIt, listener.requestIndex), getTask(), listener); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index ecf81980f894a..460739feaf1ef 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -93,8 +93,15 @@ protected void executePhaseOnShard( final SearchShardTarget shard, final SearchActionListener listener ) { + final Transport.Connection connection; + try { + connection = getConnection(shard.getClusterAlias(), shard.getNodeId()); + } catch (Exception e) { + listener.onFailure(e); + return; + } ShardSearchRequest request = rewriteShardSearchRequest(super.buildShardSearchRequest(shardIt, listener.requestIndex)); - getSearchTransport().sendExecuteQuery(getConnection(shard.getClusterAlias(), shard.getNodeId()), request, getTask(), listener); + getSearchTransport().sendExecuteQuery(connection, request, getTask(), listener); } @Override diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java index 1a2be26f547f9..6db0f61287722 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; @@ -733,17 +734,20 @@ public void run() { assertThat(phase.totalHits().relation, equalTo(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)); SearchShardTarget searchShardTarget = new SearchShardTarget("node3", shardIt.shardId(), null); + final PlainActionFuture f = new PlainActionFuture<>(); SearchActionListener listener = new SearchActionListener(searchShardTarget, 0) { @Override - public void onFailure(Exception e) {} + public void onFailure(Exception e) { + f.onFailure(e); + } @Override - protected void innerOnResponse(SearchPhaseResult response) {} + protected void innerOnResponse(SearchPhaseResult response) { + fail("should not be called"); + } }; - Exception e = expectThrows( - VersionMismatchException.class, - () -> action.executePhaseOnShard(shardIt, searchShardTarget, listener) - ); + action.executePhaseOnShard(shardIt, searchShardTarget, listener); + Exception e = expectThrows(VersionMismatchException.class, f::actionGet); assertThat(e.getMessage(), equalTo("One of the shards is incompatible with the required minimum version [" + minVersion + "]")); } } From 68d9db633bf860bfbc21970e0f8c250fc4e4ed9b Mon Sep 17 00:00:00 2001 From: Ryan Ernst Date: Fri, 22 Nov 2024 11:44:27 -0800 Subject: [PATCH 5/9] Catch up entitlements backports (#117363) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Entitlements] Consider only system modules in the boot layer (#117017) * [Entitlements] Implement entry point definitions via checker function signature (#116754) * Policy manager for entitlements (#116695) * Add java version variants of entitlements checker (#116878) As each version of Java is released, there may be additional methods we want to instrument for entitlements. Since new methods won't exist in the base version of Java that Elasticsearch is compiled with, we need to hava different classes and compilation for each version. This commit adds a scaffolding for adding the classes for new versions of Java. Unfortunately it requires several classes in different locations. But hopefully these are infrequent enough that the boilerplate is ok. We could consider adding a helper Gradle task to templatize the new classes in the future if it is too cumbersome. Note that the example for Java23 does not have anything meaningful in it yet, it's only meant as an example until we find go through classes and methods that were added after Java 21. * Spotless --------- Co-authored-by: Lorenzo Dematté Co-authored-by: Jack Conradson Co-authored-by: Patrick Doyle --- .../gradle/internal/MrjarPlugin.java | 27 ++ .../impl/InstrumentationServiceImpl.java | 106 ++++++- .../impl/InstrumenterImpl.java | 50 ++-- .../impl/InstrumentationServiceImplTests.java | 262 ++++++++++++++++++ .../impl/InstrumenterTests.java | 215 ++++++++++++-- libs/entitlement/bridge/build.gradle | 17 +- .../bridge/EntitlementChecker.java | 2 +- .../bridge/EntitlementCheckerHandle.java | 25 +- .../entitlement/bridge/HandleLoader.java | 40 +++ .../bridge/Java23EntitlementChecker.java | 12 + .../Java23EntitlementCheckerHandle.java | 27 ++ libs/entitlement/build.gradle | 12 +- .../bootstrap/EntitlementBootstrap.java | 25 +- .../EntitlementInitialization.java | 137 ++++++++- .../instrumentation/CheckerMethod.java | 23 ++ .../InstrumentationService.java | 5 +- .../instrumentation/MethodKey.java | 7 +- .../api/ElasticsearchEntitlementChecker.java | 55 +--- .../runtime/policy/FlagEntitlementType.java | 14 + .../runtime/policy/PolicyManager.java | 116 ++++++++ ...Java23ElasticsearchEntitlementChecker.java | 26 ++ .../bootstrap/Elasticsearch.java | 21 +- .../elasticsearch/plugins/PluginsUtils.java | 4 +- 23 files changed, 1074 insertions(+), 154 deletions(-) create mode 100644 libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java create mode 100644 libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/HandleLoader.java create mode 100644 libs/entitlement/bridge/src/main23/java/org/elasticsearch/entitlement/bridge/Java23EntitlementChecker.java create mode 100644 libs/entitlement/bridge/src/main23/java/org/elasticsearch/entitlement/bridge/Java23EntitlementCheckerHandle.java create mode 100644 libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/CheckerMethod.java create mode 100644 libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/FlagEntitlementType.java create mode 100644 libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyManager.java create mode 100644 libs/entitlement/src/main23/java/org/elasticsearch/entitlement/runtime/api/Java23ElasticsearchEntitlementChecker.java diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/MrjarPlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/MrjarPlugin.java index 51d816259ccac..abf4a321cb038 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/MrjarPlugin.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/MrjarPlugin.java @@ -20,9 +20,12 @@ import org.gradle.api.plugins.JavaPluginExtension; import org.gradle.api.tasks.SourceSet; import org.gradle.api.tasks.SourceSetContainer; +import org.gradle.api.tasks.TaskProvider; import org.gradle.api.tasks.compile.CompileOptions; import org.gradle.api.tasks.compile.JavaCompile; +import org.gradle.api.tasks.javadoc.Javadoc; import org.gradle.api.tasks.testing.Test; +import org.gradle.external.javadoc.CoreJavadocOptions; import org.gradle.jvm.tasks.Jar; import org.gradle.jvm.toolchain.JavaLanguageVersion; import org.gradle.jvm.toolchain.JavaToolchainService; @@ -79,6 +82,7 @@ public void apply(Project project) { String mainSourceSetName = SourceSet.MAIN_SOURCE_SET_NAME + javaVersion; SourceSet mainSourceSet = addSourceSet(project, javaExtension, mainSourceSetName, mainSourceSets, javaVersion); configureSourceSetInJar(project, mainSourceSet, javaVersion); + addJar(project, mainSourceSet, javaVersion); mainSourceSets.add(mainSourceSetName); testSourceSets.add(mainSourceSetName); @@ -142,6 +146,29 @@ private SourceSet addSourceSet( return sourceSet; } + private void addJar(Project project, SourceSet sourceSet, int javaVersion) { + project.getConfigurations().register("java" + javaVersion); + TaskProvider jarTask = project.getTasks().register("java" + javaVersion + "Jar", Jar.class, task -> { + task.from(sourceSet.getOutput()); + }); + project.getArtifacts().add("java" + javaVersion, jarTask); + } + + private void configurePreviewFeatures(Project project, SourceSet sourceSet, int javaVersion) { + project.getTasks().withType(JavaCompile.class).named(sourceSet.getCompileJavaTaskName()).configure(compileTask -> { + CompileOptions compileOptions = compileTask.getOptions(); + compileOptions.getCompilerArgs().add("--enable-preview"); + compileOptions.getCompilerArgs().add("-Xlint:-preview"); + + compileTask.doLast(t -> { stripPreviewFromFiles(compileTask.getDestinationDirectory().getAsFile().get().toPath()); }); + }); + project.getTasks().withType(Javadoc.class).named(name -> name.equals(sourceSet.getJavadocTaskName())).configureEach(javadocTask -> { + CoreJavadocOptions options = (CoreJavadocOptions) javadocTask.getOptions(); + options.addBooleanOption("-enable-preview", true); + options.addStringOption("-release", String.valueOf(javaVersion)); + }); + } + private void configureSourceSetInJar(Project project, SourceSet sourceSet, int javaVersion) { var jarTask = project.getTasks().withType(Jar.class).named(JavaPlugin.JAR_TASK_NAME); jarTask.configure(task -> task.into("META-INF/versions/" + javaVersion, copySpec -> copySpec.from(sourceSet.getOutput()))); diff --git a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java index f5fe8d41c2243..a3bbb611f3e68 100644 --- a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java +++ b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java @@ -9,19 +9,29 @@ package org.elasticsearch.entitlement.instrumentation.impl; +import org.elasticsearch.entitlement.instrumentation.CheckerMethod; import org.elasticsearch.entitlement.instrumentation.InstrumentationService; import org.elasticsearch.entitlement.instrumentation.Instrumenter; import org.elasticsearch.entitlement.instrumentation.MethodKey; +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; import org.objectweb.asm.Type; +import java.io.IOException; import java.lang.reflect.Method; -import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.stream.Stream; public class InstrumentationServiceImpl implements InstrumentationService { + @Override - public Instrumenter newInstrumenter(String classNameSuffix, Map instrumentationMethods) { + public Instrumenter newInstrumenter(String classNameSuffix, Map instrumentationMethods) { return new InstrumenterImpl(classNameSuffix, instrumentationMethods); } @@ -33,9 +43,97 @@ public MethodKey methodKeyForTarget(Method targetMethod) { return new MethodKey( Type.getInternalName(targetMethod.getDeclaringClass()), targetMethod.getName(), - Stream.of(actualType.getArgumentTypes()).map(Type::getInternalName).toList(), - Modifier.isStatic(targetMethod.getModifiers()) + Stream.of(actualType.getArgumentTypes()).map(Type::getInternalName).toList() ); } + @Override + public Map lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException, + IOException { + var methodsToInstrument = new HashMap(); + var checkerClass = Class.forName(entitlementCheckerClassName); + var classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass); + ClassReader reader = new ClassReader(classFileInfo.bytecodes()); + ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) { + @Override + public MethodVisitor visitMethod( + int access, + String checkerMethodName, + String checkerMethodDescriptor, + String signature, + String[] exceptions + ) { + var mv = super.visitMethod(access, checkerMethodName, checkerMethodDescriptor, signature, exceptions); + + var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor); + var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes); + + var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList(); + var checkerMethod = new CheckerMethod(Type.getInternalName(checkerClass), checkerMethodName, checkerParameterDescriptors); + + methodsToInstrument.put(methodToInstrument, checkerMethod); + + return mv; + } + }; + reader.accept(visitor, 0); + return methodsToInstrument; + } + + private static final Type CLASS_TYPE = Type.getType(Class.class); + + static MethodKey parseCheckerMethodSignature(String checkerMethodName, Type[] checkerMethodArgumentTypes) { + var classNameStartIndex = checkerMethodName.indexOf('$'); + var classNameEndIndex = checkerMethodName.lastIndexOf('$'); + + if (classNameStartIndex == -1 || classNameStartIndex >= classNameEndIndex) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Checker method %s has incorrect name format. " + + "It should be either check$$methodName (instance) or check$package_ClassName$methodName (static)", + checkerMethodName + ) + ); + } + + // No "className" (check$$methodName) -> method is static, and we'll get the class from the actual typed argument + final boolean targetMethodIsStatic = classNameStartIndex + 1 != classNameEndIndex; + final String targetMethodName = checkerMethodName.substring(classNameEndIndex + 1); + + final String targetClassName; + final List targetParameterTypes; + if (targetMethodIsStatic) { + if (checkerMethodArgumentTypes.length < 1 || CLASS_TYPE.equals(checkerMethodArgumentTypes[0]) == false) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Checker method %s has incorrect argument types. " + "It must have a first argument of Class type.", + checkerMethodName + ) + ); + } + + targetClassName = checkerMethodName.substring(classNameStartIndex + 1, classNameEndIndex).replace('_', '/'); + targetParameterTypes = Arrays.stream(checkerMethodArgumentTypes).skip(1).map(Type::getInternalName).toList(); + } else { + if (checkerMethodArgumentTypes.length < 2 + || CLASS_TYPE.equals(checkerMethodArgumentTypes[0]) == false + || checkerMethodArgumentTypes[1].getSort() != Type.OBJECT) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Checker method %s has incorrect argument types. " + + "It must have a first argument of Class type, and a second argument of the class containing the method to " + + "instrument", + checkerMethodName + ) + ); + } + var targetClassType = checkerMethodArgumentTypes[1]; + targetClassName = targetClassType.getInternalName(); + targetParameterTypes = Arrays.stream(checkerMethodArgumentTypes).skip(2).map(Type::getInternalName).toList(); + } + return new MethodKey(targetClassName, targetMethodName, targetParameterTypes); + } } diff --git a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java index 63c9ccd80be70..dc20b16400f3d 100644 --- a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java +++ b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java @@ -9,6 +9,7 @@ package org.elasticsearch.entitlement.instrumentation.impl; +import org.elasticsearch.entitlement.instrumentation.CheckerMethod; import org.elasticsearch.entitlement.instrumentation.Instrumenter; import org.elasticsearch.entitlement.instrumentation.MethodKey; import org.objectweb.asm.AnnotationVisitor; @@ -23,7 +24,6 @@ import java.io.IOException; import java.io.InputStream; -import java.lang.reflect.Method; import java.util.Map; import java.util.stream.Stream; @@ -36,13 +36,29 @@ import static org.objectweb.asm.Opcodes.INVOKEVIRTUAL; public class InstrumenterImpl implements Instrumenter { + + private static final String checkerClassDescriptor; + private static final String handleClass; + static { + int javaVersion = Runtime.version().feature(); + final String classNamePrefix; + if (javaVersion >= 23) { + classNamePrefix = "Java23"; + } else { + classNamePrefix = ""; + } + String checkerClass = "org/elasticsearch/entitlement/bridge/" + classNamePrefix + "EntitlementChecker"; + handleClass = checkerClass + "Handle"; + checkerClassDescriptor = Type.getObjectType(checkerClass).getDescriptor(); + } + /** * To avoid class name collisions during testing without an agent to replace classes in-place. */ private final String classNameSuffix; - private final Map instrumentationMethods; + private final Map instrumentationMethods; - public InstrumenterImpl(String classNameSuffix, Map instrumentationMethods) { + public InstrumenterImpl(String classNameSuffix, Map instrumentationMethods) { this.classNameSuffix = classNameSuffix; this.instrumentationMethods = instrumentationMethods; } @@ -138,12 +154,7 @@ public MethodVisitor visitMethod(int access, String name, String descriptor, Str var mv = super.visitMethod(access, name, descriptor, signature, exceptions); if (isAnnotationPresent == false) { boolean isStatic = (access & ACC_STATIC) != 0; - var key = new MethodKey( - className, - name, - Stream.of(Type.getArgumentTypes(descriptor)).map(Type::getInternalName).toList(), - isStatic - ); + var key = new MethodKey(className, name, Stream.of(Type.getArgumentTypes(descriptor)).map(Type::getInternalName).toList()); var instrumentationMethod = instrumentationMethods.get(key); if (instrumentationMethod != null) { // LOGGER.debug("Will instrument method {}", key); @@ -177,7 +188,7 @@ private void addClassAnnotationIfNeeded() { class EntitlementMethodVisitor extends MethodVisitor { private final boolean instrumentedMethodIsStatic; private final String instrumentedMethodDescriptor; - private final Method instrumentationMethod; + private final CheckerMethod instrumentationMethod; private boolean hasCallerSensitiveAnnotation = false; EntitlementMethodVisitor( @@ -185,7 +196,7 @@ class EntitlementMethodVisitor extends MethodVisitor { MethodVisitor methodVisitor, boolean instrumentedMethodIsStatic, String instrumentedMethodDescriptor, - Method instrumentationMethod + CheckerMethod instrumentationMethod ) { super(api, methodVisitor); this.instrumentedMethodIsStatic = instrumentedMethodIsStatic; @@ -262,22 +273,19 @@ private void forwardIncomingArguments() { private void invokeInstrumentationMethod() { mv.visitMethodInsn( INVOKEINTERFACE, - Type.getInternalName(instrumentationMethod.getDeclaringClass()), - instrumentationMethod.getName(), - Type.getMethodDescriptor(instrumentationMethod), + instrumentationMethod.className(), + instrumentationMethod.methodName(), + Type.getMethodDescriptor( + Type.VOID_TYPE, + instrumentationMethod.parameterDescriptors().stream().map(Type::getType).toArray(Type[]::new) + ), true ); } } protected void pushEntitlementChecker(MethodVisitor mv) { - mv.visitMethodInsn( - INVOKESTATIC, - "org/elasticsearch/entitlement/bridge/EntitlementCheckerHandle", - "instance", - "()Lorg/elasticsearch/entitlement/bridge/EntitlementChecker;", - false - ); + mv.visitMethodInsn(INVOKESTATIC, handleClass, "instance", "()" + checkerClassDescriptor, false); } public record ClassFileInfo(String fileName, byte[] bytecodes) {} diff --git a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java new file mode 100644 index 0000000000000..c0ff5d59d3c72 --- /dev/null +++ b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java @@ -0,0 +1,262 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.entitlement.instrumentation.impl; + +import org.elasticsearch.entitlement.instrumentation.CheckerMethod; +import org.elasticsearch.entitlement.instrumentation.InstrumentationService; +import org.elasticsearch.entitlement.instrumentation.MethodKey; +import org.elasticsearch.test.ESTestCase; +import org.objectweb.asm.Type; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; + +@ESTestCase.WithoutSecurityManager +public class InstrumentationServiceImplTests extends ESTestCase { + + final InstrumentationService instrumentationService = new InstrumentationServiceImpl(); + + static class TestTargetClass {} + + interface TestChecker { + void check$org_example_TestTargetClass$staticMethod(Class clazz, int arg0, String arg1, Object arg2); + + void check$$instanceMethodNoArgs(Class clazz, TestTargetClass that); + + void check$$instanceMethodWithArgs(Class clazz, TestTargetClass that, int x, int y); + } + + interface TestCheckerOverloads { + void check$org_example_TestTargetClass$staticMethodWithOverload(Class clazz, int x, int y); + + void check$org_example_TestTargetClass$staticMethodWithOverload(Class clazz, int x, String y); + } + + public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundException { + Map methodsMap = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName()); + + assertThat(methodsMap, aMapWithSize(3)); + assertThat( + methodsMap, + hasEntry( + equalTo(new MethodKey("org/example/TestTargetClass", "staticMethod", List.of("I", "java/lang/String", "java/lang/Object"))), + equalTo( + new CheckerMethod( + "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestChecker", + "check$org_example_TestTargetClass$staticMethod", + List.of("Ljava/lang/Class;", "I", "Ljava/lang/String;", "Ljava/lang/Object;") + ) + ) + ) + ); + assertThat( + methodsMap, + hasEntry( + equalTo( + new MethodKey( + "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetClass", + "instanceMethodNoArgs", + List.of() + ) + ), + equalTo( + new CheckerMethod( + "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestChecker", + "check$$instanceMethodNoArgs", + List.of( + "Ljava/lang/Class;", + "Lorg/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetClass;" + ) + ) + ) + ) + ); + assertThat( + methodsMap, + hasEntry( + equalTo( + new MethodKey( + "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetClass", + "instanceMethodWithArgs", + List.of("I", "I") + ) + ), + equalTo( + new CheckerMethod( + "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestChecker", + "check$$instanceMethodWithArgs", + List.of( + "Ljava/lang/Class;", + "Lorg/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetClass;", + "I", + "I" + ) + ) + ) + ) + ); + } + + public void testInstrumentationTargetLookupWithOverloads() throws IOException, ClassNotFoundException { + Map methodsMap = instrumentationService.lookupMethodsToInstrument(TestCheckerOverloads.class.getName()); + + assertThat(methodsMap, aMapWithSize(2)); + assertThat( + methodsMap, + hasEntry( + equalTo(new MethodKey("org/example/TestTargetClass", "staticMethodWithOverload", List.of("I", "java/lang/String"))), + equalTo( + new CheckerMethod( + "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerOverloads", + "check$org_example_TestTargetClass$staticMethodWithOverload", + List.of("Ljava/lang/Class;", "I", "Ljava/lang/String;") + ) + ) + ) + ); + assertThat( + methodsMap, + hasEntry( + equalTo(new MethodKey("org/example/TestTargetClass", "staticMethodWithOverload", List.of("I", "I"))), + equalTo( + new CheckerMethod( + "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerOverloads", + "check$org_example_TestTargetClass$staticMethodWithOverload", + List.of("Ljava/lang/Class;", "I", "I") + ) + ) + ) + ); + } + + public void testParseCheckerMethodSignatureStaticMethod() { + var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature( + "check$org_example_TestClass$staticMethod", + new Type[] { Type.getType(Class.class) } + ); + + assertThat(methodKey, equalTo(new MethodKey("org/example/TestClass", "staticMethod", List.of()))); + } + + public void testParseCheckerMethodSignatureStaticMethodWithArgs() { + var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature( + "check$org_example_TestClass$staticMethod", + new Type[] { Type.getType(Class.class), Type.getType("I"), Type.getType(String.class) } + ); + + assertThat(methodKey, equalTo(new MethodKey("org/example/TestClass", "staticMethod", List.of("I", "java/lang/String")))); + } + + public void testParseCheckerMethodSignatureStaticMethodInnerClass() { + var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature( + "check$org_example_TestClass$InnerClass$staticMethod", + new Type[] { Type.getType(Class.class) } + ); + + assertThat(methodKey, equalTo(new MethodKey("org/example/TestClass$InnerClass", "staticMethod", List.of()))); + } + + public void testParseCheckerMethodSignatureIncorrectName() { + var exception = assertThrows( + IllegalArgumentException.class, + () -> InstrumentationServiceImpl.parseCheckerMethodSignature("check$staticMethod", new Type[] { Type.getType(Class.class) }) + ); + + assertThat(exception.getMessage(), containsString("has incorrect name format")); + } + + public void testParseCheckerMethodSignatureStaticMethodIncorrectArgumentCount() { + var exception = assertThrows( + IllegalArgumentException.class, + () -> InstrumentationServiceImpl.parseCheckerMethodSignature("check$ClassName$staticMethod", new Type[] {}) + ); + assertThat(exception.getMessage(), containsString("It must have a first argument of Class type")); + } + + public void testParseCheckerMethodSignatureStaticMethodIncorrectArgumentType() { + var exception = assertThrows( + IllegalArgumentException.class, + () -> InstrumentationServiceImpl.parseCheckerMethodSignature( + "check$ClassName$staticMethod", + new Type[] { Type.getType(String.class) } + ) + ); + assertThat(exception.getMessage(), containsString("It must have a first argument of Class type")); + } + + public void testParseCheckerMethodSignatureInstanceMethod() { + var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature( + "check$$instanceMethod", + new Type[] { Type.getType(Class.class), Type.getType(TestTargetClass.class) } + ); + + assertThat( + methodKey, + equalTo( + new MethodKey( + "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetClass", + "instanceMethod", + List.of() + ) + ) + ); + } + + public void testParseCheckerMethodSignatureInstanceMethodWithArgs() { + var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature( + "check$$instanceMethod", + new Type[] { Type.getType(Class.class), Type.getType(TestTargetClass.class), Type.getType("I"), Type.getType(String.class) } + ); + + assertThat( + methodKey, + equalTo( + new MethodKey( + "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetClass", + "instanceMethod", + List.of("I", "java/lang/String") + ) + ) + ); + } + + public void testParseCheckerMethodSignatureInstanceMethodIncorrectArgumentTypes() { + var exception = assertThrows( + IllegalArgumentException.class, + () -> InstrumentationServiceImpl.parseCheckerMethodSignature("check$$instanceMethod", new Type[] { Type.getType(String.class) }) + ); + assertThat(exception.getMessage(), containsString("It must have a first argument of Class type")); + } + + public void testParseCheckerMethodSignatureInstanceMethodIncorrectArgumentCount() { + var exception = assertThrows( + IllegalArgumentException.class, + () -> InstrumentationServiceImpl.parseCheckerMethodSignature("check$$instanceMethod", new Type[] { Type.getType(Class.class) }) + ); + assertThat(exception.getMessage(), containsString("a second argument of the class containing the method to instrument")); + } + + public void testParseCheckerMethodSignatureInstanceMethodIncorrectArgumentTypes2() { + var exception = assertThrows( + IllegalArgumentException.class, + () -> InstrumentationServiceImpl.parseCheckerMethodSignature( + "check$$instanceMethod", + new Type[] { Type.getType(Class.class), Type.getType("I") } + ) + ); + assertThat(exception.getMessage(), containsString("a second argument of the class containing the method to instrument")); + } +} diff --git a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java index 9a57e199d4907..e3f5539999be5 100644 --- a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java +++ b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java @@ -11,7 +11,9 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.entitlement.bridge.EntitlementChecker; +import org.elasticsearch.entitlement.instrumentation.CheckerMethod; import org.elasticsearch.entitlement.instrumentation.InstrumentationService; +import org.elasticsearch.entitlement.instrumentation.MethodKey; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.test.ESTestCase; @@ -22,11 +24,12 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.Arrays; -import java.util.stream.Collectors; +import java.util.Map; import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text; import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.startsWith; import static org.objectweb.asm.Opcodes.INVOKESTATIC; /** @@ -53,7 +56,12 @@ public void initialize() { * Contains all the virtual methods from {@link ClassToInstrument}, * allowing this test to call them on the dynamically loaded instrumented class. */ - public interface Testable {} + public interface Testable { + // This method is here to demonstrate Instrumenter does not get confused by overloads + void someMethod(int arg); + + void someMethod(int arg, String anotherArg); + } /** * This is a placeholder for real class library methods. @@ -71,10 +79,26 @@ public static void systemExit(int status) { public static void anotherSystemExit(int status) { assertEquals(123, status); } + + public void someMethod(int arg) {} + + public void someMethod(int arg, String anotherArg) {} + + public static void someStaticMethod(int arg) {} + + public static void someStaticMethod(int arg, String anotherArg) {} } static final class TestException extends RuntimeException {} + public interface MockEntitlementChecker extends EntitlementChecker { + void checkSomeStaticMethod(Class clazz, int arg); + + void checkSomeStaticMethod(Class clazz, int arg, String anotherArg); + + void checkSomeInstanceMethod(Class clazz, Testable that, int arg, String anotherArg); + } + /** * We're not testing the permission checking logic here; * only that the instrumented methods are calling the correct check methods with the correct arguments. @@ -82,7 +106,7 @@ static final class TestException extends RuntimeException {} * just to demonstrate that the injected bytecodes succeed in calling these methods. * It also asserts that the arguments are correct. */ - public static class TestEntitlementChecker implements EntitlementChecker { + public static class TestEntitlementChecker implements MockEntitlementChecker { /** * This allows us to test that the instrumentation is correct in both cases: * if the check throws, and if it doesn't. @@ -90,9 +114,12 @@ public static class TestEntitlementChecker implements EntitlementChecker { volatile boolean isActive; int checkSystemExitCallCount = 0; + int checkSomeStaticMethodIntCallCount = 0; + int checkSomeStaticMethodIntStringCallCount = 0; + int checkSomeInstanceMethodCallCount = 0; @Override - public void checkSystemExit(Class callerClass, int status) { + public void check$java_lang_System$exit(Class callerClass, int status) { checkSystemExitCallCount++; assertSame(InstrumenterTests.class, callerClass); assertEquals(123, status); @@ -104,11 +131,48 @@ private void throwIfActive() { throw new TestException(); } } + + @Override + public void checkSomeStaticMethod(Class callerClass, int arg) { + checkSomeStaticMethodIntCallCount++; + assertSame(InstrumenterTests.class, callerClass); + assertEquals(123, arg); + throwIfActive(); + } + + @Override + public void checkSomeStaticMethod(Class callerClass, int arg, String anotherArg) { + checkSomeStaticMethodIntStringCallCount++; + assertSame(InstrumenterTests.class, callerClass); + assertEquals(123, arg); + assertEquals("abc", anotherArg); + throwIfActive(); + } + + @Override + public void checkSomeInstanceMethod(Class callerClass, Testable that, int arg, String anotherArg) { + checkSomeInstanceMethodCallCount++; + assertSame(InstrumenterTests.class, callerClass); + assertThat( + that.getClass().getName(), + startsWith("org.elasticsearch.entitlement.instrumentation.impl.InstrumenterTests$ClassToInstrument") + ); + assertEquals(123, arg); + assertEquals("def", anotherArg); + throwIfActive(); + } } public void testClassIsInstrumented() throws Exception { var classToInstrument = ClassToInstrument.class; - var instrumenter = createInstrumenter(classToInstrument, "systemExit"); + + CheckerMethod checkerMethod = getCheckerMethod(EntitlementChecker.class, "check$java_lang_System$exit", Class.class, int.class); + Map methods = Map.of( + instrumentationService.methodKeyForTarget(classToInstrument.getMethod("systemExit", int.class)), + checkerMethod + ); + + var instrumenter = createInstrumenter(methods); byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes(); @@ -117,7 +181,7 @@ public void testClassIsInstrumented() throws Exception { } Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( - ClassToInstrument.class.getName() + "_NEW", + classToInstrument.getName() + "_NEW", newBytecode ); @@ -134,7 +198,14 @@ public void testClassIsInstrumented() throws Exception { public void testClassIsNotInstrumentedTwice() throws Exception { var classToInstrument = ClassToInstrument.class; - var instrumenter = createInstrumenter(classToInstrument, "systemExit"); + + CheckerMethod checkerMethod = getCheckerMethod(EntitlementChecker.class, "check$java_lang_System$exit", Class.class, int.class); + Map methods = Map.of( + instrumentationService.methodKeyForTarget(classToInstrument.getMethod("systemExit", int.class)), + checkerMethod + ); + + var instrumenter = createInstrumenter(methods); InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument); var internalClassName = Type.getInternalName(classToInstrument); @@ -146,7 +217,7 @@ public void testClassIsNotInstrumentedTwice() throws Exception { logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode))); Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( - ClassToInstrument.class.getName() + "_NEW_NEW", + classToInstrument.getName() + "_NEW_NEW", instrumentedTwiceBytecode ); @@ -159,7 +230,16 @@ public void testClassIsNotInstrumentedTwice() throws Exception { public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception { var classToInstrument = ClassToInstrument.class; - var instrumenter = createInstrumenter(classToInstrument, "systemExit", "anotherSystemExit"); + + CheckerMethod checkerMethod = getCheckerMethod(EntitlementChecker.class, "check$java_lang_System$exit", Class.class, int.class); + Map methods = Map.of( + instrumentationService.methodKeyForTarget(classToInstrument.getMethod("systemExit", int.class)), + checkerMethod, + instrumentationService.methodKeyForTarget(classToInstrument.getMethod("anotherSystemExit", int.class)), + checkerMethod + ); + + var instrumenter = createInstrumenter(methods); InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument); var internalClassName = Type.getInternalName(classToInstrument); @@ -171,7 +251,7 @@ public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception { logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode))); Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( - ClassToInstrument.class.getName() + "_NEW_NEW", + classToInstrument.getName() + "_NEW_NEW", instrumentedTwiceBytecode ); @@ -185,22 +265,78 @@ public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception { assertThat(getTestEntitlementChecker().checkSystemExitCallCount, is(2)); } - /** This test doesn't replace ClassToInstrument in-place but instead loads a separate - * class ClassToInstrument_NEW that contains the instrumentation. Because of this, - * we need to configure the Transformer to use a MethodKey and instrumentationMethod - * with slightly different signatures (using the common interface Testable) which - * is not what would happen when it's run by the agent. - */ - private InstrumenterImpl createInstrumenter(Class classToInstrument, String... methodNames) throws NoSuchMethodException { - Method v1 = EntitlementChecker.class.getMethod("checkSystemExit", Class.class, int.class); - var methods = Arrays.stream(methodNames).map(name -> { - try { - return instrumentationService.methodKeyForTarget(classToInstrument.getMethod(name, int.class)); - } catch (NoSuchMethodException e) { - throw new RuntimeException(e); - } - }).collect(Collectors.toUnmodifiableMap(name -> name, name -> v1)); + public void testInstrumenterWorksWithOverloads() throws Exception { + var classToInstrument = ClassToInstrument.class; + + Map methods = Map.of( + instrumentationService.methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)), + getCheckerMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class), + instrumentationService.methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class, String.class)), + getCheckerMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class, String.class) + ); + + var instrumenter = createInstrumenter(methods); + + byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes(); + if (logger.isTraceEnabled()) { + logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode)); + } + + Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( + classToInstrument.getName() + "_NEW", + newBytecode + ); + + getTestEntitlementChecker().isActive = true; + + // After checking is activated, everything should throw + assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123)); + assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123, "abc")); + + assertThat(getTestEntitlementChecker().checkSomeStaticMethodIntCallCount, is(1)); + assertThat(getTestEntitlementChecker().checkSomeStaticMethodIntStringCallCount, is(1)); + } + + public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Exception { + var classToInstrument = ClassToInstrument.class; + + Map methods = Map.of( + instrumentationService.methodKeyForTarget(classToInstrument.getMethod("someMethod", int.class, String.class)), + getCheckerMethod(MockEntitlementChecker.class, "checkSomeInstanceMethod", Class.class, Testable.class, int.class, String.class) + ); + + var instrumenter = createInstrumenter(methods); + + byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes(); + + if (logger.isTraceEnabled()) { + logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode)); + } + + Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( + classToInstrument.getName() + "_NEW", + newBytecode + ); + + getTestEntitlementChecker().isActive = true; + + Testable testTargetClass = (Testable) (newClass.getConstructor().newInstance()); + + // This overload is not instrumented, so it will not throw + testTargetClass.someMethod(123); + assertThrows(TestException.class, () -> testTargetClass.someMethod(123, "def")); + + assertThat(getTestEntitlementChecker().checkSomeInstanceMethodCallCount, is(1)); + } + + /** This test doesn't replace classToInstrument in-place but instead loads a separate + * class with the same class name plus a "_NEW" suffix (classToInstrument.class.getName() + "_NEW") + * that contains the instrumentation. Because of this, we need to configure the Transformer to use a + * MethodKey and instrumentationMethod with slightly different signatures (using the common interface + * Testable) which is not what would happen when it's run by the agent. + */ + private InstrumenterImpl createInstrumenter(Map methods) throws NoSuchMethodException { Method getter = InstrumenterTests.class.getMethod("getTestEntitlementChecker"); return new InstrumenterImpl("_NEW", methods) { /** @@ -220,13 +356,38 @@ protected void pushEntitlementChecker(MethodVisitor mv) { }; } + private static CheckerMethod getCheckerMethod(Class clazz, String methodName, Class... parameterTypes) + throws NoSuchMethodException { + var method = clazz.getMethod(methodName, parameterTypes); + return new CheckerMethod( + Type.getInternalName(clazz), + method.getName(), + Arrays.stream(Type.getArgumentTypes(method)).map(Type::getDescriptor).toList() + ); + } + /** * Calling a static method of a dynamically loaded class is significantly more cumbersome * than calling a virtual method. */ - private static void callStaticMethod(Class c, String methodName, int status) throws NoSuchMethodException, IllegalAccessException { + private static void callStaticMethod(Class c, String methodName, int arg) throws NoSuchMethodException, IllegalAccessException { + try { + c.getMethod(methodName, int.class).invoke(null, arg); + } catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof TestException n) { + // Sometimes we're expecting this one! + throw n; + } else { + throw new AssertionError(cause); + } + } + } + + private static void callStaticMethod(Class c, String methodName, int arg1, String arg2) throws NoSuchMethodException, + IllegalAccessException { try { - c.getMethod(methodName, int.class).invoke(null, status); + c.getMethod(methodName, int.class, String.class).invoke(null, arg1, arg2); } catch (InvocationTargetException e) { Throwable cause = e.getCause(); if (cause instanceof TestException n) { diff --git a/libs/entitlement/bridge/build.gradle b/libs/entitlement/bridge/build.gradle index 3d59dd3eaf33e..a9f8f6e3a3b0a 100644 --- a/libs/entitlement/bridge/build.gradle +++ b/libs/entitlement/bridge/build.gradle @@ -7,19 +7,18 @@ * License v3.0 only", or the "Server Side Public License, v 1". */ +import org.elasticsearch.gradle.internal.precommit.CheckForbiddenApisTask + apply plugin: 'elasticsearch.build' +apply plugin: 'elasticsearch.mrjar' -configurations { - bridgeJar { - canBeConsumed = true - canBeResolved = false +tasks.named('jar').configure { + // guarding for intellij + if (sourceSets.findByName("main23")) { + from sourceSets.main23.output } } -artifacts { - bridgeJar(jar) -} - -tasks.named('forbiddenApisMain').configure { +tasks.withType(CheckForbiddenApisTask).configureEach { replaceSignatureFiles 'jdk-signatures' } diff --git a/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementChecker.java b/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementChecker.java index 5ebb7d00e26f5..167c93c90df5c 100644 --- a/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementChecker.java +++ b/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementChecker.java @@ -10,5 +10,5 @@ package org.elasticsearch.entitlement.bridge; public interface EntitlementChecker { - void checkSystemExit(Class callerClass, int status); + void check$java_lang_System$exit(Class callerClass, int status); } diff --git a/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementCheckerHandle.java b/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementCheckerHandle.java index 2fe4a163a4136..26c9c83b8eb51 100644 --- a/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementCheckerHandle.java +++ b/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementCheckerHandle.java @@ -9,9 +9,6 @@ package org.elasticsearch.entitlement.bridge; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; - /** * Makes the {@link EntitlementChecker} available to injected bytecode. */ @@ -35,27 +32,7 @@ private static class Holder { * The {@code EntitlementInitialization} class is what actually instantiates it and makes it available; * here, we copy it into a static final variable for maximum performance. */ - private static final EntitlementChecker instance; - static { - String initClazz = "org.elasticsearch.entitlement.initialization.EntitlementInitialization"; - final Class clazz; - try { - clazz = ClassLoader.getSystemClassLoader().loadClass(initClazz); - } catch (ClassNotFoundException e) { - throw new AssertionError("java.base cannot find entitlement initialziation", e); - } - final Method checkerMethod; - try { - checkerMethod = clazz.getMethod("checker"); - } catch (NoSuchMethodException e) { - throw new AssertionError("EntitlementInitialization is missing checker() method", e); - } - try { - instance = (EntitlementChecker) checkerMethod.invoke(null); - } catch (IllegalAccessException | InvocationTargetException e) { - throw new AssertionError(e); - } - } + private static final EntitlementChecker instance = HandleLoader.load(EntitlementChecker.class); } // no construction diff --git a/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/HandleLoader.java b/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/HandleLoader.java new file mode 100644 index 0000000000000..bbfec47884f79 --- /dev/null +++ b/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/HandleLoader.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.entitlement.bridge; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +class HandleLoader { + + static T load(Class checkerClass) { + String initClassName = "org.elasticsearch.entitlement.initialization.EntitlementInitialization"; + final Class initClazz; + try { + initClazz = ClassLoader.getSystemClassLoader().loadClass(initClassName); + } catch (ClassNotFoundException e) { + throw new AssertionError("java.base cannot find entitlement initialization", e); + } + final Method checkerMethod; + try { + checkerMethod = initClazz.getMethod("checker"); + } catch (NoSuchMethodException e) { + throw new AssertionError("EntitlementInitialization is missing checker() method", e); + } + try { + return checkerClass.cast(checkerMethod.invoke(null)); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new AssertionError(e); + } + } + + // no instance + private HandleLoader() {} +} diff --git a/libs/entitlement/bridge/src/main23/java/org/elasticsearch/entitlement/bridge/Java23EntitlementChecker.java b/libs/entitlement/bridge/src/main23/java/org/elasticsearch/entitlement/bridge/Java23EntitlementChecker.java new file mode 100644 index 0000000000000..244632e80ffa0 --- /dev/null +++ b/libs/entitlement/bridge/src/main23/java/org/elasticsearch/entitlement/bridge/Java23EntitlementChecker.java @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.entitlement.bridge; + +public interface Java23EntitlementChecker extends EntitlementChecker {} diff --git a/libs/entitlement/bridge/src/main23/java/org/elasticsearch/entitlement/bridge/Java23EntitlementCheckerHandle.java b/libs/entitlement/bridge/src/main23/java/org/elasticsearch/entitlement/bridge/Java23EntitlementCheckerHandle.java new file mode 100644 index 0000000000000..f41c5dcdf14fd --- /dev/null +++ b/libs/entitlement/bridge/src/main23/java/org/elasticsearch/entitlement/bridge/Java23EntitlementCheckerHandle.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.entitlement.bridge; + +/** + * Java23 variant of {@link EntitlementChecker} handle holder. + */ +public class Java23EntitlementCheckerHandle { + + public static Java23EntitlementChecker instance() { + return Holder.instance; + } + + private static class Holder { + private static final Java23EntitlementChecker instance = HandleLoader.load(Java23EntitlementChecker.class); + } + + // no construction + private Java23EntitlementCheckerHandle() {} +} diff --git a/libs/entitlement/build.gradle b/libs/entitlement/build.gradle index 12e0bb48a54b7..841591873153c 100644 --- a/libs/entitlement/build.gradle +++ b/libs/entitlement/build.gradle @@ -6,10 +6,13 @@ * your election, the "Elastic License 2.0", the "GNU Affero General Public * License v3.0 only", or the "Server Side Public License, v 1". */ + +import org.elasticsearch.gradle.internal.precommit.CheckForbiddenApisTask + apply plugin: 'elasticsearch.build' apply plugin: 'elasticsearch.publish' - apply plugin: 'elasticsearch.embedded-providers' +apply plugin: 'elasticsearch.mrjar' embeddedProviders { impl 'entitlement', project(':libs:entitlement:asm-provider') @@ -23,8 +26,13 @@ dependencies { testImplementation(project(":test:framework")) { exclude group: 'org.elasticsearch', module: 'entitlement' } + + // guarding for intellij + if (sourceSets.findByName("main23")) { + main23CompileOnly project(path: ':libs:entitlement:bridge', configuration: 'java23') + } } -tasks.named('forbiddenApisMain').configure { +tasks.withType(CheckForbiddenApisTask).configureEach { replaceSignatureFiles 'jdk-signatures' } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/bootstrap/EntitlementBootstrap.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/bootstrap/EntitlementBootstrap.java index 7f68457baea9e..01b8f4d574f90 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/bootstrap/EntitlementBootstrap.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/bootstrap/EntitlementBootstrap.java @@ -15,6 +15,7 @@ import com.sun.tools.attach.VirtualMachine; import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.core.Tuple; import org.elasticsearch.entitlement.initialization.EntitlementInitialization; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; @@ -22,15 +23,33 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; +import java.util.Collection; +import java.util.Objects; +import java.util.function.Function; public class EntitlementBootstrap { + public record BootstrapArgs(Collection> pluginData, Function, String> pluginResolver) {} + + private static BootstrapArgs bootstrapArgs; + + public static BootstrapArgs bootstrapArgs() { + return bootstrapArgs; + } + /** - * Activates entitlement checking. Once this method returns, calls to forbidden methods - * will throw {@link org.elasticsearch.entitlement.runtime.api.NotEntitledException}. + * Activates entitlement checking. Once this method returns, calls to methods protected by Entitlements from classes without a valid + * policy will throw {@link org.elasticsearch.entitlement.runtime.api.NotEntitledException}. + * @param pluginData a collection of (plugin path, boolean), that holds the paths of all the installed Elasticsearch modules and + * plugins, and whether they are Java modular or not. + * @param pluginResolver a functor to map a Java Class to the plugin it belongs to (the plugin name). */ - public static void bootstrap() { + public static void bootstrap(Collection> pluginData, Function, String> pluginResolver) { logger.debug("Loading entitlement agent"); + if (EntitlementBootstrap.bootstrapArgs != null) { + throw new IllegalStateException("plugin data is already set"); + } + EntitlementBootstrap.bootstrapArgs = new BootstrapArgs(Objects.requireNonNull(pluginData), Objects.requireNonNull(pluginResolver)); exportInitializationToAgent(); loadAgent(findAgentJar()); } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java index 155d5a27c606b..ca57e7b255bca 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java @@ -9,17 +9,37 @@ package org.elasticsearch.entitlement.initialization; +import org.elasticsearch.core.Tuple; import org.elasticsearch.core.internal.provider.ProviderLocator; +import org.elasticsearch.entitlement.bootstrap.EntitlementBootstrap; import org.elasticsearch.entitlement.bridge.EntitlementChecker; +import org.elasticsearch.entitlement.instrumentation.CheckerMethod; import org.elasticsearch.entitlement.instrumentation.InstrumentationService; import org.elasticsearch.entitlement.instrumentation.MethodKey; import org.elasticsearch.entitlement.instrumentation.Transformer; import org.elasticsearch.entitlement.runtime.api.ElasticsearchEntitlementChecker; +import org.elasticsearch.entitlement.runtime.policy.Policy; +import org.elasticsearch.entitlement.runtime.policy.PolicyManager; +import org.elasticsearch.entitlement.runtime.policy.PolicyParser; +import org.elasticsearch.entitlement.runtime.policy.Scope; +import java.io.IOException; import java.lang.instrument.Instrumentation; -import java.lang.reflect.Method; +import java.lang.module.ModuleFinder; +import java.lang.module.ModuleReference; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; + +import static org.elasticsearch.entitlement.runtime.policy.PolicyManager.ALL_UNNAMED; /** * Called by the agent during {@code agentmain} to configure the entitlement system, @@ -29,6 +49,9 @@ * to begin injecting our instrumentation. */ public class EntitlementInitialization { + + private static final String POLICY_FILE_NAME = "entitlement-policy.yaml"; + private static ElasticsearchEntitlementChecker manager; // Note: referenced by bridge reflectively @@ -38,16 +61,112 @@ public static EntitlementChecker checker() { // Note: referenced by agent reflectively public static void initialize(Instrumentation inst) throws Exception { - manager = new ElasticsearchEntitlementChecker(); + manager = initChecker(); + + Map methodMap = INSTRUMENTER_FACTORY.lookupMethodsToInstrument( + "org.elasticsearch.entitlement.bridge.EntitlementChecker" + ); + + var classesToTransform = methodMap.keySet().stream().map(MethodKey::className).collect(Collectors.toSet()); + + inst.addTransformer(new Transformer(INSTRUMENTER_FACTORY.newInstrumenter("", methodMap), classesToTransform), true); + // TODO: should we limit this array somehow? + var classesToRetransform = classesToTransform.stream().map(EntitlementInitialization::internalNameToClass).toArray(Class[]::new); + inst.retransformClasses(classesToRetransform); + } + + private static Class internalNameToClass(String internalName) { + try { + return Class.forName(internalName.replace('/', '.'), false, ClassLoader.getPlatformClassLoader()); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + private static PolicyManager createPolicyManager() throws IOException { + Map pluginPolicies = createPluginPolicies(EntitlementBootstrap.bootstrapArgs().pluginData()); + + // TODO: What should the name be? + // TODO(ES-10031): Decide what goes in the elasticsearch default policy and extend it + var serverPolicy = new Policy("server", List.of()); + return new PolicyManager(serverPolicy, pluginPolicies, EntitlementBootstrap.bootstrapArgs().pluginResolver()); + } + + private static Map createPluginPolicies(Collection> pluginData) throws IOException { + Map pluginPolicies = new HashMap<>(pluginData.size()); + for (Tuple entry : pluginData) { + Path pluginRoot = entry.v1(); + boolean isModular = entry.v2(); + + String pluginName = pluginRoot.getFileName().toString(); + final Policy policy = loadPluginPolicy(pluginRoot, isModular, pluginName); + + pluginPolicies.put(pluginName, policy); + } + return pluginPolicies; + } + + private static Policy loadPluginPolicy(Path pluginRoot, boolean isModular, String pluginName) throws IOException { + Path policyFile = pluginRoot.resolve(POLICY_FILE_NAME); + + final Set moduleNames = getModuleNames(pluginRoot, isModular); + final Policy policy = parsePolicyIfExists(pluginName, policyFile); + + // TODO: should this check actually be part of the parser? + for (Scope scope : policy.scopes) { + if (moduleNames.contains(scope.name) == false) { + throw new IllegalStateException("policy [" + policyFile + "] contains invalid module [" + scope.name + "]"); + } + } + return policy; + } + + private static Policy parsePolicyIfExists(String pluginName, Path policyFile) throws IOException { + if (Files.exists(policyFile)) { + return new PolicyParser(Files.newInputStream(policyFile, StandardOpenOption.READ), pluginName).parsePolicy(); + } + return new Policy(pluginName, List.of()); + } + + private static Set getModuleNames(Path pluginRoot, boolean isModular) { + if (isModular) { + ModuleFinder moduleFinder = ModuleFinder.of(pluginRoot); + Set moduleReferences = moduleFinder.findAll(); + + return moduleReferences.stream().map(mr -> mr.descriptor().name()).collect(Collectors.toUnmodifiableSet()); + } + // When isModular == false we use the same "ALL-UNNAMED" constant as the JDK to indicate (any) unnamed module for this plugin + return Set.of(ALL_UNNAMED); + } - // TODO: Configure actual entitlement grants instead of this hardcoded one - Method targetMethod = System.class.getMethod("exit", int.class); - Method instrumentationMethod = Class.forName("org.elasticsearch.entitlement.bridge.EntitlementChecker") - .getMethod("checkSystemExit", Class.class, int.class); - Map methodMap = Map.of(INSTRUMENTER_FACTORY.methodKeyForTarget(targetMethod), instrumentationMethod); + private static ElasticsearchEntitlementChecker initChecker() throws IOException { + final PolicyManager policyManager = createPolicyManager(); - inst.addTransformer(new Transformer(INSTRUMENTER_FACTORY.newInstrumenter("", methodMap), Set.of(internalName(System.class))), true); - inst.retransformClasses(System.class); + int javaVersion = Runtime.version().feature(); + final String classNamePrefix; + if (javaVersion >= 23) { + classNamePrefix = "Java23"; + } else { + classNamePrefix = ""; + } + final String className = "org.elasticsearch.entitlement.runtime.api." + classNamePrefix + "ElasticsearchEntitlementChecker"; + Class clazz; + try { + clazz = Class.forName(className); + } catch (ClassNotFoundException e) { + throw new AssertionError("entitlement lib cannot find entitlement impl", e); + } + Constructor constructor; + try { + constructor = clazz.getConstructor(PolicyManager.class); + } catch (NoSuchMethodException e) { + throw new AssertionError("entitlement impl is missing no arg constructor", e); + } + try { + return (ElasticsearchEntitlementChecker) constructor.newInstance(policyManager); + } catch (IllegalAccessException | InvocationTargetException | InstantiationException e) { + throw new AssertionError(e); + } } private static String internalName(Class c) { diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/CheckerMethod.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/CheckerMethod.java new file mode 100644 index 0000000000000..c20a75a61a608 --- /dev/null +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/CheckerMethod.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.entitlement.instrumentation; + +import java.util.List; + +/** + * A structure to use as a representation of the checker method the instrumentation will inject. + * + * @param className the "internal name" of the class: includes the package info, but with periods replaced by slashes + * @param methodName the checker method name + * @param parameterDescriptors a list of + * type descriptors) + * for methodName parameters. + */ +public record CheckerMethod(String className, String methodName, List parameterDescriptors) {} diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java index 25fa84ec7c4ba..12316bfb043c5 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java @@ -9,6 +9,7 @@ package org.elasticsearch.entitlement.instrumentation; +import java.io.IOException; import java.lang.reflect.Method; import java.util.Map; @@ -16,10 +17,12 @@ * The SPI service entry point for instrumentation. */ public interface InstrumentationService { - Instrumenter newInstrumenter(String classNameSuffix, Map instrumentationMethods); + Instrumenter newInstrumenter(String classNameSuffix, Map instrumentationMethods); /** * @return a {@link MethodKey} suitable for looking up the given {@code targetMethod} in the entitlements trampoline */ MethodKey methodKeyForTarget(Method targetMethod); + + Map lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException, IOException; } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/MethodKey.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/MethodKey.java index 54e09c10bcc57..256a4d709d9dc 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/MethodKey.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/MethodKey.java @@ -12,7 +12,10 @@ import java.util.List; /** + * A structure to use as a key/lookup for a method target of instrumentation * - * @param className the "internal name" of the class: includes the package info, but with periods replaced by slashes + * @param className the "internal name" of the class: includes the package info, but with periods replaced by slashes + * @param methodName the method name + * @param parameterTypes a list of "internal names" for the parameter types */ -public record MethodKey(String className, String methodName, List parameterTypes, boolean isStatic) {} +public record MethodKey(String className, String methodName, List parameterTypes) {} diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/api/ElasticsearchEntitlementChecker.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/api/ElasticsearchEntitlementChecker.java index 6d5dbd4098aa9..790416ca5659a 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/api/ElasticsearchEntitlementChecker.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/api/ElasticsearchEntitlementChecker.java @@ -10,10 +10,8 @@ package org.elasticsearch.entitlement.runtime.api; import org.elasticsearch.entitlement.bridge.EntitlementChecker; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; - -import java.util.Optional; +import org.elasticsearch.entitlement.runtime.policy.FlagEntitlementType; +import org.elasticsearch.entitlement.runtime.policy.PolicyManager; /** * Implementation of the {@link EntitlementChecker} interface, providing additional @@ -21,51 +19,14 @@ * The trampoline module loads this object via SPI. */ public class ElasticsearchEntitlementChecker implements EntitlementChecker { - private static final Logger logger = LogManager.getLogger(ElasticsearchEntitlementChecker.class); + private final PolicyManager policyManager; - @Override - public void checkSystemExit(Class callerClass, int status) { - var requestingModule = requestingModule(callerClass); - if (isTriviallyAllowed(requestingModule)) { - return; - } - // Hard-forbidden until we develop the permission granting scheme - throw new NotEntitledException("Missing entitlement for " + requestingModule); + public ElasticsearchEntitlementChecker(PolicyManager policyManager) { + this.policyManager = policyManager; } - private static Module requestingModule(Class callerClass) { - if (callerClass != null) { - Module callerModule = callerClass.getModule(); - if (callerModule.getLayer() != ModuleLayer.boot()) { - // fast path - return callerModule; - } - } - int framesToSkip = 1 // getCallingClass (this method) - + 1 // the checkXxx method - + 1 // the runtime config method - + 1 // the instrumented method - ; - Optional module = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE) - .walk( - s -> s.skip(framesToSkip) - .map(f -> f.getDeclaringClass().getModule()) - .filter(m -> m.getLayer() != ModuleLayer.boot()) - .findFirst() - ); - return module.orElse(null); - } - - private static boolean isTriviallyAllowed(Module requestingModule) { - if (requestingModule == null) { - logger.debug("Trivially allowed: Entire call stack is in the boot module layer"); - return true; - } - if (requestingModule == System.class.getModule()) { - logger.debug("Trivially allowed: Caller is in {}", System.class.getModule().getName()); - return true; - } - logger.trace("Not trivially allowed"); - return false; + @Override + public void check$java_lang_System$exit(Class callerClass, int status) { + policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.SYSTEM_EXIT); } } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/FlagEntitlementType.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/FlagEntitlementType.java new file mode 100644 index 0000000000000..60490baf41a10 --- /dev/null +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/FlagEntitlementType.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.entitlement.runtime.policy; + +public enum FlagEntitlementType { + SYSTEM_EXIT; +} diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyManager.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyManager.java new file mode 100644 index 0000000000000..c06dc09758de5 --- /dev/null +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyManager.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.entitlement.runtime.policy; + +import org.elasticsearch.core.Strings; +import org.elasticsearch.entitlement.runtime.api.ElasticsearchEntitlementChecker; +import org.elasticsearch.entitlement.runtime.api.NotEntitledException; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; + +import java.lang.module.ModuleFinder; +import java.lang.module.ModuleReference; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class PolicyManager { + private static final Logger logger = LogManager.getLogger(ElasticsearchEntitlementChecker.class); + + protected final Policy serverPolicy; + protected final Map pluginPolicies; + private final Function, String> pluginResolver; + + public static final String ALL_UNNAMED = "ALL-UNNAMED"; + + private static final Set systemModules = findSystemModules(); + + private static Set findSystemModules() { + var systemModulesDescriptors = ModuleFinder.ofSystem() + .findAll() + .stream() + .map(ModuleReference::descriptor) + .collect(Collectors.toUnmodifiableSet()); + + return ModuleLayer.boot() + .modules() + .stream() + .filter(m -> systemModulesDescriptors.contains(m.getDescriptor())) + .collect(Collectors.toUnmodifiableSet()); + } + + public PolicyManager(Policy defaultPolicy, Map pluginPolicies, Function, String> pluginResolver) { + this.serverPolicy = Objects.requireNonNull(defaultPolicy); + this.pluginPolicies = Collections.unmodifiableMap(Objects.requireNonNull(pluginPolicies)); + this.pluginResolver = pluginResolver; + } + + public void checkFlagEntitlement(Class callerClass, FlagEntitlementType type) { + var requestingModule = requestingModule(callerClass); + if (isTriviallyAllowed(requestingModule)) { + return; + } + + // TODO: real policy check. For now, we only allow our hardcoded System.exit policy for server. + // TODO: this will be checked using policies + if (requestingModule.isNamed() + && requestingModule.getName().equals("org.elasticsearch.server") + && type == FlagEntitlementType.SYSTEM_EXIT) { + logger.debug("Allowed: caller [{}] in module [{}] has entitlement [{}]", callerClass, requestingModule.getName(), type); + return; + } + + // TODO: plugins policy check using pluginResolver and pluginPolicies + throw new NotEntitledException( + Strings.format("Missing entitlement [%s] for caller [%s] in module [%s]", type, callerClass, requestingModule.getName()) + ); + } + + private static Module requestingModule(Class callerClass) { + if (callerClass != null) { + Module callerModule = callerClass.getModule(); + if (systemModules.contains(callerModule) == false) { + // fast path + return callerModule; + } + } + int framesToSkip = 1 // getCallingClass (this method) + + 1 // the checkXxx method + + 1 // the runtime config method + + 1 // the instrumented method + ; + Optional module = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE) + .walk( + s -> s.skip(framesToSkip) + .map(f -> f.getDeclaringClass().getModule()) + .filter(m -> systemModules.contains(m) == false) + .findFirst() + ); + return module.orElse(null); + } + + private static boolean isTriviallyAllowed(Module requestingModule) { + if (requestingModule == null) { + logger.debug("Trivially allowed: entire call stack is in composed of classes in system modules"); + return true; + } + logger.trace("Not trivially allowed"); + return false; + } + + @Override + public String toString() { + return "PolicyManager{" + "serverPolicy=" + serverPolicy + ", pluginPolicies=" + pluginPolicies + '}'; + } +} diff --git a/libs/entitlement/src/main23/java/org/elasticsearch/entitlement/runtime/api/Java23ElasticsearchEntitlementChecker.java b/libs/entitlement/src/main23/java/org/elasticsearch/entitlement/runtime/api/Java23ElasticsearchEntitlementChecker.java new file mode 100644 index 0000000000000..d0f9f4f48609c --- /dev/null +++ b/libs/entitlement/src/main23/java/org/elasticsearch/entitlement/runtime/api/Java23ElasticsearchEntitlementChecker.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.entitlement.runtime.api; + +import org.elasticsearch.entitlement.bridge.Java23EntitlementChecker; +import org.elasticsearch.entitlement.runtime.policy.PolicyManager; + +public class Java23ElasticsearchEntitlementChecker extends ElasticsearchEntitlementChecker implements Java23EntitlementChecker { + + public Java23ElasticsearchEntitlementChecker(PolicyManager policyManager) { + super(policyManager); + } + + @Override + public void check$java_lang_System$exit(Class callerClass, int status) { + // TODO: this is just an example, we shouldn't really override a method implemented in the superclass + super.check$java_lang_System$exit(callerClass, status); + } +} diff --git a/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java b/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java index 77875e65ab9b8..95e5b00a2805f 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java @@ -30,6 +30,7 @@ import org.elasticsearch.core.AbstractRefCounted; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.core.Tuple; import org.elasticsearch.entitlement.bootstrap.EntitlementBootstrap; import org.elasticsearch.env.Environment; import org.elasticsearch.index.IndexVersion; @@ -41,7 +42,9 @@ import org.elasticsearch.nativeaccess.NativeAccess; import org.elasticsearch.node.Node; import org.elasticsearch.node.NodeValidationException; +import org.elasticsearch.plugins.PluginBundle; import org.elasticsearch.plugins.PluginsLoader; +import org.elasticsearch.plugins.PluginsUtils; import java.io.IOException; import java.io.InputStream; @@ -51,8 +54,10 @@ import java.nio.file.Path; import java.security.Permission; import java.security.Security; +import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -201,11 +206,23 @@ private static void initPhase2(Bootstrap bootstrap) throws IOException { ); // load the plugin Java modules and layers now for use in entitlements - bootstrap.setPluginsLoader(new PluginsLoader(nodeEnv.modulesFile(), nodeEnv.pluginsFile())); + var pluginsLoader = new PluginsLoader(nodeEnv.modulesFile(), nodeEnv.pluginsFile()); + bootstrap.setPluginsLoader(pluginsLoader); if (Boolean.parseBoolean(System.getProperty("es.entitlements.enabled"))) { logger.info("Bootstrapping Entitlements"); - EntitlementBootstrap.bootstrap(); + + List> pluginData = new ArrayList<>(); + Set moduleBundles = PluginsUtils.getModuleBundles(nodeEnv.modulesFile()); + for (PluginBundle moduleBundle : moduleBundles) { + pluginData.add(Tuple.tuple(moduleBundle.getDir(), moduleBundle.pluginDescriptor().isModular())); + } + Set pluginBundles = PluginsUtils.getPluginBundles(nodeEnv.pluginsFile()); + for (PluginBundle pluginBundle : pluginBundles) { + pluginData.add(Tuple.tuple(pluginBundle.getDir(), pluginBundle.pluginDescriptor().isModular())); + } + // TODO: add a functor to map module to plugin name + EntitlementBootstrap.bootstrap(pluginData, callerClass -> null); } else { // install SM after natives, shutdown hooks, etc. logger.info("Bootstrapping java SecurityManager"); diff --git a/server/src/main/java/org/elasticsearch/plugins/PluginsUtils.java b/server/src/main/java/org/elasticsearch/plugins/PluginsUtils.java index 44fb531f8610e..155cff57a0ebf 100644 --- a/server/src/main/java/org/elasticsearch/plugins/PluginsUtils.java +++ b/server/src/main/java/org/elasticsearch/plugins/PluginsUtils.java @@ -210,12 +210,12 @@ public static void checkForFailedPluginRemovals(final Path pluginsDirectory) thr } /** Get bundles for plugins installed in the given modules directory. */ - static Set getModuleBundles(Path modulesDirectory) throws IOException { + public static Set getModuleBundles(Path modulesDirectory) throws IOException { return findBundles(modulesDirectory, "module"); } /** Get bundles for plugins installed in the given plugins directory. */ - static Set getPluginBundles(final Path pluginsDirectory) throws IOException { + public static Set getPluginBundles(final Path pluginsDirectory) throws IOException { return findBundles(pluginsDirectory, "plugin"); } From 20e02fab75b22362b8050204d7ab2587757d4cea Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Fri, 22 Nov 2024 15:06:31 -0500 Subject: [PATCH 6/9] ESQL: Add docs for MV_PERCENTILE (#117377) (#117381) We built this a while back. Let's document it. --- docs/reference/esql/functions/mv-functions.asciidoc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/reference/esql/functions/mv-functions.asciidoc b/docs/reference/esql/functions/mv-functions.asciidoc index 4093e44c16911..3da0249c9c0db 100644 --- a/docs/reference/esql/functions/mv-functions.asciidoc +++ b/docs/reference/esql/functions/mv-functions.asciidoc @@ -19,6 +19,7 @@ * <> * <> * <> +* <> * <> * <> * <> @@ -37,6 +38,7 @@ include::layout/mv_max.asciidoc[] include::layout/mv_median.asciidoc[] include::layout/mv_median_absolute_deviation.asciidoc[] include::layout/mv_min.asciidoc[] +include::layout/mv_percentile.asciidoc[] include::layout/mv_pseries_weighted_sum.asciidoc[] include::layout/mv_slice.asciidoc[] include::layout/mv_sort.asciidoc[] From b9940e02ee32122ea8b24c8e60857a1e3cc53936 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Fri, 22 Nov 2024 22:12:19 +0100 Subject: [PATCH 7/9] Fix leak in DfsQueryPhase and introduce search disconnect stress test (#116060) (#117384) Fixing an obvious leak and finally adding a stress test for search disconnects. --- docs/changelog/116060.yaml | 6 + .../basic/SearchWithRandomDisconnectsIT.java | 103 ++++++++++++++++++ .../action/search/DfsQueryPhase.java | 2 +- .../discovery/AbstractDisruptionTestCase.java | 4 +- 4 files changed, 112 insertions(+), 3 deletions(-) create mode 100644 docs/changelog/116060.yaml create mode 100644 server/src/internalClusterTest/java/org/elasticsearch/search/basic/SearchWithRandomDisconnectsIT.java diff --git a/docs/changelog/116060.yaml b/docs/changelog/116060.yaml new file mode 100644 index 0000000000000..b067677ed41d9 --- /dev/null +++ b/docs/changelog/116060.yaml @@ -0,0 +1,6 @@ +pr: 116060 +summary: Fix leak in `DfsQueryPhase` and introduce search disconnect stress test +area: Search +type: bug +issues: + - 115056 diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/basic/SearchWithRandomDisconnectsIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/basic/SearchWithRandomDisconnectsIT.java new file mode 100644 index 0000000000000..d2c7e10f8aa62 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/basic/SearchWithRandomDisconnectsIT.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.search.basic; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.discovery.AbstractDisruptionTestCase; +import org.elasticsearch.index.IndexModule; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.test.disruption.NetworkDisruption; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.IntStream; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; + +public class SearchWithRandomDisconnectsIT extends AbstractDisruptionTestCase { + + public void testSearchWithRandomDisconnects() throws InterruptedException, ExecutionException { + // make sure we have a couple data nodes + int minDataNodes = randomIntBetween(3, 7); + internalCluster().ensureAtLeastNumDataNodes(minDataNodes); + final int indexCount = randomIntBetween(minDataNodes, 10 * minDataNodes); + final String[] indexNames = IntStream.range(0, indexCount).mapToObj(i -> "test-" + i).toArray(String[]::new); + final Settings indexSettings = indexSettings(1, 0).put(IndexModule.INDEX_QUERY_CACHE_ENABLED_SETTING.getKey(), false) + .put(IndexSettings.INDEX_CHECK_ON_STARTUP.getKey(), false) + .build(); + for (String indexName : indexNames) { + createIndex(indexName, indexSettings); + } + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + for (String indexName : indexNames) { + for (int i = 0; i < randomIntBetween(1, 10); i++) { + bulkRequestBuilder = bulkRequestBuilder.add(prepareIndex(indexName).setCreate(false).setSource("foo", "bar-" + i)); + } + } + assertFalse(bulkRequestBuilder.get().hasFailures()); + final AtomicBoolean done = new AtomicBoolean(); + final int concurrentSearches = randomIntBetween(2, 5); + final List> futures = new ArrayList<>(concurrentSearches); + for (int i = 0; i < concurrentSearches; i++) { + final PlainActionFuture finishFuture = new PlainActionFuture<>(); + futures.add(finishFuture); + prepareRandomSearch().execute(new ActionListener<>() { + @Override + public void onResponse(SearchResponse searchResponse) { + runMoreSearches(); + } + + @Override + public void onFailure(Exception e) { + runMoreSearches(); + } + + private void runMoreSearches() { + if (done.get() == false) { + prepareRandomSearch().execute(this); + } else { + finishFuture.onResponse(null); + } + } + }); + } + for (int i = 0, n = randomIntBetween(50, 100); i < n; i++) { + NetworkDisruption networkDisruption = new NetworkDisruption( + isolateNode(internalCluster().getRandomNodeName()), + NetworkDisruption.DISCONNECT + ); + setDisruptionScheme(networkDisruption); + networkDisruption.startDisrupting(); + networkDisruption.stopDisrupting(); + internalCluster().clearDisruptionScheme(); + ensureFullyConnectedCluster(); + } + done.set(true); + for (PlainActionFuture future : futures) { + future.get(); + } + ensureGreen(DISRUPTION_HEALING_OVERHEAD, indexNames); + assertAcked(indicesAdmin().prepareDelete(indexNames)); + } + + private static SearchRequestBuilder prepareRandomSearch() { + return prepareSearch("*").setQuery(new MatchAllQueryBuilder()) + .setSize(9999) + .setFetchSource(true) + .setAllowPartialSearchResults(randomBoolean()); + } +} diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index b8963df85b8e3..0b587e72141ff 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -96,7 +96,7 @@ public void run() { connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()); } catch (Exception e) { shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter); - return; + continue; } searchTransportService.sendExecuteQuery( connection, diff --git a/server/src/test/java/org/elasticsearch/discovery/AbstractDisruptionTestCase.java b/server/src/test/java/org/elasticsearch/discovery/AbstractDisruptionTestCase.java index 62ffc069b155b..1c3f237f852e5 100644 --- a/server/src/test/java/org/elasticsearch/discovery/AbstractDisruptionTestCase.java +++ b/server/src/test/java/org/elasticsearch/discovery/AbstractDisruptionTestCase.java @@ -48,7 +48,7 @@ public abstract class AbstractDisruptionTestCase extends ESIntegTestCase { - static final TimeValue DISRUPTION_HEALING_OVERHEAD = TimeValue.timeValueSeconds(40); // we use 30s as timeout in many places. + public static final TimeValue DISRUPTION_HEALING_OVERHEAD = TimeValue.timeValueSeconds(40); // we use 30s as timeout in many places. @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { @@ -220,7 +220,7 @@ NetworkDisruption addRandomDisruptionType(TwoPartitions partitions) { return partition; } - TwoPartitions isolateNode(String isolatedNode) { + protected TwoPartitions isolateNode(String isolatedNode) { Set side1 = new HashSet<>(); Set side2 = new HashSet<>(Arrays.asList(internalCluster().getNodeNames())); side1.add(isolatedNode); From 3fbf04eb5679371ece5f1e8da4ebdc3b6d4c9789 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Fri, 22 Nov 2024 13:20:15 -0800 Subject: [PATCH 8/9] Fix CCS exchange when multi cluster aliases point to same cluster (#117297) (#117389) [esql] > Unexpected error from Elasticsearch: illegal_state_exception - sink exchanger for id [ruxoDDxXTGW55oIPHoCT-g:964613010] already exists. This issue occurs when two or more clusterAliases point to the same physical remote cluster. The exchange service assumes the destination is unique, which is not true in this topology. This PR addresses the problem by appending a suffix using a monotonic increasing number, ensuring that different exchanges are created in such cases. Another issue arising from this behavior is that data on a remote cluster is processed multiple times, leading to incorrect results. I can work on the fix for this once we agree that this is an issue. --- docs/changelog/117297.yaml | 5 ++ .../test/AbstractMultiClustersTestCase.java | 29 ++++++++---- .../operator/exchange/ExchangeService.java | 5 ++ .../action/CrossClustersCancellationIT.java | 46 +++++++++++++++++++ .../xpack/esql/action/EsqlActionTaskIT.java | 3 +- .../xpack/esql/plugin/ComputeService.java | 22 ++++++--- 6 files changed, 93 insertions(+), 17 deletions(-) create mode 100644 docs/changelog/117297.yaml diff --git a/docs/changelog/117297.yaml b/docs/changelog/117297.yaml new file mode 100644 index 0000000000000..4a0051bbae644 --- /dev/null +++ b/docs/changelog/117297.yaml @@ -0,0 +1,5 @@ +pr: 117297 +summary: Fix CCS exchange when multi cluster aliases point to same cluster +area: ES|QL +type: bug +issues: [] diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java index 7b18cf575f190..ea82c9d21ab89 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java @@ -17,6 +17,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.network.NetworkModule; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Strings; import org.elasticsearch.plugins.Plugin; @@ -44,6 +45,7 @@ import static org.elasticsearch.discovery.DiscoveryModule.DISCOVERY_SEED_PROVIDERS_SETTING; import static org.elasticsearch.discovery.SettingsBasedSeedHostsProvider.DISCOVERY_SEED_HOSTS_SETTING; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.not; @@ -149,19 +151,23 @@ public static void stopClusters() throws IOException { } protected void disconnectFromRemoteClusters() throws Exception { - Settings.Builder settings = Settings.builder(); final Set clusterAliases = clusterGroup.clusterAliases(); for (String clusterAlias : clusterAliases) { if (clusterAlias.equals(LOCAL_CLUSTER) == false) { - settings.putNull("cluster.remote." + clusterAlias + ".seeds"); - settings.putNull("cluster.remote." + clusterAlias + ".mode"); - settings.putNull("cluster.remote." + clusterAlias + ".proxy_address"); + removeRemoteCluster(clusterAlias); } } + } + + protected void removeRemoteCluster(String clusterAlias) throws Exception { + Settings.Builder settings = Settings.builder(); + settings.putNull("cluster.remote." + clusterAlias + ".seeds"); + settings.putNull("cluster.remote." + clusterAlias + ".mode"); + settings.putNull("cluster.remote." + clusterAlias + ".proxy_address"); client().admin().cluster().prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT).setPersistentSettings(settings).get(); assertBusy(() -> { for (TransportService transportService : cluster(LOCAL_CLUSTER).getInstances(TransportService.class)) { - assertThat(transportService.getRemoteClusterService().getRegisteredRemoteClusterNames(), empty()); + assertThat(transportService.getRemoteClusterService().getRegisteredRemoteClusterNames(), not(contains(clusterAlias))); } }); } @@ -178,12 +184,17 @@ protected void configureAndConnectsToRemoteClusters() throws Exception { } protected void configureRemoteCluster(String clusterAlias, Collection seedNodes) throws Exception { - final String remoteClusterSettingPrefix = "cluster.remote." + clusterAlias + "."; - Settings.Builder settings = Settings.builder(); - final List seedAddresses = seedNodes.stream().map(node -> { + final var seedAddresses = seedNodes.stream().map(node -> { final TransportService transportService = cluster(clusterAlias).getInstance(TransportService.class, node); - return transportService.boundAddress().publishAddress().toString(); + return transportService.boundAddress().publishAddress(); }).toList(); + configureRemoteClusterWithSeedAddresses(clusterAlias, seedAddresses); + } + + protected void configureRemoteClusterWithSeedAddresses(String clusterAlias, Collection seedNodes) throws Exception { + final String remoteClusterSettingPrefix = "cluster.remote." + clusterAlias + "."; + Settings.Builder settings = Settings.builder(); + final List seedAddresses = seedNodes.stream().map(TransportAddress::toString).toList(); boolean skipUnavailable = skipUnavailableForRemoteClusters().containsKey(clusterAlias) ? skipUnavailableForRemoteClusters().get(clusterAlias) : DEFAULT_SKIP_UNAVAILABLE; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java index 06059944f1310..e6bae7ba385e6 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java @@ -40,6 +40,7 @@ import java.io.IOException; import java.util.Map; +import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicLong; @@ -339,6 +340,10 @@ public boolean isEmpty() { return sinks.isEmpty(); } + public Set sinkKeys() { + return sinks.keySet(); + } + @Override protected void doStart() { diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java index df6a1e00b0212..c426e0f528eab 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.action; +import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.TransportCancelTasksAction; import org.elasticsearch.action.bulk.BulkRequestBuilder; @@ -15,6 +16,7 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.compute.operator.DriverTaskRunner; import org.elasticsearch.compute.operator.exchange.ExchangeService; import org.elasticsearch.core.TimeValue; @@ -27,8 +29,10 @@ import org.elasticsearch.search.lookup.SearchLookup; import org.elasticsearch.tasks.TaskInfo; import org.elasticsearch.test.AbstractMultiClustersTestCase; +import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.esql.plugin.ComputeService; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.junit.Before; @@ -40,8 +44,10 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getValuesList; import static org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase.randomPragmas; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; @@ -189,4 +195,44 @@ public void testCancel() throws Exception { Exception error = expectThrows(Exception.class, requestFuture::actionGet); assertThat(error.getMessage(), containsString("proxy timeout")); } + + public void testSameRemoteClusters() throws Exception { + TransportAddress address = cluster(REMOTE_CLUSTER).getInstance(TransportService.class).getLocalNode().getAddress(); + int moreClusters = between(1, 5); + for (int i = 0; i < moreClusters; i++) { + String clusterAlias = REMOTE_CLUSTER + "-" + i; + configureRemoteClusterWithSeedAddresses(clusterAlias, List.of(address)); + } + int numDocs = between(10, 100); + createRemoteIndex(numDocs); + EsqlQueryRequest request = EsqlQueryRequest.syncEsqlQueryRequest(); + request.query("FROM *:test | STATS total=sum(const) | LIMIT 1"); + request.pragmas(randomPragmas()); + ActionFuture future = client().execute(EsqlQueryAction.INSTANCE, request); + try { + try { + assertBusy(() -> { + List tasks = client(REMOTE_CLUSTER).admin() + .cluster() + .prepareListTasks() + .setActions(ComputeService.CLUSTER_ACTION_NAME) + .get() + .getTasks(); + assertThat(tasks, hasSize(moreClusters + 1)); + }); + } finally { + PauseFieldPlugin.allowEmitting.countDown(); + } + try (EsqlQueryResponse resp = future.actionGet(30, TimeUnit.SECONDS)) { + // TODO: This produces incorrect results because data on the remote cluster is processed multiple times. + long expectedCount = numDocs * (moreClusters + 1L); + assertThat(getValuesList(resp), equalTo(List.of(List.of(expectedCount)))); + } + } finally { + for (int i = 0; i < moreClusters; i++) { + String clusterAlias = REMOTE_CLUSTER + "-" + i; + removeRemoteCluster(clusterAlias); + } + } + } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java index cde4f10ef556c..5f299fdca4d31 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java @@ -401,7 +401,8 @@ protected void doRun() throws Exception { }); sessionId = foundTasks.get(0).taskId().toString(); assertTrue(fetchingStarted.await(1, TimeUnit.MINUTES)); - ExchangeSinkHandler exchangeSink = exchangeService.getSinkHandler(sessionId); + String exchangeId = exchangeService.sinkKeys().stream().filter(s -> s.startsWith(sessionId)).findFirst().get(); + ExchangeSinkHandler exchangeSink = exchangeService.getSinkHandler(exchangeId); waitedForPages = randomBoolean(); if (waitedForPages) { // do not fail exchange requests until we have some pages diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index fc4c057e52ab6..eeed811674f60 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -82,6 +82,7 @@ import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME; @@ -101,6 +102,7 @@ public class ComputeService { private final EnrichLookupService enrichLookupService; private final LookupFromIndexService lookupFromIndexService; private final ClusterService clusterService; + private final AtomicLong childSessionIdGenerator = new AtomicLong(); public ComputeService( SearchService searchService, @@ -167,7 +169,7 @@ public void execute( return; } var computeContext = new ComputeContext( - sessionId, + newChildSession(sessionId), RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, List.of(), configuration, @@ -330,14 +332,15 @@ private void startComputeOnDataNodes( // the new remote exchange sink, and initialize the computation on the target node via data-node-request. for (DataNode node : dataNodeResult.dataNodes()) { var queryPragmas = configuration.pragmas(); + var childSessionId = newChildSession(sessionId); ExchangeService.openExchange( transportService, node.connection, - sessionId, + childSessionId, queryPragmas.exchangeBufferSize(), esqlExecutor, refs.acquire().delegateFailureAndWrap((l, unused) -> { - var remoteSink = exchangeService.newRemoteSink(parentTask, sessionId, transportService, node.connection); + var remoteSink = exchangeService.newRemoteSink(parentTask, childSessionId, transportService, node.connection); exchangeSource.addRemoteSink(remoteSink, queryPragmas.concurrentExchangeClients()); ActionListener computeResponseListener = computeListener.acquireCompute(clusterAlias); var dataNodeListener = ActionListener.runBefore(computeResponseListener, () -> l.onResponse(null)); @@ -345,7 +348,7 @@ private void startComputeOnDataNodes( node.connection, DATA_ACTION_NAME, new DataNodeRequest( - sessionId, + childSessionId, configuration, clusterAlias, node.shardIds, @@ -378,17 +381,18 @@ private void startComputeOnRemoteClusters( var linkExchangeListeners = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); try (RefCountingListener refs = new RefCountingListener(linkExchangeListeners)) { for (RemoteCluster cluster : clusters) { + final var childSessionId = newChildSession(sessionId); ExchangeService.openExchange( transportService, cluster.connection, - sessionId, + childSessionId, queryPragmas.exchangeBufferSize(), esqlExecutor, refs.acquire().delegateFailureAndWrap((l, unused) -> { - var remoteSink = exchangeService.newRemoteSink(rootTask, sessionId, transportService, cluster.connection); + var remoteSink = exchangeService.newRemoteSink(rootTask, childSessionId, transportService, cluster.connection); exchangeSource.addRemoteSink(remoteSink, queryPragmas.concurrentExchangeClients()); var remotePlan = new RemoteClusterPlan(plan, cluster.concreteIndices, cluster.originalIndices); - var clusterRequest = new ClusterComputeRequest(cluster.clusterAlias, sessionId, configuration, remotePlan); + var clusterRequest = new ClusterComputeRequest(cluster.clusterAlias, childSessionId, configuration, remotePlan); var clusterListener = ActionListener.runBefore( computeListener.acquireCompute(cluster.clusterAlias()), () -> l.onResponse(null) @@ -912,4 +916,8 @@ public List searchExecutionContexts() { return searchContexts.stream().map(ctx -> ctx.getSearchExecutionContext()).toList(); } } + + private String newChildSession(String session) { + return session + "/" + childSessionIdGenerator.incrementAndGet(); + } } From 519057d917f33dd0cd411db4ab3a66b827860bf0 Mon Sep 17 00:00:00 2001 From: Ryan Ernst Date: Fri, 22 Nov 2024 13:21:14 -0800 Subject: [PATCH 9/9] Fix SecureSM to allow innocuous threads and threadgroups for parallel streams (#117277) (#117292) When a parallel stream is opened, the jdk uses an internal fork join pool to do work on processing the stream. This pool is internal to the jdk, and so it should always be allowed to create threads. This commit modifies SecureSM to account for this innocuous thread group and threads. Co-authored-by: Elastic Machine --- .../org/elasticsearch/secure_sm/SecureSM.java | 18 +++++++++++++++--- .../elasticsearch/secure_sm/SecureSMTests.java | 11 +++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/SecureSM.java b/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/SecureSM.java index 4fd471c529e75..02d0491118dc7 100644 --- a/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/SecureSM.java +++ b/libs/secure-sm/src/main/java/org/elasticsearch/secure_sm/SecureSM.java @@ -157,7 +157,9 @@ private static void debugThreadGroups(final ThreadGroup caller, final ThreadGrou // Returns true if the given thread is an instance of the JDK's InnocuousThread. private static boolean isInnocuousThread(Thread t) { final Class c = t.getClass(); - return c.getModule() == Object.class.getModule() && c.getName().equals("jdk.internal.misc.InnocuousThread"); + return c.getModule() == Object.class.getModule() + && (c.getName().equals("jdk.internal.misc.InnocuousThread") + || c.getName().equals("java.util.concurrent.ForkJoinWorkerThread$InnocuousForkJoinWorkerThread")); } protected void checkThreadAccess(Thread t) { @@ -184,11 +186,21 @@ protected void checkThreadAccess(Thread t) { private static final Permission MODIFY_THREADGROUP_PERMISSION = new RuntimePermission("modifyThreadGroup"); private static final Permission MODIFY_ARBITRARY_THREADGROUP_PERMISSION = new ThreadPermission("modifyArbitraryThreadGroup"); + // Returns true if the given thread is an instance of the JDK's InnocuousThread. + private static boolean isInnocuousThreadGroup(ThreadGroup t) { + final Class c = t.getClass(); + return c.getModule() == Object.class.getModule() && t.getName().equals("InnocuousForkJoinWorkerThreadGroup"); + } + protected void checkThreadGroupAccess(ThreadGroup g) { Objects.requireNonNull(g); + boolean targetThreadGroupIsInnocuous = isInnocuousThreadGroup(g); + // first, check if we can modify thread groups at all. - checkPermission(MODIFY_THREADGROUP_PERMISSION); + if (targetThreadGroupIsInnocuous == false) { + checkPermission(MODIFY_THREADGROUP_PERMISSION); + } // check the threadgroup, if its our thread group or an ancestor, its fine. final ThreadGroup source = Thread.currentThread().getThreadGroup(); @@ -196,7 +208,7 @@ protected void checkThreadGroupAccess(ThreadGroup g) { if (source == null) { return; // we are a dead thread, do nothing - } else if (source.parentOf(target) == false) { + } else if (source.parentOf(target) == false && targetThreadGroupIsInnocuous == false) { checkPermission(MODIFY_ARBITRARY_THREADGROUP_PERMISSION); } } diff --git a/libs/secure-sm/src/test/java/org/elasticsearch/secure_sm/SecureSMTests.java b/libs/secure-sm/src/test/java/org/elasticsearch/secure_sm/SecureSMTests.java index b94639414ffe5..69c6973f57cdf 100644 --- a/libs/secure-sm/src/test/java/org/elasticsearch/secure_sm/SecureSMTests.java +++ b/libs/secure-sm/src/test/java/org/elasticsearch/secure_sm/SecureSMTests.java @@ -14,7 +14,10 @@ import java.security.Permission; import java.security.Policy; import java.security.ProtectionDomain; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; /** Simple tests for SecureSM */ public class SecureSMTests extends TestCase { @@ -128,4 +131,12 @@ public void run() { t1.join(); assertTrue(interrupted1.get()); } + + public void testParallelStreamThreadGroup() throws Exception { + List list = new ArrayList<>(); + for (int i = 0; i < 100; ++i) { + list.add(i); + } + list.parallelStream().collect(Collectors.toSet()); + } }