From f1f05fb110ee5eaa17d2e33e9d0151426c662942 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Fri, 9 Aug 2024 15:53:28 +0800 Subject: [PATCH 1/2] Fix v2 multi bucketed inner joins throw AssertionError --- .../exchange/EnsureRequirements.scala | 7 +++---- .../KeyGroupedPartitioningSuite.scala | 21 +++++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 90287c202846..e669165f4f2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -616,7 +616,7 @@ case class EnsureRequirements( private def createKeyGroupedShuffleSpec( partitioning: Partitioning, distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = { - def check(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = { + def tryCreate(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = { val attributes = partitioning.expressions.flatMap(_.collectLeaves()) val clustering = distribution.clustering @@ -636,11 +636,10 @@ case class EnsureRequirements( } partitioning match { - case p: KeyGroupedPartitioning => check(p) + case p: KeyGroupedPartitioning => tryCreate(p) case PartitioningCollection(partitionings) => val specs = partitionings.map(p => createKeyGroupedShuffleSpec(p, distribution)) - assert(specs.forall(_.isEmpty) || specs.forall(_.isDefined)) - specs.head + specs.filter(_.isDefined).map(_.get).headOption case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 99d99fede848..1cf0c2203f0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -369,6 +369,27 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0))) } + test("SPARK-49179: Fix v2 multi bucketed inner joins throw AssertionError") { + val cols = Array( + Column.create("id", LongType), + Column.create("name", StringType)) + val buckets = Array(bucket(8, "id")) + + withTable("t1", "t2", "t3") { + Seq("t1", "t2", "t3").foreach { t => + createTable(t, cols, buckets) + sql(s"INSERT INTO testcat.ns.$t VALUES (1, 'aa'), (1, 'aa'), (2, 'bb'), (3, 'cc')") + } + val df = sql( + """ + |SELECT * FROM testcat.ns.t1 + |JOIN testcat.ns.t2 ON t1.id = t2.id + |JOIN testcat.ns.t3 ON t1.id = t3.id + |""".stripMargin) + assert(collectShuffles(df.queryExecution.executedPlan).isEmpty) + } + } + test("partitioned join: join with two partition keys and matching & sorted partitions") { val items_partitions = Array(bucket(8, "id"), days("arrive_time")) createTable(items, itemsColumns, items_partitions) From c72bf21196a2ab48b8e966ce6250c78e480b6c20 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Tue, 13 Aug 2024 09:16:31 +0800 Subject: [PATCH 2/2] address comments --- .../spark/sql/connector/KeyGroupedPartitioningSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 1cf0c2203f0d..6a146bc887db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -378,14 +378,15 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { withTable("t1", "t2", "t3") { Seq("t1", "t2", "t3").foreach { t => createTable(t, cols, buckets) - sql(s"INSERT INTO testcat.ns.$t VALUES (1, 'aa'), (1, 'aa'), (2, 'bb'), (3, 'cc')") + sql(s"INSERT INTO testcat.ns.$t VALUES (1, 'aa'), (2, 'bb'), (3, 'cc')") } val df = sql( """ - |SELECT * FROM testcat.ns.t1 + |SELECT t1.id, t2.id, t3.name FROM testcat.ns.t1 |JOIN testcat.ns.t2 ON t1.id = t2.id |JOIN testcat.ns.t3 ON t1.id = t3.id |""".stripMargin) + checkAnswer(df, Seq(Row(1, 1, "aa"), Row(2, 2, "bb"), Row(3, 3, "cc"))) assert(collectShuffles(df.queryExecution.executedPlan).isEmpty) } }