Skip to content

Commit fb0894b

Browse files
tejasapatil authored and rxin committed
[SPARK-17698][SQL] Join predicates should not contain filter clauses
## What changes were proposed in this pull request? Jira : https://issues.apache.org/jira/browse/SPARK-17698 `ExtractEquiJoinKeys` is incorrectly using filter predicates as the join condition for joins. `canEvaluate` [0] tries to see if the an `Expression` can be evaluated using output of a given `Plan`. In case of filter predicates (eg. `a.id='1'`), the `Expression` passed for the right hand side (ie. '1' ) is a `Literal` which does not have any attribute references. Thus `expr.references` is an empty set which theoretically is a subset of any set. This leads to `canEvaluate` returning `true` and `a.id='1'` is treated as a join predicate. While this does not lead to incorrect results but in case of bucketed + sorted tables, we might miss out on avoiding un-necessary shuffle + sort. See example below: [0] : https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala#L91 eg. ``` val df = (1 until 10).toDF("id").coalesce(1) hc.sql("DROP TABLE IF EXISTS table1").collect df.write.bucketBy(8, "id").sortBy("id").saveAsTable("table1") hc.sql("DROP TABLE IF EXISTS table2").collect df.write.bucketBy(8, "id").sortBy("id").saveAsTable("table2") sqlContext.sql(""" SELECT a.id, b.id FROM table1 a FULL OUTER JOIN table2 b ON a.id = b.id AND a.id='1' AND b.id='1' """).explain(true) ``` BEFORE: This is doing shuffle + sort over table scan outputs which is not needed as both tables are bucketed and sorted on the same columns and have same number of buckets. This should be a single stage job. 
``` SortMergeJoin [id#38, cast(id#38 as double), 1.0], [id#39, 1.0, cast(id#39 as double)], FullOuter :- *Sort [id#38 ASC NULLS FIRST, cast(id#38 as double) ASC NULLS FIRST, 1.0 ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(id#38, cast(id#38 as double), 1.0, 200) : +- *FileScan parquet default.table1[id#38] Batched: true, Format: ParquetFormat, InputPaths: file:spark-warehouse/table1, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int> +- *Sort [id#39 ASC NULLS FIRST, 1.0 ASC NULLS FIRST, cast(id#39 as double) ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(id#39, 1.0, cast(id#39 as double), 200) +- *FileScan parquet default.table2[id#39] Batched: true, Format: ParquetFormat, InputPaths: file:spark-warehouse/table2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int> ``` AFTER : ``` SortMergeJoin [id#32], [id#33], FullOuter, ((cast(id#32 as double) = 1.0) && (cast(id#33 as double) = 1.0)) :- *FileScan parquet default.table1[id#32] Batched: true, Format: ParquetFormat, InputPaths: file:spark-warehouse/table1, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int> +- *FileScan parquet default.table2[id#33] Batched: true, Format: ParquetFormat, InputPaths: file:spark-warehouse/table2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int> ``` ## How was this patch tested? - Added a new test case for this scenario : `SPARK-17698 Join predicates should not contain filter clauses` - Ran all the tests in `BucketedReadSuite` Author: Tejas Patil <tejasp@fb.com> Closes #15272 from tejasapatil/SPARK-17698_join_predicate_filter_clause.
1 parent e895bc2 commit fb0894b

File tree

4 files changed

+109
-26
lines changed

4 files changed

+109
-26
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,9 @@ trait PredicateHelper {
8484
*
8585
* For example consider a join between two relations R(a, b) and S(c, d).
8686
*
87-
* `canEvaluate(EqualTo(a,b), R)` returns `true` where as `canEvaluate(EqualTo(a,c), R)` returns
88-
* `false`.
87+
* - `canEvaluate(EqualTo(a,b), R)` returns `true`
88+
* - `canEvaluate(EqualTo(a,c), R)` returns `false`
89+
* - `canEvaluate(Literal(1), R)` returns `true` as literals CAN be evaluated on any plan
8990
*/
9091
protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
9192
expr.references.subsetOf(plan.outputSet)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
6565
val conditionalJoin = rest.find { planJoinPair =>
6666
val plan = planJoinPair._1
6767
val refs = left.outputSet ++ plan.outputSet
68-
conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan))
68+
conditions
69+
.filterNot(l => l.references.nonEmpty && canEvaluate(l, left))
70+
.filterNot(r => r.references.nonEmpty && canEvaluate(r, plan))
6971
.exists(_.references.subsetOf(refs))
7072
}
7173
// pick the next one if no condition left

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
112112
// as join keys.
113113
val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
114114
val joinKeys = predicates.flatMap {
115+
case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None
115116
case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
116117
case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
117118
// Replace null with default value for joining key, then those rows with null in it could
@@ -125,6 +126,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
125126
case other => None
126127
}
127128
val otherPredicates = predicates.filterNot {
129+
case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
128130
case EqualTo(l, r) =>
129131
canEvaluate(l, left) && canEvaluate(r, right) ||
130132
canEvaluate(l, right) && canEvaluate(r, left)

sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala

Lines changed: 101 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
235235
private def testBucketing(
236236
bucketSpecLeft: Option[BucketSpec],
237237
bucketSpecRight: Option[BucketSpec],
238-
joinColumns: Seq[String],
238+
joinType: String = "inner",
239+
joinCondition: (DataFrame, DataFrame) => Column,
239240
shuffleLeft: Boolean,
240241
shuffleRight: Boolean,
241242
sortLeft: Boolean = true,
@@ -268,12 +269,12 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
268269
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
269270
val t1 = spark.table("bucketed_table1")
270271
val t2 = spark.table("bucketed_table2")
271-
val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
272+
val joined = t1.join(t2, joinCondition(t1, t2), joinType)
272273

273274
// First check the result is corrected.
274275
checkAnswer(
275276
joined.sort("bucketed_table1.k", "bucketed_table2.k"),
276-
df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k"))
277+
df1.join(df2, joinCondition(df1, df2), joinType).sort("df1.k", "df2.k"))
277278

278279
assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoinExec])
279280
val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoinExec]
@@ -297,86 +298,144 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
297298
}
298299
}
299300

300-
private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = {
301+
private def joinCondition(joinCols: Seq[String]) (left: DataFrame, right: DataFrame): Column = {
301302
joinCols.map(col => left(col) === right(col)).reduce(_ && _)
302303
}
303304

304305
test("avoid shuffle when join 2 bucketed tables") {
305306
val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
306-
testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
307+
testBucketing(
308+
bucketSpecLeft = bucketSpec,
309+
bucketSpecRight = bucketSpec,
310+
joinCondition = joinCondition(Seq("i", "j")),
311+
shuffleLeft = false,
312+
shuffleRight = false
313+
)
307314
}
308315

309316
// Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
310317
ignore("avoid shuffle when join keys are a super-set of bucket keys") {
311318
val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
312-
testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
319+
testBucketing(
320+
bucketSpecLeft = bucketSpec,
321+
bucketSpecRight = bucketSpec,
322+
joinCondition = joinCondition(Seq("i", "j")),
323+
shuffleLeft = false,
324+
shuffleRight = false
325+
)
313326
}
314327

315328
test("only shuffle one side when join bucketed table and non-bucketed table") {
316329
val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
317-
testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
330+
testBucketing(
331+
bucketSpecLeft = bucketSpec,
332+
bucketSpecRight = None,
333+
joinCondition = joinCondition(Seq("i", "j")),
334+
shuffleLeft = false,
335+
shuffleRight = true
336+
)
318337
}
319338

320339
test("only shuffle one side when 2 bucketed tables have different bucket number") {
321340
val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
322341
val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
323-
testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
342+
testBucketing(
343+
bucketSpecLeft = bucketSpec1,
344+
bucketSpecRight = bucketSpec2,
345+
joinCondition = joinCondition(Seq("i", "j")),
346+
shuffleLeft = false,
347+
shuffleRight = true
348+
)
324349
}
325350

326351
test("only shuffle one side when 2 bucketed tables have different bucket keys") {
327352
val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
328353
val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
329-
testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true)
354+
testBucketing(
355+
bucketSpecLeft = bucketSpec1,
356+
bucketSpecRight = bucketSpec2,
357+
joinCondition = joinCondition(Seq("i")),
358+
shuffleLeft = false,
359+
shuffleRight = true
360+
)
330361
}
331362

332363
test("shuffle when join keys are not equal to bucket keys") {
333364
val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
334-
testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true)
365+
testBucketing(
366+
bucketSpecLeft = bucketSpec,
367+
bucketSpecRight = bucketSpec,
368+
joinCondition = joinCondition(Seq("j")),
369+
shuffleLeft = true,
370+
shuffleRight = true
371+
)
335372
}
336373

337374
test("shuffle when join 2 bucketed tables with bucketing disabled") {
338375
val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
339376
withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
340-
testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
377+
testBucketing(
378+
bucketSpecLeft = bucketSpec,
379+
bucketSpecRight = bucketSpec,
380+
joinCondition = joinCondition(Seq("i", "j")),
381+
shuffleLeft = true,
382+
shuffleRight = true
383+
)
341384
}
342385
}
343386

344387
test("avoid shuffle and sort when bucket and sort columns are join keys") {
345388
val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
346389
testBucketing(
347-
bucketSpec, bucketSpec, Seq("i", "j"),
348-
shuffleLeft = false, shuffleRight = false,
349-
sortLeft = false, sortRight = false
390+
bucketSpecLeft = bucketSpec,
391+
bucketSpecRight = bucketSpec,
392+
joinCondition = joinCondition(Seq("i", "j")),
393+
shuffleLeft = false,
394+
shuffleRight = false,
395+
sortLeft = false,
396+
sortRight = false
350397
)
351398
}
352399

353400
test("avoid shuffle and sort when sort columns are a super set of join keys") {
354401
val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Seq("i", "j")))
355402
val bucketSpec2 = Some(BucketSpec(8, Seq("i"), Seq("i", "k")))
356403
testBucketing(
357-
bucketSpec1, bucketSpec2, Seq("i"),
358-
shuffleLeft = false, shuffleRight = false,
359-
sortLeft = false, sortRight = false
404+
bucketSpecLeft = bucketSpec1,
405+
bucketSpecRight = bucketSpec2,
406+
joinCondition = joinCondition(Seq("i")),
407+
shuffleLeft = false,
408+
shuffleRight = false,
409+
sortLeft = false,
410+
sortRight = false
360411
)
361412
}
362413

363414
test("only sort one side when sort columns are different") {
364415
val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
365416
val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("k")))
366417
testBucketing(
367-
bucketSpec1, bucketSpec2, Seq("i", "j"),
368-
shuffleLeft = false, shuffleRight = false,
369-
sortLeft = false, sortRight = true
418+
bucketSpecLeft = bucketSpec1,
419+
bucketSpecRight = bucketSpec2,
420+
joinCondition = joinCondition(Seq("i", "j")),
421+
shuffleLeft = false,
422+
shuffleRight = false,
423+
sortLeft = false,
424+
sortRight = true
370425
)
371426
}
372427

373428
test("only sort one side when sort columns are same but their ordering is different") {
374429
val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
375430
val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("j", "i")))
376431
testBucketing(
377-
bucketSpec1, bucketSpec2, Seq("i", "j"),
378-
shuffleLeft = false, shuffleRight = false,
379-
sortLeft = false, sortRight = true
432+
bucketSpecLeft = bucketSpec1,
433+
bucketSpecRight = bucketSpec2,
434+
joinCondition = joinCondition(Seq("i", "j")),
435+
shuffleLeft = false,
436+
shuffleRight = false,
437+
sortLeft = false,
438+
sortRight = true
380439
)
381440
}
382441

@@ -408,6 +467,25 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
408467
}
409468
}
410469

470+
test("SPARK-17698 Join predicates should not contain filter clauses") {
471+
val bucketSpec = Some(BucketSpec(8, Seq("i"), Seq("i")))
472+
testBucketing(
473+
bucketSpecLeft = bucketSpec,
474+
bucketSpecRight = bucketSpec,
475+
joinType = "fullouter",
476+
joinCondition = (left: DataFrame, right: DataFrame) => {
477+
val joinPredicates = Seq("i").map(col => left(col) === right(col)).reduce(_ && _)
478+
val filterLeft = left("i") === Literal("1")
479+
val filterRight = right("i") === Literal("1")
480+
joinPredicates && filterLeft && filterRight
481+
},
482+
shuffleLeft = false,
483+
shuffleRight = false,
484+
sortLeft = false,
485+
sortRight = false
486+
)
487+
}
488+
411489
test("error if there exists any malformed bucket files") {
412490
withTable("bucketed_table") {
413491
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")

0 commit comments

Comments
 (0)