Skip to content

Commit 571b802

Browse files
mihailoale-dbdongjoon-hyun
authored andcommitted
[SPARK-54075][SQL] Make ResolvedCollation evaluable
### What changes were proposed in this pull request? In this PR I propose to make `ResolvedCollation` evaluable. By making `ResolvedCollation` evaluable, it can now pass the `canEvaluateWithinJoin` check, allowing the optimizer to use efficient hash joins for queries with inline COLLATE in join conditions (before this change we would fallback to `BroadcastNestedLoopJoin`). ### Why are the changes needed? To improve performance of specific queries (see added tests). ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #52779 from mihailoale-db/makecollationresolvable. Authored-by: mihailoale-db <mihailo.aleksic@databricks.com> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent 576f9a5 commit 571b802

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,23 @@ case class UnresolvedCollation(collationName: Seq[String])
125125
/**
126126
* An expression that represents a resolved collation name.
127127
*/
128-
case class ResolvedCollation(collationName: String) extends LeafExpression with Unevaluable {
128+
case class ResolvedCollation(collationName: String) extends LeafExpression {
129129
override def nullable: Boolean = false
130130

131131
override def dataType: DataType = StringType(CollationFactory.collationNameToId(collationName))
132132

133133
override def toString: String = collationName
134134

135135
override def sql: String = collationName
136+
137+
override def eval(input: InternalRow): Any = Literal.create(collationName, dataType).eval(input)
138+
139+
/** Just a simple passthrough for code generation. */
140+
override def genCode(ctx: CodegenContext): ExprCode =
141+
Literal.create(collationName, dataType).genCode(ctx)
142+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
143+
throw SparkException.internalError("ResolvedCollation.doGenCode should not be called.")
144+
}
136145
}
137146

138147
// scalastyle:off line.contains.tab

sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,6 +1865,34 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
18651865
}
18661866
}
18671867

1868+
test("inline COLLATE expressions in join conditions should not use nested loop join") {
1869+
withTable("table1", "table2", "table3") {
1870+
sql("CREATE TABLE table1 (id STRING, col1 STRING) USING PARQUET")
1871+
sql("INSERT INTO table1 VALUES ('1', 'a'), ('2', 'b')")
1872+
1873+
sql("CREATE TABLE table2 (id STRING, col1 STRING) USING PARQUET")
1874+
sql("INSERT INTO table2 VALUES ('1', 'a'), ('2', 'b')")
1875+
1876+
sql("CREATE TABLE table3 (col1 STRING COLLATE UTF8_LCASE_RTRIM) USING PARQUET")
1877+
sql("INSERT INTO table3 VALUES ('a'), ('b')")
1878+
1879+
val df = sql(
1880+
"""SELECT t1.col1 COLLATE UTF8_LCASE_RTRIM AS result
1881+
|FROM table1 t1
1882+
|INNER JOIN table2 t2 ON t2.id = t1.id
1883+
|INNER JOIN table3 t3 ON t3.col1 = t1.col1 COLLATE UTF8_LCASE_RTRIM
1884+
|""".stripMargin
1885+
)
1886+
1887+
checkAnswer(df, Seq(Row("a"), Row("b")))
1888+
1889+
val queryPlan = df.queryExecution.executedPlan
1890+
assert(collectFirst(queryPlan) {
1891+
case _: BroadcastNestedLoopJoinExec => ()
1892+
}.isEmpty)
1893+
}
1894+
}
1895+
18681896
test("hll sketch aggregate should respect collation") {
18691897
case class HllSketchAggTestCase[R](c: String, result: R)
18701898
val testCases = Seq(

0 commit comments

Comments
 (0)