55 */
66package org .elasticsearch .xpack .core .ml .inference .results ;
77
8+ import org .elasticsearch .Version ;
89import org .elasticsearch .common .ParseField ;
910import org .elasticsearch .common .io .stream .StreamInput ;
1011import org .elasticsearch .common .io .stream .StreamOutput ;
1617
1718import java .io .IOException ;
1819import java .util .Collections ;
19- import java .util .HashMap ;
2020import java .util .LinkedHashMap ;
21+ import java .util .List ;
2122import java .util .Map ;
2223import java .util .Objects ;
24+ import java .util .stream .Collectors ;
2325
2426import static org .elasticsearch .common .xcontent .ConstructingObjectParser .constructorArg ;
2527import static org .elasticsearch .common .xcontent .ConstructingObjectParser .optionalConstructorArg ;
2628
2729public class FeatureImportance implements Writeable , ToXContentObject {
2830
29- private final Map < String , Double > classImportance ;
31+ private final List < ClassImportance > classImportance ;
3032 private final double importance ;
3133 private final String featureName ;
3234 static final String IMPORTANCE = "importance" ;
3335 static final String FEATURE_NAME = "feature_name" ;
34- static final String CLASS_IMPORTANCE = "class_importance " ;
36+ static final String CLASSES = "classes " ;
3537
3638 public static FeatureImportance forRegression (String featureName , double importance ) {
3739 return new FeatureImportance (featureName , importance , null );
3840 }
3941
40- public static FeatureImportance forClassification (String featureName , Map <String , Double > classImportance ) {
41- return new FeatureImportance (featureName , classImportance .values ().stream ().mapToDouble (Math ::abs ).sum (), classImportance );
42+ public static FeatureImportance forClassification (String featureName , List <ClassImportance > classImportance ) {
43+ return new FeatureImportance (featureName ,
44+ classImportance .stream ().mapToDouble (ClassImportance ::getImportance ).map (Math ::abs ).sum (),
45+ classImportance );
4246 }
4347
4448 @ SuppressWarnings ("unchecked" )
4549 private static final ConstructingObjectParser <FeatureImportance , Void > PARSER =
4650 new ConstructingObjectParser <>("feature_importance" ,
47- a -> new FeatureImportance ((String ) a [0 ], (Double ) a [1 ], (Map < String , Double >) a [2 ])
51+ a -> new FeatureImportance ((String ) a [0 ], (Double ) a [1 ], (List < ClassImportance >) a [2 ])
4852 );
4953
5054 static {
5155 PARSER .declareString (constructorArg (), new ParseField (FeatureImportance .FEATURE_NAME ));
5256 PARSER .declareDouble (constructorArg (), new ParseField (FeatureImportance .IMPORTANCE ));
53- PARSER .declareObject (optionalConstructorArg (), (p , c ) -> p .map (HashMap ::new , XContentParser ::doubleValue ),
54- new ParseField (FeatureImportance .CLASS_IMPORTANCE ));
57+ PARSER .declareObjectArray (optionalConstructorArg (),
58+ (p , c ) -> ClassImportance .fromXContent (p ),
59+ new ParseField (FeatureImportance .CLASSES ));
5560 }
5661
5762 public static FeatureImportance fromXContent (XContentParser parser ) {
5863 return PARSER .apply (parser , null );
5964 }
6065
61- FeatureImportance (String featureName , double importance , Map < String , Double > classImportance ) {
66+ FeatureImportance (String featureName , double importance , List < ClassImportance > classImportance ) {
6267 this .featureName = Objects .requireNonNull (featureName );
6368 this .importance = importance ;
64- this .classImportance = classImportance == null ? null : Collections .unmodifiableMap (classImportance );
69+ this .classImportance = classImportance == null ? null : Collections .unmodifiableList (classImportance );
6570 }
6671
6772 public FeatureImportance (StreamInput in ) throws IOException {
6873 this .featureName = in .readString ();
6974 this .importance = in .readDouble ();
7075 if (in .readBoolean ()) {
71- this .classImportance = in .readMap (StreamInput ::readString , StreamInput ::readDouble );
76+ if (in .getVersion ().before (Version .V_7_10_0 )) {
77+ Map <String , Double > classImportance = in .readMap (StreamInput ::readString , StreamInput ::readDouble );
78+ this .classImportance = ClassImportance .fromMap (classImportance );
79+ } else {
80+ this .classImportance = in .readList (ClassImportance ::new );
81+ }
7282 } else {
7383 this .classImportance = null ;
7484 }
7585 }
7686
77- public Map < String , Double > getClassImportance () {
87+ public List < ClassImportance > getClassImportance () {
7888 return classImportance ;
7989 }
8090
@@ -92,7 +102,11 @@ public void writeTo(StreamOutput out) throws IOException {
92102 out .writeDouble (this .importance );
93103 out .writeBoolean (this .classImportance != null );
94104 if (this .classImportance != null ) {
95- out .writeMap (this .classImportance , StreamOutput ::writeString , StreamOutput ::writeDouble );
105+ if (out .getVersion ().before (Version .V_7_10_0 )) {
106+ out .writeMap (ClassImportance .toMap (this .classImportance ), StreamOutput ::writeString , StreamOutput ::writeDouble );
107+ } else {
108+ out .writeList (this .classImportance );
109+ }
96110 }
97111 }
98112
@@ -101,7 +115,7 @@ public Map<String, Object> toMap() {
101115 map .put (FEATURE_NAME , featureName );
102116 map .put (IMPORTANCE , importance );
103117 if (classImportance != null ) {
104- classImportance .forEach ( map :: put );
118+ map . put ( CLASSES , classImportance .stream (). map ( ClassImportance :: toMap ). collect ( Collectors . toList ()) );
105119 }
106120 return map ;
107121 }
@@ -112,11 +126,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
112126 builder .field (FEATURE_NAME , featureName );
113127 builder .field (IMPORTANCE , importance );
114128 if (classImportance != null && classImportance .isEmpty () == false ) {
115- builder .startObject (CLASS_IMPORTANCE );
116- for (Map .Entry <String , Double > entry : classImportance .entrySet ()) {
117- builder .field (entry .getKey (), entry .getValue ());
118- }
119- builder .endObject ();
129+ builder .field (CLASSES , classImportance );
120130 }
121131 builder .endObject ();
122132 return builder ;
@@ -136,4 +146,92 @@ public boolean equals(Object object) {
136146 public int hashCode () {
137147 return Objects .hash (featureName , importance , classImportance );
138148 }
149+
150+ public static class ClassImportance implements Writeable , ToXContentObject {
151+
152+ static final String CLASS_NAME = "class_name" ;
153+
154+ private static final ConstructingObjectParser <ClassImportance , Void > PARSER =
155+ new ConstructingObjectParser <>("feature_importance_class_importance" ,
156+ a -> new ClassImportance ((String ) a [0 ], (Double ) a [1 ])
157+ );
158+
159+ static {
160+ PARSER .declareString (constructorArg (), new ParseField (CLASS_NAME ));
161+ PARSER .declareDouble (constructorArg (), new ParseField (FeatureImportance .IMPORTANCE ));
162+ }
163+
164+ private static ClassImportance fromMapEntry (Map .Entry <String , Double > entry ) {
165+ return new ClassImportance (entry .getKey (), entry .getValue ());
166+ }
167+
168+ private static List <ClassImportance > fromMap (Map <String , Double > classImportanceMap ) {
169+ return classImportanceMap .entrySet ().stream ().map (ClassImportance ::fromMapEntry ).collect (Collectors .toList ());
170+ }
171+
172+ private static Map <String , Double > toMap (List <ClassImportance > importances ) {
173+ return importances .stream ().collect (Collectors .toMap (i -> i .className , i -> i .importance ));
174+ }
175+
176+ public static ClassImportance fromXContent (XContentParser parser ) {
177+ return PARSER .apply (parser , null );
178+ }
179+
180+ private final String className ;
181+ private final double importance ;
182+
183+ public ClassImportance (String className , double importance ) {
184+ this .className = className ;
185+ this .importance = importance ;
186+ }
187+
188+ public ClassImportance (StreamInput in ) throws IOException {
189+ this .className = in .readString ();
190+ this .importance = in .readDouble ();
191+ }
192+
193+ public String getClassName () {
194+ return className ;
195+ }
196+
197+ public double getImportance () {
198+ return importance ;
199+ }
200+
201+ public Map <String , Object > toMap () {
202+ Map <String , Object > map = new LinkedHashMap <>();
203+ map .put (CLASS_NAME , className );
204+ map .put (IMPORTANCE , importance );
205+ return map ;
206+ }
207+
208+ @ Override
209+ public void writeTo (StreamOutput out ) throws IOException {
210+ out .writeString (className );
211+ out .writeDouble (importance );
212+ }
213+
214+ @ Override
215+ public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
216+ builder .startObject ();
217+ builder .field (CLASS_NAME , className );
218+ builder .field (IMPORTANCE , importance );
219+ builder .endObject ();
220+ return builder ;
221+ }
222+
223+ @ Override
224+ public boolean equals (Object o ) {
225+ if (this == o ) return true ;
226+ if (o == null || getClass () != o .getClass ()) return false ;
227+ ClassImportance that = (ClassImportance ) o ;
228+ return Double .compare (that .importance , importance ) == 0 &&
229+ Objects .equals (className , that .className );
230+ }
231+
232+ @ Override
233+ public int hashCode () {
234+ return Objects .hash (className , importance );
235+ }
236+ }
139237}
0 commit comments