diff --git a/be/src/util/simd/vstring_function.h b/be/src/util/simd/vstring_function.h index a1db67325e74c15..5da2a67e34e5100 100644 --- a/be/src/util/simd/vstring_function.h +++ b/be/src/util/simd/vstring_function.h @@ -108,21 +108,24 @@ class VStringFunctions { return end; } const auto* p = end; -#if defined(__SSE2__) || defined(__aarch64__) +#if defined(__AVX2__) if constexpr (trim_single) { + constexpr auto AVX2_BYTES = sizeof(__m256i); + const auto ch = remove_str.data[0]; const auto size = end - begin; - const auto SSE2_BYTES = sizeof(__m128i); - const auto* const sse2_begin = end - (size & ~(SSE2_BYTES - 1)); - const auto spaces = _mm_set1_epi8(remove_str.data[0]); - for (p = end - SSE2_BYTES; p >= sse2_begin; p -= SSE2_BYTES) { - uint32_t masks = - _mm_movemask_epi8(_mm_cmpeq_epi8(_mm_loadu_si128((__m128i*)p), spaces)); - int pos = __builtin_clz(~(masks << SSE2_BYTES)); - if (pos < SSE2_BYTES) { - return p + SSE2_BYTES - pos; + const auto* const avx2_begin = end - size / AVX2_BYTES * AVX2_BYTES; + const auto spaces = _mm256_set1_epi8(ch); + for (p = end - AVX2_BYTES; p >= avx2_begin; p -= AVX2_BYTES) { + uint32_t masks = _mm256_movemask_epi8( + _mm256_cmpeq_epi8(_mm256_loadu_si256((__m256i*)p), spaces)); + if ((~masks)) { + break; } } - p += SSE2_BYTES; + p += AVX2_BYTES; + for (; (p - 1) >= begin && *(p - 1) == ch; p--) { + } + return p; } #endif const auto remove_size = remove_str.size; @@ -144,20 +147,23 @@ class VStringFunctions { return begin; } const auto* p = begin; -#if defined(__SSE2__) || defined(__aarch64__) +#if defined(__AVX2__) if constexpr (trim_single) { + constexpr auto AVX2_BYTES = sizeof(__m256i); + const auto ch = remove_str.data[0]; const auto size = end - begin; - const auto SSE2_BYTES = sizeof(__m128i); - const auto* const sse2_end = begin + (size & ~(SSE2_BYTES - 1)); - const auto spaces = _mm_set1_epi8(remove_str.data[0]); - for (; p < sse2_end; p += SSE2_BYTES) { - uint32_t masks = - _mm_movemask_epi8(_mm_cmpeq_epi8(_mm_loadu_si128((__m128i*)p), spaces)); - int pos = __builtin_ctz((1U << SSE2_BYTES) | ~masks); - if (pos < SSE2_BYTES) { - return p + pos; + const auto* const avx2_end = begin + size / AVX2_BYTES * AVX2_BYTES; + const auto spaces = _mm256_set1_epi8(ch); + for (; p < avx2_end; p += AVX2_BYTES) { + uint32_t masks = _mm256_movemask_epi8( + _mm256_cmpeq_epi8(_mm256_loadu_si256((__m256i*)p), spaces)); + if ((~masks)) { + break; } } + for (; p < end && *p == ch; ++p) { + } + return p; } #endif diff --git a/regression-test/suites/correctness/test_trim_new_parameters.groovy b/regression-test/suites/correctness/test_trim_new_parameters.groovy index 3209eb7aae743de..17ac4a0c65eae5b 100644 --- a/regression-test/suites/correctness/test_trim_new_parameters.groovy +++ b/regression-test/suites/correctness/test_trim_new_parameters.groovy @@ -67,4 +67,7 @@ suite("test_trim_new_parameters") { rtrim = sql "select rtrim('bcTTTabcabc','abc')" assertEquals(rtrim[0][0], 'bcTTT') + + trim_one = sql "select trim('aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaabcTTTabcabcaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa','a')" + assertEquals(trim_one[0][0], 'baaaaaaaaaaabcTTTabcabcaaaaaaaaaaaaaaaaaaaaaaaaaab') }