Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix regex parsing logic handling of nested quantifiers #16798

Merged
merged 9 commits into from
Oct 10, 2024
40 changes: 27 additions & 13 deletions cpp/src/strings/regex/regcomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -716,13 +716,13 @@ class regex_parser {
if (item.type != COUNTED && item.type != COUNTED_LAZY) {
out.push_back(item);
if (item.type == LBRA || item.type == LBRA_NC) {
lbra_stack.push(index);
lbra_stack.push(out.size() - 1);
repeat_start_index = -1;
} else if (item.type == RBRA) {
repeat_start_index = lbra_stack.top();
lbra_stack.pop();
} else if ((item.type & ITEM_MASK) != OPERATOR_MASK) {
repeat_start_index = index;
repeat_start_index = out.size() - 1;
}
} else {
// item is of type COUNTED or COUNTED_LAZY
Expand All @@ -731,26 +731,39 @@ class regex_parser {
CUDF_EXPECTS(repeat_start_index >= 0, "regex: invalid counted quantifier location");

// range of affected item(s) to repeat
auto const begin = in.begin() + repeat_start_index;
auto const end = in.begin() + index;
auto const begin = out.begin() + repeat_start_index;
auto const end = out.end();

// count range values
auto const n = item.d.count.n; // minimum count
auto const m = item.d.count.m; // maximum count

assert(n >= 0 && "invalid repeat count value n");
// zero-repeat edge-case: need to erase the previous items
if (n == 0) { out.erase(out.end() - (index - repeat_start_index), out.end()); }

// minimum repeats (n)
for (int j = 1; j < n; j++) {
out.insert(out.end(), begin, end);
if (n == 0 && m == 0) { out.erase(begin, end); }

std::vector<regex_parser::Item> repeat_copy(begin, end);
// special handling for quantified capture groups
if ((n > 1) && (*begin).type == LBRA) {
(*begin).type = LBRA_NC; // change first one to non-capture
// add intermediate groups as non-capture
vyasr marked this conversation as resolved.
Show resolved Hide resolved
std::vector<regex_parser::Item> ncg_copy(begin, end);
for (int j = 1; j < (n - 1); j++) {
out.insert(out.end(), ncg_copy.begin(), ncg_copy.end());
}
// add the last entry as a regular capture-group
out.insert(out.end(), repeat_copy.begin(), repeat_copy.end());
} else {
// minimum repeats (n)
for (int j = 1; j < n; j++) {
out.insert(out.end(), repeat_copy.begin(), repeat_copy.end());
}
}

// optional maximum repeats (m)
if (m >= 0) {
for (int j = n; j < m; j++) {
out.emplace_back(LBRA_NC, 0);
out.insert(out.end(), begin, end);
out.insert(out.end(), repeat_copy.begin(), repeat_copy.end());
}
for (int j = n; j < m; j++) {
out.emplace_back(RBRA, 0);
Expand All @@ -760,8 +773,9 @@ class regex_parser {
// infinite repeats
if (n > 0) { // append '+' after last repetition
out.emplace_back(item.type == COUNTED ? PLUS : PLUS_LAZY, 0);
} else { // copy it once then append '*'
out.insert(out.end(), begin, end);
} else {
// copy it once then append '*'
out.insert(out.end(), repeat_copy.begin(), repeat_copy.end());
out.emplace_back(item.type == COUNTED ? STAR : STAR_LAZY, 0);
}
}
Expand Down
14 changes: 14 additions & 0 deletions cpp/tests/strings/contains_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,20 @@ TEST_F(StringsContainsTests, FixedQuantifier)
}
}

// Checks contains_re with a counted quantifier applied to a capture group that
// itself holds a counted quantifier: (\d{4}\s){4}.
TEST_F(StringsContainsTests, NestedQuantifier)
{
  // Rows 0 and 3 hold four consecutive runs of "four digits + whitespace".
  // Row 1 interleaves letters; row 2 has only three such runs because the
  // final "3333" has no trailing whitespace.
  auto strings = cudf::test::strings_column_wrapper({"TEST12 1111 2222 3333 4444 5555",
                                                     "0000 AAAA 9999 BBBB 8888",
                                                     "7777 6666 4444 3333",
                                                     "12345 3333 4444 1111 ABCD"});
  auto const strings_view = cudf::strings_column_view(strings);

  std::string const regex_text(R"((\d{4}\s){4})");
  cudf::test::fixed_width_column_wrapper<bool> expected({true, false, false, true});

  auto const program = cudf::strings::regex_program::create(regex_text);
  auto const matches = cudf::strings::contains_re(strings_view, *program);
  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*matches, expected);
}

TEST_F(StringsContainsTests, QuantifierErrors)
{
EXPECT_THROW(cudf::strings::regex_program::create("^+"), cudf::logic_error);
Expand Down
16 changes: 15 additions & 1 deletion cpp/tests/strings/extract_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_utilities.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/debug_utilities.hpp>
#include <cudf_test/table_utilities.hpp>

#include <cudf/detail/iterator.cuh>
Expand Down Expand Up @@ -240,6 +239,21 @@ TEST_F(StringsExtractTests, SpecialNewLines)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view().column(0), expected);
}

// Checks extract with a counted quantifier applied to a capture group that
// itself holds a counted quantifier: (\d{4}\s){4}.
TEST_F(StringsExtractTests, NestedQuantifier)
{
  auto strings = cudf::test::strings_column_wrapper({"TEST12 1111 2222 3333 4444 5555",
                                                     "0000 AAAA 9999 BBBB 8888",
                                                     "7777 6666 4444 3333",
                                                     "12345 3333 4444 1111 ABCD"});
  auto const strings_view = cudf::strings_column_view(strings);

  std::string const regex_text(R"((\d{4}\s){4})");
  auto const program   = cudf::strings::regex_program::create(regex_text);
  auto const extracted = cudf::strings::extract(strings_view, *program);

  // A fixed quantifier on a capture group only honors the last repetition, so
  // the single output column holds the fourth "dddd " run for rows 0 and 3;
  // rows 1 and 2 are expected to be null (validity 0).
  auto expected = cudf::test::strings_column_wrapper({"4444 ", "", "", "1111 "}, {1, 0, 0, 1});
  CUDF_TEST_EXPECT_COLUMNS_EQUAL(extracted->view().column(0), expected);
}

TEST_F(StringsExtractTests, EmptyExtractTest)
{
std::vector<char const*> h_strings{nullptr, "AAA", "AAA_A", "AAA_AAA_", "A__", ""};
Expand Down
Loading