@@ -131,7 +131,7 @@ final class OneVsRestModel private[ml] (

// output label and label metadata as prediction
aggregatedDataset
- .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
+ .withColumn($(predictionCol), labelUDF(col(accColName)), Some(labelMetadata))
.drop(accColName)
}

@@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String)
// TODO: use when ... otherwise after SPARK-7321 is merged
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
val labelColName = "mc2b$" + index
- val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
- val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
+ val trainingDataset = multiclassLabeled
+   .withColumn(labelColName, labelUDF(col($(labelCol))), Some(newLabelMeta))
val classifier = getClassifier
val paramMap = new ParamMap()
paramMap.put(classifier.labelCol -> labelColName)
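The change above, repeated in Bucketizer, VectorIndexerModel, and VectorSlicer below, removes a double-aliasing pattern: metadata used to ride on an explicit alias, which withColumn then aliased again under the same name. A minimal sketch of the two call styles (df, the UDF, and the metadata key are illustrative stand-ins, not code from this patch):

    import org.apache.spark.sql.functions.{col, udf}
    import org.apache.spark.sql.types.MetadataBuilder

    val toLabel = udf { (x: Double) => if (x > 0) 1.0 else 0.0 }
    val meta = new MetadataBuilder().putString("note", "binary label").build()

    // Before this patch: the alias carries both the name and the metadata,
    // and withColumn aliases the column a second time with the same name.
    val before = df.withColumn("label", toLabel(col("raw")).as("label", meta))

    // With this patch: withColumn attaches the metadata itself.
    val after = df.withColumn("label", toLabel(col("raw")), Some(meta))
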
@@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String)
}
val newCol = bucketizer(dataset($(inputCol)))
val newField = prepOutputField(dataset.schema)
- dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+ dataset.withColumn($(outputCol), newCol, Some(newField.metadata))
}

private def prepOutputField(schema: StructType): StructField = {
@@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] (
val newField = prepOutputField(dataset.schema)
val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
val newCol = transformUDF(dataset($(inputCol)))
- dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+ dataset.withColumn($(outputCol), newCol, Some(newField.metadata))
}

override def transformSchema(schema: StructType): StructType = {
@@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String)
case features: SparseVector => features.slice(inds)
}
}
- dataset.withColumn($(outputCol),
-   slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
+ dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), Some(outputAttr.toMetadata()))
}

/** Get the feature indices in order: indices, names */
@@ -82,7 +82,9 @@ class Analyzer(
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
- PullOutNondeterministic)
+ PullOutNondeterministic),
+ Batch("Cleanup", fixedPoint,
+   CleanupAliases)
)

/**
@@ -146,8 +148,6 @@
child match {
case _: UnresolvedAttribute => u
case ne: NamedExpression => ne
- case g: GetStructField => Alias(g, g.field.name)()
- case g: GetArrayStructFields => Alias(g, g.field.name)()
case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
case e if !e.resolved => u
case other => Alias(other, s"_c$i")()
@@ -384,9 +384,7 @@
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
- withPosition(u) {
-   q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
- }
+ withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -412,11 +410,6 @@
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}

- private def trimUnresolvedAlias(ne: NamedExpression) = ne match {
-   case UnresolvedAlias(child) => child
-   case other => other
- }

private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
// Resolve SortOrder in one round.
@@ -426,7 +419,7 @@
try {
val newOrder = order transformUp {
case u @ UnresolvedAttribute(nameParts) =>
- plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
+ plan.resolve(nameParts, resolver).getOrElse(u)
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
@@ -968,3 +961,61 @@ object EliminateSubQueries extends Rule[LogicalPlan] {
case Subquery(_, child) => child
}
}

/**
 * Cleans up unnecessary Aliases inside the plan. We only need an Alias as a top-level
 * expression in a Project (project list), an Aggregate (aggregate expressions), or
 * a Window (window expressions).
*/
object CleanupAliases extends Rule[LogicalPlan] {
[Review comment: Contributor]
Can you remove everything except for this and the related tests? I'd like to pull this into the release branch without new features.

[Review comment: Contributor (Author)]
I have opened #8159 to improve withColumn, but left the code here to see if we can pass the tests.

This PR does two things:

  1. use Alias instead of UnresolvedAlias when resolving nested columns in LogicalPlan.resolve
  2. clean up unnecessary Aliases at the end of analysis

If we only do 1, some tests fail, because we need to trim Aliases in the middle of a getField chain. If we only do 2, it can't fix any bugs. So I put them together here.

[Review comment: Contributor]
I've opened #8215, which is basically your patch without the mllib changes.

private def trimAliases(e: Expression): Expression = {
var stop = false
e.transformDown {
// CreateStruct is a special case: we must retain its top-level Aliases, as they decide the
// names of the StructFields. We also need to stop transforming down into this expression,
// or the Aliases under CreateStruct would be mistakenly trimmed.
case c: CreateStruct if !stop =>
stop = true
c.copy(children = c.children.map(trimNonTopLevelAliases))
case c: CreateStructUnsafe if !stop =>
stop = true
c.copy(children = c.children.map(trimNonTopLevelAliases))
case Alias(child, _) if !stop => child
}
}

def trimNonTopLevelAliases(e: Expression): Expression = e match {
case a: Alias =>
Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata)
case other => trimAliases(other)
}

override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case Project(projectList, child) =>
val cleanedProjectList =
projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
Project(cleanedProjectList, child)

case Aggregate(grouping, aggs, child) =>
val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
Aggregate(grouping.map(trimAliases), cleanedAggs, child)

case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
val cleanedWindowExprs =
windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)

case other =>
var stop = false

[Review comment: Contributor]
Not for this release, but this makes me think that we are abusing aliases. I would rather that resolved expressions past the analyzer move the names out of the subexpressions and into the CreateStruct expression itself.

other transformExpressionsDown {
case c: CreateStruct if !stop =>
stop = true
c.copy(children = c.children.map(trimNonTopLevelAliases))
case c: CreateStructUnsafe if !stop =>
stop = true
c.copy(children = c.children.map(trimNonTopLevelAliases))
case Alias(child, _) if !stop => child
}
}
}
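
A rough sketch of the rule's effect on an expression tree, written with the catalyst expression DSL (illustrative only: 'a stands for any resolved integer attribute, and trimNonTopLevelAliases is the public helper defined above; the AnalysisSuite test below exercises the same shape):

    import org.apache.spark.sql.catalyst.analysis.CleanupAliases
    import org.apache.spark.sql.catalyst.dsl.expressions._

    // 'a.int builds an AttributeReference named "a" of IntegerType.
    val a = 'a.int
    // Resolution can leave a nested Alias inside the tree:
    // Alias(Add(Alias(Add(a, 1), "a+1"), 2), "col")
    val dirty = ((a + 1) as "a+1") + 2 as "col"
    // Only the top-level Alias is needed; the inner one is trimmed,
    // leaving Alias(Add(Add(a, 1), 2), "col"):
    val clean = CleanupAliases.trimNonTopLevelAliases(dirty)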
@@ -75,8 +75,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {

override def foldable: Boolean = children.forall(_.foldable)

- override lazy val resolved: Boolean = childrenResolved

override lazy val dataType: StructType = {
val fields = children.zipWithIndex.map { case (child, idx) =>
child match {
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
- import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+ import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -260,8 +260,11 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
val substitutedProjection = projectList1.map(_.transform {
case a: Attribute => aliasMap.getOrElse(a, a)
}).asInstanceOf[Seq[NamedExpression]]

- Project(substitutedProjection, child)
+ // Collapsing two Projects may introduce unnecessary Aliases; trim them here.
+ val cleanedProjection = substitutedProjection.map(p =>
+   CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
+ )
+ Project(cleanedProjection, child)
}
}
}
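To see why the trim is needed, consider two stacked Projects (a sketch in the catalyst DSL; testRelation is the single-attribute test fixture also used in the AnalysisSuite change below, not part of this file):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._

    // Project [x AS y] on top of Project [(a + 1) AS x]:
    val doubleProject = testRelation
      .select(('a + 1) as "x")
      .select('x as "y")
    // Collapsing substitutes the alias map into the upper list, producing
    // ((a + 1) AS x) AS y; trimNonTopLevelAliases reduces it to (a + 1) AS y.
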
@@ -259,13 +259,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
// The foldLeft adds ExtractValues for every remaining part of the identifier,
- // and wrap it with UnresolvedAlias which will be removed later.
+ // and aliases the result with the last part of the name.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
- // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
- // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
+ // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final
+ // expression as "c".
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
- Some(UnresolvedAlias(fieldExprs))
+ Some(Alias(fieldExprs, nestedFields.last)())

// No matches.
case Seq() =>
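Schematically, for "a.b.c" where "a" resolves to a struct attribute, the foldLeft described above now ends with a real Alias instead of an UnresolvedAlias. A minimal sketch, assuming `a` is the resolved attribute and `resolver` is the analyzer's resolver (both are in scope inside LogicalPlan.resolve):

    import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExtractValue, Literal}

    val nestedFields = Seq("b", "c")
    // Builds ExtractValue(ExtractValue(a, "b"), "c") one field at a time:
    val fieldExprs = nestedFields.foldLeft(a: Expression) { (expr, fieldName) =>
      ExtractValue(expr, Literal(fieldName), resolver)
    }
    // Before: Some(UnresolvedAlias(fieldExprs)), with the name filled in later.
    // After: the last name part becomes the alias immediately, which is what
    // lets orderBy("a.b.c") resolve (see the SPARK-9323 test at the end).
    val result = Some(Alias(fieldExprs, nestedFields.last)())
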
@@ -228,7 +228,7 @@ case class Window(
child: LogicalPlan) extends UnaryNode {

override def output: Seq[Attribute] =
- (projectList ++ windowExpressions).map(_.toAttribute)
+ projectList ++ windowExpressions.map(_.toAttribute)
}

/**
@@ -119,4 +119,21 @@ class AnalysisSuite extends AnalysisTest {
Project(testRelation.output :+ projected, testRelation)))
checkAnalysis(plan, expected)
}

test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") {
val a = testRelation.output.head
var plan = testRelation.select(((a + 1).as("a+1") + 2).as("col"))
var expected = testRelation.select((a + 1 + 2).as("col"))
checkAnalysis(plan, expected)

plan = testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col"))
expected = testRelation.groupBy(a)((min(a) + 1).as("col"))
checkAnalysis(plan, expected)

// CreateStruct is a special case: we should not trim the Aliases directly under it.
plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))

[Review comment: Contributor]
test CreateStructUnsafe too?

checkAnalysis(plan, plan)
plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col"))
checkAnalysis(plan, plan)
}
}
sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala (4 additions & 3 deletions)
@@ -1135,17 +1135,18 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def withColumn(colName: String, col: Column): DataFrame = {
+ def withColumn(colName: String, col: Column, metadata: Option[Metadata] = None): DataFrame = {
[Review comment: Contributor (Author)]
cc @rxin, in MLlib we sometimes need to set metadata for the new column, so we alias the new column with metadata before calling withColumn, and withColumn then aliases the column again. Here I added a new parameter that lets the user set the metadata.

[Review comment: Contributor]
Can you do this in a different PR? Also, we should do it without using Option and default arguments so that it works well in Java.


val resolver = sqlContext.analyzer.resolver
val replaced = schema.exists(f => resolver(f.name, colName))
+ val aliasedColumn = metadata.map(md => col.as(colName, md)).getOrElse(col.as(colName))
if (replaced) {
val colNames = schema.map { field =>
val name = field.name
- if (resolver(name, colName)) col.as(colName) else Column(name)
+ if (resolver(name, colName)) aliasedColumn else Column(name)
}
select(colNames : _*)
} else {
select(Column("*"), col.as(colName))
select(Column("*"), aliasedColumn)
}
}

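A usage sketch of the new optional parameter (df and the column names are illustrative stand-ins; note the request above to move this API change into a separate PR and to avoid Option with default arguments for Java compatibility):

    import org.apache.spark.sql.types.MetadataBuilder

    val meta = new MetadataBuilder().putString("source", "bucketizer").build()
    val out = df.withColumn("bucketed", df("raw"), Some(meta))

    // The metadata lands on the new column's schema field:
    assert(out.schema("bucketed").metadata.getString("source") == "bucketizer")
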
@@ -871,4 +871,10 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
val actual = df.sort(rand(seed)).collect().map(_.getInt(0))
assert(expected === actual)
}

test("SPARK-9323: DataFrame.orderBy should support nested column name") {
val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
"""{"a": {"b": 1}}""" :: Nil))
checkAnswer(df.orderBy("a.b"), Row(Row(1)))
}
}