test: Add tests for Scalar and Interval values for UnaryMinus (#538)
* adding scalar tests

* refactor and test for interval

* ci checks fixed

* running operator checks when no error

* removing redundant sqlconf

* fix ci errors

* moving interval test to array section

* ci fixes
vaibhawvipul authored Jun 11, 2024
1 parent e07f24c commit a4e268c
Showing 1 changed file with 47 additions and 55 deletions.
102 changes: 47 additions & 55 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -19,6 +19,8 @@

package org.apache.comet

import java.time.{Duration, Period}

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

@@ -1635,7 +1637,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
case (Some(sparkException), Some(cometException)) =>
assert(sparkException.getMessage.contains(dtype + " overflow"))
assert(cometException.getMessage.contains(dtype + " overflow"))
case (None, None) => assert(true) // got same outputs
case (None, None) => checkSparkAnswerAndOperator(sql(query))
case (None, Some(ex)) =>
fail("Comet threw an exception but Spark did not " + ex.getMessage)
case (Some(_), None) =>
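
The functional change in this hunk is the (None, None) arm: where the old code passed vacuously with assert(true), the suite now re-runs the query and verifies both that Spark and Comet return the same rows and that the plan actually executed through Comet operators. For orientation, a sketch of how the whole helper presumably reads after this commit, assuming the checkSparkMaybeThrows and checkSparkAnswerAndOperator helpers from CometTestBase; the final fail arm is truncated above and filled in here:

def checkOverflow(query: String, dtype: String): Unit = {
  // checkSparkMaybeThrows (assumed) runs the query through Spark and Comet
  // and returns the exception thrown by each engine, if any, as a pair.
  checkSparkMaybeThrows(sql(query)) match {
    case (Some(sparkException), Some(cometException)) =>
      // Both engines overflowed; both messages must name the same type.
      assert(sparkException.getMessage.contains(dtype + " overflow"))
      assert(cometException.getMessage.contains(dtype + " overflow"))
    case (None, None) =>
      // Neither overflowed: compare answers and require Comet operators.
      checkSparkAnswerAndOperator(sql(query))
    case (None, Some(ex)) =>
      fail("Comet threw an exception but Spark did not " + ex.getMessage)
    case (Some(_), None) =>
      fail("Spark threw an exception but Comet did not")
  }
}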
@@ -1656,66 +1658,56 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {

withTempDir { dir =>
// Array values test
val arrayPath = new Path(dir.toURI.toString, "array_test.parquet").toString
Seq(Int.MaxValue, Int.MinValue).toDF("a").write.mode("overwrite").parquet(arrayPath)
val arrayQuery = "select a, -a from t"
runArrayTest(arrayQuery, "integer", arrayPath)

// long values test
val longArrayPath = new Path(dir.toURI.toString, "long_array_test.parquet").toString
Seq(Long.MaxValue, Long.MinValue)
.toDF("a")
.write
.mode("overwrite")
.parquet(longArrayPath)
val longArrayQuery = "select a, -a from t"
runArrayTest(longArrayQuery, "long", longArrayPath)

// short values test
val shortArrayPath = new Path(dir.toURI.toString, "short_array_test.parquet").toString
Seq(Short.MaxValue, Short.MinValue)
.toDF("a")
.write
.mode("overwrite")
.parquet(shortArrayPath)
val shortArrayQuery = "select a, -a from t"
runArrayTest(shortArrayQuery, "", shortArrayPath)

// byte values test
val byteArrayPath = new Path(dir.toURI.toString, "byte_array_test.parquet").toString
Seq(Byte.MaxValue, Byte.MinValue)
.toDF("a")
.write
.mode("overwrite")
.parquet(byteArrayPath)
val byteArrayQuery = "select a, -a from t"
runArrayTest(byteArrayQuery, "", byteArrayPath)

// interval values test
withTable("t_interval") {
spark.sql("CREATE TABLE t_interval(a STRING) USING PARQUET")
spark.sql("INSERT INTO t_interval VALUES ('INTERVAL 10000000000 YEAR')")
withAnsiMode(enabled = true) {
spark
.sql("SELECT CAST(a AS INTERVAL) AS a FROM t_interval")
.createOrReplaceTempView("t_interval_casted")
checkOverflow("SELECT a, -a FROM t_interval_casted", "interval")
}
val dataTypes = Seq(
("array_test.parquet", Seq(Int.MaxValue, Int.MinValue).toDF("a"), "integer"),
("long_array_test.parquet", Seq(Long.MaxValue, Long.MinValue).toDF("a"), "long"),
("short_array_test.parquet", Seq(Short.MaxValue, Short.MinValue).toDF("a"), ""),
("byte_array_test.parquet", Seq(Byte.MaxValue, Byte.MinValue).toDF("a"), ""))

dataTypes.foreach { case (fileName, df, dtype) =>
val path = new Path(dir.toURI.toString, fileName).toString
df.write.mode("overwrite").parquet(path)
val query = "select a, -a from t"
runArrayTest(query, dtype, path)
}
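
Every entry in the loop above is funneled through runArrayTest, whose definition sits above this hunk and is not shown. A hypothetical sketch of what it presumably does, given that each query reads from a view named t; the body is an assumption built from standard Spark APIs and the suite's own helpers:

def runArrayTest(query: String, dtype: String, path: String): Unit = {
  // Register the freshly written parquet file under the name the query uses.
  spark.read.parquet(path).createOrReplaceTempView("t")
  // With ANSI off, negating MinValue wraps around, so only compare answers.
  withAnsiMode(enabled = false) {
    checkSparkAnswerAndOperator(sql(query))
  }
  // With ANSI on, negating MinValue must raise matching overflow errors.
  withAnsiMode(enabled = true) {
    checkOverflow(query, dtype)
  }
}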

withTable("t") {
sql("create table t(a int) using parquet")
sql("insert into t values (-2147483648)")
withParquetTable((0 until 5).map(i => (i % 5, i % 3)), "tbl") {
withAnsiMode(enabled = true) {
checkOverflow("select a, -a from t", "integer")
// interval test without cast
val longDf = Seq(Long.MaxValue, Long.MaxValue, 2)
val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
.map(Period.ofMonths)
val dayTimeDf = Seq(106751991L, 106751991L, 2L)
.map(Duration.ofDays)
Seq(longDf, yearMonthDf, dayTimeDf).foreach { _ =>
checkOverflow("select -(_1) FROM tbl", "")
}
}
}
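
The literals here sit exactly at the ceilings of Spark's interval types: Period.ofMonths(Int.MaxValue) saturates a year-month interval, and 106751991 days is the widest whole-day value a day-time interval can hold, since Spark backs day-time intervals with a signed 64-bit microsecond count. (As written, the closure ignores the three sequences it builds and negates _1 from tbl on each pass.) The day bound is plain arithmetic, sketched below; nothing here is Comet API:

// Largest whole number of days representable in Long microseconds.
val microsPerDay = 24L * 60 * 60 * 1000 * 1000 // 86,400,000,000
val maxDays = Long.MaxValue / microsPerDay     // 9223372036854775807 / 86400000000
assert(maxDays == 106751991L)                  // matches the literal in the test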

withTable("t_float") {
sql("create table t_float(a float) using parquet")
sql("insert into t_float values (3.4128235E38)")
withAnsiMode(enabled = true) {
checkOverflow("select a, -a from t_float", "float")
// scalar tests
withParquetTable((0 until 5).map(i => (i % 5, i % 3)), "tbl") {
withSQLConf(
"spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding",
SQLConf.ANSI_ENABLED.key -> "true",
CometConf.COMET_ANSI_MODE_ENABLED.key -> "true",
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {
for (n <- Seq("2147483647", "-2147483648")) {
checkOverflow(s"select -(cast(${n} as int)) FROM tbl", "integer")
}
for (n <- Seq("32767", "-32768")) {
checkOverflow(s"select -(cast(${n} as short)) FROM tbl", "")
}
for (n <- Seq("127", "-128")) {
checkOverflow(s"select -(cast(${n} as byte)) FROM tbl", "")
}
for (n <- Seq("9223372036854775807", "-9223372036854775808")) {
checkOverflow(s"select -(cast(${n} as long)) FROM tbl", "long")
}
for (n <- Seq("3.4028235E38", "-3.4028235E38")) {
checkOverflow(s"select -(cast(${n} as float)) FROM tbl", "float")
}
}
}
}
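
Excluding ConstantFolding is what keeps these scalar cases meaningful: without it, Catalyst would evaluate -(cast(2147483647 as int)) at optimization time and the negation would never reach a Comet operator. A minimal standalone illustration of the effect, assuming only a running SparkSession named spark; this snippet is not part of the commit:

// Keep Catalyst from collapsing the cast-and-negate into a literal.
spark.conf.set(
  "spark.sql.optimizer.excludedRules",
  "org.apache.spark.sql.catalyst.optimizer.ConstantFolding")
// The negation now survives into the physical plan, so it is evaluated at
// runtime by an operator (Comet's, when Comet execution is enabled).
spark.sql("select -(cast(2147483647 as int))").explain()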
