Skip to content

Commit f708edb

Browse files
committed
updated based on similar previous PR comments
1 parent aca6255 commit f708edb

File tree

5 files changed

+19
-30
lines changed

5 files changed

+19
-30
lines changed

mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
7171
def setLabelCol(value: String): this.type = set(labelCol, value)
7272

7373
/** @group setParam */
74-
@Since("2.2.0")
74+
@Since("3.0.0")
7575
def setWeightCol(value: String): this.type = set(weightCol, value)
7676

7777
setDefault(metricName -> "rmse")
@@ -88,7 +88,7 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
8888
.rdd
8989
.map { case Row(prediction: Double, label: Double, weight: Double) =>
9090
(prediction, label, weight) }
91-
val metrics = new RegressionMetrics(false, predictionAndLabelsWithWeights)
91+
val metrics = new RegressionMetrics(predictionAndLabelsWithWeights)
9292
val metric = $(metricName) match {
9393
case "rmse" => metrics.rootMeanSquaredError
9494
case "mse" => metrics.meanSquaredError

mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,31 +34,20 @@ import org.apache.spark.sql.DataFrame
3434
*/
3535
@Since("1.2.0")
3636
class RegressionMetrics @Since("2.0.0") (
37-
throughOrigin: Boolean, predAndObsWithOptWeight: RDD[_])
37+
predAndObsWithOptWeight: RDD[_ <: Product], throughOrigin: Boolean)
3838
extends Logging {
3939

4040
@Since("1.2.0")
41-
def this(predictionAndObservations: RDD[(Double, Double)]) =
42-
this(false, predictionAndObservations)
43-
44-
/**
45-
* Evaluator for regression.
46-
*
47-
* @param predictionAndObservations an RDD of (prediction, observation) pairs
48-
* @param throughOrigin True if the regression is through the origin. For example, in linear
49-
* regression, it will be true without fitting intercept.
50-
*/
51-
@Since("2.0.0")
52-
def this(predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) =
53-
this(throughOrigin, predictionAndObservations)
41+
def this(predictionAndObservations: RDD[_ <: Product]) =
42+
this(predictionAndObservations, false)
5443

5544
/**
5645
* An auxiliary constructor taking a DataFrame.
5746
* @param predictionAndObservations a DataFrame with two double columns:
5847
* prediction and observation
5948
*/
6049
private[mllib] def this(predictionAndObservations: DataFrame) =
61-
this(false, predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
50+
this(predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
6251

6352
/**
6453
* Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.

mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
5252
private var totalCnt: Long = 0
5353
private var totalWeightSum: Double = 0.0
5454
private var weightSquareSum: Double = 0.0
55-
private var currWeightSum: Array[Double] = _
55+
private var weightSum: Array[Double] = _
5656
private var nnz: Array[Long] = _
5757
private var currMax: Array[Double] = _
5858
private var currMin: Array[Double] = _
@@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
7878
currM2n = Array.ofDim[Double](n)
7979
currM2 = Array.ofDim[Double](n)
8080
currL1 = Array.ofDim[Double](n)
81-
currWeightSum = Array.ofDim[Double](n)
81+
weightSum = Array.ofDim[Double](n)
8282
nnz = Array.ofDim[Long](n)
8383
currMax = Array.fill[Double](n)(Double.MinValue)
8484
currMin = Array.fill[Double](n)(Double.MaxValue)
@@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
9191
val localCurrM2n = currM2n
9292
val localCurrM2 = currM2
9393
val localCurrL1 = currL1
94-
val localWeightSum = currWeightSum
94+
val localWeightSum = weightSum
9595
val localNumNonzeros = nnz
9696
val localCurrMax = currMax
9797
val localCurrMin = currMin
@@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
139139
weightSquareSum += other.weightSquareSum
140140
var i = 0
141141
while (i < n) {
142-
val thisNnz = currWeightSum(i)
143-
val otherNnz = other.currWeightSum(i)
142+
val thisNnz = weightSum(i)
143+
val otherNnz = other.weightSum(i)
144144
val totalNnz = thisNnz + otherNnz
145145
val totalCnnz = nnz(i) + other.nnz(i)
146146
if (totalNnz != 0.0) {
@@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
157157
currMax(i) = math.max(currMax(i), other.currMax(i))
158158
currMin(i) = math.min(currMin(i), other.currMin(i))
159159
}
160-
currWeightSum(i) = totalNnz
160+
weightSum(i) = totalNnz
161161
nnz(i) = totalCnnz
162162
i += 1
163163
}
@@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
170170
this.totalCnt = other.totalCnt
171171
this.totalWeightSum = other.totalWeightSum
172172
this.weightSquareSum = other.weightSquareSum
173-
this.currWeightSum = other.currWeightSum.clone()
173+
this.weightSum = other.weightSum.clone()
174174
this.nnz = other.nnz.clone()
175175
this.currMax = other.currMax.clone()
176176
this.currMin = other.currMin.clone()
@@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
189189
val realMean = Array.ofDim[Double](n)
190190
var i = 0
191191
while (i < n) {
192-
realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum)
192+
realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
193193
i += 1
194194
}
195195
Vectors.dense(realMean)
@@ -214,8 +214,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
214214
val len = currM2n.length
215215
while (i < len) {
216216
// Clamp the variance at zero so that numerical error cannot make it negative.
217-
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) *
218-
(totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0)
217+
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
218+
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
219219
i += 1
220220
}
221221
}

mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ trait MultivariateStatisticalSummary {
4747
/**
4848
* Sum of weights.
4949
*/
50-
@Since("2.2.0")
50+
@Since("3.0.0")
5151
def weightSum: Double
5252

5353
/**

mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
137137
test("regression metrics with same (1.0) weight samples") {
138138
val predictionAndObservationWithWeight = sc.parallelize(
139139
Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2)
140-
val metrics = new RegressionMetrics(false, predictionAndObservationWithWeight)
140+
val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false)
141141
assert(metrics.explainedVariance ~== 8.79687 absTol eps,
142142
"explained variance regression score mismatch")
143143
assert(metrics.meanAbsoluteError ~== 0.5 absTol eps, "mean absolute error mismatch")
@@ -174,7 +174,7 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
174174
test("regression metrics with weighted samples") {
175175
val predictionAndObservationWithWeight = sc.parallelize(
176176
Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2)
177-
val metrics = new RegressionMetrics(false, predictionAndObservationWithWeight)
177+
val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false)
178178
assert(metrics.explainedVariance ~== 5.2425 absTol eps,
179179
"explained variance regression score mismatch")
180180
assert(metrics.meanAbsoluteError ~== 0.4 absTol eps, "mean absolute error mismatch")

0 commit comments

Comments
 (0)