Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,21 @@ public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
InferenceConfig config) {
this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
this(value, classificationLabel, topClasses, Collections.emptyList(), (ClassificationConfig)config);
}

public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
List<FeatureImportance> featureImportance,
InferenceConfig config) {
this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
}

private ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
List<FeatureImportance> featureImportance,
ClassificationConfig classificationConfig) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
Expand Down Expand Up @@ -118,7 +118,10 @@ public void writeResult(IngestDocument document, String parentResultField) {
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
.stream()
.map(FeatureImportance::toMap)
.collect(Collectors.toList()));
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Immutable holder for the importance of a single feature in an inference result.
 *
 * <p>An instance always carries a feature name and a total importance value. Instances built
 * through {@link #forClassification(String, Map)} additionally carry a per-class breakdown;
 * instances built through {@link #forRegression(String, double)} have a {@code null} breakdown.
 *
 * <p>Wire format (see {@link #writeTo(StreamOutput)}): feature name, importance, a boolean flag
 * telling whether a per-class map follows, and then the optional map.
 */
public class FeatureImportance implements Writeable {

    private final Map<String, Double> classImportance;
    private final double importance;
    private final String featureName;
    private static final String IMPORTANCE = "importance";
    private static final String FEATURE_NAME = "feature_name";

    /**
     * Creates an importance entry that has a single value and no per-class breakdown.
     *
     * @param featureName the feature name, must not be {@code null}
     * @param importance  the importance value of the feature
     * @return a new instance whose {@link #getClassImportance()} is {@code null}
     */
    public static FeatureImportance forRegression(String featureName, double importance) {
        return new FeatureImportance(featureName, importance, null);
    }

    /**
     * Creates an importance entry with a per-class breakdown. The total importance is the sum of
     * the absolute values of all per-class importances.
     *
     * @param featureName     the feature name, must not be {@code null}
     * @param classImportance per-class importance values, must not be {@code null}; the map is copied
     * @return a new instance carrying an unmodifiable copy of {@code classImportance}
     * @throws NullPointerException if {@code featureName} or {@code classImportance} is {@code null}
     */
    public static FeatureImportance forClassification(String featureName, Map<String, Double> classImportance) {
        Objects.requireNonNull(classImportance, "classImportance");
        double totalImportance = classImportance.values().stream().mapToDouble(Math::abs).sum();
        return new FeatureImportance(featureName, totalImportance, classImportance);
    }

    private FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
        this.featureName = Objects.requireNonNull(featureName);
        this.importance = importance;
        // Copy before wrapping: an unmodifiable view alone would still reflect later changes made by
        // the caller to its own map. LinkedHashMap keeps the caller's iteration order.
        this.classImportance = classImportance == null
            ? null
            : Collections.unmodifiableMap(new LinkedHashMap<>(classImportance));
    }

    /**
     * Reads an instance from the stream, mirroring {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public FeatureImportance(StreamInput in) throws IOException {
        this.featureName = in.readString();
        this.importance = in.readDouble();
        if (in.readBoolean()) {
            // Wrap so that deserialized instances are as read-only as factory-built ones.
            this.classImportance = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readDouble));
        } else {
            this.classImportance = null;
        }
    }

    /**
     * @return the unmodifiable per-class importance map, or {@code null} when this entry has no
     *         per-class breakdown
     */
    public Map<String, Double> getClassImportance() {
        return classImportance;
    }

    /**
     * @return the total importance of the feature
     */
    public double getImportance() {
        return importance;
    }

    /**
     * @return the feature name, never {@code null}
     */
    public String getFeatureName() {
        return featureName;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(this.featureName);
        out.writeDouble(this.importance);
        // The flag lets the reader distinguish "no per-class breakdown" from an empty map.
        out.writeBoolean(this.classImportance != null);
        if (this.classImportance != null) {
            out.writeMap(this.classImportance, StreamOutput::writeString, StreamOutput::writeDouble);
        }
    }

    /**
     * Renders this entry as a flat, insertion-ordered map: {@code feature_name}, then
     * {@code importance}, then one entry per class (when a per-class breakdown exists).
     *
     * @return a new mutable map describing this entry
     */
    public Map<String, Object> toMap() {
        Map<String, Object> map = new LinkedHashMap<>();
        map.put(FEATURE_NAME, featureName);
        map.put(IMPORTANCE, importance);
        if (classImportance != null) {
            // NOTE(review): class labels share the namespace of the two fixed keys above, so a class
            // literally named "feature_name" or "importance" would overwrite that entry.
            map.putAll(classImportance);
        }
        return map;
    }

    @Override
    public boolean equals(Object object) {
        if (object == this) { return true; }
        if (object == null || getClass() != object.getClass()) { return false; }
        FeatureImportance that = (FeatureImportance) object;
        // Double.compare matches Double.equals semantics (NaN equals NaN, 0.0 differs from -0.0)
        // without boxing, which keeps equals consistent with hashCode below.
        return Objects.equals(featureName, that.featureName)
            && Double.compare(importance, that.importance) == 0
            && Objects.equals(classImportance, that.classImportance);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureName, importance, classImportance);
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ public class RawInferenceResults implements InferenceResults {
public static final String NAME = "raw";

private final double[] value;
private final Map<String, Double> featureImportance;
private final Map<String, double[]> featureImportance;

public RawInferenceResults(double[] value, Map<String, Double> featureImportance) {
public RawInferenceResults(double[] value, Map<String, double[]> featureImportance) {
this.value = value;
this.featureImportance = featureImportance;
}
Expand All @@ -29,7 +29,7 @@ public double[] getValue() {
return value;
}

public Map<String, Double> getFeatureImportance() {
public Map<String, double[]> getFeatureImportance() {
return featureImportance;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

public class RegressionInferenceResults extends SingleValueInferenceResults {

Expand All @@ -24,14 +25,14 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
private final String resultsField;

public RegressionInferenceResults(double value, InferenceConfig config) {
this(value, (RegressionConfig) config, Collections.emptyMap());
this(value, (RegressionConfig) config, Collections.emptyList());
}

public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) {
this(value, (RegressionConfig)config, featureImportance);
}

private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
private RegressionInferenceResults(double value, RegressionConfig regressionConfig, List<FeatureImportance> featureImportance) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
regressionConfig.getNumTopFeatureImportanceValues()));
Expand Down Expand Up @@ -70,7 +71,10 @@ public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
document.setFieldValue(parentResultField + "." + this.resultsField, value());
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
.stream()
.map(FeatureImportance::toMap)
.collect(Collectors.toList()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,46 @@
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.List;
import java.util.stream.Collectors;

public abstract class SingleValueInferenceResults implements InferenceResults {

private final double value;
private final Map<String, Double> featureImportance;
private final List<FeatureImportance> featureImportance;

static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
return unsortedFeatureImportances.entrySet()
.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
static List<FeatureImportance> takeTopFeatureImportances(List<FeatureImportance> unsortedFeatureImportances, int numTopFeatures) {
if (unsortedFeatureImportances == null || unsortedFeatureImportances.isEmpty()) {
return unsortedFeatureImportances;
}
return unsortedFeatureImportances.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
.limit(numTopFeatures)
.collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
.collect(Collectors.toList());
}

SingleValueInferenceResults(StreamInput in) throws IOException {
value = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
this.featureImportance = in.readList(FeatureImportance::new);
} else {
this.featureImportance = Collections.emptyMap();
this.featureImportance = Collections.emptyList();
}
}

SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
SingleValueInferenceResults(double value, List<FeatureImportance> featureImportance) {
this.value = value;
this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
this.featureImportance = featureImportance == null ? Collections.emptyList() : featureImportance;
}

/**
 * @return the inferred value, boxed as a {@link Double}
 */
public Double value() {
    return Double.valueOf(value);
}

public Map<String, Double> getFeatureImportance() {
public List<FeatureImportance> getFeatureImportance() {
return featureImportance;
}

Expand All @@ -58,7 +59,7 @@ public String valueAsString() {
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(value);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
out.writeList(this.featureImportance);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -100,18 +102,46 @@ public static Double toDouble(Object value) {
return null;
}

public static Map<String, Double> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
Map<String, Double> featureImportances) {
public static Map<String, double[]> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
Map<String, double[]> featureImportances) {
if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
return featureImportances;
}

Map<String, Double> originalFeatureImportance = new HashMap<>();
Map<String, double[]> originalFeatureImportance = new HashMap<>();
featureImportances.forEach((feature, importance) -> {
String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature);
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : v1 + importance);
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : sumDoubleArrays(importance, v1));
});

return originalFeatureImportance;
}

/**
 * Converts raw per-feature importance arrays into {@link FeatureImportance} objects.
 *
 * <p>An array of length 1 becomes a regression-style entry holding that single value. Any other
 * length becomes a classification-style entry with one value per array position, keyed by the
 * matching label from {@code classificationLabels}, or by the position index rendered as a string
 * when no labels are given.
 *
 * @param featureImportance    feature name to raw importance values
 * @param classificationLabels optional class labels; when present, expected to have one label per
 *                             array position (checked by assertion only)
 * @return one {@link FeatureImportance} per entry of {@code featureImportance}, in the map's
 *         iteration order
 */
public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
                                                                 @Nullable List<String> classificationLabels) {
    List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
    for (Map.Entry<String, double[]> entry : featureImportance.entrySet()) {
        String featureName = entry.getKey();
        double[] values = entry.getValue();

        // A single value indicates regression, or logistic regression.
        if (values.length == 1) {
            importances.add(FeatureImportance.forRegression(featureName, values[0]));
            continue;
        }

        // Any other length is assumed to be multi-class classification.
        Map<String, Double> perClass = new LinkedHashMap<>(values.length, 1.0f);
        // If the classificationLabels exist, their length must match leaf_value length
        assert classificationLabels == null || classificationLabels.size() == values.length;
        int classIndex = 0;
        for (double classValue : values) {
            String label = classificationLabels == null
                ? String.valueOf(classIndex)
                : classificationLabels.get(classIndex);
            perClass.put(label, classValue);
            classIndex++;
        }
        importances.add(FeatureImportance.forClassification(featureName, perClass));
    }
    return importances;
}

/**
 * Adds every element of {@code inc} onto the element of {@code sumTo} at the same position.
 * The addition happens in place: {@code sumTo} is modified and returned, {@code inc} is only read.
 *
 * @param sumTo the accumulator array, updated in place
 * @param inc   the values to add; expected to have the same length as {@code sumTo}
 *              (checked by assertion only)
 * @return the same {@code sumTo} array instance, after accumulation
 */
public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
    assert sumTo != null && inc != null && sumTo.length == inc.length;
    int position = 0;
    for (double addend : inc) {
        sumTo[position] = sumTo[position] + addend;
        position++;
    }
    return sumTo;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
* NOTE: Must be thread safe
* @param fields The fields inferring against
* @param featureDecoder A Map translating processed feature names to their original feature names
* @return A {@code Map<String, Double>} mapping each featureName to its importance
* @return A {@code Map<String, double[]>} mapping each featureName to its importance
*/
Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);

default Version getMinimalCompatibilityVersion() {
return Version.V_7_6_0;
Expand Down
Loading