Skip to content

Refactor InferenceProcessorInfoExtractor to avoid ConfigurationUtils #115425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
import org.apache.lucene.util.Counter;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.ingest.ConfigurationUtils;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.transport.Transports;

import java.util.HashMap;
Expand All @@ -24,6 +22,7 @@
import java.util.function.Consumer;

import static org.elasticsearch.inference.InferenceResults.MODEL_ID_RESULTS_FIELD;
import static org.elasticsearch.ingest.Pipeline.ON_FAILURE_KEY;
import static org.elasticsearch.ingest.Pipeline.PROCESSORS_KEY;

/**
Expand Down Expand Up @@ -53,16 +52,10 @@ public static int countInferenceProcessors(ClusterState state) {
Counter counter = Counter.newCounter();
ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> {
Map<String, Object> configMap = configuration.getConfigAsMap();
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
List<Map<String, Object>> processorConfigs = (List<Map<String, Object>>) configMap.get(PROCESSORS_KEY);
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
addModelsAndPipelines(
entry.getKey(),
pipelineId,
(Map<String, Object>) entry.getValue(),
pam -> counter.addAndGet(1),
0
);
addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), pam -> counter.addAndGet(1), 0);
}
}
});
Expand All @@ -73,7 +66,6 @@ public static int countInferenceProcessors(ClusterState state) {
* @param ingestMetadata The ingestMetadata of current ClusterState
* @return The set of model IDs referenced by inference processors
*/
@SuppressWarnings("unchecked")
public static Set<String> getModelIdsFromInferenceProcessors(IngestMetadata ingestMetadata) {
if (ingestMetadata == null) {
return Set.of();
Expand All @@ -82,7 +74,7 @@ public static Set<String> getModelIdsFromInferenceProcessors(IngestMetadata inge
Set<String> modelIds = new LinkedHashSet<>();
ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> {
Map<String, Object> configMap = configuration.getConfigAsMap();
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
List<Map<String, Object>> processorConfigs = readList(configMap, PROCESSORS_KEY);
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), pam -> modelIds.add(pam.modelIdOrAlias()), 0);
Expand All @@ -96,7 +88,6 @@ public static Set<String> getModelIdsFromInferenceProcessors(IngestMetadata inge
* @param state Current cluster state
* @return a map from Model or Deployment IDs or Aliases to each pipeline referencing them.
*/
@SuppressWarnings("unchecked")
public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state, Set<String> ids) {
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
Map<String, Set<String>> pipelineIdsByModelIds = new HashMap<>();
Expand All @@ -110,7 +101,7 @@ public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state,
}
ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> {
Map<String, Object> configMap = configuration.getConfigAsMap();
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
List<Map<String, Object>> processorConfigs = readList(configMap, PROCESSORS_KEY);
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), pam -> {
Expand All @@ -128,7 +119,6 @@ public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state,
* @param state Current {@link ClusterState}
* @return a map from Model or Deployment IDs or Aliases to each pipeline referencing them.
*/
@SuppressWarnings("unchecked")
public static Set<String> pipelineIdsForResource(ClusterState state, Set<String> ids) {
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
Set<String> pipelineIds = new HashSet<>();
Expand All @@ -142,7 +132,7 @@ public static Set<String> pipelineIdsForResource(ClusterState state, Set<String>
}
ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> {
Map<String, Object> configMap = configuration.getConfigAsMap();
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
List<Map<String, Object>> processorConfigs = readList(configMap, PROCESSORS_KEY);
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), pam -> {
Expand Down Expand Up @@ -188,21 +178,16 @@ private static void addModelsAndPipelines(
addModelsAndPipelines(
innerProcessorWithName.getKey(),
pipelineId,
(Map<String, Object>) innerProcessorWithName.getValue(),
innerProcessorWithName.getValue(),
handler,
level + 1
);
}
}
return;
}
if (processorDefinition instanceof Map<?, ?> definitionMap && definitionMap.containsKey(Pipeline.ON_FAILURE_KEY)) {
List<Map<String, Object>> onFailureConfigs = ConfigurationUtils.readList(
null,
null,
(Map<String, Object>) definitionMap,
Pipeline.ON_FAILURE_KEY
);
if (processorDefinition instanceof Map<?, ?> definitionMap && definitionMap.containsKey(ON_FAILURE_KEY)) {
List<Map<String, Object>> onFailureConfigs = readList(definitionMap, ON_FAILURE_KEY);
onFailureConfigs.stream()
.flatMap(map -> map.entrySet().stream())
.forEach(entry -> addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), handler, level + 1));
Expand All @@ -211,4 +196,16 @@ private static void addModelsAndPipelines(

private record PipelineAndModel(String pipelineId, String modelIdOrAlias) {}

/**
* A local alternative to ConfigurationUtils.readList(...) that reads list properties out of the processor configuration map,
* but doesn't rely on mutating the configuration map.
*/
@SuppressWarnings("unchecked")
private static List<Map<String, Object>> readList(Map<?, ?> processorConfig, String key) {
Object val = processorConfig.get(key);
if (val == null) {
throw new IllegalArgumentException("Missing required property [" + key + "]");
}
return (List<Map<String, Object>>) val;
}
}