Skip to content

Commit 3036847

Browse files
committed
address comments.
1 parent 1abdbb9 commit 3036847

File tree

4 files changed

+52
-11
lines changed

4 files changed

+52
-11
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ class Analyzer(
109109
TimeWindowing ::
110110
TypeCoercion.typeCoercionRules ++
111111
extendedResolutionRules : _*),
112+
Batch("LIMIT", Once,
113+
ResolveLimits),
112114
Batch("Nondeterministic", Once,
113115
PullOutNondeterministic),
114116
Batch("UDF", Once,
@@ -2044,6 +2046,21 @@ object EliminateUnions extends Rule[LogicalPlan] {
20442046
}
20452047
}
20462048

2049+
/**
 * Rewrites the limit expression of [[GlobalLimit]] and [[LocalLimit]] operators into an
 * [[IntegerType]] literal when the expression is foldable and evaluates to a number whose
 * value fits in an `Int`.
 *
 * Fractional values are truncated toward zero. Expressions that are not foldable, that do not
 * evaluate to a number (including `null`), or whose value lies outside the `Int` range are left
 * untouched, so an invalid limit is never silently replaced by a wrapped or clamped row count.
 */
object ResolveLimits extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case g @ GlobalLimit(limitExpr, _) if limitExpr.foldable =>
      // Evaluate the expression only once; keep the node as is when no Int can be derived.
      toIntLimit(limitExpr.eval())
        .map(n => g.copy(limitExpr = Literal(n, IntegerType)))
        .getOrElse(g)
    case l @ LocalLimit(limitExpr, _) if limitExpr.foldable =>
      toIntLimit(limitExpr.eval())
        .map(n => l.copy(limitExpr = Literal(n, IntegerType)))
        .getOrElse(l)
  }

  /**
   * Converts an evaluated limit value to an `Int` without overflow.
   *
   * @param value the result of evaluating a limit expression; may be `null` or non-numeric
   * @return the value as an `Int` (fractions truncated toward zero), or `None` when the value
   *         is not a number, is NaN, or does not fit in the `Int` range
   */
  private def toIntLimit(value: Any): Option[Int] = value match {
    // Longs are checked exactly; going through Double could lose precision.
    case v: Long =>
      if (v.isValidInt) Some(v.toInt) else None
    case n: Number =>
      val d = n.doubleValue()
      // Open bounds one past the Int range admit every value that truncates to a valid Int.
      // NaN fails both comparisons and is therefore rejected.
      if (d > Int.MinValue - 1.0 && d < Int.MaxValue + 1.0) Some(n.intValue()) else None
    case _ => None
  }
}
2063+
20472064
/**
20482065
* Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level
20492066
* expression in Project(project list) or Aggregate(aggregate expressions) or

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,13 @@ trait CheckAnalysis extends PredicateHelper {
5050
if (!limitExpr.foldable) {
5151
failAnalysis(
5252
"The argument to the LIMIT clause must evaluate to a constant value. " +
53-
s"Limit:${limitExpr.sql}")
53+
s"Limit:${limitExpr.sql}")
5454
}
55-
limitExpr.eval() match {
56-
case o: Int if o >= 0 => // OK
57-
case o: Int => failAnalysis(
58-
s"number_rows in limit clause must be equal to or greater than 0. number_rows:$o")
59-
case o => failAnalysis(
60-
s"number_rows in limit clause cannot be cast to integer:$o")
55+
limitExpr match {
56+
case IntegerLiteral(limit) if limit >= 0 => // OK
57+
case IntegerLiteral(limit) => failAnalysis(
58+
s"number_rows in limit clause must be equal to or greater than 0. number_rows:$limit")
59+
case o => failAnalysis(s"""number_rows in limit clause cannot be cast to integer:"$o".""")
6160
}
6261
}
6362

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
659659
}
660660
}
661661
override lazy val statistics: Statistics = {
662-
val limit = limitExpr.eval().asInstanceOf[Int]
662+
val limit = limitExpr.eval().asInstanceOf[Number].intValue()
663663
val sizeInBytes = if (limit == 0) {
664664
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
665665
// (product of children).
@@ -680,7 +680,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
680680
}
681681
}
682682
override lazy val statistics: Statistics = {
683-
val limit = limitExpr.eval().asInstanceOf[Int]
683+
val limit = limitExpr.eval().asInstanceOf[Number].intValue()
684684
val sizeInBytes = if (limit == 0) {
685685
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
686686
// (product of children).

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
670670
checkAnswer(
671671
sql("SELECT * FROM mapData LIMIT 1"),
672672
mapData.collect().take(1).map(Row.fromTuple).toSeq)
673+
674+
checkAnswer(
675+
sql("SELECT * FROM mapData LIMIT CAST(1 AS Double)"),
676+
mapData.collect().take(1).map(Row.fromTuple).toSeq)
677+
678+
checkAnswer(
679+
sql("SELECT * FROM mapData LIMIT CAST(1 AS BYTE)"),
680+
mapData.collect().take(1).map(Row.fromTuple).toSeq)
681+
682+
checkAnswer(
683+
sql("SELECT * FROM mapData LIMIT CAST(1 AS LONG)"),
684+
mapData.collect().take(1).map(Row.fromTuple).toSeq)
685+
686+
checkAnswer(
687+
sql("SELECT * FROM mapData LIMIT CAST(1 AS SHORT)"),
688+
mapData.collect().take(1).map(Row.fromTuple).toSeq)
689+
690+
checkAnswer(
691+
sql("SELECT * FROM mapData LIMIT CAST(1 AS FLOAT)"),
692+
mapData.collect().take(1).map(Row.fromTuple).toSeq)
673693
}
674694

675695
test("non-foldable expressions in LIMIT") {
@@ -681,10 +701,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
681701
}
682702

683703
test("Limit: unable to evaluate and cast expressions in limit clauses to Int") {
684-
val e = intercept[AnalysisException] {
704+
var e = intercept[AnalysisException] {
685705
sql("SELECT * FROM testData LIMIT true")
686706
}.getMessage
687-
assert(e.contains("number_rows in limit clause cannot be cast to integer:true"))
707+
assert(e.contains("number_rows in limit clause cannot be cast to integer:\"true\""))
708+
709+
e = intercept[AnalysisException] {
710+
sql("SELECT * FROM testData LIMIT 'a'")
711+
}.getMessage
712+
assert(e.contains("number_rows in limit clause cannot be cast to integer:\"a\""))
688713
}
689714

690715
test("negative in LIMIT or TABLESAMPLE") {

0 commit comments

Comments
 (0)