diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index d6c56c8112b0..76b2901c7aab 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -59,13 +59,13 @@ set(tvm_ffi_objs_sources "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc" ) if (TVM_FFI_USE_EXTRA_CXX_API) list(APPEND tvm_ffi_objs_sources - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_equal.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_hash.cc" ) endif() diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 60743b82c67e..d99832af01f3 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -56,27 +56,6 @@ #define TVM_FFI_DLL_EXPORT __attribute__((visibility("default"))) #endif -/*! - * \brief Marks the API as extra c++ api that is defined in cc files. - * - * These APIs are extra features that depend on, but are not required to - * support essential core functionality, such as function calling and object - * access. - * - * They are implemented in cc files to reduce compile-time overhead. - * The input/output only uses POD/Any/ObjectRef for ABI stability. - * However, these extra APIs may have an issue across MSVC/Itanium ABI, - * - * Related features are also available through reflection based function - * that is fully based on C API - * - * The project aims to minimize the number of extra C++ APIs and only - * restrict the use to non-core functionalities. - */ -#ifndef TVM_FFI_EXTRA_CXX_API -#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL -#endif - #ifdef __cplusplus extern "C" { #endif diff --git a/ffi/include/tvm/ffi/extra/base.h b/ffi/include/tvm/ffi/extra/base.h new file mode 100644 index 000000000000..b09b3540a83e --- /dev/null +++ b/ffi/include/tvm/ffi/extra/base.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/base.h + * \brief Base header for Extra API. + * + * The extra APIs contains a minmal set of extra APIs that are not + * required to support essential core functionality. + */ +#ifndef TVM_FFI_EXTRA_BASE_H_ +#define TVM_FFI_EXTRA_BASE_H_ + +#include + +/*! + * \brief Marks the API as extra c++ api that is defined in cc files. + * + * They are implemented in cc files to reduce compile-time overhead. + * The input/output only uses POD/Any/ObjectRef for ABI stability. + * However, these extra APIs may have an issue across MSVC/Itanium ABI, + * + * Related features are also available through reflection based function + * that is fully based on C API + * + * The project aims to minimize the number of extra C++ APIs to keep things + * lightweight and restrict the use to non-core functionalities. + */ +#ifndef TVM_FFI_EXTRA_CXX_API +#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL +#endif + +#endif // TVM_FFI_EXTRA_BASE_H_ diff --git a/ffi/include/tvm/ffi/reflection/structural_equal.h b/ffi/include/tvm/ffi/extra/structural_equal.h similarity index 90% rename from ffi/include/tvm/ffi/reflection/structural_equal.h rename to ffi/include/tvm/ffi/extra/structural_equal.h index 860222644c95..9727940297ed 100644 --- a/ffi/include/tvm/ffi/reflection/structural_equal.h +++ b/ffi/include/tvm/ffi/extra/structural_equal.h @@ -17,19 +17,19 @@ * under the License. */ /*! - * \file tvm/ffi/reflection/structural_equal.h + * \file tvm/ffi/extra/structural_equal.h * \brief Structural equal implementation */ -#ifndef TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_ -#define TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_ +#ifndef TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ +#define TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ #include +#include #include #include namespace tvm { namespace ffi { -namespace reflection { /* * \brief Structural equality comparators */ @@ -59,7 +59,7 @@ class StructuralEqual { * \return If comparison fails, return the first mismatch AccessPath pair, * otherwise return std::nullopt. */ - TVM_FFI_EXTRA_CXX_API static Optional GetFirstMismatch( + TVM_FFI_EXTRA_CXX_API static Optional GetFirstMismatch( const Any& lhs, const Any& rhs, bool map_free_vars = false, bool skip_ndarray_content = false); @@ -74,7 +74,6 @@ class StructuralEqual { } }; -} // namespace reflection } // namespace ffi } // namespace tvm -#endif // TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_ +#endif // TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ diff --git a/ffi/include/tvm/ffi/reflection/structural_hash.h b/ffi/include/tvm/ffi/extra/structural_hash.h similarity index 87% rename from ffi/include/tvm/ffi/reflection/structural_hash.h rename to ffi/include/tvm/ffi/extra/structural_hash.h index b0d17cf8bfbc..9cb08a1c0fc8 100644 --- a/ffi/include/tvm/ffi/reflection/structural_hash.h +++ b/ffi/include/tvm/ffi/extra/structural_hash.h @@ -17,17 +17,17 @@ * under the License. */ /*! - * \file tvm/ffi/reflection/structural_hash.h + * \file tvm/ffi/extra/structural_hash.h * \brief Structural hash */ -#ifndef TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_ -#define TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_ +#ifndef TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ +#define TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ #include +#include namespace tvm { namespace ffi { -namespace reflection { /* * \brief Structural hash @@ -52,7 +52,6 @@ class StructuralHash { TVM_FFI_INLINE uint64_t operator()(const Any& value) const { return Hash(value); } }; -} // namespace reflection } // namespace ffi } // namespace tvm -#endif // TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_ +#endif // TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ diff --git a/ffi/include/tvm/ffi/reflection/access_path.h b/ffi/include/tvm/ffi/reflection/access_path.h index e37b3f410cbc..a4f40f485ebd 100644 --- a/ffi/include/tvm/ffi/reflection/access_path.h +++ b/ffi/include/tvm/ffi/reflection/access_path.h @@ -35,12 +35,12 @@ namespace reflection { enum class AccessKind : int32_t { kObjectField = 0, - kArrayIndex = 1, - kMapKey = 2, + kArrayItem = 1, + kMapItem = 2, // the following two are used for error reporting when // the supposed access field is not available - kArrayIndexMissing = 3, - kMapKeyMissing = 4, + kArrayItemMissing = 3, + kMapItemMissing = 4, }; /*! @@ -86,15 +86,15 @@ class AccessStep : public ObjectRef { return AccessStep(AccessKind::kObjectField, field_name); } - static AccessStep ArrayIndex(int64_t index) { return AccessStep(AccessKind::kArrayIndex, index); } + static AccessStep ArrayItem(int64_t index) { return AccessStep(AccessKind::kArrayItem, index); } - static AccessStep ArrayIndexMissing(int64_t index) { - return AccessStep(AccessKind::kArrayIndexMissing, index); + static AccessStep ArrayItemMissing(int64_t index) { + return AccessStep(AccessKind::kArrayItemMissing, index); } - static AccessStep MapKey(Any key) { return AccessStep(AccessKind::kMapKey, key); } + static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, key); } - static AccessStep MapKeyMissing(Any key) { return AccessStep(AccessKind::kMapKeyMissing, key); } + static AccessStep MapItemMissing(Any key) { return AccessStep(AccessKind::kMapItemMissing, key); } TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessStep, ObjectRef, AccessStepObj); }; diff --git a/ffi/src/ffi/reflection/structural_equal.cc b/ffi/src/ffi/extra/structural_equal.cc similarity index 83% rename from ffi/src/ffi/reflection/structural_equal.cc rename to ffi/src/ffi/extra/structural_equal.cc index e44a0c3256f3..a73c07713f1e 100644 --- a/ffi/src/ffi/reflection/structural_equal.cc +++ b/ffi/src/ffi/extra/structural_equal.cc @@ -25,8 +25,8 @@ #include #include #include +#include #include -#include #include #include @@ -34,7 +34,6 @@ namespace tvm { namespace ffi { -namespace reflection { /** * \brief Internal Handler class for structural equal comparison. @@ -135,11 +134,11 @@ class StructEqualHandler { bool success = true; if (custom_s_equal[type_info->type_index] == nullptr) { // We recursively compare the fields the object - ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) { + reflection::ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) { // skip fields that are marked as structural eq hash ignore if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) return false; // get the field value from both side - FieldGetter getter(field_info); + reflection::FieldGetter getter(field_info); Any lhs_value = getter(lhs); Any rhs_value = getter(rhs); // field is in def region, enable free var mapping @@ -155,9 +154,9 @@ class StructEqualHandler { // record the first mismatching field if we sub-rountine compare failed if (mismatch_lhs_reverse_path_ != nullptr) { mismatch_lhs_reverse_path_->emplace_back( - AccessStep::ObjectField(String(field_info->name))); + reflection::AccessStep::ObjectField(String(field_info->name))); mismatch_rhs_reverse_path_->emplace_back( - AccessStep::ObjectField(String(field_info->name))); + reflection::AccessStep::ObjectField(String(field_info->name))); } // return true to indicate early stop return true; @@ -185,8 +184,10 @@ class StructEqualHandler { if (!success) { if (mismatch_lhs_reverse_path_ != nullptr) { String field_name_str = field_name.cast(); - mismatch_lhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str)); - mismatch_rhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str)); + mismatch_lhs_reverse_path_->emplace_back( + reflection::AccessStep::ObjectField(field_name_str)); + mismatch_rhs_reverse_path_->emplace_back( + reflection::AccessStep::ObjectField(field_name_str)); } } return success; @@ -235,16 +236,16 @@ class StructEqualHandler { auto it = rhs.find(rhs_key); if (it == rhs.end()) { if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first)); - mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKeyMissing(rhs_key)); + mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first)); + mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItemMissing(rhs_key)); } return false; } // now recursively compare value if (!CompareAny(kv.second, (*it).second)) { if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first)); - mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKey(rhs_key)); + mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first)); + mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(rhs_key)); } return false; } @@ -258,8 +259,8 @@ class StructEqualHandler { auto it = lhs.find(lhs_key); if (it == lhs.end()) { if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKeyMissing(lhs_key)); - mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first)); + mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItemMissing(lhs_key)); + mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first)); } return false; } @@ -276,8 +277,8 @@ class StructEqualHandler { for (size_t i = 0; i < std::min(lhs.size(), rhs.size()); ++i) { if (!CompareAny(lhs[i], rhs[i])) { if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(i)); - mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(i)); + mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(i)); + mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(i)); } return false; } @@ -285,11 +286,13 @@ class StructEqualHandler { if (lhs.size() == rhs.size()) return true; if (mismatch_lhs_reverse_path_ != nullptr) { if (lhs.size() > rhs.size()) { - mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(rhs.size())); - mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndexMissing(rhs.size())); + mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(rhs.size())); + mismatch_rhs_reverse_path_->emplace_back( + reflection::AccessStep::ArrayItemMissing(rhs.size())); } else { - mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndexMissing(lhs.size())); - mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(lhs.size())); + mismatch_lhs_reverse_path_->emplace_back( + reflection::AccessStep::ArrayItemMissing(lhs.size())); + mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(lhs.size())); } } return false; @@ -354,8 +357,8 @@ class StructEqualHandler { // whether we compare ndarray data bool skip_ndarray_content_{false}; // the root lhs for result printing - std::vector* mismatch_lhs_reverse_path_ = nullptr; - std::vector* mismatch_rhs_reverse_path_ = nullptr; + std::vector* mismatch_lhs_reverse_path_ = nullptr; + std::vector* mismatch_rhs_reverse_path_ = nullptr; // lazily initialize custom equal function ffi::Function s_equal_callback_ = nullptr; // map from lhs to rhs @@ -372,32 +375,31 @@ bool StructuralEqual::Equal(const Any& lhs, const Any& rhs, bool map_free_vars, return handler.CompareAny(lhs, rhs); } -Optional StructuralEqual::GetFirstMismatch(const Any& lhs, const Any& rhs, - bool map_free_vars, - bool skip_ndarray_content) { +Optional StructuralEqual::GetFirstMismatch(const Any& lhs, + const Any& rhs, + bool map_free_vars, + bool skip_ndarray_content) { StructEqualHandler handler; handler.map_free_vars_ = map_free_vars; handler.skip_ndarray_content_ = skip_ndarray_content; - std::vector lhs_reverse_path; - std::vector rhs_reverse_path; + std::vector lhs_reverse_path; + std::vector rhs_reverse_path; handler.mismatch_lhs_reverse_path_ = &lhs_reverse_path; handler.mismatch_rhs_reverse_path_ = &rhs_reverse_path; if (handler.CompareAny(lhs, rhs)) { return std::nullopt; } - AccessPath lhs_path(lhs_reverse_path.rbegin(), lhs_reverse_path.rend()); - AccessPath rhs_path(rhs_reverse_path.rbegin(), rhs_reverse_path.rend()); - return AccessPathPair(lhs_path, rhs_path); + reflection::AccessPath lhs_path(lhs_reverse_path.rbegin(), lhs_reverse_path.rend()); + reflection::AccessPath rhs_path(rhs_reverse_path.rbegin(), rhs_reverse_path.rend()); + return reflection::AccessPathPair(lhs_path, rhs_path); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.reflection.GetFirstStructuralMismatch", - StructuralEqual::GetFirstMismatch); + refl::GlobalDef().def("ffi.GetFirstStructuralMismatch", StructuralEqual::GetFirstMismatch); // ensure the type attribute column is presented in the system even if it is empty. refl::EnsureTypeAttrColumn("__s_equal__"); }); -} // namespace reflection } // namespace ffi } // namespace tvm diff --git a/ffi/src/ffi/reflection/structural_hash.cc b/ffi/src/ffi/extra/structural_hash.cc similarity index 97% rename from ffi/src/ffi/reflection/structural_hash.cc rename to ffi/src/ffi/extra/structural_hash.cc index e8ffcf6d2a72..e47fbbacc806 100644 --- a/ffi/src/ffi/reflection/structural_hash.cc +++ b/ffi/src/ffi/extra/structural_hash.cc @@ -25,9 +25,9 @@ #include #include #include +#include #include #include -#include #include #include @@ -37,7 +37,6 @@ namespace tvm { namespace ffi { -namespace reflection { /** * \brief Internal Handler class for structural hash. */ @@ -119,11 +118,11 @@ class StructuralHashHandler { uint64_t hash_value = obj->GetTypeKeyHash(); if (custom_s_hash[type_info->type_index] == nullptr) { // go over the content and hash the fields - ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { + reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { // skip fields that are marked as structural eq hash ignore if (!(field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore)) { // get the field value from both side - FieldGetter getter(field_info); + reflection::FieldGetter getter(field_info); Any field_value = getter(obj); // field is in def region, enable free var mapping if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) { @@ -297,10 +296,9 @@ uint64_t StructuralHash::Hash(const Any& value, bool map_free_vars, bool skip_nd TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.reflection.StructuralHash", StructuralHash::Hash); + refl::GlobalDef().def("ffi.StructuralHash", StructuralHash::Hash); refl::EnsureTypeAttrColumn("__s_hash__"); }); -} // namespace reflection } // namespace ffi } // namespace tvm diff --git a/ffi/tests/cpp/extra/test_reflection_structural_equal_hash.cc b/ffi/tests/cpp/extra/test_structural_equal_hash.cc similarity index 70% rename from ffi/tests/cpp/extra/test_reflection_structural_equal_hash.cc rename to ffi/tests/cpp/extra/test_structural_equal_hash.cc index d3353b782d33..76c485d9062e 100644 --- a/ffi/tests/cpp/extra/test_reflection_structural_equal_hash.cc +++ b/ffi/tests/cpp/extra/test_structural_equal_hash.cc @@ -20,10 +20,10 @@ #include #include #include +#include +#include #include #include -#include -#include #include #include "../testing_object.h" @@ -37,18 +37,18 @@ namespace refl = tvm::ffi::reflection; TEST(StructuralEqualHash, Array) { Array a = {1, 2, 3}; Array b = {1, 2, 3}; - EXPECT_TRUE(refl::StructuralEqual()(a, b)); - EXPECT_EQ(refl::StructuralHash()(a), refl::StructuralHash()(b)); + EXPECT_TRUE(StructuralEqual()(a, b)); + EXPECT_EQ(StructuralHash()(a), StructuralHash()(b)); Array c = {1, 3}; - EXPECT_FALSE(refl::StructuralEqual()(a, c)); - EXPECT_NE(refl::StructuralHash()(a), refl::StructuralHash()(c)); - auto diff_a_c = refl::StructuralEqual::GetFirstMismatch(a, c); + EXPECT_FALSE(StructuralEqual()(a, c)); + EXPECT_NE(StructuralHash()(a), StructuralHash()(c)); + auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c); // first directly interepret diff, EXPECT_TRUE(diff_a_c.has_value()); - EXPECT_EQ((*diff_a_c).get<0>()[0]->kind, refl::AccessKind::kArrayIndex); - EXPECT_EQ((*diff_a_c).get<1>()[0]->kind, refl::AccessKind::kArrayIndex); + EXPECT_EQ((*diff_a_c).get<0>()[0]->kind, refl::AccessKind::kArrayItem); + EXPECT_EQ((*diff_a_c).get<1>()[0]->kind, refl::AccessKind::kArrayItem); EXPECT_EQ((*diff_a_c).get<0>()[0]->key.cast(), 1); EXPECT_EQ((*diff_a_c).get<1>()[0]->key.cast(), 1); EXPECT_EQ((*diff_a_c).get<0>().size(), 1); @@ -57,90 +57,90 @@ TEST(StructuralEqualHash, Array) { // use structural equal for checking in future parts // given we have done some basic checks above by directly interepret diff, Array d = {1, 2}; - auto diff_a_d = refl::StructuralEqual::GetFirstMismatch(a, d); + auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d); auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath({ - refl::AccessStep::ArrayIndex(2), + refl::AccessStep::ArrayItem(2), }), refl::AccessPath({ - refl::AccessStep::ArrayIndexMissing(2), + refl::AccessStep::ArrayItemMissing(2), })); // then use structural equal to check it - EXPECT_TRUE(refl::StructuralEqual()(diff_a_d, expected_diff_a_d)); + EXPECT_TRUE(StructuralEqual()(diff_a_d, expected_diff_a_d)); } TEST(StructuralEqualHash, Map) { // same map but different insertion order Map a = {{"a", 1}, {"b", 2}, {"c", 3}}; Map b = {{"b", 2}, {"c", 3}, {"a", 1}}; - EXPECT_TRUE(refl::StructuralEqual()(a, b)); - EXPECT_EQ(refl::StructuralHash()(a), refl::StructuralHash()(b)); + EXPECT_TRUE(StructuralEqual()(a, b)); + EXPECT_EQ(StructuralHash()(a), StructuralHash()(b)); Map c = {{"a", 1}, {"b", 2}, {"c", 4}}; - EXPECT_FALSE(refl::StructuralEqual()(a, c)); - EXPECT_NE(refl::StructuralHash()(a), refl::StructuralHash()(c)); + EXPECT_FALSE(StructuralEqual()(a, c)); + EXPECT_NE(StructuralHash()(a), StructuralHash()(c)); - auto diff_a_c = refl::StructuralEqual::GetFirstMismatch(a, c); + auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c); auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath({ - refl::AccessStep::MapKey("c"), + refl::AccessStep::MapItem("c"), }), refl::AccessPath({ - refl::AccessStep::MapKey("c"), + refl::AccessStep::MapItem("c"), })); EXPECT_TRUE(diff_a_c.has_value()); - EXPECT_TRUE(refl::StructuralEqual()(diff_a_c, expected_diff_a_c)); + EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c)); } TEST(StructuralEqualHash, NestedMapArray) { Map> a = {{"a", {1, 2, 3}}, {"b", {4, "hello", 6}}}; Map> b = {{"a", {1, 2, 3}}, {"b", {4, "hello", 6}}}; - EXPECT_TRUE(refl::StructuralEqual()(a, b)); - EXPECT_EQ(refl::StructuralHash()(a), refl::StructuralHash()(b)); + EXPECT_TRUE(StructuralEqual()(a, b)); + EXPECT_EQ(StructuralHash()(a), StructuralHash()(b)); Map> c = {{"a", {1, 2, 3}}, {"b", {4, "world", 6}}}; - EXPECT_FALSE(refl::StructuralEqual()(a, c)); - EXPECT_NE(refl::StructuralHash()(a), refl::StructuralHash()(c)); + EXPECT_FALSE(StructuralEqual()(a, c)); + EXPECT_NE(StructuralHash()(a), StructuralHash()(c)); - auto diff_a_c = refl::StructuralEqual::GetFirstMismatch(a, c); + auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c); auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath({ - refl::AccessStep::MapKey("b"), - refl::AccessStep::ArrayIndex(1), + refl::AccessStep::MapItem("b"), + refl::AccessStep::ArrayItem(1), }), refl::AccessPath({ - refl::AccessStep::MapKey("b"), - refl::AccessStep::ArrayIndex(1), + refl::AccessStep::MapItem("b"), + refl::AccessStep::ArrayItem(1), })); EXPECT_TRUE(diff_a_c.has_value()); - EXPECT_TRUE(refl::StructuralEqual()(diff_a_c, expected_diff_a_c)); + EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c)); Map> d = {{"a", {1, 2, 3}}}; - auto diff_a_d = refl::StructuralEqual::GetFirstMismatch(a, d); + auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d); auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath({ - refl::AccessStep::MapKey("b"), + refl::AccessStep::MapItem("b"), }), refl::AccessPath({ - refl::AccessStep::MapKeyMissing("b"), + refl::AccessStep::MapItemMissing("b"), })); EXPECT_TRUE(diff_a_d.has_value()); - EXPECT_TRUE(refl::StructuralEqual()(diff_a_d, expected_diff_a_d)); + EXPECT_TRUE(StructuralEqual()(diff_a_d, expected_diff_a_d)); - auto diff_d_a = refl::StructuralEqual::GetFirstMismatch(d, a); + auto diff_d_a = StructuralEqual::GetFirstMismatch(d, a); auto expected_diff_d_a = refl::AccessPathPair(refl::AccessPath({ - refl::AccessStep::MapKeyMissing("b"), + refl::AccessStep::MapItemMissing("b"), }), refl::AccessPath({ - refl::AccessStep::MapKey("b"), + refl::AccessStep::MapItem("b"), })); } TEST(StructuralEqualHash, FreeVar) { TVar a = TVar("a"); TVar b = TVar("b"); - EXPECT_TRUE(refl::StructuralEqual::Equal(a, b, /*map_free_vars=*/true)); - EXPECT_FALSE(refl::StructuralEqual::Equal(a, b)); + EXPECT_TRUE(StructuralEqual::Equal(a, b, /*map_free_vars=*/true)); + EXPECT_FALSE(StructuralEqual::Equal(a, b)); - EXPECT_NE(refl::StructuralHash()(a), refl::StructuralHash()(b)); - EXPECT_EQ(refl::StructuralHash::Hash(a, /*map_free_vars=*/true), - refl::StructuralHash::Hash(b, /*map_free_vars=*/true)); + EXPECT_NE(StructuralHash()(a), StructuralHash()(b)); + EXPECT_EQ(StructuralHash::Hash(a, /*map_free_vars=*/true), + StructuralHash::Hash(b, /*map_free_vars=*/true)); } TEST(StructuralEqualHash, FuncDefAndIgnoreField) { @@ -152,21 +152,21 @@ TEST(StructuralEqualHash, FuncDefAndIgnoreField) { TFunc fc = TFunc({x}, {TInt(1), TInt(2)}, "comment c"); - EXPECT_TRUE(refl::StructuralEqual()(fa, fb)); - EXPECT_EQ(refl::StructuralHash()(fa), refl::StructuralHash()(fb)); + EXPECT_TRUE(StructuralEqual()(fa, fb)); + EXPECT_EQ(StructuralHash()(fa), StructuralHash()(fb)); - EXPECT_FALSE(refl::StructuralEqual()(fa, fc)); - auto diff_fa_fc = refl::StructuralEqual::GetFirstMismatch(fa, fc); + EXPECT_FALSE(StructuralEqual()(fa, fc)); + auto diff_fa_fc = StructuralEqual::GetFirstMismatch(fa, fc); auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath({ refl::AccessStep::ObjectField("body"), - refl::AccessStep::ArrayIndex(1), + refl::AccessStep::ArrayItem(1), }), refl::AccessPath({ refl::AccessStep::ObjectField("body"), - refl::AccessStep::ArrayIndex(1), + refl::AccessStep::ArrayItem(1), })); EXPECT_TRUE(diff_fa_fc.has_value()); - EXPECT_TRUE(refl::StructuralEqual()(diff_fa_fc, expected_diff_fa_fc)); + EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc)); } TEST(StructuralEqualHash, CustomTreeNode) { @@ -178,21 +178,21 @@ TEST(StructuralEqualHash, CustomTreeNode) { TCustomFunc fc = TCustomFunc({x}, {TInt(1), TInt(2)}, "comment c"); - EXPECT_TRUE(refl::StructuralEqual()(fa, fb)); - EXPECT_EQ(refl::StructuralHash()(fa), refl::StructuralHash()(fb)); + EXPECT_TRUE(StructuralEqual()(fa, fb)); + EXPECT_EQ(StructuralHash()(fa), StructuralHash()(fb)); - EXPECT_FALSE(refl::StructuralEqual()(fa, fc)); - auto diff_fa_fc = refl::StructuralEqual::GetFirstMismatch(fa, fc); + EXPECT_FALSE(StructuralEqual()(fa, fc)); + auto diff_fa_fc = StructuralEqual::GetFirstMismatch(fa, fc); auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath({ refl::AccessStep::ObjectField("body"), - refl::AccessStep::ArrayIndex(1), + refl::AccessStep::ArrayItem(1), }), refl::AccessPath({ refl::AccessStep::ObjectField("body"), - refl::AccessStep::ArrayIndex(1), + refl::AccessStep::ArrayItem(1), })); EXPECT_TRUE(diff_fa_fc.has_value()); - EXPECT_TRUE(refl::StructuralEqual()(diff_fa_fc, expected_diff_fa_fc)); + EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc)); } } // namespace diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index 501d55b8efa6..df8c45b5e697 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -18,8 +18,8 @@ */ #include "module_equality.h" -#include -#include +#include +#include #include #include #include @@ -40,12 +40,12 @@ class ModuleEqualityStructural : public ModuleEquality { class ModuleEqualityIgnoreNDArray : public ModuleEquality { public: size_t Hash(IRModule mod) const { - return tvm::ffi::reflection::StructuralHash::Hash(mod, /*map_free_vars=*/false, - /*skip_ndarray_content=*/true); + return tvm::ffi::StructuralHash::Hash(mod, /*map_free_vars=*/false, + /*skip_ndarray_content=*/true); } bool Equal(IRModule lhs, IRModule rhs) const { - return tvm::ffi::reflection::StructuralEqual::Equal(lhs, rhs, /*map_free_vars=*/false, - /*skip_ndarray_content=*/true); + return tvm::ffi::StructuralEqual::Equal(lhs, rhs, /*map_free_vars=*/false, + /*skip_ndarray_content=*/true); } String GetName() const { return "ignore-ndarray"; } }; @@ -56,9 +56,9 @@ class ModuleEqualityAnchorBlock : public ModuleEquality { size_t Hash(IRModule mod) const { auto anchor_block = tir::FindAnchorBlock(mod); if (anchor_block) { - return ffi::reflection::StructuralHash::Hash(GetRef(anchor_block), - /*map_free_vars=*/false, - /*skip_ndarray_content=*/true); + return ffi::StructuralHash::Hash(GetRef(anchor_block), + /*map_free_vars=*/false, + /*skip_ndarray_content=*/true); } return ModuleEqualityIgnoreNDArray().Hash(mod); } @@ -66,10 +66,10 @@ class ModuleEqualityAnchorBlock : public ModuleEquality { auto anchor_block_lhs = tir::FindAnchorBlock(lhs); auto anchor_block_rhs = tir::FindAnchorBlock(rhs); if (anchor_block_lhs && anchor_block_rhs) { - return tvm::ffi::reflection::StructuralEqual::Equal(GetRef(anchor_block_lhs), - GetRef(anchor_block_rhs), - /*map_free_vars=*/false, - /*skip_ndarray_content=*/true); + return tvm::ffi::StructuralEqual::Equal(GetRef(anchor_block_lhs), + GetRef(anchor_block_rhs), + /*map_free_vars=*/false, + /*skip_ndarray_content=*/true); } return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs); } diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 186f50947230..5474be667655 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -19,10 +19,10 @@ /*! * \file src/node/structural_equal.cc */ +#include #include #include #include -#include #include #include #include @@ -64,19 +64,19 @@ Optional ObjectPathPairFromAccessPathPair( result = result->Attr(step->key.cast()); break; } - case ffi::reflection::AccessKind::kArrayIndex: { + case ffi::reflection::AccessKind::kArrayItem: { result = result->ArrayIndex(step->key.cast()); break; } - case ffi::reflection::AccessKind::kMapKey: { + case ffi::reflection::AccessKind::kMapItem: { result = result->MapValue(step->key); break; } - case ffi::reflection::AccessKind::kArrayIndexMissing: { + case ffi::reflection::AccessKind::kArrayItemMissing: { result = result->MissingArrayElement(step->key.cast()); break; } - case ffi::reflection::AccessKind::kMapKeyMissing: { + case ffi::reflection::AccessKind::kMapItemMissing: { result = result->MissingMapEntry(); break; } @@ -96,7 +96,7 @@ bool NodeStructuralEqualAdapter(const Any& lhs, const Any& rhs, bool assert_mode bool map_free_vars) { if (assert_mode) { auto first_mismatch = ObjectPathPairFromAccessPathPair( - ffi::reflection::StructuralEqual::GetFirstMismatch(lhs, rhs, map_free_vars)); + ffi::StructuralEqual::GetFirstMismatch(lhs, rhs, map_free_vars)); if (first_mismatch.has_value()) { std::ostringstream oss; oss << "StructuralEqual check failed, caused by lhs"; @@ -129,7 +129,7 @@ bool NodeStructuralEqualAdapter(const Any& lhs, const Any& rhs, bool assert_mode } return true; } else { - return ffi::reflection::StructuralEqual::Equal(lhs, rhs, map_free_vars); + return ffi::StructuralEqual::Equal(lhs, rhs, map_free_vars); } } @@ -147,12 +147,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ return first_mismatch; */ return ObjectPathPairFromAccessPathPair( - ffi::reflection::StructuralEqual::GetFirstMismatch(lhs, rhs, map_free_vars)); + ffi::StructuralEqual::GetFirstMismatch(lhs, rhs, map_free_vars)); }); }); bool StructuralEqual::operator()(const ffi::Any& lhs, const ffi::Any& rhs, bool map_free_params) const { - return ffi::reflection::StructuralEqual::Equal(lhs, rhs, map_free_params); + return ffi::StructuralEqual::Equal(lhs, rhs, map_free_params); } } // namespace tvm diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 3a5f1de04165..383f344facae 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -20,9 +20,9 @@ * \file src/node/structural_hash.cc */ #include +#include #include #include -#include #include #include #include @@ -44,12 +44,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("node.StructuralHash", [](const Any& object, bool map_free_vars) -> int64_t { - return ffi::reflection::StructuralHash::Hash(object, map_free_vars); + return ffi::StructuralHash::Hash(object, map_free_vars); }); }); uint64_t StructuralHash::operator()(const ffi::Any& object) const { - return ffi::reflection::StructuralHash::Hash(object, false); + return ffi::StructuralHash::Hash(object, false); } struct RefToObjectPtr : public ObjectRef { diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 037a9f3021fb..f87f531a875b 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -21,9 +21,9 @@ * \file src/relax/block_builder.cc */ #include +#include #include #include -#include #include #include #include @@ -431,8 +431,8 @@ class BlockBuilderImpl : public BlockBuilderNode { class StructuralHashIgnoreNDarray { public: uint64_t operator()(const ObjectRef& key) const { - return ffi::reflection::StructuralHash::Hash(key, /*map_free_vars=*/false, - /*skip_ndarray_content=*/true); + return ffi::StructuralHash::Hash(key, /*map_free_vars=*/false, + /*skip_ndarray_content=*/true); } }; diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 83d978f27d0c..40a1c307cee5 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -22,8 +22,8 @@ * \brief Lift local functions into global functions. */ +#include #include -#include #include #include #include @@ -541,8 +541,8 @@ class ParamRemapper : private ExprFunctor { } else { var_remap_.Set(GetRef(lhs_var), rhs_var); } - CHECK(tvm::ffi::reflection::StructuralEqual::Equal(lhs_var->struct_info_, rhs_var->struct_info_, - /*map_free_vars=*/true)) + CHECK(tvm::ffi::StructuralEqual::Equal(lhs_var->struct_info_, rhs_var->struct_info_, + /*map_free_vars=*/true)) << "The struct info of the parameters should be the same for all target functions"; auto lhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(GetRef(lhs_var))); auto rhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(rhs_expr));