Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): support latin1/utf16 string encoding in python #1997

Merged
merged 16 commits into from
Jan 7, 2025
Merged
2 changes: 1 addition & 1 deletion cpp/fury/util/string_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -623,4 +623,4 @@ std::u16string utf8ToUtf16(const std::string &utf8, bool is_little_endian) {

#endif

} // namespace fury
} // namespace fury
66 changes: 65 additions & 1 deletion cpp/fury/util/string_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,78 @@

#pragma once

#include <cstdint>
#include <string>
// AVX2 is intentionally not used here: some older Intel CPUs do not support
// AVX2, and the wheel built for AVX2 is the same as the one built for SSE2.
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
pandalee99 marked this conversation as resolved.
Show resolved Hide resolved
#include <arm_neon.h>
#define USE_NEON_SIMD
#elif defined(__SSE2__)
#include <emmintrin.h>
#define USE_SSE2_SIMD
#endif

namespace fury {

bool isLatin(const std::string &str);

// Scalar scan for UTF-16 surrogate code units.
//
// data: pointer to the first UTF-16 code unit to inspect.
// size: number of code units readable from `data`.
// Returns true if any code unit lies in 0xD800..0xDFFF (a high or a low
// surrogate), and false otherwise, including when `size` is 0.
static inline bool hasSurrogatePairFallback(const uint16_t *data, size_t size) {
  const uint16_t *const end = data + size;
  for (const uint16_t *cursor = data; cursor != end; ++cursor) {
    // Every value in 0xD800..0xDFFF has the top five bits 11011, so masking
    // with 0xF800 and comparing against 0xD800 is the same range test.
    if ((*cursor & 0xF800) == 0xD800) {
      return true;
    }
  }
  return false;
}

#if defined(USE_NEON_SIMD)
// NEON implementation: reports whether any of the `length` UTF-16 code units
// starting at `data` is a surrogate (0xD800..0xDFFF, high or low).
//
// Eight code units are tested per iteration; the remaining tail (fewer than
// eight units) is handed to hasSurrogatePairFallback.
inline bool utf16HasSurrogatePairs(const uint16_t *data, size_t length) {
  size_t i = 0;
  const uint16x8_t lower_bound = vdupq_n_u16(0xD800);
  const uint16x8_t higher_bound = vdupq_n_u16(0xDFFF);
  for (; i + 7 < length; i += 8) {
    const uint16x8_t chunk = vld1q_u16(data + i);
    // Unsigned lane-wise compares: a lane is all ones when the test holds.
    const uint16x8_t at_least_lower = vcgeq_u16(chunk, lower_bound);
    const uint16x8_t at_most_higher = vcleq_u16(chunk, higher_bound);
    // Use the intrinsic instead of operator& so the header does not depend on
    // GNU vector-extension operators being available for NEON types.
    const uint16x8_t in_range = vandq_u16(at_least_lower, at_most_higher);
#if defined(__aarch64__) || defined(_M_ARM64)
    const bool found = vmaxvq_u16(in_range) != 0;
#else
    // vmaxvq_u16 is only available on AArch64; on 32-bit ARM fold the two
    // halves together and test the resulting 64 bits instead.
    const uint16x4_t folded =
        vorr_u16(vget_low_u16(in_range), vget_high_u16(in_range));
    const bool found = vget_lane_u64(vreinterpret_u64_u16(folded), 0) != 0;
#endif
    if (found) {
      return true; // Detected a surrogate (high or low)
    }
  }
  return hasSurrogatePairFallback(data + i, length - i);
}
#elif defined(USE_SSE2_SIMD)
// SSE2 implementation: reports whether any of the `length` UTF-16 code units
// starting at `data` is a surrogate (0xD800..0xDFFF, high or low).
//
// Eight code units are tested per iteration; the remaining tail (fewer than
// eight units) is handed to hasSurrogatePairFallback.
inline bool utf16HasSurrogatePairs(const uint16_t *data, size_t length) {
  // Every surrogate code unit has the top five bits 11011, so
  // `(unit & 0xF800) == 0xD800` is equivalent to 0xD800 <= unit <= 0xDFFF.
  // An equality test is used because the SSE2 16-bit ordering compares are
  // signed; the casts make the int -> short constant conversion explicit.
  const __m128i top_bits_mask = _mm_set1_epi16(static_cast<short>(0xF800));
  const __m128i surrogate_tag = _mm_set1_epi16(static_cast<short>(0xD800));

  size_t offset = 0;
  while (offset + 7 < length) {
    // Unaligned load: `data` carries no 16-byte alignment guarantee.
    const __m128i units =
        _mm_loadu_si128(reinterpret_cast<const __m128i *>(data + offset));
    const __m128i tagged =
        _mm_cmpeq_epi16(_mm_and_si128(units, top_bits_mask), surrogate_tag);
    if (_mm_movemask_epi8(tagged) != 0) {
      return true; // Detected a surrogate
    }
    offset += 8;
  }
  return hasSurrogatePairFallback(data + offset, length - offset);
}
#else
// Portable implementation, compiled when neither USE_NEON_SIMD nor
// USE_SSE2_SIMD is defined: the whole range goes through the scalar scan.
inline bool utf16HasSurrogatePairs(const uint16_t *data, size_t length) {
  const bool found = hasSurrogatePairFallback(data, length);
  return found;
}
#endif

inline bool utf16HasSurrogatePairs(const std::u16string &str) {
// Get the data pointer
const std::uint16_t *data =
reinterpret_cast<const std::uint16_t *>(str.data());
return utf16HasSurrogatePairs(data, str.size());
}

std::string utf16ToUtf8(const std::u16string &utf16, bool is_little_endian);

std::u16string utf8ToUtf16(const std::string &utf8, bool is_little_endian);

} // namespace fury
} // namespace fury
21 changes: 19 additions & 2 deletions cpp/fury/util/string_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,24 @@ std::u16string generateRandomUTF16String(size_t length) {
return str;
}

// Basic implementation
// Exercises utf16HasSurrogatePairs on a surrogate-free string, an explicit
// surrogate pair, and random strings of several lengths.
TEST(StringUtilTest, TestUtf16HasSurrogatePairs) {
  // Two code units outside 0xD800..0xDFFF: no surrogate expected.
  const std::u16string bmp_only = {0x99, 0x100};
  EXPECT_FALSE(utf16HasSurrogatePairs(bmp_only));

  // One explicit high/low surrogate pair.
  const std::u16string emoji = {0xD83D, 0xDE00}; // 😀 emoji
  EXPECT_TRUE(utf16HasSurrogatePairs(emoji));

  // Random string of the requested length followed by a fixed suffix.
  // NOTE(review): the suffix u"性能好" contains no surrogates, so the
  // expectations below presumably rely on generateRandomUTF16String emitting
  // at least one surrogate - confirm against its implementation.
  const auto withRandomPrefix = [](size_t prefix_length) {
    return generateRandomUTF16String(prefix_length) + u"性能好";
  };
  EXPECT_TRUE(utf16HasSurrogatePairs(withRandomPrefix(3)));
  EXPECT_TRUE(utf16HasSurrogatePairs(withRandomPrefix(10)));
  EXPECT_TRUE(utf16HasSurrogatePairs(withRandomPrefix(30)));
  EXPECT_TRUE(utf16HasSurrogatePairs(withRandomPrefix(60)));
  EXPECT_TRUE(utf16HasSurrogatePairs(withRandomPrefix(120)));
  EXPECT_TRUE(utf16HasSurrogatePairs(withRandomPrefix(200)));
  EXPECT_TRUE(utf16HasSurrogatePairs(withRandomPrefix(300)));
}

// Swap bytes to convert from big endian to little endian
inline uint16_t swapBytes(uint16_t value) {
Expand Down Expand Up @@ -542,4 +559,4 @@ TEST(UTF8ToUTF16Test, PerformanceTest) {
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ public void xwrite(MemoryBuffer buffer, String[] value) {
for (String elem : value) {
if (elem != null) {
buffer.writeByte(Fury.NOT_NULL_VALUE_FLAG);
stringSerializer.writeUTF8String(buffer, elem);
stringSerializer.writeString(buffer, elem);
} else {
buffer.writeByte(Fury.NULL_FLAG);
}
Expand All @@ -695,7 +695,7 @@ public String[] xread(MemoryBuffer buffer) {
String[] value = new String[numElements];
for (int i = 0; i < numElements; i++) {
if (buffer.readByte() >= Fury.NOT_NULL_VALUE_FLAG) {
value[i] = stringSerializer.readUTF8String(buffer);
value[i] = stringSerializer.readString(buffer);
} else {
value[i] = null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ public AbstractStringBuilderSerializer(Fury fury, Class<T> type) {

@Override
public void xwrite(MemoryBuffer buffer, T value) {
stringSerializer.writeUTF8String(buffer, value.toString());
stringSerializer.writeString(buffer, value.toString());
}

@Override
Expand Down Expand Up @@ -276,7 +276,7 @@ public StringBuilder read(MemoryBuffer buffer) {

@Override
public StringBuilder xread(MemoryBuffer buffer) {
return new StringBuilder(stringSerializer.readUTF8String(buffer));
return new StringBuilder(stringSerializer.readString(buffer));
}
}

Expand All @@ -299,7 +299,7 @@ public StringBuffer read(MemoryBuffer buffer) {

@Override
public StringBuffer xread(MemoryBuffer buffer) {
return new StringBuffer(stringSerializer.readUTF8String(buffer));
return new StringBuffer(stringSerializer.readString(buffer));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public void write(MemoryBuffer buffer, String value) {

@Override
public void xwrite(MemoryBuffer buffer, String value) {
writeUTF8String(buffer, value);
writeJavaString(buffer, value);
}

@Override
Expand All @@ -131,68 +131,52 @@ public String read(MemoryBuffer buffer) {

@Override
public String xread(MemoryBuffer buffer) {
return readUTF8String(buffer);
return readJavaString(buffer);
}

public void writeString(MemoryBuffer buffer, String value) {
if (isJava) {
writeJavaString(buffer, value);
} else {
writeUTF8String(buffer, value);
}
writeJavaString(buffer, value);
}

public Expression writeStringExpr(Expression strSerializer, Expression buffer, Expression str) {
if (isJava) {
if (STRING_VALUE_FIELD_IS_BYTES) {
if (compressString) {
return new Invoke(strSerializer, "writeCompressedBytesString", buffer, str);
} else {
return new StaticInvoke(StringSerializer.class, "writeBytesString", buffer, str);
}
if (STRING_VALUE_FIELD_IS_BYTES) {
if (compressString) {
return new Invoke(strSerializer, "writeCompressedBytesString", buffer, str);
} else {
if (!STRING_VALUE_FIELD_IS_CHARS) {
throw new UnsupportedOperationException();
}
if (compressString) {
return new Invoke(strSerializer, "writeCompressedCharsString", buffer, str);
} else {
return new Invoke(strSerializer, "writeCharsString", buffer, str);
}
return new StaticInvoke(StringSerializer.class, "writeBytesString", buffer, str);
}
} else {
return new Invoke(strSerializer, "writeUTF8String", buffer, str);
if (!STRING_VALUE_FIELD_IS_CHARS) {
throw new UnsupportedOperationException();
}
if (compressString) {
return new Invoke(strSerializer, "writeCompressedCharsString", buffer, str);
} else {
return new Invoke(strSerializer, "writeCharsString", buffer, str);
}
}
}

public String readString(MemoryBuffer buffer) {
if (isJava) {
return readJavaString(buffer);
} else {
return readUTF8String(buffer);
}
return readJavaString(buffer);
}

public Expression readStringExpr(Expression strSerializer, Expression buffer) {
if (isJava) {
if (STRING_VALUE_FIELD_IS_BYTES) {
if (compressString) {
return new Invoke(strSerializer, "readCompressedBytesString", STRING_TYPE, buffer);
} else {
return new Invoke(strSerializer, "readBytesString", STRING_TYPE, buffer);
}
if (STRING_VALUE_FIELD_IS_BYTES) {
if (compressString) {
return new Invoke(strSerializer, "readCompressedBytesString", STRING_TYPE, buffer);
} else {
if (!STRING_VALUE_FIELD_IS_CHARS) {
throw new UnsupportedOperationException();
}
if (compressString) {
return new Invoke(strSerializer, "readCompressedCharsString", STRING_TYPE, buffer);
} else {
return new Invoke(strSerializer, "readCharsString", STRING_TYPE, buffer);
}
return new Invoke(strSerializer, "readBytesString", STRING_TYPE, buffer);
}
} else {
return new Invoke(strSerializer, "readUTF8String", STRING_TYPE, buffer);
if (!STRING_VALUE_FIELD_IS_CHARS) {
throw new UnsupportedOperationException();
}
if (compressString) {
return new Invoke(strSerializer, "readCompressedCharsString", STRING_TYPE, buffer);
} else {
return new Invoke(strSerializer, "readCharsString", STRING_TYPE, buffer);
}
}
}

Expand Down Expand Up @@ -275,13 +259,6 @@ public void writeJavaString(MemoryBuffer buffer, String value) {
}
}

@CodegenInvoke
public void writeUTF8String(MemoryBuffer buffer, String value) {
byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
buffer.writeVarUint32(bytes.length);
buffer.writeBytes(bytes);
}

// Invoked by fury JIT
public String readJavaString(MemoryBuffer buffer) {
if (STRING_VALUE_FIELD_IS_BYTES) {
Expand Down Expand Up @@ -367,24 +344,6 @@ public void writeCharsString(MemoryBuffer buffer, String value) {
}
}

@CodegenInvoke
public String readUTF8String(MemoryBuffer buffer) {
int numBytes = buffer.readVarUint32Small14();
buffer.checkReadableBytes(numBytes);
final byte[] targetArray = buffer.getHeapMemory();
if (targetArray != null) {
String str =
new String(
targetArray, buffer._unsafeHeapReaderIndex(), numBytes, StandardCharsets.UTF_8);
buffer.increaseReaderIndex(numBytes);
return str;
} else {
final byte[] tmpArray = getByteArray(numBytes);
buffer.readBytes(tmpArray, 0, numBytes);
return new String(tmpArray, 0, numBytes, StandardCharsets.UTF_8);
}
}

public char[] readCharsLatin1(MemoryBuffer buffer, int numBytes) {
buffer.checkReadableBytes(numBytes);
byte[] srcArray = buffer.getHeapMemory();
Expand Down
2 changes: 1 addition & 1 deletion python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pip install -v -e .

### Environment Requirements

- python 3.6+
- python 3.8+

## Testing

Expand Down
15 changes: 3 additions & 12 deletions python/pyfury/_fury.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,10 @@ def __init__(
stacklevel=2,
)
self.pickler = Pickler(self.buffer)
self.unpickler = Unpickler(self.buffer)
else:
self.pickler = _PicklerStub(self.buffer)
self.unpickler = None
self.pickler = _PicklerStub()
self.unpickler = _UnpicklerStub()
self._buffer_callback = None
self._buffers = None
self._unsupported_callback = None
Expand Down Expand Up @@ -334,10 +335,6 @@ def _deserialize(
):
if type(buffer) == bytes:
buffer = Buffer(buffer)
if self.require_class_registration:
self.unpickler = _UnpicklerStub(buffer)
else:
self.unpickler = Unpickler(buffer)
if unsupported_objects is not None:
self._unsupported_objects = iter(unsupported_objects)
if self.language == Language.XLANG:
Expand Down Expand Up @@ -527,9 +524,6 @@ def reset(self):


class _PicklerStub:
def __init__(self, buf):
self.buf = buf

def dump(self, o):
raise ValueError(
f"Class {type(o)} is not registered, "
Expand All @@ -542,9 +536,6 @@ def clear_memo(self):


class _UnpicklerStub:
def __init__(self, buf):
self.buf = buf

def load(self):
raise ValueError(
"pickle is not allowed when class registration enabled, Please register"
Expand Down
9 changes: 5 additions & 4 deletions python/pyfury/_serialization.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,8 @@ cdef class Fury:
)
self.pickler = Pickler(self.buffer)
else:
self.pickler = _PicklerStub(self.buffer)
self.pickler = _PicklerStub()
self.unpickler = _UnpicklerStub()
self.unpickler = None
self._buffer_callback = None
self._buffers = None
Expand Down Expand Up @@ -815,9 +816,7 @@ cdef class Fury:

cpdef inline _deserialize(
self, Buffer buffer, buffers=None, unsupported_objects=None):
if self.require_class_registration:
self.unpickler = _UnpicklerStub(buffer)
else:
if not self.require_class_registration:
self.unpickler = Unpickler(buffer)
if unsupported_objects is not None:
self._unsupported_objects = iter(unsupported_objects)
Expand Down Expand Up @@ -955,6 +954,8 @@ cdef class Fury:
cpdef inline handle_unsupported_read(self, Buffer buffer):
cdef c_bool in_band = buffer.read_bool()
if in_band:
if self.unpickler is None:
self.unpickler.buffer = Unpickler(buffer)
return self.unpickler.load()
else:
assert self._unsupported_objects is not None
Expand Down
Loading
Loading