Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
SimplifyCaseConversionExpressions,
RewriteCorrelatedScalarSubquery,
EliminateSerialization,
RemoveAliasOnlyProject) ::
RemoveRedundantAliases,
RemoveRedundantProject) ::
Batch("Check Cartesian Products", Once,
CheckCartesianProducts(conf)) ::
Batch("Decimal Optimizations", fixedPoint,
Expand Down Expand Up @@ -153,56 +154,98 @@ class SimpleTestOptimizer extends Optimizer(
new SimpleCatalystConf(caseSensitiveAnalysis = true))

/**
* Removes the Project only conducting Alias of its child node.
* It is created mainly for removing extra Project added in EliminateSerialization rule,
* but can also benefit other operators.
* Remove redundant aliases from a query plan. A redundant alias is an alias that does not change
* the name or metadata of a column, and does not deduplicate it.
*/
object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
object RemoveRedundantAliases extends Rule[LogicalPlan] {

/**
* Returns true if the project list is semantically same as child output, after strip alias on
* attribute.
* Create an attribute mapping from the old to the new attributes. This function will only
* return the attribute pairs that have changed.
*/
private def isAliasOnly(
projectList: Seq[NamedExpression],
childOutput: Seq[Attribute]): Boolean = {
if (projectList.length != childOutput.length) {
false
} else {
stripAliasOnAttribute(projectList).zip(childOutput).forall {
case (a: Attribute, o) if a semanticEquals o => true
case _ => false
}
private def createAttributeMapping(current: LogicalPlan, next: LogicalPlan)
    : Seq[(Attribute, Attribute)] = {
  // Pair the outputs of both plans positionally and keep only the pairs whose two sides are
  // not semantically equal, i.e. the attributes that actually changed.
  val positionalPairs = current.output.zip(next.output)
  positionalPairs.filter { case (before, after) =>
    !before.semanticEquals(after)
  }
}

private def stripAliasOnAttribute(projectList: Seq[NamedExpression]) = {
projectList.map {
// Alias with metadata can not be stripped, or the metadata will be lost.
// If the alias name is different from attribute name, we can't strip it either, or we may
// accidentally change the output schema name of the root plan.
case a @ Alias(attr: Attribute, name) if a.metadata == Metadata.empty && name == attr.name =>
attr
case other => other
}
/**
 * Strip the outermost alias of an expression if that alias is redundant; any other expression
 * is returned unchanged.
 *
 * @param e the expression to inspect.
 * @param blacklist attributes whose aliases must be kept, even if they look redundant.
 * @return the aliased attribute when the alias is redundant, `e` otherwise.
 */
private def removeRedundantAlias(e: Expression, blacklist: AttributeSet): Expression = {
  e match {
    case alias @ Alias(child: Attribute, aliasName) =>
      // The alias has to stay when any of the following holds:
      //  - it carries metadata, which would be lost by stripping it;
      //  - it renames the attribute, because stripping it could change the output schema name
      //    of the root plan;
      //  - the aliased attribute is on the blacklist.
      val mustKeep =
        alias.metadata != Metadata.empty || aliasName != child.name || blacklist.contains(child)
      if (mustKeep) alias else child
    case other => other
  }
}

def apply(plan: LogicalPlan): LogicalPlan = {
val aliasOnlyProject = plan.collectFirst {
case p @ Project(pList, child) if isAliasOnly(pList, child.output) => p
}
/**
* Remove redundant alias expression from a LogicalPlan and its subtree. A blacklist is used to
* prevent the removal of seemingly redundant aliases used to deduplicate the input for a (self)
* join.
*/
private def removeRedundantAliases(plan: LogicalPlan, blacklist: AttributeSet): LogicalPlan = {
plan match {
// A join has to be treated differently, because the left and the right side of the join are
// not allowed to use the same attributes. We use a blacklist to prevent us from creating a
// situation in which this happens; the rule will only remove an alias if its child
// attribute is not on the black list.
case Join(left, right, joinType, condition) =>
val newLeft = removeRedundantAliases(left, blacklist ++ right.outputSet)
val newRight = removeRedundantAliases(right, blacklist ++ newLeft.outputSet)
val mapping = AttributeMap(
createAttributeMapping(left, newLeft) ++
createAttributeMapping(right, newRight))
val newCondition = condition.map(_.transform {
case a: Attribute => mapping.getOrElse(a, a)
})
Join(newLeft, newRight, joinType, newCondition)

case _ =>
// Remove redundant aliases in the subtree(s).
val currentNextAttrPairs = mutable.Buffer.empty[(Attribute, Attribute)]
val newNode = plan.mapChildren { child =>
val newChild = removeRedundantAliases(child, blacklist)
currentNextAttrPairs ++= createAttributeMapping(child, newChild)
newChild
}

aliasOnlyProject.map { case proj =>
val attributesToReplace = proj.output.zip(proj.child.output).filterNot {
case (a1, a2) => a1 semanticEquals a2
}
val attrMap = AttributeMap(attributesToReplace)
plan transform {
case plan: Project if plan eq proj => plan.child
case plan => plan transformExpressions {
case a: Attribute if attrMap.contains(a) => attrMap(a)
// Create the attribute mapping. Note that the currentNextAttrPairs can contain duplicate
// keys in case of Union (this is caused by the PushProjectionThroughUnion rule); in this
// case we use the first mapping (which should be provided by the first child).
val mapping = AttributeMap(currentNextAttrPairs)

// Create an expression cleaning function for nodes that can actually produce redundant
// aliases; use identity otherwise.
val clean: Expression => Expression = plan match {
case _: Project => removeRedundantAlias(_, blacklist)
case _: Aggregate => removeRedundantAlias(_, blacklist)
case _: Window => removeRedundantAlias(_, blacklist)
case _ => identity[Expression]
}
}
}.getOrElse(plan)

// Transform the expressions.
newNode.mapExpressions { expr =>
clean(expr.transform {
case a: Attribute => mapping.getOrElse(a, a)
})
}
}
}

// Rule entry point: delegates to removeRedundantAliases on the given plan, starting with an
// empty blacklist.
def apply(plan: LogicalPlan): LogicalPlan = removeRedundantAliases(plan, AttributeSet.empty)
}

/**
 * Remove every [[Project]] whose output is equal to the output of its child; such a projection
 * makes no modification and is replaced by the child itself.
 */
object RemoveRedundantProject extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    // The projection exposes exactly the output of its child, so the child can stand in for it.
    case project: Project if project.output == project.child.output => project.child
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,31 +242,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* @param rule the rule to be applied to every expression in this operator.
*/
def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = {
var changed = false

@inline def transformExpressionDown(e: Expression): Expression = {
val newE = e.transformDown(rule)
if (newE.fastEquals(e)) {
e
} else {
changed = true
newE
}
}

def recursiveTransform(arg: Any): AnyRef = arg match {
case e: Expression => transformExpressionDown(e)
case Some(e: Expression) => Some(transformExpressionDown(e))
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map(recursiveTransform)
case other: AnyRef => other
case null => null
}

val newArgs = mapProductIterator(recursiveTransform)

if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
mapExpressions(_.transformDown(rule))
}

/**
Expand All @@ -276,10 +252,18 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* @return
*/
def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type =
  // Run `rule` through transformUp on each expression handed over by mapExpressions.
  mapExpressions { expression => expression.transformUp(rule) }

/**
* Apply a map function to each expression present in this query operator, and return a new
* query operator based on the mapped expressions.
*/
def mapExpressions(f: Expression => Expression): this.type = {
var changed = false

@inline def transformExpressionUp(e: Expression): Expression = {
val newE = e.transformUp(rule)
@inline def transformExpression(e: Expression): Expression = {
val newE = f(e)
if (newE.fastEquals(e)) {
e
} else {
Expand All @@ -289,8 +273,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
}

def recursiveTransform(arg: Any): AnyRef = arg match {
case e: Expression => transformExpressionUp(e)
case Some(e: Expression) => Some(transformExpressionUp(e))
case e: Expression => transformExpression(e)
case Some(e: Expression) => Some(transformExpression(e))
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map(recursiveTransform)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
if (!analyzed) {
val afterRuleOnChildren = transformChildren(rule, (t, r) => t.resolveOperators(r))
val afterRuleOnChildren = mapChildren(_.resolveOperators(rule))
if (this fastEquals afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[LogicalPlan])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,26 +191,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
arr
}

/**
 * Returns a copy of this node in which `f` has been applied to each of the node's children.
 * When `f` leaves every child unchanged (according to `fastEquals`), this node itself is
 * returned instead of a copy.
 */
def mapChildren(f: BaseType => BaseType): BaseType = {
  var anyChildReplaced = false
  val mappedArgs = mapProductIterator {
    case node: TreeNode[_] if containsChild(node) =>
      val mapped = f(node.asInstanceOf[BaseType])
      if (!(mapped fastEquals node)) {
        // Remember that at least one child differs, so that a copy has to be made below.
        anyChildReplaced = true
        mapped
      } else {
        // Keep the original instance when `f` did not change the child.
        node
      }
    case other: AnyRef => other
    case null => null
  }
  if (anyChildReplaced) makeCopy(mappedArgs) else this
}

/**
* Returns a copy of this node with the children replaced.
* TODO: Validate somewhere (in debug mode?) that children are ordered correctly.
Expand Down Expand Up @@ -290,9 +270,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {

// Check if unchanged and then possibly return old copy to avoid gc churn.
if (this fastEquals afterRule) {
transformChildren(rule, (t, r) => t.transformDown(r))
mapChildren(_.transformDown(rule))
} else {
afterRule.transformChildren(rule, (t, r) => t.transformDown(r))
afterRule.mapChildren(_.transformDown(rule))
}
}

Expand All @@ -304,7 +284,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* @param rule the function used to transform this node's children
*/
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r))
val afterRuleOnChildren = mapChildren(_.transformUp(rule))
if (this fastEquals afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
Expand All @@ -317,26 +297,22 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}

/**
* Returns a copy of this node where `rule` has been recursively applied to all the children of
* this node. When `rule` does not apply to a given node it is left unchanged.
* @param rule the function used to transform this nodes children
* Returns a copy of this node where `f` has been applied to all of the node's children.
*/
protected def transformChildren(
rule: PartialFunction[BaseType, BaseType],
nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = {
def mapChildren(f: BaseType => BaseType): BaseType = {
if (children.nonEmpty) {
var changed = false
val newArgs = mapProductIterator {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
changed = true
Some(newChild)
Expand All @@ -345,7 +321,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
case m: Map[_, _] => m.mapValues {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
changed = true
newChild
Expand All @@ -357,16 +333,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
val newChild = f(arg.asInstanceOf[BaseType])
if (!(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule)
val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule)
val newChild1 = f(arg1.asInstanceOf[BaseType])
val newChild2 = f(arg2.asInstanceOf[BaseType])
if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
Expand Down
Loading