Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions be/src/vec/functions/function_hash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "vec/columns/column.h"
#include "vec/columns/column_const.h"
#include "vec/columns/column_string.h"
#include "vec/columns/column_varbinary.h"
#include "vec/columns/column_vector.h"
#include "vec/common/assert_cast.h"
#include "vec/core/field.h"
Expand Down Expand Up @@ -177,6 +178,17 @@ struct XxHashImpl {
HashUtil::xxHash64WithSeed(value.data(), value.size(), col_to_data[i]);
}
}
} else if (const auto* vb_col = check_and_get_column<ColumnVarbinary>(column)) {
// Hash every row of the varbinary column. The value currently stored in
// col_to_data[i] is passed as the seed argument and then overwritten with the
// result (presumably so successive arguments chain into one hash per row;
// the caller is outside this view, so confirm there).
for (size_t i = 0; i < input_rows_count; ++i) {
// Pointer/length view of row i; only .data and .size are used below.
auto data_ref = vb_col->get_data_at(i);
// ReturnType selects the 32-bit or 64-bit xxHash variant at compile time.
if constexpr (ReturnType == TYPE_INT) {
col_to_data[i] = HashUtil::xxHash32WithSeed(data_ref.data, data_ref.size,
col_to_data[i]);
} else {
col_to_data[i] = HashUtil::xxHash64WithSeed(data_ref.data, data_ref.size,
col_to_data[i]);
}
}
} else {
DCHECK(false);
return Status::NotSupported("Illegal column {} of argument of function {}",
Expand Down
8 changes: 4 additions & 4 deletions be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1458,8 +1458,8 @@ void register_function_string(SimpleFunctionFactory& factory) {
factory.register_function<
FunctionStringFormatRound<FormatRoundDecimalImpl<TYPE_DECIMAL128I>>>();
factory.register_function<FunctionStringFormatRound<FormatRoundDecimalImpl<TYPE_DECIMAL256>>>();
factory.register_function<FunctionStringDigestOneArg<SM3Sum>>();
factory.register_function<FunctionStringDigestOneArg<MD5Sum>>();
factory.register_function<FunctionStringDigestMulti<SM3Sum>>();
factory.register_function<FunctionStringDigestMulti<MD5Sum>>();
factory.register_function<FunctionStringDigestSHA1>();
factory.register_function<FunctionStringDigestSHA2>();
factory.register_function<FunctionReplace<ReplaceImpl, true>>();
Expand All @@ -1482,9 +1482,9 @@ void register_function_string(SimpleFunctionFactory& factory) {
factory.register_alias(SubstringUtil::name, "substr");
factory.register_alias(FunctionToLower::name, "lcase");
factory.register_alias(FunctionToUpper::name, "ucase");
factory.register_alias(FunctionStringDigestOneArg<MD5Sum>::name, "md5");
factory.register_alias(FunctionStringDigestMulti<MD5Sum>::name, "md5");
factory.register_alias(FunctionStringUTF8Length::name, "character_length");
factory.register_alias(FunctionStringDigestOneArg<SM3Sum>::name, "sm3");
factory.register_alias(FunctionStringDigestMulti<SM3Sum>::name, "sm3");
factory.register_alias(FunctionStringDigestSHA1::name, "sha");
factory.register_alias(FunctionStringLocatePos::name, "position");
}
Expand Down
133 changes: 79 additions & 54 deletions be/src/vec/functions/function_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/columns/column_const.h"
#include "vec/columns/column_varbinary.h"
#include "vec/columns/column_vector.h"
#include "vec/common/hash_table/phmap_fwd_decl.h"
#include "vec/common/int_exp.h"
Expand Down Expand Up @@ -2335,10 +2336,10 @@ struct MD5Sum {
};

template <typename Impl>
class FunctionStringDigestOneArg : public IFunction {
class FunctionStringDigestMulti : public IFunction {
public:
static constexpr auto name = Impl::name;
static FunctionPtr create() { return std::make_shared<FunctionStringDigestOneArg>(); }
static FunctionPtr create() { return std::make_shared<FunctionStringDigestMulti>(); }
String get_name() const override { return name; }
size_t get_number_of_arguments() const override { return 0; }
bool is_variadic() const override { return true; }
Expand All @@ -2351,51 +2352,54 @@ class FunctionStringDigestOneArg : public IFunction {
uint32_t result, size_t input_rows_count) const override {
DCHECK_GE(arguments.size(), 1);

int argument_size = arguments.size();
std::vector<ColumnPtr> argument_columns(argument_size);
auto res = ColumnString::create();
auto& res_data = res->get_chars();
auto& res_offset = res->get_offsets();
res_offset.resize(input_rows_count);

std::vector<const ColumnString::Offsets*> offsets_list(argument_size);
std::vector<const ColumnString::Chars*> chars_list(argument_size);
std::vector<ColumnPtr> argument_columns(arguments.size());
std::vector<uint8_t> is_const(arguments.size(), 0);
for (size_t i = 0; i < arguments.size(); ++i) {
std::tie(argument_columns[i], is_const[i]) =
unpack_if_const(block.get_by_position(arguments[i]).column);
}

for (int i = 0; i < argument_size; ++i) {
argument_columns[i] =
block.get_by_position(arguments[i]).column->convert_to_full_column_if_const();
if (const auto* col_str = assert_cast<const ColumnString*>(argument_columns[i].get())) {
offsets_list[i] = &col_str->get_offsets();
chars_list[i] = &col_str->get_chars();
} else {
return Status::RuntimeError("Illegal column {} of argument of function {}",
block.get_by_position(arguments[0]).column->get_name(),
get_name());
}
if (check_and_get_column<ColumnString>(argument_columns[0].get())) {
vector_execute<ColumnString>(block, input_rows_count, argument_columns, is_const,
res_data, res_offset);
} else if (check_and_get_column<ColumnVarbinary>(argument_columns[0].get())) {
vector_execute<ColumnVarbinary>(block, input_rows_count, argument_columns, is_const,
res_data, res_offset);
} else {
return Status::RuntimeError("Illegal column {} of argument of function {}",
argument_columns[0]->get_name(), get_name());
}

auto res = ColumnString::create();
auto& res_data = res->get_chars();
auto& res_offset = res->get_offsets();
block.replace_by_position(result, std::move(res));
return Status::OK();
}

res_offset.resize(input_rows_count);
private:
template <typename ColumnType>
void vector_execute(Block& block, size_t input_rows_count,
const std::vector<ColumnPtr>& argument_columns,
const std::vector<uint8_t>& is_const, ColumnString::Chars& res_data,
ColumnString::Offsets& res_offset) const {
using ObjectData = typename Impl::ObjectData;
for (size_t i = 0; i < input_rows_count; ++i) {
using ObjectData = typename Impl::ObjectData;
ObjectData digest;
for (size_t j = 0; j < offsets_list.size(); ++j) {
const auto& current_offsets = *offsets_list[j];
const auto& current_chars = *chars_list[j];

int size = current_offsets[i] - current_offsets[i - 1];
if (size < 1) {
for (size_t j = 0; j < argument_columns.size(); ++j) {
const auto* col = assert_cast<const ColumnType*>(argument_columns[j].get());
StringRef data_ref = col->get_data_at(is_const[j] ? 0 : i);
if (data_ref.size < 1) {
continue;
}
digest.update(&current_chars[current_offsets[i - 1]], size);
digest.update(data_ref.data, data_ref.size);
}
digest.digest();

StringOP::push_value_string(std::string_view(digest.hex().c_str(), digest.hex().size()),
i, res_data, res_offset);
}

block.replace_by_position(result, std::move(res));
return Status::OK();
}
};

Expand All @@ -2414,27 +2418,37 @@ class FunctionStringDigestSHA1 : public IFunction {
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
uint32_t result, size_t input_rows_count) const override {
DCHECK_EQ(arguments.size(), 1);

ColumnPtr str_col = block.get_by_position(arguments[0]).column;
auto& data = assert_cast<const ColumnString*>(str_col.get())->get_chars();
auto& offset = assert_cast<const ColumnString*>(str_col.get())->get_offsets();
ColumnPtr data_col = block.get_by_position(arguments[0]).column;

auto res_col = ColumnString::create();
auto& res_data = res_col->get_chars();
auto& res_offset = res_col->get_offsets();
res_offset.resize(input_rows_count);
if (const auto* str_col = check_and_get_column<ColumnString>(data_col.get())) {
vector_execute(str_col, input_rows_count, res_data, res_offset);
} else if (const auto* vb_col = check_and_get_column<ColumnVarbinary>(data_col.get())) {
vector_execute(vb_col, input_rows_count, res_data, res_offset);
} else {
return Status::RuntimeError("Illegal column {} of argument of function {}",
data_col->get_name(), get_name());
}

block.replace_by_position(result, std::move(res_col));
return Status::OK();
}

private:
template <typename ColumnType>
void vector_execute(const ColumnType* col, size_t input_rows_count,
ColumnString::Chars& res_data, ColumnString::Offsets& res_offset) const {
SHA1Digest digest;
for (size_t i = 0; i < input_rows_count; ++i) {
int size = offset[i] - offset[i - 1];
digest.reset(&data[offset[i - 1]], size);
StringRef data_ref = col->get_data_at(i);
digest.reset(data_ref.data, data_ref.size);
std::string_view ans = digest.digest();

StringOP::push_value_string(ans, i, res_data, res_offset);
}

block.replace_by_position(result, std::move(res_col));
return Status::OK();
}
};

Expand All @@ -2454,9 +2468,7 @@ class FunctionStringDigestSHA2 : public IFunction {
uint32_t result, size_t input_rows_count) const override {
DCHECK(!is_column_const(*block.get_by_position(arguments[0]).column));

ColumnPtr str_col = block.get_by_position(arguments[0]).column;
auto& data = assert_cast<const ColumnString*>(str_col.get())->get_chars();
auto& offset = assert_cast<const ColumnString*>(str_col.get())->get_offsets();
ColumnPtr data_col = block.get_by_position(arguments[0]).column;

[[maybe_unused]] const auto& [right_column, right_const] =
unpack_if_const(block.get_by_position(arguments[1]).column);
Expand All @@ -2468,13 +2480,13 @@ class FunctionStringDigestSHA2 : public IFunction {
res_offset.resize(input_rows_count);

if (digest_length == 224) {
execute_base<SHA224Digest>(data, offset, input_rows_count, res_data, res_offset);
execute_base<SHA224Digest>(data_col, input_rows_count, res_data, res_offset);
} else if (digest_length == 256) {
execute_base<SHA256Digest>(data, offset, input_rows_count, res_data, res_offset);
execute_base<SHA256Digest>(data_col, input_rows_count, res_data, res_offset);
} else if (digest_length == 384) {
execute_base<SHA384Digest>(data, offset, input_rows_count, res_data, res_offset);
execute_base<SHA384Digest>(data_col, input_rows_count, res_data, res_offset);
} else if (digest_length == 512) {
execute_base<SHA512Digest>(data, offset, input_rows_count, res_data, res_offset);
execute_base<SHA512Digest>(data_col, input_rows_count, res_data, res_offset);
} else {
return Status::InvalidArgument(
"sha2's digest length only support 224/256/384/512 but meet {}", digest_length);
Expand All @@ -2486,13 +2498,26 @@ class FunctionStringDigestSHA2 : public IFunction {

private:
template <typename T>
void execute_base(const ColumnString::Chars& data, const ColumnString::Offsets& offset,
int input_rows_count, ColumnString::Chars& res_data,
void execute_base(ColumnPtr data_col, int input_rows_count, ColumnString::Chars& res_data,
ColumnString::Offsets& res_offset) const {
T digest;
if (const auto* str_col = check_and_get_column<ColumnString>(data_col.get())) {
vector_execute<T>(str_col, input_rows_count, res_data, res_offset);
} else if (const auto* vb_col = check_and_get_column<ColumnVarbinary>(data_col.get())) {
vector_execute<T>(vb_col, input_rows_count, res_data, res_offset);
} else {
throw Exception(ErrorCode::RUNTIME_ERROR,
"Illegal column {} of argument of function {}", data_col->get_name(),
get_name());
}
}

template <typename DigestType, typename ColumnType>
void vector_execute(const ColumnType* col, size_t input_rows_count,
ColumnString::Chars& res_data, ColumnString::Offsets& res_offset) const {
DigestType digest;
for (size_t i = 0; i < input_rows_count; ++i) {
int size = offset[i] - offset[i - 1];
digest.reset(&data[offset[i - 1]], size);
StringRef data_ref = col->get_data_at(i);
digest.reset(data_ref.data, data_ref.size);
std::string_view ans = digest.digest();

StringOP::push_value_string(ans, i, res_data, res_offset);
Expand Down
58 changes: 58 additions & 0 deletions be/test/vec/function/function_hash_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,34 @@ TEST(HashFunctionTest, xxhash_32_test) {

static_cast<void>(check_function<DataTypeInt32, true>(func_name, input_types, data_set));
};

{
// One VARBINARY argument. Each data_set row pairs the input values with the
// expected result: NULL in -> NULL out, "hello" -> -83855367 (32-bit).
InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY};

DataSet data_set = {{{Null()}, Null()}, {{std::string("hello")}, (int32_t)-83855367}};

// The value returned by check_function is explicitly discarded.
static_cast<void>(check_function<DataTypeInt32, true>(func_name, input_types, data_set));
};

{
// Two VARBINARY arguments: ("hello", "world") -> -920844969 (32-bit);
// a NULL second argument expects a NULL result.
InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, PrimitiveType::TYPE_VARBINARY};

DataSet data_set = {{{std::string("hello"), std::string("world")}, (int32_t)-920844969},
{{std::string("hello"), Null()}, Null()}};

// The value returned by check_function is explicitly discarded.
static_cast<void>(check_function<DataTypeInt32, true>(func_name, input_types, data_set));
};

{
// Three VARBINARY arguments: ("hello", "world", "!") -> 352087701 (32-bit);
// a NULL third argument expects a NULL result.
InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, PrimitiveType::TYPE_VARBINARY,
PrimitiveType::TYPE_VARBINARY};

DataSet data_set = {{{std::string("hello"), std::string("world"), std::string("!")},
(int32_t)352087701},
{{std::string("hello"), std::string("world"), Null()}, Null()}};

// The value returned by check_function is explicitly discarded.
static_cast<void>(check_function<DataTypeInt32, true>(func_name, input_types, data_set));
};
}

TEST(HashFunctionTest, xxhash_64_test) {
Expand Down Expand Up @@ -160,6 +188,36 @@ TEST(HashFunctionTest, xxhash_64_test) {

static_cast<void>(check_function<DataTypeInt64, true>(func_name, input_types, data_set));
};

{
// One VARBINARY argument. Each data_set row pairs the input values with the
// expected result: NULL in -> NULL out, "hello" -> -7685981735718036227 (64-bit).
InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY};

DataSet data_set = {{{Null()}, Null()},
{{std::string("hello")}, (int64_t)-7685981735718036227}};

// The value returned by check_function is explicitly discarded.
static_cast<void>(check_function<DataTypeInt64, true>(func_name, input_types, data_set));
};

{
// Two VARBINARY arguments: ("hello", "world") -> 7001965798170371843 (64-bit);
// a NULL second argument expects a NULL result.
InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, PrimitiveType::TYPE_VARBINARY};

DataSet data_set = {
{{std::string("hello"), std::string("world")}, (int64_t)7001965798170371843},
{{std::string("hello"), Null()}, Null()}};

// The value returned by check_function is explicitly discarded.
static_cast<void>(check_function<DataTypeInt64, true>(func_name, input_types, data_set));
};

{
// Three VARBINARY arguments: ("hello", "world", "!") -> 6796829678999971400
// (64-bit); a NULL third argument expects a NULL result.
InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, PrimitiveType::TYPE_VARBINARY,
PrimitiveType::TYPE_VARBINARY};

DataSet data_set = {{{std::string("hello"), std::string("world"), std::string("!")},
(int64_t)6796829678999971400},
{{std::string("hello"), std::string("world"), Null()}, Null()}};

// The value returned by check_function is explicitly discarded.
static_cast<void>(check_function<DataTypeInt64, true>(func_name, input_types, data_set));
};
}

} // namespace doris::vectorized
Loading
Loading