@@ -550,7 +550,7 @@ case class EnsureRequirements(
   private def createKeyGroupedShuffleSpec(
       partitioning: Partitioning,
       distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = {
-    def check(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = {
+    def tryCreate(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = {
       val attributes = partitioning.expressions.flatMap(_.collectLeaves())
       val clustering = distribution.clustering

@@ -570,11 +570,10 @@ case class EnsureRequirements(
     }

     partitioning match {
-      case p: KeyGroupedPartitioning => check(p)
+      case p: KeyGroupedPartitioning => tryCreate(p)
       case PartitioningCollection(partitionings) =>
         val specs = partitionings.map(p => createKeyGroupedShuffleSpec(p, distribution))
-        assert(specs.forall(_.isEmpty) || specs.forall(_.isDefined))
-        specs.head
+        specs.filter(_.isDefined).map(_.get).headOption
       case _ => None
     }
   }
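Why this change: when `partitioning` is a `PartitioningCollection`, recursing into its members can yield a mix of defined and empty specs, as in a multi-way bucketed join where not every member partitioning satisfies the distribution. The old code asserted the results were uniformly empty or uniformly defined and then took `specs.head`, which raised the AssertionError that the test below reproduces; the new code simply returns the first defined spec, if any. A minimal sketch of the behavioral difference, using a stand-in case class in place of the real `KeyGroupedShuffleSpec` (illustrative only):

    // Stand-in for KeyGroupedShuffleSpec; for illustration only.
    case class Spec(label: String)

    // A mixed result, as can occur for a PartitioningCollection whose
    // members do not all produce a key-grouped spec.
    val specs: Seq[Option[Spec]] = Seq(None, Some(Spec("bucket(8, id)")))

    // Old logic: throws java.lang.AssertionError on the mixed sequence above.
    // assert(specs.forall(_.isEmpty) || specs.forall(_.isDefined))
    // specs.head

    // New logic: first defined spec, or None when every entry is empty.
    val spec: Option[Spec] = specs.filter(_.isDefined).map(_.get).headOption
    // spec == Some(Spec("bucket(8, id)"))

Note that `specs.filter(_.isDefined).map(_.get).headOption` is equivalent to `specs.flatten.headOption`; the longer spelling mirrors the shape of the original code.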
@@ -330,6 +330,28 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
.add("price", FloatType)
.add("time", TimestampType)

test("SPARK-49179: Fix v2 multi bucketed inner joins throw AssertionError") {
val cols = new StructType()
.add("id", LongType)
.add("name", StringType)
val buckets = Array(bucket(8, "id"))

withTable("t1", "t2", "t3") {
Seq("t1", "t2", "t3").foreach { t =>
createTable(t, cols, buckets)
sql(s"INSERT INTO testcat.ns.$t VALUES (1, 'aa'), (2, 'bb'), (3, 'cc')")
}
val df = sql(
"""
|SELECT t1.id, t2.id, t3.name FROM testcat.ns.t1
|JOIN testcat.ns.t2 ON t1.id = t2.id
|JOIN testcat.ns.t3 ON t1.id = t3.id
|""".stripMargin)
checkAnswer(df, Seq(Row(1, 1, "aa"), Row(2, 2, "bb"), Row(3, 3, "cc")))
assert(collectShuffles(df.queryExecution.executedPlan).isEmpty)
}
}

test("partitioned join: join with two partition keys and matching & sorted partitions") {
val items_partitions = Array(bucket(8, "id"), days("arrive_time"))
createTable(items, items_schema, items_partitions)
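A note on the new test's assertions: `checkAnswer` verifies the three-way join result, and the final `assert` verifies that the executed plan contains no shuffle, i.e. the key-grouped (storage-partitioned) join is preserved across both joins instead of falling back to a shuffle. `collectShuffles` is a helper defined elsewhere in this suite; a plausible sketch of what it does, assuming it simply collects shuffle exchange operators from the physical plan:

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

    // Hypothetical reconstruction of the suite's helper: gather every
    // shuffle exchange node present in the physical plan tree.
    def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] =
      plan.collect { case s: ShuffleExchangeExec => s }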