From 154f030670743c87334bbce1660f1dde77e83942 Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Fri, 25 Mar 2022 09:24:25 +0200 Subject: [PATCH] Refactored FeatureMap$Entry$Type enum --- README.md | 4 +- .../java/org/jpmml/xgboost/FeatureMap.java | 87 ++++++++++--------- 2 files changed, 48 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 2d89f44..99c320d 100644 --- a/README.md +++ b/README.md @@ -71,8 +71,8 @@ def to_fmap_type(dtype): return "int" # Continuous floats elif dtype == "float": - return "q" - # Binary indicators generated by pandas.get_dummies(X) + return "float" + # Binary indicators (ie. 0/1 values) generated by pandas.get_dummies(X) elif dtype == "uint8": return "i" else: diff --git a/pmml-xgboost/src/main/java/org/jpmml/xgboost/FeatureMap.java b/pmml-xgboost/src/main/java/org/jpmml/xgboost/FeatureMap.java index a22b9d9..35311e3 100644 --- a/pmml-xgboost/src/main/java/org/jpmml/xgboost/FeatureMap.java +++ b/pmml-xgboost/src/main/java/org/jpmml/xgboost/FeatureMap.java @@ -84,7 +84,7 @@ public void addEntry(String name, String type){ public void addEntry(String name, Entry.Type type){ Entry entry; - if(type == Entry.Type.BINARY_INDICATOR){ + if(type == Entry.Type.INDICATOR){ String value = null; int equals = name.indexOf('='); @@ -93,7 +93,7 @@ public void addEntry(String name, Entry.Type type){ name = name.substring(0, equals); } - entry = new CategoricalEntry(name, value, type); + entry = new IndicatorEntry(name, value, type); } else { @@ -176,9 +176,10 @@ private void setType(Type type){ static public enum Type { - BINARY_INDICATOR, - FLOAT, + INDICATOR, + QUANTITIVE, INTEGER, + FLOAT, ; static @@ -186,12 +187,13 @@ public Type fromString(String string){ switch(string){ case "i": - return Type.BINARY_INDICATOR; + return Type.INDICATOR; case "q": - case "float": - return Type.FLOAT; + return Type.QUANTITIVE; case "int": return Type.INTEGER; + case "float": + return Type.FLOAT; default: throw new IllegalArgumentException(string); } @@ -200,43 +202,12 @@ public Type fromString(String string){ } static - private class ContinuousEntry extends Entry { - - public ContinuousEntry(String name, Type type){ - super(name, type); - } - - @Override - public Feature encodeFeature(PMMLEncoder encoder){ - String name = getName(); - Type type = getType(); - - DataField dataField = encoder.getDataField(name); - if(dataField == null){ - - switch(type){ - case FLOAT: - dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.FLOAT); - break; - case INTEGER: - dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.INTEGER); - break; - default: - throw new IllegalArgumentException(); - } - } - - return new ContinuousFeature(encoder, dataField); - } - } - - static - private class CategoricalEntry extends Entry { + private class IndicatorEntry extends Entry { private String value = null; - public CategoricalEntry(String name, String value, Type type){ + public IndicatorEntry(String name, String value, Type type){ super(name, type); setValue(value); @@ -252,7 +223,7 @@ public Feature encodeFeature(PMMLEncoder encoder){ if(dataField == null){ switch(type){ - case BINARY_INDICATOR: + case INDICATOR: if(value != null){ dataField = encoder.createDataField(name, OpType.CATEGORICAL, DataType.STRING); } else @@ -285,4 +256,38 @@ private void setValue(String value){ this.value = value; } } + + static + private class ContinuousEntry extends Entry { + + public ContinuousEntry(String name, Type type){ + super(name, type); + } + + @Override + public Feature encodeFeature(PMMLEncoder encoder){ + String name = getName(); + Type type = getType(); + + DataField dataField = encoder.getDataField(name); + if(dataField == null){ + + switch(type){ + case QUANTITIVE: + dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.FLOAT); + break; + case INTEGER: + dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.INTEGER); + break; + case FLOAT: + dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.FLOAT); + break; + default: + throw new IllegalArgumentException(); + } + } + + return new ContinuousFeature(encoder, dataField); + } + } } \ No newline at end of file