[SPARK-29890][SQL] DataFrameNaFunctions.fill should handle duplicate columns #26593
Changes from all commits: ef9e38d, 623099e, 5642d9e, b0f5f5c, 204bb10, 3efcf13, 702897c, 033beb5
DataFrameNaFunctions.scala

```diff
@@ -130,20 +130,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 2.2.0
    */
-  def fill(value: Long): DataFrame = fill(value, df.columns)
+  def fill(value: Long): DataFrame = fillValue(value, outputAttributes)
 
   /**
    * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
    * @since 1.3.1
    */
-  def fill(value: Double): DataFrame = fill(value, df.columns)
+  def fill(value: Double): DataFrame = fillValue(value, outputAttributes)
 
   /**
    * Returns a new `DataFrame` that replaces null values in string columns with `value`.
    *
    * @since 1.3.1
    */
-  def fill(value: String): DataFrame = fill(value, df.columns)
+  def fill(value: String): DataFrame = fillValue(value, outputAttributes)
 
   /**
    * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
```
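For orientation, a quick user-level sketch of what these parameterless overloads do (spark-shell style; the input data is adapted from the comment inside `fillValue` further down and is not part of the PR). The user-facing call is unchanged; only the internal column lookup moves from `df.columns` to the analyzed plan's output attributes:

```scala
// Assumes a spark-shell session, where `spark` and its implicits are available.
import spark.implicits._

val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3), (1, Double.NaN))
  .toDF("a", "b")

// fill(0L) targets every numeric column: the null in "a" and the NaN in "b"
// become 0 (cast to each column's type); 164.3 is left untouched.
input.na.fill(0L).show()
```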
```diff
@@ -167,15 +167,15 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 2.2.0
    */
-  def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols)
+  def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
 
   /**
    * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
    * numeric columns. If a specified column is not a numeric column, it is ignored.
    *
    * @since 1.3.1
    */
-  def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols)
+  def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
 
 
   /**
```
```diff
@@ -192,22 +192,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 1.3.1
    */
-  def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)
+  def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
 
   /**
    * Returns a new `DataFrame` that replaces null values in boolean columns with `value`.
    *
    * @since 2.3.0
    */
-  def fill(value: Boolean): DataFrame = fill(value, df.columns)
+  def fill(value: Boolean): DataFrame = fillValue(value, outputAttributes)
 
   /**
    * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified
    * boolean columns. If a specified column is not a boolean column, it is ignored.
    *
    * @since 2.3.0
    */
-  def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, cols)
+  def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
 
   /**
    * Returns a new `DataFrame` that replaces null values in specified boolean columns.
```
```diff
@@ -433,15 +433,24 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
 
   /**
    * Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
+   * It selects a column based on its name.
    */
   private def fillCol[T](col: StructField, replacement: T): Column = {
     val quotedColName = "`" + col.name + "`"
-    val colValue = col.dataType match {
+    fillCol(col.dataType, col.name, df.col(quotedColName), replacement)
+  }
+
+  /**
+   * Returns a [[Column]] expression that replaces null value in `expr` with `replacement`.
+   * It uses the given `expr` as a column.
+   */
+  private def fillCol[T](dataType: DataType, name: String, expr: Column, replacement: T): Column = {
+    val colValue = dataType match {
       case DoubleType | FloatType =>
-        nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
-      case _ => df.col(quotedColName)
+        nanvl(expr, lit(null)) // nanvl only supports these types
+      case _ => expr
     }
-    coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)
+    coalesce(colValue, lit(replacement).cast(dataType)).as(name)
   }
 
   /**
```
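To make the refactored helper concrete: for a Double/Float column, the expression it builds is essentially `coalesce(nanvl(expr, null), lit(replacement).cast(dataType))`. Below is a standalone sketch of that expression using only the public `functions` API (spark-shell style; the column name and replacement value are made up for illustration):

```scala
import org.apache.spark.sql.functions.{coalesce, lit, nanvl}
import spark.implicits._

val df = Seq(Some(1.0), None, Some(Double.NaN)).toDF("x")

// nanvl turns NaN into null first (it only supports Double/Float columns),
// then coalesce substitutes the replacement, cast to the column's type.
df.select(coalesce(nanvl($"x", lit(null)), lit(5.0).cast("double")).as("x")).show()
// expected column values: 1.0, 5.0, 5.0
```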
```diff
@@ -468,12 +477,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       s"Unsupported value type ${v.getClass.getName} ($v).")
   }
 
+  private def toAttributes(cols: Seq[String]): Seq[Attribute] = {
+    cols.map(name => df.col(name).expr).collect {
+      case a: Attribute => a
+    }
+  }
+
```
Review comments on the added `toAttributes` method:

- Author: @cloud-fan I noticed that […]
- Contributor: We should. To avoid breaking change, I think we should change […]
- Author: OK. I will do this in a separate PR, which will address the issue with handling duplicate columns in […]
The hunk continues below the comment thread:

```diff
+  private def outputAttributes: Seq[Attribute] = {
+    df.queryExecution.analyzed.output
+  }
+
   /**
-   * Returns a new `DataFrame` that replaces null or NaN values in specified
-   * numeric, string columns. If a specified column is not a numeric, string
-   * or boolean column it is ignored.
+   * Returns a new `DataFrame` that replaces null or NaN values in the specified
+   * columns. If a specified column is not a numeric, string or boolean column,
+   * it is ignored.
    */
-  private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
+  private def fillValue[T](value: T, cols: Seq[Attribute]): DataFrame = {
     // the fill[T] which T is Long/Double,
     // should apply on all the NumericType Column, for example:
     // val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b")
```
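Why `outputAttributes` copes with duplicate names where `df.columns` cannot: the analyzed plan's output is a `Seq[Attribute]`, and each attribute carries its own expression ID, whereas `df.columns` is only an array of name strings. A spark-shell sketch with illustrative data (not part of the PR):

```scala
import spark.implicits._

val left = Seq(("1", "a")).toDF("col1", "col2")
val right = Seq(("1", "b")).toDF("col1", "col2")
val joined = left.join(right, Seq("col1"))

// By name alone the two "col2" columns cannot be told apart: col1, col2, col2
println(joined.columns.mkString(", "))

// The analyzed plan keeps one Attribute per output column, each with a unique
// expression ID, so the duplicates stay distinct (the ids themselves vary per run).
joined.queryExecution.analyzed.output.foreach(a => println(s"${a.name}#${a.exprId.id}"))
```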
```diff
@@ -487,18 +506,21 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
         s"Unsupported value type ${value.getClass.getName} ($value).")
     }
 
-    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
-    val filledColumns = df.schema.fields.filter { f =>
-      val typeMatches = (targetType, f.dataType) match {
+    val projections = outputAttributes.map { col =>
+      val typeMatches = (targetType, col.dataType) match {
         case (NumericType, dt) => dt.isInstanceOf[NumericType]
         case (StringType, dt) => dt == StringType
         case (BooleanType, dt) => dt == BooleanType
         case _ =>
           throw new IllegalArgumentException(s"$targetType is not matched at fillValue")
       }
       // Only fill if the column is part of the cols list.
-      typeMatches && cols.exists(col => columnEquals(f.name, col))
+      if (typeMatches && cols.exists(_.semanticEquals(col))) {
+        fillCol(col.dataType, col.name, Column(col), value)
+      } else {
+        Column(col)
+      }
     }
-    df.withColumns(filledColumns.map(_.name), filledColumns.map(fillCol[T](_, value)))
+    df.select(projections : _*)
   }
 }
```
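The net effect of building one projection per output attribute and ending with `df.select(projections : _*)` is that the no-argument `fill()` now works when column names are duplicated; previously the name-based lookup hit an ambiguous-reference error even though no names were passed. The sketch below simply mirrors the SPARK-29890 test added further down (same data and expected rows):

```scala
import spark.implicits._

val left = Seq(("1", null), ("3", "4")).toDF("col1", "col2")
val right = Seq(("1", "2"), ("3", null)).toDF("col1", "col2")
val df = left.join(right, Seq("col1"))

// Every eligible output attribute is filled in place, duplicates included.
df.na.fill("hello").show()
// expected rows, per the test: ("1", "hello", "2") and ("3", "4", "hello")
```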
DataFrameNaFunctionsSuite.scala

```diff
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{StringType, StructType}
 
 class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
```
```diff
@@ -266,6 +267,33 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
     assert(message.contains("Reference 'f2' is ambiguous"))
   }
 
+  test("fill with col(*)") {
+    val df = createDF()
+    // If columns are specified with "*", they are ignored.
+    checkAnswer(df.na.fill("new name", Seq("*")), df.collect())
+  }
+
+  test("fill with nested columns") {
+    val schema = new StructType()
+      .add("c1", new StructType()
+        .add("c1-1", StringType)
+        .add("c1-2", StringType))
+
+    val data = Seq(
+      Row(Row(null, "a2")),
+      Row(Row("b1", "b2")),
+      Row(null))
+
+    val df = spark.createDataFrame(
+      spark.sparkContext.parallelize(data), schema)
+
+    checkAnswer(df.select("c1.c1-1"),
+      Row(null) :: Row("b1") :: Row(null) :: Nil)
+
+    // Nested columns are ignored for fill().
+    checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data)
+  }
```
Review comments on the nested-columns test:

- Contributor: We can expose the bug if the data is […]
- Author: This example works fine, and you also need to update the previous line to […] (This is just to illustrate the nested type, but I can remove it if you think it's confusing.) The reason the nested types are ignored is the following check:

  ```scala
  case (NumericType, dt) => dt.isInstanceOf[NumericType]
  case (StringType, dt) => dt == StringType
  case (BooleanType, dt) => dt == BooleanType
  ```

  The datatype for the nested column that is resolved to […]
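One way to see why a nested column never matches the type check quoted above: the top-level attribute `c1` has a `StructType` data type, so none of the `NumericType`/`StringType`/`BooleanType` branches match it. A small sketch (schema copied from the test; the printed values are what the checks evaluate to):

```scala
import org.apache.spark.sql.types.{StringType, StructType}

val schema = new StructType()
  .add("c1", new StructType()
    .add("c1-1", StringType)
    .add("c1-2", StringType))

// The top-level column is a StructType, not a StringType, so the
// `dt == StringType` branch in fillValue never matches and fill() ignores it.
println(schema("c1").dataType.isInstanceOf[StructType]) // true
println(schema("c1").dataType == StringType)            // false
```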
The hunk's remaining context:

```diff
 
   test("replace") {
     val input = createDF()
```
```diff
@@ -340,4 +368,21 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
       )).na.drop("name" :: Nil).select("name"),
       Row("Alice") :: Row("David") :: Nil)
   }
+
+  test("SPARK-29890: duplicate names are allowed for fill() if column names are not specified.") {
+    val left = Seq(("1", null), ("3", "4")).toDF("col1", "col2")
+    val right = Seq(("1", "2"), ("3", null)).toDF("col1", "col2")
+    val df = left.join(right, Seq("col1"))
+
+    // If column names are specified, the following fails due to ambiguity.
+    val exception = intercept[AnalysisException] {
+      df.na.fill("hello", Seq("col2"))
+    }
+    assert(exception.getMessage.contains("Reference 'col2' is ambiguous"))
+
+    // If column names are not specified, fill() is applied to all the eligible columns.
+    checkAnswer(
+      df.na.fill("hello"),
+      Row("1", "hello", "2") :: Row("3", "4", "hello") :: Nil)
+  }
 }
```