diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
index e7f7a8e07d7f2..3e32f746e9cd9 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
@@ -55,7 +55,9 @@ class MultivariateGaussian @Since("2.0.0") (
    *    rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t
    *    u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
    */
-  @transient private lazy val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants
+  @transient private lazy val tuple = calculateCovarianceConstants
+  @transient private lazy val rootSigmaInv = tuple._1
+  @transient private lazy val u = tuple._2
 
   /**
    * Returns density of this multivariate Gaussian at given point, x
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 94681ae9ef796..5459a0fab9135 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -405,18 +405,26 @@ class NaiveBayesModel private[ml] (
    * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
    * application of this condition (in predict function).
    */
-  @transient private lazy val (thetaMinusNegTheta, piMinusThetaSum) = $(modelType) match {
+  @transient private lazy val thetaMinusNegTheta = $(modelType) match {
+    case Bernoulli =>
+      theta.map(value => value - math.log1p(-math.exp(value)))
+    case _ =>
+      // This should never happen.
+      throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}. " +
+        "Variables thetaMinusNegTheta should only be precomputed in Bernoulli NB.")
+  }
+
+  @transient private lazy val piMinusThetaSum = $(modelType) match {
     case Bernoulli =>
-      val thetaMinusNegTheta = theta.map(value => value - math.log1p(-math.exp(value)))
       val negTheta = theta.map(value => math.log1p(-math.exp(value)))
       val ones = new DenseVector(Array.fill(theta.numCols)(1.0))
       val piMinusThetaSum = pi.toDense.copy
       BLAS.gemv(1.0, negTheta, ones, 1.0, piMinusThetaSum)
-      (thetaMinusNegTheta, piMinusThetaSum)
+      piMinusThetaSum
     case _ =>
       // This should never happen.
       throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}. " +
-        "Variables thetaMinusNegTheta and negThetaSum should only be precomputed in Bernoulli NB.")
+        "Variables piMinusThetaSum should only be precomputed in Bernoulli NB.")
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index 9a746dcf35556..f34c22915ae15 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -60,7 +60,9 @@ class MultivariateGaussian @Since("1.3.0") (
    *    rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t
    *    u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
    */
-  @transient private lazy val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
+  @transient private lazy val tuple = calculateCovarianceConstants
+  @transient private lazy val rootSigmaInv = tuple._1
+  @transient private lazy val u = tuple._2
 
   /**
    * Returns density of this multivariate Gaussian at given point, x