@@ -103,7 +103,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
103103 SimplifyCaseConversionExpressions ,
104104 RewriteCorrelatedScalarSubquery ,
105105 EliminateSerialization ,
106- RemoveExtraProjectForSerialization ) ::
106+ RemoveAliasOnlyProject ) ::
107107 Batch (" Decimal Optimizations" , fixedPoint,
108108 DecimalAggregates ) ::
109109 Batch (" Typed Filter Optimization" , fixedPoint,
@@ -157,21 +157,52 @@ object SamplePushDown extends Rule[LogicalPlan] {
157157}
158158
159159/**
160- * Removes extra Project added in EliminateSerialization rule.
160+ * Removes a Project whose project list only aliases the output attributes of its child node.
161+ * It is created mainly for removing extra Project added in EliminateSerialization rule,
162+ * but can also benefit other operators.
161163 */
162- object RemoveExtraProjectForSerialization extends Rule [LogicalPlan ] {
164+ object RemoveAliasOnlyProject extends Rule [LogicalPlan ] {
/**
 * Checks whether a Project's `projectList` does nothing but alias its child's output.
 *
 * The result is true only when every entry of `projectList` is an [[Alias]], the list has
 * the same length as `childOutput`, and, position by position, the alias wraps an
 * [[Attribute]] such that:
 *  - the alias name, the wrapped attribute's name and the child output's name are equal;
 *  - the wrapped attribute's `dataType` and `exprId` equal those of the child output.
 *
 * @param projectList the named expressions of the Project node being inspected
 * @param childOutput the output attributes of the Project's child, in order
 * @return true if the project list is an alias-only, same-name, same-order view of
 *         `childOutput`; false otherwise
 */
private def checkAliasOnly(
    projectList: Seq[NamedExpression],
    childOutput: Seq[Attribute]): Boolean = {
  // Keep only the entries that are aliases; if any entry was dropped, the project list
  // contains something other than an Alias and cannot qualify.
  val aliases = projectList.collect { case alias: Alias => alias }

  aliases.length == projectList.length &&
    projectList.length == childOutput.length &&
    aliases.zip(childOutput).forall { case (alias, out) =>
      alias.child match {
        case attr: Attribute =>
          alias.name == attr.name && attr.name == out.name &&
            attr.dataType == out.dataType && attr.exprId == out.exprId
        // An alias over anything other than a plain attribute computes a new value.
        case _ => false
      }
    }
}
184+
163185 def apply (plan : LogicalPlan ): LogicalPlan = {
164- val objectProject = plan.find(_.isInstanceOf [ObjectProject ]).map { case o : ObjectProject =>
165- val replaceFrom = o.outputObjAttr
166- val replaceTo = o.child.output.head
186+ val processedPlan = plan.find { p =>
187+ p match {
188+ case Project (pList, child) if checkAliasOnly(pList, child.output) => true
189+ case _ => false
190+ }
191+ }.map { case p : Project =>
192+ val attrMap = p.projectList.map { a =>
193+ val alias = a.asInstanceOf [Alias ]
194+ val replaceFrom = alias.toAttribute
195+ val replaceTo = alias.child.asInstanceOf [Attribute ]
196+ (replaceFrom, replaceTo)
197+ }.toMap
167198 plan.transformAllExpressions {
168- case a : Attribute if a.equals(replaceFrom ) => replaceTo
199+ case a : Attribute if attrMap.contains(a ) => attrMap(a)
169200 }.transform {
170- case op : ObjectProject if o == op => op.child
201+ case op : Project if op == p => op.child
171202 }
172203 }
173- if (objectProject .isDefined) {
174- objectProject .get
204+ if (processedPlan .isDefined) {
205+ processedPlan .get
175206 } else {
176207 plan
177208 }
@@ -186,9 +217,10 @@ object EliminateSerialization extends Rule[LogicalPlan] {
186217 def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
187218 case d @ DeserializeToObject (_, _, s : SerializeFromObject )
188219 if d.outputObjectType == s.inputObjectType =>
189- // Adds an extra ObjectProject here, to preserve the output expr id of `DeserializeToObject`.
190- // We will remove it later.
191- ObjectProject (d.output.head, s.child)
220+ // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
221+ // We will remove it later in RemoveAliasOnlyProject rule.
222+ val objAttr = Alias (s.child.output.head, " obj" )(exprId = d.output.head.exprId)
223+ Project (objAttr :: Nil , s.child)
192224 case a @ AppendColumns (_, _, _, s : SerializeFromObject )
193225 if a.deserializer.dataType == s.inputObjectType =>
194226 AppendColumnsWithObject (a.func, s.serializer, a.serializer, s.child)
0 commit comments