Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-enable dollar ($) line anchor in regular expressions in find mode #5289

Merged
merged 15 commits into from
Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 124 additions & 11 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
// parse the source regular expression
val regex = new RegexParser(pattern).parse()
// validate that the regex is supported by cuDF
val cudfRegex = rewrite(regex)
val cudfRegex = rewrite(regex, None)
// write out to regex string, performing minor transformations
// such as adding additional escaping
cudfRegex.toRegexString
Expand Down Expand Up @@ -535,19 +535,79 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}

private val lineTerminatorChars = Seq('\n', '\r', '\u0085', '\u2028', '\u2029')

// Per the Java 8 documentation, a line terminator is a one- or two-character sequence that
// marks the end of a line of the input character sequence.
// Builds a RegexAST that renders to a regular expression matching any of the line
// terminators left after dropping the single characters in `exclude` and, when
// `excludeCRLF` is true, the two-character sequence \r\n.
private def lineTerminatorMatcher(exclude: Set[Char], excludeCRLF: Boolean): RegexAST = {
  // single-character terminators that the caller did not exclude
  val singles = new ListBuffer[RegexCharacterClassComponent]()
  singles ++= lineTerminatorChars.collect { case c if !exclude.contains(c) => RegexChar(c) }

  // the only two-character terminator: carriage return followed by line feed
  def crlf: RegexAST = RegexSequence(ListBuffer(RegexChar('\r'), RegexChar('\n')))

  (singles.isEmpty, excludeCRLF) match {
    case (true, true) =>
      // nothing left to match
      RegexEmpty()
    case (true, false) =>
      // only \r\n remains
      RegexGroup(capture = false, crlf)
    case (false, true) =>
      // only single characters remain, so a plain character class is enough
      RegexCharacterClass(negated = false, characters = singles)
    case (false, false) =>
      // either one of the single characters or the \r\n pair
      RegexGroup(capture = false,
        RegexChoice(RegexCharacterClass(negated = false, characters = singles), crlf))
  }
}

private def rewrite(regex: RegexAST): RegexAST = {
private def rewrite(regex: RegexAST, previous: Option[RegexAST]): RegexAST = {
regex match {

case RegexChar(ch) => ch match {
case '.' =>
// workaround for https://github.com/rapidsai/cudf/issues/9619
RegexCharacterClass(negated = true, ListBuffer(RegexChar('\r'), RegexChar('\n')))
case '$' =>
case '$' if mode == RegexSplitMode || mode == RegexReplaceMode =>
// see https://github.com/NVIDIA/spark-rapids/issues/4533
throw new RegexUnsupportedException("line anchor $ is not supported")
throw new RegexUnsupportedException("line anchor $ is not supported in split or replace")
case '$' =>
// in the case of the line anchor $, the JVM has special conditions when handling line
// terminators in and around the anchor
// this handles cases where the line terminator characters are *before* the anchor ($)
// NOTE: this applies to when using *standard* mode. In multiline mode, all these
// conditions will change. Currently Spark does not use multiline mode.
previous match {
case Some(RegexChar('$')) =>
// repeating the line anchor in cuDF (for example b$$) causes matches to fail, but in
// Java, it's treated as a single anchor (b$ and b$$ are synonymous), so we create
// an empty RegexAST that outputs to empty string
RegexEmpty()
case Some(RegexChar(ch)) if ch == '\r' =>
// when using the CR (\r), it prevents the line anchor from handling any other
// line terminator sequences, so we just output the anchor and we are finished
// for example: \r$ -> \r$ (no transpilation)
RegexChar('$')
case Some(RegexChar(ch)) if lineTerminatorChars.contains(ch) =>
// when using any other line terminator character, you can match any of the other
// line terminator characters individually as part of the line anchor match.
// for example: \n$ -> \n[\r\u0085\u2028\u2029]?$
RegexSequence(ListBuffer(
RegexRepetition(lineTerminatorMatcher(Set(ch), true), SimpleQuantifier('?')),
RegexChar('$')))
case _ =>
// otherwise by default we can match any or none of the full set of line terminators
RegexSequence(ListBuffer(
RegexRepetition(lineTerminatorMatcher(Set.empty, false), SimpleQuantifier('?')),
RegexChar('$')))
}
case '^' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("line anchor ^ is not supported in split mode")
case '\r' | '\n' if mode == RegexFindMode =>
previous match {
case Some(RegexChar('$')) =>
RegexEmpty()
case _ =>
regex
}
case _ =>
regex
}
Expand Down Expand Up @@ -634,7 +694,10 @@ class CudfRegexTranspiler(mode: RegexMode) {
case _ =>
}
val components: Seq[RegexCharacterClassComponent] = characters
.map(x => rewrite(x).asInstanceOf[RegexCharacterClassComponent])
.map(x => x match {
case RegexChar(ch) if "^$".contains(ch) => x
case _ => rewrite(x, None).asInstanceOf[RegexCharacterClassComponent]
})

if (negated) {
// There are differences between cuDF and Java handling of newlines
Expand Down Expand Up @@ -699,7 +762,52 @@ class CudfRegexTranspiler(mode: RegexMode) {
throw new RegexUnsupportedException(
"sequences that only contain '^' or '$' are not supported")
}
RegexSequence(parts.map(rewrite))

// Special handling for line anchor ($)
// This is implemented here because making it work in cuDF requires reordering
// the items in the regex.
// In the JVM, regexes like "\n$" and "$\n" have similar treatment
RegexSequence(parts.foldLeft((new ListBuffer[RegexAST](),
Option.empty[RegexAST]))((m, part) => {
val (r, last) = m
last match {
// when the previous character is a line anchor ($), the JVM has special handling
// when matching against line terminator characters
case Some(RegexChar('$')) =>
val j = r.lastIndexWhere {
case RegexEmpty() => false
case _ => true
}
part match {
case RegexCharacterClass(true, parts)
if parts.forall(!isBeginOrEndLineAnchor(_)) =>
r(j) = RegexSequence(
ListBuffer(lineTerminatorMatcher(Set.empty, true), RegexChar('$')))
case RegexChar(ch) if ch == '\n' =>
// what's really needed here is negative lookahead, but that is not
// supported by cuDF
// in this case: $\n would transpile to (?!\r)\n$
throw new RegexUnsupportedException("regex sequence $\\n is not supported")
case RegexChar(ch) if "\r\u0085\u2028\u2029".contains(ch) =>
r(j) = RegexSequence(
ListBuffer(
rewrite(part, None),
RegexSequence(ListBuffer(
RegexRepetition(lineTerminatorMatcher(Set(ch), true),
SimpleQuantifier('?')), RegexChar('$')))))
case _ =>
r.append(rewrite(part, last))
}
case _ =>
r.append(rewrite(part, last))
}
r.last match {
case RegexEmpty() =>
(r, last)
case _ =>
(r, Some(part))
}
})._1)

case RegexRepetition(base, quantifier) => (base, quantifier) match {
case (_, SimpleQuantifier(ch)) if mode == RegexReplaceMode && "?*".contains(ch) =>
Expand Down Expand Up @@ -742,17 +850,17 @@ class CudfRegexTranspiler(mode: RegexMode) {
// specifically this variable length repetition: \A{2,}
throw new RegexUnsupportedException(nothingToRepeat)
case (RegexGroup(_, _), SimpleQuantifier(ch)) if ch == '?' =>
RegexRepetition(rewrite(base), quantifier)
RegexRepetition(rewrite(base, None), quantifier)
case _ if isSupportedRepetitionBase(base) =>
RegexRepetition(rewrite(base), quantifier)
RegexRepetition(rewrite(base, None), quantifier)
case _ =>
throw new RegexUnsupportedException(nothingToRepeat)

}

case RegexChoice(l, r) =>
val ll = rewrite(l)
val rr = rewrite(r)
val ll = rewrite(l, None)
val rr = rewrite(r, None)

// cuDF does not support repetition on one side of a choice, such as "a*|a"
if (isRepetition(ll) || isRepetition(rr)) {
Expand All @@ -776,7 +884,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
RegexChoice(ll, rr)

case RegexGroup(capture, term) =>
RegexGroup(capture, rewrite(term))
RegexGroup(capture, rewrite(term, None))

case other =>
throw new RegexUnsupportedException(s"Unhandled expression in transpiler: $other")
Expand Down Expand Up @@ -804,6 +912,11 @@ sealed trait RegexAST {
def toRegexString: String
}

/**
 * AST node that stands for "nothing": it has no children and renders as the empty string.
 */
sealed case class RegexEmpty() extends RegexAST {
  override def toRegexString: String = ""
  override def children(): Seq[RegexAST] = Seq.empty[RegexAST]
}

sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST {
override def children(): Seq[RegexAST] = parts
override def toRegexString: String = parts.map(_.toRegexString).mkString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,13 @@ class RegularExpressionSuite extends SparkQueryCompareTestSuite {
frame => frame.selectExpr("regexp_replace(strings,'\\(foo\\)','D')")
}

testGpuFallback("String regexp_extract regex 1",
"RegExpExtract",
extractStrings, execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec", "Alias",
"RegExpExtract", "AttributeReference", "Literal"),conf = conf) {
testSparkResultsAreEqual("String regexp_extract regex 1",
extractStrings, conf = conf) {
frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 1)")
}

testGpuFallback("String regexp_extract regex 2",
"RegExpExtract",
extractStrings, execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec", "Alias",
"RegExpExtract", "AttributeReference", "Literal"),conf = conf) {
testSparkResultsAreEqual("String regexp_extract regex 2",
extractStrings, conf = conf) {
frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 2)")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,25 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
"\ntest", "test\n", "\ntest\n", "\ntest\r\ntest\n"))
}

test("line anchor $ fall back to CPU") {
for (mode <- Seq(RegexFindMode, RegexReplaceMode)) {
assertUnsupported("a$b", mode, "line anchor $ is not supported")
test("line anchor $ fall back to CPU - split and replace") {
for (mode <- Seq(RegexSplitMode, RegexReplaceMode)) {
assertUnsupported("a$b", mode, "line anchor $ is not supported in split or replace")
}
}

// In find mode, `$` immediately followed by `\n` is rejected by the transpiler: per the
// comment at the throw site, an equivalent pattern would need negative lookahead, which
// cuDF does not support. This test pins the resulting error message.
test("line anchor sequence $\\n fall back to CPU") {
assertUnsupported("a$\n", RegexFindMode, "regex sequence $\\n is not supported")
}

// Runs find-mode patterns that put `$` before, after, or away from line-terminator
// characters against inputs ending in single terminators and two-character combinations.
// NOTE(review): assertCpuGpuMatchesRegexpFind is not visible here; by its name it compares
// CPU and GPU results for every pattern/input pair - confirm against the helper.
test("line anchor $ - find") {
// patterns: `$` alone after a literal, after each terminator, before \r and \f, and mid-pattern
val patterns = Seq("$\r", "a$", "\r$", "\f$", "$\f", "\u0085$", "\u2028$", "\u2029$", "\n$",
"\r\n$", "[\r\n]?$", "\\00*[D$3]$", "a$b")
// inputs: plain text, each terminator on its own, and pairs of terminators in both orders
val inputs = Seq("a", "a\n", "a\r", "a\r\n", "a\u0085\n", "a\f", "\f", "\r", "\u0085", "\u2028",
"\u2029", "\n", "\r\n", "\r\n\r", "\r\n\u0085", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r",
"\n\u0085", "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb")
assertCpuGpuMatchesRegexpFind(patterns, inputs)
}

test("whitespace boundaries - replace") {
assertCpuGpuMatchesRegexpReplace(
Seq("\\s", "\\S"),
Expand Down Expand Up @@ -280,6 +293,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
doTranspileTest("abc\\z", "abc$")
}

// Pins the transpiled output for patterns containing the `$` line anchor.
// NOTE(review): doTranspileTest is not visible here; it appears to take
// (source pattern, expected transpiled pattern) - confirm against the helper.
test("transpile $") {
// default case: an optional "any terminator or \r\n" group is emitted before `$`
doTranspileTest("a$", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
// a repeated `$` collapses to one, and a `\r` following `$` is moved in front of the anchor
doTranspileTest("$$\r", "\r[\n\u0085\u2028\u2029]?$")
// same reordering when `$` is preceded by an ordinary character
doTranspileTest("]$\r", "]\r[\n\u0085\u2028\u2029]?$")
// a negated character class after `$` is replaced by a terminator class followed by `$`
doTranspileTest("^$[^*A-ZA-Z]", "^[\n\r\u0085\u2028\u2029]$")
}

test("compare CPU and GPU: character range including unescaped + and -") {
val patterns = Seq("a[-]+", "a[a-b-]+", "a[-a-b]", "a[-+]", "a[+-]")
val inputs = Seq("a+", "a-", "a", "a-+", "a[a-b-]")
Expand Down