From bac208a154ea04853befcbe81ea9e0ce592c0727 Mon Sep 17 00:00:00 2001 From: Salvatore Campagna <93581129+salvatore-campagna@users.noreply.github.com> Date: Mon, 23 Sep 2024 22:16:08 +0200 Subject: [PATCH 01/14] Introduce an `ignore_above` index-level setting (#113121) (#113414) Here we introduce a new index-level setting, `ignore_above`, similar to what we have for `ignore_malformed`. The setting will apply to all `keyword`, `wildcard` and `flattened` fields. Each field mapping will still be allowed to override the index-level setting using a mapping-level `ignore_above` value. (cherry picked from commit 208a1fe5714c0e49549de7aaed7a9a847e7b4a15) --- .../mapping/params/ignore-above.asciidoc | 30 +++ .../search/530_ignore_above_stored_source.yml | 214 ++++++++++++++++++ .../540_ignore_above_synthetic_source.yml | 179 +++++++++++++++ .../test/search/550_ignore_above_invalid.yml | 63 ++++++ .../common/settings/IndexScopedSettings.java | 1 + .../elasticsearch/index/IndexSettings.java | 26 +++ .../index/mapper/KeywordFieldMapper.java | 53 +++-- .../index/mapper/MapperFeatures.java | 2 + .../flattened/FlattenedFieldMapper.java | 49 ++-- .../index/mapper/KeywordFieldTypeTests.java | 1 + .../index/mapper/MultiFieldsTests.java | 1 + .../20_ignore_above_stored_source.yml | 56 +++++ .../30_ignore_above_synthetic_source.yml | 58 +++++ .../wildcard/mapper/WildcardFieldMapper.java | 68 +++--- .../test/CoreTestTranslater.java | 24 +- 15 files changed, 762 insertions(+), 63 deletions(-) create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/530_ignore_above_stored_source.yml create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/540_ignore_above_synthetic_source.yml create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/550_ignore_above_invalid.yml create mode 100644 x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/wildcard/20_ignore_above_stored_source.yml create mode 100644 
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/wildcard/30_ignore_above_synthetic_source.yml diff --git a/docs/reference/mapping/params/ignore-above.asciidoc b/docs/reference/mapping/params/ignore-above.asciidoc index 7d04bc82dcbb..526f2d620596 100644 --- a/docs/reference/mapping/params/ignore-above.asciidoc +++ b/docs/reference/mapping/params/ignore-above.asciidoc @@ -57,3 +57,33 @@ NOTE: The value for `ignore_above` is the _character count_, but Lucene counts bytes. If you use UTF-8 text with many non-ASCII characters, you may want to set the limit to `32766 / 4 = 8191` since UTF-8 characters may occupy at most 4 bytes. + +[[index-mapping-ignore-above]] +=== `index.mapping.ignore_above` + +The `ignore_above` setting, typically used at the field level, can also be applied at the index level using +`index.mapping.ignore_above`. This setting lets you define a maximum string length for all applicable fields across +the index, including `keyword`, `wildcard`, and keyword values in `flattened` fields. Any values that exceed this +limit will be ignored during indexing and won’t be stored. + +This index-wide setting ensures a consistent approach to managing excessively long values. It works the same as the +field-level setting—if a string’s length goes over the specified limit, that string won’t be indexed or stored. +When dealing with arrays, each element is evaluated separately, and only the elements that exceed the limit are ignored. + +[source,console] +-------------------------------------------------- +PUT my-index-000001 +{ + "settings": { + "index.mapping.ignore_above": 256 + } +} +-------------------------------------------------- + +In this example, all applicable fields in `my-index-000001` will ignore any strings longer than 256 characters. + +TIP: You can override this index-wide setting for specific fields by specifying a custom `ignore_above` value in the +field mapping. 
+ +NOTE: Just like the field-level `ignore_above`, this setting only affects indexing and storage. The original values +are still available in the `_source` field if `_source` is enabled, which is the default behavior in Elasticsearch. diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/530_ignore_above_stored_source.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/530_ignore_above_stored_source.yml new file mode 100644 index 000000000000..1730a49f743d --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/530_ignore_above_stored_source.yml @@ -0,0 +1,214 @@ +--- +ignore_above mapping level setting: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + indices.create: + index: test + body: + settings: + index: + mapping: + ignore_above: 10 + mappings: + properties: + keyword: + type: keyword + flattened: + type: flattened + + - do: + index: + index: test + refresh: true + id: "1" + body: { "keyword": "foo bar", "flattened": { "value": "the quick brown fox" } } + + - do: + search: + body: + fields: + - keyword + - flattened + query: + match_all: {} + + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.keyword: "foo bar" } + - match: { hits.hits.0._source.flattened.value: "the quick brown fox" } + - match: { hits.hits.0.fields.keyword.0: "foo bar" } + - match: { hits.hits.0.fields.flattened: null } + +--- +ignore_above mapping level setting on arrays: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + indices.create: + index: test + body: + settings: + index: + mapping: + ignore_above: 10 + mappings: + properties: + keyword: + type: keyword + flattened: + type: flattened + + - do: + index: + index: test + refresh: true + id: "1" + body: { "keyword": ["foo bar", "the quick brown fox"], "flattened": 
{ "value": ["the quick brown fox", "jumps over"] } } + + - do: + search: + body: + fields: + - keyword + - flattened + query: + match_all: {} + + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.keyword: ["foo bar", "the quick brown fox"] } + - match: { hits.hits.0._source.flattened.value: ["the quick brown fox", "jumps over"] } + - match: { hits.hits.0.fields.keyword.0: "foo bar" } + - match: { hits.hits.0.fields.flattened.0.value: "jumps over" } + +--- +ignore_above mapping overrides setting: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + indices.create: + index: test + body: + settings: + index: + mapping: + ignore_above: 10 + mappings: + properties: + keyword: + type: keyword + ignore_above: 100 + flattened: + type: flattened + ignore_above: 100 + + - do: + index: + index: test + refresh: true + id: "1" + body: { "keyword": "foo bar baz foo bar baz", "flattened": { "value": "the quick brown fox" } } + + - do: + search: + body: + fields: + - keyword + - flattened + query: + match_all: { } + + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.keyword: "foo bar baz foo bar baz" } + - match: { hits.hits.0._source.flattened.value: "the quick brown fox" } + - match: { hits.hits.0.fields.keyword.0: "foo bar baz foo bar baz" } + - match: { hits.hits.0.fields.flattened.0.value: "the quick brown fox" } + +--- +ignore_above mapping overrides setting on arrays: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + indices.create: + index: test + body: + settings: + index: + mapping: + ignore_above: 10 + mappings: + properties: + keyword: + type: keyword + ignore_above: 100 + flattened: + type: flattened + ignore_above: 100 + + - do: + index: + index: test + refresh: true + id: "1" + body: { "keyword": ["foo bar baz foo bar baz", "the quick brown fox jumps over"], 
"flattened": { "value": ["the quick brown fox", "jumps over the lazy dog"] } } + + - do: + search: + body: + fields: + - keyword + - flattened + query: + match_all: { } + + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.keyword: ["foo bar baz foo bar baz", "the quick brown fox jumps over"] } + - match: { hits.hits.0._source.flattened.value: ["the quick brown fox", "jumps over the lazy dog"] } + - match: { hits.hits.0.fields.keyword: ["foo bar baz foo bar baz", "the quick brown fox jumps over"] } + - match: { hits.hits.0.fields.flattened.0.value: ["the quick brown fox", "jumps over the lazy dog"] } + +--- +date ignore_above index level setting: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + indices.create: + index: test + body: + settings: + index: + mapping: + ignore_above: 10 + mappings: + properties: + keyword: + type: keyword + date: + type: date + format: "yyyy-MM-dd'T'HH:mm:ss" + + - do: + index: + index: test + refresh: true + id: "1" + body: { "keyword": ["2023-09-17T15:30:00", "2023-09-17T15:31:00"], "date": ["2023-09-17T15:30:00", "2023-09-17T15:31:00"] } + + - do: + search: + body: + fields: + - keyword + - date + query: + match_all: {} + + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.keyword: ["2023-09-17T15:30:00", "2023-09-17T15:31:00"] } + - match: { hits.hits.0._source.date: ["2023-09-17T15:30:00", "2023-09-17T15:31:00"] } + - match: { hits.hits.0.fields.keyword: null } + - match: { hits.hits.0.fields.date: ["2023-09-17T15:30:00","2023-09-17T15:31:00"] } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/540_ignore_above_synthetic_source.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/540_ignore_above_synthetic_source.yml new file mode 100644 index 000000000000..defdc8467bf8 --- /dev/null +++ 
b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/540_ignore_above_synthetic_source.yml @@ -0,0 +1,179 @@ +--- +ignore_above mapping level setting: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + indices.create: + index: test + body: + settings: + index: + mapping: + ignore_above: 10 + mappings: + _source: + mode: synthetic + properties: + keyword: + type: keyword + flattened: + type: flattened + + - do: + index: + index: test + refresh: true + id: "1" + body: { "keyword": "foo bar", "flattened": { "value": "the quick brown fox" } } + + - do: + search: + body: + fields: + - keyword + - flattened + query: + match_all: {} + + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.keyword: "foo bar" } + - match: { hits.hits.0._source.flattened.value: "the quick brown fox" } + - match: { hits.hits.0.fields.keyword.0: "foo bar" } + +--- +ignore_above mapping level setting on arrays: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + indices.create: + index: test + body: + settings: + index: + mapping: + ignore_above: 10 + mappings: + _source: + mode: synthetic + properties: + keyword: + type: keyword + flattened: + type: flattened + + - do: + index: + index: test + refresh: true + id: "1" + body: { "keyword": ["foo bar", "the quick brown fox"], "flattened": { "value": ["the quick brown fox", "jumps over"] } } + + - do: + search: + body: + fields: + - keyword + - flattened + query: + match_all: {} + + - length: { hits.hits: 1 } + #TODO: synthetic source field reconstruction bug (TBD: add link to the issue here) + #- match: { hits.hits.0._source.keyword: ["foo bar", "the quick brown fox"] } + - match: { hits.hits.0._source.flattened.value: ["the quick brown fox", "jumps over"] } + - match: { hits.hits.0.fields.keyword.0: "foo bar" } + - match: { 
hits.hits.0.fields.flattened.0.value: "jumps over" } + +--- +ignore_above mapping overrides setting: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + indices.create: + index: test + body: + settings: + index: + mapping: + ignore_above: 10 + mappings: + _source: + mode: synthetic + properties: + keyword: + type: keyword + ignore_above: 100 + flattened: + type: flattened + ignore_above: 100 + + - do: + index: + index: test + refresh: true + id: "1" + body: { "keyword": "foo bar baz foo bar baz", "flattened": { "value": "the quick brown fox" } } + + - do: + search: + body: + fields: + - keyword + - flattened + query: + match_all: { } + + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.keyword: "foo bar baz foo bar baz" } + - match: { hits.hits.0._source.flattened.value: "the quick brown fox" } + - match: { hits.hits.0.fields.keyword.0: "foo bar baz foo bar baz" } + - match: { hits.hits.0.fields.flattened.0.value: "the quick brown fox" } + +--- +ignore_above mapping overrides setting on arrays: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + indices.create: + index: test + body: + settings: + index: + mapping: + ignore_above: 10 + mappings: + _source: + mode: synthetic + properties: + keyword: + type: keyword + ignore_above: 100 + flattened: + type: flattened + ignore_above: 100 + + - do: + index: + index: test + refresh: true + id: "1" + body: { "keyword": ["foo bar baz foo bar baz", "the quick brown fox jumps over"], "flattened": { "value": ["the quick brown fox", "jumps over the lazy dog"] } } + + - do: + search: + body: + fields: + - keyword + - flattened + query: + match_all: { } + + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.keyword: ["foo bar baz foo bar baz", "the quick brown fox jumps over"] } + - match: { hits.hits.0._source.flattened.value: 
["jumps over the lazy dog", "the quick brown fox"] } + - match: { hits.hits.0.fields.keyword: ["foo bar baz foo bar baz", "the quick brown fox jumps over"] } + - match: { hits.hits.0.fields.flattened.0.value: ["jumps over the lazy dog", "the quick brown fox"] } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/550_ignore_above_invalid.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/550_ignore_above_invalid.yml new file mode 100644 index 000000000000..3c29845871fe --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/550_ignore_above_invalid.yml @@ -0,0 +1,63 @@ +--- +ignore_above index setting negative value: + - do: + catch: bad_request + indices.create: + index: test + body: + settings: + index: + mapping: + ignore_above: -1 + mappings: + properties: + keyword: + type: keyword + +--- +keyword ignore_above mapping setting negative value: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + catch: bad_request + indices.create: + index: test + body: + mappings: + properties: + keyword: + ignore_above: -2 + type: keyword + +--- +flattened ignore_above mapping setting negative value: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + catch: bad_request + indices.create: + index: test + body: + mappings: + properties: + flattened: + ignore_above: -2 + type: flattened + +--- +wildcard ignore_above mapping setting negative value: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + catch: bad_request + indices.create: + index: test + body: + mappings: + properties: + wildcard: + ignore_above: -2 + type: wildcard diff --git a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java 
b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java index 778136cbf5d3..0258fdc77ead 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java @@ -151,6 +151,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings { IndexSettings.INDEX_SEARCH_IDLE_AFTER, IndexSettings.INDEX_SEARCH_THROTTLED, IndexFieldDataService.INDEX_FIELDDATA_CACHE_KEY, + IndexSettings.IGNORE_ABOVE_SETTING, FieldMapper.IGNORE_MALFORMED_SETTING, FieldMapper.COERCE_SETTING, Store.INDEX_STORE_STATS_REFRESH_INTERVAL_SETTING, diff --git a/server/src/main/java/org/elasticsearch/index/IndexSettings.java b/server/src/main/java/org/elasticsearch/index/IndexSettings.java index 41523c6dc2c7..c97ba3953a58 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexSettings.java +++ b/server/src/main/java/org/elasticsearch/index/IndexSettings.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.translog.Translog; @@ -700,6 +701,31 @@ public Iterator> settings() { Property.IndexSettingDeprecatedInV7AndRemovedInV8 ); + /** + * The `index.mapping.ignore_above` setting defines the maximum length for the content of a field that will be indexed + * or stored. If the length of the field’s content exceeds this limit, the field value will be ignored during indexing. + * This setting is useful for `keyword`, `flattened`, and `wildcard` fields where very large values are undesirable. + * It allows users to manage the size of indexed data by skipping fields with excessively long content. 
As an index-level + * setting, it applies to all `keyword` and `wildcard` fields, as well as to keyword values within `flattened` fields. + * When it comes to arrays, the `ignore_above` setting applies individually to each element of the array. If any element's + * length exceeds the specified limit, only that element will be ignored during indexing, while the rest of the array will + * still be processed. This behavior is consistent with the field-level `ignore_above` setting. + * This setting can be overridden at the field level by specifying a custom `ignore_above` value in the field mapping. + *

+     * Example usage:
+     * <pre>
+     * "index.mapping.ignore_above": 256
+     * </pre>
+ */ + public static final Setting IGNORE_ABOVE_SETTING = Setting.intSetting( + "index.mapping.ignore_above", + Integer.MAX_VALUE, + 0, + Property.IndexScope, + Property.ServerlessPublic + ); + public static final NodeFeature IGNORE_ABOVE_INDEX_LEVEL_SETTING = new NodeFeature("mapper.ignore_above_index_level_setting"); + private final Index index; private final IndexVersion version; private final Logger logger; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java index 2da8d3277373..46b1dbdce4c4 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java @@ -79,6 +79,7 @@ import static org.apache.lucene.index.IndexWriter.MAX_TERM_LENGTH; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.index.IndexSettings.IGNORE_ABOVE_SETTING; /** * A field mapper for keywords. This mapper accepts strings and indexes them as-is. 
@@ -110,8 +111,6 @@ public static class Defaults { Lucene.KEYWORD_ANALYZER, Lucene.KEYWORD_ANALYZER ); - - public static final int IGNORE_ABOVE = Integer.MAX_VALUE; } public static class KeywordField extends Field { @@ -158,12 +157,8 @@ public static final class Builder extends FieldMapper.DimensionBuilder { m -> toType(m).fieldType().eagerGlobalOrdinals(), false ); - private final Parameter ignoreAbove = Parameter.intParam( - "ignore_above", - true, - m -> toType(m).fieldType().ignoreAbove(), - Defaults.IGNORE_ABOVE - ); + private final Parameter ignoreAbove; + private final int ignoreAboveDefault; private final Parameter indexOptions = TextParams.keywordIndexOptions(m -> toType(m).indexOptions); private final Parameter hasNorms = TextParams.norms(false, m -> toType(m).fieldType.omitNorms() == false); @@ -193,7 +188,23 @@ public static final class Builder extends FieldMapper.DimensionBuilder { private final ScriptCompiler scriptCompiler; private final IndexVersion indexCreatedVersion; - public Builder(String name, IndexAnalyzers indexAnalyzers, ScriptCompiler scriptCompiler, IndexVersion indexCreatedVersion) { + public Builder(final String name, final MappingParserContext mappingParserContext) { + this( + name, + mappingParserContext.getIndexAnalyzers(), + mappingParserContext.scriptCompiler(), + IGNORE_ABOVE_SETTING.get(mappingParserContext.getSettings()), + mappingParserContext.getIndexSettings().getIndexVersionCreated() + ); + } + + Builder( + String name, + IndexAnalyzers indexAnalyzers, + ScriptCompiler scriptCompiler, + int ignoreAboveDefault, + IndexVersion indexCreatedVersion + ) { super(name); this.indexAnalyzers = indexAnalyzers; this.scriptCompiler = Objects.requireNonNull(scriptCompiler); @@ -220,10 +231,17 @@ public Builder(String name, IndexAnalyzers indexAnalyzers, ScriptCompiler script ); } }).precludesParameters(normalizer); + this.ignoreAboveDefault = ignoreAboveDefault; + this.ignoreAbove = Parameter.intParam("ignore_above", true, m -> 
toType(m).fieldType().ignoreAbove(), ignoreAboveDefault) + .addValidator(v -> { + if (v < 0) { + throw new IllegalArgumentException("[ignore_above] must be positive, got [" + v + "]"); + } + }); } public Builder(String name, IndexVersion indexCreatedVersion) { - this(name, null, ScriptCompiler.NONE, indexCreatedVersion); + this(name, null, ScriptCompiler.NONE, Integer.MAX_VALUE, indexCreatedVersion); } public Builder ignoreAbove(int ignoreAbove) { @@ -370,10 +388,7 @@ public KeywordFieldMapper build(MapperBuilderContext context) { private static final IndexVersion MINIMUM_COMPATIBILITY_VERSION = IndexVersion.fromId(5000099); - public static final TypeParser PARSER = new TypeParser( - (n, c) -> new Builder(n, c.getIndexAnalyzers(), c.scriptCompiler(), c.indexVersionCreated()), - MINIMUM_COMPATIBILITY_VERSION - ); + public static final TypeParser PARSER = new TypeParser(Builder::new, MINIMUM_COMPATIBILITY_VERSION); public static final class KeywordFieldType extends StringFieldType { @@ -865,6 +880,8 @@ public boolean hasNormalizer() { private final boolean isSyntheticSource; private final IndexAnalyzers indexAnalyzers; + private final int ignoreAboveDefault; + private final int ignoreAbove; private KeywordFieldMapper( String simpleName, @@ -887,6 +904,8 @@ private KeywordFieldMapper( this.scriptCompiler = builder.scriptCompiler; this.indexCreatedVersion = builder.indexCreatedVersion; this.isSyntheticSource = isSyntheticSource; + this.ignoreAboveDefault = builder.ignoreAboveDefault; + this.ignoreAbove = builder.ignoreAbove.getValue(); } @Override @@ -1004,7 +1023,9 @@ public Map indexAnalyzers() { @Override public FieldMapper.Builder getMergeBuilder() { - return new Builder(leafName(), indexAnalyzers, scriptCompiler, indexCreatedVersion).dimension(fieldType().isDimension()).init(this); + return new Builder(leafName(), indexAnalyzers, scriptCompiler, ignoreAboveDefault, indexCreatedVersion).dimension( + fieldType().isDimension() + ).init(this); } @Override @@ -1072,7 
+1093,7 @@ protected BytesRef preserve(BytesRef value) { }); } - if (fieldType().ignoreAbove != Defaults.IGNORE_ABOVE) { + if (fieldType().ignoreAbove != ignoreAboveDefault) { layers.add(new CompositeSyntheticFieldLoader.StoredFieldLayer(originalName()) { @Override protected void writeValue(Object value, XContentBuilder b) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index d18c3283ef90..d2ca7a24a78f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -11,6 +11,7 @@ import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.mapper.flattened.FlattenedFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -41,6 +42,7 @@ public Set getFeatures() { SourceFieldMapper.SYNTHETIC_SOURCE_WITH_COPY_TO_AND_DOC_VALUES_FALSE_SUPPORT, SourceFieldMapper.SYNTHETIC_SOURCE_COPY_TO_FIX, FlattenedFieldMapper.IGNORE_ABOVE_SUPPORT, + IndexSettings.IGNORE_ABOVE_INDEX_LEVEL_SETTING, SourceFieldMapper.SYNTHETIC_SOURCE_COPY_TO_INSIDE_OBJECTS_FIX ); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/flattened/FlattenedFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/flattened/FlattenedFieldMapper.java index 867a4a7ec39e..9ea52752ec67 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/flattened/FlattenedFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/flattened/FlattenedFieldMapper.java @@ -82,6 +82,8 @@ import java.util.Set; import java.util.function.Function; +import static org.elasticsearch.index.IndexSettings.IGNORE_ABOVE_SETTING; + /** * A field mapper that accepts a JSON object and flattens it into a single field. 
This data type * can be a useful alternative to an 'object' mapping when the object has a large, unknown set @@ -123,6 +125,9 @@ private static Builder builder(Mapper in) { return ((FlattenedFieldMapper) in).builder; } + private final int ignoreAboveDefault; + private final int ignoreAbove; + public static class Builder extends FieldMapper.Builder { final Parameter depthLimit = Parameter.intParam( @@ -148,12 +153,8 @@ public static class Builder extends FieldMapper.Builder { m -> builder(m).eagerGlobalOrdinals.get(), false ); - private final Parameter ignoreAbove = Parameter.intParam( - "ignore_above", - true, - m -> builder(m).ignoreAbove.get(), - Integer.MAX_VALUE - ); + private final int ignoreAboveDefault; + private final Parameter ignoreAbove; private final Parameter indexOptions = TextParams.keywordIndexOptions(m -> builder(m).indexOptions.get()); private final Parameter similarity = TextParams.similarity(m -> builder(m).similarity.get()); @@ -176,7 +177,7 @@ public static class Builder extends FieldMapper.Builder { + "] are true" ); } - }).precludesParameters(ignoreAbove); + }); private final Parameter> meta = Parameter.metaParam(); @@ -184,8 +185,20 @@ public static FieldMapper.Parameter> dimensionsParam(Function builder(m).ignoreAbove.get(), ignoreAboveDefault) + .addValidator(v -> { + if (v < 0) { + throw new IllegalArgumentException("[ignore_above] must be positive, got [" + v + "]"); + } + }); + this.dimensions.precludesParameters(ignoreAbove); } @Override @@ -223,11 +236,11 @@ public FlattenedFieldMapper build(MapperBuilderContext context) { dimensions.get(), ignoreAbove.getValue() ); - return new FlattenedFieldMapper(leafName(), ft, builderParams(this, context), this); + return new FlattenedFieldMapper(leafName(), ft, builderParams(this, context), ignoreAboveDefault, this); } } - public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n)); + public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n, 
IGNORE_ABOVE_SETTING.get(c.getSettings()))); /** * A field type that represents the values under a particular JSON key, used @@ -808,9 +821,17 @@ public void validateMatchedRoutingPath(final String routingPath) { private final FlattenedFieldParser fieldParser; private final Builder builder; - private FlattenedFieldMapper(String leafName, MappedFieldType mappedFieldType, BuilderParams builderParams, Builder builder) { + private FlattenedFieldMapper( + String leafName, + MappedFieldType mappedFieldType, + BuilderParams builderParams, + int ignoreAboveDefault, + Builder builder + ) { super(leafName, mappedFieldType, builderParams); + this.ignoreAboveDefault = ignoreAboveDefault; this.builder = builder; + this.ignoreAbove = builder.ignoreAbove.get(); this.fieldParser = new FlattenedFieldParser( mappedFieldType.name(), mappedFieldType.name() + KEYED_FIELD_SUFFIX, @@ -835,8 +856,8 @@ int depthLimit() { return builder.depthLimit.get(); } - int ignoreAbove() { - return builder.ignoreAbove.get(); + public int ignoreAbove() { + return ignoreAbove; } @Override @@ -876,7 +897,7 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio @Override public FieldMapper.Builder getMergeBuilder() { - return new Builder(leafName()).init(this); + return new Builder(leafName(), ignoreAboveDefault).init(this); } @Override diff --git a/server/src/test/java/org/elasticsearch/index/mapper/KeywordFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/KeywordFieldTypeTests.java index 7e5cc5045c10..b4c7ea0ed950 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/KeywordFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/KeywordFieldTypeTests.java @@ -243,6 +243,7 @@ public void testFetchSourceValue() throws IOException { "field", createIndexAnalyzers(), ScriptCompiler.NONE, + Integer.MAX_VALUE, IndexVersion.current() ).normalizer("lowercase").build(MapperBuilderContext.root(false, false)).fieldType(); 
assertEquals(List.of("value"), fetchSourceValue(normalizerMapper, "VALUE")); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldsTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldsTests.java index 06c312564830..fd024c5d23e2 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldsTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldsTests.java @@ -63,6 +63,7 @@ private KeywordFieldMapper.Builder getKeywordFieldMapperBuilder(boolean isStored "field", IndexAnalyzers.of(Map.of(), Map.of("normalizer", Lucene.STANDARD_ANALYZER), Map.of()), ScriptCompiler.NONE, + Integer.MAX_VALUE, IndexVersion.current() ); if (isStored) { diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/wildcard/20_ignore_above_stored_source.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/wildcard/20_ignore_above_stored_source.yml new file mode 100644 index 000000000000..252bafbdbe15 --- /dev/null +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/wildcard/20_ignore_above_stored_source.yml @@ -0,0 +1,56 @@ +--- +wildcard field type ignore_above: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + indices.create: + index: test + body: + settings: + index: + mapping: + ignore_above: 10 + mappings: + properties: + a_wildcard: + type: wildcard + b_wildcard: + type: wildcard + ignore_above: 20 + c_wildcard: + type: wildcard + d_wildcard: + type: wildcard + ignore_above: 5 + + + + - do: + index: + index: test + refresh: true + id: "1" + body: { "a_wildcard": "foo bar", "b_wildcard": "the quick brown", "c_wildcard": ["foo", "bar", "jumps over the lazy dog"], "d_wildcard": ["foo", "bar", "the quick"]} + + - do: + search: + body: + fields: + - a_wildcard + - b_wildcard + - c_wildcard + - d_wildcard + query: + match_all: {} + + - length: { hits.hits: 1 } + - match: { 
hits.hits.0._source.a_wildcard: "foo bar" } + - match: { hits.hits.0._source.b_wildcard: "the quick brown" } + - match: { hits.hits.0._source.c_wildcard: ["foo", "bar", "jumps over the lazy dog"] } + - match: { hits.hits.0._source.d_wildcard: ["foo", "bar", "the quick"] } + - match: { hits.hits.0.fields.a_wildcard.0: "foo bar" } + - match: { hits.hits.0.fields.b_wildcard.0: "the quick brown" } + - match: { hits.hits.0.fields.c_wildcard: ["foo", "bar"] } + - match: { hits.hits.0.fields.d_wildcard: ["foo", "bar"] } + diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/wildcard/30_ignore_above_synthetic_source.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/wildcard/30_ignore_above_synthetic_source.yml new file mode 100644 index 000000000000..f5c9f3d92369 --- /dev/null +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/wildcard/30_ignore_above_synthetic_source.yml @@ -0,0 +1,58 @@ +--- +wildcard field type ignore_above: + - requires: + cluster_features: [ "mapper.ignore_above_index_level_setting" ] + reason: introduce ignore_above index level setting + - do: + indices.create: + index: test + body: + settings: + index: + mapping: + ignore_above: 10 + mappings: + _source: + mode: synthetic + properties: + a_wildcard: + type: wildcard + b_wildcard: + type: wildcard + ignore_above: 20 + c_wildcard: + type: wildcard + d_wildcard: + type: wildcard + ignore_above: 5 + + + + - do: + index: + index: test + refresh: true + id: "1" + body: { "a_wildcard": "foo bar", "b_wildcard": "the quick brown", "c_wildcard": ["foo", "bar", "jumps over the lazy dog"], "d_wildcard": ["foo", "bar", "the quick"]} + + - do: + search: + body: + fields: + - a_wildcard + - b_wildcard + - c_wildcard + - d_wildcard + query: + match_all: {} + + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.a_wildcard: "foo bar" } + - match: { hits.hits.0._source.b_wildcard: "the quick brown" } + - match: { hits.hits.0._source.c_wildcard: ["bar", "foo"] 
} + - match: { hits.hits.0._source.d_wildcard: ["bar", "foo", "the quick"] } + - match: { hits.hits.0.fields.a_wildcard.0: "foo bar" } + - match: { hits.hits.0.fields.b_wildcard.0: "the quick brown" } + - match: { hits.hits.0.fields.c_wildcard: ["bar", "foo"] } + - match: { hits.hits.0.fields.d_wildcard: ["bar", "foo"] } + diff --git a/x-pack/plugin/wildcard/src/main/java/org/elasticsearch/xpack/wildcard/mapper/WildcardFieldMapper.java b/x-pack/plugin/wildcard/src/main/java/org/elasticsearch/xpack/wildcard/mapper/WildcardFieldMapper.java index 8e4f56e29958..1e97e6437158 100644 --- a/x-pack/plugin/wildcard/src/main/java/org/elasticsearch/xpack/wildcard/mapper/WildcardFieldMapper.java +++ b/x-pack/plugin/wildcard/src/main/java/org/elasticsearch/xpack/wildcard/mapper/WildcardFieldMapper.java @@ -87,6 +87,8 @@ import java.util.Map; import java.util.Set; +import static org.elasticsearch.index.IndexSettings.IGNORE_ABOVE_SETTING; + /** * A {@link FieldMapper} for indexing fields with ngrams for efficient wildcard matching */ @@ -191,7 +193,6 @@ public static class Defaults { Lucene.KEYWORD_ANALYZER, Lucene.KEYWORD_ANALYZER ); - public static final int IGNORE_ABOVE = Integer.MAX_VALUE; } private static WildcardFieldMapper toType(FieldMapper in) { @@ -200,21 +201,28 @@ private static WildcardFieldMapper toType(FieldMapper in) { public static class Builder extends FieldMapper.Builder { - final Parameter ignoreAbove = Parameter.intParam("ignore_above", true, m -> toType(m).ignoreAbove, Defaults.IGNORE_ABOVE) - .addValidator(v -> { - if (v < 0) { - throw new IllegalArgumentException("[ignore_above] must be positive, got [" + v + "]"); - } - }); + final Parameter ignoreAbove; final Parameter nullValue = Parameter.stringParam("null_value", false, m -> toType(m).nullValue, null).acceptsNull(); final Parameter> meta = Parameter.metaParam(); final IndexVersion indexVersionCreated; - public Builder(String name, IndexVersion indexVersionCreated) { + final int ignoreAboveDefault; + + 
public Builder(final String name, IndexVersion indexVersionCreated) { + this(name, Integer.MAX_VALUE, indexVersionCreated); + } + + private Builder(String name, int ignoreAboveDefault, IndexVersion indexVersionCreated) { super(name); this.indexVersionCreated = indexVersionCreated; + this.ignoreAboveDefault = ignoreAboveDefault; + this.ignoreAbove = Parameter.intParam("ignore_above", true, m -> toType(m).ignoreAbove, ignoreAboveDefault).addValidator(v -> { + if (v < 0) { + throw new IllegalArgumentException("[ignore_above] must be positive, got [" + v + "]"); + } + }); } @Override @@ -236,23 +244,18 @@ Builder nullValue(String nullValue) { public WildcardFieldMapper build(MapperBuilderContext context) { return new WildcardFieldMapper( leafName(), - new WildcardFieldType( - context.buildFullName(leafName()), - nullValue.get(), - ignoreAbove.get(), - indexVersionCreated, - meta.get() - ), - ignoreAbove.get(), + new WildcardFieldType(context.buildFullName(leafName()), indexVersionCreated, meta.get(), this), context.isSourceSynthetic(), builderParams(this, context), - nullValue.get(), - indexVersionCreated + indexVersionCreated, + this ); } } - public static TypeParser PARSER = new TypeParser((n, c) -> new Builder(n, c.indexVersionCreated())); + public static TypeParser PARSER = new TypeParser( + (n, c) -> new Builder(n, IGNORE_ABOVE_SETTING.get(c.getSettings()), c.indexVersionCreated()) + ); public static final char TOKEN_START_OR_END_CHAR = 0; public static final String TOKEN_START_STRING = Character.toString(TOKEN_START_OR_END_CHAR); @@ -263,18 +266,18 @@ public static final class WildcardFieldType extends MappedFieldType { static Analyzer lowercaseNormalizer = new LowercaseNormalizer(); private final String nullValue; - private final int ignoreAbove; private final NamedAnalyzer analyzer; + private final int ignoreAbove; - private WildcardFieldType(String name, String nullValue, int ignoreAbove, IndexVersion version, Map meta) { + private WildcardFieldType(String 
name, IndexVersion version, Map meta, Builder builder) { super(name, true, false, true, Defaults.TEXT_SEARCH_INFO, meta); if (version.onOrAfter(IndexVersions.V_7_10_0)) { this.analyzer = WILDCARD_ANALYZER_7_10; } else { this.analyzer = WILDCARD_ANALYZER_7_9; } - this.nullValue = nullValue; - this.ignoreAbove = ignoreAbove; + this.nullValue = builder.nullValue.getValue(); + this.ignoreAbove = builder.ignoreAbove.getValue(); } @Override @@ -889,26 +892,27 @@ protected String parseSourceValue(Object value) { NGRAM_FIELD_TYPE = freezeAndDeduplicateFieldType(ft); assert NGRAM_FIELD_TYPE.indexOptions() == IndexOptions.DOCS; } - - private final int ignoreAbove; private final String nullValue; private final IndexVersion indexVersionCreated; + + private final int ignoreAbove; + private final int ignoreAboveDefault; private final boolean storeIgnored; private WildcardFieldMapper( String simpleName, WildcardFieldType mappedFieldType, - int ignoreAbove, boolean storeIgnored, BuilderParams builderParams, - String nullValue, - IndexVersion indexVersionCreated + IndexVersion indexVersionCreated, + Builder builder ) { super(simpleName, mappedFieldType, builderParams); - this.nullValue = nullValue; - this.ignoreAbove = ignoreAbove; + this.nullValue = builder.nullValue.getValue(); this.storeIgnored = storeIgnored; this.indexVersionCreated = indexVersionCreated; + this.ignoreAbove = builder.ignoreAbove.getValue(); + this.ignoreAboveDefault = builder.ignoreAboveDefault; } @Override @@ -983,14 +987,14 @@ protected String contentType() { @Override public FieldMapper.Builder getMergeBuilder() { - return new Builder(leafName(), indexVersionCreated).init(this); + return new Builder(leafName(), ignoreAboveDefault, indexVersionCreated).init(this); } @Override protected SyntheticSourceSupport syntheticSourceSupport() { var layers = new ArrayList(); layers.add(new WildcardSyntheticFieldLoader()); - if (ignoreAbove != Defaults.IGNORE_ABOVE) { + if (ignoreAbove != ignoreAboveDefault) { 
layers.add(new CompositeSyntheticFieldLoader.StoredFieldLayer(originalName()) { @Override protected void writeValue(Object value, XContentBuilder b) throws IOException { diff --git a/x-pack/qa/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/test/CoreTestTranslater.java b/x-pack/qa/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/test/CoreTestTranslater.java index 2bea4bb247d8..d34303ea803d 100644 --- a/x-pack/qa/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/test/CoreTestTranslater.java +++ b/x-pack/qa/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/test/CoreTestTranslater.java @@ -222,10 +222,32 @@ public boolean modifySections(List executables) { */ protected abstract boolean modifySearch(ApiCallSection search); + private static Object getSetting(final Object map, final String... keys) { + Map current = (Map) map; + for (final String key : keys) { + if (current != null) { + current = (Map) current.get(key); + } else { + return null; + } + } + return current; + } + private boolean modifyCreateIndex(ApiCallSection createIndex) { String index = createIndex.getParams().get("index"); for (Map body : createIndex.getBodies()) { - Object settings = body.get("settings"); + final Object settings = body.get("settings"); + final Object indexMapping = getSetting(settings, "index", "mapping"); + if (indexMapping instanceof Map m) { + final Object ignoreAbove = m.get("ignore_above"); + if (ignoreAbove instanceof Integer ignoreAboveValue) { + if (ignoreAboveValue >= 0) { + // Scripts don't support ignore_above so we skip those fields + continue; + } + } + } if (settings instanceof Map && ((Map) settings).containsKey("sort.field")) { /* * You can't sort the index on a runtime field From 62d3d538a4d82a5118d8889114077ffbf8a8672e Mon Sep 17 00:00:00 2001 From: Stanislav Malyshev Date: Mon, 23 Sep 2024 14:52:04 -0600 Subject: [PATCH 02/14] Test fix: ensure we don't accidentally generate two identical 
histograms (#113322) (#113415) * Test fix: looks like using one value is not random enough --- .../admin/cluster/stats/CCSTelemetrySnapshotTests.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/CCSTelemetrySnapshotTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/CCSTelemetrySnapshotTests.java index 0bca6e57dc47..e9188d9cb8f0 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/CCSTelemetrySnapshotTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/CCSTelemetrySnapshotTests.java @@ -33,7 +33,7 @@ public class CCSTelemetrySnapshotTests extends AbstractWireSerializingTestCase Date: Mon, 23 Sep 2024 13:53:37 -0700 Subject: [PATCH 03/14] Unmute logsdb data generation tests (#113306) (#113321) (cherry picked from commit 413b23a9ea16206e8cb97bc99f5ab6ac578229c7) # Conflicts: # muted-tests.yml Co-authored-by: Elastic Machine --- muted-tests.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index 44cb6f631ddf..20863a6f6349 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -208,9 +208,6 @@ tests: - class: org.elasticsearch.packaging.test.WindowsServiceTests method: test33JavaChanged issue: https://github.com/elastic/elasticsearch/issues/113177 -- class: org.elasticsearch.datastreams.logsdb.qa.StandardVersusLogsIndexModeRandomDataChallengeRestIT - method: testMatchAllQuery - issue: https://github.com/elastic/elasticsearch/issues/113265 # Examples: # From 54ddc29fc7d086574b17901032b63516711b499e Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Mon, 23 Sep 2024 16:18:04 -0600 Subject: [PATCH 04/14] Default incremental bulk functionality to false (#113416) (#113417) This commit flips the incremental bulk setting to false. Additionally, it removes some test code which intermittently causes issues with security test cases. 
--- .../http/IncrementalBulkRestIT.java | 8 +++ .../action/bulk/IncrementalBulkService.java | 2 +- .../elasticsearch/test/ESIntegTestCase.java | 49 ++----------------- 3 files changed, 13 insertions(+), 46 deletions(-) diff --git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IncrementalBulkRestIT.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IncrementalBulkRestIT.java index 2b24e53874e5..da0501169627 100644 --- a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IncrementalBulkRestIT.java +++ b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IncrementalBulkRestIT.java @@ -29,6 +29,14 @@ @ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, supportsDedicatedMasters = false, numDataNodes = 2, numClientNodes = 0) public class IncrementalBulkRestIT extends HttpSmokeTestCase { + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal, otherSettings)) + .put(IncrementalBulkService.INCREMENTAL_BULK.getKey(), true) + .build(); + } + public void testBulkUriMatchingDoesNotMatchBulkCapabilitiesApi() throws IOException { Request request = new Request("GET", "/_capabilities?method=GET&path=%2F_bulk&capabilities=failure_store_status&pretty"); Response response = getRestClient().performRequest(request); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java b/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java index 7185c4d76265..fc264de35f51 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java @@ -36,7 +36,7 @@ public class IncrementalBulkService { public static final Setting INCREMENTAL_BULK = boolSetting( "rest.incremental_bulk", - true, + false, Setting.Property.NodeScope, Setting.Property.Dynamic ); diff 
--git a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java index 684236b9af66..ab0a0bf626d5 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java @@ -26,7 +26,6 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; -import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.DocWriteResponse; import org.elasticsearch.action.admin.cluster.allocation.ClusterAllocationExplainRequest; import org.elasticsearch.action.admin.cluster.allocation.ClusterAllocationExplainResponse; @@ -49,8 +48,6 @@ import org.elasticsearch.action.admin.indices.template.put.PutIndexTemplateRequestBuilder; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; -import org.elasticsearch.action.bulk.IncrementalBulkService; -import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.ClearScrollResponse; import org.elasticsearch.action.search.SearchRequest; @@ -188,7 +185,6 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.Callable; -import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -1774,48 +1770,11 @@ public void indexRandom(boolean forceRefresh, boolean dummyDocuments, boolean ma logger.info("Index [{}] docs async: [{}] bulk: [{}] partitions [{}]", builders.size(), false, true, partition.size()); for (List segmented : partition) { BulkResponse actionGet; - if (randomBoolean()) { - BulkRequestBuilder bulkBuilder = client().prepareBulk(); - for (IndexRequestBuilder indexRequestBuilder : 
segmented) { - bulkBuilder.add(indexRequestBuilder); - } - actionGet = bulkBuilder.get(); - } else { - IncrementalBulkService bulkService = internalCluster().getInstance(IncrementalBulkService.class); - IncrementalBulkService.Handler handler = bulkService.newBulkRequest(); - - ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); - segmented.forEach(b -> queue.add(b.request())); - - PlainActionFuture future = new PlainActionFuture<>(); - AtomicInteger runs = new AtomicInteger(0); - Runnable r = new Runnable() { - - @Override - public void run() { - int toRemove = Math.min(randomIntBetween(5, 10), queue.size()); - ArrayList> docs = new ArrayList<>(); - for (int i = 0; i < toRemove; i++) { - docs.add(queue.poll()); - } - - if (queue.isEmpty()) { - handler.lastItems(docs, () -> {}, future); - } else { - handler.addItems(docs, () -> {}, () -> { - // Every 10 runs dispatch to new thread to prevent stackoverflow - if (runs.incrementAndGet() % 10 == 0) { - new Thread(this).start(); - } else { - this.run(); - } - }); - } - } - }; - r.run(); - actionGet = future.actionGet(); + BulkRequestBuilder bulkBuilder = client().prepareBulk(); + for (IndexRequestBuilder indexRequestBuilder : segmented) { + bulkBuilder.add(indexRequestBuilder); } + actionGet = bulkBuilder.get(); assertThat(actionGet.hasFailures() ? actionGet.buildFailureMessage() : "", actionGet.hasFailures(), equalTo(false)); } } From d9188591a5c3fc4e3c408c567f6eb66c6823f970 Mon Sep 17 00:00:00 2001 From: Bogdan Pintea Date: Tue, 24 Sep 2024 00:53:49 +0200 Subject: [PATCH 05/14] ESQL: add tests checking on data availabiltiy (#113292) (#113422) This adds simple tests that check the shape of the available data to query as a first step in troubleshooting some non-reproducible failures. 
--- .../test/esql/26_aggs_bucket.yml | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/26_aggs_bucket.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/26_aggs_bucket.yml index 7d0989a6e188..ea7684fb69a0 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/26_aggs_bucket.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/26_aggs_bucket.yml @@ -30,6 +30,20 @@ - { "index": { "_index": "test_bucket" } } - { "ts": "2024-07-16T11:40:00Z" } + - do: + allowed_warnings_regex: + - "No limit defined, adding default limit of \\[.*\\]" + esql.query: + body: + query: 'FROM test_bucket | SORT ts' + - match: { columns.0.name: ts } + - match: { columns.0.type: date } + - length: { values: 4 } + - match: { values.0.0: "2024-07-16T08:10:00.000Z" } + - match: { values.1.0: "2024-07-16T09:20:00.000Z" } + - match: { values.2.0: "2024-07-16T10:30:00.000Z" } + - match: { values.3.0: "2024-07-16T11:40:00.000Z" } + - do: allowed_warnings_regex: - "No limit defined, adding default limit of \\[.*\\]" @@ -119,6 +133,40 @@ - { "index": { "_index": "test_bucket" } } - { "ts": "2024-09-16" } + - do: + allowed_warnings_regex: + - "No limit defined, adding default limit of \\[.*\\]" + esql.query: + body: + query: 'FROM test_bucket | STATS c = COUNT(*)' + - match: { columns.0.name: c } + - match: { columns.0.type: long } + - match: { values.0.0: 4 } + + - do: + allowed_warnings_regex: + - "No limit defined, adding default limit of \\[.*\\]" + esql.query: + body: + query: 'FROM test_bucket | SORT ts' + - match: { columns.0.name: ts } + - match: { columns.0.type: date } + - length: { values: 4 } + - match: { values.0.0: "2024-06-16T00:00:00.000Z" } + - match: { values.1.0: "2024-07-16T00:00:00.000Z" } + - match: { values.2.0: "2024-08-16T00:00:00.000Z" } + - match: { values.3.0: "2024-09-16T00:00:00.000Z" } + + - do: + allowed_warnings_regex: + - 
"No limit defined, adding default limit of \\[.*\\]" + esql.query: + body: + query: 'FROM test_bucket | STATS c = COUNT(*)' + - match: { columns.0.name: c } + - match: { columns.0.type: long } + - match: { values.0.0: 4 } + - do: allowed_warnings_regex: - "No limit defined, adding default limit of \\[.*\\]" From f7190599c239e885387de07b1b78aebda7ad938f Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Tue, 24 Sep 2024 08:08:52 +0200 Subject: [PATCH 06/14] Small performance improvement in h3 library (#113385) (#113429) Changing some FDIV's into FMUL's leads to performance improvements --- docs/changelog/113385.yaml | 5 +++ .../java/org/elasticsearch/h3/Constants.java | 7 +++- .../java/org/elasticsearch/h3/CoordIJK.java | 11 +++-- .../java/org/elasticsearch/h3/FastMath.java | 41 +++++++++++-------- .../main/java/org/elasticsearch/h3/Vec2d.java | 23 ++++++----- .../main/java/org/elasticsearch/h3/Vec3d.java | 2 +- 6 files changed, 57 insertions(+), 32 deletions(-) create mode 100644 docs/changelog/113385.yaml diff --git a/docs/changelog/113385.yaml b/docs/changelog/113385.yaml new file mode 100644 index 000000000000..9cee1ebcd4f6 --- /dev/null +++ b/docs/changelog/113385.yaml @@ -0,0 +1,5 @@ +pr: 113385 +summary: Small performance improvement in h3 library +area: Geo +type: enhancement +issues: [] diff --git a/libs/h3/src/main/java/org/elasticsearch/h3/Constants.java b/libs/h3/src/main/java/org/elasticsearch/h3/Constants.java index 5192fe836e73..570052700615 100644 --- a/libs/h3/src/main/java/org/elasticsearch/h3/Constants.java +++ b/libs/h3/src/main/java/org/elasticsearch/h3/Constants.java @@ -33,7 +33,7 @@ final class Constants { /** * 2.0 * PI */ - public static final double M_2PI = 6.28318530717958647692528676655900576839433; + public static final double M_2PI = 2.0 * Math.PI; /** * max H3 resolution; H3 version 1 has 16 resolutions, numbered 0 through 15 */ @@ -58,6 +58,11 @@ final class Constants { * square root of 7 */ public static final double M_SQRT7 = 
2.6457513110645905905016157536392604257102; + + /** + * 1 / square root of 7 + */ + public static final double M_RSQRT7 = 1.0 / M_SQRT7; /** * scaling factor from hex2d resolution 0 unit length * (or distance between adjacent cell center points diff --git a/libs/h3/src/main/java/org/elasticsearch/h3/CoordIJK.java b/libs/h3/src/main/java/org/elasticsearch/h3/CoordIJK.java index e57f681fc2ea..8aae7583ef04 100644 --- a/libs/h3/src/main/java/org/elasticsearch/h3/CoordIJK.java +++ b/libs/h3/src/main/java/org/elasticsearch/h3/CoordIJK.java @@ -39,6 +39,9 @@ */ final class CoordIJK { + /** one seventh (1/7) **/ + private static final double M_ONESEVENTH = 1.0 / 7.0; + /** CoordIJK unit vectors corresponding to the 7 H3 digits. */ private static final int[][] UNIT_VECS = { @@ -281,8 +284,8 @@ public void neighbor(int digit) { public void upAp7r() { final int i = Math.subtractExact(this.i, this.k); final int j = Math.subtractExact(this.j, this.k); - this.i = (int) Math.round((Math.addExact(Math.multiplyExact(2, i), j)) / 7.0); - this.j = (int) Math.round((Math.subtractExact(Math.multiplyExact(3, j), i)) / 7.0); + this.i = (int) Math.round((Math.addExact(Math.multiplyExact(2, i), j)) * M_ONESEVENTH); + this.j = (int) Math.round((Math.subtractExact(Math.multiplyExact(3, j), i)) * M_ONESEVENTH); this.k = 0; ijkNormalize(); } @@ -295,8 +298,8 @@ public void upAp7r() { public void upAp7() { final int i = Math.subtractExact(this.i, this.k); final int j = Math.subtractExact(this.j, this.k); - this.i = (int) Math.round((Math.subtractExact(Math.multiplyExact(3, i), j)) / 7.0); - this.j = (int) Math.round((Math.addExact(Math.multiplyExact(2, j), i)) / 7.0); + this.i = (int) Math.round((Math.subtractExact(Math.multiplyExact(3, i), j)) * M_ONESEVENTH); + this.j = (int) Math.round((Math.addExact(Math.multiplyExact(2, j), i)) * M_ONESEVENTH); this.k = 0; ijkNormalize(); } diff --git a/libs/h3/src/main/java/org/elasticsearch/h3/FastMath.java 
b/libs/h3/src/main/java/org/elasticsearch/h3/FastMath.java index 61d767901ae0..760fa7553548 100644 --- a/libs/h3/src/main/java/org/elasticsearch/h3/FastMath.java +++ b/libs/h3/src/main/java/org/elasticsearch/h3/FastMath.java @@ -102,6 +102,15 @@ final class FastMath { private static final int MIN_DOUBLE_EXPONENT = -1074; private static final int MAX_DOUBLE_EXPONENT = 1023; + /** + * PI / 2.0 + */ + private static final double M_HALF_PI = Math.PI * 0.5; + /** + * PI / 4.0 + */ + private static final double M_QUARTER_PI = Math.PI * 0.25; + // -------------------------------------------------------------------------- // CONSTANTS FOR NORMALIZATIONS // -------------------------------------------------------------------------- @@ -335,7 +344,7 @@ public static double cos(double angle) { // Faster than using normalizeZeroTwoPi. angle = remainderTwoPi(angle); if (angle < 0.0) { - angle += 2 * Math.PI; + angle += Constants.M_2PI; } } // index: possibly outside tables range. @@ -366,7 +375,7 @@ public static double sin(double angle) { // Faster than using normalizeZeroTwoPi. angle = remainderTwoPi(angle); if (angle < 0.0) { - angle += 2 * Math.PI; + angle += Constants.M_2PI; } } int index = (int) (angle * SIN_COS_INDEXER + 0.5); @@ -387,9 +396,9 @@ public static double tan(double angle) { if (Math.abs(angle) > TAN_MAX_VALUE_FOR_INT_MODULO) { // Faster than using normalizeMinusHalfPiHalfPi. angle = remainderTwoPi(angle); - if (angle < -Math.PI / 2) { + if (angle < -M_HALF_PI) { angle += Math.PI; - } else if (angle > Math.PI / 2) { + } else if (angle > M_HALF_PI) { angle -= Math.PI; } } @@ -428,7 +437,7 @@ public static double tan(double angle) { * @return Value arccosine, in radians, in [0,PI]. */ public static double acos(double value) { - return Math.PI / 2 - FastMath.asin(value); + return M_HALF_PI - FastMath.asin(value); } /** @@ -468,7 +477,7 @@ public static double asin(double value) { return negateResult ? 
-result : result; } else { // value >= 1.0, or value is NaN if (value == 1.0) { - return negateResult ? -Math.PI / 2 : Math.PI / 2; + return negateResult ? -M_HALF_PI : M_HALF_PI; } else { return Double.NaN; } @@ -490,7 +499,7 @@ public static double atan(double value) { } if (value == 1.0) { // We want "exact" result for 1.0. - return negateResult ? -Math.PI / 4 : Math.PI / 4; + return negateResult ? -M_QUARTER_PI : M_QUARTER_PI; } else if (value <= ATAN_MAX_VALUE_FOR_TABS) { int index = (int) (value * ATAN_INDEXER + 0.5); double delta = value - index * ATAN_DELTA; @@ -511,7 +520,7 @@ public static double atan(double value) { if (Double.isNaN(value)) { return Double.NaN; } else { - return negateResult ? -Math.PI / 2 : Math.PI / 2; + return negateResult ? -M_HALF_PI : M_HALF_PI; } } } @@ -532,9 +541,9 @@ public static double atan2(double y, double x) { } if (x == Double.POSITIVE_INFINITY) { if (y == Double.POSITIVE_INFINITY) { - return Math.PI / 4; + return M_QUARTER_PI; } else if (y == Double.NEGATIVE_INFINITY) { - return -Math.PI / 4; + return -M_QUARTER_PI; } else if (y > 0.0) { return 0.0; } else if (y < 0.0) { @@ -551,9 +560,9 @@ public static double atan2(double y, double x) { } if (x == Double.NEGATIVE_INFINITY) { if (y == Double.POSITIVE_INFINITY) { - return 3 * Math.PI / 4; + return 3 * M_QUARTER_PI; } else if (y == Double.NEGATIVE_INFINITY) { - return -3 * Math.PI / 4; + return -3 * M_QUARTER_PI; } else if (y > 0.0) { return Math.PI; } else if (y < 0.0) { @@ -562,9 +571,9 @@ public static double atan2(double y, double x) { return Double.NaN; } } else if (y > 0.0) { - return Math.PI / 2 + FastMath.atan(-x / y); + return M_HALF_PI + FastMath.atan(-x / y); } else if (y < 0.0) { - return -Math.PI / 2 - FastMath.atan(x / y); + return -M_HALF_PI - FastMath.atan(x / y); } else { return Double.NaN; } @@ -577,9 +586,9 @@ public static double atan2(double y, double x) { } } if (y > 0.0) { - return Math.PI / 2; + return M_HALF_PI; } else if (y < 0.0) { - return 
-Math.PI / 2; + return -M_HALF_PI; } else { return Double.NaN; } diff --git a/libs/h3/src/main/java/org/elasticsearch/h3/Vec2d.java b/libs/h3/src/main/java/org/elasticsearch/h3/Vec2d.java index 12ce728a9996..b0c2627a5f39 100644 --- a/libs/h3/src/main/java/org/elasticsearch/h3/Vec2d.java +++ b/libs/h3/src/main/java/org/elasticsearch/h3/Vec2d.java @@ -29,8 +29,11 @@ */ final class Vec2d { - /** sin(60') */ - private static final double M_SIN60 = Constants.M_SQRT3_2; + /** 1/sin(60') **/ + private static final double M_RSIN60 = 1.0 / Constants.M_SQRT3_2; + + /** one third **/ + private static final double M_ONETHIRD = 1.0 / 3.0; private static final double VEC2D_RESOLUTION = 1e-7; @@ -133,14 +136,14 @@ static LatLng hex2dToGeo(double x, double y, int face, int res, boolean substrat // scale for current resolution length u for (int i = 0; i < res; i++) { - r /= Constants.M_SQRT7; + r *= Constants.M_RSQRT7; } // scale accordingly if this is a substrate grid if (substrate) { r /= 3.0; if (H3Index.isResolutionClassIII(res)) { - r /= Constants.M_SQRT7; + r *= Constants.M_RSQRT7; } } @@ -181,8 +184,8 @@ static CoordIJK hex2dToCoordIJK(double x, double y) { a2 = Math.abs(y); // first do a reverse conversion - x2 = a2 / M_SIN60; - x1 = a1 + x2 / 2.0; + x2 = a2 * M_RSIN60; + x1 = a1 + x2 * 0.5; // check if we have the center of a hex m1 = (int) x1; @@ -193,8 +196,8 @@ static CoordIJK hex2dToCoordIJK(double x, double y) { r2 = x2 - m2; if (r1 < 0.5) { - if (r1 < 1.0 / 3.0) { - if (r2 < (1.0 + r1) / 2.0) { + if (r1 < M_ONETHIRD) { + if (r2 < (1.0 + r1) * 0.5) { i = m1; j = m2; } else { @@ -215,7 +218,7 @@ static CoordIJK hex2dToCoordIJK(double x, double y) { } } } else { - if (r1 < 2.0 / 3.0) { + if (r1 < 2.0 * M_ONETHIRD) { if (r2 < (1.0 - r1)) { j = m2; } else { @@ -228,7 +231,7 @@ static CoordIJK hex2dToCoordIJK(double x, double y) { i = Math.incrementExact(m1); } } else { - if (r2 < (r1 / 2.0)) { + if (r2 < (r1 * 0.5)) { i = Math.incrementExact(m1); j = m2; } else { diff 
--git a/libs/h3/src/main/java/org/elasticsearch/h3/Vec3d.java b/libs/h3/src/main/java/org/elasticsearch/h3/Vec3d.java index c5c4f8975597..5973af4b51f6 100644 --- a/libs/h3/src/main/java/org/elasticsearch/h3/Vec3d.java +++ b/libs/h3/src/main/java/org/elasticsearch/h3/Vec3d.java @@ -96,7 +96,7 @@ static long geoToH3(int res, double lat, double lon) { } } // cos(r) = 1 - 2 * sin^2(r/2) = 1 - 2 * (sqd / 4) = 1 - sqd/2 - double r = FastMath.acos(1 - sqd / 2); + double r = FastMath.acos(1 - sqd * 0.5); if (r < Constants.EPSILON) { return FaceIJK.faceIjkToH3(res, face, new CoordIJK(0, 0, 0)); From d086e149fdffd0b9a154498a6530225bba58908d Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 24 Sep 2024 09:07:52 +0100 Subject: [PATCH 07/14] Add extra context to `TransportNodesAction` invocations (#113140) (#113201) Several `TransportNodesAction` implementations do some kind of top-level computation in addition to fanning out requests to individual nodes. Today they all have to do this once the node-level fanout is complete, but in most cases the top-level computation can happen in parallel with the fanout. This commit adds support for an additional `ActionContext` object, created when starting to process the request and exposed to `newResponseAsync()` at the end, to allow this parallelization. All implementations use `(Void) null` for this param, except for `TransportClusterStatsAction` which now parallelizes the computation of the cluster-state-based stats with the node-level fanout. 
--- .../stats/GeoIpStatsTransportAction.java | 2 +- .../TransportNodesCapabilitiesAction.java | 3 +- .../TransportNodesFeaturesAction.java | 3 +- .../TransportNodesHotThreadsAction.java | 3 +- .../node/info/TransportNodesInfoAction.java | 3 +- ...nsportNodesReloadSecureSettingsAction.java | 3 +- .../TransportPrevalidateShardPathAction.java | 3 +- .../node/stats/TransportNodesStatsAction.java | 4 +- .../node/usage/TransportNodesUsageAction.java | 3 +- .../status/TransportNodesSnapshotsStatus.java | 3 +- .../stats/TransportClusterStatsAction.java | 136 +++++++++++++----- .../TransportFindDanglingIndexAction.java | 3 +- .../TransportListDanglingIndicesAction.java | 3 +- .../support/nodes/TransportNodesAction.java | 18 ++- ...ransportNodesListGatewayStartedShards.java | 3 +- .../stats/HealthApiStatsTransportAction.java | 3 +- .../TransportNodesListShardStoreMetadata.java | 3 +- .../node/tasks/TaskManagerTestCase.java | 2 +- .../cluster/node/tasks/TestTaskPlugin.java | 2 +- .../nodes/TransportNodesActionTests.java | 68 ++++++++- .../action/TransportAnalyticsStatsAction.java | 3 +- .../NodesDataTiersUsageTransportAction.java | 3 +- .../TransportNodeDeprecationCheckAction.java | 3 +- .../TransportDeprecationCacheResetAction.java | 3 +- .../action/EnrichCoordinatorStatsAction.java | 2 +- .../eql/plugin/TransportEqlStatsAction.java | 3 +- .../esql/plugin/TransportEsqlStatsAction.java | 3 +- ...ransportGetInferenceDiagnosticsAction.java | 3 +- .../TransportTrainedModelCacheInfoAction.java | 3 +- ...rtClearRepositoriesStatsArchiveAction.java | 3 +- .../TransportRepositoriesStatsAction.java | 3 +- ...rtSearchableSnapshotCacheStoresAction.java | 3 +- ...rchableSnapshotsNodeCachesStatsAction.java | 3 +- .../TransportClearSecurityCacheAction.java | 3 +- .../TransportClearPrivilegesCacheAction.java | 3 +- .../realm/TransportClearRealmCacheAction.java | 3 +- .../role/TransportClearRolesCacheAction.java | 3 +- ...tServiceAccountNodesCredentialsAction.java | 3 +- 
.../action/SpatialStatsTransportAction.java | 3 +- .../sql/plugin/TransportSqlStatsAction.java | 3 +- .../TransportGetTransformNodeStatsAction.java | 3 +- .../actions/TransportWatcherStatsAction.java | 3 +- 42 files changed, 255 insertions(+), 81 deletions(-) diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/stats/GeoIpStatsTransportAction.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/stats/GeoIpStatsTransportAction.java index c1e9b04dda90..9ebf97ca4e9e 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/stats/GeoIpStatsTransportAction.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/stats/GeoIpStatsTransportAction.java @@ -30,7 +30,7 @@ import java.io.IOException; import java.util.List; -public class GeoIpStatsTransportAction extends TransportNodesAction { +public class GeoIpStatsTransportAction extends TransportNodesAction { private final DatabaseNodeService registry; private final GeoIpDownloaderTaskExecutor geoIpDownloaderTaskExecutor; diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/capabilities/TransportNodesCapabilitiesAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/capabilities/TransportNodesCapabilitiesAction.java index 1f772be2ed1e..8df34d882941 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/capabilities/TransportNodesCapabilitiesAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/capabilities/TransportNodesCapabilitiesAction.java @@ -38,7 +38,8 @@ public class TransportNodesCapabilitiesAction extends TransportNodesAction< NodesCapabilitiesRequest, NodesCapabilitiesResponse, TransportNodesCapabilitiesAction.NodeCapabilitiesRequest, - NodeCapability> { + NodeCapability, + Void> { public static final ActionType TYPE = new ActionType<>("cluster:monitor/nodes/capabilities"); diff --git 
a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/features/TransportNodesFeaturesAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/features/TransportNodesFeaturesAction.java index c0cf86288fd3..e5e04c8490c8 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/features/TransportNodesFeaturesAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/features/TransportNodesFeaturesAction.java @@ -33,7 +33,8 @@ public class TransportNodesFeaturesAction extends TransportNodesAction< NodesFeaturesRequest, NodesFeaturesResponse, TransportNodesFeaturesAction.NodeFeaturesRequest, - NodeFeatures> { + NodeFeatures, + Void> { public static final ActionType TYPE = new ActionType<>("cluster:monitor/nodes/features"); diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/hotthreads/TransportNodesHotThreadsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/hotthreads/TransportNodesHotThreadsAction.java index cf3b34877afa..f1e24258eb57 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/hotthreads/TransportNodesHotThreadsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/hotthreads/TransportNodesHotThreadsAction.java @@ -39,7 +39,8 @@ public class TransportNodesHotThreadsAction extends TransportNodesAction< NodesHotThreadsRequest, NodesHotThreadsResponse, TransportNodesHotThreadsAction.NodeRequest, - NodeHotThreads> { + NodeHotThreads, + Void> { public static final ActionType TYPE = new ActionType<>("cluster:monitor/nodes/hot_threads"); diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/info/TransportNodesInfoAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/info/TransportNodesInfoAction.java index 9fc657feeb46..65bf76319759 100644 --- 
a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/info/TransportNodesInfoAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/info/TransportNodesInfoAction.java @@ -34,7 +34,8 @@ public class TransportNodesInfoAction extends TransportNodesAction< NodesInfoRequest, NodesInfoResponse, TransportNodesInfoAction.NodeInfoRequest, - NodeInfo> { + NodeInfo, + Void> { public static final ActionType TYPE = new ActionType<>("cluster:monitor/nodes/info"); private final NodeService nodeService; diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/reload/TransportNodesReloadSecureSettingsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/reload/TransportNodesReloadSecureSettingsAction.java index 8f13e69a35a5..c84df0ddfe64 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/reload/TransportNodesReloadSecureSettingsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/reload/TransportNodesReloadSecureSettingsAction.java @@ -39,7 +39,8 @@ public class TransportNodesReloadSecureSettingsAction extends TransportNodesActi NodesReloadSecureSettingsRequest, NodesReloadSecureSettingsResponse, NodesReloadSecureSettingsRequest.NodeRequest, - NodesReloadSecureSettingsResponse.NodeResponse> { + NodesReloadSecureSettingsResponse.NodeResponse, + Void> { public static final ActionType TYPE = new ActionType<>("cluster:admin/nodes/reload_secure_settings"); diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/shutdown/TransportPrevalidateShardPathAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/shutdown/TransportPrevalidateShardPathAction.java index d3f59292009f..8c49175c320f 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/shutdown/TransportPrevalidateShardPathAction.java +++ 
b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/shutdown/TransportPrevalidateShardPathAction.java @@ -44,7 +44,8 @@ public class TransportPrevalidateShardPathAction extends TransportNodesAction< PrevalidateShardPathRequest, PrevalidateShardPathResponse, NodePrevalidateShardPathRequest, - NodePrevalidateShardPathResponse> { + NodePrevalidateShardPathResponse, + Void> { public static final String ACTION_NAME = "internal:admin/indices/prevalidate_shard_path"; public static final ActionType TYPE = new ActionType<>(ACTION_NAME); diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/stats/TransportNodesStatsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/stats/TransportNodesStatsAction.java index 0ac55291a797..379ebe80539b 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/stats/TransportNodesStatsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/stats/TransportNodesStatsAction.java @@ -46,7 +46,8 @@ public class TransportNodesStatsAction extends TransportNodesAction< NodesStatsRequest, NodesStatsResponse, TransportNodesStatsAction.NodeStatsRequest, - NodeStats> { + NodeStats, + Void> { public static final ActionType TYPE = new ActionType<>("cluster:monitor/nodes/stats"); @@ -83,6 +84,7 @@ protected NodesStatsResponse newResponse(NodesStatsRequest request, List responses, List failures, ActionListener listener diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/usage/TransportNodesUsageAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/usage/TransportNodesUsageAction.java index 967f619d31f4..a55c58568647 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/usage/TransportNodesUsageAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/usage/TransportNodesUsageAction.java @@ -34,7 +34,8 @@ public class TransportNodesUsageAction extends 
TransportNodesAction< NodesUsageRequest, NodesUsageResponse, TransportNodesUsageAction.NodeUsageRequest, - NodeUsage> { + NodeUsage, + Void> { public static final ActionType TYPE = new ActionType<>("cluster:monitor/nodes/usage"); private final UsageService restUsageService; diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/status/TransportNodesSnapshotsStatus.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/status/TransportNodesSnapshotsStatus.java index 19b5894e2139..42b71e275bb1 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/status/TransportNodesSnapshotsStatus.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/status/TransportNodesSnapshotsStatus.java @@ -47,7 +47,8 @@ public class TransportNodesSnapshotsStatus extends TransportNodesAction< TransportNodesSnapshotsStatus.Request, TransportNodesSnapshotsStatus.NodesSnapshotStatus, TransportNodesSnapshotsStatus.NodeRequest, - TransportNodesSnapshotsStatus.NodeSnapshotStatus> { + TransportNodesSnapshotsStatus.NodeSnapshotStatus, + Void> { public static final String ACTION_NAME = TransportSnapshotsStatusAction.TYPE.name() + "[nodes]"; public static final ActionType TYPE = new ActionType<>(ACTION_NAME); diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/TransportClusterStatsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/TransportClusterStatsAction.java index 80f3e8c439d2..7e25fe45f633 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/TransportClusterStatsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/TransportClusterStatsAction.java @@ -12,6 +12,7 @@ import org.apache.lucene.store.AlreadyClosedException; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRunnable; import 
org.elasticsearch.action.ActionType; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.admin.cluster.node.info.NodeInfo; @@ -20,6 +21,8 @@ import org.elasticsearch.action.admin.indices.stats.CommonStatsFlags; import org.elasticsearch.action.admin.indices.stats.ShardStats; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.action.support.nodes.TransportNodesAction; import org.elasticsearch.cluster.ClusterSnapshotStats; import org.elasticsearch.cluster.ClusterState; @@ -31,7 +34,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.CancellableSingleObjectCache; -import org.elasticsearch.common.util.concurrent.ListenableFuture; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.UpdateForV9; import org.elasticsearch.index.IndexService; @@ -57,6 +59,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.Executor; import java.util.function.BiFunction; import java.util.function.BooleanSupplier; @@ -64,7 +67,8 @@ public class TransportClusterStatsAction extends TransportNodesAction< ClusterStatsRequest, ClusterStatsResponse, TransportClusterStatsAction.ClusterStatsNodeRequest, - ClusterStatsNodeResponse> { + ClusterStatsNodeResponse, + SubscribableListener> { public static final ActionType TYPE = new ActionType<>("cluster:monitor/stats"); private static final CommonStatsFlags SHARD_STATS_FLAGS = new CommonStatsFlags( @@ -84,6 +88,7 @@ public class TransportClusterStatsAction extends TransportNodesAction< private final SearchUsageHolder searchUsageHolder; private final CCSUsageTelemetry ccsUsageHolder; + private final Executor clusterStateStatsExecutor; private final MetadataStatsCache 
mappingStatsCache; private final MetadataStatsCache analysisStatsCache; @@ -111,14 +116,32 @@ public TransportClusterStatsAction( this.repositoriesService = repositoriesService; this.searchUsageHolder = usageService.getSearchUsageHolder(); this.ccsUsageHolder = usageService.getCcsUsageHolder(); + this.clusterStateStatsExecutor = threadPool.executor(ThreadPool.Names.MANAGEMENT); this.mappingStatsCache = new MetadataStatsCache<>(threadPool.getThreadContext(), MappingStats::of); this.analysisStatsCache = new MetadataStatsCache<>(threadPool.getThreadContext(), AnalysisStats::of); } + @Override + protected SubscribableListener createActionContext(Task task, ClusterStatsRequest request) { + assert task instanceof CancellableTask; + final var cancellableTask = (CancellableTask) task; + final var additionalStatsListener = new SubscribableListener(); + AdditionalStats.compute( + cancellableTask, + clusterStateStatsExecutor, + clusterService, + mappingStatsCache, + analysisStatsCache, + additionalStatsListener + ); + return additionalStatsListener; + } + @Override protected void newResponseAsync( final Task task, final ClusterStatsRequest request, + final SubscribableListener additionalStatsListener, final List responses, final List failures, final ActionListener listener @@ -128,41 +151,19 @@ protected void newResponseAsync( + "the cluster state that are too slow for a transport thread" ); assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.MANAGEMENT); - assert task instanceof CancellableTask; - final CancellableTask cancellableTask = (CancellableTask) task; - final ClusterState state = clusterService.state(); - final Metadata metadata = state.metadata(); - final ClusterSnapshotStats clusterSnapshotStats = ClusterSnapshotStats.of( - state, - clusterService.threadPool().absoluteTimeInMillis() - ); - - final ListenableFuture mappingStatsStep = new ListenableFuture<>(); - final ListenableFuture analysisStatsStep = new ListenableFuture<>(); - 
mappingStatsCache.get(metadata, cancellableTask::isCancelled, mappingStatsStep); - analysisStatsCache.get(metadata, cancellableTask::isCancelled, analysisStatsStep); - mappingStatsStep.addListener( - listener.delegateFailureAndWrap( - (l, mappingStats) -> analysisStatsStep.addListener( - l.delegateFailureAndWrap( - (ll, analysisStats) -> ActionListener.completeWith( - ll, - () -> new ClusterStatsResponse( - System.currentTimeMillis(), - metadata.clusterUUID(), - clusterService.getClusterName(), - responses, - failures, - mappingStats, - analysisStats, - VersionStats.of(metadata, responses), - clusterSnapshotStats - ) - ) - ) - ) + additionalStatsListener.andThenApply( + additionalStats -> new ClusterStatsResponse( + System.currentTimeMillis(), + additionalStats.clusterUUID(), + clusterService.getClusterName(), + responses, + failures, + additionalStats.mappingStats(), + additionalStats.analysisStats(), + VersionStats.of(clusterService.state().metadata(), responses), + additionalStats.clusterSnapshotStats() ) - ); + ).addListener(listener); } @Override @@ -316,4 +317,67 @@ protected boolean isFresh(Long currentKey, Long newKey) { return newKey <= currentKey; } } + + public static final class AdditionalStats { + + private String clusterUUID; + private MappingStats mappingStats; + private AnalysisStats analysisStats; + private ClusterSnapshotStats clusterSnapshotStats; + + static void compute( + CancellableTask task, + Executor executor, + ClusterService clusterService, + MetadataStatsCache mappingStatsCache, + MetadataStatsCache analysisStatsCache, + ActionListener listener + ) { + executor.execute(ActionRunnable.wrap(listener, l -> { + task.ensureNotCancelled(); + final var result = new AdditionalStats(); + result.compute( + clusterService.state(), + mappingStatsCache, + analysisStatsCache, + task::isCancelled, + clusterService.threadPool().absoluteTimeInMillis(), + l.map(ignored -> result) + ); + })); + } + + private void compute( + ClusterState clusterState, + 
MetadataStatsCache mappingStatsCache, + MetadataStatsCache analysisStatsCache, + BooleanSupplier isCancelledSupplier, + long absoluteTimeInMillis, + ActionListener listener + ) { + try (var listeners = new RefCountingListener(listener)) { + final var metadata = clusterState.metadata(); + clusterUUID = metadata.clusterUUID(); + mappingStatsCache.get(metadata, isCancelledSupplier, listeners.acquire(s -> mappingStats = s)); + analysisStatsCache.get(metadata, isCancelledSupplier, listeners.acquire(s -> analysisStats = s)); + clusterSnapshotStats = ClusterSnapshotStats.of(clusterState, absoluteTimeInMillis); + } + } + + String clusterUUID() { + return clusterUUID; + } + + MappingStats mappingStats() { + return mappingStats; + } + + AnalysisStats analysisStats() { + return analysisStats; + } + + ClusterSnapshotStats clusterSnapshotStats() { + return clusterSnapshotStats; + } + } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/dangling/find/TransportFindDanglingIndexAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/dangling/find/TransportFindDanglingIndexAction.java index 30d6cdf932fe..a181b059b82e 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/dangling/find/TransportFindDanglingIndexAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/dangling/find/TransportFindDanglingIndexAction.java @@ -34,7 +34,8 @@ public class TransportFindDanglingIndexAction extends TransportNodesAction< FindDanglingIndexRequest, FindDanglingIndexResponse, NodeFindDanglingIndexRequest, - NodeFindDanglingIndexResponse> { + NodeFindDanglingIndexResponse, + Void> { public static final ActionType TYPE = new ActionType<>("cluster:admin/indices/dangling/find"); diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/dangling/list/TransportListDanglingIndicesAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/dangling/list/TransportListDanglingIndicesAction.java 
index 70b7ff370afd..3410e617e3ed 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/dangling/list/TransportListDanglingIndicesAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/dangling/list/TransportListDanglingIndicesAction.java @@ -36,7 +36,8 @@ public class TransportListDanglingIndicesAction extends TransportNodesAction< ListDanglingIndicesRequest, ListDanglingIndicesResponse, NodeListDanglingIndicesRequest, - NodeListDanglingIndicesResponse> { + NodeListDanglingIndicesResponse, + Void> { public static final ActionType TYPE = new ActionType<>("cluster:admin/indices/dangling/list"); diff --git a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java index 2eed3b6263c8..89b7ec01c040 100644 --- a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java @@ -30,6 +30,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.UpdateForV9; import org.elasticsearch.tasks.CancellableTask; @@ -53,7 +54,8 @@ public abstract class TransportNodesAction< NodesRequest extends BaseNodesRequest, NodesResponse extends BaseNodesResponse, NodeRequest extends TransportRequest, - NodeResponse extends BaseNodeResponse> extends TransportAction { + NodeResponse extends BaseNodeResponse, + ActionContext> extends TransportAction { private static final Logger logger = LogManager.getLogger(TransportNodesAction.class); @@ -99,6 +101,7 @@ protected void doExecute(Task task, NodesRequest request, ActionListener, Exception>>() { + final ActionContext actionContext = createActionContext(task, 
request); final ArrayList responses = new ArrayList<>(concreteNodes.length); final ArrayList exceptions = new ArrayList<>(0); @@ -166,7 +169,7 @@ protected CheckedConsumer, Exception> onCompletion // ref releases all happen-before here so no need to be synchronized return l -> { try (var ignored = Releasables.wrap(Iterators.map(responses.iterator(), r -> r::decRef))) { - newResponseAsync(task, request, responses, exceptions, l); + newResponseAsync(task, request, actionContext, responses, exceptions, l); } }; } @@ -187,6 +190,16 @@ private Writeable.Reader nodeResponseReader(DiscoveryNode discover return in -> TransportNodesAction.this.newNodeResponse(in, discoveryNode); } + /** + * Create an (optional) {@link ActionContext}: called when starting to execute this action, and the result passed to + * {@link #newResponseAsync} on completion. NB runs on the transport worker thread, must not do anything expensive without dispatching + * to a different executor. + */ + @Nullable + protected ActionContext createActionContext(Task task, NodesRequest request) { + return null; + } + /** * Create a new {@link NodesResponse}. This method is executed on {@link #finalExecutor}. 
* @@ -211,6 +224,7 @@ private Writeable.Reader nodeResponseReader(DiscoveryNode discover protected void newResponseAsync( Task task, NodesRequest request, + ActionContext actionContext, List responses, List failures, ActionListener listener diff --git a/server/src/main/java/org/elasticsearch/gateway/TransportNodesListGatewayStartedShards.java b/server/src/main/java/org/elasticsearch/gateway/TransportNodesListGatewayStartedShards.java index d77635af8f45..b7ddb9226ddb 100644 --- a/server/src/main/java/org/elasticsearch/gateway/TransportNodesListGatewayStartedShards.java +++ b/server/src/main/java/org/elasticsearch/gateway/TransportNodesListGatewayStartedShards.java @@ -58,7 +58,8 @@ public class TransportNodesListGatewayStartedShards extends TransportNodesAction TransportNodesListGatewayStartedShards.Request, TransportNodesListGatewayStartedShards.NodesGatewayStartedShards, TransportNodesListGatewayStartedShards.NodeRequest, - TransportNodesListGatewayStartedShards.NodeGatewayStartedShards> { + TransportNodesListGatewayStartedShards.NodeGatewayStartedShards, + Void> { private static final Logger logger = LogManager.getLogger(TransportNodesListGatewayStartedShards.class); diff --git a/server/src/main/java/org/elasticsearch/health/stats/HealthApiStatsTransportAction.java b/server/src/main/java/org/elasticsearch/health/stats/HealthApiStatsTransportAction.java index a0325b4c467e..4c2e809f48be 100644 --- a/server/src/main/java/org/elasticsearch/health/stats/HealthApiStatsTransportAction.java +++ b/server/src/main/java/org/elasticsearch/health/stats/HealthApiStatsTransportAction.java @@ -29,7 +29,8 @@ public class HealthApiStatsTransportAction extends TransportNodesAction< HealthApiStatsAction.Request, HealthApiStatsAction.Response, HealthApiStatsAction.Request.Node, - HealthApiStatsAction.Response.Node> { + HealthApiStatsAction.Response.Node, + Void> { private final HealthApiStats healthApiStats; diff --git 
a/server/src/main/java/org/elasticsearch/indices/store/TransportNodesListShardStoreMetadata.java b/server/src/main/java/org/elasticsearch/indices/store/TransportNodesListShardStoreMetadata.java index 10d6c32585a9..dc6b14cec3ca 100644 --- a/server/src/main/java/org/elasticsearch/indices/store/TransportNodesListShardStoreMetadata.java +++ b/server/src/main/java/org/elasticsearch/indices/store/TransportNodesListShardStoreMetadata.java @@ -60,7 +60,8 @@ public class TransportNodesListShardStoreMetadata extends TransportNodesAction< TransportNodesListShardStoreMetadata.Request, TransportNodesListShardStoreMetadata.NodesStoreFilesMetadata, TransportNodesListShardStoreMetadata.NodeRequest, - TransportNodesListShardStoreMetadata.NodeStoreFilesMetadata> { + TransportNodesListShardStoreMetadata.NodeStoreFilesMetadata, + Void> { private static final Logger logger = LogManager.getLogger(TransportNodesListShardStoreMetadata.class); diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java index 25c7ac1d39d0..a61360dab7e3 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java @@ -138,7 +138,7 @@ public int failureCount() { * Simulates node-based task that can be used to block node tasks so they are guaranteed to be registered by task manager */ abstract class AbstractTestNodesAction, NodeRequest extends TransportRequest> - extends TransportNodesAction { + extends TransportNodesAction { AbstractTestNodesAction( String actionName, diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TestTaskPlugin.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TestTaskPlugin.java index 4c0ac871d2e3..059197443795 100644 --- 
a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TestTaskPlugin.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TestTaskPlugin.java @@ -243,7 +243,7 @@ public Task createTask(long id, String type, String action, TaskId parentTaskId, } } - public static class TransportTestTaskAction extends TransportNodesAction { + public static class TransportTestTaskAction extends TransportNodesAction { @Inject public TransportTestTaskAction(ThreadPool threadPool, ClusterService clusterService, TransportService transportService) { diff --git a/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java b/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java index ed347643f0e7..4a3b060c3e1c 100644 --- a/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java @@ -78,7 +78,8 @@ public class TransportNodesActionTests extends ESTestCase { private TransportService transportService; public void testRequestIsSentToEachNode() { - TransportNodesAction action = getTestTransportNodesAction(); + TransportNodesAction action = + getTestTransportNodesAction(); TestNodesRequest request = new TestNodesRequest(); action.execute(null, request, new PlainActionFuture<>()); Map> capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear(); @@ -89,7 +90,8 @@ public void testRequestIsSentToEachNode() { } public void testNodesSelectors() { - TransportNodesAction action = getTestTransportNodesAction(); + TransportNodesAction action = + getTestTransportNodesAction(); int numSelectors = randomIntBetween(1, 5); Set nodeSelectors = new HashSet<>(); for (int i = 0; i < numSelectors; i++) { @@ -109,7 +111,7 @@ public void testNodesSelectors() { } public void testCustomResolving() { - TransportNodesAction action = + TransportNodesAction action = 
getDataNodesOnlyTransportNodesAction(transportService); TestNodesRequest request = new TestNodesRequest(randomBoolean() ? null : generateRandomStringArray(10, 5, false, true)); action.execute(null, request, new PlainActionFuture<>()); @@ -257,6 +259,63 @@ public void testResponsesReleasedOnCancellation() { assertTrue(cancellableTask.isCancelled()); // keep task alive } + public void testActionContextReleasedOnCancellation() { + final var reachabilityChecker = new ReachabilityChecker(); + final TransportNodesAction action = + new TransportNodesAction<>( + "indices:admin/test", + clusterService, + transportService, + new ActionFilters(Collections.emptySet()), + TestNodeRequest::new, + THREAD_POOL.executor(ThreadPool.Names.GENERIC) + ) { + @Override + protected TestNodesResponse newResponse( + TestNodesRequest request, + List testNodeResponses, + List failures + ) { + return fail(null, "should not be called"); + } + + @Override + protected TestNodeRequest newNodeRequest(TestNodesRequest request) { + return new TestNodeRequest(); + } + + @Override + protected TestNodeResponse newNodeResponse(StreamInput in, DiscoveryNode node) throws IOException { + return new TestNodeResponse(in); + } + + @Override + protected TestNodeResponse nodeOperation(TestNodeRequest request, Task task) { + return new TestNodeResponse(); + } + + @Override + protected Object createActionContext(Task task, TestNodesRequest request) { + return reachabilityChecker.register(new Object()); + } + }; + + final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()); + final PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(cancellableTask, new TestNodesRequest(), listener); + + reachabilityChecker.checkReachable(); + TaskCancelHelper.cancel(cancellableTask, "simulated"); + reachabilityChecker.ensureUnreachable(); + + for (CapturingTransport.CapturedRequest capturedRequest : transport.getCapturedRequestsAndClear()) { + 
transport.handleLocalError(capturedRequest.requestId(), new ElasticsearchException("simulated")); + } + + expectThrows(TaskCancelledException.class, () -> listener.actionGet(10, TimeUnit.SECONDS)); + assertTrue(cancellableTask.isCancelled()); // keep task alive + } + @BeforeClass public static void startThreadPool() { THREAD_POOL = new TestThreadPool(TransportNodesActionTests.class.getSimpleName()); @@ -341,7 +400,8 @@ private static class TestTransportNodesAction extends TransportNodesAction< TestNodesRequest, TestNodesResponse, TestNodeRequest, - TestNodeResponse> { + TestNodeResponse, + Void> { TestTransportNodesAction( ClusterService clusterService, diff --git a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/action/TransportAnalyticsStatsAction.java b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/action/TransportAnalyticsStatsAction.java index d20ef5abe238..830ab3528dcc 100644 --- a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/action/TransportAnalyticsStatsAction.java +++ b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/action/TransportAnalyticsStatsAction.java @@ -26,7 +26,8 @@ public class TransportAnalyticsStatsAction extends TransportNodesAction< AnalyticsStatsAction.Request, AnalyticsStatsAction.Response, AnalyticsStatsAction.NodeRequest, - AnalyticsStatsAction.NodeResponse> { + AnalyticsStatsAction.NodeResponse, + Void> { private final AnalyticsUsage usage; @Inject diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/datatiers/NodesDataTiersUsageTransportAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/datatiers/NodesDataTiersUsageTransportAction.java index eb35ba651df2..6a544f637772 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/datatiers/NodesDataTiersUsageTransportAction.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/datatiers/NodesDataTiersUsageTransportAction.java @@ -54,7 +54,8 @@ public class NodesDataTiersUsageTransportAction extends TransportNodesAction< NodesDataTiersUsageTransportAction.NodesRequest, NodesDataTiersUsageTransportAction.NodesResponse, NodesDataTiersUsageTransportAction.NodeRequest, - NodeDataTiersUsage> { + NodeDataTiersUsage, + Void> { public static final ActionType TYPE = new ActionType<>("cluster:monitor/nodes/data_tier_usage"); public static final NodeFeature LOCALLY_PRECALCULATED_STATS_FEATURE = new NodeFeature("usage.data_tiers.precalculate_stats"); diff --git a/x-pack/plugin/deprecation/src/main/java/org/elasticsearch/xpack/deprecation/TransportNodeDeprecationCheckAction.java b/x-pack/plugin/deprecation/src/main/java/org/elasticsearch/xpack/deprecation/TransportNodeDeprecationCheckAction.java index 92b16b6a3430..745f5e7ae895 100644 --- a/x-pack/plugin/deprecation/src/main/java/org/elasticsearch/xpack/deprecation/TransportNodeDeprecationCheckAction.java +++ b/x-pack/plugin/deprecation/src/main/java/org/elasticsearch/xpack/deprecation/TransportNodeDeprecationCheckAction.java @@ -43,7 +43,8 @@ public class TransportNodeDeprecationCheckAction extends TransportNodesAction< NodesDeprecationCheckRequest, NodesDeprecationCheckResponse, NodesDeprecationCheckAction.NodeRequest, - NodesDeprecationCheckAction.NodeResponse> { + NodesDeprecationCheckAction.NodeResponse, + Void> { private final Settings settings; private final XPackLicenseState licenseState; diff --git a/x-pack/plugin/deprecation/src/main/java/org/elasticsearch/xpack/deprecation/logging/TransportDeprecationCacheResetAction.java b/x-pack/plugin/deprecation/src/main/java/org/elasticsearch/xpack/deprecation/logging/TransportDeprecationCacheResetAction.java index 01d9089a153f..1a82574752fe 100644 --- a/x-pack/plugin/deprecation/src/main/java/org/elasticsearch/xpack/deprecation/logging/TransportDeprecationCacheResetAction.java +++ 
b/x-pack/plugin/deprecation/src/main/java/org/elasticsearch/xpack/deprecation/logging/TransportDeprecationCacheResetAction.java @@ -28,7 +28,8 @@ public class TransportDeprecationCacheResetAction extends TransportNodesAction< DeprecationCacheResetAction.Request, DeprecationCacheResetAction.Response, DeprecationCacheResetAction.NodeRequest, - DeprecationCacheResetAction.NodeResponse> { + DeprecationCacheResetAction.NodeResponse, + Void> { private static final Logger logger = LogManager.getLogger(TransportDeprecationCacheResetAction.class); diff --git a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/action/EnrichCoordinatorStatsAction.java b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/action/EnrichCoordinatorStatsAction.java index 1213c439c628..808acee58df3 100644 --- a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/action/EnrichCoordinatorStatsAction.java +++ b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/action/EnrichCoordinatorStatsAction.java @@ -111,7 +111,7 @@ public void writeTo(StreamOutput out) throws IOException { } } - public static class TransportAction extends TransportNodesAction { + public static class TransportAction extends TransportNodesAction { private final EnrichCache enrichCache; private final EnrichCoordinatorProxyAction.Coordinator coordinator; diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlStatsAction.java b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlStatsAction.java index 76f3d05cdb9d..5f1fcbbe6659 100644 --- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlStatsAction.java +++ b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlStatsAction.java @@ -28,7 +28,8 @@ public class TransportEqlStatsAction extends TransportNodesAction< EqlStatsRequest, EqlStatsResponse, EqlStatsRequest.NodeStatsRequest, - 
EqlStatsResponse.NodeStatsResponse> { + EqlStatsResponse.NodeStatsResponse, + Void> { // the plan executor holds the metrics private final PlanExecutor planExecutor; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlStatsAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlStatsAction.java index 7ed027436bbc..985dcf118ac5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlStatsAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlStatsAction.java @@ -31,7 +31,8 @@ public class TransportEsqlStatsAction extends TransportNodesAction< EsqlStatsRequest, EsqlStatsResponse, EsqlStatsRequest.NodeStatsRequest, - EsqlStatsResponse.NodeStatsResponse> { + EsqlStatsResponse.NodeStatsResponse, + Void> { static final NodeFeature ESQL_STATS_FEATURE = new NodeFeature("esql.stats_node"); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java index 88689035fbd8..cdd322cfe74f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceDiagnosticsAction.java @@ -28,7 +28,8 @@ public class TransportGetInferenceDiagnosticsAction extends TransportNodesAction GetInferenceDiagnosticsAction.Request, GetInferenceDiagnosticsAction.Response, GetInferenceDiagnosticsAction.NodeRequest, - GetInferenceDiagnosticsAction.NodeResponse> { + GetInferenceDiagnosticsAction.NodeResponse, + Void> { private final HttpClientManager httpClientManager; diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportTrainedModelCacheInfoAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportTrainedModelCacheInfoAction.java index 5b76b46f66c6..af7e1869420b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportTrainedModelCacheInfoAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportTrainedModelCacheInfoAction.java @@ -35,7 +35,8 @@ public class TransportTrainedModelCacheInfoAction extends TransportNodesAction< TrainedModelCacheInfoAction.Request, TrainedModelCacheInfoAction.Response, TransportTrainedModelCacheInfoAction.NodeModelCacheInfoRequest, - CacheInfo> { + CacheInfo, + Void> { private final ModelLoadingService modelLoadingService; diff --git a/x-pack/plugin/repositories-metering-api/src/main/java/org/elasticsearch/xpack/repositories/metering/action/TransportClearRepositoriesStatsArchiveAction.java b/x-pack/plugin/repositories-metering-api/src/main/java/org/elasticsearch/xpack/repositories/metering/action/TransportClearRepositoriesStatsArchiveAction.java index a7ffc096f6ff..f138449559d2 100644 --- a/x-pack/plugin/repositories-metering-api/src/main/java/org/elasticsearch/xpack/repositories/metering/action/TransportClearRepositoriesStatsArchiveAction.java +++ b/x-pack/plugin/repositories-metering-api/src/main/java/org/elasticsearch/xpack/repositories/metering/action/TransportClearRepositoriesStatsArchiveAction.java @@ -29,7 +29,8 @@ public final class TransportClearRepositoriesStatsArchiveAction extends Transpor ClearRepositoriesMeteringArchiveRequest, RepositoriesMeteringResponse, TransportClearRepositoriesStatsArchiveAction.ClearRepositoriesStatsArchiveNodeRequest, - RepositoriesNodeMeteringResponse> { + RepositoriesNodeMeteringResponse, + Void> { private final RepositoriesService repositoriesService; diff --git 
a/x-pack/plugin/repositories-metering-api/src/main/java/org/elasticsearch/xpack/repositories/metering/action/TransportRepositoriesStatsAction.java b/x-pack/plugin/repositories-metering-api/src/main/java/org/elasticsearch/xpack/repositories/metering/action/TransportRepositoriesStatsAction.java index cb7d27481448..76ad89a9dfea 100644 --- a/x-pack/plugin/repositories-metering-api/src/main/java/org/elasticsearch/xpack/repositories/metering/action/TransportRepositoriesStatsAction.java +++ b/x-pack/plugin/repositories-metering-api/src/main/java/org/elasticsearch/xpack/repositories/metering/action/TransportRepositoriesStatsAction.java @@ -27,7 +27,8 @@ public final class TransportRepositoriesStatsAction extends TransportNodesAction RepositoriesMeteringRequest, RepositoriesMeteringResponse, TransportRepositoriesStatsAction.RepositoriesNodeStatsRequest, - RepositoriesNodeMeteringResponse> { + RepositoriesNodeMeteringResponse, + Void> { private final RepositoriesService repositoriesService; diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/action/cache/TransportSearchableSnapshotCacheStoresAction.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/action/cache/TransportSearchableSnapshotCacheStoresAction.java index 446f0f433fe3..67cb5cddd988 100644 --- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/action/cache/TransportSearchableSnapshotCacheStoresAction.java +++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/action/cache/TransportSearchableSnapshotCacheStoresAction.java @@ -39,7 +39,8 @@ public class TransportSearchableSnapshotCacheStoresAction extends TransportNodes TransportSearchableSnapshotCacheStoresAction.Request, TransportSearchableSnapshotCacheStoresAction.NodesCacheFilesMetadata, TransportSearchableSnapshotCacheStoresAction.NodeRequest, - 
TransportSearchableSnapshotCacheStoresAction.NodeCacheFilesMetadata> { + TransportSearchableSnapshotCacheStoresAction.NodeCacheFilesMetadata, + Void> { public static final String ACTION_NAME = "internal:admin/xpack/searchable_snapshots/cache/store"; diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/action/cache/TransportSearchableSnapshotsNodeCachesStatsAction.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/action/cache/TransportSearchableSnapshotsNodeCachesStatsAction.java index c28948b4101e..b414ff6daf71 100644 --- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/action/cache/TransportSearchableSnapshotsNodeCachesStatsAction.java +++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/action/cache/TransportSearchableSnapshotsNodeCachesStatsAction.java @@ -47,7 +47,8 @@ public class TransportSearchableSnapshotsNodeCachesStatsAction extends Transport TransportSearchableSnapshotsNodeCachesStatsAction.NodesRequest, TransportSearchableSnapshotsNodeCachesStatsAction.NodesCachesStatsResponse, TransportSearchableSnapshotsNodeCachesStatsAction.NodeRequest, - TransportSearchableSnapshotsNodeCachesStatsAction.NodeCachesStatsResponse> { + TransportSearchableSnapshotsNodeCachesStatsAction.NodeCachesStatsResponse, + Void> { public static final String ACTION_NAME = "cluster:admin/xpack/searchable_snapshots/cache/stats"; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportClearSecurityCacheAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportClearSecurityCacheAction.java index 56965274c6fa..ac06cf5f1eb6 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportClearSecurityCacheAction.java +++ 
b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportClearSecurityCacheAction.java @@ -33,7 +33,8 @@ public class TransportClearSecurityCacheAction extends TransportNodesAction< ClearSecurityCacheRequest, ClearSecurityCacheResponse, ClearSecurityCacheRequest.Node, - ClearSecurityCacheResponse.Node> { + ClearSecurityCacheResponse.Node, + Void> { private final CacheInvalidatorRegistry cacheInvalidatorRegistry; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/privilege/TransportClearPrivilegesCacheAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/privilege/TransportClearPrivilegesCacheAction.java index 852144d1c277..14868dda9f04 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/privilege/TransportClearPrivilegesCacheAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/privilege/TransportClearPrivilegesCacheAction.java @@ -30,7 +30,8 @@ public class TransportClearPrivilegesCacheAction extends TransportNodesAction< ClearPrivilegesCacheRequest, ClearPrivilegesCacheResponse, ClearPrivilegesCacheRequest.Node, - ClearPrivilegesCacheResponse.Node> { + ClearPrivilegesCacheResponse.Node, + Void> { private final CompositeRolesStore rolesStore; private final CacheInvalidatorRegistry cacheInvalidatorRegistry; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/realm/TransportClearRealmCacheAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/realm/TransportClearRealmCacheAction.java index 4d574c6b6c0a..23c4e312e2f3 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/realm/TransportClearRealmCacheAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/realm/TransportClearRealmCacheAction.java @@ -32,7 +32,8 @@ public class 
TransportClearRealmCacheAction extends TransportNodesAction< ClearRealmCacheRequest, ClearRealmCacheResponse, ClearRealmCacheRequest.Node, - ClearRealmCacheResponse.Node> { + ClearRealmCacheResponse.Node, + Void> { private final Realms realms; private final AuthenticationService authenticationService; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/role/TransportClearRolesCacheAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/role/TransportClearRolesCacheAction.java index 82e62187f7f4..412b0d0b70da 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/role/TransportClearRolesCacheAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/role/TransportClearRolesCacheAction.java @@ -28,7 +28,8 @@ public class TransportClearRolesCacheAction extends TransportNodesAction< ClearRolesCacheRequest, ClearRolesCacheResponse, ClearRolesCacheRequest.Node, - ClearRolesCacheResponse.Node> { + ClearRolesCacheResponse.Node, + Void> { private final CompositeRolesStore rolesStore; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/service/TransportGetServiceAccountNodesCredentialsAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/service/TransportGetServiceAccountNodesCredentialsAction.java index 228c606dd1e3..82ec40718959 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/service/TransportGetServiceAccountNodesCredentialsAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/service/TransportGetServiceAccountNodesCredentialsAction.java @@ -35,7 +35,8 @@ public class TransportGetServiceAccountNodesCredentialsAction extends TransportN GetServiceAccountCredentialsNodesRequest, GetServiceAccountCredentialsNodesResponse, GetServiceAccountCredentialsNodesRequest.Node, - 
GetServiceAccountCredentialsNodesResponse.Node> { + GetServiceAccountCredentialsNodesResponse.Node, + Void> { private final FileServiceAccountTokenStore fileServiceAccountTokenStore; diff --git a/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/action/SpatialStatsTransportAction.java b/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/action/SpatialStatsTransportAction.java index f36ee616996e..526b2c85c84f 100644 --- a/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/action/SpatialStatsTransportAction.java +++ b/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/action/SpatialStatsTransportAction.java @@ -26,7 +26,8 @@ public class SpatialStatsTransportAction extends TransportNodesAction< SpatialStatsAction.Request, SpatialStatsAction.Response, SpatialStatsAction.NodeRequest, - SpatialStatsAction.NodeResponse> { + SpatialStatsAction.NodeResponse, + Void> { private final SpatialUsage usage; @Inject diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlStatsAction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlStatsAction.java index c334c5779050..337abf47ca0e 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlStatsAction.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlStatsAction.java @@ -28,7 +28,8 @@ public class TransportSqlStatsAction extends TransportNodesAction< SqlStatsRequest, SqlStatsResponse, SqlStatsRequest.NodeStatsRequest, - SqlStatsResponse.NodeStatsResponse> { + SqlStatsResponse.NodeStatsResponse, + Void> { // the plan executor holds the metrics private final PlanExecutor planExecutor; diff --git a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/action/TransportGetTransformNodeStatsAction.java 
b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/action/TransportGetTransformNodeStatsAction.java index 83e7f55df04b..3fd97ee49e1d 100644 --- a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/action/TransportGetTransformNodeStatsAction.java +++ b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/action/TransportGetTransformNodeStatsAction.java @@ -35,7 +35,8 @@ public class TransportGetTransformNodeStatsAction extends TransportNodesAction< NodesStatsRequest, NodesStatsResponse, NodeStatsRequest, - NodeStatsResponse> { + NodeStatsResponse, + Void> { private final TransportService transportService; private final TransformScheduler scheduler; diff --git a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/transport/actions/TransportWatcherStatsAction.java b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/transport/actions/TransportWatcherStatsAction.java index 220415cf9d09..0c79bba50722 100644 --- a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/transport/actions/TransportWatcherStatsAction.java +++ b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/transport/actions/TransportWatcherStatsAction.java @@ -36,7 +36,8 @@ public class TransportWatcherStatsAction extends TransportNodesAction< WatcherStatsRequest, WatcherStatsResponse, WatcherStatsRequest.Node, - WatcherStatsResponse.Node> { + WatcherStatsResponse.Node, + Void> { private final ExecutionService executionService; private final TriggerService triggerService; From 68c2efc7f68cfe19f37453e089b622cb23b402c0 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 24 Sep 2024 09:09:43 +0100 Subject: [PATCH 08/14] Detect long-running outbound tasks on network threads (#113250) (#113375) Extends the mechanism introduced in #109204 to cover slow-running outbound tasks too. 
Closes #108710 Closes ES-8625 --- .../netty4/Netty4HttpServerTransport.java | 12 +-- .../transport/netty4/Netty4Transport.java | 5 +- .../netty4/Netty4WriteThrottlingHandler.java | 81 ++++++++++++++----- .../Netty4WriteThrottlingHandlerTests.java | 66 +++++++++------ .../common/network/ThreadWatchdog.java | 10 +++ .../common/network/ThreadWatchdogTests.java | 38 ++++++--- 6 files changed, 148 insertions(+), 64 deletions(-) diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java index 5ed3d8139295..b971a52b7afb 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java @@ -332,8 +332,12 @@ protected void initChannel(Channel ch) throws Exception { if (tlsConfig.isTLSEnabled()) { ch.pipeline().addLast("ssl", new SslHandler(tlsConfig.createServerSSLEngine())); } + final var threadWatchdogActivityTracker = transport.threadWatchdog.getActivityTrackerForCurrentThread(); ch.pipeline() - .addLast("chunked_writer", new Netty4WriteThrottlingHandler(transport.getThreadPool().getThreadContext())) + .addLast( + "chunked_writer", + new Netty4WriteThrottlingHandler(transport.getThreadPool().getThreadContext(), threadWatchdogActivityTracker) + ) .addLast("byte_buf_sizer", NettyByteBufSizer.INSTANCE); if (transport.readTimeoutMillis > 0) { ch.pipeline().addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS)); @@ -409,11 +413,7 @@ protected Result beginEncode(HttpResponse httpResponse, String acceptEncoding) t ch.pipeline() .addLast( "pipelining", - new Netty4HttpPipeliningHandler( - transport.pipeliningMaxEvents, - transport, - transport.threadWatchdog.getActivityTrackerForCurrentThread() - ) + new 
Netty4HttpPipeliningHandler(transport.pipeliningMaxEvents, transport, threadWatchdogActivityTracker) ); transport.serverAcceptedChannel(nettyHttpChannel); } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index b99c76e7b061..d8b02a0e9a0d 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -359,7 +359,10 @@ private void setupPipeline(Channel ch, boolean isRemoteClusterServerChannel) { if (NetworkTraceFlag.TRACE_ENABLED) { pipeline.addLast("logging", ESLoggingHandler.INSTANCE); } - pipeline.addLast("chunked_writer", new Netty4WriteThrottlingHandler(getThreadPool().getThreadContext())); + pipeline.addLast( + "chunked_writer", + new Netty4WriteThrottlingHandler(getThreadPool().getThreadContext(), threadWatchdog.getActivityTrackerForCurrentThread()) + ); pipeline.addLast( "dispatcher", new Netty4MessageInboundHandler( diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4WriteThrottlingHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4WriteThrottlingHandler.java index 15011957040a..738da83817cb 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4WriteThrottlingHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4WriteThrottlingHandler.java @@ -23,6 +23,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefIterator; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.network.ThreadWatchdog; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.transport.Transports; @@ -42,31 +43,44 @@ public 
final class Netty4WriteThrottlingHandler extends ChannelDuplexHandler { private final Queue queuedWrites = new LinkedList<>(); private final ThreadContext threadContext; + private final ThreadWatchdog.ActivityTracker threadWatchdogActivityTracker; private WriteOperation currentWrite; - public Netty4WriteThrottlingHandler(ThreadContext threadContext) { + public Netty4WriteThrottlingHandler(ThreadContext threadContext, ThreadWatchdog.ActivityTracker threadWatchdogActivityTracker) { this.threadContext = threadContext; + this.threadWatchdogActivityTracker = threadWatchdogActivityTracker; } @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws IOException { - if (msg instanceof BytesReference reference) { - if (reference.hasArray()) { - writeSingleByteBuf(ctx, Unpooled.wrappedBuffer(reference.array(), reference.arrayOffset(), reference.length()), promise); - } else { - BytesRefIterator iter = reference.iterator(); - final PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); - BytesRef next; - while ((next = iter.next()) != null) { - final ChannelPromise chunkPromise = ctx.newPromise(); - combiner.add((Future) chunkPromise); - writeSingleByteBuf(ctx, Unpooled.wrappedBuffer(next.bytes, next.offset, next.length), chunkPromise); + final boolean startedActivity = threadWatchdogActivityTracker.maybeStartActivity(); + try { + if (msg instanceof BytesReference reference) { + if (reference.hasArray()) { + writeSingleByteBuf( + ctx, + Unpooled.wrappedBuffer(reference.array(), reference.arrayOffset(), reference.length()), + promise + ); + } else { + BytesRefIterator iter = reference.iterator(); + final PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); + BytesRef next; + while ((next = iter.next()) != null) { + final ChannelPromise chunkPromise = ctx.newPromise(); + combiner.add((Future) chunkPromise); + writeSingleByteBuf(ctx, Unpooled.wrappedBuffer(next.bytes, next.offset, next.length), chunkPromise); + } + 
combiner.finish(promise); } - combiner.finish(promise); + } else { + assert msg instanceof ByteBuf; + writeSingleByteBuf(ctx, (ByteBuf) msg, promise); + } + } finally { + if (startedActivity) { + threadWatchdogActivityTracker.stopActivity(); } - } else { - assert msg instanceof ByteBuf; - writeSingleByteBuf(ctx, (ByteBuf) msg, promise); } } @@ -116,22 +130,45 @@ private void queueWrite(ByteBuf buf, ChannelPromise promise) { @Override public void channelWritabilityChanged(ChannelHandlerContext ctx) { - if (ctx.channel().isWritable()) { - doFlush(ctx); + final boolean startedActivity = threadWatchdogActivityTracker.maybeStartActivity(); + try { + if (ctx.channel().isWritable()) { + doFlush(ctx); + } + ctx.fireChannelWritabilityChanged(); + } finally { + if (startedActivity) { + threadWatchdogActivityTracker.stopActivity(); + } } - ctx.fireChannelWritabilityChanged(); } @Override public void flush(ChannelHandlerContext ctx) { - if (doFlush(ctx) == false) { - ctx.flush(); + final boolean startedActivity = threadWatchdogActivityTracker.maybeStartActivity(); + try { + if (doFlush(ctx) == false) { + ctx.flush(); + } + } finally { + if (startedActivity) { + threadWatchdogActivityTracker.stopActivity(); + } } } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { - doFlush(ctx); + final boolean startedActivity = threadWatchdogActivityTracker.maybeStartActivity(); + try { + doFlush(ctx); + } finally { + if (startedActivity) { + threadWatchdogActivityTracker.stopActivity(); + } + } + + // super.channelInactive() can trigger reads which are tracked separately (and are not re-entrant) so no activity tracking here super.channelInactive(ctx); } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4WriteThrottlingHandlerTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4WriteThrottlingHandlerTests.java index cf1fcbe88ea9..d87889c6a241 100644 --- 
a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4WriteThrottlingHandlerTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4WriteThrottlingHandlerTests.java @@ -18,43 +18,52 @@ import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.CompositeBytesReference; +import org.elasticsearch.common.network.ThreadWatchdog; +import org.elasticsearch.common.network.ThreadWatchdogHelper; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.Transports; import org.junit.After; import org.junit.Before; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.ExecutionException; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.oneOf; +import static org.hamcrest.Matchers.startsWith; public class Netty4WriteThrottlingHandlerTests extends ESTestCase { - private SharedGroupFactory.SharedGroup transportGroup; + private ThreadWatchdog threadWatchdog = new ThreadWatchdog(); @Before - public void createGroup() { - final SharedGroupFactory sharedGroupFactory = new SharedGroupFactory(Settings.EMPTY); - transportGroup = sharedGroupFactory.getTransportGroup(); + public void setFakeThreadName() { + // These tests interact with EmbeddedChannel instances directly on the test thread, so we rename it temporarily to satisfy checks + // that we're running on a transport thread + 
Thread.currentThread().setName(Transports.TEST_MOCK_TRANSPORT_THREAD_PREFIX + Thread.currentThread().getName()); } @After - public void stopGroup() { - transportGroup.shutdown(); + public void resetThreadName() { + final var threadName = Thread.currentThread().getName(); + assertThat(threadName, startsWith(Transports.TEST_MOCK_TRANSPORT_THREAD_PREFIX)); + Thread.currentThread().setName(threadName.substring(Transports.TEST_MOCK_TRANSPORT_THREAD_PREFIX.length())); } - public void testThrottlesLargeMessage() throws ExecutionException, InterruptedException { + public void testThrottlesLargeMessage() { final List seen = new CopyOnWriteArrayList<>(); final CapturingHandler capturingHandler = new CapturingHandler(seen); final EmbeddedChannel embeddedChannel = new EmbeddedChannel( capturingHandler, - new Netty4WriteThrottlingHandler(new ThreadContext(Settings.EMPTY)) + new Netty4WriteThrottlingHandler(new ThreadContext(Settings.EMPTY), threadWatchdog.getActivityTrackerForCurrentThread()) ); // we assume that the channel outbound buffer is smaller than Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE final int writeableBytes = Math.toIntExact(embeddedChannel.bytesBeforeUnwritable()); @@ -66,11 +75,11 @@ public void testThrottlesLargeMessage() throws ExecutionException, InterruptedEx ); final Object message = wrapAsNettyOrEsBuffer(messageBytes); final ChannelPromise promise = embeddedChannel.newPromise(); - transportGroup.getLowLevelGroup().submit(() -> embeddedChannel.write(message, promise)).get(); + embeddedChannel.write(message, promise); assertThat(seen, hasSize(1)); assertSliceEquals(seen.get(0), message, 0, Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE); assertFalse(promise.isDone()); - transportGroup.getLowLevelGroup().submit(embeddedChannel::flush).get(); + embeddedChannel.flush(); assertTrue(promise.isDone()); assertThat(seen, hasSize(fullSizeChunks + (extraChunkSize == 0 ? 
0 : 1))); assertTrue(capturingHandler.didWriteAfterThrottled); @@ -84,12 +93,12 @@ public void testThrottlesLargeMessage() throws ExecutionException, InterruptedEx } } - public void testThrottleLargeCompositeMessage() throws ExecutionException, InterruptedException { + public void testThrottleLargeCompositeMessage() { final List seen = new CopyOnWriteArrayList<>(); final CapturingHandler capturingHandler = new CapturingHandler(seen); final EmbeddedChannel embeddedChannel = new EmbeddedChannel( capturingHandler, - new Netty4WriteThrottlingHandler(new ThreadContext(Settings.EMPTY)) + new Netty4WriteThrottlingHandler(new ThreadContext(Settings.EMPTY), threadWatchdog.getActivityTrackerForCurrentThread()) ); // we assume that the channel outbound buffer is smaller than Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE final int writeableBytes = Math.toIntExact(embeddedChannel.bytesBeforeUnwritable()); @@ -105,51 +114,51 @@ public void testThrottleLargeCompositeMessage() throws ExecutionException, Inter new BytesArray(messageBytes, splitOffset, messageBytes.length - splitOffset) ); final ChannelPromise promise = embeddedChannel.newPromise(); - transportGroup.getLowLevelGroup().submit(() -> embeddedChannel.write(message, promise)).get(); + embeddedChannel.write(message, promise); assertThat(seen, hasSize(oneOf(1, 2))); assertSliceEquals(seen.get(0), message, 0, seen.get(0).readableBytes()); assertFalse(promise.isDone()); - transportGroup.getLowLevelGroup().submit(embeddedChannel::flush).get(); + embeddedChannel.flush(); assertTrue(promise.isDone()); assertThat(seen, hasSize(oneOf(fullSizeChunks, fullSizeChunks + 1))); assertTrue(capturingHandler.didWriteAfterThrottled); assertBufferEquals(Unpooled.compositeBuffer().addComponents(true, seen), message); } - public void testPassesSmallMessageDirectly() throws ExecutionException, InterruptedException { + public void testPassesSmallMessageDirectly() { final List seen = new CopyOnWriteArrayList<>(); final CapturingHandler 
capturingHandler = new CapturingHandler(seen); final EmbeddedChannel embeddedChannel = new EmbeddedChannel( capturingHandler, - new Netty4WriteThrottlingHandler(new ThreadContext(Settings.EMPTY)) + new Netty4WriteThrottlingHandler(new ThreadContext(Settings.EMPTY), threadWatchdog.getActivityTrackerForCurrentThread()) ); final int writeableBytes = Math.toIntExact(embeddedChannel.bytesBeforeUnwritable()); assertThat(writeableBytes, lessThan(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE)); final byte[] messageBytes = randomByteArrayOfLength(randomIntBetween(0, Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE)); final Object message = wrapAsNettyOrEsBuffer(messageBytes); final ChannelPromise promise = embeddedChannel.newPromise(); - transportGroup.getLowLevelGroup().submit(() -> embeddedChannel.write(message, promise)).get(); + embeddedChannel.write(message, promise); assertThat(seen, hasSize(1)); // first message should be passed through straight away assertBufferEquals(seen.get(0), message); assertFalse(promise.isDone()); - transportGroup.getLowLevelGroup().submit(embeddedChannel::flush).get(); + embeddedChannel.flush(); assertTrue(promise.isDone()); assertThat(seen, hasSize(1)); assertFalse(capturingHandler.didWriteAfterThrottled); } - public void testThrottlesOnUnwritable() throws ExecutionException, InterruptedException { + public void testThrottlesOnUnwritable() { final List seen = new CopyOnWriteArrayList<>(); final EmbeddedChannel embeddedChannel = new EmbeddedChannel( new CapturingHandler(seen), - new Netty4WriteThrottlingHandler(new ThreadContext(Settings.EMPTY)) + new Netty4WriteThrottlingHandler(new ThreadContext(Settings.EMPTY), threadWatchdog.getActivityTrackerForCurrentThread()) ); final int writeableBytes = Math.toIntExact(embeddedChannel.bytesBeforeUnwritable()); assertThat(writeableBytes, lessThan(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE)); final byte[] messageBytes = randomByteArrayOfLength(writeableBytes + randomIntBetween(0, 10)); final 
Object message = wrapAsNettyOrEsBuffer(messageBytes); final ChannelPromise promise = embeddedChannel.newPromise(); - transportGroup.getLowLevelGroup().submit(() -> embeddedChannel.write(message, promise)).get(); + embeddedChannel.write(message, promise); assertThat(seen, hasSize(1)); // first message should be passed through straight away assertBufferEquals(seen.get(0), message); assertFalse(promise.isDone()); @@ -157,11 +166,11 @@ public void testThrottlesOnUnwritable() throws ExecutionException, InterruptedEx randomByteArrayOfLength(randomIntBetween(0, Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE)) ); final ChannelPromise promiseForQueued = embeddedChannel.newPromise(); - transportGroup.getLowLevelGroup().submit(() -> embeddedChannel.write(messageToQueue, promiseForQueued)).get(); + embeddedChannel.write(messageToQueue, promiseForQueued); assertThat(seen, hasSize(1)); assertFalse(promiseForQueued.isDone()); assertFalse(promise.isDone()); - transportGroup.getLowLevelGroup().submit(embeddedChannel::flush).get(); + embeddedChannel.flush(); assertTrue(promise.isDone()); assertTrue(promiseForQueued.isDone()); } @@ -191,7 +200,7 @@ private static Object wrapAsNettyOrEsBuffer(byte[] messageBytes) { return new BytesArray(messageBytes); } - private static class CapturingHandler extends ChannelOutboundHandlerAdapter { + private class CapturingHandler extends ChannelOutboundHandlerAdapter { private final List seen; private boolean wasThrottled = false; @@ -204,6 +213,13 @@ private static class CapturingHandler extends ChannelOutboundHandlerAdapter { @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + assertThat( + ThreadWatchdogHelper.getStuckThreadNames(threadWatchdog), + // writes are re-entrant so we might already be considered stuck due to an earlier check + anyOf(emptyIterable(), hasItem(Thread.currentThread().getName())) + ); + assertThat(ThreadWatchdogHelper.getStuckThreadNames(threadWatchdog), 
hasItem(Thread.currentThread().getName())); + assertTrue("should only write to writeable channel", ctx.channel().isWritable()); assertThat(msg, instanceOf(ByteBuf.class)); final ByteBuf buf = (ByteBuf) msg; diff --git a/server/src/main/java/org/elasticsearch/common/network/ThreadWatchdog.java b/server/src/main/java/org/elasticsearch/common/network/ThreadWatchdog.java index 687a8f5940bd..5432e7cfa267 100644 --- a/server/src/main/java/org/elasticsearch/common/network/ThreadWatchdog.java +++ b/server/src/main/java/org/elasticsearch/common/network/ThreadWatchdog.java @@ -131,6 +131,16 @@ public void startActivity() { assert isIdle(prevValue) : "thread [" + trackedThread.getName() + "] was already active"; } + public boolean maybeStartActivity() { + assert trackedThread == Thread.currentThread() : trackedThread.getName() + " vs " + Thread.currentThread().getName(); + if (isIdle(get())) { + getAndIncrement(); + return true; + } else { + return false; + } + } + public void stopActivity() { assert trackedThread == Thread.currentThread() : trackedThread.getName() + " vs " + Thread.currentThread().getName(); final var prevValue = getAndIncrement(); diff --git a/server/src/test/java/org/elasticsearch/common/network/ThreadWatchdogTests.java b/server/src/test/java/org/elasticsearch/common/network/ThreadWatchdogTests.java index 06cfddf6c973..f8506a007bb1 100644 --- a/server/src/test/java/org/elasticsearch/common/network/ThreadWatchdogTests.java +++ b/server/src/test/java/org/elasticsearch/common/network/ThreadWatchdogTests.java @@ -49,22 +49,22 @@ public void testSimpleActivityTracking() throws InterruptedException { // step 1: thread is idle safeAwait(barrier); - activityTracker.startActivity(); + startActivity(activityTracker); safeAwait(barrier); // step 2: thread is active safeAwait(barrier); for (int i = between(1, 10); i > 0; i--) { - activityTracker.stopActivity(); - activityTracker.startActivity(); + stopActivity(activityTracker); + startActivity(activityTracker); } 
safeAwait(barrier); // step 3: thread still active, but made progress safeAwait(barrier); - activityTracker.stopActivity(); + stopActivity(activityTracker); safeAwait(barrier); // step 4: thread is idle again @@ -117,11 +117,11 @@ public void testMultipleBlockedThreads() throws InterruptedException { threads[i] = new Thread(() -> { safeAwait(barrier); final var activityTracker = watchdog.getActivityTrackerForCurrentThread(); - activityTracker.startActivity(); + startActivity(activityTracker); safeAwait(barrier); // wait for main test thread safeAwait(barrier); - activityTracker.stopActivity(); + stopActivity(activityTracker); }, threadNames.get(i)); threads[i].start(); } @@ -158,14 +158,14 @@ public void testConcurrency() throws Exception { threads[i] = new Thread(() -> { final var activityTracker = watchdog.getActivityTrackerForCurrentThread(); while (keepGoing.get()) { - activityTracker.startActivity(); + startActivity(activityTracker); try { safeAcquire(semaphore); Thread.yield(); semaphore.release(); Thread.yield(); } finally { - activityTracker.stopActivity(); + stopActivity(activityTracker); warmUpLatch.countDown(); } } @@ -233,7 +233,7 @@ public void testLoggingAndScheduling() { ); } - activityTracker.startActivity(); + startActivity(activityTracker); assertAdvanceTime(deterministicTaskQueue, checkIntervalMillis); MockLog.assertThatLogger( deterministicTaskQueue::runAllRunnableTasks, @@ -262,7 +262,7 @@ public void testLoggingAndScheduling() { ) ); assertAdvanceTime(deterministicTaskQueue, Math.max(quietTimeMillis, checkIntervalMillis)); - activityTracker.stopActivity(); + stopActivity(activityTracker); MockLog.assertThatLogger( deterministicTaskQueue::runAllRunnableTasks, ThreadWatchdog.class, @@ -303,4 +303,22 @@ private static void assertAdvanceTime(DeterministicTaskQueue deterministicTaskQu deterministicTaskQueue.advanceTime(); assertEquals(expectedMillis, deterministicTaskQueue.getCurrentTimeMillis() - currentTimeMillis); } + + private static void 
startActivity(ThreadWatchdog.ActivityTracker activityTracker) { + if (randomBoolean()) { + activityTracker.startActivity(); + } else { + assertTrue(activityTracker.maybeStartActivity()); + } + if (randomBoolean()) { + assertFalse(activityTracker.maybeStartActivity()); + } + } + + private static void stopActivity(ThreadWatchdog.ActivityTracker activityTracker) { + if (randomBoolean()) { + assertFalse(activityTracker.maybeStartActivity()); + } + activityTracker.stopActivity(); + } } From caf94ca88a9ef2bb60639df5ab9344f17a03dcb7 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 24 Sep 2024 09:16:01 +0100 Subject: [PATCH 09/14] Make `AddIndexBlockClusterStateUpdateRequest` a record (#113349) (#113389) No need to extend `IndicesClusterStateUpdateRequest`, this thing can be completely immutable. --- ...ddIndexBlockClusterStateUpdateRequest.java | 38 ++++++++----------- .../TransportAddIndexBlockAction.java | 21 ++++++---- .../metadata/MetadataIndexStateService.java | 12 +++--- 3 files changed, 34 insertions(+), 37 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/readonly/AddIndexBlockClusterStateUpdateRequest.java b/server/src/main/java/org/elasticsearch/action/admin/indices/readonly/AddIndexBlockClusterStateUpdateRequest.java index beaf561bfee5..50bd3b37b4cb 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/readonly/AddIndexBlockClusterStateUpdateRequest.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/readonly/AddIndexBlockClusterStateUpdateRequest.java @@ -8,32 +8,26 @@ */ package org.elasticsearch.action.admin.indices.readonly; -import org.elasticsearch.cluster.ack.IndicesClusterStateUpdateRequest; import org.elasticsearch.cluster.metadata.IndexMetadata.APIBlock; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.Index; + +import java.util.Objects; /** * Cluster state update request that allows to add a block to one or more indices */ -public class 
AddIndexBlockClusterStateUpdateRequest extends IndicesClusterStateUpdateRequest { - - private final APIBlock block; - private long taskId; - - public AddIndexBlockClusterStateUpdateRequest(final APIBlock block, final long taskId) { - this.block = block; - this.taskId = taskId; - } - - public long taskId() { - return taskId; - } - - public APIBlock getBlock() { - return block; - } - - public AddIndexBlockClusterStateUpdateRequest taskId(final long taskId) { - this.taskId = taskId; - return this; +public record AddIndexBlockClusterStateUpdateRequest( + TimeValue masterNodeTimeout, + TimeValue ackTimeout, + APIBlock block, + long taskId, + Index[] indices +) { + public AddIndexBlockClusterStateUpdateRequest { + Objects.requireNonNull(masterNodeTimeout); + Objects.requireNonNull(ackTimeout); + Objects.requireNonNull(block); + Objects.requireNonNull(indices); } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/readonly/TransportAddIndexBlockAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/readonly/TransportAddIndexBlockAction.java index 2b8f832b8aaf..867cd80fb68d 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/readonly/TransportAddIndexBlockAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/readonly/TransportAddIndexBlockAction.java @@ -102,13 +102,18 @@ protected void masterOperation( return; } - final AddIndexBlockClusterStateUpdateRequest addBlockRequest = new AddIndexBlockClusterStateUpdateRequest( - request.getBlock(), - task.getId() - ).ackTimeout(request.ackTimeout()).masterNodeTimeout(request.masterNodeTimeout()).indices(concreteIndices); - indexStateService.addIndexBlock(addBlockRequest, listener.delegateResponse((delegatedListener, t) -> { - logger.debug(() -> "failed to mark indices as readonly [" + Arrays.toString(concreteIndices) + "]", t); - delegatedListener.onFailure(t); - })); + indexStateService.addIndexBlock( + new 
AddIndexBlockClusterStateUpdateRequest( + request.masterNodeTimeout(), + request.ackTimeout(), + request.getBlock(), + task.getId(), + concreteIndices + ), + listener.delegateResponse((delegatedListener, t) -> { + logger.debug(() -> "failed to mark indices as readonly [" + Arrays.toString(concreteIndices) + "]", t); + delegatedListener.onFailure(t); + }) + ); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java index 00e7d2b05f2a..0c33878b0122 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java @@ -470,7 +470,7 @@ public void addIndexBlock(AddIndexBlockClusterStateUpdateRequest request, Action } addBlocksQueue.submitTask( - "add-index-block-[" + request.getBlock().name + "]-" + Arrays.toString(concreteIndices), + "add-index-block-[" + request.block().name + "]-" + Arrays.toString(concreteIndices), new AddBlocksTask(request, listener), request.masterNodeTimeout() ); @@ -480,7 +480,7 @@ private class AddBlocksExecutor extends SimpleBatchedExecutor> executeTask(AddBlocksTask task, ClusterState clusterState) { - return addIndexBlock(task.request.indices(), clusterState, task.request.getBlock()); + return addIndexBlock(task.request.indices(), clusterState, task.request.block()); } @Override @@ -497,7 +497,7 @@ public void taskSucceeded(AddBlocksTask task, Map blockedIn .delegateFailure( (delegate2, verifyResults) -> finalizeBlocksQueue.submitTask( "finalize-index-block-[" - + task.request.getBlock().name + + task.request.block().name + "]-[" + blockedIndices.keySet().stream().map(Index::getName).collect(Collectors.joining(", ")) + "]", @@ -529,7 +529,7 @@ public Tuple> executeTask(FinalizeBlocksTask clusterState, task.blockedIndices, task.verifyResults, - task.request.getBlock() + 
task.request.block() ); assert finalizeResult.v2().size() == task.verifyResults.size(); return finalizeResult; @@ -797,9 +797,7 @@ private void sendVerifyShardBlockRequest( block, parentTaskId ); - if (request.ackTimeout() != null) { - shardRequest.timeout(request.ackTimeout()); - } + shardRequest.timeout(request.ackTimeout()); client.executeLocally(TransportVerifyShardIndexBlockAction.TYPE, shardRequest, listener); } } From 8c81222b6699cbf7f0a790f1ba6b53ce1151269e Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Tue, 24 Sep 2024 11:04:08 +0100 Subject: [PATCH 10/14] Change default locale of date processors to ENGLISH (#112796) (#113438) It is English in the docs, so this fixes the code to match the docs. Note that this really impacts Elasticsearch when run on JDK 23 with the CLDR locale database, as in the COMPAT database pre-23, root and en are essentially the same. --- .../java/org/elasticsearch/ingest/common/DateProcessor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ingest-common/src/main/java/org/elasticsearch/ingest/common/DateProcessor.java b/modules/ingest-common/src/main/java/org/elasticsearch/ingest/common/DateProcessor.java index bfdf87f417b6..22db5a330fb4 100644 --- a/modules/ingest-common/src/main/java/org/elasticsearch/ingest/common/DateProcessor.java +++ b/modules/ingest-common/src/main/java/org/elasticsearch/ingest/common/DateProcessor.java @@ -98,7 +98,7 @@ private static ZoneId newDateTimeZone(String timezone) { } private static Locale newLocale(String locale) { - return locale == null ? Locale.ROOT : LocaleUtils.parse(locale); + return locale == null ? 
Locale.ENGLISH : LocaleUtils.parse(locale); } @Override From d553b8bef99c26d072f6ae1f50beaeec2629f0bf Mon Sep 17 00:00:00 2001 From: Andrei Dan Date: Tue, 24 Sep 2024 13:08:31 +0300 Subject: [PATCH 11/14] Implement `parseBytesRef` for TimeSeriesRoutingHashFieldType (#113373) (#113439) This implements the `parseBytesRef` method for the `_ts_routing_hash` field so we can parse the values generated by the companion `format` method. We parse the values when fetching them from the source when the field is used as a `sort` paired with `search_after`. Before this change a sort by and search_after `_ts_routing_hash` would yield an `UnsupportedOperationException` (cherry picked from commit 4e5e87037074e7b4a6ccd6b729da477f99aabeae) Signed-off-by: Andrei Dan --- docs/changelog/113373.yaml | 6 +++ .../test/tsdb/25_id_generation.yml | 47 +++++++++++++++++++ .../index/mapper/MapperFeatures.java | 3 +- .../TimeSeriesRoutingHashFieldMapper.java | 9 ++++ .../search/DocValueFormatTests.java | 13 +++++ 5 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 docs/changelog/113373.yaml diff --git a/docs/changelog/113373.yaml b/docs/changelog/113373.yaml new file mode 100644 index 000000000000..cbb3829e0342 --- /dev/null +++ b/docs/changelog/113373.yaml @@ -0,0 +1,6 @@ +pr: 113373 +summary: Implement `parseBytesRef` for `TimeSeriesRoutingHashFieldType` +area: TSDB +type: bug +issues: + - 112399 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/tsdb/25_id_generation.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/tsdb/25_id_generation.yml index 973832cf3ca7..4faa0424adb4 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/tsdb/25_id_generation.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/tsdb/25_id_generation.yml @@ -65,6 +65,9 @@ setup: --- generates a consistent id: + - requires: + cluster_features: "tsdb.ts_routing_hash_doc_value_parse_byte_ref" + reason: _tsid routing hash doc value 
parsing has been fixed - do: bulk: refresh: true @@ -152,6 +155,50 @@ generates a consistent id: - match: { hits.hits.8._source.@timestamp: 2021-04-28T18:52:04.467Z } - match: { hits.hits.8._source.k8s.pod.uid: 947e4ced-1786-4e53-9e0c-5c447e959507 } + - do: + search: + index: id_generation_test + body: + query: + match_all: {} + sort: ["@timestamp", "_ts_routing_hash"] + _source: true + search_after: [ "2021-04-28T18:50:03.142Z", "cn4exQ" ] + docvalue_fields: [_ts_routing_hash] + + - match: {hits.total.value: 9} + + - match: { hits.hits.0._id: cZZNs7B9sSWsyrL5AAABeRnRGTM } + - match: { hits.hits.0._source.@timestamp: 2021-04-28T18:50:04.467Z } + - match: { hits.hits.0._source.k8s.pod.uid: 947e4ced-1786-4e53-9e0c-5c447e959507 } + + - match: { hits.hits.1._id: cn4excfoxSs_KdA5AAABeRnRYiY } + - match: { hits.hits.1._source.@timestamp: 2021-04-28T18:50:23.142Z } + - match: { hits.hits.1._source.k8s.pod.uid: df3145b3-0563-4d3b-a0f7-897eb2876ea9 } + + - match: { hits.hits.2._id: cZZNs7B9sSWsyrL5AAABeRnRZ1M } + - match: { hits.hits.2._source.@timestamp: 2021-04-28T18:50:24.467Z } + - match: { hits.hits.2._source.k8s.pod.uid: 947e4ced-1786-4e53-9e0c-5c447e959507 } + + - match: { hits.hits.3._id: cZZNs7B9sSWsyrL5AAABeRnRtXM } + - match: { hits.hits.3._source.@timestamp: 2021-04-28T18:50:44.467Z } + - match: { hits.hits.3._source.k8s.pod.uid: 947e4ced-1786-4e53-9e0c-5c447e959507 } + + - match: { hits.hits.4._id: cn4excfoxSs_KdA5AAABeRnR11Y } + - match: { hits.hits.4._source.@timestamp: 2021-04-28T18:50:53.142Z } + - match: { hits.hits.4._source.k8s.pod.uid: df3145b3-0563-4d3b-a0f7-897eb2876ea9 } + + - match: { hits.hits.5._id: cn4excfoxSs_KdA5AAABeRnR_mY } + - match: { hits.hits.5._source.@timestamp: 2021-04-28T18:51:03.142Z } + - match: { hits.hits.5._source.k8s.pod.uid: df3145b3-0563-4d3b-a0f7-897eb2876ea9 } + + - match: { hits.hits.6._id: cZZNs7B9sSWsyrL5AAABeRnSA5M } + - match: { hits.hits.6._source.@timestamp: 2021-04-28T18:51:04.467Z } + - match: { 
hits.hits.6._source.k8s.pod.uid: 947e4ced-1786-4e53-9e0c-5c447e959507 } + + - match: { hits.hits.7._id: cZZNs7B9sSWsyrL5AAABeRnS7fM } + - match: { hits.hits.7._source.@timestamp: 2021-04-28T18:52:04.467Z } + - match: { hits.hits.7._source.k8s.pod.uid: 947e4ced-1786-4e53-9e0c-5c447e959507 } --- index a new document on top of an old one: - do: diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index d2ca7a24a78f..ac7d10abc712 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -43,7 +43,8 @@ public Set getFeatures() { SourceFieldMapper.SYNTHETIC_SOURCE_COPY_TO_FIX, FlattenedFieldMapper.IGNORE_ABOVE_SUPPORT, IndexSettings.IGNORE_ABOVE_INDEX_LEVEL_SETTING, - SourceFieldMapper.SYNTHETIC_SOURCE_COPY_TO_INSIDE_OBJECTS_FIX + SourceFieldMapper.SYNTHETIC_SOURCE_COPY_TO_INSIDE_OBJECTS_FIX, + TimeSeriesRoutingHashFieldMapper.TS_ROUTING_HASH_FIELD_PARSES_BYTES_REF ); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/TimeSeriesRoutingHashFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/TimeSeriesRoutingHashFieldMapper.java index 3c4a0ae4e51f..60f792068300 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/TimeSeriesRoutingHashFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/TimeSeriesRoutingHashFieldMapper.java @@ -14,6 +14,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.ByteUtils; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.fielddata.FieldData; @@ -45,6 +46,7 @@ public class TimeSeriesRoutingHashFieldMapper extends MetadataFieldMapper { public static final 
TimeSeriesRoutingHashFieldMapper INSTANCE = new TimeSeriesRoutingHashFieldMapper(); public static final TypeParser PARSER = new FixedTypeParser(c -> c.getIndexSettings().getMode().timeSeriesRoutingHashFieldMapper()); + static final NodeFeature TS_ROUTING_HASH_FIELD_PARSES_BYTES_REF = new NodeFeature("tsdb.ts_routing_hash_doc_value_parse_byte_ref"); static final class TimeSeriesRoutingHashFieldType extends MappedFieldType { @@ -64,6 +66,13 @@ public Object format(BytesRef value) { return Uid.decodeId(value.bytes, value.offset, value.length); } + @Override + public BytesRef parseBytesRef(Object value) { + if (value instanceof BytesRef valueAsBytesRef) { + return valueAsBytesRef; + } + return Uid.encodeId(value.toString()); + } }; private TimeSeriesRoutingHashFieldType() { diff --git a/server/src/test/java/org/elasticsearch/search/DocValueFormatTests.java b/server/src/test/java/org/elasticsearch/search/DocValueFormatTests.java index 0a830b598817..6b42dbbb39c9 100644 --- a/server/src/test/java/org/elasticsearch/search/DocValueFormatTests.java +++ b/server/src/test/java/org/elasticsearch/search/DocValueFormatTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.time.DateFormatter; import org.elasticsearch.index.mapper.DateFieldMapper.Resolution; import org.elasticsearch.index.mapper.TimeSeriesIdFieldMapper.TimeSeriesIdBuilder; +import org.elasticsearch.index.mapper.TimeSeriesRoutingHashFieldMapper; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -33,6 +34,8 @@ import static org.elasticsearch.search.aggregations.bucket.geogrid.GeoTileUtils.longEncode; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; public class DocValueFormatTests extends ESTestCase { @@ -388,4 +391,14 @@ public void testParseTsid() throws IOException { Object tsidBase64 = 
Base64.getUrlEncoder().withoutPadding().encodeToString(expectedBytes); assertEquals(tsidFormat, tsidBase64); } + + public void testFormatAndParseTsRoutingHash() throws IOException { + BytesRef tsRoutingHashInput = new BytesRef("cn4exQ"); + DocValueFormat docValueFormat = TimeSeriesRoutingHashFieldMapper.INSTANCE.fieldType().docValueFormat(null, ZoneOffset.UTC); + Object formattedValue = docValueFormat.format(tsRoutingHashInput); + // the format method takes BytesRef as input and outputs a String + assertThat(formattedValue, instanceOf(String.class)); + // the parse method will output the BytesRef input + assertThat(docValueFormat.parseBytesRef(formattedValue), is(tsRoutingHashInput)); + } } From db00f0c106cdcaa7c15f0454678d07d37ecbab22 Mon Sep 17 00:00:00 2001 From: Pooya Salehi Date: Tue, 24 Sep 2024 12:12:43 +0200 Subject: [PATCH 12/14] Remove test logging from PrevalidateShardPathIT#testCheckShards (#113434) (#113440) Relates https://github.com/elastic/elasticsearch/pull/113107 Closes https://github.com/elastic/elasticsearch/issues/111134 --- .../elasticsearch/cluster/PrevalidateShardPathIT.java | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateShardPathIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateShardPathIT.java index 062f4adb2712..87943dedc708 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateShardPathIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateShardPathIT.java @@ -21,7 +21,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.test.ESIntegTestCase; -import org.elasticsearch.test.junit.annotations.TestIssueLogging; import java.util.HashSet; import java.util.Set; @@ -41,15 +40,6 @@ @ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0) public class 
PrevalidateShardPathIT extends ESIntegTestCase { - @TestIssueLogging( - value = "org.elasticsearch.cluster.service.MasterService:DEBUG," - + "org.elasticsearch.indices.store.IndicesStore:TRACE," - + "org.elasticsearch.indices.cluster.IndicesClusterStateService:DEBUG," - + "org.elasticsearch.indices.IndicesService:TRACE," - + "org.elasticsearch.index.IndexService:TRACE," - + "org.elasticsearch.env.NodeEnvironment:TRACE", - issueUrl = "https://github.com/elastic/elasticsearch/issues/111134" - ) public void testCheckShards() throws Exception { internalCluster().startMasterOnlyNode(); String node1 = internalCluster().startDataOnlyNode(); From 9a21ca63d7f2f307739fd24885a931ae8b75aea5 Mon Sep 17 00:00:00 2001 From: Salvatore Campagna <93581129+salvatore-campagna@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:47:09 +0200 Subject: [PATCH 13/14] LogsDB data migration integration testing (#112710) (#113448) Here we test reindexing logsdb indices, creating and restoring snapshots. Note that logsdb uses synthetic source and restoring source only snapshots fails due to missing _source. (cherry picked from commit f7880ae85f0be9f9a8c89c5415ce406c293a09db) --- .../repository-source-only.asciidoc | 2 +- .../datastreams/LogsDataStreamRestIT.java | 293 +++++++++++++++++- 2 files changed, 280 insertions(+), 15 deletions(-) diff --git a/docs/reference/snapshot-restore/repository-source-only.asciidoc b/docs/reference/snapshot-restore/repository-source-only.asciidoc index 07ddedd19793..04e53c42aff9 100644 --- a/docs/reference/snapshot-restore/repository-source-only.asciidoc +++ b/docs/reference/snapshot-restore/repository-source-only.asciidoc @@ -18,7 +18,7 @@ stream or index. ================================================== Source-only snapshots are only supported if the `_source` field is enabled and no source-filtering is applied. -When you restore a source-only snapshot: +As a result, indices adopting synthetic source cannot be restored. 
When you restore a source-only snapshot: * The restored index is read-only and can only serve `match_all` search or scroll requests to enable reindexing. diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/LogsDataStreamRestIT.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/LogsDataStreamRestIT.java index f62fa83b4e11..f95815d1daff 100644 --- a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/LogsDataStreamRestIT.java +++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/LogsDataStreamRestIT.java @@ -9,16 +9,23 @@ package org.elasticsearch.datastreams; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpPut; import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.RestClient; import org.elasticsearch.common.network.InetAddresses; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.time.DateFormatter; import org.elasticsearch.common.time.FormatNames; +import org.elasticsearch.index.IndexMode; +import org.elasticsearch.repositories.fs.FsRepository; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; +import org.hamcrest.Matchers; import org.junit.Before; import org.junit.ClassRule; @@ -41,6 +48,7 @@ public class LogsDataStreamRestIT extends ESRestTestCase { public static ElasticsearchCluster cluster = ElasticsearchCluster.local() .distribution(DistributionType.DEFAULT) .setting("xpack.security.enabled", "false") + .setting("xpack.license.self_generated.type", "trial") .build(); @Override @@ -102,7 +110,7 @@ private static void waitForLogs(RestClient client) throws Exception { } }"""; - private static 
final String STANDARD_TEMPLATE = """ + private static final String LOGS_STANDARD_INDEX_MODE = """ { "index_patterns": [ "logs-*-*" ], "data_stream": {}, @@ -135,6 +143,39 @@ private static void waitForLogs(RestClient client) throws Exception { } }"""; + private static final String STANDARD_TEMPLATE = """ + { + "index_patterns": [ "standard-*-*" ], + "data_stream": {}, + "priority": 201, + "template": { + "settings": { + "index": { + "mode": "standard" + } + }, + "mappings": { + "properties": { + "@timestamp" : { + "type": "date" + }, + "host.name": { + "type": "keyword" + }, + "pid": { + "type": "long" + }, + "method": { + "type": "keyword" + }, + "ip_address": { + "type": "ip" + } + } + } + } + }"""; + private static final String TIME_SERIES_TEMPLATE = """ { "index_patterns": [ "logs-*-*" ], @@ -203,7 +244,7 @@ public void testLogsIndexing() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("logsdb", 0); + assertDataStreamBackingIndexMode("logsdb", 0, DATA_STREAM_NAME); rolloverDataStream(client, DATA_STREAM_NAME); indexDocument( client, @@ -218,7 +259,7 @@ public void testLogsIndexing() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("logsdb", 1); + assertDataStreamBackingIndexMode("logsdb", 1, DATA_STREAM_NAME); } public void testLogsStandardIndexModeSwitch() throws IOException { @@ -237,9 +278,9 @@ public void testLogsStandardIndexModeSwitch() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("logsdb", 0); + assertDataStreamBackingIndexMode("logsdb", 0, DATA_STREAM_NAME); - putTemplate(client, "custom-template", STANDARD_TEMPLATE); + putTemplate(client, "custom-template", LOGS_STANDARD_INDEX_MODE); rolloverDataStream(client, DATA_STREAM_NAME); indexDocument( client, @@ -254,7 +295,7 @@ public void testLogsStandardIndexModeSwitch() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - 
assertDataStreamBackingIndexMode("standard", 1); + assertDataStreamBackingIndexMode("standard", 1, DATA_STREAM_NAME); putTemplate(client, "custom-template", LOGS_TEMPLATE); rolloverDataStream(client, DATA_STREAM_NAME); @@ -271,7 +312,7 @@ public void testLogsStandardIndexModeSwitch() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("logsdb", 2); + assertDataStreamBackingIndexMode("logsdb", 2, DATA_STREAM_NAME); } public void testLogsTimeSeriesIndexModeSwitch() throws IOException { @@ -290,7 +331,7 @@ public void testLogsTimeSeriesIndexModeSwitch() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("logsdb", 0); + assertDataStreamBackingIndexMode("logsdb", 0, DATA_STREAM_NAME); putTemplate(client, "custom-template", TIME_SERIES_TEMPLATE); rolloverDataStream(client, DATA_STREAM_NAME); @@ -307,7 +348,7 @@ public void testLogsTimeSeriesIndexModeSwitch() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("time_series", 1); + assertDataStreamBackingIndexMode("time_series", 1, DATA_STREAM_NAME); putTemplate(client, "custom-template", LOGS_TEMPLATE); rolloverDataStream(client, DATA_STREAM_NAME); @@ -324,11 +365,193 @@ public void testLogsTimeSeriesIndexModeSwitch() throws IOException { randomLongBetween(1_000_000L, 2_000_000L) ) ); - assertDataStreamBackingIndexMode("logsdb", 2); + assertDataStreamBackingIndexMode("logsdb", 2, DATA_STREAM_NAME); + } + + public void testLogsDBToStandardReindex() throws IOException { + // LogsDB data stream + putTemplate(client, "logs-template", LOGS_TEMPLATE); + createDataStream(client, "logs-apache-kafka"); + + // Standard data stream + putTemplate(client, "standard-template", STANDARD_TEMPLATE); + createDataStream(client, "standard-apache-kafka"); + + // Index some documents in the LogsDB index + for (int i = 0; i < 10; i++) { + indexDocument( + client, + "logs-apache-kafka", + 
document( + Instant.now().plusSeconds(10), + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomFrom("PUT", "POST", "GET"), + randomAlphaOfLength(64), + randomIp(randomBoolean()), + randomLongBetween(1_000_000L, 2_000_000L) + ) + ); + } + assertDataStreamBackingIndexMode("logsdb", 0, "logs-apache-kafka"); + assertDocCount(client, "logs-apache-kafka", 10); + + // Reindex a LogsDB data stream into a standard data stream + final Request reindexRequest = new Request("POST", "/_reindex?refresh=true"); + reindexRequest.setJsonEntity(""" + { + "source": { + "index": "logs-apache-kafka" + }, + "dest": { + "index": "standard-apache-kafka", + "op_type": "create" + } + } + """); + assertOK(client.performRequest(reindexRequest)); + assertDataStreamBackingIndexMode("standard", 0, "standard-apache-kafka"); + assertDocCount(client, "standard-apache-kafka", 10); + } + + public void testStandardToLogsDBReindex() throws IOException { + // LogsDB data stream + putTemplate(client, "logs-template", LOGS_TEMPLATE); + createDataStream(client, "logs-apache-kafka"); + + // Standard data stream + putTemplate(client, "standard-template", STANDARD_TEMPLATE); + createDataStream(client, "standard-apache-kafka"); + + // Index some documents in a standard index + for (int i = 0; i < 10; i++) { + indexDocument( + client, + "standard-apache-kafka", + document( + Instant.now().plusSeconds(10), + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomFrom("PUT", "POST", "GET"), + randomAlphaOfLength(64), + randomIp(randomBoolean()), + randomLongBetween(1_000_000L, 2_000_000L) + ) + ); + } + assertDataStreamBackingIndexMode("standard", 0, "standard-apache-kafka"); + assertDocCount(client, "standard-apache-kafka", 10); + + // Reindex a standard data stream into a LogsDB data stream + final Request reindexRequest = new Request("POST", "/_reindex?refresh=true"); + reindexRequest.setJsonEntity(""" + { + "source": { + "index": "standard-apache-kafka" + }, + "dest": { + "index": 
"logs-apache-kafka", + "op_type": "create" + } + } + """); + assertOK(client.performRequest(reindexRequest)); + assertDataStreamBackingIndexMode("logsdb", 0, "logs-apache-kafka"); + assertDocCount(client, "logs-apache-kafka", 10); + } + + public void testLogsDBSnapshotCreateRestoreMount() throws IOException { + final String repository = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + registerRepository(repository, FsRepository.TYPE, Settings.builder().put("location", randomAlphaOfLength(6))); + + final String index = randomAlphaOfLength(12).toLowerCase(Locale.ROOT); + createIndex(client, index, Settings.builder().put("index.mode", IndexMode.LOGSDB.getName()).build()); + + for (int i = 0; i < 10; i++) { + indexDocument( + client, + index, + document( + Instant.now().plusSeconds(10), + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomFrom("PUT", "POST", "GET"), + randomAlphaOfLength(64), + randomIp(randomBoolean()), + randomLongBetween(1_000_000L, 2_000_000L) + ) + ); + } + + final String snapshot = randomAlphaOfLength(8).toLowerCase(Locale.ROOT); + deleteSnapshot(repository, snapshot, true); + createSnapshot(client, repository, snapshot, true, index); + wipeDataStreams(); + wipeAllIndices(); + restoreSnapshot(client, repository, snapshot, true, index); + + final String restoreIndex = randomAlphaOfLength(7).toLowerCase(Locale.ROOT); + final Request mountRequest = new Request("POST", "/_snapshot/" + repository + '/' + snapshot + "/_mount"); + mountRequest.addParameter("wait_for_completion", "true"); + mountRequest.setJsonEntity("{\"index\": \"" + index + "\",\"renamed_index\": \"" + restoreIndex + "\"}"); + + assertOK(client.performRequest(mountRequest)); + assertDocCount(client, restoreIndex, 10); + assertThat(getSettings(client, restoreIndex).get("index.mode"), Matchers.equalTo(IndexMode.LOGSDB.getName())); + } + + // NOTE: this test will fail on snapshot creation after fixing + // https://github.com/elastic/elasticsearch/issues/112735 + public void 
testLogsDBSourceOnlySnapshotCreation() throws IOException { + final String repository = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + registerRepository(repository, FsRepository.TYPE, Settings.builder().put("location", randomAlphaOfLength(6))); + // A source-only repository delegates storage to another repository + final String sourceOnlyRepository = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + registerRepository( + sourceOnlyRepository, + "source", + Settings.builder().put("delegate_type", FsRepository.TYPE).put("location", repository) + ); + + final String index = randomAlphaOfLength(12).toLowerCase(Locale.ROOT); + createIndex(client, index, Settings.builder().put("index.mode", IndexMode.LOGSDB.getName()).build()); + + for (int i = 0; i < 10; i++) { + indexDocument( + client, + index, + document( + Instant.now().plusSeconds(10), + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomFrom("PUT", "POST", "GET"), + randomAlphaOfLength(64), + randomIp(randomBoolean()), + randomLongBetween(1_000_000L, 2_000_000L) + ) + ); + } + + final String snapshot = randomAlphaOfLength(8).toLowerCase(Locale.ROOT); + deleteSnapshot(sourceOnlyRepository, snapshot, true); + createSnapshot(client, sourceOnlyRepository, snapshot, true, index); + wipeDataStreams(); + wipeAllIndices(); + // Can't snapshot _source only on an index that has incomplete source ie. 
has _source disabled or filters the source + final ResponseException responseException = expectThrows( + ResponseException.class, + () -> restoreSnapshot(client, sourceOnlyRepository, snapshot, true, index) + ); + assertThat(responseException.getMessage(), Matchers.containsString("wasn't fully snapshotted")); + } + + private static void registerRepository(final String repository, final String type, final Settings.Builder settings) throws IOException { + registerRepository(repository, type, false, settings.build()); } - private void assertDataStreamBackingIndexMode(final String indexMode, int backingIndex) throws IOException { - assertThat(getSettings(client, getWriteBackingIndex(client, DATA_STREAM_NAME, backingIndex)).get("index.mode"), is(indexMode)); + private void assertDataStreamBackingIndexMode(final String indexMode, int backingIndex, final String dataStreamName) + throws IOException { + assertThat(getSettings(client, getWriteBackingIndex(client, dataStreamName, backingIndex)).get("index.mode"), is(indexMode)); } private String document( @@ -364,8 +587,8 @@ private static void putTemplate(final RestClient client, final String templateNa assertOK(client.performRequest(request)); } - private static void indexDocument(final RestClient client, String dataStreamName, String doc) throws IOException { - final Request request = new Request("POST", "/" + dataStreamName + "/_doc?refresh=true"); + private static void indexDocument(final RestClient client, String indexOrtDataStream, String doc) throws IOException { + final Request request = new Request("POST", "/" + indexOrtDataStream + "/_doc?refresh=true"); request.setJsonEntity(doc); final Response response = client.performRequest(request); assertOK(response); @@ -393,4 +616,46 @@ private static Map getSettings(final RestClient client, final St final Request request = new Request("GET", "/" + indexName + "/_settings?flat_settings"); return ((Map>) 
entityAsMap(client.performRequest(request)).get(indexName)).get("settings"); } + + private static void createSnapshot( + RestClient restClient, + String repository, + String snapshot, + boolean waitForCompletion, + final String... indices + ) throws IOException { + final Request request = new Request(HttpPut.METHOD_NAME, "_snapshot/" + repository + '/' + snapshot); + request.addParameter("wait_for_completion", Boolean.toString(waitForCompletion)); + request.setJsonEntity(""" + "indices": $indices + """.replace("$indices", String.join(", ", indices))); + + final Response response = restClient.performRequest(request); + assertThat( + "Failed to create snapshot [" + snapshot + "] in repository [" + repository + "]: " + response, + response.getStatusLine().getStatusCode(), + equalTo(RestStatus.OK.getStatus()) + ); + } + + private static void restoreSnapshot( + final RestClient client, + final String repository, + String snapshot, + boolean waitForCompletion, + final String... indices + ) throws IOException { + final Request request = new Request(HttpPost.METHOD_NAME, "_snapshot/" + repository + '/' + snapshot + "/_restore"); + request.addParameter("wait_for_completion", Boolean.toString(waitForCompletion)); + request.setJsonEntity(""" + "indices": $indices + """.replace("$indices", String.join(", ", indices))); + + final Response response = client.performRequest(request); + assertThat( + "Failed to restore snapshot [" + snapshot + "] from repository [" + repository + "]: " + response, + response.getStatusLine().getStatusCode(), + equalTo(RestStatus.OK.getStatus()) + ); + } } From cb42fd45de3f61166f315c70e53d8519963209d2 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 24 Sep 2024 08:14:58 -0400 Subject: [PATCH 14/14] [ML] Stream Inference API (#113158) (#113423) Create `POST _inference///_stream` and `POST _inference//_stream` API. REST Streaming API will reuse InferenceAction. 
For now, all services and task types will return an HTTP 405 status code and error message. Co-authored-by: Elastic Machine --- docs/changelog/113158.yaml | 5 + .../inference/InferenceService.java | 17 ++ .../inference/action/InferenceAction.java | 18 +- .../action/InferenceActionRequestTests.java | 63 ++++-- .../AsyncInferenceResponseConsumer.java | 68 ++++++ .../inference/InferenceBaseRestTest.java | 61 +++++- .../xpack/inference/InferenceCrudIT.java | 58 +++++ .../mock/TestInferenceServicePlugin.java | 5 + ...stStreamingCompletionServiceExtension.java | 204 ++++++++++++++++++ ...search.inference.InferenceServiceExtension | 1 + .../xpack/inference/InferencePlugin.java | 2 + .../action/TransportInferenceAction.java | 65 ++++-- .../queries/SemanticQueryBuilder.java | 3 +- ...ankFeaturePhaseRankCoordinatorContext.java | 3 +- .../inference/rest/BaseInferenceAction.java | 55 +++++ .../xpack/inference/rest/Paths.java | 7 + .../inference/rest/RestInferenceAction.java | 35 +-- .../rest/RestStreamInferenceAction.java | 43 ++++ .../TextSimilarityRankTests.java | 3 +- .../TextSimilarityTestPlugin.java | 3 +- .../rest/BaseInferenceActionTests.java | 107 +++++++++ .../rest/RestInferenceActionTests.java | 40 +--- .../rest/RestStreamInferenceActionTests.java | 50 +++++ .../TransportCoordinatedInferenceAction.java | 3 +- 24 files changed, 798 insertions(+), 121 deletions(-) create mode 100644 docs/changelog/113158.yaml create mode 100644 x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/AsyncInferenceResponseConsumer.java create mode 100644 x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java create mode 100644 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java diff --git a/docs/changelog/113158.yaml b/docs/changelog/113158.yaml new file mode 100644 index 000000000000..d097ea11b3a2 --- /dev/null +++ b/docs/changelog/113158.yaml @@ -0,0 +1,5 @@ +pr: 113158 +summary: Adds a new Inference API for streaming responses back to the user. +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index a37fb3dd7567..9e9a4cf89037 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -188,4 +188,21 @@ default boolean isInClusterService() { * @return {@link TransportVersion} specifying the version */ TransportVersion getMinimalSupportedVersion(); + + /** + * The set of tasks where this service provider supports using the streaming API. + * @return set of supported task types. Defaults to empty. + */ + default Set supportedStreamingTasks() { + return Set.of(); + } + + /** + * Checks the task type against the set of supported streaming tasks returned by {@link #supportedStreamingTasks()}. 
+ * @param taskType the task that supports streaming + * @return true if the taskType is supported + */ + default boolean canStream(TaskType taskType) { + return supportedStreamingTasks().contains(taskType); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index d898f961651f..a19edd5a0816 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -92,6 +92,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType, private final Map taskSettings; private final InputType inputType; private final TimeValue inferenceTimeout; + private final boolean stream; public Request( TaskType taskType, @@ -100,7 +101,8 @@ public Request( List input, Map taskSettings, InputType inputType, - TimeValue inferenceTimeout + TimeValue inferenceTimeout, + boolean stream ) { this.taskType = taskType; this.inferenceEntityId = inferenceEntityId; @@ -109,6 +111,7 @@ public Request( this.taskSettings = taskSettings; this.inputType = inputType; this.inferenceTimeout = inferenceTimeout; + this.stream = stream; } public Request(StreamInput in) throws IOException { @@ -134,6 +137,9 @@ public Request(StreamInput in) throws IOException { this.query = null; this.inferenceTimeout = DEFAULT_TIMEOUT; } + + // streaming is not supported yet for transport traffic + this.stream = false; } public TaskType getTaskType() { @@ -165,7 +171,7 @@ public TimeValue getInferenceTimeout() { } public boolean isStreaming() { - return false; + return stream; } @Override @@ -261,6 +267,7 @@ public static class Builder { private Map taskSettings = Map.of(); private String query; private TimeValue timeout = DEFAULT_TIMEOUT; + private boolean stream = false; private 
Builder() {} @@ -303,8 +310,13 @@ private Builder setInferenceTimeout(String inferenceTimeout) { return setInferenceTimeout(TimeValue.parseTimeValue(inferenceTimeout, TIMEOUT.getPreferredName())); } + public Builder setStream(boolean stream) { + this.stream = stream; + return this; + } + public Request build() { - return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout); + return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index f41e117e75b9..a9ca5e6da872 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -46,7 +46,8 @@ protected InferenceAction.Request createTestInstance() { randomList(1, 5, () -> randomAlphaOfLength(8)), randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))), randomFrom(InputType.values()), - TimeValue.timeValueMillis(randomLongBetween(1, 2048)) + TimeValue.timeValueMillis(randomLongBetween(1, 2048)), + false ); } @@ -80,7 +81,8 @@ public void testValidation_TextEmbedding() { List.of("input"), null, null, - null + null, + false ); ActionRequestValidationException e = request.validate(); assertNull(e); @@ -94,7 +96,8 @@ public void testValidation_Rerank() { List.of("input"), null, null, - null + null, + false ); ActionRequestValidationException e = request.validate(); assertNull(e); @@ -108,7 +111,8 @@ public void testValidation_TextEmbedding_Null() { null, null, null, - null + null, + false ); ActionRequestValidationException inputNullError = inputNullRequest.validate(); 
assertNotNull(inputNullError); @@ -123,7 +127,8 @@ public void testValidation_TextEmbedding_Empty() { List.of(), null, null, - null + null, + false ); ActionRequestValidationException inputEmptyError = inputEmptyRequest.validate(); assertNotNull(inputEmptyError); @@ -138,7 +143,8 @@ public void testValidation_Rerank_Null() { List.of("input"), null, null, - null + null, + false ); ActionRequestValidationException queryNullError = queryNullRequest.validate(); assertNotNull(queryNullError); @@ -153,7 +159,8 @@ public void testValidation_Rerank_Empty() { List.of("input"), null, null, - null + null, + false ); ActionRequestValidationException queryEmptyError = queryEmptyRequest.validate(); assertNotNull(queryEmptyError); @@ -185,7 +192,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInput(), instance.getTaskSettings(), instance.getInputType(), - instance.getInferenceTimeout() + instance.getInferenceTimeout(), + false ); } case 1 -> new InferenceAction.Request( @@ -195,7 +203,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInput(), instance.getTaskSettings(), instance.getInputType(), - instance.getInferenceTimeout() + instance.getInferenceTimeout(), + false ); case 2 -> { var changedInputs = new ArrayList(instance.getInput()); @@ -207,7 +216,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc changedInputs, instance.getTaskSettings(), instance.getInputType(), - instance.getInferenceTimeout() + instance.getInferenceTimeout(), + false ); } case 3 -> { @@ -225,7 +235,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInput(), taskSettings, instance.getInputType(), - instance.getInferenceTimeout() + instance.getInferenceTimeout(), + false ); } case 4 -> { @@ -237,7 +248,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInput(), 
instance.getTaskSettings(), nextInputType, - instance.getInferenceTimeout() + instance.getInferenceTimeout(), + false ); } case 5 -> new InferenceAction.Request( @@ -247,7 +259,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInput(), instance.getTaskSettings(), instance.getInputType(), - instance.getInferenceTimeout() + instance.getInferenceTimeout(), + false ); case 6 -> { var newDuration = Duration.of( @@ -262,7 +275,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInput(), instance.getTaskSettings(), instance.getInputType(), - TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis()) + TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis()), + false ); } default -> throw new UnsupportedOperationException(); @@ -279,7 +293,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getInput().subList(0, 1), instance.getTaskSettings(), InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } else if (version.before(TransportVersions.V_8_13_0)) { return new InferenceAction.Request( @@ -289,7 +304,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getInput(), instance.getTaskSettings(), InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } else if (version.before(TransportVersions.V_8_13_0) && (instance.getInputType() == InputType.UNSPECIFIED @@ -302,7 +318,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getInput(), instance.getTaskSettings(), InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } else if (version.before(TransportVersions.V_8_13_0) && (instance.getInputType() == InputType.CLUSTERING || instance.getInputType() == 
InputType.CLASSIFICATION)) { @@ -313,7 +330,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getInput(), instance.getTaskSettings(), InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } else if (version.before(TransportVersions.V_8_14_0)) { return new InferenceAction.Request( @@ -323,7 +341,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getInput(), instance.getTaskSettings(), instance.getInputType(), - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } @@ -339,7 +358,8 @@ public void testWriteTo_WhenVersionIsOnAfterUnspecifiedAdded() throws IOExceptio List.of(), Map.of(), InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ), TransportVersions.V_8_13_0 ); @@ -353,7 +373,8 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn List.of(), Map.of(), InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); InferenceAction.Request deserializedInstance = copyWriteable( diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/AsyncInferenceResponseConsumer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/AsyncInferenceResponseConsumer.java new file mode 100644 index 000000000000..eb5f3c75bab6 --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/AsyncInferenceResponseConsumer.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.apache.http.HttpEntity; +import org.apache.http.HttpResponse; +import org.apache.http.entity.ContentType; +import org.apache.http.nio.ContentDecoder; +import org.apache.http.nio.IOControl; +import org.apache.http.nio.protocol.AbstractAsyncResponseConsumer; +import org.apache.http.nio.util.SimpleInputBuffer; +import org.apache.http.protocol.HttpContext; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.concurrent.atomic.AtomicReference; + +class AsyncInferenceResponseConsumer extends AbstractAsyncResponseConsumer { + private final AtomicReference httpResponse = new AtomicReference<>(); + private final Deque collector = new ArrayDeque<>(); + private final ServerSentEventParser sseParser = new ServerSentEventParser(); + private final SimpleInputBuffer inputBuffer = new SimpleInputBuffer(4096); + + @Override + protected void onResponseReceived(HttpResponse httpResponse) { + this.httpResponse.set(httpResponse); + } + + @Override + protected void onContentReceived(ContentDecoder contentDecoder, IOControl ioControl) throws IOException { + inputBuffer.consumeContent(contentDecoder); + } + + @Override + protected void onEntityEnclosed(HttpEntity httpEntity, ContentType contentType) { + httpResponse.updateAndGet(response -> { + response.setEntity(httpEntity); + return response; + }); + } + + @Override + protected HttpResponse buildResult(HttpContext httpContext) { + var allBytes = new byte[inputBuffer.length()]; + try { + inputBuffer.read(allBytes); + sseParser.parse(allBytes).forEach(collector::offer); + } catch (IOException e) { + failed(e); + } + return 
httpResponse.get(); + } + + @Override + protected void releaseResources() {} + + Deque events() { + return collector; + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index f30f2e8fe201..c19cd916055d 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -9,7 +9,9 @@ import org.apache.http.util.EntityUtils; import org.elasticsearch.client.Request; +import org.elasticsearch.client.RequestOptions; import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseListener; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; @@ -19,11 +21,15 @@ import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; import org.junit.ClassRule; import java.io.IOException; +import java.util.Deque; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; @@ -72,6 +78,23 @@ static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody) { """, taskType); } + static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody) { + var taskType = taskTypeInBody == null ? 
"" : "\"task_type\": \"" + taskTypeInBody + "\","; + return Strings.format(""" + { + %s + "service": "streaming_completion_test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + "temperature": 3 + } + } + """, taskType); + } + static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) { var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\","; return Strings.format(""" @@ -252,6 +275,32 @@ protected Map inferOnMockService(String modelId, List in return inferOnMockServiceInternal(endpoint, input); } + protected Deque streamInferOnMockService(String modelId, TaskType taskType, List input) throws Exception { + var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId); + return callAsync(endpoint, input); + } + + private Deque callAsync(String endpoint, List input) throws Exception { + var responseConsumer = new AsyncInferenceResponseConsumer(); + var request = new Request("POST", endpoint); + request.setJsonEntity(jsonBody(input)); + request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build()); + var latch = new CountDownLatch(1); + client().performRequestAsync(request, new ResponseListener() { + @Override + public void onSuccess(Response response) { + latch.countDown(); + } + + @Override + public void onFailure(Exception exception) { + latch.countDown(); + } + }); + assertTrue(latch.await(30, TimeUnit.SECONDS)); + return responseConsumer.events(); + } + protected Map inferOnMockService(String modelId, TaskType taskType, List input) throws IOException { var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); return inferOnMockServiceInternal(endpoint, input); @@ -259,7 +308,13 @@ protected Map inferOnMockService(String modelId, TaskType taskTy private Map inferOnMockServiceInternal(String endpoint, List input) throws IOException { var 
request = new Request("POST", endpoint); + request.setJsonEntity(jsonBody(input)); + var response = client().performRequest(request); + assertOkOrCreated(response); + return entityAsMap(response); + } + private String jsonBody(List input) { var bodyBuilder = new StringBuilder("{\"input\": ["); for (var in : input) { bodyBuilder.append('"').append(in).append('"').append(','); @@ -267,11 +322,7 @@ private Map inferOnMockServiceInternal(String endpoint, List { + switch (event.name()) { + case EVENT -> assertThat(event.value(), equalToIgnoringCase("error")); + case DATA -> assertThat( + event.value(), + containsString( + "Streaming is not allowed for service [streaming_completion_test_service] and task [sparse_embedding]" + ) + ); + } + }); + } finally { + deleteModel(modelId); + } + } + + public void testSupportedStream() throws Exception { + String modelId = "streaming"; + putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION)); + var singleModel = getModel(modelId); + assertEquals(modelId, singleModel.get("inference_id")); + assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type")); + + var input = IntStream.range(0, randomInt(10)).mapToObj(i -> randomAlphaOfLength(10)).toList(); + + try { + var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input); + + var expectedResponses = Stream.concat( + input.stream().map(String::toUpperCase).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"), + Stream.of("[DONE]") + ).iterator(); + assertThat(events.size(), equalTo((input.size() + 1) * 2)); + events.forEach(event -> { + switch (event.name()) { + case EVENT -> assertThat(event.value(), equalToIgnoringCase("message")); + case DATA -> assertThat(event.value(), equalTo(expectedResponses.next())); + } + }); + } finally { + deleteModel(modelId); + } + } } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java index 752472b90374..eef0da909f52 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java @@ -44,6 +44,11 @@ public List getNamedWriteables() { ServiceSettings.class, TestRerankingServiceExtension.TestServiceSettings.NAME, TestRerankingServiceExtension.TestServiceSettings::new + ), + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + TestStreamingCompletionServiceExtension.TestServiceSettings.NAME, + TestStreamingCompletionServiceExtension.TestServiceSettings::new ) ); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java new file mode 100644 index 000000000000..3d72b1f2729b --- /dev/null +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -0,0 +1,204 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.mock; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Flow; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION; + +public class TestStreamingCompletionServiceExtension implements InferenceServiceExtension { + @Override + public List getInferenceServiceFactories() { + return List.of(TestInferenceService::new); + } + + public static class TestInferenceService extends AbstractTestInferenceService { + private static final String NAME = "streaming_completion_test_service"; + private static final Set supportedStreamingTasks = Set.of(TaskType.COMPLETION); 
+ + public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {} + + @Override + public String name() { + return NAME; + } + + @Override + protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { + return TestServiceSettings.fromMap(serviceSettingsMap); + } + + @Override + @SuppressWarnings("unchecked") + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + Set platformArchitectures, + ActionListener parsedModelListener + ) { + var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); + var serviceSettings = TestSparseInferenceServiceExtension.TestServiceSettings.fromMap(serviceSettingsMap); + var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap); + + var taskSettingsMap = getTaskSettingsMap(config); + var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); + + parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings)); + } + + @Override + public void infer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + switch (model.getConfigurations().getTaskType()) { + case COMPLETION -> listener.onResponse(makeResults(input)); + default -> listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + } + + private StreamingChatCompletionResults makeResults(List input) { + var responseIter = input.stream().map(String::toUpperCase).iterator(); + return new StreamingChatCompletionResults(subscriber -> { + subscriber.onSubscribe(new Flow.Subscription() { + @Override + public void request(long n) { + if (responseIter.hasNext()) { + subscriber.onNext(completionChunk(responseIter.next())); + } else { + subscriber.onComplete(); + } + } + + @Override + public void 
cancel() {} + }); + }); + } + + private ChunkedToXContent completionChunk(String delta) { + return params -> Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.startArray(COMPLETION), + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field("delta", delta), + ChunkedToXContentHelper.endObject(), + ChunkedToXContentHelper.endArray(), + ChunkedToXContentHelper.endObject() + ); + } + + @Override + public void chunkedInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + TimeValue timeout, + ActionListener> listener + ) { + listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + + @Override + public Set supportedStreamingTasks() { + return supportedStreamingTasks; + } + } + + public record TestServiceSettings(String modelId) implements ServiceSettings { + public static final String NAME = "streaming_completion_test_service_settings"; + + public TestServiceSettings(StreamInput in) throws IOException { + this(in.readString()); + } + + public static TestServiceSettings fromMap(Map map) { + var modelId = map.remove("model").toString(); + + if (modelId == null) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError("missing model id"); + throw validationException; + } + + return new TestServiceSettings(modelId); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId()); + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + 
@Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject().field("model", modelId()).endObject(); + } + } +} diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension b/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension index 690168b538fb..c996a33d1e91 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension @@ -1,3 +1,4 @@ org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension +org.elasticsearch.xpack.inference.mock.TestStreamingCompletionServiceExtension diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 1cec996400a9..a6972ddc214f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -73,6 +73,7 @@ import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; +import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; import 
org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockService; @@ -167,6 +168,7 @@ public List getRestHandlers( ) { return List.of( new RestInferenceAction(), + new RestStreamInferenceAction(), new RestGetInferenceModelAction(), new RestPutInferenceModelAction(), new RestDeleteInferenceEndpointAction(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index bfdfca166ef3..803e8f1e0761 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -26,10 +27,17 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.elasticsearch.core.Strings.format; + public class TransportInferenceAction extends HandledTransportAction { private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; + private static final Set> supportsStreaming = Set.of(); + private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; private final InferenceStats inferenceStats; @@ 
-101,15 +109,40 @@ private void inferOnService( InferenceService service, ActionListener listener ) { - service.infer( - model, - request.getQuery(), - request.getInput(), - request.getTaskSettings(), - request.getInputType(), - request.getInferenceTimeout(), - createListener(request, listener) - ); + if (request.isStreaming() == false || service.canStream(request.getTaskType())) { + service.infer( + model, + request.getQuery(), + request.getInput(), + request.getTaskSettings(), + request.getInputType(), + request.getInferenceTimeout(), + createListener(request, listener) + ); + } else { + listener.onFailure(unsupportedStreamingTaskException(request, service)); + } + } + + private ElasticsearchStatusException unsupportedStreamingTaskException(InferenceAction.Request request, InferenceService service) { + var supportedTasks = service.supportedStreamingTasks(); + if (supportedTasks.isEmpty()) { + return new ElasticsearchStatusException( + format("Streaming is not allowed for service [%s].", service.name()), + RestStatus.METHOD_NOT_ALLOWED + ); + } else { + var validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(",")); + return new ElasticsearchStatusException( + format( + "Streaming is not allowed for service [%s] and task [%s]. 
Supported tasks: [%s]", + service.name(), + request.getTaskType(), + validTasks + ), + RestStatus.METHOD_NOT_ALLOWED + ); + } } private ActionListener createListener( @@ -118,17 +151,9 @@ private ActionListener createListener( ) { if (request.isStreaming()) { return listener.delegateFailureAndWrap((l, inferenceResults) -> { - if (inferenceResults.isStreaming()) { - var taskProcessor = streamingTaskManager.create( - STREAMING_INFERENCE_TASK_TYPE, - STREAMING_TASK_ACTION - ); - inferenceResults.publisher().subscribe(taskProcessor); - l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor)); - } else { - // if we asked for streaming but the provider doesn't support it, for now we're going to get back the single response - l.onResponse(new InferenceAction.Response(inferenceResults)); - } + var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); + inferenceResults.publisher().subscribe(taskProcessor); + l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor)); }); } return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults))); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 8f1e28d0d8ee..7f21f94d3327 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -204,7 +204,8 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu List.of(query), Map.of(), InputType.SEARCH, - InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, + false ); queryRewriteContext.registerAsyncAction( diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index cad11cbdc9d5..0ff48bfd493b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -144,7 +144,8 @@ protected InferenceAction.Request generateRequest(List docFeatures) { docFeatures, Map.of(), InputType.SEARCH, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java new file mode 100644 index 000000000000..e72e68052f64 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.rest; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import java.io.IOException; + +import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; +import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; + +abstract class BaseInferenceAction extends BaseRestHandler { + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String inferenceEntityId; + TaskType taskType; + if (restRequest.hasParam(INFERENCE_ID)) { + inferenceEntityId = restRequest.param(INFERENCE_ID); + taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID)); + } else { + inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID); + taskType = TaskType.ANY; + } + + InferenceAction.Request.Builder requestBuilder; + try (var parser = restRequest.contentParser()) { + requestBuilder = InferenceAction.Request.parseRequest(inferenceEntityId, taskType, parser); + } + + var inferTimeout = restRequest.paramAsTime( + InferenceAction.Request.TIMEOUT.getPreferredName(), + InferenceAction.Request.DEFAULT_TIMEOUT + ); + requestBuilder.setInferenceTimeout(inferTimeout); + var request = prepareInferenceRequest(requestBuilder); + return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel)); + } + + protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) { + return builder.build(); + } + + protected abstract ActionListener listener(RestChannel channel); +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java index e33931f3d2f8..9f64b58e48b5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java @@ -15,6 +15,13 @@ public final class Paths { static final String TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/{" + INFERENCE_ID + "}"; static final String INFERENCE_DIAGNOSTICS_PATH = "_inference/.diagnostics"; + static final String STREAM_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_stream"; + static final String STREAM_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + + TASK_TYPE_OR_INFERENCE_ID + + "}/{" + + INFERENCE_ID + + "}/_stream"; + private Paths() { } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java index f5c30d0a94c5..0fbc2f8214cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java @@ -7,26 +7,21 @@ package org.elasticsearch.xpack.inference.rest; -import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestChunkedToXContentListener; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import 
java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.POST; -import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH; import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH; -import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; @ServerlessScope(Scope.PUBLIC) -public class RestInferenceAction extends BaseRestHandler { +public class RestInferenceAction extends BaseInferenceAction { @Override public String getName() { return "inference_action"; @@ -38,27 +33,7 @@ public List routes() { } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - String inferenceEntityId; - TaskType taskType; - if (restRequest.hasParam(INFERENCE_ID)) { - inferenceEntityId = restRequest.param(INFERENCE_ID); - taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID)); - } else { - inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID); - taskType = TaskType.ANY; - } - - InferenceAction.Request.Builder requestBuilder; - try (var parser = restRequest.contentParser()) { - requestBuilder = InferenceAction.Request.parseRequest(inferenceEntityId, taskType, parser); - } - - var inferTimeout = restRequest.paramAsTime( - InferenceAction.Request.TIMEOUT.getPreferredName(), - InferenceAction.Request.DEFAULT_TIMEOUT - ); - requestBuilder.setInferenceTimeout(inferTimeout); - return channel -> client.execute(InferenceAction.INSTANCE, requestBuilder.build(), new RestChunkedToXContentListener<>(channel)); + protected ActionListener listener(RestChannel channel) { + return new RestChunkedToXContentListener<>(channel); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java new file mode 100644 index 000000000000..875c288da52b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rest; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.Scope; +import org.elasticsearch.rest.ServerlessScope; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_INFERENCE_ID_PATH; +import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_TASK_TYPE_INFERENCE_ID_PATH; + +@ServerlessScope(Scope.PUBLIC) +public class RestStreamInferenceAction extends BaseInferenceAction { + @Override + public String getName() { + return "stream_inference_action"; + } + + @Override + public List routes() { + return List.of(new Route(POST, STREAM_INFERENCE_ID_PATH), new Route(POST, STREAM_TASK_TYPE_INFERENCE_ID_PATH)); + } + + @Override + protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) { + return builder.setStream(true).build(); + } + + @Override + protected ActionListener listener(RestChannel channel) { + return new ServerSentEventsRestActionListener(channel); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java index a26dc50097cf..a042fca44fdb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java @@ -92,7 +92,8 @@ protected InferenceAction.Request generateRequest(List docFeatures) { docFeatures, Map.of("inferenceResultCount", inferenceResultCount), InputType.SEARCH, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index 6d0c15d5c0bf..120527f48954 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -312,7 +312,8 @@ protected InferenceAction.Request generateRequest(List docFeatures) { docFeatures, Map.of("throwing", true), InputType.SEARCH, - InferenceAction.Request.DEFAULT_TIMEOUT + InferenceAction.Request.DEFAULT_TIMEOUT, + false ); } }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java new file mode 100644 index 000000000000..05a8d52be5df --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rest; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestChunkedToXContentListener; +import org.elasticsearch.test.rest.FakeRestRequest; +import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; +import org.junit.Before; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class BaseInferenceActionTests extends RestActionTestCase { + + @Before + public void setUpAction() { + controller().registerHandler(new BaseInferenceAction() { + @Override + protected ActionListener listener(RestChannel channel) { + return new RestChunkedToXContentListener<>(channel); + } + + @Override + public String getName() { + return "base_inference_action"; + } + + @Override + public List routes() { + return List.of(new Route(POST, route("{task_type_or_id}"))); + } + }); + } + + private static String route(String param) { + return "_route/" + param; + } + + public void testUsesDefaultTimeout() { + SetOnce executeCalled = new SetOnce<>(); + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, 
instanceOf(InferenceAction.Request.class)); + + var request = (InferenceAction.Request) actionRequest; + assertThat(request.getInferenceTimeout(), is(InferenceAction.Request.DEFAULT_TIMEOUT)); + + executeCalled.set(true); + return createResponse(); + })); + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(route("test")) + .withContent(new BytesArray("{}"), XContentType.JSON) + .build(); + dispatchRequest(inferenceRequest); + assertThat(executeCalled.get(), equalTo(true)); + } + + public void testUses3SecondTimeoutFromParams() { + SetOnce executeCalled = new SetOnce<>(); + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + + var request = (InferenceAction.Request) actionRequest; + assertThat(request.getInferenceTimeout(), is(TimeValue.timeValueSeconds(3))); + + executeCalled.set(true); + return createResponse(); + })); + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(route("test")) + .withParams(new HashMap<>(Map.of("timeout", "3s"))) + .withContent(new BytesArray("{}"), XContentType.JSON) + .build(); + dispatchRequest(inferenceRequest); + assertThat(executeCalled.get(), equalTo(true)); + } + + static InferenceAction.Response createResponse() { + return new InferenceAction.Response( + new InferenceTextEmbeddingByteResults( + List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1 })) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java index 48e5d54a6273..1b0df1b4a20d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java @@ -9,19 +9,14 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.core.TimeValue; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.junit.Before; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - +import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -33,13 +28,13 @@ public void setUpAction() { controller().registerHandler(new RestInferenceAction()); } - public void testUsesDefaultTimeout() { + public void testStreamIsFalse() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); var request = (InferenceAction.Request) actionRequest; - assertThat(request.getInferenceTimeout(), is(InferenceAction.Request.DEFAULT_TIMEOUT)); + assertThat(request.isStreaming(), is(false)); executeCalled.set(true); return createResponse(); @@ -52,33 +47,4 @@ public void testUsesDefaultTimeout() { dispatchRequest(inferenceRequest); assertThat(executeCalled.get(), equalTo(true)); } - - public void testUses3SecondTimeoutFromParams() { - SetOnce executeCalled = new SetOnce<>(); - verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { - assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); - - var request = (InferenceAction.Request) actionRequest; - 
assertThat(request.getInferenceTimeout(), is(TimeValue.timeValueSeconds(3))); - - executeCalled.set(true); - return createResponse(); - })); - - RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) - .withPath("_inference/test") - .withParams(new HashMap<>(Map.of("timeout", "3s"))) - .withContent(new BytesArray("{}"), XContentType.JSON) - .build(); - dispatchRequest(inferenceRequest); - assertThat(executeCalled.get(), equalTo(true)); - } - - private static InferenceAction.Response createResponse() { - return new InferenceAction.Response( - new InferenceTextEmbeddingByteResults( - List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1 })) - ) - ); - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java new file mode 100644 index 000000000000..b999e2c9b72f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.rest; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.rest.FakeRestRequest; +import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.junit.Before; + +import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class RestStreamInferenceActionTests extends RestActionTestCase { + + @Before + public void setUpAction() { + controller().registerHandler(new RestStreamInferenceAction()); + } + + public void testStreamIsTrue() { + SetOnce executeCalled = new SetOnce<>(); + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + + var request = (InferenceAction.Request) actionRequest; + assertThat(request.isStreaming(), is(true)); + + executeCalled.set(true); + return createResponse(); + })); + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath("_inference/test/_stream") + .withContent(new BytesArray("{}"), XContentType.JSON) + .build(); + dispatchRequest(inferenceRequest); + assertThat(executeCalled.get(), equalTo(true)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java index fd13e3de4e6c..ab5a9d43fd6d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java @@ -126,7 +126,8 @@ private void doInferenceServiceModel(CoordinatedInferenceAction.Request request, request.getInputs(), request.getTaskSettings(), inputType, - request.getInferenceTimeout() + request.getInferenceTimeout(), + false ), listener.delegateFailureAndWrap((l, r) -> l.onResponse(translateInferenceServiceResponse(r.getResults()))) );