From c3a0575829d20fcb3abaa36d72db297162855d8d Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Sat, 26 Mar 2022 21:44:06 +0200 Subject: [PATCH] Added support for categorical tree splits See https://github.com/dmlc/xgboost/issues/6503 --- .../java/org/jpmml/xgboost/BinaryNode.java | 5 + .../java/org/jpmml/xgboost/FeatureMap.java | 110 +++++++++++++---- .../main/java/org/jpmml/xgboost/JSONNode.java | 28 +++++ .../main/java/org/jpmml/xgboost/Learner.java | 7 ++ .../src/main/java/org/jpmml/xgboost/Node.java | 3 + .../main/java/org/jpmml/xgboost/RegTree.java | 115 +++++++++++++++++- 6 files changed, 241 insertions(+), 27 deletions(-) diff --git a/pmml-xgboost/src/main/java/org/jpmml/xgboost/BinaryNode.java b/pmml-xgboost/src/main/java/org/jpmml/xgboost/BinaryNode.java index f930de4..f35d05f 100644 --- a/pmml-xgboost/src/main/java/org/jpmml/xgboost/BinaryNode.java +++ b/pmml-xgboost/src/main/java/org/jpmml/xgboost/BinaryNode.java @@ -70,6 +70,11 @@ public int split_index(){ return (int)(this.sindex & ((1L << 31) - 1L)); } + @Override + public int split_type(){ + return 0; + } + @Override public int split_cond(){ return this.info; diff --git a/pmml-xgboost/src/main/java/org/jpmml/xgboost/FeatureMap.java b/pmml-xgboost/src/main/java/org/jpmml/xgboost/FeatureMap.java index 35311e3..6220eca 100644 --- a/pmml-xgboost/src/main/java/org/jpmml/xgboost/FeatureMap.java +++ b/pmml-xgboost/src/main/java/org/jpmml/xgboost/FeatureMap.java @@ -18,6 +18,7 @@ */ package org.jpmml.xgboost; +import java.util.AbstractList; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -33,6 +34,7 @@ import org.dmg.pmml.OpType; import org.dmg.pmml.Value; import org.jpmml.converter.BinaryFeature; +import org.jpmml.converter.CategoricalFeature; import org.jpmml.converter.ContinuousFeature; import org.jpmml.converter.Feature; import org.jpmml.converter.PMMLEncoder; @@ -77,28 +79,7 @@ public List encodeFeatures(PMMLEncoder encoder){ } public void addEntry(String name, String type){ - addEntry(name, Entry.Type.fromString(type)); - - } - - public void addEntry(String name, Entry.Type type){ - Entry entry; - - if(type == Entry.Type.INDICATOR){ - String value = null; - - int equals = name.indexOf('='); - if(equals > -1){ - value = name.substring(equals + 1); - name = name.substring(0, equals); - } - - entry = new IndicatorEntry(name, value, type); - } else - - { - entry = new ContinuousEntry(name, type); - } + Entry entry = createEntry(name, Entry.Type.fromString(type)); addEntry(entry); } @@ -141,6 +122,31 @@ private void addValue(Value.Property property, String value){ values.add(value); } + static + private Entry createEntry(String name, Entry.Type type){ + + switch(type){ + case INDICATOR: + String value = null; + + int equals = name.indexOf('='); + if(equals > -1){ + value = name.substring(equals + 1); + name = name.substring(0, equals); + } + + return new IndicatorEntry(name, value, type); + case QUANTITIVE: + case INTEGER: + case FLOAT: + return new ContinuousEntry(name, type); + case CATEGORICAL: + return new CategoricalEntry(name, type); + default: + throw new IllegalArgumentException(); + } + } + abstract static public class Entry { @@ -180,6 +186,7 @@ public enum Type { QUANTITIVE, INTEGER, FLOAT, + CATEGORICAL, ; static @@ -194,6 +201,9 @@ public Type fromString(String string){ return Type.INTEGER; case "float": return Type.FLOAT; + case "c": + case "categorical": + return Type.CATEGORICAL; default: throw new IllegalArgumentException(string); } @@ -290,4 +300,60 @@ public Feature encodeFeature(PMMLEncoder encoder){ return new ContinuousFeature(encoder, dataField); } } + + static + private class CategoricalEntry extends Entry { + + public CategoricalEntry(String name, Type type){ + super(name, type); + } + + @Override + public Feature encodeFeature(PMMLEncoder encoder){ + String name = getName(); + Type type = getType(); + + DataField dataField = encoder.getDataField(name); + if(dataField == null){ + + switch(type){ + case CATEGORICAL: + dataField = encoder.createDataField(name, OpType.CATEGORICAL, DataType.STRING); + break; + default: + throw new IllegalArgumentException(); + } + } + + List values = new AbstractList(){ + + private int max = -1; + + + @Override + public boolean isEmpty(){ + return false; + } + + @Override + public int size(){ + + if(this.max < 0){ + throw new IllegalStateException(); + } + + return (this.max + 1); + } + + @Override + public Integer get(int i){ + this.max = Math.max(this.max, i); + + return i; + } + }; + + return new CategoricalFeature(encoder, dataField, values); + } + } } \ No newline at end of file diff --git a/pmml-xgboost/src/main/java/org/jpmml/xgboost/JSONNode.java b/pmml-xgboost/src/main/java/org/jpmml/xgboost/JSONNode.java index 6998ce1..97ba478 100644 --- a/pmml-xgboost/src/main/java/org/jpmml/xgboost/JSONNode.java +++ b/pmml-xgboost/src/main/java/org/jpmml/xgboost/JSONNode.java @@ -18,6 +18,8 @@ */ package org.jpmml.xgboost; +import java.util.BitSet; + import com.google.gson.JsonObject; public class JSONNode extends Node implements JSONLoadable { @@ -32,8 +34,12 @@ public class JSONNode extends Node implements JSONLoadable { private int split_index; + private int split_type; + private float split_condition; + private BitSet split_categories; + public JSONNode(){ } @@ -45,7 +51,16 @@ public void loadJSON(JsonObject node){ this.right_child = node.getAsJsonPrimitive("right_child").getAsInt(); this.default_left = node.getAsJsonPrimitive("default_left").getAsBoolean(); this.split_index = node.getAsJsonPrimitive("split_index").getAsInt(); + this.split_type = node.getAsJsonPrimitive("split_type").getAsInt(); this.split_condition = node.getAsJsonPrimitive("split_condition").getAsFloat(); + + switch(this.split_type){ + case 0: + case 1: + break; + default: + throw new IllegalArgumentException(); + } } @Override @@ -68,6 +83,11 @@ public boolean default_left(){ return this.default_left; } + @Override + public int split_type(){ + return this.split_type; + } + @Override public int split_index(){ return this.split_index; @@ -82,4 +102,12 @@ public int split_cond(){ public float leaf_value(){ return this.split_condition; } + + public BitSet get_split_categories(){ + return this.split_categories; + } + + void set_split_categories(BitSet split_categories){ + this.split_categories = split_categories; + } } \ No newline at end of file diff --git a/pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java b/pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java index db00c4f..bf5050e 100644 --- a/pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java +++ b/pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java @@ -39,6 +39,7 @@ import org.dmg.pmml.Visitor; import org.dmg.pmml.mining.MiningModel; import org.jpmml.converter.BinaryFeature; +import org.jpmml.converter.CategoricalFeature; import org.jpmml.converter.ContinuousFeature; import org.jpmml.converter.Feature; import org.jpmml.converter.Label; @@ -291,6 +292,12 @@ public Schema toXGBoostSchema(boolean numeric, Schema schema){ @Override public Feature apply(Feature feature){ + if(feature instanceof CategoricalFeature){ + CategoricalFeature categoricalFeature = (CategoricalFeature)feature; + + return categoricalFeature; + } else + if(feature instanceof BinaryFeature){ BinaryFeature binaryFeature = (BinaryFeature)feature; diff --git a/pmml-xgboost/src/main/java/org/jpmml/xgboost/Node.java b/pmml-xgboost/src/main/java/org/jpmml/xgboost/Node.java index cfc8e22..0283dac 100644 --- a/pmml-xgboost/src/main/java/org/jpmml/xgboost/Node.java +++ b/pmml-xgboost/src/main/java/org/jpmml/xgboost/Node.java @@ -30,6 +30,9 @@ public class Node { abstract public int split_index(); + abstract + public int split_type(); + abstract public int split_cond(); diff --git a/pmml-xgboost/src/main/java/org/jpmml/xgboost/RegTree.java b/pmml-xgboost/src/main/java/org/jpmml/xgboost/RegTree.java index 1eb606c..fb3cc6e 100644 --- a/pmml-xgboost/src/main/java/org/jpmml/xgboost/RegTree.java +++ b/pmml-xgboost/src/main/java/org/jpmml/xgboost/RegTree.java @@ -19,10 +19,12 @@ package org.jpmml.xgboost; import java.io.IOException; +import java.util.BitSet; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; +import com.google.common.primitives.Ints; import com.google.gson.JsonObject; import com.google.gson.JsonPrimitive; import org.dmg.pmml.DataType; @@ -35,6 +37,7 @@ import org.dmg.pmml.tree.LeafNode; import org.dmg.pmml.tree.TreeModel; import org.jpmml.converter.BinaryFeature; +import org.jpmml.converter.CategoricalFeature; import org.jpmml.converter.CategoryManager; import org.jpmml.converter.ContinuousFeature; import org.jpmml.converter.Feature; @@ -100,25 +103,81 @@ public void loadJSON(JsonObject tree){ int[] split_type = JSONUtil.toIntArray(tree.getAsJsonArray("split_type")); float[] split_conditions = JSONUtil.toFloatArray(tree.getAsJsonArray("split_conditions")); + boolean has_cat = Ints.contains(split_type, 1); + this.nodes = new Node[this.num_nodes]; for(int i = 0; i < this.num_nodes; i++){ - - if(split_type[i] != 0){ - throw new IllegalArgumentException(); - } - JsonObject node = new JsonObject(); node.add("parent", new JsonPrimitive(parents[i])); node.add("left_child", new JsonPrimitive(left_children[i])); node.add("right_child", new JsonPrimitive(right_children[i])); node.add("default_left", new JsonPrimitive(default_left[i])); node.add("split_index", new JsonPrimitive(split_indices[i])); + node.add("split_type", new JsonPrimitive(split_type[i])); node.add("split_condition", new JsonPrimitive(split_conditions[i])); this.nodes[i] = new JSONNode(); ((JSONLoadable)this.nodes[i]).loadJSON(node); } + + if(has_cat){ + int[] categories_segments = JSONUtil.toIntArray(tree.getAsJsonArray("categories_segments")); + int[] categories_sizes = JSONUtil.toIntArray(tree.getAsJsonArray("categories_sizes")); + int[] categories_nodes = JSONUtil.toIntArray(tree.getAsJsonArray("categories_nodes")); + int[] categories = JSONUtil.toIntArray(tree.getAsJsonArray("categories")); + + int cnt = 0; + + int last_cat_node = categories_nodes[cnt]; + + for(int i = 0; i < this.num_nodes; i++){ + JSONNode node = (JSONNode)this.nodes[i]; + + if(i == last_cat_node){ + int j_begin = categories_segments[cnt]; + int j_end = j_begin + categories_sizes[cnt]; + + int max_cat = -1; + + for(int j = j_begin; j < j_end; j++){ + int category = categories[j]; + + max_cat = Math.max(max_cat, category); + } + + if(max_cat == -1){ + throw new IllegalArgumentException(); + } + + int n_cats = (max_cat + 1); + + BitSet cat_bits = new BitSet(n_cats); + + for(int j = j_begin; j < j_end; j++){ + int category = categories[j]; + + cat_bits.set(category, true); + } + + node.set_split_categories(cat_bits); + + cnt++; + + if(cnt == categories_nodes.length){ + last_cat_node = -1; + } else + + { + last_cat_node = categories_nodes[cnt]; + } + } else + + { + node.set_split_categories(null); + } + } + } } public Float getLeafValue(){ @@ -162,6 +221,52 @@ private org.dmg.pmml.tree.Node encodeNode(int index, Predicate predicate, boolea Predicate leftPredicate; Predicate rightPredicate; + if(feature instanceof CategoricalFeature){ + + if(node.split_type() != 1){ + throw new IllegalArgumentException(); + } + } else + + { + if(node.split_type() != 0){ + throw new IllegalArgumentException(); + } + } // End if + + if(feature instanceof CategoricalFeature){ + CategoricalFeature categoricalFeature = (CategoricalFeature)feature; + + Float splitValue = Float.intBitsToFloat(node.split_cond()); + if(!splitValue.isNaN()){ + throw new IllegalArgumentException(); + } + + BitSet split_categories = null; + + if(node instanceof JSONNode){ + JSONNode jsonNode = (JSONNode)node; + + split_categories = jsonNode.get_split_categories(); + } // End if + + if(split_categories == null){ + throw new IllegalArgumentException(); + } else + + // Assume one-hot-encoding + if(split_categories.cardinality() != 1){ + throw new IllegalArgumentException(); + } + + int catIndex = split_categories.nextSetBit(0); + + Object value = categoricalFeature.getValue(catIndex); + + leftPredicate = predicateManager.createSimplePredicate(categoricalFeature, SimplePredicate.Operator.NOT_EQUAL, value); + rightPredicate = predicateManager.createSimplePredicate(categoricalFeature, SimplePredicate.Operator.EQUAL, value); + } else + if(feature instanceof BinaryFeature){ BinaryFeature binaryFeature = (BinaryFeature)feature;