@@ -55,30 +55,40 @@ private[tree] object TreePoint {
5555 input : RDD [LabeledPoint ],
5656 bins : Array [Array [Bin ]],
5757 metadata : DecisionTreeMetadata ): RDD [TreePoint ] = {
58+ // Construct arrays for featureArity and isUnordered for efficiency in the inner loop.
59+ val featureArity : Array [Int ] = new Array [Int ](metadata.numFeatures)
60+ val isUnordered : Array [Boolean ] = new Array [Boolean ](metadata.numFeatures)
61+ var featureIndex = 0
62+ while (featureIndex < metadata.numFeatures) {
63+ featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0 )
64+ isUnordered(featureIndex) = metadata.isUnordered(featureIndex)
65+ featureIndex += 1
66+ }
5867 input.map { x =>
59- TreePoint .labeledPointToTreePoint(x, bins, metadata )
68+ TreePoint .labeledPointToTreePoint(x, bins, featureArity, isUnordered )
6069 }
6170 }
6271
6372 /**
6473 * Convert one LabeledPoint into its TreePoint representation.
6574 * @param bins Bins for features, of size (numFeatures, numBins).
66- * @param metadata DecisionTree training info, used for dataset metadata.
75+ * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
76+ * for categorical features.
77+ * @param isUnordered Array index by feature, with value true for unordered categorical features.
6778 */
6879 private def labeledPointToTreePoint (
6980 labeledPoint : LabeledPoint ,
7081 bins : Array [Array [Bin ]],
71- metadata : DecisionTreeMetadata ): TreePoint = {
82+ featureArity : Array [Int ],
83+ isUnordered : Array [Boolean ]): TreePoint = {
7284 val numFeatures = labeledPoint.features.size
7385 val arr = new Array [Int ](numFeatures)
7486 var featureIndex = 0
7587 while (featureIndex < numFeatures) {
76- val featureArity = metadata.featureArity.getOrElse(featureIndex, 0 )
77- arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity,
78- metadata.isUnordered(featureIndex), bins)
88+ arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
89+ isUnordered(featureIndex), bins)
7990 featureIndex += 1
8091 }
81-
8292 new TreePoint (labeledPoint.label, arr)
8393 }
8494
0 commit comments