
test: Add tests for Scalar and Interval values for UnaryMinus #538

Merged · 10 commits · Jun 11, 2024
105 changes: 46 additions & 59 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -19,6 +19,8 @@

package org.apache.comet

import java.time.{Duration, Period}

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -1562,7 +1564,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
case (Some(sparkException), Some(cometException)) =>
assert(sparkException.getMessage.contains(dtype + " overflow"))
assert(cometException.getMessage.contains(dtype + " overflow"))
case (None, None) => assert(true) // got same outputs
case (None, None) => checkSparkAnswerAndOperator(sql(query))
Review comment (Member):
nit: we call checkSparkMaybeThrows, and if that doesn't fail, we then call checkSparkAnswerAndOperator. I think this could all be streamlined, but that is beyond the scope of this PR.
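As a rough illustration of that streamlining, the two calls could be folded into a single helper along the lines of the sketch below. This is only a sketch: it assumes the checkSparkMaybeThrows and checkSparkAnswerAndOperator helpers from CometTestBase behave as they are used in this diff (the former returning a pair of optional Spark/Comet exceptions), and the name checkOverflowOrAnswer is hypothetical, not part of the suite.

// Hypothetical helper, not in this PR: folds the exception comparison and the
// answer/operator check (currently two steps inside checkOverflow) into one call.
def checkOverflowOrAnswer(df: => DataFrame, dtype: String): Unit = {
  checkSparkMaybeThrows(df) match {
    case (Some(sparkEx), Some(cometEx)) =>
      // Both Spark and Comet overflowed; their messages should name the same type.
      assert(sparkEx.getMessage.contains(dtype + " overflow"))
      assert(cometEx.getMessage.contains(dtype + " overflow"))
    case (None, None) =>
      // Neither engine threw, so compare results and the Comet operator coverage.
      checkSparkAnswerAndOperator(df)
    case (None, Some(ex)) =>
      fail("Comet threw an exception but Spark did not: " + ex.getMessage)
    case (Some(ex), None) =>
      fail("Spark threw an exception but Comet did not: " + ex.getMessage)
  }
}

A caller would then pass the query DataFrame and expected type once, e.g. checkOverflowOrAnswer(sql(query), "integer"), instead of matching on the result tuple at each call site.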

case (None, Some(ex)) =>
fail("Comet threw an exception but Spark did not " + ex.getMessage)
case (Some(_), None) =>
@@ -1583,66 +1585,51 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {

withTempDir { dir =>
// Array values test
val arrayPath = new Path(dir.toURI.toString, "array_test.parquet").toString
Seq(Int.MaxValue, Int.MinValue).toDF("a").write.mode("overwrite").parquet(arrayPath)
val arrayQuery = "select a, -a from t"
runArrayTest(arrayQuery, "integer", arrayPath)

// long values test
val longArrayPath = new Path(dir.toURI.toString, "long_array_test.parquet").toString
Seq(Long.MaxValue, Long.MinValue)
.toDF("a")
.write
.mode("overwrite")
.parquet(longArrayPath)
val longArrayQuery = "select a, -a from t"
runArrayTest(longArrayQuery, "long", longArrayPath)

// short values test
val shortArrayPath = new Path(dir.toURI.toString, "short_array_test.parquet").toString
Seq(Short.MaxValue, Short.MinValue)
.toDF("a")
.write
.mode("overwrite")
.parquet(shortArrayPath)
val shortArrayQuery = "select a, -a from t"
runArrayTest(shortArrayQuery, "", shortArrayPath)

// byte values test
val byteArrayPath = new Path(dir.toURI.toString, "byte_array_test.parquet").toString
Seq(Byte.MaxValue, Byte.MinValue)
.toDF("a")
.write
.mode("overwrite")
.parquet(byteArrayPath)
val byteArrayQuery = "select a, -a from t"
runArrayTest(byteArrayQuery, "", byteArrayPath)

// interval values test
withTable("t_interval") {
spark.sql("CREATE TABLE t_interval(a STRING) USING PARQUET")
spark.sql("INSERT INTO t_interval VALUES ('INTERVAL 10000000000 YEAR')")
withAnsiMode(enabled = true) {
spark
.sql("SELECT CAST(a AS INTERVAL) AS a FROM t_interval")
.createOrReplaceTempView("t_interval_casted")
checkOverflow("SELECT a, -a FROM t_interval_casted", "interval")
}
}

withTable("t") {
sql("create table t(a int) using parquet")
sql("insert into t values (-2147483648)")
withAnsiMode(enabled = true) {
checkOverflow("select a, -a from t", "integer")
}
val dataTypes = Seq(
("array_test.parquet", Seq(Int.MaxValue, Int.MinValue).toDF("a"), "integer"),
("long_array_test.parquet", Seq(Long.MaxValue, Long.MinValue).toDF("a"), "long"),
("short_array_test.parquet", Seq(Short.MaxValue, Short.MinValue).toDF("a"), ""),
("byte_array_test.parquet", Seq(Byte.MaxValue, Byte.MinValue).toDF("a"), ""))

dataTypes.foreach { case (fileName, df, dtype) =>
val path = new Path(dir.toURI.toString, fileName).toString
df.write.mode("overwrite").parquet(path)
val query = "select a, -a from t"
runArrayTest(query, dtype, path)
}

withTable("t_float") {
sql("create table t_float(a float) using parquet")
sql("insert into t_float values (3.4128235E38)")
withAnsiMode(enabled = true) {
checkOverflow("select a, -a from t_float", "float")
// scalar tests
withParquetTable((0 until 5).map(i => (i % 5, i % 3)), "tbl") {
withSQLConf(
"spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding",
SQLConf.ANSI_ENABLED.key -> "true",
CometConf.COMET_ANSI_MODE_ENABLED.key -> "true",
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {
for (n <- Seq("2147483647", "-2147483648")) {
checkOverflow(s"select -(cast(${n} as int)) FROM tbl", "integer")
}
for (n <- Seq("32767", "-32768")) {
checkOverflow(s"select -(cast(${n} as short)) FROM tbl", "")
}
for (n <- Seq("127", "-128")) {
checkOverflow(s"select -(cast(${n} as byte)) FROM tbl", "")
}
for (n <- Seq("9223372036854775807", "-9223372036854775808")) {
checkOverflow(s"select -(cast(${n} as long)) FROM tbl", "long")
}
for (n <- Seq("3.4028235E38", "-3.4028235E38")) {
checkOverflow(s"select -(cast(${n} as float)) FROM tbl", "float")
}
// interval test without cast
val longDf = Seq(Long.MaxValue, Long.MaxValue, 2)
val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
.map(Period.ofMonths)
val dayTimeDf = Seq(106751991L, 106751991L, 2L)
.map(Duration.ofDays)
Seq(longDf, yearMonthDf, dayTimeDf).foreach { df =>
checkOverflow(s"select -(_1) FROM tbl", "")
}
}
}
}