@@ -28,7 +28,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.{Instance, InstanceBlock}
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.optim.aggregator.LogisticAggregator
 import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
@@ -50,8 +50,7 @@ import org.apache.spark.util.VersionUtils
  */
 private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
-  with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth
-  with HasBlockSize {
+  with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth {
 
   import org.apache.spark.ml.classification.LogisticRegression.supportedFamilyNames
 
@@ -431,15 +430,6 @@ class LogisticRegression @Since("1.2.0") (
   @Since("2.2.0")
   def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)
 
-  /**
-   * Set block size for stacking input data in matrices.
-   * Default is 1024.
-   *
-   * @group expertSetParam
-   */
-  @Since("3.0.0")
-  def setBlockSize(value: Int): this.type = set(blockSize, value)
-
   private def assertBoundConstrainedOptimizationParamsValid(
       numCoefficientSets: Int,
       numFeatures: Int): Unit = {
@@ -492,17 +482,24 @@ class LogisticRegression @Since("1.2.0") (
     this
   }
 
-  override protected[spark] def train(
-      dataset: Dataset[_]): LogisticRegressionModel = instrumented { instr =>
+  override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
+    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+    train(dataset, handlePersistence)
+  }
+
+  protected[spark] def train(
+      dataset: Dataset[_],
+      handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
+    val instances = extractInstances(dataset)
+
+    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
     instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, rawPredictionCol,
       probabilityCol, regParam, elasticNetParam, standardization, threshold, maxIter, tol,
       fitIntercept)
 
-    val sc = dataset.sparkSession.sparkContext
-    val instances = extractInstances(dataset)
-
     val (summarizer, labelSummarizer) = instances.treeAggregate(
       (Summarizer.createSummarizerBuffer("mean", "std", "count"), new MultiClassSummarizer))(
       seqOp = (c: (SummarizerBuffer, MultiClassSummarizer), instance: Instance) =>
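
Note on the hunk above: it restores the two-method training entry point. The public override decides whether to cache, based on whether the caller already persisted the input Dataset, then delegates to an overload that takes that decision explicitly and undoes only its own persist at the end. A minimal sketch of this caching pattern, with a hypothetical `extract` function standing in for `extractInstances`:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.Dataset
    import org.apache.spark.storage.StorageLevel

    def fit[T](dataset: Dataset[T], extract: Dataset[T] => RDD[T]): Unit = {
      // Cache the derived RDD only if the user did not persist the input.
      val handlePersistence = dataset.storageLevel == StorageLevel.NONE
      val instances = extract(dataset)
      if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
      try {
        // ... the iterative optimizer makes many passes over `instances` ...
      } finally {
        // Undo only what this method persisted itself.
        if (handlePersistence) instances.unpersist()
      }
    }
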
@@ -585,9 +582,8 @@ class LogisticRegression @Since("1.2.0") (
         s"dangerous ground, so the algorithm may not converge.")
     }
 
-    val featuresMean = summarizer.mean.compressed
-    val featuresStd = summarizer.std.compressed
-    val bcFeaturesStd = sc.broadcast(featuresStd)
+    val featuresMean = summarizer.mean.toArray
+    val featuresStd = summarizer.std.toArray
 
     if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
       featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
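
The summary statistics also change type here: `compressed` keeps an `ml.linalg.Vector` (dense or sparse, whichever is smaller), while `toArray` yields a plain `Array[Double]` suited to the scalar indexing below and to the broadcast that follows. A tiny standalone illustration of the two accessors:

    import org.apache.spark.ml.linalg.{Vector, Vectors}

    val std: Vector = Vectors.dense(1.0, 0.0, 2.5)
    val asVector: Vector = std.compressed     // stays a Vector, possibly sparse
    val asArray: Array[Double] = std.toArray  // plain array, cheap indexing
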
@@ -599,7 +595,8 @@ class LogisticRegression @Since("1.2.0") (
     val regParamL1 = $(elasticNetParam) * $(regParam)
     val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
 
-    val getAggregatorFunc = new LogisticAggregator(numFeatures, numClasses, $(fitIntercept),
+    val bcFeaturesStd = instances.context.broadcast(featuresStd)
+    val getAggregatorFunc = new LogisticAggregator(bcFeaturesStd, numClasses, $(fitIntercept),
       multinomial = isMultinomial)(_)
     val getFeaturesStd = (j: Int) => if (j >= 0 && j < numCoefficientSets * numFeatures) {
       featuresStd(j / numCoefficientSets)
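
With the blocked path reverted, feature scaling moves back inside the gradient pass: the per-feature standard deviations are broadcast once and `LogisticAggregator` divides by them on the fly, so no standardized copy of the dataset is ever materialized. A rough sketch of the idea, with assumed shapes rather than the real `LogisticAggregator` API:

    import org.apache.spark.broadcast.Broadcast

    // Hypothetical aggregator: applies 1/std scaling inside the pass itself.
    class ScalingAggregator(bcStd: Broadcast[Array[Double]]) extends Serializable {
      // Gradient accumulator, sized from the broadcast value on the driver.
      val gradientSum: Array[Double] = Array.ofDim[Double](bcStd.value.length)

      def add(weight: Double, margin: Double, features: Array[Double]): this.type = {
        val std = bcStd.value  // deserialized once per executor, not per record
        var i = 0
        while (i < features.length) {
          // Constant features (std == 0) carry no signal; skip them.
          if (std(i) != 0.0) gradientSum(i) += weight * margin * (features(i) / std(i))
          i += 1
        }
        this
      }
    }
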
@@ -615,21 +612,7 @@ class LogisticRegression @Since("1.2.0") (
       None
     }
 
-    val standardized = instances.map {
-      case Instance(label, weight, features) =>
-        val featuresStd = bcFeaturesStd.value
-        val array = Array.ofDim[Double](numFeatures)
-        features.foreachNonZero { (i, v) =>
-          val std = featuresStd(i)
-          if (std != 0) array(i) = v / std
-        }
-        Instance(label, weight, Vectors.dense(array))
-    }
-    val blocks = InstanceBlock.blokify(standardized, $(blockSize))
-      .persist(StorageLevel.MEMORY_AND_DISK)
-      .setName(s"training dataset (blockSize=${$(blockSize)})")
-
-    val costFun = new RDDLossFunction(blocks, getAggregatorFunc, regularization,
+    val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization,
       $(aggregationDepth))
 
     val numCoeffsPlusIntercepts = numFeaturesPlusIntercept * numCoefficientSets
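
The lines removed above were the block-based path: each instance was standardized eagerly, stacked into `InstanceBlock`s of `$(blockSize)` rows, and that second, separately persisted copy of the data was what `RDDLossFunction` consumed. After the revert, the raw `instances` RDD feeds the loss function directly. For reference, a self-contained sketch of the eager standardization step that was dropped (mirroring the removed code, not an exact copy of it):

    import org.apache.spark.broadcast.Broadcast
    import org.apache.spark.ml.feature.Instance
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.rdd.RDD

    def standardize(instances: RDD[Instance], bcStd: Broadcast[Array[Double]]): RDD[Instance] =
      instances.map { case Instance(label, weight, features) =>
        val std = bcStd.value
        val arr = Array.ofDim[Double](features.size)
        features.foreachNonZero { (i, v) =>
          if (std(i) != 0.0) arr(i) = v / std(i)  // constant features stay at 0
        }
        Instance(label, weight, Vectors.dense(arr))
      }
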
@@ -823,7 +806,6 @@ class LogisticRegression @Since("1.2.0") (
       state = states.next()
       arrayBuilder += state.adjustedValue
     }
-    blocks.unpersist()
     bcFeaturesStd.destroy()
 
     if (state == null) {
@@ -893,6 +875,8 @@ class LogisticRegression @Since("1.2.0") (
       }
     }
 
+    if (handlePersistence) instances.unpersist()
+
     val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
       numClasses, isMultinomial))
 