diff --git a/docs/changelog/125570.yaml b/docs/changelog/125570.yaml
new file mode 100644
index 0000000000000..ede177c666470
--- /dev/null
+++ b/docs/changelog/125570.yaml
@@ -0,0 +1,5 @@
+pr: 125570
+summary: ES|QL random sampling
+area: Machine Learning
+type: feature
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index a62e201f08544..054c47704caea 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -226,6 +226,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SYNONYMS_REFRESH_PARAM = def(9_060_0_00);
public static final TransportVersion DOC_FIELDS_AS_LIST = def(9_061_0_00);
public static final TransportVersion DENSE_VECTOR_OFF_HEAP_STATS = def(9_062_00_0);
+ public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER = def(9_063_0_00);
/*
* STOP! READ THIS FIRST! No, really,
diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java
index 2183ce5646293..56b203700b362 100644
--- a/server/src/main/java/org/elasticsearch/search/SearchModule.java
+++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java
@@ -134,6 +134,7 @@
import org.elasticsearch.search.aggregations.bucket.sampler.UnmappedSampler;
import org.elasticsearch.search.aggregations.bucket.sampler.random.InternalRandomSampler;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplerAggregationBuilder;
+import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQueryBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.DoubleTerms;
import org.elasticsearch.search.aggregations.bucket.terms.LongRareTerms;
import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
@@ -1186,6 +1187,9 @@ private void registerQueryParsers(List plugins) {
registerQuery(new QuerySpec<>(ExactKnnQueryBuilder.NAME, ExactKnnQueryBuilder::new, parser -> {
throw new IllegalArgumentException("[exact_knn] queries cannot be provided directly");
}));
+ registerQuery(
+ new QuerySpec<>(RandomSamplingQueryBuilder.NAME, RandomSamplingQueryBuilder::new, RandomSamplingQueryBuilder::fromXContent)
+ );
registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery);
}
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQuery.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQuery.java
index 89fe1a53a01cc..4834fdb87c12d 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQuery.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQuery.java
@@ -44,14 +44,34 @@ public final class RandomSamplingQuery extends Query {
* can be generated
*/
public RandomSamplingQuery(double p, int seed, int hash) {
- if (p <= 0.0 || p >= 1.0) {
- throw new IllegalArgumentException("RandomSampling probability must be between 0.0 and 1.0, was [" + p + "]");
- }
+ checkProbabilityRange(p);
this.p = p;
this.seed = seed;
this.hash = hash;
}
+    /**
+     * Verifies that the probability lies strictly within the open interval (0.0, 1.0).
+     * <p>
+     * The check is written as a negated "in range" test so that {@code NaN}, for which
+     * every ordered comparison evaluates to {@code false}, is rejected as well.
+     *
+     * @param p the sampling probability to validate
+     * @throws IllegalArgumentException if {@code p} is {@code NaN} or not strictly between 0.0 and 1.0
+     */
+    public static void checkProbabilityRange(double p) throws IllegalArgumentException {
+        if ((p > 0.0 && p < 1.0) == false) {
+            throw new IllegalArgumentException("RandomSampling probability must be strictly between 0.0 and 1.0, was [" + p + "]");
+        }
+    }
+
+ /**
+ * Returns the sampling probability {@code p} this query was constructed with.
+ */
+ public double probability() {
+ return p;
+ }
+
+ /**
+ * Returns the seed this query was constructed with.
+ */
+ public int seed() {
+ return seed;
+ }
+
+ /**
+ * Returns the hash this query was constructed with.
+ */
+ public int hash() {
+ return hash;
+ }
+
@Override
public String toString(String field) {
return "RandomSamplingQuery{" + "p=" + p + ", seed=" + seed + ", hash=" + hash + '}';
@@ -98,13 +118,13 @@ public void visit(QueryVisitor visitor) {
/**
* A DocIDSetIter that skips a geometrically random number of documents
*/
- static class RandomSamplingIterator extends DocIdSetIterator {
+ public static class RandomSamplingIterator extends DocIdSetIterator {
private final int maxDoc;
private final double p;
private final FastGeometric distribution;
private int doc = -1;
- RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
+ public RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
this.maxDoc = maxDoc;
this.p = p;
this.distribution = new FastGeometric(rng, p);
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQueryBuilder.java
new file mode 100644
index 0000000000000..c59a70753fad7
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQueryBuilder.java
@@ -0,0 +1,149 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.aggregations.bucket.sampler.random;
+
+import org.apache.lucene.search.Query;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.Randomness;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.index.query.AbstractQueryBuilder;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQuery.checkProbabilityRange;
+import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+/**
+ * Builder for {@link RandomSamplingQuery}, registered under the name {@value #NAME}.
+ * <p>
+ * The probability is mandatory and must lie strictly between 0.0 and 1.0. The seed defaults to a
+ * random integer drawn when the builder is created and the hash defaults to {@code 0}; both can be
+ * overridden through {@link #seed(int)} and {@link #hash(Integer)}.
+ */
+public class RandomSamplingQueryBuilder extends AbstractQueryBuilder<RandomSamplingQueryBuilder> {
+
+    public static final String NAME = "random_sampling";
+    // The XContent key is named after the value it carries.
+    static final ParseField PROBABILITY = new ParseField("probability");
+    static final ParseField SEED = new ParseField("seed");
+    static final ParseField HASH = new ParseField("hash");
+
+    private final double probability;
+    private int seed = Randomness.get().nextInt();
+    private int hash = 0;
+
+    /**
+     * Creates a builder with a randomly drawn seed and a hash of {@code 0}.
+     *
+     * @param probability the sampling probability, strictly between 0.0 and 1.0
+     * @throws IllegalArgumentException if the probability is outside that range
+     */
+    public RandomSamplingQueryBuilder(double probability) {
+        checkProbabilityRange(probability);
+        this.probability = probability;
+    }
+
+    /**
+     * Replaces the randomly drawn default seed.
+     *
+     * @return this builder, for chaining
+     */
+    public RandomSamplingQueryBuilder seed(int seed) {
+        this.seed = seed;
+        return this;
+    }
+
+    /**
+     * Reads a builder from the wire; the fields are read in the order in which
+     * {@link #doWriteTo(StreamOutput)} writes them.
+     */
+    public RandomSamplingQueryBuilder(StreamInput in) throws IOException {
+        super(in);
+        this.probability = in.readDouble();
+        this.seed = in.readInt();
+        this.hash = in.readInt();
+    }
+
+    /**
+     * Replaces the default hash of {@code 0}.
+     *
+     * @param hash the new hash, must not be {@code null}
+     * @return this builder, for chaining
+     * @throws NullPointerException if {@code hash} is {@code null}
+     */
+    public RandomSamplingQueryBuilder hash(Integer hash) {
+        this.hash = Objects.requireNonNull(hash, "hash must not be null");
+        return this;
+    }
+
+    public double probability() {
+        return probability;
+    }
+
+    public int seed() {
+        return seed;
+    }
+
+    public int hash() {
+        return hash;
+    }
+
+    @Override
+    protected void doWriteTo(StreamOutput out) throws IOException {
+        out.writeDouble(probability);
+        out.writeInt(seed);
+        out.writeInt(hash);
+    }
+
+    @Override
+    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject(NAME);
+        builder.field(PROBABILITY.getPreferredName(), probability);
+        builder.field(SEED.getPreferredName(), seed);
+        builder.field(HASH.getPreferredName(), hash);
+        builder.endObject();
+    }
+
+    /**
+     * Strict parser (unknown fields are rejected): the probability is required, seed and hash are
+     * optional and only applied when present so that the builder defaults stay in effect otherwise.
+     */
+    private static final ConstructingObjectParser<RandomSamplingQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
+        NAME,
+        false,
+        args -> {
+            var randomSamplingQueryBuilder = new RandomSamplingQueryBuilder((double) args[0]);
+            if (args[1] != null) {
+                randomSamplingQueryBuilder.seed((int) args[1]);
+            }
+            if (args[2] != null) {
+                randomSamplingQueryBuilder.hash((int) args[2]);
+            }
+            return randomSamplingQueryBuilder;
+        }
+    );
+
+    static {
+        PARSER.declareDouble(constructorArg(), PROBABILITY);
+        PARSER.declareInt(optionalConstructorArg(), SEED);
+        PARSER.declareInt(optionalConstructorArg(), HASH);
+    }
+
+    /**
+     * Parses a builder from its XContent representation.
+     */
+    public static RandomSamplingQueryBuilder fromXContent(XContentParser parser) throws IOException {
+        return PARSER.apply(parser, null);
+    }
+
+    @Override
+    protected Query doToQuery(SearchExecutionContext context) throws IOException {
+        return new RandomSamplingQuery(probability, seed, hash);
+    }
+
+    @Override
+    protected boolean doEquals(RandomSamplingQueryBuilder other) {
+        return probability == other.probability && seed == other.seed && hash == other.hash;
+    }
+
+    @Override
+    protected int doHashCode() {
+        return Objects.hash(probability, seed, hash);
+    }
+
+    /**
+     * Returns the name of the writeable object
+     */
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    /**
+     * The minimal version of the recipient this object can be sent to
+     */
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.RANDOM_SAMPLER_QUERY_BUILDER;
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/search/SearchModuleTests.java b/server/src/test/java/org/elasticsearch/search/SearchModuleTests.java
index 9109cd6b89bed..1e638f8e7b30e 100644
--- a/server/src/test/java/org/elasticsearch/search/SearchModuleTests.java
+++ b/server/src/test/java/org/elasticsearch/search/SearchModuleTests.java
@@ -444,6 +444,7 @@ public CheckedBiConsumer getReque
"range",
"regexp",
"knn_score_doc",
+ "random_sampling",
"script",
"script_score",
"simple_query_string",
diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQueryBuilderTests.java
new file mode 100644
index 0000000000000..d64068042a57c
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQueryBuilderTests.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.aggregations.bucket.sampler.random;
+
+import org.apache.lucene.search.Query;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.test.AbstractQueryTestCase;
+import org.elasticsearch.xcontent.XContentParseException;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.equalTo;
+
+/**
+ * Tests for {@link RandomSamplingQueryBuilder}. Boost and query name are declared as unsupported.
+ */
+public class RandomSamplingQueryBuilderTests extends AbstractQueryTestCase<RandomSamplingQueryBuilder> {
+
+    @Override
+    protected RandomSamplingQueryBuilder doCreateTestQueryBuilder() {
+        double probability = randomDoubleBetween(0.0, 1.0, false);
+        var builder = new RandomSamplingQueryBuilder(probability);
+        // Seed and hash are optional; leave them at their defaults in part of the runs.
+        if (randomBoolean()) {
+            builder.seed(randomInt());
+        }
+        if (randomBoolean()) {
+            builder.hash(randomInt());
+        }
+        return builder;
+    }
+
+    @Override
+    protected void doAssertLuceneQuery(RandomSamplingQueryBuilder queryBuilder, Query query, SearchExecutionContext context)
+        throws IOException {
+        var rsQuery = asInstanceOf(RandomSamplingQuery.class, query);
+        assertThat(rsQuery.probability(), equalTo(queryBuilder.probability()));
+        assertThat(rsQuery.seed(), equalTo(queryBuilder.seed()));
+        assertThat(rsQuery.hash(), equalTo(queryBuilder.hash()));
+    }
+
+    @Override
+    protected boolean supportsBoost() {
+        return false;
+    }
+
+    @Override
+    protected boolean supportsQueryName() {
+        return false;
+    }
+
+    @Override
+    public void testUnknownField() {
+        // Every known field carries a valid value, so the bogus field is the only thing the parser can reject.
+        var json = "{ \""
+            + RandomSamplingQueryBuilder.NAME
+            + "\" : {\"bogusField\" : \"someValue\", \""
+            + RandomSamplingQueryBuilder.PROBABILITY.getPreferredName()
+            + "\" : "
+            + randomDoubleBetween(0.0, 1.0, false)
+            + ", \""
+            + RandomSamplingQueryBuilder.SEED.getPreferredName()
+            + "\" : "
+            + randomInt()
+            + ", \""
+            + RandomSamplingQueryBuilder.HASH.getPreferredName()
+            + "\" : "
+            + randomInt()
+            + " } }";
+        var e = expectThrows(XContentParseException.class, () -> parseQuery(json));
+        assertTrue(e.getMessage().contains("bogusField"));
+    }
+}
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java
index c922d0f928640..3260489983abd 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java
@@ -172,6 +172,10 @@ public static boolean isSupported(String name) {
return ATTRIBUTES_MAP.containsKey(name);
}
+    /**
+     * Tells whether the given expression is a {@link MetadataAttribute} whose name equals {@link #SCORE}.
+     *
+     * @param a the expression to inspect
+     * @return {@code true} only for a metadata attribute named {@link #SCORE}
+     */
+    public static boolean isScoreAttribute(Expression a) {
+        if (a instanceof MetadataAttribute ma) {
+            return ma.name().equals(SCORE);
+        }
+        return false;
+    }
+
@Override
@SuppressWarnings("checkstyle:EqualsHashCode")// equals is implemented in parent. See innerEquals instead
public int hashCode() {
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/Page.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/Page.java
index 59ef3d778ca79..f08c66f4f9e6c 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/Page.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/Page.java
@@ -294,4 +294,21 @@ public Page projectBlocks(int[] blockMapping) {
}
}
}
+
+    /**
+     * Builds a new page from the result of {@code filter(positions)} on each block of this page.
+     * <p>
+     * This page's blocks are released in all cases, whether filtering succeeds or fails. If filtering
+     * fails part-way, the blocks that were already filtered are closed before the exception propagates.
+     *
+     * @param positions the positions handed to each block's {@code filter}
+     * @return a new page made of the filtered blocks
+     */
+    public Page filter(int... positions) {
+        final Block[] result = new Block[blocks.length];
+        boolean filtered = false;
+        try {
+            int index = 0;
+            while (index < blocks.length) {
+                result[index] = getBlock(index).filter(positions);
+                index++;
+            }
+            filtered = true;
+        } finally {
+            releaseBlocks();
+            if (filtered == false) {
+                Releasables.closeExpectNoException(result);
+            }
+        }
+        return new Page(result);
+    }
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ChangePointOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ChangePointOperator.java
index 2693c13a5383a..21efa314f1eed 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ChangePointOperator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ChangePointOperator.java
@@ -19,9 +19,9 @@
import org.elasticsearch.xpack.ml.aggs.changepoint.ChangePointDetector;
import org.elasticsearch.xpack.ml.aggs.changepoint.ChangeType;
+import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
-import java.util.LinkedList;
import java.util.List;
/**
@@ -68,8 +68,8 @@ public ChangePointOperator(DriverContext driverContext, int channel, String sour
this.sourceColumn = sourceColumn;
finished = false;
- inputPages = new LinkedList<>();
- outputPages = new LinkedList<>();
+ inputPages = new ArrayDeque<>();
+ outputPages = new ArrayDeque<>();
warnings = null;
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java
index 5b8d485c4da3a..d95f60f2191c8 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java
@@ -7,7 +7,6 @@
package org.elasticsearch.compute.operator;
-import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
@@ -69,20 +68,7 @@ protected Page process(Page page) {
}
positions = Arrays.copyOf(positions, rowCount);
- Block[] filteredBlocks = new Block[page.getBlockCount()];
- boolean success = false;
- try {
- for (int i = 0; i < page.getBlockCount(); i++) {
- filteredBlocks[i] = page.getBlock(i).filter(positions);
- }
- success = true;
- } finally {
- page.releaseBlocks();
- if (success == false) {
- Releasables.closeExpectNoException(filteredBlocks);
- }
- }
- return new Page(filteredBlocks);
+ return page.filter(positions);
}
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SampleOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SampleOperator.java
new file mode 100644
index 0000000000000..0a2158015c950
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SampleOperator.java
@@ -0,0 +1,228 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.operator;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQuery;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.ArrayDeque;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.Objects;
+import java.util.SplittableRandom;
+
+/**
+ * Operator that forwards a random subset of the rows it receives. Which rows are kept is decided by a
+ * {@link RandomSamplingQuery.RandomSamplingIterator} driven by a seeded {@link SplittableRandom}.
+ */
+public class SampleOperator implements Operator {
+
+    /**
+     * Factory handing the same probability and seed to every operator it creates.
+     */
+    public record Factory(double probability, int seed) implements OperatorFactory {
+
+        @Override
+        public SampleOperator get(DriverContext driverContext) {
+            return new SampleOperator(probability, seed);
+        }
+
+        @Override
+        public String describe() {
+            return "SampleOperator[probability = " + probability + ", seed = " + seed + "]";
+        }
+    }
+
+    private final Deque<Page> outputPages;
+
+    /**
+     * At any time this iterator points to the next row that still needs to be
+     * sampled, counted from the first row this operator received. If that row
+     * is on the current page, it is added to the output and the iterator is
+     * advanced. If the row is not on the current page, the current page is
+     * finished and the iterator position is carried over to the next page.
+     * <p>
+     * NOTE(review): positions are tracked as {@code int}s (the iterator limit is
+     * {@code Integer.MAX_VALUE} and {@code rowsReceived} is an {@code int}), so sampling
+     * presumably breaks down once more rows than that have been received — confirm
+     * whether a single operator can see that many rows.
+     */
+    private final RandomSamplingQuery.RandomSamplingIterator randomSamplingIterator;
+    private boolean finished;
+
+    private int pagesProcessed = 0;
+    private int rowsReceived = 0;
+    private int rowsEmitted = 0;
+    private long collectNanos;
+    private long emitNanos;
+
+    private SampleOperator(double probability, int seed) {
+        finished = false;
+        outputPages = new ArrayDeque<>();
+        SplittableRandom random = new SplittableRandom(seed);
+        randomSamplingIterator = new RandomSamplingQuery.RandomSamplingIterator(Integer.MAX_VALUE, probability, random::nextInt);
+        // Initialize the iterator to the next document that needs to be sampled.
+        randomSamplingIterator.nextDoc();
+    }
+
+    /**
+     * whether the given operator can accept more input pages
+     */
+    @Override
+    public boolean needsInput() {
+        return finished == false;
+    }
+
+    /**
+     * adds an input page to the operator. only called when needsInput() == true and isFinished() == false
+     * <p>
+     * The sampled rows of the page, if any, are queued as one output page. The input page is released
+     * before this method returns, also when sampling fails.
+     *
+     * @param page the input page; its blocks are released by this call
+     * @throws UnsupportedOperationException if the operator is a {@link SourceOperator}
+     */
+    @Override
+    public void addInput(Page page) {
+        long startTime = System.nanoTime();
+        try {
+            createOutputPage(page);
+            rowsReceived += page.getPositionCount();
+            pagesProcessed++;
+        } finally {
+            // Page#filter releases the page itself, but it is only called when at least one row was sampled.
+            // NOTE(review): relies on Page#releaseBlocks being safe to call on an already released page — confirm.
+            page.releaseBlocks();
+            collectNanos += System.nanoTime() - startTime;
+        }
+    }
+
+    /**
+     * Collects the positions of the given page selected by the sampling iterator and, if there are
+     * any, queues the page filtered down to those positions.
+     */
+    private void createOutputPage(Page page) {
+        final int positionCount = page.getPositionCount();
+        final int[] sampledPositions = new int[positionCount];
+        int sampledIdx = 0;
+        // The iterator counts rows since the first page; subtracting rowsReceived makes that page-relative.
+        for (int i = randomSamplingIterator.docID(); i - rowsReceived < positionCount; i = randomSamplingIterator.nextDoc()) {
+            sampledPositions[sampledIdx++] = i - rowsReceived;
+        }
+        if (sampledIdx > 0) {
+            outputPages.add(page.filter(Arrays.copyOf(sampledPositions, sampledIdx)));
+        }
+    }
+
+    /**
+     * notifies the operator that it won't receive any more input pages
+     */
+    @Override
+    public void finish() {
+        finished = true;
+    }
+
+    /**
+     * whether the operator has finished processing all input pages and made the corresponding output pages available
+     */
+    @Override
+    public boolean isFinished() {
+        return finished && outputPages.isEmpty();
+    }
+
+    /**
+     * Returns the oldest queued output page, or {@code null} when none is queued.
+     */
+    @Override
+    public Page getOutput() {
+        final var emitStart = System.nanoTime();
+        Page page;
+        if (outputPages.isEmpty()) {
+            page = null;
+        } else {
+            page = outputPages.removeFirst();
+            rowsEmitted += page.getPositionCount();
+        }
+        emitNanos += System.nanoTime() - emitStart;
+        return page;
+    }
+
+    /**
+     * notifies the operator that it won't be used anymore (i.e. none of the other methods called),
+     * and its resources can be cleaned up
+     */
+    @Override
+    public void close() {
+        for (Page page : outputPages) {
+            page.releaseBlocks();
+        }
+    }
+
+    @Override
+    public String toString() {
+        return "SampleOperator[sampled = " + rowsEmitted + "/" + rowsReceived + "]";
+    }
+
+    @Override
+    public Operator.Status status() {
+        return new Status(collectNanos, emitNanos, pagesProcessed, rowsReceived, rowsEmitted);
+    }
+
+    /**
+     * Snapshot of the operator's timing and row counters.
+     * <p>
+     * The record is public so that {@link #ENTRY} can be referenced from outside this class.
+     * NOTE(review): confirm that {@link #ENTRY} is added to the named writeable registry together
+     * with the statuses of the other operators.
+     */
+    public record Status(long collectNanos, long emitNanos, int pagesProcessed, int rowsReceived, int rowsEmitted)
+        implements
+            Operator.Status {
+
+        public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+            Operator.Status.class,
+            "sample",
+            Status::new
+        );
+
+        // Reads the fields in the order in which writeTo writes them.
+        Status(StreamInput streamInput) throws IOException {
+            this(streamInput.readVLong(), streamInput.readVLong(), streamInput.readVInt(), streamInput.readVInt(), streamInput.readVInt());
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeVLong(collectNanos);
+            out.writeVLong(emitNanos);
+            out.writeVInt(pagesProcessed);
+            out.writeVInt(rowsReceived);
+            out.writeVInt(rowsEmitted);
+        }
+
+        @Override
+        public String getWriteableName() {
+            return ENTRY.name;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field("collect_nanos", collectNanos);
+            if (builder.humanReadable()) {
+                builder.field("collect_time", TimeValue.timeValueNanos(collectNanos));
+            }
+            builder.field("emit_nanos", emitNanos);
+            if (builder.humanReadable()) {
+                builder.field("emit_time", TimeValue.timeValueNanos(emitNanos));
+            }
+            builder.field("pages_processed", pagesProcessed);
+            builder.field("rows_received", rowsReceived);
+            builder.field("rows_emitted", rowsEmitted);
+            return builder.endObject();
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Status other = (Status) o;
+            return collectNanos == other.collectNanos
+                && emitNanos == other.emitNanos
+                && pagesProcessed == other.pagesProcessed
+                && rowsReceived == other.rowsReceived
+                && rowsEmitted == other.rowsEmitted;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(collectNanos, emitNanos, pagesProcessed, rowsReceived, rowsEmitted);
+        }
+
+        @Override
+        public String toString() {
+            return Strings.toString(this);
+        }
+
+        @Override
+        public TransportVersion getMinimalSupportedVersion() {
+            return TransportVersions.ZERO;
+        }
+    }
+}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SampleOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SampleOperatorTests.java
new file mode 100644
index 0000000000000..cf2d73ab4768f
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SampleOperatorTests.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.operator;
+
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.test.OperatorTestCase;
+import org.elasticsearch.compute.test.SequenceLongBlockSourceOperator;
+import org.hamcrest.Matcher;
+
+import java.util.List;
+import java.util.stream.LongStream;
+
+import static org.hamcrest.Matchers.both;
+import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+import static org.hamcrest.Matchers.matchesPattern;
+
+/**
+ * Tests for {@link SampleOperator}, always sampling with a probability of 0.5 and a random seed.
+ */
+public class SampleOperatorTests extends OperatorTestCase {
+
+    /** Number of independent sampling runs in {@link #testAccuracy()}. */
+    private static final int ACCURACY_ITERATIONS = 10000;
+    /** Number of rows fed to the operator in each run of {@link #testAccuracy()}. */
+    private static final int ACCURACY_ROWS = 20000;
+
+    @Override
+    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
+        return new SequenceLongBlockSourceOperator(blockFactory, LongStream.range(0, size));
+    }
+
+    @Override
+    protected void assertSimpleOutput(List<Page> input, List<Page> results) {
+        int inputCount = input.stream().mapToInt(Page::getPositionCount).sum();
+        int outputCount = results.stream().mapToInt(Page::getPositionCount).sum();
+        double meanExpectedOutputCount = 0.5 * inputCount;
+        // sqrt(mean) is used as a deliberately generous estimate of the standard deviation.
+        double stdDevExpectedOutputCount = Math.sqrt(meanExpectedOutputCount);
+        assertThat((double) outputCount, closeTo(meanExpectedOutputCount, 10 * stdDevExpectedOutputCount));
+    }
+
+    @Override
+    protected SampleOperator.Factory simple() {
+        return new SampleOperator.Factory(0.5, randomInt());
+    }
+
+    @Override
+    protected Matcher<String> expectedDescriptionOfSimple() {
+        return matchesPattern("SampleOperator\\[probability = 0.5, seed = -?\\d+]");
+    }
+
+    @Override
+    protected Matcher<String> expectedToStringOfSimple() {
+        return equalTo("SampleOperator[sampled = 0/0]");
+    }
+
+    public void testAccuracy() {
+        BlockFactory blockFactory = driverContext().blockFactory();
+        int totalPositionCount = 0;
+
+        for (int iter = 0; iter < ACCURACY_ITERATIONS; iter++) {
+            SampleOperator operator = simple().get(driverContext());
+            operator.addInput(new Page(blockFactory.newConstantNullBlock(ACCURACY_ROWS)));
+            Page output = operator.getOutput();
+            // 10000 rows are expected. If rows are kept independently with probability 0.5, the standard
+            // deviation is sqrt(20000 * 0.5 * 0.5), about 71, so a margin of 1000 is more than 10 of them.
+            assertThat(output.getPositionCount(), both(greaterThan(9000)).and(lessThan(11000)));
+            totalPositionCount += output.getPositionCount();
+            output.releaseBlocks();
+        }
+
+        int averagePositionCount = totalPositionCount / ACCURACY_ITERATIONS;
+        // Averaging over 10000 runs divides the standard deviation by sqrt(10000) = 100, to about 0.71,
+        // so a margin of 10 is again more than 10 standard deviations.
+        assertThat(averagePositionCount, both(greaterThan(9990)).and(lessThan(10010)));
+    }
+}
diff --git a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestSampleIT.java b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestSampleIT.java
new file mode 100644
index 0000000000000..523d9c15ab128
--- /dev/null
+++ b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestSampleIT.java
@@ -0,0 +1,26 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.qa.single_node;
+
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;
+
+import org.elasticsearch.test.TestClustersThreadFilter;
+import org.elasticsearch.test.cluster.ElasticsearchCluster;
+import org.elasticsearch.xpack.esql.qa.rest.RestSampleTestCase;
+import org.junit.ClassRule;
+
+/**
+ * Runs the tests inherited from {@link RestSampleTestCase} against the cluster
+ * created by {@code Clusters.testCluster()}.
+ */
+@ThreadLeakFilters(filters = TestClustersThreadFilter.class)
+public class RestSampleIT extends RestSampleTestCase {
+ // Class-level rule: one cluster is shared by all tests of this class.
+ @ClassRule
+ public static ElasticsearchCluster cluster = Clusters.testCluster();
+
+ /** Returns the HTTP addresses of the shared test cluster. */
+ @Override
+ protected String getTestRestCluster() {
+ return cluster.getHttpAddresses();
+ }
+}
diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestSampleTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestSampleTestCase.java
new file mode 100644
index 0000000000000..2362c163d57d8
--- /dev/null
+++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestSampleTestCase.java
@@ -0,0 +1,148 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.qa.rest;
+
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.ResponseException;
+import org.elasticsearch.test.rest.ESRestTestCase;
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.hamcrest.Description;
+import org.hamcrest.TypeSafeMatcher;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.IntStream;
+
+import static org.hamcrest.Matchers.both;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+
+public class RestSampleTestCase extends ESRestTestCase {
+
+ /**
+ * Skips the test (via a failed assumption) unless the cluster reports the
+ * {@code SAMPLE} ES|QL capability.
+ */
+ @Before
+ public void skipWhenSampleDisabled() throws IOException {
+ assumeTrue(
+ "Requires SAMPLE capability",
+ EsqlSpecTestCase.hasCapabilities(adminClient(), List.of(EsqlCapabilities.Cap.SAMPLE.capabilityName()))
+ );
+ }
+
+ /**
+ * Delegates to {@code EsqlSpecTestCase.assertRequestBreakerEmpty()}; annotated to run
+ * both before and after every test.
+ */
+ @Before
+ @After
+ public void assertRequestBreakerEmpty() throws Exception {
+ EsqlSpecTestCase.assertRequestBreakerEmpty();
+ }
+
+    /**
+     * Matcher for the results of sampling 50% of the elements 0,1,2,...,998,999.
+     * The results should consist of single-value rows holding unique numbers in
+     * [0,999]. Furthermore, the size should on average be 500. If elements are
+     * kept independently, the standard deviation of the size is
+     * sqrt(1000 * 0.5 * 0.5), about 16, so the accepted range [250,750] leaves
+     * more than 10 standard deviations on either side.
+     */
+    private static final TypeSafeMatcher<List<List<Integer>>> RESULT_MATCHER = new TypeSafeMatcher<>() {
+        @Override
+        public void describeTo(Description description) {
+            description.appendText("a list with between 250 and 750 unique elements in [0,999]");
+        }
+
+        @Override
+        protected boolean matchesSafely(List<List<Integer>> lists) {
+            if (lists.size() < 250 || lists.size() > 750) {
+                return false;
+            }
+            Set<Integer> values = new HashSet<>();
+            for (List<Integer> list : lists) {
+                // Every row must hold exactly one value.
+                if (list.size() != 1) {
+                    return false;
+                }
+                Integer value = list.get(0);
+                if (value == null || value < 0 || value >= 1000) {
+                    return false;
+                }
+                values.add(value);
+            }
+            // Equal sizes mean that no value occurred twice.
+            return values.size() == lists.size();
+        }
+    };
+
+    /**
+     * This tests sampling in the Lucene query.
+     */
+    public void testSample_withFrom() throws IOException {
+        createTestIndex();
+        try {
+            test("FROM sample-test-index | SAMPLE 0.5 | LIMIT 1000");
+        } finally {
+            // Remove the index even when an assertion fails.
+            deleteTestIndex();
+        }
+    }
+
+    /**
+     * This tests sampling in the ES|QL operator.
+     */
+    public void testSample_withRow() throws IOException {
+        // The upper bound is exclusive: this yields the 1000 values 0..999 the result matcher is written for.
+        List<Integer> numbers = IntStream.range(0, 1000).boxed().toList();
+        test("ROW value = " + numbers + " | MV_EXPAND value | SAMPLE 0.5 | LIMIT 1000");
+    }
+
+    /**
+     * Runs the query repeatedly, checks every single result against {@link #RESULT_MATCHER} and
+     * finally checks the average result size over all runs.
+     */
+    private void test(String query) throws IOException {
+        int iterationCount = 1000;
+        int totalResultSize = 0;
+        for (int iteration = 0; iteration < iterationCount; iteration++) {
+            Map<String, Object> result = runEsqlQuery(query);
+            assertResultMap(result, defaultOutputColumns(), RESULT_MATCHER);
+            totalResultSize += ((List<?>) result.get("values")).size();
+        }
+        // On average there's 500 elements in the results set.
+        // Averaging over 1000 runs shrinks the spread of that mean to well below 1,
+        // so the bounds below leave more than 10 standard deviations.
+        assertThat(totalResultSize / iterationCount, both(greaterThan(490)).and(lessThan(510)));
+    }
+
+ private static List
*/
@Override public T visitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx) { return visitChildren(ctx); }
+ /**
+ * {@inheritDoc}
+ *
+ * The default implementation returns the result of calling
+ * {@link #visitChildren} on {@code ctx}.
+ */
+ @Override public T visitSampleCommand(EsqlBaseParser.SampleCommandContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
index f94769d79bc08..c7068edc32c18 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
@@ -655,6 +655,16 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
* @param ctx the parse tree
*/
void exitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx);
+ /**
+ * Enter a parse tree produced by {@link EsqlBaseParser#sampleCommand}.
+ * @param ctx the parse tree
+ */
+ void enterSampleCommand(EsqlBaseParser.SampleCommandContext ctx);
+ /**
+ * Exit a parse tree produced by {@link EsqlBaseParser#sampleCommand}.
+ * @param ctx the parse tree
+ */
+ void exitSampleCommand(EsqlBaseParser.SampleCommandContext ctx);
/**
* Enter a parse tree produced by the {@code matchExpression}
* labeled alternative in {@link EsqlBaseParser#booleanExpression}.
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
index fab0c03af5e56..b25416a8cb35a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
@@ -400,6 +400,12 @@ public interface EsqlBaseParserVisitor extends ParseTreeVisitor {
* @return the visitor result
*/
T visitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx);
+ /**
+ * Visit a parse tree produced by {@link EsqlBaseParser#sampleCommand}.
+ * @param ctx the parse tree
+ * @return the visitor result
+ */
+ T visitSampleCommand(EsqlBaseParser.SampleCommandContext ctx);
/**
* Visit a parse tree produced by the {@code matchExpression}
* labeled alternative in {@link EsqlBaseParser#booleanExpression}.
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
index 51da9da8bbf9e..0f0b4a7f9140e 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
@@ -66,6 +66,7 @@
import org.elasticsearch.xpack.esql.plan.logical.Rename;
import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
+import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
@@ -768,4 +769,23 @@ public Literal inferenceId(EsqlBaseParser.IdentifierOrParameterContext ctx) {
ctx.parameter().getText()
);
}
+
+    /**
+     * Builds the plan factory for the {@code SAMPLE} command from its probability
+     * and its optional seed.
+     *
+     * @throws ParsingException if a seed is given but is not of type INTEGER
+     */
+    @Override
+    public PlanFactory visitSampleCommand(EsqlBaseParser.SampleCommandContext ctx) {
+        var probability = visitDecimalValue(ctx.probability);
+        if (ctx.seed == null) {
+            // No seed given: the Sample node carries a null seed.
+            return plan -> new Sample(source(ctx), probability, null, plan);
+        }
+        Literal seed = visitIntegerValue(ctx.seed);
+        if (seed.dataType() != DataType.INTEGER) {
+            throw new ParsingException(
+                seed.source(),
+                "seed must be an integer, provided [{}] of type [{}]",
+                ctx.seed.getText(),
+                seed.dataType()
+            );
+        }
+        return plan -> new Sample(source(ctx), probability, seed, plan);
+    }
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java
index d15b6aa2973aa..2fe9f5182ae00 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java
@@ -21,6 +21,7 @@
import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Project;
+import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
@@ -47,6 +48,7 @@
import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.MvExpandExec;
import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
+import org.elasticsearch.xpack.esql.plan.physical.SampleExec;
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.SubqueryExec;
import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
@@ -87,6 +89,7 @@ public static List logical() {
OrderBy.ENTRY,
Project.ENTRY,
Rerank.ENTRY,
+ Sample.ENTRY,
TimeSeriesAggregate.ENTRY,
TopN.ENTRY
);
@@ -114,6 +117,7 @@ public static List physical() {
MvExpandExec.ENTRY,
ProjectExec.ENTRY,
RerankExec.ENTRY,
+ SampleExec.ENTRY,
ShowExec.ENTRY,
SubqueryExec.ENTRY,
TimeSeriesAggregateExec.ENTRY,
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Sample.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Sample.java
new file mode 100644
index 0000000000000..ea4e9396ecca8
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Sample.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.esql.plan.logical;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQuery;
+import org.elasticsearch.xpack.esql.capabilities.PostAnalysisVerificationAware;
+import org.elasticsearch.xpack.esql.capabilities.TelemetryAware;
+import org.elasticsearch.xpack.esql.common.Failures;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.core.expression.Foldables;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.esql.common.Failure.fail;
+
+/**
+ * Logical plan node that carries a sampling probability and an optional seed on
+ * top of a single child plan.
+ */
+public class Sample extends UnaryPlan implements TelemetryAware, PostAnalysisVerificationAware {
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Sample", Sample::new);
+
+    private final Expression probability;
+    private final Expression seed;
+
+    /**
+     * @param source      source location of the command
+     * @param probability expression holding the sampling probability
+     * @param seed        expression holding the seed, or {@code null} when none was given
+     * @param child       the plan whose output is sampled
+     */
+    public Sample(Source source, Expression probability, @Nullable Expression seed, LogicalPlan child) {
+        super(source, child);
+        this.probability = probability;
+        this.seed = seed;
+    }
+
+    /**
+     * Deserializing constructor; reads the fields in the order {@link #writeTo} writes them.
+     */
+    private Sample(StreamInput in) throws IOException {
+        this(
+            Source.readFrom((PlanStreamInput) in),
+            in.readNamedWriteable(Expression.class), // probability
+            in.readOptionalNamedWriteable(Expression.class), // seed
+            in.readNamedWriteable(LogicalPlan.class) // child
+        );
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        source().writeTo(out);
+        out.writeNamedWriteable(probability);
+        // The seed may be null, hence the optional variant.
+        out.writeOptionalNamedWriteable(seed);
+        out.writeNamedWriteable(child());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
+    @Override
+    protected NodeInfo<Sample> info() {
+        return NodeInfo.create(this, Sample::new, probability, seed, child());
+    }
+
+    @Override
+    public Sample replaceChild(LogicalPlan newChild) {
+        return new Sample(source(), probability, seed, newChild);
+    }
+
+    public Expression probability() {
+        return probability;
+    }
+
+    /**
+     * @return the seed expression, or {@code null} when no seed was given
+     */
+    public Expression seed() {
+        return seed;
+    }
+
+    @Override
+    public boolean expressionsResolved() {
+        // A missing seed counts as resolved.
+        return probability.resolved() && (seed == null || seed.resolved());
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(probability, seed, child());
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+
+        var other = (Sample) obj;
+
+        return Objects.equals(probability, other.probability) && Objects.equals(seed, other.seed) && Objects.equals(child(), other.child());
+    }
+
+    /**
+     * Folds the probability and reports a failure on the probability expression when
+     * {@link RandomSamplingQuery#checkProbabilityRange} rejects the value.
+     */
+    @Override
+    public void postAnalysisVerification(Failures failures) {
+        try {
+            RandomSamplingQuery.checkProbabilityRange((double) Foldables.valueOf(FoldContext.small(), probability));
+        } catch (IllegalArgumentException e) {
+            // Surface the range error as a verification failure instead of an exception.
+            failures.add(fail(probability, e.getMessage()));
+        }
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java
index 60e7eb535f444..2e74c7153f77e 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java
@@ -308,6 +308,12 @@ public EsQueryExec withSorts(List sorts) {
: new EsQueryExec(source(), indexPattern, indexMode, indexNameWithModes, attrs, query, limit, sorts, estimatedRowSize);
}
+    /**
+     * Returns this node when the given query equals the current one, otherwise a
+     * copy of this node that uses the given query.
+     */
+    public EsQueryExec withQuery(QueryBuilder query) {
+        if (Objects.equals(this.query, query)) {
+            return this;
+        }
+        return new EsQueryExec(source(), indexPattern, indexMode, indexNameWithModes, attrs, query, limit, sorts, estimatedRowSize);
+    }
+
+
@Override
public int hashCode() {
return Objects.hash(indexPattern, indexMode, indexNameWithModes, attrs, query, limit, sorts);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/SampleExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/SampleExec.java
new file mode 100644
index 0000000000000..e110a1b60928a
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/SampleExec.java
@@ -0,0 +1,114 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.physical;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xpack.esql.capabilities.PostPhysicalOptimizationVerificationAware;
+import org.elasticsearch.xpack.esql.common.Failures;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.esql.common.Failure.fail;
+
+/**
+ * Physical plan node that carries a sampling probability and an optional seed on
+ * top of a single child plan.
+ */
+public class SampleExec extends UnaryExec implements PostPhysicalOptimizationVerificationAware {
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+        PhysicalPlan.class,
+        "SampleExec",
+        SampleExec::new
+    );
+
+    private final Expression probability;
+    private final Expression seed;
+
+    /**
+     * @param source      source location of the command
+     * @param child       the plan whose output is sampled
+     * @param probability expression holding the sampling probability
+     * @param seed        expression holding the seed, or {@code null} when none was given
+     */
+    public SampleExec(Source source, PhysicalPlan child, Expression probability, @Nullable Expression seed) {
+        super(source, child);
+        this.probability = probability;
+        this.seed = seed;
+    }
+
+    /**
+     * Deserializing constructor; reads the fields in the order {@link #writeTo} writes them.
+     */
+    public SampleExec(StreamInput in) throws IOException {
+        this(
+            Source.readFrom((PlanStreamInput) in),
+            in.readNamedWriteable(PhysicalPlan.class), // child
+            in.readNamedWriteable(Expression.class), // probability
+            in.readOptionalNamedWriteable(Expression.class) // seed
+        );
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        source().writeTo(out);
+        out.writeNamedWriteable(child());
+        out.writeNamedWriteable(probability);
+        // The seed may be null, hence the optional variant.
+        out.writeOptionalNamedWriteable(seed);
+    }
+
+    @Override
+    public UnaryExec replaceChild(PhysicalPlan newChild) {
+        return new SampleExec(source(), newChild, probability, seed);
+    }
+
+    @Override
+    protected NodeInfo<? extends PhysicalPlan> info() {
+        return NodeInfo.create(this, SampleExec::new, child(), probability, seed);
+    }
+
+    /**
+     * Returns the name of the writeable object
+     */
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
+    public Expression probability() {
+        return probability;
+    }
+
+    /**
+     * @return the seed expression, or {@code null} when no seed was given
+     */
+    public Expression seed() {
+        return seed;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(child(), probability, seed);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+
+        var other = (SampleExec) obj;
+
+        return Objects.equals(child(), other.child()) && Objects.equals(probability, other.probability) && Objects.equals(seed, other.seed);
+    }
+
+    /**
+     * Reports a failure on the seed expression whenever this node carries a seed.
+     */
+    @Override
+    public void postPhysicalOptimizationVerification(Failures failures) {
+        // It's currently impossible in ES|QL to handle all data in deterministic order, therefore
+        // a fixed random seed in the sample operator doesn't work as intended and is disallowed.
+        // TODO: fix this.
+        if (seed != null) {
+            // TODO: what should the error message here be? This doesn't seem right.
+            failures.add(fail(seed, "Seed not supported when sampling can't be pushed down to Lucene"));
+        }
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
index f263c2f80e429..8ef7d43b28d4b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.planner;
import org.elasticsearch.cluster.ClusterName;
+import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
@@ -37,6 +38,7 @@
import org.elasticsearch.compute.operator.OutputOperator.OutputOperatorFactory;
import org.elasticsearch.compute.operator.RowInTableLookupOperator;
import org.elasticsearch.compute.operator.RrfScoreEvalOperator;
+import org.elasticsearch.compute.operator.SampleOperator;
import org.elasticsearch.compute.operator.ScoreOperator;
import org.elasticsearch.compute.operator.ShowOperator;
import org.elasticsearch.compute.operator.SinkOperator;
@@ -66,6 +68,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.core.expression.Foldables;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NameId;
@@ -107,6 +110,7 @@
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
import org.elasticsearch.xpack.esql.plan.physical.RrfScoreEvalExec;
+import org.elasticsearch.xpack.esql.plan.physical.SampleExec;
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
@@ -253,6 +257,8 @@ private PhysicalOperation plan(PhysicalPlan node, LocalExecutionPlannerContext c
return planRerank(rerank, context);
} else if (node instanceof ChangePointExec changePoint) {
return planChangePoint(changePoint, context);
+        } else if (node instanceof SampleExec sample) {
+            return planSample(sample, context);
}
// source nodes
else if (node instanceof EsQueryExec esQuery) {
@@ -800,6 +806,13 @@ private PhysicalOperation planChangePoint(ChangePointExec changePoint, LocalExec
);
}
+    /**
+     * Plans the child of the given node and appends a {@link SampleOperator} that
+     * uses the folded probability and either the folded seed or a random one.
+     */
+    private PhysicalOperation planSample(SampleExec rsx, LocalExecutionPlannerContext context) {
+        PhysicalOperation source = plan(rsx.child(), context);
+        double probability = (double) Foldables.valueOf(context.foldCtx(), rsx.probability());
+        int seed;
+        if (rsx.seed() == null) {
+            // No seed on the plan node: draw one at random.
+            seed = Randomness.get().nextInt();
+        } else {
+            seed = (int) Foldables.valueOf(context.foldCtx(), rsx.seed());
+        }
+        SampleOperator.Factory factory = new SampleOperator.Factory(probability, seed);
+        return source.with(factory, source.layout);
+    }
+
/**
* Immutable physical operation.
*/
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/LocalMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/LocalMapper.java
index c8758bbe4bff7..4eba58edbe762 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/LocalMapper.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/LocalMapper.java
@@ -17,6 +17,7 @@
import org.elasticsearch.xpack.esql.plan.logical.LeafPlan;
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
@@ -28,6 +29,7 @@
import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.LookupJoinExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
+import org.elasticsearch.xpack.esql.plan.physical.SampleExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import java.util.List;
@@ -83,6 +85,10 @@ private PhysicalPlan mapUnary(UnaryPlan unary) {
return new TopNExec(topN.source(), mappedChild, topN.order(), topN.limit(), null);
}
+ if (unary instanceof Sample sample) {
+ return new SampleExec(sample.source(), mappedChild, sample.probability(), sample.seed());
+ }
+
//
// Pipeline operators
//
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java
index ab137fe872795..1af000a7a36bd 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java
@@ -21,6 +21,7 @@
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
+import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
@@ -37,6 +38,7 @@
import org.elasticsearch.xpack.esql.plan.physical.LookupJoinExec;
import org.elasticsearch.xpack.esql.plan.physical.MergeExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
+import org.elasticsearch.xpack.esql.plan.physical.SampleExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
@@ -186,6 +188,12 @@ private PhysicalPlan mapUnary(UnaryPlan unary) {
);
}
+ // TODO: share code with LocalMapper?
+ if (unary instanceof Sample sample) {
+ mappedChild = addExchangeForFragment(sample, mappedChild);
+ return new SampleExec(sample.source(), mappedChild, sample.probability(), sample.seed());
+ }
+
//
// Pipeline operators
//
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
index 3db455a0c8bb6..4851de1616844 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
@@ -12,6 +12,7 @@
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.ChangePoint;
import org.elasticsearch.xpack.esql.plan.logical.Dissect;
@@ -53,7 +54,7 @@
/**
* Class for sharing code across Mappers.
*/
-class MapperUtils {
+public class MapperUtils {
private MapperUtils() {}
static PhysicalPlan mapLeaf(LeafPlan p) {
@@ -177,4 +178,13 @@ static AggregateExec aggExec(Aggregate aggregate, PhysicalPlan child, Aggregator
static PhysicalPlan unsupported(LogicalPlan p) {
throw new EsqlIllegalArgumentException("unsupported logical plan node [" + p.nodeName() + "]");
}
+
+    /**
+     * Checks whether any of the given attributes is a score attribute according to
+     * {@link MetadataAttribute#isScoreAttribute}.
+     *
+     * @param attributes the attributes to inspect
+     * @return {@code true} if at least one attribute is a score attribute
+     */
+    public static boolean hasScoreAttribute(List<? extends Attribute> attributes) {
+        for (Attribute attr : attributes) {
+            if (MetadataAttribute.isScoreAttribute(attr)) {
+                return true;
+            }
+        }
+        return false;
+    }
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/telemetry/FeatureMetric.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/telemetry/FeatureMetric.java
index 69682c0d0bb0f..ec4f5db99576c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/telemetry/FeatureMetric.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/telemetry/FeatureMetric.java
@@ -31,6 +31,7 @@
import org.elasticsearch.xpack.esql.plan.logical.Rename;
import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
+import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
@@ -68,7 +69,8 @@ public enum FeatureMetric {
INSIST(Insist.class::isInstance),
FORK(Fork.class::isInstance),
RRF(RrfScoreEval.class::isInstance),
- COMPLETION(Completion.class::isInstance);
+ COMPLETION(Completion.class::isInstance),
+ SAMPLE(Sample.class::isInstance);
/**
* List here plans we want to exclude from telemetry
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
index db9047be3f065..704a0395e97b4 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
@@ -26,6 +26,8 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import java.util.function.Predicate;
+import java.util.function.Supplier;
import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.GEO_MATCH_TYPE;
import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.MATCH_TYPE;
@@ -209,4 +211,13 @@ public static void loadEnrichPolicyResolution(EnrichResolution enrich, String po
public static IndexResolution tsdbIndexResolution() {
return loadMapping("tsdb-mapping.json", "test");
}
+
+    /**
+     * Keeps drawing values from the supplier until one is not matched by the
+     * {@code exclude} predicate, and returns that value.
+     *
+     * @param exclude  predicate matching the values to reject
+     * @param supplier source of candidate values
+     * @return the first supplied value for which {@code exclude} is false
+     */
+    public static <E> E randomValueOtherThanTest(Predicate<E> exclude, Supplier<E> supplier) {
+        E value;
+        do {
+            value = supplier.get();
+        } while (exclude.test(value));
+        return value;
+    }
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index dab59a17805b1..5f9c14e2fda61 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -107,6 +107,7 @@
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzerDefaultMapping;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultEnrichResolution;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadMapping;
+import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.randomValueOtherThanTest;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.tsdbIndexResolution;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
import static org.hamcrest.Matchers.contains;
@@ -3403,6 +3404,19 @@ public void testRrfError() {
assertThat(e.getMessage(), containsString("Unknown column [_id]"));
}
+    /**
+     * SAMPLE must be rejected for the probabilities 1.0 and 0.0 and for a random
+     * non-negative probability outside the open interval (0, 1).
+     */
+    public void testRandomSampleProbability() {
+        var upperBound = expectThrows(VerificationException.class, () -> analyze("FROM test | SAMPLE 1."));
+        assertThat(upperBound.getMessage(), containsString("RandomSampling probability must be strictly between 0.0 and 1.0, was [1.0]"));
+
+        var lowerBound = expectThrows(VerificationException.class, () -> analyze("FROM test | SAMPLE .0"));
+        assertThat(lowerBound.getMessage(), containsString("RandomSampling probability must be strictly between 0.0 and 1.0, was [0.0]"));
+
+        // Redraw until the value falls outside (0, 1).
+        double p = randomValueOtherThanTest(d -> 0 < d && d < 1, () -> randomDoubleBetween(0, Double.MAX_VALUE, false));
+        var outOfRange = expectThrows(VerificationException.class, () -> analyze("FROM test | SAMPLE " + p));
+        assertThat(
+            outOfRange.getMessage(),
+            containsString("RandomSampling probability must be strictly between 0.0 and 1.0, was [" + p + "]")
+        );
+    }
+
+ // TODO There's too much boilerplate involved here! We need a better way of creating FieldCapabilitiesResponses from a mapping or index.
private static FieldCapabilitiesIndexResponse fieldCapabilitiesIndexResponse(
String indexName,
Map fields
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
index cbd3998f5a850..de56f0676b4a7 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
@@ -109,6 +109,7 @@
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.plan.GeneratingPlan;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
+import org.elasticsearch.xpack.esql.plan.logical.ChangePoint;
import org.elasticsearch.xpack.esql.plan.logical.Dissect;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
@@ -121,6 +122,7 @@
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.Row;
+import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
@@ -7804,4 +7806,153 @@ public void testPruneRedundantOrderBy() {
var mvExpand2 = as(mvExpand.child(), MvExpand.class);
as(mvExpand2.child(), Row.class);
}
+
+    /**
+     * Three SAMPLE commands separated by EVALs are expected to end up as one Sample
+     * directly above the relation:
+     *
+     * Eval[[1[INTEGER] AS irrelevant1, 2[INTEGER] AS irrelevant2]]
+     * \_Limit[1000[INTEGER],false]
+     *   \_Sample[0.015[DOUBLE],15[INTEGER]]
+     *     \_EsRelation[test][_meta_field{f}#12, emp_no{f}#6, first_name{f}#7, ge..]
+     */
+    public void testSampleMerged() {
+        assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
+
+        var optimized = optimizedPlan("""
+            FROM TEST
+            | SAMPLE .3 5
+            | EVAL irrelevant1 = 1
+            | SAMPLE .5 10
+            | EVAL irrelevant2 = 2
+            | SAMPLE .1
+            """);
+
+        var eval = as(optimized, Eval.class);
+        var limit = as(eval.child(), Limit.class);
+        var sample = as(limit.child(), Sample.class);
+        // The cast itself is the assertion; the relation is not inspected further.
+        as(sample.child(), EsRelation.class);
+
+        assertThat(sample.probability().fold(FoldContext.small()), equalTo(0.015));
+        assertThat(sample.seed().fold(FoldContext.small()), equalTo(5 ^ 10));
+    }
+
+    /**
+     * For each of the listed commands, a following SAMPLE is expected to end up
+     * directly above the relation, below the command and the implicit limit.
+     */
+    public void testSamplePushDown() {
+        assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
+
+        var commands = List.of(
+            "ENRICH languages_idx on first_name",
+            "EVAL x = 1",
+            // "INSIST emp_no", // TODO
+            "KEEP emp_no",
+            "DROP emp_no",
+            "RENAME emp_no AS x",
+            "GROK first_name \"%{WORD:bar}\"",
+            "DISSECT first_name \"%{z}\""
+        );
+        for (var command : commands) {
+            var optimized = optimizedPlan("FROM TEST | " + command + " | SAMPLE .5");
+
+            var unary = as(optimized, UnaryPlan.class);
+            var limit = as(unary.child(), Limit.class);
+            var sample = as(limit.child(), Sample.class);
+            // The cast itself is the assertion; the relation is not inspected further.
+            as(sample.child(), EsRelation.class);
+
+            assertThat(sample.probability().fold(FoldContext.small()), equalTo(0.5));
+            assertNull(sample.seed());
+        }
+    }
+
+    /**
+     * A SAMPLE after a WHERE is expected to end up below the Filter, directly above
+     * the relation.
+     */
+    public void testSamplePushDown_where() {
+        assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
+
+        var query = "FROM TEST | WHERE emp_no > 0 | SAMPLE 0.5 | LIMIT 100";
+        var optimized = optimizedPlan(query);
+
+        var limit = as(optimized, Limit.class);
+        var filter = as(limit.child(), Filter.class);
+        var sample = as(filter.child(), Sample.class);
+        var source = as(sample.child(), EsRelation.class);
+
+        assertThat(sample.probability().fold(FoldContext.small()), equalTo(0.5));
+        assertNull(sample.seed());
+    }
+
+    /**
+     * A SAMPLE after a SORT is expected to end up below the TopN, directly above
+     * the relation.
+     */
+    public void testSamplePushDown_sort() {
+        assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
+
+        var query = "FROM TEST | SORT emp_no | SAMPLE 0.5 | LIMIT 100";
+        var optimized = optimizedPlan(query);
+
+        var topN = as(optimized, TopN.class);
+        var sample = as(topN.child(), Sample.class);
+        var source = as(sample.child(), EsRelation.class);
+
+        assertThat(sample.probability().fold(FoldContext.small()), equalTo(0.5));
+        assertNull(sample.seed());
+    }
+
+    /**
+     * For each of the listed commands, a following SAMPLE is expected to stay above
+     * the command instead of moving down to the relation.
+     */
+    public void testSampleNoPushDown() {
+        assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
+
+        var commands = List.of("LIMIT 100", "MV_EXPAND languages", "STATS COUNT()");
+        for (var command : commands) {
+            var optimized = optimizedPlan("FROM TEST | " + command + " | SAMPLE .5");
+
+            var limit = as(optimized, Limit.class);
+            var sample = as(limit.child(), Sample.class);
+            var unary = as(sample.child(), UnaryPlan.class);
+            // The cast itself is the assertion; the relation is not inspected further.
+            as(unary.child(), EsRelation.class);
+        }
+    }
+
+    /**
+     * A SAMPLE after a LOOKUP JOIN is expected to stay above the Join:
+     *
+     * Limit[1000[INTEGER],false]
+     * \_Sample[0.5[DOUBLE],null]
+     *   \_Join[LEFT,[language_code{r}#4],[language_code{r}#4],[language_code{f}#17]]
+     *     |_Eval[[emp_no{f}#6 AS language_code]]
+     *     | \_EsRelation[test][_meta_field{f}#12, emp_no{f}#6, first_name{f}#7, ge..]
+     *     \_EsRelation[languages_lookup][LOOKUP][language_code{f}#17, language_name{f}#18]
+     */
+    public void testSampleNoPushDownLookupJoin() {
+        assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
+
+        var optimized = optimizedPlan("""
+            FROM TEST
+            | EVAL language_code = emp_no
+            | LOOKUP JOIN languages_lookup ON language_code
+            | SAMPLE .5
+            """);
+
+        var limit = as(optimized, Limit.class);
+        var sample = as(limit.child(), Sample.class);
+        var join = as(sample.child(), Join.class);
+        var eval = as(join.left(), Eval.class);
+        // The cast itself is the assertion; the relation is not inspected further.
+        as(eval.child(), EsRelation.class);
+    }
+
+    /**
+     * A SAMPLE after a CHANGE_POINT is expected to stay above the ChangePoint:
+     *
+     * Limit[1000[INTEGER],false]
+     * \_Sample[0.5[DOUBLE],null]
+     *   \_Limit[1000[INTEGER],false]
+     *     \_ChangePoint[emp_no{f}#6,hire_date{f}#13,type{r}#4,pvalue{r}#5]
+     *       \_TopN[[Order[hire_date{f}#13,ASC,ANY]],1001[INTEGER]]
+     *         \_EsRelation[test][_meta_field{f}#12, emp_no{f}#6, first_name{f}#7, ge..]
+     */
+    public void testSampleNoPushDownChangePoint() {
+        assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
+
+        var optimized = optimizedPlan("""
+            FROM TEST
+            | CHANGE_POINT emp_no ON hire_date
+            | SAMPLE .5 -55
+            """);
+
+        var outerLimit = as(optimized, Limit.class);
+        var sample = as(outerLimit.child(), Sample.class);
+        var innerLimit = as(sample.child(), Limit.class);
+        var changePoint = as(innerLimit.child(), ChangePoint.class);
+        var topN = as(changePoint.child(), TopN.class);
+        // The cast itself is the assertion; the relation is not inspected further.
+        as(topN.child(), EsRelation.class);
+    }
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
index 64a281718ae8c..d8da50585bdfa 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
@@ -35,6 +35,7 @@
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.index.query.WildcardQueryBuilder;
+import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQueryBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.GeoDistanceSortBuilder;
import org.elasticsearch.test.ESTestCase;
@@ -185,6 +186,7 @@
import static org.elasticsearch.xpack.esql.core.util.TestUtils.stripThrough;
import static org.elasticsearch.xpack.esql.parser.ExpressionBuilder.MAX_EXPRESSION_DEPTH;
import static org.elasticsearch.xpack.esql.parser.LogicalPlanBuilder.MAX_QUERY_DEPTH;
+import static org.elasticsearch.xpack.esql.planner.mapper.MapperUtils.hasScoreAttribute;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -7732,7 +7734,7 @@ public void testScore() {
EsRelation esRelation = as(filter.child(), EsRelation.class);
assertTrue(esRelation.optimized());
assertTrue(esRelation.resolved());
- assertTrue(esRelation.output().stream().anyMatch(a -> a.name().equals(MetadataAttribute.SCORE) && a instanceof MetadataAttribute));
+ assertTrue(hasScoreAttribute(esRelation.output()));
}
public void testScoreTopN() {
@@ -7754,7 +7756,7 @@ public void testScoreTopN() {
Order scoreOrer = order.getFirst();
assertEquals(Order.OrderDirection.DESC, scoreOrer.direction());
Expression child = scoreOrer.child();
- assertTrue(child instanceof MetadataAttribute ma && ma.name().equals(MetadataAttribute.SCORE));
+ assertTrue(MetadataAttribute.isScoreAttribute(child));
Filter filter = as(topN.child(), Filter.class);
Match match = as(filter.condition(), Match.class);
@@ -7764,7 +7766,7 @@ public void testScoreTopN() {
EsRelation esRelation = as(filter.child(), EsRelation.class);
assertTrue(esRelation.optimized());
assertTrue(esRelation.resolved());
- assertTrue(esRelation.output().stream().anyMatch(a -> a.name().equals(MetadataAttribute.SCORE) && a instanceof MetadataAttribute));
+ assertTrue(hasScoreAttribute(esRelation.output()));
}
public void testReductionPlanForTopN() {
@@ -7822,6 +7824,54 @@ public void testEqualsPushdownToDelegateTooBig() {
as(limit2.child(), FilterExec.class);
}
+ /*
+ * LimitExec[1000[INTEGER]]
+ * \_ExchangeExec[[_meta_field{f}#8, emp_no{f}#2, first_name{f}#3, gender{f}#4, hire_date{f}#9, job{f}#10, job.raw{f}#11, langua
+ * ges{f}#5, last_name{f}#6, long_noidx{f}#12, salary{f}#7],false]
+ * \_ProjectExec[[_meta_field{f}#8, emp_no{f}#2, first_name{f}#3, gender{f}#4, hire_date{f}#9, job{f}#10, job.raw{f}#11, langua
+ * ges{f}#5, last_name{f}#6, long_noidx{f}#12, salary{f}#7]]
+ * \_FieldExtractExec[_meta_field{f}#8, emp_no{f}#2, first_name{f}#3, gen..]<[],[]>
+ * \_EsQueryExec[test], indexMode[standard],
+ * query[{"bool":{"filter":[{"sampling":{"probability":0.1,"seed":-234,"hash":0}}],"boost":1.0}}]
+ * [_doc{f}#24], limit[1000], sort[] estimatedRowSize[332]
+ */
+    public void testSamplePushDown() {
+        assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
+
+        var optimized = optimizedPlan(physicalPlan("""
+            FROM test
+            | SAMPLE +0.1 -234
+            """));
+
+        // Unwrap the physical plan, node by node, down to the source query node.
+        var limitExec = as(optimized, LimitExec.class);
+        var exchangeExec = as(limitExec.child(), ExchangeExec.class);
+        var projectExec = as(exchangeExec.child(), ProjectExec.class);
+        var extractExec = as(projectExec.child(), FieldExtractExec.class);
+        var queryExec = as(extractExec.child(), EsQueryExec.class);
+
+        // The sampling parameters must show up as the first filter clause of the bool query.
+        var bool = as(queryExec.query(), BoolQueryBuilder.class);
+        var sampling = as(bool.filter().get(0), RandomSamplingQueryBuilder.class);
+        assertThat(sampling.probability(), equalTo(0.1));
+        assertThat(sampling.seed(), equalTo(-234));
+        assertThat(sampling.hash(), equalTo(0));
+    }
+
+    public void testSample_seedNotSupportedInOperator() {
+        assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
+
+        // These queries are expected to plan without a verification failure.
+        for (String accepted : new String[] {
+            "FROM test | SAMPLE 0.1",
+            "FROM test | SAMPLE 0.1 42",
+            "FROM test | MV_EXPAND first_name | SAMPLE 0.1" }) {
+            optimizedPlan(physicalPlan(accepted));
+        }
+
+        // A seed on a SAMPLE that follows MV_EXPAND must be rejected.
+        String rejected = "FROM test | MV_EXPAND first_name | SAMPLE 0.1 42";
+        var failure = expectThrows(VerificationException.class, () -> optimizedPlan(physicalPlan(rejected)));
+        assertThat(
+            failure.getMessage(),
+            equalTo("Found 1 problem\nline 1:47: Seed not supported when sampling can't be pushed down to Lucene")
+        );
+    }
+
@SuppressWarnings("SameParameterValue")
private static void assertFilterCondition(
Filter filter,
@@ -8005,7 +8055,7 @@ private PhysicalPlan physicalPlan(String query, TestDataSource dataSource, boole
var logical = logicalOptimizer.optimize(dataSource.analyzer.analyze(parser.createStatement(query)));
// System.out.println("Logical\n" + logical);
var physical = mapper.map(logical);
- // System.out.println(physical);
+ // System.out.println("Physical\n" + physical);
if (assertSerialization) {
assertSerialization(physical);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
index a445aeb8dd763..bc08f14867f56 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
@@ -3485,6 +3485,14 @@ public void testInvalidCompletion() {
expectError("FROM foo* | COMPLETION prompt AS targetField", "line 1:31: mismatched input 'AS' expecting {");
}
+ /** Parser-level checks for the SAMPLE command: each malformed query must produce the expected positioned error text. */
+ public void testSample() {
+ // a third argument after probability and seed is not accepted
+ expectError("FROM test | SAMPLE .1 2 3", "line 1:25: extraneous input '3' expecting ");
+ // a quoted string is not accepted in the seed position
+ expectError("FROM test | SAMPLE .1 \"2\"", "line 1:23: extraneous input '\"2\"' expecting ");
+ // the probability must be a (optionally signed) decimal literal; a bare integer is rejected
+ expectError("FROM test | SAMPLE 1", "line 1:20: mismatched input '1' expecting {DECIMAL_LITERAL, '+', '-'}");
+ // the probability is mandatory
+ expectError("FROM test | SAMPLE", "line 1:19: mismatched input '' expecting {DECIMAL_LITERAL, '+', '-'}");
+ // 2147483648 is Integer.MAX_VALUE + 1, so the seed is typed LONG and rejected
+ expectError("FROM test | SAMPLE +.1 2147483648", "line 1:24: seed must be an integer, provided [2147483648] of type [LONG]");
+ }
+
static Alias alias(String name, Expression value) {
return new Alias(EMPTY, name, value);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/SampleSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/SampleSerializationTests.java
new file mode 100644
index 0000000000000..f24d738789b3e
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/SampleSerializationTests.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.logical;
+
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+
+import java.io.IOException;
+
+/**
+ * Serialization tests for the logical {@link Sample} plan node.
+ */
+public class SampleSerializationTests extends AbstractLogicalPlanSerializationTests {
+    /**
+     * Produces a fresh random {@link Sample}. The method is called many times during a test run,
+     * and every call draws new random values.
+     */
+    @Override
+    protected Sample createTestInstance() {
+        var source = randomSource();
+        var probability = randomProbability();
+        var seed = randomSeed();
+        return new Sample(source, probability, seed, randomChild(0));
+    }
+
+    /**
+     * Random probability as a {@code DOUBLE} literal drawn between 0 and 1.
+     */
+    public static Literal randomProbability() {
+        return new Literal(randomSource(), randomDoubleBetween(0, 1, false), DataType.DOUBLE);
+    }
+
+    /**
+     * Random seed as an {@code INTEGER} literal, or {@code null}; the choice is a coin flip.
+     */
+    public static Literal randomSeed() {
+        if (randomBoolean()) {
+            return new Literal(randomSource(), randomInt(), DataType.INTEGER);
+        }
+        return null;
+    }
+
+    /**
+     * Returns a copy of {@code instance} in which exactly one of probability, seed or child has
+     * been replaced by a different random value, so the result should not equal the input.
+     *
+     * @param instance the plan node to mutate
+     */
+    @Override
+    protected Sample mutateInstance(Sample instance) throws IOException {
+        int selector = randomIntBetween(0, 2);
+        return switch (selector) {
+            case 0 -> new Sample(
+                instance.source(),
+                randomValueOtherThan(instance.probability(), SampleSerializationTests::randomProbability),
+                instance.seed(),
+                instance.child()
+            );
+            case 1 -> new Sample(
+                instance.source(),
+                instance.probability(),
+                randomValueOtherThan(instance.seed(), SampleSerializationTests::randomSeed),
+                instance.child()
+            );
+            case 2 -> new Sample(
+                instance.source(),
+                instance.probability(),
+                instance.seed(),
+                randomValueOtherThan(instance.child(), () -> randomChild(0))
+            );
+            default -> throw new IllegalArgumentException("Invalid selector: " + selector);
+        };
+    }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/SampleExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/SampleExecSerializationTests.java
new file mode 100644
index 0000000000000..12159d8afae7c
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/SampleExecSerializationTests.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.physical;
+
+import org.elasticsearch.xpack.esql.plan.logical.SampleSerializationTests;
+
+import java.io.IOException;
+
+import static org.elasticsearch.xpack.esql.plan.logical.SampleSerializationTests.randomProbability;
+import static org.elasticsearch.xpack.esql.plan.logical.SampleSerializationTests.randomSeed;
+
+/**
+ * Serialization tests for the physical {@link SampleExec} plan node.
+ */
+public class SampleExecSerializationTests extends AbstractPhysicalPlanSerializationTests {
+    /**
+     * Produces a fresh random {@link SampleExec}. The method is called many times during a test
+     * run, and every call draws new random values.
+     */
+    @Override
+    protected SampleExec createTestInstance() {
+        var source = randomSource();
+        var child = randomChild(0);
+        var probability = randomProbability();
+        return new SampleExec(source, child, probability, randomSeed());
+    }
+
+    /**
+     * Returns a copy of {@code instance} in which exactly one of probability, seed or child has
+     * been replaced by a different random value, so the result should not equal the input.
+     *
+     * @param instance the plan node to mutate
+     */
+    @Override
+    protected SampleExec mutateInstance(SampleExec instance) throws IOException {
+        int selector = randomIntBetween(0, 2);
+        return switch (selector) {
+            case 0 -> new SampleExec(
+                instance.source(),
+                instance.child(),
+                randomValueOtherThan(instance.probability(), SampleSerializationTests::randomProbability),
+                instance.seed()
+            );
+            case 1 -> new SampleExec(
+                instance.source(),
+                instance.child(),
+                instance.probability(),
+                randomValueOtherThan(instance.seed(), SampleSerializationTests::randomSeed)
+            );
+            case 2 -> new SampleExec(
+                instance.source(),
+                randomValueOtherThan(instance.child(), () -> randomChild(0)),
+                instance.probability(),
+                instance.seed()
+            );
+            default -> throw new IllegalArgumentException("Invalid selector: " + selector);
+        };
+    }
+}
diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml
index 40e93bafb8998..40c720584dbf1 100644
--- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml
+++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml
@@ -39,7 +39,7 @@ setup:
- do: {xpack.usage: {}}
- match: { esql.available: true }
- match: { esql.enabled: true }
- - length: { esql.features: 25 }
+ - length: { esql.features: 26 }
- set: {esql.features.dissect: dissect_counter}
- set: {esql.features.drop: drop_counter}
- set: {esql.features.eval: eval_counter}
@@ -65,6 +65,7 @@ setup:
- set: {esql.features.fork: fork_counter}
- set: {esql.features.rrf: rrf_counter}
- set: {esql.features.completion: completion_counter}
+ - set: {esql.features.sample: sample_counter}
- length: { esql.queries: 3 }
- set: {esql.queries.rest.total: rest_total_counter}
- set: {esql.queries.rest.failed: rest_failed_counter}
@@ -108,6 +109,7 @@ setup:
- match: {esql.features.fork: $fork_counter}
- match: {esql.features.rrf: $rrf_counter}
- match: {esql.features.completion: $completion_counter}
+ - match: {esql.features.sample: $sample_counter}
- gt: {esql.queries.rest.total: $rest_total_counter}
- match: {esql.queries.rest.failed: $rest_failed_counter}
- match: {esql.queries.kibana.total: $kibana_total_counter}