From f8361470bbd87148a46cdafc3b4834dae50df71b Mon Sep 17 00:00:00 2001 From: Mryange <59914473+Mryange@users.noreply.github.com> Date: Fri, 28 Jun 2024 15:57:54 +0800 Subject: [PATCH] [opt](function) Optimize the trim function for single-char inputs (#36497) before ``` mysql [test]>select count(ltrim(str,"1")) from stringDb2; +------------------------+ | count(ltrim(str, '1')) | +------------------------+ | 64000000 | +------------------------+ 1 row in set (7.79 sec) ``` now ``` mysql [test]>select count(ltrim(str,"1")) from stringDb2; +------------------------+ | count(ltrim(str, '1')) | +------------------------+ | 64000000 | +------------------------+ 1 row in set (0.73 sec) ``` --- be/src/util/simd/vstring_function.h | 196 +++++------------- be/src/vec/functions/function_string.cpp | 54 +++-- .../test_trim_new_parameters.groovy | 3 + 3 files changed, 92 insertions(+), 161 deletions(-) diff --git a/be/src/util/simd/vstring_function.h b/be/src/util/simd/vstring_function.h index dac964b1b94224..4fff59a01df2d8 100644 --- a/be/src/util/simd/vstring_function.h +++ b/be/src/util/simd/vstring_function.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -100,169 +101,86 @@ class VStringFunctions { /// n equals to 16 chars length static constexpr auto REGISTER_SIZE = sizeof(__m128i); #endif -public: - static StringRef rtrim(const StringRef& str) { - if (str.size == 0) { - return str; - } - auto begin = 0; - int64_t end = str.size - 1; -#if defined(__SSE2__) || defined(__aarch64__) - char blank = ' '; - const auto pattern = _mm_set1_epi8(blank); - while (end - begin + 1 >= REGISTER_SIZE) { - const auto v_haystack = _mm_loadu_si128( - reinterpret_cast(str.data + end + 1 - REGISTER_SIZE)); - const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern); - const auto mask = _mm_movemask_epi8(v_against_pattern); - int offset = __builtin_clz(~(mask << REGISTER_SIZE)); - /// means not found - if (offset == 0) { - return StringRef(str.data + begin, end - begin + 1); - } else { - end -= offset; - } - } -#endif - while (end >= begin && str.data[end] == ' ') { - --end; - } - if (end < 0) { - return StringRef(""); - } - return StringRef(str.data + begin, end - begin + 1); - } - - static StringRef ltrim(const StringRef& str) { - if (str.size == 0) { - return str; - } - auto begin = 0; - auto end = str.size - 1; -#if defined(__SSE2__) || defined(__aarch64__) - char blank = ' '; - const auto pattern = _mm_set1_epi8(blank); - while (end - begin + 1 >= REGISTER_SIZE) { - const auto v_haystack = - _mm_loadu_si128(reinterpret_cast(str.data + begin)); - const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern); - const auto mask = _mm_movemask_epi8(v_against_pattern) ^ 0xffff; - /// zero means not found - if (mask == 0) { - begin += REGISTER_SIZE; - } else { - const auto offset = __builtin_ctz(mask); - begin += offset; - return StringRef(str.data + begin, end - begin + 1); - } - } -#endif - while (begin <= end && str.data[begin] == ' ') { - ++begin; - } - return StringRef(str.data + begin, end - begin + 1); - } - static StringRef trim(const StringRef& str) { - if (str.size == 0) { - return str; + template + static inline const unsigned char* rtrim(const unsigned char* begin, const unsigned char* end, + const StringRef& remove_str) { + if (remove_str.size == 0) { + return end; } - return rtrim(ltrim(str)); - } + const auto* p = end; - static StringRef rtrim(const StringRef& str, const StringRef& rhs) { - if (str.size == 0 || rhs.size == 0) { - return str; - } - if (rhs.size == 1) { - auto begin = 0; - int64_t end = str.size - 1; - const char blank = rhs.data[0]; -#if defined(__SSE2__) || defined(__aarch64__) - const auto pattern = _mm_set1_epi8(blank); - while (end - begin + 1 >= REGISTER_SIZE) { - const auto v_haystack = _mm_loadu_si128( - reinterpret_cast(str.data + end + 1 - REGISTER_SIZE)); - const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern); - const auto mask = _mm_movemask_epi8(v_against_pattern); - int offset = __builtin_clz(~(mask << REGISTER_SIZE)); - /// means not found - if (offset == 0) { - return StringRef(str.data + begin, end - begin + 1); - } else { - end -= offset; + if constexpr (trim_single) { + const auto ch = remove_str.data[0]; +#if defined(__AVX2__) || defined(__aarch64__) + constexpr auto AVX2_BYTES = sizeof(__m256i); + const auto size = end - begin; + const auto* const avx2_begin = end - size / AVX2_BYTES * AVX2_BYTES; + const auto spaces = _mm256_set1_epi8(ch); + for (p = end - AVX2_BYTES; p >= avx2_begin; p -= AVX2_BYTES) { + uint32_t masks = _mm256_movemask_epi8( + _mm256_cmpeq_epi8(_mm256_loadu_si256((__m256i*)p), spaces)); + if ((~masks)) { + break; } } + p += AVX2_BYTES; #endif - while (end >= begin && str.data[end] == blank) { - --end; - } - if (end < 0) { - return StringRef(""); + for (; (p - 1) >= begin && *(p - 1) == ch; p--) { } - return StringRef(str.data + begin, end - begin + 1); + return p; } - auto begin = 0; - auto end = str.size - 1; - const auto rhs_size = rhs.size; - while (end - begin + 1 >= rhs_size) { - if (memcmp(str.data + end - rhs_size + 1, rhs.data, rhs_size) == 0) { - end -= rhs.size; + + const auto remove_size = remove_str.size; + const auto* const remove_data = remove_str.data; + while (p - begin >= remove_size) { + if (memcmp(p - remove_size, remove_data, remove_size) == 0) { + p -= remove_str.size; } else { break; } } - return StringRef(str.data + begin, end - begin + 1); + return p; } - static StringRef ltrim(const StringRef& str, const StringRef& rhs) { - if (str.size == 0 || rhs.size == 0) { - return str; + template + static inline const unsigned char* ltrim(const unsigned char* begin, const unsigned char* end, + const StringRef& remove_str) { + if (remove_str.size == 0) { + return begin; } - if (str.size == 1) { - auto begin = 0; - auto end = str.size - 1; - const char blank = rhs.data[0]; -#if defined(__SSE2__) || defined(__aarch64__) - const auto pattern = _mm_set1_epi8(blank); - while (end - begin + 1 >= REGISTER_SIZE) { - const auto v_haystack = - _mm_loadu_si128(reinterpret_cast(str.data + begin)); - const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern); - const auto mask = _mm_movemask_epi8(v_against_pattern) ^ 0xffff; - /// zero means not found - if (mask == 0) { - begin += REGISTER_SIZE; - } else { - const auto offset = __builtin_ctz(mask); - begin += offset; - return StringRef(str.data + begin, end - begin + 1); + const auto* p = begin; + + if constexpr (trim_single) { + const auto ch = remove_str.data[0]; +#if defined(__AVX2__) || defined(__aarch64__) + constexpr auto AVX2_BYTES = sizeof(__m256i); + const auto size = end - begin; + const auto* const avx2_end = begin + size / AVX2_BYTES * AVX2_BYTES; + const auto spaces = _mm256_set1_epi8(ch); + for (; p < avx2_end; p += AVX2_BYTES) { + uint32_t masks = _mm256_movemask_epi8( + _mm256_cmpeq_epi8(_mm256_loadu_si256((__m256i*)p), spaces)); + if ((~masks)) { + break; } } #endif - while (begin <= end && str.data[begin] == blank) { - ++begin; + for (; p < end && *p == ch; ++p) { } - return StringRef(str.data + begin, end - begin + 1); + return p; } - auto begin = 0; - auto end = str.size - 1; - const auto rhs_size = rhs.size; - while (end - begin + 1 >= rhs_size) { - if (memcmp(str.data + begin, rhs.data, rhs_size) == 0) { - begin += rhs.size; + + const auto remove_size = remove_str.size; + const auto* const remove_data = remove_str.data; + while (end - p >= remove_size) { + if (memcmp(p, remove_data, remove_size) == 0) { + p += remove_str.size; } else { break; } } - return StringRef(str.data + begin, end - begin + 1); - } - - static StringRef trim(const StringRef& str, const StringRef& rhs) { - if (str.size == 0 || rhs.size == 0) { - return str; - } - return rtrim(ltrim(str, rhs), rhs); + return p; } // Gcc will do auto simd in this function diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp index d4dae54612cac1..9216ad1b9c80e0 100644 --- a/be/src/vec/functions/function_string.cpp +++ b/be/src/vec/functions/function_string.cpp @@ -485,25 +485,29 @@ struct NameLTrim { struct NameRTrim { static constexpr auto name = "rtrim"; }; -template +template struct TrimUtil { static Status vector(const ColumnString::Chars& str_data, - const ColumnString::Offsets& str_offsets, const StringRef& rhs, + const ColumnString::Offsets& str_offsets, const StringRef& remove_str, ColumnString::Chars& res_data, ColumnString::Offsets& res_offsets) { - size_t offset_size = str_offsets.size(); - res_offsets.resize(str_offsets.size()); + const size_t offset_size = str_offsets.size(); + res_offsets.resize(offset_size); + res_data.reserve(str_data.size()); for (size_t i = 0; i < offset_size; ++i) { - const char* raw_str = reinterpret_cast(&str_data[str_offsets[i - 1]]); - ColumnString::Offset size = str_offsets[i] - str_offsets[i - 1]; - StringRef str(raw_str, size); + const auto* str_begin = str_data.data() + str_offsets[i - 1]; + const auto* str_end = str_data.data() + str_offsets[i]; + if constexpr (is_ltrim) { - str = simd::VStringFunctions::ltrim(str, rhs); + str_begin = + simd::VStringFunctions::ltrim(str_begin, str_end, remove_str); } if constexpr (is_rtrim) { - str = simd::VStringFunctions::rtrim(str, rhs); + str_end = + simd::VStringFunctions::rtrim(str_begin, str_end, remove_str); } - StringOP::push_value_string(std::string_view((char*)str.data, str.size), i, res_data, - res_offsets); + + res_data.insert_assume_reserved(str_begin, str_end); + res_offsets[i] = res_data.size(); } return Status::OK(); } @@ -521,9 +525,9 @@ struct Trim1Impl { if (const auto* col = assert_cast(column.get())) { auto col_res = ColumnString::create(); char blank[] = " "; - StringRef rhs(blank, 1); - RETURN_IF_ERROR((TrimUtil::vector( - col->get_chars(), col->get_offsets(), rhs, col_res->get_chars(), + const StringRef remove_str(blank, 1); + RETURN_IF_ERROR((TrimUtil::vector( + col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(), col_res->get_offsets()))); block.replace_by_position(result, std::move(col_res)); } else { @@ -550,15 +554,21 @@ struct Trim2Impl { const auto& rcol = assert_cast(block.get_by_position(arguments[1]).column.get()) ->get_data_column_ptr(); - if (auto col = assert_cast(column.get())) { - if (auto col_right = assert_cast(rcol.get())) { + if (const auto* col = assert_cast(column.get())) { + if (const auto* col_right = assert_cast(rcol.get())) { auto col_res = ColumnString::create(); - const char* raw_rhs = reinterpret_cast(&(col_right->get_chars()[0])); - ColumnString::Offset rhs_size = col_right->get_offsets()[0]; - StringRef rhs(raw_rhs, rhs_size); - RETURN_IF_ERROR((TrimUtil::vector( - col->get_chars(), col->get_offsets(), rhs, col_res->get_chars(), - col_res->get_offsets()))); + const auto* remove_str_raw = col_right->get_chars().data(); + const ColumnString::Offset remove_str_size = col_right->get_offsets()[0]; + const StringRef remove_str(remove_str_raw, remove_str_size); + if (remove_str.size == 1) { + RETURN_IF_ERROR((TrimUtil::vector( + col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(), + col_res->get_offsets()))); + } else { + RETURN_IF_ERROR((TrimUtil::vector( + col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(), + col_res->get_offsets()))); + } block.replace_by_position(result, std::move(col_res)); } else { return Status::RuntimeError("Illegal column {} of argument of function {}", diff --git a/regression-test/suites/correctness/test_trim_new_parameters.groovy b/regression-test/suites/correctness/test_trim_new_parameters.groovy index 3209eb7aae743d..17ac4a0c65eae5 100644 --- a/regression-test/suites/correctness/test_trim_new_parameters.groovy +++ b/regression-test/suites/correctness/test_trim_new_parameters.groovy @@ -67,4 +67,7 @@ suite("test_trim_new_parameters") { rtrim = sql "select rtrim('bcTTTabcabc','abc')" assertEquals(rtrim[0][0], 'bcTTT') + + trim_one = sql "select trim('aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaabcTTTabcabcaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa','a')" + assertEquals(trim_one[0][0], 'baaaaaaaaaaabcTTTabcabcaaaaaaaaaaaaaaaaaaaaaaaaaab') }