diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 6723a7a5d473..c570a61e4b49 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -576,7 +576,7 @@ struct AnyEqual { details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); const BytesObjBase* rhs_str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size) == 0; + return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size); } return false; } diff --git a/ffi/include/tvm/ffi/base_details.h b/ffi/include/tvm/ffi/base_details.h index fb7be1a955ba..cfdadff6ea48 100644 --- a/ffi/include/tvm/ffi/base_details.h +++ b/ffi/include/tvm/ffi/base_details.h @@ -181,17 +181,30 @@ TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) { const char* it = data; const char* end = it + size; uint64_t result = 0; - for (; it + 8 <= end; it += 8) { - if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) { - u.a[0] = it[0]; - u.a[1] = it[1]; - u.a[2] = it[2]; - u.a[3] = it[3]; - u.a[4] = it[4]; - u.a[5] = it[5]; - u.a[6] = it[6]; - u.a[7] = it[7]; + if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) { + // if alignment requirement is met, directly use load + if (reinterpret_cast(it) % 8 == 0) { + for (; it + 8 <= end; it += 8) { + u.b = *reinterpret_cast(it); + result = (result * kMultiplier + u.b) % kMod; + } } else { + // unaligned version + for (; it + 8 <= end; it += 8) { + u.a[0] = it[0]; + u.a[1] = it[1]; + u.a[2] = it[2]; + u.a[3] = it[3]; + u.a[4] = it[4]; + u.a[5] = it[5]; + u.a[6] = it[6]; + u.a[7] = it[7]; + result = (result * kMultiplier + u.b) % kMod; + } + } + } else { + // need endian swap + for (; it + 8 <= end; it += 8) { u.a[0] = it[7]; u.a[1] = it[6]; u.a[2] = it[5]; @@ -200,9 +213,10 @@ TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) { u.a[5] = it[2]; u.a[6] = it[1]; u.a[7] = it[0]; + result = (result * kMultiplier + u.b) % kMod; } - result = (result * kMultiplier + u.b) % kMod; } + if (it < end) { u.b = 0; uint8_t* a = u.a; diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h index ed654e8557e0..e77b27b26831 100644 --- a/ffi/include/tvm/ffi/string.h +++ b/ffi/include/tvm/ffi/string.h @@ -175,7 +175,34 @@ class Bytes : public ObjectRef { * \return int zero if both char sequences compare equal. negative if this * appear before other, positive otherwise. */ - static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); + static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { + if (lhs == rhs && lhs_count == rhs_count) return 0; + + for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { + if (lhs[i] < rhs[i]) return -1; + if (lhs[i] > rhs[i]) return 1; + } + if (lhs_count < rhs_count) { + return -1; + } else if (lhs_count > rhs_count) { + return 1; + } else { + return 0; + } + } + /*! + * \brief Compare two char sequence for equality + * + * \param lhs Pointers to the char array to compare + * \param rhs Pointers to the char array to compare + * \param lhs_count Length of the char array to compare + * \param rhs_count Length of the char array to compare + * + * \return true if the two char sequences are equal, false otherwise. + */ + static bool memequal(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { + return lhs_count == rhs_count && (lhs == rhs || std::memcmp(lhs, rhs, lhs_count) == 0); + } private: friend class String; @@ -311,7 +338,18 @@ class String : public ObjectRef { * before other, positive otherwise. */ int compare(const char* other) const { - return Bytes::memncmp(data(), other, size(), std::strlen(other)); + const char* this_data = data(); + size_t this_size = size(); + for (size_t i = 0; i < this_size; ++i) { + // other is shorter than this + if (other[i] == '\0') return 1; + if (this_data[i] < other[i]) return -1; + if (this_data[i] > other[i]) return 1; + } + // other equals this + if (other[this_size] == '\0') return 0; + // other longer than this + return -1; } /*! @@ -616,11 +654,17 @@ inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare( inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } // Overload == operator -inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; } +inline bool operator==(const String& lhs, const std::string& rhs) { + return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); +} -inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; } +inline bool operator==(const std::string& lhs, const String& rhs) { + return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); +} -inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } +inline bool operator==(const String& lhs, const String& rhs) { + return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); +} inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; } @@ -641,22 +685,6 @@ inline std::ostream& operator<<(std::ostream& out, const String& input) { out.write(input.data(), input.size()); return out; } - -inline int Bytes::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { - if (lhs == rhs && lhs_count == rhs_count) return 0; - - for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { - if (lhs[i] < rhs[i]) return -1; - if (lhs[i] > rhs[i]) return 1; - } - if (lhs_count < rhs_count) { - return -1; - } else if (lhs_count > rhs_count) { - return 1; - } else { - return 0; - } -} } // namespace ffi // Expose to the tvm namespace for usability diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc index a74102a95349..d53ac105abe4 100644 --- a/ffi/tests/cpp/test_string.cc +++ b/ffi/tests/cpp/test_string.cc @@ -95,6 +95,24 @@ TEST(String, Comparisons) { EXPECT_EQ(m != s, mismatch != source); } +TEST(String, Compare) { + // string compare const char* + String s{"hello"}; + EXPECT_EQ(s.compare("hello"), 0); + EXPECT_EQ(s.compare(String("hello")), 0); + + EXPECT_EQ(s.compare("hallo"), 1); + EXPECT_EQ(s.compare(String("hallo")), 1); + EXPECT_EQ(s.compare("hfllo"), -1); + EXPECT_EQ(s.compare(String("hfllo")), -1); + // s is longer + EXPECT_EQ(s.compare("hell"), 1); + EXPECT_EQ(s.compare(String("hell")), 1); + // s is shorter + EXPECT_EQ(s.compare("hello world"), -1); + EXPECT_EQ(s.compare(String("helloworld")), -1); +} + // Check '\0' handling TEST(String, null_byte_handling) { using namespace std; @@ -369,4 +387,20 @@ TEST(String, CAPIAccessor) { EXPECT_EQ(arr->size, 5); EXPECT_EQ(std::string(arr->data, arr->size), "hello"); } + +TEST(String, BytesHash) { + std::vector data1(10); + std::vector data2(11); + for (size_t i = 0; i < data1.size(); ++i) { + data1[i] = i; + } + char* data1_ptr = reinterpret_cast(data1.data()); + char* data2_ptr = reinterpret_cast(data2.data()) + 1; + std::memcpy(data2_ptr, data1.data(), data1.size() * sizeof(int64_t)); + // has of aligned and unaligned data should be the same + uint64_t hash1 = details::StableHashBytes(data1_ptr, data1.size() * sizeof(int64_t)); + uint64_t hash2 = details::StableHashBytes(data2_ptr, data1.size() * sizeof(int64_t)); + EXPECT_EQ(hash1, hash2); +} + } // namespace