Skip to content

Commit

Permalink
Refactored FeatureMap$Entry$Type enum
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Mar 26, 2022
1 parent 94ec549 commit 154f030
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 43 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def to_fmap_type(dtype):
return "int"
# Continuous floats
elif dtype == "float":
return "q"
# Binary indicators generated by pandas.get_dummies(X)
return "float"
# Binary indicators (ie. 0/1 values) generated by pandas.get_dummies(X)
elif dtype == "uint8":
return "i"
else:
Expand Down
87 changes: 46 additions & 41 deletions pmml-xgboost/src/main/java/org/jpmml/xgboost/FeatureMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void addEntry(String name, String type){
public void addEntry(String name, Entry.Type type){
Entry entry;

if(type == Entry.Type.BINARY_INDICATOR){
if(type == Entry.Type.INDICATOR){
String value = null;

int equals = name.indexOf('=');
Expand All @@ -93,7 +93,7 @@ public void addEntry(String name, Entry.Type type){
name = name.substring(0, equals);
}

entry = new CategoricalEntry(name, value, type);
entry = new IndicatorEntry(name, value, type);
} else

{
Expand Down Expand Up @@ -176,22 +176,24 @@ private void setType(Type type){

static
public enum Type {
BINARY_INDICATOR,
FLOAT,
INDICATOR,
QUANTITIVE,
INTEGER,
FLOAT,
;

static
public Type fromString(String string){

switch(string){
case "i":
return Type.BINARY_INDICATOR;
return Type.INDICATOR;
case "q":
case "float":
return Type.FLOAT;
return Type.QUANTITIVE;
case "int":
return Type.INTEGER;
case "float":
return Type.FLOAT;
default:
throw new IllegalArgumentException(string);
}
Expand All @@ -200,43 +202,12 @@ public Type fromString(String string){
}

static
private class ContinuousEntry extends Entry {

public ContinuousEntry(String name, Type type){
super(name, type);
}

@Override
public Feature encodeFeature(PMMLEncoder encoder){
String name = getName();
Type type = getType();

DataField dataField = encoder.getDataField(name);
if(dataField == null){

switch(type){
case FLOAT:
dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.FLOAT);
break;
case INTEGER:
dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.INTEGER);
break;
default:
throw new IllegalArgumentException();
}
}

return new ContinuousFeature(encoder, dataField);
}
}

static
private class CategoricalEntry extends Entry {
private class IndicatorEntry extends Entry {

private String value = null;


public CategoricalEntry(String name, String value, Type type){
public IndicatorEntry(String name, String value, Type type){
super(name, type);

setValue(value);
Expand All @@ -252,7 +223,7 @@ public Feature encodeFeature(PMMLEncoder encoder){
if(dataField == null){

switch(type){
case BINARY_INDICATOR:
case INDICATOR:
if(value != null){
dataField = encoder.createDataField(name, OpType.CATEGORICAL, DataType.STRING);
} else
Expand Down Expand Up @@ -285,4 +256,38 @@ private void setValue(String value){
this.value = value;
}
}

static
private class ContinuousEntry extends Entry {

public ContinuousEntry(String name, Type type){
super(name, type);
}

@Override
public Feature encodeFeature(PMMLEncoder encoder){
String name = getName();
Type type = getType();

DataField dataField = encoder.getDataField(name);
if(dataField == null){

switch(type){
case QUANTITIVE:
dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.FLOAT);
break;
case INTEGER:
dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.INTEGER);
break;
case FLOAT:
dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.FLOAT);
break;
default:
throw new IllegalArgumentException();
}
}

return new ContinuousFeature(encoder, dataField);
}
}
}

0 comments on commit 154f030

Please sign in to comment.