diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 32b0af72ba9b..1ed218aa58bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -440,7 +440,11 @@ private[ml] object WeightedLeastSquares {
     /**
      * Weighted population standard deviation of labels.
      */
-    def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
+    def bStd: Double = {
+      // We prevent variance from negative value caused by numerical error.
+      val variance = math.max(bbSum / wSum - bBar * bBar, 0.0)
+      math.sqrt(variance)
+    }
 
     /**
      * Weighted mean of (label * features).
@@ -471,7 +475,8 @@ private[ml] object WeightedLeastSquares {
       while (i < triK) {
         val l = j - 2
         val aw = aSum(l) / wSum
-        std(l) = math.sqrt(aaValues(i) / wSum - aw * aw)
+        // We prevent variance from negative value caused by numerical error.
+        std(l) = math.sqrt(math.max(aaValues(i) / wSum - aw * aw, 0.0))
         i += j
         j += 1
       }
@@ -489,7 +494,8 @@ private[ml] object WeightedLeastSquares {
       while (i < triK) {
         val l = j - 2
         val aw = aSum(l) / wSum
-        variance(l) = aaValues(i) / wSum - aw * aw
+        // We prevent variance from negative value caused by numerical error.
+        variance(l) = math.max(aaValues(i) / wSum - aw * aw, 0.0)
         i += j
         j += 1
       }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
index 7e408b9dbd13..cae41edb7aca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -436,8 +436,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
         var i = 0
         val len = currM2n.length
         while (i < len) {
-          realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
-            (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
+          // We prevent variance from negative value caused by numerical error.
+          realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
+            (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
           i += 1
         }
       }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 7dc0c459ec03..8121880cfb23 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -213,8 +213,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
       var i = 0
       val len = currM2n.length
       while (i < len) {
-        realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
-          (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
+        // We prevent variance from negative value caused by numerical error.
+        realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
+          (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
         i += 1
       }
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
index dfb733ff6e76..1ea851ef2d67 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
@@ -402,6 +402,24 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(summarizer.count === 6)
   }
 
+  test("summarizer buffer zero variance test (SPARK-21818)") {
+    val summarizer1 = new SummarizerBuffer()
+      .add(Vectors.dense(3.0), 0.7)
+    val summarizer2 = new SummarizerBuffer()
+      .add(Vectors.dense(3.0), 0.4)
+    val summarizer3 = new SummarizerBuffer()
+      .add(Vectors.dense(3.0), 0.5)
+    val summarizer4 = new SummarizerBuffer()
+      .add(Vectors.dense(3.0), 0.4)
+
+    val summarizer = summarizer1
+      .merge(summarizer2)
+      .merge(summarizer3)
+      .merge(summarizer4)
+
+    assert(summarizer.variance(0) >= 0.0)
+  }
+
   test("summarizer buffer merging summarizer with empty summarizer") {
     // If one of two is non-empty, this should return the non-empty summarizer.
     // If both of them are empty, then just return the empty summarizer.
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 797e84fcc737..c6466bc918dd 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
     assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
     assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
   }
+
+  test ("test zero variance (SPARK-21818)") {
+    val summarizer1 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.7)
+    val summarizer2 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.4)
+    val summarizer3 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.5)
+    val summarizer4 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.4)
+
+    val summarizer = summarizer1
+      .merge(summarizer2)
+      .merge(summarizer3)
+      .merge(summarizer4)
+
+    assert(summarizer.variance(0) >= 0.0)
+  }
 }