diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc
index 2f30d084f0014..f9cb907f568a1 100644
--- a/cpp/src/gandiva/function_registry_string.cc
+++ b/cpp/src/gandiva/function_registry_string.cc
@@ -325,6 +325,10 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
 
       NativeFunction("url_decoder", {}, DataTypeVector{utf8()}, utf8(),
                      kResultNullIfNull, "url_decoder",
+                     NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+      NativeFunction("conv", {}, DataTypeVector{utf8(), int32(), int32()}, utf8(),
+                     kResultNullInternal, "conv",
                      NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)};
 
   return string_fn_registry_;
diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc
index bf1532392488a..5ae0b178dc424 100644
--- a/cpp/src/gandiva/precompiled/string_ops.cc
+++ b/cpp/src/gandiva/precompiled/string_ops.cc
@@ -1568,4 +1568,119 @@ const char* url_decoder(gdv_int64 context, const char* input, gdv_int32 input_le
   return out_str;
 }
 
+// Converts the integer written in `input` from base `from_base` to base `to_base`.
+//
+// Only the first `input_len` bytes of `input` are read, so the buffer does not need to
+// be NUL-terminated. Leading whitespace and one optional sign are skipped, then the
+// longest prefix of valid digits is converted. Digits above 9 are accepted in either
+// case and written in upper case. A value that does not fit in 64 bits saturates to
+// 2^64 - 1.
+//
+// A positive `to_base` renders the value as unsigned, so a negative input appears as
+// its 64-bit two's complement. A negative `to_base` renders a signed value in base
+// |to_base| with a leading '-'.
+//
+// The result is null (`*out_valid` false, `*out_len` 0) when any argument is null, the
+// input is empty, or a base is unsupported. If the output buffer cannot be allocated,
+// an error message is also set on the context.
+FORCE_INLINE
+const char* conv(gdv_int64 context, const char* input, gdv_int32 input_len,
+                 bool in1_valid, gdv_int32 from_base, bool in2_valid, gdv_int32 to_base,
+                 bool in3_valid, bool* out_valid, gdv_int32* out_len) {
+  // Every early return below yields a null result.
+  *out_len = 0;
+  *out_valid = false;
+  if (!in1_valid || !in2_valid || !in3_valid || input_len <= 0) {
+    return "";
+  }
+
+  // Consistent with Spark, only bases whose magnitude lies in [2, 36] are supported.
+  const gdv_int32 kMinBase = 2;
+  const gdv_int32 kMaxBase = 36;
+  const bool signed_output = to_base < 0;
+  // Range-check to_base before negating it so that INT32_MIN is never negated.
+  const bool to_base_ok = signed_output ? (to_base <= -kMinBase && to_base >= -kMaxBase)
+                                        : (to_base >= kMinBase && to_base <= kMaxBase);
+  if (from_base < kMinBase || from_base > kMaxBase || !to_base_ok) {
+    return "";
+  }
+  const uint64_t in_radix = static_cast<uint64_t>(from_base);
+  const uint64_t out_radix = static_cast<uint64_t>(signed_output ? -to_base : to_base);
+  const uint64_t kMaxValue = ~static_cast<uint64_t>(0);
+
+  // Skip leading whitespace (space, \t, \n, \v, \f, \r), then one optional sign.
+  gdv_int32 pos = 0;
+  while (pos < input_len &&
+         (input[pos] == ' ' || (input[pos] >= '\t' && input[pos] <= '\r'))) {
+    ++pos;
+  }
+  bool is_negative = false;
+  if (pos < input_len && (input[pos] == '-' || input[pos] == '+')) {
+    is_negative = input[pos] == '-';
+    ++pos;
+  }
+
+  // Accumulate the longest valid digit prefix without reading past input_len.
+  uint64_t magnitude = 0;
+  for (; pos < input_len; ++pos) {
+    const char c = input[pos];
+    gdv_int32 digit;
+    if (c >= '0' && c <= '9') {
+      digit = c - '0';
+    } else if (c >= 'a' && c <= 'z') {
+      digit = c - 'a' + 10;
+    } else if (c >= 'A' && c <= 'Z') {
+      digit = c - 'A' + 10;
+    } else {
+      break;
+    }
+    if (digit >= from_base) {
+      break;
+    }
+    const uint64_t digit_value = static_cast<uint64_t>(digit);
+    if (magnitude > (kMaxValue - digit_value) / in_radix) {
+      // The value no longer fits in 64 bits: saturate and stop.
+      magnitude = kMaxValue;
+      break;
+    }
+    magnitude = magnitude * in_radix + digit_value;
+  }
+
+  bool emit_minus = false;
+  if (is_negative) {
+    if (signed_output) {
+      // Never print "-0".
+      emit_minus = magnitude != 0;
+    } else {
+      // Unsigned output shows a negative value as its 64-bit two's complement.
+      magnitude = 0 - magnitude;
+    }
+  }
+
+  // A 64-bit value has at most 64 digits (in base 2); one more slot holds the sign.
+  char reversed[65];
+  gdv_int32 n_chars = 0;
+  do {
+    const gdv_int32 digit = static_cast<gdv_int32>(magnitude % out_radix);
+    reversed[n_chars++] = static_cast<char>(digit < 10 ? '0' + digit : 'A' + digit - 10);
+    magnitude /= out_radix;
+  } while (magnitude != 0);
+  if (emit_minus) {
+    reversed[n_chars++] = '-';
+  }
+
+  char* out_str = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, n_chars));
+  if (out_str == nullptr) {
+    gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+    return "";
+  }
+  // Digits were produced least-significant first; emit them in reading order.
+  for (gdv_int32 i = 0; i < n_chars; ++i) {
+    out_str[i] = reversed[n_chars - 1 - i];
+  }
+  *out_len = n_chars;
+  *out_valid = true;
+  return out_str;
+}
+
 }  // extern "C"
diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc
index 0afb48d9aa523..a672c65183541 100644
--- a/cpp/src/gandiva/precompiled/string_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/string_ops_test.cc
@@ -1142,4 +1142,91 @@ TEST(TestStringOps, TestURLDecoder) {
   EXPECT_EQ(std::string(out_str, out_len), exp_str);
 }
 
+TEST(TestStringOps, TestConv) {
+  gandiva::ExecutionContext ctx;
+  uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+  gdv_int32 out_len = 0;
+  const char* out_str;
+  bool out_valid;
+
+  // 10-base to 2-base
+  out_str = conv(ctx_ptr, "4", 1, true, 10, true, 2, true, &out_valid, &out_len);
+  EXPECT_EQ(out_len, 3);
+  EXPECT_EQ(out_valid, true);
+  EXPECT_EQ(std::string(out_str, out_len), "100");
+
+  // 2-base to 10-base
+  out_str = conv(ctx_ptr, "110", 3, true, 2, true, 10, true, &out_valid, &out_len);
+  EXPECT_EQ(out_len, 1);
+  EXPECT_EQ(std::string(out_str, out_len), "6");
+
+  // 10-base to 16-base
+  out_str = conv(ctx_ptr, "15", 2, true, 10, true, 16, true, &out_valid, &out_len);
+  EXPECT_EQ(out_len, 1);
+  EXPECT_EQ(std::string(out_str, out_len), "F");
+
+  // 36-base to 16-base
+  out_str = conv(ctx_ptr, "big", 3, true, 36, true, 16, true, &out_valid, &out_len);
+  EXPECT_EQ(out_len, 4);
+  EXPECT_EQ(std::string(out_str, out_len), "3A48");
+
+  // A value that overflows 64 bits saturates to the maximum unsigned value.
+  std::string input = "9223372036854775807";
+  out_str = conv(ctx_ptr, input.c_str(), static_cast<gdv_int32>(input.length()), true, 36,
+                 true, 16, true, &out_valid, &out_len);
+  std::string expected_str = "FFFFFFFFFFFFFFFF";
+  EXPECT_EQ(out_len, static_cast<gdv_int32>(expected_str.length()));
+  EXPECT_EQ(std::string(out_str, out_len), expected_str);
+
+  // Leading whitespace is skipped and parsing stops at the trailing whitespace.
+  out_str = conv(ctx_ptr, " 15 ", 4, true, 10, true, 16, true, &out_valid, &out_len);
+  EXPECT_EQ(out_len, 1);
+  EXPECT_EQ(std::string(out_str, out_len), "F");
+
+  // Only the first input_len bytes are read.
+  out_str = conv(ctx_ptr, "1599", 2, true, 10, true, 16, true, &out_valid, &out_len);
+  EXPECT_EQ(out_len, 1);
+  EXPECT_EQ(std::string(out_str, out_len), "F");
+
+  // Negative input and negative to_base.
+  out_str = conv(ctx_ptr, "-15", 3, true, 10, true, -16, true, &out_valid, &out_len);
+  EXPECT_EQ(out_len, 2);
+  EXPECT_EQ(std::string(out_str, out_len), "-F");
+
+  // Negative input and positive to_base
+  out_str = conv(ctx_ptr, "-15", 3, true, 10, true, 16, true, &out_valid, &out_len);
+  EXPECT_EQ(out_len, 16);
+  EXPECT_EQ(std::string(out_str, out_len), "FFFFFFFFFFFFFFF1");
+
+  // Negative input and negative base.
+  out_str = conv(ctx_ptr, "-10", 3, true, 16, true, -10, true, &out_valid, &out_len);
+  EXPECT_EQ(out_len, 3);
+  EXPECT_EQ(std::string(out_str, out_len), "-16");
+
+  // Longest possible output: 64 binary digits plus the sign.
+  out_str = conv(ctx_ptr, "-FFFFFFFFFFFFFFFF", 17, true, 16, true, -2, true, &out_valid,
+                 &out_len);
+  EXPECT_EQ(out_len, 65);
+  EXPECT_EQ(std::string(out_str, out_len), "-" + std::string(64, '1'));
+
+  // If there is an invalid digit in the number, the longest
+  // valid prefix should be converted.
+  out_str = conv(ctx_ptr, "11abc", 5, true, 10, true, 16, true, &out_valid, &out_len);
+  EXPECT_EQ(out_len, 1);
+  EXPECT_EQ(std::string(out_str, out_len), "B");
+
+  // Zero is rendered as "0", not as an empty string.
+  out_str = conv(ctx_ptr, "0", 1, true, 10, true, 16, true, &out_valid, &out_len);
+  EXPECT_EQ(out_valid, true);
+  EXPECT_EQ(std::string(out_str, out_len), "0");
+
+  // Should return null for an unsupported base.
+  out_str = conv(ctx_ptr, "15", 2, true, 10, true, 37, true, &out_valid, &out_len);
+  EXPECT_EQ(out_valid, false);
+
+  // Should return null for empty input.
+  out_str = conv(ctx_ptr, "", 0, true, 10, true, 16, true, &out_valid, &out_len);
+  EXPECT_EQ(out_valid, false);
+}
+
 }  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h
index 82c3daa6059a1..e69269cca97ac 100644
--- a/cpp/src/gandiva/precompiled/types.h
+++ b/cpp/src/gandiva/precompiled/types.h
@@ -517,4 +517,8 @@ double castFLOAT8_utf8(int64_t context, const char* data, int32_t len);
 const char* url_decoder(gdv_int64 context, const char* input, gdv_int32 input_len,
                         gdv_int32* out_len);
 
+const char* conv(gdv_int64 context, const char* input, gdv_int32 input_len,
+                 bool in1_valid, gdv_int32 from_base, bool in2_valid, gdv_int32 to_base,
+                 bool in3_valid, bool* out_valid, gdv_int32* out_len);
+
 }  // extern "C"