Commit abc86b2

Fix tests, remove perf test in LocalTreeIntegrationSuite, use Scala transpose in LocalDecisionTreeUtils.

Changes made to fix tests:
* Return correct impurity stats for splits that achieve a gain of 0 but don't violate user-specified constraints on min info gain or min instances per node.
* Previously, ImpurityStats.impurity was set incorrectly in ImpurityStats.getInvalidImpurityStats(), requiring a correction in LearningNode.toNode. This commit fixes the issue by setting impurity = -1 directly in getInvalidImpurityStats().
Parent: 9a7174e
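
To make the new validity rules concrete, here is a minimal standalone Scala sketch of the behavior described above. Stats, evaluateSplit, and minInfoGain are hypothetical stand-ins for the Spark classes changed below, not actual Spark API:

// Hypothetical, simplified model of the split-validity rules in this commit;
// Stats and evaluateSplit are illustrative stand-ins, not the Spark classes.
case class Stats(gain: Double, impurity: Double, valid: Boolean)

def evaluateSplit(gain: Double, impurity: Double, minInfoGain: Double): Stats = {
  if (gain < minInfoGain) {
    // Constraint violated: stats are unusable, so impurity is the sentinel -1
    // (previously patched up later in LearningNode.toNode).
    Stats(Double.MinValue, impurity = -1, valid = false)
  } else if (gain <= 0) {
    // Non-positive gain that doesn't violate the constraint: keep the correct
    // impurity, but mark the split as one that should not be performed.
    Stats(gain, impurity, valid = false)
  } else {
    Stats(gain, impurity, valid = true)
  }
}

With the default minimum info gain of 0.0, a zero-gain split no longer trips the constraint check, so it reaches the second branch and keeps its true impurity while still being rejected as a split.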

6 files changed: +19 −55 lines
mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala

Lines changed: 2 additions & 8 deletions
@@ -278,14 +278,8 @@ private[tree] class LearningNode(
     } else {
       assert(stats != null, "Unknown error during Decision Tree learning. Could not convert " +
         "LearningNode to Node")
-      if (stats.valid) {
-        new LeafNode(stats.impurityCalculator.predict, stats.impurity,
-          stats.impurityCalculator)
-      } else {
-        // Here we want to keep same behavior with the old mllib.DecisionTreeModel
-        new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
-      }
-
+      new LeafNode(stats.impurityCalculator.predict, stats.impurity,
+        stats.impurityCalculator)
     }
   }

mllib/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala

Lines changed: 10 additions & 4 deletions
@@ -60,14 +60,20 @@ private[impl] object ImpurityUtils {
     val rightWeight = rightCount / totalCount.toDouble

     val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
-    // if information gain doesn't satisfy minimum information gain,
+    // If information gain doesn't satisfy minimum information gain,
     // then this split is invalid, return invalid information gain stats.
-    // NOTE: We check gain < metadata.minInfoGain and gain <= 0 separately as this is what the
-    // original tree training logic did.
-    if (gain < metadata.minInfoGain || gain <= 0) {
+    if (gain < metadata.minInfoGain) {
       return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
     }

+    // If information gain is non-positive but doesn't violate the minimum info gain constraint,
+    // return a stats object with correct values but valid = false to indicate that we should not
+    // split.
+    if (gain <= 0) {
+      return new ImpurityStats(gain, impurity, parentImpurityCalculator, leftImpurityCalculator,
+        rightImpurityCalculator, valid = false)
+    }
+
     new ImpurityStats(gain, impurity, parentImpurityCalculator,
       leftImpurityCalculator, rightImpurityCalculator)
   }
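
As a quick worked example of the gain formula in this hunk (toy numbers in standalone Scala, not Spark code): a split whose children are exactly as impure as the parent yields zero gain, which is the case that now returns correct stats with valid = false:

// Toy numbers: an even split that leaves impurity unchanged.
val (leftCount, rightCount) = (50.0, 50.0)
val totalCount = leftCount + rightCount
val (impurity, leftImpurity, rightImpurity) = (0.5, 0.5, 0.5)

val gain = impurity -
  (leftCount / totalCount) * leftImpurity -
  (rightCount / totalCount) * rightImpurity
// gain == 0.0: with the default minInfoGain of 0.0 this no longer returns
// invalid stats; it returns correct stats marked valid = false.
assert(gain == 0.0)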

mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeUtils.scala

Lines changed: 1 addition & 3 deletions
@@ -51,9 +51,7 @@ private[ml] object LocalDecisionTreeUtils extends Logging {
     val numFeatures = rowStore(0).length
     require(numFeatures > 0, "Local decision tree training requires numFeatures > 0.")
     // Return the transpose of the rowStore matrix
-    0.until(numFeatures).map { colIdx =>
-      rowStore.map(row => row(colIdx))
-    }.toArray
+    rowStore.transpose
   }

 }
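
As a sanity check that the built-in transpose matches the removed manual loop, a toy standalone Scala example (illustrative data, not Spark code):

// Toy data: two rows, three features.
val rowStore: Array[Array[Int]] = Array(Array(1, 2, 3), Array(4, 5, 6))

// The manual column-extraction loop this commit removes.
val manual: Array[Array[Int]] = (0 until rowStore(0).length).map { colIdx =>
  rowStore.map(row => row(colIdx))
}.toArray

// Both yield Array(Array(1, 4), Array(2, 5), Array(3, 6)).
assert(manual.map(_.toSeq).toSeq == rowStore.transpose.map(_.toSeq).toSeq)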

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

Lines changed: 2 additions & 10 deletions
@@ -638,16 +638,8 @@ private[spark] object RandomForest extends Logging {

     // For each (feature, split), calculate the gain, and select the best (feature, split).
     val splitsAndImpurityInfo =
-      validFeatureSplits.flatMap { case (featureIndexIdx, featureIndex) =>
-        val (split, stats) = SplitUtils.chooseSplit(binAggregates,
-          featureIndex, featureIndexIdx, splits)
-        // Filter out invalid splits
-        // TODO(smurching): Better to use map + filter or flatmap?
-        if (stats.valid) {
-          Seq((split, stats))
-        } else {
-          Seq.empty
-        }
+      validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
+        SplitUtils.chooseSplit(binAggregates, featureIndex, featureIndexIdx, splits)
       }

     val (bestSplit, bestSplitStats) =
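
Since invalid splits are no longer filtered out here, the selection step that follows must tolerate them. A hedged Scala sketch of that idea, with hypothetical names (CandidateStats and the data are illustrative; the actual downstream logic is not shown in this diff):

// Keep every candidate, pick the highest-gain one, and let its valid flag
// decide whether the node actually splits or becomes a leaf.
case class CandidateStats(gain: Double, valid: Boolean)

val candidates = Seq(
  "split0" -> CandidateStats(gain = 0.0, valid = false),
  "split1" -> CandidateStats(gain = 0.3, valid = true))

val (bestSplit, bestStats) = candidates.maxBy { case (_, stats) => stats.gain }
val shouldSplit = bestStats.valid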

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 4 additions & 3 deletions
@@ -75,8 +75,9 @@ class InformationGainStats(
  * @param impurityCalculator impurity statistics for current node
  * @param leftImpurityCalculator impurity statistics for left child node
  * @param rightImpurityCalculator impurity statistics for right child node
- * @param valid whether the current split satisfies minimum info gain or
- *              minimum number of instances per node
+ * @param valid whether the current split should be performed; true if split
+ *              satisfies minimum info gain, minimum number of instances per node, and
+ *              has positive info gain.
  */
 private[spark] class ImpurityStats(
     val gain: Double,
@@ -112,7 +113,7 @@ private[spark] object ImpurityStats {
    * minimum number of instances per node.
    */
   def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
-    new ImpurityStats(Double.MinValue, impurityCalculator.calculate(),
+    new ImpurityStats(Double.MinValue, impurity = -1,
       impurityCalculator, null, null, false)
   }

mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeIntegrationSuite.scala

Lines changed: 0 additions & 27 deletions
@@ -94,31 +94,4 @@ class LocalTreeIntegrationSuite extends SparkFunSuite with MLlibTestSparkContext
     testEquivalence(df, TreeTests.allParamSettings)
   }

-  // TODO(smurching): Probably remove this (since it depends on user env). Currently fails, partly
-  // because collecting data for local training is slow but also because local training is
-  // slightly slower than distributed training.
-  // test("Local tree training is faster than distributed training on a medium-sized dataset") {
-  //   val sqlContext = spark.sqlContext
-  //   import sqlContext.implicits._
-  //   val df = LogisticRegressionDataGenerator.generateLogisticRDD(spark.sparkContext,
-  //     nexamples = 100000, nfeatures = 5, eps = 2.0, nparts = 1, probOne = 0.2)
-  //     .map(_.asML).toDF().cache()
-  //
-  //   val timer = new TimeTracker()
-  //
-  //   timer.start("local")
-  //   val localTree = setParams(new LocalDecisionTreeRegressor(), TreeTests.allParamSettings)
-  //   localTree.fit(df)
-  //   val localTrainTime = timer.stop("local")
-  //
-  //   timer.start("distributed")
-  //   val distribTree = setParams(new DecisionTreeRegressor(), TreeTests.allParamSettings)
-  //   distribTree.fit(df)
-  //   val distribTrainTime = timer.stop("distributed")
-  //
-  //   assert(localTrainTime < distribTrainTime, s"Local tree training time ($localTrainTime) " +
-  //     s"should be less than distributed tree training time ($distribTrainTime).")
-  // }
-
-
 }
