
Commit 14f1241

fix the performance bug when inferring constraints for Generate
1 parent bde47c8 commit 14f1241

3 files changed: 48 additions, 32 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 18 additions & 10 deletions
@@ -1178,16 +1178,24 @@ object InferFiltersFromGenerate extends Rule[LogicalPlan] {
         e.children.exists(_.isInstanceOf[UserDefinedExpression]) => generate
 
     case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) =>
-      // Exclude child's constraints to guarantee idempotency
-      val inferredFilters = ExpressionSet(
-        Seq(
-          GreaterThan(Size(g.children.head), Literal(0)),
-          IsNotNull(g.children.head)
-        )
-      ) -- generate.child.constraints
-
-      if (inferredFilters.nonEmpty) {
-        generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
+      val input = g.children.head
+      // Generating extra predicates here has overheads/risks:
+      //   - We may evaluate expensive input expressions multiple times.
+      //   - We may infer too many constraints later.
+      //   - The input expression may fail to be evaluated under ANSI mode. If we reorder the
+      //     predicates and evaluate the input expression first, we may fail the query unexpectedly.
+      // To be safe, here we only generate extra predicates if the input is an attribute.
+      if (input.isInstanceOf[Attribute]) {
+        // Exclude child's constraints to guarantee idempotency
+        val inferredFilters = ExpressionSet(
+          Seq(GreaterThan(Size(input), Literal(0)), IsNotNull(input))
+        ) -- generate.child.constraints
+
+        if (inferredFilters.nonEmpty) {
+          generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
+        } else {
+          generate
+        }
       } else {
         generate
       }
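With this change the rule only infers size(input) > 0 and isnotnull(input) filters when the generator input is a plain attribute. The sketch below is not part of this commit (the object name and local SparkSession setup are assumptions); it only illustrates the situation the new code comment describes: exploding a from_json result makes the generator input an expensive non-attribute expression, and a filter inferred over it would re-parse the JSON once more per row.

// Hypothetical repro sketch, assuming a local SparkSession (not part of this commit).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{explode, from_json}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

object InferFiltersFromGenerateRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("repro").getOrCreate()
    import spark.implicits._

    // The generator input below is from_json(c1): an expression, not an attribute.
    val schema = ArrayType(new StructType().add("s", StringType))
    val df = Seq("""[{"s":"a"},{"s":"b"}]""").toDF("c1")

    // Before the fix, InferFiltersFromGenerate could add a Filter such as
    //   size(from_json(c1)) > 0 AND isnotnull(from_json(c1))
    // below the Generate, re-evaluating the JSON parsing for every input row.
    // After the fix no filter is inferred here, because the input is not an Attribute.
    df.select(explode(from_json($"c1", schema)).as("row")).explain(true)

    spark.stop()
  }
}

With the fix applied, the optimized plan printed by explain(true) should contain no extra Filter over the from_json expression.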

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
   private def hasNoSideEffect(e: Expression): Boolean = e match {
     case _: Attribute => true
     case _: Literal => true
+    case c: Cast if !conf.ansiEnabled => hasNoSideEffect(c.child)
     case _: NoThrow if e.deterministic => e.children.forall(hasNoSideEffect)
     case _ => false
   }
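The new case treats a Cast as side-effect-free only when ANSI mode is disabled. A small illustration of why (assumed local SparkSession and object name; not part of this commit): the same cast that quietly returns NULL in non-ANSI mode throws at runtime under ANSI mode, so evaluating it eagerly during constant folding could surface an error the query might never have hit.

// Illustrative sketch only, assuming a local SparkSession.
import org.apache.spark.sql.SparkSession

object AnsiCastFoldingDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ansi-cast-demo").getOrCreate()

    // Non-ANSI mode: an invalid cast evaluates to NULL, so it is safe for
    // ConstantFolding to evaluate it eagerly and replace it with a literal.
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT CAST('abc' AS INT)").show()

    // ANSI mode: the same cast throws at runtime, i.e. it does have a "side effect",
    // so hasNoSideEffect returns false and the cast is not folded eagerly.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    try {
      spark.sql("SELECT CAST('abc' AS INT)").show()
    } catch {
      case e: Exception => println(s"ANSI cast failed as expected: ${e.getMessage}")
    }

    spark.stop()
  }
}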

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala

Lines changed: 29 additions & 22 deletions
@@ -36,7 +36,7 @@ class InferFiltersFromGenerateSuite extends PlanTest {
   val testRelation = LocalRelation('a.array(StructType(Seq(
     StructField("x", IntegerType),
     StructField("y", IntegerType)
-  ))), 'c1.string, 'c2.string)
+  ))), 'c1.string, 'c2.string, 'c3.int)
 
   Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
     val generator = f('a)
@@ -74,6 +74,13 @@ class InferFiltersFromGenerateSuite extends PlanTest {
       val optimized = Optimize.execute(originalQuery)
       comparePlans(optimized, originalQuery)
     }
+
+    val fromJson = f(JsonToStructs(ArrayType(new StructType().add("s", "string")), Map.empty, 'c1))
+    test("SPARK-37392: Don't infer filters from " + fromJson) {
+      val originalQuery = testRelation.generate(fromJson).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
   }
 
   // setup rules to test inferFilters with ConstantFolding to make sure
@@ -91,28 +98,28 @@ class InferFiltersFromGenerateSuite extends PlanTest {
   }
 
   Seq(Explode(_), PosExplode(_)).foreach { f =>
-    val createArrayExplode = f(CreateArray(Seq('c1)))
-    test("SPARK-33544: Don't infer filters from CreateArray " + createArrayExplode) {
-      val originalQuery = testRelation.generate(createArrayExplode).analyze
-      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
-      comparePlans(optimized, originalQuery)
-    }
-    val createMapExplode = f(CreateMap(Seq('c1, 'c2)))
-    test("SPARK-33544: Don't infer filters from CreateMap " + createMapExplode) {
-      val originalQuery = testRelation.generate(createMapExplode).analyze
-      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
-      comparePlans(optimized, originalQuery)
-    }
-  }
+    val createArrayExplode = f(CreateArray(Seq('c1)))
+    test("SPARK-33544: Don't infer filters from CreateArray " + createArrayExplode) {
+      val originalQuery = testRelation.generate(createArrayExplode).analyze
+      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+    val createMapExplode = f(CreateMap(Seq('c1, 'c2)))
+    test("SPARK-33544: Don't infer filters from CreateMap " + createMapExplode) {
+      val originalQuery = testRelation.generate(createMapExplode).analyze
+      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+  }
 
-  Seq(Inline(_)).foreach { f =>
-    val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1)))))
-    test("SPARK-33544: Don't infer filters from CreateArray " + createArrayStructExplode) {
-      val originalQuery = testRelation.generate(createArrayStructExplode).analyze
-      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
-      comparePlans(optimized, originalQuery)
-    }
-  }
+  Seq(Inline(_)).foreach { f =>
+    val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1)))))
+    test("SPARK-33544: Don't infer filters from CreateArray " + createArrayStructExplode) {
+      val originalQuery = testRelation.generate(createArrayStructExplode).analyze
+      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+  }
 
   test("SPARK-36715: Don't infer filters from udf") {
     Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
