diff --git a/cpp/src/gandiva/function_holder_registry.h b/cpp/src/gandiva/function_holder_registry.h index 5c2d978b3921a..ed111c86dcfdf 100644 --- a/cpp/src/gandiva/function_holder_registry.h +++ b/cpp/src/gandiva/function_holder_registry.h @@ -65,7 +65,7 @@ class FunctionHolderRegistry { static map_type maker_map = { {"like", LAMBDA_MAKER(LikeHolder)}, {"ilike", LAMBDA_MAKER(LikeHolder)}, - {"get_json_obejct", LAMBDA_MAKER(JsonHolder)}, + {"get_json_object", LAMBDA_MAKER(JsonHolder)}, {"to_date", LAMBDA_MAKER(ToDateHolder)}, {"random", LAMBDA_MAKER(RandomGeneratorHolder)}, {"rand", LAMBDA_MAKER(RandomGeneratorHolder)}, diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index fe8a75333b5b0..457b87df9949e 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -169,8 +169,9 @@ std::vector GetStringFunctionRegistry() { NativeFunction::kNeedsFunctionHolder), NativeFunction("get_json_object", {}, DataTypeVector{utf8(), utf8()}, utf8(), - kResultNullIfNull, "gdv_fn_get_json_object_utf8_utf8", - NativeFunction::kNeedsFunctionHolder), + kResultNullInternal, "gdv_fn_get_json_object_utf8_utf8", + NativeFunction::kNeedsContext | NativeFunction::kNeedsFunctionHolder | + NativeFunction::kCanReturnErrors), NativeFunction("ltrim", {}, DataTypeVector{utf8(), utf8()}, utf8(), kResultNullIfNull, "ltrim_utf8_utf8", NativeFunction::kNeedsContext), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index db1802ae6b8dc..952af2f1e43e9 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -38,10 +38,23 @@ extern "C" { -const uint8_t* gdv_fn_get_json_object_utf8_utf8(int64_t ptr, const char* data, int data_len, - const char* pattern, int pattern_len, int32_t* out_len) { - gandiva::JsonHolder* holder = reinterpret_cast(ptr); - return (*holder)(std::string(data, data_len), std::string(pattern, pattern_len), out_len); +const uint8_t* gdv_fn_get_json_object_utf8_utf8(int64_t ptr, int64_t holder_ptr, const char* data, int data_len, bool in1_valid, + const char* pattern, int pattern_len, bool in2_valid, bool* out_valid, int32_t* out_len) { + if (!in1_valid || !in2_valid) { + *out_valid = false; + *out_len = 0; + return reinterpret_cast(""); + } + gandiva::ExecutionContext* context = reinterpret_cast(ptr); + gandiva::JsonHolder* holder = reinterpret_cast(holder_ptr); + auto res = (*holder)(context, std::string(data, data_len), std::string(pattern, pattern_len), out_len); + if (res == nullptr) { + *out_valid = false; + *out_len = 0; + return reinterpret_cast(""); + } + *out_valid = true; + return res; } bool gdv_fn_like_utf8_utf8(int64_t ptr, const char* data, int data_len, @@ -492,12 +505,15 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { // gdv_fn_get_json_object_utf8_utf8 args = {types->i64_type(), // int64_t ptr + types->i64_type(), // int64_t holder_ptr types->i8_ptr_type(), // const char* data types->i32_type(), // int data_len + types->i1_type(), // bool in1_validity types->i8_ptr_type(), // const char* pattern types->i32_type(), // int pattern_len - types->i32_ptr_type()}; // int out_len - + types->i1_type(), // bool in2_validity + types->ptr_type(types->i8_type()), // bool* out_valid + types->i32_ptr_type()}; // int out_len engine->AddGlobalMappingForFunc("gdv_fn_get_json_object_utf8_utf8", types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(gdv_fn_get_json_object_utf8_utf8)); @@ -655,7 +671,7 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { args = {types->i64_type(), // int64_t context_ptr types->i8_ptr_type(), // const char* data types->i32_type(), // int32_t lenr - types->i1_type(), // bool in2_validity + types->i1_type(), // bool in1_validity types->ptr_type(types->i8_type())}; // bool* out_valid engine->AddGlobalMappingForFunc("gdv_fn_castINT_or_null_utf8", types->i32_type(), args, @@ -671,7 +687,7 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { args = {types->i64_type(), // int64_t context_ptr types->i8_ptr_type(), // const char* data types->i32_type(), // int32_t lenr - types->i1_type(), // bool in2_validity + types->i1_type(), // bool in1_validity types->ptr_type(types->i8_type())}; // bool* out_valid engine->AddGlobalMappingForFunc("gdv_fn_castBIGINT_or_null_utf8", types->i64_type(), args, @@ -687,7 +703,7 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { args = {types->i64_type(), // int64_t context_ptr types->i8_ptr_type(), // const char* data types->i32_type(), // int32_t lenr - types->i1_type(), // bool in2_validity + types->i1_type(), // bool in1_validity types->ptr_type(types->i8_type())}; // bool* out_valid engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_or_null_utf8", types->float_type(), args, @@ -703,7 +719,7 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { args = {types->i64_type(), // int64_t context_ptr types->i8_ptr_type(), // const char* data types->i32_type(), // int32_t lenr - types->i1_type(), // bool in2_validity + types->i1_type(), // bool in1_validity types->ptr_type(types->i8_type())}; // bool* out_valid engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_or_null_utf8", types->double_type(), args, diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index ea32aa5cefc01..2a92f25493640 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -54,8 +54,8 @@ bool gdv_fn_like_utf8_utf8_utf8(int64_t ptr, const char* data, int data_len, bool gdv_fn_ilike_utf8_utf8(int64_t ptr, const char* data, int data_len, const char* pattern, int pattern_len); -const uint8_t* gdv_fn_get_json_object_utf8_utf8(int64_t ptr, const char* data, int data_len, - const char* pattern, int pattern_len, int32_t* out_len); +const uint8_t* gdv_fn_get_json_object_utf8_utf8(int64_t ptr, int64_t holder_ptr, const char* data, int data_len, bool in1_valid, + const char* pattern, int pattern_len, bool in2_valid, bool* out_valid, int32_t* out_len); int64_t gdv_fn_to_date_utf8_utf8_int32(int64_t context, int64_t ptr, const char* data, int data_len, bool in1_validity, diff --git a/cpp/src/gandiva/json_holder.cc b/cpp/src/gandiva/json_holder.cc index 5d640aab838ea..7862016696566 100644 --- a/cpp/src/gandiva/json_holder.cc +++ b/cpp/src/gandiva/json_holder.cc @@ -34,26 +34,36 @@ Status JsonHolder::Make(std::shared_ptr* holder) { return Status::OK(); } -const uint8_t* JsonHolder::operator()(const std::string& json_str, const std::string& json_path, int32_t* out_len) { - +const uint8_t* JsonHolder::operator()(gandiva::ExecutionContext* ctx, const std::string& json_str, const std::string& json_path, int32_t* out_len) { std::unique_ptr parser; (arrow::json::BlockParser::Make(parse_options_, &parser)); - (parser->Parse(std::make_shared(json_str))); std::shared_ptr parsed; (parser->Finish(&parsed)); auto struct_parsed = std::dynamic_pointer_cast(parsed); - //json_path example: $.col_14, will extract col_14 here // needs to gurad failure here + if (json_path.length() < 3) { + return nullptr; + } auto col_name = json_path.substr(2); - + // illegal json string. + if (struct_parsed == nullptr) { + return nullptr; + } auto dict_parsed = std::dynamic_pointer_cast( struct_parsed->GetFieldByName(col_name)); + // no data contained for given field. + if (dict_parsed == nullptr) { + return nullptr; + } auto dict_array = dict_parsed->dictionary(); - auto uft8_array = std::dynamic_pointer_cast(dict_array); - - return uft8_array->GetValue(0, out_len); + auto utf8_array = std::dynamic_pointer_cast(dict_array); + auto res = utf8_array->GetValue(0, out_len); + + uint8_t* result_buffer = reinterpret_cast(ctx->arena()->Allocate(*out_len)); + memcpy(result_buffer, std::string((char*)res, *out_len).data(), *out_len); + return result_buffer; } } // namespace gandiva diff --git a/cpp/src/gandiva/json_holder.h b/cpp/src/gandiva/json_holder.h index 4cf289465db89..0ca4765505083 100644 --- a/cpp/src/gandiva/json_holder.h +++ b/cpp/src/gandiva/json_holder.h @@ -23,6 +23,7 @@ #include "arrow/json/api.h" #include "arrow/json/parser.h" #include "arrow/status.h" +#include "gandiva/execution_context.h" #include "gandiva/function_holder.h" #include "gandiva/node.h" #include "gandiva/visibility.h" @@ -39,8 +40,8 @@ class GANDIVA_EXPORT JsonHolder : public FunctionHolder { static Status Make(std::shared_ptr* holder); //TODO(): should try to return const uint8_t * - const uint8_t* operator()(const std::string& json_str, const std::string& json_path, int32_t* out_len); - + const uint8_t* operator()(ExecutionContext* ctx, const std::string& json_str, const std::string& json_path, int32_t* out_len); + arrow::json::ParseOptions parse_options_ = arrow::json::ParseOptions::Defaults(); arrow::json::ReadOptions read_options_ = arrow::json::ReadOptions::Defaults(); }; diff --git a/cpp/src/gandiva/json_holder_test.cc b/cpp/src/gandiva/json_holder_test.cc index 63f4370c33720..9767962f27693 100644 --- a/cpp/src/gandiva/json_holder_test.cc +++ b/cpp/src/gandiva/json_holder_test.cc @@ -26,7 +26,10 @@ namespace gandiva { -class TestJsonHolder : public ::testing::Test {}; +class TestJsonHolder : public ::testing::Test { + protected: + ExecutionContext execution_context_; +}; TEST_F(TestJsonHolder, TestJson) { std::shared_ptr json_holder; @@ -34,12 +37,28 @@ TEST_F(TestJsonHolder, TestJson) { auto status = JsonHolder::Make(&json_holder); EXPECT_EQ(status.ok(), true) << status.message(); - auto& get_json_object = *json_holder; + auto get_json_object = *json_holder; int32_t out_len; - const uint8_t* data = get_json_object(R"({"hello": 3.5 })", "$.hello", &out_len); + + const uint8_t* data = get_json_object(&execution_context_, R"({"hello": 3.5})", "$.hello", &out_len); EXPECT_EQ(std::string((char*)data, out_len), "3.5"); + + // no data contained for given field. + data = get_json_object(&execution_context_, R"({"hello": 3.5})", "$.hi", &out_len); + EXPECT_EQ(data, nullptr); + + // illegal json string. + data = get_json_object(&execution_context_, R"({"hello"-3.5})", "$.hello", &out_len); + EXPECT_EQ(data, nullptr); + + // illegal field is given. + data = get_json_object(&execution_context_, R"({"hello": 3.5})", "$xx", &out_len); + EXPECT_EQ(data, nullptr); + // illegal field is given and a short string field. + data = get_json_object(&execution_context_, R"({"hello": 3.5})", "$", &out_len); + EXPECT_EQ(data, nullptr); } } // namespace gandiva