diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 37e3dd5ea89e..01b6e81b3cfb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -213,7 +213,7 @@ case class CheckOverflowInSum(
 }
 
 /**
- * An add expression for decimal values which is only used internally by Sum/Avg.
+ * An add expression for decimal values which is only used internally by Sum/Avg/Window.
  *
  * Nota that, this expression does not check overflow which is different with `Add`. When
  * aggregating values, Spark writes the aggregation buffer values to `UnsafeRow` via
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
index 0f19f14576b9..44181c79bced 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
@@ -128,6 +128,7 @@ trait WindowExecBase extends UnaryExecNode {
             TimestampAddYMInterval(expr, boundOffset, Some(timeZone))
           case (TimestampType | TimestampNTZType, _: DayTimeIntervalType) =>
             TimeAdd(expr, boundOffset, Some(timeZone))
+          case (d: DecimalType, _: DecimalType) => DecimalAddNoOverflowCheck(expr, boundOffset, d)
           case (a, b) if a == b => Add(expr, boundOffset)
         }
         val bound = MutableProjection.create(boundExpr :: Nil, child.output)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala
index ea8cfc7b81a4..48a3d7405597 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, NonFoldableLiteral, RangeFrame, SortOrder, SpecifiedWindowFrame, UnspecifiedFrame}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, NonFoldableLiteral, RangeFrame, SortOrder, SpecifiedWindowFrame, UnaryMinus, UnspecifiedFrame}
 import org.apache.spark.sql.catalyst.plans.logical.{Window => WindowNode}
 import org.apache.spark.sql.expressions.{Window, WindowSpec}
 import org.apache.spark.sql.functions._
@@ -474,4 +474,22 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df,
       Row(3, 1.5) :: Row(3, 1.5) :: Row(6, 2.0) :: Row(6, 2.0) :: Row(6, 2.0) :: Nil)
   }
+
+  test("SPARK-41793: Incorrect result for window frames defined by a range clause on large " +
+    "decimals") {
+    val window = new WindowSpec(Seq($"a".expr), Seq(SortOrder($"b".expr, Ascending)),
+      SpecifiedWindowFrame(RangeFrame,
+        UnaryMinus(Literal(BigDecimal(10.2345))), Literal(BigDecimal(6.7890))))
+
+    val df = Seq(
+      1 -> "11342371013783243717493546650944543.47",
+      1 -> "999999999999999999999999999999999999.99"
+    ).toDF("a", "b")
+      .select($"a", $"b".cast("decimal(38, 2)"))
+      .select(count("*").over(window))
+
+    checkAnswer(
+      df,
+      Row(1) :: Row(1) :: Nil)
+  }
 }
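
For context, here is a minimal standalone sketch of how the bug this patch fixes (SPARK-41793) surfaces from the SQL surface. It is not part of the patch: the object name, local session setup, and temp view name `t` are illustrative. The reading behind the fix is that, before the new `DecimalType` case in `WindowExecBase`, a decimal range-frame bound fell through to `Add`, whose overflow check can null out `b + offset` for values near the top of `decimal(38, 2)`, corrupting frame membership and hence the count.

// Minimal reproduction sketch for SPARK-41793, assuming a local SparkSession.
// The object name and temp view name `t` are hypothetical, not part of the patch.
import org.apache.spark.sql.SparkSession

object Spark41793Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // Same data as the new test: two decimal(38, 2) values far enough apart
    // that neither falls inside the other's range frame.
    Seq(
      1 -> "11342371013783243717493546650944543.47",
      1 -> "999999999999999999999999999999999999.99"
    ).toDF("a", "b")
      .selectExpr("a", "cast(b as decimal(38, 2)) as b")
      .createOrReplaceTempView("t")

    // Each row's frame [b - 10.2345, b + 6.7890] contains only the row itself,
    // so the correct answer is cnt = 1 for both rows (the expectation encoded
    // in the new test above).
    spark.sql(
      """SELECT count(*) OVER (
        |  PARTITION BY a ORDER BY b
        |  RANGE BETWEEN 10.2345 PRECEDING AND 6.7890 FOLLOWING) AS cnt
        |FROM t""".stripMargin).show(truncate = false)

    spark.stop()
  }
}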