diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index a7b67f55d8cd..653ee9f836ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -24,7 +24,6 @@ import com.google.common.math.{DoubleMath, IntMath, LongMath} import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.util.DateTimeConstants.MONTHS_PER_YEAR import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -572,28 +571,14 @@ case class MakeYMInterval(years: Expression, months: Expression) override def dataType: DataType = YearMonthIntervalType() override def nullSafeEval(year: Any, month: Any): Any = { - try { - Math.toIntExact( - Math.addExact(month.asInstanceOf[Int], - Math.multiplyExact(year.asInstanceOf[Int], MONTHS_PER_YEAR))) - } catch { - case _: ArithmeticException => - throw QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError(origin.context) - } + makeYearMonthInterval(year.asInstanceOf[Int], month.asInstanceOf[Int], origin.context) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (years, months) => { - val math = classOf[Math].getName.stripSuffix("$") + defineCodeGen(ctx, ev, (years, months) => { val errorContext = getContextOrNullCode(ctx) - // scalastyle:off line.size.limit - s""" - |try { - | ${ev.value} = $math.toIntExact($math.addExact($months, $math.multiplyExact($years, $MONTHS_PER_YEAR))); - |} catch (java.lang.ArithmeticException e) { - | throw QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError($errorContext); - |}""".stripMargin - // scalastyle:on line.size.limit + val iu = IntervalUtils.getClass.getName.stripSuffix("$") + s"$iu.makeYearMonthInterval($years, $months, $errorContext)" }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 39a07990dea3..8793c0407a9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -797,6 +797,15 @@ object IntervalUtils extends SparkIntervalUtils { } } + def makeYearMonthInterval(year: Int, month: Int, context: QueryContext): Int = { + try { + Math.toIntExact(Math.addExact(month, Math.multiplyExact(year, MONTHS_PER_YEAR))) + } catch { + case _: ArithmeticException => + throw QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError(context) + } + } + def intToYearMonthInterval(v: Int, startField: Byte, endField: Byte): Int = { endField match { case YEAR =>