diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 1ecb3d1958f4..0e820c40a32b 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -532,7 +532,9 @@ predicate
     : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression
     | NOT? kind=IN '(' expression (',' expression)* ')'
     | NOT? kind=IN '(' query ')'
-    | NOT? kind=(RLIKE | LIKE) pattern=valueExpression
+    | NOT? kind=LIKE pattern=valueExpression
+    | NOT? kind=RLIKE regex=regexString
+    | NOT? kind=RLIKE pattern=valueExpression
     | IS NOT? kind=NULL
     ;
 
@@ -576,6 +578,10 @@ constant
     | STRING+                                                                                  #stringLiteral
     ;
 
+regexString
+    : STRING+                                                                                  #regexPattern
+    ;
+
 comparisonOperator
     : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ
     ;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index e1db1ef5b869..c5bdedd5f186 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -954,7 +954,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
       case SqlBaseParser.LIKE =>
         invertIfNotDefined(Like(e, expression(ctx.pattern)))
       case SqlBaseParser.RLIKE =>
-        invertIfNotDefined(RLike(e, expression(ctx.pattern)))
+        if (ctx.pattern != null) {
+          invertIfNotDefined(RLike(e, expression(ctx.pattern)))
+        } else {
+          invertIfNotDefined(RLike(e, expression(ctx.regex)))
+        }
       case SqlBaseParser.NULL if ctx.NOT != null =>
         IsNotNull(e)
       case SqlBaseParser.NULL =>
@@ -1398,6 +1402,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
     Literal(createString(ctx))
   }
 
+  override def visitRegexPattern(ctx: RegexPatternContext): Literal = withOrigin(ctx) {
+    Literal(ctx.STRING().asScala.map(regexString).mkString)
+  }
+
   /**
    * Create a String from a string literal context. This supports multiple consecutive string
    * literals, these are concatenated, for example this expression "'hello' 'world'" will be
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index 6fbc33fad735..218d52e4d483 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -68,6 +68,11 @@ object ParserUtils {
   /** Convert a string node into a string. */
   def string(node: TerminalNode): String = unescapeSQLString(node.getText)
 
+  /** Convert a string node to a regex string. */
+  def regexString(node: TerminalNode): String = {
+    node.getText.slice(1, node.getText.size - 1)
+  }
+
   /** Get the origin (line and position) of the token. */
   def position(token: Token): Origin = {
     val opt = Option(token)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index e7f3b64a7113..1994bec6fc0e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -158,6 +158,11 @@ class ExpressionParserSuite extends PlanTest {
     assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%"))
     assertEqual("a regexp 'pattern%'", 'a rlike "pattern%")
     assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%"))
+
+    assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$")
+    assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\")
+    assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n")
+    intercept("a rlike 'pattern\\'", "mismatched input")
   }
 
   test("is null expressions") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 5b5cd28ad0c9..f5beab567029 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1168,6 +1168,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds = Seq(WithMapInOption(Some(Map(1 -> 1)))).toDS()
     checkDataset(ds, WithMapInOption(Some(Map(1 -> 1))))
   }
+
+  test("do not unescape regex pattern string") {
+    val data = Seq("\u0020\u0021\u0023", "abc")
+    val df = data.toDF()
+    val rlike1 = df.filter("value rlike '^\\x20[\\x20-\\x23]+$'")
+    val rlike2 = df.filter($"value".rlike("^\\x20[\\x20-\\x23]+$"))
+    val rlike3 = df.filter("value rlike '^\\\\x20[\\\\x20-\\\\x23]+$'")
+    checkAnswer(rlike1, rlike2)
+    assert(rlike3.count() == 0)
+  }
 }
 
 case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
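
As a rough, standalone illustration of what the patch changes: the new `regexString` path strips only the surrounding quotes from an `RLIKE` pattern literal, so Java-regex escapes such as `\x20` reach the regex engine intact, whereas the old `string`/`unescapeSQLString` path consumed those backslash escapes first. The sketch below is not Spark code; `stripQuotes` and `unescapeMinimal` are illustrative stand-ins for the two code paths, and `unescapeMinimal` handles far fewer escapes than the real `unescapeSQLString`.

```scala
import java.util.regex.Pattern

// Standalone sketch (not Spark code) of the two ways a quoted pattern token can
// be turned into the string that is handed to the regex engine.
object RegexEscapeSketch {

  // Mirrors the idea of the new regexString helper: drop only the surrounding
  // quotes and leave backslash escapes for the regex engine to interpret.
  def stripQuotes(raw: String): String = raw.slice(1, raw.length - 1)

  // Very reduced stand-in for SQL string unescaping; just enough to show how
  // regex escapes get consumed before the regex engine ever sees them.
  def unescapeMinimal(raw: String): String = {
    val s = stripQuotes(raw)
    val sb = new StringBuilder
    var i = 0
    while (i < s.length) {
      if (s(i) == '\\' && i + 1 < s.length) {
        s(i + 1) match {
          case 't'  => sb += '\t'
          case 'n'  => sb += '\n'
          case '\\' => sb += '\\'
          case c    => sb += c // unknown escape: backslash dropped, char kept
        }
        i += 2
      } else {
        sb += s(i)
        i += 1
      }
    }
    sb.toString
  }

  def main(args: Array[String]): Unit = {
    // Token text as the parser sees it, quotes included.
    val rawToken = """'^\x20[\x20-\x23]+$'"""
    val value = "\u0020\u0021\u0023" // " !#": every char is in [\x20-\x23]

    // Quote stripping keeps \x20 as a regex hex escape -> the value matches.
    println(Pattern.matches(stripQuotes(rawToken), value))     // true
    // Unescaping first mangles \x20 into x20 -> the pattern no longer matches.
    println(Pattern.matches(unescapeMinimal(rawToken), value)) // false
  }
}
```

This mirrors the intent of the new DatasetSuite test: after the change, `value rlike '^\x20[\x20-\x23]+$'` in SQL is expected to behave like `$"value".rlike("^\\x20[\\x20-\\x23]+$")` in the DataFrame API.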