diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index 55fbd1c1bcf4..af9943476e3d 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -59,7 +59,6 @@ set(tvm_ffi_objs_sources "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc" ) if (TVM_FFI_USE_EXTRA_CXX_API) @@ -69,6 +68,7 @@ if (TVM_FFI_USE_EXTRA_CXX_API) "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_parser.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_writer.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/serialization.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/reflection_extra.cc" ) endif() diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 11080a21f0b8..c8d46d455227 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -896,7 +896,7 @@ TVM_FFI_DLL const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index); #endif //--------------------------------------------------------------- -// The following API defines static object field accessors +// The following API defines static object attribute accessors // for language bindings. // // They are defined in C++ inline functions for cleaner code. diff --git a/ffi/include/tvm/ffi/reflection/access_path.h b/ffi/include/tvm/ffi/reflection/access_path.h index a4f40f485ebd..267cb76fc1fe 100644 --- a/ffi/include/tvm/ffi/reflection/access_path.h +++ b/ffi/include/tvm/ffi/reflection/access_path.h @@ -25,24 +25,31 @@ #include #include +#include #include #include +#include #include +#include + namespace tvm { namespace ffi { namespace reflection { enum class AccessKind : int32_t { - kObjectField = 0, + kAttr = 0, kArrayItem = 1, kMapItem = 2, // the following two are used for error reporting when // the supposed access field is not available - kArrayItemMissing = 3, - kMapItemMissing = 4, + kAttrMissing = 3, + kArrayItemMissing = 4, + kMapItemMissing = 5, }; +class AccessStep; + /*! * \brief Represent a single step in object field, map key, array index access. */ @@ -59,16 +66,18 @@ class AccessStepObj : public Object { */ Any key; + // default constructor to enable auto-serialization + AccessStepObj() = default; AccessStepObj(AccessKind kind, Any key) : kind(kind), key(key) {} - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("kind", &AccessStepObj::kind) - .def_ro("key", &AccessStepObj::key); - } + /*! + * \brief Deep check if two steps are equal. + * \param other The other step to compare with. + * \return True if the two steps are equal, false otherwise. + */ + inline bool StepEqual(const AccessStep& other) const; - static constexpr const char* _type_key = "tvm.ffi.reflection.AccessStep"; + static constexpr const char* _type_key = "ffi.reflection.AccessStep"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessStepObj, Object); }; @@ -82,8 +91,10 @@ class AccessStep : public ObjectRef { public: AccessStep(AccessKind kind, Any key) : ObjectRef(make_object(kind, key)) {} - static AccessStep ObjectField(String field_name) { - return AccessStep(AccessKind::kObjectField, field_name); + static AccessStep Attr(String field_name) { return AccessStep(AccessKind::kAttr, field_name); } + + static AccessStep AttrMissing(String field_name) { + return AccessStep(AccessKind::kAttrMissing, field_name); } static AccessStep ArrayItem(int64_t index) { return AccessStep(AccessKind::kArrayItem, index); } @@ -94,15 +105,273 @@ class AccessStep : public ObjectRef { static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, key); } - static AccessStep MapItemMissing(Any key) { return AccessStep(AccessKind::kMapItemMissing, key); } + static AccessStep MapItemMissing(Any key = nullptr) { + return AccessStep(AccessKind::kMapItemMissing, key); + } TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessStep, ObjectRef, AccessStepObj); }; -using AccessPath = Array; +inline bool AccessStepObj::StepEqual(const AccessStep& other) const { + return this->kind == other->kind && AnyEqual()(this->key, other->key); +} + +// forward declaration +class AccessPath; + +/*! + * \brief ObjectRef class of AccessPathObj. + * + * \sa AccessPathObj + */ +class AccessPathObj : public Object { + public: + /*! + * \brief The parent of the access path. + * + * This parent-pointing tree structure is more space efficient when + * representing multiple paths that share a common prefix. + * + * \note Empty for root. + */ + Optional parent; + /*! + * \brief The current of the access path. + * \note Empty for root. + */ + Optional step; + /*! + * \brief The current depth of the access path, 0 for root + */ + int32_t depth; + + // default constructor to enable auto-serialization + AccessPathObj() = default; + /*! + * \brief Constructor for the access path. + * \param parent The parent of the access path. + * \param step The current step of the access path. + * \param depth The current depth of the access path. + */ + AccessPathObj(Optional parent, Optional step, int32_t depth) + : parent(parent), step(step), depth(depth) {} + + /*! + * \brief Get the parent of the access path. + * \return The parent of the access path. + */ + inline Optional GetParent() const; + + /*! + * \brief Extend the access path with a new step. + * \param step The step to extend the access path with. + * \return The extended access path. + */ + inline AccessPath Extend(AccessStep step) const; + + /*! + * \brief Extend the access path with an object attribute access. + * \param field_name The name of the field to access. + * \return The extended access path. + */ + inline AccessPath Attr(String field_name) const; + + /*! + * \brief Extend the access path with an object attribute missing access. + * \param field_name The name of the field to access. + * \return The extended access path. + */ + inline AccessPath AttrMissing(String field_name) const; + + /*! + * \brief Extend the access path with an array item access. + * \param index The index of the array item to access. + * \return The extended access path. + */ + inline AccessPath ArrayItem(int64_t index) const; + + /*! + * \brief Extend the access path with an array item missing access. + * \param index The index of the array item to access. + * \return The extended access path. + */ + inline AccessPath ArrayItemMissing(int64_t index) const; + + /*! + * \brief Extend the access path with a map item access. + * \param key The key of the map item to access. + * \return The extended access path. + */ + inline AccessPath MapItem(Any key) const; + + /*! + * \brief Extend the access path with a map item missing access. + * \param key The key of the map item to access. + * \return The extended access path. + */ + inline AccessPath MapItemMissing(Any key) const; + + /*! + * \brief Get the array of steps that corresponds to the access path. + * \return The array of steps that corresponds to the access path. + */ + inline Array ToSteps() const; + + /*! + * \brief Check if two paths are equal by deep comparing the steps. + * \param other The other path to compare with. + * \return True if the two paths are equal, false otherwise. + */ + inline bool PathEqual(const AccessPath& other) const; + + /*! + * \brief Check if this path is a prefix of another path. + * \param other The other path to compare with. + * \return True if this path is a prefix of the other path, false otherwise. + */ + inline bool IsPrefixOf(const AccessPath& other) const; + + static constexpr const char* _type_key = "ffi.reflection.AccessPath"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessPathObj, Object); + + private: + static bool PathEqual(const AccessPathObj* lhs, const AccessPathObj* rhs) { + // fast path for same pointer + if (lhs == rhs) return true; + if (lhs->depth != rhs->depth) return false; + // do deep equality checks + while (lhs->parent.has_value()) { + TVM_FFI_ICHECK(rhs->parent.has_value()); + TVM_FFI_ICHECK(lhs->step.has_value()); + TVM_FFI_ICHECK(rhs->step.has_value()); + if (!(*lhs->step)->StepEqual(*(rhs->step))) { + return false; + } + lhs = static_cast(lhs->parent.get()); + rhs = static_cast(rhs->parent.get()); + // fast path for same pointer + if (lhs == rhs) return true; + TVM_FFI_ICHECK(lhs != nullptr); + TVM_FFI_ICHECK(rhs != nullptr); + } + return true; + } +}; + +/*! + * \brief ObjectRef class of AccessPath. + * + * \sa AccessPathObj + */ +class AccessPath : public ObjectRef { + public: + /*! + * \brief Create an access path from an iterator range of steps. + * \param begin The beginning of the iterator range. + * \param end The end of the iterator range. + * \return The access path. + */ + template + static AccessPath FromSteps(Iter begin, Iter end) { + AccessPath path = AccessPath::Root(); + for (Iter it = begin; it != end; ++it) { + path = path->Extend(*it); + } + return path; + } + /*! + * \brief Create an access path from an array of steps. + * \param steps The array of steps. + * \return The access path. + */ + static AccessPath FromSteps(Array steps) { + AccessPath path = AccessPath::Root(); + for (AccessStep step : steps) { + path = path->Extend(step); + } + return path; + } + + /*! + * \brief Create a root access path. + * \return The root access path. + */ + static AccessPath Root() { + return AccessPath(make_object(std::nullopt, std::nullopt, 0)); + } + + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessPath, ObjectRef, AccessPathObj); +}; + using AccessPathPair = Tuple; +inline Optional AccessPathObj::GetParent() const { + if (auto opt_parent = this->parent.as()) { + return opt_parent; + } + return std::nullopt; +} + +inline AccessPath AccessPathObj::Extend(AccessStep step) const { + return AccessPath(make_object(GetRef(this), step, this->depth + 1)); +} + +inline AccessPath AccessPathObj::Attr(String field_name) const { + return this->Extend(AccessStep::Attr(field_name)); +} + +inline AccessPath AccessPathObj::AttrMissing(String field_name) const { + return this->Extend(AccessStep::AttrMissing(field_name)); +} + +inline AccessPath AccessPathObj::ArrayItem(int64_t index) const { + return this->Extend(AccessStep::ArrayItem(index)); +} + +inline AccessPath AccessPathObj::ArrayItemMissing(int64_t index) const { + return this->Extend(AccessStep::ArrayItemMissing(index)); +} + +inline AccessPath AccessPathObj::MapItem(Any key) const { + return this->Extend(AccessStep::MapItem(key)); +} + +inline AccessPath AccessPathObj::MapItemMissing(Any key) const { + return this->Extend(AccessStep::MapItemMissing(key)); +} + +inline Array AccessPathObj::ToSteps() const { + std::vector reverse_steps; + reverse_steps.reserve(this->depth); + const AccessPathObj* current = this; + while (current->parent.has_value()) { + TVM_FFI_ICHECK(current->step.has_value()); + reverse_steps.push_back(*(current->step)); + current = static_cast(current->parent.get()); + TVM_FFI_ICHECK(current != nullptr); + } + return Array(reverse_steps.rbegin(), reverse_steps.rend()); +} + +inline bool AccessPathObj::PathEqual(const AccessPath& other) const { + return PathEqual(this, other.get()); +} + +inline bool AccessPathObj::IsPrefixOf(const AccessPath& other) const { + if (this->depth > other->depth) { + return false; + } + const AccessPathObj* rhs_path = other.get(); + while (rhs_path->depth > this->depth) { + TVM_FFI_ICHECK(rhs_path->parent.has_value()); + rhs_path = static_cast(rhs_path->parent.get()); + } + return PathEqual(this, rhs_path); +} + } // namespace reflection } // namespace ffi } // namespace tvm + #endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_ diff --git a/ffi/include/tvm/ffi/reflection/registry.h b/ffi/include/tvm/ffi/reflection/registry.h index 14b49395d743..107a6e77592b 100644 --- a/ffi/include/tvm/ffi/reflection/registry.h +++ b/ffi/include/tvm/ffi/reflection/registry.h @@ -198,7 +198,7 @@ class ReflectionDefBase { } } - template + template TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) { return ffi::Function::FromTyped(std::forward(func), name); } @@ -258,27 +258,12 @@ class GlobalDef : public ReflectionDefBase { */ template GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) { - RegisterFunc(name, GetMethod_(std::string(name), std::forward(func)), + RegisterFunc(name, GetMethod(std::string(name), std::forward(func)), std::forward(extra)...); return *this; } private: - template - TVM_FFI_INLINE static Function GetMethod_(std::string name, Func&& func) { - return ffi::Function::FromTyped(std::forward(func), name); - } - - template - TVM_FFI_INLINE static Function GetMethod_(std::string name, R (Class::*func)(Args...) const) { - return GetMethod(std::string(name), func); - } - - template - TVM_FFI_INLINE static Function GetMethod_(std::string name, R (Class::*func)(Args...)) { - return GetMethod(std::string(name), func); - } - template void RegisterFunc(const char* name, ffi::Function func, Extra&&... extra) { TVMFFIMethodInfo info; @@ -434,8 +419,7 @@ class ObjectDef : public ReflectionDefBase { info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod; } // obtain the method function - Function method = - GetMethod(std::string(type_key_) + "." + name, std::forward(func)); + Function method = GetMethod(std::string(type_key_) + "." + name, std::forward(func)); info.method = AnyView(method).CopyToTVMFFIAny(); // apply method info traits ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); @@ -467,7 +451,7 @@ class TypeAttrDef : public ReflectionDefBase { TypeAttrDef& def(const char* name, Func&& func) { TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; ffi::Function ffi_func = - GetMethod(std::string(type_key_) + "." + name, std::forward(func)); + GetMethod(std::string(type_key_) + "." + name, std::forward(func)); TVMFFIAny value_any = AnyView(ffi_func).CopyToTVMFFIAny(); TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); return *this; diff --git a/ffi/src/ffi/extra/reflection_extra.cc b/ffi/src/ffi/extra/reflection_extra.cc new file mode 100644 index 000000000000..698be6337698 --- /dev/null +++ b/ffi/src/ffi/extra/reflection_extra.cc @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/extra/reflection_extra.cc + * + * \brief Extra reflection registrations. * + */ +#include +#include + +namespace tvm { +namespace ffi { +namespace reflection { + +void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) { + int32_t type_index; + if (auto opt_type_index = args[0].try_cast()) { + type_index = *opt_type_index; + } else { + String type_key = args[0].cast(); + TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); + } + + TVM_FFI_ICHECK(args.size() % 2 == 1); + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); + + if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) + << "` does not support reflection creation"; + } + TVMFFIObjectHandle handle; + TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); + ObjectPtr ptr = + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + + std::vector keys; + std::vector keys_found; + + for (int i = 1; i < args.size(); i += 2) { + keys.push_back(args[i].cast()); + } + keys_found.resize(keys.size(), false); + + auto search_field = [&](const TVMFFIByteArray& field_name) { + for (size_t i = 0; i < keys.size(); ++i) { + if (keys_found[i]) continue; + if (keys[i].compare(field_name) == 0) { + return i; + } + } + return keys.size(); + }; + + auto update_fields = [&](const TVMFFITypeInfo* tinfo) { + for (int i = 0; i < tinfo->num_fields; ++i) { + const TVMFFIFieldInfo* field_info = tinfo->fields + i; + size_t arg_index = search_field(field_info->name); + void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; + if (arg_index < keys.size()) { + AnyView field_value = args[arg_index * 2 + 2]; + field_info->setter(field_addr, reinterpret_cast(&field_value)); + keys_found[arg_index] = true; + } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { + field_info->setter(field_addr, &(field_info->default_value)); + } else { + TVM_FFI_THROW(TypeError) << "Required field `" + << String(field_info->name.data, field_info->name.size) + << "` not set in type `" << TypeIndexToTypeKey(type_index) << "`"; + } + } + }; + + // iterate through acenstors in parent to child order + // skip the first one since it is always the root object + for (int i = 1; i < type_info->type_depth; ++i) { + update_fields(type_info->type_acenstors[i]); + } + update_fields(type_info); + + for (size_t i = 0; i < keys.size(); ++i) { + if (!keys_found[i]) { + TVM_FFI_THROW(TypeError) << "Type `" << TypeIndexToTypeKey(type_index) + << "` does not have field `" << keys[i] << "`"; + } + } + *ret = ObjectRef(ptr); +} + +inline void AccessStepRegisterReflection() { + // register access step reflection here since it is only needed for bindings + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("kind", &AccessStepObj::kind) + .def_ro("key", &AccessStepObj::key); +} + +inline void AccessPathRegisterReflection() { + // register access path reflection here since it is only needed for bindings + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("parent", &AccessPathObj::parent) + .def_ro("step", &AccessPathObj::step) + .def_ro("depth", &AccessPathObj::depth) + .def_static("_root", &AccessPath::Root) + .def("_extend", &AccessPathObj::Extend) + .def("_attr", &AccessPathObj::Attr) + .def("_array_item", &AccessPathObj::ArrayItem) + .def("_map_item", &AccessPathObj::MapItem) + .def("_attr_missing", &AccessPathObj::AttrMissing) + .def("_array_item_missing", &AccessPathObj::ArrayItemMissing) + .def("_map_item_missing", &AccessPathObj::MapItemMissing) + .def("_is_prefix_of", &AccessPathObj::IsPrefixOf) + .def("_to_steps", &AccessPathObj::ToSteps) + .def("_path_equal", + [](const AccessPath& self, const AccessPath& other) { return self->PathEqual(other); }); +} + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + AccessStepRegisterReflection(); + AccessPathRegisterReflection(); + refl::GlobalDef().def_packed("ffi.MakeObjectFromPackedArgs", MakeObjectFromPackedArgs); +}); + +} // namespace reflection +} // namespace ffi +} // namespace tvm diff --git a/ffi/src/ffi/extra/structural_equal.cc b/ffi/src/ffi/extra/structural_equal.cc index 97ebbf4072cd..171fa2f750a0 100644 --- a/ffi/src/ffi/extra/structural_equal.cc +++ b/ffi/src/ffi/extra/structural_equal.cc @@ -185,9 +185,9 @@ class StructEqualHandler { // record the first mismatching field if we sub-rountine compare failed if (mismatch_lhs_reverse_path_ != nullptr) { mismatch_lhs_reverse_path_->emplace_back( - reflection::AccessStep::ObjectField(String(field_info->name))); + reflection::AccessStep::Attr(String(field_info->name))); mismatch_rhs_reverse_path_->emplace_back( - reflection::AccessStep::ObjectField(String(field_info->name))); + reflection::AccessStep::Attr(String(field_info->name))); } // return true to indicate early stop return true; @@ -216,9 +216,9 @@ class StructEqualHandler { if (mismatch_lhs_reverse_path_ != nullptr) { String field_name_str = field_name.cast(); mismatch_lhs_reverse_path_->emplace_back( - reflection::AccessStep::ObjectField(field_name_str)); + reflection::AccessStep::Attr(field_name_str)); mismatch_rhs_reverse_path_->emplace_back( - reflection::AccessStep::ObjectField(field_name_str)); + reflection::AccessStep::Attr(field_name_str)); } } return success; @@ -420,8 +420,11 @@ Optional StructuralEqual::GetFirstMismatch(const Any if (handler.CompareAny(lhs, rhs)) { return std::nullopt; } - reflection::AccessPath lhs_path(lhs_reverse_path.rbegin(), lhs_reverse_path.rend()); - reflection::AccessPath rhs_path(rhs_reverse_path.rbegin(), rhs_reverse_path.rend()); + using reflection::AccessPath; + reflection::AccessPath lhs_path = + AccessPath::FromSteps(lhs_reverse_path.rbegin(), lhs_reverse_path.rend()); + reflection::AccessPath rhs_path = + AccessPath::FromSteps(rhs_reverse_path.rbegin(), rhs_reverse_path.rend()); return reflection::AccessPathPair(lhs_path, rhs_path); } diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 374c0c7c4eeb..61107cb63ff7 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -385,86 +385,6 @@ class TypeTable { Map type_attr_name_to_column_index_; }; -void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) { - int32_t type_index; - if (auto opt_type_index = args[0].try_cast()) { - type_index = *opt_type_index; - } else { - String type_key = args[0].cast(); - TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - } - - TVM_FFI_ICHECK(args.size() % 2 == 1); - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - - if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support reflection creation"; - } - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); - - std::vector keys; - std::vector keys_found; - - for (int i = 1; i < args.size(); i += 2) { - keys.push_back(args[i].cast()); - } - keys_found.resize(keys.size(), false); - - auto search_field = [&](const TVMFFIByteArray& field_name) { - for (size_t i = 0; i < keys.size(); ++i) { - if (keys_found[i]) continue; - if (keys[i].compare(field_name) == 0) { - return i; - } - } - return keys.size(); - }; - - auto update_fields = [&](const TVMFFITypeInfo* tinfo) { - for (int i = 0; i < tinfo->num_fields; ++i) { - const TVMFFIFieldInfo* field_info = tinfo->fields + i; - size_t arg_index = search_field(field_info->name); - void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; - if (arg_index < keys.size()) { - AnyView field_value = args[arg_index * 2 + 2]; - field_info->setter(field_addr, reinterpret_cast(&field_value)); - keys_found[arg_index] = true; - } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { - field_info->setter(field_addr, &(field_info->default_value)); - } else { - TVM_FFI_THROW(TypeError) << "Required field `" - << String(field_info->name.data, field_info->name.size) - << "` not set in type `" << TypeIndexToTypeKey(type_index) << "`"; - } - } - }; - - // iterate through acenstors in parent to child order - // skip the first one since it is always the root object - for (int i = 1; i < type_info->type_depth; ++i) { - update_fields(type_info->type_acenstors[i]); - } - update_fields(type_info); - - for (size_t i = 0; i < keys.size(); ++i) { - if (!keys_found[i]) { - TVM_FFI_THROW(TypeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not have field `" << keys[i] << "`"; - } - } - *ret = ObjectRef(ptr); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("ffi.MakeObjectFromPackedArgs", MakeObjectFromPackedArgs); -}); - } // namespace ffi } // namespace tvm diff --git a/ffi/src/ffi/reflection/access_path.cc b/ffi/src/ffi/reflection/access_path.cc deleted file mode 100644 index 17b8abb062ff..000000000000 --- a/ffi/src/ffi/reflection/access_path.cc +++ /dev/null @@ -1,34 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/reflection/access_path.cc - */ - -#include - -namespace tvm { -namespace ffi { -namespace reflection { - -TVM_FFI_STATIC_INIT_BLOCK({ AccessStepObj::RegisterReflection(); }); - -} // namespace reflection -} // namespace ffi -} // namespace tvm diff --git a/ffi/tests/cpp/extra/test_structural_equal_hash.cc b/ffi/tests/cpp/extra/test_structural_equal_hash.cc index 8a377f483713..a05c50cc2617 100644 --- a/ffi/tests/cpp/extra/test_structural_equal_hash.cc +++ b/ffi/tests/cpp/extra/test_structural_equal_hash.cc @@ -47,21 +47,23 @@ TEST(StructuralEqualHash, Array) { // first directly interepret diff, EXPECT_TRUE(diff_a_c.has_value()); - EXPECT_EQ((*diff_a_c).get<0>()[0]->kind, refl::AccessKind::kArrayItem); - EXPECT_EQ((*diff_a_c).get<1>()[0]->kind, refl::AccessKind::kArrayItem); - EXPECT_EQ((*diff_a_c).get<0>()[0]->key.cast(), 1); - EXPECT_EQ((*diff_a_c).get<1>()[0]->key.cast(), 1); - EXPECT_EQ((*diff_a_c).get<0>().size(), 1); - EXPECT_EQ((*diff_a_c).get<1>().size(), 1); + auto lhs_steps = (*diff_a_c).get<0>()->ToSteps(); + auto rhs_steps = (*diff_a_c).get<1>()->ToSteps(); + EXPECT_EQ(lhs_steps[0]->kind, refl::AccessKind::kArrayItem); + EXPECT_EQ(rhs_steps[0]->kind, refl::AccessKind::kArrayItem); + EXPECT_EQ(lhs_steps[0]->key.cast(), 1); + EXPECT_EQ(rhs_steps[0]->key.cast(), 1); + EXPECT_EQ(lhs_steps.size(), 1); + EXPECT_EQ(rhs_steps.size(), 1); // use structural equal for checking in future parts // given we have done some basic checks above by directly interepret diff, Array d = {1, 2}; auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d); - auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath({ + auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath::FromSteps({ refl::AccessStep::ArrayItem(2), }), - refl::AccessPath({ + refl::AccessPath::FromSteps({ refl::AccessStep::ArrayItemMissing(2), })); // then use structural equal to check it @@ -80,12 +82,8 @@ TEST(StructuralEqualHash, Map) { EXPECT_NE(StructuralHash()(a), StructuralHash()(c)); auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c); - auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath({ - refl::AccessStep::MapItem("c"), - }), - refl::AccessPath({ - refl::AccessStep::MapItem("c"), - })); + auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath::Root()->MapItem("c"), + refl::AccessPath::Root()->MapItem("c")); EXPECT_TRUE(diff_a_c.has_value()); EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c)); } @@ -101,35 +99,22 @@ TEST(StructuralEqualHash, NestedMapArray) { EXPECT_NE(StructuralHash()(a), StructuralHash()(c)); auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c); - auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath({ - refl::AccessStep::MapItem("b"), - refl::AccessStep::ArrayItem(1), - }), - refl::AccessPath({ - refl::AccessStep::MapItem("b"), - refl::AccessStep::ArrayItem(1), - })); + auto expected_diff_a_c = + refl::AccessPathPair(refl::AccessPath::Root()->MapItem("b")->ArrayItem(1), + refl::AccessPath::Root()->MapItem("b")->ArrayItem(1)); EXPECT_TRUE(diff_a_c.has_value()); EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c)); Map> d = {{"a", {1, 2, 3}}}; auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d); - auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath({ - refl::AccessStep::MapItem("b"), - }), - refl::AccessPath({ - refl::AccessStep::MapItemMissing("b"), - })); + auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath::Root()->MapItem("b"), + refl::AccessPath::Root()->MapItemMissing("b")); EXPECT_TRUE(diff_a_d.has_value()); EXPECT_TRUE(StructuralEqual()(diff_a_d, expected_diff_a_d)); auto diff_d_a = StructuralEqual::GetFirstMismatch(d, a); - auto expected_diff_d_a = refl::AccessPathPair(refl::AccessPath({ - refl::AccessStep::MapItemMissing("b"), - }), - refl::AccessPath({ - refl::AccessStep::MapItem("b"), - })); + auto expected_diff_d_a = refl::AccessPathPair(refl::AccessPath::Root()->MapItemMissing("b"), + refl::AccessPath::Root()->MapItem("b")); } TEST(StructuralEqualHash, FreeVar) { @@ -157,12 +142,12 @@ TEST(StructuralEqualHash, FuncDefAndIgnoreField) { EXPECT_FALSE(StructuralEqual()(fa, fc)); auto diff_fa_fc = StructuralEqual::GetFirstMismatch(fa, fc); - auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath({ - refl::AccessStep::ObjectField("body"), + auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath::FromSteps({ + refl::AccessStep::Attr("body"), refl::AccessStep::ArrayItem(1), }), - refl::AccessPath({ - refl::AccessStep::ObjectField("body"), + refl::AccessPath::FromSteps({ + refl::AccessStep::Attr("body"), refl::AccessStep::ArrayItem(1), })); EXPECT_TRUE(diff_fa_fc.has_value()); @@ -183,14 +168,9 @@ TEST(StructuralEqualHash, CustomTreeNode) { EXPECT_FALSE(StructuralEqual()(fa, fc)); auto diff_fa_fc = StructuralEqual::GetFirstMismatch(fa, fc); - auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath({ - refl::AccessStep::ObjectField("body"), - refl::AccessStep::ArrayItem(1), - }), - refl::AccessPath({ - refl::AccessStep::ObjectField("body"), - refl::AccessStep::ArrayItem(1), - })); + auto expected_diff_fa_fc = + refl::AccessPathPair(refl::AccessPath::Root()->Attr("body")->ArrayItem(1), + refl::AccessPath::Root()->Attr("body")->ArrayItem(1)); EXPECT_TRUE(diff_fa_fc.has_value()); EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc)); } diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc index 98915c54e19c..85da00c1321d 100644 --- a/ffi/tests/cpp/test_reflection.cc +++ b/ffi/tests/cpp/test_reflection.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -165,4 +166,107 @@ TEST(Reflection, ObjectCreator) { refl::ObjectCreator creator("test.Int"); EXPECT_EQ(creator(Map({{"value", 1}})).cast()->value, 1); } + +TEST(Reflection, AccessPath) { + namespace refl = tvm::ffi::reflection; + + // Test basic path construction and ToSteps() + refl::AccessPath path = refl::AccessPath::Root()->Attr("body")->ArrayItem(1); + auto steps = path->ToSteps(); + EXPECT_EQ(steps.size(), 2); + EXPECT_EQ(steps[0]->kind, refl::AccessKind::kAttr); + EXPECT_EQ(steps[1]->kind, refl::AccessKind::kArrayItem); + EXPECT_EQ(steps[0]->key.cast(), "body"); + EXPECT_EQ(steps[1]->key.cast(), 1); + + // Test PathEqual with identical paths + refl::AccessPath path2 = refl::AccessPath::Root()->Attr("body")->ArrayItem(1); + EXPECT_TRUE(path->PathEqual(path2)); + EXPECT_TRUE(path->IsPrefixOf(path2)); + + // Test PathEqual with different paths + refl::AccessPath path3 = refl::AccessPath::Root()->Attr("body")->ArrayItem(2); + EXPECT_FALSE(path->PathEqual(path3)); + EXPECT_FALSE(path->IsPrefixOf(path3)); + + // Test prefix relationship - path4 extends path, so path should be prefix of path4 + refl::AccessPath path4 = refl::AccessPath::Root()->Attr("body")->ArrayItem(1)->Attr("body"); + EXPECT_FALSE(path->PathEqual(path4)); // Not equal (different lengths) + EXPECT_TRUE(path->IsPrefixOf(path4)); // But path is a prefix of path4 + + // Test completely different paths + refl::AccessPath path5 = refl::AccessPath::Root()->ArrayItem(0)->ArrayItem(1)->Attr("body"); + EXPECT_FALSE(path->PathEqual(path5)); + EXPECT_FALSE(path->IsPrefixOf(path5)); + + // Test Root path + refl::AccessPath root = refl::AccessPath::Root(); + auto root_steps = root->ToSteps(); + EXPECT_EQ(root_steps.size(), 0); + EXPECT_EQ(root->depth, 0); + EXPECT_TRUE(root->IsPrefixOf(path)); + EXPECT_TRUE(root->IsPrefixOf(root)); + EXPECT_TRUE(root->PathEqual(refl::AccessPath::Root())); + + // Test depth calculations + EXPECT_EQ(path->depth, 2); + EXPECT_EQ(path4->depth, 3); + EXPECT_EQ(root->depth, 0); + + // Test MapItem access + refl::AccessPath map_path = refl::AccessPath::Root()->Attr("data")->MapItem("key1"); + auto map_steps = map_path->ToSteps(); + EXPECT_EQ(map_steps.size(), 2); + EXPECT_EQ(map_steps[0]->kind, refl::AccessKind::kAttr); + EXPECT_EQ(map_steps[1]->kind, refl::AccessKind::kMapItem); + EXPECT_EQ(map_steps[0]->key.cast(), "data"); + EXPECT_EQ(map_steps[1]->key.cast(), "key1"); + + // Test MapItemMissing access + refl::AccessPath map_missing_path = refl::AccessPath::Root()->MapItemMissing(42); + auto map_missing_steps = map_missing_path->ToSteps(); + EXPECT_EQ(map_missing_steps.size(), 1); + EXPECT_EQ(map_missing_steps[0]->kind, refl::AccessKind::kMapItemMissing); + EXPECT_EQ(map_missing_steps[0]->key.cast(), 42); + + // Test ArrayItemMissing access + refl::AccessPath array_missing_path = refl::AccessPath::Root()->ArrayItemMissing(5); + auto array_missing_steps = array_missing_path->ToSteps(); + EXPECT_EQ(array_missing_steps.size(), 1); + EXPECT_EQ(array_missing_steps[0]->kind, refl::AccessKind::kArrayItemMissing); + EXPECT_EQ(array_missing_steps[0]->key.cast(), 5); + + // Test FromSteps static method - round trip conversion + auto original_steps = path->ToSteps(); + refl::AccessPath reconstructed = refl::AccessPath::FromSteps(original_steps); + EXPECT_TRUE(path->PathEqual(reconstructed)); + EXPECT_EQ(path->depth, reconstructed->depth); + + // Test complex prefix relationships + refl::AccessPath short_path = refl::AccessPath::Root()->Attr("x"); + refl::AccessPath medium_path = refl::AccessPath::Root()->Attr("x")->ArrayItem(0); + refl::AccessPath long_path = refl::AccessPath::Root()->Attr("x")->ArrayItem(0)->MapItem("z"); + + EXPECT_TRUE(short_path->IsPrefixOf(medium_path)); + EXPECT_TRUE(short_path->IsPrefixOf(long_path)); + EXPECT_TRUE(medium_path->IsPrefixOf(long_path)); + EXPECT_FALSE(medium_path->IsPrefixOf(short_path)); + EXPECT_FALSE(long_path->IsPrefixOf(medium_path)); + EXPECT_FALSE(long_path->IsPrefixOf(short_path)); + + // Test non-prefix relationships + refl::AccessPath branch1 = refl::AccessPath::Root()->Attr("x")->ArrayItem(0); + refl::AccessPath branch2 = refl::AccessPath::Root()->Attr("x")->ArrayItem(1); + EXPECT_FALSE(branch1->IsPrefixOf(branch2)); + EXPECT_FALSE(branch2->IsPrefixOf(branch1)); + EXPECT_FALSE(branch1->PathEqual(branch2)); + + // Test GetParent functionality + auto parent = path4->GetParent(); + EXPECT_TRUE(parent.has_value()); + EXPECT_TRUE(parent.value()->PathEqual(path)); + + auto root_parent = root->GetParent(); + EXPECT_FALSE(root_parent.has_value()); +} } // namespace diff --git a/python/tvm/ffi/__init__.py b/python/tvm/ffi/__init__.py index 43a20e751c29..e615e22a0cbc 100644 --- a/python/tvm/ffi/__init__.py +++ b/python/tvm/ffi/__init__.py @@ -31,6 +31,7 @@ from .ndarray import from_dlpack, NDArray, Shape from .container import Array, Map from . import serialization +from . import access_path from . import testing @@ -67,4 +68,7 @@ "Shape", "Array", "Map", + "testing", + "access_path", + "serialization", ] diff --git a/python/tvm/ffi/access_path.py b/python/tvm/ffi/access_path.py new file mode 100644 index 000000000000..c4822074ebb8 --- /dev/null +++ b/python/tvm/ffi/access_path.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Access path classes.""" + +from enum import IntEnum +from typing import List, Any +from . import core +from .registry import register_object + + +class AccessKind(IntEnum): + ATTR = 0 + ARRAY_ITEM = 1 + MAP_ITEM = 2 + ATTR_MISSING = 3 + ARRAY_ITEM_MISSING = 4 + MAP_ITEM_MISSING = 5 + + +@register_object("ffi.reflection.AccessStep") +class AccessStep(core.Object): + """Access step container""" + + +@register_object("ffi.reflection.AccessPath") +class AccessPath(core.Object): + """Access path container""" + + def __init__(self) -> None: + super().__init__() + raise ValueError( + "AccessPath can't be initialized directly. " + "Use AccessPath.root() to create a path to the root object" + ) + + @staticmethod + def root() -> "AccessPath": + """Create a root access path""" + return AccessPath._root() + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, AccessPath): + return False + return self._path_equal(other) + + def __ne__(self, other: Any) -> bool: + if not isinstance(other, AccessPath): + return True + return not self._path_equal(other) + + def is_prefix_of(self, other: "AccessPath") -> bool: + """Check if this access path is a prefix of another access path + + Parameters + ---------- + other : AccessPath + The access path to check if it is a prefix of this access path + + Returns + ------- + bool + True if this access path is a prefix of the other access path, False otherwise + """ + return self._is_prefix_of(other) + + def attr(self, attr_key: str) -> "AccessPath": + """Create an access path to the attribute of the current object + + Parameters + ---------- + attr_key : str + The key of the attribute to access + + Returns + ------- + AccessPath + The extended access path + """ + return self._attr(attr_key) + + def attr_missing(self, attr_key: str) -> "AccessPath": + """Create an access path that indicate an attribute is missing + + Parameters + ---------- + attr_key : str + The key of the attribute to access + + Returns + ------- + AccessPath + The extended access path + """ + return self._attr_missing(attr_key) + + def array_item(self, index: int) -> "AccessPath": + """Create an access path to the item of the current array + + Parameters + ---------- + index : int + The index of the item to access + + Returns + ------- + AccessPath + The extended access path + """ + return self._array_item(index) + + def array_item_missing(self, index: int) -> "AccessPath": + """Create an access path that indicate an array item is missing + + Parameters + ---------- + index : int + The index of the item to access + + Returns + ------- + AccessPath + The extended access path + """ + return self._array_item_missing(index) + + def map_item(self, key: Any) -> "AccessPath": + """Create an access path to the item of the current map + + Parameters + ---------- + key : Any + The key of the item to access + + Returns + ------- + AccessPath + The extended access path + """ + return self._map_item(key) + + def map_item_missing(self, key: Any) -> "AccessPath": + """Create an access path that indicate a map item is missing + + Parameters + ---------- + key : Any + The key of the item to access + + Returns + ------- + AccessPath + The extended access path + """ + return self._map_item_missing(key) + + def to_steps(self) -> List["AccessStep"]: + """Convert the access path to a list of access steps + + Returns + ------- + List[AccessStep] + The list of access steps + """ + return self._to_steps() diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi index 8c9df19642b0..999c2e1338b5 100644 --- a/python/tvm/ffi/cython/function.pxi +++ b/python/tvm/ffi/cython/function.pxi @@ -291,6 +291,12 @@ cdef _get_method_from_method_info(const TVMFFIMethodInfo* method): return make_ret(result) +def _member_method_wrapper(method_func): + def wrapper(self, *args): + return method_func(self, *args) + return wrapper + + def _add_class_attrs_by_reflection(int type_index, object cls): """Decorate the class attrs by reflection""" cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(type_index) @@ -335,8 +341,10 @@ def _add_class_attrs_by_reflection(int type_index, object cls): if method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod: method_pyfunc = staticmethod(method_func) else: - def method_pyfunc(self, *args): - return method_func(self, *args) + # must call into another method instead of direct capture + # to avoid the same method_func variable being used + # across multiple loop iterations + method_pyfunc = _member_method_wrapper(method_func) if doc is not None: method_pyfunc.__doc__ = doc @@ -345,7 +353,6 @@ def _add_class_attrs_by_reflection(int type_index, object cls): if hasattr(cls, name): # skip already defined attributes continue - setattr(cls, name, method_pyfunc) return cls diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi index 7df5f7a19aff..dad6bee51b34 100644 --- a/python/tvm/ffi/cython/object.pxi +++ b/python/tvm/ffi/cython/object.pxi @@ -31,7 +31,7 @@ def _set_func_convert_to_object(func): def __object_repr__(obj): """Object repr function that can be overridden by assigning to it""" - return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")" + return type(obj).__name__ + "(" + str(obj.__ctypes_handle__().value) + ")" def _new_object(cls): diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index d1954413dc92..c6875d3fca4b 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -55,9 +55,9 @@ Optional ObjectPathPairFromAccessPathPair( if (!src.has_value()) return std::nullopt; auto translate_path = [](ffi::reflection::AccessPath path) { ObjectPath result = ObjectPath::Root(); - for (const auto& step : path) { + for (const auto& step : path->ToSteps()) { switch (step->kind) { - case ffi::reflection::AccessKind::kObjectField: { + case ffi::reflection::AccessKind::kAttr: { result = result->Attr(step->key.cast()); break; } diff --git a/tests/python/ffi/test_access_path.py b/tests/python/ffi/test_access_path.py new file mode 100644 index 000000000000..06fbb64ff217 --- /dev/null +++ b/tests/python/ffi/test_access_path.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from tvm.ffi.access_path import AccessPath, AccessKind + + +def test_root_path(): + root = AccessPath.root() + assert isinstance(root, AccessPath) + steps = root.to_steps() + assert len(steps) == 0 + assert root == AccessPath.root() + + +def test_path_attr(): + path = AccessPath.root().attr("foo") + assert isinstance(path, AccessPath) + steps = path.to_steps() + assert len(steps) == 1 + assert steps[0].kind == AccessKind.ATTR + assert steps[0].key == "foo" + assert path.parent == AccessPath.root() + + +def test_path_array_item(): + path = AccessPath.root().array_item(2) + assert isinstance(path, AccessPath) + steps = path.to_steps() + assert len(steps) == 1 + assert steps[0].kind == AccessKind.ARRAY_ITEM + assert steps[0].key == 2 + assert path.parent == AccessPath.root() + + +def test_path_missing_array_element(): + path = AccessPath.root().array_item_missing(2) + assert isinstance(path, AccessPath) + steps = path.to_steps() + assert len(steps) == 1 + assert steps[0].kind == AccessKind.ARRAY_ITEM_MISSING + assert steps[0].key == 2 + assert path.parent == AccessPath.root() + + +def test_path_map_item(): + path = AccessPath.root().map_item("foo") + assert isinstance(path, AccessPath) + steps = path.to_steps() + assert len(steps) == 1 + assert steps[0].kind == AccessKind.MAP_ITEM + assert steps[0].key == "foo" + assert path.parent == AccessPath.root() + + +def test_path_missing_map_item(): + path = AccessPath.root().map_item_missing("foo") + assert isinstance(path, AccessPath) + steps = path.to_steps() + assert len(steps) == 1 + assert steps[0].kind == AccessKind.MAP_ITEM_MISSING + assert steps[0].key == "foo" + assert path.parent == AccessPath.root() + + +def test_path_is_prefix_of(): + # Root is prefix of root + assert AccessPath.root().is_prefix_of(AccessPath.root()) + + # Root is prefix of any path + assert AccessPath.root().is_prefix_of(AccessPath.root().attr("foo")) + + # Non-root is not prefix of root + assert not AccessPath.root().attr("foo").is_prefix_of(AccessPath.root()) + + # Path is prefix of itself + assert AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("foo")) + + # Different attrs are not prefixes of each other + assert not AccessPath.root().attr("bar").is_prefix_of(AccessPath.root().attr("foo")) + + # Shorter path is prefix of longer path with same start + assert AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("foo").array_item(2)) + + # Longer path is not prefix of shorter path + assert ( + not AccessPath.root().attr("foo").array_item(2).is_prefix_of(AccessPath.root().attr("foo")) + ) + + # Different paths are not prefixes + assert ( + not AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("bar").array_item(2)) + ) + + +def test_path_equal(): + # Root equals root + assert AccessPath.root() == AccessPath.root() + + # Root does not equal non-root paths + assert not (AccessPath.root() == AccessPath.root().attr("foo")) + + # Non-root does not equal root + assert not (AccessPath.root().attr("foo") == AccessPath.root()) + + # Path equals itself + assert AccessPath.root().attr("foo") == AccessPath.root().attr("foo") + + # Different attrs are not equal + assert not (AccessPath.root().attr("bar") == AccessPath.root().attr("foo")) + + # Shorter path does not equal longer path + assert not (AccessPath.root().attr("foo") == AccessPath.root().attr("foo").array_item(2)) + + # Longer path does not equal shorter path + assert not (AccessPath.root().attr("foo").array_item(2) == AccessPath.root().attr("foo")) + + # Different paths are not equal + assert not (AccessPath.root().attr("foo") == AccessPath.root().attr("bar").array_item(2))