diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 5a8505dc6992f..b0fa4f889cda1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -951,10 +951,13 @@ case class KeyGroupedShuffleSpec(
   }
 
   override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
-    val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map {
-      case (c, e: TransformExpression) => TransformExpression(
-        e.function, Seq(c), e.numBucketsOpt)
-      case (c, _) => c
+    assert(clustering.size == distribution.clustering.size,
+      "Required distributions of join legs should be the same size.")
+
+    val newExpressions = partitioning.expressions.zip(keyPositions).map {
+      case (te: TransformExpression, positionSet) =>
+        te.copy(children = te.children.map(_ => clustering(positionSet.head)))
+      case (_, positionSet) => clustering(positionSet.head)
     }
     KeyGroupedPartitioning(newExpressions, partitioning.numPartitions,
       partitioning.partitionValues)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 7798397d96b38..a1b1b8444719a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -2771,4 +2771,56 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       }
     }
   }
+
+  test("SPARK-54439: KeyGroupedPartitioning and join key size mismatch") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    createTable(purchases, purchasesColumns, Array.empty)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+    withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
+      // `time` and `item_id` are in the required `ClusteredDistribution` for `purchases`,
+      // but `items` is storage-partitioned only by `id`
+      val df = createJoinTestDF(Seq("arrive_time" -> "time", "id" -> "item_id"))
+      val shuffles = collectShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.size == 1, "should only shuffle the side not reporting partitioning")
+
+      checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
+    }
+  }
+
+  test("SPARK-54439: KeyGroupedPartitioning with transform and join key size mismatch") {
+    // Do not use `bucket()` in "one side partitioned" tests, as its implementation in
+    // `InMemoryBaseTable` conflicts with `BucketFunction`
+    val items_partitions = Array(years("arrive_time"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 'bb', 10.0, cast('2021-01-01' as timestamp)), " +
+      "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")
+
+    createTable(purchases, purchasesColumns, Array.empty)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 19.5, cast('2021-02-01' as timestamp))")
+
+    withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
+      // `item_id` and `time` are in the required `ClusteredDistribution` for `purchases`,
+      // but `items` is storage-partitioned only by `years(arrive_time)`
+      val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
+      val shuffles = collectShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.size == 1, "should only shuffle the side not reporting partitioning")
+
+      checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
+    }
+  }
 }
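
For context on the `createPartitioning` change above: the old code zipped the join's clustering keys positionally against the table's partition transforms, which breaks when the required clustering has more keys than the table has partition expressions. The new code instead walks the partition expressions and uses `keyPositions` to look up which clustering key each one corresponds to. Below is a minimal, self-contained Scala sketch of that remapping idea, under simplified assumptions: `Expr`, `Attr`, and `Transform` are illustrative stand-ins, not Catalyst's `Expression` or `TransformExpression`, and `keyPositions` is passed in as a plain `Seq[Set[Int]]` rather than being a field of the shuffle spec.

object KeyRemapSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Transform(fn: String, child: Expr) extends Expr

  // partitionExprs: the table's partition transforms, e.g. Seq(Transform("years", Attr("arrive_time")))
  // keyPositions:   for each partition expression, the positions where its key
  //                 occurs in the required clustering
  // clustering:     the clustering required of the other join leg
  def createPartitioning(
      partitionExprs: Seq[Expr],
      keyPositions: Seq[Set[Int]],
      clustering: Seq[Expr]): Seq[Expr] = {
    assert(keyPositions.forall(_.nonEmpty),
      "every partition expression must map to at least one join key")
    partitionExprs.zip(keyPositions).map {
      // keep the transform, but re-point it at the matching clustering key
      case (Transform(fn, _), positions) => Transform(fn, clustering(positions.head))
      // a bare attribute is replaced by the matching clustering key directly
      case (_, positions) => clustering(positions.head)
    }
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the second test: two join keys (item_id, time) on the purchases
    // side, while items is partitioned only by years(arrive_time), whose key
    // sits at position 1 of the required clustering.
    val remapped = createPartitioning(
      partitionExprs = Seq(Transform("years", Attr("arrive_time"))),
      keyPositions = Seq(Set(1)),
      clustering = Seq(Attr("item_id"), Attr("time")))
    println(remapped) // List(Transform(years,Attr(time)))
  }
}

With the old positional zip, `clustering.zip(partitioning.expressions)` would have paired `item_id` with the `years` transform and silently dropped `time`; keying off `keyPositions` makes the result independent of the join keys' order and count.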