diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 7694773c816b2..067b9eeee49ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ @@ -340,8 +341,9 @@ class LogisticRegression @Since("1.2.0") ( val regParamL1 = $(elasticNetParam) * $(regParam) val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam) + val bcFeaturesStd = instances.context.broadcast(featuresStd) val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), - $(standardization), featuresStd, featuresMean, regParamL2) + $(standardization), bcFeaturesStd, regParamL2) val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) @@ -436,6 +438,7 @@ class LogisticRegression @Since("1.2.0") ( rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } i += 1 } + bcFeaturesStd.destroy(blocking = false) if ($(fitIntercept)) { (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last, @@ -932,11 +935,15 @@ class BinaryLogisticRegressionSummary private[classification] ( * Two LogisticAggregator can be merged together to have a summary of loss and gradient of * the corresponding joint dataset. * + * @param bcCoefficients The broadcast coefficients corresponding to the features. + * @param bcFeaturesStd The broadcast standard deviation values of the features. * @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. */ private class LogisticAggregator( + val bcCoefficients: Broadcast[Vector], + val bcFeaturesStd: Broadcast[Array[Double]], private val numFeatures: Int, numClasses: Int, fitIntercept: Boolean) extends Serializable { @@ -952,14 +959,9 @@ private class LogisticAggregator( * of the objective function. * * @param instance The instance of data point to be added. - * @param coefficients The coefficients corresponding to the features. - * @param featuresStd The standard deviation values of the features. * @return This LogisticAggregator object. */ - def add( - instance: Instance, - coefficients: Vector, - featuresStd: Array[Double]): this.type = { + def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." + s" Expecting $numFeatures but got ${features.size}.") @@ -967,14 +969,16 @@ private class LogisticAggregator( if (weight == 0.0) return this - val coefficientsArray = coefficients match { + val coefficientsArray = bcCoefficients.value match { case dv: DenseVector => dv.values case _ => throw new IllegalArgumentException( - s"coefficients only supports dense vector but got type ${coefficients.getClass}.") + "coefficients only supports dense vector" + + s"but got type ${bcCoefficients.value.getClass}.") } val localGradientSumArray = gradientSumArray + val featuresStd = bcFeaturesStd.value numClasses match { case 2 => // For Binary Logistic Regression. @@ -1071,24 +1075,23 @@ private class LogisticCostFun( numClasses: Int, fitIntercept: Boolean, standardization: Boolean, - featuresStd: Array[Double], - featuresMean: Array[Double], + bcFeaturesStd: Broadcast[Array[Double]], regParamL2: Double) extends DiffFunction[BDV[Double]] { + val featuresStd = bcFeaturesStd.value + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { val numFeatures = featuresStd.length val coeffs = Vectors.fromBreeze(coefficients) + val bcCoeffs = instances.context.broadcast(coeffs) val n = coeffs.size - val localFeaturesStd = featuresStd - val logisticAggregator = { - val seqOp = (c: LogisticAggregator, instance: Instance) => - c.add(instance, coeffs, localFeaturesStd) + val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance) val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2) instances.treeAggregate( - new LogisticAggregator(numFeatures, numClasses, fitIntercept) + new LogisticAggregator(bcCoeffs, bcFeaturesStd, numFeatures, numClasses, fitIntercept) )(seqOp, combOp) } @@ -1128,6 +1131,7 @@ private class LogisticCostFun( } 0.5 * regParamL2 * sum } + bcCoeffs.destroy(blocking = false) (logisticAggregator.loss + regVal, new BDV(totalGradientArray)) }