From c2310e08a7b5cbf68d36b8be281fe19fbd892b68 Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Thu, 14 Apr 2022 12:10:35 -0700 Subject: [PATCH 01/10] WIP: dollar anchor support Signed-off-by: Navin Kumar --- .../com/nvidia/spark/rapids/RegexParser.scala | 51 +++++++++++++++---- .../RegularExpressionTranspilerSuite.scala | 13 +++-- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 2589bf5c897..8c6a80e18c5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -467,7 +467,7 @@ class CudfRegexTranspiler(mode: RegexMode) { // parse the source regular expression val regex = new RegexParser(pattern).parse() // validate that the regex is supported by cuDF - val cudfRegex = rewrite(regex) + val cudfRegex = rewrite(regex, None) // write out to regex string, performing minor transformations // such as adding additional escaping cudfRegex.toRegexString @@ -536,16 +536,35 @@ class CudfRegexTranspiler(mode: RegexMode) { } - private def rewrite(regex: RegexAST): RegexAST = { + private def rewrite(regex: RegexAST, previous: Option[RegexAST]): RegexAST = { regex match { case RegexChar(ch) => ch match { case '.' => // workaround for https://github.com/rapidsai/cudf/issues/9619 RegexCharacterClass(negated = true, ListBuffer(RegexChar('\r'), RegexChar('\n'))) - case '$' => + case '$' if mode == RegexSplitMode || mode == RegexReplaceMode => // see https://github.com/NVIDIA/spark-rapids/issues/4533 throw new RegexUnsupportedException("line anchor $ is not supported") + case '$' => + previous match { + case Some(RegexChar('$')) => + RegexEmpty() + case _ => + RegexSequence( + ListBuffer( + RegexRepetition( + RegexChar('\r'), + SimpleQuantifier('?') + ), + RegexRepetition( + RegexChar('\n'), + SimpleQuantifier('?') + ), + RegexChar('$') + ) + ) + } case '^' if mode == RegexSplitMode => throw new RegexUnsupportedException("line anchor ^ is not supported in split mode") case _ => @@ -632,7 +651,7 @@ class CudfRegexTranspiler(mode: RegexMode) { case _ => } val components: Seq[RegexCharacterClassComponent] = characters - .map(x => rewrite(x).asInstanceOf[RegexCharacterClassComponent]) + .map(x => rewrite(x, None).asInstanceOf[RegexCharacterClassComponent]) if (negated) { // There are differences between cuDF and Java handling of newlines @@ -697,7 +716,14 @@ class CudfRegexTranspiler(mode: RegexMode) { throw new RegexUnsupportedException( "sequences that only contain '^' or '$' are not supported") } - RegexSequence(parts.map(rewrite)) + RegexSequence(parts.foldLeft(new ListBuffer[RegexAST]())((m, part) => { + if (m.isEmpty) { + m.append(rewrite(part, None)) + } else { + m.append(rewrite(part, Some(m.last))) + } + m + })) case RegexRepetition(base, quantifier) => (base, quantifier) match { case (_, SimpleQuantifier(ch)) if mode == RegexReplaceMode && "?*".contains(ch) => @@ -740,17 +766,17 @@ class CudfRegexTranspiler(mode: RegexMode) { // specifically this variable length repetition: \A{2,} throw new RegexUnsupportedException(nothingToRepeat) case (RegexGroup(_, _), SimpleQuantifier(ch)) if ch == '?' => - RegexRepetition(rewrite(base), quantifier) + RegexRepetition(rewrite(base, None), quantifier) case _ if isSupportedRepetitionBase(base) => - RegexRepetition(rewrite(base), quantifier) + RegexRepetition(rewrite(base, None), quantifier) case _ => throw new RegexUnsupportedException(nothingToRepeat) } case RegexChoice(l, r) => - val ll = rewrite(l) - val rr = rewrite(r) + val ll = rewrite(l, None) + val rr = rewrite(r, None) // cuDF does not support repetition on one side of a choice, such as "a*|a" if (isRepetition(ll) || isRepetition(rr)) { @@ -774,7 +800,7 @@ class CudfRegexTranspiler(mode: RegexMode) { RegexChoice(ll, rr) case RegexGroup(capture, term) => - RegexGroup(capture, rewrite(term)) + RegexGroup(capture, rewrite(term, None)) case other => throw new RegexUnsupportedException(s"Unhandled expression in transpiler: $other") @@ -802,6 +828,11 @@ sealed trait RegexAST { def toRegexString: String } +sealed case class RegexEmpty() extends RegexAST { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = "" +} + sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST { override def children(): Seq[RegexAST] = parts override def toRegexString: String = parts.map(_.toRegexString).mkString diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index f1e2c0ca3ff..70c87f34810 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -215,10 +215,15 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { "\ntest", "test\n", "\ntest\n", "\ntest\r\ntest\n")) } - test("line anchor $ fall back to CPU") { - for (mode <- Seq(RegexFindMode, RegexReplaceMode)) { - assertUnsupported("a$b", mode, "line anchor $ is not supported") - } + // test("line anchor $ fall back to CPU") { + // for (mode <- Seq(RegexFindMode, RegexReplaceMode)) { + // assertUnsupported("a$b", mode, "line anchor $ is not supported") + // } + // } + + test("line anchor $ - find") { + val patterns = Seq("\\00*[D$3]$") + assertCpuGpuMatchesRegexpFind(patterns, Seq("2+|+??wD\n")) } test("match literal $ - find") { From 2857de8a4c0b9d3964757ba3ceda8170a6386c0b Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Thu, 14 Apr 2022 12:58:26 -0700 Subject: [PATCH 02/10] Solved issue with anchor in character class (treat as literal $) Signed-off-by: Navin Kumar --- .../src/main/scala/com/nvidia/spark/rapids/RegexParser.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 8c6a80e18c5..7f2ec66081c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -651,7 +651,10 @@ class CudfRegexTranspiler(mode: RegexMode) { case _ => } val components: Seq[RegexCharacterClassComponent] = characters - .map(x => rewrite(x, None).asInstanceOf[RegexCharacterClassComponent]) + .map(x => x match { + case RegexChar(ch) if "^$".contains(ch) => x + case _ => rewrite(x, None).asInstanceOf[RegexCharacterClassComponent] + }) if (negated) { // There are differences between cuDF and Java handling of newlines From 99abe185473dbcbafaedcfc61e60a0973132fbef Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Tue, 19 Apr 2022 12:24:33 -0700 Subject: [PATCH 03/10] WIP: more dollar anchor support, a few more cases to tidy up Signed-off-by: Navin Kumar --- .../com/nvidia/spark/rapids/RegexParser.scala | 117 +++++++++++++----- .../RegularExpressionTranspilerSuite.scala | 21 ++-- 2 files changed, 101 insertions(+), 37 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index c7003d1bdee..d125ad1fc7a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -467,7 +467,7 @@ class CudfRegexTranspiler(mode: RegexMode) { // parse the source regular expression val regex = new RegexParser(pattern).parse() // validate that the regex is supported by cuDF - val cudfRegex = rewrite(regex, None) + val cudfRegex = rewrite(regex, None, None) // write out to regex string, performing minor transformations // such as adding additional escaping cudfRegex.toRegexString @@ -535,8 +535,7 @@ class CudfRegexTranspiler(mode: RegexMode) { } } - - private def rewrite(regex: RegexAST, previous: Option[RegexAST]): RegexAST = { + private def rewrite(regex: RegexAST, previous: Option[RegexAST], next: Option[RegexAST]): RegexAST = { regex match { case RegexChar(ch) => ch match { @@ -551,22 +550,51 @@ class CudfRegexTranspiler(mode: RegexMode) { case Some(RegexChar('$')) => RegexEmpty() case _ => - RegexSequence( - ListBuffer( - RegexRepetition( - RegexChar('\r'), - SimpleQuantifier('?') - ), - RegexRepetition( - RegexChar('\n'), - SimpleQuantifier('?') - ), - RegexChar('$') - ) - ) + // next match { + // case Some(RegexChar(ch)) if "\n\r\u0085\u2028\u2029".contains(ch) => + // RegexSequence( + // ListBuffer( + // RegexChar(ch), + // RegexChar('$') + // ) + // ) + // case _ => + RegexSequence( + ListBuffer( + RegexRepetition( + RegexGroup(capture = false, + RegexChoice( + RegexCharacterClass(negated=false, ListBuffer( + RegexChar('\n'), + RegexChar('\r'), + RegexChar('\u0085'), + RegexChar('\u2028'), + RegexChar('\u2029') + )), + RegexSequence( + ListBuffer( + RegexChar('\r'), + RegexChar('\n') + ) + ) + ) + ), + SimpleQuantifier('?') + ), + RegexChar('$') + ) + ) + // } } case '^' if mode == RegexSplitMode => throw new RegexUnsupportedException("line anchor ^ is not supported in split mode") + case '\r' | '\n' if mode == RegexFindMode => + previous match { + case Some(RegexChar('$')) => + RegexEmpty() + case _ => + regex + } case _ => regex } @@ -655,7 +683,7 @@ class CudfRegexTranspiler(mode: RegexMode) { val components: Seq[RegexCharacterClassComponent] = characters .map(x => x match { case RegexChar(ch) if "^$".contains(ch) => x - case _ => rewrite(x, None).asInstanceOf[RegexCharacterClassComponent] + case _ => rewrite(x, None, None).asInstanceOf[RegexCharacterClassComponent] }) if (negated) { @@ -721,14 +749,43 @@ class CudfRegexTranspiler(mode: RegexMode) { throw new RegexUnsupportedException( "sequences that only contain '^' or '$' are not supported") } - RegexSequence(parts.foldLeft(new ListBuffer[RegexAST]())((m, part) => { - if (m.isEmpty) { - m.append(rewrite(part, None)) - } else { - m.append(rewrite(part, Some(m.last))) - } - m - })) + RegexSequence(parts.foldLeft((new ListBuffer[RegexAST](), + Option.empty[RegexAST]))((m, part) => { + val (r, last) = m + last match { + case Some(RegexChar('$')) => + val j = r.lastIndexWhere { + case RegexEmpty() => false + case _ => true + } + part match { + case RegexCharacterClass(true, parts) if parts.forall(!isBeginOrEndLineAnchor(_)) => + r(j) = RegexSequence( + ListBuffer( + rewrite(part, None, None), + last.get + ) + ) + case RegexChar(ch) if "\n\r\u0085\u2028\u2029".contains(ch) => + r(j) = RegexSequence( + ListBuffer( + rewrite(part, None, None), + last.get + ) + ) + case _ => + r.append(rewrite(part, last, None)) + } + case _ => + r.append(rewrite(part, last, None)) + } + r.last match { + case RegexEmpty() => + (r, last) + case _ => + (r, Some(part)) + } + })._1) case RegexRepetition(base, quantifier) => (base, quantifier) match { case (_, SimpleQuantifier(ch)) if mode == RegexReplaceMode && "?*".contains(ch) => @@ -771,17 +828,17 @@ class CudfRegexTranspiler(mode: RegexMode) { // specifically this variable length repetition: \A{2,} throw new RegexUnsupportedException(nothingToRepeat) case (RegexGroup(_, _), SimpleQuantifier(ch)) if ch == '?' => - RegexRepetition(rewrite(base, None), quantifier) + RegexRepetition(rewrite(base, None, None), quantifier) case _ if isSupportedRepetitionBase(base) => - RegexRepetition(rewrite(base, None), quantifier) + RegexRepetition(rewrite(base, None, None), quantifier) case _ => throw new RegexUnsupportedException(nothingToRepeat) } case RegexChoice(l, r) => - val ll = rewrite(l, None) - val rr = rewrite(r, None) + val ll = rewrite(l, None, None) + val rr = rewrite(r, None, None) // cuDF does not support repetition on one side of a choice, such as "a*|a" if (isRepetition(ll) || isRepetition(rr)) { @@ -805,7 +862,7 @@ class CudfRegexTranspiler(mode: RegexMode) { RegexChoice(ll, rr) case RegexGroup(capture, term) => - RegexGroup(capture, rewrite(term, None)) + RegexGroup(capture, rewrite(term, None, None)) case other => throw new RegexUnsupportedException(s"Unhandled expression in transpiler: $other") diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 774282a8884..869ae3e5ae8 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -215,15 +215,16 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { "\ntest", "test\n", "\ntest\n", "\ntest\r\ntest\n")) } - // test("line anchor $ fall back to CPU") { - // for (mode <- Seq(RegexFindMode, RegexReplaceMode)) { - // assertUnsupported("a$b", mode, "line anchor $ is not supported") - // } - // } + test("line anchor $ fall back to CPU") { + for (mode <- Seq(RegexSplitMode, RegexReplaceMode)) { + assertUnsupported("a$b", mode, "line anchor $ is not supported") + } + } test("line anchor $ - find") { - val patterns = Seq("\\00*[D$3]$") - assertCpuGpuMatchesRegexpFind(patterns, Seq("2+|+??wD\n")) + val patterns = Seq("$\r", "a$", "\r$", "\n$", "\r\n$", "[\r\n]?$", "\\00*[D$3]$", "a$b") + val inputs = Seq("a", "a\n", "a\r", "a\r\n", "\r", "\n", "\r\n", "\n\r", "2+|+??wD\n", "a\r\nb") + assertCpuGpuMatchesRegexpFind(patterns, inputs) } test("whitespace boundaries - replace") { @@ -285,6 +286,12 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { doTranspileTest("abc\\z", "abc$") } + test("transpile $") { + doTranspileTest("a$", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") + doTranspileTest("$$\n", "\n$") + doTranspileTest("^$[^*A-ZA-Z]", "^(?:[\r\n]|[^*A-ZA-Z])$") + } + test("compare CPU and GPU: character range including unescaped + and -") { val patterns = Seq("a[-]+", "a[a-b-]+", "a[-a-b]", "a[-+]", "a[+-]") val inputs = Seq("a+", "a-", "a", "a-+", "a[a-b-]") From 9c2879cb917d60de422b9d2b369638b945ea6959 Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Tue, 19 Apr 2022 13:58:14 -0700 Subject: [PATCH 04/10] Remove next parameter and fix style issues Signed-off-by: Navin Kumar --- .../com/nvidia/spark/rapids/RegexParser.scala | 83 ++++++++----------- 1 file changed, 36 insertions(+), 47 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index d125ad1fc7a..3f1bb598980 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -467,7 +467,7 @@ class CudfRegexTranspiler(mode: RegexMode) { // parse the source regular expression val regex = new RegexParser(pattern).parse() // validate that the regex is supported by cuDF - val cudfRegex = rewrite(regex, None, None) + val cudfRegex = rewrite(regex, None) // write out to regex string, performing minor transformations // such as adding additional escaping cudfRegex.toRegexString @@ -535,7 +535,7 @@ class CudfRegexTranspiler(mode: RegexMode) { } } - private def rewrite(regex: RegexAST, previous: Option[RegexAST], next: Option[RegexAST]): RegexAST = { + private def rewrite(regex: RegexAST, previous: Option[RegexAST]): RegexAST = { regex match { case RegexChar(ch) => ch match { @@ -550,41 +550,29 @@ class CudfRegexTranspiler(mode: RegexMode) { case Some(RegexChar('$')) => RegexEmpty() case _ => - // next match { - // case Some(RegexChar(ch)) if "\n\r\u0085\u2028\u2029".contains(ch) => - // RegexSequence( - // ListBuffer( - // RegexChar(ch), - // RegexChar('$') - // ) - // ) - // case _ => - RegexSequence( - ListBuffer( - RegexRepetition( - RegexGroup(capture = false, - RegexChoice( - RegexCharacterClass(negated=false, ListBuffer( - RegexChar('\n'), - RegexChar('\r'), - RegexChar('\u0085'), - RegexChar('\u2028'), - RegexChar('\u2029') - )), - RegexSequence( - ListBuffer( - RegexChar('\r'), - RegexChar('\n') - ) - ) + RegexSequence( + ListBuffer( + RegexRepetition( + RegexGroup(capture = false, + RegexChoice( + RegexCharacterClass(negated=false, ListBuffer( + RegexChar('\n'), + RegexChar('\r'), + RegexChar('\u0085'), + RegexChar('\u2028'), + RegexChar('\u2029') + )), + RegexSequence( + ListBuffer( + RegexChar('\r'), + RegexChar('\n') ) - ), - SimpleQuantifier('?') - ), - RegexChar('$') - ) - ) - // } + ) + ) + ), + SimpleQuantifier('?') + ), + RegexChar('$'))) } case '^' if mode == RegexSplitMode => throw new RegexUnsupportedException("line anchor ^ is not supported in split mode") @@ -683,7 +671,7 @@ class CudfRegexTranspiler(mode: RegexMode) { val components: Seq[RegexCharacterClassComponent] = characters .map(x => x match { case RegexChar(ch) if "^$".contains(ch) => x - case _ => rewrite(x, None, None).asInstanceOf[RegexCharacterClassComponent] + case _ => rewrite(x, None).asInstanceOf[RegexCharacterClassComponent] }) if (negated) { @@ -759,25 +747,26 @@ class CudfRegexTranspiler(mode: RegexMode) { case _ => true } part match { - case RegexCharacterClass(true, parts) if parts.forall(!isBeginOrEndLineAnchor(_)) => + case RegexCharacterClass(true, parts) + if parts.forall(!isBeginOrEndLineAnchor(_)) => r(j) = RegexSequence( ListBuffer( - rewrite(part, None, None), + rewrite(part, None), last.get ) ) case RegexChar(ch) if "\n\r\u0085\u2028\u2029".contains(ch) => r(j) = RegexSequence( ListBuffer( - rewrite(part, None, None), + rewrite(part, None), last.get ) ) case _ => - r.append(rewrite(part, last, None)) + r.append(rewrite(part, last)) } case _ => - r.append(rewrite(part, last, None)) + r.append(rewrite(part, last)) } r.last match { case RegexEmpty() => @@ -828,17 +817,17 @@ class CudfRegexTranspiler(mode: RegexMode) { // specifically this variable length repetition: \A{2,} throw new RegexUnsupportedException(nothingToRepeat) case (RegexGroup(_, _), SimpleQuantifier(ch)) if ch == '?' => - RegexRepetition(rewrite(base, None, None), quantifier) + RegexRepetition(rewrite(base, None), quantifier) case _ if isSupportedRepetitionBase(base) => - RegexRepetition(rewrite(base, None, None), quantifier) + RegexRepetition(rewrite(base, None), quantifier) case _ => throw new RegexUnsupportedException(nothingToRepeat) } case RegexChoice(l, r) => - val ll = rewrite(l, None, None) - val rr = rewrite(r, None, None) + val ll = rewrite(l, None) + val rr = rewrite(r, None) // cuDF does not support repetition on one side of a choice, such as "a*|a" if (isRepetition(ll) || isRepetition(rr)) { @@ -862,7 +851,7 @@ class CudfRegexTranspiler(mode: RegexMode) { RegexChoice(ll, rr) case RegexGroup(capture, term) => - RegexGroup(capture, rewrite(term, None, None)) + RegexGroup(capture, rewrite(term, None)) case other => throw new RegexUnsupportedException(s"Unhandled expression in transpiler: $other") From 40e0746ca755b2f1db46a30934933417a0fbb23d Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Tue, 19 Apr 2022 15:05:59 -0700 Subject: [PATCH 05/10] Handle most edge cases with line anchor $ Signed-off-by: Navin Kumar --- .../com/nvidia/spark/rapids/RegexParser.scala | 64 ++++++++++--------- .../RegularExpressionTranspilerSuite.scala | 6 +- 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 3f1bb598980..3ca87d52f90 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -535,6 +535,26 @@ class CudfRegexTranspiler(mode: RegexMode) { } } + private val lineTerminatorChars = Seq('\n', '\r', '\u0085', '\u2028', '\u2029') + + private def lineTerminatorMatcher(exclude: Set[Char], excludeCRLF: Boolean):RegexAST = { + val terminatorChars = new ListBuffer[RegexCharacterClassComponent]() + terminatorChars ++= lineTerminatorChars.filter(!exclude.contains(_)).map(RegexChar) + + if (terminatorChars.size == 0 && excludeCRLF) { + RegexEmpty() + } else if (terminatorChars.size == 0) { + RegexGroup(capture = false, RegexSequence(ListBuffer(RegexChar('\r'), RegexChar('\n')))) + } else if (excludeCRLF) { + RegexCharacterClass(negated = false, characters = terminatorChars) + } else { + RegexGroup(capture = false, + RegexChoice( + RegexCharacterClass(negated = false, characters = terminatorChars), + RegexSequence(ListBuffer(RegexChar('\r'), RegexChar('\n'))))) + } + } + private def rewrite(regex: RegexAST, previous: Option[RegexAST]): RegexAST = { regex match { @@ -549,30 +569,16 @@ class CudfRegexTranspiler(mode: RegexMode) { previous match { case Some(RegexChar('$')) => RegexEmpty() + case Some(RegexChar(ch)) if ch == '\n' => + RegexSequence(ListBuffer( + RegexRepetition(lineTerminatorMatcher(Set('\n'), true), SimpleQuantifier('?')), + RegexChar('$'))) + case Some(RegexChar(ch)) if lineTerminatorChars.contains(ch) => + RegexChar('$') case _ => - RegexSequence( - ListBuffer( - RegexRepetition( - RegexGroup(capture = false, - RegexChoice( - RegexCharacterClass(negated=false, ListBuffer( - RegexChar('\n'), - RegexChar('\r'), - RegexChar('\u0085'), - RegexChar('\u2028'), - RegexChar('\u2029') - )), - RegexSequence( - ListBuffer( - RegexChar('\r'), - RegexChar('\n') - ) - ) - ) - ), - SimpleQuantifier('?') - ), - RegexChar('$'))) + RegexSequence(ListBuffer( + RegexRepetition(lineTerminatorMatcher(Set.empty, false), SimpleQuantifier('?')), + RegexChar('$'))) } case '^' if mode == RegexSplitMode => throw new RegexUnsupportedException("line anchor ^ is not supported in split mode") @@ -750,18 +756,14 @@ class CudfRegexTranspiler(mode: RegexMode) { case RegexCharacterClass(true, parts) if parts.forall(!isBeginOrEndLineAnchor(_)) => r(j) = RegexSequence( - ListBuffer( - rewrite(part, None), - last.get - ) - ) + ListBuffer(lineTerminatorMatcher(Set.empty, true), RegexChar('$'))) case RegexChar(ch) if "\n\r\u0085\u2028\u2029".contains(ch) => r(j) = RegexSequence( ListBuffer( rewrite(part, None), - last.get - ) - ) + RegexSequence(ListBuffer( + RegexRepetition(lineTerminatorMatcher(Set(ch), true), + SimpleQuantifier('?')), RegexChar('$'))))) case _ => r.append(rewrite(part, last)) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 869ae3e5ae8..4833ebb9fda 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -215,7 +215,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { "\ntest", "test\n", "\ntest\n", "\ntest\r\ntest\n")) } - test("line anchor $ fall back to CPU") { + test("line anchor $ fall back to CPU - split and replace") { for (mode <- Seq(RegexSplitMode, RegexReplaceMode)) { assertUnsupported("a$b", mode, "line anchor $ is not supported") } @@ -288,8 +288,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("transpile $") { doTranspileTest("a$", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") - doTranspileTest("$$\n", "\n$") - doTranspileTest("^$[^*A-ZA-Z]", "^(?:[\r\n]|[^*A-ZA-Z])$") + doTranspileTest("$$\n", "\n[\r\u0085\u2028\u2029]?$") + doTranspileTest("^$[^*A-ZA-Z]", "^[\n\r\u0085\u2028\u2029]$") } test("compare CPU and GPU: character range including unescaped + and -") { From 868a3ffaefbc510a7c6a85bcb81e621227120bcf Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Tue, 19 Apr 2022 15:29:23 -0700 Subject: [PATCH 06/10] Finish handling last few edge cases regarding carriage return vs other line termination characters Signed-off-by: Navin Kumar --- .../scala/com/nvidia/spark/rapids/RegexParser.scala | 10 +++++----- .../rapids/RegularExpressionTranspilerSuite.scala | 9 ++++++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 3ca87d52f90..579bd681d33 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -564,17 +564,17 @@ class CudfRegexTranspiler(mode: RegexMode) { RegexCharacterClass(negated = true, ListBuffer(RegexChar('\r'), RegexChar('\n'))) case '$' if mode == RegexSplitMode || mode == RegexReplaceMode => // see https://github.com/NVIDIA/spark-rapids/issues/4533 - throw new RegexUnsupportedException("line anchor $ is not supported") + throw new RegexUnsupportedException("line anchor $ is not supported in split or replace") case '$' => previous match { case Some(RegexChar('$')) => RegexEmpty() - case Some(RegexChar(ch)) if ch == '\n' => + case Some(RegexChar(ch)) if ch == '\r' => + RegexChar('$') + case Some(RegexChar(ch)) if lineTerminatorChars.contains(ch) => RegexSequence(ListBuffer( - RegexRepetition(lineTerminatorMatcher(Set('\n'), true), SimpleQuantifier('?')), + RegexRepetition(lineTerminatorMatcher(Set(ch), true), SimpleQuantifier('?')), RegexChar('$'))) - case Some(RegexChar(ch)) if lineTerminatorChars.contains(ch) => - RegexChar('$') case _ => RegexSequence(ListBuffer( RegexRepetition(lineTerminatorMatcher(Set.empty, false), SimpleQuantifier('?')), diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 4833ebb9fda..784ed464477 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -217,13 +217,16 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("line anchor $ fall back to CPU - split and replace") { for (mode <- Seq(RegexSplitMode, RegexReplaceMode)) { - assertUnsupported("a$b", mode, "line anchor $ is not supported") + assertUnsupported("a$b", mode, "line anchor $ is not supported in split or replace") } } test("line anchor $ - find") { - val patterns = Seq("$\r", "a$", "\r$", "\n$", "\r\n$", "[\r\n]?$", "\\00*[D$3]$", "a$b") - val inputs = Seq("a", "a\n", "a\r", "a\r\n", "\r", "\n", "\r\n", "\n\r", "2+|+??wD\n", "a\r\nb") + val patterns = Seq("$\r", "a$", "\r$", "\u0085$", "\u2028$", "\u2029$", "\n$", "\r\n$", + "[\r\n]?$", "\\00*[D$3]$", "a$b") + val inputs = Seq("a", "a\n", "a\r", "a\r\n", "\r", "\u0085", "\u2028", "\u2029", "\n", + "\r\n", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r", "\n\u0085", "\n\u2028", "\n\u2029", + "2+|+??wD\n", "a\r\nb") assertCpuGpuMatchesRegexpFind(patterns, inputs) } From d789db5c3218c4a4df0f08cd665253eb548f1e30 Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Wed, 20 Apr 2022 17:06:05 -0700 Subject: [PATCH 07/10] Disable one particular edge case as we don't have the underlying cudf support to transpile. Also, add more comments Signed-off-by: Navin Kumar --- .../com/nvidia/spark/rapids/RegexParser.scala | 31 ++++++++++++++++++- .../RegularExpressionTranspilerSuite.scala | 13 +++++--- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 5933ca99b71..de4b4dde939 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -537,6 +537,10 @@ class CudfRegexTranspiler(mode: RegexMode) { private val lineTerminatorChars = Seq('\n', '\r', '\u0085', '\u2028', '\u2029') + // from Java 8 documention: a line terminator is a 1 to 2 character sequence that marks + // the end of a line of an input character sequence. + // this method produces a RegexAST which outputs a regular expression to match any possible + // combination of line terminators private def lineTerminatorMatcher(exclude: Set[Char], excludeCRLF: Boolean):RegexAST = { val terminatorChars = new ListBuffer[RegexCharacterClassComponent]() terminatorChars ++= lineTerminatorChars.filter(!exclude.contains(_)).map(RegexChar) @@ -566,16 +570,29 @@ class CudfRegexTranspiler(mode: RegexMode) { // see https://github.com/NVIDIA/spark-rapids/issues/4533 throw new RegexUnsupportedException("line anchor $ is not supported in split or replace") case '$' => + // in the case of the line anchor $, the JVM has special conditions when handling line + // terminators in and around the anchor + // this handles cases where the line terminator characters are *before* the anchor ($) + // NOTE: this applies to when using *standard* mode. In multiline mode, all these + // conditions will change. Currently Spark does not use multiline mode. previous match { case Some(RegexChar('$')) => + // repeating the line anchor in cuDF (for example b$$) causes matches to fail, but in + // Java, it's treated as a single (b$ and b$$ are synonymous), so we create + // an empty RegexAST that outputs to empty string RegexEmpty() case Some(RegexChar(ch)) if ch == '\r' => + // when using the the CR (\r), it prevents the line anchor from handling any other + // line terminator sequences, so we just output the anchor and we are finished RegexChar('$') case Some(RegexChar(ch)) if lineTerminatorChars.contains(ch) => + // when using any other line terminator character, you can match any of the other + // line terminator characters individually as part of the line anchor match. RegexSequence(ListBuffer( RegexRepetition(lineTerminatorMatcher(Set(ch), true), SimpleQuantifier('?')), RegexChar('$'))) case _ => + // otherwise by default we can match any or none the full set of line terminators RegexSequence(ListBuffer( RegexRepetition(lineTerminatorMatcher(Set.empty, false), SimpleQuantifier('?')), RegexChar('$'))) @@ -743,10 +760,17 @@ class CudfRegexTranspiler(mode: RegexMode) { throw new RegexUnsupportedException( "sequences that only contain '^' or '$' are not supported") } + + // Special handling for line anchor ($) + // This code is implemented here because to make it work in cuDF, we have to reorder + // the items in the regex. + // In the JVM, regexes like "\n$" and "$\n" have similar treatment RegexSequence(parts.foldLeft((new ListBuffer[RegexAST](), Option.empty[RegexAST]))((m, part) => { val (r, last) = m last match { + // when the previous character is a line anchor ($), the JVM has special handling + // when matching against line terminator characters case Some(RegexChar('$')) => val j = r.lastIndexWhere { case RegexEmpty() => false @@ -757,7 +781,12 @@ class CudfRegexTranspiler(mode: RegexMode) { if parts.forall(!isBeginOrEndLineAnchor(_)) => r(j) = RegexSequence( ListBuffer(lineTerminatorMatcher(Set.empty, true), RegexChar('$'))) - case RegexChar(ch) if "\n\r\u0085\u2028\u2029".contains(ch) => + case RegexChar(ch) if ch == '\n' => + // what's really needed here is negative lookahead, but that is not + // supported by cuDF + // in this case: $\n would transpile to (?!\r)\n$ + throw new RegexUnsupportedException("regex sequence $\\n is not supported") + case RegexChar(ch) if "\r\u0085\u2028\u2029".contains(ch) => r(j) = RegexSequence( ListBuffer( rewrite(part, None), diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 355b10de88c..e4066ab0386 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -221,12 +221,16 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } } + test("line anchor sequence $\\n fall back to CPU") { + assertUnsupported("a$\n", RegexFindMode, "regex sequence $\\n is not supported") + } + test("line anchor $ - find") { val patterns = Seq("$\r", "a$", "\r$", "\u0085$", "\u2028$", "\u2029$", "\n$", "\r\n$", "[\r\n]?$", "\\00*[D$3]$", "a$b") - val inputs = Seq("a", "a\n", "a\r", "a\r\n", "\r", "\u0085", "\u2028", "\u2029", "\n", - "\r\n", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r", "\n\u0085", "\n\u2028", "\n\u2029", - "2+|+??wD\n", "a\r\nb") + val inputs = Seq("a", "a\n", "a\r", "a\r\n", "a\u0085\n", "\r", "\u0085", "\u2028", "\u2029", + "\n", "\r\n", "\r\n\r", "\r\n\u0085", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r", "\n\u0085", + "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb") assertCpuGpuMatchesRegexpFind(patterns, inputs) } @@ -291,7 +295,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("transpile $") { doTranspileTest("a$", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") - doTranspileTest("$$\n", "\n[\r\u0085\u2028\u2029]?$") + doTranspileTest("$$\r", "\r[\n\u0085\u2028\u2029]?$") + doTranspileTest("]$\r", "]\r[\n\u0085\u2028\u2029]?$") doTranspileTest("^$[^*A-ZA-Z]", "^[\n\r\u0085\u2028\u2029]$") } From 4f1cd551acaf0f5e48acd0ecf55587032ddb7e31 Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Wed, 20 Apr 2022 17:59:26 -0700 Subject: [PATCH 08/10] forgot to add these examples to comments Signed-off-by: Navin Kumar --- .../src/main/scala/com/nvidia/spark/rapids/RegexParser.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index de4b4dde939..bace90d6d16 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -584,10 +584,12 @@ class CudfRegexTranspiler(mode: RegexMode) { case Some(RegexChar(ch)) if ch == '\r' => // when using the the CR (\r), it prevents the line anchor from handling any other // line terminator sequences, so we just output the anchor and we are finished + // for example: \r$ -> \r$ (no transpilation) RegexChar('$') case Some(RegexChar(ch)) if lineTerminatorChars.contains(ch) => // when using any other line terminator character, you can match any of the other // line terminator characters individually as part of the line anchor match. + // for example: \n$ -> \n[\r\u0085\u2028\u2029]?$ RegexSequence(ListBuffer( RegexRepetition(lineTerminatorMatcher(Set(ch), true), SimpleQuantifier('?')), RegexChar('$'))) From 02004118ba4491592ec3efe5807f5c91c26fea80 Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Fri, 22 Apr 2022 10:55:25 -0700 Subject: [PATCH 09/10] add tests including form feed \f which is not technically a line terminator to test whitespace around a line anchor Signed-off-by: Navin Kumar --- .../rapids/RegularExpressionTranspilerSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index e4066ab0386..1b6c0e0d12f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -226,11 +226,11 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } test("line anchor $ - find") { - val patterns = Seq("$\r", "a$", "\r$", "\u0085$", "\u2028$", "\u2029$", "\n$", "\r\n$", - "[\r\n]?$", "\\00*[D$3]$", "a$b") - val inputs = Seq("a", "a\n", "a\r", "a\r\n", "a\u0085\n", "\r", "\u0085", "\u2028", "\u2029", - "\n", "\r\n", "\r\n\r", "\r\n\u0085", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r", "\n\u0085", - "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb") + val patterns = Seq("$\r", "a$", "\r$", "\f$", "$\f", "\u0085$", "\u2028$", "\u2029$", "\n$", + "\r\n$", "[\r\n]?$", "\\00*[D$3]$", "a$b") + val inputs = Seq("a", "a\n", "a\r", "a\r\n", "a\u0085\n", "a\f", "\f", "\r", "\u0085", "\u2028", + "\u2029", "\n", "\r\n", "\r\n\r", "\r\n\u0085", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r", + "\n\u0085", "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb") assertCpuGpuMatchesRegexpFind(patterns, inputs) } From d187506f877e88904b1cbb56d6d459d5beb6f680 Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Mon, 25 Apr 2022 11:21:55 -0700 Subject: [PATCH 10/10] fix additional regular expression tests that now run on GPU Signed-off-by: Navin Kumar --- .../nvidia/spark/rapids/RegularExpressionSuite.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala index c336da0c0dd..1bec4a73d45 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala @@ -93,17 +93,13 @@ class RegularExpressionSuite extends SparkQueryCompareTestSuite { frame => frame.selectExpr("regexp_replace(strings,'\\(foo\\)','D')") } - testGpuFallback("String regexp_extract regex 1", - "RegExpExtract", - extractStrings, execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec", "Alias", - "RegExpExtract", "AttributeReference", "Literal"),conf = conf) { + testSparkResultsAreEqual("String regexp_extract regex 1", + extractStrings, conf = conf) { frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 1)") } - testGpuFallback("String regexp_extract regex 2", - "RegExpExtract", - extractStrings, execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec", "Alias", - "RegExpExtract", "AttributeReference", "Literal"),conf = conf) { + testSparkResultsAreEqual("String regexp_extract regex 2", + extractStrings, conf = conf) { frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 2)") }