From ef95cdd4cce3cdb3c788dd6c2de122dcc7f82d4a Mon Sep 17 00:00:00 2001
From: Ryan Ernst
Date: Mon, 26 Aug 2024 18:51:12 -0700
Subject: [PATCH 01/46] Fix native library loading zstd with jna (#112221)
Recent refactoring of native library paths broke jna loading zstd. This
commit fixes jna to set the jna.library.path during init so that jna
calls to load libraries still work.
---
.../nativeaccess/jna/JnaNativeLibraryProvider.java | 11 +++++++++++
.../elasticsearch/nativeaccess/lib/LoaderHelper.java | 2 +-
2 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaNativeLibraryProvider.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaNativeLibraryProvider.java
index 79caf04c97246..e0233187425ea 100644
--- a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaNativeLibraryProvider.java
+++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaNativeLibraryProvider.java
@@ -8,9 +8,11 @@
package org.elasticsearch.nativeaccess.jna;
+import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.nativeaccess.lib.JavaLibrary;
import org.elasticsearch.nativeaccess.lib.Kernel32Library;
import org.elasticsearch.nativeaccess.lib.LinuxCLibrary;
+import org.elasticsearch.nativeaccess.lib.LoaderHelper;
import org.elasticsearch.nativeaccess.lib.MacCLibrary;
import org.elasticsearch.nativeaccess.lib.NativeLibrary;
import org.elasticsearch.nativeaccess.lib.NativeLibraryProvider;
@@ -23,6 +25,10 @@
public class JnaNativeLibraryProvider extends NativeLibraryProvider {
+ static {
+ setJnaLibraryPath();
+ }
+
public JnaNativeLibraryProvider() {
super(
"jna",
@@ -45,6 +51,11 @@ public JnaNativeLibraryProvider() {
);
}
+ @SuppressForbidden(reason = "jna library path must be set for load library to work with our own libs")
+ private static void setJnaLibraryPath() {
+ System.setProperty("jna.library.path", LoaderHelper.platformLibDir.toString());
+ }
+
private static Supplier notImplemented() {
return () -> { throw new AssertionError(); };
}
diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/LoaderHelper.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/LoaderHelper.java
index 4da52c415c040..42ca60b81a027 100644
--- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/LoaderHelper.java
+++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/LoaderHelper.java
@@ -16,7 +16,7 @@
* A utility for loading libraries from Elasticsearch's platform specific lib dir.
*/
public class LoaderHelper {
- private static final Path platformLibDir = findPlatformLibDir();
+ public static final Path platformLibDir = findPlatformLibDir();
private static Path findPlatformLibDir() {
// tests don't have an ES install, so the platform dir must be passed in explicitly
From 535e9edced9995e8411b46622e29f8ae006ab4f1 Mon Sep 17 00:00:00 2001
From: Quentin Pradet
Date: Tue, 27 Aug 2024 06:38:11 +0400
Subject: [PATCH 02/46] Add ingest-geoip module to rest-resources-zip (#112216)
---
modules/ingest-geoip/build.gradle | 4 ++++
x-pack/rest-resources-zip/build.gradle | 1 +
2 files changed, 5 insertions(+)
diff --git a/modules/ingest-geoip/build.gradle b/modules/ingest-geoip/build.gradle
index 5bdb6da5c7b29..bc5bb165cd0d2 100644
--- a/modules/ingest-geoip/build.gradle
+++ b/modules/ingest-geoip/build.gradle
@@ -88,3 +88,7 @@ tasks.named("yamlRestTestV7CompatTransform").configure { task ->
task.skipTestsByFilePattern("**/ingest_geoip/20_geoip_processor.yml", "from 8.0 yaml rest tests use geoip test fixture and default geoip are no longer packaged. In 7.x yaml tests used default databases which makes tests results very different, so skipping these tests")
// task.skipTest("lang_mustache/50_multi_search_template/Multi-search template with errors", "xxx")
}
+
+artifacts {
+ restTests(new File(projectDir, "src/yamlRestTest/resources/rest-api-spec/test"))
+}
diff --git a/x-pack/rest-resources-zip/build.gradle b/x-pack/rest-resources-zip/build.gradle
index cc5bddf12d801..0133ff80dfadf 100644
--- a/x-pack/rest-resources-zip/build.gradle
+++ b/x-pack/rest-resources-zip/build.gradle
@@ -21,6 +21,7 @@ dependencies {
freeTests project(path: ':rest-api-spec', configuration: 'restTests')
freeTests project(path: ':modules:aggregations', configuration: 'restTests')
freeTests project(path: ':modules:analysis-common', configuration: 'restTests')
+ freeTests project(path: ':modules:ingest-geoip', configuration: 'restTests')
compatApis project(path: ':rest-api-spec', configuration: 'restCompatSpecs')
compatApis project(path: ':x-pack:plugin', configuration: 'restCompatSpecs')
freeCompatTests project(path: ':rest-api-spec', configuration: 'restCompatTests')
From d14fe7733b2ce361e08c05624668fddbf2763a86 Mon Sep 17 00:00:00 2001
From: Yang Wang
Date: Tue, 27 Aug 2024 17:03:01 +1000
Subject: [PATCH 03/46] Expand RecordingInstruments to support collection of
 observers (#112195)
The support is needed for RecordingInstruments to be used in tests for
gauges with a collection of observers.
Relates: #110630
---
.../telemetry/RecordingInstruments.java | 29 ++++++++-----
.../telemetry/RecordingMeterRegistry.java | 42 +++++++++++--------
2 files changed, 43 insertions(+), 28 deletions(-)
diff --git a/test/framework/src/main/java/org/elasticsearch/telemetry/RecordingInstruments.java b/test/framework/src/main/java/org/elasticsearch/telemetry/RecordingInstruments.java
index 35417c16e7e1c..49e667bb74e5b 100644
--- a/test/framework/src/main/java/org/elasticsearch/telemetry/RecordingInstruments.java
+++ b/test/framework/src/main/java/org/elasticsearch/telemetry/RecordingInstruments.java
@@ -24,6 +24,7 @@
import org.elasticsearch.telemetry.metric.LongUpDownCounter;
import org.elasticsearch.telemetry.metric.LongWithAttributes;
+import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
@@ -53,7 +54,7 @@ public String getName() {
}
}
- protected interface NumberWithAttributesObserver extends Supplier>> {
+ protected interface NumberWithAttributesObserver extends Supplier>>> {
}
@@ -74,7 +75,7 @@ public void run() {
return;
}
var observation = observer.get();
- call(observation.v1(), observation.v2());
+ observation.forEach(o -> call(o.v1(), o.v2()));
}
}
@@ -109,10 +110,10 @@ public void incrementBy(double inc, Map attributes) {
}
public static class RecordingDoubleGauge extends CallbackRecordingInstrument implements DoubleGauge {
- public RecordingDoubleGauge(String name, Supplier observer, MetricRecorder recorder) {
+ public RecordingDoubleGauge(String name, Supplier> observer, MetricRecorder recorder) {
super(name, () -> {
var observation = observer.get();
- return new Tuple<>(observation.value(), observation.attributes());
+ return observation.stream().map(o -> new Tuple<>((Number) o.value(), o.attributes())).toList();
}, recorder);
}
}
@@ -172,10 +173,14 @@ public void incrementBy(long inc, Map attributes) {
public static class RecordingAsyncLongCounter extends CallbackRecordingInstrument implements LongAsyncCounter {
- public RecordingAsyncLongCounter(String name, Supplier observer, MetricRecorder recorder) {
+ public RecordingAsyncLongCounter(
+ String name,
+ Supplier> observer,
+ MetricRecorder recorder
+ ) {
super(name, () -> {
var observation = observer.get();
- return new Tuple<>(observation.value(), observation.attributes());
+ return observation.stream().map(o -> new Tuple<>((Number) o.value(), o.attributes())).toList();
}, recorder);
}
@@ -183,10 +188,14 @@ public RecordingAsyncLongCounter(String name, Supplier obser
public static class RecordingAsyncDoubleCounter extends CallbackRecordingInstrument implements DoubleAsyncCounter {
- public RecordingAsyncDoubleCounter(String name, Supplier observer, MetricRecorder recorder) {
+ public RecordingAsyncDoubleCounter(
+ String name,
+ Supplier> observer,
+ MetricRecorder recorder
+ ) {
super(name, () -> {
var observation = observer.get();
- return new Tuple<>(observation.value(), observation.attributes());
+ return observation.stream().map(o -> new Tuple<>((Number) o.value(), o.attributes())).toList();
}, recorder);
}
@@ -194,10 +203,10 @@ public RecordingAsyncDoubleCounter(String name, Supplier o
public static class RecordingLongGauge extends CallbackRecordingInstrument implements LongGauge {
- public RecordingLongGauge(String name, Supplier observer, MetricRecorder recorder) {
+ public RecordingLongGauge(String name, Supplier> observer, MetricRecorder recorder) {
super(name, () -> {
var observation = observer.get();
- return new Tuple<>(observation.value(), observation.attributes());
+ return observation.stream().map(o -> new Tuple<>((Number) o.value(), o.attributes())).toList();
}, recorder);
}
}
diff --git a/test/framework/src/main/java/org/elasticsearch/telemetry/RecordingMeterRegistry.java b/test/framework/src/main/java/org/elasticsearch/telemetry/RecordingMeterRegistry.java
index 97fe0ad1370ef..392445aa77a8f 100644
--- a/test/framework/src/main/java/org/elasticsearch/telemetry/RecordingMeterRegistry.java
+++ b/test/framework/src/main/java/org/elasticsearch/telemetry/RecordingMeterRegistry.java
@@ -24,6 +24,7 @@
import org.elasticsearch.telemetry.metric.MeterRegistry;
import java.util.Collection;
+import java.util.Collections;
import java.util.function.Supplier;
/**
@@ -72,9 +73,7 @@ protected DoubleUpDownCounter buildDoubleUpDownCounter(String name, String descr
@Override
public DoubleGauge registerDoubleGauge(String name, String description, String unit, Supplier observer) {
- DoubleGauge instrument = buildDoubleGauge(name, description, unit, observer);
- recorder.register(instrument, InstrumentType.fromInstrument(instrument), name, description, unit);
- return instrument;
+ return registerDoublesGauge(name, description, unit, () -> Collections.singleton(observer.get()));
}
@Override
@@ -84,7 +83,9 @@ public DoubleGauge registerDoublesGauge(
String unit,
Supplier> observer
) {
- throw new UnsupportedOperationException("not implemented");
+ DoubleGauge instrument = buildDoubleGauge(name, description, unit, observer);
+ recorder.register(instrument, InstrumentType.fromInstrument(instrument), name, description, unit);
+ return instrument;
}
@Override
@@ -92,7 +93,12 @@ public DoubleGauge getDoubleGauge(String name) {
return (DoubleGauge) recorder.getInstrument(InstrumentType.DOUBLE_GAUGE, name);
}
- protected DoubleGauge buildDoubleGauge(String name, String description, String unit, Supplier observer) {
+ protected DoubleGauge buildDoubleGauge(
+ String name,
+ String description,
+ String unit,
+ Supplier> observer
+ ) {
return new RecordingInstruments.RecordingDoubleGauge(name, observer, recorder);
}
@@ -121,9 +127,7 @@ public LongCounter registerLongCounter(String name, String description, String u
@Override
public LongAsyncCounter registerLongAsyncCounter(String name, String description, String unit, Supplier observer) {
- LongAsyncCounter instrument = new RecordingInstruments.RecordingAsyncLongCounter(name, observer, recorder);
- recorder.register(instrument, InstrumentType.fromInstrument(instrument), name, description, unit);
- return instrument;
+ return registerLongsAsyncCounter(name, description, unit, () -> Collections.singleton(observer.get()));
}
@Override
@@ -133,7 +137,9 @@ public LongAsyncCounter registerLongsAsyncCounter(
String unit,
Supplier> observer
) {
- throw new UnsupportedOperationException("not implemented");
+ LongAsyncCounter instrument = new RecordingInstruments.RecordingAsyncLongCounter(name, observer, recorder);
+ recorder.register(instrument, InstrumentType.fromInstrument(instrument), name, description, unit);
+ return instrument;
}
@Override
@@ -148,9 +154,7 @@ public DoubleAsyncCounter registerDoubleAsyncCounter(
String unit,
Supplier observer
) {
- DoubleAsyncCounter instrument = new RecordingInstruments.RecordingAsyncDoubleCounter(name, observer, recorder);
- recorder.register(instrument, InstrumentType.fromInstrument(instrument), name, description, unit);
- return instrument;
+ return registerDoublesAsyncCounter(name, description, unit, () -> Collections.singleton(observer.get()));
}
@Override
@@ -160,7 +164,9 @@ public DoubleAsyncCounter registerDoublesAsyncCounter(
String unit,
Supplier> observer
) {
- throw new UnsupportedOperationException("not implemented");
+ DoubleAsyncCounter instrument = new RecordingInstruments.RecordingAsyncDoubleCounter(name, observer, recorder);
+ recorder.register(instrument, InstrumentType.fromInstrument(instrument), name, description, unit);
+ return instrument;
}
@Override
@@ -196,14 +202,14 @@ protected LongUpDownCounter buildLongUpDownCounter(String name, String descripti
@Override
public LongGauge registerLongGauge(String name, String description, String unit, Supplier observer) {
- LongGauge instrument = buildLongGauge(name, description, unit, observer);
- recorder.register(instrument, InstrumentType.fromInstrument(instrument), name, description, unit);
- return instrument;
+ return registerLongsGauge(name, description, unit, () -> Collections.singleton(observer.get()));
}
@Override
public LongGauge registerLongsGauge(String name, String description, String unit, Supplier> observer) {
- throw new UnsupportedOperationException("not implemented");
+ LongGauge instrument = buildLongGauge(name, description, unit, observer);
+ recorder.register(instrument, InstrumentType.fromInstrument(instrument), name, description, unit);
+ return instrument;
}
@Override
@@ -211,7 +217,7 @@ public LongGauge getLongGauge(String name) {
return (LongGauge) recorder.getInstrument(InstrumentType.LONG_GAUGE, name);
}
- protected LongGauge buildLongGauge(String name, String description, String unit, Supplier observer) {
+ protected LongGauge buildLongGauge(String name, String description, String unit, Supplier> observer) {
return new RecordingInstruments.RecordingLongGauge(name, observer, recorder);
}
From 303b2274766595c2bbbd2b339345cfa6b6a2009e Mon Sep 17 00:00:00 2001
From: David Turner
Date: Tue, 27 Aug 2024 08:05:46 +0100
Subject: [PATCH 04/46] Add link to warning re. single-node clusters (#112114)
Expands the message added in #88013 to include a link to the relevant
docs.
---
.../cluster/coordination/Coordinator.java | 7 +++++--
.../java/org/elasticsearch/common/ReferenceDocs.java | 1 +
.../elasticsearch/common/reference-docs-links.json | 3 ++-
.../cluster/coordination/CoordinatorTests.java | 11 ++++++++++-
4 files changed, 18 insertions(+), 4 deletions(-)
diff --git a/server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java b/server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java
index 437219b312045..e922d130d7f83 100644
--- a/server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java
+++ b/server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java
@@ -41,6 +41,7 @@
import org.elasticsearch.cluster.service.MasterServiceTaskQueue;
import org.elasticsearch.cluster.version.CompatibilityVersions;
import org.elasticsearch.common.Priority;
+import org.elasticsearch.common.ReferenceDocs;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@@ -831,10 +832,12 @@ public void run() {
discover other nodes and form a multi-node cluster via the [{}={}] setting. Fully-formed clusters do \
not attempt to discover other nodes, and nodes with different cluster UUIDs cannot belong to the same \
cluster. The cluster UUID persists across restarts and can only be changed by deleting the contents of \
- the node's data path(s). Remove the discovery configuration to suppress this message.""",
+ the node's data path(s). Remove the discovery configuration to suppress this message. See [{}] for \
+ more information.""",
applierState.metadata().clusterUUID(),
DISCOVERY_SEED_HOSTS_SETTING.getKey(),
- DISCOVERY_SEED_HOSTS_SETTING.get(settings)
+ DISCOVERY_SEED_HOSTS_SETTING.get(settings),
+ ReferenceDocs.FORMING_SINGLE_NODE_CLUSTERS
);
}
}
diff --git a/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java b/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java
index f710ae7c3b84a..59c55fb7b624a 100644
--- a/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java
+++ b/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java
@@ -81,6 +81,7 @@ public enum ReferenceDocs {
MAX_SHARDS_PER_NODE,
FLOOD_STAGE_WATERMARK,
X_OPAQUE_ID,
+ FORMING_SINGLE_NODE_CLUSTERS,
// this comment keeps the ';' on the next line so every entry above has a trailing ',' which makes the diff for adding new links cleaner
;
diff --git a/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json b/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json
index 8288ca792b0f1..3eb8939c22a65 100644
--- a/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json
+++ b/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json
@@ -41,5 +41,6 @@
"LUCENE_MAX_DOCS_LIMIT": "size-your-shards.html#troubleshooting-max-docs-limit",
"MAX_SHARDS_PER_NODE": "size-your-shards.html#troubleshooting-max-shards-open",
"FLOOD_STAGE_WATERMARK": "fix-watermark-errors.html",
- "X_OPAQUE_ID": "api-conventions.html#x-opaque-id"
+ "X_OPAQUE_ID": "api-conventions.html#x-opaque-id",
+ "FORMING_SINGLE_NODE_CLUSTERS": "modules-discovery-bootstrap-cluster.html#modules-discovery-bootstrap-cluster-joining"
}
diff --git a/server/src/test/java/org/elasticsearch/cluster/coordination/CoordinatorTests.java b/server/src/test/java/org/elasticsearch/cluster/coordination/CoordinatorTests.java
index b57badb3a180f..bf64b29d364e0 100644
--- a/server/src/test/java/org/elasticsearch/cluster/coordination/CoordinatorTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/coordination/CoordinatorTests.java
@@ -25,6 +25,7 @@
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterStateUpdateStats;
+import org.elasticsearch.common.ReferenceDocs;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -79,6 +80,8 @@
import static org.elasticsearch.discovery.SettingsBasedSeedHostsProvider.DISCOVERY_SEED_HOSTS_SETTING;
import static org.elasticsearch.monitor.StatusInfo.Status.HEALTHY;
import static org.elasticsearch.monitor.StatusInfo.Status.UNHEALTHY;
+import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -1762,7 +1765,13 @@ public void testLogsWarningPeriodicallyIfSingleNodeClusterHasSeedHosts() {
@Override
public void match(LogEvent event) {
final String message = event.getMessage().getFormattedMessage();
- assertThat(message, startsWith("This node is a fully-formed single-node cluster with cluster UUID"));
+ assertThat(
+ message,
+ allOf(
+ startsWith("This node is a fully-formed single-node cluster with cluster UUID"),
+ containsString(ReferenceDocs.FORMING_SINGLE_NODE_CLUSTERS.toString())
+ )
+ );
loggedClusterUuid = (String) event.getMessage().getParameters()[0];
}
From ec90d2c1239bf848914dc4411c676a1f05f2777a Mon Sep 17 00:00:00 2001
From: David Turner
Date: Tue, 27 Aug 2024 08:06:05 +0100
Subject: [PATCH 05/46] Reduce nesting in restore-snapshot path (#112107)
Also cleans up the exception-handling a little to ensure that all
failures are logged.
---
.../snapshots/RestoreService.java | 114 +++++++++---------
1 file changed, 59 insertions(+), 55 deletions(-)
diff --git a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java
index 0f03cfab4ad2e..d8987495f9035 100644
--- a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java
+++ b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java
@@ -15,6 +15,7 @@
import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotRequest;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.RefCountingRunnable;
+import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateApplier;
@@ -56,7 +57,6 @@
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.Maps;
-import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.Nullable;
@@ -92,9 +92,9 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
-import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
+import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -248,62 +248,66 @@ public void restoreSnapshot(
final BiConsumer updater
) {
assert Repository.assertSnapshotMetaThread();
- try {
- // Try and fill in any missing repository UUIDs in case they're needed during the restore
- final var repositoryUuidRefreshStep = new ListenableFuture();
- refreshRepositoryUuids(
- refreshRepositoryUuidOnRestore,
- repositoriesService,
- () -> repositoryUuidRefreshStep.onResponse(null),
- snapshotMetaExecutor
- );
- // Read snapshot info and metadata from the repository
- final String repositoryName = request.repository();
- Repository repository = repositoriesService.repository(repositoryName);
- final ListenableFuture repositoryDataListener = new ListenableFuture<>();
- repository.getRepositoryData(snapshotMetaExecutor, repositoryDataListener);
-
- repositoryDataListener.addListener(
- listener.delegateFailureAndWrap(
- (delegate, repositoryData) -> repositoryUuidRefreshStep.addListener(
- delegate.delegateFailureAndWrap((subDelegate, ignored) -> {
- assert Repository.assertSnapshotMetaThread();
- final String snapshotName = request.snapshot();
- final Optional matchingSnapshotId = repositoryData.getSnapshotIds()
- .stream()
- .filter(s -> snapshotName.equals(s.getName()))
- .findFirst();
- if (matchingSnapshotId.isPresent() == false) {
- throw new SnapshotRestoreException(repositoryName, snapshotName, "snapshot does not exist");
- }
+ // Try and fill in any missing repository UUIDs in case they're needed during the restore
+ final var repositoryUuidRefreshStep = SubscribableListener.newForked(
+ l -> refreshRepositoryUuids(refreshRepositoryUuidOnRestore, repositoriesService, () -> l.onResponse(null), snapshotMetaExecutor)
+ );
- final SnapshotId snapshotId = matchingSnapshotId.get();
- if (request.snapshotUuid() != null && request.snapshotUuid().equals(snapshotId.getUUID()) == false) {
- throw new SnapshotRestoreException(
- repositoryName,
- snapshotName,
- "snapshot UUID mismatch: expected ["
- + request.snapshotUuid()
- + "] but got ["
- + snapshotId.getUUID()
- + "]"
- );
- }
- repository.getSnapshotInfo(
- snapshotId,
- subDelegate.delegateFailureAndWrap(
- (l, snapshotInfo) -> startRestore(snapshotInfo, repository, request, repositoryData, updater, l)
- )
- );
- })
- )
+ // AtomicReference just so we have somewhere to hold these objects, there's no interesting concurrency here
+ final AtomicReference repositoryRef = new AtomicReference<>();
+ final AtomicReference repositoryDataRef = new AtomicReference<>();
+
+ SubscribableListener
+
+ .newForked(repositorySetListener -> {
+ // do this within newForked for exception handling
+ repositoryRef.set(repositoriesService.repository(request.repository()));
+ repositorySetListener.onResponse(null);
+ })
+
+ .andThen(
+ repositoryDataListener -> repositoryRef.get().getRepositoryData(snapshotMetaExecutor, repositoryDataListener)
+ )
+ .andThenAccept(repositoryDataRef::set)
+ .andThen(repositoryUuidRefreshStep::addListener)
+
+ .andThen(snapshotInfoListener -> {
+ assert Repository.assertSnapshotMetaThread();
+ final String snapshotName = request.snapshot();
+ final SnapshotId snapshotId = repositoryDataRef.get()
+ .getSnapshotIds()
+ .stream()
+ .filter(s -> snapshotName.equals(s.getName()))
+ .findFirst()
+ .orElseThrow(() -> new SnapshotRestoreException(request.repository(), snapshotName, "snapshot does not exist"));
+
+ if (request.snapshotUuid() != null && request.snapshotUuid().equals(snapshotId.getUUID()) == false) {
+ throw new SnapshotRestoreException(
+ request.repository(),
+ snapshotName,
+ "snapshot UUID mismatch: expected [" + request.snapshotUuid() + "] but got [" + snapshotId.getUUID() + "]"
+ );
+ }
+
+ repositoryRef.get().getSnapshotInfo(snapshotId, snapshotInfoListener);
+ })
+
+ .andThen(
+ (responseListener, snapshotInfo) -> startRestore(
+ snapshotInfo,
+ repositoryRef.get(),
+ request,
+ repositoryDataRef.get(),
+ updater,
+ responseListener
)
- );
- } catch (Exception e) {
- logger.warn(() -> "[" + request.repository() + ":" + request.snapshot() + "] failed to restore snapshot", e);
- listener.onFailure(e);
- }
+ )
+
+ .addListener(listener.delegateResponse((delegate, e) -> {
+ logger.warn(() -> "[" + request.repository() + ":" + request.snapshot() + "] failed to restore snapshot", e);
+ delegate.onFailure(e);
+ }));
}
/**
From bff45aaa8a2d53d3de44c66a2c692664fa3b3d46 Mon Sep 17 00:00:00 2001
From: David Turner
Date: Tue, 27 Aug 2024 08:06:20 +0100
Subject: [PATCH 06/46] Reduce `CompletableFuture` usage in tests (#111848)
Fixes some spots in tests where we use `CompletableFuture` instead of
one of the preferred alternatives.
---
.../grok/MatcherWatchdogTests.java | 9 +-
.../action/bulk/BulkOperationTests.java | 136 +++++-------------
.../ingest/ConditionalProcessorTests.java | 8 +-
.../ingest/PipelineProcessorTests.java | 10 +-
.../security/authc/ApiKeyServiceTests.java | 16 +--
5 files changed, 53 insertions(+), 126 deletions(-)
diff --git a/libs/grok/src/test/java/org/elasticsearch/grok/MatcherWatchdogTests.java b/libs/grok/src/test/java/org/elasticsearch/grok/MatcherWatchdogTests.java
index b66778743aec0..5ed1a7d13b80a 100644
--- a/libs/grok/src/test/java/org/elasticsearch/grok/MatcherWatchdogTests.java
+++ b/libs/grok/src/test/java/org/elasticsearch/grok/MatcherWatchdogTests.java
@@ -7,12 +7,12 @@
*/
package org.elasticsearch.grok;
+import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.test.ESTestCase;
import org.joni.Matcher;
import org.mockito.Mockito;
import java.util.Map;
-import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -77,16 +77,17 @@ public void testIdleIfNothingRegistered() throws Exception {
);
// Periodic action is not scheduled because no thread is registered
verifyNoMoreInteractions(threadPool);
- CompletableFuture commandFuture = new CompletableFuture<>();
+
+ PlainActionFuture commandFuture = new PlainActionFuture<>();
// Periodic action is scheduled because a thread is registered
doAnswer(invocationOnMock -> {
- commandFuture.complete((Runnable) invocationOnMock.getArguments()[0]);
+ commandFuture.onResponse(invocationOnMock.getArgument(0));
return null;
}).when(threadPool).schedule(any(Runnable.class), eq(interval), eq(TimeUnit.MILLISECONDS));
Matcher matcher = mock(Matcher.class);
watchdog.register(matcher);
// Registering the first thread should have caused the command to get scheduled again
- Runnable command = commandFuture.get(1L, TimeUnit.MILLISECONDS);
+ Runnable command = safeGet(commandFuture);
Mockito.reset(threadPool);
watchdog.unregister(matcher);
command.run();
diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java
index e950901a538b4..0c0e1de74a3e7 100644
--- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java
+++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java
@@ -20,6 +20,7 @@
import org.elasticsearch.action.delete.DeleteResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
+import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.update.UpdateResponse;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.ClusterName;
@@ -60,9 +61,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
-import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
@@ -201,9 +200,6 @@ public void tearDownThreadpool() {
public void testClusterBlockedFailsBulk() {
NodeClient client = getNodeClient(assertNoClientInteraction());
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
// Not retryable
ClusterState state = ClusterState.builder(DEFAULT_STATE)
.blocks(ClusterBlocks.builder().addGlobalBlock(Metadata.CLUSTER_READ_ONLY_BLOCK).build())
@@ -215,9 +211,10 @@ public void testClusterBlockedFailsBulk() {
when(observer.isTimedOut()).thenReturn(false);
doThrow(new AssertionError("Should not wait")).when(observer).waitForNextChange(any());
- newBulkOperation(client, new BulkRequest(), state, observer, listener).run();
-
- expectThrows(ExecutionException.class, ClusterBlockException.class, future::get);
+ assertThat(
+ safeAwaitFailure(BulkResponse.class, l -> newBulkOperation(client, new BulkRequest(), state, observer, l).run()),
+ instanceOf(ClusterBlockException.class)
+ );
}
/**
@@ -226,9 +223,6 @@ public void testClusterBlockedFailsBulk() {
public void testTimeoutOnRetryableClusterBlockedFailsBulk() {
NodeClient client = getNodeClient(assertNoClientInteraction());
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
// Retryable
final ClusterState state = ClusterState.builder(DEFAULT_STATE)
.blocks(ClusterBlocks.builder().addGlobalBlock(NoMasterBlockService.NO_MASTER_BLOCK_WRITES).build())
@@ -248,9 +242,11 @@ public void testTimeoutOnRetryableClusterBlockedFailsBulk() {
return null;
}).doThrow(new AssertionError("Should not wait")).when(observer).waitForNextChange(any());
- newBulkOperation(client, new BulkRequest(), state, observer, listener).run();
+ assertThat(
+ safeAwaitFailure(BulkResponse.class, l -> newBulkOperation(client, new BulkRequest(), state, observer, l).run()),
+ instanceOf(ClusterBlockException.class)
+ );
- expectThrows(ExecutionException.class, ClusterBlockException.class, future::get);
verify(observer, times(2)).isTimedOut();
verify(observer, times(1)).waitForNextChange(any());
}
@@ -261,9 +257,6 @@ public void testTimeoutOnRetryableClusterBlockedFailsBulk() {
public void testNodeClosedOnRetryableClusterBlockedFailsBulk() {
NodeClient client = getNodeClient(assertNoClientInteraction());
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
// Retryable
final ClusterState state = ClusterState.builder(DEFAULT_STATE)
.blocks(ClusterBlocks.builder().addGlobalBlock(NoMasterBlockService.NO_MASTER_BLOCK_WRITES).build())
@@ -278,9 +271,10 @@ public void testNodeClosedOnRetryableClusterBlockedFailsBulk() {
return null;
}).doThrow(new AssertionError("Should not wait")).when(observer).waitForNextChange(any());
- newBulkOperation(client, new BulkRequest(), state, observer, listener).run();
-
- expectThrows(ExecutionException.class, NodeClosedException.class, future::get);
+ assertThat(
+ safeAwaitFailure(BulkResponse.class, l -> newBulkOperation(client, new BulkRequest(), state, observer, l).run()),
+ instanceOf(NodeClosedException.class)
+ );
verify(observer, times(1)).isTimedOut();
verify(observer, times(1)).waitForNextChange(any());
}
@@ -296,12 +290,7 @@ public void testBulkToIndex() throws Exception {
NodeClient client = getNodeClient(acceptAllShardWrites());
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
- newBulkOperation(client, bulkRequest, listener).run();
-
- BulkResponse bulkItemResponses = future.get();
+ BulkResponse bulkItemResponses = safeAwait(l -> newBulkOperation(client, bulkRequest, l).run());
assertThat(bulkItemResponses.hasFailures(), is(false));
}
@@ -318,12 +307,7 @@ public void testBulkToIndexFailingEntireShard() throws Exception {
shardSpecificResponse(Map.of(new ShardId(indexMetadata.getIndex(), 0), failWithException(() -> new MapperException("test"))))
);
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
- newBulkOperation(client, bulkRequest, listener).run();
-
- BulkResponse bulkItemResponses = future.get();
+ BulkResponse bulkItemResponses = safeAwait(l -> newBulkOperation(client, bulkRequest, l).run());
assertThat(bulkItemResponses.hasFailures(), is(true));
BulkItemResponse failedItem = Arrays.stream(bulkItemResponses.getItems())
.filter(BulkItemResponse::isFailed)
@@ -344,12 +328,7 @@ public void testBulkToDataStream() throws Exception {
NodeClient client = getNodeClient(acceptAllShardWrites());
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
- newBulkOperation(client, bulkRequest, listener).run();
-
- BulkResponse bulkItemResponses = future.get();
+ BulkResponse bulkItemResponses = safeAwait(l -> newBulkOperation(client, bulkRequest, l).run());
assertThat(bulkItemResponses.hasFailures(), is(false));
}
@@ -366,12 +345,7 @@ public void testBulkToDataStreamFailingEntireShard() throws Exception {
shardSpecificResponse(Map.of(new ShardId(ds1BackingIndex2.getIndex(), 0), failWithException(() -> new MapperException("test"))))
);
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
- newBulkOperation(client, bulkRequest, listener).run();
-
- BulkResponse bulkItemResponses = future.get();
+ BulkResponse bulkItemResponses = safeAwait(l -> newBulkOperation(client, bulkRequest, l).run());
assertThat(bulkItemResponses.hasFailures(), is(true));
BulkItemResponse failedItem = Arrays.stream(bulkItemResponses.getItems())
.filter(BulkItemResponse::isFailed)
@@ -396,12 +370,7 @@ public void testFailingEntireShardRedirectsToFailureStore() throws Exception {
shardSpecificResponse(Map.of(new ShardId(ds2BackingIndex1.getIndex(), 0), failWithException(() -> new MapperException("test"))))
);
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
- newBulkOperation(client, bulkRequest, listener).run();
-
- BulkResponse bulkItemResponses = future.get();
+ BulkResponse bulkItemResponses = safeAwait(l -> newBulkOperation(client, bulkRequest, l).run());
assertThat(bulkItemResponses.hasFailures(), is(false));
BulkItemResponse failedItem = Arrays.stream(bulkItemResponses.getItems())
.filter(item -> item.getIndex().equals(ds2FailureStore1.getIndex().getName()))
@@ -426,12 +395,7 @@ public void testFailingDocumentRedirectsToFailureStore() throws Exception {
thatFailsDocuments(Map.of(new IndexAndId(ds2BackingIndex1.getIndex().getName(), "3"), () -> new MapperException("test")))
);
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
- newBulkOperation(client, bulkRequest, listener).run();
-
- BulkResponse bulkItemResponses = future.get();
+ BulkResponse bulkItemResponses = safeAwait(l -> newBulkOperation(client, bulkRequest, l).run());
assertThat(bulkItemResponses.hasFailures(), is(false));
BulkItemResponse failedItem = Arrays.stream(bulkItemResponses.getItems())
.filter(item -> item.getIndex().equals(ds2FailureStore1.getIndex().getName()))
@@ -465,12 +429,7 @@ public void testFailureStoreShardFailureRejectsDocument() throws Exception {
)
);
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
- newBulkOperation(client, bulkRequest, listener).run();
-
- BulkResponse bulkItemResponses = future.get();
+ BulkResponse bulkItemResponses = safeAwait(l -> newBulkOperation(client, bulkRequest, l).run());
assertThat(bulkItemResponses.hasFailures(), is(true));
BulkItemResponse failedItem = Arrays.stream(bulkItemResponses.getItems())
.filter(BulkItemResponse::isFailed)
@@ -500,16 +459,12 @@ public void testFailedDocumentCanNotBeConvertedFails() throws Exception {
thatFailsDocuments(Map.of(new IndexAndId(ds2BackingIndex1.getIndex().getName(), "3"), () -> new MapperException("root cause")))
);
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
// Mock a failure store document converter that always fails
FailureStoreDocumentConverter mockConverter = mock(FailureStoreDocumentConverter.class);
when(mockConverter.transformFailedRequest(any(), any(), any(), any())).thenThrow(new IOException("Could not serialize json"));
- newBulkOperation(client, bulkRequest, mockConverter, listener).run();
+ BulkResponse bulkItemResponses = safeAwait(l -> newBulkOperation(client, bulkRequest, mockConverter, l).run());
- BulkResponse bulkItemResponses = future.get();
assertThat(bulkItemResponses.hasFailures(), is(true));
BulkItemResponse failedItem = Arrays.stream(bulkItemResponses.getItems())
.filter(BulkItemResponse::isFailed)
@@ -579,13 +534,10 @@ public void testRetryableBlockAcceptsFailureStoreDocument() throws Exception {
return null;
}).when(observer).waitForNextChange(any());
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.notifyOnce(
- ActionListener.wrap(future::complete, future::completeExceptionally)
+ final SubscribableListener responseListener = SubscribableListener.newForked(
+ l -> newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, l).run()
);
- newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, listener).run();
-
// The operation will attempt to write the documents in the request, receive a failure, wait for a stable cluster state, and then
// redirect the failed documents to the failure store. Wait for that failure store write to start:
if (readyToPerformFailureStoreWrite.await(30, TimeUnit.SECONDS) == false) {
@@ -595,7 +547,7 @@ public void testRetryableBlockAcceptsFailureStoreDocument() throws Exception {
}
// Check to make sure there is no response yet
- if (future.isDone()) {
+ if (responseListener.isDone()) {
// we're going to fail the test, but be a good citizen and unblock the other thread first
beginFailureStoreWrite.countDown();
fail("bulk operation completed prematurely");
@@ -605,7 +557,7 @@ public void testRetryableBlockAcceptsFailureStoreDocument() throws Exception {
beginFailureStoreWrite.countDown();
// Await final result and verify
- BulkResponse bulkItemResponses = future.get();
+ BulkResponse bulkItemResponses = safeAwait(responseListener);
assertThat(bulkItemResponses.hasFailures(), is(false));
BulkItemResponse failedItem = Arrays.stream(bulkItemResponses.getItems())
.filter(item -> item.getIndex().equals(ds2FailureStore1.getIndex().getName()))
@@ -650,12 +602,7 @@ public void testBlockedClusterRejectsFailureStoreDocument() throws Exception {
when(observer.isTimedOut()).thenReturn(false);
doThrow(new AssertionError("Should not wait on non retryable block")).when(observer).waitForNextChange(any());
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
- newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, listener).run();
-
- BulkResponse bulkItemResponses = future.get();
+ BulkResponse bulkItemResponses = safeAwait(l -> newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, l).run());
assertThat(bulkItemResponses.hasFailures(), is(true));
BulkItemResponse failedItem = Arrays.stream(bulkItemResponses.getItems())
.filter(BulkItemResponse::isFailed)
@@ -715,12 +662,7 @@ public void testOperationTimeoutRejectsFailureStoreDocument() throws Exception {
return null;
}).doThrow(new AssertionError("Should not wait any longer")).when(observer).waitForNextChange(any());
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
- newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, listener).run();
-
- BulkResponse bulkItemResponses = future.get();
+ BulkResponse bulkItemResponses = safeAwait(l -> newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, l).run());
assertThat(bulkItemResponses.hasFailures(), is(true));
BulkItemResponse failedItem = Arrays.stream(bulkItemResponses.getItems())
.filter(BulkItemResponse::isFailed)
@@ -775,12 +717,10 @@ public void testNodeClosureRejectsFailureStoreDocument() {
return null;
}).doThrow(new AssertionError("Should not wait any longer")).when(observer).waitForNextChange(any());
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
- newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, listener).run();
-
- expectThrows(ExecutionException.class, NodeClosedException.class, future::get);
+ assertThat(
+ safeAwaitFailure(BulkResponse.class, l -> newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, l).run()),
+ instanceOf(NodeClosedException.class)
+ );
verify(observer, times(1)).isTimedOut();
verify(observer, times(1)).waitForNextChange(any());
@@ -832,12 +772,7 @@ public void testLazilyRollingOverFailureStore() throws Exception {
ClusterState rolledOverState = ClusterState.builder(DEFAULT_STATE).metadata(metadata).build();
ClusterStateObserver observer = mockObserver(DEFAULT_STATE, DEFAULT_STATE, rolledOverState);
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
- newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, listener).run();
-
- BulkResponse bulkItemResponses = future.get();
+ BulkResponse bulkItemResponses = safeAwait(l -> newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, l).run());
BulkItemResponse failedItem = Arrays.stream(bulkItemResponses.getItems())
.filter(item -> item.getIndex().equals(ds3FailureStore2.getIndex().getName()))
.findFirst()
@@ -880,12 +815,7 @@ public void testFailureWhileRollingOverFailureStore() throws Exception {
ClusterState rolledOverState = ClusterState.builder(DEFAULT_STATE).metadata(metadata).build();
ClusterStateObserver observer = mockObserver(DEFAULT_STATE, DEFAULT_STATE, rolledOverState);
- CompletableFuture future = new CompletableFuture<>();
- ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally);
-
- newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, listener).run();
-
- BulkResponse bulkItemResponses = future.get();
+ BulkResponse bulkItemResponses = safeAwait(l -> newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, l).run());
BulkItemResponse failedItem = Arrays.stream(bulkItemResponses.getItems())
.filter(BulkItemResponse::isFailed)
.findFirst()
diff --git a/server/src/test/java/org/elasticsearch/ingest/ConditionalProcessorTests.java b/server/src/test/java/org/elasticsearch/ingest/ConditionalProcessorTests.java
index 3a6de10b5901d..546b252615b28 100644
--- a/server/src/test/java/org/elasticsearch/ingest/ConditionalProcessorTests.java
+++ b/server/src/test/java/org/elasticsearch/ingest/ConditionalProcessorTests.java
@@ -8,6 +8,7 @@
package org.elasticsearch.ingest;
+import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.script.IngestConditionalScript;
import org.elasticsearch.script.MockScriptEngine;
@@ -25,7 +26,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
@@ -242,14 +242,14 @@ public boolean execute(Map ctx) {
private static void assertMutatingCtxThrows(Consumer
*/
private void testManyTypeConflicts(boolean withParent, ByteSizeValue expected) throws IOException {
- try (BytesStreamOutput out = new BytesStreamOutput()) {
- indexWithManyConflicts(withParent).writeTo(out);
+ try (BytesStreamOutput out = new BytesStreamOutput(); var pso = new PlanStreamOutput(out, new PlanNameRegistry(), null)) {
+ indexWithManyConflicts(withParent).writeTo(pso);
assertThat(ByteSizeValue.ofBytes(out.bytes().length()), byteSizeEquals(expected));
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java
index a5f2adbc1fc29..e5f195b053349 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java
@@ -269,7 +269,7 @@ static Nullability randomNullability() {
};
}
- static EsField randomEsField() {
+ public static EsField randomEsField() {
return randomEsField(0);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutputTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutputTests.java
index d169cdb5742af..cdb6c5384e16a 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutputTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutputTests.java
@@ -259,6 +259,42 @@ public void testWriteDifferentAttributesSameID() throws IOException {
}
}
+ public void testWriteMultipleEsFields() throws IOException {
+ Configuration configuration = randomConfiguration();
+ try (
+ BytesStreamOutput out = new BytesStreamOutput();
+ PlanStreamOutput planStream = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, configuration)
+ ) {
+ List fields = new ArrayList<>();
+ int occurrences = randomIntBetween(2, 300);
+ for (int i = 0; i < occurrences; i++) {
+ fields.add(PlanNamedTypesTests.randomEsField());
+ }
+
+ // send all the EsFields, three times
+ for (int i = 0; i < 3; i++) {
+ for (EsField attr : fields) {
+ attr.writeTo(planStream);
+ }
+ }
+
+ try (PlanStreamInput in = new PlanStreamInput(out.bytes().streamInput(), PlanNameRegistry.INSTANCE, REGISTRY, configuration)) {
+ List readFields = new ArrayList<>();
+ for (int i = 0; i < occurrences; i++) {
+ readFields.add(EsField.readFrom(in));
+ assertThat(readFields.get(i), equalTo(fields.get(i)));
+ }
+ // two more times
+ for (int i = 0; i < 2; i++) {
+ for (int j = 0; j < occurrences; j++) {
+ EsField attr = EsField.readFrom(in);
+ assertThat(attr, sameInstance(readFields.get(j)));
+ }
+ }
+ }
+ }
+ }
+
private static Attribute randomAttribute() {
return switch (randomInt(3)) {
case 0 -> PlanNamedTypesTests.randomFieldAttribute();
@@ -293,7 +329,6 @@ private Column randomColumn() {
writeables.addAll(Block.getNamedWriteables());
writeables.addAll(Attribute.getNamedWriteables());
writeables.add(UnsupportedAttribute.ENTRY);
- writeables.addAll(EsField.getNamedWriteables());
REGISTRY = new NamedWriteableRegistry(new ArrayList<>(new HashSet<>(writeables)));
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AbstractLogicalPlanSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AbstractLogicalPlanSerializationTests.java
index 8562391b2e3b0..1b9df46a1c842 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AbstractLogicalPlanSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AbstractLogicalPlanSerializationTests.java
@@ -13,7 +13,6 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.tree.Node;
-import org.elasticsearch.xpack.esql.core.type.EsField;
import org.elasticsearch.xpack.esql.expression.function.FieldAttributeTests;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.plan.AbstractNodeSerializationTests;
@@ -42,7 +41,6 @@ protected final NamedWriteableRegistry getNamedWriteableRegistry() {
entries.addAll(AggregateFunction.getNamedWriteables());
entries.addAll(Expression.getNamedWriteables());
entries.addAll(Attribute.getNamedWriteables());
- entries.addAll(EsField.getNamedWriteables());
entries.addAll(Block.getNamedWriteables());
entries.addAll(NamedExpression.getNamedWriteables());
return new NamedWriteableRegistry(entries);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/AbstractPhysicalPlanSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/AbstractPhysicalPlanSerializationTests.java
index b7b321a022b87..7a0d125ad85ba 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/AbstractPhysicalPlanSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/AbstractPhysicalPlanSerializationTests.java
@@ -15,7 +15,6 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.tree.Node;
-import org.elasticsearch.xpack.esql.core.type.EsField;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.plan.AbstractNodeSerializationTests;
@@ -46,7 +45,6 @@ protected final NamedWriteableRegistry getNamedWriteableRegistry() {
entries.addAll(AggregateFunction.getNamedWriteables());
entries.addAll(Expression.getNamedWriteables());
entries.addAll(Attribute.getNamedWriteables());
- entries.addAll(EsField.getNamedWriteables());
entries.addAll(Block.getNamedWriteables());
entries.addAll(NamedExpression.getNamedWriteables());
entries.addAll(new SearchModule(Settings.EMPTY, List.of()).getNamedWriteables());
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExecSerializationTests.java
index 237f8d6a9c580..ae58c49eade17 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExecSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExecSerializationTests.java
@@ -22,7 +22,6 @@
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
-import org.elasticsearch.xpack.esql.core.type.EsField;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.EsIndexSerializationTests;
@@ -63,7 +62,12 @@ public static Source randomSource() {
* See {@link #testManyTypeConflicts(boolean, ByteSizeValue)} for more.
*/
public void testManyTypeConflicts() throws IOException {
- testManyTypeConflicts(false, ByteSizeValue.ofBytes(2444252));
+ testManyTypeConflicts(false, ByteSizeValue.ofBytes(1897374));
+ /*
+ * History:
+ * 2.3mb - shorten error messages for UnsupportedAttributes #111973
+ * 1.8mb - cache EsFields #112008
+ */
}
/**
@@ -71,12 +75,13 @@ public void testManyTypeConflicts() throws IOException {
* See {@link #testManyTypeConflicts(boolean, ByteSizeValue)} for more.
*/
public void testManyTypeConflictsWithParent() throws IOException {
- testManyTypeConflicts(true, ByteSizeValue.ofBytes(5885765));
+ testManyTypeConflicts(true, ByteSizeValue.ofBytes(3271487));
/*
* History:
* 2 gb+ - start
* 43.3mb - Cache attribute subclasses #111447
* 5.6mb - shorten error messages for UnsupportedAttributes #111973
+ * 3.1mb - cache EsFields #112008
*/
}
@@ -131,7 +136,6 @@ private NamedWriteableRegistry getNamedWriteableRegistry() {
entries.addAll(AggregateFunction.getNamedWriteables());
entries.addAll(Expression.getNamedWriteables());
entries.addAll(Attribute.getNamedWriteables());
- entries.addAll(EsField.getNamedWriteables());
entries.addAll(Block.getNamedWriteables());
entries.addAll(NamedExpression.getNamedWriteables());
entries.addAll(new SearchModule(Settings.EMPTY, List.of()).getNamedWriteables());
diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/AbstractEsFieldTypeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/AbstractEsFieldTypeTests.java
similarity index 57%
rename from x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/AbstractEsFieldTypeTests.java
rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/AbstractEsFieldTypeTests.java
index a415c529894c3..9b2bf03b5c8aa 100644
--- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/AbstractEsFieldTypeTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/AbstractEsFieldTypeTests.java
@@ -5,16 +5,26 @@
* 2.0.
*/
-package org.elasticsearch.xpack.esql.core.type;
+package org.elasticsearch.xpack.esql.type;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.test.AbstractNamedWriteableTestCase;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.test.AbstractWireTestCase;
+import org.elasticsearch.xpack.esql.EsqlTestUtils;
+import org.elasticsearch.xpack.esql.core.type.EsField;
+import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
import java.io.IOException;
+import java.util.List;
import java.util.Map;
import java.util.TreeMap;
-public abstract class AbstractEsFieldTypeTests extends AbstractNamedWriteableTestCase {
+public abstract class AbstractEsFieldTypeTests extends AbstractWireTestCase {
public static EsField randomAnyEsField(int maxDepth) {
return switch (between(0, 5)) {
case 0 -> EsFieldTests.randomEsField(maxDepth);
@@ -32,6 +42,25 @@ public static EsField randomAnyEsField(int maxDepth) {
protected abstract T mutate(T instance);
+ @Override
+ protected EsField copyInstance(EsField instance, TransportVersion version) throws IOException {
+ NamedWriteableRegistry namedWriteableRegistry = getNamedWriteableRegistry();
+ try (
+ BytesStreamOutput output = new BytesStreamOutput();
+ var pso = new PlanStreamOutput(output, new PlanNameRegistry(), EsqlTestUtils.TEST_CFG)
+ ) {
+ pso.setTransportVersion(version);
+ instance.writeTo(pso);
+ try (
+ StreamInput in1 = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry);
+ var psi = new PlanStreamInput(in1, new PlanNameRegistry(), in1.namedWriteableRegistry(), EsqlTestUtils.TEST_CFG)
+ ) {
+ psi.setTransportVersion(version);
+ return EsField.readFrom(psi);
+ }
+ }
+ }
+
/**
* Generate sub-properties.
* @param maxDepth the maximum number of levels of properties to make
@@ -59,11 +88,6 @@ protected final T mutateInstance(EsField instance) throws IOException {
@Override
protected final NamedWriteableRegistry getNamedWriteableRegistry() {
- return new NamedWriteableRegistry(EsField.getNamedWriteables());
- }
-
- @Override
- protected final Class categoryClass() {
- return EsField.class;
+ return new NamedWriteableRegistry(List.of());
}
}
diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/DataTypeConversionTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/DataTypeConversionTests.java
similarity index 99%
rename from x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/DataTypeConversionTests.java
rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/DataTypeConversionTests.java
index 929aa1c0eab49..9f8c8f91b7037 100644
--- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/DataTypeConversionTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/DataTypeConversionTests.java
@@ -4,13 +4,16 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
-package org.elasticsearch.xpack.esql.core.type;
+package org.elasticsearch.xpack.esql.type;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Location;
import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.Converter;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.core.type.DataTypeConverter;
import org.elasticsearch.xpack.versionfield.Version;
import java.math.BigDecimal;
diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/DateEsFieldTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/DateEsFieldTests.java
similarity index 89%
rename from x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/DateEsFieldTests.java
rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/DateEsFieldTests.java
index dea03ee8a8cdf..bf0494d5fd043 100644
--- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/DateEsFieldTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/DateEsFieldTests.java
@@ -5,7 +5,10 @@
* 2.0.
*/
-package org.elasticsearch.xpack.esql.core.type;
+package org.elasticsearch.xpack.esql.type;
+
+import org.elasticsearch.xpack.esql.core.type.DateEsField;
+import org.elasticsearch.xpack.esql.core.type.EsField;
import java.util.Map;
diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/EsFieldTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/EsFieldTests.java
similarity index 91%
rename from x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/EsFieldTests.java
rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/EsFieldTests.java
index e72ae0c5c0cda..e824b4de03e26 100644
--- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/EsFieldTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/EsFieldTests.java
@@ -5,7 +5,10 @@
* 2.0.
*/
-package org.elasticsearch.xpack.esql.core.type;
+package org.elasticsearch.xpack.esql.type;
+
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.core.type.EsField;
import java.util.Map;
diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/InvalidMappedFieldTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/InvalidMappedFieldTests.java
similarity index 90%
rename from x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/InvalidMappedFieldTests.java
rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/InvalidMappedFieldTests.java
index 47a99329d0222..c66088b0695d4 100644
--- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/InvalidMappedFieldTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/InvalidMappedFieldTests.java
@@ -5,7 +5,10 @@
* 2.0.
*/
-package org.elasticsearch.xpack.esql.core.type;
+package org.elasticsearch.xpack.esql.type;
+
+import org.elasticsearch.xpack.esql.core.type.EsField;
+import org.elasticsearch.xpack.esql.core.type.InvalidMappedField;
import java.util.Map;
diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/KeywordEsFieldTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/KeywordEsFieldTests.java
similarity index 92%
rename from x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/KeywordEsFieldTests.java
rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/KeywordEsFieldTests.java
index a5d3b8329b2df..ef04f0e27c096 100644
--- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/KeywordEsFieldTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/KeywordEsFieldTests.java
@@ -5,9 +5,11 @@
* 2.0.
*/
-package org.elasticsearch.xpack.esql.core.type;
+package org.elasticsearch.xpack.esql.type;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.esql.core.type.EsField;
+import org.elasticsearch.xpack.esql.core.type.KeywordEsField;
import java.util.Map;
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/MultiTypeEsFieldTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/MultiTypeEsFieldTests.java
index 618ca812005f8..d4ca40b75d2f3 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/MultiTypeEsFieldTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/MultiTypeEsFieldTests.java
@@ -9,13 +9,14 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.test.AbstractNamedWriteableTestCase;
+import org.elasticsearch.test.AbstractWireTestCase;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.type.EsField;
+import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField;
import org.elasticsearch.xpack.esql.expression.function.scalar.UnaryScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToBoolean;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToCartesianPoint;
@@ -57,7 +58,7 @@
* These differences can be minimized once Expression is fully supported in the new serialization approach, and the esql and esql.core
* modules are merged, or at least the relevant classes are moved.
*/
-public class MultiTypeEsFieldTests extends AbstractNamedWriteableTestCase {
+public class MultiTypeEsFieldTests extends AbstractWireTestCase {
private Configuration config;
@@ -94,26 +95,19 @@ protected MultiTypeEsField mutateInstance(MultiTypeEsField instance) throws IOEx
protected final NamedWriteableRegistry getNamedWriteableRegistry() {
List entries = new ArrayList<>(UnaryScalarFunction.getNamedWriteables());
entries.addAll(Attribute.getNamedWriteables());
- entries.addAll(EsField.getNamedWriteables());
- entries.add(MultiTypeEsField.ENTRY);
entries.addAll(Expression.getNamedWriteables());
return new NamedWriteableRegistry(entries);
}
- @Override
- protected final Class categoryClass() {
- return MultiTypeEsField.class;
- }
-
@Override
protected final MultiTypeEsField copyInstance(MultiTypeEsField instance, TransportVersion version) throws IOException {
return copyInstance(
instance,
getNamedWriteableRegistry(),
- (out, v) -> new PlanStreamOutput(out, new PlanNameRegistry(), config).writeNamedWriteable(v),
+ (out, v) -> v.writeTo(new PlanStreamOutput(out, new PlanNameRegistry(), config)),
in -> {
PlanStreamInput pin = new PlanStreamInput(in, new PlanNameRegistry(), in.namedWriteableRegistry(), config);
- return (MultiTypeEsField) pin.readNamedWriteable(EsField.class);
+ return EsField.readFrom(pin);
},
version
);
diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/TextEsFieldTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/TextEsFieldTests.java
similarity index 90%
rename from x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/TextEsFieldTests.java
rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/TextEsFieldTests.java
index 817dd7cd27094..9af3b7376f2b2 100644
--- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/TextEsFieldTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/TextEsFieldTests.java
@@ -5,7 +5,10 @@
* 2.0.
*/
-package org.elasticsearch.xpack.esql.core.type;
+package org.elasticsearch.xpack.esql.type;
+
+import org.elasticsearch.xpack.esql.core.type.EsField;
+import org.elasticsearch.xpack.esql.core.type.TextEsField;
import java.util.Map;
diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/UnsupportedEsFieldTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/UnsupportedEsFieldTests.java
similarity index 91%
rename from x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/UnsupportedEsFieldTests.java
rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/UnsupportedEsFieldTests.java
index e05d8ca10425e..a89ca9481b7e1 100644
--- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/UnsupportedEsFieldTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/UnsupportedEsFieldTests.java
@@ -5,7 +5,10 @@
* 2.0.
*/
-package org.elasticsearch.xpack.esql.core.type;
+package org.elasticsearch.xpack.esql.type;
+
+import org.elasticsearch.xpack.esql.core.type.EsField;
+import org.elasticsearch.xpack.esql.core.type.UnsupportedEsField;
import java.util.Map;
From 73c5c1e1c587cc7ec7ce1f0d10fea49ecfd39002 Mon Sep 17 00:00:00 2001
From: Chris Berkhout
Date: Tue, 27 Aug 2024 11:35:53 +0200
Subject: [PATCH 11/46] ByteArrayStreamInput: Return -1 when there are no more
bytes to read (#112214)
---
docs/changelog/112214.yaml | 5 +++++
.../common/io/stream/ByteArrayStreamInput.java | 6 +++++-
.../elasticsearch/common/io/stream/AbstractStreamTests.java | 1 +
3 files changed, 11 insertions(+), 1 deletion(-)
create mode 100644 docs/changelog/112214.yaml
diff --git a/docs/changelog/112214.yaml b/docs/changelog/112214.yaml
new file mode 100644
index 0000000000000..430f95a72bb3f
--- /dev/null
+++ b/docs/changelog/112214.yaml
@@ -0,0 +1,5 @@
+pr: 112214
+summary: '`ByteArrayStreamInput:` Return -1 when there are no more bytes to read'
+area: Infra/Core
+type: bug
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/ByteArrayStreamInput.java b/server/src/main/java/org/elasticsearch/common/io/stream/ByteArrayStreamInput.java
index 838f2998d339f..a27eec4c12061 100644
--- a/server/src/main/java/org/elasticsearch/common/io/stream/ByteArrayStreamInput.java
+++ b/server/src/main/java/org/elasticsearch/common/io/stream/ByteArrayStreamInput.java
@@ -120,7 +120,11 @@ public void readBytes(byte[] b, int offset, int len) {
@Override
public int read(byte[] b, int off, int len) throws IOException {
- int toRead = Math.min(len, available());
+ final int available = limit - pos;
+ if (available <= 0) {
+ return -1;
+ }
+ int toRead = Math.min(len, available);
readBytes(b, off, toRead);
return toRead;
}
diff --git a/server/src/test/java/org/elasticsearch/common/io/stream/AbstractStreamTests.java b/server/src/test/java/org/elasticsearch/common/io/stream/AbstractStreamTests.java
index 8451d2fd64b9c..b1104a72400ea 100644
--- a/server/src/test/java/org/elasticsearch/common/io/stream/AbstractStreamTests.java
+++ b/server/src/test/java/org/elasticsearch/common/io/stream/AbstractStreamTests.java
@@ -723,6 +723,7 @@ public void testReadAfterReachingEndOfStream() throws IOException {
input.readBytes(new byte[len], 0, len);
assertEquals(-1, input.read());
+ assertEquals(-1, input.read(new byte[2], 0, 2));
}
}
From fb32adcb174a7f32338b55737c8273fd962fefdd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20Fred=C3=A9n?=
<109296772+jfreden@users.noreply.github.com>
Date: Tue, 27 Aug 2024 14:10:05 +0200
Subject: [PATCH 12/46] Add manage roles privilege (#110633)
This PR adds functionality to limit the resources and privileges an
Elasticsearch user can grant permissions to when creating a role. This
is achieved using a new
[global](https://www.elastic.co/guide/en/elasticsearch/reference/current/defining-roles.html)
(configurable/request aware) cluster privilege, named `role`, with a
sub-key called `manage/indices` which is an array where each entry is a
pair of [index
patterns](https://docs.google.com/document/d/1VN73C2KpmvvOW85-XGUqMmnMwXrfK4aoxRtG8tPqk7Y/edit#heading=h.z74zwo30t0pf)
and [index
privileges](https://www.elastic.co/guide/en/elasticsearch/reference/current/security-privileges.html#privileges-list-indices).
## Definition - Using a role with this privilege to create, update or
delete roles with privileges on indices outside of the indices matched
by the [index
pattern](https://docs.google.com/document/d/1VN73C2KpmvvOW85-XGUqMmnMwXrfK4aoxRtG8tPqk7Y/edit#heading=h.z74zwo30t0pf)
in the indices array, will fail. - Using a role with this privilege to
try to create, update or delete roles with cluster, run_as, etc.
privileges will fail. - Using a role with this privilege with
restricted indices will fail. - Other broader privileges (such as
manage_security) will nullify this privilege.
## Example Create `test-manage` role:
```
POST _security/role/test-manage
{
"global": {
"role": {
"manage": {
"indices": [
{
"names": ["allowed-index-prefix-*"],
"privileges":["read"]
}
]
}
}
}
}
```
And then a user with that role creates a role:
```
POST _security/role/a-test-role
{
"indices": [
{
"names": [
"allowed-index-prefix-some-index"
],
"privileges": [
"read"
]}]
}
```
But this would fail for:
```
POST _security/role/a-test-role
{
"indices": [
{
"names": [
"not-allowed-index-prefix-some-index"
],
"privileges": [
"read"
]}]
}
```
## Backwards compatibility and mixed cluster concerns - A new mapping
version has been added to the security index to store the new privilege.
- If the new mapping version is not applied and a role descriptor with
the new global privilege is written, the write will fail causing an
exception. - When sending role descriptors over the transport layer in a
mixed cluster, the new global privilege needs to be excluded for older
versions. This is handled with a new transport version. - If a role
descriptor is serialized for API keys on one node in a mixed cluster and
read from another, an older node might not be able to deserialize it, so
it needs to be removed before being written in mixed cluster with old
nodes. This is handled in the API key service. - If a role descriptor
containing a global privilege is in a put role request in a mixed
cluster where it's not supported on all nodes, fail request to create
role. - RCS is not applicable here since RCS only considers cluster
privileges and index privileges (not global cluster privileges). - This
doesn't include remote privileges, since the current use case with
connectors doesn't need roles to be created on a cluster separate from
the cluster where the search data resides.
## Follow up work - Create a docs PR - Error handling for actions that
use manage roles. Should configurable cluster privileges that grant
restricted usage of actions be listed in error authorization error
messages?
---
docs/changelog/110633.yaml | 5 +
.../org/elasticsearch/TransportVersions.java | 1 +
.../xpack/core/XPackClientPlugin.java | 7 +-
.../authz/permission/ClusterPermission.java | 22 ++
.../authz/permission/IndicesPermission.java | 87 ++++-
.../core/security/authz/permission/Role.java | 2 +-
.../ConfigurableClusterPrivilege.java | 3 +-
.../ConfigurableClusterPrivileges.java | 319 +++++++++++++++-
.../authz/RoleDescriptorTestHelper.java | 35 +-
.../RoleDescriptorsIntersectionTests.java | 5 +
.../ConfigurableClusterPrivilegesTests.java | 8 +-
.../privilege/ManageRolesPrivilegesTests.java | 351 ++++++++++++++++++
.../security/ManageRolesPrivilegeIT.java | 211 +++++++++++
.../xpack/security/apikey/ApiKeyRestIT.java | 67 ++++
.../xpack/security/authc/ApiKeyService.java | 125 ++++---
.../authz/store/NativeRolesStore.java | 11 +-
.../support/SecuritySystemIndices.java | 40 ++
.../audit/logfile/LoggingAuditTrailTests.java | 10 +-
.../security/audit/logfile/audited_roles.txt | 4 +-
.../RolesBackwardsCompatibilityIT.java | 186 ++++++++--
20 files changed, 1397 insertions(+), 102 deletions(-)
create mode 100644 docs/changelog/110633.yaml
create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/privilege/ManageRolesPrivilegesTests.java
create mode 100644 x-pack/plugin/security/qa/security-basic/src/javaRestTest/java/org/elasticsearch/xpack/security/ManageRolesPrivilegeIT.java
diff --git a/docs/changelog/110633.yaml b/docs/changelog/110633.yaml
new file mode 100644
index 0000000000000..d4d1dc68cdbcc
--- /dev/null
+++ b/docs/changelog/110633.yaml
@@ -0,0 +1,5 @@
+pr: 110633
+summary: Add manage roles privilege
+area: Authorization
+type: enhancement
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index 33f483c57b54e..582c618216999 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -198,6 +198,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ALLOW_PARTIAL_SEARCH_RESULTS_IN_PIT = def(8_728_00_0);
public static final TransportVersion RANK_DOCS_RETRIEVER = def(8_729_00_0);
public static final TransportVersion ESQL_ES_FIELD_CACHED_SERIALIZATION = def(8_730_00_0);
+ public static final TransportVersion ADD_MANAGE_ROLES_PRIVILEGE = def(8_731_00_0);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java
index a2c3e40c76ae4..2e806a24ad469 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java
@@ -149,7 +149,7 @@ public List getNamedWriteables() {
new NamedWriteableRegistry.Entry(ClusterState.Custom.class, TokenMetadata.TYPE, TokenMetadata::new),
new NamedWriteableRegistry.Entry(NamedDiff.class, TokenMetadata.TYPE, TokenMetadata::readDiffFrom),
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.SECURITY, SecurityFeatureSetUsage::new),
- // security : conditional privileges
+ // security : configurable cluster privileges
new NamedWriteableRegistry.Entry(
ConfigurableClusterPrivilege.class,
ConfigurableClusterPrivileges.ManageApplicationPrivileges.WRITEABLE_NAME,
@@ -160,6 +160,11 @@ public List getNamedWriteables() {
ConfigurableClusterPrivileges.WriteProfileDataPrivileges.WRITEABLE_NAME,
ConfigurableClusterPrivileges.WriteProfileDataPrivileges::createFrom
),
+ new NamedWriteableRegistry.Entry(
+ ConfigurableClusterPrivilege.class,
+ ConfigurableClusterPrivileges.ManageRolesPrivilege.WRITEABLE_NAME,
+ ConfigurableClusterPrivileges.ManageRolesPrivilege::createFrom
+ ),
// security : role-mappings
new NamedWriteableRegistry.Entry(Metadata.Custom.class, RoleMappingMetadata.TYPE, RoleMappingMetadata::new),
new NamedWriteableRegistry.Entry(NamedDiff.class, RoleMappingMetadata.TYPE, RoleMappingMetadata::readDiffFrom),
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/ClusterPermission.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/ClusterPermission.java
index c70f2a05bfe93..9c41786f39eeb 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/ClusterPermission.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/ClusterPermission.java
@@ -10,6 +10,7 @@
import org.apache.lucene.util.automaton.Operations;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.xpack.core.security.authc.Authentication;
+import org.elasticsearch.xpack.core.security.authz.RestrictedIndices;
import org.elasticsearch.xpack.core.security.authz.privilege.ClusterPrivilege;
import org.elasticsearch.xpack.core.security.support.Automatons;
@@ -17,6 +18,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Set;
+import java.util.function.Function;
import java.util.function.Predicate;
/**
@@ -84,6 +86,16 @@ public static class Builder {
private final List actionAutomatons = new ArrayList<>();
private final List permissionChecks = new ArrayList<>();
+ private final RestrictedIndices restrictedIndices;
+
+ public Builder(RestrictedIndices restrictedIndices) {
+ this.restrictedIndices = restrictedIndices;
+ }
+
+ public Builder() {
+ this.restrictedIndices = null;
+ }
+
public Builder add(
final ClusterPrivilege clusterPrivilege,
final Set allowedActionPatterns,
@@ -110,6 +122,16 @@ public Builder add(final ClusterPrivilege clusterPrivilege, final PermissionChec
return this;
}
+ public Builder addWithPredicateSupplier(
+ final ClusterPrivilege clusterPrivilege,
+ final Set allowedActionPatterns,
+ final Function> requestPredicateSupplier
+ ) {
+ final Automaton actionAutomaton = createAutomaton(allowedActionPatterns, Set.of());
+ Predicate requestPredicate = requestPredicateSupplier.apply(restrictedIndices);
+ return add(clusterPrivilege, new ActionRequestBasedPermissionCheck(clusterPrivilege, actionAutomaton, requestPredicate));
+ }
+
public ClusterPermission build() {
if (clusterPrivileges.isEmpty()) {
return NONE;
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/IndicesPermission.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/IndicesPermission.java
index d29b1dd67757a..e1b72cc43b38e 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/IndicesPermission.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/IndicesPermission.java
@@ -20,6 +20,7 @@
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.Index;
import org.elasticsearch.xpack.core.security.authz.RestrictedIndices;
import org.elasticsearch.xpack.core.security.authz.accesscontrol.IndicesAccessControl;
@@ -86,6 +87,7 @@ public Builder addGroup(
public IndicesPermission build() {
return new IndicesPermission(restrictedIndices, groups.toArray(Group.EMPTY_ARRAY));
}
+
}
private IndicesPermission(RestrictedIndices restrictedIndices, Group[] groups) {
@@ -238,6 +240,21 @@ public boolean check(String action) {
return false;
}
+ public boolean checkResourcePrivileges(
+ Set checkForIndexPatterns,
+ boolean allowRestrictedIndices,
+ Set checkForPrivileges,
+ @Nullable ResourcePrivilegesMap.Builder resourcePrivilegesMapBuilder
+ ) {
+ return checkResourcePrivileges(
+ checkForIndexPatterns,
+ allowRestrictedIndices,
+ checkForPrivileges,
+ false,
+ resourcePrivilegesMapBuilder
+ );
+ }
+
/**
* For given index patterns and index privileges determines allowed privileges and creates an instance of {@link ResourcePrivilegesMap}
* holding a map of resource to {@link ResourcePrivileges} where resource is index pattern and the map of index privilege to whether it
@@ -246,6 +263,7 @@ public boolean check(String action) {
* @param checkForIndexPatterns check permission grants for the set of index patterns
* @param allowRestrictedIndices if {@code true} then checks permission grants even for restricted indices by index matching
* @param checkForPrivileges check permission grants for the set of index privileges
+ * @param combineIndexGroups combine index groups to enable checking against regular expressions
* @param resourcePrivilegesMapBuilder out-parameter for returning the details on which privilege over which resource is granted or not.
* Can be {@code null} when no such details are needed so the method can return early, after
* encountering the first privilege that is not granted over some resource.
@@ -255,10 +273,13 @@ public boolean checkResourcePrivileges(
Set checkForIndexPatterns,
boolean allowRestrictedIndices,
Set checkForPrivileges,
+ boolean combineIndexGroups,
@Nullable ResourcePrivilegesMap.Builder resourcePrivilegesMapBuilder
) {
- final Map predicateCache = new HashMap<>();
boolean allMatch = true;
+ Map indexGroupAutomatons = indexGroupAutomatons(
+ combineIndexGroups && checkForIndexPatterns.stream().anyMatch(Automatons::isLuceneRegex)
+ );
for (String forIndexPattern : checkForIndexPatterns) {
Automaton checkIndexAutomaton = Automatons.patterns(forIndexPattern);
if (false == allowRestrictedIndices && false == isConcreteRestrictedIndex(forIndexPattern)) {
@@ -266,15 +287,14 @@ public boolean checkResourcePrivileges(
}
if (false == Operations.isEmpty(checkIndexAutomaton)) {
Automaton allowedIndexPrivilegesAutomaton = null;
- for (Group group : groups) {
- final Automaton groupIndexAutomaton = predicateCache.computeIfAbsent(group, Group::getIndexMatcherAutomaton);
- if (Operations.subsetOf(checkIndexAutomaton, groupIndexAutomaton)) {
+ for (var indexAndPrivilegeAutomaton : indexGroupAutomatons.entrySet()) {
+ if (Operations.subsetOf(checkIndexAutomaton, indexAndPrivilegeAutomaton.getValue())) {
if (allowedIndexPrivilegesAutomaton != null) {
allowedIndexPrivilegesAutomaton = Automatons.unionAndMinimize(
- Arrays.asList(allowedIndexPrivilegesAutomaton, group.privilege().getAutomaton())
+ Arrays.asList(allowedIndexPrivilegesAutomaton, indexAndPrivilegeAutomaton.getKey())
);
} else {
- allowedIndexPrivilegesAutomaton = group.privilege().getAutomaton();
+ allowedIndexPrivilegesAutomaton = indexAndPrivilegeAutomaton.getKey();
}
}
}
@@ -656,6 +676,61 @@ private static boolean containsPrivilegeThatGrantsMappingUpdatesForBwc(Group gro
return group.privilege().name().stream().anyMatch(PRIVILEGE_NAME_SET_BWC_ALLOW_MAPPING_UPDATE::contains);
}
+ /**
+ * Get all automatons for the index groups in this permission and optionally combine the index groups to enable checking if a set of
+ * index patterns specified using a regular expression grants a set of index privileges.
+ *
+ *
An index group is defined as a set of index patterns and a set of privileges (excluding field permissions and DLS queries).
+ * {@link IndicesPermission} consist of a set of index groups. For non-regular expression privilege checks, an index pattern is checked
+ * against each index group, to see if it's a sub-pattern of the index pattern for the group and then if that group grants some or all
+ * of the privileges requested. For regular expressions it's not sufficient to check per group since the index patterns covered by a
+ * group can be distinct sets and a regular expression can cover several distinct sets.
+ *
+ *
For example the two index groups: {"names": ["a"], "privileges": ["read", "create"]} and {"names": ["b"],
+ * "privileges": ["read","delete"]} will not match on ["\[ab]\"], while a single index group:
+ * {"names": ["a", "b"], "privileges": ["read"]} will. This happens because the index groups are evaluated against a request index
+ * pattern without first being combined. In the example above, the two index patterns should be combined to:
+ * {"names": ["a", "b"], "privileges": ["read"]} before being checked.
+ *
+ *
+ * @param combine combine index groups to allow for checking against regular expressions
+ *
+ * @return a map of all index and privilege pattern automatons
+ */
+ private Map indexGroupAutomatons(boolean combine) {
+ // Map of privilege automaton object references (cached by IndexPrivilege::CACHE)
+ Map allAutomatons = new HashMap<>();
+ for (Group group : groups) {
+ Automaton indexAutomaton = group.getIndexMatcherAutomaton();
+ allAutomatons.compute(
+ group.privilege().getAutomaton(),
+ (key, value) -> value == null ? indexAutomaton : Automatons.unionAndMinimize(List.of(value, indexAutomaton))
+ );
+ if (combine) {
+ List> combinedAutomatons = new ArrayList<>();
+ for (var indexAndPrivilegeAutomatons : allAutomatons.entrySet()) {
+ Automaton intersectingPrivileges = Operations.intersection(
+ indexAndPrivilegeAutomatons.getKey(),
+ group.privilege().getAutomaton()
+ );
+ if (Operations.isEmpty(intersectingPrivileges) == false) {
+ Automaton indexPatternAutomaton = Automatons.unionAndMinimize(
+ List.of(indexAndPrivilegeAutomatons.getValue(), indexAutomaton)
+ );
+ combinedAutomatons.add(new Tuple<>(intersectingPrivileges, indexPatternAutomaton));
+ }
+ }
+ combinedAutomatons.forEach(
+ automatons -> allAutomatons.compute(
+ automatons.v1(),
+ (key, value) -> value == null ? automatons.v2() : Automatons.unionAndMinimize(List.of(value, automatons.v2()))
+ )
+ );
+ }
+ }
+ return allAutomatons;
+ }
+
public static class Group {
public static final Group[] EMPTY_ARRAY = new Group[0];
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/Role.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/Role.java
index 0fc04e8cc9a52..d8d56a4fbb247 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/Role.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/Role.java
@@ -233,7 +233,7 @@ private Builder(RestrictedIndices restrictedIndices, String[] names) {
}
public Builder cluster(Set privilegeNames, Iterable configurableClusterPrivileges) {
- ClusterPermission.Builder builder = ClusterPermission.builder();
+ ClusterPermission.Builder builder = new ClusterPermission.Builder(restrictedIndices);
if (privilegeNames.isEmpty() == false) {
for (String name : privilegeNames) {
builder = ClusterPrivilegeResolver.resolve(name).buildPermission(builder);
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivilege.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivilege.java
index f9722ca42f20d..edb0cb8f9e79d 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivilege.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivilege.java
@@ -41,7 +41,8 @@ public interface ConfigurableClusterPrivilege extends NamedWriteable, ToXContent
*/
enum Category {
APPLICATION(new ParseField("application")),
- PROFILE(new ParseField("profile"));
+ PROFILE(new ParseField("profile")),
+ ROLE(new ParseField("role"));
public final ParseField field;
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivileges.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivileges.java
index fed8b7e0d7a1c..b93aa079a28d2 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivileges.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivileges.java
@@ -7,6 +7,9 @@
package org.elasticsearch.xpack.core.security.authz.privilege;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -17,10 +20,21 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.security.action.ActionTypes;
import org.elasticsearch.xpack.core.security.action.privilege.ApplicationPrivilegesRequest;
import org.elasticsearch.xpack.core.security.action.profile.UpdateProfileDataAction;
import org.elasticsearch.xpack.core.security.action.profile.UpdateProfileDataRequest;
+import org.elasticsearch.xpack.core.security.action.role.BulkDeleteRolesRequest;
+import org.elasticsearch.xpack.core.security.action.role.BulkPutRolesRequest;
+import org.elasticsearch.xpack.core.security.action.role.DeleteRoleAction;
+import org.elasticsearch.xpack.core.security.action.role.DeleteRoleRequest;
+import org.elasticsearch.xpack.core.security.action.role.PutRoleAction;
+import org.elasticsearch.xpack.core.security.action.role.PutRoleRequest;
+import org.elasticsearch.xpack.core.security.authz.RestrictedIndices;
+import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
import org.elasticsearch.xpack.core.security.authz.permission.ClusterPermission;
+import org.elasticsearch.xpack.core.security.authz.permission.FieldPermissions;
+import org.elasticsearch.xpack.core.security.authz.permission.IndicesPermission;
import org.elasticsearch.xpack.core.security.authz.privilege.ConfigurableClusterPrivilege.Category;
import org.elasticsearch.xpack.core.security.support.StringMatcher;
import org.elasticsearch.xpack.core.security.xcontent.XContentUtils;
@@ -30,12 +44,18 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Function;
import java.util.function.Predicate;
+import static org.elasticsearch.xpack.core.security.authz.privilege.IndexPrivilege.DELETE_INDEX;
+
/**
* Static utility class for working with {@link ConfigurableClusterPrivilege} instances
*/
@@ -43,6 +63,7 @@ public final class ConfigurableClusterPrivileges {
public static final ConfigurableClusterPrivilege[] EMPTY_ARRAY = new ConfigurableClusterPrivilege[0];
+ private static final Logger logger = LogManager.getLogger(ConfigurableClusterPrivileges.class);
public static final Writeable.Reader READER = in1 -> in1.readNamedWriteable(
ConfigurableClusterPrivilege.class
);
@@ -61,7 +82,16 @@ public static ConfigurableClusterPrivilege[] readArray(StreamInput in) throws IO
* Utility method to write an array of {@link ConfigurableClusterPrivilege} objects to a {@link StreamOutput}
*/
public static void writeArray(StreamOutput out, ConfigurableClusterPrivilege[] privileges) throws IOException {
- out.writeArray(WRITER, privileges);
+ if (out.getTransportVersion().onOrAfter(TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE)) {
+ out.writeArray(WRITER, privileges);
+ } else {
+ out.writeArray(
+ WRITER,
+ Arrays.stream(privileges)
+ .filter(privilege -> privilege instanceof ManageRolesPrivilege == false)
+ .toArray(ConfigurableClusterPrivilege[]::new)
+ );
+ }
}
/**
@@ -97,7 +127,7 @@ public static List parse(XContentParser parser) th
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
expectedToken(parser.currentToken(), parser, XContentParser.Token.FIELD_NAME);
- expectFieldName(parser, Category.APPLICATION.field, Category.PROFILE.field);
+ expectFieldName(parser, Category.APPLICATION.field, Category.PROFILE.field, Category.ROLE.field);
if (Category.APPLICATION.field.match(parser.currentName(), parser.getDeprecationHandler())) {
expectedToken(parser.nextToken(), parser, XContentParser.Token.START_OBJECT);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -106,8 +136,7 @@ public static List parse(XContentParser parser) th
expectFieldName(parser, ManageApplicationPrivileges.Fields.MANAGE);
privileges.add(ManageApplicationPrivileges.parse(parser));
}
- } else {
- assert Category.PROFILE.field.match(parser.currentName(), parser.getDeprecationHandler());
+ } else if (Category.PROFILE.field.match(parser.currentName(), parser.getDeprecationHandler())) {
expectedToken(parser.nextToken(), parser, XContentParser.Token.START_OBJECT);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
expectedToken(parser.currentToken(), parser, XContentParser.Token.FIELD_NAME);
@@ -115,9 +144,16 @@ public static List parse(XContentParser parser) th
expectFieldName(parser, WriteProfileDataPrivileges.Fields.WRITE);
privileges.add(WriteProfileDataPrivileges.parse(parser));
}
+ } else if (Category.ROLE.field.match(parser.currentName(), parser.getDeprecationHandler())) {
+ expectedToken(parser.nextToken(), parser, XContentParser.Token.START_OBJECT);
+ while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+ expectedToken(parser.currentToken(), parser, XContentParser.Token.FIELD_NAME);
+
+ expectFieldName(parser, ManageRolesPrivilege.Fields.MANAGE);
+ privileges.add(ManageRolesPrivilege.parse(parser));
+ }
}
}
-
return privileges;
}
@@ -362,4 +398,277 @@ private interface Fields {
ParseField APPLICATIONS = new ParseField("applications");
}
}
+
+ public static class ManageRolesPrivilege implements ConfigurableClusterPrivilege {
+ public static final String WRITEABLE_NAME = "manage-roles-privilege";
+ private final List<ManageRolesIndexPermissionGroup> indexPermissionGroups;
+ private final Function<RestrictedIndices, Predicate<TransportRequest>> requestPredicateSupplier;
+
+ private static final Set<String> EXPECTED_INDEX_GROUP_FIELDS = Set.of(
+ Fields.NAMES.getPreferredName(),
+ Fields.PRIVILEGES.getPreferredName()
+ );
+
+ public ManageRolesPrivilege(List<ManageRolesIndexPermissionGroup> manageRolesIndexPermissionGroups) {
+ this.indexPermissionGroups = manageRolesIndexPermissionGroups;
+ this.requestPredicateSupplier = (restrictedIndices) -> {
+ IndicesPermission.Builder indicesPermissionBuilder = new IndicesPermission.Builder(restrictedIndices);
+ for (ManageRolesIndexPermissionGroup indexPatternPrivilege : manageRolesIndexPermissionGroups) {
+ indicesPermissionBuilder.addGroup(
+ IndexPrivilege.get(Set.of(indexPatternPrivilege.privileges())),
+ FieldPermissions.DEFAULT,
+ null,
+ false,
+ indexPatternPrivilege.indexPatterns()
+ );
+ }
+ final IndicesPermission indicesPermission = indicesPermissionBuilder.build();
+
+ return (TransportRequest request) -> {
+ if (request instanceof final PutRoleRequest putRoleRequest) {
+ return hasNonIndexPrivileges(putRoleRequest.roleDescriptor()) == false
+ && Arrays.stream(putRoleRequest.indices())
+ .noneMatch(
+ indexPrivilege -> requestIndexPatternsAllowed(
+ indicesPermission,
+ indexPrivilege.getIndices(),
+ indexPrivilege.getPrivileges()
+ ) == false
+ );
+ } else if (request instanceof final BulkPutRolesRequest bulkPutRoleRequest) {
+ return bulkPutRoleRequest.getRoles().stream().noneMatch(ManageRolesPrivilege::hasNonIndexPrivileges)
+ && bulkPutRoleRequest.getRoles()
+ .stream()
+ .allMatch(
+ roleDescriptor -> Arrays.stream(roleDescriptor.getIndicesPrivileges())
+ .noneMatch(
+ indexPrivilege -> requestIndexPatternsAllowed(
+ indicesPermission,
+ indexPrivilege.getIndices(),
+ indexPrivilege.getPrivileges()
+ ) == false
+ )
+ );
+ } else if (request instanceof final DeleteRoleRequest deleteRoleRequest) {
+ return requestIndexPatternsAllowed(
+ indicesPermission,
+ new String[] { deleteRoleRequest.name() },
+ DELETE_INDEX.name().toArray(String[]::new)
+ );
+ } else if (request instanceof final BulkDeleteRolesRequest bulkDeleteRoleRequest) {
+ return requestIndexPatternsAllowed(
+ indicesPermission,
+ bulkDeleteRoleRequest.getRoleNames().toArray(String[]::new),
+ DELETE_INDEX.name().toArray(String[]::new)
+ );
+ }
+ throw new IllegalArgumentException("Unsupported request type [" + request.getClass() + "]");
+ };
+ };
+ }
+
+ @Override
+ public Category getCategory() {
+ return Category.ROLE;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return WRITEABLE_NAME;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeCollection(indexPermissionGroups);
+ }
+
+ public static ManageRolesPrivilege createFrom(StreamInput in) throws IOException {
+ final List<ManageRolesIndexPermissionGroup> indexPatternPrivileges = in.readCollectionAsList(
+ ManageRolesIndexPermissionGroup::createFrom
+ );
+ return new ManageRolesPrivilege(indexPatternPrivileges);
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ return builder.field(
+ Fields.MANAGE.getPreferredName(),
+ Map.of(Fields.INDICES.getPreferredName(), indexPermissionGroups.stream().map(indexPatternPrivilege -> {
+ Map<String, String[]> sortedMap = new TreeMap<>();
+ sortedMap.put(Fields.NAMES.getPreferredName(), indexPatternPrivilege.indexPatterns());
+ sortedMap.put(Fields.PRIVILEGES.getPreferredName(), indexPatternPrivilege.privileges());
+ return sortedMap;
+ }).toList())
+ );
+ }
+
+ private static void expectedIndexGroupFields(String fieldName, XContentParser parser) {
+ if (EXPECTED_INDEX_GROUP_FIELDS.contains(fieldName) == false) {
+ throw new XContentParseException(
+ parser.getTokenLocation(),
+ "failed to parse privilege. expected one of "
+ + Arrays.toString(EXPECTED_INDEX_GROUP_FIELDS.toArray(String[]::new))
+ + " but found ["
+ + fieldName
+ + "] instead"
+ );
+ }
+ }
+
+ public static ManageRolesPrivilege parse(XContentParser parser) throws IOException {
+ expectedToken(parser.currentToken(), parser, XContentParser.Token.FIELD_NAME);
+ expectFieldName(parser, Fields.MANAGE);
+ expectedToken(parser.nextToken(), parser, XContentParser.Token.START_OBJECT);
+ expectedToken(parser.nextToken(), parser, XContentParser.Token.FIELD_NAME);
+ expectFieldName(parser, Fields.INDICES);
+ expectedToken(parser.nextToken(), parser, XContentParser.Token.START_ARRAY);
+ List<ManageRolesIndexPermissionGroup> indexPrivileges = new ArrayList<>();
+ Map<String, String[]> parsedArraysByFieldName = new HashMap<>();
+
+ XContentParser.Token token;
+ while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
+ expectedToken(token, parser, XContentParser.Token.START_OBJECT);
+ expectedToken(parser.nextToken(), parser, XContentParser.Token.FIELD_NAME);
+ String currentFieldName = parser.currentName();
+ expectedIndexGroupFields(currentFieldName, parser);
+ expectedToken(parser.nextToken(), parser, XContentParser.Token.START_ARRAY);
+ parsedArraysByFieldName.put(currentFieldName, XContentUtils.readStringArray(parser, false));
+ expectedToken(parser.nextToken(), parser, XContentParser.Token.FIELD_NAME);
+ currentFieldName = parser.currentName();
+ expectedIndexGroupFields(currentFieldName, parser);
+ expectedToken(parser.nextToken(), parser, XContentParser.Token.START_ARRAY);
+ parsedArraysByFieldName.put(currentFieldName, XContentUtils.readStringArray(parser, false));
+ expectedToken(parser.nextToken(), parser, XContentParser.Token.END_OBJECT);
+ indexPrivileges.add(
+ new ManageRolesIndexPermissionGroup(
+ parsedArraysByFieldName.get(Fields.NAMES.getPreferredName()),
+ parsedArraysByFieldName.get(Fields.PRIVILEGES.getPreferredName())
+ )
+ );
+ }
+ expectedToken(parser.nextToken(), parser, XContentParser.Token.END_OBJECT);
+
+ for (var indexPrivilege : indexPrivileges) {
+ if (indexPrivilege.indexPatterns == null || indexPrivilege.indexPatterns.length == 0) {
+ throw new IllegalArgumentException("Indices privileges must refer to at least one index name or index name pattern");
+ }
+ if (indexPrivilege.privileges == null || indexPrivilege.privileges.length == 0) {
+ throw new IllegalArgumentException("Indices privileges must define at least one privilege");
+ }
+ }
+ return new ManageRolesPrivilege(indexPrivileges);
+ }
+
+ public record ManageRolesIndexPermissionGroup(String[] indexPatterns, String[] privileges) implements Writeable {
+ public static ManageRolesIndexPermissionGroup createFrom(StreamInput in) throws IOException {
+ return new ManageRolesIndexPermissionGroup(in.readStringArray(), in.readStringArray());
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeStringArray(indexPatterns);
+ out.writeStringArray(privileges);
+ }
+
+ @Override
+ public String toString() {
+ return "{"
+ + Fields.NAMES
+ + ":"
+ + Arrays.toString(indexPatterns())
+ + ":"
+ + Fields.PRIVILEGES
+ + ":"
+ + Arrays.toString(privileges())
+ + "}";
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ ManageRolesIndexPermissionGroup that = (ManageRolesIndexPermissionGroup) o;
+ return Arrays.equals(indexPatterns, that.indexPatterns) && Arrays.equals(privileges, that.privileges);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(Arrays.hashCode(indexPatterns), Arrays.hashCode(privileges));
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "{"
+ + getCategory()
+ + ":"
+ + Fields.MANAGE.getPreferredName()
+ + ":"
+ + Fields.INDICES.getPreferredName()
+ + "=["
+ + Strings.collectionToDelimitedString(indexPermissionGroups, ",")
+ + "]}";
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ final ManageRolesPrivilege that = (ManageRolesPrivilege) o;
+
+ if (this.indexPermissionGroups.size() != that.indexPermissionGroups.size()) {
+ return false;
+ }
+
+ for (int i = 0; i < this.indexPermissionGroups.size(); i++) {
+ if (Objects.equals(this.indexPermissionGroups.get(i), that.indexPermissionGroups.get(i)) == false) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(indexPermissionGroups.hashCode());
+ }
+
+ @Override
+ public ClusterPermission.Builder buildPermission(final ClusterPermission.Builder builder) {
+ return builder.addWithPredicateSupplier(
+ this,
+ Set.of(PutRoleAction.NAME, ActionTypes.BULK_PUT_ROLES.name(), ActionTypes.BULK_DELETE_ROLES.name(), DeleteRoleAction.NAME),
+ requestPredicateSupplier
+ );
+ }
+
+ private static boolean requestIndexPatternsAllowed(
+ IndicesPermission indicesPermission,
+ String[] requestIndexPatterns,
+ String[] privileges
+ ) {
+ return indicesPermission.checkResourcePrivileges(Set.of(requestIndexPatterns), false, Set.of(privileges), true, null);
+ }
+
+ private static boolean hasNonIndexPrivileges(RoleDescriptor roleDescriptor) {
+ return roleDescriptor.hasApplicationPrivileges()
+ || roleDescriptor.hasClusterPrivileges()
+ || roleDescriptor.hasConfigurableClusterPrivileges()
+ || roleDescriptor.hasRemoteIndicesPrivileges()
+ || roleDescriptor.hasRemoteClusterPermissions()
+ || roleDescriptor.hasRunAs()
+ || roleDescriptor.hasWorkflowsRestriction();
+ }
+
+ private interface Fields {
+ ParseField MANAGE = new ParseField("manage");
+ ParseField INDICES = new ParseField("indices");
+ ParseField PRIVILEGES = new ParseField("privileges");
+ ParseField NAMES = new ParseField("names");
+ }
+ }
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptorTestHelper.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptorTestHelper.java
index 2d8b62335f4ef..77a37cec45b25 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptorTestHelper.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptorTestHelper.java
@@ -26,6 +26,7 @@
import static org.elasticsearch.test.ESTestCase.generateRandomStringArray;
import static org.elasticsearch.test.ESTestCase.randomAlphaOfLengthBetween;
+import static org.elasticsearch.test.ESTestCase.randomArray;
import static org.elasticsearch.test.ESTestCase.randomBoolean;
import static org.elasticsearch.test.ESTestCase.randomInt;
import static org.elasticsearch.test.ESTestCase.randomIntBetween;
@@ -52,6 +53,7 @@ public static RoleDescriptor randomRoleDescriptor() {
.allowRestriction(randomBoolean())
.allowDescription(randomBoolean())
.allowRemoteClusters(randomBoolean())
+ .allowConfigurableClusterPrivileges(randomBoolean())
.build();
}
@@ -69,7 +71,7 @@ public static Map randomRoleDescriptorMetadata(boolean allowRese
}
public static ConfigurableClusterPrivilege[] randomClusterPrivileges() {
- final ConfigurableClusterPrivilege[] configurableClusterPrivileges = switch (randomIntBetween(0, 4)) {
+ return switch (randomIntBetween(0, 5)) {
case 0 -> new ConfigurableClusterPrivilege[0];
case 1 -> new ConfigurableClusterPrivilege[] {
new ConfigurableClusterPrivileges.ManageApplicationPrivileges(
@@ -93,9 +95,9 @@ public static ConfigurableClusterPrivilege[] randomClusterPrivileges() {
new ConfigurableClusterPrivileges.WriteProfileDataPrivileges(
Sets.newHashSet(generateRandomStringArray(3, randomIntBetween(4, 12), false, false))
) };
+ case 5 -> randomManageRolesPrivileges();
default -> throw new IllegalStateException("Unexpected value");
};
- return configurableClusterPrivileges;
}
public static RoleDescriptor.ApplicationResourcePrivileges[] randomApplicationPrivileges() {
@@ -119,6 +121,27 @@ public static RoleDescriptor.ApplicationResourcePrivileges[] randomApplicationPr
return applicationPrivileges;
}
+ public static ConfigurableClusterPrivilege[] randomManageRolesPrivileges() {
+ List<ConfigurableClusterPrivileges.ManageRolesPrivilege.ManageRolesIndexPermissionGroup> indexPatternPrivileges = randomList(
+ 1,
+ 10,
+ () -> {
+ String[] indexPatterns = randomArray(1, 5, String[]::new, () -> randomAlphaOfLengthBetween(5, 100));
+
+ int startIndex = randomIntBetween(0, IndexPrivilege.names().size() - 2);
+ int endIndex = randomIntBetween(startIndex + 1, IndexPrivilege.names().size());
+
+ String[] indexPrivileges = IndexPrivilege.names().stream().toList().subList(startIndex, endIndex).toArray(String[]::new);
+ return new ConfigurableClusterPrivileges.ManageRolesPrivilege.ManageRolesIndexPermissionGroup(
+ indexPatterns,
+ indexPrivileges
+ );
+ }
+ );
+
+ return new ConfigurableClusterPrivilege[] { new ConfigurableClusterPrivileges.ManageRolesPrivilege(indexPatternPrivileges) };
+ }
+
public static RoleDescriptor.RemoteIndicesPrivileges[] randomRemoteIndicesPrivileges(int min, int max) {
return randomRemoteIndicesPrivileges(min, max, Set.of());
}
@@ -251,6 +274,7 @@ public static class Builder {
private boolean allowRestriction = false;
private boolean allowDescription = false;
private boolean allowRemoteClusters = false;
+ private boolean allowConfigurableClusterPrivileges = false;
public Builder() {}
@@ -259,6 +283,11 @@ public Builder allowReservedMetadata(boolean allowReservedMetadata) {
return this;
}
+ public Builder allowConfigurableClusterPrivileges(boolean allowConfigurableClusterPrivileges) {
+ this.allowConfigurableClusterPrivileges = allowConfigurableClusterPrivileges;
+ return this;
+ }
+
public Builder alwaysIncludeRemoteIndices() {
this.alwaysIncludeRemoteIndices = true;
return this;
@@ -302,7 +331,7 @@ public RoleDescriptor build() {
randomSubsetOf(ClusterPrivilegeResolver.names()).toArray(String[]::new),
randomIndicesPrivileges(0, 3),
randomApplicationPrivileges(),
- randomClusterPrivileges(),
+ allowConfigurableClusterPrivileges ? randomClusterPrivileges() : null,
generateRandomStringArray(5, randomIntBetween(2, 8), false, true),
randomRoleDescriptorMetadata(allowReservedMetadata),
Map.of(),
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptorsIntersectionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptorsIntersectionTests.java
index a892e8b864e6e..b67292e76961f 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptorsIntersectionTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptorsIntersectionTests.java
@@ -48,6 +48,11 @@ public void testSerialization() throws IOException {
ConfigurableClusterPrivilege.class,
ConfigurableClusterPrivileges.WriteProfileDataPrivileges.WRITEABLE_NAME,
ConfigurableClusterPrivileges.WriteProfileDataPrivileges::createFrom
+ ),
+ new NamedWriteableRegistry.Entry(
+ ConfigurableClusterPrivilege.class,
+ ConfigurableClusterPrivileges.ManageRolesPrivilege.WRITEABLE_NAME,
+ ConfigurableClusterPrivileges.ManageRolesPrivilege::createFrom
)
)
);
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivilegesTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivilegesTests.java
index c6fac77ea26e6..5599b33fbcfe7 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivilegesTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivilegesTests.java
@@ -61,13 +61,15 @@ public void testGenerateAndParseXContent() throws Exception {
}
private ConfigurableClusterPrivilege[] buildSecurityPrivileges() {
- return switch (randomIntBetween(0, 3)) {
+ return switch (randomIntBetween(0, 4)) {
case 0 -> new ConfigurableClusterPrivilege[0];
case 1 -> new ConfigurableClusterPrivilege[] { ManageApplicationPrivilegesTests.buildPrivileges() };
case 2 -> new ConfigurableClusterPrivilege[] { WriteProfileDataPrivilegesTests.buildPrivileges() };
- case 3 -> new ConfigurableClusterPrivilege[] {
+ case 3 -> new ConfigurableClusterPrivilege[] { ManageRolesPrivilegesTests.buildPrivileges() };
+ case 4 -> new ConfigurableClusterPrivilege[] {
ManageApplicationPrivilegesTests.buildPrivileges(),
- WriteProfileDataPrivilegesTests.buildPrivileges() };
+ WriteProfileDataPrivilegesTests.buildPrivileges(),
+ ManageRolesPrivilegesTests.buildPrivileges() };
default -> throw new IllegalStateException("Unexpected value");
};
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/privilege/ManageRolesPrivilegesTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/privilege/ManageRolesPrivilegesTests.java
new file mode 100644
index 0000000000000..2d47752063d9d
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/privilege/ManageRolesPrivilegesTests.java
@@ -0,0 +1,351 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.security.authz.privilege;
+
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.test.AbstractNamedWriteableTestCase;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.XPackClientPlugin;
+import org.elasticsearch.xpack.core.security.action.role.BulkDeleteRolesRequest;
+import org.elasticsearch.xpack.core.security.action.role.BulkPutRolesRequest;
+import org.elasticsearch.xpack.core.security.action.role.DeleteRoleRequest;
+import org.elasticsearch.xpack.core.security.action.role.PutRoleRequest;
+import org.elasticsearch.xpack.core.security.authc.Authentication;
+import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper;
+import org.elasticsearch.xpack.core.security.authz.RestrictedIndices;
+import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
+import org.elasticsearch.xpack.core.security.authz.permission.ClusterPermission;
+import org.elasticsearch.xpack.core.security.authz.permission.RemoteClusterPermissionGroup;
+import org.elasticsearch.xpack.core.security.authz.permission.RemoteClusterPermissions;
+import org.elasticsearch.xpack.core.security.authz.privilege.ConfigurableClusterPrivileges.ManageRolesPrivilege;
+import org.elasticsearch.xpack.core.security.authz.store.ReservedRolesStore;
+import org.elasticsearch.xpack.core.security.test.TestRestrictedIndices;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+import static org.hamcrest.Matchers.nullValue;
+import static org.hamcrest.core.Is.is;
+import static org.hamcrest.core.IsEqual.equalTo;
+
+public class ManageRolesPrivilegesTests extends AbstractNamedWriteableTestCase<ConfigurableClusterPrivilege> {
+
+ private static final int MIN_INDEX_NAME_LENGTH = 4;
+
+ public void testSimplePutRoleRequest() {
+ new ReservedRolesStore();
+ final ManageRolesPrivilege privilege = new ManageRolesPrivilege(
+ List.of(new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "allowed*" }, new String[] { "all" }))
+ );
+ final ClusterPermission permission = privilege.buildPermission(
+ new ClusterPermission.Builder(new RestrictedIndices(TestRestrictedIndices.RESTRICTED_INDICES.getAutomaton()))
+ ).build();
+
+ assertAllowedIndexPatterns(permission, randomArray(1, 10, String[]::new, () -> "allowed-" + randomAlphaOfLength(5)), true);
+ assertAllowedIndexPatterns(permission, randomArray(1, 10, String[]::new, () -> "not-allowed-" + randomAlphaOfLength(5)), false);
+ assertAllowedIndexPatterns(
+ permission,
+ new String[] { "allowed-" + randomAlphaOfLength(5), "not-allowed-" + randomAlphaOfLength(5) },
+ false
+ );
+ }
+
+ public void testDeleteRoleRequest() {
+ new ReservedRolesStore();
+ {
+ final ManageRolesPrivilege privilege = new ManageRolesPrivilege(
+ List.of(new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "allowed*" }, new String[] { "manage" }))
+ );
+ final ClusterPermission permission = privilege.buildPermission(
+ new ClusterPermission.Builder(new RestrictedIndices(TestRestrictedIndices.RESTRICTED_INDICES.getAutomaton()))
+ ).build();
+
+ assertAllowedDeleteIndex(permission, randomArray(1, 10, String[]::new, () -> "allowed-" + randomAlphaOfLength(5)), true);
+ assertAllowedDeleteIndex(permission, randomArray(1, 10, String[]::new, () -> "not-allowed-" + randomAlphaOfLength(5)), false);
+ assertAllowedDeleteIndex(
+ permission,
+ new String[] { "allowed-" + randomAlphaOfLength(5), "not-allowed-" + randomAlphaOfLength(5) },
+ false
+ );
+ }
+ {
+ final ManageRolesPrivilege privilege = new ManageRolesPrivilege(
+ List.of(new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "allowed*" }, new String[] { "read" }))
+ );
+ final ClusterPermission permission = privilege.buildPermission(
+ new ClusterPermission.Builder(new RestrictedIndices(TestRestrictedIndices.RESTRICTED_INDICES.getAutomaton()))
+ ).build();
+ assertAllowedDeleteIndex(permission, randomArray(1, 10, String[]::new, () -> "allowed-" + randomAlphaOfLength(5)), false);
+ }
+ }
+
+ public void testSeveralIndexGroupsPutRoleRequest() {
+ new ReservedRolesStore();
+
+ final ManageRolesPrivilege privilege = new ManageRolesPrivilege(
+ List.of(
+ new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "a", "b" }, new String[] { "read" }),
+ new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "c" }, new String[] { "read" }),
+ new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "d" }, new String[] { "read" })
+ )
+ );
+
+ final ClusterPermission permission = privilege.buildPermission(
+ new ClusterPermission.Builder(new RestrictedIndices(TestRestrictedIndices.RESTRICTED_INDICES.getAutomaton()))
+ ).build();
+
+ assertAllowedIndexPatterns(permission, new String[] { "/[ab]/" }, new String[] { "read" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { "/[cd]/" }, new String[] { "read" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { "/[acd]/" }, new String[] { "read" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { "/[ef]/" }, new String[] { "read" }, false);
+ }
+
+ public void testPrivilegeIntersectionPutRoleRequest() {
+ new ReservedRolesStore();
+
+ final ManageRolesPrivilege privilege = new ManageRolesPrivilege(
+ List.of(
+ new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "a", "b" }, new String[] { "all" }),
+ new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "c" }, new String[] { "create" }),
+ new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "d" }, new String[] { "delete" }),
+ new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "e" }, new String[] { "create_doc" }),
+ new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "f" }, new String[] { "read", "manage" })
+ )
+ );
+
+ final ClusterPermission permission = privilege.buildPermission(
+ new ClusterPermission.Builder(new RestrictedIndices(TestRestrictedIndices.RESTRICTED_INDICES.getAutomaton()))
+ ).build();
+
+ assertAllowedIndexPatterns(permission, new String[] { "/[ab]/" }, new String[] { "all" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { "/[abc]/" }, new String[] { "all" }, false);
+ assertAllowedIndexPatterns(permission, new String[] { "/[ab]/" }, new String[] { "read", "manage" }, true);
+
+ assertAllowedIndexPatterns(permission, new String[] { "/[ac]/" }, new String[] { "create" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { "/[ac]/" }, new String[] { "create", "create_doc" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { "/[ce]/" }, new String[] { "create_doc" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { "/[abce]/" }, new String[] { "create_doc" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { "/[abcde]/" }, new String[] { "create_doc" }, false);
+ assertAllowedIndexPatterns(permission, new String[] { "/[ce]/" }, new String[] { "create_doc" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { "/[eb]/" }, new String[] { "create_doc" }, true);
+
+ assertAllowedIndexPatterns(permission, new String[] { "/[d]/" }, new String[] { "delete" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { "/[ad]/" }, new String[] { "delete" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { "/[de]/" }, new String[] { "delete" }, false);
+
+ assertAllowedIndexPatterns(permission, new String[] { "/[f]/" }, new String[] { "read", "manage" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { "/[f]/" }, new String[] { "read", "write" }, false);
+ assertAllowedIndexPatterns(permission, new String[] { "/[f]/" }, new String[] { "read", "manage" }, true);
+ }
+
+ public void testEmptyPrivileges() {
+ new ReservedRolesStore();
+
+ final ManageRolesPrivilege privilege = new ManageRolesPrivilege(List.of());
+
+ final ClusterPermission permission = privilege.buildPermission(
+ new ClusterPermission.Builder(new RestrictedIndices(TestRestrictedIndices.RESTRICTED_INDICES.getAutomaton()))
+ ).build();
+
+ assertAllowedIndexPatterns(permission, new String[] { "test" }, new String[] { "all" }, false);
+ }
+
+ public void testRestrictedIndexPutRoleRequest() {
+ new ReservedRolesStore();
+
+ final ManageRolesPrivilege privilege = new ManageRolesPrivilege(
+ List.of(new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "*" }, new String[] { "all" }))
+ );
+ final ClusterPermission permission = privilege.buildPermission(
+ new ClusterPermission.Builder(new RestrictedIndices(TestRestrictedIndices.RESTRICTED_INDICES.getAutomaton()))
+ ).build();
+
+ assertAllowedIndexPatterns(permission, new String[] { "security" }, true);
+ assertAllowedIndexPatterns(permission, new String[] { ".security" }, false);
+ assertAllowedIndexPatterns(permission, new String[] { "security", ".security-7" }, false);
+ }
+
+ public void testGenerateAndParseXContent() throws Exception {
+ final XContent xContent = randomFrom(XContentType.values()).xContent();
+ try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
+ final XContentBuilder builder = new XContentBuilder(xContent, out);
+
+ final ManageRolesPrivilege original = buildPrivileges();
+ builder.startObject();
+ original.toXContent(builder, ToXContent.EMPTY_PARAMS);
+ builder.endObject();
+ builder.flush();
+
+ final byte[] bytes = out.toByteArray();
+ try (XContentParser parser = xContent.createParser(XContentParserConfiguration.EMPTY, bytes)) {
+ assertThat(parser.nextToken(), equalTo(XContentParser.Token.START_OBJECT));
+ assertThat(parser.nextToken(), equalTo(XContentParser.Token.FIELD_NAME));
+ final ManageRolesPrivilege clone = ManageRolesPrivilege.parse(parser);
+ assertThat(parser.nextToken(), equalTo(XContentParser.Token.END_OBJECT));
+
+ assertThat(clone, equalTo(original));
+ assertThat(original, equalTo(clone));
+ }
+ }
+ }
+
+ public void testPutRoleRequestContainsNonIndexPrivileges() {
+ new ReservedRolesStore();
+ final ManageRolesPrivilege privilege = new ManageRolesPrivilege(
+ List.of(new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "allowed*" }, new String[] { "all" }))
+ );
+ final ClusterPermission permission = privilege.buildPermission(
+ new ClusterPermission.Builder(new RestrictedIndices(TestRestrictedIndices.RESTRICTED_INDICES.getAutomaton()))
+ ).build();
+
+ final PutRoleRequest putRoleRequest = new PutRoleRequest();
+
+ switch (randomIntBetween(0, 5)) {
+ case 0:
+ putRoleRequest.cluster("all");
+ break;
+ case 1:
+ putRoleRequest.runAs("test");
+ break;
+ case 2:
+ putRoleRequest.addApplicationPrivileges(
+ RoleDescriptor.ApplicationResourcePrivileges.builder()
+ .privileges("all")
+ .application("test-app")
+ .resources("test-resource")
+ .build()
+ );
+ break;
+ case 3:
+ putRoleRequest.addRemoteIndex(
+ new RoleDescriptor.RemoteIndicesPrivileges.Builder("test-cluster").privileges("all").indices("test*").build()
+ );
+ break;
+ case 4:
+ putRoleRequest.putRemoteCluster(
+ new RemoteClusterPermissions().addGroup(
+ new RemoteClusterPermissionGroup(new String[] { "monitor_enrich" }, new String[] { "test" })
+ )
+ );
+ break;
+ case 5:
+ putRoleRequest.conditionalCluster(
+ new ConfigurableClusterPrivileges.ManageRolesPrivilege(
+ List.of(
+ new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(new String[] { "test-*" }, new String[] { "read" })
+ )
+ )
+ );
+ break;
+ }
+
+ putRoleRequest.name(randomAlphaOfLength(4));
+ assertThat(permissionCheck(permission, "cluster:admin/xpack/security/role/put", putRoleRequest), is(false));
+ }
+
+ private static boolean permissionCheck(ClusterPermission permission, String action, ActionRequest request) {
+ final Authentication authentication = AuthenticationTestHelper.builder().build();
+ assertThat(request.validate(), nullValue());
+ return permission.check(action, request, authentication);
+ }
+
+ private static void assertAllowedIndexPatterns(ClusterPermission permission, String[] indexPatterns, boolean expected) {
+ assertAllowedIndexPatterns(permission, indexPatterns, new String[] { "index", "write", "indices:data/read" }, expected);
+ }
+
+ private static void assertAllowedIndexPatterns(
+ ClusterPermission permission,
+ String[] indexPatterns,
+ String[] privileges,
+ boolean expected
+ ) {
+ {
+ final PutRoleRequest putRoleRequest = new PutRoleRequest();
+ putRoleRequest.name(randomAlphaOfLength(3));
+ putRoleRequest.addIndex(indexPatterns, privileges, null, null, null, false);
+ assertThat(permissionCheck(permission, "cluster:admin/xpack/security/role/put", putRoleRequest), is(expected));
+ }
+ {
+ final BulkPutRolesRequest bulkPutRolesRequest = new BulkPutRolesRequest(
+ List.of(
+ new RoleDescriptor(
+ randomAlphaOfLength(3),
+ new String[] {},
+ new RoleDescriptor.IndicesPrivileges[] {
+ RoleDescriptor.IndicesPrivileges.builder().indices(indexPatterns).privileges(privileges).build() },
+ new String[] {}
+ )
+ )
+ );
+ assertThat(permissionCheck(permission, "cluster:admin/xpack/security/role/bulk_put", bulkPutRolesRequest), is(expected));
+ }
+ }
+
+ private static void assertAllowedDeleteIndex(ClusterPermission permission, String[] indices, boolean expected) {
+ {
+ final BulkDeleteRolesRequest bulkDeleteRolesRequest = new BulkDeleteRolesRequest(List.of(indices));
+ assertThat(permissionCheck(permission, "cluster:admin/xpack/security/role/bulk_delete", bulkDeleteRolesRequest), is(expected));
+ }
+ {
+ assertThat(Arrays.stream(indices).allMatch(pattern -> {
+ final DeleteRoleRequest deleteRolesRequest = new DeleteRoleRequest();
+ deleteRolesRequest.name(pattern);
+ return permissionCheck(permission, "cluster:admin/xpack/security/role/delete", deleteRolesRequest);
+ }), is(expected));
+ }
+ }
+
+ public static ManageRolesPrivilege buildPrivileges() {
+ return buildPrivileges(randomIntBetween(MIN_INDEX_NAME_LENGTH, 7));
+ }
+
+ private static ManageRolesPrivilege buildPrivileges(int indexNameLength) {
+ String[] indexNames = Objects.requireNonNull(generateRandomStringArray(5, indexNameLength, false, false));
+
+ return new ManageRolesPrivilege(
+ List.of(new ManageRolesPrivilege.ManageRolesIndexPermissionGroup(indexNames, IndexPrivilege.READ.name().toArray(String[]::new)))
+ );
+ }
+
+ @Override
+ protected NamedWriteableRegistry getNamedWriteableRegistry() {
+ try (var xClientPlugin = new XPackClientPlugin()) {
+ return new NamedWriteableRegistry(xClientPlugin.getNamedWriteables());
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ protected Class<ConfigurableClusterPrivilege> categoryClass() {
+ return ConfigurableClusterPrivilege.class;
+ }
+
+ @Override
+ protected ConfigurableClusterPrivilege createTestInstance() {
+ return buildPrivileges();
+ }
+
+ @Override
+ protected ConfigurableClusterPrivilege mutateInstance(ConfigurableClusterPrivilege instance) throws IOException {
+ if (instance instanceof ManageRolesPrivilege) {
+ return buildPrivileges(MIN_INDEX_NAME_LENGTH - 1);
+ }
+ fail();
+ return null;
+ }
+}
diff --git a/x-pack/plugin/security/qa/security-basic/src/javaRestTest/java/org/elasticsearch/xpack/security/ManageRolesPrivilegeIT.java b/x-pack/plugin/security/qa/security-basic/src/javaRestTest/java/org/elasticsearch/xpack/security/ManageRolesPrivilegeIT.java
new file mode 100644
index 0000000000000..728f068adcae4
--- /dev/null
+++ b/x-pack/plugin/security/qa/security-basic/src/javaRestTest/java/org/elasticsearch/xpack/security/ManageRolesPrivilegeIT.java
@@ -0,0 +1,211 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.security;
+
+import org.elasticsearch.client.RequestOptions;
+import org.elasticsearch.client.ResponseException;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.test.TestSecurityClient;
+import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
+import org.elasticsearch.xpack.core.security.authz.privilege.ConfigurableClusterPrivilege;
+import org.elasticsearch.xpack.core.security.authz.privilege.ConfigurableClusterPrivileges;
+import org.elasticsearch.xpack.core.security.user.User;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.hamcrest.core.StringContains.containsString;
+
+public class ManageRolesPrivilegeIT extends SecurityInBasicRestTestCase {
+
+ private TestSecurityClient adminSecurityClient;
+ private static final SecureString TEST_PASSWORD = new SecureString("100%-secure-password".toCharArray());
+
+ @Before
+ public void setupClient() {
+ adminSecurityClient = new TestSecurityClient(adminClient());
+ }
+
+ public void testManageRoles() throws Exception {
+ createManageRolesRole("manage-roles-role", new String[0], Set.of("*-allowed-suffix"), Set.of("read", "write"));
+ createUser("test-user", Set.of("manage-roles-role"));
+
+ String authHeader = basicAuthHeaderValue("test-user", TEST_PASSWORD);
+
+ createRole(
+ authHeader,
+ new RoleDescriptor(
+ "manage-roles-role",
+ new String[0],
+ new RoleDescriptor.IndicesPrivileges[] {
+ RoleDescriptor.IndicesPrivileges.builder().indices("test-allowed-suffix").privileges(Set.of("read", "write")).build() },
+ new RoleDescriptor.ApplicationResourcePrivileges[0],
+ new ConfigurableClusterPrivilege[0],
+ new String[0],
+ Map.of(),
+ Map.of()
+ )
+ );
+
+ {
+ ResponseException responseException = assertThrows(
+ ResponseException.class,
+ () -> createRole(
+ authHeader,
+ new RoleDescriptor(
+ "manage-roles-role",
+ new String[0],
+ new RoleDescriptor.IndicesPrivileges[] {
+ RoleDescriptor.IndicesPrivileges.builder().indices("test-suffix-not-allowed").privileges("write").build() },
+ new RoleDescriptor.ApplicationResourcePrivileges[0],
+ new ConfigurableClusterPrivilege[0],
+ new String[0],
+ Map.of(),
+ Map.of()
+ )
+ )
+ );
+
+ assertThat(
+ responseException.getMessage(),
+ containsString("this action is granted by the cluster privileges [manage_security,all]")
+ );
+ }
+
+ {
+ ResponseException responseException = assertThrows(
+ ResponseException.class,
+ () -> createRole(
+ authHeader,
+ new RoleDescriptor(
+ "manage-roles-role",
+ new String[0],
+ new RoleDescriptor.IndicesPrivileges[] {
+ RoleDescriptor.IndicesPrivileges.builder().indices("test-allowed-suffix").privileges("manage").build() },
+ new RoleDescriptor.ApplicationResourcePrivileges[0],
+ new ConfigurableClusterPrivilege[0],
+ new String[0],
+ Map.of(),
+ Map.of()
+ )
+ )
+ );
+ assertThat(
+ responseException.getMessage(),
+ containsString("this action is granted by the cluster privileges [manage_security,all]")
+ );
+ }
+ }
+
+ public void testManageSecurityNullifiesManageRoles() throws Exception {
+ createManageRolesRole("manage-roles-no-manage-security", new String[0], Set.of("allowed"));
+ createManageRolesRole("manage-roles-manage-security", new String[] { "manage_security" }, Set.of("allowed"));
+
+ createUser("test-user-no-manage-security", Set.of("manage-roles-no-manage-security"));
+ createUser("test-user-manage-security", Set.of("manage-roles-manage-security"));
+
+ String authHeaderNoManageSecurity = basicAuthHeaderValue("test-user-no-manage-security", TEST_PASSWORD);
+ String authHeaderManageSecurity = basicAuthHeaderValue("test-user-manage-security", TEST_PASSWORD);
+
+ createRole(
+ authHeaderNoManageSecurity,
+ new RoleDescriptor(
+ "test-role-allowed-by-manage-roles",
+ new String[0],
+ new RoleDescriptor.IndicesPrivileges[] {
+ RoleDescriptor.IndicesPrivileges.builder().indices("allowed").privileges("read").build() },
+ new RoleDescriptor.ApplicationResourcePrivileges[0],
+ new ConfigurableClusterPrivilege[0],
+ new String[0],
+ Map.of(),
+ Map.of()
+ )
+ );
+
+ ResponseException responseException = assertThrows(
+ ResponseException.class,
+ () -> createRole(
+ authHeaderNoManageSecurity,
+ new RoleDescriptor(
+ "test-role-not-allowed-by-manage-roles",
+ new String[0],
+ new RoleDescriptor.IndicesPrivileges[] {
+ RoleDescriptor.IndicesPrivileges.builder().indices("not-allowed").privileges("read").build() },
+ new RoleDescriptor.ApplicationResourcePrivileges[0],
+ new ConfigurableClusterPrivilege[0],
+ new String[0],
+ Map.of(),
+ Map.of()
+ )
+ )
+ );
+
+ assertThat(
+ responseException.getMessage(),
+ // TODO Should the new global role/manage privilege be listed here? Probably not because it's not documented
+ containsString("this action is granted by the cluster privileges [manage_security,all]")
+ );
+
+ createRole(
+ authHeaderManageSecurity,
+ new RoleDescriptor(
+ "test-role-not-allowed-by-manage-roles",
+ new String[0],
+ new RoleDescriptor.IndicesPrivileges[] {
+ RoleDescriptor.IndicesPrivileges.builder().indices("not-allowed").privileges("read").build() },
+ new RoleDescriptor.ApplicationResourcePrivileges[0],
+ new ConfigurableClusterPrivilege[0],
+ new String[0],
+ Map.of(),
+ Map.of()
+ )
+ );
+ }
+
+ private void createRole(String authHeader, RoleDescriptor descriptor) throws IOException {
+ TestSecurityClient userAuthSecurityClient = new TestSecurityClient(
+ adminClient(),
+ RequestOptions.DEFAULT.toBuilder().addHeader("Authorization", authHeader).build()
+ );
+ userAuthSecurityClient.putRole(descriptor);
+ }
+
+ private void createUser(String username, Set roles) throws IOException {
+ adminSecurityClient.putUser(new User(username, roles.toArray(String[]::new)), TEST_PASSWORD);
+ }
+
+ private void createManageRolesRole(String roleName, String[] clusterPrivileges, Set indexPatterns) throws IOException {
+ createManageRolesRole(roleName, clusterPrivileges, indexPatterns, Set.of("read"));
+ }
+
+ private void createManageRolesRole(String roleName, String[] clusterPrivileges, Set indexPatterns, Set privileges)
+ throws IOException {
+ adminSecurityClient.putRole(
+ new RoleDescriptor(
+ roleName,
+ clusterPrivileges,
+ new RoleDescriptor.IndicesPrivileges[0],
+ new RoleDescriptor.ApplicationResourcePrivileges[0],
+ new ConfigurableClusterPrivilege[] {
+ new ConfigurableClusterPrivileges.ManageRolesPrivilege(
+ List.of(
+ new ConfigurableClusterPrivileges.ManageRolesPrivilege.ManageRolesIndexPermissionGroup(
+ indexPatterns.toArray(String[]::new),
+ privileges.toArray(String[]::new)
+ )
+ )
+ ) },
+ new String[0],
+ Map.of(),
+ Map.of()
+ )
+ );
+ }
+}
diff --git a/x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/apikey/ApiKeyRestIT.java b/x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/apikey/ApiKeyRestIT.java
index 5ae84517202d4..667140b849951 100644
--- a/x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/apikey/ApiKeyRestIT.java
+++ b/x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/apikey/ApiKeyRestIT.java
@@ -31,6 +31,8 @@
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
import org.elasticsearch.xpack.core.security.authz.permission.RemoteClusterPermissionGroup;
import org.elasticsearch.xpack.core.security.authz.permission.RemoteClusterPermissions;
+import org.elasticsearch.xpack.core.security.authz.privilege.ConfigurableClusterPrivilege;
+import org.elasticsearch.xpack.core.security.authz.privilege.ConfigurableClusterPrivileges;
import org.elasticsearch.xpack.security.SecurityOnTrialLicenseRestTestCase;
import org.junit.After;
import org.junit.Before;
@@ -385,6 +387,50 @@ public void testGrantApiKeyWithOnlyManageOwnApiKeyPrivilegeFails() throws IOExce
assertThat(e.getMessage(), containsString("action [" + GrantApiKeyAction.NAME + "] is unauthorized for user"));
}
+ public void testApiKeyWithManageRoles() throws IOException {
+ RoleDescriptor role = roleWithManageRoles("manage-roles-role", new String[] { "manage_own_api_key" }, "allowed-prefix*");
+ getSecurityClient().putRole(role);
+ createUser("test-user", END_USER_PASSWORD, List.of("manage-roles-role"));
+
+ final Request createApiKeyrequest = new Request("POST", "_security/api_key");
+ createApiKeyrequest.setOptions(
+ RequestOptions.DEFAULT.toBuilder().addHeader("Authorization", basicAuthHeaderValue("test-user", END_USER_PASSWORD))
+ );
+ final Map requestBody = Map.of(
+ "name",
+ "test-api-key",
+ "role_descriptors",
+ Map.of(
+ "test-role",
+ XContentTestUtils.convertToMap(roleWithManageRoles("test-role", new String[0], "allowed-prefix*")),
+ "another-test-role",
+ // This is not allowed by the limited-by-role (creator of the api key), so should not grant access to not-allowed-prefix*
+ XContentTestUtils.convertToMap(roleWithManageRoles("another-test-role", new String[0], "not-allowed-prefix*"))
+ )
+ );
+
+ createApiKeyrequest.setJsonEntity(XContentTestUtils.convertToXContent(requestBody, XContentType.JSON).utf8ToString());
+ Map responseMap = responseAsMap(client().performRequest(createApiKeyrequest));
+ String encodedApiKey = responseMap.get("encoded").toString();
+
+ final Request createRoleRequest = new Request("POST", "_security/role/test-role");
+ createRoleRequest.setOptions(RequestOptions.DEFAULT.toBuilder().addHeader("Authorization", "ApiKey " + encodedApiKey));
+ // Allowed role by manage roles permission
+ {
+ createRoleRequest.setJsonEntity("""
+ {"indices": [{"names": ["allowed-prefix-test"],"privileges": ["read"]}]}""");
+ assertOK(client().performRequest(createRoleRequest));
+ }
+ // Not allowed role by manage roles permission
+ {
+ createRoleRequest.setJsonEntity("""
+ {"indices": [{"names": ["not-allowed-prefix-test"],"privileges": ["read"]}]}""");
+ final ResponseException e = expectThrows(ResponseException.class, () -> client().performRequest(createRoleRequest));
+ assertEquals(403, e.getResponse().getStatusLine().getStatusCode());
+ assertThat(e.getMessage(), containsString("this action is granted by the cluster privileges [manage_security,all]"));
+ }
+ }
+
public void testUpdateApiKey() throws IOException {
final var apiKeyName = "my-api-key-name";
final Map apiKeyMetadata = Map.of("not", "returned");
@@ -2393,6 +2439,27 @@ private void createRole(String name, Collection localClusterPrivileges,
getSecurityClient().putRole(role);
}
+ private RoleDescriptor roleWithManageRoles(String name, String[] clusterPrivileges, String indexPattern) {
+ return new RoleDescriptor(
+ name,
+ clusterPrivileges,
+ null,
+ null,
+ new ConfigurableClusterPrivilege[] {
+ new ConfigurableClusterPrivileges.ManageRolesPrivilege(
+ List.of(
+ new ConfigurableClusterPrivileges.ManageRolesPrivilege.ManageRolesIndexPermissionGroup(
+ new String[] { indexPattern },
+ new String[] { "read" }
+ )
+ )
+ ) },
+ null,
+ null,
+ null
+ );
+ }
+
protected void createRoleWithDescription(String name, Collection clusterPrivileges, String description) throws IOException {
final RoleDescriptor role = new RoleDescriptor(
name,
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/ApiKeyService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/ApiKeyService.java
index d88577f905e96..90566e25b4ea5 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/ApiKeyService.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/ApiKeyService.java
@@ -100,6 +100,7 @@
import org.elasticsearch.xpack.core.security.authc.support.Hasher;
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
import org.elasticsearch.xpack.core.security.authz.privilege.ClusterPrivilegeResolver;
+import org.elasticsearch.xpack.core.security.authz.privilege.ConfigurableClusterPrivileges;
import org.elasticsearch.xpack.core.security.authz.store.ReservedRolesStore;
import org.elasticsearch.xpack.core.security.authz.store.RoleReference;
import org.elasticsearch.xpack.core.security.support.MetadataUtils;
@@ -137,6 +138,7 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;
+import static org.elasticsearch.TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE;
import static org.elasticsearch.TransportVersions.ROLE_REMOTE_CLUSTER_PRIVS;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.search.SearchService.DEFAULT_KEEPALIVE_SETTING;
@@ -363,29 +365,10 @@ public void createApiKey(
listener.onFailure(new IllegalArgumentException("authentication must be provided"));
} else {
final TransportVersion transportVersion = getMinTransportVersion();
- if (transportVersion.before(TRANSPORT_VERSION_ADVANCED_REMOTE_CLUSTER_SECURITY)
- && hasRemoteIndices(request.getRoleDescriptors())) {
- // Creating API keys with roles which define remote indices privileges is not allowed in a mixed cluster.
- listener.onFailure(
- new IllegalArgumentException(
- "all nodes must have version ["
- + TRANSPORT_VERSION_ADVANCED_REMOTE_CLUSTER_SECURITY.toReleaseVersion()
- + "] or higher to support remote indices privileges for API keys"
- )
- );
- return;
- }
- if (transportVersion.before(ROLE_REMOTE_CLUSTER_PRIVS) && hasRemoteCluster(request.getRoleDescriptors())) {
- // Creating API keys with roles which define remote cluster privileges is not allowed in a mixed cluster.
- listener.onFailure(
- new IllegalArgumentException(
- "all nodes must have version ["
- + ROLE_REMOTE_CLUSTER_PRIVS
- + "] or higher to support remote cluster privileges for API keys"
- )
- );
+ if (validateRoleDescriptorsForMixedCluster(listener, request.getRoleDescriptors(), transportVersion) == false) {
return;
}
+
if (transportVersion.before(TRANSPORT_VERSION_ADVANCED_REMOTE_CLUSTER_SECURITY)
&& request.getType() == ApiKey.Type.CROSS_CLUSTER) {
listener.onFailure(
@@ -407,15 +390,63 @@ && hasRemoteIndices(request.getRoleDescriptors())) {
return;
}
- final Set userRolesWithoutDescription = removeUserRoleDescriptorDescriptions(userRoleDescriptors);
- final Set filteredUserRoleDescriptors = maybeRemoveRemotePrivileges(
- userRolesWithoutDescription,
+ Set filteredRoleDescriptors = filterRoleDescriptorsForMixedCluster(
+ userRoleDescriptors,
transportVersion,
request.getId()
);
- createApiKeyAndIndexIt(authentication, request, filteredUserRoleDescriptors, listener);
+ createApiKeyAndIndexIt(authentication, request, filteredRoleDescriptors, listener);
+ }
+ }
+
+ private Set filterRoleDescriptorsForMixedCluster(
+ final Set userRoleDescriptors,
+ final TransportVersion transportVersion,
+ final String... apiKeyIds
+ ) {
+ final Set userRolesWithoutDescription = removeUserRoleDescriptorDescriptions(userRoleDescriptors);
+ return maybeRemoveRemotePrivileges(userRolesWithoutDescription, transportVersion, apiKeyIds);
+ }
+
+ private boolean validateRoleDescriptorsForMixedCluster(
+ final ActionListener> listener,
+ final List roleDescriptors,
+ final TransportVersion transportVersion
+ ) {
+ if (transportVersion.before(TRANSPORT_VERSION_ADVANCED_REMOTE_CLUSTER_SECURITY) && hasRemoteIndices(roleDescriptors)) {
+ // API keys with roles which define remote indices privileges are not allowed in a mixed cluster.
+ listener.onFailure(
+ new IllegalArgumentException(
+ "all nodes must have version ["
+ + TRANSPORT_VERSION_ADVANCED_REMOTE_CLUSTER_SECURITY.toReleaseVersion()
+ + "] or higher to support remote indices privileges for API keys"
+ )
+ );
+ return false;
+ }
+ if (transportVersion.before(ROLE_REMOTE_CLUSTER_PRIVS) && hasRemoteCluster(roleDescriptors)) {
+ // API keys with roles which define remote cluster privileges are not allowed in a mixed cluster.
+ listener.onFailure(
+ new IllegalArgumentException(
+ "all nodes must have version ["
+ + ROLE_REMOTE_CLUSTER_PRIVS
+ + "] or higher to support remote cluster privileges for API keys"
+ )
+ );
+ return false;
+ }
+ if (transportVersion.before(ADD_MANAGE_ROLES_PRIVILEGE) && hasGlobalManageRolesPrivilege(roleDescriptors)) {
+ listener.onFailure(
+ new IllegalArgumentException(
+ "all nodes must have version ["
+ + ADD_MANAGE_ROLES_PRIVILEGE
+ + "] or higher to support the manage roles privilege for API keys"
+ )
+ );
+ return false;
}
+ return true;
}
/**
@@ -458,6 +489,13 @@ private static boolean hasRemoteCluster(Collection roleDescripto
return roleDescriptors != null && roleDescriptors.stream().anyMatch(RoleDescriptor::hasRemoteClusterPermissions);
}
+ private static boolean hasGlobalManageRolesPrivilege(Collection roleDescriptors) {
+ return roleDescriptors != null
+ && roleDescriptors.stream()
+ .flatMap(roleDescriptor -> Arrays.stream(roleDescriptor.getConditionalClusterPrivileges()))
+ .anyMatch(privilege -> privilege instanceof ConfigurableClusterPrivileges.ManageRolesPrivilege);
+ }
+
private static IllegalArgumentException validateWorkflowsRestrictionConstraints(
TransportVersion transportVersion,
List requestRoleDescriptors,
@@ -594,28 +632,11 @@ public void updateApiKeys(
}
final TransportVersion transportVersion = getMinTransportVersion();
- if (transportVersion.before(TRANSPORT_VERSION_ADVANCED_REMOTE_CLUSTER_SECURITY) && hasRemoteIndices(request.getRoleDescriptors())) {
- // Updating API keys with roles which define remote indices privileges is not allowed in a mixed cluster.
- listener.onFailure(
- new IllegalArgumentException(
- "all nodes must have version ["
- + TRANSPORT_VERSION_ADVANCED_REMOTE_CLUSTER_SECURITY.toReleaseVersion()
- + "] or higher to support remote indices privileges for API keys"
- )
- );
- return;
- }
- if (transportVersion.before(ROLE_REMOTE_CLUSTER_PRIVS) && hasRemoteCluster(request.getRoleDescriptors())) {
- // Updating API keys with roles which define remote cluster privileges is not allowed in a mixed cluster.
- listener.onFailure(
- new IllegalArgumentException(
- "all nodes must have version ["
- + ROLE_REMOTE_CLUSTER_PRIVS
- + "] or higher to support remote indices privileges for API keys"
- )
- );
+
+ if (validateRoleDescriptorsForMixedCluster(listener, request.getRoleDescriptors(), transportVersion) == false) {
return;
}
+
final Exception workflowsValidationException = validateWorkflowsRestrictionConstraints(
transportVersion,
request.getRoleDescriptors(),
@@ -627,22 +648,22 @@ public void updateApiKeys(
}
final String[] apiKeyIds = request.getIds().toArray(String[]::new);
- final Set userRolesWithoutDescription = removeUserRoleDescriptorDescriptions(userRoleDescriptors);
- final Set filteredUserRoleDescriptors = maybeRemoveRemotePrivileges(
- userRolesWithoutDescription,
- transportVersion,
- apiKeyIds
- );
if (logger.isDebugEnabled()) {
logger.debug("Updating [{}] API keys", buildDelimitedStringWithLimit(10, apiKeyIds));
}
+ Set filteredRoleDescriptors = filterRoleDescriptorsForMixedCluster(
+ userRoleDescriptors,
+ transportVersion,
+ apiKeyIds
+ );
+
findVersionedApiKeyDocsForSubject(
authentication,
apiKeyIds,
ActionListener.wrap(
- versionedDocs -> updateApiKeys(authentication, request, filteredUserRoleDescriptors, versionedDocs, listener),
+ versionedDocs -> updateApiKeys(authentication, request, filteredRoleDescriptors, versionedDocs, listener),
ex -> listener.onFailure(traceLog("bulk update", ex))
)
);
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java
index a2d2b21b489ea..9ddda193dba39 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java
@@ -60,6 +60,7 @@
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor.IndicesPrivileges;
import org.elasticsearch.xpack.core.security.authz.permission.RemoteClusterPermissions;
+import org.elasticsearch.xpack.core.security.authz.privilege.ConfigurableClusterPrivileges;
import org.elasticsearch.xpack.core.security.authz.store.RoleRetrievalResult;
import org.elasticsearch.xpack.core.security.authz.support.DLSRoleQueryValidator;
import org.elasticsearch.xpack.core.security.support.NativeRealmValidationUtil;
@@ -476,7 +477,15 @@ private Exception validateRoleDescriptor(RoleDescriptor role) {
+ TransportVersions.SECURITY_ROLE_DESCRIPTION.toReleaseVersion()
+ "] or higher to support specifying role description"
);
- }
+ } else if (Arrays.stream(role.getConditionalClusterPrivileges())
+ .anyMatch(privilege -> privilege instanceof ConfigurableClusterPrivileges.ManageRolesPrivilege)
+ && clusterService.state().getMinTransportVersion().before(TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE)) {
+ return new IllegalStateException(
+ "all nodes must have version ["
+ + TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE.toReleaseVersion()
+ + "] or higher to support the manage roles privilege"
+ );
+ }
try {
DLSRoleQueryValidator.validateQueryField(role.getIndicesPrivileges(), xContentRegistry);
} catch (ElasticsearchException | IllegalArgumentException e) {
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/support/SecuritySystemIndices.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/support/SecuritySystemIndices.java
index 4c5ce703f48ad..9541dd9dc470d 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/support/SecuritySystemIndices.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/support/SecuritySystemIndices.java
@@ -36,6 +36,7 @@
import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_PROFILE_ORIGIN;
import static org.elasticsearch.xpack.security.support.SecurityIndexManager.SECURITY_VERSION_STRING;
+import static org.elasticsearch.xpack.security.support.SecuritySystemIndices.SecurityMainIndexMappingVersion.ADD_MANAGE_ROLES_PRIVILEGE;
/**
* Responsible for handling system indices for the Security plugin
@@ -409,6 +410,40 @@ private XContentBuilder getMainIndexMappings(SecurityMainIndexMappingVersion map
builder.endObject();
}
builder.endObject();
+ if (mappingVersion.onOrAfter(ADD_MANAGE_ROLES_PRIVILEGE)) {
+ builder.startObject("role");
+ {
+ builder.field("type", "object");
+ builder.startObject("properties");
+ {
+ builder.startObject("manage");
+ {
+ builder.field("type", "object");
+ builder.startObject("properties");
+ {
+ builder.startObject("indices");
+ {
+ builder.startObject("properties");
+ {
+ builder.startObject("names");
+ builder.field("type", "keyword");
+ builder.endObject();
+ builder.startObject("privileges");
+ builder.field("type", "keyword");
+ builder.endObject();
+ }
+ builder.endObject();
+ }
+ builder.endObject();
+ }
+ builder.endObject();
+ }
+ builder.endObject();
+ }
+ builder.endObject();
+ }
+ builder.endObject();
+ }
}
builder.endObject();
}
@@ -1050,6 +1085,11 @@ public enum SecurityMainIndexMappingVersion implements VersionId(Arrays.asList("", "\""))),
- new ConfigurableClusterPrivileges.ManageApplicationPrivileges(Set.of("\"")) },
+ new ConfigurableClusterPrivileges.ManageApplicationPrivileges(Set.of("\"")),
+ new ConfigurableClusterPrivileges.ManageRolesPrivilege(
+ List.of(
+ new ConfigurableClusterPrivileges.ManageRolesPrivilege.ManageRolesIndexPermissionGroup(
+ new String[] { "test*" },
+ new String[] { "read", "write" }
+ )
+ )
+ ) },
new String[] { "\"[a]/" },
Map.of(),
Map.of()
diff --git a/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/audit/logfile/audited_roles.txt b/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/audit/logfile/audited_roles.txt
index 7b5e24c97d65a..f913c8608960b 100644
--- a/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/audit/logfile/audited_roles.txt
+++ b/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/audit/logfile/audited_roles.txt
@@ -7,6 +7,6 @@ role_descriptor2
role_descriptor3
{"cluster":[],"indices":[],"applications":[{"application":"maps","privileges":["{","}","\n","\\","\""],"resources":["raster:*"]},{"application":"maps","privileges":["*:*"],"resources":["noooooo!!\n\n\f\\\\r","{"]}],"run_as":["jack","nich*","//\""],"metadata":{"some meta":42}}
role_descriptor4
-{"cluster":["manage_ml","grant_api_key","manage_rollup"],"global":{"application":{"manage":{"applications":["a+b+|b+a+"]}},"profile":{}},"indices":[{"names":["/. ? + * | { } [ ] ( ) \" \\/","*"],"privileges":["read","read_cross_cluster"],"field_security":{"grant":["almost","all*"],"except":["denied*"]}}],"applications":[],"run_as":["//+a+\"[a]/"],"metadata":{"?list":["e1","e2","*"],"some other meta":{"r":"t"}}}
+{"cluster":["manage_ml","grant_api_key","manage_rollup"],"global":{"application":{"manage":{"applications":["a+b+|b+a+"]}},"profile":{},"role":{}},"indices":[{"names":["/. ? + * | { } [ ] ( ) \" \\/","*"],"privileges":["read","read_cross_cluster"],"field_security":{"grant":["almost","all*"],"except":["denied*"]}}],"applications":[],"run_as":["//+a+\"[a]/"],"metadata":{"?list":["e1","e2","*"],"some other meta":{"r":"t"}}}
role_descriptor5
-{"cluster":["all"],"global":{"application":{"manage":{"applications":["\""]}},"profile":{"write":{"applications":["","\""]}}},"indices":[],"applications":[],"run_as":["\"[a]/"]}
+{"cluster":["all"],"global":{"application":{"manage":{"applications":["\""]}},"profile":{"write":{"applications":["","\""]}},"role":{"manage":{"indices":[{"names":["test*"],"privileges":["read","write"]}]}}},"indices":[],"applications":[],"run_as":["\"[a]/"]}
diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/RolesBackwardsCompatibilityIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/RolesBackwardsCompatibilityIT.java
index 4f4ff1d5743ee..650779cfbc85d 100644
--- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/RolesBackwardsCompatibilityIT.java
+++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/RolesBackwardsCompatibilityIT.java
@@ -29,6 +29,7 @@
import static org.elasticsearch.xpack.core.security.authz.RoleDescriptorTestHelper.randomApplicationPrivileges;
import static org.elasticsearch.xpack.core.security.authz.RoleDescriptorTestHelper.randomIndicesPrivileges;
+import static org.elasticsearch.xpack.core.security.authz.RoleDescriptorTestHelper.randomManageRolesPrivileges;
import static org.elasticsearch.xpack.core.security.authz.RoleDescriptorTestHelper.randomRoleDescriptorMetadata;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
@@ -40,7 +41,7 @@ public class RolesBackwardsCompatibilityIT extends AbstractUpgradeTestCase {
private RestClient oldVersionClient = null;
private RestClient newVersionClient = null;
- public void testCreatingAndUpdatingRoles() throws Exception {
+ public void testRolesWithDescription() throws Exception {
assumeTrue(
"The role description is supported after transport version: " + TransportVersions.SECURITY_ROLE_DESCRIPTION,
minimumTransportVersion().before(TransportVersions.SECURITY_ROLE_DESCRIPTION)
@@ -48,14 +49,14 @@ public void testCreatingAndUpdatingRoles() throws Exception {
switch (CLUSTER_TYPE) {
case OLD -> {
// Creating role in "old" cluster should succeed when description is not provided
- final String initialRole = randomRoleDescriptorSerialized(false);
+ final String initialRole = randomRoleDescriptorSerialized();
createRole(client(), "my-old-role", initialRole);
- updateRole("my-old-role", randomValueOtherThan(initialRole, () -> randomRoleDescriptorSerialized(false)));
+ updateRole("my-old-role", randomValueOtherThan(initialRole, RolesBackwardsCompatibilityIT::randomRoleDescriptorSerialized));
// and fail if we include description
var createException = expectThrows(
Exception.class,
- () -> createRole(client(), "my-invalid-old-role", randomRoleDescriptorSerialized(true))
+ () -> createRole(client(), "my-invalid-old-role", randomRoleDescriptorWithDescriptionSerialized())
);
assertThat(
createException.getMessage(),
@@ -65,7 +66,7 @@ public void testCreatingAndUpdatingRoles() throws Exception {
RestClient client = client();
var updateException = expectThrows(
Exception.class,
- () -> updateRole(client, "my-old-role", randomRoleDescriptorSerialized(true))
+ () -> updateRole(client, "my-old-role", randomRoleDescriptorWithDescriptionSerialized())
);
assertThat(
updateException.getMessage(),
@@ -74,17 +75,20 @@ public void testCreatingAndUpdatingRoles() throws Exception {
}
case MIXED -> {
try {
- this.createClientsByVersion();
+ this.createClientsByVersion(TransportVersions.SECURITY_ROLE_DESCRIPTION);
// succeed when role description is not provided
- final String initialRole = randomRoleDescriptorSerialized(false);
+ final String initialRole = randomRoleDescriptorSerialized();
createRole(client(), "my-valid-mixed-role", initialRole);
- updateRole("my-valid-mixed-role", randomValueOtherThan(initialRole, () -> randomRoleDescriptorSerialized(false)));
+ updateRole(
+ "my-valid-mixed-role",
+ randomValueOtherThan(initialRole, RolesBackwardsCompatibilityIT::randomRoleDescriptorSerialized)
+ );
// against old node, fail when description is provided either in update or create request
{
Exception e = expectThrows(
Exception.class,
- () -> updateRole(oldVersionClient, "my-valid-mixed-role", randomRoleDescriptorSerialized(true))
+ () -> updateRole(oldVersionClient, "my-valid-mixed-role", randomRoleDescriptorWithDescriptionSerialized())
);
assertThat(
e.getMessage(),
@@ -94,7 +98,7 @@ public void testCreatingAndUpdatingRoles() throws Exception {
{
Exception e = expectThrows(
Exception.class,
- () -> createRole(oldVersionClient, "my-invalid-mixed-role", randomRoleDescriptorSerialized(true))
+ () -> createRole(oldVersionClient, "my-invalid-mixed-role", randomRoleDescriptorWithDescriptionSerialized())
);
assertThat(
e.getMessage(),
@@ -106,7 +110,7 @@ public void testCreatingAndUpdatingRoles() throws Exception {
{
Exception e = expectThrows(
Exception.class,
- () -> createRole(newVersionClient, "my-invalid-mixed-role", randomRoleDescriptorSerialized(true))
+ () -> createRole(newVersionClient, "my-invalid-mixed-role", randomRoleDescriptorWithDescriptionSerialized())
);
assertThat(
e.getMessage(),
@@ -120,7 +124,7 @@ public void testCreatingAndUpdatingRoles() throws Exception {
{
Exception e = expectThrows(
Exception.class,
- () -> updateRole(newVersionClient, "my-valid-mixed-role", randomRoleDescriptorSerialized(true))
+ () -> updateRole(newVersionClient, "my-valid-mixed-role", randomRoleDescriptorWithDescriptionSerialized())
);
assertThat(
e.getMessage(),
@@ -138,11 +142,129 @@ public void testCreatingAndUpdatingRoles() throws Exception {
case UPGRADED -> {
// on upgraded cluster which supports new description field
// create/update requests should succeed either way (with or without description)
- final String initialRole = randomRoleDescriptorSerialized(randomBoolean());
+ final String initialRole = randomFrom(randomRoleDescriptorSerialized(), randomRoleDescriptorWithDescriptionSerialized());
createRole(client(), "my-valid-upgraded-role", initialRole);
updateRole(
"my-valid-upgraded-role",
- randomValueOtherThan(initialRole, () -> randomRoleDescriptorSerialized(randomBoolean()))
+ randomValueOtherThan(
+ initialRole,
+ () -> randomFrom(randomRoleDescriptorSerialized(), randomRoleDescriptorWithDescriptionSerialized())
+ )
+ );
+ }
+ }
+ }
+
+ public void testRolesWithManageRoles() throws Exception {
+ assumeTrue(
+ "The manage roles privilege is supported after transport version: " + TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE,
+ minimumTransportVersion().before(TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE)
+ );
+ switch (CLUSTER_TYPE) {
+ case OLD -> {
+ // Creating role in "old" cluster should succeed when manage roles is not provided
+ final String initialRole = randomRoleDescriptorSerialized();
+ createRole(client(), "my-old-role", initialRole);
+ updateRole("my-old-role", randomValueOtherThan(initialRole, RolesBackwardsCompatibilityIT::randomRoleDescriptorSerialized));
+
+ // and fail if we include manage roles
+ var createException = expectThrows(
+ Exception.class,
+ () -> createRole(client(), "my-invalid-old-role", randomRoleDescriptorWithManageRolesSerialized())
+ );
+ assertThat(
+ createException.getMessage(),
+ allOf(containsString("failed to parse privilege"), containsString("but found [role] instead"))
+ );
+
+ RestClient client = client();
+ var updateException = expectThrows(
+ Exception.class,
+ () -> updateRole(client, "my-old-role", randomRoleDescriptorWithManageRolesSerialized())
+ );
+ assertThat(
+ updateException.getMessage(),
+ allOf(containsString("failed to parse privilege"), containsString("but found [role] instead"))
+ );
+ }
+ case MIXED -> {
+ try {
+ this.createClientsByVersion(TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE);
+ // succeed when role manage roles is not provided
+ final String initialRole = randomRoleDescriptorSerialized();
+ createRole(client(), "my-valid-mixed-role", initialRole);
+ updateRole(
+ "my-valid-mixed-role",
+ randomValueOtherThan(initialRole, RolesBackwardsCompatibilityIT::randomRoleDescriptorSerialized)
+ );
+
+ // against old node, fail when manage roles is provided either in update or create request
+ {
+ Exception e = expectThrows(
+ Exception.class,
+ () -> updateRole(oldVersionClient, "my-valid-mixed-role", randomRoleDescriptorWithManageRolesSerialized())
+ );
+ assertThat(
+ e.getMessage(),
+ allOf(containsString("failed to parse privilege"), containsString("but found [role] instead"))
+ );
+ }
+ {
+ Exception e = expectThrows(
+ Exception.class,
+ () -> createRole(oldVersionClient, "my-invalid-mixed-role", randomRoleDescriptorWithManageRolesSerialized())
+ );
+ assertThat(
+ e.getMessage(),
+ allOf(containsString("failed to parse privilege"), containsString("but found [role] instead"))
+ );
+ }
+
+ // and against new node in a mixed cluster we should fail
+ {
+ Exception e = expectThrows(
+ Exception.class,
+ () -> createRole(newVersionClient, "my-invalid-mixed-role", randomRoleDescriptorWithManageRolesSerialized())
+ );
+
+ assertThat(
+ e.getMessage(),
+ containsString(
+ "all nodes must have version ["
+ + TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE.toReleaseVersion()
+ + "] or higher to support the manage roles privilege"
+ )
+ );
+ }
+ {
+ Exception e = expectThrows(
+ Exception.class,
+ () -> updateRole(newVersionClient, "my-valid-mixed-role", randomRoleDescriptorWithManageRolesSerialized())
+ );
+ assertThat(
+ e.getMessage(),
+ containsString(
+ "all nodes must have version ["
+ + TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE.toReleaseVersion()
+ + "] or higher to support the manage roles privilege"
+ )
+ );
+ }
+ } finally {
+ this.closeClientsByVersion();
+ }
+ }
+ case UPGRADED -> {
+                // on upgraded cluster which supports the manage roles privilege
+                // create/update requests should succeed either way (with or without manage roles)
+ final String initialRole = randomFrom(randomRoleDescriptorSerialized(), randomRoleDescriptorWithManageRolesSerialized());
+ createRole(client(), "my-valid-upgraded-role", initialRole);
+ updateRole(
+ "my-valid-upgraded-role",
+ randomValueOtherThan(
+ initialRole,
+ () -> randomFrom(randomRoleDescriptorSerialized(), randomRoleDescriptorWithManageRolesSerialized())
+ )
);
}
}
@@ -166,10 +288,22 @@ private void updateRole(RestClient client, String roleName, String payload) thro
assertThat(created, equalTo(false));
}
- private static String randomRoleDescriptorSerialized(boolean includeDescription) {
+ private static String randomRoleDescriptorSerialized() {
+ return randomRoleDescriptorSerialized(false, false);
+ }
+
+ private static String randomRoleDescriptorWithDescriptionSerialized() {
+ return randomRoleDescriptorSerialized(true, false);
+ }
+
+ private static String randomRoleDescriptorWithManageRolesSerialized() {
+ return randomRoleDescriptorSerialized(false, true);
+ }
+
+ private static String randomRoleDescriptorSerialized(boolean includeDescription, boolean includeManageRoles) {
try {
return XContentTestUtils.convertToXContent(
- XContentTestUtils.convertToMap(randomRoleDescriptor(includeDescription)),
+ XContentTestUtils.convertToMap(randomRoleDescriptor(includeDescription, includeManageRoles)),
XContentType.JSON
).utf8ToString();
} catch (IOException e) {
@@ -177,26 +311,26 @@ private static String randomRoleDescriptorSerialized(boolean includeDescription)
}
}
- private boolean nodeSupportRoleDescription(Map nodeDetails) {
+ private boolean nodeSupportTransportVersion(Map nodeDetails, TransportVersion transportVersion) {
String nodeVersionString = (String) nodeDetails.get("version");
- TransportVersion transportVersion = getTransportVersionWithFallback(
+ TransportVersion nodeTransportVersion = getTransportVersionWithFallback(
nodeVersionString,
nodeDetails.get("transport_version"),
() -> TransportVersions.ZERO
);
- if (transportVersion.equals(TransportVersions.ZERO)) {
+ if (nodeTransportVersion.equals(TransportVersions.ZERO)) {
// In cases where we were not able to find a TransportVersion, a pre-8.8.0 node answered about a newer (upgraded) node.
// In that case, the node will be current (upgraded), and remote indices are supported for sure.
var nodeIsCurrent = nodeVersionString.equals(Build.current().version());
assertTrue(nodeIsCurrent);
return true;
}
- return transportVersion.onOrAfter(TransportVersions.SECURITY_ROLE_DESCRIPTION);
+ return nodeTransportVersion.onOrAfter(transportVersion);
}
- private void createClientsByVersion() throws IOException {
- var clientsByCapability = getRestClientByCapability();
+ private void createClientsByVersion(TransportVersion transportVersion) throws IOException {
+ var clientsByCapability = getRestClientByCapability(transportVersion);
if (clientsByCapability.size() == 2) {
for (Map.Entry client : clientsByCapability.entrySet()) {
if (client.getKey() == false) {
@@ -224,7 +358,7 @@ private void closeClientsByVersion() throws IOException {
}
@SuppressWarnings("unchecked")
- private Map getRestClientByCapability() throws IOException {
+ private Map getRestClientByCapability(TransportVersion transportVersion) throws IOException {
Response response = client().performRequest(new Request("GET", "_nodes"));
assertOK(response);
ObjectPath objectPath = ObjectPath.createFromResponse(response);
@@ -232,7 +366,7 @@ private Map getRestClientByCapability() throws IOException
Map> hostsByCapability = new HashMap<>();
for (Map.Entry entry : nodesAsMap.entrySet()) {
Map nodeDetails = (Map) entry.getValue();
- var capabilitySupported = nodeSupportRoleDescription(nodeDetails);
+ var capabilitySupported = nodeSupportTransportVersion(nodeDetails, transportVersion);
Map httpInfo = (Map) nodeDetails.get("http");
hostsByCapability.computeIfAbsent(capabilitySupported, k -> new ArrayList<>())
.add(HttpHost.create((String) httpInfo.get("publish_address")));
@@ -244,7 +378,7 @@ private Map getRestClientByCapability() throws IOException
return clientsByCapability;
}
- private static RoleDescriptor randomRoleDescriptor(boolean includeDescription) {
+ private static RoleDescriptor randomRoleDescriptor(boolean includeDescription, boolean includeManageRoles) {
final Set excludedPrivileges = Set.of(
"cross_cluster_replication",
"cross_cluster_replication_internal",
@@ -255,7 +389,7 @@ private static RoleDescriptor randomRoleDescriptor(boolean includeDescription) {
randomSubsetOf(Set.of("all", "monitor", "none")).toArray(String[]::new),
randomIndicesPrivileges(0, 3, excludedPrivileges),
randomApplicationPrivileges(),
- null,
+ includeManageRoles ? randomManageRolesPrivileges() : null,
generateRandomStringArray(5, randomIntBetween(2, 8), false, true),
randomRoleDescriptorMetadata(false),
Map.of(),
From f150e2c11df0fe3bef298c55bd867437e50f5f73 Mon Sep 17 00:00:00 2001
From: David Turner
Date: Tue, 27 Aug 2024 14:34:02 +0100
Subject: [PATCH 13/46] Add telemetry for repository usage (#112133)
Adds to the `GET _cluster/stats` endpoint information about the snapshot
repositories in use, including their types, whether they are read-only
or read-write, and for Azure repositories the kind of credentials in
use.
---
docs/changelog/112133.yaml | 5 ++
docs/reference/cluster/stats.asciidoc | 31 +++++++++-
.../repositories/azure/AzureRepository.java | 6 ++
.../azure/AzureStorageService.java | 12 ++++
.../azure/AzureStorageSettings.java | 12 ++++
.../test/repository_azure/20_repository.yml | 13 ++++
.../test/repository_gcs/20_repository.yml | 13 ++++
.../20_repository_permanent_credentials.yml | 13 ++++
.../30_repository_temporary_credentials.yml | 13 ++++
.../40_repository_ec2_credentials.yml | 13 ++++
.../50_repository_ecs_credentials.yml | 13 ++++
.../60_repository_sts_credentials.yml | 13 ++++
server/src/main/java/module-info.java | 1 +
.../org/elasticsearch/TransportVersions.java | 2 +
.../stats/ClusterStatsNodeResponse.java | 36 ++++++-----
.../cluster/stats/ClusterStatsResponse.java | 12 ++++
.../cluster/stats/RepositoryUsageStats.java | 59 +++++++++++++++++++
.../stats/TransportClusterStatsAction.java | 19 ++++--
.../cluster/health/ClusterHealthStatus.java | 2 +-
.../repositories/RepositoriesFeatures.java | 23 ++++++++
.../repositories/RepositoriesService.java | 27 +++++++--
.../repositories/Repository.java | 8 +++
.../blobstore/BlobStoreRepository.java | 25 ++++++++
...lasticsearch.features.FeatureSpecification | 1 +
.../cluster/stats/VersionStatsTests.java | 3 +-
.../ClusterStatsMonitoringDocTests.java | 25 ++++----
.../AzureRepositoryAnalysisRestIT.java | 37 ++++++++++++
27 files changed, 400 insertions(+), 37 deletions(-)
create mode 100644 docs/changelog/112133.yaml
create mode 100644 server/src/main/java/org/elasticsearch/action/admin/cluster/stats/RepositoryUsageStats.java
create mode 100644 server/src/main/java/org/elasticsearch/repositories/RepositoriesFeatures.java
diff --git a/docs/changelog/112133.yaml b/docs/changelog/112133.yaml
new file mode 100644
index 0000000000000..11109402b7373
--- /dev/null
+++ b/docs/changelog/112133.yaml
@@ -0,0 +1,5 @@
+pr: 112133
+summary: Add telemetry for repository usage
+area: Snapshot/Restore
+type: enhancement
+issues: []
diff --git a/docs/reference/cluster/stats.asciidoc b/docs/reference/cluster/stats.asciidoc
index 3b429ef427071..c39bc0dcd2878 100644
--- a/docs/reference/cluster/stats.asciidoc
+++ b/docs/reference/cluster/stats.asciidoc
@@ -1282,6 +1282,31 @@ They are included here for expert users, but should otherwise be ignored.
=====
+====
+
+`repositories`::
+(object) Contains statistics about the <<snapshot-restore,snapshot>> repositories defined in the cluster, broken down
+by repository type.
++
+.Properties of `repositories`
+[%collapsible%open]
+=====
+
+`count`:::
+(integer) The number of repositories of this type in the cluster.
+
+`read_only`:::
+(integer) The number of repositories of this type in the cluster which are registered read-only.
+
+`read_write`:::
+(integer) The number of repositories of this type in the cluster which are not registered as read-only.
+
+Each repository type may also include other statistics about the repositories of that type here.
+
+=====
+
+====
+
[[cluster-stats-api-example]]
==== {api-examples-title}
@@ -1579,6 +1604,9 @@ The API returns the following response:
},
"snapshots": {
...
+ },
+ "repositories": {
+ ...
}
}
--------------------------------------------------
@@ -1589,6 +1617,7 @@ The API returns the following response:
// TESTRESPONSE[s/"count": \{[^\}]*\}/"count": $body.$_path/]
// TESTRESPONSE[s/"packaging_types": \[[^\]]*\]/"packaging_types": $body.$_path/]
// TESTRESPONSE[s/"snapshots": \{[^\}]*\}/"snapshots": $body.$_path/]
+// TESTRESPONSE[s/"repositories": \{[^\}]*\}/"repositories": $body.$_path/]
// TESTRESPONSE[s/"field_types": \[[^\]]*\]/"field_types": $body.$_path/]
// TESTRESPONSE[s/"runtime_field_types": \[[^\]]*\]/"runtime_field_types": $body.$_path/]
// TESTRESPONSE[s/"search": \{[^\}]*\}/"search": $body.$_path/]
@@ -1600,7 +1629,7 @@ The API returns the following response:
// the plugins that will be in it. And because we figure folks don't need to
// see an exhaustive list anyway.
// 2. Similarly, ignore the contents of `network_types`, `discovery_types`,
-// `packaging_types` and `snapshots`.
+// `packaging_types`, `snapshots` and `repositories`.
// 3. Ignore the contents of the (nodes) count object, as what's shown here
// depends on the license. Voting-only nodes are e.g. only shown when this
// test runs with a basic license.
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureRepository.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureRepository.java
index 388474acc75ea..c8c0b15db5ebe 100644
--- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureRepository.java
+++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureRepository.java
@@ -26,6 +26,7 @@
import java.util.Locale;
import java.util.Map;
+import java.util.Set;
import java.util.function.Function;
import static org.elasticsearch.core.Strings.format;
@@ -175,4 +176,9 @@ protected ByteSizeValue chunkSize() {
public boolean isReadOnly() {
return readonly;
}
+
+ @Override
+    protected Set<String> getExtraUsageFeatures() {
+ return storageService.getExtraUsageFeatures(Repository.CLIENT_NAME.get(getMetadata().settings()));
+ }
}
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageService.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageService.java
index 0d6cd7bf3d246..09088004759a8 100644
--- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageService.java
+++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageService.java
@@ -24,6 +24,7 @@
import java.net.Proxy;
import java.net.URL;
import java.util.Map;
+import java.util.Set;
import java.util.function.BiConsumer;
import static java.util.Collections.emptyMap;
@@ -165,4 +166,15 @@ public void refreshSettings(Map clientsSettings) {
this.storageSettings = Map.copyOf(clientsSettings);
// clients are built lazily by {@link client(String, LocationMode)}
}
+
+ /**
+ * For Azure repositories, we report the different kinds of credentials in use in the telemetry.
+ */
+    public Set<String> getExtraUsageFeatures(String clientName) {
+ try {
+ return getClientSettings(clientName).credentialsUsageFeatures();
+ } catch (Exception e) {
+ return Set.of();
+ }
+ }
}
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageSettings.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageSettings.java
index b3e8dd8898bea..2333a1fdb9e93 100644
--- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageSettings.java
+++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageSettings.java
@@ -29,6 +29,7 @@
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
+import java.util.Set;
final class AzureStorageSettings {
@@ -130,6 +131,7 @@ final class AzureStorageSettings {
private final int maxRetries;
private final Proxy proxy;
private final boolean hasCredentials;
+    private final Set<String> credentialsUsageFeatures;
private AzureStorageSettings(
String account,
@@ -150,6 +152,12 @@ private AzureStorageSettings(
this.endpointSuffix = endpointSuffix;
this.timeout = timeout;
this.maxRetries = maxRetries;
+ this.credentialsUsageFeatures = Strings.hasText(key) ? Set.of("uses_key_credentials")
+ : Strings.hasText(sasToken) ? Set.of("uses_sas_token")
+ : SocketAccess.doPrivilegedException(() -> System.getenv("AZURE_FEDERATED_TOKEN_FILE")) == null
+ ? Set.of("uses_default_credentials", "uses_managed_identity")
+ : Set.of("uses_default_credentials", "uses_workload_identity");
+
// Register the proxy if we have any
// Validate proxy settings
if (proxyType.equals(Proxy.Type.DIRECT) && ((proxyPort != 0) || Strings.hasText(proxyHost))) {
@@ -366,4 +374,8 @@ private String deriveURIFromSettings(boolean isPrimary) {
throw new IllegalArgumentException(e);
}
}
+
+    public Set<String> credentialsUsageFeatures() {
+ return credentialsUsageFeatures;
+ }
}
diff --git a/modules/repository-azure/src/yamlRestTest/resources/rest-api-spec/test/repository_azure/20_repository.yml b/modules/repository-azure/src/yamlRestTest/resources/rest-api-spec/test/repository_azure/20_repository.yml
index 299183f26d9dc..a4a7d0b22a0ed 100644
--- a/modules/repository-azure/src/yamlRestTest/resources/rest-api-spec/test/repository_azure/20_repository.yml
+++ b/modules/repository-azure/src/yamlRestTest/resources/rest-api-spec/test/repository_azure/20_repository.yml
@@ -235,6 +235,19 @@ setup:
snapshot: missing
wait_for_completion: true
+---
+"Usage stats":
+ - requires:
+ cluster_features:
+ - repositories.supports_usage_stats
+ reason: requires this feature
+
+ - do:
+ cluster.stats: {}
+
+ - gte: { repositories.azure.count: 1 }
+ - gte: { repositories.azure.read_write: 1 }
+
---
teardown:
diff --git a/modules/repository-gcs/src/yamlRestTest/resources/rest-api-spec/test/repository_gcs/20_repository.yml b/modules/repository-gcs/src/yamlRestTest/resources/rest-api-spec/test/repository_gcs/20_repository.yml
index 68d61be4983c5..e8c34a4b6a20b 100644
--- a/modules/repository-gcs/src/yamlRestTest/resources/rest-api-spec/test/repository_gcs/20_repository.yml
+++ b/modules/repository-gcs/src/yamlRestTest/resources/rest-api-spec/test/repository_gcs/20_repository.yml
@@ -232,6 +232,19 @@ setup:
snapshot: missing
wait_for_completion: true
+---
+"Usage stats":
+ - requires:
+ cluster_features:
+ - repositories.supports_usage_stats
+ reason: requires this feature
+
+ - do:
+ cluster.stats: {}
+
+ - gte: { repositories.gcs.count: 1 }
+ - gte: { repositories.gcs.read_write: 1 }
+
---
teardown:
diff --git a/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/20_repository_permanent_credentials.yml b/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/20_repository_permanent_credentials.yml
index 77870697f93ae..e88a0861ec01c 100644
--- a/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/20_repository_permanent_credentials.yml
+++ b/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/20_repository_permanent_credentials.yml
@@ -345,6 +345,19 @@ setup:
snapshot: missing
wait_for_completion: true
+---
+"Usage stats":
+ - requires:
+ cluster_features:
+ - repositories.supports_usage_stats
+ reason: requires this feature
+
+ - do:
+ cluster.stats: {}
+
+ - gte: { repositories.s3.count: 1 }
+ - gte: { repositories.s3.read_write: 1 }
+
---
teardown:
diff --git a/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/30_repository_temporary_credentials.yml b/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/30_repository_temporary_credentials.yml
index 4a62d6183470d..501af980e17e3 100644
--- a/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/30_repository_temporary_credentials.yml
+++ b/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/30_repository_temporary_credentials.yml
@@ -256,6 +256,19 @@ setup:
snapshot: missing
wait_for_completion: true
+---
+"Usage stats":
+ - requires:
+ cluster_features:
+ - repositories.supports_usage_stats
+ reason: requires this feature
+
+ - do:
+ cluster.stats: {}
+
+ - gte: { repositories.s3.count: 1 }
+ - gte: { repositories.s3.read_write: 1 }
+
---
teardown:
diff --git a/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/40_repository_ec2_credentials.yml b/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/40_repository_ec2_credentials.yml
index e24ff1ad0e559..129f0ba5d7588 100644
--- a/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/40_repository_ec2_credentials.yml
+++ b/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/40_repository_ec2_credentials.yml
@@ -256,6 +256,19 @@ setup:
snapshot: missing
wait_for_completion: true
+---
+"Usage stats":
+ - requires:
+ cluster_features:
+ - repositories.supports_usage_stats
+ reason: requires this feature
+
+ - do:
+ cluster.stats: {}
+
+ - gte: { repositories.s3.count: 1 }
+ - gte: { repositories.s3.read_write: 1 }
+
---
teardown:
diff --git a/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/50_repository_ecs_credentials.yml b/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/50_repository_ecs_credentials.yml
index 9c332cc7d9301..de334b4b3df96 100644
--- a/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/50_repository_ecs_credentials.yml
+++ b/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/50_repository_ecs_credentials.yml
@@ -256,6 +256,19 @@ setup:
snapshot: missing
wait_for_completion: true
+---
+"Usage stats":
+ - requires:
+ cluster_features:
+ - repositories.supports_usage_stats
+ reason: requires this feature
+
+ - do:
+ cluster.stats: {}
+
+ - gte: { repositories.s3.count: 1 }
+ - gte: { repositories.s3.read_write: 1 }
+
---
teardown:
diff --git a/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/60_repository_sts_credentials.yml b/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/60_repository_sts_credentials.yml
index 24c2b2b1741d6..09a8526017960 100644
--- a/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/60_repository_sts_credentials.yml
+++ b/modules/repository-s3/src/yamlRestTest/resources/rest-api-spec/test/repository_s3/60_repository_sts_credentials.yml
@@ -257,6 +257,19 @@ setup:
snapshot: missing
wait_for_completion: true
+---
+"Usage stats":
+ - requires:
+ cluster_features:
+ - repositories.supports_usage_stats
+ reason: requires this feature
+
+ - do:
+ cluster.stats: {}
+
+ - gte: { repositories.s3.count: 1 }
+ - gte: { repositories.s3.read_write: 1 }
+
---
teardown:
diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java
index c223db531e688..d412748ed4e57 100644
--- a/server/src/main/java/module-info.java
+++ b/server/src/main/java/module-info.java
@@ -429,6 +429,7 @@
org.elasticsearch.cluster.metadata.MetadataFeatures,
org.elasticsearch.rest.RestFeatures,
org.elasticsearch.indices.IndicesFeatures,
+ org.elasticsearch.repositories.RepositoriesFeatures,
org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures,
org.elasticsearch.index.mapper.MapperFeatures,
org.elasticsearch.ingest.IngestGeoIpFeatures,
diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index 582c618216999..41fa34bb5a4a3 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -199,6 +199,8 @@ static TransportVersion def(int id) {
public static final TransportVersion RANK_DOCS_RETRIEVER = def(8_729_00_0);
public static final TransportVersion ESQL_ES_FIELD_CACHED_SERIALIZATION = def(8_730_00_0);
public static final TransportVersion ADD_MANAGE_ROLES_PRIVILEGE = def(8_731_00_0);
+ public static final TransportVersion REPOSITORIES_TELEMETRY = def(8_732_00_0);
+
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsNodeResponse.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsNodeResponse.java
index d74889b623589..b48295dc8b3eb 100644
--- a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsNodeResponse.java
+++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsNodeResponse.java
@@ -20,29 +20,33 @@
import org.elasticsearch.core.Nullable;
import java.io.IOException;
+import java.util.Objects;
public class ClusterStatsNodeResponse extends BaseNodeResponse {
private final NodeInfo nodeInfo;
private final NodeStats nodeStats;
private final ShardStats[] shardsStats;
- private ClusterHealthStatus clusterStatus;
+ private final ClusterHealthStatus clusterStatus;
private final SearchUsageStats searchUsageStats;
+ private final RepositoryUsageStats repositoryUsageStats;
public ClusterStatsNodeResponse(StreamInput in) throws IOException {
super(in);
- clusterStatus = null;
- if (in.readBoolean()) {
- clusterStatus = ClusterHealthStatus.readFrom(in);
- }
+ this.clusterStatus = in.readOptionalWriteable(ClusterHealthStatus::readFrom);
this.nodeInfo = new NodeInfo(in);
this.nodeStats = new NodeStats(in);
- shardsStats = in.readArray(ShardStats::new, ShardStats[]::new);
+ this.shardsStats = in.readArray(ShardStats::new, ShardStats[]::new);
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_6_0)) {
searchUsageStats = new SearchUsageStats(in);
} else {
searchUsageStats = new SearchUsageStats();
}
+ if (in.getTransportVersion().onOrAfter(TransportVersions.REPOSITORIES_TELEMETRY)) {
+ repositoryUsageStats = RepositoryUsageStats.readFrom(in);
+ } else {
+ repositoryUsageStats = RepositoryUsageStats.EMPTY;
+ }
}
public ClusterStatsNodeResponse(
@@ -51,14 +55,16 @@ public ClusterStatsNodeResponse(
NodeInfo nodeInfo,
NodeStats nodeStats,
ShardStats[] shardsStats,
- SearchUsageStats searchUsageStats
+ SearchUsageStats searchUsageStats,
+ RepositoryUsageStats repositoryUsageStats
) {
super(node);
this.nodeInfo = nodeInfo;
this.nodeStats = nodeStats;
this.shardsStats = shardsStats;
this.clusterStatus = clusterStatus;
- this.searchUsageStats = searchUsageStats;
+ this.searchUsageStats = Objects.requireNonNull(searchUsageStats);
+ this.repositoryUsageStats = Objects.requireNonNull(repositoryUsageStats);
}
public NodeInfo nodeInfo() {
@@ -85,20 +91,22 @@ public SearchUsageStats searchUsageStats() {
return searchUsageStats;
}
+ public RepositoryUsageStats repositoryUsageStats() {
+ return repositoryUsageStats;
+ }
+
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
- if (clusterStatus == null) {
- out.writeBoolean(false);
- } else {
- out.writeBoolean(true);
- out.writeByte(clusterStatus.value());
- }
+ out.writeOptionalWriteable(clusterStatus);
nodeInfo.writeTo(out);
nodeStats.writeTo(out);
out.writeArray(shardsStats);
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_6_0)) {
searchUsageStats.writeTo(out);
}
+ if (out.getTransportVersion().onOrAfter(TransportVersions.REPOSITORIES_TELEMETRY)) {
+ repositoryUsageStats.writeTo(out);
+ } // else just drop these stats, ok for bwc
}
}
diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsResponse.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsResponse.java
index 36e7b247befac..b6dd40e8c8b79 100644
--- a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsResponse.java
+++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsResponse.java
@@ -30,6 +30,7 @@ public class ClusterStatsResponse extends BaseNodesResponse r.isEmpty() == false)
+ // stats should be the same on every node so just pick one of them
+ .findAny()
+ .orElse(RepositoryUsageStats.EMPTY);
}
public String getClusterUUID() {
@@ -113,6 +122,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field("snapshots");
clusterSnapshotStats.toXContent(builder, params);
+ builder.field("repositories");
+ repositoryUsageStats.toXContent(builder, params);
+
return builder;
}
diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/RepositoryUsageStats.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/RepositoryUsageStats.java
new file mode 100644
index 0000000000000..771aa0fbef842
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/RepositoryUsageStats.java
@@ -0,0 +1,59 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.admin.cluster.stats;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * Stats on repository feature usage exposed in cluster stats for telemetry.
+ *
+ * @param statsByType a count of the repositories using various named features, keyed by repository type and then by feature name.
+ */
+public record RepositoryUsageStats(Map<String, Map<String, Long>> statsByType) implements Writeable, ToXContentObject {
+
+ public static final RepositoryUsageStats EMPTY = new RepositoryUsageStats(Map.of());
+
+ public static RepositoryUsageStats readFrom(StreamInput in) throws IOException {
+ final var statsByType = in.readMap(i -> i.readMap(StreamInput::readVLong));
+ if (statsByType.isEmpty()) {
+ return EMPTY;
+ } else {
+ return new RepositoryUsageStats(statsByType);
+ }
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeMap(statsByType, (o, m) -> o.writeMap(m, StreamOutput::writeVLong));
+ }
+
+ public boolean isEmpty() {
+ return statsByType.isEmpty();
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+        for (Map.Entry<String, Map<String, Long>> typeAndStats : statsByType.entrySet()) {
+ builder.startObject(typeAndStats.getKey());
+            for (Map.Entry<String, Long> statAndValue : typeAndStats.getValue().entrySet()) {
+ builder.field(statAndValue.getKey(), statAndValue.getValue());
+ }
+ builder.endObject();
+ }
+ return builder.endObject();
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/TransportClusterStatsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/TransportClusterStatsAction.java
index bcf49bca421f6..1912de3cfa4d2 100644
--- a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/TransportClusterStatsAction.java
+++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/TransportClusterStatsAction.java
@@ -41,6 +41,7 @@
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.node.NodeService;
+import org.elasticsearch.repositories.RepositoriesService;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
@@ -78,6 +79,7 @@ public class TransportClusterStatsAction extends TransportNodesAction<
private final NodeService nodeService;
private final IndicesService indicesService;
+ private final RepositoriesService repositoriesService;
private final SearchUsageHolder searchUsageHolder;
private final MetadataStatsCache mappingStatsCache;
@@ -90,6 +92,7 @@ public TransportClusterStatsAction(
TransportService transportService,
NodeService nodeService,
IndicesService indicesService,
+ RepositoriesService repositoriesService,
UsageService usageService,
ActionFilters actionFilters
) {
@@ -103,6 +106,7 @@ public TransportClusterStatsAction(
);
this.nodeService = nodeService;
this.indicesService = indicesService;
+ this.repositoriesService = repositoriesService;
this.searchUsageHolder = usageService.getSearchUsageHolder();
this.mappingStatsCache = new MetadataStatsCache<>(threadPool.getThreadContext(), MappingStats::of);
this.analysisStatsCache = new MetadataStatsCache<>(threadPool.getThreadContext(), AnalysisStats::of);
@@ -237,12 +241,14 @@ protected ClusterStatsNodeResponse nodeOperation(ClusterStatsNodeRequest nodeReq
}
}
- ClusterHealthStatus clusterStatus = null;
- if (clusterService.state().nodes().isLocalNodeElectedMaster()) {
- clusterStatus = new ClusterStateHealth(clusterService.state()).getStatus();
- }
+ final ClusterState clusterState = clusterService.state();
+ final ClusterHealthStatus clusterStatus = clusterState.nodes().isLocalNodeElectedMaster()
+ ? new ClusterStateHealth(clusterState).getStatus()
+ : null;
+
+ final SearchUsageStats searchUsageStats = searchUsageHolder.getSearchUsageStats();
- SearchUsageStats searchUsageStats = searchUsageHolder.getSearchUsageStats();
+ final RepositoryUsageStats repositoryUsageStats = repositoriesService.getUsageStats();
return new ClusterStatsNodeResponse(
nodeInfo.getNode(),
@@ -250,7 +256,8 @@ protected ClusterStatsNodeResponse nodeOperation(ClusterStatsNodeRequest nodeReq
nodeInfo,
nodeStats,
shardsStats.toArray(new ShardStats[shardsStats.size()]),
- searchUsageStats
+ searchUsageStats,
+ repositoryUsageStats
);
}
diff --git a/server/src/main/java/org/elasticsearch/cluster/health/ClusterHealthStatus.java b/server/src/main/java/org/elasticsearch/cluster/health/ClusterHealthStatus.java
index d025ddab26af6..c53395b5d76c1 100644
--- a/server/src/main/java/org/elasticsearch/cluster/health/ClusterHealthStatus.java
+++ b/server/src/main/java/org/elasticsearch/cluster/health/ClusterHealthStatus.java
@@ -19,7 +19,7 @@ public enum ClusterHealthStatus implements Writeable {
YELLOW((byte) 1),
RED((byte) 2);
- private byte value;
+ private final byte value;
ClusterHealthStatus(byte value) {
this.value = value;
diff --git a/server/src/main/java/org/elasticsearch/repositories/RepositoriesFeatures.java b/server/src/main/java/org/elasticsearch/repositories/RepositoriesFeatures.java
new file mode 100644
index 0000000000000..141dac0c5c430
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/repositories/RepositoriesFeatures.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.repositories;
+
+import org.elasticsearch.features.FeatureSpecification;
+import org.elasticsearch.features.NodeFeature;
+
+import java.util.Set;
+
+public class RepositoriesFeatures implements FeatureSpecification {
+ public static final NodeFeature SUPPORTS_REPOSITORIES_USAGE_STATS = new NodeFeature("repositories.supports_usage_stats");
+
+ @Override
+    public Set<NodeFeature> getFeatures() {
+ return Set.of(SUPPORTS_REPOSITORIES_USAGE_STATS);
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/repositories/RepositoriesService.java b/server/src/main/java/org/elasticsearch/repositories/RepositoriesService.java
index de4ae1051ba62..732a18dffe233 100644
--- a/server/src/main/java/org/elasticsearch/repositories/RepositoriesService.java
+++ b/server/src/main/java/org/elasticsearch/repositories/RepositoriesService.java
@@ -14,6 +14,7 @@
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.admin.cluster.repositories.delete.DeleteRepositoryRequest;
import org.elasticsearch.action.admin.cluster.repositories.put.PutRepositoryRequest;
+import org.elasticsearch.action.admin.cluster.stats.RepositoryUsageStats;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.internal.node.NodeClient;
public List<BiConsumer<Snapshot, IndexVersion>> getPreRestoreVersionChecks() {
return preRestoreChecks;
}
- @Override
- protected void doStart() {
+ public static String COUNT_USAGE_STATS_NAME = "count";
+ public RepositoryUsageStats getUsageStats() {
+ if (repositories.isEmpty()) {
+ return RepositoryUsageStats.EMPTY;
+ }
+        final var statsByType = new HashMap<String, Map<String, Long>>();
+ for (final var repository : repositories.values()) {
+ final var repositoryType = repository.getMetadata().type();
+ final var typeStats = statsByType.computeIfAbsent(repositoryType, ignored -> new HashMap<>());
+ typeStats.compute(COUNT_USAGE_STATS_NAME, (k, count) -> (count == null ? 0L : count) + 1);
+ final var repositoryUsageTags = repository.getUsageFeatures();
+ assert repositoryUsageTags.contains(COUNT_USAGE_STATS_NAME) == false : repositoryUsageTags;
+ for (final var repositoryUsageTag : repositoryUsageTags) {
+ typeStats.compute(repositoryUsageTag, (k, count) -> (count == null ? 0L : count) + 1);
+ }
+ }
+ return new RepositoryUsageStats(
+ statsByType.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> Map.copyOf(e.getValue())))
+ );
}
@Override
- protected void doStop() {
+ protected void doStart() {}
- }
+ @Override
+ protected void doStop() {}
@Override
protected void doClose() throws IOException {
diff --git a/server/src/main/java/org/elasticsearch/repositories/Repository.java b/server/src/main/java/org/elasticsearch/repositories/Repository.java
index fd52c21cad3f8..09f4782b6e5fa 100644
--- a/server/src/main/java/org/elasticsearch/repositories/Repository.java
+++ b/server/src/main/java/org/elasticsearch/repositories/Repository.java
@@ -312,6 +312,14 @@ void cloneShardSnapshot(
*/
void awaitIdle();
+ /**
+ * @return a set of the names of the features that this repository instance uses, for reporting in the cluster stats for telemetry
+ * collection.
+ */
+    default Set<String> getUsageFeatures() {
+ return Set.of();
+ }
+
static boolean assertSnapshotMetaThread() {
return ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SNAPSHOT_META);
}
diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java
index e8af752bec179..cc56e940530e8 100644
--- a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java
+++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java
@@ -3943,4 +3943,29 @@ public String getAnalysisFailureExtraDetail() {
ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS
);
}
+
+ public static final String READ_ONLY_USAGE_STATS_NAME = "read_only";
+ public static final String READ_WRITE_USAGE_STATS_NAME = "read_write";
+
+ @Override
+    public final Set<String> getUsageFeatures() {
+ final var extraUsageFeatures = getExtraUsageFeatures();
+ assert extraUsageFeatures.contains(READ_ONLY_USAGE_STATS_NAME) == false : extraUsageFeatures;
+ assert extraUsageFeatures.contains(READ_WRITE_USAGE_STATS_NAME) == false : extraUsageFeatures;
+ return Set.copyOf(
+ Stream.concat(Stream.of(isReadOnly() ? READ_ONLY_USAGE_STATS_NAME : READ_WRITE_USAGE_STATS_NAME), extraUsageFeatures.stream())
+ .toList()
+ );
+ }
+
+ /**
+ * All blob-store repositories include the counts of read-only and read-write repositories in their telemetry. This method returns other
+ * features of the repositories in use.
+ *
+ * @return a set of the names of the extra features that this repository instance uses, for reporting in the cluster stats for telemetry
+ * collection.
+ */
+    protected Set<String> getExtraUsageFeatures() {
+ return Set.of();
+ }
}
diff --git a/server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification b/server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification
index baf7e53345944..90a1c29972ff3 100644
--- a/server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification
+++ b/server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification
@@ -13,6 +13,7 @@ org.elasticsearch.cluster.service.TransportFeatures
org.elasticsearch.cluster.metadata.MetadataFeatures
org.elasticsearch.rest.RestFeatures
org.elasticsearch.indices.IndicesFeatures
+org.elasticsearch.repositories.RepositoriesFeatures
org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures
org.elasticsearch.index.mapper.MapperFeatures
org.elasticsearch.ingest.IngestGeoIpFeatures
diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/VersionStatsTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/VersionStatsTests.java
index 49528c204b042..20eae9833e4b0 100644
--- a/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/VersionStatsTests.java
+++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/VersionStatsTests.java
@@ -127,7 +127,8 @@ public void testCreation() {
null,
null,
new ShardStats[] { shardStats },
- null
+ new SearchUsageStats(),
+ RepositoryUsageStats.EMPTY
);
stats = VersionStats.of(metadata, Collections.singletonList(nodeResponse));
diff --git a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/collector/cluster/ClusterStatsMonitoringDocTests.java b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/collector/cluster/ClusterStatsMonitoringDocTests.java
index c89638045a5a8..4a695f7c51e4c 100644
--- a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/collector/cluster/ClusterStatsMonitoringDocTests.java
+++ b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/collector/cluster/ClusterStatsMonitoringDocTests.java
@@ -15,6 +15,7 @@
import org.elasticsearch.action.admin.cluster.stats.ClusterStatsNodeResponse;
import org.elasticsearch.action.admin.cluster.stats.ClusterStatsResponse;
import org.elasticsearch.action.admin.cluster.stats.MappingStats;
+import org.elasticsearch.action.admin.cluster.stats.RepositoryUsageStats;
import org.elasticsearch.action.admin.cluster.stats.SearchUsageStats;
import org.elasticsearch.action.admin.cluster.stats.VersionStats;
import org.elasticsearch.action.admin.indices.stats.CommonStats;
@@ -420,6 +421,7 @@ public void testToXContent() throws IOException {
when(mockNodeResponse.nodeStats()).thenReturn(mockNodeStats);
when(mockNodeResponse.shardsStats()).thenReturn(new ShardStats[] { mockShardStats });
when(mockNodeResponse.searchUsageStats()).thenReturn(new SearchUsageStats());
+ when(mockNodeResponse.repositoryUsageStats()).thenReturn(RepositoryUsageStats.EMPTY);
final Metadata metadata = testClusterState.metadata();
final ClusterStatsResponse clusterStatsResponse = new ClusterStatsResponse(
@@ -533,7 +535,9 @@ public void testToXContent() throws IOException {
"fielddata": {
"memory_size_in_bytes": 1,
"evictions": 0,
- "global_ordinals":{"build_time_in_millis":1}
+ "global_ordinals": {
+ "build_time_in_millis": 1
+ }
},
"query_cache": {
"memory_size_in_bytes": 0,
@@ -563,9 +567,9 @@ public void testToXContent() throws IOException {
"file_sizes": {}
},
"mappings": {
- "total_field_count" : 0,
- "total_deduplicated_field_count" : 0,
- "total_deduplicated_mapping_size_in_bytes" : 0,
+ "total_field_count": 0,
+ "total_deduplicated_field_count": 0,
+ "total_deduplicated_mapping_size_in_bytes": 0,
"field_types": [],
"runtime_field_types": []
},
@@ -581,11 +585,11 @@ public void testToXContent() throws IOException {
"synonyms": {}
},
"versions": [],
- "search" : {
- "total" : 0,
- "queries" : {},
- "rescorers" : {},
- "sections" : {}
+ "search": {
+ "total": 0,
+ "queries": {},
+ "rescorers": {},
+ "sections": {}
},
"dense_vector": {
"value_count": 0
@@ -749,7 +753,8 @@ public void testToXContent() throws IOException {
"cleanups": 0
},
"repositories": {}
- }
+ },
+ "repositories": {}
},
"cluster_state": {
"nodes_hash": 1314980060,
diff --git a/x-pack/plugin/snapshot-repo-test-kit/qa/azure/src/javaRestTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/AzureRepositoryAnalysisRestIT.java b/x-pack/plugin/snapshot-repo-test-kit/qa/azure/src/javaRestTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/AzureRepositoryAnalysisRestIT.java
index ecc8401e1d79a..a9b8fe51c01cc 100644
--- a/x-pack/plugin/snapshot-repo-test-kit/qa/azure/src/javaRestTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/AzureRepositoryAnalysisRestIT.java
+++ b/x-pack/plugin/snapshot-repo-test-kit/qa/azure/src/javaRestTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/AzureRepositoryAnalysisRestIT.java
@@ -8,6 +8,8 @@
import fixture.azure.AzureHttpFixture;
+import org.apache.http.client.methods.HttpGet;
+import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Booleans;
@@ -15,15 +17,20 @@
import org.elasticsearch.test.TestTrustStore;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.util.resource.Resource;
+import org.elasticsearch.test.rest.ObjectPath;
+import org.hamcrest.Matcher;
import org.junit.ClassRule;
import org.junit.rules.RuleChain;
import org.junit.rules.TestRule;
+import java.io.IOException;
import java.util.Map;
import java.util.function.Predicate;
import static org.hamcrest.Matchers.blankOrNullString;
+import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
public class AzureRepositoryAnalysisRestIT extends AbstractRepositoryAnalysisRestTestCase {
private static final boolean USE_FIXTURE = Booleans.parseBoolean(System.getProperty("test.azure.fixture", "true"));
@@ -119,4 +126,34 @@ protected Settings repositorySettings() {
return Settings.builder().put("client", "repository_test_kit").put("container", container).put("base_path", basePath).build();
}
+
+ public void testClusterStats() throws IOException {
+ registerRepository(randomIdentifier(), repositoryType(), true, repositorySettings());
+
+ final var request = new Request(HttpGet.METHOD_NAME, "/_cluster/stats");
+ final var response = client().performRequest(request);
+ assertOK(response);
+
+ final var objectPath = ObjectPath.createFromResponse(response);
+ assertThat(objectPath.evaluate("repositories.azure.count"), isSetIff(true));
+ assertThat(objectPath.evaluate("repositories.azure.read_write"), isSetIff(true));
+
+ assertThat(objectPath.evaluate("repositories.azure.uses_key_credentials"), isSetIff(Strings.hasText(AZURE_TEST_KEY)));
+ assertThat(objectPath.evaluate("repositories.azure.uses_sas_token"), isSetIff(Strings.hasText(AZURE_TEST_SASTOKEN)));
+ assertThat(
+ objectPath.evaluate("repositories.azure.uses_default_credentials"),
+ isSetIff((Strings.hasText(AZURE_TEST_SASTOKEN) || Strings.hasText(AZURE_TEST_KEY)) == false)
+ );
+ assertThat(
+ objectPath.evaluate("repositories.azure.uses_managed_identity"),
+ isSetIff(
+ (Strings.hasText(AZURE_TEST_SASTOKEN) || Strings.hasText(AZURE_TEST_KEY) || Strings.hasText(AZURE_TEST_CLIENT_ID)) == false
+ )
+ );
+ assertThat(objectPath.evaluate("repositories.azure.uses_workload_identity"), isSetIff(Strings.hasText(AZURE_TEST_CLIENT_ID)));
+ }
+
+    private static Matcher<Integer> isSetIff(boolean predicate) {
+ return predicate ? equalTo(1) : nullValue(Integer.class);
+ }
}
From b7e1d5593b42f03aecc387160af6f452c4d25351 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20Fred=C3=A9n?=
<109296772+jfreden@users.noreply.github.com>
Date: Tue, 27 Aug 2024 15:45:53 +0200
Subject: [PATCH 14/46] Fix connection timeout for OpenIdConnectAuthenticator
get Userinfo (#112230)
* Fix connection timeout for OpenIdConnectAuthenticator get Userinfo
* Update docs/changelog/112230.yaml
---
docs/changelog/112230.yaml | 5 +++++
.../security/authc/oidc/OpenIdConnectAuthenticator.java | 2 +-
2 files changed, 6 insertions(+), 1 deletion(-)
create mode 100644 docs/changelog/112230.yaml
diff --git a/docs/changelog/112230.yaml b/docs/changelog/112230.yaml
new file mode 100644
index 0000000000000..ef12dc3f78267
--- /dev/null
+++ b/docs/changelog/112230.yaml
@@ -0,0 +1,5 @@
+pr: 112230
+summary: Fix connection timeout for `OpenIdConnectAuthenticator` get Userinfo
+area: Security
+type: bug
+issues: []
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java
index 0f34850b861b7..c2e0caf7234cb 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java
@@ -718,7 +718,7 @@ private CloseableHttpAsyncClient createHttpClient() {
connectionManager.setMaxTotal(realmConfig.getSetting(HTTP_MAX_CONNECTIONS));
final RequestConfig requestConfig = RequestConfig.custom()
.setConnectTimeout(Math.toIntExact(realmConfig.getSetting(HTTP_CONNECT_TIMEOUT).getMillis()))
- .setConnectionRequestTimeout(Math.toIntExact(realmConfig.getSetting(HTTP_CONNECTION_READ_TIMEOUT).getSeconds()))
+ .setConnectionRequestTimeout(Math.toIntExact(realmConfig.getSetting(HTTP_CONNECTION_READ_TIMEOUT).getMillis()))
.setSocketTimeout(Math.toIntExact(realmConfig.getSetting(HTTP_SOCKET_TIMEOUT).getMillis()))
.build();
From b14bada16f3c66598e18393d8d30271a81096ec3 Mon Sep 17 00:00:00 2001
From: Pat Whelan
Date: Tue, 27 Aug 2024 10:44:29 -0400
Subject: [PATCH 15/46] [ML] Update inference interfaces for streaming
(#112234)
Using InferenceServiceResults and InferenceAction to stream
ChunkedToXContent through to the Rest handler.
---
.../inference/InferenceServiceResults.java | 24 ++++++++++++++++---
.../inference/action/InferenceAction.java | 20 ++++++++++++++++
2 files changed, 41 insertions(+), 3 deletions(-)
diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
index f8330404c1538..0000e0ddc9af9 100644
--- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
+++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
@@ -13,17 +13,18 @@
import java.util.List;
import java.util.Map;
+import java.util.concurrent.Flow;
public interface InferenceServiceResults extends NamedWriteable, ChunkedToXContent {
/**
- * Transform the result to match the format required for the TransportCoordinatedInferenceAction.
+ * <p>Transform the result to match the format required for the TransportCoordinatedInferenceAction.
* For the inference plugin TextEmbeddingResults, the {@link #transformToLegacyFormat()} transforms the
* results into an intermediate format only used by the plugin's return value. It doesn't align with what the
* TransportCoordinatedInferenceAction expects. TransportCoordinatedInferenceAction expects an ml plugin
- * TextEmbeddingResults.
+ * TextEmbeddingResults.</p>
*
- * For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat.
+ * <p>For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat.
*/
List<? extends InferenceResults> transformToCoordinationFormat();
@@ -37,4 +38,21 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte
* Convert the result to a map to aid with test assertions
*/
Map<String, Object> asMap();
+
+ /**
+ * Returns {@code true} if these results are streamed as chunks, or {@code false} if these results contain the entire payload.
+ * Defaults to {@code false}.
+ */
+ default boolean isStreaming() {
+ return false;
+ }
+
+ /**
+ * When {@link #isStreaming()} is {@code true}, the InferenceAction.Results will subscribe to this publisher.
+ * Implementations should follow the {@link java.util.concurrent.Flow.Publisher} spec to stream the chunks.
+ */
+    default Flow.Publisher<ChunkedToXContent> publisher() {
+ assert isStreaming() == false : "This must be implemented when isStreaming() == true";
+ throw new UnsupportedOperationException("This must be implemented when isStreaming() == true");
+ }
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
index 7ecb5aef4ce8d..c38f508db1b6a 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
@@ -17,6 +17,7 @@
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
import org.elasticsearch.core.TimeValue;
@@ -40,6 +41,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.concurrent.Flow;
import static org.elasticsearch.core.Strings.format;
@@ -391,6 +393,24 @@ public InferenceServiceResults getResults() {
return results;
}
+ /**
+ * Returns {@code true} if these results are streamed as chunks, or {@code false} if these results contain the entire payload.
+ * Currently set to false while it is being implemented.
+ */
+ public boolean isStreaming() {
+ return false;
+ }
+
+ /**
+ * When {@link #isStreaming()} is {@code true}, the RestHandler will subscribe to this publisher.
+ * When the RestResponse is finished with the current chunk, it will request the next chunk using the subscription.
+ * If the RestResponse is closed, it will cancel the subscription.
+ */
+    public Flow.Publisher<ChunkedToXContent> publisher() {
+ assert isStreaming() == false : "This must be implemented when isStreaming() == true";
+ throw new UnsupportedOperationException("This must be implemented when isStreaming() == true");
+ }
+
@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
From b43470feeb82d602f549b6dfee9243d9afa6ce25 Mon Sep 17 00:00:00 2001
From: Oleksandr Kolomiiets
Date: Tue, 27 Aug 2024 07:55:50 -0700
Subject: [PATCH 16/46] Fix nested field generation in
StandardVersusLogsIndexModeRandomDataChallengeRestIT (#112223)
---
.../logsdb/datageneration/fields/Context.java | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/test/framework/src/main/java/org/elasticsearch/logsdb/datageneration/fields/Context.java b/test/framework/src/main/java/org/elasticsearch/logsdb/datageneration/fields/Context.java
index 647d5bff152d1..62130967508f6 100644
--- a/test/framework/src/main/java/org/elasticsearch/logsdb/datageneration/fields/Context.java
+++ b/test/framework/src/main/java/org/elasticsearch/logsdb/datageneration/fields/Context.java
@@ -13,6 +13,7 @@
import org.elasticsearch.logsdb.datageneration.datasource.DataSourceResponse;
import java.util.Optional;
+import java.util.concurrent.atomic.AtomicInteger;
class Context {
private final DataGeneratorSpecification specification;
@@ -21,13 +22,14 @@ class Context {
private final DataSourceResponse.FieldTypeGenerator fieldTypeGenerator;
private final DataSourceResponse.ObjectArrayGenerator objectArrayGenerator;
private final int objectDepth;
- private final int nestedFieldsCount;
+ // We don't need atomicity, but we need to pass counter by reference to accumulate total value from sub-objects.
+ private final AtomicInteger nestedFieldsCount;
Context(DataGeneratorSpecification specification) {
- this(specification, 0, 0);
+ this(specification, 0, new AtomicInteger(0));
}
- private Context(DataGeneratorSpecification specification, int objectDepth, int nestedFieldsCount) {
+ private Context(DataGeneratorSpecification specification, int objectDepth, AtomicInteger nestedFieldsCount) {
this.specification = specification;
this.childFieldGenerator = specification.dataSource().get(new DataSourceRequest.ChildFieldGenerator(specification));
this.fieldTypeGenerator = specification.dataSource().get(new DataSourceRequest.FieldTypeGenerator());
@@ -53,7 +55,8 @@ public Context subObject() {
}
public Context nestedObject() {
- return new Context(specification, objectDepth + 1, nestedFieldsCount + 1);
+ nestedFieldsCount.incrementAndGet();
+ return new Context(specification, objectDepth + 1, nestedFieldsCount);
}
public boolean shouldAddObjectField() {
@@ -63,7 +66,7 @@ public boolean shouldAddObjectField() {
public boolean shouldAddNestedField() {
return childFieldGenerator.generateNestedSubObject()
&& objectDepth < specification.maxObjectDepth()
- && nestedFieldsCount < specification.nestedFieldsLimit();
+ && nestedFieldsCount.get() < specification.nestedFieldsLimit();
}
public Optional generateObjectArray() {
From ed515138160da2b2431fd93462d3f3b7178e2e1b Mon Sep 17 00:00:00 2001
From: Nik Everett
Date: Tue, 27 Aug 2024 10:57:17 -0400
Subject: [PATCH 17/46] ESQL: Remove `LogicalPlan` from old serialization
(#112237)
This removes `LogicalPlan` subclasses from `PlanNamedTypes` because it
is no longer used.
---
.../xpack/esql/io/stream/PlanNamedTypes.java | 35 +------------
.../esql/io/stream/PlanNamedTypesTests.java | 52 -------------------
2 files changed, 1 insertion(+), 86 deletions(-)
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java
index 180ba8c028e6a..77d982453203c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java
@@ -23,24 +23,9 @@
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.Order;
import org.elasticsearch.xpack.esql.index.EsIndex;
-import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
-import org.elasticsearch.xpack.esql.plan.logical.Dissect;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
-import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
-import org.elasticsearch.xpack.esql.plan.logical.Eval;
-import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.Grok;
-import org.elasticsearch.xpack.esql.plan.logical.InlineStats;
-import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
-import org.elasticsearch.xpack.esql.plan.logical.Lookup;
-import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
-import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
-import org.elasticsearch.xpack.esql.plan.logical.Project;
-import org.elasticsearch.xpack.esql.plan.logical.TopN;
-import org.elasticsearch.xpack.esql.plan.logical.join.Join;
-import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
-import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.DissectExec;
import org.elasticsearch.xpack.esql.plan.physical.EnrichExec;
@@ -132,25 +117,7 @@ public static List<PlanNameRegistry.Entry> namedTypeEntries() {
of(PhysicalPlan.class, ProjectExec.class, PlanNamedTypes::writeProjectExec, PlanNamedTypes::readProjectExec),
of(PhysicalPlan.class, RowExec.class, PlanNamedTypes::writeRowExec, PlanNamedTypes::readRowExec),
of(PhysicalPlan.class, ShowExec.class, PlanNamedTypes::writeShowExec, PlanNamedTypes::readShowExec),
- of(PhysicalPlan.class, TopNExec.class, PlanNamedTypes::writeTopNExec, PlanNamedTypes::readTopNExec),
- // Logical Plan Nodes - a subset of plans that end up being actually serialized
- of(LogicalPlan.class, Aggregate.ENTRY),
- of(LogicalPlan.class, Dissect.ENTRY),
- of(LogicalPlan.class, EsRelation.ENTRY),
- of(LogicalPlan.class, Eval.ENTRY),
- of(LogicalPlan.class, Enrich.ENTRY),
- of(LogicalPlan.class, EsqlProject.ENTRY),
- of(LogicalPlan.class, Filter.ENTRY),
- of(LogicalPlan.class, Grok.ENTRY),
- of(LogicalPlan.class, InlineStats.ENTRY),
- of(LogicalPlan.class, Join.ENTRY),
- of(LogicalPlan.class, Limit.ENTRY),
- of(LogicalPlan.class, LocalRelation.ENTRY),
- of(LogicalPlan.class, Lookup.ENTRY),
- of(LogicalPlan.class, MvExpand.ENTRY),
- of(LogicalPlan.class, OrderBy.ENTRY),
- of(LogicalPlan.class, Project.ENTRY),
- of(LogicalPlan.class, TopN.ENTRY)
+ of(PhysicalPlan.class, TopNExec.class, PlanNamedTypes::writeTopNExec, PlanNamedTypes::readTopNExec)
);
return declared;
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java
index e5f195b053349..56ab1bd41693e 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java
@@ -38,24 +38,6 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
-import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
-import org.elasticsearch.xpack.esql.plan.logical.Dissect;
-import org.elasticsearch.xpack.esql.plan.logical.Enrich;
-import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
-import org.elasticsearch.xpack.esql.plan.logical.Eval;
-import org.elasticsearch.xpack.esql.plan.logical.Filter;
-import org.elasticsearch.xpack.esql.plan.logical.Grok;
-import org.elasticsearch.xpack.esql.plan.logical.InlineStats;
-import org.elasticsearch.xpack.esql.plan.logical.Limit;
-import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
-import org.elasticsearch.xpack.esql.plan.logical.Lookup;
-import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
-import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
-import org.elasticsearch.xpack.esql.plan.logical.Project;
-import org.elasticsearch.xpack.esql.plan.logical.TopN;
-import org.elasticsearch.xpack.esql.plan.logical.join.Join;
-import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
-import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.DissectExec;
import org.elasticsearch.xpack.esql.plan.physical.EnrichExec;
@@ -130,40 +112,6 @@ public void testPhysicalPlanEntries() {
assertMap(actual, matchesList(expected));
}
- // List of known serializable logical plan nodes - this should be kept up to date or retrieved
- // programmatically.
- public static final List<Class<? extends LogicalPlan>> LOGICAL_PLAN_NODE_CLS = List.of(
- Aggregate.class,
- Dissect.class,
- Enrich.class,
- EsRelation.class,
- EsqlProject.class,
- Eval.class,
- Filter.class,
- Grok.class,
- InlineStats.class,
- Join.class,
- Limit.class,
- LocalRelation.class,
- Lookup.class,
- MvExpand.class,
- OrderBy.class,
- Project.class,
- TopN.class
- );
-
- // Tests that all logical plan nodes have a suitably named serialization entry.
- public void testLogicalPlanEntries() {
- var expected = LOGICAL_PLAN_NODE_CLS.stream().map(Class::getSimpleName).toList();
- var actual = PlanNamedTypes.namedTypeEntries()
- .stream()
- .filter(e -> e.categoryClass().isAssignableFrom(LogicalPlan.class))
- .map(PlanNameRegistry.Entry::name)
- .sorted()
- .toList();
- assertMap(actual, matchesList(expected));
- }
-
// Tests that all names are unique - there should be a good reason if this is not the case.
public void testUniqueNames() {
var actual = PlanNamedTypes.namedTypeEntries().stream().map(PlanNameRegistry.Entry::name).distinct().toList();
From bd2d6aa55fdf839ca42ebf04a6493732b6c94b24 Mon Sep 17 00:00:00 2001
From: Lee Hinman
Date: Tue, 27 Aug 2024 09:14:49 -0600
Subject: [PATCH 18/46] Fix template alias parsing livelock (#112217)
* Fix template alias parsing livelock
This commit fixes an issue with templates parsing alias definitions that can cause the ES thread to
hang indefinitely. Due to the malformed alias definition, the parsing gets into a loop which never
exits. In this commit a null check in both the component template and alias parsing code is added,
which prevents the looping.
---
docs/changelog/112217.yaml | 5 +++++
.../cluster/metadata/AliasMetadata.java | 2 ++
.../cluster/metadata/Template.java | 6 +++++-
.../metadata/ComponentTemplateTests.java | 19 +++++++++++++++++++
4 files changed, 31 insertions(+), 1 deletion(-)
create mode 100644 docs/changelog/112217.yaml
diff --git a/docs/changelog/112217.yaml b/docs/changelog/112217.yaml
new file mode 100644
index 0000000000000..bb367d6128001
--- /dev/null
+++ b/docs/changelog/112217.yaml
@@ -0,0 +1,5 @@
+pr: 112217
+summary: Fix template alias parsing livelock
+area: Indices APIs
+type: bug
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/AliasMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/AliasMetadata.java
index a0f4a929dafdb..ff412d629b3b1 100644
--- a/server/src/main/java/org/elasticsearch/cluster/metadata/AliasMetadata.java
+++ b/server/src/main/java/org/elasticsearch/cluster/metadata/AliasMetadata.java
@@ -396,6 +396,8 @@ public static AliasMetadata fromXContent(XContentParser parser) throws IOExcepti
} else if ("is_hidden".equals(currentFieldName)) {
builder.isHidden(parser.booleanValue());
}
+ } else if (token == null) {
+ throw new IllegalArgumentException("unexpected null token while parsing alias");
}
}
return builder.build();
diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/Template.java b/server/src/main/java/org/elasticsearch/cluster/metadata/Template.java
index 70440adc4ebbe..b044ef6042428 100644
--- a/server/src/main/java/org/elasticsearch/cluster/metadata/Template.java
+++ b/server/src/main/java/org/elasticsearch/cluster/metadata/Template.java
@@ -70,7 +70,11 @@ public class Template implements SimpleDiffable<Template>, ToXContentObject {
}, MAPPINGS, ObjectParser.ValueType.VALUE_OBJECT_ARRAY);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
Map<String, AliasMetadata> aliasMap = new HashMap<>();
- while ((p.nextToken()) != XContentParser.Token.END_OBJECT) {
+ XContentParser.Token token;
+ while ((token = p.nextToken()) != XContentParser.Token.END_OBJECT) {
+ if (token == null) {
+ break;
+ }
AliasMetadata alias = AliasMetadata.Builder.fromXContent(p);
aliasMap.put(alias.alias(), alias);
}
diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/ComponentTemplateTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/ComponentTemplateTests.java
index b93ccb0f978af..6ca267e5c9df2 100644
--- a/server/src/test/java/org/elasticsearch/cluster/metadata/ComponentTemplateTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/metadata/ComponentTemplateTests.java
@@ -24,6 +24,7 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import java.io.IOException;
@@ -307,4 +308,22 @@ public void testXContentSerializationWithRolloverAndEffectiveRetention() throws
assertThat(serialized, not(containsString("effective_retention")));
}
}
+
+ public void testHangingParsing() throws IOException {
+ String cutDown = """
+ {
+ "template": {
+ "aliases": {
+ "foo": "bar"
+ },
+ "food": "eggplant"
+ },
+ "potato": true
+ }
+ """;
+
+ try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, cutDown)) {
+ expectThrows(Exception.class, () -> ComponentTemplate.parse(parser));
+ }
+ }
}
From ae41e9ab65ec4b662f21810e16292742897ee674 Mon Sep 17 00:00:00 2001
From: Patrick Doyle <810052+prdoyle@users.noreply.github.com>
Date: Tue, 27 Aug 2024 11:22:54 -0400
Subject: [PATCH 19/46] Pluggable BuiltInExecutorBuilders (#111939)
* Refactor: move static calculations to Util
* BuiltInExecutorBuilders
* Spotless
* Change to getBuilders
* Move helper functions back into ThreadPool
---
...sAvailabilityHealthIndicatorBenchmark.java | 3 +-
.../threadpool/ThreadPoolBridge.java | 3 +-
.../geoip/EnterpriseGeoIpDownloaderTests.java | 7 +-
.../ingest/geoip/GeoIpDownloaderTests.java | 7 +-
.../Netty4SizeHeaderFrameDecoderTests.java | 3 +-
.../ingest/LogstashInternalBridge.java | 5 +-
.../elasticsearch/node/NodeConstruction.java | 3 +
.../DefaultBuiltInExecutorBuilders.java | 215 ++++++++++++++++++
.../elasticsearch/threadpool/ThreadPool.java | 124 +---------
.../internal/BuiltInExecutorBuilders.java | 19 ++
.../threadpool/internal/package-info.java | 13 ++
.../TransportMultiSearchActionTests.java | 5 +-
.../search/TransportSearchActionTests.java | 3 +-
.../TransportActionFilterChainTests.java | 4 +-
.../action/support/TransportActionTests.java | 4 +-
.../AbstractClientHeadersTestCase.java | 3 +-
.../http/HttpClientStatsTrackerTests.java | 7 +-
.../threadpool/FixedThreadPoolTests.java | 2 +-
.../threadpool/ScalingThreadPoolTests.java | 2 +-
.../ScheduleWithFixedDelayTests.java | 6 +-
.../ThreadPoolSerializationTests.java | 2 +-
.../UpdateThreadPoolSettingsTests.java | 13 +-
.../ClusterConnectionManagerTests.java | 3 +-
.../threadpool/TestThreadPool.java | 7 +-
...seGeoIpDownloaderLicenseListenerTests.java | 7 +-
.../authc/AuthenticationServiceTests.java | 2 +
.../security/authc/TokenServiceTests.java | 2 +
...InternalEnrollmentTokenGeneratorTests.java | 2 +
.../apikey/RestCreateApiKeyActionTests.java | 3 +-
.../apikey/RestGetApiKeyActionTests.java | 3 +-
.../RestInvalidateApiKeyActionTests.java | 3 +-
.../apikey/RestQueryApiKeyActionTests.java | 3 +-
.../SecurityNetty4HeaderSizeLimitTests.java | 3 +-
33 files changed, 345 insertions(+), 146 deletions(-)
create mode 100644 server/src/main/java/org/elasticsearch/threadpool/DefaultBuiltInExecutorBuilders.java
create mode 100644 server/src/main/java/org/elasticsearch/threadpool/internal/BuiltInExecutorBuilders.java
create mode 100644 server/src/main/java/org/elasticsearch/threadpool/internal/package-info.java
diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/routing/allocation/ShardsAvailabilityHealthIndicatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/routing/allocation/ShardsAvailabilityHealthIndicatorBenchmark.java
index 8c5de05a01648..d7a72615f4b93 100644
--- a/benchmarks/src/main/java/org/elasticsearch/benchmark/routing/allocation/ShardsAvailabilityHealthIndicatorBenchmark.java
+++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/routing/allocation/ShardsAvailabilityHealthIndicatorBenchmark.java
@@ -32,6 +32,7 @@
import org.elasticsearch.indices.SystemIndices;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.telemetry.metric.MeterRegistry;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
@@ -167,7 +168,7 @@ public void setUp() throws Exception {
.build();
Settings settings = Settings.builder().put("node.name", ShardsAvailabilityHealthIndicatorBenchmark.class.getSimpleName()).build();
- ThreadPool threadPool = new ThreadPool(settings, MeterRegistry.NOOP);
+ ThreadPool threadPool = new ThreadPool(settings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
ClusterService clusterService = new ClusterService(
Settings.EMPTY,
diff --git a/libs/logstash-bridge/src/main/java/org/elasticsearch/logstashbridge/threadpool/ThreadPoolBridge.java b/libs/logstash-bridge/src/main/java/org/elasticsearch/logstashbridge/threadpool/ThreadPoolBridge.java
index 13218a9b206a5..30801b4f0b078 100644
--- a/libs/logstash-bridge/src/main/java/org/elasticsearch/logstashbridge/threadpool/ThreadPoolBridge.java
+++ b/libs/logstash-bridge/src/main/java/org/elasticsearch/logstashbridge/threadpool/ThreadPoolBridge.java
@@ -10,6 +10,7 @@
import org.elasticsearch.logstashbridge.StableBridgeAPI;
import org.elasticsearch.logstashbridge.common.SettingsBridge;
import org.elasticsearch.telemetry.metric.MeterRegistry;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import java.util.concurrent.TimeUnit;
@@ -17,7 +18,7 @@
public class ThreadPoolBridge extends StableBridgeAPI.Proxy<ThreadPool> {
public ThreadPoolBridge(final SettingsBridge settingsBridge) {
- this(new ThreadPool(settingsBridge.unwrap(), MeterRegistry.NOOP));
+ this(new ThreadPool(settingsBridge.unwrap(), MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders()));
}
public ThreadPoolBridge(final ThreadPool delegate) {
diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloaderTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloaderTests.java
index 203ecaea72c0e..1676ce14698a9 100644
--- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloaderTests.java
+++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpDownloaderTests.java
@@ -40,6 +40,7 @@
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.hamcrest.Matchers;
@@ -86,7 +87,11 @@ public void setup() throws IOException {
"e4a3411cdd7b21eaf18675da5a7f9f360d33c6882363b2c19c38715834c9e836 GeoIP2-City_20240709.tar.gz".getBytes(StandardCharsets.UTF_8)
);
clusterService = mock(ClusterService.class);
- threadPool = new ThreadPool(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "test").build(), MeterRegistry.NOOP);
+ threadPool = new ThreadPool(
+ Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "test").build(),
+ MeterRegistry.NOOP,
+ new DefaultBuiltInExecutorBuilders()
+ );
when(clusterService.getClusterSettings()).thenReturn(
new ClusterSettings(Settings.EMPTY, Set.of(GeoIpDownloaderTaskExecutor.POLL_INTERVAL_SETTING))
);
diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java
index 984bd37181fe7..f213868fb65a1 100644
--- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java
+++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java
@@ -45,6 +45,7 @@
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentType;
@@ -92,7 +93,11 @@ public void setup() throws IOException {
httpClient = mock(HttpClient.class);
when(httpClient.getBytes(anyString())).thenReturn("[]".getBytes(StandardCharsets.UTF_8));
clusterService = mock(ClusterService.class);
- threadPool = new ThreadPool(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "test").build(), MeterRegistry.NOOP);
+ threadPool = new ThreadPool(
+ Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "test").build(),
+ MeterRegistry.NOOP,
+ new DefaultBuiltInExecutorBuilders()
+ );
when(clusterService.getClusterSettings()).thenReturn(
new ClusterSettings(
Settings.EMPTY,
diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4SizeHeaderFrameDecoderTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4SizeHeaderFrameDecoderTests.java
index 3e74a74dbd49c..ce7704e6e040c 100644
--- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4SizeHeaderFrameDecoderTests.java
+++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4SizeHeaderFrameDecoderTests.java
@@ -19,6 +19,7 @@
import org.elasticsearch.mocksocket.MockSocket;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportSettings;
import org.junit.After;
@@ -52,7 +53,7 @@ public class Netty4SizeHeaderFrameDecoderTests extends ESTestCase {
@Before
public void startThreadPool() {
- threadPool = new ThreadPool(settings, MeterRegistry.NOOP);
+ threadPool = new ThreadPool(settings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
NetworkService networkService = new NetworkService(Collections.emptyList());
PageCacheRecycler recycler = new MockPageCacheRecycler(Settings.EMPTY);
nettyTransport = new Netty4Transport(
diff --git a/server/src/main/java/org/elasticsearch/ingest/LogstashInternalBridge.java b/server/src/main/java/org/elasticsearch/ingest/LogstashInternalBridge.java
index 889a4ffb932f9..af0cb187ba05d 100644
--- a/server/src/main/java/org/elasticsearch/ingest/LogstashInternalBridge.java
+++ b/server/src/main/java/org/elasticsearch/ingest/LogstashInternalBridge.java
@@ -10,6 +10,7 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.telemetry.metric.MeterRegistry;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
/**
@@ -40,9 +41,9 @@ public static void resetReroute(final IngestDocument ingestDocument) {
/**
* @param settings
- * @return a new {@link ThreadPool} with a noop {@link MeterRegistry}
+ * @return a new {@link ThreadPool} with a noop {@link MeterRegistry} and default executors
*/
public static ThreadPool createThreadPool(final Settings settings) {
- return new ThreadPool(settings, MeterRegistry.NOOP);
+ return new ThreadPool(settings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
}
}
diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java
index 9c5b72a573d44..ec0d293dc0064 100644
--- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java
+++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java
@@ -199,8 +199,10 @@
import org.elasticsearch.telemetry.TelemetryProvider;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.telemetry.tracing.Tracer;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.threadpool.internal.BuiltInExecutorBuilders;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.upgrades.SystemIndexMigrationExecutor;
@@ -480,6 +482,7 @@ private ThreadPool createThreadPool(Settings settings, MeterRegistry meterRegist
ThreadPool threadPool = new ThreadPool(
settings,
meterRegistry,
+ pluginsService.loadSingletonServiceProvider(BuiltInExecutorBuilders.class, DefaultBuiltInExecutorBuilders::new),
pluginsService.flatMap(p -> p.getExecutorBuilders(settings)).toArray(ExecutorBuilder<?>[]::new)
);
resourcesToClose.add(() -> ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS));
diff --git a/server/src/main/java/org/elasticsearch/threadpool/DefaultBuiltInExecutorBuilders.java b/server/src/main/java/org/elasticsearch/threadpool/DefaultBuiltInExecutorBuilders.java
new file mode 100644
index 0000000000000..bf7e2a7bbdc86
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/threadpool/DefaultBuiltInExecutorBuilders.java
@@ -0,0 +1,215 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.threadpool;
+
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.threadpool.internal.BuiltInExecutorBuilders;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static java.util.Collections.unmodifiableMap;
+import static org.elasticsearch.threadpool.ThreadPool.searchAutoscalingEWMA;
+
+public class DefaultBuiltInExecutorBuilders implements BuiltInExecutorBuilders {
+ @Override
+ @SuppressWarnings("rawtypes")
+ public Map<String, ExecutorBuilder> getBuilders(Settings settings, int allocatedProcessors) {
+ final int halfProc = ThreadPool.halfAllocatedProcessors(allocatedProcessors);
+ final int halfProcMaxAt5 = ThreadPool.halfAllocatedProcessorsMaxFive(allocatedProcessors);
+ final int halfProcMaxAt10 = ThreadPool.halfAllocatedProcessorsMaxTen(allocatedProcessors);
+ final int genericThreadPoolMax = ThreadPool.boundedBy(4 * allocatedProcessors, 128, 512);
+
+ Map<String, ExecutorBuilder> result = new HashMap<>();
+ result.put(
+ ThreadPool.Names.GENERIC,
+ new ScalingExecutorBuilder(ThreadPool.Names.GENERIC, 4, genericThreadPoolMax, TimeValue.timeValueSeconds(30), false)
+ );
+ result.put(
+ ThreadPool.Names.WRITE,
+ new FixedExecutorBuilder(
+ settings,
+ ThreadPool.Names.WRITE,
+ allocatedProcessors,
+ 10000,
+ new EsExecutors.TaskTrackingConfig(true, 0.1)
+ )
+ );
+ int searchOrGetThreadPoolSize = ThreadPool.searchOrGetThreadPoolSize(allocatedProcessors);
+ result.put(
+ ThreadPool.Names.GET,
+ new FixedExecutorBuilder(
+ settings,
+ ThreadPool.Names.GET,
+ searchOrGetThreadPoolSize,
+ 1000,
+ EsExecutors.TaskTrackingConfig.DO_NOT_TRACK
+ )
+ );
+ result.put(
+ ThreadPool.Names.ANALYZE,
+ new FixedExecutorBuilder(settings, ThreadPool.Names.ANALYZE, 1, 16, EsExecutors.TaskTrackingConfig.DO_NOT_TRACK)
+ );
+ result.put(
+ ThreadPool.Names.SEARCH,
+ new FixedExecutorBuilder(
+ settings,
+ ThreadPool.Names.SEARCH,
+ searchOrGetThreadPoolSize,
+ 1000,
+ new EsExecutors.TaskTrackingConfig(true, searchAutoscalingEWMA)
+ )
+ );
+ result.put(
+ ThreadPool.Names.SEARCH_WORKER,
+ new FixedExecutorBuilder(
+ settings,
+ ThreadPool.Names.SEARCH_WORKER,
+ searchOrGetThreadPoolSize,
+ -1,
+ EsExecutors.TaskTrackingConfig.DEFAULT
+ )
+ );
+ result.put(
+ ThreadPool.Names.SEARCH_COORDINATION,
+ new FixedExecutorBuilder(
+ settings,
+ ThreadPool.Names.SEARCH_COORDINATION,
+ halfProc,
+ 1000,
+ new EsExecutors.TaskTrackingConfig(true, searchAutoscalingEWMA)
+ )
+ );
+ result.put(
+ ThreadPool.Names.AUTO_COMPLETE,
+ new FixedExecutorBuilder(
+ settings,
+ ThreadPool.Names.AUTO_COMPLETE,
+ Math.max(allocatedProcessors / 4, 1),
+ 100,
+ EsExecutors.TaskTrackingConfig.DEFAULT
+ )
+ );
+ result.put(
+ ThreadPool.Names.SEARCH_THROTTLED,
+ new FixedExecutorBuilder(settings, ThreadPool.Names.SEARCH_THROTTLED, 1, 100, EsExecutors.TaskTrackingConfig.DEFAULT)
+ );
+ result.put(
+ ThreadPool.Names.MANAGEMENT,
+ new ScalingExecutorBuilder(
+ ThreadPool.Names.MANAGEMENT,
+ 1,
+ ThreadPool.boundedBy(allocatedProcessors, 1, 5),
+ TimeValue.timeValueMinutes(5),
+ false
+ )
+ );
+ result.put(
+ ThreadPool.Names.FLUSH,
+ new ScalingExecutorBuilder(ThreadPool.Names.FLUSH, 1, halfProcMaxAt5, TimeValue.timeValueMinutes(5), false)
+ );
+ // TODO: remove (or refine) this temporary stateless custom refresh pool sizing once ES-7631 is solved.
+ final int refreshThreads = DiscoveryNode.isStateless(settings) ? allocatedProcessors : halfProcMaxAt10;
+ result.put(
+ ThreadPool.Names.REFRESH,
+ new ScalingExecutorBuilder(ThreadPool.Names.REFRESH, 1, refreshThreads, TimeValue.timeValueMinutes(5), false)
+ );
+ result.put(
+ ThreadPool.Names.WARMER,
+ new ScalingExecutorBuilder(ThreadPool.Names.WARMER, 1, halfProcMaxAt5, TimeValue.timeValueMinutes(5), false)
+ );
+ final int maxSnapshotCores = ThreadPool.getMaxSnapshotThreadPoolSize(allocatedProcessors);
+ result.put(
+ ThreadPool.Names.SNAPSHOT,
+ new ScalingExecutorBuilder(ThreadPool.Names.SNAPSHOT, 1, maxSnapshotCores, TimeValue.timeValueMinutes(5), false)
+ );
+ result.put(
+ ThreadPool.Names.SNAPSHOT_META,
+ new ScalingExecutorBuilder(
+ ThreadPool.Names.SNAPSHOT_META,
+ 1,
+ Math.min(allocatedProcessors * 3, 50),
+ TimeValue.timeValueSeconds(30L),
+ false
+ )
+ );
+ result.put(
+ ThreadPool.Names.FETCH_SHARD_STARTED,
+ new ScalingExecutorBuilder(
+ ThreadPool.Names.FETCH_SHARD_STARTED,
+ 1,
+ 2 * allocatedProcessors,
+ TimeValue.timeValueMinutes(5),
+ false
+ )
+ );
+ result.put(
+ ThreadPool.Names.FORCE_MERGE,
+ new FixedExecutorBuilder(
+ settings,
+ ThreadPool.Names.FORCE_MERGE,
+ ThreadPool.oneEighthAllocatedProcessors(allocatedProcessors),
+ -1,
+ EsExecutors.TaskTrackingConfig.DO_NOT_TRACK
+ )
+ );
+ result.put(
+ ThreadPool.Names.CLUSTER_COORDINATION,
+ new FixedExecutorBuilder(settings, ThreadPool.Names.CLUSTER_COORDINATION, 1, -1, EsExecutors.TaskTrackingConfig.DO_NOT_TRACK)
+ );
+ result.put(
+ ThreadPool.Names.FETCH_SHARD_STORE,
+ new ScalingExecutorBuilder(ThreadPool.Names.FETCH_SHARD_STORE, 1, 2 * allocatedProcessors, TimeValue.timeValueMinutes(5), false)
+ );
+ result.put(
+ ThreadPool.Names.SYSTEM_READ,
+ new FixedExecutorBuilder(
+ settings,
+ ThreadPool.Names.SYSTEM_READ,
+ halfProcMaxAt5,
+ 2000,
+ EsExecutors.TaskTrackingConfig.DO_NOT_TRACK
+ )
+ );
+ result.put(
+ ThreadPool.Names.SYSTEM_WRITE,
+ new FixedExecutorBuilder(
+ settings,
+ ThreadPool.Names.SYSTEM_WRITE,
+ halfProcMaxAt5,
+ 1000,
+ new EsExecutors.TaskTrackingConfig(true, 0.1)
+ )
+ );
+ result.put(
+ ThreadPool.Names.SYSTEM_CRITICAL_READ,
+ new FixedExecutorBuilder(
+ settings,
+ ThreadPool.Names.SYSTEM_CRITICAL_READ,
+ halfProcMaxAt5,
+ 2000,
+ EsExecutors.TaskTrackingConfig.DO_NOT_TRACK
+ )
+ );
+ result.put(
+ ThreadPool.Names.SYSTEM_CRITICAL_WRITE,
+ new FixedExecutorBuilder(
+ settings,
+ ThreadPool.Names.SYSTEM_CRITICAL_WRITE,
+ halfProcMaxAt5,
+ 1500,
+ new EsExecutors.TaskTrackingConfig(true, 0.1)
+ )
+ );
+ return unmodifiableMap(result);
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java b/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java
index 29ab3ec7e0848..859c4a3e924c6 100644
--- a/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java
+++ b/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java
@@ -10,7 +10,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
-import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
@@ -20,7 +19,6 @@
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.SizeValue;
import org.elasticsearch.common.util.concurrent.EsExecutors;
-import org.elasticsearch.common.util.concurrent.EsExecutors.TaskTrackingConfig;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionHandler;
import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
@@ -35,6 +33,7 @@
import org.elasticsearch.telemetry.metric.LongGauge;
import org.elasticsearch.telemetry.metric.LongWithAttributes;
import org.elasticsearch.telemetry.metric.MeterRegistry;
+import org.elasticsearch.threadpool.internal.BuiltInExecutorBuilders;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentBuilder;
@@ -229,125 +228,22 @@ public Collection<ExecutorBuilder> builders() {
*
* @param settings
* @param meterRegistry
+ * @param builtInExecutorBuilders used to construct builders for the built-in thread pools
* @param customBuilders a list of additional thread pool builders that were defined elsewhere (like a Plugin).
*/
@SuppressWarnings({ "rawtypes", "unchecked" })
- public ThreadPool(final Settings settings, MeterRegistry meterRegistry, final ExecutorBuilder<?>... customBuilders) {
+ public ThreadPool(
+ final Settings settings,
+ MeterRegistry meterRegistry,
+ BuiltInExecutorBuilders builtInExecutorBuilders,
+ final ExecutorBuilder<?>... customBuilders
+ ) {
assert Node.NODE_NAME_SETTING.exists(settings);
- final Map<String, ExecutorBuilder> builders = new HashMap<>();
- final int allocatedProcessors = EsExecutors.allocatedProcessors(settings);
- final int halfProc = halfAllocatedProcessors(allocatedProcessors);
- final int halfProcMaxAt5 = halfAllocatedProcessorsMaxFive(allocatedProcessors);
- final int halfProcMaxAt10 = halfAllocatedProcessorsMaxTen(allocatedProcessors);
- final int genericThreadPoolMax = boundedBy(4 * allocatedProcessors, 128, 512);
final Map<String, List<Instrument>> instruments = new HashMap<>();
+ final int allocatedProcessors = EsExecutors.allocatedProcessors(settings);
- builders.put(
- Names.GENERIC,
- new ScalingExecutorBuilder(Names.GENERIC, 4, genericThreadPoolMax, TimeValue.timeValueSeconds(30), false)
- );
- builders.put(
- Names.WRITE,
- new FixedExecutorBuilder(settings, Names.WRITE, allocatedProcessors, 10000, new TaskTrackingConfig(true, 0.1))
- );
- int searchOrGetThreadPoolSize = searchOrGetThreadPoolSize(allocatedProcessors);
- builders.put(
- Names.GET,
- new FixedExecutorBuilder(settings, Names.GET, searchOrGetThreadPoolSize, 1000, TaskTrackingConfig.DO_NOT_TRACK)
- );
- builders.put(Names.ANALYZE, new FixedExecutorBuilder(settings, Names.ANALYZE, 1, 16, TaskTrackingConfig.DO_NOT_TRACK));
- builders.put(
- Names.SEARCH,
- new FixedExecutorBuilder(
- settings,
- Names.SEARCH,
- searchOrGetThreadPoolSize,
- 1000,
- new TaskTrackingConfig(true, searchAutoscalingEWMA)
- )
- );
- builders.put(
- Names.SEARCH_WORKER,
- new FixedExecutorBuilder(settings, Names.SEARCH_WORKER, searchOrGetThreadPoolSize, -1, TaskTrackingConfig.DEFAULT)
- );
- builders.put(
- Names.SEARCH_COORDINATION,
- new FixedExecutorBuilder(
- settings,
- Names.SEARCH_COORDINATION,
- halfProc,
- 1000,
- new TaskTrackingConfig(true, searchAutoscalingEWMA)
- )
- );
- builders.put(
- Names.AUTO_COMPLETE,
- new FixedExecutorBuilder(settings, Names.AUTO_COMPLETE, Math.max(allocatedProcessors / 4, 1), 100, TaskTrackingConfig.DEFAULT)
- );
- builders.put(
- Names.SEARCH_THROTTLED,
- new FixedExecutorBuilder(settings, Names.SEARCH_THROTTLED, 1, 100, TaskTrackingConfig.DEFAULT)
- );
- builders.put(
- Names.MANAGEMENT,
- new ScalingExecutorBuilder(Names.MANAGEMENT, 1, boundedBy(allocatedProcessors, 1, 5), TimeValue.timeValueMinutes(5), false)
- );
- builders.put(Names.FLUSH, new ScalingExecutorBuilder(Names.FLUSH, 1, halfProcMaxAt5, TimeValue.timeValueMinutes(5), false));
- // TODO: remove (or refine) this temporary stateless custom refresh pool sizing once ES-7631 is solved.
- final int refreshThreads = DiscoveryNode.isStateless(settings) ? allocatedProcessors : halfProcMaxAt10;
- builders.put(Names.REFRESH, new ScalingExecutorBuilder(Names.REFRESH, 1, refreshThreads, TimeValue.timeValueMinutes(5), false));
- builders.put(Names.WARMER, new ScalingExecutorBuilder(Names.WARMER, 1, halfProcMaxAt5, TimeValue.timeValueMinutes(5), false));
- final int maxSnapshotCores = getMaxSnapshotThreadPoolSize(allocatedProcessors);
- builders.put(Names.SNAPSHOT, new ScalingExecutorBuilder(Names.SNAPSHOT, 1, maxSnapshotCores, TimeValue.timeValueMinutes(5), false));
- builders.put(
- Names.SNAPSHOT_META,
- new ScalingExecutorBuilder(
- Names.SNAPSHOT_META,
- 1,
- Math.min(allocatedProcessors * 3, 50),
- TimeValue.timeValueSeconds(30L),
- false
- )
- );
- builders.put(
- Names.FETCH_SHARD_STARTED,
- new ScalingExecutorBuilder(Names.FETCH_SHARD_STARTED, 1, 2 * allocatedProcessors, TimeValue.timeValueMinutes(5), false)
- );
- builders.put(
- Names.FORCE_MERGE,
- new FixedExecutorBuilder(
- settings,
- Names.FORCE_MERGE,
- oneEighthAllocatedProcessors(allocatedProcessors),
- -1,
- TaskTrackingConfig.DO_NOT_TRACK
- )
- );
- builders.put(
- Names.CLUSTER_COORDINATION,
- new FixedExecutorBuilder(settings, Names.CLUSTER_COORDINATION, 1, -1, TaskTrackingConfig.DO_NOT_TRACK)
- );
- builders.put(
- Names.FETCH_SHARD_STORE,
- new ScalingExecutorBuilder(Names.FETCH_SHARD_STORE, 1, 2 * allocatedProcessors, TimeValue.timeValueMinutes(5), false)
- );
- builders.put(
- Names.SYSTEM_READ,
- new FixedExecutorBuilder(settings, Names.SYSTEM_READ, halfProcMaxAt5, 2000, TaskTrackingConfig.DO_NOT_TRACK)
- );
- builders.put(
- Names.SYSTEM_WRITE,
- new FixedExecutorBuilder(settings, Names.SYSTEM_WRITE, halfProcMaxAt5, 1000, new TaskTrackingConfig(true, 0.1))
- );
- builders.put(
- Names.SYSTEM_CRITICAL_READ,
- new FixedExecutorBuilder(settings, Names.SYSTEM_CRITICAL_READ, halfProcMaxAt5, 2000, TaskTrackingConfig.DO_NOT_TRACK)
- );
- builders.put(
- Names.SYSTEM_CRITICAL_WRITE,
- new FixedExecutorBuilder(settings, Names.SYSTEM_CRITICAL_WRITE, halfProcMaxAt5, 1500, new TaskTrackingConfig(true, 0.1))
- );
+ final Map<String, ExecutorBuilder> builders = new HashMap<>(builtInExecutorBuilders.getBuilders(settings, allocatedProcessors));
for (final ExecutorBuilder<?> builder : customBuilders) {
if (builders.containsKey(builder.name())) {
diff --git a/server/src/main/java/org/elasticsearch/threadpool/internal/BuiltInExecutorBuilders.java b/server/src/main/java/org/elasticsearch/threadpool/internal/BuiltInExecutorBuilders.java
new file mode 100644
index 0000000000000..6709685af6a75
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/threadpool/internal/BuiltInExecutorBuilders.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.threadpool.internal;
+
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.threadpool.ExecutorBuilder;
+
+import java.util.Map;
+
+public interface BuiltInExecutorBuilders {
+ @SuppressWarnings("rawtypes")
+ Map<String, ExecutorBuilder> getBuilders(Settings settings, int allocatedProcessors);
+}
diff --git a/server/src/main/java/org/elasticsearch/threadpool/internal/package-info.java b/server/src/main/java/org/elasticsearch/threadpool/internal/package-info.java
new file mode 100644
index 0000000000000..821e27a4a854f
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/threadpool/internal/package-info.java
@@ -0,0 +1,13 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+/**
+ * This package can be selectively exported to control access to features
+ * that are not intended to be available to any arbitrary plugin.
+ */
+package org.elasticsearch.threadpool.internal;
diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportMultiSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportMultiSearchActionTests.java
index 4fd3221ccb1b8..7694553f4de21 100644
--- a/server/src/test/java/org/elasticsearch/action/search/TransportMultiSearchActionTests.java
+++ b/server/src/test/java/org/elasticsearch/action/search/TransportMultiSearchActionTests.java
@@ -28,6 +28,7 @@
import org.elasticsearch.tasks.Task;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportService;
@@ -54,7 +55,7 @@ public void testParentTaskId() throws Exception {
Settings settings = Settings.builder().put("node.name", TransportMultiSearchActionTests.class.getSimpleName()).build();
ActionFilters actionFilters = mock(ActionFilters.class);
when(actionFilters.filters()).thenReturn(new ActionFilter[0]);
- ThreadPool threadPool = new ThreadPool(settings, MeterRegistry.NOOP);
+ ThreadPool threadPool = new ThreadPool(settings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
try {
TransportService transportService = new TransportService(
Settings.EMPTY,
@@ -121,7 +122,7 @@ public void testBatchExecute() throws ExecutionException, InterruptedException {
Settings settings = Settings.builder().put("node.name", TransportMultiSearchActionTests.class.getSimpleName()).build();
ActionFilters actionFilters = mock(ActionFilters.class);
when(actionFilters.filters()).thenReturn(new ActionFilter[0]);
- ThreadPool threadPool = new ThreadPool(settings, MeterRegistry.NOOP);
+ ThreadPool threadPool = new ThreadPool(settings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
TransportService transportService = new TransportService(
Settings.EMPTY,
mock(Transport.class),
diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java
index 6621f2055968f..487d8c6f3a7ee 100644
--- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java
+++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java
@@ -85,6 +85,7 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.TransportVersionUtils;
import org.elasticsearch.test.transport.MockTransportService;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.NodeDisconnectedException;
@@ -1722,7 +1723,7 @@ public void testCCSCompatibilityCheck() throws Exception {
ActionFilters actionFilters = mock(ActionFilters.class);
when(actionFilters.filters()).thenReturn(new ActionFilter[0]);
TransportVersion transportVersion = TransportVersionUtils.getNextVersion(TransportVersions.MINIMUM_CCS_VERSION, true);
- ThreadPool threadPool = new ThreadPool(settings, MeterRegistry.NOOP);
+ ThreadPool threadPool = new ThreadPool(settings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
try {
TransportService transportService = MockTransportService.createNewService(
Settings.EMPTY,
diff --git a/server/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java b/server/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java
index f793255f3b98d..29963ac98d957 100644
--- a/server/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java
+++ b/server/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java
@@ -23,6 +23,7 @@
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
@@ -54,7 +55,8 @@ public void init() throws Exception {
counter = new AtomicInteger();
threadPool = new ThreadPool(
Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "TransportActionFilterChainTests").build(),
- MeterRegistry.NOOP
+ MeterRegistry.NOOP,
+ new DefaultBuiltInExecutorBuilders()
);
}
diff --git a/server/src/test/java/org/elasticsearch/action/support/TransportActionTests.java b/server/src/test/java/org/elasticsearch/action/support/TransportActionTests.java
index 97fa537874397..2fa91937a7dcc 100644
--- a/server/src/test/java/org/elasticsearch/action/support/TransportActionTests.java
+++ b/server/src/test/java/org/elasticsearch/action/support/TransportActionTests.java
@@ -20,6 +20,7 @@
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
@@ -41,7 +42,8 @@ public class TransportActionTests extends ESTestCase {
public void init() throws Exception {
threadPool = new ThreadPool(
Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "TransportActionTests").build(),
- MeterRegistry.NOOP
+ MeterRegistry.NOOP,
+ new DefaultBuiltInExecutorBuilders()
);
}
diff --git a/server/src/test/java/org/elasticsearch/client/internal/AbstractClientHeadersTestCase.java b/server/src/test/java/org/elasticsearch/client/internal/AbstractClientHeadersTestCase.java
index e946d355ebb32..453ca8ab873b6 100644
--- a/server/src/test/java/org/elasticsearch/client/internal/AbstractClientHeadersTestCase.java
+++ b/server/src/test/java/org/elasticsearch/client/internal/AbstractClientHeadersTestCase.java
@@ -30,6 +30,7 @@
import org.elasticsearch.env.Environment;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
@@ -79,7 +80,7 @@ public void setUp() throws Exception {
.put("node.name", "test-" + getTestName())
.put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString())
.build();
- threadPool = new ThreadPool(settings, MeterRegistry.NOOP);
+ threadPool = new ThreadPool(settings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
client = buildClient(settings, ACTIONS);
}
diff --git a/server/src/test/java/org/elasticsearch/http/HttpClientStatsTrackerTests.java b/server/src/test/java/org/elasticsearch/http/HttpClientStatsTrackerTests.java
index 2dfaaf34bb1f1..0ae494428939b 100644
--- a/server/src/test/java/org/elasticsearch/http/HttpClientStatsTrackerTests.java
+++ b/server/src/test/java/org/elasticsearch/http/HttpClientStatsTrackerTests.java
@@ -19,6 +19,7 @@
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import java.net.InetSocketAddress;
@@ -438,7 +439,11 @@ private static class FakeTimeThreadPool extends ThreadPool {
private final long absoluteTimeOffset = randomLong();
FakeTimeThreadPool() {
- super(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "test").build(), MeterRegistry.NOOP);
+ super(
+ Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "test").build(),
+ MeterRegistry.NOOP,
+ new DefaultBuiltInExecutorBuilders()
+ );
stopCachedTimeThread();
setRandomTime();
}
diff --git a/server/src/test/java/org/elasticsearch/threadpool/FixedThreadPoolTests.java b/server/src/test/java/org/elasticsearch/threadpool/FixedThreadPoolTests.java
index 6be78f27135a5..7ece91c48da6f 100644
--- a/server/src/test/java/org/elasticsearch/threadpool/FixedThreadPoolTests.java
+++ b/server/src/test/java/org/elasticsearch/threadpool/FixedThreadPoolTests.java
@@ -34,7 +34,7 @@ public void testRejectedExecutionCounter() throws InterruptedException {
.put("thread_pool." + threadPoolName + ".queue_size", queueSize)
.build();
try {
- threadPool = new ThreadPool(nodeSettings, MeterRegistry.NOOP);
+ threadPool = new ThreadPool(nodeSettings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
// these tasks will consume the thread pool causing further
// submissions to queue
diff --git a/server/src/test/java/org/elasticsearch/threadpool/ScalingThreadPoolTests.java b/server/src/test/java/org/elasticsearch/threadpool/ScalingThreadPoolTests.java
index 9a0c5c4b75d54..8c2ed2cc5bf00 100644
--- a/server/src/test/java/org/elasticsearch/threadpool/ScalingThreadPoolTests.java
+++ b/server/src/test/java/org/elasticsearch/threadpool/ScalingThreadPoolTests.java
@@ -425,7 +425,7 @@ public void runScalingThreadPoolTest(final Settings settings, final BiConsumer... customBuilders) {
}
public TestThreadPool(String name, Settings settings, ExecutorBuilder<?>... customBuilders) {
- super(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), name).put(settings).build(), MeterRegistry.NOOP, customBuilders);
+ super(
+ Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), name).put(settings).build(),
+ MeterRegistry.NOOP,
+ new DefaultBuiltInExecutorBuilders(),
+ customBuilders
+ );
}
@Override
diff --git a/x-pack/plugin/geoip-enterprise-downloader/src/test/java/org/elasticsearch/xpack/geoip/EnterpriseGeoIpDownloaderLicenseListenerTests.java b/x-pack/plugin/geoip-enterprise-downloader/src/test/java/org/elasticsearch/xpack/geoip/EnterpriseGeoIpDownloaderLicenseListenerTests.java
index 5a5aacd392f3c..8b5b2b84c3ca8 100644
--- a/x-pack/plugin/geoip-enterprise-downloader/src/test/java/org/elasticsearch/xpack/geoip/EnterpriseGeoIpDownloaderLicenseListenerTests.java
+++ b/x-pack/plugin/geoip-enterprise-downloader/src/test/java/org/elasticsearch/xpack/geoip/EnterpriseGeoIpDownloaderLicenseListenerTests.java
@@ -33,6 +33,7 @@
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
@@ -50,7 +51,11 @@ public class EnterpriseGeoIpDownloaderLicenseListenerTests extends ESTestCase {
@Before
public void setup() {
- threadPool = new ThreadPool(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "test").build(), MeterRegistry.NOOP);
+ threadPool = new ThreadPool(
+ Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "test").build(),
+ MeterRegistry.NOOP,
+ new DefaultBuiltInExecutorBuilders()
+ );
}
@After
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java
index 85a1dc1aa029d..e1c3b936e5a32 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java
@@ -55,6 +55,7 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.MockLog;
import org.elasticsearch.test.rest.FakeRestRequest;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
@@ -263,6 +264,7 @@ public void init() throws Exception {
threadPool = new ThreadPool(
settings,
MeterRegistry.NOOP,
+ new DefaultBuiltInExecutorBuilders(),
new FixedExecutorBuilder(
settings,
THREAD_POOL_NAME,
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java
index adf0b44266260..e53fa83b89617 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java
@@ -67,6 +67,7 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.EqualsHashCodeTestUtils;
import org.elasticsearch.test.XContentTestUtils;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
@@ -271,6 +272,7 @@ public static void startThreadPool() throws IOException {
threadPool = new ThreadPool(
settings,
MeterRegistry.NOOP,
+ new DefaultBuiltInExecutorBuilders(),
new FixedExecutorBuilder(
settings,
TokenService.THREAD_POOL_NAME,
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/enrollment/InternalEnrollmentTokenGeneratorTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/enrollment/InternalEnrollmentTokenGeneratorTests.java
index 888483613a187..0a1f5f801143d 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/enrollment/InternalEnrollmentTokenGeneratorTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/enrollment/InternalEnrollmentTokenGeneratorTests.java
@@ -30,6 +30,7 @@
import org.elasticsearch.node.Node;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.security.EnrollmentToken;
@@ -85,6 +86,7 @@ public static void startThreadPool() throws IOException {
threadPool = new ThreadPool(
settings,
MeterRegistry.NOOP,
+ new DefaultBuiltInExecutorBuilders(),
new FixedExecutorBuilder(
settings,
TokenService.THREAD_POOL_NAME,
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestCreateApiKeyActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestCreateApiKeyActionTests.java
index d487eab9f7887..9a05230d82ae6 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestCreateApiKeyActionTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestCreateApiKeyActionTests.java
@@ -26,6 +26,7 @@
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentType;
@@ -56,7 +57,7 @@ public void setUp() throws Exception {
.put("node.name", "test-" + getTestName())
.put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString())
.build();
- threadPool = new ThreadPool(settings, MeterRegistry.NOOP);
+ threadPool = new ThreadPool(settings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
}
@Override
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestGetApiKeyActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestGetApiKeyActionTests.java
index 577a8eb9f698e..d88a217cd0949 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestGetApiKeyActionTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestGetApiKeyActionTests.java
@@ -25,6 +25,7 @@
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentType;
@@ -62,7 +63,7 @@ public void setUp() throws Exception {
.put("node.name", "test-" + getTestName())
.put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString())
.build();
- threadPool = new ThreadPool(settings, MeterRegistry.NOOP);
+ threadPool = new ThreadPool(settings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
}
@Override
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestInvalidateApiKeyActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestInvalidateApiKeyActionTests.java
index 8bbd051c2fc32..ac472378d4874 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestInvalidateApiKeyActionTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestInvalidateApiKeyActionTests.java
@@ -26,6 +26,7 @@
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentType;
@@ -54,7 +55,7 @@ public void setUp() throws Exception {
.put("node.name", "test-" + getTestName())
.put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString())
.build();
- threadPool = new ThreadPool(settings, MeterRegistry.NOOP);
+ threadPool = new ThreadPool(settings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
}
@Override
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestQueryApiKeyActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestQueryApiKeyActionTests.java
index 2240b72c1a963..d5aa249b1d0f5 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestQueryApiKeyActionTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestQueryApiKeyActionTests.java
@@ -34,6 +34,7 @@
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentParseException;
@@ -72,7 +73,7 @@ public void setUp() throws Exception {
.put("node.name", "test-" + getTestName())
.put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString())
.build();
- threadPool = new ThreadPool(settings, MeterRegistry.NOOP);
+ threadPool = new ThreadPool(settings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
}
@Override
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HeaderSizeLimitTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HeaderSizeLimitTests.java
index 8c422342c3640..ba7c2e3844521 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HeaderSizeLimitTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HeaderSizeLimitTests.java
@@ -25,6 +25,7 @@
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.telemetry.tracing.Tracer;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.DefaultBuiltInExecutorBuilders;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.RemoteClusterPortSettings;
import org.elasticsearch.transport.RequestHandlerRegistry;
@@ -78,7 +79,7 @@ public final class SecurityNetty4HeaderSizeLimitTests extends ESTestCase {
@Before
public void startThreadPool() {
- threadPool = new ThreadPool(settings, MeterRegistry.NOOP);
+ threadPool = new ThreadPool(settings, MeterRegistry.NOOP, new DefaultBuiltInExecutorBuilders());
TaskManager taskManager = new TaskManager(settings, threadPool, Collections.emptySet());
NetworkService networkService = new NetworkService(Collections.emptyList());
PageCacheRecycler recycler = new MockPageCacheRecycler(Settings.EMPTY);
From 1b9fa40bdd7de1d40fc475a6a60334148afcf2e0 Mon Sep 17 00:00:00 2001
From: Fang Xing <155562079+fang-xing-esql@users.noreply.github.com>
Date: Tue, 27 Aug 2024 11:30:50 -0400
Subject: [PATCH 20/46] [ES|QL] Cast mixed numeric types to the first not null
numeric type for Coalesce at Analyzer (#111917)
Cast mixed numeric for Coalesce in analyzer
---
docs/changelog/111917.yaml | 7 ++
.../src/main/resources/convert.csv-spec | 34 ++++++
.../src/main/resources/null.csv-spec | 49 ++++++++
.../xpack/esql/action/EsqlCapabilities.java | 7 +-
.../xpack/esql/analysis/Analyzer.java | 110 ++++++++++++++---
.../xpack/esql/analysis/AnalyzerTests.java | 25 ++++
.../xpack/esql/analysis/VerifierTests.java | 112 ++++++++++++++++++
7 files changed, 327 insertions(+), 17 deletions(-)
create mode 100644 docs/changelog/111917.yaml
diff --git a/docs/changelog/111917.yaml b/docs/changelog/111917.yaml
new file mode 100644
index 0000000000000..0dc760d76a698
--- /dev/null
+++ b/docs/changelog/111917.yaml
@@ -0,0 +1,7 @@
+pr: 111917
+summary: "[ES|QL] Cast mixed numeric types to a common numeric type for Coalesce and\
+ \ In at Analyzer"
+area: ES|QL
+type: enhancement
+issues:
+ - 111486
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/convert.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/convert.csv-spec
index 3ef1a322eb94a..42b5c0344b559 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/convert.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/convert.csv-spec
@@ -180,3 +180,37 @@ ver:version
1.2.3
//end::docsCastOperator-result[]
;
+
+mixedNumericTypesInLiterals
+required_capability: mixed_numeric_types_in_coalesce
+from employees
+| where languages.long in (1, 2.0, null)
+| keep emp_no, languages
+| sort emp_no
+| limit 10
+;
+
+emp_no:integer | languages:integer
+10001 | 2
+10005 | 1
+10008 | 2
+10009 | 1
+10013 | 1
+10016 | 2
+10017 | 2
+10018 | 2
+10019 | 1
+10033 | 1
+;
+
+mixedNumericTypesInFields
+required_capability: mixed_numeric_types_in_coalesce
+from employees
+| where languages in (7.0, height)
+| keep emp_no, languages, height
+| sort emp_no
+;
+
+emp_no:integer | languages:integer | height:double
+10037 | 2 | 2.0
+;
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/null.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/null.csv-spec
index 92537ed1221e8..9914d073a589d 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/null.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/null.csv-spec
@@ -142,3 +142,52 @@ emp_no:integer | first_name:keyword | foo:boolean
10002 | Bezalel | true
10003 | Parto | true
;
+
+coalesceMixedNumeric
+required_capability: mixed_numeric_types_in_coalesce
+FROM employees
+| WHERE emp_no <= 10021 and emp_no >= 10018
+| EVAL x = coalesce(languages.long, 0), y = coalesce(height, 0), z = coalesce(languages::double, 0)
+| SORT emp_no ASC
+| KEEP emp_no, languages, x, height, y, z
+;
+
+emp_no:integer | languages:integer | x:long | height:double | y:double | z:double
+ 10018 | 2 | 2 | 1.97 | 1.97 | 2.0
+ 10019 | 1 | 1 | 2.06 | 2.06 | 1.0
+ 10020 | null | 0 | 1.41 | 1.41 | 0.0
+ 10021 | null | 0 | 1.47 | 1.47 | 0.0
+;
+
+coalesceMixedNumericWithNull
+required_capability: mixed_numeric_types_in_coalesce
+FROM employees
+| WHERE emp_no <= 10021 and emp_no >= 10018
+| EVAL x = coalesce(languages.long, null, 0), y = coalesce(null, height, 0), z = coalesce(languages::double, null, 0)
+| SORT emp_no ASC
+| KEEP emp_no, languages.long, x, height, y, z
+;
+
+emp_no:integer | languages.long:long | x:long | height:double | y:double | z:double
+ 10018 | 2 | 2 | 1.97 | 1.97 | 2.0
+ 10019 | 1 | 1 | 2.06 | 2.06 | 1.0
+ 10020 | null | 0 | 1.41 | 1.41 | 0.0
+ 10021 | null | 0 | 1.47 | 1.47 | 0.0
+;
+
+coalesceMixedNumericFields
+required_capability: mixed_numeric_types_in_coalesce
+FROM employees
+| WHERE emp_no <= 10021 and emp_no >= 10018
+| EVAL x = coalesce(height, languages.long, 0), y = coalesce(height, null, languages, 0),
+z = coalesce(languages::double, null, salary, height, 0)
+| SORT emp_no ASC
+| KEEP emp_no, languages, height, x, y, z, salary
+;
+
+emp_no:integer | languages:integer | height:double | x:double | y:double | z:double | salary:integer
+ 10018 | 2 | 1.97 | 1.97 | 1.97 | 2.0 | 56760
+ 10019 | 1 | 2.06 | 2.06 | 2.06 | 1.0 | 73717
+ 10020 | null | 1.41 | 1.41 | 1.41 | 40031.0 | 40031
+ 10021 | null | 1.47 | 1.47 | 1.47 | 60408.0 | 60408
+;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index afa8b6e1d06d7..81b2ba71b8808 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -264,7 +264,12 @@ public enum Cap {
/**
* Support for the whole number spans in BUCKET function.
*/
- BUCKET_WHOLE_NUMBER_AS_SPAN;
+ BUCKET_WHOLE_NUMBER_AS_SPAN,
+
+ /**
+ * Allow mixed numeric types in coalesce
+ */
+ MIXED_NUMERIC_TYPES_IN_COALESCE;
private final boolean snapshotOnly;
private final FeatureFlag featureFlag;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
index f88c603b4cacb..5baced5bc93f2 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
@@ -56,6 +56,11 @@
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction;
+import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble;
+import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
+import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
+import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToUnsignedLong;
+import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DateTimeArithmeticOperation;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
@@ -112,6 +117,7 @@
import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD;
import static org.elasticsearch.xpack.esql.core.type.DataType.LONG;
import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT;
+import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG;
import static org.elasticsearch.xpack.esql.core.type.DataType.VERSION;
import static org.elasticsearch.xpack.esql.core.type.DataType.isTemporalAmount;
import static org.elasticsearch.xpack.esql.stats.FeatureMetric.LIMIT;
@@ -943,6 +949,24 @@ private BitSet gatherPreAnalysisMetrics(LogicalPlan plan, BitSet b) {
return b;
}
+ /**
+ * Cast string literals in ScalarFunction, EsqlArithmeticOperation, BinaryComparison and In to desired data types.
+ * For example, the string literals in the following expressions will be cast implicitly to the field data type on the left hand side.
+ * date > "2024-08-21"
+ * date in ("2024-08-21", "2024-08-22", "2024-08-23")
+ * date = "2024-08-21" + 3 days
+ * ip == "127.0.0.1"
+ * version != "1.0"
+ *
+ * If the inputs to Coalesce are mixed numeric types, cast the rest of the numeric field or value to the first numeric data type if
+ * applicable. For example, implicit casting converts:
+ * Coalesce(Long, Int) to Coalesce(Long, Long)
+ * Coalesce(null, Long, Int) to Coalesce(null, Long, Long)
+ * Coalesce(Double, Long, Int) to Coalesce(Double, Double, Double)
+ * Coalesce(null, Double, Long, Int) to Coalesce(null, Double, Double, Double)
+ *
+ * Coalesce(Int, Long) will NOT be converted to Coalesce(Long, Long) or Coalesce(Int, Int).
+ */
private static class ImplicitCasting extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {
@Override
public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
@@ -972,22 +996,38 @@ private static Expression processScalarFunction(EsqlScalarFunction f, EsqlFuncti
boolean childrenChanged = false;
DataType targetDataType = DataType.NULL;
Expression arg;
+ DataType targetNumericType = null;
+ boolean castNumericArgs = true;
for (int i = 0; i < args.size(); i++) {
arg = args.get(i);
- if (arg.resolved() && arg.dataType() == KEYWORD && arg.foldable() && ((arg instanceof EsqlScalarFunction) == false)) {
- if (i < targetDataTypes.size()) {
- targetDataType = targetDataTypes.get(i);
- }
- if (targetDataType != DataType.NULL && targetDataType != DataType.UNSUPPORTED) {
- Expression e = castStringLiteral(arg, targetDataType);
- childrenChanged = true;
- newChildren.add(e);
- continue;
+ if (arg.resolved()) {
+ var dataType = arg.dataType();
+ if (dataType == KEYWORD) {
+ if (arg.foldable() && ((arg instanceof EsqlScalarFunction) == false)) {
+ if (i < targetDataTypes.size()) {
+ targetDataType = targetDataTypes.get(i);
+ }
+ if (targetDataType != DataType.NULL && targetDataType != DataType.UNSUPPORTED) {
+ Expression e = castStringLiteral(arg, targetDataType);
+ childrenChanged = true;
+ newChildren.add(e);
+ continue;
+ }
+ }
+ } else if (dataType.isNumeric() && canCastMixedNumericTypes(f) && castNumericArgs) {
+ if (targetNumericType == null) {
+ targetNumericType = dataType; // target data type is the first numeric data type
+ } else if (dataType != targetNumericType) {
+ castNumericArgs = canCastNumeric(dataType, targetNumericType);
+ }
}
}
newChildren.add(args.get(i));
}
- return childrenChanged ? f.replaceChildren(newChildren) : f;
+ Expression resultF = childrenChanged ? f.replaceChildren(newChildren) : f;
+ return targetNumericType != null && castNumericArgs
+ ? castMixedNumericTypes((EsqlScalarFunction) resultF, targetNumericType)
+ : resultF;
}
private static Expression processBinaryOperator(BinaryOperator<?, ?, ?, ?> o) {
@@ -1002,7 +1042,7 @@ private static Expression processBinaryOperator(BinaryOperator, ?, ?, ?> o) {
Expression from = Literal.NULL;
if (left.dataType() == KEYWORD && left.foldable() && (left instanceof EsqlScalarFunction == false)) {
- if (supportsImplicitCasting(right.dataType())) {
+ if (supportsStringImplicitCasting(right.dataType())) {
targetDataType = right.dataType();
from = left;
} else if (supportsImplicitTemporalCasting(right, o)) {
@@ -1011,7 +1051,7 @@ private static Expression processBinaryOperator(BinaryOperator<?, ?, ?, ?> o) {
}
}
if (right.dataType() == KEYWORD && right.foldable() && (right instanceof EsqlScalarFunction == false)) {
- if (supportsImplicitCasting(left.dataType())) {
+ if (supportsStringImplicitCasting(left.dataType())) {
targetDataType = left.dataType();
from = right;
} else if (supportsImplicitTemporalCasting(left, o)) {
@@ -1031,16 +1071,18 @@ private static Expression processBinaryOperator(BinaryOperator<?, ?, ?, ?> o) {
private static Expression processIn(In in) {
Expression left = in.value();
List<Expression> right = in.list();
+ DataType targetDataType = left.dataType();
- if (left.resolved() == false || supportsImplicitCasting(left.dataType()) == false) {
+ if (left.resolved() == false || supportsStringImplicitCasting(targetDataType) == false) {
return in;
}
+
List<Expression> newChildren = new ArrayList<>(right.size() + 1);
boolean childrenChanged = false;
for (Expression value : right) {
- if (value.dataType() == KEYWORD && value.foldable()) {
- Expression e = castStringLiteral(value, left.dataType());
+ if (value.resolved() && value.dataType() == KEYWORD && value.foldable()) {
+ Expression e = castStringLiteral(value, targetDataType);
newChildren.add(e);
childrenChanged = true;
} else {
@@ -1051,11 +1093,47 @@ private static Expression processIn(In in) {
return childrenChanged ? in.replaceChildren(newChildren) : in;
}
+ private static boolean canCastMixedNumericTypes(EsqlScalarFunction f) {
+ return f instanceof Coalesce;
+ }
+
+ private static boolean canCastNumeric(DataType from, DataType to) {
+ DataType commonType = EsqlDataTypeConverter.commonType(from, to);
+ return commonType == to;
+ }
+
+ private static Expression castMixedNumericTypes(EsqlScalarFunction f, DataType targetNumericType) {
+ List<Expression> newChildren = new ArrayList<>(f.children().size());
+ boolean childrenChanged = false;
+ DataType childDataType;
+
+ for (Expression e : f.children()) {
+ childDataType = e.dataType();
+ if (childDataType.isNumeric() == false
+ || childDataType == targetNumericType
+ || canCastNumeric(childDataType, targetNumericType) == false) {
+ newChildren.add(e);
+ continue;
+ }
+ childrenChanged = true;
+ // add a casting function
+ switch (targetNumericType) {
+ case INTEGER -> newChildren.add(new ToInteger(e.source(), e));
+ case LONG -> newChildren.add(new ToLong(e.source(), e));
+ case DOUBLE -> newChildren.add(new ToDouble(e.source(), e));
+ case UNSIGNED_LONG -> newChildren.add(new ToUnsignedLong(e.source(), e));
+ default -> throw new EsqlIllegalArgumentException("unexpected data type: " + targetNumericType);
+ }
+
+ }
+ return childrenChanged ? f.replaceChildren(newChildren) : f;
+ }
+
private static boolean supportsImplicitTemporalCasting(Expression e, BinaryOperator<?, ?, ?, ?> o) {
return isTemporalAmount(e.dataType()) && (o instanceof DateTimeArithmeticOperation);
}
- private static boolean supportsImplicitCasting(DataType type) {
+ private static boolean supportsStringImplicitCasting(DataType type) {
return type == DATETIME || type == IP || type == VERSION || type == BOOLEAN;
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index 3fb4b80d3974e..72a905f4b37a4 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -2061,6 +2061,31 @@ public void testRateRequiresCounterTypes() {
);
}
+ public void testCoalesceWithMixedNumericTypes() {
+ LogicalPlan plan = analyze("""
+ from test
+ | eval x = coalesce(salary_change, null, 0), y = coalesce(languages, null, 0), z = coalesce(languages.long, null, 0)
+ , w = coalesce(salary_change, null, 0::long)
+ | keep x, y, z, w
+ """, "mapping-default.json");
+ var limit = as(plan, Limit.class);
+ var esqlProject = as(limit.child(), EsqlProject.class);
+ List<? extends NamedExpression> projections = esqlProject.projections();
+ var projection = as(projections.get(0), ReferenceAttribute.class);
+ assertEquals(projection.name(), "x");
+ assertEquals(projection.dataType(), DataType.DOUBLE);
+ projection = as(projections.get(1), ReferenceAttribute.class);
+ assertEquals(projection.name(), "y");
+ assertEquals(projection.dataType(), DataType.INTEGER);
+ projection = as(projections.get(2), ReferenceAttribute.class);
+ assertEquals(projection.name(), "z");
+ assertEquals(projection.dataType(), DataType.LONG);
+ projection = as(projections.get(3), ReferenceAttribute.class);
+ assertEquals(projection.name(), "w");
+ assertEquals(projection.dataType(), DataType.DOUBLE);
+ assertThat(limit.limit().fold(), equalTo(1000));
+ }
+
private void verifyUnsupported(String query, String errorMessage) {
verifyUnsupported(query, errorMessage, "mapping-multi-field-variation.json");
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
index e2403505921a9..b50b801785a9f 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
@@ -950,6 +950,118 @@ public void testMatchCommand() throws Exception {
// TODO Keep adding tests for all unsupported commands
}
+ public void testCoalesceWithMixedNumericTypes() {
+ assertEquals(
+ "1:22: second argument of [coalesce(languages, height)] must be [integer], found value [height] type [double]",
+ error("from test | eval x = coalesce(languages, height)")
+ );
+ assertEquals(
+ "1:22: second argument of [coalesce(languages.long, height)] must be [long], found value [height] type [double]",
+ error("from test | eval x = coalesce(languages.long, height)")
+ );
+ assertEquals(
+ "1:22: second argument of [coalesce(salary, languages.long)] must be [integer], found value [languages.long] type [long]",
+ error("from test | eval x = coalesce(salary, languages.long)")
+ );
+ assertEquals(
+ "1:22: second argument of [coalesce(languages.short, height)] must be [integer], found value [height] type [double]",
+ error("from test | eval x = coalesce(languages.short, height)")
+ );
+ assertEquals(
+ "1:22: second argument of [coalesce(languages.byte, height)] must be [integer], found value [height] type [double]",
+ error("from test | eval x = coalesce(languages.byte, height)")
+ );
+ assertEquals(
+ "1:22: second argument of [coalesce(languages, height.float)] must be [integer], found value [height.float] type [double]",
+ error("from test | eval x = coalesce(languages, height.float)")
+ );
+ assertEquals(
+ "1:22: second argument of [coalesce(languages, height.scaled_float)] must be [integer], "
+ + "found value [height.scaled_float] type [double]",
+ error("from test | eval x = coalesce(languages, height.scaled_float)")
+ );
+ assertEquals(
+ "1:22: second argument of [coalesce(languages, height.half_float)] must be [integer], "
+ + "found value [height.half_float] type [double]",
+ error("from test | eval x = coalesce(languages, height.half_float)")
+ );
+
+ assertEquals(
+ "1:22: third argument of [coalesce(null, languages, height)] must be [integer], found value [height] type [double]",
+ error("from test | eval x = coalesce(null, languages, height)")
+ );
+ assertEquals(
+ "1:22: third argument of [coalesce(null, languages.long, height)] must be [long], found value [height] type [double]",
+ error("from test | eval x = coalesce(null, languages.long, height)")
+ );
+ assertEquals(
+ "1:22: third argument of [coalesce(null, salary, languages.long)] must be [integer], "
+ + "found value [languages.long] type [long]",
+ error("from test | eval x = coalesce(null, salary, languages.long)")
+ );
+ assertEquals(
+ "1:22: third argument of [coalesce(null, languages.short, height)] must be [integer], found value [height] type [double]",
+ error("from test | eval x = coalesce(null, languages.short, height)")
+ );
+ assertEquals(
+ "1:22: third argument of [coalesce(null, languages.byte, height)] must be [integer], found value [height] type [double]",
+ error("from test | eval x = coalesce(null, languages.byte, height)")
+ );
+ assertEquals(
+ "1:22: third argument of [coalesce(null, languages, height.float)] must be [integer], "
+ + "found value [height.float] type [double]",
+ error("from test | eval x = coalesce(null, languages, height.float)")
+ );
+ assertEquals(
+ "1:22: third argument of [coalesce(null, languages, height.scaled_float)] must be [integer], "
+ + "found value [height.scaled_float] type [double]",
+ error("from test | eval x = coalesce(null, languages, height.scaled_float)")
+ );
+ assertEquals(
+ "1:22: third argument of [coalesce(null, languages, height.half_float)] must be [integer], "
+ + "found value [height.half_float] type [double]",
+ error("from test | eval x = coalesce(null, languages, height.half_float)")
+ );
+
+ // counter
+ assertEquals(
+ "1:23: second argument of [coalesce(network.bytes_in, 0)] must be [counter_long], found value [0] type [integer]",
+ error("FROM tests | eval x = coalesce(network.bytes_in, 0)", tsdb)
+ );
+
+ assertEquals(
+ "1:23: second argument of [coalesce(network.bytes_in, to_long(0))] must be [counter_long], "
+ + "found value [to_long(0)] type [long]",
+ error("FROM tests | eval x = coalesce(network.bytes_in, to_long(0))", tsdb)
+ );
+ assertEquals(
+ "1:23: second argument of [coalesce(network.bytes_in, 0.0)] must be [counter_long], found value [0.0] type [double]",
+ error("FROM tests | eval x = coalesce(network.bytes_in, 0.0)", tsdb)
+ );
+
+ assertEquals(
+ "1:23: third argument of [coalesce(null, network.bytes_in, 0)] must be [counter_long], found value [0] type [integer]",
+ error("FROM tests | eval x = coalesce(null, network.bytes_in, 0)", tsdb)
+ );
+
+ assertEquals(
+ "1:23: third argument of [coalesce(null, network.bytes_in, to_long(0))] must be [counter_long], "
+ + "found value [to_long(0)] type [long]",
+ error("FROM tests | eval x = coalesce(null, network.bytes_in, to_long(0))", tsdb)
+ );
+ assertEquals(
+ "1:23: third argument of [coalesce(null, network.bytes_in, 0.0)] must be [counter_long], found value [0.0] type [double]",
+ error("FROM tests | eval x = coalesce(null, network.bytes_in, 0.0)", tsdb)
+ );
+ }
+
+ public void testCoalesceWithCounterType() {
+ assertEquals(
+ "1:23: second argument of [coalesce(network.bytes_in, 0)] must be [counter_long], found value [0] type [integer]",
+ error("FROM tests | eval x = coalesce(network.bytes_in, 0)", tsdb)
+ );
+ }
+
private String error(String query) {
return error(query, defaultAnalyzer);
}
From 29121fdf8f492885f7fda8141b17d54d8d53ad74 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lien=20FOUCRET?=
Date: Tue, 27 Aug 2024 18:12:46 +0200
Subject: [PATCH 21/46] New version of the script_score term stats helpers.
(#108634)
---
.../script/ScriptScoreBenchmark.java | 5 +
.../expression/ExpressionScoreScript.java | 6 +
.../org.elasticsearch.script.score.txt | 17 +
.../painless/ScoreScriptTests.java | 60 ++
.../190_term_statistics_script_score.yml | 612 ++++++++++++++++++
.../functionscore/ExplainableScriptIT.java | 5 +
.../search/function/ScriptScoreQuery.java | 18 +-
.../org/elasticsearch/script/ScoreScript.java | 29 +-
.../elasticsearch/script/ScriptFeatures.java | 2 +-
.../elasticsearch/script/ScriptTermStats.java | 234 +++++++
.../elasticsearch/script/StatsSummary.java | 123 ++++
.../search/internal/ContextIndexSearcher.java | 37 +-
.../mapper/BooleanScriptFieldTypeTests.java | 10 +
.../mapper/DateScriptFieldTypeTests.java | 5 +
.../mapper/DoubleScriptFieldTypeTests.java | 5 +
.../mapper/GeoPointScriptFieldTypeTests.java | 5 +
.../index/mapper/IpScriptFieldTypeTests.java | 5 +
.../mapper/KeywordScriptFieldTypeTests.java | 5 +
.../mapper/LongScriptFieldTypeTests.java | 5 +
.../script/ScriptTermStatsTests.java | 358 ++++++++++
.../script/StatsSummaryTests.java | 83 +++
.../highlight/PlainHighlighterTests.java | 5 +
.../search/query/ScriptScoreQueryTests.java | 131 ++--
.../script/MockScriptEngine.java | 5 +
.../AggregateDoubleMetricFieldTypeTests.java | 5 +
.../mapper/GeoShapeScriptFieldTypeTests.java | 5 +
26 files changed, 1718 insertions(+), 62 deletions(-)
create mode 100644 modules/lang-painless/src/test/java/org/elasticsearch/painless/ScoreScriptTests.java
create mode 100644 modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/190_term_statistics_script_score.yml
create mode 100644 server/src/main/java/org/elasticsearch/script/ScriptTermStats.java
create mode 100644 server/src/main/java/org/elasticsearch/script/StatsSummary.java
create mode 100644 server/src/test/java/org/elasticsearch/script/ScriptTermStatsTests.java
create mode 100644 server/src/test/java/org/elasticsearch/script/StatsSummaryTests.java
diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/script/ScriptScoreBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/script/ScriptScoreBenchmark.java
index 5a27abe8be2a4..fe221ec980dc3 100644
--- a/benchmarks/src/main/java/org/elasticsearch/benchmark/script/ScriptScoreBenchmark.java
+++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/script/ScriptScoreBenchmark.java
@@ -186,6 +186,11 @@ public void setDocument(int docid) {
public boolean needs_score() {
return false;
}
+
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
};
};
}
diff --git a/modules/lang-expression/src/main/java/org/elasticsearch/script/expression/ExpressionScoreScript.java b/modules/lang-expression/src/main/java/org/elasticsearch/script/expression/ExpressionScoreScript.java
index 159851affd004..622a1bd4afd25 100644
--- a/modules/lang-expression/src/main/java/org/elasticsearch/script/expression/ExpressionScoreScript.java
+++ b/modules/lang-expression/src/main/java/org/elasticsearch/script/expression/ExpressionScoreScript.java
@@ -42,6 +42,12 @@ public boolean needs_score() {
return needsScores;
}
+ @Override
+ public boolean needs_termStats() {
+ // _termStats is not available for expressions
+ return false;
+ }
+
@Override
public ScoreScript newInstance(final DocReader reader) throws IOException {
// Use DocReader to get the leaf context while transitioning to DocReader for Painless. DocReader for expressions should follow.
diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt
index 5082d5f1c7bdb..0dab7dcbadfb5 100644
--- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt
+++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt
@@ -13,6 +13,23 @@ class org.elasticsearch.script.ScoreScript @no_import {
class org.elasticsearch.script.ScoreScript$Factory @no_import {
}
+class org.elasticsearch.script.StatsSummary {
+ double getMin()
+ double getMax()
+ double getAverage()
+ double getSum()
+ long getCount()
+}
+
+class org.elasticsearch.script.ScriptTermStats {
+ int uniqueTermsCount()
+ int matchedTermsCount()
+ StatsSummary docFreq()
+ StatsSummary totalTermFreq()
+ StatsSummary termFreq()
+ StatsSummary termPositions()
+}
+
static_import {
double saturation(double, double) from_class org.elasticsearch.script.ScoreScriptUtils
double sigmoid(double, double, double) from_class org.elasticsearch.script.ScoreScriptUtils
diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScoreScriptTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScoreScriptTests.java
new file mode 100644
index 0000000000000..08b55fdf3bcc3
--- /dev/null
+++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScoreScriptTests.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.painless;
+
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.IndexService;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.painless.spi.Whitelist;
+import org.elasticsearch.painless.spi.WhitelistLoader;
+import org.elasticsearch.script.ScoreScript;
+import org.elasticsearch.script.ScriptContext;
+import org.elasticsearch.test.ESSingleNodeTestCase;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static java.util.Collections.emptyMap;
+import static org.elasticsearch.painless.ScriptTestCase.PAINLESS_BASE_WHITELIST;
+
+public class ScoreScriptTests extends ESSingleNodeTestCase {
+ /**
+ * Test that needTermStats() is reported correctly depending on whether _termStats is used
+ */
+ public void testNeedsTermStats() {
+ IndexService index = createIndex("test", Settings.EMPTY, "type", "d", "type=double");
+
+ Map<ScriptContext<?>, List<Whitelist>> contexts = new HashMap<>();
+ List<Whitelist> whitelists = new ArrayList<>(PAINLESS_BASE_WHITELIST);
+ whitelists.add(WhitelistLoader.loadFromResourceFiles(PainlessPlugin.class, "org.elasticsearch.script.score.txt"));
+ contexts.put(ScoreScript.CONTEXT, whitelists);
+ PainlessScriptEngine service = new PainlessScriptEngine(Settings.EMPTY, contexts);
+
+ SearchExecutionContext searchExecutionContext = index.newSearchExecutionContext(0, 0, null, () -> 0, null, emptyMap());
+
+ ScoreScript.Factory factory = service.compile(null, "1.2", ScoreScript.CONTEXT, Collections.emptyMap());
+ ScoreScript.LeafFactory ss = factory.newFactory(Collections.emptyMap(), searchExecutionContext.lookup());
+ assertFalse(ss.needs_termStats());
+
+ factory = service.compile(null, "doc['d'].value", ScoreScript.CONTEXT, Collections.emptyMap());
+ ss = factory.newFactory(Collections.emptyMap(), searchExecutionContext.lookup());
+ assertFalse(ss.needs_termStats());
+
+ factory = service.compile(null, "1/_termStats.totalTermFreq().getAverage()", ScoreScript.CONTEXT, Collections.emptyMap());
+ ss = factory.newFactory(Collections.emptyMap(), searchExecutionContext.lookup());
+ assertTrue(ss.needs_termStats());
+
+ factory = service.compile(null, "doc['d'].value * _termStats.docFreq().getSum()", ScoreScript.CONTEXT, Collections.emptyMap());
+ ss = factory.newFactory(Collections.emptyMap(), searchExecutionContext.lookup());
+ assertTrue(ss.needs_termStats());
+ }
+}
diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/190_term_statistics_script_score.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/190_term_statistics_script_score.yml
new file mode 100644
index 0000000000000..f82b844f01588
--- /dev/null
+++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/190_term_statistics_script_score.yml
@@ -0,0 +1,612 @@
+setup:
+ - requires:
+ cluster_features: ["script.term_stats"]
+ reason: "support for term stats has been added in 8.16"
+
+ - do:
+ indices.create:
+ index: test-index
+ body:
+ settings:
+ number_of_shards: "2"
+ mappings:
+ properties:
+ title:
+ type: text
+ genre:
+ type: text
+ fields:
+ keyword:
+ type: keyword
+
+ - do:
+ index: { refresh: true, index: test-index, id: "1", routing: 0, body: {"title": "Star wars", "genre": "Sci-fi"} }
+ - do:
+ index: { refresh: true, index: test-index, id: "2", routing: 1, body: {"title": "Star trek", "genre": "Sci-fi"} }
+ - do:
+ index: { refresh: true, index: test-index, id: "3", routing: 1, body: {"title": "Rambo", "genre": "War movie"} }
+ - do:
+ index: { refresh: true, index: test-index, id: "4", routing: 1, body: {"title": "Rambo II", "genre": "War movie"} }
+
+---
+"match query: uniqueTermsCount without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.uniqueTermsCount()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 2 }
+ - match: { hits.hits.1._score: 2 }
+
+---
+"match query: uniqueTermsCount with DFS":
+ - do:
+ search:
+ search_type: dfs_query_then_fetch
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.uniqueTermsCount()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 2 }
+ - match: { hits.hits.1._score: 2 }
+
+---
+"match query: matchedTermsCount without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.matchedTermsCount()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 2 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"match query: matchedTermsCount with DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ search_type: dfs_query_then_fetch
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.matchedTermsCount()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 2 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"match query: docFreq min without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.docFreq().getMin()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1 }
+ - match: { hits.hits.1._score: 0 }
+
+---
+"match query: docFreq min with DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ search_type: dfs_query_then_fetch
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.docFreq().getMin()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"match query: docFreq max without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.docFreq().getMax()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"match query: docFreq max with DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ search_type: dfs_query_then_fetch
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.docFreq().getMax()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 2 }
+ - match: { hits.hits.1._score: 2 }
+
+---
+"match query: totalTermFreq sum without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.totalTermFreq().getSum()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 2 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"match query: totalTermFreq sum with DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ search_type: dfs_query_then_fetch
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.totalTermFreq().getSum()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 3 }
+ - match: { hits.hits.1._score: 3 }
+
+---
+"match query: termFreq sum without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.termFreq().getSum()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 2 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"match query: termFreq sum with DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ search_type: dfs_query_then_fetch
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.termFreq().getSum()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 2 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"match query: termPositions avg without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.termPositions().getAverage()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1.5 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"match query: termPositions avg with DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ search_type: dfs_query_then_fetch
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { match: { "title": "Star wars" } }
+ script:
+ source: "return _termStats.termPositions().getAverage()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1.5 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"term query: uniqueTermsCount without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.uniqueTermsCount()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"term query: uniqueTermsCount with DFS":
+ - do:
+ search:
+ search_type: dfs_query_then_fetch
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.uniqueTermsCount()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"term query: matchedTermsCount without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.matchedTermsCount()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"term query: matchedTermsCount with DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ search_type: dfs_query_then_fetch
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.matchedTermsCount()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"term query: docFreq min without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.docFreq().getMin()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"term query: docFreq min with DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ search_type: dfs_query_then_fetch
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.docFreq().getMin()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 2 }
+ - match: { hits.hits.1._score: 2 }
+
+---
+"term query: docFreq max without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.docFreq().getMax()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"term query: docFreq max with DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ search_type: dfs_query_then_fetch
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.docFreq().getMax()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 2 }
+ - match: { hits.hits.1._score: 2 }
+
+---
+"term query: totalTermFreq sum without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.totalTermFreq().getSum()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"term query: totalTermFreq sum with DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ search_type: dfs_query_then_fetch
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.totalTermFreq().getSum()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 2 }
+ - match: { hits.hits.1._score: 2 }
+
+---
+"term query: termFreq sum without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.termFreq().getSum()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"term query: termFreq sum with DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ search_type: dfs_query_then_fetch
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.termFreq().getSum()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 1 }
+ - match: { hits.hits.1._score: 1 }
+
+---
+"term query: termPositions avg without DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.termPositions().getAverage()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 0 }
+ - match: { hits.hits.1._score: 0 }
+
+---
+"term query: termPositions avg with DFS":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ search_type: dfs_query_then_fetch
+ index: test-index
+ body:
+ query:
+ script_score:
+ query: { term: { "genre.keyword": "Sci-fi" } }
+ script:
+ source: "return _termStats.termPositions().getAverage()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 0 }
+ - match: { hits.hits.1._score: 0 }
+
+---
+"Complex bool query: uniqueTermsCount":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query:
+ bool:
+ must:
+ match: { "title": "star wars" }
+ should:
+ term: { "genre.keyword": "Sci-fi" }
+ filter:
+ match: { "genre" : "sci"}
+ must_not:
+ term: { "genre.keyword": "War" }
+ script:
+ source: "return _termStats.uniqueTermsCount()"
+ - match: { hits.total: 2 }
+ - match: { hits.hits.0._score: 4 }
+ - match: { hits.hits.1._score: 4 }
+
+
+---
+"match_all query: uniqueTermsCount":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query:
+ match_all: {}
+ script:
+ source: "return _termStats.uniqueTermsCount()"
+ - match: { hits.total: 4 }
+ - match: { hits.hits.0._score: 0 }
+ - match: { hits.hits.1._score: 0 }
+ - match: { hits.hits.2._score: 0 }
+ - match: { hits.hits.3._score: 0 }
+
+---
+"match_all query: docFreq":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query:
+ match_all: {}
+ script:
+ source: "return _termStats.docFreq().getMax()"
+ - match: { hits.total: 4 }
+ - match: { hits.hits.0._score: 0 }
+ - match: { hits.hits.1._score: 0 }
+ - match: { hits.hits.2._score: 0 }
+ - match: { hits.hits.3._score: 0 }
+
+---
+"match_all query: totalTermFreq":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query:
+ match_all: {}
+ script:
+ source: "return _termStats.totalTermFreq().getSum()"
+ - match: { hits.total: 4 }
+ - match: { hits.hits.0._score: 0 }
+ - match: { hits.hits.1._score: 0 }
+ - match: { hits.hits.2._score: 0 }
+ - match: { hits.hits.3._score: 0 }
+
+---
+"match_all query: termFreq":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query:
+ match_all: {}
+ script:
+ source: "return _termStats.termFreq().getMax()"
+ - match: { hits.total: 4 }
+ - match: { hits.hits.0._score: 0 }
+ - match: { hits.hits.1._score: 0 }
+ - match: { hits.hits.2._score: 0 }
+ - match: { hits.hits.3._score: 0 }
+
+---
+"match_all query: termPositions":
+ - do:
+ search:
+ rest_total_hits_as_int: true
+ index: test-index
+ body:
+ query:
+ script_score:
+ query:
+ match_all: {}
+ script:
+ source: "return _termStats.termPositions().getSum()"
+ - match: { hits.total: 4 }
+ - match: { hits.hits.0._score: 0 }
+ - match: { hits.hits.1._score: 0 }
+ - match: { hits.hits.2._score: 0 }
+ - match: { hits.hits.3._score: 0 }
diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/ExplainableScriptIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/ExplainableScriptIT.java
index ee60888d7a0a8..c59fc0f68c4d4 100644
--- a/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/ExplainableScriptIT.java
+++ b/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/ExplainableScriptIT.java
@@ -73,6 +73,11 @@ public boolean needs_score() {
return false;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) {
return new MyScript(params1, lookup, ((DocValuesDocReader) docReader).getLeafReaderContext());
diff --git a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java
index 7ddaa4bc681fa..93837269f2090 100644
--- a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java
+++ b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java
@@ -9,6 +9,7 @@
package org.elasticsearch.common.lucene.search.function;
import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.DocIdSetIterator;
@@ -29,10 +30,13 @@
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.ScoreScript.ExplanationHolder;
import org.elasticsearch.script.Script;
+import org.elasticsearch.script.ScriptTermStats;
import org.elasticsearch.search.lookup.SearchLookup;
import java.io.IOException;
+import java.util.HashSet;
import java.util.Objects;
+import java.util.Set;
/**
* A query that uses a script to compute documents' scores.
@@ -86,9 +90,18 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
return subQuery.createWeight(searcher, scoreMode, boost);
}
boolean needsScore = scriptBuilder.needs_score();
- ScoreMode subQueryScoreMode = needsScore ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
+ boolean needsTermStatistics = scriptBuilder.needs_termStats();
+
+ ScoreMode subQueryScoreMode = needsScore || needsTermStatistics ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
Weight subQueryWeight = subQuery.createWeight(searcher, subQueryScoreMode, 1.0f);
+        // Collect the distinct terms referenced by the child query (used for term statistics).
+ final Set terms = new HashSet<>();
+
+ if (needsTermStatistics) {
+ this.visit(QueryVisitor.termCollector(terms));
+ }
+
return new Weight(this) {
@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
@@ -167,6 +180,9 @@ private ScoreScript makeScoreScript(LeafReaderContext context) throws IOExceptio
final ScoreScript scoreScript = scriptBuilder.newInstance(new DocValuesDocReader(lookup, context));
scoreScript._setIndexName(indexName);
scoreScript._setShard(shardId);
+ if (needsTermStatistics) {
+ scoreScript._setTermStats(new ScriptTermStats(searcher, context, scoreScript::_getDocId, terms));
+ }
return scoreScript;
}
diff --git a/server/src/main/java/org/elasticsearch/script/ScoreScript.java b/server/src/main/java/org/elasticsearch/script/ScoreScript.java
index 503bd11fb434a..61d225c069c68 100644
--- a/server/src/main/java/org/elasticsearch/script/ScoreScript.java
+++ b/server/src/main/java/org/elasticsearch/script/ScoreScript.java
@@ -7,6 +7,7 @@
*/
package org.elasticsearch.script;
+import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Scorable;
import org.elasticsearch.common.logging.DeprecationCategory;
@@ -26,7 +27,6 @@
* A script used for adjusting the score on a per document basis.
*/
public abstract class ScoreScript extends DocBasedScript {
-
/** A helper to take in an explanation from a script and turn it into an {@link org.apache.lucene.search.Explanation} */
public static class ExplanationHolder {
private String description;
@@ -82,6 +82,8 @@ public Explanation get(double score, Explanation subQueryExplanation) {
private int shardId = -1;
private String indexName = null;
+ private ScriptTermStats termStats = null;
+
public ScoreScript(Map params, SearchLookup searchLookup, DocReader docReader) {
// searchLookup parameter is ignored but part of the ScriptFactory contract. It is part of that contract because it's required
// for expressions. Expressions should eventually be transitioned to using DocReader.
@@ -90,13 +92,13 @@ public ScoreScript(Map params, SearchLookup searchLookup, DocRea
if (docReader == null) {
assert params == null;
this.params = null;
- ;
this.docBase = 0;
} else {
params = new HashMap<>(params);
params.putAll(docReader.docAsMap());
this.params = new DynamicMap(params, PARAMS_FUNCTIONS);
- this.docBase = ((DocValuesDocReader) docReader).getLeafReaderContext().docBase;
+ LeafReaderContext leafReaderContext = ((DocValuesDocReader) docReader).getLeafReaderContext();
+ this.docBase = leafReaderContext.docBase;
}
}
@@ -189,14 +191,33 @@ public void _setIndexName(String indexName) {
this.indexName = indexName;
}
+ /**
+ * Starting a name with underscore, so that the user cannot access this function directly through a script.
+ */
+ public void _setTermStats(ScriptTermStats termStats) {
+ this.termStats = termStats;
+ }
+
+ /**
+ * Accessed as _termStats in the painless script.
+ */
+ public ScriptTermStats get_termStats() {
+ assert termStats != null : "termStats is not available";
+ return termStats;
+ }
+
/** A factory to construct {@link ScoreScript} instances. */
public interface LeafFactory {
-
/**
* Return {@code true} if the script needs {@code _score} calculated, or {@code false} otherwise.
*/
boolean needs_score();
+ /**
+ * Return {@code true} if the script needs {@code _termStats} calculated, or {@code false} otherwise.
+ */
+ boolean needs_termStats();
+
ScoreScript newInstance(DocReader reader) throws IOException;
}
diff --git a/server/src/main/java/org/elasticsearch/script/ScriptFeatures.java b/server/src/main/java/org/elasticsearch/script/ScriptFeatures.java
index d4d78bf08844b..2522ecee83223 100644
--- a/server/src/main/java/org/elasticsearch/script/ScriptFeatures.java
+++ b/server/src/main/java/org/elasticsearch/script/ScriptFeatures.java
@@ -16,6 +16,6 @@
public final class ScriptFeatures implements FeatureSpecification {
@Override
public Set getFeatures() {
- return Set.of(VectorScoreScriptUtils.HAMMING_DISTANCE_FUNCTION);
+ return Set.of(VectorScoreScriptUtils.HAMMING_DISTANCE_FUNCTION, ScriptTermStats.TERM_STAT_FEATURE);
}
}
diff --git a/server/src/main/java/org/elasticsearch/script/ScriptTermStats.java b/server/src/main/java/org/elasticsearch/script/ScriptTermStats.java
new file mode 100644
index 0000000000000..0863e649487b9
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/ScriptTermStats.java
@@ -0,0 +1,234 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.PostingsEnum;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.TermState;
+import org.apache.lucene.index.TermStates;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.IndexSearcher;
+import org.elasticsearch.common.util.CachedSupplier;
+import org.elasticsearch.features.NodeFeature;
+import org.elasticsearch.search.internal.ContextIndexSearcher;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.Set;
+import java.util.function.IntSupplier;
+import java.util.function.Supplier;
+
+/**
+ * Access the term statistics of the children query of a script_score query.
+ */
+public class ScriptTermStats {
+
+ public static final NodeFeature TERM_STAT_FEATURE = new NodeFeature("script.term_stats");
+
+ private final IntSupplier docIdSupplier;
+ private final Term[] terms;
+ private final IndexSearcher searcher;
+ private final LeafReaderContext leafReaderContext;
+ private final StatsSummary statsSummary = new StatsSummary();
+    private final Supplier<TermStates[]> termContextsSupplier;
+    private final Supplier<PostingsEnum[]> postingsSupplier;
+    private final Supplier<StatsSummary> docFreqSupplier;
+    private final Supplier<StatsSummary> totalTermFreqSupplier;
+
+    public ScriptTermStats(IndexSearcher searcher, LeafReaderContext leafReaderContext, IntSupplier docIdSupplier, Set<Term> terms) {
+ this.searcher = searcher;
+ this.leafReaderContext = leafReaderContext;
+ this.docIdSupplier = docIdSupplier;
+ this.terms = terms.toArray(new Term[0]);
+ this.termContextsSupplier = CachedSupplier.wrap(this::loadTermContexts);
+ this.postingsSupplier = CachedSupplier.wrap(this::loadPostings);
+ this.docFreqSupplier = CachedSupplier.wrap(this::loadDocFreq);
+ this.totalTermFreqSupplier = CachedSupplier.wrap(this::loadTotalTermFreq);
+ }
+
+ /**
+ * Number of unique terms in the query.
+ *
+ * @return the number of unique terms
+ */
+ public int uniqueTermsCount() {
+ return terms.length;
+ }
+
+ /**
+     * Number of terms that are matched in the query.
+ *
+ * @return the number of matched terms
+ */
+ public int matchedTermsCount() {
+ final int docId = docIdSupplier.getAsInt();
+ int matchedTerms = 0;
+
+ try {
+ for (PostingsEnum postingsEnum : postingsSupplier.get()) {
+ if (postingsEnum != null && postingsEnum.advance(docId) == docId && postingsEnum.freq() > 0) {
+ matchedTerms++;
+ }
+ }
+ return matchedTerms;
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Collect docFreq (number of documents a term occurs in) for the terms of the query and returns statistics for them.
+ *
+ * @return statistics on docFreq for the terms of the query.
+ */
+ public StatsSummary docFreq() {
+ return docFreqSupplier.get();
+ }
+
+ private StatsSummary loadDocFreq() {
+ StatsSummary docFreqStats = new StatsSummary();
+ TermStates[] termContexts = termContextsSupplier.get();
+
+ try {
+ for (int i = 0; i < termContexts.length; i++) {
+ if (searcher instanceof ContextIndexSearcher contextIndexSearcher) {
+ docFreqStats.accept(contextIndexSearcher.docFreq(terms[i], termContexts[i].docFreq()));
+ } else {
+ docFreqStats.accept(termContexts[i].docFreq());
+ }
+ }
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+
+ return docFreqStats;
+ }
+
+ /**
+ * Collect totalTermFreq (total number of occurrence of a term in the index) for the terms of the query and returns statistics for them.
+ *
+ * @return statistics on totalTermFreq for the terms of the query.
+ */
+ public StatsSummary totalTermFreq() {
+ return this.totalTermFreqSupplier.get();
+ }
+
+ private StatsSummary loadTotalTermFreq() {
+ StatsSummary totalTermFreqStats = new StatsSummary();
+ TermStates[] termContexts = termContextsSupplier.get();
+
+ try {
+ for (int i = 0; i < termContexts.length; i++) {
+ if (searcher instanceof ContextIndexSearcher contextIndexSearcher) {
+ totalTermFreqStats.accept(contextIndexSearcher.totalTermFreq(terms[i], termContexts[i].totalTermFreq()));
+ } else {
+ totalTermFreqStats.accept(termContexts[i].totalTermFreq());
+ }
+ }
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+
+ return totalTermFreqStats;
+ }
+
+ /**
+     * Collect termFreq (number of occurrences of a term in the current doc) for the terms of the query and returns statistics for them.
+     *
+     * @return statistics on termFreq for the terms of the query in the current doc
+ */
+ public StatsSummary termFreq() {
+ statsSummary.reset();
+ final int docId = docIdSupplier.getAsInt();
+
+ try {
+ for (PostingsEnum postingsEnum : postingsSupplier.get()) {
+ if (postingsEnum == null || postingsEnum.advance(docId) != docId) {
+ statsSummary.accept(0);
+ } else {
+ statsSummary.accept(postingsEnum.freq());
+ }
+ }
+
+ return statsSummary;
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Collect termPositions (positions of a term in the current document) for the terms of the query and returns statistics for them.
+ *
+     * @return statistics on termPositions for the terms of the query in the current doc
+ */
+ public StatsSummary termPositions() {
+ try {
+ statsSummary.reset();
+ int docId = docIdSupplier.getAsInt();
+
+ for (PostingsEnum postingsEnum : postingsSupplier.get()) {
+ if (postingsEnum == null || postingsEnum.advance(docId) != docId) {
+ continue;
+ }
+ for (int i = 0; i < postingsEnum.freq(); i++) {
+ statsSummary.accept(postingsEnum.nextPosition() + 1);
+ }
+ }
+
+ return statsSummary;
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ private TermStates[] loadTermContexts() {
+ try {
+ TermStates[] termContexts = new TermStates[terms.length];
+
+ for (int i = 0; i < terms.length; i++) {
+ termContexts[i] = TermStates.build(searcher, terms[i], true);
+ }
+
+ return termContexts;
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ private PostingsEnum[] loadPostings() {
+ try {
+ PostingsEnum[] postings = new PostingsEnum[terms.length];
+ TermStates[] contexts = termContextsSupplier.get();
+
+ for (int i = 0; i < terms.length; i++) {
+ TermStates termStates = contexts[i];
+ if (termStates.docFreq() == 0) {
+ postings[i] = null;
+ continue;
+ }
+
+ TermState state = termStates.get(leafReaderContext);
+ if (state == null) {
+ postings[i] = null;
+ continue;
+ }
+
+ TermsEnum termsEnum = leafReaderContext.reader().terms(terms[i].field()).iterator();
+ termsEnum.seekExact(terms[i].bytes(), state);
+
+ postings[i] = termsEnum.postings(null, PostingsEnum.ALL);
+ }
+
+ return postings;
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/script/StatsSummary.java b/server/src/main/java/org/elasticsearch/script/StatsSummary.java
new file mode 100644
index 0000000000000..d59a7af7e1fa9
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/StatsSummary.java
@@ -0,0 +1,123 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script;
+
+import org.elasticsearch.common.Strings;
+
+import java.util.Objects;
+import java.util.function.DoubleConsumer;
+
+/**
+ * The {@link StatsSummary} class accumulates statistical data for a sequence of double values.
+ *
+ *
+ * <p>This class provides statistics such as count, sum, minimum, maximum, and arithmetic mean
+ * of the recorded values.
+ */
+public class StatsSummary implements DoubleConsumer {
+
+ private long count = 0;
+ private double sum = 0d;
+ private Double min;
+ private Double max;
+
+ public StatsSummary() {}
+
+ StatsSummary(long count, double sum, double min, double max) {
+ this.count = count;
+ this.sum = sum;
+ this.min = min;
+ this.max = max;
+ }
+
+ @Override
+ public void accept(double value) {
+ count++;
+ sum += value;
+ min = min == null ? value : (value < min ? value : min);
+ max = max == null ? value : (value > max ? value : max);
+ }
+
+ /**
+     * Returns the min of the recorded values.
+ */
+ public double getMin() {
+ return min == null ? 0.0 : min;
+ }
+
+ /**
+ * Returns the max for recorded values.
+ */
+ public double getMax() {
+ return max == null ? 0.0 : max;
+ }
+
+ /**
+ * Returns the arithmetic mean for recorded values.
+ */
+ public double getAverage() {
+ return count == 0.0 ? 0.0 : sum / count;
+ }
+
+ /**
+ * Returns the sum of all recorded values.
+ */
+ public double getSum() {
+ return sum;
+ }
+
+ /**
+ * Returns the number of recorded values.
+ */
+ public long getCount() {
+ return count;
+ }
+
+ /**
+ * Resets the accumulator, clearing all accumulated statistics.
+ * After calling this method, the accumulator will be in its initial state.
+ */
+ public void reset() {
+ count = 0;
+ sum = 0d;
+ min = null;
+ max = null;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(count, sum, min, max);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null || getClass() != obj.getClass()) {
+ return false;
+ }
+
+ StatsSummary other = (StatsSummary) obj;
+
+ return Objects.equals(count, other.count)
+ && Objects.equals(sum, other.sum)
+ && Objects.equals(min, other.min)
+ && Objects.equals(max, other.max);
+ }
+
+ @Override
+ public String toString() {
+ return Strings.format(
+ "%s{count=%d, sum=%f, min=%f, average=%f, max=%f}",
+ this.getClass().getSimpleName(),
+ getCount(),
+ getSum(),
+ getMin(),
+ getAverage(),
+ getMax()
+ );
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java
index cba2cf761e6f3..b23988af92606 100644
--- a/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java
+++ b/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java
@@ -495,13 +495,11 @@ static void intersectScorerAndBitSet(Scorer scorer, BitSet acceptDocs, LeafColle
@Override
public TermStatistics termStatistics(Term term, int docFreq, long totalTermFreq) throws IOException {
- if (aggregatedDfs == null) {
- // we are either executing the dfs phase or the search_type doesn't include the dfs phase.
- return super.termStatistics(term, docFreq, totalTermFreq);
- }
- TermStatistics termStatistics = aggregatedDfs.termStatistics().get(term);
+ TermStatistics termStatistics = termStatisticsFromDfs(term);
+
if (termStatistics == null) {
- // we don't have stats for this - this might be a must_not clauses etc. that doesn't allow extract terms on the query
+            // we don't have stats for this - dfs might be disabled or this might be a must_not clause etc.
+            // that doesn't allow extracting terms from the query
return super.termStatistics(term, docFreq, totalTermFreq);
}
return termStatistics;
@@ -521,6 +519,33 @@ public CollectionStatistics collectionStatistics(String field) throws IOExceptio
return collectionStatistics;
}
+ public long docFreq(Term term, long docFreq) throws IOException {
+ TermStatistics termStatistics = termStatisticsFromDfs(term);
+
+ if (termStatistics == null) {
+ return docFreq;
+ }
+ return termStatistics.docFreq();
+ }
+
+ public long totalTermFreq(Term term, long totalTermFreq) throws IOException {
+ TermStatistics termStatistics = termStatisticsFromDfs(term);
+
+ if (termStatistics == null) {
+ return totalTermFreq;
+ }
+        return termStatistics.totalTermFreq();
+ }
+
+ private TermStatistics termStatisticsFromDfs(Term term) {
+ if (aggregatedDfs == null) {
+ // we are either executing the dfs phase or the search_type doesn't include the dfs phase.
+ return null;
+ }
+
+ return aggregatedDfs.termStatistics().get(term);
+ }
+
public DirectoryReader getDirectoryReader() {
final IndexReader reader = getIndexReader();
assert reader instanceof DirectoryReader : "expected an instance of DirectoryReader, got " + reader.getClass();
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/BooleanScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/BooleanScriptFieldTypeTests.java
index 0cdc9568f1fac..89d5e300112b4 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/BooleanScriptFieldTypeTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/BooleanScriptFieldTypeTests.java
@@ -142,6 +142,11 @@ public boolean needs_score() {
return false;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) {
return new ScoreScript(Map.of(), searchContext.lookup(), docReader) {
@@ -165,6 +170,11 @@ public boolean needs_score() {
return false;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) {
return new ScoreScript(Map.of(), searchContext.lookup(), docReader) {
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/DateScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/DateScriptFieldTypeTests.java
index 3728943cae418..0507a37636370 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/DateScriptFieldTypeTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/DateScriptFieldTypeTests.java
@@ -233,6 +233,11 @@ public boolean needs_score() {
return false;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) throws IOException {
return new ScoreScript(Map.of(), searchContext.lookup(), docReader) {
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/DoubleScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/DoubleScriptFieldTypeTests.java
index 9547b4f9cb9a3..6e12778f87cf9 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/DoubleScriptFieldTypeTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/DoubleScriptFieldTypeTests.java
@@ -141,6 +141,11 @@ public boolean needs_score() {
return false;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) {
return new ScoreScript(Map.of(), searchContext.lookup(), docReader) {
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/GeoPointScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/GeoPointScriptFieldTypeTests.java
index 3289e46941a45..863f6a0819554 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/GeoPointScriptFieldTypeTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/GeoPointScriptFieldTypeTests.java
@@ -150,6 +150,11 @@ public boolean needs_score() {
return false;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) {
return new ScoreScript(Map.of(), searchContext.lookup(), docReader) {
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/IpScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/IpScriptFieldTypeTests.java
index 4726424ada5f2..4593b149e13db 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/IpScriptFieldTypeTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/IpScriptFieldTypeTests.java
@@ -155,6 +155,11 @@ public boolean needs_score() {
return false;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) {
return new ScoreScript(Map.of(), searchContext.lookup(), docReader) {
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/KeywordScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/KeywordScriptFieldTypeTests.java
index 6912194625bb7..b5270b358ec40 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/KeywordScriptFieldTypeTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/KeywordScriptFieldTypeTests.java
@@ -136,6 +136,11 @@ public boolean needs_score() {
return false;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) {
return new ScoreScript(Map.of(), searchContext.lookup(), docReader) {
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/LongScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/LongScriptFieldTypeTests.java
index 83b3dbe858471..c9ac1516b6f8e 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/LongScriptFieldTypeTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/LongScriptFieldTypeTests.java
@@ -177,6 +177,11 @@ public boolean needs_score() {
return false;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) {
return new ScoreScript(Map.of(), searchContext.lookup(), docReader) {
diff --git a/server/src/test/java/org/elasticsearch/script/ScriptTermStatsTests.java b/server/src/test/java/org/elasticsearch/script/ScriptTermStatsTests.java
new file mode 100644
index 0000000000000..d748ad0f1569d
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/script/ScriptTermStatsTests.java
@@ -0,0 +1,358 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.TextField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.store.ByteBuffersDirectory;
+import org.apache.lucene.store.Directory;
+import org.elasticsearch.core.CheckedConsumer;
+import org.elasticsearch.test.ESTestCase;
+import org.hamcrest.Matcher;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+import java.util.function.Predicate;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class ScriptTermStatsTests extends ESTestCase {
+ public void testMatchedTermsCount() throws IOException {
+ // Returns number of matched terms for each doc.
+ assertAllDocs(
+ Set.of(new Term("field", "foo"), new Term("field", "bar")),
+ ScriptTermStats::matchedTermsCount,
+ Map.of("doc-1", equalTo(2), "doc-2", equalTo(2), "doc-3", equalTo(1))
+ );
+
+ // Partial match
+ assertAllDocs(
+ Set.of(new Term("field", "foo"), new Term("field", "baz")),
+ ScriptTermStats::matchedTermsCount,
+ Map.of("doc-1", equalTo(1), "doc-2", equalTo(1), "doc-3", equalTo(0))
+ );
+
+ // Always returns 0 when no term is provided.
+ assertAllDocs(
+ Set.of(),
+ ScriptTermStats::matchedTermsCount,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(0)))
+ );
+
+ // Always Returns 0 when none of the provided term has a match.
+ assertAllDocs(
+ randomTerms(),
+ ScriptTermStats::matchedTermsCount,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(0)))
+ );
+
+ // Always returns 0 when using a non-existing field
+ assertAllDocs(
+ Set.of(new Term("field-that-does-not-exists", "foo"), new Term("field-that-does-not-exists", "bar")),
+ ScriptTermStats::matchedTermsCount,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(0)))
+ );
+ }
+
+ public void testDocFreq() throws IOException {
+ // Single term
+ {
+ StatsSummary expected = new StatsSummary(1, 2, 2, 2);
+ assertAllDocs(
+ Set.of(new Term("field", "foo")),
+ ScriptTermStats::docFreq,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+
+ // Multiple terms
+ {
+ StatsSummary expected = new StatsSummary(2, 5, 2, 3);
+ assertAllDocs(
+ Set.of(new Term("field", "foo"), new Term("field", "bar")),
+ ScriptTermStats::docFreq,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+
+ // With missing terms
+ {
+ StatsSummary expected = new StatsSummary(2, 2, 0, 2);
+ assertAllDocs(
+ Set.of(new Term("field", "foo"), new Term("field", "baz")),
+ ScriptTermStats::docFreq,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+
+ // When no term is provided.
+ {
+ StatsSummary expected = new StatsSummary();
+ assertAllDocs(
+ Set.of(),
+ ScriptTermStats::docFreq,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+
+ // When using a non-existing field
+ {
+ StatsSummary expected = new StatsSummary(2, 0, 0, 0);
+ assertAllDocs(
+ Set.of(new Term("non-existing-field", "foo"), new Term("non-existing-field", "baz")),
+ ScriptTermStats::docFreq,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+ }
+
+ public void testTotalTermFreq() throws IOException {
+ // Single term
+ {
+ StatsSummary expected = new StatsSummary(1, 3, 3, 3);
+ assertAllDocs(
+ Set.of(new Term("field", "foo")),
+ ScriptTermStats::totalTermFreq,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+
+ // Multiple terms
+ {
+ StatsSummary expected = new StatsSummary(2, 6, 3, 3);
+ assertAllDocs(
+ Set.of(new Term("field", "foo"), new Term("field", "bar")),
+ ScriptTermStats::totalTermFreq,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+
+ // With missing terms
+ {
+ StatsSummary expected = new StatsSummary(2, 3, 0, 3);
+ assertAllDocs(
+ Set.of(new Term("field", "foo"), new Term("field", "baz")),
+ ScriptTermStats::totalTermFreq,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+
+ // When no term is provided.
+ {
+ StatsSummary expected = new StatsSummary();
+ assertAllDocs(
+ Set.of(),
+ ScriptTermStats::totalTermFreq,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+
+ // When using a non-existing field
+ {
+ StatsSummary expected = new StatsSummary(2, 0, 0, 0);
+ assertAllDocs(
+ Set.of(new Term("non-existing-field", "foo"), new Term("non-existing-field", "baz")),
+ ScriptTermStats::totalTermFreq,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+ }
+
+ public void testTermFreq() throws IOException {
+ // Single term
+ {
+
+ assertAllDocs(
+ Set.of(new Term("field", "foo")),
+ ScriptTermStats::termFreq,
+ Map.ofEntries(
+ Map.entry("doc-1", equalTo(new StatsSummary(1, 1, 1, 1))),
+ Map.entry("doc-2", equalTo(new StatsSummary(1, 2, 2, 2))),
+ Map.entry("doc-3", equalTo(new StatsSummary(1, 0, 0, 0)))
+ )
+ );
+ }
+
+ // Multiple terms
+ {
+ StatsSummary expected = new StatsSummary(2, 6, 3, 3);
+ assertAllDocs(
+ Set.of(new Term("field", "foo"), new Term("field", "bar")),
+ ScriptTermStats::termFreq,
+ Map.ofEntries(
+ Map.entry("doc-1", equalTo(new StatsSummary(2, 2, 1, 1))),
+ Map.entry("doc-2", equalTo(new StatsSummary(2, 3, 1, 2))),
+ Map.entry("doc-3", equalTo(new StatsSummary(2, 1, 0, 1)))
+ )
+ );
+ }
+
+ // With missing terms
+ {
+ assertAllDocs(
+ Set.of(new Term("field", "foo"), new Term("field", "baz")),
+ ScriptTermStats::termFreq,
+ Map.ofEntries(
+ Map.entry("doc-1", equalTo(new StatsSummary(2, 1, 0, 1))),
+ Map.entry("doc-2", equalTo(new StatsSummary(2, 2, 0, 2))),
+ Map.entry("doc-3", equalTo(new StatsSummary(2, 0, 0, 0)))
+ )
+ );
+ }
+
+ // When no term is provided.
+ {
+ StatsSummary expected = new StatsSummary();
+ assertAllDocs(
+ Set.of(),
+ ScriptTermStats::termFreq,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+
+ // When using a non-existing field
+ {
+ StatsSummary expected = new StatsSummary(2, 0, 0, 0);
+ assertAllDocs(
+ Set.of(new Term("non-existing-field", "foo"), new Term("non-existing-field", "baz")),
+ ScriptTermStats::termFreq,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+ }
+
+ public void testTermPositions() throws IOException {
+ // Single term
+ {
+
+ assertAllDocs(
+ Set.of(new Term("field", "foo")),
+ ScriptTermStats::termPositions,
+ Map.ofEntries(
+ Map.entry("doc-1", equalTo(new StatsSummary(1, 1, 1, 1))),
+ Map.entry("doc-2", equalTo(new StatsSummary(2, 3, 1, 2))),
+ Map.entry("doc-3", equalTo(new StatsSummary()))
+ )
+ );
+ }
+
+ // Multiple terms
+ {
+ StatsSummary expected = new StatsSummary(2, 6, 3, 3);
+ assertAllDocs(
+ Set.of(new Term("field", "foo"), new Term("field", "bar")),
+ ScriptTermStats::termPositions,
+ Map.ofEntries(
+ Map.entry("doc-1", equalTo(new StatsSummary(2, 3, 1, 2))),
+ Map.entry("doc-2", equalTo(new StatsSummary(3, 6, 1, 3))),
+ Map.entry("doc-3", equalTo(new StatsSummary(1, 1, 1, 1)))
+ )
+ );
+ }
+
+ // With missing terms
+ {
+ assertAllDocs(
+ Set.of(new Term("field", "foo"), new Term("field", "baz")),
+ ScriptTermStats::termPositions,
+ Map.ofEntries(
+ Map.entry("doc-1", equalTo(new StatsSummary(1, 1, 1, 1))),
+ Map.entry("doc-2", equalTo(new StatsSummary(2, 3, 1, 2))),
+ Map.entry("doc-3", equalTo(new StatsSummary()))
+ )
+ );
+ }
+
+ // When no term is provided.
+ {
+ StatsSummary expected = new StatsSummary();
+ assertAllDocs(
+ Set.of(),
+ ScriptTermStats::termPositions,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+
+ // When using a non-existing field
+ {
+ StatsSummary expected = new StatsSummary();
+ assertAllDocs(
+ Set.of(new Term("non-existing-field", "foo"), new Term("non-existing-field", "bar")),
+ ScriptTermStats::termPositions,
+ Stream.of("doc-1", "doc-2", "doc-3").collect(Collectors.toMap(Function.identity(), k -> equalTo(expected)))
+ );
+ }
+ }
+
+ private void withIndexSearcher(CheckedConsumer<IndexSearcher, IOException> consumer) throws IOException {
+ try (Directory dir = new ByteBuffersDirectory()) {
+ IndexWriter w = new IndexWriter(dir, newIndexWriterConfig());
+
+ Document doc = new Document();
+ doc.add(new TextField("id", "doc-1", Field.Store.YES));
+ doc.add(new TextField("field", "foo bar", Field.Store.YES));
+ w.addDocument(doc);
+
+ doc = new Document();
+ doc.add(new TextField("id", "doc-2", Field.Store.YES));
+ doc.add(new TextField("field", "foo foo bar", Field.Store.YES));
+ w.addDocument(doc);
+
+ doc = new Document();
+ doc.add(new TextField("id", "doc-3", Field.Store.YES));
+ doc.add(new TextField("field", "bar", Field.Store.YES));
+ w.addDocument(doc);
+
+ try (IndexReader r = DirectoryReader.open(w)) {
+ w.close();
+ consumer.accept(newSearcher(r));
+ }
+ }
+ }
+
+ private void assertAllDocs(Set<Term> terms, Function<ScriptTermStats, StatsSummary> function, Map<String, Matcher<StatsSummary>> expectedValues)
+ throws IOException {
+ withIndexSearcher(searcher -> {
+ for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
+ IndexReader reader = leafReaderContext.reader();
+ DocIdSetIterator docIdSetIterator = DocIdSetIterator.all(reader.maxDoc());
+ ScriptTermStats termStats = new ScriptTermStats(searcher, leafReaderContext, docIdSetIterator::docID, terms);
+ while (docIdSetIterator.nextDoc() <= reader.maxDoc()) {
+ String docId = reader.document(docIdSetIterator.docID()).get("id");
+ if (expectedValues.containsKey(docId)) {
+ assertThat(function.apply(termStats), expectedValues.get(docId));
+ }
+ }
+ }
+ });
+ }
+
+ private Set<Term> randomTerms() {
+ Predicate<Term> isReservedTerm = term -> List.of("foo", "bar").contains(term.text());
+ Supplier<Term> termSupplier = () -> randomValueOtherThanMany(
+ isReservedTerm,
+ () -> new Term("field", randomAlphaOfLengthBetween(1, 5))
+ );
+ return randomSet(1, randomIntBetween(1, 10), termSupplier);
+ }
+}
diff --git a/server/src/test/java/org/elasticsearch/script/StatsSummaryTests.java b/server/src/test/java/org/elasticsearch/script/StatsSummaryTests.java
new file mode 100644
index 0000000000000..4c3d3ea4b7595
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/script/StatsSummaryTests.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script;
+
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.not;
+
+public class StatsSummaryTests extends ESTestCase {
+ public void testEmpty() {
+ StatsSummary accumulator = new StatsSummary();
+ assertThat(accumulator.getMin(), equalTo(0.0));
+ assertThat(accumulator.getMax(), equalTo(0.0));
+ assertThat(accumulator.getSum(), equalTo(0.0));
+ assertThat(accumulator.getAverage(), equalTo(0.0));
+ }
+
+ public void testGivenPositiveValues() {
+ StatsSummary accumulator = new StatsSummary();
+
+ for (int i = 1; i <= 10; i++) {
+ accumulator.accept(i);
+ }
+
+ assertThat(accumulator.getMin(), equalTo(1.0));
+ assertThat(accumulator.getMax(), equalTo(10.0));
+ assertThat(accumulator.getSum(), equalTo(55.0));
+ assertThat(accumulator.getAverage(), equalTo(5.5));
+ }
+
+ public void testGivenNegativeValues() {
+ StatsSummary accumulator = new StatsSummary();
+
+ for (int i = 1; i <= 10; i++) {
+ accumulator.accept(-1 * i);
+ }
+
+ assertThat(accumulator.getMin(), equalTo(-10.0));
+ assertThat(accumulator.getMax(), equalTo(-1.0));
+ assertThat(accumulator.getSum(), equalTo(-55.0));
+ assertThat(accumulator.getAverage(), equalTo(-5.5));
+ }
+
+ public void testReset() {
+ StatsSummary accumulator = new StatsSummary();
+ randomDoubles(randomIntBetween(1, 20)).forEach(accumulator);
+ assertThat(accumulator, not(equalTo(new StatsSummary())));
+
+ accumulator.reset();
+ assertThat(accumulator, equalTo(new StatsSummary()));
+ assertThat(accumulator.getMin(), equalTo(0.0));
+ assertThat(accumulator.getMax(), equalTo(0.0));
+ assertThat(accumulator.getSum(), equalTo(0.0));
+ assertThat(accumulator.getAverage(), equalTo(0.0));
+ }
+
+ public void testEqualsAndHashCode() {
+ StatsSummary stats1 = new StatsSummary();
+ StatsSummary stats2 = new StatsSummary();
+
+ // Empty accumulators are equal.
+ assertThat(stats1, equalTo(stats2));
+ assertThat(stats1.hashCode(), equalTo(stats2.hashCode()));
+
+ // Accumulators with the same values are equal
+ randomDoubles(randomIntBetween(0, 20)).forEach(stats1.andThen(stats2));
+ assertThat(stats1, equalTo(stats2));
+ assertThat(stats1.hashCode(), equalTo(stats2.hashCode()));
+
+ // Accumulators with different values are not equal
+ randomDoubles(randomIntBetween(0, 20)).forEach(stats1);
+ randomDoubles(randomIntBetween(0, 20)).forEach(stats2);
+ assertThat(stats1, not(equalTo(stats2)));
+ assertThat(stats1.hashCode(), not(equalTo(stats2.hashCode())));
+ }
+}
diff --git a/server/src/test/java/org/elasticsearch/search/fetch/subphase/highlight/PlainHighlighterTests.java b/server/src/test/java/org/elasticsearch/search/fetch/subphase/highlight/PlainHighlighterTests.java
index a6c95efef9c46..209f67b1bb504 100644
--- a/server/src/test/java/org/elasticsearch/search/fetch/subphase/highlight/PlainHighlighterTests.java
+++ b/server/src/test/java/org/elasticsearch/search/fetch/subphase/highlight/PlainHighlighterTests.java
@@ -134,6 +134,11 @@ public boolean needs_score() {
return true;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader reader) throws IOException {
return new ScoreScript(params, lookup, reader) {
diff --git a/server/src/test/java/org/elasticsearch/search/query/ScriptScoreQueryTests.java b/server/src/test/java/org/elasticsearch/search/query/ScriptScoreQueryTests.java
index 6fff108cfb5ce..d6b1da9f76b42 100644
--- a/server/src/test/java/org/elasticsearch/search/query/ScriptScoreQueryTests.java
+++ b/server/src/test/java/org/elasticsearch/search/query/ScriptScoreQueryTests.java
@@ -15,9 +15,14 @@
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.elasticsearch.common.lucene.search.Queries;
@@ -26,13 +31,16 @@
import org.elasticsearch.script.DocReader;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.Script;
+import org.elasticsearch.script.ScriptTermStats;
import org.elasticsearch.search.lookup.LeafSearchLookup;
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.test.ESTestCase;
import org.junit.After;
import org.junit.Before;
+import org.mockito.ArgumentCaptor;
import java.io.IOException;
+import java.util.function.BiFunction;
import java.util.function.Function;
import static org.hamcrest.CoreMatchers.containsString;
@@ -40,6 +48,8 @@
import static org.hamcrest.collection.IsArrayWithSize.arrayWithSize;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class ScriptScoreQueryTests extends ESTestCase {
@@ -74,22 +84,13 @@ public void closeAllTheReaders() throws IOException {
public void testExplain() throws IOException {
Script script = new Script("script using explain");
- ScoreScript.LeafFactory factory = newFactory(script, true, explanation -> {
+ ScoreScript.LeafFactory factory = newFactory(script, true, false, explanation -> {
assertNotNull(explanation);
explanation.set("this explains the score");
return 1.0;
});
- ScriptScoreQuery query = new ScriptScoreQuery(
- Queries.newMatchAllQuery(),
- script,
- factory,
- lookup,
- null,
- "index",
- 0,
- IndexVersion.current()
- );
+ ScriptScoreQuery query = createScriptScoreQuery(Queries.newMatchAllQuery(), script, factory);
Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
Explanation explanation = weight.explain(leafReaderContext, 0);
assertNotNull(explanation);
@@ -99,18 +100,9 @@ public void testExplain() throws IOException {
public void testExplainDefault() throws IOException {
Script script = new Script("script without setting explanation");
- ScoreScript.LeafFactory factory = newFactory(script, true, explanation -> 1.5);
+ ScoreScript.LeafFactory factory = newFactory(script, true, false, explanation -> 1.5);
- ScriptScoreQuery query = new ScriptScoreQuery(
- Queries.newMatchAllQuery(),
- script,
- factory,
- lookup,
- null,
- "index",
- 0,
- IndexVersion.current()
- );
+ ScriptScoreQuery query = createScriptScoreQuery(Queries.newMatchAllQuery(), script, factory);
Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
Explanation explanation = weight.explain(leafReaderContext, 0);
assertNotNull(explanation);
@@ -124,18 +116,9 @@ public void testExplainDefault() throws IOException {
public void testExplainDefaultNoScore() throws IOException {
Script script = new Script("script without setting explanation and no score");
- ScoreScript.LeafFactory factory = newFactory(script, false, explanation -> 2.0);
+ ScoreScript.LeafFactory factory = newFactory(script, false, false, explanation -> 2.0);
- ScriptScoreQuery query = new ScriptScoreQuery(
- Queries.newMatchAllQuery(),
- script,
- factory,
- lookup,
- null,
- "index",
- 0,
- IndexVersion.current()
- );
+ ScriptScoreQuery query = createScriptScoreQuery(Queries.newMatchAllQuery(), script, factory);
Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
Explanation explanation = weight.explain(leafReaderContext, 0);
assertNotNull(explanation);
@@ -148,43 +131,91 @@ public void testExplainDefaultNoScore() throws IOException {
public void testScriptScoreErrorOnNegativeScore() {
Script script = new Script("script that returns a negative score");
- ScoreScript.LeafFactory factory = newFactory(script, false, explanation -> -1000.0);
- ScriptScoreQuery query = new ScriptScoreQuery(
- Queries.newMatchAllQuery(),
- script,
- factory,
- lookup,
- null,
- "index",
- 0,
- IndexVersion.current()
- );
+ ScoreScript.LeafFactory factory = newFactory(script, false, false, explanation -> -1000.0);
+ ScriptScoreQuery query = createScriptScoreQuery(Queries.newMatchAllQuery(), script, factory);
+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> searcher.search(query, 1));
assertTrue(e.getMessage().contains("Must be a non-negative score!"));
}
+ public void testScriptTermStatsAvailable() throws IOException {
+ Script script = new Script("termStats script without setting explanation");
+ ScoreScript scoreScriptMock = mock(ScoreScript.class);
+ ScoreScript.LeafFactory factory = newFactory(false, true, (lookup, docReader) -> scoreScriptMock);
+
+ ScriptScoreQuery query = createScriptScoreQuery(
+ new BooleanQuery.Builder().add(new TermQuery(new Term("field", "text")), BooleanClause.Occur.MUST)
+ .add(new TermQuery(new Term("field", "missing")), BooleanClause.Occur.SHOULD)
+ .build(),
+ script,
+ factory
+ );
+
+ query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f).scorer(leafReaderContext);
+
+ ArgumentCaptor<ScriptTermStats> scriptTermStats = ArgumentCaptor.forClass(ScriptTermStats.class);
+ verify(scoreScriptMock)._setTermStats(scriptTermStats.capture());
+ assertThat(scriptTermStats.getValue().uniqueTermsCount(), equalTo(2));
+ }
+
+ public void testScriptTermStatsNotAvailable() throws IOException {
+ Script script = new Script("termStats script without setting explanation");
+ ScoreScript scoreScriptMock = mock(ScoreScript.class);
+ ScoreScript.LeafFactory factory = newFactory(false, false, (lookup, docReader) -> scoreScriptMock);
+
+ ScriptScoreQuery query = createScriptScoreQuery(
+ new BooleanQuery.Builder().add(new TermQuery(new Term("field", "text")), BooleanClause.Occur.MUST)
+ .add(new TermQuery(new Term("field", "missing")), BooleanClause.Occur.SHOULD)
+ .build(),
+ script,
+ factory
+ );
+
+ query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f).scorer(leafReaderContext);
+ verify(scoreScriptMock, never())._setTermStats(any());
+ }
+
+ private ScriptScoreQuery createScriptScoreQuery(Query subQuery, Script script, ScoreScript.LeafFactory factory) {
+ return new ScriptScoreQuery(subQuery, script, factory, lookup, null, "index", 0, IndexVersion.current());
+ }
+
private ScoreScript.LeafFactory newFactory(
Script script,
boolean needsScore,
+ boolean needsTermStats,
Function function
+ ) {
+ return newFactory(needsScore, needsTermStats, (lookup, docReader) -> new ScoreScript(script.getParams(), lookup, docReader) {
+ @Override
+ public double execute(ExplanationHolder explanation) {
+ return function.apply(explanation);
+ }
+ });
+ }
+
+ private ScoreScript.LeafFactory newFactory(
+ boolean needsScore,
+ boolean needsTermStats,
+ BiFunction<SearchLookup, DocReader, ScoreScript> scoreScriptProvider
) {
SearchLookup lookup = mock(SearchLookup.class);
LeafSearchLookup leafLookup = mock(LeafSearchLookup.class);
when(lookup.getLeafSearchLookup(any())).thenReturn(leafLookup);
+
return new ScoreScript.LeafFactory() {
@Override
public boolean needs_score() {
return needsScore;
}
+ @Override
+ public boolean needs_termStats() {
+ return needsTermStats;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) {
- return new ScoreScript(script.getParams(), lookup, docReader) {
- @Override
- public double execute(ExplanationHolder explanation) {
- return function.apply(explanation);
- }
- };
+ return scoreScriptProvider.apply(lookup, docReader);
}
};
}
diff --git a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java
index 2eb1811c12691..995481b31243e 100644
--- a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java
+++ b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java
@@ -739,6 +739,11 @@ public boolean needs_score() {
return true;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) throws IOException {
Scorable[] scorerHolder = new Scorable[1];
diff --git a/x-pack/plugin/mapper-aggregate-metric/src/test/java/org/elasticsearch/xpack/aggregatemetric/mapper/AggregateDoubleMetricFieldTypeTests.java b/x-pack/plugin/mapper-aggregate-metric/src/test/java/org/elasticsearch/xpack/aggregatemetric/mapper/AggregateDoubleMetricFieldTypeTests.java
index 4b7b27bf2cec3..89c2799d8327d 100644
--- a/x-pack/plugin/mapper-aggregate-metric/src/test/java/org/elasticsearch/xpack/aggregatemetric/mapper/AggregateDoubleMetricFieldTypeTests.java
+++ b/x-pack/plugin/mapper-aggregate-metric/src/test/java/org/elasticsearch/xpack/aggregatemetric/mapper/AggregateDoubleMetricFieldTypeTests.java
@@ -136,6 +136,11 @@ public boolean needs_score() {
return false;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) {
return new ScoreScript(Map.of(), searchExecutionContext.lookup(), docReader) {
diff --git a/x-pack/plugin/spatial/src/test/java/org/elasticsearch/xpack/spatial/index/mapper/GeoShapeScriptFieldTypeTests.java b/x-pack/plugin/spatial/src/test/java/org/elasticsearch/xpack/spatial/index/mapper/GeoShapeScriptFieldTypeTests.java
index 592cb65800b71..5d6f68f8e06a9 100644
--- a/x-pack/plugin/spatial/src/test/java/org/elasticsearch/xpack/spatial/index/mapper/GeoShapeScriptFieldTypeTests.java
+++ b/x-pack/plugin/spatial/src/test/java/org/elasticsearch/xpack/spatial/index/mapper/GeoShapeScriptFieldTypeTests.java
@@ -174,6 +174,11 @@ public boolean needs_score() {
return false;
}
+ @Override
+ public boolean needs_termStats() {
+ return false;
+ }
+
@Override
public ScoreScript newInstance(DocReader docReader) {
return new ScoreScript(Map.of(), searchContext.lookup(), docReader) {
From d6b5c3d5cd4b78d2b6a79497f8ebc3202e8ec7b6 Mon Sep 17 00:00:00 2001
From: Athena Brown
Date: Tue, 27 Aug 2024 10:34:17 -0600
Subject: [PATCH 22/46] Fix DLS & FLS sometimes being enforced when it is
disabled (#111915)
This commit adjusts a few places where DLS & FLS are enforced
to make sure they respect the `xpack.security.dls_fls.enabled`
setting.
---
docs/changelog/111915.yaml | 6 ++++
.../xpack/security/Security.java | 13 ++++---
.../IndicesAliasesRequestInterceptor.java | 7 ++--
.../interceptor/ResizeRequestInterceptor.java | 11 ++++--
...IndicesAliasesRequestInterceptorTests.java | 35 +++++++++++++++----
.../ResizeRequestInterceptorTests.java | 30 ++++++++++++----
6 files changed, 80 insertions(+), 22 deletions(-)
create mode 100644 docs/changelog/111915.yaml
diff --git a/docs/changelog/111915.yaml b/docs/changelog/111915.yaml
new file mode 100644
index 0000000000000..f64c45b82d10c
--- /dev/null
+++ b/docs/changelog/111915.yaml
@@ -0,0 +1,6 @@
+pr: 111915
+summary: Fix DLS & FLS sometimes being enforced when it is disabled
+area: Authorization
+type: bug
+issues:
+ - 94709
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java
index 11c688e9ee5eb..7e44d6d8b1c99 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java
@@ -584,6 +584,7 @@ public class Security extends Plugin
private Settings settings;
private final boolean enabled;
+ private final SetOnce<Boolean> dlsFlsEnabled = new SetOnce<>();
private final SecuritySystemIndices systemIndices;
private final ListenableFuture nodeStartedListenable;
@@ -1106,12 +1107,13 @@ Collection
*/
@Override public T visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx) { return visitChildren(ctx); }
- /**
- * {@inheritDoc}
- *
- *
The default implementation returns the result of calling
- * {@link #visitChildren} on {@code ctx}.
- */
- @Override public T visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
@@ -593,6 +594,13 @@ public class EsqlBaseParserBaseVisitor extends AbstractParseTreeVisitor im
* {@link #visitChildren} on {@code ctx}.
*/
@Override public T visitLookupCommand(EsqlBaseParser.LookupCommandContext ctx) { return visitChildren(ctx); }
+ /**
+ * {@inheritDoc}
+ *
+ *
The default implementation returns the result of calling
+ * {@link #visitChildren} on {@code ctx}.
+ */
+ @Override public T visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
index 6ef7b2e07ce78..0c39b3ea83fa9 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
@@ -1,5 +1,13 @@
// ANTLR GENERATED CODE: DO NOT EDIT
package org.elasticsearch.xpack.esql.parser;
+
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
import org.antlr.v4.runtime.tree.ParseTreeListener;
/**
@@ -447,16 +455,6 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
* @param ctx the parse tree
*/
void exitStatsCommand(EsqlBaseParser.StatsCommandContext ctx);
- /**
- * Enter a parse tree produced by {@link EsqlBaseParser#inlinestatsCommand}.
- * @param ctx the parse tree
- */
- void enterInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx);
- /**
- * Exit a parse tree produced by {@link EsqlBaseParser#inlinestatsCommand}.
- * @param ctx the parse tree
- */
- void exitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx);
/**
* Enter a parse tree produced by {@link EsqlBaseParser#qualifiedName}.
* @param ctx the parse tree
@@ -905,6 +903,16 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
* @param ctx the parse tree
*/
void exitLookupCommand(EsqlBaseParser.LookupCommandContext ctx);
+ /**
+ * Enter a parse tree produced by {@link EsqlBaseParser#inlinestatsCommand}.
+ * @param ctx the parse tree
+ */
+ void enterInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx);
+ /**
+ * Exit a parse tree produced by {@link EsqlBaseParser#inlinestatsCommand}.
+ * @param ctx the parse tree
+ */
+ void exitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx);
/**
* Enter a parse tree produced by {@link EsqlBaseParser#matchCommand}.
* @param ctx the parse tree
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
index cbef7c55372a4..31c9371b9f806 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
@@ -1,5 +1,13 @@
// ANTLR GENERATED CODE: DO NOT EDIT
package org.elasticsearch.xpack.esql.parser;
+
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
import org.antlr.v4.runtime.tree.ParseTreeVisitor;
/**
@@ -270,12 +278,6 @@ public interface EsqlBaseParserVisitor extends ParseTreeVisitor {
* @return the visitor result
*/
T visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx);
- /**
- * Visit a parse tree produced by {@link EsqlBaseParser#inlinestatsCommand}.
- * @param ctx the parse tree
- * @return the visitor result
- */
- T visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx);
/**
* Visit a parse tree produced by {@link EsqlBaseParser#qualifiedName}.
* @param ctx the parse tree
@@ -542,6 +544,12 @@ public interface EsqlBaseParserVisitor extends ParseTreeVisitor {
* @return the visitor result
*/
T visitLookupCommand(EsqlBaseParser.LookupCommandContext ctx);
+ /**
+ * Visit a parse tree produced by {@link EsqlBaseParser#inlinestatsCommand}.
+ * @param ctx the parse tree
+ * @return the visitor result
+ */
+ T visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx);
/**
* Visit a parse tree produced by {@link EsqlBaseParser#matchCommand}.
* @param ctx the parse tree
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlConfig.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlConfig.java
new file mode 100644
index 0000000000000..8e5aa28d25c00
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlConfig.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.parser;
+
+import org.elasticsearch.Build;
+
+class EsqlConfig {
+
+ // versioning information
+ boolean devVersion = Build.current().isSnapshot();
+
+ public boolean isDevVersion() {
+ return devVersion;
+ }
+
+ boolean isReleaseVersion() {
+ return isDevVersion() == false;
+ }
+
+ public void setDevVersion(boolean dev) {
+ this.devVersion = dev;
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java
index 593e9db9fc956..4be80af958e36 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java
@@ -7,23 +7,24 @@
package org.elasticsearch.xpack.esql.parser;
import org.antlr.v4.runtime.BaseErrorListener;
-import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.RecognitionException;
import org.antlr.v4.runtime.Recognizer;
import org.antlr.v4.runtime.Token;
-import org.antlr.v4.runtime.TokenFactory;
import org.antlr.v4.runtime.TokenSource;
import org.antlr.v4.runtime.atn.PredictionMode;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
+import org.elasticsearch.xpack.esql.core.util.StringUtils;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import java.util.BitSet;
import java.util.function.BiFunction;
import java.util.function.Function;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
import static org.elasticsearch.xpack.esql.core.util.StringUtils.isInteger;
import static org.elasticsearch.xpack.esql.parser.ParserUtils.source;
@@ -32,6 +33,16 @@ public class EsqlParser {
private static final Logger log = LogManager.getLogger(EsqlParser.class);
+ private EsqlConfig config = new EsqlConfig();
+
+ public EsqlConfig config() {
+ return config;
+ }
+
+ public void setEsqlConfig(EsqlConfig config) {
+ this.config = config;
+ }
+
public LogicalPlan createStatement(String query) {
return createStatement(query, new QueryParams());
}
@@ -50,11 +61,14 @@ private T invokeParser(
BiFunction result
) {
try {
- EsqlBaseLexer lexer = new EsqlBaseLexer(new CaseChangingCharStream(CharStreams.fromString(query)));
+ // new CaseChangingCharStream()
+ EsqlBaseLexer lexer = new EsqlBaseLexer(CharStreams.fromString(query));
lexer.removeErrorListeners();
lexer.addErrorListener(ERROR_LISTENER);
+ lexer.setEsqlConfig(config);
+
TokenSource tokenSource = new ParametrizedTokenSource(lexer, params);
CommonTokenStream tokenStream = new CommonTokenStream(tokenSource);
EsqlBaseParser parser = new EsqlBaseParser(tokenStream);
@@ -66,6 +80,8 @@ private T invokeParser(
parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
+ parser.setEsqlConfig(config);
+
ParserRuleContext tree = parseFunction.apply(parser);
if (log.isTraceEnabled()) {
@@ -93,6 +109,9 @@ public void exitFunctionExpression(EsqlBaseParser.FunctionExpressionContext ctx)
}
private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener() {
+ // replace entries that start with DEV_
+ private final Pattern REPLACE_DEV = Pattern.compile(",*\\s*DEV_\\w+\\s*");
+
@Override
public void syntaxError(
Recognizer, ?> recognizer,
@@ -102,27 +121,31 @@ public void syntaxError(
String message,
RecognitionException e
) {
+ if (recognizer instanceof EsqlBaseParser parser && parser.isDevVersion() == false) {
+ Matcher m = REPLACE_DEV.matcher(message);
+ message = m.replaceAll(StringUtils.EMPTY);
+ }
+
throw new ParsingException(message, e, line, charPositionInLine);
}
};
/**
- * Finds all parameter tokens (?) and associates them with actual parameter values
+ * Finds all parameter tokens (?) and associates them with actual parameter values.
*
* Parameters are positional and we know where parameters occurred in the original stream in order to associate them
* with actual values.
*/
- private static class ParametrizedTokenSource implements TokenSource {
+ private static class ParametrizedTokenSource extends DelegatingTokenSource {
private static String message = "Inconsistent parameter declaration, "
+ "use one of positional, named or anonymous params but not a combination of ";
- private TokenSource delegate;
private QueryParams params;
private BitSet paramTypes = new BitSet(3);
private int param = 1;
ParametrizedTokenSource(TokenSource delegate, QueryParams params) {
- this.delegate = delegate;
+ super(delegate);
this.params = params;
}
@@ -148,36 +171,6 @@ public Token nextToken() {
return token;
}
- @Override
- public int getLine() {
- return delegate.getLine();
- }
-
- @Override
- public int getCharPositionInLine() {
- return delegate.getCharPositionInLine();
- }
-
- @Override
- public CharStream getInputStream() {
- return delegate.getInputStream();
- }
-
- @Override
- public String getSourceName() {
- return delegate.getSourceName();
- }
-
- @Override
- public void setTokenFactory(TokenFactory> factory) {
- delegate.setTokenFactory(factory);
- }
-
- @Override
- public TokenFactory> getTokenFactory() {
- return delegate.getTokenFactory();
- }
-
private void checkAnonymousParam(Token token) {
paramTypes.set(0);
if (paramTypes.cardinality() > 1) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
index c344fdc144e60..0352afdee4622 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
@@ -772,7 +772,7 @@ public Expression visitMatchBooleanExpression(EsqlBaseParser.MatchBooleanExpress
}
return new MatchQueryPredicate(
source(ctx),
- visitQualifiedName(ctx.qualifiedName()),
+ expression(ctx.valueExpression()),
visitString(ctx.queryString).fold().toString(),
null
);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LexerConfig.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LexerConfig.java
new file mode 100644
index 0000000000000..adcdba2f2eb4d
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LexerConfig.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.parser;
+
+import org.antlr.v4.runtime.CharStream;
+import org.antlr.v4.runtime.Lexer;
+
+/**
+ * Base class for hooking versioning information into the ANTLR parser.
+ */
+public abstract class LexerConfig extends Lexer {
+
+ // is null when running inside the IDEA plugin
+ EsqlConfig config;
+
+ public LexerConfig() {}
+
+ public LexerConfig(CharStream input) {
+ super(input);
+ }
+
+ boolean isDevVersion() {
+ return config == null || config.isDevVersion();
+ }
+
+ void setEsqlConfig(EsqlConfig config) {
+ this.config = config;
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ParserConfig.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ParserConfig.java
new file mode 100644
index 0000000000000..c6d4d3efa4a5a
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ParserConfig.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.parser;
+
+import org.antlr.v4.runtime.Parser;
+import org.antlr.v4.runtime.TokenStream;
+
+public abstract class ParserConfig extends Parser {
+
+ // is null when running inside the IDEA plugin
+ private EsqlConfig config;
+
+ public ParserConfig(TokenStream input) {
+ super(input);
+ }
+
+ boolean isDevVersion() {
+ return config == null || config.isDevVersion();
+ }
+
+ void setEsqlConfig(EsqlConfig config) {
+ this.config = config;
+ }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/GrammarInDevelopmentParsingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/GrammarInDevelopmentParsingTests.java
new file mode 100644
index 0000000000000..18d8bc9fb0a75
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/GrammarInDevelopmentParsingTests.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.parser;
+
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.not;
+
+public class GrammarInDevelopmentParsingTests extends ESTestCase {
+
+ public void testDevelopmentInline() throws Exception {
+ parse("row a = 1 | inlinestats b = min(a) by c, d.e", "inlinestats");
+ }
+
+ public void testDevelopmentLookup() throws Exception {
+ parse("row a = 1 | lookup \"foo\" on j", "lookup");
+ }
+
+ public void testDevelopmentMetrics() throws Exception {
+ parse("metrics foo", "metrics");
+ }
+
+ public void testDevelopmentMatch() throws Exception {
+ parse("row a = 1 | match foo", "match");
+ }
+
+ void parse(String query, String errorMessage) {
+ ParsingException pe = expectThrows(ParsingException.class, () -> parser().createStatement(query));
+ assertThat(pe.getMessage(), containsString("mismatched input '" + errorMessage + "'"));
+ // check the parser eliminated the DEV_ tokens from the message
+ assertThat(pe.getMessage(), not(containsString("DEV_")));
+ }
+
+ private EsqlParser parser() {
+ EsqlParser parser = new EsqlParser();
+ EsqlConfig config = parser.config();
+ assumeTrue(" requires snapshot builds", config.devVersion);
+
+ // manually disable dev mode (make it production)
+ config.setDevVersion(false);
+ return parser;
+ }
+}
From 17e8b5fecf36de9b760a591243c5d2ea7a04804e Mon Sep 17 00:00:00 2001
From: weizijun
Date: Wed, 28 Aug 2024 05:31:23 +0800
Subject: [PATCH 25/46] [Inference API] Add Alibaba Cloud AI Search Model
support to Inference API (#111181)
Add Alibaba Cloud AI Search Model support to Inference API.
Supports the text_embedding, sparse_embedding and rerank tasks.
Requires an Alibaba Cloud account with Alibaba Cloud OpenSearch access
and an API key used to access the Alibaba Cloud AI Search Model.
---
docs/changelog/111181.yaml | 5 +
.../org/elasticsearch/TransportVersions.java | 1 +
.../InferenceNamedWriteablesProvider.java | 61 ++++
.../xpack/inference/InferencePlugin.java | 2 +
.../AlibabaCloudSearchActionCreator.java | 53 +++
.../AlibabaCloudSearchActionVisitor.java | 24 ++
.../AlibabaCloudSearchEmbeddingsAction.java | 57 ++++
.../AlibabaCloudSearchRerankAction.java | 61 ++++
.../AlibabaCloudSearchSparseAction.java | 61 ++++
.../AlibabaCloudSearchAccount.java | 19 ++
.../AlibabaCloudSearchResponseHandler.java | 63 ++++
...baCloudSearchEmbeddingsRequestManager.java | 77 +++++
.../AlibabaCloudSearchRequestManager.java | 28 ++
...libabaCloudSearchRerankRequestManager.java | 77 +++++
...libabaCloudSearchSparseRequestManager.java | 77 +++++
.../AlibabaCloudSearchEmbeddingsRequest.java | 111 ++++++
...abaCloudSearchEmbeddingsRequestEntity.java | 66 ++++
.../AlibabaCloudSearchRequest.java | 22 ++
.../AlibabaCloudSearchRerankRequest.java | 113 +++++++
...AlibabaCloudSearchRerankRequestEntity.java | 42 +++
.../AlibabaCloudSearchSparseRequest.java | 111 ++++++
...AlibabaCloudSearchSparseRequestEntity.java | 47 +++
.../AlibabaCloudSearchUtils.java | 18 +
...baCloudSearchEmbeddingsResponseEntity.java | 109 ++++++
...AlibabaCloudSearchErrorResponseEntity.java | 69 ++++
...libabaCloudSearchRerankResponseEntity.java | 139 ++++++++
.../AlibabaCloudSearchResponseEntity.java | 78 +++++
...libabaCloudSearchSparseResponseEntity.java | 199 +++++++++++
.../AlibabaCloudSearchModel.java | 49 +++
...baCloudSearchRateLimitServiceSettings.java | 15 +
.../AlibabaCloudSearchService.java | 318 ++++++++++++++++++
.../AlibabaCloudSearchServiceSettings.java | 193 +++++++++++
.../AlibabaCloudSearchEmbeddingsModel.java | 104 ++++++
...aCloudSearchEmbeddingsServiceSettings.java | 152 +++++++++
...babaCloudSearchEmbeddingsTaskSettings.java | 173 ++++++++++
.../rerank/AlibabaCloudSearchRerankModel.java | 94 ++++++
...ibabaCloudSearchRerankServiceSettings.java | 97 ++++++
.../AlibabaCloudSearchRerankTaskSettings.java | 101 ++++++
.../sparse/AlibabaCloudSearchSparseModel.java | 98 ++++++
...ibabaCloudSearchSparseServiceSettings.java | 97 ++++++
.../AlibabaCloudSearchSparseTaskSettings.java | 186 ++++++++++
.../xpack/inference/InputTypeTests.java | 4 +
...oudSearchEmbeddingsRequestEntityTests.java | 57 ++++
...babaCloudSearchEmbeddingsRequestTests.java | 63 ++++
...baCloudSearchRerankRequestEntityTests.java | 34 ++
...baCloudSearchSparseRequestEntityTests.java | 49 +++
.../AlibabaCloudSearchSparseRequestTests.java | 63 ++++
...udSearchEmbeddingsResponseEntityTests.java | 69 ++++
...baCloudSearchErrorResponseEntityTests.java | 35 ++
...aCloudSearchRerankResponseEntityTests.java | 71 ++++
...aCloudSearchSparseResponseEntityTests.java | 85 +++++
...libabaCloudSearchServiceSettingsTests.java | 125 +++++++
.../AlibabaCloudSearchServiceTests.java | 172 ++++++++++
...libabaCloudSearchEmbeddingsModelTests.java | 71 ++++
...dSearchEmbeddingsServiceSettingsTests.java | 96 ++++++
...loudSearchEmbeddingsTaskSettingsTests.java | 73 ++++
.../AlibabaCloudSearchSparseModelTests.java | 71 ++++
...CloudSearchSparseServiceSettingsTests.java | 77 +++++
...abaCloudSearchSparseTaskSettingsTests.java | 74 ++++
59 files changed, 4756 insertions(+)
create mode 100644 docs/changelog/111181.yaml
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionCreator.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchEmbeddingsAction.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchRerankAction.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchSparseAction.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchAccount.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchResponseHandler.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRequestManager.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequest.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntity.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRequest.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequest.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntity.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchUtils.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntity.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchErrorResponseEntity.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchRerankResponseEntity.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchResponseEntity.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchSparseResponseEntity.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchModel.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchRateLimitServiceSettings.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettings.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankModel.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankServiceSettings.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettings.java
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntityTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntityTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntityTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchErrorResponseEntityTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchRerankResponseEntityTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchSparseResponseEntityTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettingsTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettingsTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettingsTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettingsTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettingsTests.java
diff --git a/docs/changelog/111181.yaml b/docs/changelog/111181.yaml
new file mode 100644
index 0000000000000..7f9f5937b7652
--- /dev/null
+++ b/docs/changelog/111181.yaml
@@ -0,0 +1,5 @@
+pr: 111181
+summary: "[Inference API] Add Alibaba Cloud AI Search Model support to Inference API"
+area: Machine Learning
+type: enhancement
+issues: [ ]
diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index 41fa34bb5a4a3..c68a33c6df6c4 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -200,6 +200,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_ES_FIELD_CACHED_SERIALIZATION = def(8_730_00_0);
public static final TransportVersion ADD_MANAGE_ROLES_PRIVILEGE = def(8_731_00_0);
public static final TransportVersion REPOSITORIES_TELEMETRY = def(8_732_00_0);
+ public static final TransportVersion ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED = def(8_733_00_0);
/*
* STOP! READ THIS FIRST! No, really,
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
index 489a81b642492..d4810ba930b44 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
@@ -25,6 +25,13 @@
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankServiceSettings;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseServiceSettings;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings;
@@ -117,6 +124,7 @@ public static List getNamedWriteables() {
addAnthropicNamedWritables(namedWriteables);
addAmazonBedrockNamedWriteables(namedWriteables);
addEisNamedWriteables(namedWriteables);
+ addAlibabaCloudSearchNamedWriteables(namedWriteables);
return namedWriteables;
}
@@ -482,6 +490,59 @@ private static void addAnthropicNamedWritables(List namedWriteables) {
+ namedWriteables.add(
+ new NamedWriteableRegistry.Entry(
+ ServiceSettings.class,
+ AlibabaCloudSearchServiceSettings.NAME,
+ AlibabaCloudSearchServiceSettings::new
+ )
+ );
+ namedWriteables.add(
+ new NamedWriteableRegistry.Entry(
+ ServiceSettings.class,
+ AlibabaCloudSearchEmbeddingsServiceSettings.NAME,
+ AlibabaCloudSearchEmbeddingsServiceSettings::new
+ )
+ );
+ namedWriteables.add(
+ new NamedWriteableRegistry.Entry(
+ TaskSettings.class,
+ AlibabaCloudSearchEmbeddingsTaskSettings.NAME,
+ AlibabaCloudSearchEmbeddingsTaskSettings::new
+ )
+ );
+ namedWriteables.add(
+ new NamedWriteableRegistry.Entry(
+ ServiceSettings.class,
+ AlibabaCloudSearchSparseServiceSettings.NAME,
+ AlibabaCloudSearchSparseServiceSettings::new
+ )
+ );
+ namedWriteables.add(
+ new NamedWriteableRegistry.Entry(
+ TaskSettings.class,
+ AlibabaCloudSearchSparseTaskSettings.NAME,
+ AlibabaCloudSearchSparseTaskSettings::new
+ )
+ );
+ namedWriteables.add(
+ new NamedWriteableRegistry.Entry(
+ ServiceSettings.class,
+ AlibabaCloudSearchRerankServiceSettings.NAME,
+ AlibabaCloudSearchRerankServiceSettings::new
+ )
+ );
+ namedWriteables.add(
+ new NamedWriteableRegistry.Entry(
+ TaskSettings.class,
+ AlibabaCloudSearchRerankTaskSettings.NAME,
+ AlibabaCloudSearchRerankTaskSettings::new
+ )
+ );
+
+ }
+
private static void addEisNamedWriteables(List namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index 9d85bbf751250..dff93a63d0647 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -74,6 +74,7 @@
import org.elasticsearch.xpack.inference.rest.RestInferenceAction;
import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockService;
import org.elasticsearch.xpack.inference.services.anthropic.AnthropicService;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
@@ -237,6 +238,7 @@ public List getInferenceServiceFactories() {
context -> new MistralService(httpFactory.get(), serviceComponents.get()),
context -> new AnthropicService(httpFactory.get(), serviceComponents.get()),
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
+ context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionCreator.java
new file mode 100644
index 0000000000000..218ca2ef39ed6
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionCreator.java
@@ -0,0 +1,53 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;
+
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
+
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the alibaba cloud search model type.
+ */
+public class AlibabaCloudSearchActionCreator implements AlibabaCloudSearchActionVisitor {
+ private final Sender sender;
+ private final ServiceComponents serviceComponents;
+
+ public AlibabaCloudSearchActionCreator(Sender sender, ServiceComponents serviceComponents) {
+ this.sender = Objects.requireNonNull(sender);
+ this.serviceComponents = Objects.requireNonNull(serviceComponents);
+ }
+
+ @Override
+ public ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map taskSettings, InputType inputType) {
+ var overriddenModel = AlibabaCloudSearchEmbeddingsModel.of(model, taskSettings, inputType);
+
+ return new AlibabaCloudSearchEmbeddingsAction(sender, overriddenModel, serviceComponents);
+ }
+
+ @Override
+ public ExecutableAction create(AlibabaCloudSearchSparseModel model, Map taskSettings, InputType inputType) {
+ var overriddenModel = AlibabaCloudSearchSparseModel.of(model, taskSettings, inputType);
+
+ return new AlibabaCloudSearchSparseAction(sender, overriddenModel, serviceComponents);
+ }
+
+ @Override
+ public ExecutableAction create(AlibabaCloudSearchRerankModel model, Map taskSettings) {
+ var overriddenModel = AlibabaCloudSearchRerankModel.of(model, taskSettings);
+
+ return new AlibabaCloudSearchRerankAction(sender, overriddenModel, serviceComponents);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java
new file mode 100644
index 0000000000000..69ae903c7b38f
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;
+
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
+
+import java.util.Map;
+
+public interface AlibabaCloudSearchActionVisitor {
+ ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map taskSettings, InputType inputType);
+
+ ExecutableAction create(AlibabaCloudSearchSparseModel model, Map taskSettings, InputType inputType);
+
+ ExecutableAction create(AlibabaCloudSearchRerankModel model, Map taskSettings);
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchEmbeddingsAction.java
new file mode 100644
index 0000000000000..7a22bbf6b4bfd
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchEmbeddingsAction.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchEmbeddingsRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
+
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
+
+public class AlibabaCloudSearchEmbeddingsAction implements ExecutableAction {
+ private final AlibabaCloudSearchAccount account;
+ private final AlibabaCloudSearchEmbeddingsModel model;
+ private final String failedToSendRequestErrorMessage;
+ private final Sender sender;
+ private final AlibabaCloudSearchEmbeddingsRequestManager requestCreator;
+
+ public AlibabaCloudSearchEmbeddingsAction(Sender sender, AlibabaCloudSearchEmbeddingsModel model, ServiceComponents serviceComponents) {
+ this.model = Objects.requireNonNull(model);
+ this.sender = Objects.requireNonNull(sender);
+ this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
+ this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search text embeddings");
+ this.requestCreator = AlibabaCloudSearchEmbeddingsRequestManager.of(account, model, serviceComponents.threadPool());
+ }
+
+ @Override
+ public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) {
+ try {
+ ActionListener wrappedListener = wrapFailuresInElasticsearchException(
+ failedToSendRequestErrorMessage,
+ listener
+ );
+ sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
+ } catch (ElasticsearchException e) {
+ listener.onFailure(e);
+ } catch (Exception e) {
+ listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage));
+ }
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchRerankAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchRerankAction.java
new file mode 100644
index 0000000000000..88229ce63463b
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchRerankAction.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchRerankRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
+
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
+
+/**
+ * An {@link ExecutableAction} that dispatches rerank requests to the AlibabaCloud Search service
+ * through the shared {@code Sender} infrastructure.
+ */
+public class AlibabaCloudSearchRerankAction implements ExecutableAction {
+    // NOTE(review): this logger is not referenced anywhere in the class — consider removing it
+    // or using it to log failures.
+    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchRerankAction.class);
+
+    private final AlibabaCloudSearchAccount account;
+    private final AlibabaCloudSearchRerankModel model;
+    // Error message computed once up front so failures report a consistent, descriptive context.
+    private final String failedToSendRequestErrorMessage;
+    private final Sender sender;
+    private final AlibabaCloudSearchRerankRequestManager requestCreator;
+
+    public AlibabaCloudSearchRerankAction(Sender sender, AlibabaCloudSearchRerankModel model, ServiceComponents serviceComponents) {
+        this.model = Objects.requireNonNull(model);
+        // Credentials are derived from the model's secret settings (API key).
+        this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
+        this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search rerank");
+        this.sender = Objects.requireNonNull(sender);
+        this.requestCreator = AlibabaCloudSearchRerankRequestManager.of(account, model, serviceComponents.threadPool());
+    }
+
+    @Override
+    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) {
+        try {
+            // Wrap the listener so any downstream failure is surfaced as an ElasticsearchException
+            // carrying the precomputed error message.
+            ActionListener wrappedListener = wrapFailuresInElasticsearchException(
+                failedToSendRequestErrorMessage,
+                listener
+            );
+            sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
+        } catch (ElasticsearchException e) {
+            // Already a well-formed Elasticsearch failure; propagate unchanged.
+            listener.onFailure(e);
+        } catch (Exception e) {
+            // Anything unexpected becomes an internal server error with the standard message.
+            listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage));
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchSparseAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchSparseAction.java
new file mode 100644
index 0000000000000..2cd31ff83d200
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchSparseAction.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchSparseRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
+
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
+
+/**
+ * An {@link ExecutableAction} that dispatches sparse embedding requests to the AlibabaCloud Search
+ * service through the shared {@code Sender} infrastructure.
+ */
+public class AlibabaCloudSearchSparseAction implements ExecutableAction {
+    // NOTE(review): this logger is not referenced anywhere in the class — consider removing it
+    // or using it to log failures.
+    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchSparseAction.class);
+
+    private final AlibabaCloudSearchAccount account;
+    private final AlibabaCloudSearchSparseModel model;
+    // Error message computed once up front so failures report a consistent, descriptive context.
+    private final String failedToSendRequestErrorMessage;
+    private final Sender sender;
+    private final AlibabaCloudSearchSparseRequestManager requestCreator;
+
+    public AlibabaCloudSearchSparseAction(Sender sender, AlibabaCloudSearchSparseModel model, ServiceComponents serviceComponents) {
+        this.model = Objects.requireNonNull(model);
+        // Credentials are derived from the model's secret settings (API key).
+        this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
+        this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search sparse embeddings");
+        this.sender = Objects.requireNonNull(sender);
+        requestCreator = AlibabaCloudSearchSparseRequestManager.of(account, model, serviceComponents.threadPool());
+    }
+
+    @Override
+    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) {
+        try {
+            // Wrap the listener so any downstream failure is surfaced as an ElasticsearchException
+            // carrying the precomputed error message.
+            ActionListener wrappedListener = wrapFailuresInElasticsearchException(
+                failedToSendRequestErrorMessage,
+                listener
+            );
+            sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
+        } catch (ElasticsearchException e) {
+            // Already a well-formed Elasticsearch failure; propagate unchanged.
+            listener.onFailure(e);
+        } catch (Exception e) {
+            // Anything unexpected becomes an internal server error with the standard message.
+            listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage));
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchAccount.java
new file mode 100644
index 0000000000000..6aabbe20cc355
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchAccount.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.alibabacloudsearch;
+
+import org.elasticsearch.common.settings.SecureString;
+
+import java.util.Objects;
+
+/**
+ * Credentials for the AlibabaCloud Search service.
+ *
+ * @param apiKey the API key used to authenticate requests; must not be null
+ */
+public record AlibabaCloudSearchAccount(SecureString apiKey) {
+
+    // Compact constructor: reject null keys at construction time.
+    public AlibabaCloudSearchAccount {
+        Objects.requireNonNull(apiKey);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchResponseHandler.java
new file mode 100644
index 0000000000000..05d51372d9cdc
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchResponseHandler.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.alibabacloudsearch;
+
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchErrorResponseEntity;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+
+import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
+
+/**
+ * Defines how to handle various errors returned from the AlibabaCloudSearch integration.
+ * Failure responses are parsed via {@code AlibabaCloudSearchErrorResponseEntity}.
+ */
+public class AlibabaCloudSearchResponseHandler extends BaseResponseHandler {
+
+    public AlibabaCloudSearchResponseHandler(String requestType, ResponseParser parseFunction) {
+        super(requestType, parseFunction, AlibabaCloudSearchErrorResponseEntity::fromResponse);
+    }
+
+    @Override
+    public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
+        throws RetryException {
+        // Check the status code first, then reject responses with an empty body.
+        checkForFailureStatusCode(request, result);
+        checkForEmptyBody(throttlerManager, logger, request, result);
+    }
+
+    /**
+     * Validates the status code throws an RetryException if not in the range [200, 300).
+     *
+     * @param request The http request
+     * @param result The http response and body
+     * @throws RetryException Throws if status code is {@code >= 300 or < 200 }
+     */
+    void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
+        int statusCode = result.response().getStatusLine().getStatusCode();
+        if (statusCode >= 200 && statusCode < 300) {
+            return;
+        }
+
+        // handle error codes
+        if (statusCode >= 500) {
+            // Server-side errors are not retried here (RetryException created with shouldRetry=false).
+            throw new RetryException(false, buildError(SERVER_ERROR, request, result));
+        } else if (statusCode == 429) {
+            // Rate limiting is the only retryable case.
+            throw new RetryException(true, buildError(RATE_LIMIT, request, result));
+        } else if (statusCode == 401) {
+            throw new RetryException(false, buildError(AUTHENTICATION, request, result));
+        } else if (statusCode >= 300 && statusCode < 400) {
+            throw new RetryException(false, buildError(REDIRECTION, request, result));
+        } else {
+            // Remaining 4xx codes fall through to a generic unsuccessful-request error.
+            throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java
new file mode 100644
index 0000000000000..55c699bf26e82
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http.sender;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchEmbeddingsRequest;
+import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchEmbeddingsResponseEntity;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.function.Supplier;
+
+/**
+ * Builds and executes text embedding requests against the AlibabaCloud Search service,
+ * inheriting rate limiting behavior from {@link AlibabaCloudSearchRequestManager}.
+ */
+public class AlibabaCloudSearchEmbeddingsRequestManager extends AlibabaCloudSearchRequestManager {
+    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchEmbeddingsRequestManager.class);
+
+    // Shared, stateless response handler; safe to keep as a static singleton.
+    private static final ResponseHandler HANDLER = createEmbeddingsHandler();
+
+    private static ResponseHandler createEmbeddingsHandler() {
+        return new AlibabaCloudSearchResponseHandler(
+            "alibaba cloud search text embedding",
+            AlibabaCloudSearchEmbeddingsResponseEntity::fromResponse
+        );
+    }
+
+    // Static factory: validates all arguments before construction.
+    public static AlibabaCloudSearchEmbeddingsRequestManager of(
+        AlibabaCloudSearchAccount account,
+        AlibabaCloudSearchEmbeddingsModel model,
+        ThreadPool threadPool
+    ) {
+        return new AlibabaCloudSearchEmbeddingsRequestManager(
+            Objects.requireNonNull(account),
+            Objects.requireNonNull(model),
+            Objects.requireNonNull(threadPool)
+        );
+    }
+
+    private final AlibabaCloudSearchEmbeddingsModel model;
+
+    private final AlibabaCloudSearchAccount account;
+
+    private AlibabaCloudSearchEmbeddingsRequestManager(
+        AlibabaCloudSearchAccount account,
+        AlibabaCloudSearchEmbeddingsModel model,
+        ThreadPool threadPool
+    ) {
+        super(threadPool, model);
+        this.account = Objects.requireNonNull(account);
+        this.model = Objects.requireNonNull(model);
+    }
+
+    @Override
+    public void execute(
+        InferenceInputs inferenceInputs,
+        RequestSender requestSender,
+        Supplier hasRequestCompletedFunction,
+        ActionListener listener
+    ) {
+        // Embeddings accept document-only input (no query component).
+        List input = DocumentsOnlyInput.of(inferenceInputs).getInputs();
+        AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, input, model);
+
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRequestManager.java
new file mode 100644
index 0000000000000..c8ade15ac5057
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRequestManager.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http.sender;
+
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchModel;
+
+import java.util.Objects;
+
+/**
+ * Base request manager for AlibabaCloud Search integrations; wires rate limit settings
+ * and grouping from the model into {@link BaseRequestManager}.
+ */
+abstract class AlibabaCloudSearchRequestManager extends BaseRequestManager {
+
+    protected AlibabaCloudSearchRequestManager(ThreadPool threadPool, AlibabaCloudSearchModel model) {
+        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+    }
+
+    // Key used to group requests for rate limiting purposes.
+    // NOTE(review): the component is named apiKeyHash but is populated with the hash of the
+    // rate limit service settings, not of the API key — confirm which grouping is intended.
+    record RateLimitGrouping(int apiKeyHash) {
+        public static RateLimitGrouping of(AlibabaCloudSearchModel model) {
+            Objects.requireNonNull(model);
+
+            return new RateLimitGrouping(model.rateLimitServiceSettings().hashCode());
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java
new file mode 100644
index 0000000000000..446db40aa5ae5
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http.sender;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchRerankRequest;
+import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchRerankResponseEntity;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
+
+import java.util.Objects;
+import java.util.function.Supplier;
+
+/**
+ * Builds and executes rerank requests against the AlibabaCloud Search service,
+ * inheriting rate limiting behavior from {@link AlibabaCloudSearchRequestManager}.
+ */
+public class AlibabaCloudSearchRerankRequestManager extends AlibabaCloudSearchRequestManager {
+    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchRerankRequestManager.class);
+    // Shared, stateless response handler; safe to keep as a static singleton.
+    private static final ResponseHandler HANDLER = createRerankHandler();
+
+    private static ResponseHandler createRerankHandler() {
+        return new AlibabaCloudSearchResponseHandler("alibaba cloud search rerank", AlibabaCloudSearchRerankResponseEntity::fromResponse);
+    }
+
+    // Static factory: validates all arguments before construction.
+    public static AlibabaCloudSearchRerankRequestManager of(
+        AlibabaCloudSearchAccount account,
+        AlibabaCloudSearchRerankModel model,
+        ThreadPool threadPool
+    ) {
+        return new AlibabaCloudSearchRerankRequestManager(
+            Objects.requireNonNull(account),
+            Objects.requireNonNull(model),
+            Objects.requireNonNull(threadPool)
+        );
+    }
+
+    private final AlibabaCloudSearchRerankModel model;
+
+    private final AlibabaCloudSearchAccount account;
+
+    // Null checks are performed in of(); the sibling managers (embeddings, sparse) also
+    // requireNonNull here — NOTE(review): consider doing the same for consistency.
+    private AlibabaCloudSearchRerankRequestManager(
+        AlibabaCloudSearchAccount account,
+        AlibabaCloudSearchRerankModel model,
+        ThreadPool threadPool
+    ) {
+        super(threadPool, model);
+        this.account = account;
+        this.model = model;
+    }
+
+    @Override
+    public void execute(
+        InferenceInputs inferenceInputs,
+        RequestSender requestSender,
+        Supplier hasRequestCompletedFunction,
+        ActionListener listener
+    ) {
+        // Rerank requires both the query and the candidate documents (chunks).
+        var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
+        AlibabaCloudSearchRerankRequest request = new AlibabaCloudSearchRerankRequest(
+            account,
+            rerankInput.getQuery(),
+            rerankInput.getChunks(),
+            model
+        );
+
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java
new file mode 100644
index 0000000000000..b0cc524bb4cbe
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http.sender;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchSparseRequest;
+import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchSparseResponseEntity;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.function.Supplier;
+
+/**
+ * Builds and executes sparse embedding requests against the AlibabaCloud Search service,
+ * inheriting rate limiting behavior from {@link AlibabaCloudSearchRequestManager}.
+ */
+public class AlibabaCloudSearchSparseRequestManager extends AlibabaCloudSearchRequestManager {
+    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchSparseRequestManager.class);
+
+    // Shared, stateless response handler; safe to keep as a static singleton.
+    private static final ResponseHandler HANDLER = createEmbeddingsHandler();
+
+    private static ResponseHandler createEmbeddingsHandler() {
+        return new AlibabaCloudSearchResponseHandler(
+            "alibaba cloud search sparse embedding",
+            AlibabaCloudSearchSparseResponseEntity::fromResponse
+        );
+    }
+
+    // Static factory: validates all arguments before construction.
+    public static AlibabaCloudSearchSparseRequestManager of(
+        AlibabaCloudSearchAccount account,
+        AlibabaCloudSearchSparseModel model,
+        ThreadPool threadPool
+    ) {
+        return new AlibabaCloudSearchSparseRequestManager(
+            Objects.requireNonNull(account),
+            Objects.requireNonNull(model),
+            Objects.requireNonNull(threadPool)
+        );
+    }
+
+    private final AlibabaCloudSearchSparseModel model;
+
+    private final AlibabaCloudSearchAccount account;
+
+    private AlibabaCloudSearchSparseRequestManager(
+        AlibabaCloudSearchAccount account,
+        AlibabaCloudSearchSparseModel model,
+        ThreadPool threadPool
+    ) {
+        super(threadPool, model);
+        this.account = Objects.requireNonNull(account);
+        this.model = Objects.requireNonNull(model);
+    }
+
+    @Override
+    public void execute(
+        InferenceInputs inferenceInputs,
+        RequestSender requestSender,
+        Supplier hasRequestCompletedFunction,
+        ActionListener listener
+    ) {
+        // Sparse embeddings accept document-only input (no query component).
+        List input = DocumentsOnlyInput.of(inferenceInputs).getInputs();
+        AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, input, model);
+
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequest.java
new file mode 100644
index 0000000000000..081854903405e
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequest.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.client.utils.URIBuilder;
+import org.apache.http.entity.ByteArrayEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
+
+/**
+ * An HTTP request for the AlibabaCloud Search text embedding endpoint. The request body is a
+ * JSON entity built from the input texts and the model's task settings; authentication uses a
+ * bearer token derived from the account's API key.
+ */
+public class AlibabaCloudSearchEmbeddingsRequest extends AlibabaCloudSearchRequest {
+
+    private final AlibabaCloudSearchAccount account;
+    private final List input;
+    private final URI uri;
+    private final AlibabaCloudSearchEmbeddingsTaskSettings taskSettings;
+    private final String model;
+    private final String host;
+    private final String workspaceName;
+    private final String httpSchema;
+    private final String inferenceEntityId;
+
+    public AlibabaCloudSearchEmbeddingsRequest(
+        AlibabaCloudSearchAccount account,
+        List input,
+        AlibabaCloudSearchEmbeddingsModel embeddingsModel
+    ) {
+        Objects.requireNonNull(embeddingsModel);
+
+        this.account = Objects.requireNonNull(account);
+        this.input = Objects.requireNonNull(input);
+        taskSettings = embeddingsModel.getTaskSettings();
+        model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
+        host = embeddingsModel.getServiceSettings().getCommonSettings().getHost();
+        workspaceName = embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName();
+        // Fall back to https when the service settings do not specify a scheme.
+        httpSchema = embeddingsModel.getServiceSettings().getCommonSettings().getHttpSchema() != null
+            ? embeddingsModel.getServiceSettings().getCommonSettings().getHttpSchema()
+            : "https";
+        // buildUri uses buildDefaultUri() as the fallback when no explicit URI is configured.
+        uri = buildUri(null, AlibabaCloudSearchUtils.SERVICE_NAME, this::buildDefaultUri);
+        inferenceEntityId = embeddingsModel.getInferenceEntityId();
+    }
+
+    @Override
+    public HttpRequest createHttpRequest() {
+        HttpPost httpPost = new HttpPost(uri);
+
+        // Serialize the request entity (input texts + task settings) as UTF-8 JSON.
+        ByteArrayEntity byteEntity = new ByteArrayEntity(
+            Strings.toString(new AlibabaCloudSearchEmbeddingsRequestEntity(input, taskSettings)).getBytes(StandardCharsets.UTF_8)
+        );
+        httpPost.setEntity(byteEntity);
+
+        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
+        httpPost.setHeader(createAuthBearerHeader(account.apiKey()));
+
+        return new HttpRequest(httpPost, getInferenceEntityId());
+    }
+
+    @Override
+    public String getInferenceEntityId() {
+        return inferenceEntityId;
+    }
+
+    @Override
+    public URI getURI() {
+        return uri;
+    }
+
+    // Truncation is not supported for this request type; return the request unchanged.
+    @Override
+    public Request truncate() {
+        return this;
+    }
+
+    // No truncation info because truncate() is a no-op.
+    @Override
+    public boolean[] getTruncationInfo() {
+        return null;
+    }
+
+    // Builds the default endpoint: {schema}://{host}/v3/openapi/workspaces/{workspace}/text-embedding/{model}
+    // (exact path segment values come from AlibabaCloudSearchUtils constants).
+    URI buildDefaultUri() throws URISyntaxException {
+        return new URIBuilder().setScheme(httpSchema)
+            .setHost(host)
+            .setPathSegments(
+                AlibabaCloudSearchUtils.VERSION_3,
+                AlibabaCloudSearchUtils.OPENAPI_PATH,
+                AlibabaCloudSearchUtils.WORKSPACE_PATH,
+                workspaceName,
+                AlibabaCloudSearchUtils.TEXT_EMBEDDING_PATH,
+                model
+            )
+            .build();
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntity.java
new file mode 100644
index 0000000000000..c2367aeff3070
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntity.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings.invalidInputTypeMessage;
+
+/**
+ * The JSON body of a text embedding request: the input texts plus an optional
+ * {@code input_type} derived from the task settings.
+ *
+ * @param input the texts to embed; must not be null
+ * @param taskSettings task settings supplying the optional input type; must not be null
+ */
+public record AlibabaCloudSearchEmbeddingsRequestEntity(List input, AlibabaCloudSearchEmbeddingsTaskSettings taskSettings)
+    implements
+        ToXContentObject {
+
+    // Wire values the service expects for the input_type field.
+    private static final String SEARCH_DOCUMENT = "document";
+    private static final String SEARCH_QUERY = "query";
+
+    private static final String TEXTS_FIELD = "input";
+
+    static final String INPUT_TYPE_FIELD = "input_type";
+
+    // Compact constructor: reject null components at construction time.
+    public AlibabaCloudSearchEmbeddingsRequestEntity {
+        Objects.requireNonNull(input);
+        Objects.requireNonNull(taskSettings);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(TEXTS_FIELD, input);
+
+        // input_type is only emitted when the task settings specify a recognized type.
+        String inputType = covertToString(taskSettings.getInputType());
+        if (inputType != null) {
+            builder.field(INPUT_TYPE_FIELD, inputType);
+        }
+
+        builder.endObject();
+        return builder;
+    }
+
+    // default for testing
+    // Maps the InputType enum to the service's wire value; returns null for absent/unsupported types.
+    // NOTE(review): "covertToString" looks like a typo for "convertToString".
+    static String covertToString(InputType inputType) {
+        if (inputType == null) {
+            return null;
+        }
+
+        return switch (inputType) {
+            case INGEST -> SEARCH_DOCUMENT;
+            case SEARCH -> SEARCH_QUERY;
+            default -> {
+                // Unsupported types are a programming error upstream; fail the assertion in
+                // tests but degrade to omitting the field in production.
+                assert false : invalidInputTypeMessage(inputType);
+                yield null;
+            }
+        };
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRequest.java
new file mode 100644
index 0000000000000..75fc12e1bad31
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRequest.java
@@ -0,0 +1,22 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+/**
+ * Base class for AlibabaCloud Search HTTP requests; records the wall-clock time
+ * (epoch millis) at which the request object was created.
+ */
+public abstract class AlibabaCloudSearchRequest implements Request {
+    private final long startTime;
+
+    public AlibabaCloudSearchRequest() {
+        this.startTime = System.currentTimeMillis();
+    }
+
+    /** Returns the creation time of this request in milliseconds since the epoch. */
+    public long getStartTime() {
+        return startTime;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java
new file mode 100644
index 0000000000000..878bcc6e6a0db
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.client.utils.URIBuilder;
+import org.apache.http.entity.ByteArrayEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
+
+public class AlibabaCloudSearchRerankRequest implements Request {
+ private final AlibabaCloudSearchAccount account;
+ private final String query;
+ private final List input;
+ private final URI uri;
+ private final AlibabaCloudSearchRerankTaskSettings taskSettings;
+ private final String model;
+ private final String host;
+ private final String workspaceName;
+ private final String httpSchema;
+ private final String inferenceEntityId;
+
+ public AlibabaCloudSearchRerankRequest(
+ AlibabaCloudSearchAccount account,
+ String query,
+ List input,
+ AlibabaCloudSearchRerankModel rerankModel
+ ) {
+ Objects.requireNonNull(rerankModel);
+
+ this.account = Objects.requireNonNull(account);
+ this.query = Objects.requireNonNull(query);
+ this.input = Objects.requireNonNull(input);
+ taskSettings = rerankModel.getTaskSettings();
+ model = rerankModel.getServiceSettings().getCommonSettings().modelId();
+ host = rerankModel.getServiceSettings().getCommonSettings().getHost();
+ workspaceName = rerankModel.getServiceSettings().getCommonSettings().getWorkspaceName();
+ httpSchema = rerankModel.getServiceSettings().getCommonSettings().getHttpSchema() != null
+ ? rerankModel.getServiceSettings().getCommonSettings().getHttpSchema()
+ : "https";
+ uri = buildUri(null, AlibabaCloudSearchUtils.SERVICE_NAME, this::buildDefaultUri);
+ inferenceEntityId = rerankModel.getInferenceEntityId();
+ }
+
+ @Override
+ public HttpRequest createHttpRequest() {
+ HttpPost httpPost = new HttpPost(uri);
+
+ ByteArrayEntity byteEntity = new ByteArrayEntity(
+ Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, taskSettings)).getBytes(StandardCharsets.UTF_8)
+ );
+ httpPost.setEntity(byteEntity);
+
+ httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
+ httpPost.setHeader(createAuthBearerHeader(account.apiKey()));
+
+ return new HttpRequest(httpPost, getInferenceEntityId());
+ }
+
+ @Override
+ public String getInferenceEntityId() {
+ return inferenceEntityId;
+ }
+
+ @Override
+ public URI getURI() {
+ return uri;
+ }
+
+ @Override
+ public Request truncate() {
+ return this;
+ }
+
+ @Override
+ public boolean[] getTruncationInfo() {
+ return null;
+ }
+
+ URI buildDefaultUri() throws URISyntaxException {
+ return new URIBuilder().setScheme(httpSchema)
+ .setHost(host)
+ .setPathSegments(
+ AlibabaCloudSearchUtils.VERSION_3,
+ AlibabaCloudSearchUtils.OPENAPI_PATH,
+ AlibabaCloudSearchUtils.WORKSPACE_PATH,
+ workspaceName,
+ AlibabaCloudSearchUtils.RERANK_PATH,
+ model
+ )
+ .build();
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java
new file mode 100644
index 0000000000000..054e373e3e525
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * JSON body for the AlibabaCloud Search rerank request: {"query": ..., "docs": [...]}.
+ * The task settings are carried along but are not serialized by this entity.
+ */
+public record AlibabaCloudSearchRerankRequestEntity(String query, List<String> input, AlibabaCloudSearchRerankTaskSettings taskSettings)
+    implements
+        ToXContentObject {
+
+    private static final String SEARCH_QUERY = "query";
+    private static final String TEXTS_FIELD = "docs";
+
+    public AlibabaCloudSearchRerankRequestEntity {
+        Objects.requireNonNull(query);
+        Objects.requireNonNull(input);
+        Objects.requireNonNull(taskSettings);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        {
+            builder.field(SEARCH_QUERY, query);
+            builder.field(TEXTS_FIELD, input);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequest.java
new file mode 100644
index 0000000000000..c7b4c314b07a7
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequest.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.client.utils.URIBuilder;
+import org.apache.http.entity.ByteArrayEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettings;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
+
+public class AlibabaCloudSearchSparseRequest extends AlibabaCloudSearchRequest {
+
+ private final AlibabaCloudSearchAccount account;
+ private final List input;
+ private final URI uri;
+ private final AlibabaCloudSearchSparseTaskSettings taskSettings;
+ private final String model;
+ private final String host;
+ private final String workspaceName;
+ private final String httpSchema;
+ private final String inferenceEntityId;
+
+ public AlibabaCloudSearchSparseRequest(
+ AlibabaCloudSearchAccount account,
+ List input,
+ AlibabaCloudSearchSparseModel sparseEmbeddingsModel
+ ) {
+ Objects.requireNonNull(sparseEmbeddingsModel);
+
+ this.account = Objects.requireNonNull(account);
+ this.input = Objects.requireNonNull(input);
+ taskSettings = sparseEmbeddingsModel.getTaskSettings();
+ model = sparseEmbeddingsModel.getServiceSettings().getCommonSettings().modelId();
+ host = sparseEmbeddingsModel.getServiceSettings().getCommonSettings().getHost();
+ workspaceName = sparseEmbeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName();
+ httpSchema = sparseEmbeddingsModel.getServiceSettings().getCommonSettings().getHttpSchema() != null
+ ? sparseEmbeddingsModel.getServiceSettings().getCommonSettings().getHttpSchema()
+ : "https";
+ uri = buildUri(null, AlibabaCloudSearchUtils.SERVICE_NAME, this::buildDefaultUri);
+ inferenceEntityId = sparseEmbeddingsModel.getInferenceEntityId();
+ }
+
+ @Override
+ public HttpRequest createHttpRequest() {
+ HttpPost httpPost = new HttpPost(uri);
+
+ ByteArrayEntity byteEntity = new ByteArrayEntity(
+ Strings.toString(new AlibabaCloudSearchSparseRequestEntity(input, taskSettings)).getBytes(StandardCharsets.UTF_8)
+ );
+ httpPost.setEntity(byteEntity);
+
+ httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
+ httpPost.setHeader(createAuthBearerHeader(account.apiKey()));
+
+ return new HttpRequest(httpPost, getInferenceEntityId());
+ }
+
+ @Override
+ public String getInferenceEntityId() {
+ return inferenceEntityId;
+ }
+
+ @Override
+ public URI getURI() {
+ return uri;
+ }
+
+ @Override
+ public Request truncate() {
+ return this;
+ }
+
+ @Override
+ public boolean[] getTruncationInfo() {
+ return null;
+ }
+
+ URI buildDefaultUri() throws URISyntaxException {
+ return new URIBuilder().setScheme(httpSchema)
+ .setHost(host)
+ .setPathSegments(
+ AlibabaCloudSearchUtils.VERSION_3,
+ AlibabaCloudSearchUtils.OPENAPI_PATH,
+ AlibabaCloudSearchUtils.WORKSPACE_PATH,
+ workspaceName,
+ AlibabaCloudSearchUtils.SPARSE_EMBEDDING_PATH,
+ model
+ )
+ .build();
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntity.java
new file mode 100644
index 0000000000000..3aec226bfc277
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntity.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettings;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * JSON body for the AlibabaCloud Search sparse-embedding request. Always serializes "input";
+ * "input_type" and "return_token" are emitted only when set in the task settings.
+ */
+public record AlibabaCloudSearchSparseRequestEntity(List<String> input, AlibabaCloudSearchSparseTaskSettings taskSettings)
+    implements
+        ToXContentObject {
+
+    private static final String TEXTS_FIELD = "input";
+
+    static final String INPUT_TYPE_FIELD = "input_type";
+
+    static final String RETURN_TOKEN_FIELD = "return_token";
+
+    public AlibabaCloudSearchSparseRequestEntity {
+        Objects.requireNonNull(input);
+        Objects.requireNonNull(taskSettings);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(TEXTS_FIELD, input);
+        // Reuses the embeddings entity's mapping of InputType -> API string (INGEST/SEARCH only).
+        String inputType = AlibabaCloudSearchEmbeddingsRequestEntity.covertToString(taskSettings.getInputType());
+        if (inputType != null) {
+            builder.field(INPUT_TYPE_FIELD, inputType);
+        }
+        if (taskSettings.isReturnToken() != null) {
+            builder.field(RETURN_TOKEN_FIELD, taskSettings.isReturnToken());
+        }
+        builder.endObject();
+        return builder;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchUtils.java
new file mode 100644
index 0000000000000..7d671471976f5
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchUtils.java
@@ -0,0 +1,18 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+public class AlibabaCloudSearchUtils {
+ public static final String SERVICE_NAME = "alibabacloud-ai-search";
+ public static final String VERSION_3 = "v3";
+ public static final String OPENAPI_PATH = "openapi";
+ public static final String WORKSPACE_PATH = "workspaces";
+ public static final String TEXT_EMBEDDING_PATH = "text-embedding";
+ public static final String SPARSE_EMBEDDING_PATH = "text-sparse-embedding";
+ public static final String RERANK_PATH = "ranker";
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntity.java
new file mode 100644
index 0000000000000..33fa645b107bc
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntity.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch;
+
+import org.elasticsearch.common.xcontent.XContentParserUtils;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
+
+public class AlibabaCloudSearchEmbeddingsResponseEntity extends AlibabaCloudSearchResponseEntity {
+ private static final String FAILED_TO_FIND_FIELD_TEMPLATE =
+ "Failed to find required field [%s] in AlibabaCloud Search embeddings response";
+
+ /**
+ * Parses the AlibabaCloud Search embedding json response.
+ * For a request like:
+ *
+ *
+ *
+ * {
+ * "texts": ["hello this is my name", "I wish I was there!"]
+ * }
+ *
+ *
+ */
+/**
+ * Task settings for the AlibabaCloud Search embeddings service. The only setting is an optional
+ * {@link InputType}, restricted to INGEST or SEARCH.
+ */
+public class AlibabaCloudSearchEmbeddingsTaskSettings implements TaskSettings {
+
+    public static final String NAME = "alibabacloud_search_embeddings_task_settings";
+    public static final AlibabaCloudSearchEmbeddingsTaskSettings EMPTY_SETTINGS = new AlibabaCloudSearchEmbeddingsTaskSettings(
+        (InputType) null
+    );
+    static final String INPUT_TYPE = "input_type";
+    // The service only accepts these input types; anything else fails validation in fromMap().
+    static final EnumSet<InputType> VALID_REQUEST_VALUES = EnumSet.of(InputType.INGEST, InputType.SEARCH);
+
+    /**
+     * Parses the task settings map; a null or empty map yields {@link #EMPTY_SETTINGS}.
+     *
+     * @throws ValidationException if input_type is present but not INGEST or SEARCH
+     */
+    public static AlibabaCloudSearchEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
+        if (map == null || map.isEmpty()) {
+            return EMPTY_SETTINGS;
+        }
+
+        ValidationException validationException = new ValidationException();
+
+        InputType inputType = extractOptionalEnum(
+            map,
+            INPUT_TYPE,
+            ModelConfigurations.TASK_SETTINGS,
+            InputType::fromString,
+            VALID_REQUEST_VALUES,
+            validationException
+        );
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AlibabaCloudSearchEmbeddingsTaskSettings(inputType);
+    }
+
+    /**
+     * Creates a new {@link AlibabaCloudSearchEmbeddingsTaskSettings} by preferring non-null fields from the provided parameters.
+     * For the input type, preference is given to requestInputType if it is a valid request value,
+     * then to the requestTaskSettings, and finally to originalSettings even if its value is null.
+     *
+     * @param originalSettings the settings stored as part of the inference entity configuration
+     * @param requestTaskSettings the settings passed in within the task_settings field of the request
+     * @param requestInputType the input type passed in the request parameters
+     * @return a constructed {@link AlibabaCloudSearchEmbeddingsTaskSettings}
+     */
+    public static AlibabaCloudSearchEmbeddingsTaskSettings of(
+        AlibabaCloudSearchEmbeddingsTaskSettings originalSettings,
+        AlibabaCloudSearchEmbeddingsTaskSettings requestTaskSettings,
+        InputType requestInputType
+    ) {
+        var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings, requestInputType);
+
+        return new AlibabaCloudSearchEmbeddingsTaskSettings(inputTypeToUse);
+    }
+
+    // Request-level input type wins when valid, then the request task settings, then the stored settings.
+    private static InputType getValidInputType(
+        AlibabaCloudSearchEmbeddingsTaskSettings originalSettings,
+        AlibabaCloudSearchEmbeddingsTaskSettings requestTaskSettings,
+        InputType requestInputType
+    ) {
+        InputType inputTypeToUse = originalSettings.inputType;
+
+        if (VALID_REQUEST_VALUES.contains(requestInputType)) {
+            inputTypeToUse = requestInputType;
+        } else if (requestTaskSettings.inputType != null) {
+            inputTypeToUse = requestTaskSettings.inputType;
+        }
+
+        return inputTypeToUse;
+    }
+
+    // Null means "not specified"; callers fall back to their own defaults.
+    private final InputType inputType;
+
+    public AlibabaCloudSearchEmbeddingsTaskSettings(StreamInput in) throws IOException {
+        this(in.readOptionalEnum(InputType.class));
+    }
+
+    public AlibabaCloudSearchEmbeddingsTaskSettings(@Nullable InputType inputType) {
+        validateInputType(inputType);
+        this.inputType = inputType;
+    }
+
+    // Assertion-only check: a non-null input type must be one of the valid request values.
+    private static void validateInputType(InputType inputType) {
+        if (inputType == null) {
+            return;
+        }
+
+        assert VALID_REQUEST_VALUES.contains(inputType) : invalidInputTypeMessage(inputType);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (inputType != null) {
+            builder.field(INPUT_TYPE, inputType);
+        }
+
+        builder.endObject();
+        return builder;
+    }
+
+    public InputType getInputType() {
+        return inputType;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalEnum(inputType);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AlibabaCloudSearchEmbeddingsTaskSettings that = (AlibabaCloudSearchEmbeddingsTaskSettings) o;
+        return Objects.equals(inputType, that.inputType);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(inputType);
+    }
+
+    public static String invalidInputTypeMessage(InputType inputType) {
+        return Strings.format("received invalid input type value [%s]", inputType.toString());
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankModel.java
new file mode 100644
index 0000000000000..a9152b6edd4c5
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankModel.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchModel;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.util.Map;
+
+/**
+ * Model wrapper for the AlibabaCloud Search rerank task, bundling the service settings,
+ * task settings and API-key secret settings.
+ */
+public class AlibabaCloudSearchRerankModel extends AlibabaCloudSearchModel {
+    /** Returns a copy of {@code model} with request-level task settings overlaid on the stored ones. */
+    public static AlibabaCloudSearchRerankModel of(AlibabaCloudSearchRerankModel model, Map<String, Object> taskSettings) {
+        var requestTaskSettings = AlibabaCloudSearchRerankTaskSettings.fromMap(taskSettings);
+        return new AlibabaCloudSearchRerankModel(
+            model,
+            AlibabaCloudSearchRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings)
+        );
+    }
+
+    public AlibabaCloudSearchRerankModel(
+        String modelId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
+    ) {
+        this(
+            modelId,
+            taskType,
+            service,
+            AlibabaCloudSearchRerankServiceSettings.fromMap(serviceSettings, context),
+            AlibabaCloudSearchRerankTaskSettings.fromMap(taskSettings),
+            DefaultSecretSettings.fromMap(secrets)
+        );
+    }
+
+    // should only be used for testing
+    AlibabaCloudSearchRerankModel(
+        String modelId,
+        TaskType taskType,
+        String service,
+        AlibabaCloudSearchRerankServiceSettings serviceSettings,
+        AlibabaCloudSearchRerankTaskSettings taskSettings,
+        @Nullable DefaultSecretSettings secretSettings
+    ) {
+        super(
+            new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
+            new ModelSecrets(secretSettings),
+            serviceSettings.getCommonSettings()
+        );
+    }
+
+    private AlibabaCloudSearchRerankModel(AlibabaCloudSearchRerankModel model, AlibabaCloudSearchRerankTaskSettings taskSettings) {
+        super(model, taskSettings);
+    }
+
+    public AlibabaCloudSearchRerankModel(AlibabaCloudSearchRerankModel model, AlibabaCloudSearchRerankServiceSettings serviceSettings) {
+        super(model, serviceSettings);
+    }
+
+    @Override
+    public AlibabaCloudSearchRerankServiceSettings getServiceSettings() {
+        return (AlibabaCloudSearchRerankServiceSettings) super.getServiceSettings();
+    }
+
+    @Override
+    public AlibabaCloudSearchRerankTaskSettings getTaskSettings() {
+        return (AlibabaCloudSearchRerankTaskSettings) super.getTaskSettings();
+    }
+
+    @Override
+    public DefaultSecretSettings getSecretSettings() {
+        return (DefaultSecretSettings) super.getSecretSettings();
+    }
+
+    /** Delegates action creation to the visitor; the request input type is not used for rerank. */
+    @Override
+    public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
+        return visitor.create(this, taskSettings);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankServiceSettings.java
new file mode 100644
index 0000000000000..42c7238aefa7f
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankServiceSettings.java
@@ -0,0 +1,97 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Service settings for the AlibabaCloud Search rerank service. A thin wrapper around the common
+ * AlibabaCloud Search service settings shared by all task types.
+ */
+public class AlibabaCloudSearchRerankServiceSettings implements ServiceSettings {
+    public static final String NAME = "alibabacloud_search_rerank_service_settings";
+
+    /**
+     * Parses the service settings from the configuration map via the common-settings parser.
+     *
+     * @param map the raw service settings from the request or stored configuration
+     * @param context parse context (e.g. request vs persisted configuration)
+     */
+    public static AlibabaCloudSearchRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) {
+        // NOTE(review): nothing records errors into this ValidationException before the check below,
+        // so the throw appears unreachable from here — confirm whether it was meant to be passed into
+        // AlibabaCloudSearchServiceSettings.fromMap.
+        ValidationException validationException = new ValidationException();
+        var commonServiceSettings = AlibabaCloudSearchServiceSettings.fromMap(map, context);
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AlibabaCloudSearchRerankServiceSettings(commonServiceSettings);
+    }
+
+    // Settings shared by all AlibabaCloud Search task types (host, workspace, model id, schema).
+    private final AlibabaCloudSearchServiceSettings commonSettings;
+
+    public AlibabaCloudSearchRerankServiceSettings(AlibabaCloudSearchServiceSettings commonSettings) {
+        this.commonSettings = commonSettings;
+    }
+
+    /** Deserializes the settings from a stream; delegates entirely to the common settings. */
+    public AlibabaCloudSearchRerankServiceSettings(StreamInput in) throws IOException {
+        commonSettings = new AlibabaCloudSearchServiceSettings(in);
+    }
+
+    public AlibabaCloudSearchServiceSettings getCommonSettings() {
+        return commonSettings;
+    }
+
+    @Override
+    public String modelId() {
+        return commonSettings.modelId();
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        commonSettings.toXContentFragment(builder, params);
+        builder.endObject();
+        return builder;
+    }
+
+    /** No fields are filtered for rerank settings; the full object is returned. */
+    @Override
+    public ToXContentObject getFilteredXContentObject() {
+        return this;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        commonSettings.writeTo(out);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AlibabaCloudSearchRerankServiceSettings that = (AlibabaCloudSearchRerankServiceSettings) o;
+        return Objects.equals(commonSettings, that.commonSettings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(commonSettings);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java
new file mode 100644
index 0000000000000..e9fb468eab7fb
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Task settings for the AlibabaCloudSearch sparse-embedding task: an optional
+ * {@code input_type} (restricted to ingest/search) and an optional {@code return_token} flag.
+ *
+ * NOTE(review): this new file is added at .../rerank/AlibabaCloudSearchRerankTaskSettings.java but
+ * declares AlibabaCloudSearchSparseTaskSettings with the sparse-embeddings NAME, and the file's
+ * import block does not import EnumSet, InputType, ModelConfigurations, Nullable, Strings or the
+ * extractOptionalEnum/extractOptionalBoolean helpers referenced below — confirm the intended
+ * rerank class was committed; as written the file/class mismatch cannot compile.
+ */
+public class AlibabaCloudSearchSparseTaskSettings implements TaskSettings {
+
+ // NamedWriteable registry key for these task settings.
+ public static final String NAME = "alibabacloud_search_sparse_embeddings_task_settings";
+ // Shared sentinel meaning "no task settings supplied in the request".
+ public static final AlibabaCloudSearchSparseTaskSettings EMPTY_SETTINGS = new AlibabaCloudSearchSparseTaskSettings(null, null);
+ static final String INPUT_TYPE = "input_type";
+ static final String RETURN_TOKEN = "return_token";
+ // Only ingest and search may be sent. NOTE(review): raw EnumSet — should be EnumSet<InputType>.
+ static final EnumSet VALID_REQUEST_VALUES = EnumSet.of(InputType.INGEST, InputType.SEARCH);
+
+ // Parses task settings from the request body map, collecting validation errors before throwing.
+ // NOTE(review): raw Map parameter — should be Map<String, Object>.
+ public static AlibabaCloudSearchSparseTaskSettings fromMap(Map map) {
+ if (map == null || map.isEmpty()) {
+ return EMPTY_SETTINGS;
+ }
+
+ ValidationException validationException = new ValidationException();
+
+ InputType inputType = extractOptionalEnum(
+ map,
+ INPUT_TYPE,
+ ModelConfigurations.TASK_SETTINGS,
+ InputType::fromString,
+ VALID_REQUEST_VALUES,
+ validationException
+ );
+
+ Boolean returnToken = extractOptionalBoolean(map, RETURN_TOKEN, validationException);
+
+ if (validationException.validationErrors().isEmpty() == false) {
+ throw validationException;
+ }
+
+ return new AlibabaCloudSearchSparseTaskSettings(inputType, returnToken);
+ }
+
+ /**
+ * Creates a new {@link AlibabaCloudSearchSparseTaskSettings} by preferring non-null fields from the provided parameters.
+ * For the input type, preference is given to requestInputType if it is not null and not UNSPECIFIED.
+ * Then preference is given to the requestTaskSettings and finally to originalSettings even if the value is null.
+ *
+ * Similarly, for the return_token field preference is given to requestTaskSettings if it is not null and then to
+ * originalSettings.
+ *
+ * @param originalSettings the settings stored as part of the inference entity configuration
+ * @param requestTaskSettings the settings passed in within the task_settings field of the request
+ * @param requestInputType the input type passed in the request parameters
+ * @return a constructed {@link AlibabaCloudSearchSparseTaskSettings}
+ */
+ public static AlibabaCloudSearchSparseTaskSettings of(
+ AlibabaCloudSearchSparseTaskSettings originalSettings,
+ AlibabaCloudSearchSparseTaskSettings requestTaskSettings,
+ InputType requestInputType
+ ) {
+ // NOTE(review): both settings arguments are dereferenced unconditionally — a null
+ // originalSettings or requestTaskSettings would NPE here; confirm callers never pass null.
+ var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings, requestInputType);
+ var returnToken = requestTaskSettings.isReturnToken() != null
+ ? requestTaskSettings.isReturnToken()
+ : originalSettings.isReturnToken();
+ return new AlibabaCloudSearchSparseTaskSettings(inputTypeToUse, returnToken);
+ }
+
+ // Resolution order: valid request-level input type > request task settings > original settings.
+ private static InputType getValidInputType(
+ AlibabaCloudSearchSparseTaskSettings originalSettings,
+ AlibabaCloudSearchSparseTaskSettings requestTaskSettings,
+ InputType requestInputType
+ ) {
+ InputType inputTypeToUse = originalSettings.inputType;
+
+ if (VALID_REQUEST_VALUES.contains(requestInputType)) {
+ inputTypeToUse = requestInputType;
+ } else if (requestTaskSettings.inputType != null) {
+ inputTypeToUse = requestTaskSettings.inputType;
+ }
+
+ return inputTypeToUse;
+ }
+
+ // Both fields are optional; null means "not specified" and is omitted from XContent.
+ private final InputType inputType;
+ private final Boolean returnToken;
+
+ // Wire-format constructor; both fields are written as optionals by writeTo.
+ public AlibabaCloudSearchSparseTaskSettings(StreamInput in) throws IOException {
+ this(in.readOptionalEnum(InputType.class), in.readOptionalBoolean());
+ }
+
+ public AlibabaCloudSearchSparseTaskSettings(@Nullable InputType inputType, Boolean returnToken) {
+ validateInputType(inputType);
+ this.inputType = inputType;
+ this.returnToken = returnToken;
+ }
+
+ // Assertion-only guard: fromMap already rejects invalid values, so this only trips in tests/-ea runs.
+ private static void validateInputType(InputType inputType) {
+ if (inputType == null) {
+ return;
+ }
+
+ assert VALID_REQUEST_VALUES.contains(inputType) : invalidInputTypeMessage(inputType);
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+ // Null fields are omitted entirely rather than serialized as explicit nulls.
+ if (inputType != null) {
+ builder.field(INPUT_TYPE, inputType);
+ }
+ if (returnToken != null) {
+ builder.field(RETURN_TOKEN, returnToken);
+ }
+ builder.endObject();
+ return builder;
+ }
+
+ public InputType getInputType() {
+ return inputType;
+ }
+
+ // May return null when the flag was not specified.
+ public Boolean isReturnToken() {
+ return returnToken;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return NAME;
+ }
+
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeOptionalEnum(inputType);
+ out.writeOptionalBoolean(returnToken);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ AlibabaCloudSearchSparseTaskSettings that = (AlibabaCloudSearchSparseTaskSettings) o;
+ return Objects.equals(inputType, that.inputType) && Objects.equals(returnToken, that.returnToken);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(inputType, returnToken);
+ }
+
+ // Message used by the validateInputType assertion above.
+ public static String invalidInputTypeMessage(InputType inputType) {
+ return Strings.format("received invalid input type value [%s]", inputType.toString());
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java
index d275d00373cbe..055b4581e067b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java
@@ -14,4 +14,8 @@ public class InputTypeTests extends ESTestCase {
public static InputType randomWithoutUnspecified() {
return randomFrom(InputType.INGEST, InputType.SEARCH, InputType.CLUSTERING, InputType.CLASSIFICATION);
}
+
+ // Random input type limited to the two values AlibabaCloudSearch accepts (ingest/search).
+ public static InputType randomWithIngestAndSearch() {
+ return randomFrom(InputType.INGEST, InputType.SEARCH);
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntityTests.java
new file mode 100644
index 0000000000000..6aaab219c331d
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntityTests.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.is;
+
+// Verifies the JSON body produced by AlibabaCloudSearchEmbeddingsRequestEntity.
+public class AlibabaCloudSearchEmbeddingsRequestEntityTests extends ESTestCase {
+ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
+ var entity = new AlibabaCloudSearchEmbeddingsRequestEntity(
+ List.of("abc"),
+ new AlibabaCloudSearchEmbeddingsTaskSettings(InputType.INGEST)
+ );
+
+ XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+ entity.toXContent(builder, null);
+ String xContentResult = Strings.toString(builder);
+
+ // INGEST is serialized as the provider's "document" input type.
+ MatcherAssert.assertThat(xContentResult, is("""
+ {"input":["abc"],"input_type":"document"}"""));
+ }
+
+ public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException {
+ var entity = new AlibabaCloudSearchEmbeddingsRequestEntity(List.of("abc"), AlibabaCloudSearchEmbeddingsTaskSettings.EMPTY_SETTINGS);
+
+ XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+ entity.toXContent(builder, null);
+ String xContentResult = Strings.toString(builder);
+
+ MatcherAssert.assertThat(xContentResult, is("""
+ {"input":["abc"]}"""));
+ }
+
+ // NOTE(review): the production method under test is named "covertToString" — likely a typo of
+ // "convertToString"; confirm before it leaks into more callers.
+ public void testConvertToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() {
+ var thrownException = expectThrows(
+ AssertionError.class,
+ () -> AlibabaCloudSearchEmbeddingsRequestEntity.covertToString(InputType.UNSPECIFIED)
+ );
+ MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]"));
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestTests.java
new file mode 100644
index 0000000000000..378401f589b19
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestTests.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettingsTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+// Verifies the HTTP request (URI, headers, body) built for text-embedding inference calls.
+public class AlibabaCloudSearchEmbeddingsRequestTests extends ESTestCase {
+ public void testCreateRequest() throws IOException {
+ var request = createRequest(
+ List.of("abc"),
+ AlibabaCloudSearchEmbeddingsModelTests.createModel(
+ "embedding_test",
+ TaskType.TEXT_EMBEDDING,
+ AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("embeddings_test", "host", "default"),
+ AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null),
+ getSecretSettingsMap("secret")
+ )
+ );
+
+ var httpRequest = request.createHttpRequest();
+ assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+ var httpPost = (HttpPost) httpRequest.httpRequestBase();
+ // URI is built from host / workspace ("default") / task path / service id ("embeddings_test").
+ MatcherAssert.assertThat(
+ httpPost.getURI().toString(),
+ is("https://host/v3/openapi/workspaces/default/text-embedding/embeddings_test")
+ );
+ MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+ MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
+ MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"))));
+ }
+
+ // NOTE(review): raw List parameter — should be List<String>.
+ public static AlibabaCloudSearchEmbeddingsRequest createRequest(List input, AlibabaCloudSearchEmbeddingsModel model) {
+ var account = new AlibabaCloudSearchAccount(model.getSecretSettings().apiKey());
+ return new AlibabaCloudSearchEmbeddingsRequest(account, input, model);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java
new file mode 100644
index 0000000000000..8f981d64d36eb
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.is;
+
+// Verifies the JSON body produced by AlibabaCloudSearchRerankRequestEntity.
+// NOTE(review): this relies on a no-arg AlibabaCloudSearchRerankTaskSettings constructor, but the
+// file added at rerank/AlibabaCloudSearchRerankTaskSettings.java in this patch declares the Sparse
+// class instead — confirm the rerank task-settings class actually exists as referenced here.
+public class AlibabaCloudSearchRerankRequestEntityTests extends ESTestCase {
+ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
+ var entity = new AlibabaCloudSearchRerankRequestEntity("query", List.of("abc"), new AlibabaCloudSearchRerankTaskSettings());
+
+ XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+ entity.toXContent(builder, null);
+ String xContentResult = Strings.toString(builder);
+
+ MatcherAssert.assertThat(xContentResult, is("""
+ {"query":"query","docs":["abc"]}"""));
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntityTests.java
new file mode 100644
index 0000000000000..6ae209bc3c6f1
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntityTests.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettings;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.is;
+
+// Verifies the JSON body produced by AlibabaCloudSearchSparseRequestEntity.
+public class AlibabaCloudSearchSparseRequestEntityTests extends ESTestCase {
+ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
+ var entity = new AlibabaCloudSearchSparseRequestEntity(
+ List.of("abc"),
+ new AlibabaCloudSearchSparseTaskSettings(InputType.INGEST, true)
+ );
+
+ XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+ entity.toXContent(builder, null);
+ String xContentResult = Strings.toString(builder);
+
+ // INGEST maps to "document"; return_token is emitted only when explicitly set.
+ MatcherAssert.assertThat(xContentResult, is("""
+ {"input":["abc"],"input_type":"document","return_token":true}"""));
+ }
+
+ public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException {
+ var entity = new AlibabaCloudSearchSparseRequestEntity(List.of("abc"), AlibabaCloudSearchSparseTaskSettings.EMPTY_SETTINGS);
+
+ XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+ entity.toXContent(builder, null);
+ String xContentResult = Strings.toString(builder);
+
+ MatcherAssert.assertThat(xContentResult, is("""
+ {"input":["abc"]}"""));
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestTests.java
new file mode 100644
index 0000000000000..74fc225820641
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestTests.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModelTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseServiceSettingsTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettingsTests;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+// Verifies the HTTP request (URI, headers, body) built for sparse-embedding inference calls.
+public class AlibabaCloudSearchSparseRequestTests extends ESTestCase {
+ public void testCreateRequest() throws IOException {
+ var request = createRequest(
+ List.of("abc"),
+ AlibabaCloudSearchSparseModelTests.createModel(
+ "embedding_test",
+ // NOTE(review): TEXT_EMBEDDING used for a sparse model — confirm SPARSE_EMBEDDING was intended.
+ TaskType.TEXT_EMBEDDING,
+ AlibabaCloudSearchSparseServiceSettingsTests.getServiceSettingsMap("embeddings_test", "host", "default"),
+ AlibabaCloudSearchSparseTaskSettingsTests.getTaskSettingsMap(null, null),
+ getSecretSettingsMap("secret")
+ )
+ );
+
+ var httpRequest = request.createHttpRequest();
+ assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+ var httpPost = (HttpPost) httpRequest.httpRequestBase();
+ // Sparse requests use the text-sparse-embedding path segment.
+ MatcherAssert.assertThat(
+ httpPost.getURI().toString(),
+ is("https://host/v3/openapi/workspaces/default/text-sparse-embedding/embeddings_test")
+ );
+ MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+ MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
+ MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"))));
+ }
+
+ // NOTE(review): raw List parameter — should be List<String>.
+ public static AlibabaCloudSearchSparseRequest createRequest(List input, AlibabaCloudSearchSparseModel model) {
+ var account = new AlibabaCloudSearchAccount(model.getSecretSettings().apiKey());
+ return new AlibabaCloudSearchSparseRequest(account, input, model);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntityTests.java
new file mode 100644
index 0000000000000..33fa6a2a542cb
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntityTests.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchRequest;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+// Verifies parsing of the provider's embeddings response JSON into float embedding results.
+public class AlibabaCloudSearchEmbeddingsResponseEntityTests extends ESTestCase {
+ public void testFromResponse_CreatesResultsForASingleItem() throws IOException, URISyntaxException {
+ String responseJson = """
+ {
+ "request_id": "B4AB89C8-B135-xxxx-A6F8-2BAB801A2CE4",
+ "latency": 38,
+ "usage": {
+ "token_count": 3072
+ },
+ "result": {
+ "embeddings": [
+ {
+ "index": 0,
+ "embedding": [
+ -0.02868066355586052,
+ 0.022033605724573135
+ ]
+ }
+ ]
+ }
+ }
+ """;
+
+ // The parser only needs the request's URI (for error messages), so a mock suffices.
+ AlibabaCloudSearchRequest request = mock(AlibabaCloudSearchRequest.class);
+ URI uri = new URI("mock_uri");
+ when(request.getURI()).thenReturn(uri);
+
+ InferenceTextEmbeddingFloatResults parsedResults = AlibabaCloudSearchEmbeddingsResponseEntity.fromResponse(
+ request,
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+ );
+
+ assertThat(
+ parsedResults.embeddings(),
+ is(
+ List.of(
+ new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(
+ new float[] { -0.02868066355586052f, 0.022033605724573135f }
+ )
+ )
+ )
+ );
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchErrorResponseEntityTests.java
new file mode 100644
index 0000000000000..a03349c66b6d5
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchErrorResponseEntityTests.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.nio.charset.StandardCharsets;
+
+import static org.mockito.Mockito.mock;
+
+// Verifies the error "message" field is extracted from the provider's error response JSON.
+public class AlibabaCloudSearchErrorResponseEntityTests extends ESTestCase {
+ public void testFromResponse() {
+ String responseJson = """
+ {
+ "request_id": "651B3087-8A07-4BF3-B931-9C4E7B60F52D",
+ "latency": 0,
+ "code": "InvalidParameter",
+ "message": "JSON parse error: Cannot deserialize value of type `InputType` from String \\"xxx\\""
+ }
+ """;
+
+ AlibabaCloudSearchErrorResponseEntity errorMessage = AlibabaCloudSearchErrorResponseEntity.fromResponse(
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+ );
+ assertNotNull(errorMessage);
+ assertEquals("JSON parse error: Cannot deserialize value of type `InputType` from String \"xxx\"", errorMessage.getErrorMessage());
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchRerankResponseEntityTests.java
new file mode 100644
index 0000000000000..bebc8bb66f207
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchRerankResponseEntityTests.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+// Verifies parsing of the provider's rerank response JSON into RankedDocsResults.
+public class AlibabaCloudSearchRerankResponseEntityTests extends ESTestCase {
+
+ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
+ InferenceServiceResults parsedResults = AlibabaCloudSearchRerankResponseEntity.fromResponse(
+ mock(Request.class),
+ new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8))
+ );
+
+ MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
+ // NOTE(review): raw List — should be List<RankedDocsResults.RankedDoc>.
+ List expected = responseLiteralDocs();
+ // NOTE(review): only index() is compared — relevance scores are never asserted; confirm intentional.
+ for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) {
+ assertThat(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), is(expected.get(i).index()));
+ }
+ }
+
+ // Scores arrive already sorted by relevance (index 1 ranked above index 0).
+ private final String responseLiteral = """
+ {
+ "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f",
+ "latency": 564.903929,
+ "usage": {
+ "doc_count": 2
+ },
+ "result": {
+ "scores":[
+ {
+ "index":1,
+ "score": 1.37
+ },
+ {
+ "index":0,
+ "score": -0.3
+ }
+ ]
+ }
+ }
+ """;
+
+ // NOTE(review): raw ArrayList (should be ArrayList<RankedDocsResults.RankedDoc>) and a stray
+ // ';' after the method's closing brace.
+ private ArrayList responseLiteralDocs() {
+ var list = new ArrayList();
+
+ list.add(new RankedDocsResults.RankedDoc(1, 1.37F, null));
+ list.add(new RankedDocsResults.RankedDoc(0, -0.3F, null));
+ return list;
+ };
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchSparseResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchSparseResponseEntityTests.java
new file mode 100644
index 0000000000000..a6d3a4b77d74f
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchSparseResponseEntityTests.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.search.WeightedToken;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchRequest;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+// Verifies parsing of the provider's sparse-embeddings response JSON into weighted-token results.
+public class AlibabaCloudSearchSparseResponseEntityTests extends ESTestCase {
+ public void testFromResponse_CreatesResultsForASingleItem() throws IOException, URISyntaxException {
+ String responseJson = """
+ {
+ "request_id": "DDC4306F-xxxx-xxxx-xxxx-92C5CEA756A0",
+ "latency": 25,
+ "usage": {
+ "token_count": 11
+ },
+ "result": {
+ "sparse_embeddings": [
+ {
+ "index": 0,
+ "embedding": [
+ {
+ "token_id": 6,
+ "weight": 0.1014404296875
+ },
+ {
+ "token_id": 163040,
+ "weight": 0.2841796875
+ },
+ {
+ "token_id": 354,
+ "weight": 0.1431884765625
+ }
+ ]
+ }
+ ]
+ }
+ }
+ """;
+
+ // The parser only needs the request's URI (for error messages), so a mock suffices.
+ AlibabaCloudSearchRequest request = mock(AlibabaCloudSearchRequest.class);
+ URI uri = new URI("mock_uri");
+ when(request.getURI()).thenReturn(uri);
+
+ SparseEmbeddingResults parsedResults = AlibabaCloudSearchSparseResponseEntity.fromResponse(
+ request,
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+ );
+
+ // Numeric token_ids are expected to round-trip as their string form; isTruncated defaults to false.
+ assertThat(
+ parsedResults.embeddings(),
+ is(
+ List.of(
+ new SparseEmbeddingResults.Embedding(
+ List.of(
+ new WeightedToken("6", 0.1014404296875f),
+ new WeightedToken("163040", 0.2841796875f),
+ new WeightedToken("354", 0.1431884765625f)
+ ),
+ false
+ )
+ )
+ )
+ );
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettingsTests.java
new file mode 100644
index 0000000000000..d7965a38c845b
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettingsTests.java
@@ -0,0 +1,125 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class AlibabaCloudSearchServiceSettingsTests extends AbstractWireSerializingTestCase {
+ /**
+ * The created settings can have a url set to null.
+ */
+ public static AlibabaCloudSearchServiceSettings createRandom() {
+ var model = randomAlphaOfLength(15);
+ String host = randomAlphaOfLength(15);
+ String workspaceName = randomAlphaOfLength(10);
+ String httpSchema = "https";
+ return new AlibabaCloudSearchServiceSettings(model, host, workspaceName, httpSchema, RateLimitSettingsTests.createRandom());
+ }
+
+ public void testFromMap() throws URISyntaxException {
+ var model = "model";
+ var host = "host";
+ var workspaceName = "default";
+ var httpSchema = "https";
+ var serviceSettings = AlibabaCloudSearchServiceSettings.fromMap(
+ new HashMap<>(
+ Map.of(
+ AlibabaCloudSearchServiceSettings.SERVICE_ID,
+ model,
+ AlibabaCloudSearchServiceSettings.HOST,
+ host,
+ AlibabaCloudSearchServiceSettings.WORKSPACE_NAME,
+ workspaceName,
+ AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME,
+ httpSchema
+ )
+ ),
+ null
+ );
+
+ MatcherAssert.assertThat(serviceSettings, is(new AlibabaCloudSearchServiceSettings(model, host, workspaceName, httpSchema, null)));
+ }
+
+ public void testFromMap_WithRateLimit() {
+ var model = "model";
+ var host = "host";
+ var workspaceName = "default";
+ var httpSchema = "https";
+ var serviceSettings = AlibabaCloudSearchServiceSettings.fromMap(
+ new HashMap<>(
+ Map.of(
+ AlibabaCloudSearchServiceSettings.SERVICE_ID,
+ model,
+ AlibabaCloudSearchServiceSettings.HOST,
+ host,
+ AlibabaCloudSearchServiceSettings.WORKSPACE_NAME,
+ workspaceName,
+ AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME,
+ httpSchema,
+ RateLimitSettings.FIELD_NAME,
+ new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))
+ )
+ ),
+ null
+ );
+
+ MatcherAssert.assertThat(
+ serviceSettings,
+ is(new AlibabaCloudSearchServiceSettings(model, host, workspaceName, httpSchema, new RateLimitSettings(3)))
+ );
+ }
+
+ public void testXContent() throws IOException {
+ var entity = new AlibabaCloudSearchServiceSettings("model_id_name", "host_name", "workspace_name", null, null);
+
+ XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+ entity.toXContent(builder, null);
+ String xContentResult = Strings.toString(builder);
+
+ assertThat(xContentResult, is("""
+ {"service_id":"model_id_name","host":"host_name","workspace":"workspace_name","rate_limit":{"requests_per_minute":1000}}"""));
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return AlibabaCloudSearchServiceSettings::new;
+ }
+
+ @Override
+ protected AlibabaCloudSearchServiceSettings createTestInstance() {
+ return createRandom();
+ }
+
+ @Override
+ protected AlibabaCloudSearchServiceSettings mutateInstance(AlibabaCloudSearchServiceSettings instance) throws IOException {
+ return null;
+ }
+
+ public static Map getServiceSettingsMap(String serviceId, String host, String workspaceName) {
+ var map = new HashMap();
+ map.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, serviceId);
+ map.put(AlibabaCloudSearchServiceSettings.HOST, host);
+ map.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, workspaceName);
+ return map;
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
new file mode 100644
index 0000000000000..cc70b61226fe3
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
@@ -0,0 +1,172 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettingsTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests;
+import org.hamcrest.MatcherAssert;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.Mockito.mock;
+
+public class AlibabaCloudSearchServiceTests extends ESTestCase {
+ private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+ private ThreadPool threadPool;
+ private HttpClientManager clientManager;
+
+ @Before
+ public void init() throws Exception {
+ threadPool = createThreadPool(inferenceUtilityPool());
+ clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+ }
+
+ @After
+ public void shutdown() throws IOException {
+ clientManager.close();
+ terminate(threadPool);
+ }
+
+ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException {
+ try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) {
+ ActionListener modelVerificationListener = ActionListener.wrap(model -> {
+ assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class));
+
+ var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model;
+ assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id"));
+ assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host"));
+ assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default"));
+ assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+ }, e -> fail("Model parsing should have succeeded " + e.getMessage()));
+
+ service.parseRequestConfig(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ getRequestConfigMap(
+ AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"),
+ AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null),
+ getSecretSettingsMap("secret")
+ ),
+ Set.of(),
+ modelVerificationListener
+ );
+ }
+ }
+
+ public void testCheckModelConfig() throws IOException {
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+ try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool)) {
+ @Override
+ public void doInfer(
+ Model model,
+ List input,
+ Map taskSettings,
+ InputType inputType,
+ TimeValue timeout,
+ ActionListener listener
+ ) {
+ InferenceTextEmbeddingFloatResults results = new InferenceTextEmbeddingFloatResults(
+ List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { -0.028680f, 0.022033f }))
+ );
+
+ listener.onResponse(results);
+ }
+ }) {
+ Map serviceSettingsMap = new HashMap<>();
+ serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id");
+ serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host");
+ serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default");
+ serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536);
+
+ Map taskSettingsMap = new HashMap<>();
+
+ Map secretSettingsMap = new HashMap<>();
+ secretSettingsMap.put("api_key", "secret");
+
+ var model = AlibabaCloudSearchEmbeddingsModelTests.createModel(
+ "service",
+ TaskType.TEXT_EMBEDDING,
+ serviceSettingsMap,
+ taskSettingsMap,
+ secretSettingsMap
+ );
+ PlainActionFuture listener = new PlainActionFuture<>();
+ service.checkModelConfig(model, listener);
+ var result = listener.actionGet(TIMEOUT);
+
+ Map expectedServiceSettingsMap = new HashMap<>();
+ expectedServiceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id");
+ expectedServiceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host");
+ expectedServiceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default");
+ expectedServiceSettingsMap.put(ServiceFields.SIMILARITY, "DOT_PRODUCT");
+ expectedServiceSettingsMap.put(ServiceFields.DIMENSIONS, 2);
+
+ Map expectedTaskSettingsMap = new HashMap<>();
+
+ Map expectedSecretSettingsMap = new HashMap<>();
+ expectedSecretSettingsMap.put("api_key", "secret");
+
+ var expectedModel = AlibabaCloudSearchEmbeddingsModelTests.createModel(
+ "service",
+ TaskType.TEXT_EMBEDDING,
+ expectedServiceSettingsMap,
+ expectedTaskSettingsMap,
+ expectedSecretSettingsMap
+ );
+
+ MatcherAssert.assertThat(result, is(expectedModel));
+ }
+ }
+
+ private Map getRequestConfigMap(
+ Map serviceSettings,
+ Map taskSettings,
+ Map secretSettings
+ ) {
+ var builtServiceSettings = new HashMap<>();
+ builtServiceSettings.putAll(serviceSettings);
+ builtServiceSettings.putAll(secretSettings);
+
+ return new HashMap<>(
+ Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)
+ );
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java
new file mode 100644
index 0000000000000..fca0ee11e5c78
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests;
+import org.hamcrest.MatcherAssert;
+
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class AlibabaCloudSearchEmbeddingsModelTests extends ESTestCase {
+ public void testOverride() {
+ AlibabaCloudSearchEmbeddingsTaskSettings taskSettings = AlibabaCloudSearchEmbeddingsTaskSettingsTests.createRandom();
+ var model = createModel(
+ "service",
+ TaskType.TEXT_EMBEDDING,
+ AlibabaCloudSearchEmbeddingsServiceSettingsTests.createRandom(),
+ taskSettings,
+ DefaultSecretSettingsTests.createRandom()
+ );
+
+ var overriddenModel = AlibabaCloudSearchEmbeddingsModel.of(model, Map.of(), taskSettings.getInputType());
+ MatcherAssert.assertThat(overriddenModel, is(model));
+ }
+
+ public static AlibabaCloudSearchEmbeddingsModel createModel(
+ String modelId,
+ TaskType taskType,
+ Map serviceSettings,
+ Map taskSettings,
+ @Nullable Map secrets
+ ) {
+ return new AlibabaCloudSearchEmbeddingsModel(
+ modelId,
+ taskType,
+ AlibabaCloudSearchUtils.SERVICE_NAME,
+ serviceSettings,
+ taskSettings,
+ secrets,
+ null
+ );
+ }
+
+ public static AlibabaCloudSearchEmbeddingsModel createModel(
+ String modelId,
+ TaskType taskType,
+ AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings,
+ AlibabaCloudSearchEmbeddingsTaskSettings taskSettings,
+ @Nullable DefaultSecretSettings secretSettings
+ ) {
+ return new AlibabaCloudSearchEmbeddingsModel(
+ modelId,
+ taskType,
+ AlibabaCloudSearchUtils.SERVICE_NAME,
+ serviceSettings,
+ taskSettings,
+ secretSettings
+ );
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettingsTests.java
new file mode 100644
index 0000000000000..815e6d0311195
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettingsTests.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettingsTests;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class AlibabaCloudSearchEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<
+ AlibabaCloudSearchEmbeddingsServiceSettings> {
+ public static AlibabaCloudSearchEmbeddingsServiceSettings createRandom() {
+ var commonSettings = AlibabaCloudSearchServiceSettingsTests.createRandom();
+ var similarity = SimilarityMeasure.DOT_PRODUCT;
+ var dims = 1536;
+ var maxInputTokens = 512;
+ return new AlibabaCloudSearchEmbeddingsServiceSettings(commonSettings, similarity, dims, maxInputTokens);
+ }
+
+ public void testFromMap() {
+ var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
+ var dims = 1536;
+ var maxInputTokens = 512;
+ var model = "model";
+ var host = "host";
+ var workspaceName = "default";
+ var httpSchema = "https";
+ var serviceSettings = AlibabaCloudSearchEmbeddingsServiceSettings.fromMap(
+ new HashMap<>(
+ Map.of(
+ ServiceFields.SIMILARITY,
+ similarity,
+ ServiceFields.DIMENSIONS,
+ dims,
+ ServiceFields.MAX_INPUT_TOKENS,
+ maxInputTokens,
+ AlibabaCloudSearchServiceSettings.HOST,
+ host,
+ AlibabaCloudSearchServiceSettings.SERVICE_ID,
+ model,
+ AlibabaCloudSearchServiceSettings.WORKSPACE_NAME,
+ workspaceName,
+ AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME,
+ httpSchema
+ )
+ ),
+ null
+ );
+
+ MatcherAssert.assertThat(
+ serviceSettings,
+ is(
+ new AlibabaCloudSearchEmbeddingsServiceSettings(
+ new AlibabaCloudSearchServiceSettings(model, host, workspaceName, httpSchema, null),
+ SimilarityMeasure.DOT_PRODUCT,
+ dims,
+ maxInputTokens
+ )
+ )
+ );
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return AlibabaCloudSearchEmbeddingsServiceSettings::new;
+ }
+
+ @Override
+ protected AlibabaCloudSearchEmbeddingsServiceSettings createTestInstance() {
+ return createRandom();
+ }
+
+ @Override
+ protected AlibabaCloudSearchEmbeddingsServiceSettings mutateInstance(AlibabaCloudSearchEmbeddingsServiceSettings instance)
+ throws IOException {
+ return null;
+ }
+
+ public static Map getServiceSettingsMap(String serviceId, String host, String workspaceName) {
+ return AlibabaCloudSearchServiceSettingsTests.getServiceSettingsMap(serviceId, host, workspaceName);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettingsTests.java
new file mode 100644
index 0000000000000..244685d8e9833
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettingsTests.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithIngestAndSearch;
+import static org.hamcrest.Matchers.is;
+
+public class AlibabaCloudSearchEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase<
+ AlibabaCloudSearchEmbeddingsTaskSettings> {
+ public static AlibabaCloudSearchEmbeddingsTaskSettings createRandom() {
+ var inputType = randomBoolean() ? randomWithIngestAndSearch() : null;
+
+ return new AlibabaCloudSearchEmbeddingsTaskSettings(inputType);
+ }
+
+ public void testFromMap() {
+ MatcherAssert.assertThat(
+ AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(
+ new HashMap<>(Map.of(AlibabaCloudSearchEmbeddingsTaskSettings.INPUT_TYPE, "ingest"))
+ ),
+ is(new AlibabaCloudSearchEmbeddingsTaskSettings(InputType.INGEST))
+ );
+ }
+
+ public void testFromMap_WhenInputTypeIsNull() {
+ InputType inputType = null;
+ MatcherAssert.assertThat(
+ AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())),
+ is(new AlibabaCloudSearchEmbeddingsTaskSettings(inputType))
+ );
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return AlibabaCloudSearchEmbeddingsTaskSettings::new;
+ }
+
+ @Override
+ protected AlibabaCloudSearchEmbeddingsTaskSettings createTestInstance() {
+ return createRandom();
+ }
+
+ @Override
+ protected AlibabaCloudSearchEmbeddingsTaskSettings mutateInstance(AlibabaCloudSearchEmbeddingsTaskSettings instance)
+ throws IOException {
+ return null;
+ }
+
+ public static Map getTaskSettingsMap(@Nullable InputType inputType) {
+ var map = new HashMap();
+
+ if (inputType != null) {
+ map.put(AlibabaCloudSearchEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString());
+ }
+
+ return map;
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java
new file mode 100644
index 0000000000000..4e9179b66c36f
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests;
+import org.hamcrest.MatcherAssert;
+
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class AlibabaCloudSearchSparseModelTests extends ESTestCase {
+ public void testOverride() {
+ AlibabaCloudSearchSparseTaskSettings taskSettings = AlibabaCloudSearchSparseTaskSettingsTests.createRandom();
+ var model = createModel(
+ "service",
+ TaskType.TEXT_EMBEDDING,
+ AlibabaCloudSearchSparseServiceSettingsTests.createRandom(),
+ taskSettings,
+ DefaultSecretSettingsTests.createRandom()
+ );
+
+ var overriddenModel = AlibabaCloudSearchSparseModel.of(model, Map.of(), taskSettings.getInputType());
+ MatcherAssert.assertThat(overriddenModel, is(model));
+ }
+
+ public static AlibabaCloudSearchSparseModel createModel(
+ String modelId,
+ TaskType taskType,
+ Map serviceSettings,
+ Map taskSettings,
+ @Nullable Map secrets
+ ) {
+ return new AlibabaCloudSearchSparseModel(
+ modelId,
+ taskType,
+ AlibabaCloudSearchUtils.SERVICE_NAME,
+ serviceSettings,
+ taskSettings,
+ secrets,
+ null
+ );
+ }
+
+ public static AlibabaCloudSearchSparseModel createModel(
+ String modelId,
+ TaskType taskType,
+ AlibabaCloudSearchSparseServiceSettings serviceSettings,
+ AlibabaCloudSearchSparseTaskSettings taskSettings,
+ @Nullable DefaultSecretSettings secretSettings
+ ) {
+ return new AlibabaCloudSearchSparseModel(
+ modelId,
+ taskType,
+ AlibabaCloudSearchUtils.SERVICE_NAME,
+ serviceSettings,
+ taskSettings,
+ secretSettings
+ );
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettingsTests.java
new file mode 100644
index 0000000000000..8dc635a52f06f
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettingsTests.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettingsTests;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class AlibabaCloudSearchSparseServiceSettingsTests extends AbstractWireSerializingTestCase {
+ public static AlibabaCloudSearchSparseServiceSettings createRandom() {
+ var commonSettings = AlibabaCloudSearchServiceSettingsTests.createRandom();
+ return new AlibabaCloudSearchSparseServiceSettings(commonSettings);
+ }
+
+ public void testFromMap() {
+ var model = "model";
+ var host = "host";
+ var workspaceName = "default";
+ var httpSchema = "https";
+ var serviceSettings = AlibabaCloudSearchSparseServiceSettings.fromMap(
+ new HashMap<>(
+ Map.of(
+ AlibabaCloudSearchServiceSettings.HOST,
+ host,
+ AlibabaCloudSearchServiceSettings.SERVICE_ID,
+ model,
+ AlibabaCloudSearchServiceSettings.WORKSPACE_NAME,
+ workspaceName,
+ AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME,
+ httpSchema
+ )
+ ),
+ null
+ );
+
+ MatcherAssert.assertThat(
+ serviceSettings,
+ is(
+ new AlibabaCloudSearchSparseServiceSettings(
+ new AlibabaCloudSearchServiceSettings(model, host, workspaceName, httpSchema, null)
+ )
+ )
+ );
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return AlibabaCloudSearchSparseServiceSettings::new;
+ }
+
+ @Override
+ protected AlibabaCloudSearchSparseServiceSettings createTestInstance() {
+ return createRandom();
+ }
+
+ @Override
+ protected AlibabaCloudSearchSparseServiceSettings mutateInstance(AlibabaCloudSearchSparseServiceSettings instance) throws IOException {
+ return null;
+ }
+
+ public static Map getServiceSettingsMap(String serviceId, String host, String workspaceName) {
+ return AlibabaCloudSearchServiceSettingsTests.getServiceSettingsMap(serviceId, host, workspaceName);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettingsTests.java
new file mode 100644
index 0000000000000..b16d96f9a081b
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettingsTests.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithIngestAndSearch;
+import static org.hamcrest.Matchers.is;
+
+public class AlibabaCloudSearchSparseTaskSettingsTests extends AbstractWireSerializingTestCase {
+ public static AlibabaCloudSearchSparseTaskSettings createRandom() {
+ var inputType = randomBoolean() ? randomWithIngestAndSearch() : null;
+ var returnToken = randomBoolean();
+
+ return new AlibabaCloudSearchSparseTaskSettings(inputType, returnToken);
+ }
+
+ public void testFromMap() {
+ MatcherAssert.assertThat(
+ AlibabaCloudSearchSparseTaskSettings.fromMap(new HashMap<>(Map.of(AlibabaCloudSearchSparseTaskSettings.INPUT_TYPE, "ingest"))),
+ is(new AlibabaCloudSearchSparseTaskSettings(InputType.INGEST, null))
+ );
+ }
+
+ public void testFromMap_WhenInputTypeIsNull() {
+ InputType inputType = null;
+ MatcherAssert.assertThat(
+ AlibabaCloudSearchSparseTaskSettings.fromMap(new HashMap<>(Map.of())),
+ is(new AlibabaCloudSearchSparseTaskSettings(inputType, null))
+ );
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return AlibabaCloudSearchSparseTaskSettings::new;
+ }
+
+ @Override
+ protected AlibabaCloudSearchSparseTaskSettings createTestInstance() {
+ return createRandom();
+ }
+
+ @Override
+ protected AlibabaCloudSearchSparseTaskSettings mutateInstance(AlibabaCloudSearchSparseTaskSettings instance) throws IOException {
+ return null;
+ }
+
+ public static Map getTaskSettingsMap(@Nullable InputType inputType, @Nullable Boolean returnToken) {
+ var map = new HashMap();
+
+ if (inputType != null) {
+ map.put(AlibabaCloudSearchSparseTaskSettings.INPUT_TYPE, inputType.toString());
+ }
+
+ if (returnToken != null) {
+ map.put(AlibabaCloudSearchSparseTaskSettings.RETURN_TOKEN, returnToken);
+ }
+
+ return map;
+ }
+}
From 38adbb07246527627330154e7f5877599c418d2a Mon Sep 17 00:00:00 2001
From: Oleksandr Kolomiiets
Date: Tue, 27 Aug 2024 14:55:00 -0700
Subject: [PATCH 26/46] Prevent synthetic field loaders accessing stored fields
from using stale data (#112173)
---
docs/changelog/112173.yaml | 7 +
.../extras/MatchOnlyTextFieldMapper.java | 2 +-
.../AnnotatedTextFieldMapper.java | 2 +-
.../indices.create/20_synthetic_source.yml | 53 +++-
.../search/fieldcaps/FieldCapabilitiesIT.java | 2 +-
.../mapper/CompositeSyntheticFieldLoader.java | 100 +++++---
.../mapper/IgnoreMalformedStoredValues.java | 12 +-
.../index/mapper/IpFieldMapper.java | 13 +-
.../index/mapper/KeywordFieldMapper.java | 42 ++--
.../index/mapper/ObjectMapper.java | 38 ++-
...dNumericDocValuesSyntheticFieldLoader.java | 1 -
...etDocValuesSyntheticFieldLoaderLayer.java} | 89 +------
.../index/mapper/SourceLoader.java | 13 +
.../mapper/StringStoredFieldFieldLoader.java | 48 ++--
.../index/mapper/TextFieldMapper.java | 2 +-
...ortedSetDocValuesSyntheticFieldLoader.java | 22 +-
.../CompositeSyntheticFieldLoaderTests.java | 226 ++++++++++++++++++
.../index/mapper/DocumentParserTests.java | 2 +-
.../mapper/HistogramFieldMapper.java | 2 +-
.../AggregateDoubleMetricFieldMapper.java | 8 +-
.../VersionStringFieldMapper.java | 7 +-
.../wildcard/mapper/WildcardFieldMapper.java | 53 ++--
22 files changed, 517 insertions(+), 227 deletions(-)
create mode 100644 docs/changelog/112173.yaml
rename server/src/main/java/org/elasticsearch/index/mapper/{SortedSetDocValuesSyntheticFieldLoader.java => SortedSetDocValuesSyntheticFieldLoaderLayer.java} (69%)
create mode 100644 server/src/test/java/org/elasticsearch/index/mapper/CompositeSyntheticFieldLoaderTests.java
diff --git a/docs/changelog/112173.yaml b/docs/changelog/112173.yaml
new file mode 100644
index 0000000000000..9a43b0d1bf1fa
--- /dev/null
+++ b/docs/changelog/112173.yaml
@@ -0,0 +1,7 @@
+pr: 112173
+summary: Prevent synthetic field loaders accessing stored fields from using stale
+ data
+area: Mapping
+type: bug
+issues:
+ - 112156
diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/MatchOnlyTextFieldMapper.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/MatchOnlyTextFieldMapper.java
index 899cc42fea1e0..b3cd3586fca54 100644
--- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/MatchOnlyTextFieldMapper.java
+++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/MatchOnlyTextFieldMapper.java
@@ -447,7 +447,7 @@ public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() {
"field [" + fullPath() + "] of type [" + typeName() + "] doesn't support synthetic source because it declares copy_to"
);
}
- return new StringStoredFieldFieldLoader(fieldType().storedFieldNameForSyntheticSource(), leafName(), null) {
+ return new StringStoredFieldFieldLoader(fieldType().storedFieldNameForSyntheticSource(), leafName()) {
@Override
protected void write(XContentBuilder b, Object value) throws IOException {
b.value((String) value);
diff --git a/plugins/mapper-annotated-text/src/main/java/org/elasticsearch/index/mapper/annotatedtext/AnnotatedTextFieldMapper.java b/plugins/mapper-annotated-text/src/main/java/org/elasticsearch/index/mapper/annotatedtext/AnnotatedTextFieldMapper.java
index dac8e051f25f8..8d50a9f7e29a9 100644
--- a/plugins/mapper-annotated-text/src/main/java/org/elasticsearch/index/mapper/annotatedtext/AnnotatedTextFieldMapper.java
+++ b/plugins/mapper-annotated-text/src/main/java/org/elasticsearch/index/mapper/annotatedtext/AnnotatedTextFieldMapper.java
@@ -584,7 +584,7 @@ public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() {
);
}
if (fieldType.stored()) {
- return new StringStoredFieldFieldLoader(fullPath(), leafName(), null) {
+ return new StringStoredFieldFieldLoader(fullPath(), leafName()) {
@Override
protected void write(XContentBuilder b, Object value) throws IOException {
b.value((String) value);
diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.create/20_synthetic_source.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.create/20_synthetic_source.yml
index 1393d5454a9da..a696f3b2b3224 100644
--- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.create/20_synthetic_source.yml
+++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.create/20_synthetic_source.yml
@@ -1250,8 +1250,6 @@ empty nested object sorted as a first document:
- match: { hits.hits.1._source.name: B }
- match: { hits.hits.1._source.nested.a: "b" }
-
-
---
subobjects auto:
- requires:
@@ -1339,3 +1337,54 @@ subobjects auto:
- match: { hits.hits.3._source.id: 4 }
- match: { hits.hits.3._source.auto_obj.foo: 40 }
- match: { hits.hits.3._source.auto_obj.foo\.bar: 400 }
+
+---
+# 112156
+stored field under object with store_array_source:
+ - requires:
+ cluster_features: ["mapper.track_ignored_source"]
+ reason: requires tracking ignored source
+
+ - do:
+ indices.create:
+ index: test
+ body:
+ settings:
+ index:
+ sort.field: "name"
+ sort.order: "asc"
+ mappings:
+ _source:
+ mode: synthetic
+ properties:
+ name:
+ type: keyword
+ obj:
+ store_array_source: true
+ properties:
+ foo:
+ type: keyword
+ store: true
+
+ - do:
+ bulk:
+ index: test
+ refresh: true
+ body:
+ - '{ "create": { } }'
+ - '{ "name": "B", "obj": null }'
+ - '{ "create": { } }'
+ - '{ "name": "A", "obj": [ { "foo": "hello_from_the_other_side" } ] }'
+
+ - match: { errors: false }
+
+ - do:
+ search:
+ index: test
+ sort: name
+
+ - match: { hits.total.value: 2 }
+ - match: { hits.hits.0._source.name: A }
+ - match: { hits.hits.0._source.obj: [ { "foo": "hello_from_the_other_side" } ] }
+ - match: { hits.hits.1._source.name: B }
+ - match: { hits.hits.1._source.obj: null }
diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java
index 076158ee22037..0bce9ecb178d0 100644
--- a/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java
+++ b/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java
@@ -859,7 +859,7 @@ protected String contentType() {
@Override
public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() {
- return new StringStoredFieldFieldLoader(fullPath(), leafName(), null) {
+ return new StringStoredFieldFieldLoader(fullPath(), leafName()) {
@Override
protected void write(XContentBuilder b, Object value) throws IOException {
BytesRef ref = (BytesRef) value;
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/CompositeSyntheticFieldLoader.java b/server/src/main/java/org/elasticsearch/index/mapper/CompositeSyntheticFieldLoader.java
index efc3c7b507300..7bb1f99e81705 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/CompositeSyntheticFieldLoader.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/CompositeSyntheticFieldLoader.java
@@ -15,6 +15,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
@@ -28,29 +29,44 @@
* stored in a different field in case of ignore_malformed being enabled.
*/
public class CompositeSyntheticFieldLoader implements SourceLoader.SyntheticFieldLoader {
- private final String fieldName;
+ private final String leafFieldName;
private final String fullFieldName;
- private final SyntheticFieldLoaderLayer[] parts;
- private boolean hasValue;
+ private final Collection<Layer> parts;
+ private boolean storedFieldLoadersHaveValues;
+ private boolean docValuesLoadersHaveValues;
- public CompositeSyntheticFieldLoader(String fieldName, String fullFieldName, SyntheticFieldLoaderLayer... parts) {
- this.fieldName = fieldName;
+ public CompositeSyntheticFieldLoader(String leafFieldName, String fullFieldName, Layer... parts) {
+ this(leafFieldName, fullFieldName, Arrays.asList(parts));
+ }
+
+ public CompositeSyntheticFieldLoader(String leafFieldName, String fullFieldName, Collection<Layer> parts) {
+ this.leafFieldName = leafFieldName;
this.fullFieldName = fullFieldName;
this.parts = parts;
- this.hasValue = false;
+ this.storedFieldLoadersHaveValues = false;
+ this.docValuesLoadersHaveValues = false;
}
@Override
public Stream<Map.Entry<String, StoredFieldLoader>> storedFieldLoaders() {
- return Arrays.stream(parts).flatMap(SyntheticFieldLoaderLayer::storedFieldLoaders).map(e -> Map.entry(e.getKey(), values -> {
- hasValue = true;
- e.getValue().load(values);
+ return parts.stream().flatMap(Layer::storedFieldLoaders).map(e -> Map.entry(e.getKey(), new StoredFieldLoader() {
+ @Override
+ public void advanceToDoc(int docId) {
+ storedFieldLoadersHaveValues = false;
+ e.getValue().advanceToDoc(docId);
+ }
+
+ @Override
+ public void load(List