Skip to content

Commit b3c7fec

Browse files
committed
add require in OneVsRestModel
add setRawPredictionCol in OneVsRest create a local var numClass to resolve the issue
1 parent ebf4a6c commit b3c7fec

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ final class OneVsRestModel private[ml] (
138138
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
139139
extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
140140

141+
require(models.nonEmpty, "OneVsRestModel requires at least one model for one class")
142+
141143
@Since("2.4.0")
142144
val numClasses: Int = models.length
143145

@@ -206,24 +208,25 @@ final class OneVsRestModel private[ml] (
206208
newDataset.unpersist()
207209
}
208210

209-
// output the RawPrediction as vector
210211
if (getRawPredictionCol != "") {
212+
val numClass = models.length
213+
214+
// output the RawPrediction as vector
211215
val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
212-
val predArray = Array.fill[Double](numClasses)(0.0)
216+
val predArray = Array.fill[Double](numClass)(0.0)
213217
predictions.foreach { case (idx, value) => predArray(idx) = value }
214218
Vectors.dense(predArray)
215219
}
216220

217221
// output the index of the classifier with highest confidence as prediction
218-
val labelUDF = udf { (rawpredictions: Vector) => rawpredictions.argmax.toDouble }
222+
val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble }
219223

220224
// output confidence as raw prediction, label and label metadata as prediction
221225
aggregatedDataset
222226
.withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
223227
.withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
224228
.drop(accColName)
225-
}
226-
else {
229+
} else {
227230
// output the index of the classifier with highest confidence as prediction
228231
val labelUDF = udf { (predictions: Map[Int, Double]) =>
229232
predictions.maxBy(_._2)._1.toDouble
@@ -326,6 +329,10 @@ final class OneVsRest @Since("1.4.0") (
326329
@Since("1.5.0")
327330
def setPredictionCol(value: String): this.type = set(predictionCol, value)
328331

332+
/** @group setParam */
333+
@Since("2.4.0")
334+
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
335+
329336
/**
330337
* The implementation of parallel one vs. rest runs the classification for
331338
* each class in a separate threads.

0 commit comments

Comments
 (0)