Skip to content

Commit

Permalink
ARROW-11960: [C++][Gandiva] Support escape in LIKE
Browse files Browse the repository at this point in the history
Add the gdv_fn_like_utf8_utf8_utf8 function in Gandiva to support an escape character in LIKE. The escape character is passed as a utf8 string literal that must be either empty or exactly one character long.

Closes apache#9700 from Crystrix/arrow-11960

Authored-by: crystrix <chenxi.li@live.com>
Signed-off-by: Sutou Kouhei <kou@clear-code.com>
(cherry picked from commit ca66567)
  • Loading branch information
Crystrix authored and jvictorhuguenin committed Sep 16, 2021
1 parent 86c72de commit ad154d2
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 4 deletions.
4 changes: 4 additions & 0 deletions cpp/src/gandiva/function_registry_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
kResultNullIfNull, "gdv_fn_like_utf8_utf8",
NativeFunction::kNeedsFunctionHolder),

NativeFunction("like", {}, DataTypeVector{utf8(), utf8(), utf8()}, boolean(),
kResultNullIfNull, "gdv_fn_like_utf8_utf8_utf8",
NativeFunction::kNeedsFunctionHolder),

NativeFunction("ltrim", {}, DataTypeVector{utf8(), utf8()}, utf8(),
kResultNullIfNull, "ltrim_utf8_utf8", NativeFunction::kNeedsContext),

Expand Down
20 changes: 20 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ bool gdv_fn_like_utf8_utf8(int64_t ptr, const char* data, int data_len,
return (*holder)(std::string(data, data_len));
}

bool gdv_fn_like_utf8_utf8_utf8(int64_t ptr, const char* data, int data_len,
const char* pattern, int pattern_len,
const char* escape_char, int escape_char_len) {
gandiva::LikeHolder* holder = reinterpret_cast<gandiva::LikeHolder*>(ptr);
return (*holder)(std::string(data, data_len));
}

double gdv_fn_random(int64_t ptr) {
gandiva::RandomGeneratorHolder* holder =
reinterpret_cast<gandiva::RandomGeneratorHolder*>(ptr);
Expand Down Expand Up @@ -732,6 +739,19 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
types->i1_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_like_utf8_utf8));

// gdv_fn_like_utf8_utf8_utf8
args = {types->i64_type(), // int64_t ptr
types->i8_ptr_type(), // const char* data
types->i32_type(), // int data_len
types->i8_ptr_type(), // const char* pattern
types->i32_type(), // int pattern_len
types->i8_ptr_type(), // const char* escape_char
types->i32_type()}; // int escape_char_len

engine->AddGlobalMappingForFunc("gdv_fn_like_utf8_utf8_utf8",
types->i1_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_like_utf8_utf8_utf8));

// gdv_fn_to_date_utf8_utf8
args = {types->i64_type(), // int64_t execution_context
types->i64_type(), // int64_t holder_ptr
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ using gdv_day_time_interval = int64_t;
// SQL LIKE stub taking a holder pointer, the input bytes and the pattern bytes.
// The definition returns the result of applying the holder to `data[0, data_len)`.
// NOTE(review): `ptr` is presumably the address of a gandiva::LikeHolder, as in the
// three-argument variant — confirm against the definition.
bool gdv_fn_like_utf8_utf8(int64_t ptr, const char* data, int data_len,
const char* pattern, int pattern_len);

// SQL LIKE stub with an additional escape-character argument.
// The definition casts `ptr` to a gandiva::LikeHolder and applies it to
// `data[0, data_len)`; `pattern` and `escape_char` are not read by the stub body.
bool gdv_fn_like_utf8_utf8_utf8(int64_t ptr, const char* data, int data_len,
const char* pattern, int pattern_len,
const char* escape_char, int escape_char_len);

int64_t gdv_fn_to_date_utf8_utf8_int32(int64_t context, int64_t ptr, const char* data,
int data_len, bool in1_validity,
const char* pattern, int pattern_len,
Expand Down
43 changes: 39 additions & 4 deletions cpp/src/gandiva/like_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ static bool IsArrowStringLiteral(arrow::Type::type type) {
}

Status LikeHolder::Make(const FunctionNode& node, std::shared_ptr<LikeHolder>* holder) {
ARROW_RETURN_IF(node.children().size() != 2,
Status::Invalid("'like' function requires two parameters"));
ARROW_RETURN_IF(node.children().size() != 2 && node.children().size() != 3,
Status::Invalid("'like' function requires two or three parameters"));

auto literal = dynamic_cast<LiteralNode*>(node.children().at(1).get());
ARROW_RETURN_IF(
Expand All @@ -80,8 +80,22 @@ Status LikeHolder::Make(const FunctionNode& node, std::shared_ptr<LikeHolder>* h
!IsArrowStringLiteral(literal_type),
Status::Invalid(
"'like' function requires a string literal as the second parameter"));

return Make(arrow::util::get<std::string>(literal->holder()), holder);
if (node.children().size() == 2) {
return Make(arrow::util::get<std::string>(literal->holder()), holder);
} else {
auto escape_char = dynamic_cast<LiteralNode*>(node.children().at(2).get());
ARROW_RETURN_IF(
escape_char == nullptr,
Status::Invalid("'like' function requires a literal as the third parameter"));

auto escape_char_type = escape_char->return_type()->id();
ARROW_RETURN_IF(
!IsArrowStringLiteral(escape_char_type),
Status::Invalid(
"'like' function requires a string literal as the third parameter"));
return Make(arrow::util::get<std::string>(literal->holder()),
arrow::util::get<std::string>(escape_char->holder()), holder);
}
}

Status LikeHolder::Make(const std::string& sql_pattern,
Expand All @@ -97,4 +111,25 @@ Status LikeHolder::Make(const std::string& sql_pattern,
return Status::OK();
}

// Builds a LikeHolder for `sql_pattern` with an optional escape character.
//
// `escape_char` must be empty or exactly one character long; anything longer is
// rejected with Status::Invalid. An empty `escape_char` converts the pattern without
// an escape character. On success the new holder is stored in `*holder`; if the
// resulting RE2 pattern does not compile, Status::Invalid is returned and `*holder`
// is left untouched.
Status LikeHolder::Make(const std::string& sql_pattern, const std::string& escape_char,
                        std::shared_ptr<LikeHolder>* holder) {
  const auto escape_len = escape_char.length();
  ARROW_RETURN_IF(escape_len > 1,
                  Status::Invalid("The length of escape char ", escape_char,
                                  " in 'like' function is greater than 1"));

  // Translate the SQL pattern into a regex, with or without an escape character.
  std::string regex_source;
  if (escape_len == 0) {
    ARROW_RETURN_NOT_OK(RegexUtil::SqlLikePatternToPcre(sql_pattern, regex_source));
  } else {
    ARROW_RETURN_NOT_OK(
        RegexUtil::SqlLikePatternToPcre(sql_pattern, escape_char.at(0), regex_source));
  }

  std::shared_ptr<LikeHolder> built(new LikeHolder(regex_source));
  ARROW_RETURN_IF(!built->regex_.ok(),
                  Status::Invalid("Building RE2 pattern '", regex_source, "' failed"));

  *holder = built;
  return Status::OK();
}

} // namespace gandiva
3 changes: 3 additions & 0 deletions cpp/src/gandiva/like_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class GANDIVA_EXPORT LikeHolder : public FunctionHolder {

// Creates a LikeHolder from a SQL LIKE pattern (no escape character argument).
// NOTE(review): on success the holder is presumably returned through `holder` —
// confirm against the definition.
static Status Make(const std::string& sql_pattern, std::shared_ptr<LikeHolder>* holder);

// Creates a LikeHolder from a SQL LIKE pattern and an escape character.
// `escape_char` must be empty (no escaping) or exactly one character long; a longer
// value yields Status::Invalid. On success the holder is stored in `*holder`.
static Status Make(const std::string& sql_pattern, const std::string& escape_char,
std::shared_ptr<LikeHolder>* holder);

// Try and optimise a function node with a "like" pattern.
static const FunctionNode TryOptimize(const FunctionNode& node);

Expand Down
84 changes: 84 additions & 0 deletions cpp/src/gandiva/like_holder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ class TestLikeHolder : public ::testing::Test {
std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(pattern), false);
return FunctionNode("like", {field, pattern_node}, arrow::boolean());
}

// Builds a three-child "like" node: like(in, pattern, escape_char), where `in` is a
// utf8 field, `pattern` a utf8 literal and `escape_char` an int8 literal.
// NOTE(review): LikeHolder::Make(const FunctionNode&, ...) requires a string literal
// as the third parameter, while this helper produces an int8 literal — confirm that
// the mismatch is intentional for the tests that use it.
FunctionNode BuildLike(std::string pattern, char escape_char) {
  const auto escape_value = static_cast<int8_t>(escape_char);
  auto input_field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8()));
  auto pattern_literal =
      std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(pattern), false);
  auto escape_literal =
      std::make_shared<LiteralNode>(arrow::int8(), LiteralHolder(escape_value), false);
  return FunctionNode("like", {input_field, pattern_literal, escape_literal},
                      arrow::boolean());
}
};

TEST_F(TestLikeHolder, TestMatchAny) {
Expand Down Expand Up @@ -125,6 +135,80 @@ TEST_F(TestLikeHolder, TestOptimise) {

fnode = LikeHolder::TryOptimize(BuildLike("x_yz%"));
EXPECT_EQ(fnode.descriptor()->name(), "like");

// no optimisation for escaped pattern.
fnode = LikeHolder::TryOptimize(BuildLike("\\%xyz", '\\'));
EXPECT_EQ(fnode.descriptor()->name(), "like");
EXPECT_EQ(fnode.ToString(),
"bool like((string) in, (const string) \\%xyz, (const int8) \\)");
}

// With '\' as the escape character, the pattern "ab\_" must match only the literal
// text "ab_": the escaped '_' no longer acts as a single-character wildcard.
TEST_F(TestLikeHolder, TestMatchOneEscape) {
  std::shared_ptr<LikeHolder> like_holder;

  auto status = LikeHolder::Make("ab\\_", "\\", &like_holder);
  // Fatal assertion: like_holder is dereferenced below and is still null if Make
  // failed, so the test must stop here instead of crashing.
  ASSERT_TRUE(status.ok()) << status.message();

  auto& like = *like_holder;

  EXPECT_TRUE(like("ab_"));

  EXPECT_FALSE(like("abc"));
  EXPECT_FALSE(like("abd"));
  EXPECT_FALSE(like("a"));
  EXPECT_FALSE(like("abcd"));
  EXPECT_FALSE(like("dabc"));
}

// With '\' as the escape character, the pattern "ab\%" must match only the literal
// text "ab%": the escaped '%' no longer acts as a multi-character wildcard.
TEST_F(TestLikeHolder, TestMatchManyEscape) {
  std::shared_ptr<LikeHolder> like_holder;

  auto status = LikeHolder::Make("ab\\%", "\\", &like_holder);
  // Fatal assertion: like_holder is dereferenced below and is still null if Make
  // failed, so the test must stop here instead of crashing.
  ASSERT_TRUE(status.ok()) << status.message();

  auto& like = *like_holder;

  EXPECT_TRUE(like("ab%"));

  EXPECT_FALSE(like("abc"));
  EXPECT_FALSE(like("abd"));
  EXPECT_FALSE(like("a"));
  EXPECT_FALSE(like("abcd"));
  EXPECT_FALSE(like("dabc"));
}

// An escaped escape character ("\\" in the pattern, with '\' as the escape
// character) must match a single literal backslash.
TEST_F(TestLikeHolder, TestMatchEscape) {
  std::shared_ptr<LikeHolder> like_holder;

  auto status = LikeHolder::Make("ab\\\\", "\\", &like_holder);
  // Fatal assertion: like_holder is dereferenced below and is still null if Make
  // failed, so the test must stop here instead of crashing.
  ASSERT_TRUE(status.ok()) << status.message();

  auto& like = *like_holder;

  EXPECT_TRUE(like("ab\\"));

  EXPECT_FALSE(like("abc"));
}

// An empty escape string disables escaping: in "ab\_" the backslash is an ordinary
// character and '_' remains a single-character wildcard.
TEST_F(TestLikeHolder, TestEmptyEscapeChar) {
  std::shared_ptr<LikeHolder> like_holder;

  auto status = LikeHolder::Make("ab\\_", "", &like_holder);
  // Fatal assertion: like_holder is dereferenced below and is still null if Make
  // failed, so the test must stop here instead of crashing.
  ASSERT_TRUE(status.ok()) << status.message();

  auto& like = *like_holder;

  EXPECT_TRUE(like("ab\\c"));
  EXPECT_TRUE(like("ab\\_"));

  EXPECT_FALSE(like("ab\\_d"));
  EXPECT_FALSE(like("ab__"));
}

// An escape string longer than one character must be rejected by LikeHolder::Make.
TEST_F(TestLikeHolder, TestMultipleEscapeChar) {
  std::shared_ptr<LikeHolder> holder;

  const auto make_status = LikeHolder::Make("ab\\_", "\\\\", &holder);
  EXPECT_FALSE(make_status.ok()) << make_status.message();
}
} // namespace gandiva
43 changes: 43 additions & 0 deletions cpp/src/gandiva/tests/utf8_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,49 @@ TEST_F(TestUtf8, TestLike) {
EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}

// End-to-end check of the three-argument "like": with '\' as the escape character,
// the pattern "%pa\%rk%" must match only inputs that contain the literal "pa%rk".
TEST_F(TestUtf8, TestLikeWithEscape) {
  // schema for input fields
  auto field_a = field("a", utf8());
  auto schema = arrow::schema({field_a});

  // output fields
  auto res = field("res", boolean());

  // build expressions.
  // like(a, literal(s), '\')

  auto node_a = TreeExprBuilder::MakeField(field_a);
  auto literal_s = TreeExprBuilder::MakeStringLiteral("%pa\\%rk%");
  auto escape_char = TreeExprBuilder::MakeStringLiteral("\\");
  auto is_like =
      TreeExprBuilder::MakeFunction("like", {node_a, literal_s, escape_char}, boolean());
  auto expr = TreeExprBuilder::MakeExpression(is_like, res);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Fatal assertion: `projector` is dereferenced below, so stop if it was not built.
  ASSERT_TRUE(status.ok()) << status.message();

  // Create a row-batch with some sample data
  int num_records = 4;
  auto array_a = MakeArrowArrayUtf8(
      {"park", "spa%rkle", "bright spa%rk and fire", "spark"}, {true, true, true, true});

  // expected output
  auto exp = MakeArrowArrayBool({false, true, true, false}, {true, true, true, true});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Fatal assertion: outputs.at(0) below is only meaningful after a successful
  // evaluation.
  ASSERT_TRUE(status.ok()) << status.message();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}

TEST_F(TestUtf8, TestBeginsEnds) {
// schema for input fields
auto field_a = field("a", utf8());
Expand Down

0 comments on commit ad154d2

Please sign in to comment.