Skip to content

Commit

Permalink
[opt](function) Optimize the trim function for single-char inputs (#3…
Browse files Browse the repository at this point in the history
…6497)

before
```
mysql [test]>select count(ltrim(str,"1")) from stringDb2;
+------------------------+
| count(ltrim(str, '1')) |
+------------------------+
|               64000000 |
+------------------------+
1 row in set (7.79 sec)
```

now
```
mysql [test]>select count(ltrim(str,"1")) from stringDb2;
+------------------------+
| count(ltrim(str, '1')) |
+------------------------+
|               64000000 |
+------------------------+
1 row in set (0.73 sec)
```
  • Loading branch information
Mryange authored Jun 28, 2024
1 parent 95594d6 commit f836147
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 161 deletions.
196 changes: 57 additions & 139 deletions be/src/util/simd/vstring_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include <immintrin.h>
#include <unistd.h>

#include <array>
Expand Down Expand Up @@ -100,169 +101,86 @@ class VStringFunctions {
/// n equals to 16 chars length
static constexpr auto REGISTER_SIZE = sizeof(__m128i);
#endif
public:
static StringRef rtrim(const StringRef& str) {
if (str.size == 0) {
return str;
}
auto begin = 0;
int64_t end = str.size - 1;
#if defined(__SSE2__) || defined(__aarch64__)
char blank = ' ';
const auto pattern = _mm_set1_epi8(blank);
while (end - begin + 1 >= REGISTER_SIZE) {
const auto v_haystack = _mm_loadu_si128(
reinterpret_cast<const __m128i*>(str.data + end + 1 - REGISTER_SIZE));
const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern);
const auto mask = _mm_movemask_epi8(v_against_pattern);
int offset = __builtin_clz(~(mask << REGISTER_SIZE));
/// means not found
if (offset == 0) {
return StringRef(str.data + begin, end - begin + 1);
} else {
end -= offset;
}
}
#endif
while (end >= begin && str.data[end] == ' ') {
--end;
}
if (end < 0) {
return StringRef("");
}
return StringRef(str.data + begin, end - begin + 1);
}

static StringRef ltrim(const StringRef& str) {
if (str.size == 0) {
return str;
}
auto begin = 0;
auto end = str.size - 1;
#if defined(__SSE2__) || defined(__aarch64__)
char blank = ' ';
const auto pattern = _mm_set1_epi8(blank);
while (end - begin + 1 >= REGISTER_SIZE) {
const auto v_haystack =
_mm_loadu_si128(reinterpret_cast<const __m128i*>(str.data + begin));
const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern);
const auto mask = _mm_movemask_epi8(v_against_pattern) ^ 0xffff;
/// zero means not found
if (mask == 0) {
begin += REGISTER_SIZE;
} else {
const auto offset = __builtin_ctz(mask);
begin += offset;
return StringRef(str.data + begin, end - begin + 1);
}
}
#endif
while (begin <= end && str.data[begin] == ' ') {
++begin;
}
return StringRef(str.data + begin, end - begin + 1);
}

static StringRef trim(const StringRef& str) {
if (str.size == 0) {
return str;
template <bool trim_single>
static inline const unsigned char* rtrim(const unsigned char* begin, const unsigned char* end,
const StringRef& remove_str) {
if (remove_str.size == 0) {
return end;
}
return rtrim(ltrim(str));
}
const auto* p = end;

static StringRef rtrim(const StringRef& str, const StringRef& rhs) {
if (str.size == 0 || rhs.size == 0) {
return str;
}
if (rhs.size == 1) {
auto begin = 0;
int64_t end = str.size - 1;
const char blank = rhs.data[0];
#if defined(__SSE2__) || defined(__aarch64__)
const auto pattern = _mm_set1_epi8(blank);
while (end - begin + 1 >= REGISTER_SIZE) {
const auto v_haystack = _mm_loadu_si128(
reinterpret_cast<const __m128i*>(str.data + end + 1 - REGISTER_SIZE));
const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern);
const auto mask = _mm_movemask_epi8(v_against_pattern);
int offset = __builtin_clz(~(mask << REGISTER_SIZE));
/// means not found
if (offset == 0) {
return StringRef(str.data + begin, end - begin + 1);
} else {
end -= offset;
if constexpr (trim_single) {
const auto ch = remove_str.data[0];
#if defined(__AVX2__) || defined(__aarch64__)
constexpr auto AVX2_BYTES = sizeof(__m256i);
const auto size = end - begin;
const auto* const avx2_begin = end - size / AVX2_BYTES * AVX2_BYTES;
const auto spaces = _mm256_set1_epi8(ch);
for (p = end - AVX2_BYTES; p >= avx2_begin; p -= AVX2_BYTES) {
uint32_t masks = _mm256_movemask_epi8(
_mm256_cmpeq_epi8(_mm256_loadu_si256((__m256i*)p), spaces));
if ((~masks)) {
break;
}
}
p += AVX2_BYTES;
#endif
while (end >= begin && str.data[end] == blank) {
--end;
}
if (end < 0) {
return StringRef("");
for (; (p - 1) >= begin && *(p - 1) == ch; p--) {
}
return StringRef(str.data + begin, end - begin + 1);
return p;
}
auto begin = 0;
auto end = str.size - 1;
const auto rhs_size = rhs.size;
while (end - begin + 1 >= rhs_size) {
if (memcmp(str.data + end - rhs_size + 1, rhs.data, rhs_size) == 0) {
end -= rhs.size;

const auto remove_size = remove_str.size;
const auto* const remove_data = remove_str.data;
while (p - begin >= remove_size) {
if (memcmp(p - remove_size, remove_data, remove_size) == 0) {
p -= remove_str.size;
} else {
break;
}
}
return StringRef(str.data + begin, end - begin + 1);
return p;
}

static StringRef ltrim(const StringRef& str, const StringRef& rhs) {
if (str.size == 0 || rhs.size == 0) {
return str;
template <bool trim_single>
static inline const unsigned char* ltrim(const unsigned char* begin, const unsigned char* end,
const StringRef& remove_str) {
if (remove_str.size == 0) {
return begin;
}
if (str.size == 1) {
auto begin = 0;
auto end = str.size - 1;
const char blank = rhs.data[0];
#if defined(__SSE2__) || defined(__aarch64__)
const auto pattern = _mm_set1_epi8(blank);
while (end - begin + 1 >= REGISTER_SIZE) {
const auto v_haystack =
_mm_loadu_si128(reinterpret_cast<const __m128i*>(str.data + begin));
const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern);
const auto mask = _mm_movemask_epi8(v_against_pattern) ^ 0xffff;
/// zero means not found
if (mask == 0) {
begin += REGISTER_SIZE;
} else {
const auto offset = __builtin_ctz(mask);
begin += offset;
return StringRef(str.data + begin, end - begin + 1);
const auto* p = begin;

if constexpr (trim_single) {
const auto ch = remove_str.data[0];
#if defined(__AVX2__) || defined(__aarch64__)
constexpr auto AVX2_BYTES = sizeof(__m256i);
const auto size = end - begin;
const auto* const avx2_end = begin + size / AVX2_BYTES * AVX2_BYTES;
const auto spaces = _mm256_set1_epi8(ch);
for (; p < avx2_end; p += AVX2_BYTES) {
uint32_t masks = _mm256_movemask_epi8(
_mm256_cmpeq_epi8(_mm256_loadu_si256((__m256i*)p), spaces));
if ((~masks)) {
break;
}
}
#endif
while (begin <= end && str.data[begin] == blank) {
++begin;
for (; p < end && *p == ch; ++p) {
}
return StringRef(str.data + begin, end - begin + 1);
return p;
}
auto begin = 0;
auto end = str.size - 1;
const auto rhs_size = rhs.size;
while (end - begin + 1 >= rhs_size) {
if (memcmp(str.data + begin, rhs.data, rhs_size) == 0) {
begin += rhs.size;

const auto remove_size = remove_str.size;
const auto* const remove_data = remove_str.data;
while (end - p >= remove_size) {
if (memcmp(p, remove_data, remove_size) == 0) {
p += remove_str.size;
} else {
break;
}
}
return StringRef(str.data + begin, end - begin + 1);
}

static StringRef trim(const StringRef& str, const StringRef& rhs) {
if (str.size == 0 || rhs.size == 0) {
return str;
}
return rtrim(ltrim(str, rhs), rhs);
return p;
}

// Gcc will do auto simd in this function
Expand Down
54 changes: 32 additions & 22 deletions be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,25 +485,29 @@ struct NameLTrim {
struct NameRTrim {
static constexpr auto name = "rtrim";
};
template <bool is_ltrim, bool is_rtrim>
template <bool is_ltrim, bool is_rtrim, bool trim_single>
struct TrimUtil {
static Status vector(const ColumnString::Chars& str_data,
const ColumnString::Offsets& str_offsets, const StringRef& rhs,
const ColumnString::Offsets& str_offsets, const StringRef& remove_str,
ColumnString::Chars& res_data, ColumnString::Offsets& res_offsets) {
size_t offset_size = str_offsets.size();
res_offsets.resize(str_offsets.size());
const size_t offset_size = str_offsets.size();
res_offsets.resize(offset_size);
res_data.reserve(str_data.size());
for (size_t i = 0; i < offset_size; ++i) {
const char* raw_str = reinterpret_cast<const char*>(&str_data[str_offsets[i - 1]]);
ColumnString::Offset size = str_offsets[i] - str_offsets[i - 1];
StringRef str(raw_str, size);
const auto* str_begin = str_data.data() + str_offsets[i - 1];
const auto* str_end = str_data.data() + str_offsets[i];

if constexpr (is_ltrim) {
str = simd::VStringFunctions::ltrim(str, rhs);
str_begin =
simd::VStringFunctions::ltrim<trim_single>(str_begin, str_end, remove_str);
}
if constexpr (is_rtrim) {
str = simd::VStringFunctions::rtrim(str, rhs);
str_end =
simd::VStringFunctions::rtrim<trim_single>(str_begin, str_end, remove_str);
}
StringOP::push_value_string(std::string_view((char*)str.data, str.size), i, res_data,
res_offsets);

res_data.insert_assume_reserved(str_begin, str_end);
res_offsets[i] = res_data.size();
}
return Status::OK();
}
Expand All @@ -521,9 +525,9 @@ struct Trim1Impl {
if (const auto* col = assert_cast<const ColumnString*>(column.get())) {
auto col_res = ColumnString::create();
char blank[] = " ";
StringRef rhs(blank, 1);
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim>::vector(
col->get_chars(), col->get_offsets(), rhs, col_res->get_chars(),
const StringRef remove_str(blank, 1);
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, true>::vector(
col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(),
col_res->get_offsets())));
block.replace_by_position(result, std::move(col_res));
} else {
Expand All @@ -550,15 +554,21 @@ struct Trim2Impl {
const auto& rcol =
assert_cast<const ColumnConst*>(block.get_by_position(arguments[1]).column.get())
->get_data_column_ptr();
if (auto col = assert_cast<const ColumnString*>(column.get())) {
if (auto col_right = assert_cast<const ColumnString*>(rcol.get())) {
if (const auto* col = assert_cast<const ColumnString*>(column.get())) {
if (const auto* col_right = assert_cast<const ColumnString*>(rcol.get())) {
auto col_res = ColumnString::create();
const char* raw_rhs = reinterpret_cast<const char*>(&(col_right->get_chars()[0]));
ColumnString::Offset rhs_size = col_right->get_offsets()[0];
StringRef rhs(raw_rhs, rhs_size);
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim>::vector(
col->get_chars(), col->get_offsets(), rhs, col_res->get_chars(),
col_res->get_offsets())));
const auto* remove_str_raw = col_right->get_chars().data();
const ColumnString::Offset remove_str_size = col_right->get_offsets()[0];
const StringRef remove_str(remove_str_raw, remove_str_size);
if (remove_str.size == 1) {
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, true>::vector(
col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(),
col_res->get_offsets())));
} else {
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, false>::vector(
col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(),
col_res->get_offsets())));
}
block.replace_by_position(result, std::move(col_res));
} else {
return Status::RuntimeError("Illegal column {} of argument of function {}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,7 @@ suite("test_trim_new_parameters") {

rtrim = sql "select rtrim('bcTTTabcabc','abc')"
assertEquals(rtrim[0][0], 'bcTTT')

trim_one = sql "select trim('aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaabcTTTabcabcaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa','a')"
assertEquals(trim_one[0][0], 'baaaaaaaaaaabcTTTabcabcaaaaaaaaaaaaaaaaaaaaaaaaaab')
}

0 comments on commit f836147

Please sign in to comment.