Skip to content

Commit

Permalink
Re-enable dollar ($) line anchor in regular expressions in find mode (#…
Browse files Browse the repository at this point in the history
…5289)

* WIP: dollar anchor support

Signed-off-by: Navin Kumar <navink@nvidia.com>

* Solved issue with anchor in character class (treat as literal $)

Signed-off-by: Navin Kumar <navink@nvidia.com>

* WIP: more dollar anchor support, a few more cases to tidy up

Signed-off-by: Navin Kumar <navink@nvidia.com>

* Remove next parameter and fix style issues

Signed-off-by: Navin Kumar <navink@nvidia.com>

* Handle most edge cases with line anchor $

Signed-off-by: Navin Kumar <navink@nvidia.com>

* Finish handling last few edge cases regarding carriage return vs other line termination characters

Signed-off-by: Navin Kumar <navink@nvidia.com>

* Disable one particular edge case as we don't have the underlying cudf support to transpile. Also, add more comments

Signed-off-by: Navin Kumar <navink@nvidia.com>

* forgot to add these examples to comments

Signed-off-by: Navin Kumar <navink@nvidia.com>

* add tests including form feed \f which is not technically a line terminator to test whitespace around a line anchor

Signed-off-by: Navin Kumar <navink@nvidia.com>

* fix additional regular expression tests that now run on GPU

Signed-off-by: Navin Kumar <navink@nvidia.com>
  • Loading branch information
NVnavkumar authored Apr 26, 2022
1 parent 785b4ac commit cc3af4b
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 22 deletions.
135 changes: 124 additions & 11 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
// parse the source regular expression
val regex = new RegexParser(pattern).parse()
// validate that the regex is supported by cuDF
val cudfRegex = rewrite(regex)
val cudfRegex = rewrite(regex, None)
// write out to regex string, performing minor transformations
// such as adding additional escaping
cudfRegex.toRegexString
Expand Down Expand Up @@ -535,19 +535,79 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}

private val lineTerminatorChars = Seq('\n', '\r', '\u0085', '\u2028', '\u2029')

// from Java 8 documention: a line terminator is a 1 to 2 character sequence that marks
// the end of a line of an input character sequence.
// this method produces a RegexAST which outputs a regular expression to match any possible
// combination of line terminators
private def lineTerminatorMatcher(exclude: Set[Char], excludeCRLF: Boolean):RegexAST = {
val terminatorChars = new ListBuffer[RegexCharacterClassComponent]()
terminatorChars ++= lineTerminatorChars.filter(!exclude.contains(_)).map(RegexChar)

if (terminatorChars.size == 0 && excludeCRLF) {
RegexEmpty()
} else if (terminatorChars.size == 0) {
RegexGroup(capture = false, RegexSequence(ListBuffer(RegexChar('\r'), RegexChar('\n'))))
} else if (excludeCRLF) {
RegexCharacterClass(negated = false, characters = terminatorChars)
} else {
RegexGroup(capture = false,
RegexChoice(
RegexCharacterClass(negated = false, characters = terminatorChars),
RegexSequence(ListBuffer(RegexChar('\r'), RegexChar('\n')))))
}
}

private def rewrite(regex: RegexAST): RegexAST = {
private def rewrite(regex: RegexAST, previous: Option[RegexAST]): RegexAST = {
regex match {

case RegexChar(ch) => ch match {
case '.' =>
// workaround for https://github.com/rapidsai/cudf/issues/9619
RegexCharacterClass(negated = true, ListBuffer(RegexChar('\r'), RegexChar('\n')))
case '$' =>
case '$' if mode == RegexSplitMode || mode == RegexReplaceMode =>
// see https://github.com/NVIDIA/spark-rapids/issues/4533
throw new RegexUnsupportedException("line anchor $ is not supported")
throw new RegexUnsupportedException("line anchor $ is not supported in split or replace")
case '$' =>
// in the case of the line anchor $, the JVM has special conditions when handling line
// terminators in and around the anchor
// this handles cases where the line terminator characters are *before* the anchor ($)
// NOTE: this applies to when using *standard* mode. In multiline mode, all these
// conditions will change. Currently Spark does not use multiline mode.
previous match {
case Some(RegexChar('$')) =>
// repeating the line anchor in cuDF (for example b$$) causes matches to fail, but in
// Java, it's treated as a single (b$ and b$$ are synonymous), so we create
// an empty RegexAST that outputs to empty string
RegexEmpty()
case Some(RegexChar(ch)) if ch == '\r' =>
// when using the the CR (\r), it prevents the line anchor from handling any other
// line terminator sequences, so we just output the anchor and we are finished
// for example: \r$ -> \r$ (no transpilation)
RegexChar('$')
case Some(RegexChar(ch)) if lineTerminatorChars.contains(ch) =>
// when using any other line terminator character, you can match any of the other
// line terminator characters individually as part of the line anchor match.
// for example: \n$ -> \n[\r\u0085\u2028\u2029]?$
RegexSequence(ListBuffer(
RegexRepetition(lineTerminatorMatcher(Set(ch), true), SimpleQuantifier('?')),
RegexChar('$')))
case _ =>
// otherwise by default we can match any or none the full set of line terminators
RegexSequence(ListBuffer(
RegexRepetition(lineTerminatorMatcher(Set.empty, false), SimpleQuantifier('?')),
RegexChar('$')))
}
case '^' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("line anchor ^ is not supported in split mode")
case '\r' | '\n' if mode == RegexFindMode =>
previous match {
case Some(RegexChar('$')) =>
RegexEmpty()
case _ =>
regex
}
case _ =>
regex
}
Expand Down Expand Up @@ -634,7 +694,10 @@ class CudfRegexTranspiler(mode: RegexMode) {
case _ =>
}
val components: Seq[RegexCharacterClassComponent] = characters
.map(x => rewrite(x).asInstanceOf[RegexCharacterClassComponent])
.map(x => x match {
case RegexChar(ch) if "^$".contains(ch) => x
case _ => rewrite(x, None).asInstanceOf[RegexCharacterClassComponent]
})

if (negated) {
// There are differences between cuDF and Java handling of newlines
Expand Down Expand Up @@ -699,7 +762,52 @@ class CudfRegexTranspiler(mode: RegexMode) {
throw new RegexUnsupportedException(
"sequences that only contain '^' or '$' are not supported")
}
RegexSequence(parts.map(rewrite))

// Special handling for line anchor ($)
// This code is implemented here because to make it work in cuDF, we have to reorder
// the items in the regex.
// In the JVM, regexes like "\n$" and "$\n" have similar treatment
RegexSequence(parts.foldLeft((new ListBuffer[RegexAST](),
Option.empty[RegexAST]))((m, part) => {
val (r, last) = m
last match {
// when the previous character is a line anchor ($), the JVM has special handling
// when matching against line terminator characters
case Some(RegexChar('$')) =>
val j = r.lastIndexWhere {
case RegexEmpty() => false
case _ => true
}
part match {
case RegexCharacterClass(true, parts)
if parts.forall(!isBeginOrEndLineAnchor(_)) =>
r(j) = RegexSequence(
ListBuffer(lineTerminatorMatcher(Set.empty, true), RegexChar('$')))
case RegexChar(ch) if ch == '\n' =>
// what's really needed here is negative lookahead, but that is not
// supported by cuDF
// in this case: $\n would transpile to (?!\r)\n$
throw new RegexUnsupportedException("regex sequence $\\n is not supported")
case RegexChar(ch) if "\r\u0085\u2028\u2029".contains(ch) =>
r(j) = RegexSequence(
ListBuffer(
rewrite(part, None),
RegexSequence(ListBuffer(
RegexRepetition(lineTerminatorMatcher(Set(ch), true),
SimpleQuantifier('?')), RegexChar('$')))))
case _ =>
r.append(rewrite(part, last))
}
case _ =>
r.append(rewrite(part, last))
}
r.last match {
case RegexEmpty() =>
(r, last)
case _ =>
(r, Some(part))
}
})._1)

case RegexRepetition(base, quantifier) => (base, quantifier) match {
case (_, SimpleQuantifier(ch)) if mode == RegexReplaceMode && "?*".contains(ch) =>
Expand Down Expand Up @@ -742,17 +850,17 @@ class CudfRegexTranspiler(mode: RegexMode) {
// specifically this variable length repetition: \A{2,}
throw new RegexUnsupportedException(nothingToRepeat)
case (RegexGroup(_, _), SimpleQuantifier(ch)) if ch == '?' =>
RegexRepetition(rewrite(base), quantifier)
RegexRepetition(rewrite(base, None), quantifier)
case _ if isSupportedRepetitionBase(base) =>
RegexRepetition(rewrite(base), quantifier)
RegexRepetition(rewrite(base, None), quantifier)
case _ =>
throw new RegexUnsupportedException(nothingToRepeat)

}

case RegexChoice(l, r) =>
val ll = rewrite(l)
val rr = rewrite(r)
val ll = rewrite(l, None)
val rr = rewrite(r, None)

// cuDF does not support repetition on one side of a choice, such as "a*|a"
if (isRepetition(ll) || isRepetition(rr)) {
Expand All @@ -776,7 +884,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
RegexChoice(ll, rr)

case RegexGroup(capture, term) =>
RegexGroup(capture, rewrite(term))
RegexGroup(capture, rewrite(term, None))

case other =>
throw new RegexUnsupportedException(s"Unhandled expression in transpiler: $other")
Expand Down Expand Up @@ -804,6 +912,11 @@ sealed trait RegexAST {
def toRegexString: String
}

sealed case class RegexEmpty() extends RegexAST {
override def children(): Seq[RegexAST] = Seq.empty
override def toRegexString: String = ""
}

sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST {
override def children(): Seq[RegexAST] = parts
override def toRegexString: String = parts.map(_.toRegexString).mkString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,13 @@ class RegularExpressionSuite extends SparkQueryCompareTestSuite {
frame => frame.selectExpr("regexp_replace(strings,'\\(foo\\)','D')")
}

testGpuFallback("String regexp_extract regex 1",
"RegExpExtract",
extractStrings, execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec", "Alias",
"RegExpExtract", "AttributeReference", "Literal"),conf = conf) {
testSparkResultsAreEqual("String regexp_extract regex 1",
extractStrings, conf = conf) {
frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 1)")
}

testGpuFallback("String regexp_extract regex 2",
"RegExpExtract",
extractStrings, execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec", "Alias",
"RegExpExtract", "AttributeReference", "Literal"),conf = conf) {
testSparkResultsAreEqual("String regexp_extract regex 2",
extractStrings, conf = conf) {
frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 2)")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,25 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
"\ntest", "test\n", "\ntest\n", "\ntest\r\ntest\n"))
}

test("line anchor $ fall back to CPU") {
for (mode <- Seq(RegexFindMode, RegexReplaceMode)) {
assertUnsupported("a$b", mode, "line anchor $ is not supported")
test("line anchor $ fall back to CPU - split and replace") {
for (mode <- Seq(RegexSplitMode, RegexReplaceMode)) {
assertUnsupported("a$b", mode, "line anchor $ is not supported in split or replace")
}
}

test("line anchor sequence $\\n fall back to CPU") {
assertUnsupported("a$\n", RegexFindMode, "regex sequence $\\n is not supported")
}

test("line anchor $ - find") {
val patterns = Seq("$\r", "a$", "\r$", "\f$", "$\f", "\u0085$", "\u2028$", "\u2029$", "\n$",
"\r\n$", "[\r\n]?$", "\\00*[D$3]$", "a$b")
val inputs = Seq("a", "a\n", "a\r", "a\r\n", "a\u0085\n", "a\f", "\f", "\r", "\u0085", "\u2028",
"\u2029", "\n", "\r\n", "\r\n\r", "\r\n\u0085", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r",
"\n\u0085", "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb")
assertCpuGpuMatchesRegexpFind(patterns, inputs)
}

test("whitespace boundaries - replace") {
assertCpuGpuMatchesRegexpReplace(
Seq("\\s", "\\S"),
Expand Down Expand Up @@ -280,6 +293,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
doTranspileTest("abc\\z", "abc$")
}

test("transpile $") {
doTranspileTest("a$", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
doTranspileTest("$$\r", "\r[\n\u0085\u2028\u2029]?$")
doTranspileTest("]$\r", "]\r[\n\u0085\u2028\u2029]?$")
doTranspileTest("^$[^*A-ZA-Z]", "^[\n\r\u0085\u2028\u2029]$")
}

test("compare CPU and GPU: character range including unescaped + and -") {
val patterns = Seq("a[-]+", "a[a-b-]+", "a[-a-b]", "a[-+]", "a[+-]")
val inputs = Seq("a+", "a-", "a", "a-+", "a[a-b-]")
Expand Down

0 comments on commit cc3af4b

Please sign in to comment.