diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 79c0ca56a8e79..99a62343f138d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -21,8 +21,11 @@ import java.text.DecimalFormat import java.util.Locale import java.util.regex.{MatchResult, Pattern} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -160,32 +163,51 @@ trait StringRegexExpression extends ImplicitCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression with CodegenFallback { - // replace the _ with .{1} exactly match 1 time of any character - // replace the % with .*, match 0 or more times with any character - override def escape(v: String): String = - if (!v.isEmpty) { - "(?s)" + (' ' +: v.init).zip(v).flatMap { - case (prev, '\\') => "" - case ('\\', c) => - c match { - case '_' => "_" - case '%' => "%" - case _ => Pattern.quote("\\" + c) - } - case (prev, c) => - c match { - case '_' => "." - case '%' => ".*" - case _ => Pattern.quote(Character.toString(c)) - } - }.mkString - } else { - v - } + override def escape(v: String): String = StringUtils.escapeLikeRegex(v) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() override def toString: String = s"$left LIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. + val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).matches(); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); + ${ev.primitive} = $pattern.matcher(${eval1}.toString()).matches(); + """ + }) + } + } } @@ -195,6 +217,45 @@ case class RLike(left: Expression, right: Expression) override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. + val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).find(0); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile(rightStr); + ${ev.primitive} = $pattern.matcher(${eval1}.toString()).find(0); + """ + }) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala new file mode 100644 index 0000000000000..9ddfb3a0d3759 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.util.regex.Pattern + +object StringUtils { + + // replace the _ with .{1} exactly match 1 time of any character + // replace the % with .*, match 0 or more times with any character + def escapeLikeRegex(v: String): String = { + if (!v.isEmpty) { + "(?s)" + (' ' +: v.init).zip(v).flatMap { + case (prev, '\\') => "" + case ('\\', c) => + c match { + case '_' => "_" + case '%' => "%" + case _ => Pattern.quote("\\" + c) + } + case (prev, c) => + c match { + case '_' => "." + case '%' => ".*" + case _ => Pattern.quote(Character.toString(c)) + } + }.mkString + } else { + v + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 07b952531ec2e..3ecd0d374c46b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -191,6 +191,15 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, StringType).like("a"), null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) + checkEvaluation( + Literal.create("a", StringType).like(NonFoldableLiteral.create("a", StringType)), true) + checkEvaluation( + Literal.create("a", StringType).like(NonFoldableLiteral.create(null, StringType)), null) + checkEvaluation( + Literal.create(null, StringType).like(NonFoldableLiteral.create("a", StringType)), null) + checkEvaluation( + Literal.create(null, StringType).like(NonFoldableLiteral.create(null, StringType)), null) + checkEvaluation("abdef" like "abdef", true) checkEvaluation("a_%b" like "a\\__b", true) checkEvaluation("addb" like "a_%b", true) @@ -232,6 +241,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) checkEvaluation("abdef" rlike Literal.create(null, StringType), null) checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) + checkEvaluation("abdef" rlike NonFoldableLiteral.create("abdef", StringType), true) + checkEvaluation("abdef" rlike NonFoldableLiteral.create(null, StringType), null) + checkEvaluation( + Literal.create(null, StringType) rlike NonFoldableLiteral.create("abdef", StringType), null) + checkEvaluation( + Literal.create(null, StringType) rlike NonFoldableLiteral.create(null, StringType), null) + checkEvaluation("abdef" rlike "abdef", true) checkEvaluation("abbbbc" rlike "a.*c", true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala new file mode 100644 index 0000000000000..d6f273f9e568a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.StringUtils._ + +class StringUtilsSuite extends SparkFunSuite { + + test("escapeLikeRegex") { + assert(escapeLikeRegex("abdef") === "(?s)\\Qa\\E\\Qb\\E\\Qd\\E\\Qe\\E\\Qf\\E") + assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E_.\\Qb\\E") + assert(escapeLikeRegex("a_%b") === "(?s)\\Qa\\E..*\\Qb\\E") + assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*%\\Qb\\E") + assert(escapeLikeRegex("a%") === "(?s)\\Qa\\E.*") + assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E") + assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E") + } +}