diff --git a/ext/mbstring/mbstring.c b/ext/mbstring/mbstring.c
index 964d459543efd..ed5c3721eb748 100644
--- a/ext/mbstring/mbstring.c
+++ b/ext/mbstring/mbstring.c
@@ -1034,6 +1034,10 @@ static PHP_GSHUTDOWN_FUNCTION(mbstring)
 }
 /* }}} */
 
+#ifdef ZEND_INTRIN_AVX2_FUNC_PTR
+static void init_check_utf8(void);
+#endif
+
 /* {{{ PHP_MINIT_FUNCTION(mbstring) */
 PHP_MINIT_FUNCTION(mbstring)
 {
@@ -1074,6 +1078,10 @@ ZEND_TSRMLS_CACHE_UPDATE();
 		php_mb_rfc1867_getword_conf,
 		php_mb_rfc1867_basename);
 
+#ifdef ZEND_INTRIN_AVX2_FUNC_PTR
+	init_check_utf8();
+#endif
+
 	return SUCCESS;
 }
 /* }}} */
@@ -4605,9 +4613,14 @@ MBSTRING_API bool php_mb_check_encoding(const char *input, size_t length, const
 	return true;
 }
 
-static bool mb_fast_check_utf8(zend_string *str)
+/* If we are building an AVX2-only binary, don't compile the next function */
+#ifndef ZEND_INTRIN_AVX2_NATIVE
+
+/* SSE2-based function for validating UTF-8 strings
+ * A faster implementation which uses AVX2 instructions follows */
+static bool mb_fast_check_utf8_default(zend_string *str)
 {
-#ifdef __SSE2__
+# ifdef __SSE2__
 	unsigned char *p = (unsigned char*)ZSTR_VAL(str);
 	/* `e` points 1 byte past the last full 16-byte block of string content
 	 * Note that we include the terminating null byte which is included in each zend_string
@@ -4723,13 +4736,14 @@ static bool mb_fast_check_utf8(zend_string *str)
 		p += sizeof(__m128i);
 	}
 
-finish_up_remaining_bytes: ;
+finish_up_remaining_bytes:
 	/* Finish up 1-15 remaining bytes */
 	if (p == e) {
 		uint8_t remaining_bytes = ZSTR_LEN(str) & (sizeof(__m128i) - 1); /* Not including terminating null */
 
-		/* Crazy hack here... we want to use the above vectorized code to check a block of less than 16
-		 * bytes, but there is no good way to read a variable number of bytes into an XMM register
+		/* Crazy hack here for cases where 5, 6, or 9-15 bytes are remaining...
+		 * We want to use the above vectorized code to check a block of less than 16 bytes,
+		 * but there is no good way to read a variable number of bytes into an XMM register
 		 * However, we know that these bytes are part of a zend_string, and a zend_string has some
 		 * 'header' fields which occupy the memory just before its content
 		 * And, those header fields occupy more than 16 bytes...
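+		 * For example, when 9 bytes remain, case 9 below loads 16 bytes starting at (p - 6):
+		 *
+		 *   [6 junk bytes from header or earlier content][9 data bytes][terminating null]
+		 *
+		 * and _mm_srli_si128(operand, 6) then shifts the 6 junk bytes out, leaving the
+		 * 9 data bytes plus the null byte in the low 10 bytes of the XMM register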
@@ -4744,20 +4758,17 @@ finish_up_remaining_bytes: ;
 		 * shift distance, so the compiler will choke on _mm_srli_si128(operand, shift_dist) */
 		switch (remaining_bytes) {
-		case 0:
-			operand = _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 15)), 15);
-			goto check_operand;
+		case 0: ;
+			__m128i bad_mask = _mm_set_epi8(-64, -32, -16, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
+			__m128i bad = _mm_cmpeq_epi8(_mm_and_si128(last_block, bad_mask), bad_mask);
+			return _mm_movemask_epi8(bad) == 0;
 		case 1:
-			operand = _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 14)), 14);
-			goto check_operand;
 		case 2:
-			operand = _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 13)), 13);
+			operand = _mm_set_epi16(0, 0, 0, 0, 0, 0, 0, *((uint16_t*)p));
 			goto check_operand;
 		case 3:
-			operand = _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 12)), 12);
-			goto check_operand;
 		case 4:
-			operand = _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 11)), 11);
+			operand = _mm_set_epi32(0, 0, 0, *((uint32_t*)p));
 			goto check_operand;
 		case 5:
 			operand = _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 10)), 10);
 			goto check_operand;
@@ -4766,10 +4777,8 @@ finish_up_remaining_bytes: ;
 			operand = _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 9)), 9);
 			goto check_operand;
 		case 7:
-			operand = _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 8)), 8);
-			goto check_operand;
 		case 8:
-			operand = _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 7)), 7);
+			operand = _mm_set_epi64x(0, *((uint64_t*)p));
 			goto check_operand;
 		case 9:
 			operand = _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 6)), 6);
@@ -4800,11 +4809,366 @@ finish_up_remaining_bytes: ;
 	}
 
 	return true;
-#else
+# else
+	/* No SSE2 support; we might add a generic UTF-8 validation implementation here later */
 	return php_mb_check_encoding(ZSTR_VAL(str), ZSTR_LEN(str), &mbfl_encoding_utf8);
-#endif
-}
+# endif
+}
+
+#endif /* #ifndef ZEND_INTRIN_AVX2_NATIVE */
+
+#ifdef ZEND_INTRIN_AVX2_NATIVE
+
+/* We are building an AVX2-only binary */
+# include <immintrin.h>
+# define mb_fast_check_utf8 mb_fast_check_utf8_avx2
+
+#elif defined(ZEND_INTRIN_AVX2_RESOLVER)
+
+/* We are building a binary which works with or without AVX2; whether or not to use
+ * AVX2-accelerated functions will be determined at runtime */
+# include <immintrin.h>
+# include "Zend/zend_cpuinfo.h"
+
+# ifdef ZEND_INTRIN_AVX2_FUNC_PROTO
+/* The dynamic linker will decide whether or not to use AVX2-based functions and
+ * resolve symbols accordingly */
+
+ZEND_INTRIN_AVX2_FUNC_DECL(bool mb_fast_check_utf8_avx2(zend_string *str));
+
+bool mb_fast_check_utf8(zend_string *str) __attribute__((ifunc("resolve_check_utf8")));
+
+typedef bool (*check_utf8_func_t)(zend_string*);
+
+ZEND_NO_SANITIZE_ADDRESS
+ZEND_ATTRIBUTE_UNUSED
+static check_utf8_func_t resolve_check_utf8(void)
+{
+	if (zend_cpu_supports_avx2()) {
+		return mb_fast_check_utf8_avx2;
+	}
+	return mb_fast_check_utf8_default;
+}
+
+# else /* ZEND_INTRIN_AVX2_FUNC_PTR */
+/* We are compiling for a target where the dynamic linker will not be able to
+ * resolve symbols according to whether the host supports AVX2 or not; so instead,
+ * we make calls go through a function pointer which is set on module load */
+
+#ifdef HAVE_FUNC_ATTRIBUTE_TARGET
+static bool mb_fast_check_utf8_avx2(zend_string *str) __attribute__((target("avx2")));
+#else
+static bool mb_fast_check_utf8_avx2(zend_string *str);
+#endif
+
+static bool (*check_utf8_ptr)(zend_string *str) = NULL;
+
+static bool mb_fast_check_utf8(zend_string *str)
+{
+	return check_utf8_ptr(str);
+}
+
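+/* Called once from PHP_MINIT_FUNCTION(mbstring) (see above), so the CPU-feature check
+ * happens at module startup; after that, every call to mb_fast_check_utf8 dispatches
+ * through check_utf8_ptr with no further runtime feature checks */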
+static void init_check_utf8(void)
+{
+	if (zend_cpu_supports_avx2()) {
+		check_utf8_ptr = mb_fast_check_utf8_avx2;
+	} else {
+		check_utf8_ptr = mb_fast_check_utf8_default;
+	}
+}
+# endif
+
+#else
+
+/* No AVX2 support */
+#define mb_fast_check_utf8 mb_fast_check_utf8_default
+
+#endif
+
+#if defined(ZEND_INTRIN_AVX2_NATIVE) || defined(ZEND_INTRIN_AVX2_RESOLVER)
+
+/* Take (256-bit) `hi` and `lo` as a 512-bit value, shift down by some
+ * number of bytes, then take the low 256 bits
+ * This is used to take some number of trailing bytes from the previous 32-byte
+ * block followed by some number of leading bytes from the current 32-byte block
+ *
+ * _mm256_alignr_epi8 (VPALIGNR) is used to shift out bytes from a 256-bit
+ * YMM register while shifting in bytes from another YMM register... but
+ * it works separately on the respective 128-bit halves of the YMM registers,
+ * which is not what we want.
+ * To make it work as desired, we first do _mm256_permute2x128_si256
+ * (VPERM2I128) to combine the high 128 bits of the previous block and
+ * the low 128 bits of the current block in one YMM register.
+ * Then VPALIGNR will do what is needed. */
+#define _mm256_shift_epi8(hi, lo, shift) _mm256_alignr_epi8(lo, _mm256_permute2x128_si256(hi, lo, 33), 16 - shift)
+
+/* AVX2-based UTF-8 validation function; validates text in 32-byte chunks
+ *
+ * Some parts of this function are the same as `mb_fast_check_utf8_default`; code
+ * comments are not repeated, so consult `mb_fast_check_utf8_default` for information
+ * on uncommented sections. */
+#ifdef ZEND_INTRIN_AVX2_FUNC_PROTO
+ZEND_API bool mb_fast_check_utf8_avx2(zend_string *str)
+#else
+static bool mb_fast_check_utf8_avx2(zend_string *str)
+#endif
+{
+	unsigned char *p = (unsigned char*)ZSTR_VAL(str);
+	unsigned char *e = p + ((ZSTR_LEN(str) + 1) & ~(sizeof(__m256i) - 1));
+
+	/* The algorithm used here for UTF-8 validation is partially adapted from the
+	 * paper "Validating UTF-8 In Less Than One Instruction Per Byte", by John Keiser
+	 * and Daniel Lemire.
+	 * Ref: https://arxiv.org/pdf/2010.03090.pdf
+	 *
+	 * Most types of invalid UTF-8 text can be detected by examining pairs of
+	 * successive bytes. Specifically:
+	 *
+	 * • Overlong 2-byte code units start with 0xC0 or 0xC1.
+	 *   No valid UTF-8 string ever uses these byte values.
+	 * • Overlong 3-byte code units start with 0xE0, followed by a byte < 0xA0.
+	 * • Overlong 4-byte code units start with 0xF0, followed by a byte < 0x90.
+	 * • 5-byte or 6-byte code units, which should never be used, start with
+	 *   0xF8-0xFE.
+	 * • A codepoint value higher than U+10FFFF, which is the highest value for
+	 *   any Unicode codepoint, would either start with 0xF4, followed by a
+	 *   byte >= 0x90, or else would start with 0xF5-0xF7, followed by any value.
+	 * • Codepoint values from U+D800-DFFF, which are reserved and should never
+	 *   be used, would start with 0xED, followed by a byte >= 0xA0.
+	 * • The byte value 0xFF is also illegal and is never used in valid UTF-8.
+	 *
+	 * To detect all these problems, for each pair of successive bytes, we do
+	 * table lookups using the high nibble of the first byte, the low nibble of
+	 * the first byte, and the high nibble of the second byte. Each table lookup
+	 * retrieves a bitmask, in which each 1 bit indicates a possible invalid
+	 * combination; AND those three bitmasks together, and any 1 bit in the result
+	 * will indicate an actual invalid byte combination was found.
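+	 *
+	 * As a worked example, take the byte pair (0xED, 0xA0), which would start an
+	 * encoded surrogate: the lookup on the first byte's high nibble (0xE) yields
+	 * (OVERLONG_3BYTE | SURROGATE), the lookup on its low nibble (0xD) yields
+	 * (BAD_BYTE | SURROGATE), and the lookup on the second byte's high nibble (0xA)
+	 * yields (_1BYTE | SURROGATE | INVALID_CP); ANDing the three masks leaves the
+	 * SURROGATE bit set, so the pair is flagged as invalid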
+	 */
+
+#define BAD_BYTE 0x1
+#define OVERLONG_2BYTE 0x2
+#define _1BYTE (BAD_BYTE | OVERLONG_2BYTE)
+#define OVERLONG_3BYTE 0x4
+#define SURROGATE 0x8
+#define OVERLONG_4BYTE 0x10
+#define INVALID_CP 0x20
+
+	/* Each of these is a 16-entry table, repeated twice; the repetition is required
+	 * by the VPSHUFB instruction, which we use to perform 32 table lookups in parallel
+	 * The first entry is for 0xF, the second is for 0xE, and so on down to 0x0
+	 *
+	 * So, for example, notice that the 4th entry in the 1st table is OVERLONG_2BYTE;
+	 * that means that high nibble 0xC is consistent with the byte pair being part of
+	 * an overlong 2-byte code unit */
+	const __m256i bad_hi_nibble2 = _mm256_set_epi8(
+		BAD_BYTE | OVERLONG_4BYTE | INVALID_CP, OVERLONG_3BYTE | SURROGATE, 0, OVERLONG_2BYTE,
+		0, 0, 0, 0,
+		0, 0, 0, 0,
+		0, 0, 0, 0,
+		BAD_BYTE | OVERLONG_4BYTE | INVALID_CP, OVERLONG_3BYTE | SURROGATE, 0, OVERLONG_2BYTE,
+		0, 0, 0, 0,
+		0, 0, 0, 0,
+		0, 0, 0, 0);
+	const __m256i bad_lo_nibble2 = _mm256_set_epi8(
+		BAD_BYTE, BAD_BYTE, BAD_BYTE | SURROGATE, BAD_BYTE,
+		BAD_BYTE, BAD_BYTE, BAD_BYTE, BAD_BYTE,
+		BAD_BYTE, BAD_BYTE, BAD_BYTE, INVALID_CP,
+		0, 0, OVERLONG_2BYTE, OVERLONG_2BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE,
+		BAD_BYTE, BAD_BYTE, BAD_BYTE | SURROGATE, BAD_BYTE,
+		BAD_BYTE, BAD_BYTE, BAD_BYTE, BAD_BYTE,
+		BAD_BYTE, BAD_BYTE, BAD_BYTE, INVALID_CP,
+		0, 0, OVERLONG_2BYTE, OVERLONG_2BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE);
+	const __m256i bad_hi_nibble = _mm256_set_epi8(
+		_1BYTE | SURROGATE | INVALID_CP, _1BYTE | SURROGATE | INVALID_CP,
+		_1BYTE | SURROGATE | INVALID_CP, _1BYTE | SURROGATE | INVALID_CP,
+		_1BYTE | SURROGATE | INVALID_CP, _1BYTE | SURROGATE | INVALID_CP,
+		_1BYTE | OVERLONG_3BYTE | INVALID_CP, _1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE,
+		_1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE, _1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE,
+		_1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE, _1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE,
+		_1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE, _1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE,
+		_1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE, _1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE,
+		_1BYTE | SURROGATE | INVALID_CP, _1BYTE | SURROGATE | INVALID_CP,
+		_1BYTE | SURROGATE | INVALID_CP, _1BYTE | SURROGATE | INVALID_CP,
+		_1BYTE | SURROGATE | INVALID_CP, _1BYTE | SURROGATE | INVALID_CP,
+		_1BYTE | OVERLONG_3BYTE | INVALID_CP, _1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE,
+		_1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE, _1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE,
+		_1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE, _1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE,
+		_1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE, _1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE,
+		_1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE, _1BYTE | OVERLONG_3BYTE | OVERLONG_4BYTE);
+
+	const __m256i find_continuation = _mm256_set1_epi8(-64);
+	const __m256i _b = _mm256_set1_epi8(0xB);
+	const __m256i _d = _mm256_set1_epi8(0xD);
+	const __m256i _f = _mm256_set1_epi8(0xF);
+
+	__m256i last_hi_nibbles = _mm256_setzero_si256(), last_lo_nibbles = _mm256_setzero_si256();
+	__m256i operand;
+
+	while (p < e) {
+		operand = _mm256_loadu_si256((__m256i*)p);
+
+check_operand:
+		if (!_mm256_movemask_epi8(operand)) {
+			/* Entire 32-byte block is ASCII characters; the only thing we need to validate is that
+			 * the previous block didn't end with an incomplete multi-byte character
+			 * (This will also confirm that the previous block didn't end with a bad byte like 0xFF) */
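+			/* In the mask below, the three non-127 entries line up with the last three
+			 * bytes of the previous block: a high nibble > 0xB (0xC-0xF) in the very last
+			 * byte means a multi-byte character started with no room for its continuation
+			 * bytes; > 0xD (0xE-0xF) in the 2nd-to-last byte means a 3/4-byte character
+			 * was truncated; > 0xE (0xF) in the 3rd-to-last byte means a 4-byte character
+			 * was truncated */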
+			__m256i bad_mask = _mm256_set_epi8(0xB, 0xD, 0xE, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
+				127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127);
+			__m256i bad = _mm256_cmpgt_epi8(last_hi_nibbles, bad_mask);
+			if (_mm256_movemask_epi8(bad)) {
+				return false;
+			}
+
+			/* Consume as many full blocks of single-byte characters as we can */
+			while (true) {
+				p += sizeof(__m256i);
+				if (p >= e) {
+					goto finish_up_remaining_bytes;
+				}
+				operand = _mm256_loadu_si256((__m256i*)p);
+				if (_mm256_movemask_epi8(operand)) {
+					break;
+				}
+			}
+		}
+
+		__m256i hi_nibbles = _mm256_and_si256(_mm256_srli_epi16(operand, 4), _f);
+		__m256i lo_nibbles = _mm256_and_si256(operand, _f);
+
+		__m256i lo_nibbles2 = _mm256_shift_epi8(last_lo_nibbles, lo_nibbles, 1);
+		__m256i hi_nibbles2 = _mm256_shift_epi8(last_hi_nibbles, hi_nibbles, 1);
+
+		/* Do parallel table lookups in all 3 tables */
+		__m256i bad = _mm256_cmpgt_epi8(
+			_mm256_and_si256(
+				_mm256_and_si256(
+					_mm256_shuffle_epi8(bad_lo_nibble2, lo_nibbles2),
+					_mm256_shuffle_epi8(bad_hi_nibble2, hi_nibbles2)),
+				_mm256_shuffle_epi8(bad_hi_nibble, hi_nibbles)),
+			_mm256_setzero_si256());
+
+		__m256i cont_mask = _mm256_cmpgt_epi8(hi_nibbles2, _b);
+		__m256i hi_nibbles3 = _mm256_shift_epi8(last_hi_nibbles, hi_nibbles, 2);
+		cont_mask = _mm256_or_si256(cont_mask, _mm256_cmpgt_epi8(hi_nibbles3, _d));
+		__m256i hi_nibbles4 = _mm256_shift_epi8(last_hi_nibbles, hi_nibbles, 3);
+		cont_mask = _mm256_or_si256(cont_mask, _mm256_cmpeq_epi8(hi_nibbles4, _f));
+
+		__m256i continuation = _mm256_cmpgt_epi8(find_continuation, operand);
+		bad = _mm256_or_si256(bad, _mm256_xor_si256(continuation, cont_mask));
+
+		if (_mm256_movemask_epi8(bad)) {
+			return false;
+		}
+
+		last_hi_nibbles = hi_nibbles;
+		last_lo_nibbles = lo_nibbles;
+		p += sizeof(__m256i);
+	}
+
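+	/* Just as in mb_fast_check_utf8_default, handle the 0-31 bytes which remain after
+	 * the last full 32-byte block; cases which read before `p` rely on the same
+	 * zend_string header trick described in that function, and scalar loads which pick
+	 * up the terminating null byte are harmless, since null is valid single-byte UTF-8 */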
+finish_up_remaining_bytes:
+	if (p == e) {
+		uint8_t remaining_bytes = ZSTR_LEN(str) & (sizeof(__m256i) - 1); /* Not including terminating null */
+
+		switch (remaining_bytes) {
+		case 0: ;
+			/* No actual data bytes are remaining */
+			__m256i bad_mask = _mm256_set_epi8(0xB, 0xD, 0xE, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
+				127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127);
+			__m256i bad = _mm256_cmpgt_epi8(last_hi_nibbles, bad_mask);
+			return _mm256_movemask_epi8(bad) == 0;
+		case 1:
+		case 2:
+			operand = _mm256_set_epi16(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, *((int16_t*)p));
+			goto check_operand;
+		case 3:
+		case 4:
+			operand = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, *((int32_t*)p));
+			goto check_operand;
+		case 5:
+			operand = _mm256_set_m128i(_mm_setzero_si128(), _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 10)), 10));
+			goto check_operand;
+		case 6:
+			operand = _mm256_set_m128i(_mm_setzero_si128(), _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 9)), 9));
+			goto check_operand;
+		case 7:
+		case 8:
+			operand = _mm256_set_epi64x(0, 0, 0, *((int64_t*)p));
+			goto check_operand;
+		case 9:
+			operand = _mm256_set_m128i(_mm_setzero_si128(), _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 6)), 6));
+			goto check_operand;
+		case 10:
+			operand = _mm256_set_m128i(_mm_setzero_si128(), _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 5)), 5));
+			goto check_operand;
+		case 11:
+			operand = _mm256_set_m128i(_mm_setzero_si128(), _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 4)), 4));
+			goto check_operand;
+		case 12:
+			operand = _mm256_set_m128i(_mm_setzero_si128(), _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 3)), 3));
+			goto check_operand;
+		case 13:
+			operand = _mm256_set_m128i(_mm_setzero_si128(), _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 2)), 2));
+			goto check_operand;
+		case 14:
+			operand = _mm256_set_m128i(_mm_setzero_si128(), _mm_srli_si128(_mm_loadu_si128((__m128i*)(p - 1)), 1));
+			goto check_operand;
+		case 15:
+		case 16:
+			operand = _mm256_set_m128i(_mm_setzero_si128(), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 17:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 2)), 14), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 18:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 3)), 13), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 19:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 4)), 12), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 20:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 5)), 11), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 21:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 6)), 10), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 22:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 7)), 9), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 23:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 8)), 8), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 24:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 9)), 7), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 25:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 10)), 6), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 26:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 11)), 5), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 27:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 12)), 4), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 28:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 13)), 3), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 29:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 14)), 2), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 30:
+			operand = _mm256_set_m128i(_mm_srli_si128(_mm_loadu_si128((__m128i*)(p + 15)), 1), _mm_loadu_si128((__m128i*)p));
+			goto check_operand;
+		case 31:
+			return true;
+		}
+
+		ZEND_UNREACHABLE();
+	}
+
+	return true;
+}
+
+#endif /* defined(ZEND_INTRIN_AVX2_NATIVE) || defined(ZEND_INTRIN_AVX2_RESOLVER) */
+
 static bool mb_check_str_encoding(zend_string *str, const mbfl_encoding *encoding)
 {
 	if (encoding == &mbfl_encoding_utf8) {