@@ -40,7 +40,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
-  def drop(): DataFrame = drop("any", df.columns)
+  def drop(): DataFrame = drop0("any", outputAttributes)

/**
* Returns a new `DataFrame` that drops rows containing null or NaN values.
@@ -50,7 +50,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
-  def drop(how: String): DataFrame = drop(how, df.columns)
+  def drop(how: String): DataFrame = drop0(how, outputAttributes)
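For context, a minimal usage sketch of the two overloads above (assuming a SparkSession `spark` is in scope and `import spark.implicits._` has been run); this patch only changes how they delegate internally, not their observable behavior:

  val df = Seq((Some(1), Some("a")), (None, Some("b")), (None, None)).toDF("x", "y")
  df.na.drop()       // "any": keeps only rows with no null in any column -> Row(1, "a")
  df.na.drop("all")  // drops only rows where every column is null -> Row(1, "a") and Row(null, "b")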

/**
* Returns a new `DataFrame` that drops rows containing any null or NaN values
@@ -89,11 +89,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* @since 1.3.1
*/
def drop(how: String, cols: Seq[String]): DataFrame = {
-    how.toLowerCase(Locale.ROOT) match {
-      case "any" => drop(cols.size, cols)
-      case "all" => drop(1, cols)
-      case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'")
-    }
+    drop0(how, toAttributes(cols))
}

/**
@@ -119,10 +115,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* @since 1.3.1
*/
def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
-    // Filtering condition:
-    // only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
-    val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name)))
-    df.filter(Column(predicate))
+    drop0(minNonNulls, toAttributes(cols))
}
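Both string-based overloads now funnel into drop0 through toAttributes, whose definition lies outside this hunk. As a rough, hypothetical approximation only (not the actual private helper), a resolver built on the public Column API might look like the sketch below; names that do not resolve to a plain top-level attribute (nested fields such as "c1.c1-1", or "*") would simply fall away, which is consistent with the test expectations further down:

  import org.apache.spark.sql.DataFrame
  import org.apache.spark.sql.catalyst.expressions.Attribute

  // Hypothetical stand-in for the private toAttributes helper referenced above.
  def resolveToAttributes(df: DataFrame, cols: Seq[String]): Seq[Attribute] =
    cols.map(name => df.col(name).expr).collect { case a: Attribute => a }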

/**
@@ -487,6 +480,23 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
df.queryExecution.analyzed.output
}

+  private def drop0(how: String, cols: Seq[Attribute]): DataFrame = {
+    how.toLowerCase(Locale.ROOT) match {
+      case "any" => drop0(cols.size, cols)
+      case "all" => drop0(1, cols)
+      case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'")
+    }
+  }
+
+  private def drop0(minNonNulls: Int, cols: Seq[Attribute]): DataFrame = {
+    // Filtering condition:
+    // only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
+    val predicate = AtLeastNNonNulls(
+      minNonNulls,
+      outputAttributes.filter { col => cols.exists(_.semanticEquals(col)) })
+    df.filter(Column(predicate))
+  }
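Design-wise, drop0 takes pre-resolved Attributes and matches them against the plan's output via semanticEquals instead of re-resolving names, which is what lets the overloads that take no column names work on plans with duplicate output columns. A small sketch of the resulting behavior, mirroring the new test added to the suite below (assuming `import spark.implicits._`):

  val left = Seq(("1", null), ("3", "4"), ("5", "6")).toDF("col1", "col2")
  val right = Seq(("1", "2"), ("3", null), ("5", "6")).toDF("col1", "col2")
  val joined = left.join(right, Seq("col1"))  // output columns: col1, col2, col2

  joined.na.drop("any")                  // works; only Row("5", "6", "6") survives
  // joined.na.drop("any", Seq("col2"))  // still fails analysis: Reference 'col2' is ambiguous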

/**
* Returns a new `DataFrame` that replaces null or NaN values in the specified
* columns. If a specified column is not a numeric, string or boolean column,
@@ -267,13 +267,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
assert(message.contains("Reference 'f2' is ambiguous"))
}

-  test("fill with col(*)") {
+  test("fill/drop with col(*)") {
val df = createDF()
// If columns are specified with "*", they are ignored.
checkAnswer(df.na.fill("new name", Seq("*")), df.collect())
+    checkAnswer(df.na.drop("any", Seq("*")), df.collect())
}

-  test("fill with nested columns") {
+  test("fill/drop with nested columns") {
val schema = new StructType()
.add("c1", new StructType()
.add("c1-1", StringType)
@@ -290,8 +291,9 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(df.select("c1.c1-1"),
Row(null) :: Row("b1") :: Row(null) :: Nil)

-    // Nested columns are ignored for fill().
+    // Nested columns are ignored for fill() and drop().
checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data)
+    checkAnswer(df.na.drop("any", Seq("c1.c1-1")), data)
}

test("replace") {
@@ -385,4 +387,21 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
df.na.fill("hello"),
Row("1", "hello", "2") :: Row("3", "4", "hello") :: Nil)
}

+  test("SPARK-30065: duplicate names are allowed for drop() if column names are not specified.") {
+    val left = Seq(("1", null), ("3", "4"), ("5", "6")).toDF("col1", "col2")
+    val right = Seq(("1", "2"), ("3", null), ("5", "6")).toDF("col1", "col2")
+    val df = left.join(right, Seq("col1"))
+
+    // If column names are specified, the following fails due to ambiguity.
+    val exception = intercept[AnalysisException] {
+      df.na.drop("any", Seq("col2"))
+    }
+    assert(exception.getMessage.contains("Reference 'col2' is ambiguous"))
+
+    // If column names are not specified, drop() is applied to all the eligible rows.
+    checkAnswer(
+      df.na.drop("any"),
+      Row("5", "6", "6") :: Nil)
+  }
}