From 4b477093737e9d9fae16c82836e421b5e0e7c63e Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 14 Sep 2017 03:49:16 +0000
Subject: [PATCH 1/5] Do withColumn on all input columns at once.

---
 .../org/apache/spark/ml/feature/Imputer.scala | 10 ++--
 .../scala/org/apache/spark/sql/Dataset.scala  | 49 +++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala | 26 ++++++++++
 3 files changed, 79 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 1f36eced3d08..4663f16b5f5d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -223,20 +223,18 @@ class ImputerModel private[ml] (
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    var outputDF = dataset
     val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq
 
-    $(inputCols).zip($(outputCols)).zip(surrogates).foreach {
+    val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map {
       case ((inputCol, outputCol), surrogate) =>
         val inputType = dataset.schema(inputCol).dataType
         val ic = col(inputCol)
-        outputDF = outputDF.withColumn(outputCol,
-          when(ic.isNull, surrogate)
+        when(ic.isNull, surrogate)
          .when(ic === $(missingValue), surrogate)
          .otherwise(ic)
-          .cast(inputType))
+          .cast(inputType)
     }
-    outputDF.toDF()
+    dataset.withColumns($(outputCols), newCols).toDF()
   }
 
   override def transformSchema(schema: StructType): StructType = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ab0c4126bcbd..0fc6e2a56e3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2101,6 +2101,55 @@ class Dataset[T] private[sql](
     }
   }
 
+  /**
+   * Returns a new Dataset by adding columns or replacing the existing columns that has
+   * the same names.
+   */
+  private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = {
+    require(colNames.size == cols.size,
+      s"The size of column names: ${colNames.size} isn't equal to " +
+        s"the size of columns: ${cols.size}")
+    require(colNames.distinct.size == colNames.size,
+      s"It is disallowed to use duplicate column names: $colNames")
+
+    val resolver = sparkSession.sessionState.analyzer.resolver
+    val output = queryExecution.analyzed.output
+
+    val columnMap = colNames.zip(cols).toMap
+
+    val replacedAndExistingColumns = output.map { field =>
+      val dupColumn = columnMap.find { case (colName, col) =>
+        resolver(field.name, colName)
+      }
+      if (dupColumn.isDefined) {
+        val colName = dupColumn.get._1
+        val col = dupColumn.get._2
+        col.as(colName)
+      } else {
+        Column(field)
+      }
+    }
+
+    val newColumns = columnMap.filter { case (colName, col) =>
+      !output.exists(f => resolver(f.name, colName))
+    }.map { case (colName, col) => col.as(colName) }
+
+    select(replacedAndExistingColumns ++ newColumns : _*)
+  }
+
+  /**
+   * Returns a new Dataset by adding columns with metadata.
+   */
+  private[spark] def withColumns(
+      colNames: Seq[String],
+      cols: Seq[Column],
+      metadata: Seq[Metadata]): DataFrame = {
+    val newCols = colNames.zip(cols).zip(metadata).map { case ((colName, col), metadata) =>
+      col.as(colName, metadata)
+    }
+    withColumns(colNames, newCols)
+  }
+
   /**
    * Returns a new Dataset by adding a column with metadata.
   */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 13341645e8ff..7f2c90ea3606 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -641,6 +641,23 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
   }
 
+  test("withColumns") {
+    val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"),
+      Seq(col("key") + 1, col("key") + 2))
+    checkAnswer(
+      df,
+      testData.collect().map { case Row(key: Int, value: String) =>
+        Row(key, value, key + 1, key + 2)
+      }.toSeq)
+    assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2"))
+
+    val err = intercept[IllegalArgumentException] {
+      testData.toDF().withColumns(Seq("newCol1", "newCol1"),
+        Seq(col("key") + 1, col("key") + 2))
+    }
+    assert(err.getMessage.contains("It is disallowed to use duplicate column names"))
+  }
+
   test("replace column using withColumn") {
     val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
     val df3 = df2.withColumn("x", df2("x") + 1)
@@ -649,6 +666,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row(2) :: Row(3) :: Row(4) :: Nil)
   }
 
+  test("replace column using withColumns") {
+    val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y")
+    val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"),
+      Seq(df2("x") + 1, df2("y"), df2("y") + 1))
+    checkAnswer(
+      df3.select("x", "newCol1", "newCol2"),
+      Row(2, 2, 3) :: Row(3, 3, 4) :: Row(4, 4, 5) :: Nil)
+  }
+
   test("drop column using drop") {
     val df = testData.drop("key")
     checkAnswer(

From 2086900168bb1595de7e68efdebfecc9fb38314b Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 20 Sep 2017 03:16:18 +0000
Subject: [PATCH 2/5] Sync withColumns and related test.

---
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 7f2c90ea3606..3a161efb5f2c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -652,10 +652,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2"))
 
     val err = intercept[IllegalArgumentException] {
+      testData.toDF().withColumns(Seq("newCol1"),
+        Seq(col("key") + 1, col("key") + 2))
+    }
+    assert(
+      err.getMessage.contains("The size of column names: 1 isn't equal to the size of columns: 2"))
+
+    val err2 = intercept[IllegalArgumentException] {
       testData.toDF().withColumns(Seq("newCol1", "newCol1"),
         Seq(col("key") + 1, col("key") + 2))
     }
-    assert(err.getMessage.contains("It is disallowed to use duplicate column names"))
+    assert(err2.getMessage.contains("It is disallowed to use duplicate column names"))
   }
 
   test("replace column using withColumn") {

From 07dec0f9ad946fa5f9858306af0b56d9343cdee6 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 25 Sep 2017 03:13:20 +0000
Subject: [PATCH 3/5] Address case sensitivity in withColumns.
---
 .../scala/org/apache/spark/sql/Dataset.scala  |  6 +++--
 .../org/apache/spark/sql/DataFrameSuite.scala | 25 ++++++++++++++++---
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 0fc6e2a56e3d..520c8e5f81a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2109,8 +2109,10 @@ class Dataset[T] private[sql](
     require(colNames.size == cols.size,
       s"The size of column names: ${colNames.size} isn't equal to " +
        s"the size of columns: ${cols.size}")
-    require(colNames.distinct.size == colNames.size,
-      s"It is disallowed to use duplicate column names: $colNames")
+    SchemaUtils.checkColumnNameDuplication(
+      colNames,
+      "in given column names",
+      sparkSession.sessionState.conf.caseSensitiveAnalysis)
 
     val resolver = sparkSession.sessionState.analyzer.resolver
     val output = queryExecution.analyzed.output
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 3a161efb5f2c..3c14cb7d6bc5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -658,11 +658,30 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(
       err.getMessage.contains("The size of column names: 1 isn't equal to the size of columns: 2"))
 
-    val err2 = intercept[IllegalArgumentException] {
-      testData.toDF().withColumns(Seq("newCol1", "newCol1"),
+    val err2 = intercept[AnalysisException] {
+      testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
         Seq(col("key") + 1, col("key") + 2))
     }
-    assert(err2.getMessage.contains("It is disallowed to use duplicate column names"))
+    assert(err2.getMessage.contains("Found duplicate column(s)"))
+  }
+
+  test("withColumns: case sensitive") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
+        Seq(col("key") + 1, col("key") + 2))
+      checkAnswer(
+        df,
+        testData.collect().map { case Row(key: Int, value: String) =>
+          Row(key, value, key + 1, key + 2)
+        }.toSeq)
+      assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCOL1"))
+
+      val err = intercept[AnalysisException] {
+        testData.toDF().withColumns(Seq("newCol1", "newCol1"),
+          Seq(col("key") + 1, col("key") + 2))
+      }
+      assert(err.getMessage.contains("Found duplicate column(s)"))
+    }
   }
 
   test("replace column using withColumn") {

From 21048a85cc736ff223bc5249489789edf53e1fc3 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sat, 30 Sep 2017 06:45:23 +0000
Subject: [PATCH 4/5] Address comment.
---
 .../scala/org/apache/spark/sql/Dataset.scala | 18 +-----------------
 1 file changed, 1 insertion(+), 17 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 520c8e5f81a0..393e205b967b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2083,23 +2083,7 @@ class Dataset[T] private[sql](
    * @group untypedrel
    * @since 2.0.0
    */
-  def withColumn(colName: String, col: Column): DataFrame = {
-    val resolver = sparkSession.sessionState.analyzer.resolver
-    val output = queryExecution.analyzed.output
-    val shouldReplace = output.exists(f => resolver(f.name, colName))
-    if (shouldReplace) {
-      val columns = output.map { field =>
-        if (resolver(field.name, colName)) {
-          col.as(colName)
-        } else {
-          Column(field)
-        }
-      }
-      select(columns : _*)
-    } else {
-      select(Column("*"), col.as(colName))
-    }
-  }
+  def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col))
 
   /**
    * Returns a new Dataset by adding columns or replacing the existing columns that has

From 1292ce01ccf24eb40638e748a9cb0b0dbabbb72c Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sun, 1 Oct 2017 05:37:24 +0000
Subject: [PATCH 5/5] Address comment.

---
 .../scala/org/apache/spark/sql/Dataset.scala | 25 +++----------------
 1 file changed, 4 insertions(+), 21 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 393e205b967b..f2a76a506eb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2104,15 +2104,11 @@ class Dataset[T] private[sql](
     val columnMap = colNames.zip(cols).toMap
 
     val replacedAndExistingColumns = output.map { field =>
-      val dupColumn = columnMap.find { case (colName, col) =>
+      columnMap.find { case (colName, _) =>
         resolver(field.name, colName)
-      }
-      if (dupColumn.isDefined) {
-        val colName = dupColumn.get._1
-        val col = dupColumn.get._2
-        col.as(colName)
-      } else {
-        Column(field)
+      } match {
+        case Some((colName: String, col: Column)) => col.as(colName)
+        case _ => Column(field)
       }
     }
 
@@ -2123,19 +2119,6 @@ class Dataset[T] private[sql](
     select(replacedAndExistingColumns ++ newColumns : _*)
   }
 
-  /**
-   * Returns a new Dataset by adding columns with metadata.
-   */
-  private[spark] def withColumns(
-      colNames: Seq[String],
-      cols: Seq[Column],
-      metadata: Seq[Metadata]): DataFrame = {
-    val newCols = colNames.zip(cols).zip(metadata).map { case ((colName, col), metadata) =>
-      col.as(colName, metadata)
-    }
-    withColumns(colNames, newCols)
-  }
-
   /**
    * Returns a new Dataset by adding a column with metadata.
    */
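
For reference, a minimal usage sketch of the withColumns API introduced above; it is not part of the patch series. It assumes an existing SparkSession named "spark", and because withColumns is private[spark] at this stage, a call like this only compiles inside Spark itself, e.g. from the DataFrameSuite tests added in PATCH 1/5.

  import org.apache.spark.sql.functions.col

  // A small DataFrame with a single numeric column named "key".
  val df = spark.range(3).toDF("key")

  // Add two derived columns in a single pass instead of chaining withColumn twice;
  // mismatched name/column counts and duplicate output names are rejected up front.
  val out = df.withColumns(Seq("newCol1", "newCol2"),
    Seq(col("key") + 1, col("key") + 2))

  out.show()

This single-pass form is what ImputerModel.transform switches to in PATCH 1/5, avoiding one analyzed plan per output column.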