Skip to content

Commit 7c3bfb9

Browse files
authored
[ML] updating feature_importance results mapping (#61104) (#61144)
This updates the feature_importance results mapping to match the change made in elastic/ml-cpp#1387.
1 parent f2f1552 commit 7c3bfb9

File tree

10 files changed

+174
-44
lines changed

10 files changed

+174
-44
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
314314
@Override
315315
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
316316
Map<String, Object> additionalProperties = new HashMap<>();
317-
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
317+
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.classificationFeatureImportanceMapping());
318318
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
319319
if ((dependentVariableMapping instanceof Map) == false) {
320320
return additionalProperties;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,46 @@
1818

1919
final class MapUtils {
2020

21-
private static final Map<String, Object> FEATURE_IMPORTANCE_MAPPING;
22-
static {
23-
Map<String, Object> featureImportanceMappingProperties = new HashMap<>();
21+
private static Map<String, Object> createFeatureImportanceMapping(Map<String, Object> featureImportanceMappingProperties){
2422
featureImportanceMappingProperties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
25-
featureImportanceMappingProperties.put("importance",
26-
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
2723
Map<String, Object> featureImportanceMapping = new HashMap<>();
2824
// TODO sorted indices don't support nested types
2925
//featureImportanceMapping.put("dynamic", true);
3026
//featureImportanceMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
3127
featureImportanceMapping.put("properties", featureImportanceMappingProperties);
32-
FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(featureImportanceMapping);
28+
return featureImportanceMapping;
29+
}
30+
31+
private static final Map<String, Object> CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING;
32+
static {
33+
Map<String, Object> classImportancePropertiesMapping = new HashMap<>();
34+
// TODO sorted indices don't support nested types
35+
//classImportancePropertiesMapping.put("dynamic", true);
36+
//classImportancePropertiesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
37+
classImportancePropertiesMapping.put("class_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
38+
classImportancePropertiesMapping.put("importance",
39+
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
40+
Map<String, Object> featureImportancePropertiesMapping = new HashMap<>();
41+
featureImportancePropertiesMapping.put("classes", Collections.singletonMap("properties", classImportancePropertiesMapping));
42+
CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING =
43+
Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping));
44+
}
45+
46+
private static final Map<String, Object> REGRESSION_FEATURE_IMPORTANCE_MAPPING;
47+
static {
48+
Map<String, Object> featureImportancePropertiesMapping = new HashMap<>();
49+
featureImportancePropertiesMapping.put("importance",
50+
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
51+
REGRESSION_FEATURE_IMPORTANCE_MAPPING =
52+
Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping));
53+
}
54+
55+
static Map<String, Object> regressionFeatureImportanceMapping() {
56+
return REGRESSION_FEATURE_IMPORTANCE_MAPPING;
3357
}
3458

35-
static Map<String, Object> featureImportanceMapping() {
36-
return FEATURE_IMPORTANCE_MAPPING;
59+
static Map<String, Object> classificationFeatureImportanceMapping() {
60+
return CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING;
3761
}
3862

3963
private MapUtils() {}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
247247
@Override
248248
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
249249
Map<String, Object> additionalProperties = new HashMap<>();
250-
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
250+
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.regressionFeatureImportanceMapping());
251251
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
252252
// high (over 10M) values of dependent variable.
253253
additionalProperties.put(resultsFieldName + "." + predictionFieldName,

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java

Lines changed: 117 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.inference.results;
77

8+
import org.elasticsearch.Version;
89
import org.elasticsearch.common.ParseField;
910
import org.elasticsearch.common.io.stream.StreamInput;
1011
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -16,65 +17,74 @@
1617

1718
import java.io.IOException;
1819
import java.util.Collections;
19-
import java.util.HashMap;
2020
import java.util.LinkedHashMap;
21+
import java.util.List;
2122
import java.util.Map;
2223
import java.util.Objects;
24+
import java.util.stream.Collectors;
2325

2426
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
2527
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
2628

2729
public class FeatureImportance implements Writeable, ToXContentObject {
2830

29-
private final Map<String, Double> classImportance;
31+
private final List<ClassImportance> classImportance;
3032
private final double importance;
3133
private final String featureName;
3234
static final String IMPORTANCE = "importance";
3335
static final String FEATURE_NAME = "feature_name";
34-
static final String CLASS_IMPORTANCE = "class_importance";
36+
static final String CLASSES = "classes";
3537

3638
public static FeatureImportance forRegression(String featureName, double importance) {
3739
return new FeatureImportance(featureName, importance, null);
3840
}
3941

40-
public static FeatureImportance forClassification(String featureName, Map<String, Double> classImportance) {
41-
return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
42+
public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
43+
return new FeatureImportance(featureName,
44+
classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
45+
classImportance);
4246
}
4347

4448
@SuppressWarnings("unchecked")
4549
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
4650
new ConstructingObjectParser<>("feature_importance",
47-
a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2])
51+
a -> new FeatureImportance((String) a[0], (Double) a[1], (List<ClassImportance>) a[2])
4852
);
4953

5054
static {
5155
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
5256
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
53-
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
54-
new ParseField(FeatureImportance.CLASS_IMPORTANCE));
57+
PARSER.declareObjectArray(optionalConstructorArg(),
58+
(p, c) -> ClassImportance.fromXContent(p),
59+
new ParseField(FeatureImportance.CLASSES));
5560
}
5661

5762
public static FeatureImportance fromXContent(XContentParser parser) {
5863
return PARSER.apply(parser, null);
5964
}
6065

61-
FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
66+
FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
6267
this.featureName = Objects.requireNonNull(featureName);
6368
this.importance = importance;
64-
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
69+
this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
6570
}
6671

6772
public FeatureImportance(StreamInput in) throws IOException {
6873
this.featureName = in.readString();
6974
this.importance = in.readDouble();
7075
if (in.readBoolean()) {
71-
this.classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
76+
if (in.getVersion().before(Version.V_7_10_0)) {
77+
Map<String, Double> classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
78+
this.classImportance = ClassImportance.fromMap(classImportance);
79+
} else {
80+
this.classImportance = in.readList(ClassImportance::new);
81+
}
7282
} else {
7383
this.classImportance = null;
7484
}
7585
}
7686

77-
public Map<String, Double> getClassImportance() {
87+
public List<ClassImportance> getClassImportance() {
7888
return classImportance;
7989
}
8090

@@ -92,7 +102,11 @@ public void writeTo(StreamOutput out) throws IOException {
92102
out.writeDouble(this.importance);
93103
out.writeBoolean(this.classImportance != null);
94104
if (this.classImportance != null) {
95-
out.writeMap(this.classImportance, StreamOutput::writeString, StreamOutput::writeDouble);
105+
if (out.getVersion().before(Version.V_7_10_0)) {
106+
out.writeMap(ClassImportance.toMap(this.classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
107+
} else {
108+
out.writeList(this.classImportance);
109+
}
96110
}
97111
}
98112

@@ -101,7 +115,7 @@ public Map<String, Object> toMap() {
101115
map.put(FEATURE_NAME, featureName);
102116
map.put(IMPORTANCE, importance);
103117
if (classImportance != null) {
104-
classImportance.forEach(map::put);
118+
map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList()));
105119
}
106120
return map;
107121
}
@@ -112,11 +126,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
112126
builder.field(FEATURE_NAME, featureName);
113127
builder.field(IMPORTANCE, importance);
114128
if (classImportance != null && classImportance.isEmpty() == false) {
115-
builder.startObject(CLASS_IMPORTANCE);
116-
for (Map.Entry<String, Double> entry : classImportance.entrySet()) {
117-
builder.field(entry.getKey(), entry.getValue());
118-
}
119-
builder.endObject();
129+
builder.field(CLASSES, classImportance);
120130
}
121131
builder.endObject();
122132
return builder;
@@ -136,4 +146,92 @@ public boolean equals(Object object) {
136146
public int hashCode() {
137147
return Objects.hash(featureName, importance, classImportance);
138148
}
149+
150+
public static class ClassImportance implements Writeable, ToXContentObject {
151+
152+
static final String CLASS_NAME = "class_name";
153+
154+
private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
155+
new ConstructingObjectParser<>("feature_importance_class_importance",
156+
a -> new ClassImportance((String) a[0], (Double) a[1])
157+
);
158+
159+
static {
160+
PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
161+
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
162+
}
163+
164+
private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
165+
return new ClassImportance(entry.getKey(), entry.getValue());
166+
}
167+
168+
private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
169+
return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList());
170+
}
171+
172+
private static Map<String, Double> toMap(List<ClassImportance> importances) {
173+
return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance));
174+
}
175+
176+
public static ClassImportance fromXContent(XContentParser parser) {
177+
return PARSER.apply(parser, null);
178+
}
179+
180+
private final String className;
181+
private final double importance;
182+
183+
public ClassImportance(String className, double importance) {
184+
this.className = className;
185+
this.importance = importance;
186+
}
187+
188+
public ClassImportance(StreamInput in) throws IOException {
189+
this.className = in.readString();
190+
this.importance = in.readDouble();
191+
}
192+
193+
public String getClassName() {
194+
return className;
195+
}
196+
197+
public double getImportance() {
198+
return importance;
199+
}
200+
201+
public Map<String, Object> toMap() {
202+
Map<String, Object> map = new LinkedHashMap<>();
203+
map.put(CLASS_NAME, className);
204+
map.put(IMPORTANCE, importance);
205+
return map;
206+
}
207+
208+
@Override
209+
public void writeTo(StreamOutput out) throws IOException {
210+
out.writeString(className);
211+
out.writeDouble(importance);
212+
}
213+
214+
@Override
215+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
216+
builder.startObject();
217+
builder.field(CLASS_NAME, className);
218+
builder.field(IMPORTANCE, importance);
219+
builder.endObject();
220+
return builder;
221+
}
222+
223+
@Override
224+
public boolean equals(Object o) {
225+
if (this == o) return true;
226+
if (o == null || getClass() != o.getClass()) return false;
227+
ClassImportance that = (ClassImportance) o;
228+
return Double.compare(that.importance, importance) == 0 &&
229+
Objects.equals(className, that.className);
230+
}
231+
232+
@Override
233+
public int hashCode() {
234+
return Objects.hash(className, importance);
235+
}
236+
}
139237
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import java.util.Collections;
1616
import java.util.Comparator;
1717
import java.util.HashMap;
18-
import java.util.LinkedHashMap;
1918
import java.util.List;
2019
import java.util.Map;
2120
import java.util.stream.Collectors;
@@ -139,11 +138,13 @@ public static List<FeatureImportance> transformFeatureImportance(Map<String, dou
139138
if (v.length == 1) {
140139
importances.add(FeatureImportance.forRegression(k, v[0]));
141140
} else {
142-
Map<String, Double> classImportance = new LinkedHashMap<>(v.length, 1.0f);
141+
List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
143142
// If the classificationLabels exist, their length must match leaf_value length
144143
assert classificationLabels == null || classificationLabels.size() == v.length;
145144
for (int i = 0; i < v.length; i++) {
146-
classImportance.put(classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), v[i]);
145+
classImportance.add(new FeatureImportance.ClassImportance(
146+
classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i),
147+
v[i]));
147148
}
148149
importances.add(FeatureImportance.forClassification(k, classImportance));
149150
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -261,20 +261,20 @@ public void testFieldCardinalityLimitsIsNonEmpty() {
261261

262262
public void testGetExplicitlyMappedFields() {
263263
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"),
264-
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
264+
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
265265
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"),
266-
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
266+
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
267267
assertThat(
268268
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
269-
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
269+
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
270270
Map<String, Object> explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
271271
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
272272
"results");
273273
assertThat(explicitlyMappedFields,
274274
allOf(
275275
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
276276
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
277-
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
277+
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()));
278278

279279
explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
280280
new HashMap<String, Object>() {{
@@ -289,7 +289,7 @@ public void testGetExplicitlyMappedFields() {
289289
allOf(
290290
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
291291
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
292-
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
292+
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()));
293293

294294
assertThat(
295295
new Classification("foo").getExplicitlyMappedFields(
@@ -298,7 +298,7 @@ public void testGetExplicitlyMappedFields() {
298298
put("path", "missing");
299299
}}),
300300
"results"),
301-
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
301+
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
302302
}
303303

304304
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ public void testFieldCardinalityLimitsIsEmpty() {
206206
public void testGetExplicitlyMappedFields() {
207207
Map<String, Object> explicitlyMappedFields = new Regression("foo").getExplicitlyMappedFields(null, "results");
208208
assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
209-
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
209+
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.regressionFeatureImportanceMapping()));
210210
}
211211

212212
public void testGetStateDocId() {

0 commit comments

Comments
 (0)