diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 518f9dcf94a7..1678006debc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -143,6 +143,9 @@ class DataFrame private[sql](
     queryExecution.analyzed
   }
 
+  /**
+   * Resolves a column path, i.e. a column name that may contain "." or "`".
+   */
   protected[sql] def resolve(colName: String): NamedExpression = {
     queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse {
       throw new AnalysisException(
@@ -150,6 +153,14 @@ class DataFrame private[sql](
     }
   }
 
+  private[sql] def resolveToIndex(colName: String): Option[Int] = {
+    val resolver = sqlContext.analyzer.resolver
+    // First remove any user supplied quotes.
+    val unquotedColName = colName.stripPrefix("`").stripSuffix("`")
+    val index = queryExecution.analyzed.output.indexWhere(f => resolver(f.name, unquotedColName))
+    if (index >= 0) Some(index) else None
+  }
+
   protected[sql] def numericColumns: Seq[Expression] = {
     schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
       queryExecution.analyzed.resolveQuoted(n.name, sqlContext.analyzer.resolver).get
@@ -1175,19 +1186,10 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   def withColumn(colName: String, col: Column): DataFrame = {
-    val resolver = sqlContext.analyzer.resolver
     val output = queryExecution.analyzed.output
-    val shouldReplace = output.exists(f => resolver(f.name, colName))
-    if (shouldReplace) {
-      val columns = output.map { field =>
-        if (resolver(field.name, colName)) {
-          col.as(colName)
-        } else {
-          Column(field)
-        }
-      }
-      select(columns : _*)
-    } else {
+    resolveToIndex(colName).map { index =>
+      select(output.map(attr => Column(attr)).updated(index, col.as(colName)) : _*)
+    }.getOrElse {
       select(Column("*"), col.as(colName))
     }
   }
@@ -1196,19 +1198,10 @@ class DataFrame private[sql](
    * Returns a new [[DataFrame]] by adding a column with metadata.
   */
  private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
-    val resolver = sqlContext.analyzer.resolver
     val output = queryExecution.analyzed.output
-    val shouldReplace = output.exists(f => resolver(f.name, colName))
-    if (shouldReplace) {
-      val columns = output.map { field =>
-        if (resolver(field.name, colName)) {
-          col.as(colName, metadata)
-        } else {
-          Column(field)
-        }
-      }
-      select(columns : _*)
-    } else {
+    resolveToIndex(colName).map { index =>
+      select(output.map(attr => Column(attr)).updated(index, col.as(colName, metadata)) : _*)
+    }.getOrElse {
       select(Column("*"), col.as(colName, metadata))
     }
   }
@@ -1220,19 +1213,11 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   def withColumnRenamed(existingName: String, newName: String): DataFrame = {
-    val resolver = sqlContext.analyzer.resolver
     val output = queryExecution.analyzed.output
-    val shouldRename = output.exists(f => resolver(f.name, existingName))
-    if (shouldRename) {
-      val columns = output.map { col =>
-        if (resolver(col.name, existingName)) {
-          Column(col).as(newName)
-        } else {
-          Column(col)
-        }
-      }
-      select(columns : _*)
-    } else {
+    resolveToIndex(existingName).map { index =>
+      val renamed = Column(output(index)).as(newName)
+      select(output.map(attr => Column(attr)).updated(index, renamed) : _*)
+    }.getOrElse {
       this
     }
   }
@@ -1255,13 +1240,13 @@ class DataFrame private[sql](
    */
   @scala.annotation.varargs
   def drop(colNames: String*): DataFrame = {
-    val resolver = sqlContext.analyzer.resolver
-    val remainingCols =
-      schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name))
-    if (remainingCols.size == this.schema.size) {
+    val indexesToDrop = colNames.flatMap(resolveToIndex)
+    if (indexesToDrop.isEmpty) {
       this
     } else {
-      this.select(remainingCols: _*)
+      val output = queryExecution.analyzed.output
+      val remainingCols = output.indices.diff(indexesToDrop).map(index => Column(output(index)))
+      select(remainingCols: _*)
     }
   }
 
@@ -1274,16 +1259,20 @@ class DataFrame private[sql](
    * @since 1.4.1
    */
   def drop(col: Column): DataFrame = {
-    val expression = col match {
+    val expression: Expression = col match {
       case Column(u: UnresolvedAttribute) =>
-        queryExecution.analyzed.resolveQuoted(u.name, sqlContext.analyzer.resolver).getOrElse(u)
+        resolveToIndex(u.name).map(this.logicalPlan.output).getOrElse(u)
       case Column(expr: Expression) => expr
     }
     val attrs = this.logicalPlan.output
     val colsAfterDrop = attrs.filter { attr =>
       attr != expression
     }.map(attr => Column(attr))
-    select(colsAfterDrop : _*)
+    if (colsAfterDrop.size == this.schema.size) {
+      this
+    } else {
+      select(colsAfterDrop: _*)
+    }
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 09bbe57a43ce..a0c17ed049e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1270,4 +1270,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Seq(1 -> "a").toDF("i", "j").filter($"i".cast(StringType) === "1"),
       Row(1, "a"))
   }
+
+  test("SPARK-12988: drop columns with ` in column name") {
+    val src = Seq((1, 2, 3)).toDF("a_b", "a.b", "a.c")
+    val df = src.drop("a_b")
+    checkAnswer(df, Row(2, 3))
+    assert(df.schema.map(_.name) === Seq("a.b", "a.c"))
+    val df1 = src.drop("a.b")
+    checkAnswer(df1, Row(1, 3))
+    assert(df1.schema.map(_.name) === Seq("a_b", "a.c"))
+    val df2 = src.drop("`a.c`")
+    checkAnswer(df2, Row(1, 2))
+    assert(df2.schema.map(_.name) === Seq("a_b", "a.b"))
+    val col1 = new Column("a_b")
+    val df4 = src.drop(col1)
+    checkAnswer(df4, Row(2, 3))
+    assert(df4.schema.map(_.name) === Seq("a.b", "a.c"))
+    val col2 = new Column("a.b")
+    val df5 = src.drop(col2)
+    checkAnswer(df5, Row(1, 3))
+    assert(df5.schema.map(_.name) === Seq("a_b", "a.c"))
+    val col3 = new Column("`a.c`")
+    val df6 = src.drop(col3)
+    checkAnswer(df6, Row(1, 2))
+    assert(df6.schema.map(_.name) === Seq("a_b", "a.b"))
+  }
 }
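
Note on the core helper: `resolveToIndex` strips any user-supplied backticks, matches the unquoted name against the analyzed output with the analyzer's resolver, and returns the column's position rather than a resolved attribute. A minimal standalone sketch of that lookup, using a case-insensitive function as a stand-in for `sqlContext.analyzer.resolver` (the stand-in and the sample schema are illustrative assumptions, not code from the patch):

object ResolveToIndexSketch {
  // Stand-in for Catalyst's Resolver type (assumption: case-insensitive
  // matching, which mirrors Spark SQL's default resolver).
  type Resolver = (String, String) => Boolean
  val resolver: Resolver = (a, b) => a.equalsIgnoreCase(b)

  // Mirrors the patch: strip user-supplied backticks, then return the index
  // of the first output column whose name matches under the resolver.
  def resolveToIndex(output: Seq[String], colName: String): Option[Int] = {
    val unquoted = colName.stripPrefix("`").stripSuffix("`")
    val index = output.indexWhere(name => resolver(name, unquoted))
    if (index >= 0) Some(index) else None
  }

  def main(args: Array[String]): Unit = {
    val output = Seq("a_b", "a.b", "a.c")
    assert(resolveToIndex(output, "a.b") == Some(1))   // "." is matched literally
    assert(resolveToIndex(output, "`a.c`") == Some(2)) // backticks stripped first
    assert(resolveToIndex(output, "a_z") == None)      // no match: callers keep the plan unchanged
  }
}

Returning an index instead of an attribute lets `withColumn`, `withColumnRenamed`, and `drop` share one shape of code: replace or remove the column at that position via `updated`/`diff`, and fall back to the unmodified DataFrame when the name does not resolve.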
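For reference, the user-visible behavior the new test pins down, as a spark-shell style usage sketch (assumes a `sqlContext` with its implicits imported, as the suite's `SharedSQLContext` provides):

import sqlContext.implicits._

val src = Seq((1, 2, 3)).toDF("a_b", "a.b", "a.c")

// "a.b" is matched literally against top-level column names rather than
// parsed as a nested-field path, so the column named "a.b" is dropped.
src.drop("a.b").columns   // Array(a_b, a.c)

// User-supplied backticks are stripped before matching.
src.drop("`a.c`").columns // Array(a_b, a.b)

// A name that does not resolve is a no-op, not an error.
src.drop("a_z").columns   // Array(a_b, a.b, a.c)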