diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 14b3cd91701c6..a02d74ae0832c 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -486,6 +486,37 @@ std::string RelayPrint( const NodeRef& node, bool show_meta_data = true, runtime::TypedPackedFunc annotate = nullptr); + +/*! + * \brief User defined type relation, is an input-output relation on types. + */ +class TypeOf; +/*! + * \brief TypeRelation container. + * \note This node is not directly serializable. + * The type function need to be lookedup in the module. + */ +class TypeOfNode : public TypeNode { + public: + /*! + * \brief The function on input and output variables which + * this is not directly serializable, + * need to be looked-up in the module. + */ + relay::Expr expr; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("expr", &expr); + } + + TVM_DLL static TypeOf make(relay::Expr expr); + + static constexpr const char* _type_key = "relay.TypeOf"; + TVM_DECLARE_NODE_TYPE_INFO(TypeOfNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeOf, TypeOfNode, Type); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 13526b9ce3634..90b4663f7c026 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -48,6 +48,7 @@ scalar_type = ty.scalar_type GlobalTypeVar = ty.GlobalTypeVar TypeCall = ty.TypeCall +TypeOf = ty.TypeOf # Expr Expr = expr.Expr diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 3aef57cff6dc7..7df07cf72714b 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -6,7 +6,7 @@ from .op import register_gradient from .op import schedule_injective, OpPattern from .transform import collapse_sum_like -from .tensor import negative +from .tensor import negative, zeros_like, ones_like def add_grad(orig, grad): @@ -29,6 +29,12 @@ def multiply_grad(orig, grad): register_gradient("multiply", multiply_grad) +def take_grad(orig, grad): + return [zeros_like(orig.args[0]), zeros_like(orig.args[1])] + + +register_gradient("take", take_grad) + schedule_broadcast = schedule_injective schedule_elemwise = schedule_injective diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 619df4b966651..24eeacf55b443 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -270,6 +270,10 @@ def __init__(self, func, args, num_inputs, attrs): self.__init_handle_by_constructor__(_make.TypeRelation, func, args, num_inputs, attrs) +@register_relay_node +class TypeOf(Type): + def __init__(self, expr): + self.__init_handle_by_constructor__(_make.TypeOf, expr) def scalar_type(dtype): """Creates a scalar type. diff --git a/src/relay/ir/error.cc b/src/relay/ir/error.cc index c8a8894205c09..6d54c9b7008ca 100644 --- a/src/relay/ir/error.cc +++ b/src/relay/ir/error.cc @@ -87,9 +87,9 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) { std::endl << rang::style::reset; - // for (auto pair : err_map) { - // std::cout << "Key: " << pair.first << " Value: " << pair.second << std::endl; - // } + for (auto pair : err_map) { + std::cout << "Key: " << pair.first << std::endl << " Value: " << pair.second << std::endl; + } // We then call into the Relay printer to generate the program. // diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 3af619c70643f..92f799b21f93f 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -283,6 +283,24 @@ TVM_REGISTER_API("relay._expr.TempExprRealize") *ret = temp->Realize(); }); +TypeOf TypeOfNode::make(relay::Expr expr) { + NodePtr n = make_node(); + n->expr = std::move(expr); + return TypeOf(n); +} + +TVM_REGISTER_NODE_TYPE(TypeOfNode); + +TVM_REGISTER_API("relay._make.TypeOf") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TypeOfNode::make(args[0]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const TypeOfNode* node, + tvm::IRPrinter* p) { + p->stream << "TypeOf(" << node->expr << ")"; +}); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index bde9e53dc3297..32bc8a0d16127 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -207,5 +207,6 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TupleTypeNode(" << node->fields << ")"; }); + } // namespace relay } // namespace tvm diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 9c47a2fb58de8..a6ea13cef1d2a 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -68,6 +68,8 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) { } } +void TypeVisitor::VisitType_(const TypeOfNode* op) {} + // Type Mutator. Array TypeMutator::MutateArray(Array arr) { // The array will do copy on write @@ -172,6 +174,10 @@ Type TypeMutator::VisitType_(const TypeDataNode* op) { return GetRef(op); } +Type TypeMutator::VisitType_(const TypeOfNode* op) { + return GetRef(op); +} + // Implements bind. class TypeBinder : public TypeMutator { public: diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index 9a6f18a09cc7f..ece6f75a64355 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -72,6 +72,7 @@ class TypeFunctor { virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeOfNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const Node* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->type_key(); @@ -93,6 +94,7 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeOfNode); return vtable; } }; @@ -111,6 +113,7 @@ class TypeVisitor : public TypeFunctor { void VisitType_(const GlobalTypeVarNode* op) override; void VisitType_(const TypeCallNode* op) override; void VisitType_(const TypeDataNode* op) override; + void VisitType_(const TypeOfNode* op) override; }; // Mutator that transform a type to another one. @@ -125,6 +128,7 @@ class TypeMutator : public TypeFunctor { Type VisitType_(const GlobalTypeVarNode* op) override; Type VisitType_(const TypeCallNode* op) override; Type VisitType_(const TypeDataNode* op) override; + Type VisitType_(const TypeOfNode* op) override; private: Array MutateArray(Array arr); diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 9c9b434fb204d..8b7ba4515e1ae 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -156,6 +156,10 @@ struct KindChecker : TypeFunctor { return Kind::kTypeData; } + Kind VisitType_(const TypeOfNode* op) override { + return kType; + } + Kind Check(const Type& t) { return this->VisitType(t); } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 78766a0fc77c3..490727aedfb54 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -74,6 +74,9 @@ struct ResolvedTypeInfo { Array type_args = Array(NodePtr(nullptr)); }; +class TypeInferencer; +Type ExpandTypeOf(TypeInferencer* infer, const Type& t); + // // The inference algorithm can roughly be devided into three stages: // - Populate the constraints by visiting the expression (TypeInferencer.GetType) @@ -83,6 +86,7 @@ struct ResolvedTypeInfo { // class TypeInferencer : private ExprFunctor, private PatternFunctor { + friend struct ExpandTypeOfMutator; public: // constructors @@ -116,12 +120,17 @@ class TypeInferencer : private ExprFunctor, TypeRelationFn tuple_getitem_rel_; TypeRelationFn make_tuple_rel_; + inline Type ExpandTypeOf(const Type& t) { + return ::tvm::relay::ExpandTypeOf(this, t); + } + // Perform unification on two types and report the error at the expression // or the span of the expression. Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) { // TODO(tqchen, jroesch): propagate span to solver + try { - return solver_.Unify(t1, t2, expr); + return solver_.Unify(ExpandTypeOf(t1), ExpandTypeOf(t2), expr); } catch (const dmlc::Error &e) { this->ReportFatalError( expr, @@ -157,7 +166,7 @@ class TypeInferencer : private ExprFunctor, // Visitor Logic Type VisitExpr_(const VarNode* op) final { if (op->type_annotation.defined()) { - return op->type_annotation; + return ExpandTypeOf(op->type_annotation); } else { return IncompleteTypeNode::make(Kind::kType); } @@ -718,6 +727,19 @@ Function InferType(const Function& func, return Downcast(func_ret); } +struct ExpandTypeOfMutator : TypeMutator { + TypeInferencer* infer; + Type VisitType_(const TypeOfNode* type_of) final { + return this->infer->GetType(type_of->expr); + } + ExpandTypeOfMutator(TypeInferencer* infer) : infer(infer) {} +}; + +inline Type ExpandTypeOf(TypeInferencer* infer, const Type& t) { + auto expand = ExpandTypeOfMutator(infer); + return expand.VisitType(t); +} + TVM_REGISTER_API("relay._ir_pass.infer_type") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = InferType(args[0], args[1]); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index cebdd8fca82de..55494f4348cf8 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -194,7 +194,6 @@ class TypeSolver::Unifier : public TypeFunctor { } return TypeCallNode::make(func, args); } - private: TypeSolver* solver_; };