Skip to content

Commit 00922cf

Browse files
authored
fix: Fallback to Spark for lpad/rpad for unsupported arguments & fix negative length handling (#2630)
1 parent 62a68ac commit 00922cf

File tree

6 files changed

+174
-87
lines changed

6 files changed

+174
-87
lines changed

native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,14 +204,21 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
204204
);
205205

206206
for (string, length) in string_array.iter().zip(int_pad_array) {
207+
let length = length.unwrap();
207208
match string {
208-
Some(string) => builder.append_value(add_padding_string(
209-
string.parse().unwrap(),
210-
length.unwrap() as usize,
211-
truncate,
212-
pad_string,
213-
is_left_pad,
214-
)?),
209+
Some(string) => {
210+
if length >= 0 {
211+
builder.append_value(add_padding_string(
212+
string.parse().unwrap(),
213+
length as usize,
214+
truncate,
215+
pad_string,
216+
is_left_pad,
217+
)?)
218+
} else {
219+
builder.append_value("");
220+
}
221+
}
215222
_ => builder.append_null(),
216223
}
217224
}

spark/src/main/scala/org/apache/comet/serde/strings.scala

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,16 @@ object CometRLike extends CometExpressionSerde[RLike] {
162162

163163
object CometStringRPad extends CometExpressionSerde[StringRPad] {
164164

165+
override def getSupportLevel(expr: StringRPad): SupportLevel = {
166+
if (expr.str.isInstanceOf[Literal]) {
167+
return Unsupported(Some("Scalar values are not supported for the str argument"))
168+
}
169+
if (!expr.pad.isInstanceOf[Literal]) {
170+
return Unsupported(Some("Only scalar values are supported for the pad argument"))
171+
}
172+
Compatible()
173+
}
174+
165175
override def convert(
166176
expr: StringRPad,
167177
inputs: Seq[Attribute],
@@ -177,21 +187,16 @@ object CometStringRPad extends CometExpressionSerde[StringRPad] {
177187

178188
object CometStringLPad extends CometExpressionSerde[StringLPad] {
179189

180-
/**
181-
* Convert a Spark expression into a protocol buffer representation that can be passed into
182-
* native code.
183-
*
184-
* @param expr
185-
* The Spark expression.
186-
* @param inputs
187-
* The input attributes.
188-
* @param binding
189-
* Whether the attributes are bound (this is only relevant in aggregate expressions).
190-
* @return
191-
* Protocol buffer representation, or None if the expression could not be converted. In this
192-
* case it is expected that the input expression will have been tagged with reasons why it
193-
* could not be converted.
194-
*/
190+
override def getSupportLevel(expr: StringLPad): SupportLevel = {
191+
if (expr.str.isInstanceOf[Literal]) {
192+
return Unsupported(Some("Scalar values are not supported for the str argument"))
193+
}
194+
if (!expr.pad.isInstanceOf[Literal]) {
195+
return Unsupported(Some("Only scalar values are supported for the pad argument"))
196+
}
197+
Compatible()
198+
}
199+
195200
override def convert(
196201
expr: StringLPad,
197202
inputs: Seq[Attribute],

spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,13 @@ object FuzzDataGenerator {
194194
case 1 => r.nextInt().toByte.toString
195195
case 2 => r.nextLong().toString
196196
case 3 => r.nextDouble().toString
197-
case 4 => RandomStringUtils.randomAlphabetic(8)
197+
case 4 => RandomStringUtils.randomAlphabetic(options.maxStringLength)
198198
case 5 =>
199199
// use a constant value to trigger dictionary encoding
200200
"dict_encode_me!"
201-
case _ => r.nextString(8)
201+
case 6 if options.customStrings.nonEmpty =>
202+
randomChoice(options.customStrings, r)
203+
case _ => r.nextString(options.maxStringLength)
202204
}
203205
})
204206
case DataTypes.BinaryType =>
@@ -221,6 +223,11 @@ object FuzzDataGenerator {
221223
case _ => throw new IllegalStateException(s"Cannot generate data for $dataType yet")
222224
}
223225
}
226+
227+
private def randomChoice[T](list: Seq[T], r: Random): T = {
228+
list(r.nextInt(list.length))
229+
}
230+
224231
}
225232

226233
object SchemaGenOptions {
@@ -250,4 +257,6 @@ case class SchemaGenOptions(
250257
case class DataGenOptions(
251258
allowNull: Boolean = true,
252259
generateNegativeZero: Boolean = true,
253-
baseDate: Long = FuzzDataGenerator.defaultBaseDate)
260+
baseDate: Long = FuzzDataGenerator.defaultBaseDate,
261+
customStrings: Seq[String] = Seq.empty,
262+
maxStringLength: Int = 8)

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -414,41 +414,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
414414
}
415415
}
416416
}
417-
test("Verify rpad expr support for second arg instead of just literal") {
418-
val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
419-
withParquetTable(data, "t1") {
420-
val res = sql("select rpad(_1,_2) , rpad(_1,2) from t1 order by _1")
421-
checkSparkAnswerAndOperator(res)
422-
}
423-
}
424-
425-
test("RPAD with character support other than default space") {
426-
val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
427-
withParquetTable(data, "t1") {
428-
val res = sql(
429-
""" select rpad(_1,_2,'?'), rpad(_1,_2,'??') , rpad(_1,2, '??'), hex(rpad(unhex('aabb'), 5)),
430-
rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
431-
checkSparkAnswerAndOperator(res)
432-
}
433-
}
434-
435-
test("test lpad expression support") {
436-
val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
437-
withParquetTable(data, "t1") {
438-
val res = sql("select lpad(_1,_2) , lpad(_1,2) from t1 order by _1")
439-
checkSparkAnswerAndOperator(res)
440-
}
441-
}
442-
443-
test("LPAD with character support other than default space") {
444-
val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
445-
withParquetTable(data, "t1") {
446-
val res = sql(
447-
""" select lpad(_1,_2,'?'), lpad(_1,_2,'??') , lpad(_1,2, '??'), hex(lpad(unhex('aabb'), 5)),
448-
rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
449-
checkSparkAnswerAndOperator(res)
450-
}
451-
}
452417

453418
test("dictionary arithmetic") {
454419
// TODO: test ANSI mode
@@ -2292,33 +2257,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
22922257
}
22932258
}
22942259

2295-
test("rpad") {
2296-
val table = "rpad"
2297-
val gen = new DataGenerator(new Random(42))
2298-
withTable(table) {
2299-
// generate some data
2300-
val dataChars = "abc123"
2301-
sql(s"create table $table(id int, name1 char(8), name2 varchar(8)) using parquet")
2302-
val testData = gen.generateStrings(100, dataChars, 6) ++ Seq(
2303-
"", // unicode 'e\\u{301}'
2304-
"é" // unicode '\\u{e9}'
2305-
)
2306-
testData.zipWithIndex.foreach { x =>
2307-
sql(s"insert into $table values(${x._2}, '${x._1}', '${x._1}')")
2308-
}
2309-
// test 2-arg version
2310-
checkSparkAnswerAndOperator(
2311-
s"SELECT id, rpad(name1, 10), rpad(name2, 10) FROM $table ORDER BY id")
2312-
// test 3-arg version
2313-
for (length <- Seq(2, 10)) {
2314-
checkSparkAnswerAndOperator(
2315-
s"SELECT id, name1, rpad(name1, $length, ' ') FROM $table ORDER BY id")
2316-
checkSparkAnswerAndOperator(
2317-
s"SELECT id, name2, rpad(name2, $length, ' ') FROM $table ORDER BY id")
2318-
}
2319-
}
2320-
}
2321-
23222260
test("isnan") {
23232261
Seq("true", "false").foreach { dictionary =>
23242262
withSQLConf("parquet.enable.dictionary" -> dictionary) {

spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,133 @@
1919

2020
package org.apache.comet
2121

22+
import scala.util.Random
23+
2224
import org.apache.parquet.hadoop.ParquetOutputFormat
2325
import org.apache.spark.sql.{CometTestBase, DataFrame}
2426
import org.apache.spark.sql.internal.SQLConf
27+
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
28+
29+
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
2530

2631
class CometStringExpressionSuite extends CometTestBase {
2732

33+
test("lpad string") {
34+
testStringPadding("lpad")
35+
}
36+
37+
test("rpad string") {
38+
testStringPadding("rpad")
39+
}
40+
41+
test("lpad binary") {
42+
testBinaryPadding("lpad")
43+
}
44+
45+
test("rpad binary") {
46+
testBinaryPadding("rpad")
47+
}
48+
49+
private def testStringPadding(expr: String): Unit = {
50+
val r = new Random(42)
51+
val schema = StructType(
52+
Seq(
53+
StructField("str", DataTypes.StringType, nullable = true),
54+
StructField("len", DataTypes.IntegerType, nullable = true),
55+
StructField("pad", DataTypes.StringType, nullable = true)))
56+
// scalastyle:off
57+
val edgeCases = Seq(
58+
"", // unicode 'e\\u{301}'
59+
"é", // unicode '\\u{e9}'
60+
"తెలుగు")
61+
// scalastyle:on
62+
val df = FuzzDataGenerator.generateDataFrame(
63+
r,
64+
spark,
65+
schema,
66+
1000,
67+
DataGenOptions(maxStringLength = 6, customStrings = edgeCases))
68+
df.createOrReplaceTempView("t1")
69+
70+
// test all combinations of scalar and array arguments
71+
for (str <- Seq("'hello'", "str")) {
72+
for (len <- Seq("6", "-6", "0", "len % 10")) {
73+
for (pad <- Seq(Some("'x'"), Some("'zzz'"), Some("pad"), None)) {
74+
val sql = pad match {
75+
case Some(p) =>
76+
// 3 args
77+
s"SELECT $str, $len, $expr($str, $len, $p) FROM t1 ORDER BY str, len, pad"
78+
case _ =>
79+
// 2 args (default pad of ' ')
80+
s"SELECT $str, $len, $expr($str, $len) FROM t1 ORDER BY str, len, pad"
81+
}
82+
val isLiteralStr = str == "'hello'"
83+
val isLiteralLen = !len.contains("len")
84+
val isLiteralPad = !pad.contains("pad")
85+
if (isLiteralStr && isLiteralLen && isLiteralPad) {
86+
// all arguments are literal, so Spark constant folding will kick in
87+
// and pad function will not be evaluated by Comet
88+
checkSparkAnswer(sql)
89+
} else if (isLiteralStr) {
90+
checkSparkAnswerAndFallbackReason(
91+
sql,
92+
"Scalar values are not supported for the str argument")
93+
} else if (!isLiteralPad) {
94+
checkSparkAnswerAndFallbackReason(
95+
sql,
96+
"Only scalar values are supported for the pad argument")
97+
} else {
98+
checkSparkAnswerAndOperator(sql)
99+
}
100+
}
101+
}
102+
}
103+
}
104+
105+
private def testBinaryPadding(expr: String): Unit = {
106+
val r = new Random(42)
107+
val schema = StructType(
108+
Seq(
109+
StructField("str", DataTypes.BinaryType, nullable = true),
110+
StructField("len", DataTypes.IntegerType, nullable = true),
111+
StructField("pad", DataTypes.BinaryType, nullable = true)))
112+
val df = FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, DataGenOptions())
113+
df.createOrReplaceTempView("t1")
114+
115+
// test all combinations of scalar and array arguments
116+
for (str <- Seq("unhex('DDEEFF')", "str")) {
117+
// Spark does not support negative length for lpad/rpad with binary input and Comet does
118+
// not support abs yet, so use `10 + len % 10` to avoid negative length
119+
for (len <- Seq("6", "0", "10 + len % 10")) {
120+
for (pad <- Seq(Some("unhex('CAFE')"), Some("pad"), None)) {
121+
122+
val sql = pad match {
123+
case Some(p) =>
124+
// 3 args
125+
s"SELECT $str, $len, $expr($str, $len, $p) FROM t1 ORDER BY str, len, pad"
126+
case _ =>
127+
// 2 args (default pad of ' ')
128+
s"SELECT $str, $len, $expr($str, $len) FROM t1 ORDER BY str, len, pad"
129+
}
130+
131+
val isLiteralStr = str != "str"
132+
val isLiteralLen = !len.contains("len")
133+
val isLiteralPad = !pad.contains("pad")
134+
135+
if (isLiteralStr && isLiteralLen && isLiteralPad) {
136+
// all arguments are literal, so Spark constant folding will kick in
137+
// and pad function will not be evaluated by Comet
138+
checkSparkAnswer(sql)
139+
} else {
140+
// Comet will fall back to Spark because the plan contains a staticinvoke instruction
141+
// which is not supported
142+
checkSparkAnswerAndFallbackReason(sql, "staticinvoke is not supported")
143+
}
144+
}
145+
}
146+
}
147+
}
148+
28149
test("Various String scalar functions") {
29150
val table = "names"
30151
withTable(table) {

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,13 @@ abstract class CometTestBase
166166
(sparkPlan, dfComet.queryExecution.executedPlan)
167167
}
168168

169+
/** Check for the correct results as well as the expected fallback reason */
170+
def checkSparkAnswerAndFallbackReason(sql: String, fallbackReason: String): Unit = {
171+
val (_, cometPlan) = checkSparkAnswer(sql)
172+
val explain = new ExtendedExplainInfo().generateVerboseExtendedInfo(cometPlan)
173+
assert(explain.contains(fallbackReason))
174+
}
175+
169176
protected def checkSparkAnswerAndOperator(query: String, excludedClasses: Class[_]*): Unit = {
170177
checkSparkAnswerAndOperator(sql(query), excludedClasses: _*)
171178
}

0 commit comments

Comments
 (0)