From ef9e38db13376e6da23bfb3899c8a33c360e52bb Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Mon, 18 Nov 2019 22:51:20 -0800 Subject: [PATCH 1/8] initial checkin --- .../spark/sql/DataFrameNaFunctions.scala | 86 +++++++++++++------ .../spark/sql/DataFrameNaFunctionsSuite.scala | 17 ++++ 2 files changed, 79 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 6dd21f114c90..1b2058a45ca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -130,20 +130,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 2.2.0 */ - def fill(value: Long): DataFrame = fill(value, df.columns) + def fill(value: Long): DataFrame = fillValue(value) /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * @since 1.3.1 */ - def fill(value: Double): DataFrame = fill(value, df.columns) + def fill(value: Double): DataFrame = fillValue(value) /** * Returns a new `DataFrame` that replaces null values in string columns with `value`. * * @since 1.3.1 */ - def fill(value: String): DataFrame = fill(value, df.columns) + def fill(value: String): DataFrame = fillValue(value) /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. @@ -199,7 +199,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 2.3.0 */ - def fill(value: Boolean): DataFrame = fill(value, df.columns) + def fill(value: Boolean): DataFrame = fillValue(value) /** * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified @@ -433,15 +433,28 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. + * It selects a column based on its name. */ private def fillCol[T](col: StructField, replacement: T): Column = { val quotedColName = "`" + col.name + "`" - val colValue = col.dataType match { + fillCol(col.dataType, col.name, df.col(quotedColName), replacement) + } + + /** + * Returns a [[Column]] expression that replaces null value in `expr` with `replacement`. + * It uses the given `expr` as a column. 
+ */ + private def fillCol[T]( + dataType: DataType, + colName: String, + expr: Column, + replacement: T): Column = { + val colValue = dataType match { case DoubleType | FloatType => - nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types - case _ => df.col(quotedColName) + nanvl(expr, lit(null)) // nanvl only supports these types + case _ => expr } - coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name) + coalesce(colValue, lit(replacement).cast(dataType)).as(colName) } /** @@ -468,37 +481,62 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { s"Unsupported value type ${v.getClass.getName} ($v).") } + private def getTargetType[T](value: T): AbstractDataType = { + value match { + case _: Double | _: Long => NumericType + case _: String => StringType + case _: Boolean => BooleanType + case _ => + throw new IllegalArgumentException( + s"Unsupported value type for fill(): ${value.getClass.getName} ($value).") + } + } + + private def typeMatches(targetType: AbstractDataType, sourceType: DataType): Boolean = { + (targetType, sourceType) match { + case (NumericType, dt) => dt.isInstanceOf[NumericType] + case (StringType, dt) => dt == StringType + case (BooleanType, dt) => dt == BooleanType + case _ => + throw new IllegalArgumentException(s"$targetType is not matched for fill().") + } + } + /** * Returns a new `DataFrame` that replaces null or NaN values in specified * numeric, string columns. If a specified column is not a numeric, string * or boolean column it is ignored. */ private def fillValue[T](value: T, cols: Seq[String]): DataFrame = { - // the fill[T] which T is Long/Double, + // the fill[T] which T is Long/Double, // should apply on all the NumericType Column, for example: // val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b") // input.na.fill(3.1) // the result is (3,164.3), not (null, 164.3) - val targetType = value match { - case _: Double | _: Long => NumericType - case _: String => StringType - case _: Boolean => BooleanType - case _ => throw new IllegalArgumentException( - s"Unsupported value type ${value.getClass.getName} ($value).") - } + val targetType = getTargetType(value) val columnEquals = df.sparkSession.sessionState.analyzer.resolver val filledColumns = df.schema.fields.filter { f => - val typeMatches = (targetType, f.dataType) match { - case (NumericType, dt) => dt.isInstanceOf[NumericType] - case (StringType, dt) => dt == StringType - case (BooleanType, dt) => dt == BooleanType - case _ => - throw new IllegalArgumentException(s"$targetType is not matched at fillValue") - } // Only fill if the column is part of the cols list. - typeMatches && cols.exists(col => columnEquals(f.name, col)) + typeMatches(targetType, f.dataType) && cols.exists(col => columnEquals(f.name, col)) } + df.withColumns(filledColumns.map(_.name), filledColumns.map(fillCol[T](_, value))) } + + /** + * Returns a new `DataFrame` that replaces null or NaN values for all the supported columns. + * Note that this handles the `DataFrame` with duplicate column names (e.g., self-joined). 
+ */ + private def fillValue[T](value: T): DataFrame = { + val targetType = getTargetType(value) + val projections = df.queryExecution.analyzed.output.map { attr => + if (typeMatches(targetType, attr.dataType)) { + fillCol(attr.dataType, attr.name, Column(attr), value) + } else { + Column(attr) + } + } + df.select(projections : _*) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 75642a0bd932..8aa0558e5d31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -340,4 +340,21 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { )).na.drop("name" :: Nil).select("name"), Row("Alice") :: Row("David") :: Nil) } + + test("SPARK-29890: duplicate names are allowed for fill() if column names are not specified.") { + val left = Seq(("1", null), ("3", "4")).toDF("col1", "col2") + val right = Seq(("1", "2"), ("3", null)).toDF("col1", "col2") + val df = left.join(right, Seq("col1")) + + // If column names are specified, the following fails due to unambiguity. + val exception = intercept[AnalysisException] { + df.na.fill("hello", Seq("col2")) + } + assert(exception.getMessage.contains("Reference 'col2' is ambiguous")) + + // If column names are not specified, fill() is applied to all the eligible columns. + checkAnswer( + df.na.fill("hello"), + Row("1", "hello", "2") :: Row("3", "4", "hello") :: Nil) + } } From 623099e4c691a2f3622f770352becccfbbdedccc Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Mon, 18 Nov 2019 23:08:46 -0800 Subject: [PATCH 2/8] update comments --- .../scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 8aa0558e5d31..b97c22d8ae32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -346,7 +346,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { val right = Seq(("1", "2"), ("3", null)).toDF("col1", "col2") val df = left.join(right, Seq("col1")) - // If column names are specified, the following fails due to unambiguity. + // If column names are specified, the following fails due to ambiguity. 
val exception = intercept[AnalysisException] { df.na.fill("hello", Seq("col2")) } From 5642d9efacc61f2c742b863e3de2ede774591295 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Tue, 19 Nov 2019 12:24:37 -0800 Subject: [PATCH 3/8] Address PR comments --- .../spark/sql/DataFrameNaFunctions.scala | 89 +++++++------------ 1 file changed, 34 insertions(+), 55 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 1b2058a45ca4..9c084d6e4ca3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -130,20 +130,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 2.2.0 */ - def fill(value: Long): DataFrame = fillValue(value) + def fill(value: Long): DataFrame = fill(value, Seq.empty) /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * @since 1.3.1 */ - def fill(value: Double): DataFrame = fillValue(value) + def fill(value: Double): DataFrame = fill(value, Seq.empty) /** * Returns a new `DataFrame` that replaces null values in string columns with `value`. * * @since 1.3.1 */ - def fill(value: String): DataFrame = fillValue(value) + def fill(value: String): DataFrame = fill(value, Seq.empty) /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. @@ -167,7 +167,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 2.2.0 */ - def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols) + def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols)) /** * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified @@ -175,7 +175,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols) + def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols)) /** @@ -192,14 +192,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols) + def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols)) /** * Returns a new `DataFrame` that replaces null values in boolean columns with `value`. * * @since 2.3.0 */ - def fill(value: Boolean): DataFrame = fillValue(value) + def fill(value: Boolean): DataFrame = fill(value, Seq.empty) /** * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified @@ -207,7 +207,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 2.3.0 */ - def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, cols) + def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols)) /** * Returns a new `DataFrame` that replaces null values in specified boolean columns. @@ -444,17 +444,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Returns a [[Column]] expression that replaces null value in `expr` with `replacement`. * It uses the given `expr` as a column. 
*/ - private def fillCol[T]( - dataType: DataType, - colName: String, - expr: Column, - replacement: T): Column = { + private def fillCol[T](dataType: DataType, name: String, expr: Column, replacement: T): Column = { val colValue = dataType match { case DoubleType | FloatType => nanvl(expr, lit(null)) // nanvl only supports these types case _ => expr } - coalesce(colValue, lit(replacement).cast(dataType)).as(colName) + coalesce(colValue, lit(replacement).cast(dataType)).as(name) } /** @@ -481,60 +477,43 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { s"Unsupported value type ${v.getClass.getName} ($v).") } - private def getTargetType[T](value: T): AbstractDataType = { - value match { - case _: Double | _: Long => NumericType - case _: String => StringType - case _: Boolean => BooleanType - case _ => - throw new IllegalArgumentException( - s"Unsupported value type for fill(): ${value.getClass.getName} ($value).") - } - } - - private def typeMatches(targetType: AbstractDataType, sourceType: DataType): Boolean = { - (targetType, sourceType) match { - case (NumericType, dt) => dt.isInstanceOf[NumericType] - case (StringType, dt) => dt == StringType - case (BooleanType, dt) => dt == BooleanType - case _ => - throw new IllegalArgumentException(s"$targetType is not matched for fill().") - } + private def toAttributes(cols: Seq[String]): Seq[Attribute] = { + cols.map(df.col(_).named.toAttribute) } /** * Returns a new `DataFrame` that replaces null or NaN values in specified * numeric, string columns. If a specified column is not a numeric, string - * or boolean column it is ignored. + * or boolean column it is ignored. If `cols` is empty, fill() is applied to + * all the eligible columns. */ - private def fillValue[T](value: T, cols: Seq[String]): DataFrame = { - // the fill[T] which T is Long/Double, + private def fillValue[T](value: T, cols: Seq[Attribute]): DataFrame = { + // the fill[T] which T is Long/Double, // should apply on all the NumericType Column, for example: // val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b") // input.na.fill(3.1) // the result is (3,164.3), not (null, 164.3) - val targetType = getTargetType(value) - - val columnEquals = df.sparkSession.sessionState.analyzer.resolver - val filledColumns = df.schema.fields.filter { f => - // Only fill if the column is part of the cols list. - typeMatches(targetType, f.dataType) && cols.exists(col => columnEquals(f.name, col)) + val targetType = value match { + case _: Double | _: Long => NumericType + case _: String => StringType + case _: Boolean => BooleanType + case _ => throw new IllegalArgumentException( + s"Unsupported value type ${value.getClass.getName} ($value).") } - df.withColumns(filledColumns.map(_.name), filledColumns.map(fillCol[T](_, value))) - } - - /** - * Returns a new `DataFrame` that replaces null or NaN values for all the supported columns. - * Note that this handles the `DataFrame` with duplicate column names (e.g., self-joined). 
- */ - private def fillValue[T](value: T): DataFrame = { - val targetType = getTargetType(value) - val projections = df.queryExecution.analyzed.output.map { attr => - if (typeMatches(targetType, attr.dataType)) { - fillCol(attr.dataType, attr.name, Column(attr), value) + val projections = df.queryExecution.analyzed.output.map { col => + val typeMatches = (targetType, col.dataType) match { + case (NumericType, dt) => dt.isInstanceOf[NumericType] + case (StringType, dt) => dt == StringType + case (BooleanType, dt) => dt == BooleanType + case _ => + throw new IllegalArgumentException(s"$targetType is not matched at fillValue") + } + // Only fill if the column is part of the cols list. + if (typeMatches && (cols.isEmpty || cols.exists(_.semanticEquals(col)))) { + fillCol(col.dataType, col.name, Column(col), value) } else { - Column(attr) + Column(col) } } df.select(projections : _*) From b0f5f5c2667a897d6726b7d69f54f2da5fc88127 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Tue, 19 Nov 2019 12:40:44 -0800 Subject: [PATCH 4/8] Refine --- .../org/apache/spark/sql/DataFrameNaFunctions.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 9c084d6e4ca3..edb77e0dc22a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -130,20 +130,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 2.2.0 */ - def fill(value: Long): DataFrame = fill(value, Seq.empty) + def fill(value: Long): DataFrame = fillValue(value, df.queryExecution.analyzed.output) /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * @since 1.3.1 */ - def fill(value: Double): DataFrame = fill(value, Seq.empty) + def fill(value: Double): DataFrame = fillValue(value, df.queryExecution.analyzed.output) /** * Returns a new `DataFrame` that replaces null values in string columns with `value`. * * @since 1.3.1 */ - def fill(value: String): DataFrame = fill(value, Seq.empty) + def fill(value: String): DataFrame = fillValue(value, df.queryExecution.analyzed.output) /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. @@ -199,7 +199,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 2.3.0 */ - def fill(value: Boolean): DataFrame = fill(value, Seq.empty) + def fill(value: Boolean): DataFrame = fillValue(value, df.queryExecution.analyzed.output) /** * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified @@ -510,7 +510,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { throw new IllegalArgumentException(s"$targetType is not matched at fillValue") } // Only fill if the column is part of the cols list. 
- if (typeMatches && (cols.isEmpty || cols.exists(_.semanticEquals(col)))) { + if (typeMatches && cols.exists(_.semanticEquals(col))) { fillCol(col.dataType, col.name, Column(col), value) } else { Column(col) From 204bb10b713007f2e8c88333b744776176d29657 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Tue, 19 Nov 2019 12:45:13 -0800 Subject: [PATCH 5/8] refactor code --- .../apache/spark/sql/DataFrameNaFunctions.scala | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index edb77e0dc22a..2757eaa35c9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -130,20 +130,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 2.2.0 */ - def fill(value: Long): DataFrame = fillValue(value, df.queryExecution.analyzed.output) + def fill(value: Long): DataFrame = fillValue(value, outputAttributes) /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * @since 1.3.1 */ - def fill(value: Double): DataFrame = fillValue(value, df.queryExecution.analyzed.output) + def fill(value: Double): DataFrame = fillValue(value, outputAttributes) /** * Returns a new `DataFrame` that replaces null values in string columns with `value`. * * @since 1.3.1 */ - def fill(value: String): DataFrame = fillValue(value, df.queryExecution.analyzed.output) + def fill(value: String): DataFrame = fillValue(value, outputAttributes) /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. @@ -199,7 +199,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 2.3.0 */ - def fill(value: Boolean): DataFrame = fillValue(value, df.queryExecution.analyzed.output) + def fill(value: Boolean): DataFrame = fillValue(value, outputAttributes) /** * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified @@ -349,7 +349,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed")); - * }}} + * }}}outputAttributes * * @param cols list of columns to apply the value replacement. If `col` is "*", * replacement is applied on all string, numeric or boolean columns. @@ -481,6 +481,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { cols.map(df.col(_).named.toAttribute) } + private def outputAttributes: Seq[Attribute] = { + df.queryExecution.analyzed.output + } + /** * Returns a new `DataFrame` that replaces null or NaN values in specified * numeric, string columns. 
If a specified column is not a numeric, string @@ -501,7 +505,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { s"Unsupported value type ${value.getClass.getName} ($value).") } - val projections = df.queryExecution.analyzed.output.map { col => + val projections = outputAttributes.map { col => val typeMatches = (targetType, col.dataType) match { case (NumericType, dt) => dt.isInstanceOf[NumericType] case (StringType, dt) => dt == StringType From 3efcf1391e3cf295d1888818847e1e0dfdc4fe44 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Wed, 20 Nov 2019 11:13:07 -0800 Subject: [PATCH 6/8] address PR comments --- .../apache/spark/sql/DataFrameNaFunctions.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 2757eaa35c9f..0759593d3461 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -349,7 +349,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed")); - * }}}outputAttributes + * }}} * * @param cols list of columns to apply the value replacement. If `col` is "*", * replacement is applied on all string, numeric or boolean columns. @@ -478,7 +478,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } private def toAttributes(cols: Seq[String]): Seq[Attribute] = { - cols.map(df.col(_).named.toAttribute) + def resolve(colName: String) : Attribute = { + df.col(colName).named.toAttribute match { + case a: Attribute => a + case _ => throw new IllegalArgumentException(s"'$colName' is not a top level column.") + } + } + cols.map(resolve) } private def outputAttributes: Seq[Attribute] = { @@ -486,10 +492,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } /** - * Returns a new `DataFrame` that replaces null or NaN values in specified - * numeric, string columns. If a specified column is not a numeric, string - * or boolean column it is ignored. If `cols` is empty, fill() is applied to - * all the eligible columns. + * Returns a new `DataFrame` that replaces null or NaN values in the specified + * columns. If a specified column is not a numeric, string or boolean column, + * it is ignored. 
*/ private def fillValue[T](value: T, cols: Seq[Attribute]): DataFrame = { // the fill[T] which T is Long/Double, From 702897c9792f100b08d7c9b4b7de7c455ad278ee Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Thu, 21 Nov 2019 22:26:39 -0800 Subject: [PATCH 7/8] Address comments --- .../spark/sql/DataFrameNaFunctions.scala | 6 ++--- .../spark/sql/DataFrameNaFunctionsSuite.scala | 27 +++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 0759593d3461..9cb4221dd823 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -478,13 +478,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } private def toAttributes(cols: Seq[String]): Seq[Attribute] = { - def resolve(colName: String) : Attribute = { - df.col(colName).named.toAttribute match { + cols.flatMap { colName => + df.col(colName).expr.collect { case a: Attribute => a - case _ => throw new IllegalArgumentException(s"'$colName' is not a top level column.") } } - cols.map(resolve) } private def outputAttributes: Seq[Attribute] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index b97c22d8ae32..b619b504c922 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{StringType, StructType} class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -266,6 +267,32 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { assert(message.contains("Reference 'f2' is ambiguous")) } + test("fill with col(*)") { + val df = createDF() + // If columns are specified with "*", they are ignored. + checkAnswer(df.na.fill("new name", Seq("*")), df.collect()) + } + + test("fill with nested columns") { + val schema = new StructType() + .add("c1", new StructType() + .add("c1-1", StringType) + .add("c1-2", StringType)) + + val data = Seq( + Row(Row(null, "a2")), + Row(Row("b1", "b2"))) + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(data), schema) + + checkAnswer(df.select("c1.c1-1"), + Row(null) :: Row("b1") :: Nil) + + // Nested columns are ignored for fill(). 
+ checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data) + } + test("replace") { val input = createDF() From 033beb57c3c855c417b596d0bc450417eadb94da Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Sun, 24 Nov 2019 22:19:35 -0800 Subject: [PATCH 8/8] Address PR comments --- .../scala/org/apache/spark/sql/DataFrameNaFunctions.scala | 6 ++---- .../org/apache/spark/sql/DataFrameNaFunctionsSuite.scala | 5 +++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 9cb4221dd823..07b0a54ba077 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -478,10 +478,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } private def toAttributes(cols: Seq[String]): Seq[Attribute] = { - cols.flatMap { colName => - df.col(colName).expr.collect { - case a: Attribute => a - } + cols.map(name => df.col(name).expr).collect { + case a: Attribute => a } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index b619b504c922..1afe733b855b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -281,13 +281,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { val data = Seq( Row(Row(null, "a2")), - Row(Row("b1", "b2"))) + Row(Row("b1", "b2")), + Row(null)) val df = spark.createDataFrame( spark.sparkContext.parallelize(data), schema) checkAnswer(df.select("c1.c1-1"), - Row(null) :: Row("b1") :: Nil) + Row(null) :: Row("b1") :: Row(null) :: Nil) // Nested columns are ignored for fill(). checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data)