diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 953db806258a..bbf0ac1dd85e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -89,7 +89,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * @since 1.3.1
    */
   def drop(how: String, cols: Seq[String]): DataFrame = {
-    drop0(how, toAttributes(cols))
+    drop0(how, cols.map(df.resolve(_)))
   }
 
   /**
@@ -115,7 +115,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * @since 1.3.1
    */
   def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
-    drop0(minNonNulls, toAttributes(cols))
+    drop0(minNonNulls, cols.map(df.resolve(_)))
   }
 
   /**
@@ -480,7 +480,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     df.queryExecution.analyzed.output
   }
 
-  private def drop0(how: String, cols: Seq[Attribute]): DataFrame = {
+  private def drop0(how: String, cols: Seq[NamedExpression]): DataFrame = {
     how.toLowerCase(Locale.ROOT) match {
       case "any" => drop0(cols.size, cols)
       case "all" => drop0(1, cols)
@@ -488,12 +488,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     }
   }
 
-  private def drop0(minNonNulls: Int, cols: Seq[Attribute]): DataFrame = {
+  private def drop0(minNonNulls: Int, cols: Seq[NamedExpression]): DataFrame = {
     // Filtering condition:
     // only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
-    val predicate = AtLeastNNonNulls(
-      minNonNulls,
-      outputAttributes.filter{ col => cols.exists(_.semanticEquals(col)) })
+    val predicate = AtLeastNNonNulls(minNonNulls, cols)
     df.filter(Column(predicate))
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index fb1ca69b6f73..091877f7cac3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -45,6 +45,16 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
     ).toDF("int", "long", "short", "byte", "float", "double")
   }
 
+  def createDFWithNestedColumns: DataFrame = {
+    val schema = new StructType()
+      .add("c1", new StructType()
+        .add("c1-1", StringType)
+        .add("c1-2", StringType))
+    val data = Seq(Row(Row(null, "a2")), Row(Row("b1", "b2")), Row(null))
+    spark.createDataFrame(
+      spark.sparkContext.parallelize(data), schema)
+  }
+
   test("drop") {
     val input = createDF()
     val rows = input.collect()
@@ -275,33 +285,35 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
     assert(message.contains("Reference 'f2' is ambiguous"))
   }
 
-  test("fill/drop with col(*)") {
+  test("fill with col(*)") {
     val df = createDF()
     // If columns are specified with "*", they are ignored.
checkAnswer(df.na.fill("new name", Seq("*")), df.collect()) - checkAnswer(df.na.drop("any", Seq("*")), df.collect()) } - test("fill/drop with nested columns") { - val schema = new StructType() - .add("c1", new StructType() - .add("c1-1", StringType) - .add("c1-2", StringType)) + test("drop with col(*)") { + val df = createDF() + val exception = intercept[AnalysisException] { + df.na.drop("any", Seq("*")) + } + assert(exception.getMessage.contains("Cannot resolve column name \"*\"")) + } - val data = Seq( - Row(Row(null, "a2")), - Row(Row("b1", "b2")), - Row(null)) + test("fill with nested columns") { + val df = createDFWithNestedColumns - val df = spark.createDataFrame( - spark.sparkContext.parallelize(data), schema) + // Nested columns are ignored for fill(). + checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), df) + } - checkAnswer(df.select("c1.c1-1"), - Row(null) :: Row("b1") :: Row(null) :: Nil) + test("drop with nested columns") { + val df = createDFWithNestedColumns - // Nested columns are ignored for fill() and drop(). - checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data) - checkAnswer(df.na.drop("any", Seq("c1.c1-1")), data) + // Rows with the specified nested columns whose null values are dropped. + assert(df.count == 3) + checkAnswer( + df.na.drop("any", Seq("c1.c1-1")), + Seq(Row(Row("b1", "b2")))) } test("replace") {