diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc index a98cd85ae..d3df9dfa9 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc @@ -246,6 +246,21 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) << ".rfind(" << child_visitor_list[1]->GetResult() << ") != std::string::npos;"; prepare_str_ += prepare_ss.str(); + } else if (func_name.compare("get_json_object") == 0) { + for (int i = 0; i < 2; i++) { + prepare_str_ += child_visitor_list[i]->GetPrepare(); + } + codes_str_ = "get_json_object_" + std::to_string(cur_func_id); + check_str_ = GetValidityName(codes_str_); + real_codes_str_ = codes_str_; + real_validity_str_ = check_str_; + std::stringstream prepare_ss; + prepare_ss << "bool " << check_str_ << " = true;" << std::endl; + prepare_ss << "std::string " << codes_str_ << " = get_json_object(" + << child_visitor_list[0]->GetResult() << ", " + << child_visitor_list[1]->GetResult() << ");\n"; + prepare_str_ += prepare_ss.str(); + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } else if (func_name.compare("substr") == 0) { ss << child_visitor_list[0]->GetResult() << ".substr(" << "((" << child_visitor_list[1]->GetResult() << " - 1) < 0 ? 0 : (" diff --git a/native-sql-engine/cpp/src/precompile/gandiva.h b/native-sql-engine/cpp/src/precompile/gandiva.h index a69aaa5b5..6d49c0056 100644 --- a/native-sql-engine/cpp/src/precompile/gandiva.h +++ b/native-sql-engine/cpp/src/precompile/gandiva.h @@ -16,6 +16,9 @@ */ #pragma once +#include +#include +#include #include #include @@ -226,3 +229,37 @@ arrow::Decimal128 round(arrow::Decimal128 in, int32_t original_precision, } return arrow::Decimal128(out); } + +std::string get_json_object(const std::string& json_str, const std::string& json_path) { + std::unique_ptr parser; + (arrow::json::BlockParser::Make(arrow::json::ParseOptions::Defaults(), &parser)); + (parser->Parse(std::make_shared(json_str))); + std::shared_ptr parsed; + (parser->Finish(&parsed)); + auto struct_parsed = std::dynamic_pointer_cast(parsed); + // json_path example: $.col_14, will extract col_14 here + if (json_path.length() < 3) { + return nullptr; + } + auto col_name = json_path.substr(2); + // illegal json string. + if (struct_parsed == nullptr) { + return nullptr; + } + auto dict_parsed = std::dynamic_pointer_cast( + struct_parsed->GetFieldByName(col_name)); + // no data contained for given field. + if (dict_parsed == nullptr) { + return nullptr; + } + + auto dict_array = dict_parsed->dictionary(); + // needs to see whether there is a case that has more than one indices. + auto res_index = dict_parsed->GetValueIndex(0); + // TODO(): check null results + auto utf8_array = std::dynamic_pointer_cast(dict_array); + + auto res = utf8_array->GetString(res_index); + + return res; +} \ No newline at end of file diff --git a/native-sql-engine/cpp/src/tests/arrow_compute_test_precompile.cc b/native-sql-engine/cpp/src/tests/arrow_compute_test_precompile.cc index f0ae25e72..dbd996a8a 100644 --- a/native-sql-engine/cpp/src/tests/arrow_compute_test_precompile.cc +++ b/native-sql-engine/cpp/src/tests/arrow_compute_test_precompile.cc @@ -111,5 +111,10 @@ TEST(TestArrowCompute, ArithmeticComparisonTest) { ASSERT_EQ(res, true); } +TEST(TestArrowCompute, JsonTest) { + std::string data = get_json_object(R"({"hello": "3.5"})", "$.hello"); + EXPECT_EQ(data, "3.5"); +} + } // namespace codegen } // namespace sparkcolumnarplugin