Skip to content

Commit

Permalink
Fix issue in get_json_object (apache#50)
Browse files Browse the repository at this point in the history
* Correct a typo

* Fix issues in getting result and cover some corner cases

* Return null if given field has no data in json string

* Move out the setting for out_valid

* Add missing args and re-order them

* Change arg order

* Fix incorrect number of args issue
  • Loading branch information
PHILO-HE authored Dec 1, 2021
1 parent dc54cda commit e99f4f8
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 28 deletions.
2 changes: 1 addition & 1 deletion cpp/src/gandiva/function_holder_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class FunctionHolderRegistry {
static map_type maker_map = {
{"like", LAMBDA_MAKER(LikeHolder)},
{"ilike", LAMBDA_MAKER(LikeHolder)},
{"get_json_obejct", LAMBDA_MAKER(JsonHolder)},
{"get_json_object", LAMBDA_MAKER(JsonHolder)},
{"to_date", LAMBDA_MAKER(ToDateHolder)},
{"random", LAMBDA_MAKER(RandomGeneratorHolder)},
{"rand", LAMBDA_MAKER(RandomGeneratorHolder)},
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/gandiva/function_registry_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,9 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
NativeFunction::kNeedsFunctionHolder),

NativeFunction("get_json_object", {}, DataTypeVector{utf8(), utf8()}, utf8(),
kResultNullIfNull, "gdv_fn_get_json_object_utf8_utf8",
NativeFunction::kNeedsFunctionHolder),
kResultNullInternal, "gdv_fn_get_json_object_utf8_utf8",
NativeFunction::kNeedsContext | NativeFunction::kNeedsFunctionHolder |
NativeFunction::kCanReturnErrors),

NativeFunction("ltrim", {}, DataTypeVector{utf8(), utf8()}, utf8(),
kResultNullIfNull, "ltrim_utf8_utf8", NativeFunction::kNeedsContext),
Expand Down
36 changes: 26 additions & 10 deletions cpp/src/gandiva/gdv_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,23 @@

extern "C" {

const uint8_t* gdv_fn_get_json_object_utf8_utf8(int64_t ptr, const char* data, int data_len,
const char* pattern, int pattern_len, int32_t* out_len) {
gandiva::JsonHolder* holder = reinterpret_cast<gandiva::JsonHolder*>(ptr);
return (*holder)(std::string(data, data_len), std::string(pattern, pattern_len), out_len);
const uint8_t* gdv_fn_get_json_object_utf8_utf8(int64_t ptr, int64_t holder_ptr, const char* data, int data_len, bool in1_valid,
const char* pattern, int pattern_len, bool in2_valid, bool* out_valid, int32_t* out_len) {
if (!in1_valid || !in2_valid) {
*out_valid = false;
*out_len = 0;
return reinterpret_cast<const uint8_t*>("");
}
gandiva::ExecutionContext* context = reinterpret_cast<gandiva::ExecutionContext*>(ptr);
gandiva::JsonHolder* holder = reinterpret_cast<gandiva::JsonHolder*>(holder_ptr);
auto res = (*holder)(context, std::string(data, data_len), std::string(pattern, pattern_len), out_len);
if (res == nullptr) {
*out_valid = false;
*out_len = 0;
return reinterpret_cast<const uint8_t*>("");
}
*out_valid = true;
return res;
}

bool gdv_fn_like_utf8_utf8(int64_t ptr, const char* data, int data_len,
Expand Down Expand Up @@ -492,12 +505,15 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {

// gdv_fn_get_json_object_utf8_utf8
args = {types->i64_type(), // int64_t ptr
types->i64_type(), // int64_t holder_ptr
types->i8_ptr_type(), // const char* data
types->i32_type(), // int data_len
types->i1_type(), // bool in1_validity
types->i8_ptr_type(), // const char* pattern
types->i32_type(), // int pattern_len
types->i32_ptr_type()}; // int out_len

types->i1_type(), // bool in2_validity
types->ptr_type(types->i8_type()), // bool* out_valid
types->i32_ptr_type()}; // int out_len
engine->AddGlobalMappingForFunc("gdv_fn_get_json_object_utf8_utf8",
types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_get_json_object_utf8_utf8));
Expand Down Expand Up @@ -655,7 +671,7 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
args = {types->i64_type(), // int64_t context_ptr
types->i8_ptr_type(), // const char* data
types->i32_type(), // int32_t lenr
types->i1_type(), // bool in2_validity
types->i1_type(), // bool in1_validity
types->ptr_type(types->i8_type())}; // bool* out_valid

engine->AddGlobalMappingForFunc("gdv_fn_castINT_or_null_utf8", types->i32_type(), args,
Expand All @@ -671,7 +687,7 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
args = {types->i64_type(), // int64_t context_ptr
types->i8_ptr_type(), // const char* data
types->i32_type(), // int32_t lenr
types->i1_type(), // bool in2_validity
types->i1_type(), // bool in1_validity
types->ptr_type(types->i8_type())}; // bool* out_valid

engine->AddGlobalMappingForFunc("gdv_fn_castBIGINT_or_null_utf8", types->i64_type(), args,
Expand All @@ -687,7 +703,7 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
args = {types->i64_type(), // int64_t context_ptr
types->i8_ptr_type(), // const char* data
types->i32_type(), // int32_t lenr
types->i1_type(), // bool in2_validity
types->i1_type(), // bool in1_validity
types->ptr_type(types->i8_type())}; // bool* out_valid

engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_or_null_utf8", types->float_type(), args,
Expand All @@ -703,7 +719,7 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
args = {types->i64_type(), // int64_t context_ptr
types->i8_ptr_type(), // const char* data
types->i32_type(), // int32_t lenr
types->i1_type(), // bool in2_validity
types->i1_type(), // bool in1_validity
types->ptr_type(types->i8_type())}; // bool* out_valid

engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_or_null_utf8", types->double_type(), args,
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/gandiva/gdv_function_stubs.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ bool gdv_fn_like_utf8_utf8_utf8(int64_t ptr, const char* data, int data_len,
bool gdv_fn_ilike_utf8_utf8(int64_t ptr, const char* data, int data_len,
const char* pattern, int pattern_len);

const uint8_t* gdv_fn_get_json_object_utf8_utf8(int64_t ptr, const char* data, int data_len,
const char* pattern, int pattern_len, int32_t* out_len);
const uint8_t* gdv_fn_get_json_object_utf8_utf8(int64_t ptr, int64_t holder_ptr, const char* data, int data_len, bool in1_valid,
const char* pattern, int pattern_len, bool in2_valid, bool* out_valid, int32_t* out_len);

int64_t gdv_fn_to_date_utf8_utf8_int32(int64_t context, int64_t ptr, const char* data,
int data_len, bool in1_validity,
Expand Down
26 changes: 18 additions & 8 deletions cpp/src/gandiva/json_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,36 @@ Status JsonHolder::Make(std::shared_ptr<JsonHolder>* holder) {
return Status::OK();
}

const uint8_t* JsonHolder::operator()(const std::string& json_str, const std::string& json_path, int32_t* out_len) {

const uint8_t* JsonHolder::operator()(gandiva::ExecutionContext* ctx, const std::string& json_str, const std::string& json_path, int32_t* out_len) {
std::unique_ptr<arrow::json::BlockParser> parser;
(arrow::json::BlockParser::Make(parse_options_, &parser));

(parser->Parse(std::make_shared<arrow::Buffer>(json_str)));
std::shared_ptr<arrow::Array> parsed;
(parser->Finish(&parsed));
auto struct_parsed = std::dynamic_pointer_cast<arrow::StructArray>(parsed);

//json_path example: $.col_14, will extract col_14 here
// needs to gurad failure here
if (json_path.length() < 3) {
return nullptr;
}
auto col_name = json_path.substr(2);

// illegal json string.
if (struct_parsed == nullptr) {
return nullptr;
}
auto dict_parsed = std::dynamic_pointer_cast<arrow::DictionaryArray>(
struct_parsed->GetFieldByName(col_name));
// no data contained for given field.
if (dict_parsed == nullptr) {
return nullptr;
}
auto dict_array = dict_parsed->dictionary();
auto uft8_array = std::dynamic_pointer_cast<arrow::BinaryArray>(dict_array);

return uft8_array->GetValue(0, out_len);
auto utf8_array = std::dynamic_pointer_cast<arrow::BinaryArray>(dict_array);
auto res = utf8_array->GetValue(0, out_len);

uint8_t* result_buffer = reinterpret_cast<uint8_t*>(ctx->arena()->Allocate(*out_len));
memcpy(result_buffer, std::string((char*)res, *out_len).data(), *out_len);
return result_buffer;
}

} // namespace gandiva
5 changes: 3 additions & 2 deletions cpp/src/gandiva/json_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "arrow/json/api.h"
#include "arrow/json/parser.h"
#include "arrow/status.h"
#include "gandiva/execution_context.h"
#include "gandiva/function_holder.h"
#include "gandiva/node.h"
#include "gandiva/visibility.h"
Expand All @@ -39,8 +40,8 @@ class GANDIVA_EXPORT JsonHolder : public FunctionHolder {
static Status Make(std::shared_ptr<JsonHolder>* holder);

//TODO(): should try to return const uint8_t *
const uint8_t* operator()(const std::string& json_str, const std::string& json_path, int32_t* out_len);

const uint8_t* operator()(ExecutionContext* ctx, const std::string& json_str, const std::string& json_path, int32_t* out_len);
arrow::json::ParseOptions parse_options_ = arrow::json::ParseOptions::Defaults();
arrow::json::ReadOptions read_options_ = arrow::json::ReadOptions::Defaults();
};
Expand Down
25 changes: 22 additions & 3 deletions cpp/src/gandiva/json_holder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,39 @@

namespace gandiva {

class TestJsonHolder : public ::testing::Test {};
class TestJsonHolder : public ::testing::Test {
protected:
ExecutionContext execution_context_;
};

TEST_F(TestJsonHolder, TestJson) {
std::shared_ptr<JsonHolder> json_holder;

auto status = JsonHolder::Make(&json_holder);
EXPECT_EQ(status.ok(), true) << status.message();

auto& get_json_object = *json_holder;
auto get_json_object = *json_holder;

int32_t out_len;
const uint8_t* data = get_json_object(R"({"hello": 3.5 })", "$.hello", &out_len);

const uint8_t* data = get_json_object(&execution_context_, R"({"hello": 3.5})", "$.hello", &out_len);
EXPECT_EQ(std::string((char*)data, out_len), "3.5");

// no data contained for given field.
data = get_json_object(&execution_context_, R"({"hello": 3.5})", "$.hi", &out_len);
EXPECT_EQ(data, nullptr);

// illegal json string.
data = get_json_object(&execution_context_, R"({"hello"-3.5})", "$.hello", &out_len);
EXPECT_EQ(data, nullptr);

// illegal field is given.
data = get_json_object(&execution_context_, R"({"hello": 3.5})", "$xx", &out_len);
EXPECT_EQ(data, nullptr);

// illegal field is given and a short string field.
data = get_json_object(&execution_context_, R"({"hello": 3.5})", "$", &out_len);
EXPECT_EQ(data, nullptr);
}

} // namespace gandiva

0 comments on commit e99f4f8

Please sign in to comment.