
Commit 29a0c70

Address comment.
1 parent 4b0773a commit 29a0c70

4 files changed, +51 -27 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 45 additions & 13 deletions
@@ -103,7 +103,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
       SimplifyCaseConversionExpressions,
       RewriteCorrelatedScalarSubquery,
       EliminateSerialization,
-      RemoveExtraProjectForSerialization) ::
+      RemoveAliasOnlyProject) ::
     Batch("Decimal Optimizations", fixedPoint,
       DecimalAggregates) ::
     Batch("Typed Filter Optimization", fixedPoint,
@@ -157,21 +157,52 @@ object SamplePushDown extends Rule[LogicalPlan] {
 }

 /**
- * Removes extra Project added in EliminateSerialization rule.
+ * Removes the Project only conducting Alias of its child node.
+ * It is created mainly for removing extra Project added in EliminateSerialization rule,
+ * but can also benefit other operators.
  */
-object RemoveExtraProjectForSerialization extends Rule[LogicalPlan] {
+object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
+  // Check if projectList in the Project node has the same attribute names and ordering
+  // as its child node.
+  private def checkAliasOnly(
+      projectList: Seq[NamedExpression],
+      childOutput: Seq[Attribute]): Boolean = {
+    if (!projectList.forall(_.isInstanceOf[Alias]) || projectList.length != childOutput.length) {
+      return false
+    } else {
+      projectList.map(_.asInstanceOf[Alias]).zip(childOutput).forall { case (a, o) =>
+        a.child match {
+          case attr: Attribute
+              if a.name == attr.name && attr.name == o.name && attr.dataType == o.dataType
+                && attr.exprId == o.exprId =>
+            true
+          case _ => false
+        }
+      }
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = {
-    val objectProject = plan.find(_.isInstanceOf[ObjectProject]).map { case o: ObjectProject =>
-      val replaceFrom = o.outputObjAttr
-      val replaceTo = o.child.output.head
+    val processedPlan = plan.find { p =>
+      p match {
+        case Project(pList, child) if checkAliasOnly(pList, child.output) => true
+        case _ => false
+      }
+    }.map { case p: Project =>
+      val attrMap = p.projectList.map { a =>
+        val alias = a.asInstanceOf[Alias]
+        val replaceFrom = alias.toAttribute
+        val replaceTo = alias.child.asInstanceOf[Attribute]
+        (replaceFrom, replaceTo)
+      }.toMap
       plan.transformAllExpressions {
-        case a: Attribute if a.equals(replaceFrom) => replaceTo
+        case a: Attribute if attrMap.contains(a) => attrMap(a)
       }.transform {
-        case op: ObjectProject if o == op => op.child
+        case op: Project if op == p => op.child
       }
     }
-    if (objectProject.isDefined) {
-      objectProject.get
+    if (processedPlan.isDefined) {
+      processedPlan.get
     } else {
       plan
     }
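For intuition, here is a minimal sketch of the plan shape the new rule collapses, written in the Catalyst test DSL (the relation and column names are invented for illustration):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    // A Project whose only work is re-aliasing every child attribute under the
    // same name, type, and ordering: exactly what checkAliasOnly accepts.
    val relation = LocalRelation('a.int, 'b.string)
    val aliasOnly = relation.select('a.as("a"), 'b.as("b")).analyze
    // RemoveAliasOnlyProject rewrites references to the alias attributes back
    // to the originals via attrMap, then replaces the Project with its child,
    // leaving just `relation`.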
@@ -186,9 +217,10 @@ object EliminateSerialization extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case d @ DeserializeToObject(_, _, s: SerializeFromObject)
         if d.outputObjectType == s.inputObjectType =>
-      // Adds an extra ObjectProject here, to preserve the output expr id of `DeserializeToObject`.
-      // We will remove it later.
-      ObjectProject(d.output.head, s.child)
+      // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+      // We will remove it later in RemoveAliasOnlyProject rule.
+      val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+      Project(objAttr :: Nil, s.child)
     case a @ AppendColumns(_, _, _, s: SerializeFromObject)
         if a.deserializer.dataType == s.inputObjectType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
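Concretely, mirroring the updated EliminateSerializationSuite test further down (and assuming that suite's implicit tuple encoder is in scope), the rewritten case turns a back-to-back serialize/deserialize pair into an alias-only Project over the original object:

    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze
    // EliminateSerialization reduces this to the equivalent of
    // input.select('obj.as("obj")): the serialize/deserialize pair disappears,
    // and the Alias pins the exprId that DeserializeToObject's output carried,
    // so parent expressions still resolve. RemoveAliasOnlyProject then drops
    // the leftover Project in the same batch.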

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

Lines changed: 0 additions & 9 deletions
@@ -80,15 +80,6 @@ trait ObjectConsumer extends UnaryNode {
   def inputObjectType: DataType = child.output.head.dataType
 }

-/**
- * Takes the object from child and projects it as new attribute.
- * This logical plan is just used to preserve expr id temporarily and will be removed before
- * the end of optimization phase.
- */
-case class ObjectProject(
-    outputObjAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode with ObjectProducer
-
 /**
  * Takes the input row from child and turns it into object using the given deserializer expression.
  */

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ class EliminateSerializationSuite extends PlanTest {
     val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
     val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze
     val optimized = Optimize.execute(plan)
-    val expected = ObjectProject(input.output.head.withNullability(false), input)
+    val expected = input.select('obj.as("obj")).analyze
     comparePlans(optimized, expected)
   }


sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala

Lines changed: 5 additions & 4 deletions
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, ObjectProject}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.types.BooleanType

@@ -47,9 +47,10 @@ class TypedFilterOptimizationSuite extends PlanTest {
     val query = input.filter(f1).filter(f2).analyze

     val optimized = Optimize.execute(query)
-    val deserialized = input.deserialize[(Int, Int)]
-    val expected = ObjectProject(deserialized.output.head, deserialized
-      .where(callFunction(f1, BooleanType, 'obj)))
+
+    val expected = input.deserialize[(Int, Int)]
+      .where(callFunction(f1, BooleanType, 'obj))
+      .select('obj.as("obj"))
       .where(callFunction(f2, BooleanType, 'obj))
       .serialize[(Int, Int)].analyze

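Note that the expected plan still contains the alias-only .select('obj.as("obj")), presumably because this suite's local Optimize batch runs EliminateSerialization without RemoveAliasOnlyProject. Under the full optimizer batch shown at the top of this commit, the leftover Project should collapse as well; hypothetically:

    // Hypothetical (not part of the patch): a batch that also runs
    // RemoveAliasOnlyProject should reduce `expected` to the two filters fused
    // directly over the deserialized object:
    //   input.deserialize[(Int, Int)]
    //     .where(callFunction(f1, BooleanType, 'obj))
    //     .where(callFunction(f2, BooleanType, 'obj))
    //     .serialize[(Int, Int)]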