diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index e70563555d3c9..fcb440cf34ec4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -41,7 +41,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def drop(): DataFrame = drop("any", df.columns) + def drop(): DataFrame = drop0("any", outputAttributes) /** * Returns a new `DataFrame` that drops rows containing null or NaN values. @@ -51,7 +51,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def drop(how: String): DataFrame = drop(how, df.columns) + def drop(how: String): DataFrame = drop0(how, outputAttributes) /** * Returns a new `DataFrame` that drops rows containing any null or NaN values @@ -90,11 +90,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def drop(how: String, cols: Seq[String]): DataFrame = { - how.toLowerCase(Locale.ROOT) match { - case "any" => drop(cols.size, cols) - case "all" => drop(1, cols) - case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") - } + drop0(how, toAttributes(cols)) } /** @@ -120,10 +116,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { - // Filtering condition: - // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. - val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) - df.filter(Column(predicate)) + drop0(minNonNulls, toAttributes(cols)) } /** @@ -488,6 +481,23 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { df.queryExecution.analyzed.output } + private def drop0(how: String, cols: Seq[Attribute]): DataFrame = { + how.toLowerCase(Locale.ROOT) match { + case "any" => drop0(cols.size, cols) + case "all" => drop0(1, cols) + case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") + } + } + + private def drop0(minNonNulls: Int, cols: Seq[Attribute]): DataFrame = { + // Filtering condition: + // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. + val predicate = AtLeastNNonNulls( + minNonNulls, + outputAttributes.filter{ col => cols.exists(_.semanticEquals(col)) }) + df.filter(Column(predicate)) + } + /** * Returns a new `DataFrame` that replaces null or NaN values in the specified * columns. If a specified column is not a numeric, string or boolean column, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index c1abd1edf98df..1587c99257509 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -240,13 +240,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { } } - test("fill with col(*)") { + test("fill/drop with col(*)") { val df = createDF() // If columns are specified with "*", they are ignored. checkAnswer(df.na.fill("new name", Seq("*")), df.collect()) + checkAnswer(df.na.drop("any", Seq("*")), df.collect()) } - test("fill with nested columns") { + test("fill/drop with nested columns") { val schema = new StructType() .add("c1", new StructType() .add("c1-1", StringType) @@ -263,8 +264,9 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select("c1.c1-1"), Row(null) :: Row("b1") :: Row(null) :: Nil) - // Nested columns are ignored for fill(). + // Nested columns are ignored for fill() and drop(). checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data) + checkAnswer(df.na.drop("any", Seq("c1.c1-1")), data) } test("replace") { @@ -394,4 +396,21 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { df.na.fill("hello"), Row("1", "hello", "2") :: Row("3", "4", "hello") :: Nil) } + + test("SPARK-30065: duplicate names are allowed for drop() if column names are not specified.") { + val left = Seq(("1", null), ("3", "4"), ("5", "6")).toDF("col1", "col2") + val right = Seq(("1", "2"), ("3", null), ("5", "6")).toDF("col1", "col2") + val df = left.join(right, Seq("col1")) + + // If column names are specified, the following fails due to ambiguity. + val exception = intercept[AnalysisException] { + df.na.drop("any", Seq("col2")) + } + assert(exception.getMessage.contains("Reference 'col2' is ambiguous")) + + // If column names are not specified, drop() is applied to all the eligible rows. + checkAnswer( + df.na.drop("any"), + Row("5", "6", "6") :: Nil) + } }