diff --git a/be/src/vec/functions/function_hash.cpp b/be/src/vec/functions/function_hash.cpp index 1620a608b261b1..6bc38ddfcaf9a0 100644 --- a/be/src/vec/functions/function_hash.cpp +++ b/be/src/vec/functions/function_hash.cpp @@ -26,6 +26,7 @@ #include "vec/columns/column.h" #include "vec/columns/column_const.h" #include "vec/columns/column_string.h" +#include "vec/columns/column_varbinary.h" #include "vec/columns/column_vector.h" #include "vec/common/assert_cast.h" #include "vec/core/field.h" @@ -177,6 +178,17 @@ struct XxHashImpl { HashUtil::xxHash64WithSeed(value.data(), value.size(), col_to_data[i]); } } + } else if (const auto* vb_col = check_and_get_column(column)) { + for (size_t i = 0; i < input_rows_count; ++i) { + auto data_ref = vb_col->get_data_at(i); + if constexpr (ReturnType == TYPE_INT) { + col_to_data[i] = HashUtil::xxHash32WithSeed(data_ref.data, data_ref.size, + col_to_data[i]); + } else { + col_to_data[i] = HashUtil::xxHash64WithSeed(data_ref.data, data_ref.size, + col_to_data[i]); + } + } } else { DCHECK(false); return Status::NotSupported("Illegal column {} of argument of function {}", diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp index f6a94ce94b80e6..cc7499c099f597 100644 --- a/be/src/vec/functions/function_string.cpp +++ b/be/src/vec/functions/function_string.cpp @@ -1458,8 +1458,8 @@ void register_function_string(SimpleFunctionFactory& factory) { factory.register_function< FunctionStringFormatRound>>(); factory.register_function>>(); - factory.register_function>(); - factory.register_function>(); + factory.register_function>(); + factory.register_function>(); factory.register_function(); factory.register_function(); factory.register_function>(); @@ -1482,9 +1482,9 @@ void register_function_string(SimpleFunctionFactory& factory) { factory.register_alias(SubstringUtil::name, "substr"); factory.register_alias(FunctionToLower::name, "lcase"); factory.register_alias(FunctionToUpper::name, 
"ucase"); - factory.register_alias(FunctionStringDigestOneArg::name, "md5"); + factory.register_alias(FunctionStringDigestMulti::name, "md5"); factory.register_alias(FunctionStringUTF8Length::name, "character_length"); - factory.register_alias(FunctionStringDigestOneArg::name, "sm3"); + factory.register_alias(FunctionStringDigestMulti::name, "sm3"); factory.register_alias(FunctionStringDigestSHA1::name, "sha"); factory.register_alias(FunctionStringLocatePos::name, "position"); } diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h index 16e788843be438..357670356b78ff 100644 --- a/be/src/vec/functions/function_string.h +++ b/be/src/vec/functions/function_string.h @@ -55,6 +55,7 @@ #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column.h" #include "vec/columns/column_const.h" +#include "vec/columns/column_varbinary.h" #include "vec/columns/column_vector.h" #include "vec/common/hash_table/phmap_fwd_decl.h" #include "vec/common/int_exp.h" @@ -2335,10 +2336,10 @@ struct MD5Sum { }; template -class FunctionStringDigestOneArg : public IFunction { +class FunctionStringDigestMulti : public IFunction { public: static constexpr auto name = Impl::name; - static FunctionPtr create() { return std::make_shared(); } + static FunctionPtr create() { return std::make_shared(); } String get_name() const override { return name; } size_t get_number_of_arguments() const override { return 0; } bool is_variadic() const override { return true; } @@ -2351,51 +2352,54 @@ class FunctionStringDigestOneArg : public IFunction { uint32_t result, size_t input_rows_count) const override { DCHECK_GE(arguments.size(), 1); - int argument_size = arguments.size(); - std::vector argument_columns(argument_size); + auto res = ColumnString::create(); + auto& res_data = res->get_chars(); + auto& res_offset = res->get_offsets(); + res_offset.resize(input_rows_count); - std::vector offsets_list(argument_size); - std::vector 
chars_list(argument_size); + std::vector argument_columns(arguments.size()); + std::vector is_const(arguments.size(), 0); + for (size_t i = 0; i < arguments.size(); ++i) { + std::tie(argument_columns[i], is_const[i]) = + unpack_if_const(block.get_by_position(arguments[i]).column); + } - for (int i = 0; i < argument_size; ++i) { - argument_columns[i] = - block.get_by_position(arguments[i]).column->convert_to_full_column_if_const(); - if (const auto* col_str = assert_cast(argument_columns[i].get())) { - offsets_list[i] = &col_str->get_offsets(); - chars_list[i] = &col_str->get_chars(); - } else { - return Status::RuntimeError("Illegal column {} of argument of function {}", - block.get_by_position(arguments[0]).column->get_name(), - get_name()); - } + if (check_and_get_column(argument_columns[0].get())) { + vector_execute(block, input_rows_count, argument_columns, is_const, + res_data, res_offset); + } else if (check_and_get_column(argument_columns[0].get())) { + vector_execute(block, input_rows_count, argument_columns, is_const, + res_data, res_offset); + } else { + return Status::RuntimeError("Illegal column {} of argument of function {}", + argument_columns[0]->get_name(), get_name()); } - auto res = ColumnString::create(); - auto& res_data = res->get_chars(); - auto& res_offset = res->get_offsets(); + block.replace_by_position(result, std::move(res)); + return Status::OK(); + } - res_offset.resize(input_rows_count); +private: + template + void vector_execute(Block& block, size_t input_rows_count, + const std::vector& argument_columns, + const std::vector& is_const, ColumnString::Chars& res_data, + ColumnString::Offsets& res_offset) const { + using ObjectData = typename Impl::ObjectData; for (size_t i = 0; i < input_rows_count; ++i) { - using ObjectData = typename Impl::ObjectData; ObjectData digest; - for (size_t j = 0; j < offsets_list.size(); ++j) { - const auto& current_offsets = *offsets_list[j]; - const auto& current_chars = *chars_list[j]; - - int size = 
current_offsets[i] - current_offsets[i - 1]; - if (size < 1) { + for (size_t j = 0; j < argument_columns.size(); ++j) { + const auto* col = assert_cast(argument_columns[j].get()); + StringRef data_ref = col->get_data_at(is_const[j] ? 0 : i); + if (data_ref.size < 1) { continue; } - digest.update(¤t_chars[current_offsets[i - 1]], size); + digest.update(data_ref.data, data_ref.size); } digest.digest(); - StringOP::push_value_string(std::string_view(digest.hex().c_str(), digest.hex().size()), i, res_data, res_offset); } - - block.replace_by_position(result, std::move(res)); - return Status::OK(); } }; @@ -2414,27 +2418,37 @@ class FunctionStringDigestSHA1 : public IFunction { Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, uint32_t result, size_t input_rows_count) const override { DCHECK_EQ(arguments.size(), 1); - - ColumnPtr str_col = block.get_by_position(arguments[0]).column; - auto& data = assert_cast(str_col.get())->get_chars(); - auto& offset = assert_cast(str_col.get())->get_offsets(); + ColumnPtr data_col = block.get_by_position(arguments[0]).column; auto res_col = ColumnString::create(); auto& res_data = res_col->get_chars(); auto& res_offset = res_col->get_offsets(); res_offset.resize(input_rows_count); + if (const auto* str_col = check_and_get_column(data_col.get())) { + vector_execute(str_col, input_rows_count, res_data, res_offset); + } else if (const auto* vb_col = check_and_get_column(data_col.get())) { + vector_execute(vb_col, input_rows_count, res_data, res_offset); + } else { + return Status::RuntimeError("Illegal column {} of argument of function {}", + data_col->get_name(), get_name()); + } + block.replace_by_position(result, std::move(res_col)); + return Status::OK(); + } + +private: + template + void vector_execute(const ColumnType* col, size_t input_rows_count, + ColumnString::Chars& res_data, ColumnString::Offsets& res_offset) const { SHA1Digest digest; for (size_t i = 0; i < input_rows_count; ++i) { 
- int size = offset[i] - offset[i - 1]; - digest.reset(&data[offset[i - 1]], size); + StringRef data_ref = col->get_data_at(i); + digest.reset(data_ref.data, data_ref.size); std::string_view ans = digest.digest(); StringOP::push_value_string(ans, i, res_data, res_offset); } - - block.replace_by_position(result, std::move(res_col)); - return Status::OK(); } }; @@ -2454,9 +2468,7 @@ class FunctionStringDigestSHA2 : public IFunction { uint32_t result, size_t input_rows_count) const override { DCHECK(!is_column_const(*block.get_by_position(arguments[0]).column)); - ColumnPtr str_col = block.get_by_position(arguments[0]).column; - auto& data = assert_cast(str_col.get())->get_chars(); - auto& offset = assert_cast(str_col.get())->get_offsets(); + ColumnPtr data_col = block.get_by_position(arguments[0]).column; [[maybe_unused]] const auto& [right_column, right_const] = unpack_if_const(block.get_by_position(arguments[1]).column); @@ -2468,13 +2480,13 @@ class FunctionStringDigestSHA2 : public IFunction { res_offset.resize(input_rows_count); if (digest_length == 224) { - execute_base(data, offset, input_rows_count, res_data, res_offset); + execute_base(data_col, input_rows_count, res_data, res_offset); } else if (digest_length == 256) { - execute_base(data, offset, input_rows_count, res_data, res_offset); + execute_base(data_col, input_rows_count, res_data, res_offset); } else if (digest_length == 384) { - execute_base(data, offset, input_rows_count, res_data, res_offset); + execute_base(data_col, input_rows_count, res_data, res_offset); } else if (digest_length == 512) { - execute_base(data, offset, input_rows_count, res_data, res_offset); + execute_base(data_col, input_rows_count, res_data, res_offset); } else { return Status::InvalidArgument( "sha2's digest length only support 224/256/384/512 but meet {}", digest_length); @@ -2486,13 +2498,26 @@ class FunctionStringDigestSHA2 : public IFunction { private: template - void execute_base(const ColumnString::Chars& data, const 
ColumnString::Offsets& offset, - int input_rows_count, ColumnString::Chars& res_data, + void execute_base(ColumnPtr data_col, int input_rows_count, ColumnString::Chars& res_data, ColumnString::Offsets& res_offset) const { - T digest; + if (const auto* str_col = check_and_get_column(data_col.get())) { + vector_execute(str_col, input_rows_count, res_data, res_offset); + } else if (const auto* vb_col = check_and_get_column(data_col.get())) { + vector_execute(vb_col, input_rows_count, res_data, res_offset); + } else { + throw Exception(ErrorCode::RUNTIME_ERROR, + "Illegal column {} of argument of function {}", data_col->get_name(), + get_name()); + } + } + + template + void vector_execute(const ColumnType* col, size_t input_rows_count, + ColumnString::Chars& res_data, ColumnString::Offsets& res_offset) const { + DigestType digest; for (size_t i = 0; i < input_rows_count; ++i) { - int size = offset[i] - offset[i - 1]; - digest.reset(&data[offset[i - 1]], size); + StringRef data_ref = col->get_data_at(i); + digest.reset(data_ref.data, data_ref.size); std::string_view ans = digest.digest(); StringOP::push_value_string(ans, i, res_data, res_offset); diff --git a/be/test/vec/function/function_hash_test.cpp b/be/test/vec/function/function_hash_test.cpp index a29603951ecaf6..f98430b3c64904 100644 --- a/be/test/vec/function/function_hash_test.cpp +++ b/be/test/vec/function/function_hash_test.cpp @@ -126,6 +126,34 @@ TEST(HashFunctionTest, xxhash_32_test) { static_cast(check_function(func_name, input_types, data_set)); }; + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY}; + + DataSet data_set = {{{Null()}, Null()}, {{std::string("hello")}, (int32_t)-83855367}}; + + static_cast(check_function(func_name, input_types, data_set)); + }; + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, PrimitiveType::TYPE_VARBINARY}; + + DataSet data_set = {{{std::string("hello"), std::string("world")}, (int32_t)-920844969}, + {{std::string("hello"), Null()}, 
Null()}}; + + static_cast(check_function(func_name, input_types, data_set)); + }; + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, PrimitiveType::TYPE_VARBINARY, + PrimitiveType::TYPE_VARBINARY}; + + DataSet data_set = {{{std::string("hello"), std::string("world"), std::string("!")}, + (int32_t)352087701}, + {{std::string("hello"), std::string("world"), Null()}, Null()}}; + + static_cast(check_function(func_name, input_types, data_set)); + }; } TEST(HashFunctionTest, xxhash_64_test) { @@ -160,6 +188,36 @@ TEST(HashFunctionTest, xxhash_64_test) { static_cast(check_function(func_name, input_types, data_set)); }; + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY}; + + DataSet data_set = {{{Null()}, Null()}, + {{std::string("hello")}, (int64_t)-7685981735718036227}}; + + static_cast(check_function(func_name, input_types, data_set)); + }; + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, PrimitiveType::TYPE_VARBINARY}; + + DataSet data_set = { + {{std::string("hello"), std::string("world")}, (int64_t)7001965798170371843}, + {{std::string("hello"), Null()}, Null()}}; + + static_cast(check_function(func_name, input_types, data_set)); + }; + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, PrimitiveType::TYPE_VARBINARY, + PrimitiveType::TYPE_VARBINARY}; + + DataSet data_set = {{{std::string("hello"), std::string("world"), std::string("!")}, + (int64_t)6796829678999971400}, + {{std::string("hello"), std::string("world"), Null()}, Null()}}; + + static_cast(check_function(func_name, input_types, data_set)); + }; } } // namespace doris::vectorized diff --git a/be/test/vec/function/function_string_test.cpp b/be/test/vec/function/function_string_test.cpp index d4c06c5b86096f..e5c2a044fe6ebd 100644 --- a/be/test/vec/function/function_string_test.cpp +++ b/be/test/vec/function/function_string_test.cpp @@ -1966,6 +1966,48 @@ TEST(function_string_test, function_md5sum_test) { check_function_all_arg_comb(func_name, 
input_types, data_set); } + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY}; + DataSet data_set = { + {{std::string("asd你好")}, {std::string("a38c15675555017e6b8ea042f2eb24f5")}}, + {{std::string("hello world")}, {std::string("5eb63bbbe01eeed093cb22bb8f5acdc3")}}, + {{std::string("HELLO,!^%")}, {std::string("b8e6e34d1cc3dc76b784ddfdfb7df800")}}, + {{std::string("")}, {std::string("d41d8cd98f00b204e9800998ecf8427e")}}, + {{std::string(" ")}, {std::string("7215ee9c7d9dc229d2921a40e899ec5f")}}, + {{Null()}, {Null()}}, + {{std::string("MYtestSTR")}, {std::string("cd24c90b3fc1192eb1879093029e87d4")}}, + {{std::string("ò&ø")}, {std::string("fd157b4cb921fa91acc667380184d59c")}}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, PrimitiveType::TYPE_VARBINARY}; + DataSet data_set = {{{std::string("asd"), std::string("你好")}, + {std::string("a38c15675555017e6b8ea042f2eb24f5")}}, + {{std::string("hello "), std::string("world")}, + {std::string("5eb63bbbe01eeed093cb22bb8f5acdc3")}}, + {{std::string("HELLO"), std::string(",!^%")}, + {std::string("b8e6e34d1cc3dc76b784ddfdfb7df800")}}, + {{Null(), std::string("HELLO")}, {Null()}}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, PrimitiveType::TYPE_VARBINARY, + PrimitiveType::TYPE_VARBINARY}; + DataSet data_set = {{{std::string("a"), std::string("sd"), std::string("你好")}, + {std::string("a38c15675555017e6b8ea042f2eb24f5")}}, + {{std::string(""), std::string(""), std::string("")}, + {std::string("d41d8cd98f00b204e9800998ecf8427e")}}, + {{std::string("HEL"), std::string("LO,!"), std::string("^%")}, + {std::string("b8e6e34d1cc3dc76b784ddfdfb7df800")}}, + {{Null(), std::string("HELLO"), Null()}, {Null()}}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } } TEST(function_string_test, function_sm3sum_test) { @@ -2022,6 
+2064,58 @@ TEST(function_string_test, function_sm3sum_test) { check_function_all_arg_comb(func_name, input_types, data_set); } + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY}; + DataSet data_set = { + {{std::string("asd你好")}, + {std::string("0d6b9dfa8fe5708eb0dccfbaff4f2964abaaa976cc4445a7ecace49c0ceb31d3")}}, + {{std::string("hello world")}, + {std::string("44f0061e69fa6fdfc290c494654a05dc0c053da7e5c52b84ef93a9d67d3fff88")}}, + {{std::string("HELLO,!^%")}, + {std::string("5fc6e38f40b31a659a59e1daba9b68263615f20c02037b419d9deb3509e6b5c6")}}, + {{std::string("")}, + {std::string("1ab21d8355cfa17f8e61194831e81a8f22bec8c728fefb747ed035eb5082aa2b")}}, + {{std::string(" ")}, + {std::string("2ae1d69bb8483e5944310c877573b21d0a420c3bf4a2a91b1a8370d760ba67c5")}}, + {{Null()}, {Null()}}, + {{std::string("MYtestSTR")}, + {std::string("3155ae9f834cae035385fc15b69b6f2c051b91de943ea9a03ab8bfd497aef4c6")}}, + {{std::string("ò&ø")}, + {std::string( + "aa47ac31c85aa819d4cc80c932e7900fa26a3073a67aa7eb011bc2ba4924a066")}}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, PrimitiveType::TYPE_VARBINARY}; + DataSet data_set = { + {{std::string("asd"), std::string("你好")}, + {std::string("0d6b9dfa8fe5708eb0dccfbaff4f2964abaaa976cc4445a7ecace49c0ceb31d3")}}, + {{std::string("hello "), std::string("world")}, + {std::string("44f0061e69fa6fdfc290c494654a05dc0c053da7e5c52b84ef93a9d67d3fff88")}}, + {{std::string("HELLO "), std::string(",!^%")}, + {std::string("1f5866e786ebac9ffed0dbd8f2586e3e99d1d05f7efe7c5915478b57b7423570")}}, + {{Null(), std::string("HELLO")}, {Null()}}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, PrimitiveType::TYPE_VARBINARY, + PrimitiveType::TYPE_VARBINARY}; + DataSet data_set = { + {{std::string("a"), std::string("sd"), std::string("你好")}, + 
{std::string("0d6b9dfa8fe5708eb0dccfbaff4f2964abaaa976cc4445a7ecace49c0ceb31d3")}}, + {{std::string(""), std::string(""), std::string("")}, + {std::string("1ab21d8355cfa17f8e61194831e81a8f22bec8c728fefb747ed035eb5082aa2b")}}, + {{std::string("HEL"), std::string("LO,!"), std::string("^%")}, + {std::string("5fc6e38f40b31a659a59e1daba9b68263615f20c02037b419d9deb3509e6b5c6")}}, + {{Null(), std::string("HELLO"), Null()}, {Null()}}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } } TEST(function_string_test, function_aes_encrypt_test) { @@ -3676,4 +3770,34 @@ TEST(function_string_test, soundex_test) { } } +TEST(function_string_test, function_sha1_test) { + std::string func_name = "sha1"; + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARCHAR}; + DataSet data_set = { + {{std::string("hello world")}, + {std::string("2aae6c35c94fcfb415dbe95f408b9ce91ee846ed")}}, + {{std::string("doris")}, {std::string("c29bb8e55610dcfecabb065ce5d01be6e3e810e9")}}, + {{std::string("")}, {std::string("da39a3ee5e6b4b0d3255bfef95601890afd80709")}}, + {{std::string("abc")}, {std::string("a9993e364706816aba3e25717850c26c9cd0d89d")}}, + {{Null()}, {Null()}}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + { + InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY}; + DataSet data_set = { + {{std::string("hello world")}, + {std::string("2aae6c35c94fcfb415dbe95f408b9ce91ee846ed")}}, + {{std::string("doris")}, {std::string("c29bb8e55610dcfecabb065ce5d01be6e3e810e9")}}, + {{std::string("")}, {std::string("da39a3ee5e6b4b0d3255bfef95601890afd80709")}}, + {{std::string("abc")}, {std::string("a9993e364706816aba3e25717850c26c9cd0d89d")}}, + {{Null()}, {Null()}}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } +} + } // namespace doris::vectorized diff --git a/be/test/vec/function/function_test_util.cpp b/be/test/vec/function/function_test_util.cpp index a4844a83c4ea16..ae111c9d10f115 100644 --- 
a/be/test/vec/function/function_test_util.cpp +++ b/be/test/vec/function/function_test_util.cpp @@ -47,6 +47,7 @@ #include "vec/data_types/data_type_string.h" #include "vec/data_types/data_type_struct.h" #include "vec/data_types/data_type_time.h" +#include "vec/data_types/data_type_varbinary.h" #include "vec/exprs/table_function/table_function.h" #include "vec/functions/cast/cast_base.h" #include "vec/functions/cast/cast_to_time_impl.hpp" @@ -89,6 +90,10 @@ static size_t type_index_to_data_type(const std::vector& input_types, s type = std::make_shared(); desc = type; return 1; + case PrimitiveType::TYPE_VARBINARY: + type = std::make_shared(); + desc = type; + return 1; case PrimitiveType::TYPE_JSONB: type = std::make_shared(); desc = type; @@ -376,6 +381,11 @@ bool insert_cell(MutableColumnPtr& column, DataTypePtr type_ptr, const AnyType& column->insert_data(str.c_str(), str.size()); break; } + case PrimitiveType::TYPE_VARBINARY: { + auto str = any_cast(cell); + column->insert_data(str.c_str(), str.size()); + break; + } case PrimitiveType::TYPE_JSONB: { auto str = any_cast(cell); JsonBinaryValue jsonb_val; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Md5.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Md5.java index b9ce3815b8a701..da5db00c95dfe8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Md5.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Md5.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VarBinaryType; import org.apache.doris.nereids.types.VarcharType; import com.google.common.base.Preconditions; @@ -39,7 +40,8 @@ public class Md5 extends ScalarFunction 
public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT), - FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(StringType.INSTANCE) + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(StringType.INSTANCE), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarBinaryType.INSTANCE) ); /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Md5Sum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Md5Sum.java index 9a3815e2f0b13c..eb637d18df2a36 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Md5Sum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Md5Sum.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VarBinaryType; import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.nereids.util.ExpressionUtils; @@ -39,7 +40,8 @@ public class Md5Sum extends ScalarFunction public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT), - FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(StringType.INSTANCE) + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(StringType.INSTANCE), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarBinaryType.INSTANCE) ); /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sha1.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sha1.java index b044ad72fa1526..408c2d030f2950 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sha1.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sha1.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VarBinaryType; import org.apache.doris.nereids.types.VarcharType; import com.google.common.base.Preconditions; @@ -38,7 +39,9 @@ public class Sha1 extends ScalarFunction public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT), - FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(StringType.INSTANCE)); + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(StringType.INSTANCE), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarBinaryType.INSTANCE) + ); /** * constructor with 1 arguments. 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sha2.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sha2.java index 74929c04eff998..9ed0f5b3115b95 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sha2.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sha2.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VarBinaryType; import org.apache.doris.nereids.types.VarcharType; import com.google.common.base.Preconditions; @@ -42,7 +43,9 @@ public class Sha2 extends ScalarFunction public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT, IntegerType.INSTANCE), - FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(StringType.INSTANCE, IntegerType.INSTANCE)); + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(StringType.INSTANCE, IntegerType.INSTANCE), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarBinaryType.INSTANCE, IntegerType.INSTANCE) + ); private static final List validDigest = Lists.newArrayList(224, 256, 384, 512); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sm3.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sm3.java index 832a6ea7a03ae7..e7ad41f28ee593 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sm3.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sm3.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import 
org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VarBinaryType; import org.apache.doris.nereids.types.VarcharType; import com.google.common.base.Preconditions; @@ -39,7 +40,8 @@ public class Sm3 extends ScalarFunction public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT), - FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(StringType.INSTANCE) + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(StringType.INSTANCE), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarBinaryType.INSTANCE) ); /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sm3sum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sm3sum.java index b62d1f86d2deb4..809e9204ed9303 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sm3sum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Sm3sum.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VarBinaryType; import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.nereids.util.ExpressionUtils; @@ -39,7 +40,8 @@ public class Sm3sum extends ScalarFunction public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT), - FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(StringType.INSTANCE) + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(StringType.INSTANCE), + 
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarBinaryType.INSTANCE) ); /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/XxHash32.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/XxHash32.java index 495021a4dc424e..375ef6afb8ea58 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/XxHash32.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/XxHash32.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VarBinaryType; import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.nereids.util.ExpressionUtils; @@ -40,7 +41,8 @@ public class XxHash32 extends ScalarFunction public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(IntegerType.INSTANCE).varArgs(VarcharType.SYSTEM_DEFAULT), - FunctionSignature.ret(IntegerType.INSTANCE).varArgs(StringType.INSTANCE) + FunctionSignature.ret(IntegerType.INSTANCE).varArgs(StringType.INSTANCE), + FunctionSignature.ret(IntegerType.INSTANCE).varArgs(VarBinaryType.INSTANCE) ); /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/XxHash64.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/XxHash64.java index 8171822bb3aeae..1444dc89b2eefd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/XxHash64.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/XxHash64.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; import 
org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VarBinaryType; import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.nereids.util.ExpressionUtils; @@ -40,7 +41,8 @@ public class XxHash64 extends ScalarFunction public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(BigIntType.INSTANCE).varArgs(VarcharType.SYSTEM_DEFAULT), - FunctionSignature.ret(BigIntType.INSTANCE).varArgs(StringType.INSTANCE) + FunctionSignature.ret(BigIntType.INSTANCE).varArgs(StringType.INSTANCE), + FunctionSignature.ret(BigIntType.INSTANCE).varArgs(VarBinaryType.INSTANCE) ); /** diff --git a/regression-test/suites/query_p0/sql_functions/encryption_digest/test_binary_for_digest.groovy b/regression-test/suites/query_p0/sql_functions/encryption_digest/test_binary_for_digest.groovy new file mode 100644 index 00000000000000..9fa25c134a4dcb --- /dev/null +++ b/regression-test/suites/query_p0/sql_functions/encryption_digest/test_binary_for_digest.groovy @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
// Regression suite: every digest/hash function below must return the same value
// for a MySQL VARBINARY column as for a VARCHAR column that stores the same bytes.
//
// Flow:
//   1. create a JDBC catalog pointing at the external MySQL instance,
//   2. (re)create and populate `binary_test` directly in MySQL,
//   3. query it through the catalog and compare f(vb) with f(vc) row by row,
//   4. always drop the MySQL table and the catalog, even when a check fails.
//
// The suite is a no-op unless `enableJdbcTest` is "true" in the regression config.
//
// NOTE(review): this suite only exercises the varbinary code path if the catalog
// exposes the MySQL VARBINARY column as a Doris VARBINARY column. Confirm whether
// the JDBC catalog needs an explicit type-mapping property for that; if the column
// is surfaced as a string, both sides of every comparison go through the same path.
suite("test_binary_for_digest", "p0,external,mysql,external_docker,external_docker_mysql") {
    String enabled = context.config.otherConfigs.get("enableJdbcTest")
    String externalEnvIp = context.config.otherConfigs.get("externalEnvIp")
    String s3_endpoint = getS3Endpoint()
    String bucket = getS3BucketName()
    String driver_url = "https://${bucket}.${s3_endpoint}/regression/jdbc_driver/mysql-connector-java-8.0.25.jar"

    // Guard clause: external JDBC tests are opt-in.
    if (enabled == null || !enabled.equalsIgnoreCase("true")) {
        return
    }

    String catalog_name = "mysql_varbinary_hash_catalog"
    String ex_db_name = "doris_test"
    String mysql_port = context.config.otherConfigs.get("mysql_57_port")
    String test_table = "binary_test"
    // Built once; used by the catalog definition and by both direct MySQL connections.
    String mysql_jdbc_url = "jdbc:mysql://${externalEnvIp}:${mysql_port}/${ex_db_name}?useSSL=false"

    // Runs `select id, <vbExpr>, <vcExpr>` through the catalog and returns the rows
    // ordered by id, so that row[1] is the VARBINARY result and row[2] the VARCHAR one.
    def queryBothTypes = { vbExpr, vcExpr ->
        return sql("""select id, ${vbExpr}, ${vcExpr} from ${test_table} order by id""")
    }

    // [label used in the failure message, expression over vb, expression over vc]
    def equalityCases = [
        ["SHA1 hash",     "sha1(vb)",           "sha1(vc)"],
        ["SHA2-256 hash", "sha2(vb, 256)",      "sha2(vc, 256)"],
        ["SHA2-224 hash", "sha2(vb, 224)",      "sha2(vc, 224)"],
        ["SHA2-384 hash", "sha2(vb, 384)",      "sha2(vc, 384)"],
        ["SHA2-512 hash", "sha2(vb, 512)",      "sha2(vc, 512)"],
        ["MD5 hash",      "md5(vb)",            "md5(vc)"],
        ["MD5SUM hash",   "md5sum(vb, vb, vb)", "md5sum(vc, vc, vc)"],
        ["SM3 hash",      "sm3(vb)",            "sm3(vc)"],
        ["SM3SUM hash",   "sm3sum(vb, vb, vb)", "sm3sum(vc, vc, vc)"],
        ["xxHash32",      "xxhash_32(vb)",      "xxhash_32(vc)"],
        ["xxHash64",      "xxhash_64(vb)",      "xxhash_64(vc)"]
    ]

    // Multi-argument xxhash calls: results must be non-null AND identical for both
    // column types (each argument seeds the hash of the next one).
    def variadicCases = [
        ["Variadic xxHash32", "xxhash_32(vb, vb)", "xxhash_32(vc, vc)"],
        ["Variadic xxHash64", "xxhash_64(vb, vb)", "xxhash_64(vc, vc)"]
    ]

    sql """drop catalog if exists ${catalog_name}"""

    sql """create catalog if not exists ${catalog_name} properties(
        "type"="jdbc",
        "user"="root",
        "password"="123456",
        "jdbc_url" = "${mysql_jdbc_url}",
        "driver_url" = "${driver_url}",
        "driver_class" = "com.mysql.cj.jdbc.Driver"
    );"""

    try {
        // Create the fixture directly in MySQL; vb and vc hold identical bytes per row,
        // including an empty value (row 4) and punctuation (row 5).
        connect("root", "123456", mysql_jdbc_url) {
            try_sql """DROP TABLE IF EXISTS ${test_table}"""

            sql """CREATE TABLE ${test_table} (
                id int,
                vb varbinary(100),
                vc VARCHAR(100)
            )"""

            sql """INSERT INTO ${test_table} VALUES
                (1, 'hello world', 'hello world'),
                (2, 'test data', 'test data'),
                (3, 'hash test', 'hash test'),
                (4, '', ''),
                (5, 'special chars: !@#%', 'special chars: !@#%')"""
        }

        sql """switch ${catalog_name}"""
        sql """use ${ex_db_name}"""

        for (def c : equalityCases) {
            def label = c[0]
            for (def row : queryBothTypes(c[1], c[2])) {
                assertTrue(row[1] == row[2],
                        "${label} mismatch for row ${row[0]}: VarBinary=${row[1]}, VARCHAR=${row[2]}")
            }
        }

        for (def c : variadicCases) {
            def label = c[0]
            for (def row : queryBothTypes(c[1], c[2])) {
                assertTrue(row[1] != null && row[2] != null,
                        "${label} should return non-null results for both VarBinary and VARCHAR arguments for row ${row[0]}")
                assertTrue(row[1] == row[2],
                        "${label} mismatch for row ${row[0]}: VarBinary=${row[1]}, VARCHAR=${row[2]}")
            }
        }
    } finally {
        // Cleanup must run even when an assertion above fails; otherwise the MySQL
        // table and the catalog leak into later suites. The inner try/finally makes
        // sure the catalog is dropped even if the MySQL connection cannot be opened.
        try {
            connect("root", "123456", mysql_jdbc_url) {
                try_sql """DROP TABLE IF EXISTS ${test_table}"""
            }
        } finally {
            sql """drop catalog if exists ${catalog_name}"""
        }
    }
}