@@ -138,6 +138,8 @@ final class OneVsRestModel private[ml] (
138138 @ Since (" 1.4.0" ) val models : Array [_ <: ClassificationModel [_, _]])
139139 extends Model [OneVsRestModel ] with OneVsRestParams with MLWritable {
140140
141+ require(models.nonEmpty, " OneVsRestModel requires at least one model for one class" )
142+
141143 @ Since (" 2.4.0" )
142144 val numClasses : Int = models.length
143145
@@ -206,24 +208,25 @@ final class OneVsRestModel private[ml] (
206208 newDataset.unpersist()
207209 }
208210
209- // output the RawPrediction as vector
210211 if (getRawPredictionCol != " " ) {
212+ val numClass = models.length
213+
214+ // output the RawPrediction as vector
211215 val rawPredictionUDF = udf { (predictions : Map [Int , Double ]) =>
212- val predArray = Array .fill[Double ](numClasses )(0.0 )
216+ val predArray = Array .fill[Double ](numClass )(0.0 )
213217 predictions.foreach { case (idx, value) => predArray(idx) = value }
214218 Vectors .dense(predArray)
215219 }
216220
217221 // output the index of the classifier with highest confidence as prediction
218- val labelUDF = udf { (rawpredictions : Vector ) => rawpredictions .argmax.toDouble }
222+ val labelUDF = udf { (rawPredictions : Vector ) => rawPredictions .argmax.toDouble }
219223
220224 // output confidence as raw prediction, label and label metadata as prediction
221225 aggregatedDataset
222226 .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
223227 .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
224228 .drop(accColName)
225- }
226- else {
229+ } else {
227230 // output the index of the classifier with highest confidence as prediction
228231 val labelUDF = udf { (predictions : Map [Int , Double ]) =>
229232 predictions.maxBy(_._2)._1.toDouble
@@ -326,6 +329,10 @@ final class OneVsRest @Since("1.4.0") (
326329 @ Since (" 1.5.0" )
327330 def setPredictionCol (value : String ): this .type = set(predictionCol, value)
328331
332+ /** @group setParam */
333+ @ Since (" 2.4.0" )
334+ def setRawPredictionCol (value : String ): this .type = set(rawPredictionCol, value)
335+
329336 /**
330337 * The implementation of parallel one vs. rest runs the classification for
331338 * each class in a separate threads.
0 commit comments