Skip to content

Commit

Permalink
Re-enable empty repetition near end-of-line anchor for rlike, regexp_extract and regexp_replace (#8081)
Browse files Browse the repository at this point in the history

* WIP: fix false positive repetition near line anchor bug

Signed-off-by: Navin Kumar <navink@nvidia.com>

* Enable repetition near end of line anchor edge case

Signed-off-by: Navin Kumar <navink@nvidia.com>

* Add regexp_replace test with empty repetition

Signed-off-by: Navin Kumar <navink@nvidia.com>

* Need to add unicode enabled check for these 2 unit tests

Signed-off-by: Navin Kumar <navink@nvidia.com>

---------

Signed-off-by: Navin Kumar <navink@nvidia.com>
  • Loading branch information
NVnavkumar authored Apr 13, 2023
1 parent 5907b39 commit 90cc0a3
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 23 deletions.
23 changes: 14 additions & 9 deletions integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,21 +427,23 @@ def test_regexp_extract():
'regexp_extract(a, "([0-9]+)", 1)',
'regexp_extract(a, "([0-9])([abcd]+)", 1)',
'regexp_extract(a, "([0-9])([abcd]+)", 2)',
'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)\\z", 1)',
'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)\\z", 2)',
'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)\\z", 3)',
'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", 1)',
'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", 2)',
'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", 3)',
'regexp_extract(a, "^([a-d]*)([0-9]*)\\\\/([a-d]*)", 3)',
'regexp_extract(a, "^([a-d]*)([0-9]*)(\\\\/[a-d]*)", 3)'),
'regexp_extract(a, "^([a-d]*)([0-9]*)\\\\/([a-d]*)$", 3)',
'regexp_extract(a, "^([a-d]*)([0-9]*)(\\\\/[a-d]*)", 3)',
'regexp_extract(a, "^([a-d]*)([0-9]*)(\\\\/[a-d]*)$", 3)'),
conf=_regexp_conf)

def test_regexp_extract_no_match():
gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)\\z", 0)',
'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)\\z", 1)',
'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)\\z", 2)',
'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)\\z", 3)'),
'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)$", 0)',
'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)$", 1)',
'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)$", 2)',
'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)$", 3)'),
conf=_regexp_conf)

# if we determine that the index is out of range we fall back to CPU and let
Expand Down Expand Up @@ -680,7 +682,9 @@ def test_rlike():
'a rlike "a{2}"',
'a rlike "a{1,3}"',
'a rlike "a{1,}"',
'a rlike "a[bc]d"'),
'a rlike "a[bc]d"',
'a rlike "a[bc]d"',
'a rlike "^[a-d]*$"'),
conf=_regexp_conf)

def test_rlike_embedded_null():
Expand Down Expand Up @@ -813,6 +817,7 @@ def test_regexp_replace_unicode_support():
'REGEXP_REPLACE(a, "TEST䤫", "PROD")',
'REGEXP_REPLACE(a, "TEST[䤫]", "PROD")',
'REGEXP_REPLACE(a, "TEST.*\\\\d", "PROD")',
'REGEXP_REPLACE(a, "TEST[85]*$", "PROD")',
'REGEXP_REPLACE(a, "TEST.+$", "PROD")',
),
conf=_regexp_conf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,18 @@ class CudfRegexTranspiler(mode: RegexMode) {
// check a pair of regex ast nodes for unsupported combinations
// of end string/line anchors and newlines or optional items
/**
 * Validates a pair of regex AST nodes and rejects the pair when an end of line/string
 * anchor sits next to a construct that this check treats as unsupported.
 *
 * The pair is rejected when either of these holds:
 *  - `containsEndAnchor(r1)` and `r2` satisfies `containsNewline`, `containsEmpty`
 *    or `containsBeginAnchor`
 *  - `containsEndAnchor(r2)` and `r1` satisfies `containsNewline` or
 *    `containsBeginAnchor`
 *
 * @param r1 first node of the pair
 * @param r2 second node of the pair
 * @throws RegexUnsupportedException raised with `r1.position` when the pair is rejected
 */
def checkEndAnchorContext(r1: RegexAST, r2: RegexAST): Unit = {
  // end anchor in the first node, flagged construct in the second
  def anchorThenFlagged: Boolean =
    containsEndAnchor(r1) &&
      (containsNewline(r2) || containsEmpty(r2) || containsBeginAnchor(r2))

  // flagged construct in the first node, end anchor in the second
  def flaggedThenAnchor: Boolean =
    containsEndAnchor(r2) && (containsNewline(r1) || containsBeginAnchor(r1))

  if (anchorThenFlagged || flaggedThenAnchor) {
    // both nodes are rendered back to back, with nothing between them
    val offendingPair =
      s"${toReadableString(r1.toRegexString)}${toReadableString(r2.toRegexString)}"
    throw new RegexUnsupportedException(
      s"End of line/string anchor is not supported in this context: $offendingPair",
      r1.position)
  }
}

def checkEndAnchorContextSplit(r1: RegexAST, r2: RegexAST): Unit = {
if ((containsEndAnchor(r1) &&
(containsNewline(r2) || containsEmpty(r2) || containsBeginAnchor(r2))) ||
(containsEndAnchor(r2) &&
Expand All @@ -957,7 +969,11 @@ class CudfRegexTranspiler(mode: RegexMode) {
regex match {
case RegexSequence(parts) =>
for (i <- 1 until parts.length) {
checkEndAnchorContext(parts(i - 1), parts(i))
if (mode == RegexSplitMode) {
checkEndAnchorContextSplit(parts(i - 1), parts(i))
} else {
checkEndAnchorContext(parts(i - 1), parts(i))
}
}
case RegexChoice(l, r) =>
checkUnsupported(l)
Expand All @@ -966,7 +982,11 @@ class CudfRegexTranspiler(mode: RegexMode) {
case RegexRepetition(ast, _) => checkUnsupported(ast)
case RegexCharacterClass(_, components) =>
for (i <- 1 until components.length) {
checkEndAnchorContext(components(i - 1), components(i))
if (mode == RegexSplitMode) {
checkEndAnchorContextSplit(components(i - 1), components(i))
} else {
checkEndAnchorContext(components(i - 1), components(i))
}
}
case _ =>
// ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,16 @@ class RegularExpressionSuite extends SparkQueryCompareTestSuite {
frame.selectExpr("regexp_replace(strings,'\\(foo\\)','D')")
}

// https://github.com/NVIDIA/spark-rapids/issues/5659
testGpuFallback("String regexp_extract regex 1",
"ProjectExec", extractStrings, conf = conf,
execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) {
frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 1)")
testSparkResultsAreEqual("String regexp_extract regex 1", extractStrings, conf = conf) {
frame =>
assume(isUnicodeEnabled())
frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 1)")
}

// https://github.com/NVIDIA/spark-rapids/issues/5659
testGpuFallback("String regexp_extract regex 2",
"ProjectExec", extractStrings, conf = conf,
execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) {
frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 2)")
testSparkResultsAreEqual("String regexp_extract regex 2", extractStrings, conf = conf) {
frame =>
assume(isUnicodeEnabled())
frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 2)")
}

// note that regexp_extract with a literal string gets replaced with the literal result of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,26 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
)
}

test("zero-length repetition near line anchor - regexp_find") {
  // each regex ends in `$` or `\Z` and contains at least one `*` repetition
  val anchoredRepetitions = Seq(
    "\\00*[D$3]$",
    "\\00*[D$3]\\Z",
    "^([a-z]*)([0-9]*)([a-z]*)$")
  // sample strings, some of which end in line terminators
  val samples = Seq("abcd", "abc012abc", "999abb", "\\00D", "D", "D\n", "\\00D\n\r")
  assertCpuGpuMatchesRegexpFind(anchoredRepetitions, samples)
}

test("zero-length repetition near line anchor - regexp_replace") {
  // each regex ends in `$` or `\Z` and contains at least one `*` repetition
  val anchoredRepetitions = Seq(
    "\\00*[D$3]$",
    "\\00*[D$3]\\Z",
    "^([a-z]*)([0-9]*)([a-z]*)$")
  // sample strings, some of which end in line terminators
  val samples = Seq("abcd", "abc012abc", "999abb", "\\00D", "D", "D\n", "\\00D\n\r")
  assertCpuGpuMatchesRegexpReplace(anchoredRepetitions, samples)
}

test("zero-length repetition near line anchor - regexp_split") {
  // message fragment that must be reported for every pattern in split mode
  val expectedMessage = "End of line/string anchor is not supported in this context"
  val anchoredRepetitions = Set(
    "\\00*[D$3]$",
    "\\00*[D$3]\\Z",
    "^([a-z]*)([0-9]*)([a-z]*)$")
  for (regex <- anchoredRepetitions) {
    assertUnsupported(regex, RegexSplitMode, expectedMessage)
  }
}

test("cuDF unsupported choice cases") {
val patterns = Seq("c*|d*", "c*|dog", "[cat]{3}|dog")
patterns.foreach(pattern => {
Expand Down Expand Up @@ -303,7 +323,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
"\n\u0085", "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb")
assertCpuGpuMatchesRegexpFind(patterns, inputs)
val unsupportedPatterns = Seq("[\r\n]?$", "$\r", "\r$",
"\u0085$", "\u2028$", "\u2029$", "\n$", "\r\n$", "\\00*[D$3]$")
// "\u0085$", "\u2028$", "\u2029$", "\n$", "\r\n$", "[D$3]$")
"\u0085$", "\u2028$", "\u2029$", "\n$", "\r\n$")
for (pattern <- unsupportedPatterns) {
assertUnsupported(pattern, RegexFindMode,
"End of line/string anchor is not supported in this context")
Expand All @@ -317,7 +338,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
"\n\u0085", "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb")
assertCpuGpuMatchesRegexpFind(patterns, inputs)
val unsupportedPatterns = Seq("[\r\n]?\\Z", "\\Z\r", "\r\\Z",
"\u0085\\Z", "\u2028\\Z", "\u2029\\Z", "\n\\Z", "\r\n\\Z", "\\00*[D$3]\\Z")
"\u0085\\Z", "\u2028\\Z", "\u2029\\Z", "\n\\Z", "\r\n\\Z")
for (pattern <- unsupportedPatterns) {
assertUnsupported(pattern, RegexFindMode,
"End of line/string anchor is not supported in this context")
Expand Down

0 comments on commit 90cc0a3

Please sign in to comment.