Skip to content

Commit 8a6ac78

Browse files
joegallojfreden
authored andcommitted
Refactor InferenceProcessorInfoExtractor to avoid ConfigurationUtils (elastic#115425)
1 parent 1dc33ee commit 8a6ac78

File tree

1 file changed

+21
-24
lines changed

1 file changed

+21
-24
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/InferenceProcessorInfoExtractor.java

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
import org.apache.lucene.util.Counter;
1111
import org.elasticsearch.cluster.ClusterState;
1212
import org.elasticsearch.cluster.metadata.Metadata;
13-
import org.elasticsearch.ingest.ConfigurationUtils;
1413
import org.elasticsearch.ingest.IngestMetadata;
15-
import org.elasticsearch.ingest.Pipeline;
1614
import org.elasticsearch.transport.Transports;
1715

1816
import java.util.HashMap;
@@ -24,6 +22,7 @@
2422
import java.util.function.Consumer;
2523

2624
import static org.elasticsearch.inference.InferenceResults.MODEL_ID_RESULTS_FIELD;
25+
import static org.elasticsearch.ingest.Pipeline.ON_FAILURE_KEY;
2726
import static org.elasticsearch.ingest.Pipeline.PROCESSORS_KEY;
2827

2928
/**
@@ -53,16 +52,10 @@ public static int countInferenceProcessors(ClusterState state) {
5352
Counter counter = Counter.newCounter();
5453
ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> {
5554
Map<String, Object> configMap = configuration.getConfigAsMap();
56-
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
55+
List<Map<String, Object>> processorConfigs = (List<Map<String, Object>>) configMap.get(PROCESSORS_KEY);
5756
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
5857
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
59-
addModelsAndPipelines(
60-
entry.getKey(),
61-
pipelineId,
62-
(Map<String, Object>) entry.getValue(),
63-
pam -> counter.addAndGet(1),
64-
0
65-
);
58+
addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), pam -> counter.addAndGet(1), 0);
6659
}
6760
}
6861
});
@@ -73,7 +66,6 @@ public static int countInferenceProcessors(ClusterState state) {
7366
* @param ingestMetadata The ingestMetadata of current ClusterState
7467
* @return The set of model IDs referenced by inference processors
7568
*/
76-
@SuppressWarnings("unchecked")
7769
public static Set<String> getModelIdsFromInferenceProcessors(IngestMetadata ingestMetadata) {
7870
if (ingestMetadata == null) {
7971
return Set.of();
@@ -82,7 +74,7 @@ public static Set<String> getModelIdsFromInferenceProcessors(IngestMetadata inge
8274
Set<String> modelIds = new LinkedHashSet<>();
8375
ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> {
8476
Map<String, Object> configMap = configuration.getConfigAsMap();
85-
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
77+
List<Map<String, Object>> processorConfigs = readList(configMap, PROCESSORS_KEY);
8678
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
8779
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
8880
addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), pam -> modelIds.add(pam.modelIdOrAlias()), 0);
@@ -96,7 +88,6 @@ public static Set<String> getModelIdsFromInferenceProcessors(IngestMetadata inge
9688
* @param state Current cluster state
9789
* @return a map from Model or Deployment IDs or Aliases to each pipeline referencing them.
9890
*/
99-
@SuppressWarnings("unchecked")
10091
public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state, Set<String> ids) {
10192
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
10293
Map<String, Set<String>> pipelineIdsByModelIds = new HashMap<>();
@@ -110,7 +101,7 @@ public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state,
110101
}
111102
ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> {
112103
Map<String, Object> configMap = configuration.getConfigAsMap();
113-
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
104+
List<Map<String, Object>> processorConfigs = readList(configMap, PROCESSORS_KEY);
114105
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
115106
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
116107
addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), pam -> {
@@ -128,7 +119,6 @@ public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state,
128119
* @param state Current {@link ClusterState}
129120
* @return a map from Model or Deployment IDs or Aliases to each pipeline referencing them.
130121
*/
131-
@SuppressWarnings("unchecked")
132122
public static Set<String> pipelineIdsForResource(ClusterState state, Set<String> ids) {
133123
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
134124
Set<String> pipelineIds = new HashSet<>();
@@ -142,7 +132,7 @@ public static Set<String> pipelineIdsForResource(ClusterState state, Set<String>
142132
}
143133
ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> {
144134
Map<String, Object> configMap = configuration.getConfigAsMap();
145-
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
135+
List<Map<String, Object>> processorConfigs = readList(configMap, PROCESSORS_KEY);
146136
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
147137
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
148138
addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), pam -> {
@@ -188,21 +178,16 @@ private static void addModelsAndPipelines(
188178
addModelsAndPipelines(
189179
innerProcessorWithName.getKey(),
190180
pipelineId,
191-
(Map<String, Object>) innerProcessorWithName.getValue(),
181+
innerProcessorWithName.getValue(),
192182
handler,
193183
level + 1
194184
);
195185
}
196186
}
197187
return;
198188
}
199-
if (processorDefinition instanceof Map<?, ?> definitionMap && definitionMap.containsKey(Pipeline.ON_FAILURE_KEY)) {
200-
List<Map<String, Object>> onFailureConfigs = ConfigurationUtils.readList(
201-
null,
202-
null,
203-
(Map<String, Object>) definitionMap,
204-
Pipeline.ON_FAILURE_KEY
205-
);
189+
if (processorDefinition instanceof Map<?, ?> definitionMap && definitionMap.containsKey(ON_FAILURE_KEY)) {
190+
List<Map<String, Object>> onFailureConfigs = readList(definitionMap, ON_FAILURE_KEY);
206191
onFailureConfigs.stream()
207192
.flatMap(map -> map.entrySet().stream())
208193
.forEach(entry -> addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), handler, level + 1));
@@ -211,4 +196,16 @@ private static void addModelsAndPipelines(
211196

212197
private record PipelineAndModel(String pipelineId, String modelIdOrAlias) {}
213198

199+
/**
200+
* A local alternative to ConfigurationUtils.readList(...) that reads list properties out of the processor configuration map,
201+
* but doesn't rely on mutating the configuration map.
202+
*/
203+
@SuppressWarnings("unchecked")
204+
private static List<Map<String, Object>> readList(Map<?, ?> processorConfig, String key) {
205+
Object val = processorConfig.get(key);
206+
if (val == null) {
207+
throw new IllegalArgumentException("Missing required property [" + key + "]");
208+
}
209+
return (List<Map<String, Object>>) val;
210+
}
214211
}

0 commit comments

Comments
 (0)