diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index fb4c321a146f..b041dd42d6da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1050,8 +1050,11 @@ private[spark] object RandomForest extends Logging with Serializable { // Calculate the expected number of samples for finding splits val weightedNumSamples = samplesFractionForFindSplits(metadata) * metadata.weightedNumExamples + // scale tolerance by number of samples with constant factor + // Note: constant factor was tuned by running some tests where there were no zero + // feature values and validating we are never within tolerance + val tolerance = Utils.EPSILON * unweightedNumSamples * 100 // add expected zero value count and get complete statistics - val tolerance = Utils.EPSILON * unweightedNumSamples * unweightedNumSamples val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) { partValueCountMap + (0.0 -> (weightedNumSamples - partNumSamples)) } else {