Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1700,8 +1700,11 @@ class Dataset[T] private[sql](
}

/**
* Returns a new [[Dataset]] with a column dropped.
* This is a no-op if schema doesn't contain column name.
* Returns a new [[Dataset]] with a column dropped. This is a no-op if schema doesn't contain
* column name.
*
* This method can only be used to drop top level columns. the colName string is treated
* literally without further interpretation.
*
* @group untypedrel
* @since 2.0.0
Expand All @@ -1714,15 +1717,20 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] with columns dropped.
* This is a no-op if schema doesn't contain column name(s).
*
* This method can only be used to drop top level columns. the colName string is treated literally
* without further interpretation.
*
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def drop(colNames: String*): DataFrame = {
val resolver = sparkSession.sessionState.analyzer.resolver
val remainingCols =
schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name))
if (remainingCols.size == this.schema.size) {
val allColumns = queryExecution.analyzed.output
val remainingCols = allColumns.filter { attribute =>
colNames.forall(n => !resolver(attribute.name, n))
}.map(attribute => Column(attribute))
if (remainingCols.size == allColumns.size) {
toDF()
} else {
this.select(remainingCols: _*)
Expand Down
21 changes: 21 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(df("id") == person("id"))
}

test("drop top level columns that contains dot") {
val df1 = Seq((1, 2)).toDF("a.b", "a.c")
checkAnswer(df1.drop("a.b"), Row(2))

// Creates data set: {"a.b": 1, "a": {"b": 3}}
val df2 = Seq((1)).toDF("a.b").withColumn("a", struct(lit(3) as "b"))
// Not like select(), drop() parses the column name "a.b" literally without interpreting "."
checkAnswer(df2.drop("a.b").select("a.b"), Row(3))

// "`" is treated as a normal char here with no interpreting, "`a`b" is a valid column name.
assert(df2.drop("`a.b`").columns.size == 2)
}

test("drop(name: String) search and drop all top level columns that matchs the name") {
val df1 = Seq((1, 2)).toDF("a", "b")
val df2 = Seq((3, 4)).toDF("a", "b")
checkAnswer(df1.join(df2), Row(1, 2, 3, 4))
// Finds and drops all columns that match the name (case insensitive).
checkAnswer(df1.join(df2).drop("A"), Row(2, 4))
}

test("withColumnRenamed") {
val df = testData.toDF().withColumn("newCol", col("key") + 1)
.withColumnRenamed("value", "valueRenamed")
Expand Down