diff --git a/docs/arch/runtime.rst b/docs/arch/runtime.rst index 8c2a9a0995f9..9e663b072810 100644 --- a/docs/arch/runtime.rst +++ b/docs/arch/runtime.rst @@ -228,15 +228,6 @@ Each ``Object`` subclass will override this to register its members. Here is an refl::ObjectDef().def_ro("value", &IntImmNode::value); } - bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(value, other->value); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(value); - } - static constexpr const char* _type_key = "ir.IntImm"; TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); }; diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 05b936ea9077..a49a9f170060 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -214,10 +214,6 @@ class Object { static constexpr int32_t _type_depth = 0; // the structural equality and hash kind of the type static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUnsupported; - // extra fields used by plug-ins for attribute visiting - // and structural information - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; // The following functions are provided by macro // TVM_FFI_DECLARE_BASE_OBJECT_INFO and TVM_DECLARE_FINAL_OBJECT_INFO /*! diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index 3d8b4b23ed7c..78ca008e1094 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -217,11 +217,9 @@ class TCustomFuncObj : public Object { bool SEqual(const TCustomFuncObj* other, ffi::TypedFunction cmp) const { if (!cmp(params, other->params, true, "params")) { - std::cout << "custom s_equal failed params" << std::endl; return false; } if (!cmp(body, other->body, false, "body")) { - std::cout << "custom s_equal failed body" << std::endl; return false; } return true; diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 54cbab258680..52e9e7209e89 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -94,10 +94,6 @@ class ConstIntBoundNode : public Object { .def_ro("max_value", &ConstIntBoundNode::max_value); } - bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const { - return equal(min_value, other->min_value) && equal(max_value, other->max_value); - } - /*! \brief Number to represent +inf */ static const constexpr int64_t kPosInf = std::numeric_limits::max(); /*! @@ -219,10 +215,6 @@ class ModularSetNode : public Object { .def_ro("base", &ModularSetNode::base); } - bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const { - return equal(coeff, other->coeff) && equal(base, other->base); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "arith.ModularSet"; TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object); diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 046019abd9ec..702edba1a462 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -57,7 +57,7 @@ enum SignType { kPositive, kNegative, kZero, kUnknown }; class IntSetNode : public Object { public: static constexpr const char* _type_key = "ir.IntSet"; - static constexpr bool _type_has_method_sequal_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object); }; diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index e2f384b696ac..6dfc2f0ecb88 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -71,20 +71,8 @@ class IntGroupBoundsNode : public Object { .def_ro("upper", &IntGroupBoundsNode::upper); } - bool SEqualReduce(const IntGroupBoundsNode* other, SEqualReducer eq) const { - return eq(coef, other->coef) && eq(lower, other->lower) && eq(equal, other->equal) && - eq(upper, other->upper); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(coef); - hash_reduce(lower); - hash_reduce(equal); - hash_reduce(upper); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const char* _type_key = "arith.IntGroupBounds"; TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object); }; @@ -163,19 +151,8 @@ class IntConstraintsNode : public Object { .def_ro("relations", &IntConstraintsNode::relations); } - bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const { - return equal(variables, other->variables) && equal(ranges, other->ranges) && - equal(relations, other->relations); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(variables); - hash_reduce(ranges); - hash_reduce(relations); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const char* _type_key = "arith.IntConstraints"; TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object); }; @@ -228,20 +205,8 @@ class IntConstraintsTransformNode : public Object { .def_ro("dst_to_src", &IntConstraintsTransformNode::dst_to_src); } - bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const { - return equal(src, other->src) && equal(dst, other->dst) && - equal(src_to_dst, other->src_to_dst) && equal(dst_to_src, other->dst_to_src); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(src); - hash_reduce(dst); - hash_reduce(src_to_dst); - hash_reduce(dst_to_src); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const char* _type_key = "arith.IntConstraintsTransform"; TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object); }; diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 3c666b430f13..25f8e14a7f7b 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -105,20 +105,8 @@ class IterMarkNode : public Object { .def_ro("extent", &IterMarkNode::extent); } - bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); - return equal(source, other->source) && equal(extent, other->extent); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce->MarkGraphNode(); - hash_reduce(source); - hash_reduce(extent); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const char* _type_key = "arith.IterMark"; TVM_DECLARE_FINAL_OBJECT_INFO(IterMarkNode, Object); }; @@ -165,18 +153,6 @@ class IterSplitExprNode : public IterMapExprNode { .def_ro("scale", &IterSplitExprNode::scale); } - bool SEqualReduce(const IterSplitExprNode* other, SEqualReducer equal) const { - return equal(source, other->source) && equal(lower_factor, other->lower_factor) && - equal(extent, other->extent) && equal(scale, other->scale); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(source); - hash_reduce(lower_factor); - hash_reduce(extent); - hash_reduce(scale); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "arith.IterSplitExpr"; TVM_DECLARE_FINAL_OBJECT_INFO(IterSplitExprNode, IterMapExprNode); @@ -232,15 +208,6 @@ class IterSumExprNode : public IterMapExprNode { .def_ro("base", &IterSumExprNode::base); } - bool SEqualReduce(const IterSumExprNode* other, SEqualReducer equal) const { - return equal(args, other->args) && equal(base, other->base); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(args); - hash_reduce(base); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "arith.IterSumExpr"; TVM_DECLARE_FINAL_OBJECT_INFO(IterSumExprNode, IterMapExprNode); diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 6a43274cae46..2553116634a2 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -84,8 +84,7 @@ class AttrFieldInfoNode : public Object { static constexpr const char* _type_key = "ir.AttrFieldInfo"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr bool _type_has_method_sequal_reduce = false; - static constexpr bool _type_has_method_shash_reduce = false; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object); }; @@ -123,8 +122,7 @@ class BaseAttrsNode : public Object { bool allow_unknown = false) = 0; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const char* _type_key = "ir.Attrs"; TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object); }; @@ -154,11 +152,6 @@ class DictAttrsNode : public BaseAttrsNode { rfl::ObjectDef().def_ro("__dict__", &DictAttrsNode::dict); } - bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const { - return equal(dict, other->dict); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); } void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final; // type info @@ -394,32 +387,6 @@ class AttrsNodeReflAdapter : public BaseAttrsNode { LOG(FATAL) << "`" << DerivedType::_type_key << "` uses new reflection mechanism for init"; } - bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const { - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(DerivedType::RuntimeTypeIndex()); - bool success = true; - ffi::reflection::ForEachFieldInfoWithEarlyStop( - type_info, [&](const TVMFFIFieldInfo* field_info) { - ffi::reflection::FieldGetter field_getter(field_info); - ffi::Any field_value = field_getter(self()); - ffi::Any other_field_value = field_getter(other); - if (!equal.AnyEqual(field_value, other_field_value)) { - success = false; - return true; - } - return false; - }); - return success; - } - - void SHashReduce(SHashReducer hash_reducer) const { - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(DerivedType::RuntimeTypeIndex()); - ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - ffi::reflection::FieldGetter field_getter(field_info); - ffi::Any field_value = field_getter(self()); - hash_reducer(field_value); - }); - } - private: DerivedType* self() const { return const_cast(static_cast(this)); diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index e1d7abbead15..9f4f5770aa60 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -74,11 +74,6 @@ class DiagnosticNode : public Object { .def_ro("message", &DiagnosticNode::message); } - bool SEqualReduce(const DiagnosticNode* other, SEqualReducer equal) const { - return equal(this->level, other->level) && equal(this->span, other->span) && - equal(this->message, other->message); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "Diagnostic"; TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticNode, Object); @@ -211,10 +206,6 @@ class DiagnosticContextNode : public Object { .def_ro("diagnostics", &DiagnosticContextNode::diagnostics); } - bool SEqualReduce(const DiagnosticContextNode* other, SEqualReducer equal) const { - return equal(module, other->module) && equal(diagnostics, other->diagnostics); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "DiagnosticContext"; TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticContextNode, Object); diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index c1fdeb6d1c48..0fb48a6efaab 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -57,20 +57,9 @@ class EnvFuncNode : public Object { .def_ro("func", &EnvFuncNode::func, refl::AttachFieldFlag::SEqHashIgnore()); } - bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const { - // name uniquely identifies the env function. - return name == other->name; - } - - void SHashReduce(SHashReducer hash_reduce) const { - // Name uniquely identifies the env function. - hash_reduce(name); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "ir.EnvFunc"; - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object); }; diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index cb62cbadf5bb..9b7645b56a46 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -66,8 +66,7 @@ class BaseExprNode : public Object { static constexpr const char* _type_key = "ir.BaseExpr"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 64; TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); }; @@ -468,16 +467,6 @@ class GlobalVarNode : public RelaxExprNode { refl::ObjectDef().def_ro("name_hint", &GlobalVarNode::name_hint); } - bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const { - // name matters for global var. - return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(name_hint); - hash_reduce.FreeVarHashImpl(this); - } - bool SEqual(const GlobalVarNode* other, ffi::TypedFunction equal) const { return equal(name_hint, other->name_hint, false, "name_hint"); @@ -519,15 +508,6 @@ class IntImmNode : public PrimExprNode { refl::ObjectDef().def_ro("value", &IntImmNode::value); } - bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(value, other->value); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(value); - } - static constexpr const char* _type_key = "ir.IntImm"; TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); }; @@ -565,15 +545,6 @@ class FloatImmNode : public PrimExprNode { refl::ObjectDef().def_ro("value", &FloatImmNode::value); } - bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(value, other->value); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(value); - } - static constexpr const char* _type_key = "ir.FloatImm"; TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); }; @@ -713,22 +684,13 @@ class RangeNode : public Object { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("min", &RangeNode::min) - .def_ro("extent", &RangeNode::extent); - } - - bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const { - return equal(min, other->min) && equal(extent, other->extent); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(min); - hash_reduce(extent); + .def_ro("extent", &RangeNode::extent) + .def_ro("span", &RangeNode::span, refl::AttachFieldFlag::SEqHashIgnore()); } static constexpr const char* _type_key = "ir.Range"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); }; diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h index 57eadf2b2992..e6ff10ad1bc4 100644 --- a/include/tvm/ir/global_info.h +++ b/include/tvm/ir/global_info.h @@ -45,8 +45,7 @@ class GlobalInfoNode : public Object { static constexpr const char* _type_key = "ir.GlobalInfo"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(GlobalInfoNode, Object); }; @@ -80,16 +79,6 @@ class VDeviceNode : public GlobalInfoNode { .def_ro("memory_scope", &VDeviceNode::memory_scope); } - TVM_DLL bool SEqualReduce(const VDeviceNode* other, SEqualReducer equal) const { - return equal(target, other->target) && equal(vdevice_id, other->vdevice_id) && - equal(memory_scope, other->memory_scope); - } - - TVM_DLL void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(target); - hash_reduce(vdevice_id); - hash_reduce(memory_scope); - } static constexpr const char* _type_key = "ir.VDevice"; TVM_DECLARE_FINAL_OBJECT_INFO(VDeviceNode, GlobalInfoNode); }; @@ -115,12 +104,6 @@ class DummyGlobalInfoNode : public GlobalInfoNode { } static constexpr const char* _type_key = "ir.DummyGlobalInfo"; - - TVM_DLL bool SEqualReduce(const DummyGlobalInfoNode* other, SEqualReducer equal) const { - return true; - } - - TVM_DLL void SHashReduce(SHashReducer hash_reduce) const {} TVM_DECLARE_FINAL_OBJECT_INFO(DummyGlobalInfoNode, GlobalInfoNode); }; diff --git a/include/tvm/ir/global_var_supply.h b/include/tvm/ir/global_var_supply.h index b8dfed75a7aa..8ed8e5ed4c13 100644 --- a/include/tvm/ir/global_var_supply.h +++ b/include/tvm/ir/global_var_supply.h @@ -85,8 +85,7 @@ class GlobalVarSupplyNode : public Object { NameSupply name_supply_; static constexpr const char* _type_key = "ir.GlobalVarSupply"; - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarSupplyNode, Object); private: diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 66c26b0629ba..6f7d6d2d130d 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -144,10 +144,6 @@ class IRModuleNode : public Object { .def("__s_hash__", &IRModuleNode::SHash); } - TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const; - - TVM_DLL void SHashReduce(SHashReducer hash_reduce) const; - TVM_DLL bool SEqual(const IRModuleNode* other, ffi::TypedFunction equal) const; TVM_DLL uint64_t SHash(uint64_t init_hash, @@ -247,8 +243,7 @@ class IRModuleNode : public Object { static constexpr const char* _type_key = "ir.IRModule"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object); private: diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h index ff018157a176..6eefaefea793 100644 --- a/include/tvm/ir/name_supply.h +++ b/include/tvm/ir/name_supply.h @@ -85,8 +85,6 @@ class NameSupplyNode : public Object { std::string prefix_; static constexpr const char* _type_key = "ir.NameSupply"; - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object); private: diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 5903bed8d92e..5f40ff4d3a7b 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -103,16 +103,6 @@ class OpNode : public RelaxExprNode { .def_ro("support_level", &OpNode::support_level, refl::AttachFieldFlag::SEqHashIgnore()); } - bool SEqualReduce(const OpNode* other, SEqualReducer equal) const { - // pointer equality is fine as there is only one op with the same name. - return this == other; - } - - void SHashReduce(SHashReducer hash_reduce) const { - // Name uniquely identifies an Op. - hash_reduce(name); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUniqueInstance; static constexpr const char* _type_key = "ir.Op"; TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelaxExprNode); diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h index d53c234690e2..c7fce1c5024c 100644 --- a/include/tvm/ir/source_map.h +++ b/include/tvm/ir/source_map.h @@ -53,12 +53,6 @@ class SourceNameNode : public Object { refl::ObjectDef().def_ro("name", &SourceNameNode::name); } - static constexpr bool _type_has_method_sequal_reduce = true; - - bool SEqualReduce(const SourceNameNode* other, SEqualReducer equal) const { - return equal(name, other->name); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "ir.SourceName"; TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object); @@ -111,14 +105,6 @@ class SpanNode : public Object { .def_ro("end_column", &SpanNode::end_column); } - static constexpr bool _type_has_method_sequal_reduce = true; - - bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const { - return equal(source_name, other->source_name) && equal(line, other->line) && - equal(column, other->column) && equal(end_line, other->end_line) && - equal(end_column, other->end_column); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "ir.Span"; TVM_DECLARE_BASE_OBJECT_INFO(SpanNode, Object); @@ -149,19 +135,6 @@ class SequentialSpanNode : public SpanNode { static constexpr const char* _type_key = "ir.SequentialSpan"; TVM_DECLARE_FINAL_OBJECT_INFO(SequentialSpanNode, SpanNode); - - bool SEqualReduce(const SequentialSpanNode* other, SEqualReducer equal) const { - if (spans.size() != other->spans.size()) { - return false; - } - - for (size_t i = 0, e = spans.size(); i != e; ++i) { - if (!StructuralEqual()(spans[i], other->spans[i])) { - return false; - } - } - return true; - } }; /*! @@ -231,10 +204,6 @@ class SourceMapObj : public Object { refl::ObjectDef().def_ro("source_map", &SourceMapObj::source_map); } - bool SEqualReduce(const SourceMapObj* other, SEqualReducer equal) const { - return equal(source_map, other->source_map); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "ir.SourceMap"; TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapObj, Object); diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index e92158b4e1dc..4f9004fba560 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -135,7 +135,7 @@ class PassContextNode : public Object { } static constexpr const char* _type_key = "transform.PassContext"; - static constexpr bool _type_has_method_sequal_reduce = false; + TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); }; @@ -324,7 +324,7 @@ class PassInfoNode : public Object { } static constexpr const char* _type_key = "transform.PassInfo"; - static constexpr bool _type_has_method_sequal_reduce = false; + TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object); }; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index a2ab74a3aeb1..9d75e845f88f 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -89,8 +89,7 @@ class TypeNode : public Object { static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "ir.Type"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 14; TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); }; @@ -124,12 +123,6 @@ class PrimTypeNode : public TypeNode { refl::ObjectDef().def_ro("dtype", &PrimTypeNode::dtype); } - bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); } - static constexpr const char* _type_key = "ir.PrimType"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); }; @@ -178,19 +171,6 @@ class PointerTypeNode : public TypeNode { .def_ro("storage_scope", &PointerTypeNode::storage_scope); } - bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const { - // Make "global" equal to "" - String lhs_scope = storage_scope.empty() ? "global" : storage_scope; - String rhs_scope = other->storage_scope.empty() ? "global" : other->storage_scope; - return equal(element_type, other->element_type) && equal(lhs_scope, rhs_scope); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(element_type); - // Make "global" equal to "" - hash_reduce(storage_scope.empty() ? "global" : storage_scope); - } - static constexpr const char* _type_key = "ir.PointerType"; TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode); }; @@ -229,12 +209,6 @@ class TupleTypeNode : public TypeNode { .def_ro("span", &TupleTypeNode::span); } - bool SEqualReduce(const TupleTypeNode* other, SEqualReducer equal) const { - return equal(fields, other->fields); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } - static constexpr const char* _type_key = "ir.TupleType"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); }; @@ -298,16 +272,6 @@ class FuncTypeNode : public TypeNode { .def_ro("span", &FuncTypeNode::span); } - bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const { - // type params first as they defines type vars. - return equal(arg_types, other->arg_types) && equal(ret_type, other->ret_type); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(arg_types); - hash_reduce(ret_type); - } - static constexpr const char* _type_key = "ir.FuncType"; TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); }; @@ -341,12 +305,6 @@ class TensorMapTypeNode : public TypeNode { refl::ObjectDef().def_ro("span", &TensorMapTypeNode::span); } - bool SEqualReduce(const TensorMapTypeNode* other, SEqualReducer equal) const { - return equal(span, other->span); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(span); } - static constexpr const char* _type_key = "ir.TensorMapType"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorMapTypeNode, TypeNode); }; diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 9b3186bac117..fef6c44fce5d 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -51,14 +51,6 @@ using runtime::ObjectRef; */ class ReflectionVTable { public: - /*! - * \brief Equality comparison function. - */ - typedef bool (*FSEqualReduce)(const Object* self, const Object* other, SEqualReducer equal); - /*! - * \brief Structural hash reduction function. - */ - typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce); /*! * \brief creator function. * \param repr_bytes Repr bytes to create the object. @@ -80,20 +72,6 @@ class ReflectionVTable { * \return Whether repr bytes exists */ inline bool GetReprBytes(const Object* self, std::string* repr_bytes) const; - /*! - * \brief Dispatch the SEqualReduce function. - * \param self The pointer to the object. - * \param other The pointer to another object to be compared. - * \param equal The equality comparator. - * \return the result. - */ - bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const; - /*! - * \brief Dispatch the SHashReduce function. - * \param self The pointer to the object. - * \param hash_reduce The hash reducer. - */ - void SHashReduce(const Object* self, SHashReducer hash_reduce) const; /*! * \brief Create an initial object using default constructor * by type_key and global key. @@ -138,14 +116,10 @@ class ReflectionVTable { TVM_DLL static ReflectionVTable* Global(); class Registry; - template + template inline Registry Register(); private: - /*! \brief Structural equal function. */ - std::vector fsequal_reduce_; - /*! \brief Structural hash function. */ - std::vector fshash_reduce_; /*! \brief Creation function. */ std::vector fcreate_; /*! \brief ReprBytes function. */ @@ -189,125 +163,30 @@ class ReflectionVTable::Registry { /*! * \brief Directly register reflection VTable. * \param TypeName The name of the type. - * \param TraitName A trait class that implements functions like SEqualReduce. - * - * \code - * - * // Example SEQualReduce traits for runtime StringObj. - * - * struct StringObjTrait { - * - * - * static void SHashReduce(const StringObj* key, SHashReducer hash_reduce) { - * hash_reduce->SHashReduceHashedValue(String::StableHashBytes(key->data, key->size)); - * } - * - * static bool SEqualReduce(const StringObj* lhs, - * const StringObj* rhs, - * SEqualReducer equal) { - * if (lhs == rhs) return true; - * if (lhs->size != rhs->size) return false; - * if (lhs->data != rhs->data) return true; - * return std::memcmp(lhs->data, rhs->data, lhs->size) != 0; - * } - * }; - * - * TVM_REGISTER_REFLECTION_VTABLE(StringObj, StringObjTrait); - * - * \endcode * * \note This macro can be called in different place as TVM_REGISTER_OBJECT_TYPE. * And can be used to register the related reflection functions for runtime objects. */ -#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \ +#define TVM_REGISTER_REFLECTION_VTABLE(TypeName) \ TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::ReflectionVTable::Global()->Register() + ::tvm::ReflectionVTable::Global()->Register() /*! * \brief Register a node type to object registry and reflection registry. * \param TypeName The name of the type. * \note This macro will call TVM_REGISTER_OBJECT_TYPE for the type as well. */ -#define TVM_REGISTER_NODE_TYPE(TypeName) \ - TVM_REGISTER_OBJECT_TYPE(TypeName); \ - TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait) \ - .set_creator([](const std::string&) -> ObjectPtr { \ - return ::tvm::ffi::make_object(); \ - }) - -// Implementation details -namespace detail { - -template -struct ImplSEqualReduce { - static constexpr const std::nullptr_t SEqualReduce = nullptr; -}; - -template -struct ImplSEqualReduce { - static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) { - return self->SEqualReduce(other, equal); - } -}; - -template -struct ImplSHashReduce { - static constexpr const std::nullptr_t SHashReduce = nullptr; -}; - -template -struct ImplSHashReduce { - static void SHashReduce(const T* self, SHashReducer hash_reduce) { - self->SHashReduce(hash_reduce); - } -}; +#define TVM_REGISTER_NODE_TYPE(TypeName) \ + TVM_REGISTER_REFLECTION_VTABLE(TypeName).set_creator( \ + [](const std::string&) -> ObjectPtr { return ::tvm::ffi::make_object(); }) template -struct ReflectionTrait : public ImplSEqualReduce, public ImplSHashReduce {}; - -template ::value> -struct SelectSEqualReduce { - static constexpr const std::nullptr_t SEqualReduce = nullptr; -}; - -template -struct SelectSEqualReduce { - static bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) { - return TraitName::SEqualReduce(static_cast(self), static_cast(other), - equal); - } -}; - -template ::value> -struct SelectSHashReduce { - static constexpr const std::nullptr_t SHashReduce = nullptr; -}; - -template -struct SelectSHashReduce { - static void SHashReduce(const Object* self, SHashReducer hash_reduce) { - return TraitName::SHashReduce(static_cast(self), hash_reduce); - } -}; - -} // namespace detail - -template inline ReflectionVTable::Registry ReflectionVTable::Register() { uint32_t tindex = T::RuntimeTypeIndex(); if (tindex >= fcreate_.size()) { fcreate_.resize(tindex + 1, nullptr); frepr_bytes_.resize(tindex + 1, nullptr); - fsequal_reduce_.resize(tindex + 1, nullptr); - fshash_reduce_.resize(tindex + 1, nullptr); } - // functor that implements the redirection. - fsequal_reduce_[tindex] = ::tvm::detail::SelectSEqualReduce::SEqualReduce; - - fshash_reduce_[tindex] = ::tvm::detail::SelectSHashReduce::SHashReduce; - return Registry(this, tindex); } @@ -323,11 +202,5 @@ inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr } } -/*! - * \brief Given an object and an address of its attribute, return the key of the attribute. - * \return nullptr if no attribute with the given address exists. - */ -Optional GetAttrKeyByAddress(const Object* object, const void* attr_address); - } // namespace tvm #endif // TVM_NODE_REFLECTION_H_ diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 249c2dabb64e..0e7dc246a3e7 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -126,352 +126,8 @@ class StructuralEqual : public BaseValueEqual { * \param map_free_params Whether or not to map free variables. * \return The comparison result. */ - TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, + TVM_DLL bool operator()(const ffi::Any& lhs, const ffi::Any& rhs, const bool map_free_params = false) const; - - /*! - * \brief Compare any value via strutural equal. - * \param lhs The left operand. - * \param rhs The right operand. - * \param map_free_params Whether or not to map free variables. - * \return The comparison result. - */ - TVM_FFI_INLINE bool operator()(const ffi::Any& lhs, const ffi::Any& rhs, - bool map_free_params = false) const { - if (lhs.type_index() != rhs.type_index()) return false; - if (lhs.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - return operator()(lhs.cast(), rhs.cast(), map_free_params); - } - // POD value can always use v_int64 to get the hash value - return (ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(lhs)->v_uint64 == - ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(rhs)->v_uint64); - } }; - -/*! - * \brief A Reducer class to reduce the structural equality result of two objects. - * - * The reducer will call the SEqualReduce function of each objects recursively. - * Importantly, the reducer may not directly use recursive calls to resolve the - * equality checking. Instead, it can store the necessary equality conditions - * and check later via an internally managed stack. - */ -class SEqualReducer { - private: - struct PathTracingData; - - public: - /*! \brief Internal handler that defines custom behaviors.. */ - class Handler { - public: - /*! - * \brief Reduce condition to equality of lhs and rhs. - * - * \param lhs The left operand. - * \param rhs The right operand. - * \param map_free_vars Whether do we allow remap variables if possible. - * \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability. - * - * \return false if there is an immediate failure, true otherwise. - * \note This function may save the equality condition of (lhs == rhs) in an internal - * stack and try to resolve later. - */ - virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, - const Optional& current_paths) = 0; - - /*! - * \brief Mark the comparison as failed, but don't fail immediately. - * - * This is useful for producing better error messages when comparing containers. - * For example, if two array sizes mismatch, it's better to mark the comparison as failed - * but compare array elements anyway, so that we could find the true first mismatch. - */ - virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0; - - /*! - * \brief Check if fail defferal is enabled. - * - * \return false if the fail deferral is not enabled, true otherwise. - */ - virtual bool IsFailDeferralEnabled() = 0; - - /*! - * \brief Lookup the graph node equal map for vars that are already mapped. - * - * This is an auxiliary method to check the Map equality. - * \param lhs an lhs value. - * - * \return The corresponding rhs value if any, nullptr if not available. - */ - virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0; - /*! - * \brief Mark current comparison as graph node equal comparison. - */ - virtual void MarkGraphNode() = 0; - - /*! - * \brief Map lhs to rhs. - * \param lhs The left operand. - * \return The corresponding rhs value if any, nullptr if not available. - */ - TVM_FFI_INLINE ffi::Any MapLhsToRhs(const ffi::Any& lhs) { - if (lhs.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - return MapLhsToRhs(lhs.cast()); - } else { - return lhs; - } - } - - protected: - using PathTracingData = SEqualReducer::PathTracingData; - }; - - /*! \brief default constructor */ - SEqualReducer() = default; - /*! - * \brief Constructor with a specific handler. - * \param handler The equal handler for objects. - * \param tracing_data Optional pointer to the path tracing data. - * \param map_free_vars Whether or not to map free variables. - */ - explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars) - : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {} - - /*! - * \brief Reduce condition to comparison of two attribute values. - * - * \param lhs The left operand. - * - * \param rhs The right operand. - * - * \param paths The paths to the LHS and RHS operands. If - * unspecified, will attempt to identify the attribute's address - * within the most recent ObjectRef. In general, the paths only - * require explicit handling for computed parameters - * (e.g. `array.size()`) - * - * \return the immediate check result. - */ - bool operator()(const double& lhs, const double& rhs, - Optional paths = std::nullopt) const; - bool operator()(const int64_t& lhs, const int64_t& rhs, - Optional paths = std::nullopt) const; - bool operator()(const uint64_t& lhs, const uint64_t& rhs, - Optional paths = std::nullopt) const; - bool operator()(const int& lhs, const int& rhs, - Optional paths = std::nullopt) const; - bool operator()(const bool& lhs, const bool& rhs, - Optional paths = std::nullopt) const; - bool operator()(const std::string& lhs, const std::string& rhs, - Optional paths = std::nullopt) const; - bool operator()(const DataType& lhs, const DataType& rhs, - Optional paths = std::nullopt) const; - bool operator()(const Optional& lhs, const Optional& rhs, - Optional paths = std::nullopt) const; - bool operator()(const Optional& lhs, const Optional& rhs, - Optional paths = std::nullopt) const; - template ::value>::type> - bool operator()(const ENum& lhs, const ENum& rhs, - Optional paths = std::nullopt) const { - using Underlying = typename std::underlying_type::type; - static_assert(std::is_same::value, - "Enum must have `int` as the underlying type"); - return EnumAttrsEqual(static_cast(lhs), static_cast(rhs), &lhs, &rhs, paths); - } - - template , ObjectPath>>> - bool operator()(const T& lhs, const T& rhs, const Callable& callable) { - if (IsPathTracingEnabled()) { - ObjectPathPair current_paths = GetCurrentObjectPaths(); - ObjectPathPair new_paths = {callable(current_paths->lhs_path), - callable(current_paths->rhs_path)}; - return (*this)(lhs, rhs, new_paths); - } else { - return (*this)(lhs, rhs); - } - } - - /*! - * \brief Reduce condition to comparison of two objects. - * \param lhs The left operand. - * \param rhs The right operand. - * \return the immediate check result. - */ - bool operator()(const ffi::ObjectRef& lhs, const ffi::ObjectRef& rhs) const; - - /*! - * \brief Reduce condition to comparison of two objects. - * - * Like `operator()`, but with an additional `paths` parameter that specifies explicit object - * paths for `lhs` and `rhs`. This is useful for implementing SEqualReduce() methods for container - * objects like Array and Map, or other custom objects that store nested objects that are not - * simply attributes. - * - * Can only be called when `IsPathTracingEnabled()` is `true`. - * - * \param lhs The left operand. - * \param rhs The right operand. - * \param paths Object paths for `lhs` and `rhs`. - * \return the immediate check result. - */ - bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const { - ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function"; - return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths); - } - - /* - * \brief Compare two Any values. - * \param lhs The left operand. - * \param rhs The right operand. - * \param paths Object paths for `lhs` and `rhs`. - * \return the immediate check result. - */ - bool AnyEqual(const ffi::Any& lhs, const ffi::Any& rhs, - Optional paths = std::nullopt) const; - - /*! - * \brief Reduce condition to comparison of two definitions, - * where free vars can be mapped. - * - * Call this function to compare definition points such as function params - * and var in a let-binding. - * - * \param lhs The left operand. - * \param rhs The right operand. - * \return the immediate check result. - */ - bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs); - - /*! - * \brief Reduce condition to comparison of two arrays. - * \param lhs The left operand. - * \param rhs The right operand. - * \return the immediate check result. - */ - template - bool operator()(const Array& lhs, const Array& rhs) const { - if (tracing_data_ == nullptr) { - // quick specialization for Array to reduce amount of recursion - // depth as array comparison is pretty common. - if (lhs.size() != rhs.size()) return false; - for (size_t i = 0; i < lhs.size(); ++i) { - if constexpr (std::is_same_v) { - if (!(AnyEqual(lhs[i], rhs[i]))) return false; - } else { - if (!(operator()(lhs[i], rhs[i]))) return false; - } - } - return true; - } - - // If tracing is enabled, fall back to the regular path - const ObjectRef& lhs_obj = lhs; - const ObjectRef& rhs_obj = rhs; - return (*this)(lhs_obj, rhs_obj); - } - /*! - * \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var). - * \param lhs The left operand. - * \param rhs The right operand. - * \return the result. - */ - bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const { - // var need to be remapped, so it belongs to graph node. - handler_->MarkGraphNode(); - // We only map free vars if they corresponds to the same address - // or map free_var option is set to be true. - return lhs == rhs || map_free_vars_; - } - - /*! \return Get the internal handler. */ - Handler* operator->() const { return handler_; } - - /*! \brief Check if this reducer is tracing paths to the first mismatch. */ - bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; } - - /*! - * \brief Get the paths of the currently compared objects. - * - * Can only be called when `IsPathTracingEnabled()` is true. - */ - const ObjectPathPair& GetCurrentObjectPaths() const; - - /*! - * \brief Specify the object paths of a detected mismatch. - * - * Can only be called when `IsPathTracingEnabled()` is true. - */ - void RecordMismatchPaths(const ObjectPathPair& paths) const; - - private: - bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address, - Optional paths = std::nullopt) const; - - bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, - const ObjectPathPair* paths) const; - - static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address, - const void* rhs_address, - const PathTracingData* tracing_data); - - template - static bool CompareAttributeValues(const T& lhs, const T& rhs, - const PathTracingData* tracing_data, - Optional paths = std::nullopt); - - /*! \brief Internal class pointer. */ - Handler* handler_ = nullptr; - /*! \brief Pointer to the current path tracing context, or nullptr if path tracing is disabled. */ - const PathTracingData* tracing_data_ = nullptr; - /*! \brief Whether or not to map free vars. */ - bool map_free_vars_ = false; -}; - -/*! \brief The default handler for equality testing. - * - * Users can derive from this class and override the DispatchSEqualReduce method, - * to customize equality testing. - */ -class SEqualHandlerDefault : public SEqualReducer::Handler { - public: - SEqualHandlerDefault(bool assert_mode, Optional* first_mismatch, - bool defer_fails); - virtual ~SEqualHandlerDefault(); - - bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, - const Optional& current_paths) override; - void DeferFail(const ObjectPathPair& mismatch_paths) override; - bool IsFailDeferralEnabled() override; - ObjectRef MapLhsToRhs(const ObjectRef& lhs) override; - void MarkGraphNode() override; - - /*! - * \brief The entry point for equality testing - * \param lhs The left operand. - * \param rhs The right operand. - * \param map_free_vars Whether or not to remap variables if possible. - * \return The equality result. - */ - virtual bool Equal(const ffi::Any& lhs, const ffi::Any& rhs, bool map_free_vars); - - protected: - /*! - * \brief The dispatcher for equality testing of intermediate objects - * \param lhs The left operand. - * \param rhs The right operand. - * \param map_free_vars Whether or not to remap variables if possible. - * \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability. - * \return The equality result. - */ - virtual bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, - const Optional& current_paths); - - private: - class Impl; - Impl* impl; -}; - } // namespace tvm #endif // TVM_NODE_STRUCTURAL_EQUAL_H_ diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index c909b9dfc7e3..0aca92d0e28a 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -122,189 +122,7 @@ class StructuralHash : public BaseValueHash { * \param key The left operand. * \return The hash value. */ - TVM_DLL uint64_t operator()(const ffi::ObjectRef& key) const; - - /** - * \brief Compute structural hashing value for an Any object. - * \param key The Any object. - * \return The hash value. - */ - TVM_FFI_INLINE uint64_t operator()(const ffi::Any& key) const { - if (key.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - return operator()(key.cast()); - } - return HashPODValueInAny(key); - } -}; - -/*! - * \brief A Reducer class to reduce the structural hash value. - * - * The reducer will call the SEqualHash function of each objects recursively. - * - * A SEqualHash function will make a sequence of calls to the reducer to - * indicate a sequence of child hash values that the reducer need to combine - * inorder to obtain the hash value of the hash value of the parent object. - * - * Importantly, the reducer may not directly use recursive calls - * to compute the hash values of child objects directly. - * - * Instead, it can store the necessary hash computing task into a stack - * and reduce the result later. - */ -class SHashReducer { - public: - /*! \brief Internal handler that defines custom behaviors. */ - class Handler { - public: - /*! - * \brief Append hashed_value to the current sequence of hashes. - * - * \param hashed_value The hashed value - */ - virtual void SHashReduceHashedValue(uint64_t hashed_value) = 0; - /*! - * \brief Append hash value of key to the current sequence of hashes. - * - * \param key The object to compute hash from. - * \param map_free_vars Whether to map free variables by their occurrence number. - */ - virtual void SHashReduce(const ObjectRef& key, bool map_free_vars) = 0; - /*! - * \brief Append a hash value of free variable to the current sequence of hashes. - * - * \param var The var of interest. - * \param map_free_vars Whether to map free variables by their occurrence number. - * - * \note If map_free_vars is set to be true, - * internally the handler can maintain a counter to encode free variables - * by their order of occurrence. This helps to resolve variable - * mapping of function parameters and let binding variables. - * - * If map_free_vars is set to be false, the address of the variable will be used. - */ - virtual void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) = 0; - /*! - * \brief Lookup a hash value for key - * - * \param key The hash key. - * \param hashed_value the result hash value - * - * \return Whether there is already a pre-computed hash value. - */ - virtual bool LookupHashedValue(const ObjectRef& key, uint64_t* hashed_value) = 0; - /*! - * \brief Mark current comparison as graph node in hashing. - * Graph node hash will depends on the graph structure. - */ - virtual void MarkGraphNode() = 0; - }; - - /*! \brief default constructor */ - SHashReducer() = default; - /*! - * \brief Constructor with a specific handler. - * \param handler The equal handler for objects. - * \param map_free_vars Whether to map free variables. - */ - explicit SHashReducer(Handler* handler, bool map_free_vars) - : handler_(handler), map_free_vars_(map_free_vars) {} - /*! - * \brief Push hash of key to the current sequence of hash values. - * \param key The key to be hashed. - */ - template ::value>::type> - void operator()(const T& key) const { - // handle normal values. - handler_->SHashReduceHashedValue(BaseValueHash()(key)); - } - /** - * \brief Push hash of Any object to the current sequence of hash values. - * \param key The Any object. - */ - void operator()(const ffi::Any& key) const { - if (key.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - return operator()(key.cast()); - } - // POD value can always use v_int64 to get the hash value - handler_->SHashReduceHashedValue(BaseValueHash().HashPODValueInAny(key)); - } - /*! - * \brief Push hash of key to the current sequence of hash values. - * \param key The key to be hashed. - */ - void operator()(const ObjectRef& key) const { return handler_->SHashReduce(key, map_free_vars_); } - /*! - * \brief Push hash of key to the current sequence of hash values. - * \param key The key to be hashed. - * \note This function indicate key could contain var defintions. - */ - void DefHash(const ObjectRef& key) const { return handler_->SHashReduce(key, true); } - /*! - * \brief Implementation for hash for a free var. - * \param var The variable. - */ - void FreeVarHashImpl(const runtime::Object* var) const { - handler_->SHashReduceFreeVar(var, map_free_vars_); - } - - /*! \return Get the internal handler. */ - Handler* operator->() const { return handler_; } - - private: - /*! \brief Internal class pointer. */ - Handler* handler_; - /*! - * \brief Whether or not to map free variables by their occurrence - * If the flag is false, then free variables will be mapped - * by their in-memory address. - */ - bool map_free_vars_; -}; - -/*! \brief The default handler for hash key computation - * - * Users can derive from this class and override the DispatchSHash method, - * to customize hashing. - */ -class SHashHandlerDefault : public SHashReducer::Handler { - public: - SHashHandlerDefault(); - virtual ~SHashHandlerDefault(); - - void SHashReduceHashedValue(uint64_t hashed_value) override; - void SHashReduce(const ObjectRef& key, bool map_free_vars) override; - void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) override; - bool LookupHashedValue(const ObjectRef& key, uint64_t* hashed_value) override; - void MarkGraphNode() override; - - /*! - * \brief The entry point for hashing - * \param object The object to be hashed. - * \param map_free_vars Whether or not to remap variables if possible. - * \return The hash result. - */ - virtual uint64_t Hash(const ffi::Any& object, bool map_free_vars); - - protected: - /*! - * \brief The dispatcher for hashing of intermediate objects - * \param object An intermediate object to be hashed. - * \param map_free_vars Whether or not to remap variables if possible. - */ - virtual void DispatchSHash(const ObjectRef& object, bool map_free_vars); - - private: - class Impl; - Impl* impl; -}; - -class SEqualReducer; -struct NDArrayContainerTrait { - static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce); - static bool SEqualReduce(const runtime::NDArray::Container* lhs, - const runtime::NDArray::Container* rhs, SEqualReducer equal); + TVM_DLL uint64_t operator()(const ffi::Any& key) const; }; } // namespace tvm diff --git a/include/tvm/relax/distributed/global_info.h b/include/tvm/relax/distributed/global_info.h index 89cdec14b0d8..5e0afc0dcaa7 100644 --- a/include/tvm/relax/distributed/global_info.h +++ b/include/tvm/relax/distributed/global_info.h @@ -54,26 +54,6 @@ class DeviceMeshNode : public GlobalInfoNode { } static constexpr const char* _type_key = "relax.distributed.DeviceMesh"; - - bool SEqualReduce(const DeviceMeshNode* other, SEqualReducer equal) const { - if (shape.size() != other->shape.size()) { - return false; - } - for (int i = 0; i < static_cast(shape.size()); i++) { - if (!equal(shape[i], other->shape[i])) { - return false; - } - } - return equal(device_ids, other->device_ids); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(device_ids); - for (int i = 0; i < static_cast(shape.size()); i++) { - hash_reduce(shape[i]); - } - } - TVM_DECLARE_FINAL_OBJECT_INFO(DeviceMeshNode, GlobalInfoNode); }; diff --git a/include/tvm/relax/distributed/struct_info.h b/include/tvm/relax/distributed/struct_info.h index 7f843a9f2c75..cd4c2e7daef2 100644 --- a/include/tvm/relax/distributed/struct_info.h +++ b/include/tvm/relax/distributed/struct_info.h @@ -51,19 +51,8 @@ class PlacementSpecNode : public Object { .def_ro("kind", &PlacementSpecNode::kind); } - bool SEqualReduce(const PlacementSpecNode* other, SEqualReducer equal) const { - return equal(axis, other->axis) && equal(kind, other->kind); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(axis); - hash_reduce(static_cast(kind)); - } - static constexpr const char* _type_key = "relax.distributed.PlacementSpec"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_BASE_OBJECT_INFO(PlacementSpecNode, Object); }; @@ -90,12 +79,6 @@ class ShardingNode : public PlacementSpecNode { refl::ObjectDef().def_ro("sharding_dim", &ShardingNode::sharding_dim); } - bool SEqualReduce(const ShardingNode* other, SEqualReducer equal) const { - return equal(sharding_dim, other->sharding_dim); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(sharding_dim); } - static constexpr const char* _type_key = "relax.distributed.Sharding"; TVM_DECLARE_FINAL_OBJECT_INFO(ShardingNode, PlacementSpecNode); }; @@ -112,14 +95,6 @@ class PlacementNode : public Object { refl::ObjectDef().def_ro("dim_specs", &PlacementNode::dim_specs); } - bool SEqualReduce(const PlacementNode* other, SEqualReducer equal) const { - return equal(dim_specs, other->dim_specs); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dim_specs); } - - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; static constexpr const char* _type_key = "relax.distributed.Placement"; TVM_DECLARE_FINAL_OBJECT_INFO(PlacementNode, Object); @@ -163,17 +138,6 @@ class DTensorStructInfoNode : public StructInfoNode { .def_ro("tensor_sinfo", &DTensorStructInfoNode::tensor_sinfo); } - bool SEqualReduce(const DTensorStructInfoNode* other, SEqualReducer equal) const { - return equal(tensor_sinfo, other->tensor_sinfo) && equal(device_mesh, other->device_mesh) && - equal(placement, other->placement); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(tensor_sinfo); - hash_reduce(device_mesh); - hash_reduce(placement); - } - static constexpr const char* _type_key = "relax.DTensorStructInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(DTensorStructInfoNode, StructInfoNode); }; diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 06aba8618b66..22cda9e06635 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -61,16 +61,9 @@ class IdNode : public Object { refl::AttachFieldFlag::SEqHashIgnore()); } - bool SEqualReduce(const IdNode* other, SEqualReducer equal) const { - return equal.FreeVarEqualImpl(this, other); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.FreeVarHashImpl(this); } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; static constexpr const char* _type_key = "relax.Id"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object); }; @@ -130,8 +123,7 @@ class StructInfoNode : public Object { static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "ir.StructInfo"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 7; TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object); }; @@ -182,20 +174,6 @@ class CallNode : public ExprNode { .def_ro("sinfo_args", &CallNode::sinfo_args); } - bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { - // skip sinfo_args check for primitive ops. - return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && - equal(sinfo_args, other->sinfo_args) && equal(struct_info_, other->struct_info_); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(op); - hash_reduce(args); - hash_reduce(attrs); - hash_reduce(sinfo_args); - hash_reduce(struct_info_); - } - static constexpr const char* _type_key = "relax.expr.Call"; TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); }; @@ -239,13 +217,6 @@ class TupleNode : public ExprNode { refl::ObjectDef().def_ro("fields", &TupleNode::fields); } - bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { - // struct info can be deterministically derived from fields. - return equal(fields, other->fields); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } - static constexpr const char* _type_key = "relax.expr.Tuple"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode); }; @@ -304,16 +275,6 @@ class TupleGetItemNode : public ExprNode { .def_ro("index", &TupleGetItemNode::index); } - bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const { - // struct info can be deterministically tuple and index. - return equal(tuple, other->tuple) && equal(index, other->index); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(tuple); - hash_reduce(index); - } - static constexpr const char* _type_key = "relax.expr.TupleGetItem"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); }; @@ -373,16 +334,7 @@ class ShapeExprNode : public LeafExprNode { refl::ObjectDef().def_ro("values", &ShapeExprNode::values); } - bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const { - // struct info can be deterministically derived from values. - return equal(values, other->values); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(values); } - static constexpr const char* _type_key = "relax.expr.ShapeExpr"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(ShapeExprNode, LeafExprNode); }; @@ -412,16 +364,6 @@ class VarNode : public LeafExprNode { .def("__s_hash__", &VarNode::SHash); } - bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); - return equal(vid, other->vid) && equal(struct_info_, other->struct_info_); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(vid); - hash_reduce(struct_info_); - } - bool SEqual(const VarNode* other, ffi::TypedFunction equal) const { return equal(vid, other->vid, false, "vid") && @@ -438,8 +380,6 @@ class VarNode : public LeafExprNode { static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; static constexpr const char* _type_key = "relax.expr.Var"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const uint32_t _type_child_slots = 1; TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode); }; @@ -466,20 +406,8 @@ class DataflowVarNode : public VarNode { refl::ObjectDef(); } - bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); - return equal(vid, other->vid) && equal(struct_info_, other->struct_info_); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(vid); - hash_reduce(struct_info_); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; static constexpr const char* _type_key = "relax.expr.DataflowVar"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarNode, VarNode); }; @@ -517,16 +445,6 @@ class ConstantNode : public LeafExprNode { refl::ObjectDef().def_ro("data", &ConstantNode::data); } - bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const { - // struct info can be deterministically derived from data. - return equal(data, other->data) && equal(struct_info_, other->struct_info_); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(data); - hash_reduce(struct_info_); - } - static constexpr const char* _type_key = "relax.expr.Constant"; TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, LeafExprNode); }; @@ -563,13 +481,6 @@ class PrimValueNode : public LeafExprNode { refl::ObjectDef().def_ro("value", &PrimValueNode::value); } - bool SEqualReduce(const PrimValueNode* other, SEqualReducer equal) const { - // struct info can be deterministically derived from data. - return equal(value, other->value); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } - static constexpr const char* _type_key = "relax.expr.PrimValue"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimValueNode, LeafExprNode); }; @@ -612,13 +523,6 @@ class StringImmNode : public LeafExprNode { refl::ObjectDef().def_ro("value", &StringImmNode::value); } - bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const { - // struct info can be deterministically derived from data. - return equal(value, other->value); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } - static constexpr const char* _type_key = "relax.expr.StringImm"; TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, LeafExprNode); }; @@ -653,13 +557,6 @@ class DataTypeImmNode : public LeafExprNode { refl::ObjectDef().def_ro("value", &DataTypeImmNode::value); } - bool SEqualReduce(const DataTypeImmNode* other, SEqualReducer equal) const { - // struct info can be deterministically derived from data. - return equal(value, other->value); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } - static constexpr const char* _type_key = "relax.expr.DataTypeImm"; TVM_DECLARE_FINAL_OBJECT_INFO(DataTypeImmNode, LeafExprNode); }; @@ -697,8 +594,7 @@ class BindingNode : public Object { static constexpr const char* _type_key = "relax.expr.Binding"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(BindingNode, Object); }; @@ -735,12 +631,7 @@ class MatchCastNode : public BindingNode { .def_ro("struct_info", &MatchCastNode::struct_info, refl::AttachFieldFlag::SEqHashDef()); } - bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const; - void SHashReduce(SHashReducer hash_reduce) const; - static constexpr const char* _type_key = "relax.expr.MatchCast"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(MatchCastNode, BindingNode); }; @@ -770,17 +661,13 @@ class VarBindingNode : public BindingNode { .def("__s_hash__", &VarBindingNode::SHash); } - bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const; - void SHashReduce(SHashReducer hash_reduce) const; - bool SEqual(const VarBindingNode* other, ffi::TypedFunction equal) const; uint64_t SHash(uint64_t init_hash, ffi::TypedFunction hash) const; static constexpr const char* _type_key = "relax.expr.VarBinding"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(VarBindingNode, BindingNode); }; @@ -804,16 +691,9 @@ class BindingBlockNode : public Object { refl::DefaultValue(Span())); } - bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const { - return equal(bindings, other->bindings); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "relax.expr.BindingBlock"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(BindingBlockNode, Object); }; @@ -832,15 +712,8 @@ class DataflowBlockNode : public BindingBlockNode { refl::ObjectDef(); } - bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const { - return equal(bindings, other->bindings); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } - static constexpr const char* _type_key = "relax.expr.DataflowBlock"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockNode, BindingBlockNode); }; @@ -867,20 +740,8 @@ class SeqExprNode : public ExprNode { .def_ro("body", &SeqExprNode::body); } - bool SEqualReduce(const SeqExprNode* other, SEqualReducer equal) const { - return equal(blocks, other->blocks) && equal(body, other->body) && - equal(struct_info_, other->struct_info_); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(blocks); - hash_reduce(body); - hash_reduce(struct_info_); - } - static constexpr const char* _type_key = "relax.expr.SeqExpr"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(SeqExprNode, ExprNode); }; @@ -932,20 +793,6 @@ class IfNode : public ExprNode { .def_ro("false_branch", &IfNode::false_branch); } - bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); - return equal(cond, other->cond) && equal(true_branch, other->true_branch) && - equal(false_branch, other->false_branch) && equal(struct_info_, other->struct_info_); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce->MarkGraphNode(); - hash_reduce(cond); - hash_reduce(true_branch); - hash_reduce(false_branch); - hash_reduce(struct_info_); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; static constexpr const char* _type_key = "relax.expr.If"; TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); @@ -1007,27 +854,8 @@ class FunctionNode : public BaseFuncNode { .def_ro("is_pure", &FunctionNode::is_pure); } - bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); - return equal.DefEqual(params, other->params) && equal(body, other->body) && - equal(ret_struct_info, other->ret_struct_info) && equal(is_pure, other->is_pure) && - equal(attrs, other->attrs) && equal(struct_info_, other->struct_info_); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce->MarkGraphNode(); - hash_reduce.DefHash(params); - hash_reduce(body); - hash_reduce(ret_struct_info); - hash_reduce(is_pure); - hash_reduce(attrs); - hash_reduce(struct_info_); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; static constexpr const char* _type_key = "relax.expr.Function"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); }; @@ -1111,18 +939,7 @@ class ExternFuncNode : public BaseFuncNode { refl::ObjectDef().def_ro("global_symbol", &ExternFuncNode::global_symbol); } - bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const { - return equal(global_symbol, other->global_symbol) && equal(struct_info_, other->struct_info_); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(global_symbol); - hash_reduce(struct_info_); - } - static constexpr const char* _type_key = "relax.expr.ExternFunc"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncNode, BaseFuncNode); }; diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index cd9b05ab29f0..a897f031a289 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -40,10 +40,6 @@ class ObjectStructInfoNode : public StructInfoNode { refl::ObjectDef(); } - bool SEqualReduce(const ObjectStructInfoNode* other, SEqualReducer equal) const { return true; } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } - static constexpr const char* _type_key = "relax.ObjectStructInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(ObjectStructInfoNode, StructInfoNode); }; @@ -77,15 +73,6 @@ class PrimStructInfoNode : public StructInfoNode { .def_ro("dtype", &PrimStructInfoNode::dtype); } - bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const { - return equal(value, other->value) && equal(dtype, other->dtype); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(value); - hash_reduce(dtype); - } - static constexpr const char* _type_key = "relax.PrimStructInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode); }; @@ -128,15 +115,6 @@ class ShapeStructInfoNode : public StructInfoNode { .def_ro("ndim", &ShapeStructInfoNode::ndim); } - bool SEqualReduce(const ShapeStructInfoNode* other, SEqualReducer equal) const { - return equal(values, other->values) && equal(ndim, other->ndim); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(values); - hash_reduce(ndim); - } - static constexpr const char* _type_key = "relax.ShapeStructInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(ShapeStructInfoNode, StructInfoNode); }; @@ -207,18 +185,6 @@ class TensorStructInfoNode : public StructInfoNode { .def_ro("ndim", &TensorStructInfoNode::ndim); } - bool SEqualReduce(const TensorStructInfoNode* other, SEqualReducer equal) const { - return equal(shape, other->shape) && equal(ndim, other->ndim) && - equal(vdevice, other->vdevice) && equal(dtype, other->dtype); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(shape); - hash_reduce(dtype); - hash_reduce(vdevice); - hash_reduce(ndim); - } - static constexpr const char* _type_key = "relax.TensorStructInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorStructInfoNode, StructInfoNode); }; @@ -267,12 +233,6 @@ class TupleStructInfoNode : public StructInfoNode { refl::ObjectDef().def_ro("fields", &TupleStructInfoNode::fields); } - bool SEqualReduce(const TupleStructInfoNode* other, SEqualReducer equal) const { - return equal(fields, other->fields); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } - static constexpr const char* _type_key = "relax.TupleStructInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleStructInfoNode, StructInfoNode); }; @@ -347,18 +307,6 @@ class FuncStructInfoNode : public StructInfoNode { .def_ro("purity", &FuncStructInfoNode::purity); } - bool SEqualReduce(const FuncStructInfoNode* other, SEqualReducer equal) const { - return equal.DefEqual(params, other->params) && equal(ret, other->ret) && - equal(purity, other->purity) && equal(derive_func, other->derive_func); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(params); - hash_reduce(ret); - hash_reduce(purity); - hash_reduce(derive_func); - } - static constexpr const char* _type_key = "relax.FuncStructInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(FuncStructInfoNode, StructInfoNode); }; diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 5bb7c202c1f7..18fd16af4d2b 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -49,12 +49,6 @@ class ShapeTypeNode : public TypeNode { refl::ObjectDef().def_ro("ndim", &ShapeTypeNode::ndim); } - bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { - return equal(ndim, other->ndim); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(ndim); } - static constexpr const char* _type_key = "relax.ShapeType"; TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTypeNode, TypeNode); }; @@ -89,15 +83,6 @@ class TensorTypeNode : public TypeNode { .def_ro("dtype", &TensorTypeNode::dtype); } - bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const { - return equal(ndim, other->ndim) && equal(dtype, other->dtype); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(ndim); - hash_reduce(dtype); - } - inline bool IsUnknownNdim() const { return ndim == kUnknownNDim; } inline bool IsUnknownDtype() const { return dtype.is_void(); } @@ -138,10 +123,6 @@ class ObjectTypeNode : public TypeNode { refl::ObjectDef(); } - bool SEqualReduce(const ObjectTypeNode* other, SEqualReducer equal) const { return true; } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } - static constexpr const char* _type_key = "relax.ObjectType"; TVM_DECLARE_FINAL_OBJECT_INFO(ObjectTypeNode, TypeNode); }; @@ -160,10 +141,6 @@ class PackedFuncTypeNode : public TypeNode { refl::ObjectDef(); } - bool SEqualReduce(const PackedFuncTypeNode* other, SEqualReducer equal) const { return true; } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } - static constexpr const char* _type_key = "relax.PackedFuncType"; TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncTypeNode, TypeNode); }; diff --git a/include/tvm/runtime/disco/cuda_ipc_memory.h b/include/tvm/runtime/disco/cuda_ipc_memory.h index ea272052626f..a77e06ccaef5 100644 --- a/include/tvm/runtime/disco/cuda_ipc_memory.h +++ b/include/tvm/runtime/disco/cuda_ipc_memory.h @@ -71,8 +71,6 @@ class CUDAIPCMemoryObj : public Object { int barrier_flag; static constexpr const char* _type_key = "tvm.runtime.disco.cuda_ipc_memory"; - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; TVM_DECLARE_BASE_OBJECT_INFO(CUDAIPCMemoryObj, Object); }; diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 9929791f31d3..678d36aeceda 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -173,13 +173,8 @@ class TargetNode : public Object { /*! \brief Get the keys for this target as an unordered_set of string */ TVM_DLL std::unordered_set GetLibs() const; - bool SEqualReduce(const TargetNode* other, SEqualReducer equal) const; - void SHashReduce(SHashReducer hash_reduce) const; - static constexpr const char* _type_key = "target.Target"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); private: diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index cb16d2912aa0..3cc988f49e38 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -126,28 +126,7 @@ class BufferNode : public Object { .def_ro("data_alignment", &BufferNode::data_alignment) .def_ro("offset_factor", &BufferNode::offset_factor) .def_ro("buffer_type", &BufferNode::buffer_type) - .def_ro("span", &BufferNode::span); - } - - bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const { - // Use DefEqual as buffer can define variables in its semantics, - // skip name as name is not important. - return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && - equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) && - equal.DefEqual(axis_separators, other->axis_separators) && - equal.DefEqual(elem_offset, other->elem_offset) && - equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(data); - hash_reduce(dtype); - hash_reduce.DefHash(shape); - hash_reduce.DefHash(strides); - hash_reduce.DefHash(elem_offset); - hash_reduce.DefHash(axis_separators); - hash_reduce(data_alignment); - hash_reduce(buffer_type); + .def_ro("span", &BufferNode::span, refl::AttachFieldFlag::SEqHashIgnore()); } /*! \return preferred index type for this buffer node */ @@ -165,8 +144,7 @@ class BufferNode : public Object { static constexpr const char* _type_key = "tir.Buffer"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object); TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); }; @@ -299,16 +277,7 @@ class DataProducerNode : public PrimExprConvertibleNode { */ virtual String GetNameHint() const = 0; - bool SEqualReduce(const DataProducerNode* other, SEqualReducer equal) const { - // because buffer producer is opaque, we just do pointer equality. - return this == other; - } - - void SHashReduce(SHashReducer hash_reduce) const {} - static constexpr const char* _type_key = "tir.DataProducer"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, PrimExprConvertibleNode); }; diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 3e6a07a6cd6b..1b419b569311 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -60,12 +60,6 @@ class StringImmNode : public PrimExprNode { refl::ObjectDef().def_ro("value", &StringImmNode::value); } - bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const { - return equal(value, other->value); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } - static constexpr const char* _type_key = "tir.StringImm"; TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode); }; @@ -95,15 +89,6 @@ class CastNode : public PrimExprNode { refl::ObjectDef().def_ro("value", &CastNode::value); } - bool SEqualReduce(const CastNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(value, other->value); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(value); - } - static constexpr const char* _type_key = "tir.Cast"; TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode); }; @@ -136,16 +121,6 @@ class BinaryOpNode : public PrimExprNode { refl::ObjectDef().def_ro("a", &T::a).def_ro("b", &T::b); } - bool SEqualReduce(const T* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(a); - hash_reduce(b); - } - TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); }; @@ -326,16 +301,6 @@ class CmpOpNode : public PrimExprNode { refl::ObjectDef().def_ro("a", &T::a).def_ro("b", &T::b); } - bool SEqualReduce(const T* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(a); - hash_reduce(b); - } - TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); }; @@ -454,16 +419,6 @@ class AndNode : public PrimExprNode { refl::ObjectDef().def_ro("a", &AndNode::a).def_ro("b", &AndNode::b); } - bool SEqualReduce(const AndNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(a); - hash_reduce(b); - } - static constexpr const char* _type_key = "tir.And"; TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode); }; @@ -492,16 +447,6 @@ class OrNode : public PrimExprNode { refl::ObjectDef().def_ro("a", &OrNode::a).def_ro("b", &OrNode::b); } - bool SEqualReduce(const OrNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(a); - hash_reduce(b); - } - static constexpr const char* _type_key = "tir.Or"; TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode); }; @@ -528,15 +473,6 @@ class NotNode : public PrimExprNode { refl::ObjectDef().def_ro("a", &NotNode::a); } - bool SEqualReduce(const NotNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(a, other->a); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(a); - } - static constexpr const char* _type_key = "tir.Not"; TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode); }; @@ -576,18 +512,6 @@ class SelectNode : public PrimExprNode { .def_ro("false_value", &SelectNode::false_value); } - bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(condition, other->condition) && - equal(true_value, other->true_value) && equal(false_value, other->false_value); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(condition); - hash_reduce(true_value); - hash_reduce(false_value); - } - static constexpr const char* _type_key = "tir.Select"; TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode); }; @@ -631,18 +555,6 @@ class BufferLoadNode : public PrimExprNode { .def_ro("predicate", &BufferLoadNode::predicate); } - bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(buffer, other->buffer) && - equal(indices, other->indices); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(buffer); - hash_reduce(indices); - hash_reduce(predicate); - } - static constexpr const char* _type_key = "tir.BufferLoad"; TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode); @@ -698,17 +610,6 @@ class ProducerLoadNode : public PrimExprNode { .def_ro("indices", &ProducerLoadNode::indices); } - bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(producer, other->producer) && - equal(indices, other->indices); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(producer); - hash_reduce(indices); - } - static constexpr const char* _type_key = "tir.ProducerLoad"; TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode); }; @@ -751,18 +652,6 @@ class RampNode : public PrimExprNode { .def_ro("lanes", &RampNode::lanes); } - bool SEqualReduce(const RampNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) && - equal(lanes, other->lanes); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(base); - hash_reduce(stride); - hash_reduce(lanes); - } - static constexpr const char* _type_key = "tir.Ramp"; TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode); }; @@ -793,16 +682,6 @@ class BroadcastNode : public PrimExprNode { .def_ro("lanes", &BroadcastNode::lanes); } - bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(value, other->value) && equal(lanes, other->lanes); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(value); - hash_reduce(lanes); - } - static constexpr const char* _type_key = "tir.Broadcast"; TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode); }; @@ -838,18 +717,6 @@ class LetNode : public PrimExprNode { .def_ro("body", &LetNode::body); } - bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal.DefEqual(var, other->var) && - equal(value, other->value) && equal(body, other->body); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce.DefHash(var); - hash_reduce(value); - hash_reduce(body); - } - static constexpr const char* _type_key = "tir.Let"; TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode); }; @@ -886,16 +753,6 @@ class CallNode : public PrimExprNode { refl::ObjectDef().def_ro("op", &CallNode::op).def_ro("args", &CallNode::args); } - bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(op); - hash_reduce(args); - } - static constexpr const char* _type_key = "tir.Call"; TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode); }; @@ -930,17 +787,6 @@ class ShuffleNode : public PrimExprNode { .def_ro("indices", &ShuffleNode::indices); } - bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(vectors, other->vectors) && - equal(indices, other->indices); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(vectors); - hash_reduce(indices); - } - static constexpr const char* _type_key = "tir.Shuffle"; TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode); }; @@ -996,22 +842,8 @@ class CommReducerNode : public Object { .def_ro("span", &CommReducerNode::span, refl::AttachFieldFlag::SEqHashIgnore()); } - bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const { - return equal.DefEqual(lhs, other->lhs) && equal.DefEqual(rhs, other->rhs) && - equal(result, other->result) && equal(identity_element, other->identity_element); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(lhs); - hash_reduce.DefHash(rhs); - hash_reduce(result); - hash_reduce(identity_element); - } - static constexpr const char* _type_key = "tir.CommReducer"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object); }; @@ -1057,24 +889,6 @@ class ReduceNode : public PrimExprNode { .def_ro("value_index", &ReduceNode::value_index); } - bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const { - // check axis first so IterVars can define the necessary variables. - return equal(dtype, other->dtype) && equal(axis, other->axis) && - equal(combiner, other->combiner) && equal(source, other->source) && - equal(init, other->init) && equal(condition, other->condition) && - equal(value_index, other->value_index); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(axis); - hash_reduce(combiner); - hash_reduce(source); - hash_reduce(init); - hash_reduce(condition); - hash_reduce(value_index); - } - static constexpr const char* _type_key = "tir.Reduce"; TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode); }; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 2671f9879101..6ea50e9ae0f0 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -109,21 +109,6 @@ class PrimFuncNode : public BaseFuncNode { .def_ro("body", &PrimFuncNode::body); } - bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { - // visit params and buffer_map first as they contains defs. - return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) && - equal(ret_type, other->ret_type) && equal(body, other->body) && - equal(attrs, other->attrs); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(params); - hash_reduce(buffer_map); - hash_reduce(ret_type); - hash_reduce(body); - hash_reduce(attrs); - } - /*! * \brief Return the derived function annotation of this function. * diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index 55d083834dc9..518d7602f562 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -161,20 +161,8 @@ class IndexMapNode : public Object { refl::AttachFieldFlag::SEqHashIgnore()); } - bool SEqualReduce(const IndexMapNode* other, SEqualReducer equal) const { - return equal.DefEqual(initial_indices, other->initial_indices) && - equal(final_indices, other->final_indices); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(initial_indices); - hash_reduce(final_indices); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "tir.IndexMap"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object); }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 9d31d25c398d..250475c61d90 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -55,8 +55,7 @@ class StmtNode : public Object { static constexpr const char* _type_key = "tir.Stmt"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 15; TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object); }; @@ -87,17 +86,6 @@ class LetStmtNode : public StmtNode { .def_ro("body", &LetStmtNode::body); } - bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const { - return equal.DefEqual(var, other->var) && equal(value, other->value) && - equal(body, other->body); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(var); - hash_reduce(value); - hash_reduce(body); - } - static constexpr const char* _type_key = "tir.LetStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode); }; @@ -144,18 +132,6 @@ class AttrStmtNode : public StmtNode { .def_ro("body", &AttrStmtNode::body); } - bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const { - return equal(node, other->node) && equal(attr_key, other->attr_key) && - equal(value, other->value) && equal(body, other->body); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(node); - hash_reduce(attr_key); - hash_reduce(value); - hash_reduce(body); - } - static constexpr const char* _type_key = "tir.AttrStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode); }; @@ -195,17 +171,6 @@ class AssertStmtNode : public StmtNode { .def_ro("body", &AssertStmtNode::body); } - bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const { - return equal(condition, other->condition) && equal(message, other->message) && - equal(body, other->body); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(condition); - hash_reduce(message); - hash_reduce(body); - } - static constexpr const char* _type_key = "tir.AssertStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode); }; @@ -252,18 +217,6 @@ class BufferStoreNode : public StmtNode { .def_ro("predicate", &BufferStoreNode::predicate); } - bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const { - return equal(buffer, other->buffer) && equal(value, other->value) && - equal(indices, other->indices); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(buffer); - hash_reduce(value); - hash_reduce(indices); - hash_reduce(predicate); - } - static constexpr const char* _type_key = "tir.BufferStore"; TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode); }; @@ -312,18 +265,6 @@ class BufferRealizeNode : public StmtNode { .def_ro("body", &BufferRealizeNode::body); } - bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const { - return equal(buffer, other->buffer) && equal(bounds, other->bounds) && - equal(condition, other->condition) && equal(body, other->body); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(buffer); - hash_reduce(bounds); - hash_reduce(condition); - hash_reduce(body); - } - BufferRealizeNode() = default; BufferRealizeNode(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, Span span = Span()) @@ -380,21 +321,6 @@ class AllocateNode : public StmtNode { .def_ro("annotations", &AllocateNode::annotations); } - bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const { - return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && - equal(extents, other->extents) && equal(condition, other->condition) && - equal(body, other->body) && equal(annotations, other->annotations); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(buffer_var); - hash_reduce(dtype); - hash_reduce(extents); - hash_reduce(condition); - hash_reduce(body); - hash_reduce(annotations); - } - /*! * \brief If the buffer size is constant, return the size. * Otherwise return 0. @@ -410,8 +336,7 @@ class AllocateNode : public StmtNode { TVM_DLL static int64_t ConstantAllocationSize(const Array& extents); static constexpr const char* _type_key = "tir.Allocate"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode); }; @@ -470,21 +395,6 @@ class AllocateConstNode : public StmtNode { .def_ro("annotations", &AllocateConstNode::annotations); } - bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const { - return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && - equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) && - equal(annotations, other->annotations); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(buffer_var); - hash_reduce(dtype); - hash_reduce(extents); - hash_reduce(body); - hash_reduce(annotations); - hash_reduce(data); - } - /*! * \brief If the buffer size is constant, return the size. * Otherwise return 0. @@ -500,8 +410,6 @@ class AllocateConstNode : public StmtNode { TVM_DLL static int64_t ConstantAllocationSize(const Array& extents); static constexpr const char* _type_key = "tir.AllocateConst"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode); }; @@ -538,15 +446,6 @@ class DeclBufferNode : public StmtNode { .def_ro("body", &DeclBufferNode::body); } - bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const { - return equal(buffer, other->buffer) && equal(body, other->body); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(buffer); - hash_reduce(body); - } - static constexpr const char* _type_key = "tir.DeclBuffer"; TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode); }; @@ -580,12 +479,6 @@ class SeqStmtNode : public StmtNode { refl::ObjectDef().def_ro("seq", &SeqStmtNode::seq); } - bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const { - return equal(seq, other->seq); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); } - static constexpr const char* _type_key = "tir.SeqStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); }; @@ -606,12 +499,6 @@ class EvaluateNode : public StmtNode { refl::ObjectDef().def_ro("value", &EvaluateNode::value); } - bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const { - return equal(value, other->value); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } - static constexpr const char* _type_key = "tir.Evaluate"; TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode); }; @@ -801,17 +688,6 @@ class IfThenElseNode : public StmtNode { .def_ro("else_case", &IfThenElseNode::else_case); } - bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const { - return equal(condition, other->condition) && equal(then_case, other->then_case) && - equal(else_case, other->else_case); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(condition); - hash_reduce(then_case); - hash_reduce(else_case); - } - static constexpr const char* _type_key = "tir.IfThenElse"; TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode); }; @@ -906,22 +782,6 @@ class ForNode : public StmtNode { .def_ro("annotations", &ForNode::annotations); } - bool SEqualReduce(const ForNode* other, SEqualReducer equal) const { - return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) && - equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) && - equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(loop_var); - hash_reduce(min); - hash_reduce(extent); - hash_reduce(kind); - hash_reduce(body); - hash_reduce(thread_binding); - hash_reduce(annotations); - } - static constexpr const char* _type_key = "tir.For"; TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode); }; @@ -964,15 +824,6 @@ class WhileNode : public StmtNode { .def_ro("body", &WhileNode::body); } - bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const { - return equal(condition, other->condition) && equal(body, other->body); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(condition); - hash_reduce(body); - } - static constexpr const char* _type_key = "tir.While"; TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode); }; @@ -1006,21 +857,10 @@ class BufferRegionNode : public PrimExprConvertibleNode { .def_ro("region", &BufferRegionNode::region); } - bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const { - return equal(buffer, other->buffer) && equal(region, other->region); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(buffer); - hash_reduce(region); - } - TVM_DLL PrimExpr ToPrimExpr() const final; static constexpr const char* _type_key = "tir.BufferRegion"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, PrimExprConvertibleNode); }; @@ -1074,19 +914,8 @@ class MatchBufferRegionNode : public Object { .def_ro("source", &MatchBufferRegionNode::source); } - bool SEqualReduce(const MatchBufferRegionNode* other, SEqualReducer equal) const { - return equal(buffer, other->buffer) && equal(source, other->source); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(buffer); - hash_reduce(source); - } - static constexpr const char* _type_key = "tir.MatchBufferRegion"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object); }; @@ -1164,26 +993,6 @@ class BlockNode : public StmtNode { .def_ro("body", &BlockNode::body); } - bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const { - // Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars - return equal.DefEqual(iter_vars, other->iter_vars) && - equal(alloc_buffers, other->alloc_buffers) && - equal(match_buffers, other->match_buffers) && equal(reads, other->reads) && - equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) && - equal(annotations, other->annotations); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(iter_vars); - hash_reduce(alloc_buffers); - hash_reduce(match_buffers); - hash_reduce(reads); - hash_reduce(writes); - hash_reduce(body); - hash_reduce(init); - hash_reduce(annotations); - } - static constexpr const char* _type_key = "tir.Block"; TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode); }; @@ -1229,17 +1038,6 @@ class BlockRealizeNode : public StmtNode { .def_ro("block", &BlockRealizeNode::block); } - bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const { - return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) && - equal(block, other->block); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(iter_values); - hash_reduce(predicate); - hash_reduce(block); - } - static constexpr const char* _type_key = "tir.BlockRealize"; TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode); }; diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 021b6c301a68..7bf29265ceea 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -68,18 +68,6 @@ class VarNode : public PrimExprNode { .def_ro("type_annotation", &VarNode::type_annotation); } - bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { - if (!equal(dtype, other->dtype)) return false; - if (!equal(type_annotation, other->type_annotation)) return false; - return equal.FreeVarEqualImpl(this, other); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(type_annotation); - hash_reduce.FreeVarHashImpl(this); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; static constexpr const char* _type_key = "tir.Var"; static constexpr const uint32_t _type_child_slots = 1; @@ -296,22 +284,8 @@ class IterVarNode : public PrimExprConvertibleNode { .def_ro("thread_tag", &IterVarNode::thread_tag); } - bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const { - return equal(dom, other->dom) && equal.DefEqual(var, other->var) && - equal(iter_type, other->iter_type) && equal(thread_tag, other->thread_tag); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dom); - hash_reduce.DefHash(var); - hash_reduce(iter_type); - hash_reduce(thread_tag); - } - static constexpr const char* _type_key = "tir.IterVar"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, PrimExprConvertibleNode); }; diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h index 1d6b1046d9e1..a8587a2e5ed8 100644 --- a/src/contrib/msc/core/ir/graph.h +++ b/src/contrib/msc/core/ir/graph.h @@ -387,19 +387,6 @@ class MSCTensorNode : public Object { .def_ro("prims", &MSCTensorNode::prims); } - bool SEqualReduce(const MSCTensorNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(dtype, other->dtype) && equal(shape, other->shape) && - equal(layout, other->layout) && equal(prims, other->prims); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(name); - hash_reduce(dtype); - hash_reduce(shape); - hash_reduce(layout); - hash_reduce(prims); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.MSCTensor"; TVM_DECLARE_FINAL_OBJECT_INFO(MSCTensorNode, Object); @@ -501,24 +488,8 @@ class BaseJointNode : public Object { .def_ro("children", &BaseJointNode::children); } - bool SEqualReduce(const BaseJointNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(shared_ref, other->shared_ref) && - equal(attrs, other->attrs) && equal(parents, other->parents) && - equal(children, other->children); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(name); - hash_reduce(shared_ref); - hash_reduce(attrs); - hash_reduce(parents); - hash_reduce(children); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.BaseJoint"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(BaseJointNode, Object); }; @@ -587,21 +558,6 @@ class MSCJointNode : public BaseJointNode { .def_ro("weights", &MSCJointNode::weights); } - bool SEqualReduce(const MSCJointNode* other, SEqualReducer equal) const { - return BaseJointNode::SEqualReduce(other, equal) && equal(optype, other->optype) && - equal(scope, other->scope) && equal(inputs, other->inputs) && - equal(outputs, other->outputs) && equal(weights, other->weights); - } - - void SHashReduce(SHashReducer hash_reduce) const { - BaseJointNode::SHashReduce(hash_reduce); - hash_reduce(optype); - hash_reduce(scope); - hash_reduce(inputs); - hash_reduce(outputs); - hash_reduce(weights); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.MSCJoint"; TVM_DECLARE_FINAL_OBJECT_INFO(MSCJointNode, BaseJointNode); @@ -672,15 +628,6 @@ class MSCPrimNode : public BaseJointNode { refl::ObjectDef().def_ro("optype", &MSCPrimNode::optype); } - bool SEqualReduce(const MSCPrimNode* other, SEqualReducer equal) const { - return BaseJointNode::SEqualReduce(other, equal) && equal(optype, other->optype); - } - - void SHashReduce(SHashReducer hash_reduce) const { - BaseJointNode::SHashReduce(hash_reduce); - hash_reduce(optype); - } - static constexpr const char* _type_key = "msc.core.MSCPrim"; TVM_DECLARE_FINAL_OBJECT_INFO(MSCPrimNode, BaseJointNode); }; @@ -749,18 +696,6 @@ class WeightJointNode : public BaseJointNode { .def_ro("friends", &WeightJointNode::friends); } - bool SEqualReduce(const WeightJointNode* other, SEqualReducer equal) const { - return BaseJointNode::SEqualReduce(other, equal) && equal(weight_type, other->weight_type) && - equal(weight, other->weight) && equal(friends, other->friends); - } - - void SHashReduce(SHashReducer hash_reduce) const { - BaseJointNode::SHashReduce(hash_reduce); - hash_reduce(weight_type); - hash_reduce(weight); - hash_reduce(friends); - } - static constexpr const char* _type_key = "msc.core.WeightJoint"; TVM_DECLARE_FINAL_OBJECT_INFO(WeightJointNode, BaseJointNode); }; @@ -825,21 +760,9 @@ class BaseGraphNode : public Object { .def_ro("node_names", &BaseGraphNode::node_names); } - bool SEqualReduce(const BaseGraphNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(nodes, other->nodes) && - equal(node_names, other->node_names); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(name); - hash_reduce(nodes); - hash_reduce(node_names); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.BaseGraph"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(BaseGraphNode, Object); }; @@ -929,21 +852,6 @@ class MSCGraphNode : public BaseGraphNode { .def_ro("weight_holders", &MSCGraphNode::weight_holders); } - bool SEqualReduce(const MSCGraphNode* other, SEqualReducer equal) const { - return BaseGraphNode::SEqualReduce(other, equal) && equal(prims, other->prims) && - equal(prim_names, other->prim_names) && equal(input_names, other->input_names) && - equal(output_names, other->output_names) && equal(weight_holders, other->weight_holders); - } - - void SHashReduce(SHashReducer hash_reduce) const { - BaseGraphNode::SHashReduce(hash_reduce); - hash_reduce(prims); - hash_reduce(prim_names); - hash_reduce(input_names); - hash_reduce(output_names); - hash_reduce(weight_holders); - } - static constexpr const char* _type_key = "msc.core.MSCGraph"; TVM_DECLARE_FINAL_OBJECT_INFO(MSCGraphNode, BaseGraphNode); }; @@ -1005,12 +913,6 @@ class WeightGraphNode : public BaseGraphNode { refl::ObjectDef(); } - bool SEqualReduce(const WeightGraphNode* other, SEqualReducer equal) const { - return BaseGraphNode::SEqualReduce(other, equal); - } - - void SHashReduce(SHashReducer hash_reduce) const { BaseGraphNode::SHashReduce(hash_reduce); } - static constexpr const char* _type_key = "msc.core.WeightGraph"; TVM_DECLARE_FINAL_OBJECT_INFO(WeightGraphNode, BaseGraphNode); }; diff --git a/src/contrib/msc/core/ir/plugin.h b/src/contrib/msc/core/ir/plugin.h index 291a0e196a24..f0a5dc9937b8 100644 --- a/src/contrib/msc/core/ir/plugin.h +++ b/src/contrib/msc/core/ir/plugin.h @@ -278,18 +278,6 @@ class PluginAttrNode : public Object { .def_ro("describe", &PluginAttrNode::describe); } - bool SEqualReduce(const PluginAttrNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(type, other->type) && - equal(default_value, other->default_value) && equal(describe, other->describe); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(name); - hash_reduce(type); - hash_reduce(default_value); - hash_reduce(describe); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.PluginAttr"; TVM_DECLARE_FINAL_OBJECT_INFO(PluginAttrNode, Object); @@ -359,19 +347,6 @@ class PluginTensorNode : public Object { .def_ro("describe", &PluginTensorNode::describe); } - bool SEqualReduce(const PluginTensorNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(dtype, other->dtype) && equal(ndim, other->ndim) && - equal(device, other->device) && equal(describe, other->describe); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(name); - hash_reduce(dtype); - hash_reduce(ndim); - hash_reduce(device); - hash_reduce(describe); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.PluginTensor"; TVM_DECLARE_FINAL_OBJECT_INFO(PluginTensorNode, Object); @@ -442,20 +417,6 @@ class PluginExternNode : public Object { .def_ro("describe", &PluginExternNode::describe); } - bool SEqualReduce(const PluginExternNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(header, other->header) && - equal(source, other->source) && equal(lib, other->lib) && - equal(describe, other->describe); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(name); - hash_reduce(header); - hash_reduce(source); - hash_reduce(lib); - hash_reduce(describe); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.PluginExtern"; TVM_DECLARE_FINAL_OBJECT_INFO(PluginExternNode, Object); @@ -546,28 +507,6 @@ class PluginNode : public Object { .def_ro("options", &PluginNode::options); } - bool SEqualReduce(const PluginNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(version, other->version) && - equal(describe, other->describe) && equal(attrs, other->attrs) && - equal(inputs, other->inputs) && equal(outputs, other->outputs) && - equal(buffers, other->buffers) && equal(externs, other->externs) && - equal(support_dtypes, other->support_dtypes) && equal(options, other->options); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(name); - hash_reduce(version); - hash_reduce(describe); - hash_reduce(attrs); - hash_reduce(inputs); - hash_reduce(outputs); - hash_reduce(buffers); - hash_reduce(externs); - hash_reduce(externs); - hash_reduce(support_dtypes); - hash_reduce(options); - } - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "msc.core.Plugin"; TVM_DECLARE_FINAL_OBJECT_INFO(PluginNode, Object); diff --git a/src/ir/module.cc b/src/ir/module.cc index f17874724676..6ff17c78e618 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -57,45 +57,6 @@ IRModule::IRModule(tvm::Map functions, SourceMap source_map data_ = std::move(n); } -bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { - if (!equal(this->attrs, other->attrs, [](const auto& path) { return path->Attr("attrs"); })) { - return false; - } - - if (this->global_infos.size() != other->global_infos.size()) return false; - for (const auto& kv : this->global_infos) { - if (!equal(kv.second, other->global_infos[kv.first])) return false; - } - - if (functions.size() != other->functions.size()) return false; - // Update GlobalVar remap - if (equal.IsPathTracingEnabled()) { - if (functions.size() != other->functions.size()) { - return false; - } - } - - // Define remaps for GlobalVar and GlobalTypeVar based on their - // string name. Early bail-out is only performed when path-tracing - // is disabled, as the later equality checks on the member variables - // will provide better error messages. - for (const auto& gv : this->GetGlobalVars()) { - if (other->ContainGlobalVar(gv->name_hint)) { - if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; - } else if (!equal.IsPathTracingEnabled()) { - return false; - } - } - - // Checking functions and type definitions - if (!equal(this->functions, other->functions, - [](const auto& path) { return path->Attr("functions"); })) { - return false; - } - - return true; -} - bool IRModuleNode::SEqual(const IRModuleNode* other, ffi::TypedFunction equal) const { if (!equal(this->attrs, other->attrs, false, "attrs")) { @@ -120,37 +81,6 @@ bool IRModuleNode::SEqual(const IRModuleNode* other, return true; } -void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { - using KV = std::tuple; - // hash the functions. - std::vector temp; - - auto reduce_temp = [&]() { - // sort by the hash key of the keys. - std::sort(temp.begin(), temp.end(), - [](const KV& lhs, const KV& rhs) { return std::get<0>(lhs) < std::get<0>(rhs); }); - - hash_reduce(static_cast(temp.size())); - // Defhash the GlobalVar/GlobalTypeVar - for (size_t i = 0; i < temp.size(); ++i) { - hash_reduce.DefHash(std::get<1>(temp[i])); - } - // hash the name and content - for (size_t i = 0; i < temp.size(); ++i) { - hash_reduce(std::get<0>(temp[i])); - hash_reduce(std::get<2>(temp[i])); - } - }; - - for (const auto& kv : this->functions) { - temp.emplace_back(kv.first->name_hint, kv.first, kv.second); - } - reduce_temp(); - - hash_reduce(this->attrs); - hash_reduce(this->global_infos); -} - uint64_t IRModuleNode::SHash(uint64_t init_hash, ffi::TypedFunction hash) const { uint64_t hash_value = init_hash; diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index 8901d5fd8d57..501d55b8efa6 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -27,8 +27,6 @@ #include -#include "../node/ndarray_hash_equal.h" - namespace tvm { namespace meta_schedule { @@ -58,7 +56,9 @@ class ModuleEqualityAnchorBlock : public ModuleEquality { size_t Hash(IRModule mod) const { auto anchor_block = tir::FindAnchorBlock(mod); if (anchor_block) { - return SHashHandlerIgnoreNDArray().Hash(GetRef(anchor_block), false); + return ffi::reflection::StructuralHash::Hash(GetRef(anchor_block), + /*map_free_vars=*/false, + /*skip_ndarray_content=*/true); } return ModuleEqualityIgnoreNDArray().Hash(mod); } diff --git a/src/node/ndarray_hash_equal.h b/src/node/ndarray_hash_equal.h deleted file mode 100644 index b5639f524b2b..000000000000 --- a/src/node/ndarray_hash_equal.h +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_NODE_NDARRAY_HASH_EQUAL_H_ -#define TVM_NODE_NDARRAY_HASH_EQUAL_H_ - -#include -#include - -namespace tvm { - -class SEqualReducer; -class SHashReducer; - -/*! \brief A custom hash handler that ignores NDArray raw data. */ -class SHashHandlerIgnoreNDArray : public SHashHandlerDefault { - protected: - void DispatchSHash(const ObjectRef& object, bool map_free_vars) override; -}; - -/*! - * \brief Test two NDArrays for equality. - * \param lhs The left operand. - * \param rhs The right operand. - * \param equal A Reducer class to reduce the structural equality result of two objects. - * See tvm/node/structural_equal.h. - * \param compare_data Whether or not to consider ndarray raw data in the equality testing. - * \return The equality testing result. - */ -bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs, - SEqualReducer equal, bool compare_data); - -/*! - * \brief Hash NDArray. - * \param arr The NDArray to compute the hash for. - * \param hash_reduce A Reducer class to reduce the structural hash value. - * See tvm/node/structural_hash.h. - * \param hash_data Whether or not to hash ndarray raw data. - */ -void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce, bool hash_data); - -} // namespace tvm - -#endif // TVM_NODE_NDARRAY_HASH_EQUAL_H_ diff --git a/src/node/reflection.cc b/src/node/reflection.cc index fe63d28426ba..ffe15ca2abb3 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -184,23 +184,4 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("node.MakeNode", MakeNode); }); -Optional GetAttrKeyByAddress(const Object* object, const void* attr_address) { - const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(object->type_index()); - if (tinfo->metadata != nullptr) { - Optional result; - // visit fields with the new reflection - ffi::reflection::ForEachFieldInfoWithEarlyStop(tinfo, [&](const TVMFFIFieldInfo* field_info) { - Any field_value = ffi::reflection::FieldGetter(field_info)(object); - const void* field_addr = reinterpret_cast(object) + field_info->offset; - if (field_addr == attr_address) { - result = String(field_info->name); - return true; - } - return false; - }); - return result; - } - return std::nullopt; -} - } // namespace tvm diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 5987692a0f78..186f50947230 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -33,8 +33,6 @@ #include #include -#include "ndarray_hash_equal.h" - namespace tvm { TVM_REGISTER_OBJECT_TYPE(ObjectPathPairNode); @@ -55,552 +53,6 @@ ObjectPathPair::ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path) { data_ = make_object(std::move(lhs_path), std::move(rhs_path)); } -// Define the dispatch function here since primary user is in this file. -bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other, - SEqualReducer equal) const { - uint32_t tindex = self->type_index(); - if (tindex >= fsequal_reduce_.size() || fsequal_reduce_[tindex] == nullptr) { - LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey() - << " is not registered via TVM_REGISTER_NODE_TYPE." - << " Did you forget to set _type_has_method_sequal_reduce=true?"; - } - return fsequal_reduce_[tindex](self, other, equal); -} - -namespace { -ObjectPath GetAttrPath(const ObjectRef& obj, const void* attr_address, const ObjectPath& path) { - Optional attr_key = GetAttrKeyByAddress(obj.get(), attr_address); - return path->Attr(attr_key); -} -} // namespace - -struct SEqualReducer::PathTracingData { - ObjectPathPair current_paths; - ObjectRef lhs_object; - ObjectRef rhs_object; - Optional* first_mismatch; - - ObjectPathPair GetPathsForAttrs(const ObjectRef& lhs, const ObjectRef& rhs) const { - ObjectPath lhs_attr_path = GetAttrPath(lhs_object, &lhs, current_paths->lhs_path); - ObjectPath rhs_attr_path = GetAttrPath(rhs_object, &rhs, current_paths->rhs_path); - return ObjectPathPair(lhs_attr_path, rhs_attr_path); - } -}; - -bool SEqualReducer::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { - if (tracing_data_ == nullptr) { - // Fast path: no tracing - return handler_->SEqualReduce(lhs, rhs, map_free_vars_, std::nullopt); - } - return ObjectAttrsEqual(lhs, rhs, map_free_vars_, nullptr); -} - -bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { - if (tracing_data_ == nullptr) { - // Fast path: no tracing - return handler_->SEqualReduce(lhs, rhs, true, std::nullopt); - } - return ObjectAttrsEqual(lhs, rhs, true, nullptr); -} - -/* static */ void SEqualReducer::GetPathsFromAttrAddressesAndStoreMismatch( - const void* lhs_address, const void* rhs_address, const PathTracingData* tracing_data) { - if (tracing_data != nullptr && !tracing_data->first_mismatch->defined()) { - ObjectPath lhs_attr_path = - GetAttrPath(tracing_data->lhs_object, lhs_address, tracing_data->current_paths->lhs_path); - ObjectPath rhs_attr_path = - GetAttrPath(tracing_data->rhs_object, rhs_address, tracing_data->current_paths->rhs_path); - - *tracing_data->first_mismatch = ObjectPathPair(lhs_attr_path, rhs_attr_path); - } -} - -template -/* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs, - const PathTracingData* tracing_data, - Optional paths) { - if (BaseValueEqual()(lhs, rhs)) { - return true; - } - - if (tracing_data && !tracing_data->first_mismatch->defined()) { - if (paths) { - *tracing_data->first_mismatch = paths.value(); - } else { - GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data); - } - } - return false; -} - -bool SEqualReducer::operator()(const double& lhs, const double& rhs, - Optional paths) const { - return CompareAttributeValues(lhs, rhs, tracing_data_, paths); -} - -bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs, - Optional paths) const { - return CompareAttributeValues(lhs, rhs, tracing_data_, paths); -} - -bool SEqualReducer::operator()(const Optional& lhs, const Optional& rhs, - Optional paths) const { - return CompareAttributeValues(lhs, rhs, tracing_data_, paths); -} - -bool SEqualReducer::operator()(const Optional& lhs, const Optional& rhs, - Optional paths) const { - return CompareAttributeValues(lhs, rhs, tracing_data_, paths); -} - -bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs, - Optional paths) const { - return CompareAttributeValues(lhs, rhs, tracing_data_, paths); -} - -bool SEqualReducer::operator()(const int& lhs, const int& rhs, - Optional paths) const { - return CompareAttributeValues(lhs, rhs, tracing_data_, paths); -} - -bool SEqualReducer::operator()(const bool& lhs, const bool& rhs, - Optional paths) const { - return CompareAttributeValues(lhs, rhs, tracing_data_, paths); -} - -bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs, - Optional paths) const { - return CompareAttributeValues(lhs, rhs, tracing_data_, paths); -} - -bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs, - Optional paths) const { - return CompareAttributeValues(lhs, rhs, tracing_data_, paths); -} - -bool SEqualReducer::AnyEqual(const ffi::Any& lhs, const ffi::Any& rhs, - Optional paths) const { - auto record_mismatch = [&]() { - if (tracing_data_ && !tracing_data_->first_mismatch->defined()) { - if (paths) { - *tracing_data_->first_mismatch = paths.value(); - } - } - }; - if (lhs.type_index() != rhs.type_index()) { - record_mismatch(); - return false; - } - if (lhs.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - if (paths) { - return operator()(lhs.cast(), rhs.cast(), paths.value()); - } else { - ObjectRef lhs_obj = lhs.cast(); - ObjectRef rhs_obj = rhs.cast(); - bool result = operator()(lhs_obj, rhs_obj); - return result; - } - } - - if (ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(lhs)->v_uint64 == - ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(rhs)->v_uint64) { - return true; - } - record_mismatch(); - return false; -} - -bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, - const void* rhs_address, Optional paths) const { - if (lhs == rhs) { - return true; - } - - if (tracing_data_ && !tracing_data_->first_mismatch->defined()) { - if (paths) { - *tracing_data_->first_mismatch = paths.value(); - } else { - GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data_); - } - } - - return false; -} - -const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const { - ICHECK(tracing_data_ != nullptr) - << "GetCurrentObjectPaths() can only be called when path tracing is enabled"; - return tracing_data_->current_paths; -} - -void SEqualReducer::RecordMismatchPaths(const ObjectPathPair& paths) const { - ICHECK(tracing_data_ != nullptr) - << "RecordMismatchPaths() can only be called when path tracing is enabled"; - if (!tracing_data_->first_mismatch->defined()) { - *tracing_data_->first_mismatch = paths; - } -} - -bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, - const ObjectPathPair* paths) const { - if (tracing_data_ == nullptr) { - // Fast path: no tracing - return handler_->SEqualReduce(lhs, rhs, map_free_vars, std::nullopt); - } - - // Slow path: tracing object paths for better error reporting - ObjectPathPair new_paths = paths == nullptr ? tracing_data_->GetPathsForAttrs(lhs, rhs) : *paths; - - if (handler_->SEqualReduce(lhs, rhs, map_free_vars, new_paths)) { - return true; - } else { - if (!tracing_data_->first_mismatch->defined()) { - *tracing_data_->first_mismatch = new_paths; - } - return false; - } -} - -/*! - * \brief A non recursive stack based SEqual handler that can remaps vars. - * - * This handler pushs the Object equality cases into a stack, and - * traverses the stack to expand the necessary children that need to be checked. - * - * The order of SEqual being called is the same as the order as if we - * eagerly do recursive calls in SEqualReduce. - */ -class SEqualHandlerDefault::Impl { - public: - Impl(SEqualHandlerDefault* parent, bool assert_mode, Optional* first_mismatch, - bool defer_fails) - : parent_(parent), - assert_mode_(assert_mode), - first_mismatch_(first_mismatch), - defer_fails_(defer_fails) {} - - bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, - const Optional& current_paths) { - // We cannot use check lhs.same_as(rhs) to check equality. - // if we choose to enable var remapping. - // - // Counter example below (%x, %y) are shared vars - // between the two functions(possibly before/after rewriting). - // - // - function0: fn (%x, %y) { %x + %y } - // - function1. fn (%y, %x) { %x + %y } - // - // Because we choose to enable var remapping, - // %x is mapped to %y, and %y is mapped to %x, - // the body of the function no longer means the same thing. - // - // Take away: We can either choose only compare Var by address, - // in which case we can use same_as for quick checking, - // or we have to run deep comparison and avoid to use same_as checks. - auto run = [=]() { - std::optional early_result = [&]() -> std::optional { - if (!lhs.defined() && !rhs.defined()) return true; - if (!lhs.defined() && rhs.defined()) return false; - if (!rhs.defined() && lhs.defined()) return false; - if (lhs->type_index() != rhs->type_index()) return false; - auto it = equal_map_lhs_.find(lhs); - if (it != equal_map_lhs_.end()) { - return it->second.same_as(rhs); - } - if (equal_map_rhs_.count(rhs)) return false; - - return std::nullopt; - }(); - - if (early_result.has_value()) { - if (early_result.value()) { - return true; - } else if (IsPathTracingEnabled() && IsFailDeferralEnabled() && current_paths.defined()) { - DeferFail(current_paths.value()); - return true; - } else { - return false; - } - } - - // need to push to pending tasks in this case - pending_tasks_.emplace_back(lhs, rhs, map_free_vars, current_paths); - return true; - }; - return CheckResult(run(), lhs, rhs, current_paths); - } - - void DeferFail(const ObjectPathPair& mismatch_paths) { - pending_tasks_.emplace_back(Task::ForceFailTag{}, mismatch_paths); - } - - bool IsFailDeferralEnabled() { return defer_fails_; } - - void MarkGraphNode() { - // need to push to pending tasks in this case - ICHECK(!allow_push_to_stack_ && !task_stack_.empty()); - task_stack_.back().graph_equal = true; - } - - ObjectRef MapLhsToRhs(const ObjectRef& lhs) { - auto it = equal_map_lhs_.find(lhs); - if (it != equal_map_lhs_.end()) return it->second; - return lhs; - } - - // Function that implements actual equality check. - bool Equal(const ffi::Any& lhs, const ffi::Any& rhs, bool map_free_vars) { - task_stack_.clear(); - pending_tasks_.clear(); - equal_map_lhs_.clear(); - equal_map_rhs_.clear(); - root_lhs_ = lhs; - root_rhs_ = rhs; - Optional current_paths; - if (IsPathTracingEnabled()) { - auto root_path = ObjectPath::Root(); - current_paths = ObjectPathPair(root_path, root_path); - } - if (lhs.type_index() != rhs.type_index()) { - return CheckResult(false, lhs, rhs, current_paths); - } - - if (lhs.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - if (ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(lhs)->v_uint64 == - ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(rhs)->v_uint64) { - return true; - } - return CheckResult(false, lhs, rhs, current_paths); - } - - // normal object ref path - if (!SEqualReduce(lhs.cast(), rhs.cast(), map_free_vars, current_paths)) { - return false; - } - - ICHECK_EQ(pending_tasks_.size(), 1U); - ICHECK(allow_push_to_stack_); - task_stack_.emplace_back(std::move(pending_tasks_.back())); - pending_tasks_.clear(); - return RunTasks(); - } - - // The default equal as registered in the structural equal vtable. - bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, - const Optional& current_paths) { - auto compute = [=]() { - ICHECK(lhs.defined() && rhs.defined() && lhs->type_index() == rhs->type_index()); - // skip entries that already have equality maps. - auto it = equal_map_lhs_.find(lhs); - if (it != equal_map_lhs_.end()) { - return it->second.same_as(rhs); - } - if (equal_map_rhs_.count(rhs)) return false; - - if (!IsPathTracingEnabled()) { - return vtable_->SEqualReduce(lhs.get(), rhs.get(), - SEqualReducer(parent_, nullptr, map_free_vars)); - } else { - PathTracingData tracing_data = {current_paths.value(), lhs, rhs, first_mismatch_}; - return vtable_->SEqualReduce(lhs.get(), rhs.get(), - SEqualReducer(parent_, &tracing_data, map_free_vars)); - } - }; - return CheckResult(compute(), lhs, rhs, current_paths); - } - - protected: - // Check the result. - bool CheckResult(bool result, const Any& lhs, const Any& rhs, - const Optional& current_paths) { - if (IsPathTracingEnabled() && !result && !first_mismatch_->defined()) { - *first_mismatch_ = current_paths; - } - if (assert_mode_ && !result) { - std::ostringstream oss; - oss << "ValueError: StructuralEqual check failed, caused by lhs"; - if (first_mismatch_->defined()) { - oss << " at " << first_mismatch_->value()->lhs_path; - if (root_lhs_.has_value()) { - PrinterConfig cfg; - cfg->syntax_sugar = false; - cfg->path_to_underline.push_back(first_mismatch_->value()->lhs_path); - // The TVMScriptPrinter::Script will fallback to Repr printer, - // if the root node to print is not supported yet, - // e.g. Relax nodes, ArrayObj, MapObj, etc. - oss << ":" << std::endl - << TVMScriptPrinter::Script(root_lhs_.value().cast(), cfg); - } - } else { - oss << ":" << std::endl << lhs; - } - oss << std::endl << "and rhs"; - if (first_mismatch_->defined()) { - oss << " at " << first_mismatch_->value()->rhs_path; - if (root_rhs_.has_value()) { - PrinterConfig cfg; - cfg->syntax_sugar = false; - cfg->path_to_underline.push_back(first_mismatch_->value()->rhs_path); - // The TVMScriptPrinter::Script will fallback to Repr printer, - // if the root node to print is not supported yet, - // e.g. Relax nodes, ArrayObj, MapObj, etc. - oss << ":" << std::endl - << TVMScriptPrinter::Script(root_rhs_.value().cast(), cfg); - } - } else { - oss << ":" << std::endl << rhs; - } - LOG(FATAL) << oss.str(); - } - return result; - } - /*! - * \brief Run tasks until the stack reaches the stack begin - * \param stack_begin The expected beginning of the stack. - * \return The checks we encountered throughout the process. - */ - bool RunTasks() { - while (task_stack_.size() != 0) { - // Caution: entry becomes invalid when the stack changes - auto& entry = task_stack_.back(); - - if (entry.force_fail) { - return CheckResult(false, entry.lhs, entry.rhs, entry.current_paths); - } - - if (entry.children_expanded) { - // When all the children has expanded and visited. - // This means all the condition checks for - // the current entry has been passed - // We can safely mark lhs and rhs as equal to each other. - auto it = equal_map_lhs_.find(entry.lhs); - if (it != equal_map_lhs_.end()) { - ICHECK(it->second.same_as(entry.rhs)); - } - // create the map if the quality is graph equal. - if (entry.graph_equal) { - equal_map_lhs_[entry.lhs] = entry.rhs; - equal_map_rhs_[entry.rhs] = entry.lhs; - } - task_stack_.pop_back(); - } else { - // mark before expand - // Important: because entry becomes invalid when stack changes. - entry.children_expanded = true; - // Expand the objects - // The SEqual of the object can call into this->SEqualReduce - // which populates the pending tasks. - ICHECK_EQ(pending_tasks_.size(), 0U); - allow_push_to_stack_ = false; - if (!parent_->DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars, - entry.current_paths)) - return false; - allow_push_to_stack_ = true; - // Push pending tasks in reverse order, so earlier tasks get to - // expand first in the stack - while (pending_tasks_.size() != 0) { - task_stack_.emplace_back(std::move(pending_tasks_.back())); - pending_tasks_.pop_back(); - } - } - } - return true; - } - - private: - /*! \brief Pending reduce tasks. */ - struct Task { - /*! \brief The lhs operand to be compared. */ - ObjectRef lhs; - /*! \brief The rhs operand to be compared. */ - ObjectRef rhs; - /*! \brief If path tracing is enabled, paths taken so far from the root to `lhs` and `rhs` - * objects. */ - Optional current_paths; - /*! \brief The map free var argument. */ - bool map_free_vars; - /*! \brief Whether the children has been expanded via SEqualReduce */ - bool children_expanded{false}; - /*! \brief whether the task is about graph equality(need remap). */ - bool graph_equal{false}; - /*! \brief whether the task should return "false" without actually comparing anything */ - bool force_fail{false}; - - Task() = default; - Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars, Optional current_paths) - : lhs(lhs), - rhs(rhs), - current_paths(std::move(current_paths)), - map_free_vars(map_free_vars) {} - - struct ForceFailTag {}; // dispatch tag for the constructor below - Task(ForceFailTag, const ObjectPathPair& current_paths) - : current_paths(current_paths), force_fail(true) {} - }; - - bool IsPathTracingEnabled() const { return first_mismatch_ != nullptr; } - - // The owner of this impl - SEqualHandlerDefault* parent_; - // list of pending tasks to be pushed to the stack. - std::vector pending_tasks_; - // Internal task stack to executed the task. - std::vector task_stack_; - // Whether we allow push to stack. - bool allow_push_to_stack_{true}; - // If in assert mode, must return true, and will throw error otherwise. - bool assert_mode_{false}; - // Location to store the paths to the first detected mismatch, or nullptr to disable path - // tracing. - Optional* first_mismatch_; - // reflection vtable - ReflectionVTable* vtable_ = ReflectionVTable::Global(); - // map from lhs to rhs - std::unordered_map equal_map_lhs_; - // map from rhs to lhs - std::unordered_map equal_map_rhs_; - // root lhs for result printing - Optional root_lhs_; - // root rhs for result printing - Optional root_rhs_; - // whether to defer fails - bool defer_fails_; -}; - -SEqualHandlerDefault::SEqualHandlerDefault(bool assert_mode, - Optional* first_mismatch, - bool defer_fails) { - impl = new Impl(this, assert_mode, first_mismatch, defer_fails); -} - -SEqualHandlerDefault::~SEqualHandlerDefault() { delete impl; } - -bool SEqualHandlerDefault::SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, - bool map_free_vars, - const Optional& current_paths) { - return impl->SEqualReduce(lhs, rhs, map_free_vars, current_paths); -} - -void SEqualHandlerDefault::DeferFail(const ObjectPathPair& mismatch_paths) { - impl->DeferFail(mismatch_paths); -} - -bool SEqualHandlerDefault::IsFailDeferralEnabled() { return impl->IsFailDeferralEnabled(); } - -ObjectRef SEqualHandlerDefault::MapLhsToRhs(const ObjectRef& lhs) { return impl->MapLhsToRhs(lhs); } - -void SEqualHandlerDefault::MarkGraphNode() { impl->MarkGraphNode(); } - -bool SEqualHandlerDefault::Equal(const Any& lhs, const Any& rhs, bool map_free_vars) { - return impl->Equal(lhs, rhs, map_free_vars); -} - -bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, - bool map_free_vars, - const Optional& current_paths) { - return impl->DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths); -} - Optional ObjectPathPairFromAccessPathPair( Optional src) { if (!src.has_value()) return std::nullopt; @@ -699,42 +151,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); }); -bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs, +bool StructuralEqual::operator()(const ffi::Any& lhs, const ffi::Any& rhs, bool map_free_params) const { return ffi::reflection::StructuralEqual::Equal(lhs, rhs, map_free_params); } - -bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs, - SEqualReducer equal, bool compare_data) { - if (lhs == rhs) return true; - - auto ldt = lhs->dtype; - auto rdt = rhs->dtype; - ICHECK_EQ(lhs->device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK_EQ(rhs->device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK(runtime::IsContiguous(*lhs)) << "Can only compare contiguous tensor"; - ICHECK(runtime::IsContiguous(*rhs)) << "Can only compare contiguous tensor"; - - if (lhs->ndim != rhs->ndim) return false; - for (int i = 0; i < lhs->ndim; ++i) { - if (!equal(lhs->shape[i], rhs->shape[i])) return false; - } - if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { - size_t data_size = runtime::GetDataSize(*lhs); - if (compare_data) { - return std::memcmp(lhs->data, rhs->data, data_size) == 0; - } else { - return true; - } - } else { - return false; - } -} - -bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs, - const runtime::NDArray::Container* rhs, - SEqualReducer equal) { - return NDArrayEqual(lhs, rhs, equal, true); -} - } // namespace tvm diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 6fb8d3678454..3a5f1de04165 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -37,262 +37,9 @@ #include "../support/base64.h" #include "../support/str_escape.h" #include "../support/utils.h" -#include "ndarray_hash_equal.h" namespace tvm { -// Define the dispatch function here since primary user is in this file. -void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) const { - uint32_t tindex = self->type_index(); - if (tindex >= fshash_reduce_.size() || fshash_reduce_[tindex] == nullptr) { - LOG(FATAL) << "TypeError: SHashReduce of " << self->GetTypeKey() - << " is not registered via TVM_REGISTER_NODE_TYPE"; - } - fshash_reduce_[tindex](self, reducer); -} - -// Hash handler that handles free vars -// by assigning an unique counter in the order of their occurrence. -// -// This algorithm depends on the determinism of the traversal of SHash function. -// In particular, when we traverse unordered_map, we should first sort -// the entries by keys(or hash of keys) before traversing. - -class SHashHandlerDefault::Impl { - public: - explicit Impl(SHashHandlerDefault* parent) : parent_(parent) {} - - /*! \brief Pending reduce tasks. */ - struct Task { - /*! - * \brief The object operand to be hashed. - * If the object is nullptr, then the reduced hash is already set - * the correct value. - */ - ObjectRef object; - /*! \brief The partially reduce hash value.*/ - uint64_t reduced_hash; - /*! \brief The expected location in the result stack. */ - uint64_t result_stack_index = std::numeric_limits::max(); - /*! \brief Whether the children has been expanded via SEqualReduce */ - bool children_expanded{false}; - /*! \brief Whether the node is graph node. */ - bool graph_node_hash{false}; - /*! \brief whether to map the free variables. */ - bool map_free_vars; - - Task() = default; - explicit Task(ObjectRef object, uint64_t reduced_hash, bool map_free_vars) - : object(object), reduced_hash(reduced_hash), map_free_vars(map_free_vars) {} - }; - - void MarkGraphNode() { - // need to push to pending tasks in this case - ICHECK(!allow_push_to_stack_ && !task_stack_.empty()); - task_stack_.back().graph_node_hash = true; - } - - bool LookupHashedValue(const ObjectRef& key, uint64_t* hash_value) { - auto it = hash_memo_.find(key); - if (it != hash_memo_.end()) { - hash_value[0] = it->second; - return true; - } - return false; - } - - void SHashReduceHashedValue(uint64_t hashed_value) { - pending_tasks_.emplace_back(Task(ObjectRef(nullptr), hashed_value, false)); - } - - void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) { - ICHECK(!hash_memo_.count(GetRef(var))); - if (map_free_vars) { - // use counter value. - uint64_t value = std::hash()(free_var_counter_++); - pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false)); - } else { - // use pointer hash - uint64_t value = std::hash()(var); - pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false)); - } - } - - void SHashReduce(const ObjectRef& object, bool map_free_vars) { - // Directly push the result - // Note: it is still important to push the result to pending tasks - // so that the reduction order of hash values stays the same. - if (!object.defined()) { - pending_tasks_.emplace_back(Task(ObjectRef(nullptr), 0, false)); - return; - } - auto it = hash_memo_.find(object); - if (it != hash_memo_.end()) { - pending_tasks_.emplace_back(Task(ObjectRef(nullptr), it->second, false)); - } else { - // Push a pending task with initial value. - pending_tasks_.emplace_back(Task(object, object->GetTypeKeyHash(), map_free_vars)); - } - } - - uint64_t Hash(const Any& value, bool map_free_vars) { - if (value.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - return BaseValueHash().HashPODValueInAny(value); - } - ObjectRef object = value.cast(); - ICHECK_EQ(task_stack_.size(), 0U); - ICHECK_EQ(pending_tasks_.size(), 0U); - ICHECK_EQ(result_stack_.size(), 0U); - - this->SHashReduce(object, map_free_vars); - ICHECK_EQ(pending_tasks_.size(), 1U); - ICHECK(allow_push_to_stack_); - task_stack_.emplace_back(std::move(pending_tasks_.back())); - pending_tasks_.clear(); - - this->RunTasks(); - - ICHECK_EQ(result_stack_.size(), 1U); - uint64_t ret = result_stack_.back(); - result_stack_.pop_back(); - return ret; - } - - void DispatchSHash(const ObjectRef& object, bool map_free_vars) { - ICHECK(object.defined()); - vtable_->SHashReduce(object.get(), SHashReducer(parent_, map_free_vars)); - } - - protected: - /*! - * \brief Pop the top entry of the task stack and push the hash into the result stack. - */ - void PopTaskStack() { - const auto& entry = task_stack_.back(); - result_stack_.push_back(entry.reduced_hash); - task_stack_.pop_back(); - } - /*! - * \brief Compute the reduced hash value for the task. - * \param task The indicated task. - */ - uint64_t ReduceHash(const Task& task) { - uint64_t stack_begin = task.result_stack_index; - ICHECK_LE(stack_begin, result_stack_.size()); - - // combine in the reverse order of the stack. - uint64_t reduced_hash = task.reduced_hash; - for (uint32_t i = result_stack_.size(); i != stack_begin; --i) { - reduced_hash = support::HashCombine(reduced_hash, result_stack_[i - 1]); - } - result_stack_.resize(stack_begin); - return reduced_hash; - } - // run the tasks. - void RunTasks() { - while (task_stack_.size() != 0) { - // Caution: entry becomes invalid when the stack changes - auto& entry = task_stack_.back(); - if (entry.children_expanded) { - // reduce hash - entry.reduced_hash = ReduceHash(entry); - // When all the children has expanded and visited. - // entry.reduced_hash contains the reduced hash result. - auto it = hash_memo_.find(entry.object); - if (it != hash_memo_.end()) { - // use the pre-computed hash for the object. - entry.reduced_hash = it->second; - } else { - // Append the graph node counter to the hash - // so that we can distinguish DAG from trees. - if (entry.graph_node_hash) { - entry.reduced_hash = support::HashCombine(entry.reduced_hash, - std::hash()(graph_node_counter_++)); - } - hash_memo_[entry.object] = entry.reduced_hash; - } - // send value to parent. - this->PopTaskStack(); - } else if (!entry.object.defined()) { - // Directly send value to parent - this->PopTaskStack(); - } else { - // check if there are already hash for object. - auto it = hash_memo_.find(entry.object); - if (it != hash_memo_.end()) { - entry.reduced_hash = it->second; - this->PopTaskStack(); - } else { - // NOTE: important to modify entry before visit. - // as entry becomes invalid after we change the stack. - entry.children_expanded = true; - entry.result_stack_index = result_stack_.size(); - - ICHECK_EQ(pending_tasks_.size(), 0U); - allow_push_to_stack_ = false; - // dispatch hash, reduce to the current slot. - parent_->DispatchSHash(entry.object, entry.map_free_vars); - allow_push_to_stack_ = true; - // Move pending tasks to the stack until the marked point. - while (pending_tasks_.size() != 0) { - task_stack_.emplace_back(std::move(pending_tasks_.back())); - pending_tasks_.pop_back(); - } - } - } - } - } - - private: - // The owner of this impl - SHashHandlerDefault* parent_; - // free var counter. - uint32_t free_var_counter_{0}; - // graph node counter. - uint32_t graph_node_counter_{0}; - // record current stack top - bool allow_push_to_stack_{true}; - // list of pending tasks to be pushed to the stack. - std::vector pending_tasks_; - // Internal task stack to executed the task - std::vector task_stack_; - // Internal stack to store the result popped from the task stack. - std::vector result_stack_; - // reflection vtable - ReflectionVTable* vtable_ = ReflectionVTable::Global(); - // map from lhs to rhs - std::unordered_map hash_memo_; -}; - -SHashHandlerDefault::SHashHandlerDefault() { impl = new Impl(this); } -SHashHandlerDefault::~SHashHandlerDefault() { delete impl; } - -void SHashHandlerDefault::SHashReduceHashedValue(uint64_t hashed_value) { - return impl->SHashReduceHashedValue(hashed_value); -} - -void SHashHandlerDefault::SHashReduce(const ObjectRef& key, bool map_free_vars) { - impl->SHashReduce(key, map_free_vars); -} - -void SHashHandlerDefault::SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) { - impl->SHashReduceFreeVar(var, map_free_vars); -} - -bool SHashHandlerDefault::LookupHashedValue(const ObjectRef& key, uint64_t* hashed_value) { - return impl->LookupHashedValue(key, hashed_value); -} - -void SHashHandlerDefault::MarkGraphNode() { impl->MarkGraphNode(); } - -uint64_t SHashHandlerDefault::Hash(const Any& object, bool map_free_vars) { - return impl->Hash(object, map_free_vars); -} - -void SHashHandlerDefault::DispatchSHash(const ObjectRef& key, bool map_free_vars) { - impl->DispatchSHash(key, map_free_vars); -} - TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("node.StructuralHash", @@ -301,56 +48,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); }); -uint64_t StructuralHash::operator()(const ObjectRef& object) const { +uint64_t StructuralHash::operator()(const ffi::Any& object) const { return ffi::reflection::StructuralHash::Hash(object, false); } -void SHashHandlerIgnoreNDArray::DispatchSHash(const ObjectRef& object, bool map_free_vars) { - ICHECK(object.defined()); - if (auto ndarray = object.as()) { - SHashReducer hash_reduce(this, map_free_vars); - NDArrayHash(ndarray, &hash_reduce, false); - } else { - SHashHandlerDefault::DispatchSHash(object, map_free_vars); - } -} - -// SEQualReduce traits for runtime containers. -struct StringObjTrait { - static void SHashReduce(const ffi::StringObj* key, SHashReducer hash_reduce) { - hash_reduce->SHashReduceHashedValue(ffi::details::StableHashBytes(key->data, key->size)); - } - - static bool SEqualReduce(const ffi::StringObj* lhs, const ffi::StringObj* rhs, - SEqualReducer equal) { - if (lhs == rhs) return true; - if (lhs->size != rhs->size) return false; - if (lhs->data == rhs->data) return true; - return std::memcmp(lhs->data, rhs->data, lhs->size) == 0; - } -}; - -struct BytesObjTrait { - static void SHashReduce(const ffi::BytesObj* key, SHashReducer hash_reduce) { - hash_reduce->SHashReduceHashedValue(ffi::details::StableHashBytes(key->data, key->size)); - } - - static bool SEqualReduce(const ffi::BytesObj* lhs, const ffi::BytesObj* rhs, - SEqualReducer equal) { - if (lhs == rhs) return true; - if (lhs->size != rhs->size) return false; - if (lhs->data == rhs->data) return true; - return std::memcmp(lhs->data, rhs->data, lhs->size) == 0; - } -}; - struct RefToObjectPtr : public ObjectRef { static ObjectPtr Get(const ObjectRef& ref) { return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(ref); } }; -TVM_REGISTER_REFLECTION_VTABLE(ffi::StringObj, StringObjTrait) +TVM_REGISTER_REFLECTION_VTABLE(ffi::StringObj) .set_creator([](const std::string& bytes) { return RefToObjectPtr::Get(String(bytes)); }) .set_repr_bytes([](const Object* n) -> std::string { return GetRef(static_cast(n)).operator std::string(); @@ -362,7 +70,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; }); -TVM_REGISTER_REFLECTION_VTABLE(ffi::BytesObj, BytesObjTrait) +TVM_REGISTER_REFLECTION_VTABLE(ffi::BytesObj) .set_creator([](const std::string& bytes) { return RefToObjectPtr::Get(String(bytes)); }) .set_repr_bytes([](const Object* n) -> std::string { return GetRef(static_cast(n)).operator std::string(); @@ -374,12 +82,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "b\"" << support::StrEscape(op->data, op->size) << '"'; }); -struct ModuleNodeTrait { - static constexpr const std::nullptr_t SHashReduce = nullptr; - static constexpr const std::nullptr_t SEqualReduce = nullptr; -}; - -TVM_REGISTER_REFLECTION_VTABLE(runtime::ModuleNode, ModuleNodeTrait) +TVM_REGISTER_REFLECTION_VTABLE(runtime::ModuleNode) .set_creator([](const std::string& blob) { runtime::Module rtmod = codegen::DeserializeModuleFromBytes(blob); return RefToObjectPtr::Get(rtmod); @@ -389,28 +92,7 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::ModuleNode, ModuleNodeTrait) return codegen::SerializeModuleToBytes(GetRef(rtmod), /*export_dso*/ false); }); -void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce, - bool hash_data) { - ICHECK_EQ(arr->device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK(runtime::IsContiguous(*arr)) << "Can only hash contiguous tensor"; - (*hash_reduce)(runtime::DataType(arr->dtype)); - (*hash_reduce)(arr->ndim); - for (int i = 0; i < arr->ndim; ++i) { - (*hash_reduce)(arr->shape[i]); - } - if (hash_data) { - (*hash_reduce) - ->SHashReduceHashedValue(ffi::details::StableHashBytes(static_cast(arr->data), - runtime::GetDataSize(*arr))); - } -} - -void NDArrayContainerTrait::SHashReduce(const runtime::NDArray::Container* key, - SHashReducer hash_reduce) { - NDArrayHash(key, &hash_reduce, /*bool hash_data*/ true); -} - -TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait) +TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container) .set_creator([](const std::string& blob) { dmlc::MemoryStringStream mstrm(const_cast(&blob)); support::Base64InStream b64strm(&mstrm); @@ -429,110 +111,12 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrai return blob; }); -struct ArrayObjTrait { - static void SHashReduce(const ffi::ArrayObj* key, SHashReducer hash_reduce) { - hash_reduce(static_cast(key->size())); - for (uint32_t i = 0; i < key->size(); ++i) { - hash_reduce(key->at(i)); - } - } - - static bool SEqualReduce(const ffi::ArrayObj* lhs, const ffi::ArrayObj* rhs, - SEqualReducer equal) { - if (equal.IsPathTracingEnabled()) { - return SEqualReduceTraced(lhs, rhs, equal); - } - - if (lhs->size() != rhs->size()) return false; - for (uint32_t i = 0; i < lhs->size(); ++i) { - if (!equal.AnyEqual(lhs->at(i), rhs->at(i))) return false; - } - return true; - } - - private: - static bool SEqualReduceTraced(const ffi::ArrayObj* lhs, const ffi::ArrayObj* rhs, - const SEqualReducer& equal) { - uint32_t min_size = std::min(lhs->size(), rhs->size()); - const ObjectPathPair& array_paths = equal.GetCurrentObjectPaths(); - - for (uint32_t index = 0; index < min_size; ++index) { - ObjectPathPair element_paths = {array_paths->lhs_path->ArrayIndex(index), - array_paths->rhs_path->ArrayIndex(index)}; - if (!equal.AnyEqual(lhs->at(index), rhs->at(index), element_paths)) { - return false; - } - } - - if (lhs->size() == rhs->size()) { - return true; - } - - // If the array length is mismatched, don't report it immediately. - // Instead, defer the failure until we visit all children. - // - // This is for human readability. For example, say we have two sequences - // - // (1) a b c d e f g h i j k l m - // (2) a b c d e g h i j k l m - // - // If we directly report a mismatch at the end of the array right now, - // the user will see that array (1) has an element `m` at index 12 but array (2) - // has no index 12 because it's too short: - // - // (1) a b c d e f g h i j k l m - // ^error here - // (2) a b c d e g h i j k l m - // ^ error here - // - // This is not very helpful. Instead, if we defer reporting this mismatch until all elements - // are fully visited, we can be much more helpful with pointing out the location: - // - // (1) a b c d e f g h i j k l m - // ^ - // error here - // - // (2) a b c d e g h i j k l m - // ^ - // error here - if (equal->IsFailDeferralEnabled()) { - if (lhs->size() > min_size) { - equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size), - array_paths->rhs_path->MissingArrayElement(min_size)}); - } else { - equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size), - array_paths->rhs_path->ArrayIndex(min_size)}); - } - // Can return `true` pretending that everything is good since we have deferred the failure. - return true; - } - return false; - } -}; -TVM_REGISTER_REFLECTION_VTABLE(ffi::ArrayObj, ArrayObjTrait) +TVM_REGISTER_REFLECTION_VTABLE(ffi::ArrayObj) .set_creator([](const std::string&) -> ObjectPtr { return ffi::make_object(); }); -struct ShapeObjTrait { - static void SHashReduce(const ffi::ShapeObj* self, SHashReducer hash_reduce) { - hash_reduce(static_cast(self->size)); - for (uint32_t i = 0; i < self->size; ++i) { - hash_reduce(self->data[i]); - } - } - - static bool SEqualReduce(const ffi::ShapeObj* lhs, const ffi::ShapeObj* rhs, - SEqualReducer equal) { - if (lhs->size != rhs->size) return false; - for (uint32_t i = 0; i < lhs->size; ++i) { - if (!equal(lhs->data[i], rhs->data[i])) return false; - } - return true; - } -}; - -TVM_REGISTER_REFLECTION_VTABLE(ffi::ShapeObj, ShapeObjTrait) +TVM_REGISTER_REFLECTION_VTABLE(ffi::ShapeObj) .set_creator([](const std::string& blob) { // Store shape tuple in blob to avoid large integer overflow in JSON. dmlc::MemoryStringStream mstrm(const_cast(&blob)); @@ -556,139 +140,7 @@ TVM_REGISTER_REFLECTION_VTABLE(ffi::ShapeObj, ShapeObjTrait) return blob; }); -struct MapObjTrait { - static void SHashReduceForOMap(const ffi::MapObj* key, SHashReducer hash_reduce) { - // SHash's var handling depends on the determinism of traversal. - // NOTE: only book-keep the mapped hash keys. - // This resolves common use cases where we want to store - // Map where Var is defined in the function - // parameters. - using KV = std::pair; - std::vector temp; - for (const auto& kv : *key) { - uint64_t hashed_value; - // skip non-object keys - if (kv.first.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - if (hash_reduce->LookupHashedValue(kv.first.cast(), &hashed_value)) { - temp.emplace_back(hashed_value, kv.second); - } - } - } - // sort by the hash key of the keys. - std::sort(temp.begin(), temp.end(), - [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); - // add size to the hash - hash_reduce(static_cast(key->size())); - // hash the content - for (uint32_t i = 0; i < temp.size();) { - uint32_t k = i + 1; - for (; k < temp.size() && temp[k].first == temp[i].first; ++k) { - } - // ties are rare, but we need to skip them to make the hash deterministic - if (k == i + 1) { - hash_reduce->SHashReduceHashedValue(temp[i].first); - hash_reduce(temp[i].second); - } - i = k; - } - } - - static void SHashReduceForSMap(const ffi::MapObj* key, SHashReducer hash_reduce) { - // NOTE: only book-keep the mapped hash keys. - // This resolves common use cases where we want to store - // Map where Var is defined in the function - // parameters. - using KV = std::pair; - std::vector temp; - for (const auto& kv : *key) { - temp.push_back(std::make_pair(kv.first.cast(), kv.second)); - } - // sort by the hash key of the keys. - std::sort(temp.begin(), temp.end(), - [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); - // NOTE: we won't have ties - // add size to the hash after sorting. - hash_reduce(static_cast(key->size())); - // hash the content - for (uint32_t i = 0; i < temp.size(); ++i) { - hash_reduce(temp[i].first); - hash_reduce(temp[i].second); - } - } - - static void SHashReduce(const ffi::MapObj* key, SHashReducer hash_reduce) { - bool is_str_map = std::all_of(key->begin(), key->end(), [](const auto& v) { - return v.first.template as(); - }); - if (is_str_map) { - SHashReduceForSMap(key, hash_reduce); - } else { - SHashReduceForOMap(key, hash_reduce); - } - } - - static bool SEqualReduceTraced(const ffi::MapObj* lhs, const ffi::MapObj* rhs, - const SEqualReducer& equal) { - const ObjectPathPair& map_paths = equal.GetCurrentObjectPaths(); - // First, check that every key from `lhs` is also in `rhs`, - // and their values are mapped to each other. - for (const auto& kv : *lhs) { - ObjectPath lhs_path = map_paths->lhs_path->MapValue(kv.first); - Any rhs_key = equal->MapLhsToRhs(kv.first); - auto it = rhs->find(rhs_key); - if (it == rhs->end()) { - equal.RecordMismatchPaths({lhs_path, map_paths->rhs_path->MissingMapEntry()}); - return false; - } - - if (!equal.AnyEqual(kv.second, it->second, - ObjectPathPair({lhs_path, map_paths->rhs_path->MapValue(it->first)}))) { - return false; - } - } - // fast path, lhs equals rhs - if (lhs->size() == rhs->size()) return true; - // slow path check what rhs keys are missing in lhs - std::unordered_set seen_rhs_keys; - for (const auto& kv : *lhs) { - ObjectPath lhs_path = map_paths->lhs_path->MapValue(kv.first); - Any rhs_key = equal->MapLhsToRhs(kv.first); - seen_rhs_keys.insert(rhs_key); - } - // Second, check that we have visited every `rhs` key when iterating over `lhs`. - for (const auto& kv : *rhs) { - if (!seen_rhs_keys.count(kv.first)) { - equal.RecordMismatchPaths( - {map_paths->lhs_path->MissingMapEntry(), map_paths->rhs_path->MapValue(kv.first)}); - return false; - } - } - LOG(FATAL) << "not reached"; - TVM_FFI_UNREACHABLE(); - } - - static bool SEqualReduce(const ffi::MapObj* lhs, const ffi::MapObj* rhs, SEqualReducer equal) { - if (equal.IsPathTracingEnabled()) { - return SEqualReduceTraced(lhs, rhs, equal); - } - - if (rhs->size() != lhs->size()) return false; - if (rhs->size() == 0) return true; - - for (const auto& kv : *lhs) { - // Only allow equal checking if the keys are already mapped - // This resolves common use cases where we want to store - // Map where Var is defined in the function - // parameters. - Any rhs_key = equal->MapLhsToRhs(kv.first); - auto it = rhs->find(rhs_key); - if (it == rhs->end()) return false; - if (!equal.AnyEqual(kv.second, it->second)) return false; - } - return true; - } -}; -TVM_REGISTER_REFLECTION_VTABLE(ffi::MapObj, MapObjTrait) +TVM_REGISTER_REFLECTION_VTABLE(ffi::MapObj) .set_creator([](const std::string&) -> ObjectPtr { return ffi::MapObj::Empty(); }); struct ReportNodeTrait { @@ -699,12 +151,10 @@ struct ReportNodeTrait { .def_ro("device_metrics", &runtime::profiling::ReportNode::device_metrics) .def_ro("configuration", &runtime::profiling::ReportNode::configuration); } - - static constexpr std::nullptr_t SEqualReduce = nullptr; - static constexpr std::nullptr_t SHashReduce = nullptr; }; + TVM_FFI_STATIC_INIT_BLOCK({ ReportNodeTrait::RegisterReflection(); }); -TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::ReportNode, ReportNodeTrait); +TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::ReportNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -718,14 +168,11 @@ struct CountNodeTrait { refl::ObjectDef().def_ro("value", &runtime::profiling::CountNode::value); } - - static constexpr std::nullptr_t SEqualReduce = nullptr; - static constexpr std::nullptr_t SHashReduce = nullptr; }; TVM_FFI_STATIC_INIT_BLOCK({ CountNodeTrait::RegisterReflection(); }); -TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::CountNode, CountNodeTrait); +TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::CountNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -738,9 +185,6 @@ struct DurationNodeTrait { refl::ObjectDef().def_ro( "microseconds", &runtime::profiling::DurationNode::microseconds); } - - static constexpr std::nullptr_t SEqualReduce = nullptr; - static constexpr std::nullptr_t SHashReduce = nullptr; }; TVM_FFI_STATIC_INIT_BLOCK({ DurationNodeTrait::RegisterReflection(); }); @@ -750,7 +194,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) auto* op = static_cast(node.get()); p->stream << op->GetTypeKey() << "(" << op->microseconds << ")"; }); -TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::DurationNode, DurationNodeTrait); +TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::DurationNode); struct PercentNodeTrait { static void RegisterReflection() { @@ -758,14 +202,11 @@ struct PercentNodeTrait { refl::ObjectDef().def_ro( "percent", &runtime::profiling::PercentNode::percent); } - - static constexpr std::nullptr_t SEqualReduce = nullptr; - static constexpr std::nullptr_t SHashReduce = nullptr; }; TVM_FFI_STATIC_INIT_BLOCK({ PercentNodeTrait::RegisterReflection(); }); -TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::PercentNode, PercentNodeTrait); +TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::PercentNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -778,14 +219,11 @@ struct RatioNodeTrait { refl::ObjectDef().def_ro("ratio", &runtime::profiling::RatioNode::ratio); } - - static constexpr std::nullptr_t SEqualReduce = nullptr; - static constexpr std::nullptr_t SHashReduce = nullptr; }; TVM_FFI_STATIC_INIT_BLOCK({ RatioNodeTrait::RegisterReflection(); }); -TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::RatioNode, RatioNodeTrait); +TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::RatioNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 6a972f518f05..037a9f3021fb 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -38,8 +39,6 @@ #include #include -#include "../../node/ndarray_hash_equal.h" - // Block builder have three categories of logics that are interdependent with each other. // // The logics are somewhat interdependent with each other. @@ -429,12 +428,11 @@ class BlockBuilderImpl : public BlockBuilderNode { } /*! \brief A custom structural hashing that ignores NDArray raw data. */ - class StructuralHashIgnoreNDarray : public BaseValueHash { + class StructuralHashIgnoreNDarray { public: - using BaseValueHash::operator(); - uint64_t operator()(const ObjectRef& key) const { - return SHashHandlerIgnoreNDArray().Hash(key, false); + return ffi::reflection::StructuralHash::Hash(key, /*map_free_vars=*/false, + /*skip_ndarray_content=*/true); } }; diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index c905b8730571..0bd3606cd216 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -450,33 +450,6 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); }); -bool MatchCastNode::SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const { - if (value->IsInstance()) { - // Recursive function definitions may reference the bound variable - // within the value being bound. In these cases, the - // `DefEqual(var, other->var)` must occur first, to ensure it is - // defined at point of use. - return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) && - equal(value, other->value); - } else { - // In all other cases, visit the bound value before the variable - // it is bound to, in order to provide better error messages. - return equal(value, other->value) && equal.DefEqual(struct_info, other->struct_info) && - equal.DefEqual(var, other->var); - } -} -void MatchCastNode::SHashReduce(SHashReducer hash_reduce) const { - if (value->IsInstance()) { - hash_reduce.DefHash(var); - hash_reduce.DefHash(struct_info); - hash_reduce(value); - } else { - hash_reduce(value); - hash_reduce.DefHash(struct_info); - hash_reduce.DefHash(var); - } -} - TVM_REGISTER_NODE_TYPE(VarBindingNode); VarBinding::VarBinding(Var var, Expr value, Span span) { @@ -494,29 +467,6 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); }); -bool VarBindingNode::SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { - if (value->IsInstance()) { - // Recursive function definitions may reference the bound variable - // within the value being bound. In these cases, the - // `DefEqual(var, other->var)` must occur first, to ensure it is - // defined at point of use. - return equal.DefEqual(var, other->var) && equal(value, other->value); - } else { - // In all other cases, visit the bound value before the variable - // it is bound to, in order to provide better error messages. - return equal(value, other->value) && equal.DefEqual(var, other->var); - } -} -void VarBindingNode::SHashReduce(SHashReducer hash_reduce) const { - if (value->IsInstance()) { - hash_reduce.DefHash(var); - hash_reduce(value); - } else { - hash_reduce(value); - hash_reduce.DefHash(var); - } -} - bool VarBindingNode::SEqual(const VarBindingNode* other, ffi::TypedFunction equal) const { if (value->IsInstance()) { diff --git a/src/target/target.cc b/src/target/target.cc index a0e86330d7a6..ac82c51b1b78 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -710,19 +710,6 @@ String TargetNode::ToDebugString() const { return os.str(); } -bool TargetNode::SEqualReduce(const TargetNode* other, SEqualReducer equal) const { - return equal(kind.get(), other->kind.get()) && equal(host, other->host) && - equal(tag, other->tag) && equal(keys, other->keys) && equal(attrs, other->attrs); -} - -void TargetNode::SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(kind.get()); - hash_reduce(host); - hash_reduce(tag); - hash_reduce(keys); - hash_reduce(attrs); -} - /*! \brief Entry to hold the Target context stack. */ struct TVMTargetThreadLocalEntry { /*! \brief The current target context */ diff --git a/src/tir/transforms/extract_constants.cc b/src/tir/transforms/extract_constants.cc index 75451bc0495f..301c6c13b9f0 100644 --- a/src/tir/transforms/extract_constants.cc +++ b/src/tir/transforms/extract_constants.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include #include "ir_utils.h" @@ -40,12 +41,9 @@ class Applicator : public tir::StmtMutator { protected: // returns index of the a in constant_array_, if not found - appends size_t DeDup(const runtime::NDArray& a) { - tvm::SEqualReducer eql; - auto it = std::find_if( - constant_array_.begin(), constant_array_.end(), [&eql, a](const runtime::NDArray& v) { - return NDArrayContainerTrait::SEqualReduce(a.as(), - v.as(), eql); - }); + tvm::StructuralEqual eql; + auto it = std::find_if(constant_array_.begin(), constant_array_.end(), + [&eql, a](const runtime::NDArray& v) { return eql(a, v); }); if (it != constant_array_.end()) { return it - constant_array_.begin(); } diff --git a/tests/python/ir/test_container_structural_equal.py b/tests/python/ir/test_container_structural_equal.py index 238a77b4ef4b..84556aab6b27 100644 --- a/tests/python/ir/test_container_structural_equal.py +++ b/tests/python/ir/test_container_structural_equal.py @@ -174,10 +174,5 @@ def test_string_structural_equal_to_self(contents): assert get_first_mismatch_ensure_symmetry(a, b) is None -# The behavior of structural equality for maps with non-string keys is fairly specific -# to IR variables because it assumes that map keys have been "mapped" using -# `SEqualReducer::FreeVarEqualImpl()`. So we leave this case to TIR tests. - - if __name__ == "__main__": tvm.testing.main()