From 4750b8ff7e8e43090714e2e5693f64ea0f62e803 Mon Sep 17 00:00:00 2001 From: Mryange <2319153948@qq.com> Date: Tue, 18 Jun 2024 23:33:00 +0800 Subject: [PATCH 1/5] upd --- be/src/util/simd/vstring_function.h | 192 ++++++----------------- be/src/vec/functions/function_string.cpp | 47 +++--- 2 files changed, 79 insertions(+), 160 deletions(-) diff --git a/be/src/util/simd/vstring_function.h b/be/src/util/simd/vstring_function.h index dac964b1b94224..651ea1cbfd782d 100644 --- a/be/src/util/simd/vstring_function.h +++ b/be/src/util/simd/vstring_function.h @@ -100,169 +100,77 @@ class VStringFunctions { /// n equals to 16 chars length static constexpr auto REGISTER_SIZE = sizeof(__m128i); #endif -public: - static StringRef rtrim(const StringRef& str) { - if (str.size == 0) { - return str; - } - auto begin = 0; - int64_t end = str.size - 1; -#if defined(__SSE2__) || defined(__aarch64__) - char blank = ' '; - const auto pattern = _mm_set1_epi8(blank); - while (end - begin + 1 >= REGISTER_SIZE) { - const auto v_haystack = _mm_loadu_si128( - reinterpret_cast(str.data + end + 1 - REGISTER_SIZE)); - const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern); - const auto mask = _mm_movemask_epi8(v_against_pattern); - int offset = __builtin_clz(~(mask << REGISTER_SIZE)); - /// means not found - if (offset == 0) { - return StringRef(str.data + begin, end - begin + 1); - } else { - end -= offset; - } - } -#endif - while (end >= begin && str.data[end] == ' ') { - --end; - } - if (end < 0) { - return StringRef(""); - } - return StringRef(str.data + begin, end - begin + 1); - } - static StringRef ltrim(const StringRef& str) { - if (str.size == 0) { - return str; + template + static inline const char* rtrim(const char* begin, const char* end, + const StringRef& remove_str) { + if (remove_str.size == 0) { + return end; } - auto begin = 0; - auto end = str.size - 1; + const char* p = end; #if defined(__SSE2__) || defined(__aarch64__) - char blank = ' '; - const auto pattern = _mm_set1_epi8(blank); - while (end - begin + 1 >= REGISTER_SIZE) { - const auto v_haystack = - _mm_loadu_si128(reinterpret_cast(str.data + begin)); - const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern); - const auto mask = _mm_movemask_epi8(v_against_pattern) ^ 0xffff; - /// zero means not found - if (mask == 0) { - begin += REGISTER_SIZE; - } else { - const auto offset = __builtin_ctz(mask); - begin += offset; - return StringRef(str.data + begin, end - begin + 1); - } - } -#endif - while (begin <= end && str.data[begin] == ' ') { - ++begin; - } - return StringRef(str.data + begin, end - begin + 1); - } - - static StringRef trim(const StringRef& str) { - if (str.size == 0) { - return str; - } - return rtrim(ltrim(str)); - } - - static StringRef rtrim(const StringRef& str, const StringRef& rhs) { - if (str.size == 0 || rhs.size == 0) { - return str; - } - if (rhs.size == 1) { - auto begin = 0; - int64_t end = str.size - 1; - const char blank = rhs.data[0]; -#if defined(__SSE2__) || defined(__aarch64__) - const auto pattern = _mm_set1_epi8(blank); - while (end - begin + 1 >= REGISTER_SIZE) { - const auto v_haystack = _mm_loadu_si128( - reinterpret_cast(str.data + end + 1 - REGISTER_SIZE)); - const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern); - const auto mask = _mm_movemask_epi8(v_against_pattern); - int offset = __builtin_clz(~(mask << REGISTER_SIZE)); - /// means not found - if (offset == 0) { - return StringRef(str.data + begin, end - begin + 1); - } else { - end -= offset; + if constexpr (trim_single) { + const auto size = end - begin; + const auto SSE2_BYTES = sizeof(__m128i); + const auto* const sse2_begin = end - (size & ~(SSE2_BYTES - 1)); + const auto spaces = _mm_set1_epi8(remove_str.data[0]); + for (p = end - SSE2_BYTES; p >= sse2_begin; p -= SSE2_BYTES) { + uint32_t masks = + _mm_movemask_epi8(_mm_cmpeq_epi8(_mm_loadu_si128((__m128i*)p), spaces)); + int pos = __builtin_clz(~(masks << SSE2_BYTES)); + if (pos < SSE2_BYTES) { + return p + SSE2_BYTES - pos; } } -#endif - while (end >= begin && str.data[end] == blank) { - --end; - } - if (end < 0) { - return StringRef(""); - } - return StringRef(str.data + begin, end - begin + 1); + p += SSE2_BYTES; } - auto begin = 0; - auto end = str.size - 1; - const auto rhs_size = rhs.size; - while (end - begin + 1 >= rhs_size) { - if (memcmp(str.data + end - rhs_size + 1, rhs.data, rhs_size) == 0) { - end -= rhs.size; +#endif + const auto remove_size = remove_str.size; + const auto* const remove_data = remove_str.data; + while (p - begin >= remove_size) { + if (memcmp(p - remove_size, remove_data, remove_size) == 0) { + p -= remove_str.size; } else { break; } } - return StringRef(str.data + begin, end - begin + 1); + return p; } - static StringRef ltrim(const StringRef& str, const StringRef& rhs) { - if (str.size == 0 || rhs.size == 0) { - return str; + template + static inline const char* ltrim(const char* begin, const char* end, + const StringRef& remove_str) { + if (remove_str.size == 0) { + return begin; } - if (str.size == 1) { - auto begin = 0; - auto end = str.size - 1; - const char blank = rhs.data[0]; + const char* p = begin; #if defined(__SSE2__) || defined(__aarch64__) - const auto pattern = _mm_set1_epi8(blank); - while (end - begin + 1 >= REGISTER_SIZE) { - const auto v_haystack = - _mm_loadu_si128(reinterpret_cast(str.data + begin)); - const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern); - const auto mask = _mm_movemask_epi8(v_against_pattern) ^ 0xffff; - /// zero means not found - if (mask == 0) { - begin += REGISTER_SIZE; - } else { - const auto offset = __builtin_ctz(mask); - begin += offset; - return StringRef(str.data + begin, end - begin + 1); + if constexpr (trim_single) { + const auto size = end - begin; + const auto SSE2_BYTES = sizeof(__m128i); + const auto* const sse2_end = begin + (size & ~(SSE2_BYTES - 1)); + const auto spaces = _mm_set1_epi8(remove_str.data[0]); + for (; p < sse2_end; p += SSE2_BYTES) { + uint32_t masks = + _mm_movemask_epi8(_mm_cmpeq_epi8(_mm_loadu_si128((__m128i*)p), spaces)); + int pos = __builtin_ctz((1U << SSE2_BYTES) | ~masks); + if (pos < SSE2_BYTES) { + return p + pos; } } -#endif - while (begin <= end && str.data[begin] == blank) { - ++begin; - } - return StringRef(str.data + begin, end - begin + 1); } - auto begin = 0; - auto end = str.size - 1; - const auto rhs_size = rhs.size; - while (end - begin + 1 >= rhs_size) { - if (memcmp(str.data + begin, rhs.data, rhs_size) == 0) { - begin += rhs.size; +#endif + + const auto remove_size = remove_str.size; + const auto* const remove_data = remove_str.data; + while (end - p >= remove_size) { + if (memcmp(p, remove_data, remove_size) == 0) { + p += remove_str.size; } else { break; } } - return StringRef(str.data + begin, end - begin + 1); - } - - static StringRef trim(const StringRef& str, const StringRef& rhs) { - if (str.size == 0 || rhs.size == 0) { - return str; - } - return rtrim(ltrim(str, rhs), rhs); + return p; } // Gcc will do auto simd in this function diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp index 407c9ffa1ce7c6..2e0deb044017b2 100644 --- a/be/src/vec/functions/function_string.cpp +++ b/be/src/vec/functions/function_string.cpp @@ -485,24 +485,28 @@ struct NameLTrim { struct NameRTrim { static constexpr auto name = "rtrim"; }; -template +template struct TrimUtil { static Status vector(const ColumnString::Chars& str_data, - const ColumnString::Offsets& str_offsets, const StringRef& rhs, + const ColumnString::Offsets& str_offsets, const StringRef& remove_str, ColumnString::Chars& res_data, ColumnString::Offsets& res_offsets) { size_t offset_size = str_offsets.size(); res_offsets.resize(str_offsets.size()); for (size_t i = 0; i < offset_size; ++i) { const char* raw_str = reinterpret_cast(&str_data[str_offsets[i - 1]]); - ColumnString::Offset size = str_offsets[i] - str_offsets[i - 1]; - StringRef str(raw_str, size); + const ColumnString::Offset size = str_offsets[i] - str_offsets[i - 1]; + const char* str_begin = raw_str; + const char* str_end = raw_str + size; + if constexpr (is_ltrim) { - str = simd::VStringFunctions::ltrim(str, rhs); + str_begin = + simd::VStringFunctions::ltrim(str_begin, str_end, remove_str); } if constexpr (is_rtrim) { - str = simd::VStringFunctions::rtrim(str, rhs); + str_end = + simd::VStringFunctions::rtrim(str_begin, str_end, remove_str); } - StringOP::push_value_string(std::string_view((char*)str.data, str.size), i, res_data, + StringOP::push_value_string(std::string_view(str_begin, str_end), i, res_data, res_offsets); } return Status::OK(); @@ -521,9 +525,9 @@ struct Trim1Impl { if (const auto* col = assert_cast(column.get())) { auto col_res = ColumnString::create(); char blank[] = " "; - StringRef rhs(blank, 1); - RETURN_IF_ERROR((TrimUtil::vector( - col->get_chars(), col->get_offsets(), rhs, col_res->get_chars(), + const StringRef remove_str(blank, 1); + RETURN_IF_ERROR((TrimUtil::vector( + col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(), col_res->get_offsets()))); block.replace_by_position(result, std::move(col_res)); } else { @@ -550,15 +554,22 @@ struct Trim2Impl { const auto& rcol = assert_cast(block.get_by_position(arguments[1]).column.get()) ->get_data_column_ptr(); - if (auto col = assert_cast(column.get())) { - if (auto col_right = assert_cast(rcol.get())) { + if (const auto* col = assert_cast(column.get())) { + if (const auto* col_right = assert_cast(rcol.get())) { auto col_res = ColumnString::create(); - const char* raw_rhs = reinterpret_cast(&(col_right->get_chars()[0])); - ColumnString::Offset rhs_size = col_right->get_offsets()[0]; - StringRef rhs(raw_rhs, rhs_size); - RETURN_IF_ERROR((TrimUtil::vector( - col->get_chars(), col->get_offsets(), rhs, col_res->get_chars(), - col_res->get_offsets()))); + const char* remove_str_raw_rhs = + reinterpret_cast(col_right->get_chars().data()); + const ColumnString::Offset remove_str_rhs_size = col_right->get_offsets()[0]; + const StringRef remove_str(remove_str_raw_rhs, remove_str_rhs_size); + if (remove_str.size == 1) { + RETURN_IF_ERROR((TrimUtil::vector( + col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(), + col_res->get_offsets()))); + } else { + RETURN_IF_ERROR((TrimUtil::vector( + col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(), + col_res->get_offsets()))); + } block.replace_by_position(result, std::move(col_res)); } else { return Status::RuntimeError("Illegal column {} of argument of function {}", From f7bd50debad462c7e3ac70c61e95b26b38e40041 Mon Sep 17 00:00:00 2001 From: Mryange <2319153948@qq.com> Date: Wed, 19 Jun 2024 09:09:05 +0800 Subject: [PATCH 2/5] upd --- be/src/util/simd/vstring_function.h | 12 ++++++------ be/src/vec/functions/function_string.cpp | 23 +++++++++++------------ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/be/src/util/simd/vstring_function.h b/be/src/util/simd/vstring_function.h index 651ea1cbfd782d..a1db67325e74c1 100644 --- a/be/src/util/simd/vstring_function.h +++ b/be/src/util/simd/vstring_function.h @@ -102,12 +102,12 @@ class VStringFunctions { #endif template - static inline const char* rtrim(const char* begin, const char* end, - const StringRef& remove_str) { + static inline const unsigned char* rtrim(const unsigned char* begin, const unsigned char* end, + const StringRef& remove_str) { if (remove_str.size == 0) { return end; } - const char* p = end; + const auto* p = end; #if defined(__SSE2__) || defined(__aarch64__) if constexpr (trim_single) { const auto size = end - begin; @@ -138,12 +138,12 @@ class VStringFunctions { } template - static inline const char* ltrim(const char* begin, const char* end, - const StringRef& remove_str) { + static inline const unsigned char* ltrim(const unsigned char* begin, const unsigned char* end, + const StringRef& remove_str) { if (remove_str.size == 0) { return begin; } - const char* p = begin; + const auto* p = begin; #if defined(__SSE2__) || defined(__aarch64__) if constexpr (trim_single) { const auto size = end - begin; diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp index 2e0deb044017b2..5d49ec960e899e 100644 --- a/be/src/vec/functions/function_string.cpp +++ b/be/src/vec/functions/function_string.cpp @@ -490,13 +490,12 @@ struct TrimUtil { static Status vector(const ColumnString::Chars& str_data, const ColumnString::Offsets& str_offsets, const StringRef& remove_str, ColumnString::Chars& res_data, ColumnString::Offsets& res_offsets) { - size_t offset_size = str_offsets.size(); - res_offsets.resize(str_offsets.size()); + const size_t offset_size = str_offsets.size(); + res_offsets.resize(offset_size); + res_data.reserve(str_data.size()); for (size_t i = 0; i < offset_size; ++i) { - const char* raw_str = reinterpret_cast(&str_data[str_offsets[i - 1]]); - const ColumnString::Offset size = str_offsets[i] - str_offsets[i - 1]; - const char* str_begin = raw_str; - const char* str_end = raw_str + size; + const auto* str_begin = str_data.data() + str_offsets[i - 1]; + const auto* str_end = str_data.data() + str_offsets[i]; if constexpr (is_ltrim) { str_begin = @@ -506,8 +505,9 @@ struct TrimUtil { str_end = simd::VStringFunctions::rtrim(str_begin, str_end, remove_str); } - StringOP::push_value_string(std::string_view(str_begin, str_end), i, res_data, - res_offsets); + + res_data.insert_assume_reserved(str_begin, str_end); + res_offsets[i] = res_data.size(); } return Status::OK(); } @@ -557,10 +557,9 @@ struct Trim2Impl { if (const auto* col = assert_cast(column.get())) { if (const auto* col_right = assert_cast(rcol.get())) { auto col_res = ColumnString::create(); - const char* remove_str_raw_rhs = - reinterpret_cast(col_right->get_chars().data()); - const ColumnString::Offset remove_str_rhs_size = col_right->get_offsets()[0]; - const StringRef remove_str(remove_str_raw_rhs, remove_str_rhs_size); + const auto* remove_str_raw = col_right->get_chars().data(); + const ColumnString::Offset remove_str_size = col_right->get_offsets()[0]; + const StringRef remove_str(remove_str_raw, remove_str_size); if (remove_str.size == 1) { RETURN_IF_ERROR((TrimUtil::vector( col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(), From d50b6d394ce2146a7a27b7140f6888c9783aac6b Mon Sep 17 00:00:00 2001 From: Mryange <2319153948@qq.com> Date: Wed, 19 Jun 2024 22:03:18 +0800 Subject: [PATCH 3/5] avx256 --- be/src/util/simd/vstring_function.h | 48 +++++++++++-------- .../test_trim_new_parameters.groovy | 3 ++ 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/be/src/util/simd/vstring_function.h b/be/src/util/simd/vstring_function.h index a1db67325e74c1..5da2a67e34e510 100644 --- a/be/src/util/simd/vstring_function.h +++ b/be/src/util/simd/vstring_function.h @@ -108,21 +108,24 @@ class VStringFunctions { return end; } const auto* p = end; -#if defined(__SSE2__) || defined(__aarch64__) +#if defined(__AVX2__) if constexpr (trim_single) { + constexpr auto AVX2_BYTES = sizeof(__m256i); + const auto ch = remove_str.data[0]; const auto size = end - begin; - const auto SSE2_BYTES = sizeof(__m128i); - const auto* const sse2_begin = end - (size & ~(SSE2_BYTES - 1)); - const auto spaces = _mm_set1_epi8(remove_str.data[0]); - for (p = end - SSE2_BYTES; p >= sse2_begin; p -= SSE2_BYTES) { - uint32_t masks = - _mm_movemask_epi8(_mm_cmpeq_epi8(_mm_loadu_si128((__m128i*)p), spaces)); - int pos = __builtin_clz(~(masks << SSE2_BYTES)); - if (pos < SSE2_BYTES) { - return p + SSE2_BYTES - pos; + const auto* const avx2_begin = end - size / AVX2_BYTES * AVX2_BYTES; + const auto spaces = _mm256_set1_epi8(ch); + for (p = end - AVX2_BYTES; p >= avx2_begin; p -= AVX2_BYTES) { + uint32_t masks = _mm256_movemask_epi8( + _mm256_cmpeq_epi8(_mm256_loadu_si256((__m256i*)p), spaces)); + if ((~masks)) { + break; } } - p += SSE2_BYTES; + p += AVX2_BYTES; + for (; (p - 1) >= begin && *(p - 1) == ch; p--) { + } + return p; } #endif const auto remove_size = remove_str.size; @@ -144,20 +147,23 @@ class VStringFunctions { return begin; } const auto* p = begin; -#if defined(__SSE2__) || defined(__aarch64__) +#if defined(__AVX2__) if constexpr (trim_single) { + constexpr auto AVX2_BYTES = sizeof(__m256i); + const auto ch = remove_str.data[0]; const auto size = end - begin; - const auto SSE2_BYTES = sizeof(__m128i); - const auto* const sse2_end = begin + (size & ~(SSE2_BYTES - 1)); - const auto spaces = _mm_set1_epi8(remove_str.data[0]); - for (; p < sse2_end; p += SSE2_BYTES) { - uint32_t masks = - _mm_movemask_epi8(_mm_cmpeq_epi8(_mm_loadu_si128((__m128i*)p), spaces)); - int pos = __builtin_ctz((1U << SSE2_BYTES) | ~masks); - if (pos < SSE2_BYTES) { - return p + pos; + const auto* const avx2_end = begin + size / AVX2_BYTES * AVX2_BYTES; + const auto spaces = _mm256_set1_epi8(ch); + for (; p < avx2_end; p += AVX2_BYTES) { + uint32_t masks = _mm256_movemask_epi8( + _mm256_cmpeq_epi8(_mm256_loadu_si256((__m256i*)p), spaces)); + if ((~masks)) { + break; } } + for (; p < end && *p == ch; ++p) { + } + return p; } #endif diff --git a/regression-test/suites/correctness/test_trim_new_parameters.groovy b/regression-test/suites/correctness/test_trim_new_parameters.groovy index 3209eb7aae743d..17ac4a0c65eae5 100644 --- a/regression-test/suites/correctness/test_trim_new_parameters.groovy +++ b/regression-test/suites/correctness/test_trim_new_parameters.groovy @@ -67,4 +67,7 @@ suite("test_trim_new_parameters") { rtrim = sql "select rtrim('bcTTTabcabc','abc')" assertEquals(rtrim[0][0], 'bcTTT') + + trim_one = sql "select trim('aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaabcTTTabcabcaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa','a')" + assertEquals(trim_one[0][0], 'baaaaaaaaaaabcTTTabcabcaaaaaaaaaaaaaaaaaaaaaaaaaab') } From 7f675eaef0ae084ff6a1ef62e4fbe15b5e1d13b3 Mon Sep 17 00:00:00 2001 From: Mryange <2319153948@qq.com> Date: Thu, 20 Jun 2024 10:57:37 +0800 Subject: [PATCH 4/5] for arm --- be/src/util/simd/vstring_function.h | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/be/src/util/simd/vstring_function.h b/be/src/util/simd/vstring_function.h index 5da2a67e34e510..9c4e94b9a612dc 100644 --- a/be/src/util/simd/vstring_function.h +++ b/be/src/util/simd/vstring_function.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -108,8 +109,9 @@ class VStringFunctions { return end; } const auto* p = end; -#if defined(__AVX2__) + if constexpr (trim_single) { +#if defined(__AVX2__) || defined(__aarch64__) constexpr auto AVX2_BYTES = sizeof(__m256i); const auto ch = remove_str.data[0]; const auto size = end - begin; @@ -123,11 +125,12 @@ class VStringFunctions { } } p += AVX2_BYTES; +#endif for (; (p - 1) >= begin && *(p - 1) == ch; p--) { } return p; } -#endif + const auto remove_size = remove_str.size; const auto* const remove_data = remove_str.data; while (p - begin >= remove_size) { @@ -147,8 +150,9 @@ class VStringFunctions { return begin; } const auto* p = begin; -#if defined(__AVX2__) + if constexpr (trim_single) { +#if defined(__AVX2__) || defined(__aarch64__) constexpr auto AVX2_BYTES = sizeof(__m256i); const auto ch = remove_str.data[0]; const auto size = end - begin; @@ -161,11 +165,11 @@ class VStringFunctions { break; } } +#endif for (; p < end && *p == ch; ++p) { } return p; } -#endif const auto remove_size = remove_str.size; const auto* const remove_data = remove_str.data; From 3c8689b657f569882788764bed0d14688ef91fb4 Mon Sep 17 00:00:00 2001 From: Mryange <2319153948@qq.com> Date: Thu, 20 Jun 2024 11:26:59 +0800 Subject: [PATCH 5/5] fix --- be/src/util/simd/vstring_function.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/be/src/util/simd/vstring_function.h b/be/src/util/simd/vstring_function.h index 9c4e94b9a612dc..4fff59a01df2d8 100644 --- a/be/src/util/simd/vstring_function.h +++ b/be/src/util/simd/vstring_function.h @@ -111,9 +111,9 @@ class VStringFunctions { const auto* p = end; if constexpr (trim_single) { + const auto ch = remove_str.data[0]; #if defined(__AVX2__) || defined(__aarch64__) constexpr auto AVX2_BYTES = sizeof(__m256i); - const auto ch = remove_str.data[0]; const auto size = end - begin; const auto* const avx2_begin = end - size / AVX2_BYTES * AVX2_BYTES; const auto spaces = _mm256_set1_epi8(ch); @@ -152,9 +152,9 @@ class VStringFunctions { const auto* p = begin; if constexpr (trim_single) { + const auto ch = remove_str.data[0]; #if defined(__AVX2__) || defined(__aarch64__) constexpr auto AVX2_BYTES = sizeof(__m256i); - const auto ch = remove_str.data[0]; const auto size = end - begin; const auto* const avx2_end = begin + size / AVX2_BYTES * AVX2_BYTES; const auto spaces = _mm256_set1_epi8(ch);