Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
val (rewritedFormula, censorCol) = formulaRewrite(formula)

val rFormula = new RFormula().setFormula(rewritedFormula)
RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)

// get feature names from output schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
maxIter: Int,
tol: Double): GaussianMixtureWrapper = {

val rFormulaModel = new RFormula()
val rFormula = new RFormula()
.setFormula(formula)
.setFeaturesCol("features")
.fit(data)
RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)

// get feature names from output schema
val schema = rFormulaModel.transform(data).schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ private[r] object GeneralizedLinearRegressionWrapper
weightCol: String): GeneralizedLinearRegressionWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
// get labels and feature names from output schema
val schema = rFormulaModel.transform(data).schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ private[r] object IsotonicRegressionWrapper
featureIndex: Int,
weightCol: String): IsotonicRegressionWrapper = {

val rFormulaModel = new RFormula()
val rFormula = new RFormula()
.setFormula(formula)
.setFeaturesCol("features")
.fit(data)
RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)

// get feature names from output schema
val schema = rFormulaModel.transform(data).schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,11 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
maxIter: Int,
initMode: String): KMeansWrapper = {

val rFormulaModel = new RFormula()
val rFormula = new RFormula()
.setFormula(formula)
.setFeaturesCol("features")
.fit(data)
RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)

// get feature names from output schema
val schema = rFormulaModel.transform(data).schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,14 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
.fit(data)
RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
// get labels and feature names from output schema
val schema = rFormula.transform(data).schema
val labelAttr = Attribute.fromStructField(schema(rFormula.getLabelCol))
val schema = rFormulaModel.transform(data).schema
val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
.asInstanceOf[NominalAttribute]
val labels = labelAttr.values.get
val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
.attributes.get
val features = featureAttrs.map(_.name.get)
// assemble and fit the pipeline
Expand All @@ -78,7 +79,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
.setOutputCol(PREDICTED_LABEL_COL)
.setLabels(labels)
val pipeline = new Pipeline()
.setStages(Array(rFormula, naiveBayes, idxToStr))
.setStages(Array(rFormulaModel, naiveBayes, idxToStr))
.fit(data)
new NaiveBayesWrapper(pipeline, labels, features)
}
Expand Down
71 changes: 71 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.r

import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.Dataset

object RWrapperUtils extends Logging {

/**
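* Checks the input dataset for columns that clash with the RFormula output columns.
* Some data sources (e.g. libsvm) produce datasets that already contain "label" and
* "features" columns, which conflict with the label and features column names set on RFormula.
* When a clash is found, the corresponding RFormula column name is replaced with a unique
* name (and a warning is logged) to avoid a "column already exists" error.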
*
* @param rFormula RFormula instance
* @param data Input dataset
* @return Unit
*/
def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = {
if (data.schema.fieldNames.contains(rFormula.getLabelCol)) {
val newLabelName = convertToUniqueName(rFormula.getLabelCol, data.schema.fieldNames)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: i think we end up checking for label_output twice, once in if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) and second time within convertToUniqueName? Perhaps we merge them?

Copy link
Contributor Author

@keypointt keypointt Sep 1, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)), it's checking label only

and in convertToUniqueName (), _output will be appended resulting in label_output: var newName = originalName + "_output", and then label_output is checked at while (fieldNames.contains(newName))

am I missing something? @felix

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair enough. that makes sense, thanks

logWarning(
s"data containing ${rFormula.getLabelCol} column, using new name $newLabelName instead")
rFormula.setLabelCol(newLabelName)
}

if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) {
val newFeaturesName = convertToUniqueName(rFormula.getFeaturesCol, data.schema.fieldNames)
logWarning(s"data containing ${rFormula.getFeaturesCol} column, " +
s"using new name $newFeaturesName instead")
rFormula.setFeaturesCol(newFeaturesName)
}
}

/**
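* Convert a conflicting name into a unique name.
* The name originalName_output is tried first; if it is already taken, a sequence number
* is appended (originalName_output1, originalName_output2, ...) and incremented
* until the resulting name is not present in fieldNames.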
*
* @param originalName Original name
* @param fieldNames Array of field names in existing schema
* @return String
*/
def convertToUniqueName(originalName: String, fieldNames: Array[String]): String = {
  // The un-numbered candidate; numbered variants are only tried when this one is taken.
  val base = originalName + "_output"
  // Lazily enumerate base, base1, base2, ... in the order they should be tried.
  val candidates = Iterator(base) ++ Iterator.from(1).map(suffix => base + suffix)
  // The candidate stream is unbounded, so a free name is always reached.
  candidates.dropWhile(candidate => fieldNames.contains(candidate)).next()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
intercept[IllegalArgumentException] {
formula.fit(original)
}
intercept[IllegalArgumentException] {
formula.fit(original)
}
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here is just a duplication of above


test("label column already exists") {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.r

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.{RFormula, RFormulaModel}
import org.apache.spark.mllib.util.MLlibTestSparkContext

class RWrapperUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {

test("avoid libsvm data column name conflicting") {
val rFormula = new RFormula().setFormula("label ~ features")
val data = spark.read.format("libsvm").load("../data/mllib/sample_libsvm_data.txt")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I used "../data/", I'm not sure if there is a better way to do it, something like $current_directory/data/mllib/sample_libsvm_data.txt?

All I found is like this val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala#L36


// if not checking column name, then IllegalArgumentException
intercept[IllegalArgumentException] {
rFormula.fit(data)
}

// after checking, model build is ok
RWrapperUtils.checkDataColumns(rFormula, data)

assert(rFormula.getLabelCol == "label_output")
assert(rFormula.getFeaturesCol == "features_output")

val model = rFormula.fit(data)
assert(model.isInstanceOf[RFormulaModel])

assert(model.getLabelCol == "label_output")
assert(model.getFeaturesCol == "features_output")
}

test("generate unique name by appending a sequence number") {
  // "label_output" and its suffixes 1 and 2 are already taken,
  // so the first free candidate carries suffix 3.
  val takenNames = Array("label_output", "label_output1", "label_output2")
  val uniqueName = RWrapperUtils.convertToUniqueName("label", takenNames)

  assert(uniqueName === "label_output3")
}

}