Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
val aft = new AFTSurvivalRegression()
.setCensorCol(censorCol)
.setFitIntercept(rFormula.hasIntercept)
.setFeaturesCol(rFormula.getFeaturesCol)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, aft))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
.setK(k)
.setMaxIter(maxIter)
.setTol(tol)
.setFeaturesCol(rFormula.getFeaturesCol)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, gm))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ private[r] object GeneralizedLinearRegressionWrapper
.setMaxIter(maxIter)
.setWeightCol(weightCol)
.setRegParam(regParam)
.setFeaturesCol(rFormula.getFeaturesCol)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ private[r] object IsotonicRegressionWrapper
.setIsotonic(isotonic)
.setFeatureIndex(featureIndex)
.setWeightCol(weightCol)
.setFeaturesCol(rFormula.getFeaturesCol)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, isotonicRegression))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
.setK(k)
.setMaxIter(maxIter)
.setInitMode(initMode)
.setFeaturesCol(rFormula.getFeaturesCol)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, kMeans))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
val naiveBayes = new NaiveBayes()
.setSmoothing(smoothing)
.setModelType("bernoulli")
.setFeaturesCol(rFormula.getFeaturesCol)
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
val idxToStr = new IndexToString()
.setInputCol(PREDICTED_LABEL_INDEX_COL)
Expand Down
34 changes: 4 additions & 30 deletions mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,53 +19,27 @@ package org.apache.spark.ml.r

import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Dataset

/**
 * Utilities for reconciling an [[RFormula]]'s output column names with the
 * columns already present in an input dataset.
 */
object RWrapperUtils extends Logging {

  /**
   * DataFrame column check.
   * When loading data (e.g. libsvm data), default columns "features" and "label" will be added.
   * These names would conflict with the RFormula default feature and label column names.
   * Here is to change the RFormula column names to avoid "column already exists" error.
   *
   * The label column is renamed through [[convertToUniqueName]]; the features column is
   * renamed to a random UID derived from its current name. A warning is logged for every
   * rename. The `rFormula` instance is mutated in place.
   *
   * @param rFormula RFormula instance
   * @param data Input dataset
   * @return Unit
   */
  def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = {
    // The schema does not change while we adjust the formula, so look it up once.
    val fieldNames = data.schema.fieldNames

    if (fieldNames.contains(rFormula.getLabelCol)) {
      val newLabelName = convertToUniqueName(rFormula.getLabelCol, fieldNames)
      logWarning(
        s"data containing ${rFormula.getLabelCol} column, using new name $newLabelName instead")
      rFormula.setLabelCol(newLabelName)
    }

    if (fieldNames.contains(rFormula.getFeaturesCol)) {
      // Exactly one definition of the replacement name: a second `val` with the same
      // identifier in this scope would not compile.
      val newFeaturesName = Identifiable.randomUID(rFormula.getFeaturesCol)
      logWarning(s"data containing ${rFormula.getFeaturesCol} column, " +
        s"using new name $newFeaturesName instead")
      rFormula.setFeaturesCol(newFeaturesName)
    }
  }

  /**
   * Convert conflicting name to be an unique name.
   * Tries originalName_output first, then appends a sequence number
   * (originalName_output1, originalName_output2, ...) and increments it
   * until the candidate is not already there.
   *
   * @param originalName Original name
   * @param fieldNames Array of field names in existing schema
   * @return String
   */
  def convertToUniqueName(originalName: String, fieldNames: Array[String]): String = {
    val baseName = originalName + "_output"
    // Candidates in order: baseName, baseName1, baseName2, ...
    val candidates = Iterator.single(baseName) ++ Iterator.from(1).map(i => baseName + i)
    // fieldNames is finite, so a free candidate is always found; the fallback is never used.
    candidates.find(name => !fieldNames.contains(name)).getOrElse(baseName)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,14 @@ class RWrapperUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
// after checking, model build is ok
RWrapperUtils.checkDataColumns(rFormula, data)

assert(rFormula.getLabelCol == "label_output")
assert(rFormula.getFeaturesCol == "features_output")
assert(rFormula.getLabelCol == "label")
assert(rFormula.getFeaturesCol.startsWith("features_"))

val model = rFormula.fit(data)
assert(model.isInstanceOf[RFormulaModel])

assert(model.getLabelCol == "label_output")
assert(model.getFeaturesCol == "features_output")
}

test("generate unique name by appending a sequence number") {
val originalName = "label"
val fieldNames = Array("label_output", "label_output1", "label_output2")
val newName = RWrapperUtils.convertToUniqueName(originalName, fieldNames)

assert(newName === "label_output3")
assert(model.getLabelCol == "label")
assert(model.getFeaturesCol.startsWith("features_"))
}

}