diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala index 15d71757a672..d511c1b5dda9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala @@ -204,14 +204,8 @@ class FMClassifier @Since("3.0.0") ( instr.logNumFeatures(numFeatures) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val data: RDD[(Double, OldVector)] = - dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { - case Row(label: Double, features: Vector) => - require(label == 0 || label == 1, s"FMClassifier was given" + - s" dataset with invalid label $label. Labels must be in {0,1}; note that" + - s" FMClassifier currently only supports binary classification.") - (label, features) - } + val labeledPoint = extractLabeledPoints(dataset, numClasses) + val data: RDD[(Double, OldVector)] = labeledPoint.map(x => (x.label, x.features)) if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala index 0bf1836edbd4..0bdd0b4d9146 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala @@ -419,11 +419,8 @@ class FMRegressor @Since("3.0.0") ( instr.logNumFeatures(numFeatures) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val data: RDD[(Double, OldVector)] = - dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { - case Row(label: Double, features: Vector) => - (label, features) - } + val labeledPoint = extractLabeledPoints(dataset) + val data: RDD[(Double, OldVector)] = labeledPoint.map(x => (x.label, x.features)) if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK)