@@ -27,7 +27,8 @@ import org.apache.commons.text.StringEscapeUtils

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.BinaryLike
@@ -37,7 +38,6 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-
 abstract class StringRegexExpression extends BinaryExpression
   with ImplicitCastInputTypes with NullIntolerant with Predicate {

@@ -594,14 +594,28 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression, pos: Expression)
       return defaultCheck
     }
     if (!pos.foldable) {
-      return TypeCheckFailure(s"Position expression must be foldable, but got $pos")
+      return DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "position",
+          "inputType" -> toSQLType(pos.dataType),
+          "inputExpr" -> toSQLExpr(pos)
+        )
+      )
     }
 
     val posEval = pos.eval()
     if (posEval == null || posEval.asInstanceOf[Int] > 0) {
       TypeCheckSuccess
     } else {
-      TypeCheckFailure(s"Position expression must be positive, but got: $posEval")
+      DataTypeMismatch(
+        errorSubClass = "VALUE_OUT_OF_RANGE",
+        messageParameters = Map(
+          "exprName" -> "position",
+          "valueRange" -> s"(0, ${Int.MaxValue}]",
+          "currentValue" -> toSQLValue(posEval, pos.dataType)
+        )
+      )
     }
   }

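For a quick cross-check outside the test suites, here is a minimal sketch (not part of the diff; assumes a Spark REPL with the catalyst module on the classpath) that drives both branches of the new RegExpReplace check:

    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
    import org.apache.spark.sql.catalyst.expressions.{Literal, RegExpReplace}

    // A foldable, positive position still passes the check.
    val ok = RegExpReplace(Literal("a1"), Literal("\\d"), Literal("#"), Literal(1))
    assert(ok.checkInputDataTypes() == TypeCheckResult.TypeCheckSuccess)

    // A foldable but non-positive position now yields the structured error.
    val bad = RegExpReplace(Literal("a1"), Literal("\\d"), Literal("#"), Literal(-2))
    bad.checkInputDataTypes() match {
      case DataTypeMismatch(subClass, params) =>
        assert(subClass == "VALUE_OUT_OF_RANGE" && params("exprName") == "position")
      case other => sys.error(s"unexpected result: $other")
    }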
@@ -25,6 +25,8 @@ import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
@@ -273,18 +275,35 @@ case class Elt(

   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.size < 2) {
-      TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments")
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_PARAMS",
+        messageParameters = Map(
+          "functionName" -> "elt",
+          "expectedNum" -> "> 1",
+          "actualNum" -> children.length.toString
+        )
+      )
     } else {
       val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType))
       if (indexType != IntegerType) {
-        return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " +
-          s"have ${IntegerType.catalogString}, but it's ${indexType.catalogString}")
+        return DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> "1",
+            "requiredType" -> toSQLType(IntegerType),
+            "inputSql" -> toSQLExpr(indexExpr),
+            "inputType" -> toSQLType(indexType)))
       }
       if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
-        return TypeCheckResult.TypeCheckFailure(
-          s"input to function $prettyName should have ${StringType.catalogString} or " +
-          s"${BinaryType.catalogString}, but it's " +
-          inputTypes.map(_.catalogString).mkString("[", ", ", "]"))
+        return DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> "2...",
+            "requiredType" -> (toSQLType(StringType) + " or " + toSQLType(BinaryType)),
+            "inputSql" -> inputExprs.map(toSQLExpr(_)).mkString(","),
+            "inputType" -> inputTypes.map(toSQLType(_)).mkString(",")
+          )
+        )
       }
       TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName")
     }
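The three Elt branches can likewise be exercised with hand-built expressions; a sketch under the same assumptions as above (illustrative, not from the diff):

    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
    import org.apache.spark.sql.catalyst.expressions.{Elt, Literal}

    // Fewer than two children -> WRONG_NUM_PARAMS.
    val tooFew = Elt(Seq(Literal(1))).checkInputDataTypes()

    // A non-integer index -> UNEXPECTED_INPUT_TYPE on parameter 1.
    val badIndex = Elt(Seq(Literal("1"), Literal("a"), Literal("b"))).checkInputDataTypes()

    // A non-string/binary input -> UNEXPECTED_INPUT_TYPE on parameters 2...
    val badInput = Elt(Seq(Literal(1), Literal("a"), Literal(42))).checkInputDataTypes()

    Seq(tooFew, badIndex, badInput).foreach {
      case DataTypeMismatch(subClass, _) => println(subClass)
      case other => println(other)
    }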
@@ -19,12 +19,14 @@ package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.optimizer.ConstantFolding
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{IntegerType, StringType}
 
 /**
  * Unit tests for regular expression (regexp) related SQL expressions.
@@ -531,4 +533,23 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
create_row("abc", ", (", 0),
s"$prefix `regexp_instr` is invalid: , (")
}

test("RegExpReplace: fails analysis if pos is not a constant") {
val s = $"s".string.at(0)
val p = $"p".string.at(1)
val r = $"r".string.at(2)
val posExpr = AttributeReference("b", IntegerType)()
val expr = RegExpReplace(s, p, r, posExpr)

assert(expr.checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "position",
"inputType" -> toSQLType(posExpr.dataType),
"inputExpr" -> toSQLExpr(posExpr)
)
)
)
}
}
@@ -21,7 +21,9 @@ import java.math.{BigDecimal => JavaBigDecimal}

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -1583,4 +1585,51 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Contains(Literal("Spark SQL"), Literal("SQL")), true)
     checkEvaluation(Contains(Literal("Spark SQL"), Literal("k S")), true)
   }
+
+  test("Elt: checkInputDataTypes") {
+    // requires at least two arguments
+    val indexExpr1 = Literal(8)
+    val expr1 = Elt(Seq(indexExpr1))
+    assert(expr1.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_PARAMS",
+        messageParameters = Map(
+          "functionName" -> "elt",
+          "expectedNum" -> "> 1",
+          "actualNum" -> "1"
+        )
+      )
+    )
+
+    // first input to function elt should have IntegerType
+    val indexExpr2 = Literal('a')
+    val expr2 = Elt(Seq(indexExpr2, Literal('b')))
+    assert(expr2.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> toSQLType(IntegerType),
+          "inputSql" -> toSQLExpr(indexExpr2),
+          "inputType" -> toSQLType(indexExpr2.dataType)
+        )
+      )
+    )
+
+    // inputs to function elt should have StringType or BinaryType
+    val indexExpr3 = Literal(1)
+    val inputExpr3 = Seq(Literal('a'), Literal('b'), Literal(12345))
+    val expr3 = Elt(indexExpr3 +: inputExpr3)
+    assert(expr3.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "2...",
+          "requiredType" -> (toSQLType(StringType) + " or " + toSQLType(BinaryType)),
+          "inputSql" -> inputExpr3.map(toSQLExpr(_)).mkString(","),
+          "inputType" -> inputExpr3.map(expr => toSQLType(expr.dataType)).mkString(",")
+        )
+      )
+    )
+  }
 }
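A note on the golden files that follow: since the error is now rendered as structured JSON, the expected outputs below change mechanically. Under the usual Spark workflow they can be regenerated rather than hand-edited, e.g. with SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *SQLQueryTestSuite" (command shape illustrative).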
@@ -355,7 +355,22 @@ SELECT regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something', -2)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something', -2)' due to data type mismatch: Position expression must be positive, but got: -2; line 1 pos 7
+{
+  "errorClass" : "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+  "messageParameters" : {
+    "currentValue" : "-2",
+    "exprName" : "position",
+    "sqlExpr" : "\"regexp_replace(healthy, wealthy, and wise, \\w+thy, something, -2)\"",
+    "valueRange" : "(0, 2147483647]"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 79,
+    "fragment" : "regexp_replace('healthy, wealthy, and wise', '\\\\w+thy', 'something', -2)"
+  } ]
+}
 
 
 -- !query
@@ -364,7 +379,22 @@ SELECT regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something', 0)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something', 0)' due to data type mismatch: Position expression must be positive, but got: 0; line 1 pos 7
+{
+  "errorClass" : "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+  "messageParameters" : {
+    "currentValue" : "0",
+    "exprName" : "position",
+    "sqlExpr" : "\"regexp_replace(healthy, wealthy, and wise, \\w+thy, something, 0)\"",
+    "valueRange" : "(0, 2147483647]"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 78,
+    "fragment" : "regexp_replace('healthy, wealthy, and wise', '\\\\w+thy', 'something', 0)"
+  } ]
+}
 
 
 -- !query
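For contrast with the two failing queries above, a positive position still resolves. An illustrative query with its expected result, reasoned from regexp_replace's semantics rather than copied from a golden file:

    SELECT regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something', 1);
    -- expected: something, something, and wise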