Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 33 additions & 44 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,24 @@ class DataFrame private[sql](
queryExecution.analyzed
}

/**
 * Resolves a column path to a [[NamedExpression]] in the analyzed plan.
 *
 * The name is treated as a (possibly quoted) path: it may contain "." as a
 * field separator, and individual segments may be escaped with backticks ("`").
 *
 * @param colName the column name to resolve; may contain "." or backticks
 * @throws AnalysisException if no column in the schema matches the name
 */
protected[sql] def resolve(colName: String): NamedExpression = {
  // Delegate quote parsing and case-sensitivity handling to the analyzer's resolver.
  queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse {
    throw new AnalysisException(
      s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
  }
}

private[sql] def resolveToIndex(colName: String): Option[Int] = {
val resolver = sqlContext.analyzer.resolver
// First remove any user supplied quotes.
val unquotedColName = colName.stripPrefix("`").stripSuffix("`")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to do this? I think for these methods that require a column name, the user should just pass in the exact column name string, and we don't need to do any extra parsing here — i.e. no resolver and no stripping of "`".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, what if a column is named a`a? The user should be able to just pass in a`a, and we shouldn't strip the "`".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan Hi Wenchen,

Can you please go through the following comment.

https://issues.apache.org/jira/browse/SPARK-12988?focusedCommentId=15118433&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15118433

I was trying to address the 3rd bullet in the list. Regarding your second question: per bullet one, this should be disallowed, right? Please let me know.

val index = queryExecution.analyzed.output.indexWhere(f => resolver(f.name, unquotedColName))
if (index >= 0) Some(index) else None
}

protected[sql] def numericColumns: Seq[Expression] = {
schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
queryExecution.analyzed.resolveQuoted(n.name, sqlContext.analyzer.resolver).get
Expand Down Expand Up @@ -1175,19 +1186,10 @@ class DataFrame private[sql](
* @since 1.3.0
*/
def withColumn(colName: String, col: Column): DataFrame = {
val resolver = sqlContext.analyzer.resolver
val output = queryExecution.analyzed.output
val shouldReplace = output.exists(f => resolver(f.name, colName))
if (shouldReplace) {
val columns = output.map { field =>
if (resolver(field.name, colName)) {
col.as(colName)
} else {
Column(field)
}
}
select(columns : _*)
} else {
resolveToIndex(colName).map { index =>
select(output.map(attr => Column(attr)).updated(index, col.as(colName)) : _*)
}.getOrElse {
select(Column("*"), col.as(colName))
}
}
Expand All @@ -1196,19 +1198,10 @@ class DataFrame private[sql](
* Returns a new [[DataFrame]] by adding a column with metadata.
*/
private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
val resolver = sqlContext.analyzer.resolver
val output = queryExecution.analyzed.output
val shouldReplace = output.exists(f => resolver(f.name, colName))
if (shouldReplace) {
val columns = output.map { field =>
if (resolver(field.name, colName)) {
col.as(colName, metadata)
} else {
Column(field)
}
}
select(columns : _*)
} else {
resolveToIndex(colName).map {index =>
select(output.map(attr => Column(attr)).updated(index, col.as(colName, metadata)) : _*)
}.getOrElse {
select(Column("*"), col.as(colName, metadata))
}
}
Expand All @@ -1220,19 +1213,11 @@ class DataFrame private[sql](
* @since 1.3.0
*/
def withColumnRenamed(existingName: String, newName: String): DataFrame = {
val resolver = sqlContext.analyzer.resolver
val output = queryExecution.analyzed.output
val shouldRename = output.exists(f => resolver(f.name, existingName))
if (shouldRename) {
val columns = output.map { col =>
if (resolver(col.name, existingName)) {
Column(col).as(newName)
} else {
Column(col)
}
}
select(columns : _*)
} else {
resolveToIndex(existingName).map {index =>
val renamed = Column(output(index)).as(newName)
select(output.map(attr => Column(attr)).updated(index, renamed) : _*)
}.getOrElse {
this
}
}
Expand All @@ -1255,13 +1240,13 @@ class DataFrame private[sql](
*/
@scala.annotation.varargs
def drop(colNames: String*): DataFrame = {
val resolver = sqlContext.analyzer.resolver
val remainingCols =
schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name))
if (remainingCols.size == this.schema.size) {
val indexesToDrop = colNames.flatMap(resolveToIndex)
if (indexesToDrop.isEmpty) {
this
} else {
this.select(remainingCols: _*)
val output = queryExecution.analyzed.output
val remainingCols = output.indices.diff(indexesToDrop).map(index => Column(output(index)))
select(remainingCols: _*)
}
}

Expand All @@ -1274,16 +1259,20 @@ class DataFrame private[sql](
* @since 1.4.1
*/
def drop(col: Column): DataFrame = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have this method at all?
We can only drop top-level columns, so allowing users to pass in a Column doesn't make sense.

cc @rxin @marmbrus

val expression = col match {
val expression: Expression = col match {
case Column(u: UnresolvedAttribute) =>
queryExecution.analyzed.resolveQuoted(u.name, sqlContext.analyzer.resolver).getOrElse(u)
resolveToIndex(u.name).map(this.logicalPlan.output).getOrElse(u)
case Column(expr: Expression) => expr
}
val attrs = this.logicalPlan.output
val colsAfterDrop = attrs.filter { attr =>
attr != expression
}.map(attr => Column(attr))
select(colsAfterDrop : _*)
if (colsAfterDrop.size == this.schema.size) {
this
} else {
select(colsAfterDrop: _*)
}
}

/**
Expand Down
25 changes: 25 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1270,4 +1270,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Seq(1 -> "a").toDF("i", "j").filter($"i".cast(StringType) === "1"),
Row(1, "a"))
}

test("SPARK-12988: drop columns with ` in column name") {
  // A frame whose column names exercise both "_" and "." so that quoting matters.
  val src = Seq((1, 2, 3)).toDF("a_b", "a.b", "a.c")

  // Drop by plain name.
  val withoutUnderscore = src.drop("a_b")
  checkAnswer(withoutUnderscore, Row(2, 3))
  assert(withoutUnderscore.schema.map(_.name) === Seq("a.b", "a.c"))

  // A dotted name matches the literal column, not a nested field path.
  val withoutDotted = src.drop("a.b")
  checkAnswer(withoutDotted, Row(1, 3))
  assert(withoutDotted.schema.map(_.name) === Seq("a_b", "a.c"))

  // User-supplied backticks around a dotted name are accepted.
  val withoutQuoted = src.drop("`a.c`")
  checkAnswer(withoutQuoted, Row(1, 2))
  assert(withoutQuoted.schema.map(_.name) === Seq("a_b", "a.b"))

  // Same three cases, dropping by Column instead of by name.
  val byColPlain = src.drop(new Column("a_b"))
  checkAnswer(byColPlain, Row(2, 3))
  assert(byColPlain.schema.map(_.name) === Seq("a.b", "a.c"))

  val byColDotted = src.drop(new Column("a.b"))
  checkAnswer(byColDotted, Row(1, 3))
  assert(byColDotted.schema.map(_.name) === Seq("a_b", "a.c"))

  val byColQuoted = src.drop(new Column("`a.c`"))
  checkAnswer(byColQuoted, Row(1, 2))
  assert(byColQuoted.schema.map(_.name) === Seq("a_b", "a.b"))
}
}