From cfed8844cbadbd760f73c2f906a1591806001a93 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Wed, 8 Jun 2016 16:39:28 -0700 Subject: [PATCH 01/13] [SPARK-15509] remove duplicate of intercept[IllegalArgumentException] --- .../test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index c12ab8fe9efe..0794a049d9cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -54,9 +54,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul intercept[IllegalArgumentException] { formula.fit(original) } - intercept[IllegalArgumentException] { - formula.fit(original) - } } test("label column already exists") { From 77886fe59463027f24c6ca909638731145b46ee2 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Thu, 9 Jun 2016 13:59:38 -0700 Subject: [PATCH 02/13] [SPARK-15509] no column exists error for naivebayes. expand to other wrappers --- .../apache/spark/ml/r/NaiveBayesWrapper.scala | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 28925c79da66..741d1442810e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -59,13 +59,24 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { def fit(formula: String, data: DataFrame, laplace: Double): NaiveBayesWrapper = { val rFormula = new RFormula() .setFormula(formula) - .fit(data) + + if (data.schema.fieldNames.contains("label")) { + data.withColumnRenamed("label", "label_input") + rFormula.setLabelCol("label_input") + } + + if (data.schema.fieldNames.contains("features")) { + data.withColumnRenamed("features", "features_input") + rFormula.setFeaturesCol("features_input") + } + + val model = rFormula.fit(data) // get labels and feature names from output schema - val schema = rFormula.transform(data).schema - val labelAttr = Attribute.fromStructField(schema(rFormula.getLabelCol)) + val schema = model.transform(data).schema + val labelAttr = Attribute.fromStructField(schema(model.getLabelCol)) .asInstanceOf[NominalAttribute] val labels = labelAttr.values.get - val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol)) + val featureAttrs = AttributeGroup.fromStructField(schema(model.getFeaturesCol)) .attributes.get val features = featureAttrs.map(_.name.get) // assemble and fit the pipeline @@ -78,7 +89,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { .setOutputCol(PREDICTED_LABEL_COL) .setLabels(labels) val pipeline = new Pipeline() - .setStages(Array(rFormula, naiveBayes, idxToStr)) + .setStages(Array(model, naiveBayes, idxToStr)) .fit(data) new NaiveBayesWrapper(pipeline, labels, features) } From e112ac0c0685f399f72e9ed60be00964ec4fcdc4 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Thu, 9 Jun 2016 14:04:56 -0700 Subject: [PATCH 03/13] [SPARK-15509] add a util function for all wrappers --- .../org/apache/spark/ml/r/RWrapperUtils.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala new file mode 100644 index 000000000000..3721844ec071 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +object RWrapperUtils { + +} From ef3702ee5beefad1ee51fe15cb01e1716aeda362 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Thu, 9 Jun 2016 15:27:37 -0700 Subject: [PATCH 04/13] [SPARK-15509] expand column check to other wrappers --- .../ml/r/AFTSurvivalRegressionWrapper.scala | 1 + .../GeneralizedLinearRegressionWrapper.scala | 1 + .../org/apache/spark/ml/r/KMeansWrapper.scala | 6 ++--- .../apache/spark/ml/r/NaiveBayesWrapper.scala | 20 +++++----------- .../org/apache/spark/ml/r/RWrapperUtils.scala | 23 +++++++++++++++++++ 5 files changed, 34 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 5462f80d69ff..67d037ed6e02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -87,6 +87,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg val (rewritedFormula, censorCol) = formulaRewrite(formula) val rFormula = new RFormula().setFormula(rewritedFormula) + RWrapperUtils.checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get feature names from output schema diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 9618a3423e9a..0de50fc94b8d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -71,6 +71,7 @@ private[r] object GeneralizedLinearRegressionWrapper maxit: Int): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula() .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema val schema = rFormulaModel.transform(data).schema diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala index 4d4c303fc8c2..848e3f65a776 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -70,10 +70,10 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] { maxIter: Int, initMode: String): KMeansWrapper = { - val rFormulaModel = new RFormula() + val rFormula = new RFormula() .setFormula(formula) - .setFeaturesCol("features") - .fit(data) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) // get feature names from output schema val schema = rFormulaModel.transform(data).schema diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 741d1442810e..dd25f02d37c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -60,23 +60,15 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { val rFormula = new RFormula() .setFormula(formula) - if (data.schema.fieldNames.contains("label")) { - data.withColumnRenamed("label", "label_input") - rFormula.setLabelCol("label_input") - } - - if (data.schema.fieldNames.contains("features")) { - data.withColumnRenamed("features", "features_input") - rFormula.setFeaturesCol("features_input") - } + RWrapperUtils.checkDataColumns(rFormula, data) - val model = rFormula.fit(data) + val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema - val schema = model.transform(data).schema - val labelAttr = Attribute.fromStructField(schema(model.getLabelCol)) + val schema = rFormulaModel.transform(data).schema + val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) .asInstanceOf[NominalAttribute] val labels = labelAttr.values.get - val featureAttrs = AttributeGroup.fromStructField(schema(model.getFeaturesCol)) + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) .attributes.get val features = featureAttrs.map(_.name.get) // assemble and fit the pipeline @@ -89,7 +81,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { .setOutputCol(PREDICTED_LABEL_COL) .setLabels(labels) val pipeline = new Pipeline() - .setStages(Array(model, naiveBayes, idxToStr)) + .setStages(Array(rFormulaModel, naiveBayes, idxToStr)) .fit(data) new NaiveBayesWrapper(pipeline, labels, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index 3721844ec071..fcb34c1a05a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -17,6 +17,29 @@ package org.apache.spark.ml.r +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.sql.Dataset + object RWrapperUtils { + /** + * DataFrame column check. + * When loading data, default columns "features" and "label" will be added. And these two names + * would conflict with RFormula default feature and label column names. + * Here is to change the column name to avoid "column already exists" error. + * + * @param rFormula RFormula instance + * @param data input dataset + * @return Unit + */ + def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = { + + if (data.schema.fieldNames.contains("label")) { + rFormula.setLabelCol("label_output") + } + + if (data.schema.fieldNames.contains("features")) { + rFormula.setFeaturesCol("features_output") + } + } } From aab3a12fe09cf3039708468a80837fa421739c69 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Thu, 9 Jun 2016 16:05:51 -0700 Subject: [PATCH 05/13] [SPARK-15509] add unit test --- .../org/apache/spark/ml/r/RWrapperUtils.scala | 5 +- .../spark/ml/r/RWrapperUtilsSuite.scala | 48 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index fcb34c1a05a7..1d606a13c5e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -17,10 +17,11 @@ package org.apache.spark.ml.r +import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.RFormula import org.apache.spark.sql.Dataset -object RWrapperUtils { +object RWrapperUtils extends Logging { /** * DataFrame column check. @@ -35,10 +36,12 @@ object RWrapperUtils { def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = { if (data.schema.fieldNames.contains("label")) { + logWarning("data containing 'label' column, so change its name to avoid conflict") rFormula.setLabelCol("label_output") } if (data.schema.fieldNames.contains("features")) { + logWarning("data containing 'features' column, so change its name to avoid conflict") rFormula.setFeaturesCol("features_output") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala new file mode 100644 index 000000000000..9fa82343559c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.mllib.util.MLlibTestSparkContext + + +class RWrapperUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("avoid column name conflicting") { + val rFormula = new RFormula().setFormula("label ~ features") + val data = spark.read.format("libsvm") + .load("../../data/mllib/sample_libsvm_data.txt") + + // if not checking column name, then IllegalArgumentException + intercept[IllegalArgumentException] { + rFormula.fit(data) + } + + // after checking + RWrapperUtils.checkDataColumns(rFormula, data) + + assert(rFormula.getLabelCol == "label_output") + assert(rFormula.getFeaturesCol == "features_output") + + val model = rFormula.fit(data) + + assert(model.getLabelCol == "label_output") + assert(model.getFeaturesCol == "features_output") + } +} From f68ac34907f3a7d1d66e98572ada34d47df3eab9 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Thu, 9 Jun 2016 17:01:44 -0700 Subject: [PATCH 06/13] [SPARK-15509] some clean up --- .../scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala | 2 -- .../main/scala/org/apache/spark/ml/r/RWrapperUtils.scala | 1 - .../scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala | 9 ++++----- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index dd25f02d37c9..a9d151fb1da5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -59,9 +59,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { def fit(formula: String, data: DataFrame, laplace: Double): NaiveBayesWrapper = { val rFormula = new RFormula() .setFormula(formula) - RWrapperUtils.checkDataColumns(rFormula, data) - val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema val schema = rFormulaModel.transform(data).schema diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index 1d606a13c5e6..6059fb0faf2e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -34,7 +34,6 @@ object RWrapperUtils extends Logging { * @return Unit */ def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = { - if (data.schema.fieldNames.contains("label")) { logWarning("data containing 'label' column, so change its name to avoid conflict") rFormula.setLabelCol("label_output") diff --git a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala index 9fa82343559c..341de20deda4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala @@ -18,29 +18,28 @@ package org.apache.spark.ml.r import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.feature.{RFormula, RFormulaModel} import org.apache.spark.mllib.util.MLlibTestSparkContext - class RWrapperUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { test("avoid column name conflicting") { val rFormula = new RFormula().setFormula("label ~ features") - val data = spark.read.format("libsvm") - .load("../../data/mllib/sample_libsvm_data.txt") + val data = spark.read.format("libsvm").load("../../data/mllib/sample_libsvm_data.txt") // if not checking column name, then IllegalArgumentException intercept[IllegalArgumentException] { rFormula.fit(data) } - // after checking + // after checking, model build is ok RWrapperUtils.checkDataColumns(rFormula, data) assert(rFormula.getLabelCol == "label_output") assert(rFormula.getFeaturesCol == "features_output") val model = rFormula.fit(data) + assert(model.isInstanceOf[RFormulaModel]) assert(model.getLabelCol == "label_output") assert(model.getFeaturesCol == "features_output") From c8e30e9452031908fc829e527ab82a8e93598302 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Thu, 9 Jun 2016 17:45:53 -0700 Subject: [PATCH 07/13] [SPARK-15509] fix path --- .../test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala index 341de20deda4..6d601bcd5a1a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala @@ -25,7 +25,7 @@ class RWrapperUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { test("avoid column name conflicting") { val rFormula = new RFormula().setFormula("label ~ features") - val data = spark.read.format("libsvm").load("../../data/mllib/sample_libsvm_data.txt") + val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // if not checking column name, then IllegalArgumentException intercept[IllegalArgumentException] { From 43b2f8c5fb9e0d74579b948b1d52cad4faa76b66 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Thu, 9 Jun 2016 17:48:36 -0700 Subject: [PATCH 08/13] [SPARK-15509] fix path --- .../test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala index 6d601bcd5a1a..782ad26ed300 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala @@ -25,7 +25,7 @@ class RWrapperUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { test("avoid column name conflicting") { val rFormula = new RFormula().setFormula("label ~ features") - val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val data = spark.read.format("libsvm").load("../data/mllib/sample_libsvm_data.txt") // if not checking column name, then IllegalArgumentException intercept[IllegalArgumentException] { From 1bc150f8af93f0e5d35e40fd39e33176c974d8cf Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Sun, 28 Aug 2016 17:41:06 -0700 Subject: [PATCH 09/13] [SPARK-15509] scan through all r wrappers and add checking for formular --- .../org/apache/spark/ml/r/GaussianMixtureWrapper.scala | 5 +++-- .../apache/spark/ml/r/IsotonicRegressionWrapper.scala | 5 +++-- .../scala/org/apache/spark/ml/r/KMeansWrapper.scala | 1 + .../scala/org/apache/spark/ml/r/RWrapperUtils.scala | 10 +++++----- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala index 1e8b3bbab665..b654233a8936 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala @@ -68,10 +68,11 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp maxIter: Int, tol: Double): GaussianMixtureWrapper = { - val rFormulaModel = new RFormula() + val rFormula = new RFormula() .setFormula(formula) .setFeaturesCol("features") - .fit(data) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) // get feature names from output schema val schema = rFormulaModel.transform(data).schema diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala index 1ea80cb46ab7..6643a2c4dc5e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala @@ -57,10 +57,11 @@ private[r] object IsotonicRegressionWrapper featureIndex: Int, weightCol: String): IsotonicRegressionWrapper = { - val rFormulaModel = new RFormula() + val rFormula = new RFormula() .setFormula(formula) .setFeaturesCol("features") - .fit(data) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) // get feature names from output schema val schema = rFormulaModel.transform(data).schema diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala index 848e3f65a776..8616a8c01e5a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -72,6 +72,7 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] { val rFormula = new RFormula() .setFormula(formula) + .setFeaturesCol("features") RWrapperUtils.checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index 6059fb0faf2e..64ef11eeccb6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -30,18 +30,18 @@ object RWrapperUtils extends Logging { * Here is to change the column name to avoid "column already exists" error. * * @param rFormula RFormula instance - * @param data input dataset + * @param data Input dataset * @return Unit */ def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = { - if (data.schema.fieldNames.contains("label")) { + if (data.schema.fieldNames.contains(rFormula.getLabelCol)) { logWarning("data containing 'label' column, so change its name to avoid conflict") - rFormula.setLabelCol("label_output") + rFormula.setLabelCol(rFormula.getLabelCol + "_output") } - if (data.schema.fieldNames.contains("features")) { + if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) { logWarning("data containing 'features' column, so change its name to avoid conflict") - rFormula.setFeaturesCol("features_output") + rFormula.setFeaturesCol(rFormula.getFeaturesCol + "_output") } } } From caa41833c5d6dcfac046e5a804b1498258121d94 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Wed, 31 Aug 2016 23:21:06 -0700 Subject: [PATCH 10/13] [SPARK-15509] keep searching in a sequential way until an unused column name has been found --- .../org/apache/spark/ml/r/RWrapperUtils.scala | 32 ++++++++++++++++--- .../spark/ml/r/RWrapperUtilsSuite.scala | 11 ++++++- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index 64ef11eeccb6..a628c66c44c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -35,13 +35,37 @@ object RWrapperUtils extends Logging { */ def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = { if (data.schema.fieldNames.contains(rFormula.getLabelCol)) { - logWarning("data containing 'label' column, so change its name to avoid conflict") - rFormula.setLabelCol(rFormula.getLabelCol + "_output") + val newLabelName = convertToUniqueName(rFormula.getLabelCol, data.schema.fieldNames) + logWarning( + s"data containing ${rFormula.getLabelCol} column, changing its name to $newLabelName") + rFormula.setLabelCol(newLabelName) } if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) { - logWarning("data containing 'features' column, so change its name to avoid conflict") - rFormula.setFeaturesCol(rFormula.getFeaturesCol + "_output") + val newFeaturesName = convertToUniqueName(rFormula.getFeaturesCol, data.schema.fieldNames) + logWarning( + s"data containing ${rFormula.getFeaturesCol} column, changing its name to $newFeaturesName") + rFormula.setFeaturesCol(newFeaturesName) } } + + /** + * Convert conflicting name to be an unique name. + * Appending a sequence number, like originalName_output1 + * and incrementing until it is not already there + * + * @param originalName Original name + * @param fieldNames Array of field names in existing schema + * @return String + */ + def convertToUniqueName(originalName: String, fieldNames: Array[String]): String = { + var counter = 1 + var newName = originalName + "_output" + + while (fieldNames.contains(newName)) { + newName = originalName + "_output" + counter + counter += 1 + } + newName + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala index 782ad26ed300..ddc24cb3a648 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class RWrapperUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { - test("avoid column name conflicting") { + test("avoid libsvm data column name conflicting") { val rFormula = new RFormula().setFormula("label ~ features") val data = spark.read.format("libsvm").load("../data/mllib/sample_libsvm_data.txt") @@ -44,4 +44,13 @@ class RWrapperUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.getLabelCol == "label_output") assert(model.getFeaturesCol == "features_output") } + + test("generate unique name by appending a sequence number") { + val originalName = "label" + val fieldNames = Array("label_output", "label_output1", "label_output2") + val newName = RWrapperUtils.convertToUniqueName(originalName, fieldNames) + + assert(newName === "label_output3") + } + } From 1701252cf86a615874126215d956fd32d8eab0d0 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Wed, 31 Aug 2016 23:28:50 -0700 Subject: [PATCH 11/13] [SPARK-15509] fix style --- .../org/apache/spark/ml/r/RWrapperUtils.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index a628c66c44c5..bca9984cee16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -50,14 +50,14 @@ object RWrapperUtils extends Logging { } /** - * Convert conflicting name to be an unique name. - * Appending a sequence number, like originalName_output1 - * and incrementing until it is not already there - * - * @param originalName Original name - * @param fieldNames Array of field names in existing schema - * @return String - */ + * Convert conflicting name to be an unique name. + * Appending a sequence number, like originalName_output1 + * and incrementing until it is not already there + * + * @param originalName Original name + * @param fieldNames Array of field names in existing schema + * @return String + */ def convertToUniqueName(originalName: String, fieldNames: Array[String]): String = { var counter = 1 var newName = originalName + "_output" From d9e3be5f45db5731e74613507047380d3f6f40f3 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Thu, 1 Sep 2016 16:28:19 -0700 Subject: [PATCH 12/13] [SPARK-15509] modify logging msg --- .../src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index bca9984cee16..f7dccc2ee12a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -37,14 +37,14 @@ object RWrapperUtils extends Logging { if (data.schema.fieldNames.contains(rFormula.getLabelCol)) { val newLabelName = convertToUniqueName(rFormula.getLabelCol, data.schema.fieldNames) logWarning( - s"data containing ${rFormula.getLabelCol} column, changing its name to $newLabelName") + s"data containing ${rFormula.getLabelCol} column, using new name $newLabelName instead") rFormula.setLabelCol(newLabelName) } if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) { val newFeaturesName = convertToUniqueName(rFormula.getFeaturesCol, data.schema.fieldNames) logWarning( - s"data containing ${rFormula.getFeaturesCol} column, changing its name to $newFeaturesName") + s"data containing ${rFormula.getFeaturesCol} column, using new name $newFeaturesName") rFormula.setFeaturesCol(newFeaturesName) } } From 8bb370ec0408ae10fe0c1c359b0f1b68066cbf87 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Thu, 1 Sep 2016 16:33:02 -0700 Subject: [PATCH 13/13] [SPARK-15509] add 'instead'... --- .../src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index f7dccc2ee12a..6a435992e3b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -43,8 +43,8 @@ object RWrapperUtils extends Logging { if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) { val newFeaturesName = convertToUniqueName(rFormula.getFeaturesCol, data.schema.fieldNames) - logWarning( - s"data containing ${rFormula.getFeaturesCol} column, using new name $newFeaturesName") + logWarning(s"data containing ${rFormula.getFeaturesCol} column, " + + s"using new name $newFeaturesName instead") rFormula.setFeaturesCol(newFeaturesName) } }