Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NODE][IR] Introduce StructuralHash for the Unified IR. #5160

Merged
merged 4 commits into from
Mar 28, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/tvm/ir/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ class ConstructorNode : public RelayExprNode {
equal(inputs, other->inputs);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name_hint);
hash_reduce(inputs);
}

static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
};
Expand Down Expand Up @@ -123,6 +128,12 @@ class TypeDataNode : public TypeNode {
equal(constructors, other->constructors);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(header);
hash_reduce.DefHash(type_vars);
hash_reduce(constructors);
}

static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
Expand Down
26 changes: 26 additions & 0 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class AttrFieldInfoNode : public Object {

static constexpr const char* _type_key = "AttrFieldInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
static constexpr bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};

Expand Down Expand Up @@ -281,6 +282,7 @@ class BaseAttrsNode : public Object {
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};
Expand Down Expand Up @@ -309,6 +311,10 @@ class DictAttrsNode : public BaseAttrsNode {
return equal(dict, other->dict);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dict);
}

// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
Expand Down Expand Up @@ -452,6 +458,21 @@ class AttrsHashVisitor {
const AttrsHash& hasher_;
};

class AttrsSHashVisitor {
public:
explicit AttrsSHashVisitor(const SHashReducer& hash_reducer)
: hash_reducer_(hash_reducer) {}

template<typename T>
AttrNopEntry operator()(const char* key, T* value) {
hash_reducer_(*value);
return AttrNopEntry();
}

private:
const SHashReducer& hash_reducer_;
};

// helper entry that does initialization, set default.
template<typename T>
struct AttrInitEntry {
Expand Down Expand Up @@ -858,6 +879,11 @@ class AttrsNode : public BaseAttrsNode {
return visitor.result_;
}

void SHashReduce(SHashReducer hash_reducer) const {
::tvm::detail::AttrsSHashVisitor visitor(hash_reducer);
self()->__VisitAttrs__(visitor);
}

Array<AttrFieldInfo> ListFieldInfo() const final {
::tvm::detail::AttrDocVisitor visitor;
self()->__VisitAttrs__(visitor);
Expand Down
9 changes: 8 additions & 1 deletion include/tvm/ir/env_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,18 @@ class EnvFuncNode : public Object {
}

bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
return this == other;
// name uniquely identifies the env function.
return name == other->name;
}

void SHashReduce(SHashReducer hash_reduce) const {
// Name uniquely identifies the env function.
hash_reduce(name);
}

static constexpr const char* _type_key = "EnvFunc";
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
};

Expand Down
22 changes: 22 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};

Expand Down Expand Up @@ -205,6 +206,11 @@ class GlobalVarNode : public RelayExprNode {
equal.FreeVarEqualImpl(this, other);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name_hint);
hash_reduce.FreeVarHashImpl(this);
}

static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};
Expand Down Expand Up @@ -240,6 +246,11 @@ class IntImmNode : public PrimExprNode {
return equal(dtype, other->dtype) && equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(value);
}

static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};
Expand Down Expand Up @@ -279,6 +290,11 @@ class FloatImmNode : public PrimExprNode {
return equal(dtype, other->dtype) && equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(value);
}

static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
Expand Down Expand Up @@ -373,8 +389,14 @@ class RangeNode : public Object {
return equal(min, other->min) && equal(extent, other->extent);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(min);
hash_reduce(extent);
}

static constexpr const char* _type_key = "Range";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
};

Expand Down
3 changes: 3 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class IRModuleNode : public Object {

TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;

TVM_DLL void SHashReduce(SHashReducer hash_reduce) const;

/*!
* \brief Add a function to the global environment.
* \param var The var of the global function.
Expand Down Expand Up @@ -238,6 +240,7 @@ class IRModuleNode : public Object {

static constexpr const char* _type_key = "IRModule";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);

private:
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ class OpNode : public RelayExprNode {
return this == other;
}

void SHashReduce(SHashReducer hash_reduce) const {
// Name uniquely identifies an Op.
hash_reduce(name);
}

/*!
* \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/ir/tensor_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ class TensorTypeNode : public BaseTensorTypeNode {
equal(dtype, other->dtype);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(shape);
hash_reduce(dtype);
}

/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
Expand Down
38 changes: 38 additions & 0 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class TypeNode : public Object {

static constexpr const char* _type_key = "Type";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};

Expand Down Expand Up @@ -115,6 +116,10 @@ class PrimTypeNode : public TypeNode {
return equal(dtype, other->dtype);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
}

static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
Expand Down Expand Up @@ -161,6 +166,10 @@ class PointerTypeNode : public TypeNode {
return equal(element_type, other->element_type);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(element_type);
}

static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
};
Expand Down Expand Up @@ -233,6 +242,11 @@ class TypeVarNode : public TypeNode {
equal.FreeVarEqualImpl(this, other);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(kind);
hash_reduce.FreeVarHashImpl(this);
}

static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
Expand Down Expand Up @@ -280,6 +294,11 @@ class GlobalTypeVarNode : public TypeNode {
equal.FreeVarEqualImpl(this, other);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name_hint);
hash_reduce.FreeVarHashImpl(this);
}

static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
Expand Down Expand Up @@ -320,6 +339,10 @@ class TupleTypeNode : public TypeNode {
return equal(fields, other->fields);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(fields);
}

static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
Expand Down Expand Up @@ -421,6 +444,13 @@ class FuncTypeNode : public TypeNode {
equal(type_constraints, other->type_constraints);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(type_params);
hash_reduce(arg_types);
hash_reduce(ret_type);
hash_reduce(type_constraints);
}

static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
Expand Down Expand Up @@ -471,6 +501,10 @@ class IncompleteTypeNode : public TypeNode {
return equal(kind, other->kind);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(kind);
}

static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
Expand Down Expand Up @@ -512,6 +546,10 @@ class RelayRefTypeNode : public TypeNode {
return equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(value);
}

// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static constexpr const char* _type_key = "relay.RefType";
Expand Down
12 changes: 12 additions & 0 deletions include/tvm/ir/type_relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ class TypeCallNode : public TypeNode {
equal(args, other->args);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(args);
}

static constexpr const char* _type_key = "TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
Expand Down Expand Up @@ -209,6 +214,13 @@ class TypeRelationNode : public TypeConstraintNode {
equal(attrs, other->attrs);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(args);
hash_reduce(num_inputs);
hash_reduce(attrs);
}

static constexpr const char* _type_key = "TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};
Expand Down
1 change: 1 addition & 0 deletions include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <tvm/node/repr_printer.h>
#include <tvm/node/container.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>

#include <string>
#include <vector>
Expand Down
Loading