Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ trait PredicateHelper {
*
* For example consider a join between two relations R(a, b) and S(c, d).
*
* `canEvaluate(EqualTo(a,b), R)` returns `true` where as `canEvaluate(EqualTo(a,c), R)` returns
* `false`.
* - `canEvaluate(EqualTo(a,b), R)` returns `true`
* - `canEvaluate(EqualTo(a,c), R)` returns `false`
* - `canEvaluate(Literal(1), R)` returns `true` as literals CAN be evaluated on any plan
*/
protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
expr.references.subsetOf(plan.outputSet)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
val conditionalJoin = rest.find { planJoinPair =>
val plan = planJoinPair._1
val refs = left.outputSet ++ plan.outputSet
conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan))
conditions
.filterNot(l => l.references.nonEmpty && canEvaluate(l, left))
.filterNot(r => r.references.nonEmpty && canEvaluate(r, plan))
.exists(_.references.subsetOf(refs))
}
// pick the next one if no condition left
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
// as join keys.
val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
val joinKeys = predicates.flatMap {
case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None
case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
// Replace null with default value for joining key, then those rows with null in it could
Expand All @@ -125,6 +126,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
case other => None
}
val otherPredicates = predicates.filterNot {
case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
case EqualTo(l, r) =>
canEvaluate(l, left) && canEvaluate(r, right) ||
canEvaluate(l, right) && canEvaluate(r, left)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
private def testBucketing(
bucketSpecLeft: Option[BucketSpec],
bucketSpecRight: Option[BucketSpec],
joinColumns: Seq[String],
joinType: String = "inner",
joinCondition: (DataFrame, DataFrame) => Column,
shuffleLeft: Boolean,
shuffleRight: Boolean,
sortLeft: Boolean = true,
Expand Down Expand Up @@ -268,12 +269,12 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
val t1 = spark.table("bucketed_table1")
val t2 = spark.table("bucketed_table2")
val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
val joined = t1.join(t2, joinCondition(t1, t2), joinType)

// First check the result is corrected.
checkAnswer(
joined.sort("bucketed_table1.k", "bucketed_table2.k"),
df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k"))
df1.join(df2, joinCondition(df1, df2), joinType).sort("df1.k", "df2.k"))

assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoinExec])
val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoinExec]
Expand All @@ -297,86 +298,144 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
}
}

private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = {
private def joinCondition(joinCols: Seq[String]) (left: DataFrame, right: DataFrame): Column = {
joinCols.map(col => left(col) === right(col)).reduce(_ && _)
}

test("avoid shuffle when join 2 bucketed tables") {
val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
testBucketing(
bucketSpecLeft = bucketSpec,
bucketSpecRight = bucketSpec,
joinCondition = joinCondition(Seq("i", "j")),
shuffleLeft = false,
shuffleRight = false
)
}

// Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
ignore("avoid shuffle when join keys are a super-set of bucket keys") {
val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
testBucketing(
bucketSpecLeft = bucketSpec,
bucketSpecRight = bucketSpec,
joinCondition = joinCondition(Seq("i", "j")),
shuffleLeft = false,
shuffleRight = false
)
}

test("only shuffle one side when join bucketed table and non-bucketed table") {
val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
testBucketing(
bucketSpecLeft = bucketSpec,
bucketSpecRight = None,
joinCondition = joinCondition(Seq("i", "j")),
shuffleLeft = false,
shuffleRight = true
)
}

test("only shuffle one side when 2 bucketed tables have different bucket number") {
val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
testBucketing(
bucketSpecLeft = bucketSpec1,
bucketSpecRight = bucketSpec2,
joinCondition = joinCondition(Seq("i", "j")),
shuffleLeft = false,
shuffleRight = true
)
}

test("only shuffle one side when 2 bucketed tables have different bucket keys") {
val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true)
testBucketing(
bucketSpecLeft = bucketSpec1,
bucketSpecRight = bucketSpec2,
joinCondition = joinCondition(Seq("i")),
shuffleLeft = false,
shuffleRight = true
)
}

test("shuffle when join keys are not equal to bucket keys") {
val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true)
testBucketing(
bucketSpecLeft = bucketSpec,
bucketSpecRight = bucketSpec,
joinCondition = joinCondition(Seq("j")),
shuffleLeft = true,
shuffleRight = true
)
}

test("shuffle when join 2 bucketed tables with bucketing disabled") {
val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
testBucketing(
bucketSpecLeft = bucketSpec,
bucketSpecRight = bucketSpec,
joinCondition = joinCondition(Seq("i", "j")),
shuffleLeft = true,
shuffleRight = true
)
}
}

test("avoid shuffle and sort when bucket and sort columns are join keys") {
val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
testBucketing(
bucketSpec, bucketSpec, Seq("i", "j"),
shuffleLeft = false, shuffleRight = false,
sortLeft = false, sortRight = false
bucketSpecLeft = bucketSpec,
bucketSpecRight = bucketSpec,
joinCondition = joinCondition(Seq("i", "j")),
shuffleLeft = false,
shuffleRight = false,
sortLeft = false,
sortRight = false
)
}

test("avoid shuffle and sort when sort columns are a super set of join keys") {
val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Seq("i", "j")))
val bucketSpec2 = Some(BucketSpec(8, Seq("i"), Seq("i", "k")))
testBucketing(
bucketSpec1, bucketSpec2, Seq("i"),
shuffleLeft = false, shuffleRight = false,
sortLeft = false, sortRight = false
bucketSpecLeft = bucketSpec1,
bucketSpecRight = bucketSpec2,
joinCondition = joinCondition(Seq("i")),
shuffleLeft = false,
shuffleRight = false,
sortLeft = false,
sortRight = false
)
}

test("only sort one side when sort columns are different") {
val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("k")))
testBucketing(
bucketSpec1, bucketSpec2, Seq("i", "j"),
shuffleLeft = false, shuffleRight = false,
sortLeft = false, sortRight = true
bucketSpecLeft = bucketSpec1,
bucketSpecRight = bucketSpec2,
joinCondition = joinCondition(Seq("i", "j")),
shuffleLeft = false,
shuffleRight = false,
sortLeft = false,
sortRight = true
)
}

test("only sort one side when sort columns are same but their ordering is different") {
val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("j", "i")))
testBucketing(
bucketSpec1, bucketSpec2, Seq("i", "j"),
shuffleLeft = false, shuffleRight = false,
sortLeft = false, sortRight = true
bucketSpecLeft = bucketSpec1,
bucketSpecRight = bucketSpec2,
joinCondition = joinCondition(Seq("i", "j")),
shuffleLeft = false,
shuffleRight = false,
sortLeft = false,
sortRight = true
)
}

Expand Down Expand Up @@ -408,6 +467,25 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
}
}

test("SPARK-17698 Join predicates should not contain filter clauses") {
val bucketSpec = Some(BucketSpec(8, Seq("i"), Seq("i")))
testBucketing(
bucketSpecLeft = bucketSpec,
bucketSpecRight = bucketSpec,
joinType = "fullouter",
joinCondition = (left: DataFrame, right: DataFrame) => {
val joinPredicates = Seq("i").map(col => left(col) === right(col)).reduce(_ && _)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it just val joinPredicates = left(col) === right(col)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea @tejasapatil mind fixing this? We can merge it then.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually @tejasapatil given there is a need for backport, I'd let you fix this in your other prs since this is fairly cosmetic.

val filterLeft = left("i") === Literal("1")
val filterRight = right("i") === Literal("1")
joinPredicates && filterLeft && filterRight
},
shuffleLeft = false,
shuffleRight = false,
sortLeft = false,
sortRight = false
)
}

test("error if there exists any malformed bucket files") {
withTable("bucketed_table") {
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
Expand Down