diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 593c7fc0206..bace90d6d16 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -467,7 +467,7 @@ class CudfRegexTranspiler(mode: RegexMode) { // parse the source regular expression val regex = new RegexParser(pattern).parse() // validate that the regex is supported by cuDF - val cudfRegex = rewrite(regex) + val cudfRegex = rewrite(regex, None) // write out to regex string, performing minor transformations // such as adding additional escaping cudfRegex.toRegexString @@ -535,19 +535,79 @@ class CudfRegexTranspiler(mode: RegexMode) { } } + private val lineTerminatorChars = Seq('\n', '\r', '\u0085', '\u2028', '\u2029') + + // from Java 8 documentation: a line terminator is a 1 to 2 character sequence that marks + // the end of a line of an input character sequence. 
+ // this method produces a RegexAST which outputs a regular expression to match any possible + // combination of line terminators + private def lineTerminatorMatcher(exclude: Set[Char], excludeCRLF: Boolean):RegexAST = { + val terminatorChars = new ListBuffer[RegexCharacterClassComponent]() + terminatorChars ++= lineTerminatorChars.filter(!exclude.contains(_)).map(RegexChar) + + if (terminatorChars.size == 0 && excludeCRLF) { + RegexEmpty() + } else if (terminatorChars.size == 0) { + RegexGroup(capture = false, RegexSequence(ListBuffer(RegexChar('\r'), RegexChar('\n')))) + } else if (excludeCRLF) { + RegexCharacterClass(negated = false, characters = terminatorChars) + } else { + RegexGroup(capture = false, + RegexChoice( + RegexCharacterClass(negated = false, characters = terminatorChars), + RegexSequence(ListBuffer(RegexChar('\r'), RegexChar('\n'))))) + } + } - private def rewrite(regex: RegexAST): RegexAST = { + private def rewrite(regex: RegexAST, previous: Option[RegexAST]): RegexAST = { regex match { case RegexChar(ch) => ch match { case '.' => // workaround for https://github.com/rapidsai/cudf/issues/9619 RegexCharacterClass(negated = true, ListBuffer(RegexChar('\r'), RegexChar('\n'))) - case '$' => + case '$' if mode == RegexSplitMode || mode == RegexReplaceMode => // see https://github.com/NVIDIA/spark-rapids/issues/4533 - throw new RegexUnsupportedException("line anchor $ is not supported") + throw new RegexUnsupportedException("line anchor $ is not supported in split or replace") + case '$' => + // in the case of the line anchor $, the JVM has special conditions when handling line + // terminators in and around the anchor + // this handles cases where the line terminator characters are *before* the anchor ($) + // NOTE: this applies to when using *standard* mode. In multiline mode, all these + // conditions will change. Currently Spark does not use multiline mode. 
+ previous match { + case Some(RegexChar('$')) => + // repeating the line anchor in cuDF (for example b$$) causes matches to fail, but in + // Java, it's treated as a single anchor (b$ and b$$ are synonymous), so we create + // an empty RegexAST that outputs to empty string + RegexEmpty() + case Some(RegexChar(ch)) if ch == '\r' => + // when using the CR (\r), it prevents the line anchor from handling any other + // line terminator sequences, so we just output the anchor and we are finished + // for example: \r$ -> \r$ (no transpilation) + RegexChar('$') + case Some(RegexChar(ch)) if lineTerminatorChars.contains(ch) => + // when using any other line terminator character, you can match any of the other + // line terminator characters individually as part of the line anchor match. + // for example: \n$ -> \n[\r\u0085\u2028\u2029]?$ + RegexSequence(ListBuffer( + RegexRepetition(lineTerminatorMatcher(Set(ch), true), SimpleQuantifier('?')), + RegexChar('$'))) + case _ => + // otherwise by default we can match any or none of the full set of line terminators + RegexSequence(ListBuffer( + RegexRepetition(lineTerminatorMatcher(Set.empty, false), SimpleQuantifier('?')), + RegexChar('$'))) + } case '^' if mode == RegexSplitMode => throw new RegexUnsupportedException("line anchor ^ is not supported in split mode") + case '\r' | '\n' if mode == RegexFindMode => + previous match { + case Some(RegexChar('$')) => + RegexEmpty() + case _ => + regex + } case _ => regex } @@ -634,7 +694,10 @@ class CudfRegexTranspiler(mode: RegexMode) { case _ => } val components: Seq[RegexCharacterClassComponent] = characters - .map(x => rewrite(x).asInstanceOf[RegexCharacterClassComponent]) + .map(x => x match { + case RegexChar(ch) if "^$".contains(ch) => x + case _ => rewrite(x, None).asInstanceOf[RegexCharacterClassComponent] + }) if (negated) { // There are differences between cuDF and Java handling of newlines @@ -699,7 +762,52 @@ class CudfRegexTranspiler(mode: RegexMode) { throw new 
RegexUnsupportedException( "sequences that only contain '^' or '$' are not supported") } - RegexSequence(parts.map(rewrite)) + + // Special handling for line anchor ($) + // This code is implemented here because to make it work in cuDF, we have to reorder + // the items in the regex. + // In the JVM, regexes like "\n$" and "$\n" have similar treatment + RegexSequence(parts.foldLeft((new ListBuffer[RegexAST](), + Option.empty[RegexAST]))((m, part) => { + val (r, last) = m + last match { + // when the previous character is a line anchor ($), the JVM has special handling + // when matching against line terminator characters + case Some(RegexChar('$')) => + val j = r.lastIndexWhere { + case RegexEmpty() => false + case _ => true + } + part match { + case RegexCharacterClass(true, parts) + if parts.forall(!isBeginOrEndLineAnchor(_)) => + r(j) = RegexSequence( + ListBuffer(lineTerminatorMatcher(Set.empty, true), RegexChar('$'))) + case RegexChar(ch) if ch == '\n' => + // what's really needed here is negative lookahead, but that is not + // supported by cuDF + // in this case: $\n would transpile to (?!\r)\n$ + throw new RegexUnsupportedException("regex sequence $\\n is not supported") + case RegexChar(ch) if "\r\u0085\u2028\u2029".contains(ch) => + r(j) = RegexSequence( + ListBuffer( + rewrite(part, None), + RegexSequence(ListBuffer( + RegexRepetition(lineTerminatorMatcher(Set(ch), true), + SimpleQuantifier('?')), RegexChar('$'))))) + case _ => + r.append(rewrite(part, last)) + } + case _ => + r.append(rewrite(part, last)) + } + r.last match { + case RegexEmpty() => + (r, last) + case _ => + (r, Some(part)) + } + })._1) case RegexRepetition(base, quantifier) => (base, quantifier) match { case (_, SimpleQuantifier(ch)) if mode == RegexReplaceMode && "?*".contains(ch) => @@ -742,17 +850,17 @@ class CudfRegexTranspiler(mode: RegexMode) { // specifically this variable length repetition: \A{2,} throw new RegexUnsupportedException(nothingToRepeat) case (RegexGroup(_, _), 
SimpleQuantifier(ch)) if ch == '?' => - RegexRepetition(rewrite(base), quantifier) + RegexRepetition(rewrite(base, None), quantifier) case _ if isSupportedRepetitionBase(base) => - RegexRepetition(rewrite(base), quantifier) + RegexRepetition(rewrite(base, None), quantifier) case _ => throw new RegexUnsupportedException(nothingToRepeat) } case RegexChoice(l, r) => - val ll = rewrite(l) - val rr = rewrite(r) + val ll = rewrite(l, None) + val rr = rewrite(r, None) // cuDF does not support repetition on one side of a choice, such as "a*|a" if (isRepetition(ll) || isRepetition(rr)) { @@ -776,7 +884,7 @@ class CudfRegexTranspiler(mode: RegexMode) { RegexChoice(ll, rr) case RegexGroup(capture, term) => - RegexGroup(capture, rewrite(term)) + RegexGroup(capture, rewrite(term, None)) case other => throw new RegexUnsupportedException(s"Unhandled expression in transpiler: $other") @@ -804,6 +912,11 @@ sealed trait RegexAST { def toRegexString: String } +sealed case class RegexEmpty() extends RegexAST { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = "" +} + sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST { override def children(): Seq[RegexAST] = parts override def toRegexString: String = parts.map(_.toRegexString).mkString diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala index c336da0c0dd..1bec4a73d45 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala @@ -93,17 +93,13 @@ class RegularExpressionSuite extends SparkQueryCompareTestSuite { frame => frame.selectExpr("regexp_replace(strings,'\\(foo\\)','D')") } - testGpuFallback("String regexp_extract regex 1", - "RegExpExtract", - extractStrings, execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec", "Alias", - 
"RegExpExtract", "AttributeReference", "Literal"),conf = conf) { + testSparkResultsAreEqual("String regexp_extract regex 1", + extractStrings, conf = conf) { frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 1)") } - testGpuFallback("String regexp_extract regex 2", - "RegExpExtract", - extractStrings, execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec", "Alias", - "RegExpExtract", "AttributeReference", "Literal"),conf = conf) { + testSparkResultsAreEqual("String regexp_extract regex 2", + extractStrings, conf = conf) { frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 2)") } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 622b839c7c4..1b6c0e0d12f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -215,12 +215,25 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { "\ntest", "test\n", "\ntest\n", "\ntest\r\ntest\n")) } - test("line anchor $ fall back to CPU") { - for (mode <- Seq(RegexFindMode, RegexReplaceMode)) { - assertUnsupported("a$b", mode, "line anchor $ is not supported") + test("line anchor $ fall back to CPU - split and replace") { + for (mode <- Seq(RegexSplitMode, RegexReplaceMode)) { + assertUnsupported("a$b", mode, "line anchor $ is not supported in split or replace") } } + test("line anchor sequence $\\n fall back to CPU") { + assertUnsupported("a$\n", RegexFindMode, "regex sequence $\\n is not supported") + } + + test("line anchor $ - find") { + val patterns = Seq("$\r", "a$", "\r$", "\f$", "$\f", "\u0085$", "\u2028$", "\u2029$", "\n$", + "\r\n$", "[\r\n]?$", "\\00*[D$3]$", "a$b") + val inputs = Seq("a", "a\n", "a\r", "a\r\n", "a\u0085\n", "a\f", "\f", "\r", "\u0085", "\u2028", + 
"\u2029", "\n", "\r\n", "\r\n\r", "\r\n\u0085", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r", + "\n\u0085", "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb") + assertCpuGpuMatchesRegexpFind(patterns, inputs) + } + test("whitespace boundaries - replace") { assertCpuGpuMatchesRegexpReplace( Seq("\\s", "\\S"), @@ -280,6 +293,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { doTranspileTest("abc\\z", "abc$") } + test("transpile $") { + doTranspileTest("a$", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") + doTranspileTest("$$\r", "\r[\n\u0085\u2028\u2029]?$") + doTranspileTest("]$\r", "]\r[\n\u0085\u2028\u2029]?$") + doTranspileTest("^$[^*A-ZA-Z]", "^[\n\r\u0085\u2028\u2029]$") + } + test("compare CPU and GPU: character range including unescaped + and -") { val patterns = Seq("a[-]+", "a[a-b-]+", "a[-a-b]", "a[-+]", "a[+-]") val inputs = Seq("a+", "a-", "a", "a-+", "a[a-b-]")