From 20b8ba2415f8c7c52f1d62c2fe83ad0e855567bc Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Mon, 9 Jun 2025 11:55:12 +0100 Subject: [PATCH 1/5] Increment inference stats counter for shard bulk inference calls This change updates the inference stats counter to include chunked inference calls performed by the shard bulk inference filter on all semantic text fields. It ensures that usage of inference on semantic text fields is properly recorded in the stats. --- .../xpack/inference/InferencePlugin.java | 12 ++- .../ShardBulkInferenceActionFilter.java | 22 ++++- .../ShardBulkInferenceActionFilterTests.java | 95 ++++++++++++++++--- 3 files changed, 108 insertions(+), 21 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 915a4d3f7af9b..2709d9de19c5c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -344,22 +344,24 @@ public Collection createComponents(PluginServices services) { } inferenceServiceRegistry.set(serviceRegistry); + var meterRegistry = services.telemetryProvider().getMeterRegistry(); + var inferenceStats = InferenceStats.create(meterRegistry); + var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats); + var actionFilter = new ShardBulkInferenceActionFilter( services.clusterService(), serviceRegistry, modelRegistry.get(), getLicenseState(), - services.indexingPressure() + services.indexingPressure(), + inferenceStats ); shardBulkInferenceActionFilter.set(actionFilter); - var meterRegistry = services.telemetryProvider().getMeterRegistry(); - var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry)); - components.add(serviceRegistry); 
components.add(modelRegistry.get()); components.add(httpClientManager); - components.add(inferenceStats); + components.add(inferenceStatsBinding); // Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting, // if the rate limiting feature flags are enabled, otherwise provide noop implementation diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index a4ab8663e8664..66e16ccb81952 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -63,6 +63,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.io.IOException; import java.util.ArrayList; @@ -78,6 +79,8 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; /** * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified @@ -112,6 +115,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { private final ModelRegistry 
modelRegistry; private final XPackLicenseState licenseState; private final IndexingPressure indexingPressure; + private final InferenceStats inferenceStats; private volatile long batchSizeInBytes; public ShardBulkInferenceActionFilter( @@ -119,13 +123,15 @@ public ShardBulkInferenceActionFilter( InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, XPackLicenseState licenseState, - IndexingPressure indexingPressure + IndexingPressure indexingPressure, + InferenceStats inferenceStats ) { this.clusterService = clusterService; this.inferenceServiceRegistry = inferenceServiceRegistry; this.modelRegistry = modelRegistry; this.licenseState = licenseState; this.indexingPressure = indexingPressure; + this.inferenceStats = inferenceStats; this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes(); clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize); } @@ -386,10 +392,12 @@ public void onFailure(Exception exc) { public void onResponse(List<ChunkedInference> results) { try (onFinish) { var requestsIterator = requests.iterator(); + int success = 0; for (ChunkedInference result : results) { var request = requestsIterator.next(); var acc = inferenceResults.get(request.bulkItemIndex); if (result instanceof ChunkedInferenceError error) { + recordRequestCountMetrics(inferenceProvider.model, 1, error.exception()); acc.addFailure( new InferenceException( "Exception when running inference id [{}] on field [{}]", @@ -399,6 +407,7 @@ public void onResponse(List<ChunkedInference> results) { ) ); } else { + success++; acc.addOrUpdateResponse( new FieldInferenceResponse( request.field(), @@ -412,12 +421,16 @@ public void onResponse(List<ChunkedInference> results) { ); } } + if (success > 0) { + recordRequestCountMetrics(inferenceProvider.model, success, null); + } } } @Override public void onFailure(Exception exc) { try (onFinish) { + recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc); for
(FieldInferenceRequest request : requests) { addInferenceResponseFailure( request.bulkItemIndex, @@ -444,6 +457,13 @@ public void onFailure(Exception exc) { ); } + private void recordRequestCountMetrics(Model model, int incrementBy, Throwable throwable) { + Map<String, Object> requestCountAttributes = new HashMap<>(); + requestCountAttributes.putAll(modelAttributes(model)); + requestCountAttributes.putAll(responseAttributes(throwable)); + inferenceStats.requestCount().incrementBy(incrementBy, requestCountAttributes); + } + /** * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap} * for the specified {@code item}. diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index f592774b7a356..4af7b5aa7508a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -66,6 +66,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.junit.After; import org.junit.Before; import org.mockito.stubbing.Answer; @@ -80,6 +81,7 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.index.IndexingPressure.MAX_COORDINATING_BYTES; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; @@ -103,9 +105,11 @@ import static
org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.assertArg; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.longThat; import static org.mockito.Mockito.any; +import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -127,7 +131,9 @@ public ShardBulkInferenceActionFilterTests(boolean useLegacyFormat) { @ParametersFactory public static Iterable<Object[]> parameters() throws Exception { - return List.of(new Object[] { true }, new Object[] { false }); + List<Object[]> lst = new ArrayList<>(); + lst.add(new Object[] { true }); + return lst; } @Before @@ -142,7 +148,15 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, true); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(), + NOOP_INDEXING_PRESSURE, + useLegacyFormat, + true, + inferenceStats + ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -167,8 +181,16 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testLicenseInvalidForInference() throws InterruptedException { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); StaticModel model = StaticModel.createRandomInstance(); - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, false); + ShardBulkInferenceActionFilter filter = createFilter( +
threadPool, + Map.of(), + NOOP_INDEXING_PRESSURE, + useLegacyFormat, + false, + inferenceStats + ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -205,13 +227,15 @@ public void testLicenseInvalidForInference() throws InterruptedException { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - true + true, + inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { @@ -251,14 +275,15 @@ public void testInferenceNotFound() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testItemFailures() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); - ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - true + true, + inferenceStats ); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success"))); @@ -316,10 +341,29 @@ public void testItemFailures() throws Exception { request.setInferenceFieldMap(inferenceFieldMap); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + + AtomicInteger success = new AtomicInteger(0); + AtomicInteger failed = new AtomicInteger(0); + 
verify(inferenceStats.requestCount(), atMost(3)).incrementBy(anyLong(), assertArg(attributes -> { + var statusCode = attributes.get("status_code"); + if (statusCode == null) { + failed.incrementAndGet(); + assertThat(attributes.get("error.type"), is("IllegalArgumentException")); + } else { + success.incrementAndGet(); + assertThat(statusCode, is(200)); + } + assertThat(attributes.get("task_type"), is(model.getTaskType().toString())); + assertThat(attributes.get("model_id"), is(model.getServiceSettings().modelId())); + assertThat(attributes.get("service"), is(model.getConfigurations().getService())); + })); + assertThat(success.get(), equalTo(1)); + assertThat(failed.get(), equalTo(2)); } @SuppressWarnings({ "unchecked", "rawtypes" }) public void testExplicitNull() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success"))); @@ -329,7 +373,8 @@ public void testExplicitNull() throws Exception { Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - true + true, + inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -394,13 +439,15 @@ public void testExplicitNull() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testHandleEmptyInput() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - true + true, + inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -447,6 +494,7 @@ public void testHandleEmptyInput() 
throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testManyRandomDocs() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); Map inferenceModelMap = new HashMap<>(); int numModels = randomIntBetween(1, 3); for (int i = 0; i < numModels; i++) { @@ -471,7 +519,14 @@ public void testManyRandomDocs() throws Exception { modifiedRequests[id] = res[1]; } - ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, NOOP_INDEXING_PRESSURE, useLegacyFormat, true); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + inferenceModelMap, + NOOP_INDEXING_PRESSURE, + useLegacyFormat, + true, + inferenceStats + ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -503,6 +558,7 @@ public void testManyRandomDocs() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testIndexingPressure() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY); final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING); @@ -511,7 +567,8 @@ public void testIndexingPressure() throws Exception { Map.of(sparseModel.getInferenceEntityId(), sparseModel, denseModel.getInferenceEntityId(), denseModel), indexingPressure, useLegacyFormat, - true + true, + inferenceStats ); XContentBuilder doc0Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "a test value"); @@ -619,6 +676,7 @@ public void testIndexingPressure() throws Exception { @SuppressWarnings("unchecked") public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception { + final InferenceStats inferenceStats = new 
InferenceStats(mock(), mock()); final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure( Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build() ); @@ -628,7 +686,8 @@ public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Excep Map.of(sparseModel.getInferenceEntityId(), sparseModel), indexingPressure, useLegacyFormat, - true + true, + inferenceStats ); XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar"); @@ -702,6 +761,7 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (bytesUsed(doc1Source) + 1) + "b").build() ); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); sparseModel.putResult("bar", randomChunkedInferenceEmbedding(sparseModel, List.of("bar"))); @@ -710,7 +770,8 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except Map.of(sparseModel.getInferenceEntityId(), sparseModel), indexingPressure, useLegacyFormat, - true + true, + inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -813,12 +874,14 @@ public void testIndexingPressurePartialFailure() throws Exception { .build() ); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); final ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(sparseModel.getInferenceEntityId(), sparseModel), indexingPressure, useLegacyFormat, - true + true, + inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -893,7 +956,8 @@ private static ShardBulkInferenceActionFilter createFilter( Map modelMap, IndexingPressure indexingPressure, boolean useLegacyFormat, - boolean isLicenseValidForInference + boolean isLicenseValidForInference, + InferenceStats inferenceStats ) { ModelRegistry modelRegistry = 
mock(ModelRegistry.class); Answer unparsedModelAnswer = invocationOnMock -> { @@ -970,7 +1034,8 @@ private static ShardBulkInferenceActionFilter createFilter( inferenceServiceRegistry, modelRegistry, licenseState, - indexingPressure + indexingPressure, + inferenceStats ); } From b9e4540f469a286b2667b3c8312c3d43c297078e Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Mon, 9 Jun 2025 11:58:45 +0100 Subject: [PATCH 2/5] Update docs/changelog/129140.yaml --- docs/changelog/129140.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/129140.yaml diff --git a/docs/changelog/129140.yaml b/docs/changelog/129140.yaml new file mode 100644 index 0000000000000..e7ee59122c34f --- /dev/null +++ b/docs/changelog/129140.yaml @@ -0,0 +1,5 @@ +pr: 129140 +summary: Increment inference stats counter for shard bulk inference calls +area: Machine Learning +type: enhancement +issues: [] From d66afd00c3453b09fd4a44d0914ae30a465c55d8 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Mon, 9 Jun 2025 16:48:34 +0100 Subject: [PATCH 3/5] Add a source attribute in the inference count to distinguish semantic_text bulk indexing from other usage --- .../inference/action/filter/ShardBulkInferenceActionFilter.java | 1 + .../action/filter/ShardBulkInferenceActionFilterTests.java | 1 + 2 files changed, 2 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 66e16ccb81952..e3a51d188d109 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -461,6 +461,7 @@ private void recordRequestCountMetrics(Model model, int incrementBy, Throwable t Map<String, Object>
requestCountAttributes = new HashMap<>(); requestCountAttributes.putAll(modelAttributes(model)); requestCountAttributes.putAll(responseAttributes(throwable)); + requestCountAttributes.put("source", "semantic_text_bulk"); inferenceStats.requestCount().incrementBy(incrementBy, requestCountAttributes); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 4af7b5aa7508a..c3fb04e5beff3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -356,6 +356,7 @@ public void testItemFailures() throws Exception { assertThat(attributes.get("task_type"), is(model.getTaskType().toString())); assertThat(attributes.get("model_id"), is(model.getServiceSettings().modelId())); assertThat(attributes.get("service"), is(model.getConfigurations().getService())); + assertThat(attributes.get("source"), is("semantic_text_bulk")) })); assertThat(success.get(), equalTo(1)); assertThat(failed.get(), equalTo(2)); From 24bdcdad02b13c233ca8e988ada680436df4ba7a Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Mon, 9 Jun 2025 16:49:28 +0100 Subject: [PATCH 4/5] fix typo --- .../action/filter/ShardBulkInferenceActionFilterTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index c3fb04e5beff3..0befd99f2af74 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -356,7 +356,7 @@ public void testItemFailures() throws Exception { assertThat(attributes.get("task_type"), is(model.getTaskType().toString())); assertThat(attributes.get("model_id"), is(model.getServiceSettings().modelId())); assertThat(attributes.get("service"), is(model.getConfigurations().getService())); - assertThat(attributes.get("source"), is("semantic_text_bulk")) + assertThat(attributes.get("source"), is("semantic_text_bulk")); })); assertThat(success.get(), equalTo(1)); assertThat(failed.get(), equalTo(2)); From 7bffff1e73f91cea6faa3e9e8642018a589bda6d Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Mon, 9 Jun 2025 21:26:30 +0100 Subject: [PATCH 5/5] apply review comment --- .../inference/action/filter/ShardBulkInferenceActionFilter.java | 2 +- .../action/filter/ShardBulkInferenceActionFilterTests.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index e3a51d188d109..082ece347208a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -461,7 +461,7 @@ private void recordRequestCountMetrics(Model model, int incrementBy, Throwable t Map<String, Object> requestCountAttributes = new HashMap<>(); requestCountAttributes.putAll(modelAttributes(model)); requestCountAttributes.putAll(responseAttributes(throwable)); -
requestCountAttributes.put("source", "semantic_text_bulk"); + requestCountAttributes.put("inference_source", "semantic_text_bulk"); inferenceStats.requestCount().incrementBy(incrementBy, requestCountAttributes); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 0befd99f2af74..a7cb0234aee59 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -356,7 +356,7 @@ public void testItemFailures() throws Exception { assertThat(attributes.get("task_type"), is(model.getTaskType().toString())); assertThat(attributes.get("model_id"), is(model.getServiceSettings().modelId())); assertThat(attributes.get("service"), is(model.getConfigurations().getService())); - assertThat(attributes.get("source"), is("semantic_text_bulk")); + assertThat(attributes.get("inference_source"), is("semantic_text_bulk")); })); assertThat(success.get(), equalTo(1)); assertThat(failed.get(), equalTo(2));