diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index e51b35c0accc..5e1c3f46fd11 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2012,8 +2012,20 @@
   },
   "INTERVAL_ARITHMETIC_OVERFLOW" : {
     "message" : [
-      "<message>.<alternative>"
+      "Integer overflow while operating with intervals."
     ],
+    "subClass" : {
+      "WITHOUT_SUGGESTION" : {
+        "message" : [
+          "Try devising appropriate values for the interval parameters."
+        ]
+      },
+      "WITH_SUGGESTION" : {
+        "message" : [
+          "Use <functionName> to tolerate overflow and return NULL instead."
+        ]
+      }
+    },
     "sqlState" : "22015"
   },
   "INTERVAL_DIVIDED_BY_ZERO" : {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index 1ce7dfd39acc..a7b67f55d8cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -481,7 +481,7 @@ case class MakeDTInterval(
     hours: Expression,
     mins: Expression,
     secs: Expression)
-  extends QuaternaryExpression with ImplicitCastInputTypes {
+  extends QuaternaryExpression with ImplicitCastInputTypes with SupportQueryContext {
   override def nullIntolerant: Boolean = true
 
   def this(
@@ -514,13 +514,15 @@ case class MakeDTInterval(
       day.asInstanceOf[Int],
       hour.asInstanceOf[Int],
       min.asInstanceOf[Int],
-      sec.asInstanceOf[Decimal])
+      sec.asInstanceOf[Decimal],
+      origin.context)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (day, hour, min, sec) => {
+      val errorContext = getContextOrNullCode(ctx)
       val iu = IntervalUtils.getClass.getName.stripSuffix("$")
-      s"$iu.makeDayTimeInterval($day, $hour, $min, $sec)"
+      s"$iu.makeDayTimeInterval($day, $hour, $min, $sec, $errorContext)"
     })
   }
 
@@ -532,6 +534,8 @@ case class MakeDTInterval(
       mins: Expression,
       secs: Expression): MakeDTInterval =
     copy(days, hours, mins, secs)
+
+  override def initQueryContext(): Option[QueryContext] = Some(origin.context)
 }
 
 @ExpressionDescription(
@@ -556,7 +560,7 @@ case class MakeDTInterval(
   group = "datetime_funcs")
 // scalastyle:on line.size.limit
 case class MakeYMInterval(years: Expression, months: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes with Serializable {
+  extends BinaryExpression with ImplicitCastInputTypes with Serializable with SupportQueryContext {
   override def nullIntolerant: Boolean = true
 
   def this(years: Expression) = this(years, Literal(0))
@@ -568,17 +572,28 @@ case class MakeYMInterval(years: Expression, months: Expression)
   override def dataType: DataType = YearMonthIntervalType()
 
   override def nullSafeEval(year: Any, month: Any): Any = {
-    Math.toIntExact(Math.addExact(month.asInstanceOf[Number].longValue(),
-      Math.multiplyExact(year.asInstanceOf[Number].longValue(), MONTHS_PER_YEAR)))
+    try {
+      Math.toIntExact(
+        Math.addExact(month.asInstanceOf[Int],
+          Math.multiplyExact(year.asInstanceOf[Int], MONTHS_PER_YEAR)))
+    } catch {
+      case _: ArithmeticException =>
+        throw QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError(origin.context)
+    }
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, (years, months) => {
+    nullSafeCodeGen(ctx, ev, (years, months) => {
       val math = classOf[Math].getName.stripSuffix("$")
+      val errorContext = getContextOrNullCode(ctx)
+      // scalastyle:off line.size.limit
       s"""
-         |$math.toIntExact(java.lang.Math.addExact($months,
-         |  $math.multiplyExact($years, $MONTHS_PER_YEAR)))
-         |""".stripMargin
+         |try {
+         |  ${ev.value} = $math.toIntExact($math.addExact($months, $math.multiplyExact($years, $MONTHS_PER_YEAR)));
+         |} catch (java.lang.ArithmeticException e) {
+         |  throw QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError($errorContext);
+         |}""".stripMargin
+      // scalastyle:on line.size.limit
     })
   }
 
@@ -587,6 +602,10 @@ case class MakeYMInterval(years: Expression, months: Expression)
   override protected def withNewChildrenInternal(
       newLeft: Expression, newRight: Expression): Expression =
     copy(years = newLeft, months = newRight)
+
+  override def initQueryContext(): Option[QueryContext] = {
+    Some(origin.context)
+  }
 }
 
 // Multiply an year-month interval by a numeric
@@ -699,8 +718,8 @@ trait IntervalDivide {
       context: QueryContext): Unit = {
     if (value == minValue && num.dataType.isInstanceOf[IntegralType]) {
       if (numValue.asInstanceOf[Number].longValue() == -1) {
-        throw QueryExecutionErrors.intervalArithmeticOverflowError(
-          "Interval value overflows after being divided by -1", "try_divide", context)
+        throw QueryExecutionErrors.withSuggestionIntervalArithmeticOverflowError(
+          "try_divide", context)
       }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalMathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalMathUtils.scala
index c935c6057376..756f2598f13f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalMathUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalMathUtils.scala
@@ -35,12 +35,15 @@
 
   def negateExact(a: Long): Long = withOverflow(Math.negateExact(a))
 
-  private def withOverflow[A](f: => A, hint: String = ""): A = {
+  private def withOverflow[A](f: => A, suggestedFunc: String = ""): A = {
     try {
       f
     } catch {
-      case e: ArithmeticException =>
-        throw QueryExecutionErrors.intervalArithmeticOverflowError(e.getMessage, hint, null)
+      case _: ArithmeticException if suggestedFunc.isEmpty =>
+        throw QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError(context = null)
+      case _: ArithmeticException =>
+        throw QueryExecutionErrors.withSuggestionIntervalArithmeticOverflowError(
+          suggestedFunc, context = null)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 90c802b7e28d..39a07990dea3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit
 
 import scala.util.control.NonFatal
 
-import org.apache.spark.{SparkIllegalArgumentException, SparkThrowable}
+import org.apache.spark.{QueryContext, SparkIllegalArgumentException, SparkThrowable}
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -782,13 +782,19 @@
       days: Int,
       hours: Int,
       mins: Int,
-      secs: Decimal): Long = {
+      secs: Decimal,
+      context: QueryContext): Long = {
     assert(secs.scale == 6, "Seconds fractional must have 6 digits for microseconds")
     var micros = secs.toUnscaledLong
-    micros = Math.addExact(micros, Math.multiplyExact(days, MICROS_PER_DAY))
-    micros = Math.addExact(micros, Math.multiplyExact(hours, MICROS_PER_HOUR))
-    micros = Math.addExact(micros, Math.multiplyExact(mins, MICROS_PER_MINUTE))
-    micros
+    try {
+      micros = Math.addExact(micros, Math.multiplyExact(days, MICROS_PER_DAY))
+      micros = Math.addExact(micros, Math.multiplyExact(hours, MICROS_PER_HOUR))
+      micros = Math.addExact(micros, Math.multiplyExact(mins, MICROS_PER_MINUTE))
+      micros
+    } catch {
+      case _: ArithmeticException =>
+        throw QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError(context)
+    }
   }
 
   def intToYearMonthInterval(v: Int, startField: Byte, endField: Byte): Int = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 09836995925e..fb39d3c5d7c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -636,18 +636,21 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       summary = "")
   }
 
-  def intervalArithmeticOverflowError(
-      message: String,
-      hint: String = "",
+  def withSuggestionIntervalArithmeticOverflowError(
+      suggestedFunc: String,
       context: QueryContext): ArithmeticException = {
-    val alternative = if (hint.nonEmpty) {
-      s" Use '$hint' to tolerate overflow and return NULL instead."
-    } else ""
     new SparkArithmeticException(
-      errorClass = "INTERVAL_ARITHMETIC_OVERFLOW",
-      messageParameters = Map(
-        "message" -> message,
-        "alternative" -> alternative),
+      errorClass = "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
+      messageParameters = Map("functionName" -> toSQLId(suggestedFunc)),
+      context = getQueryContext(context),
+      summary = getSummary(context))
+  }
+
+  def withoutSuggestionIntervalArithmeticOverflowError(
+      context: QueryContext): SparkArithmeticException = {
+    new SparkArithmeticException(
+      errorClass = "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+      messageParameters = Map(),
       context = getQueryContext(context),
       summary = getSummary(context))
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
index 78bc77b9dc2a..8fb72ad53062 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
@@ -316,7 +316,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       val secFrac = DateTimeTestUtils.secFrac(seconds, millis, micros)
       val durationExpr = MakeDTInterval(Literal(days), Literal(hours),
         Literal(minutes), Literal(Decimal(secFrac, Decimal.MAX_LONG_DIGITS, 6)))
-      checkExceptionInExpression[ArithmeticException](durationExpr, EmptyRow, "")
+      checkExceptionInExpression[ArithmeticException](
+        durationExpr, "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION")
     }
 
     check(millis = -123)
@@ -528,7 +529,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     Seq(MakeYMInterval(Literal(178956970), Literal(8)),
       MakeYMInterval(Literal(-178956970), Literal(-9)))
       .foreach { ym =>
-        checkExceptionInExpression[ArithmeticException](ym, "integer overflow")
+        checkExceptionInExpression[ArithmeticException](
+          ym, "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION")
       }
 
     def checkImplicitEvaluation(expr: Expression, value: Any): Unit = {
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
index 766bfba7696f..4e012df792de 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
@@ -936,8 +936,18 @@ select make_dt_interval(2147483647)
 -- !query schema
 struct<>
 -- !query output
-java.lang.ArithmeticException
-long overflow
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+  "sqlState" : "22015",
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 35,
+    "fragment" : "make_dt_interval(2147483647)"
+  } ]
+}
 
 
 -- !query
@@ -977,8 +987,18 @@ select make_ym_interval(178956970, 8)
 -- !query schema
 struct<>
 -- !query output
-java.lang.ArithmeticException
-integer overflow
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+  "sqlState" : "22015",
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 37,
+    "fragment" : "make_ym_interval(178956970, 8)"
+  } ]
+}
 
 
 -- !query
@@ -994,8 +1014,18 @@ select make_ym_interval(-178956970, -9)
 -- !query schema
 struct<>
 -- !query output
-java.lang.ArithmeticException
-integer overflow
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+  "sqlState" : "22015",
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 39,
+    "fragment" : "make_ym_interval(-178956970, -9)"
+  } ]
+}
 
 
 -- !query
@@ -2493,12 +2523,8 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
-  "sqlState" : "22015",
-  "messageParameters" : {
-    "alternative" : "",
-    "message" : "integer overflow"
-  }
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+  "sqlState" : "22015"
 }
 
 
@@ -2509,11 +2535,10 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
   "sqlState" : "22015",
   "messageParameters" : {
-    "alternative" : " Use 'try_subtract' to tolerate overflow and return NULL instead.",
-    "message" : "integer overflow"
+    "functionName" : "`try_subtract`"
   }
 }
 
@@ -2525,11 +2550,10 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
   "sqlState" : "22015",
   "messageParameters" : {
-    "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.",
-    "message" : "integer overflow"
+    "functionName" : "`try_add`"
   }
 }
 
@@ -2838,11 +2862,10 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
   "sqlState" : "22015",
   "messageParameters" : {
-    "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.",
-    "message" : "Interval value overflows after being divided by -1"
+    "functionName" : "`try_divide`"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -2861,11 +2884,10 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
   "sqlState" : "22015",
   "messageParameters" : {
-    "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.",
-    "message" : "Interval value overflows after being divided by -1"
+    "functionName" : "`try_divide`"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -2918,11 +2940,10 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
   "sqlState" : "22015",
   "messageParameters" : {
-    "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.",
-    "message" : "Interval value overflows after being divided by -1"
+    "functionName" : "`try_divide`"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -2941,11 +2962,10 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
   "sqlState" : "22015",
   "messageParameters" : {
-    "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.",
-    "message" : "Interval value overflows after being divided by -1"
+    "functionName" : "`try_divide`"
   },
   "queryContext" : [ {
     "objectType" : "",
diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out
index 7eed2d42da04..a8a0423bdb3e 100644
--- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out
@@ -823,8 +823,18 @@ select make_dt_interval(2147483647)
 -- !query schema
 struct<>
 -- !query output
-java.lang.ArithmeticException
-long overflow
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+  "sqlState" : "22015",
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 35,
+    "fragment" : "make_dt_interval(2147483647)"
+  } ]
+}
 
 
 -- !query
@@ -864,8 +874,18 @@ select make_ym_interval(178956970, 8)
 -- !query schema
 struct<>
 -- !query output
-java.lang.ArithmeticException
-integer overflow
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+  "sqlState" : "22015",
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 37,
+    "fragment" : "make_ym_interval(178956970, 8)"
+  } ]
+}
 
 
 -- !query
@@ -881,8 +901,18 @@ select make_ym_interval(-178956970, -9)
 -- !query schema
 struct<>
 -- !query output
-java.lang.ArithmeticException
-integer overflow
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+  "sqlState" : "22015",
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 39,
+    "fragment" : "make_ym_interval(-178956970, -9)"
+  } ]
+}
 
 
 -- !query
@@ -2316,12 +2346,8 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
-  "sqlState" : "22015",
-  "messageParameters" : {
-    "alternative" : "",
-    "message" : "integer overflow"
-  }
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+  "sqlState" : "22015"
 }
 
 
@@ -2332,11 +2358,10 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
   "sqlState" : "22015",
   "messageParameters" : {
-    "alternative" : " Use 'try_subtract' to tolerate overflow and return NULL instead.",
-    "message" : "integer overflow"
+    "functionName" : "`try_subtract`"
   }
 }
 
@@ -2348,11 +2373,10 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
   "sqlState" : "22015",
   "messageParameters" : {
-    "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.",
-    "message" : "integer overflow"
+    "functionName" : "`try_add`"
   }
 }
 
@@ -2661,11 +2685,10 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
   "sqlState" : "22015",
   "messageParameters" : {
-    "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.",
-    "message" : "Interval value overflows after being divided by -1"
+    "functionName" : "`try_divide`"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -2684,11 +2707,10 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
   "sqlState" : "22015",
   "messageParameters" : {
-    "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.",
-    "message" : "Interval value overflows after being divided by -1"
+    "functionName" : "`try_divide`"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -2741,11 +2763,10 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
   "sqlState" : "22015",
   "messageParameters" : {
-    "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.",
-    "message" : "Interval value overflows after being divided by -1"
+    "functionName" : "`try_divide`"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -2764,11 +2785,10 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 {
-  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW",
+  "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
   "sqlState" : "22015",
   "messageParameters" : {
-    "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.",
-    "message" : "Interval value overflows after being divided by -1"
+    "functionName" : "`try_divide`"
   },
   "queryContext" : [ {
     "objectType" : "",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 7ebcb280def6..6348e5f31539 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.matchers.must.Matchers.the
 import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
 import org.apache.spark.sql.catalyst.plans.logical.Expand
 import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -1485,15 +1486,22 @@ class DataFrameAggregateSuite extends QueryTest
     val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
       (Period.ofMonths(10), Duration.ofDays(10)))
       .toDF("year-month", "day")
 
-    val error = intercept[SparkArithmeticException] {
-      checkAnswer(df2.select(sum($"year-month")), Nil)
-    }
-    assert(error.getMessage contains "[INTERVAL_ARITHMETIC_OVERFLOW] integer overflow")
-    val error2 = intercept[SparkArithmeticException] {
-      checkAnswer(df2.select(sum($"day")), Nil)
-    }
-    assert(error2.getMessage contains "[INTERVAL_ARITHMETIC_OVERFLOW] long overflow")
+    checkError(
+      exception = intercept[SparkArithmeticException] {
+        checkAnswer(df2.select(sum($"year-month")), Nil)
+      },
+      condition = "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
+      parameters = Map("functionName" -> toSQLId("try_add"))
+    )
+
+    checkError(
+      exception = intercept[SparkArithmeticException] {
+        checkAnswer(df2.select(sum($"day")), Nil)
+      },
+      condition = "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
+      parameters = Map("functionName" -> toSQLId("try_add"))
+    )
   }
 
   test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
@@ -1620,15 +1628,22 @@ class DataFrameAggregateSuite extends QueryTest
     val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
       (Period.ofMonths(10), Duration.ofDays(10)))
       .toDF("year-month", "day")
 
-    val error = intercept[SparkArithmeticException] {
-      checkAnswer(df2.select(avg($"year-month")), Nil)
-    }
-    assert(error.getMessage contains "[INTERVAL_ARITHMETIC_OVERFLOW] integer overflow")
-    val error2 = intercept[SparkArithmeticException] {
-      checkAnswer(df2.select(avg($"day")), Nil)
-    }
-    assert(error2.getMessage contains "[INTERVAL_ARITHMETIC_OVERFLOW] long overflow")
+    checkError(
+      exception = intercept[SparkArithmeticException] {
+        checkAnswer(df2.select(avg($"year-month")), Nil)
+      },
+      condition = "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
+      parameters = Map("functionName" -> toSQLId("try_add"))
+    )
+
+    checkError(
+      exception = intercept[SparkArithmeticException] {
+        checkAnswer(df2.select(avg($"day")), Nil)
+      },
+      condition = "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION",
+      parameters = Map("functionName" -> toSQLId("try_add"))
+    )
 
     val df3 = intervalData.filter($"class" > 4)
     val avgDF3 = df3.select(avg($"year-month"), avg($"day"))